// src/services/rag/reranker.ts

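// `@xenova/transformers` is ESM-only, and this module runs in a CommonJS environment,
// so a static `import` would fail with ERR_REQUIRE_ESM. The library is therefore loaded
// with a dynamic `import()`, and the library objects are typed as `any`.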

// Lazily-populated handles: the dynamically imported library module, plus the tokenizer
// and model created from it. All start out null and are typed `any` because the library
// is only reachable through a dynamic import.
let transformers: any = null,
    tokenizer: any = null,
    model: any = null;

/** Hugging Face Hub id of the cross-encoder checkpoint used for reranking. */
const MODEL_ID = 'Xenova/bge-reranker-base';

/**
 * One scored document produced by `rerank`.
 */
export interface RerankResult {
    /** Position of the document in the `documents` array that was passed to `rerank`. */
    index: number;
    /** Relevance of the document to the query; `rerank` sorts results by this, highest first. */
    score: number;
}

/**
 * Lazily imports `@xenova/transformers`, configures its global `env` once, and caches
 * the module in the module-level `transformers` variable.
 *
 * @returns The library module namespace (typed `any`; see the note at the top of the file).
 */
async function getTransformers() {
    if (!transformers) {
        // Dynamic import to bypass ERR_REQUIRE_ESM.
        // NOTE(review): this only stays a real `import()` if the TypeScript compiler preserves it
        // (e.g. `module: node16`/`nodenext`). Under `module: commonjs` tsc rewrites it to
        // `require()` and the ESM error comes back — confirm against tsconfig.
        const lib: any = await import('@xenova/transformers');

        // Do not look for model files under the library's local model path; always resolve
        // models through the remote (download + cache) path.
        lib.env.allowLocalModels = false;

        // The browser Cache API (`caches`) does not exist in Node. When `useBrowserCache` is
        // true there, transformers.js throws "Browser cache is not available in this
        // environment." while fetching model files. Only enable it where the API exists.
        lib.env.useBrowserCache = 'caches' in globalThis;

        // Publish the handle only after it has been configured.
        transformers = lib;
    }
    return transformers;
}

// Pending (or completed) model load, shared by concurrent callers so the model is only
// fetched and instantiated once. Reset to null on failure so a later call can retry.
let rerankerLoad: Promise<void> | null = null;

/**
 * Ensures the reranker tokenizer and model are loaded into the module-level `tokenizer`
 * and `model` variables. Safe to call repeatedly and concurrently: callers that arrive
 * while a load is in progress await the same load, not a second one.
 *
 * @throws Whatever error the library raises if the tokenizer or model cannot be loaded.
 */
export async function loadReranker() {
    if (tokenizer && model) return;

    if (!rerankerLoad) {
        rerankerLoad = (async () => {
            console.log("Loading Reranker Model (this may take a while on first run)...");
            const { AutoTokenizer, AutoModelForSequenceClassification } = await getTransformers();

            // Load both parts, then publish them together so no caller ever sees a
            // tokenizer without its model (or vice versa) after a partial failure.
            const [loadedTokenizer, loadedModel] = await Promise.all([
                AutoTokenizer.from_pretrained(MODEL_ID),
                AutoModelForSequenceClassification.from_pretrained(MODEL_ID),
            ]);
            tokenizer = loadedTokenizer;
            model = loadedModel;
            console.log("Reranker Model Loaded.");
        })();

        // Forget a failed load so the next call retries. The rejection still reaches every
        // caller awaiting `rerankerLoad` below.
        rerankerLoad.catch(() => {
            rerankerLoad = null;
        });
    }

    await rerankerLoad;
}

/**
 * Scores every document against `query` with the cross-encoder and returns them best-first.
 *
 * @param query - The search query.
 * @param documents - Candidate passages. Each result's `index` refers to a position in this array.
 * @returns One entry per document, sorted by descending `score`. The score is the sigmoid of
 *   the model's logit, so it lies in [0, 1]. Returns `[]` when `documents` is empty, without
 *   loading the model.
 * @throws Error if the model fails to load, or if it does not produce exactly one score per document.
 */
export async function rerank(query: string, documents: string[]): Promise<RerankResult[]> {
    // Nothing to score: skip the (expensive) model load and the empty-batch tokenizer call.
    if (documents.length === 0) return [];

    await loadReranker();
    if (!tokenizer || !model) throw new Error("Model failed to load");

    // Cross-encoder input is one (query, document) pair per document; the tokenizer pairs
    // queries[i] with text_pair[i].
    const queries = documents.map(() => query);

    // Batch processing
    const inputs = await tokenizer(queries, {
        text_pair: documents,
        padding: true,
        truncation: true,
    });

    const output = await model(inputs);
    // Sigmoid to get 0-1 scores
    const scores: ArrayLike<number> = output.logits.sigmoid().data;

    // A single-logit head yields exactly one score per pair. Anything else would make the
    // indices below point past the `documents` array, so fail loudly instead.
    if (scores.length !== documents.length) {
        throw new Error(
            `Reranker produced ${scores.length} scores for ${documents.length} documents`
        );
    }

    // Map back to indices, then sort descending by score.
    return Array.from(scores, (score, index): RerankResult => ({ index, score }))
        .sort((a, b) => b.score - a.score);
}