// src/retrieval/engine.ts

/**
 * @file retrieval/engine.ts
 * Retrieval engine: combines TF-IDF similarity, memory decay, access
 * frequency, and confidence into a single composite score.
 *
 * Inspired by the SRLM paper's insight that multiple uncertainty signals
 * (self-consistency, trace length, confidence) outperform any single one.
 * We analogously blend multiple retrieval signals.
 */

import { TfIdfIndex } from "./tfidf";
import { MemoryDatabase } from "../storage/db";
import { srlmRerank } from "../processing/ai";
import {
  DECAY_HALF_LIFE_DAYS,
  DECAY_WEIGHT,
  FREQUENCY_WEIGHT,
  SIMILARITY_WEIGHT,
  CONFIDENCE_WEIGHT,
  MIN_RELEVANCE_THRESHOLD,
  MAX_SEARCH_RESULTS,
} from "../constants";
import type { MemoryRecord, ScoredMemory, RetrievalResult } from "../types";

/**
 * Compute exponential decay score based on time since last access.
 * Returns a value in [0, 1] where 1 = just accessed, 0 = very old.
 */
/**
 * Compute an exponential recency score from the time since last access.
 *
 * score = 2^(-ageDays / halfLifeDays): the score halves every
 * `halfLifeDays` days since `lastAccessedAt`.
 *
 * @param lastAccessedAt Last-access timestamp in epoch milliseconds (same
 *   clock as `Date.now()`).
 * @param halfLifeDays Half-life in days. A non-positive or NaN half-life is
 *   treated as instantaneous: only an age of zero scores 1, anything older 0.
 * @param now Reference "current time" in epoch ms; defaults to `Date.now()`.
 *   Lets callers and tests pin the clock.
 * @returns A value in [0, 1]. It is 1 when the memory was just accessed (or
 *   the timestamp lies in the future) and approaches 0 as the memory ages.
 *   It is 0 if the age cannot be computed (NaN input).
 */
function computeDecay(
  lastAccessedAt: number,
  halfLifeDays: number,
  now: number = Date.now(),
): number {
  const msPerDay = 24 * 60 * 60 * 1000;
  // Clamp future timestamps (clock skew, imported data) to age 0, so the
  // score never exceeds 1 and cannot outweigh the other composite signals.
  const ageDays = Math.max(0, (now - lastAccessedAt) / msPerDay);
  // Math.max(0, NaN) is NaN; keep NaN out of the composite score/sort.
  if (Number.isNaN(ageDays)) return 0;
  if (!(halfLifeDays > 0)) {
    // Avoids 0/0 = NaN, and a negative half-life that would make the score
    // grow with age.
    return ageDays === 0 ? 1 : 0;
  }
  return Math.pow(2, -ageDays / halfLifeDays);
}

/**
 * Normalize access count to [0, 1] using logarithmic scaling.
 * This prevents a single heavily-accessed memory from dominating.
 */
/**
 * Normalize an access count to [0, 1] using logarithmic scaling.
 *
 * Log scaling keeps a single heavily-accessed memory from dominating the
 * frequency signal.
 *
 * @param accessCount The memory's access count. Non-positive or NaN counts
 *   score 0.
 * @param maxAccessCount The largest access count used as the reference.
 *   A non-positive or NaN maximum yields 0.
 * @returns log(1 + count) / log(1 + max). The count is clamped to
 *   [0, max] first, so the result never leaves [0, 1]. This matters when the
 *   caller's max is stale, e.g. the count exceeds it.
 */
function normalizeFrequency(
  accessCount: number,
  maxAccessCount: number,
): number {
  if (!(maxAccessCount > 0)) return 0;
  if (!(accessCount > 0)) return 0;
  const clamped = Math.min(accessCount, maxAccessCount);
  // log1p(x) == log(1 + x), computed without precision loss for small x.
  return Math.log1p(clamped) / Math.log1p(maxAccessCount);
}

/**
 * Ranks stored memories by a weighted blend of four signals: TF-IDF
 * similarity, recency decay, log-scaled access frequency, and stored
 * confidence. Optionally adds an SRLM re-ranking pass on top.
 *
 * The engine keeps an in-process TF-IDF index over each memory's content,
 * tags, and category. Callers keep it in sync via `rebuildIndex`,
 * `indexMemory`, and `removeFromIndex`.
 */
export class RetrievalEngine {
  /** Number of top phase-1 candidates sent to the SRLM re-ranker. */
  private static readonly srlmCandidateCount = 12;
  /** Weight of the phase-1 composite score in the SRLM blend. */
  private static readonly srlmCompositeWeight = 0.6;
  /** Weight of the SRLM score in the SRLM blend. */
  private static readonly srlmScoreWeight = 0.4;
  /** Multiplier for candidates the re-ranker did not return a score for. */
  private static readonly srlmUnscoredPenalty = 0.5;

  private readonly tfIdf: TfIdfIndex;
  private readonly db: MemoryDatabase;
  /**
   * Largest access count observed so far, never below 1. It is the
   * denominator for frequency normalization. It is reset only by
   * `rebuildIndex`; otherwise it only grows.
   */
  private maxAccessCount = 1;

  constructor(db: MemoryDatabase) {
    this.db = db;
    this.tfIdf = new TfIdfIndex();
  }

  /** Text indexed for a memory: content, then tags, then category. */
  private static documentText(
    content: string,
    tags: readonly string[],
    category: string,
  ): string {
    return `${content} ${tags.join(" ")} ${category}`;
  }

  /** Lowercase whitespace-separated query terms of length >= 2. */
  private static extractQueryTerms(query: string): string[] {
    return query
      .toLowerCase()
      .split(/\s+/)
      .filter((t) => t.length >= 2);
  }

  /** Raise `maxAccessCount` so it covers every record given. */
  private observeAccessCounts(
    records: Iterable<{ readonly accessCount: number }>,
  ): void {
    for (const rec of records) {
      if (rec.accessCount > this.maxAccessCount) {
        this.maxAccessCount = rec.accessCount;
      }
    }
  }

  /** Update access counters for `ids`; a failure never breaks retrieval. */
  private touchBestEffort(ids: string[]): void {
    if (ids.length === 0) return;
    try {
      this.db.touchAccessBatch(ids);
    } catch {
      // Deliberately ignored: access bookkeeping is best-effort and must not
      // make an otherwise successful retrieval fail.
    }
  }

  /**
   * Build/rebuild the TF-IDF index from stored memories, and recompute
   * `maxAccessCount` from them.
   * Intended to run once at startup; afterwards the index is maintained via
   * `indexMemory` / `removeFromIndex`.
   */
  rebuildIndex(): void {
    this.tfIdf.clear();
    // NOTE(review): getAll is called with 10,000, presumably a row cap;
    // memories beyond it would not be indexed. TODO confirm.
    const allMemories = this.db.getAll(10_000);
    this.maxAccessCount = 1;
    for (const mem of allMemories) {
      this.tfIdf.addDocument(
        mem.id,
        RetrievalEngine.documentText(mem.content, mem.tags, mem.category),
      );
    }
    this.observeAccessCounts(allMemories);
  }

  /** Add a single memory to the index (incremental). */
  indexMemory(
    id: string,
    content: string,
    tags: string[],
    category: string,
  ): void {
    this.tfIdf.addDocument(
      id,
      RetrievalEngine.documentText(content, tags, category),
    );
  }

  /** Remove a memory from the index. */
  removeFromIndex(id: string): void {
    this.tfIdf.removeDocument(id);
  }

  /**
   * Retrieve memories ranked by composite score.
   * Used by both the prompt preprocessor and the explicit tools.
   *
   * Search runs in two stages:
   * - TF-IDF is searched first, over-fetching up to min(3 × limit, 100)
   *   candidates.
   * - If TF-IDF finds nothing, the database full-text search is used as a
   *   fallback. Its rows get a flat similarity of 0.5.
   *
   * @param limit Maximum number of memories returned. A non-positive or NaN
   *   limit returns an empty result.
   * @param halfLifeDays Half-life passed to the recency decay.
   * @param touchAccess If false, skip updating access counters.
   *   The preprocessor sets this to false to prevent auto-inject
   *   from artificially inflating access counts.
   */
  retrieve(
    query: string,
    limit: number = MAX_SEARCH_RESULTS,
    halfLifeDays: number = DECAY_HALF_LIFE_DAYS,
    touchAccess: boolean = true,
  ): RetrievalResult {
    const start = performance.now();

    // A negative count must never reach search() or slice(). slice(0, -k)
    // would return "all but the last k" rather than nothing.
    if (!(limit > 0)) {
      return this.scoreAndRank(
        [],
        null,
        0,
        halfLifeDays,
        start,
        query,
        undefined,
        false,
      );
    }

    // Over-fetch so composite scoring has room to reorder TF-IDF hits.
    const candidateLimit = Math.min(limit * 3, 100);
    const tfIdfResults = this.tfIdf.search(query, candidateLimit);

    if (tfIdfResults.length === 0) {
      const ftsResults = this.db.ftsSearch(query, limit);
      // These rows may never have been seen by rebuildIndex or the TF-IDF
      // path; include them so frequency normalization stays within [0, 1].
      this.observeAccessCounts(ftsResults);
      // An empty ftsResults flows through scoreAndRank too, so the
      // "no matches" result has the same shape (incl. queryTerms).
      return this.scoreAndRank(
        ftsResults,
        0.5,
        limit,
        halfLifeDays,
        start,
        query,
        undefined,
        touchAccess,
      );
    }

    const ids = tfIdfResults.map(([id]) => id);
    const memories = this.db.getByIds(ids);
    const similarityMap = new Map<string, number>();
    for (const [docId, score] of tfIdfResults) {
      similarityMap.set(docId, score);
    }
    this.observeAccessCounts(memories);

    return this.scoreAndRank(
      memories,
      null,
      limit,
      halfLifeDays,
      start,
      query,
      similarityMap,
      touchAccess,
    );
  }

  /**
   * Score each memory, drop those below MIN_RELEVANCE_THRESHOLD, sort by
   * composite score (descending), and keep the top `limit`.
   *
   * Similarity is `flatSimilarity` when non-null. Otherwise it is looked up
   * in `similarityMap`, defaulting to 0. If `touchAccess` is true, the
   * returned memories' access counters are touched (best-effort).
   */
  private scoreAndRank(
    memories: readonly MemoryRecord[],
    flatSimilarity: number | null,
    limit: number,
    halfLifeDays: number,
    startTime: number,
    query: string,
    similarityMap?: Map<string, number>,
    touchAccess: boolean = true,
  ): RetrievalResult {
    const scored: ScoredMemory[] = [];

    for (const mem of memories) {
      const similarity = flatSimilarity ?? similarityMap?.get(mem.id) ?? 0;
      const decay = computeDecay(mem.lastAccessedAt, halfLifeDays);
      const frequency = normalizeFrequency(
        mem.accessCount,
        this.maxAccessCount,
      );
      const confidence = mem.confidence;

      const composite =
        SIMILARITY_WEIGHT * similarity +
        DECAY_WEIGHT * decay +
        FREQUENCY_WEIGHT * frequency +
        CONFIDENCE_WEIGHT * confidence;

      if (composite < MIN_RELEVANCE_THRESHOLD) continue;

      scored.push({
        ...mem,
        relevanceScore: similarity,
        decayScore: decay,
        compositeScore: composite,
      });
    }

    scored.sort((a, b) => b.compositeScore - a.compositeScore);
    const results = scored.slice(0, Math.max(0, limit));

    if (touchAccess) {
      this.touchBestEffort(results.map((m) => m.id));
    }

    return {
      memories: results,
      totalMatched: scored.length,
      queryTerms: RetrievalEngine.extractQueryTerms(query),
      timeTakenMs: performance.now() - startTime,
    };
  }

  get indexStats() {
    return this.tfIdf.stats;
  }

  /**
   * Phase 2 of SRLM retrieval. Blends re-ranker scores into `memories` in
   * place, then re-sorts them.
   *
   * Only the first `srlmCandidateCount` memories are sent to the re-ranker.
   * Any memory without an SRLM score (not sent, or not scored) has its
   * composite score multiplied by `srlmUnscoredPenalty`. If the re-ranker
   * returns no scores at all, the array is left untouched.
   */
  private async applySrlmScores(
    query: string,
    memories: ScoredMemory[],
    K: number,
  ): Promise<void> {
    const candidates = memories
      .slice(0, RetrievalEngine.srlmCandidateCount)
      .map((m) => ({ id: m.id, content: m.content }));

    const srlmScores = await srlmRerank(query, candidates, K);
    if (srlmScores.size === 0) return;

    for (const mem of memories) {
      const srlmScore = srlmScores.get(mem.id);
      mem.compositeScore =
        srlmScore !== undefined
          ? RetrievalEngine.srlmCompositeWeight * mem.compositeScore +
            RetrievalEngine.srlmScoreWeight * srlmScore
          : mem.compositeScore * RetrievalEngine.srlmUnscoredPenalty;
    }

    memories.sort((a, b) => b.compositeScore - a.compositeScore);
  }

  /**
   * SRLM-enhanced retrieval: standard retrieval + AI re-ranking.
   *
   * Two-phase approach:
   * Phase 1: Fast TF-IDF + composite scoring (same as retrieve()) over a
   *   widened pool of min(2 × limit, 30) candidates.
   * Phase 2: SRLM re-ranking via K-candidate self-consistency + VC + trace length
   *
   * The SRLM scores are blended with the composite scores:
   *   finalScore = 0.6 × compositeScore + 0.4 × srlmScore
   *
   * Falls back to the phase-1 ordering if the re-ranker throws (e.g. AI
   * unavailable or timed out).
   *
   * Access counters are touched only for the memories actually returned,
   * after the final cut. Phase 1 deliberately does not touch them, so
   * discarded candidates are not inflated.
   *
   * @param touchAccess If false, skip updating access counters entirely.
   * @returns The final result. `timeTakenMs` covers both phases;
   *   `totalMatched` is the phase-1 count.
   */
  async retrieveWithSRLM(
    query: string,
    limit: number = MAX_SEARCH_RESULTS,
    halfLifeDays: number = DECAY_HALF_LIFE_DAYS,
    K: number = 3,
    touchAccess: boolean = true,
  ): Promise<RetrievalResult> {
    const start = performance.now();

    const baseResult = this.retrieve(
      query,
      Math.min(limit * 2, 30),
      halfLifeDays,
      false,
    );

    if (baseResult.memories.length > 0) {
      try {
        await this.applySrlmScores(query, baseResult.memories, K);
      } catch {
        // Deliberate fallback: keep the phase-1 ordering when re-ranking fails.
      }
    }

    const finalMemories = baseResult.memories.slice(0, Math.max(0, limit));
    if (touchAccess) {
      this.touchBestEffort(finalMemories.map((m) => m.id));
    }

    return {
      ...baseResult,
      memories: finalMemories,
      timeTakenMs: performance.now() - start,
    };
  }
}