// src/memory/retrieval/engine.ts

/**
 * @file retrieval/engine.ts
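 * Retrieval engine: combines query similarity (embedding-based when an
 * embedding backend is available, otherwise TF-IDF or full-text search),
 * memory decay, access frequency, confidence, and keyword overlap into a
 * single composite score.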
 *
 * Inspired by the SRLM paper's insight that multiple uncertainty signals
 * (self-consistency, trace length, confidence) outperform any single one.
 * We analogously blend multiple retrieval signals.
 */

import { TfIdfIndex } from "./tfidf";
import { MemoryDatabase } from "../storage/db";
import {
  semanticSearch,
  isEmbeddingAvailable,
  removeEmbedding,
  clearEmbeddingCache,
} from "./embedding";
import {
  DECAY_HALF_LIFE_DAYS,
  DECAY_WEIGHT,
  FREQUENCY_WEIGHT,
  SIMILARITY_WEIGHT,
  CONFIDENCE_WEIGHT,
  KEYWORD_BOOST_WEIGHT,
  MIN_RELEVANCE_THRESHOLD,
  MAX_SEARCH_RESULTS,
} from "../constants";
import { tokenize } from "./tfidf";
import type { MemoryRecord, ScoredMemory, RetrievalResult } from "../types";

/**
 * Compute an exponential recency score from the time since last access:
 * score = 2^(-t / halfLife), with t in days, so the score halves every
 * `halfLifeDays` days.
 *
 * @param lastAccessedAt Time of last access, in epoch milliseconds.
 * @param halfLifeDays Days for the score to halve. A non-positive (or NaN)
 *   half-life means "no decay window": the score is 1 only when the memory
 *   was accessed at `now` (or later), otherwise 0.
 * @param now Reference time in epoch milliseconds; defaults to Date.now().
 * @returns A value in [0, 1] where 1 = just accessed and values approach 0
 *   as the memory ages.
 */
function computeDecay(
  lastAccessedAt: number,
  halfLifeDays: number,
  now: number = Date.now(),
): number {
  const msPerDay = 24 * 60 * 60 * 1000;
  // Clamp at 0 so a timestamp in the future (clock skew, imported data)
  // cannot push the score above 1.
  const daysSinceAccess = Math.max(0, (now - lastAccessedAt) / msPerDay);
  if (!(halfLifeDays > 0)) {
    // Dividing by a zero/negative half-life would give NaN (0/0) or a
    // score that grows with age; treat it as instant decay instead.
    return daysSinceAccess === 0 ? 1 : 0;
  }
  return Math.pow(2, -daysSinceAccess / halfLifeDays);
}

/**
 * Normalize an access count to [0, 1] using logarithmic scaling, so a single
 * heavily-accessed memory cannot dominate the frequency signal.
 *
 * @param accessCount The memory's access count. Values outside
 *   [0, maxAccessCount] are clamped into that range (a count above the
 *   observed maximum would otherwise score above 1, and a negative count
 *   would produce NaN).
 * @param maxAccessCount The largest access count used as the normalizer.
 * @returns log(1 + count) / log(1 + max), or 0 when max is not positive.
 */
function normalizeFrequency(
  accessCount: number,
  maxAccessCount: number,
): number {
  // `!(x > 0)` also rejects NaN, which `x <= 0` would let through.
  if (!(maxAccessCount > 0)) return 0;
  const clamped = Math.min(Math.max(accessCount, 0), maxAccessCount);
  return Math.log(1 + clamped) / Math.log(1 + maxAccessCount);
}

export class RetrievalEngine {
  /** Hard cap on candidates fetched from a similarity backend before re-ranking. */
  private static readonly MAX_CANDIDATES = 100;
  /** Candidates fetched per requested result, giving composite re-ranking room to reorder. */
  private static readonly CANDIDATE_MULTIPLIER = 3;
  /** Number of valid memories loaded as the candidate pool for embedding search. */
  private static readonly EMBEDDING_POOL_SIZE = 500;
  /** Similarity assigned to full-text-search hits, which carry no per-document score. */
  private static readonly FTS_FLAT_SIMILARITY = 0.5;
  /** Default number of memories loaded by rebuildIndex. */
  private static readonly DEFAULT_REBUILD_LIMIT = 10_000;
  /** Splits a raw query into whitespace-delimited terms. */
  private static readonly WHITESPACE = /\s+/;

  private readonly tfIdf: TfIdfIndex;
  private readonly db: MemoryDatabase;
  /**
   * Largest access count observed so far; the denominator used for frequency
   * normalization. Reset to 1 by rebuildIndex and only raised otherwise.
   */
  private maxAccessCount = 1;

  constructor(db: MemoryDatabase) {
    this.db = db;
    this.tfIdf = new TfIdfIndex();
  }

  /**
   * Build/rebuild the TF-IDF index from stored memories and clear the
   * embedding cache. Called once at startup, then incrementally maintained
   * via indexMemory/removeFromIndex.
   *
   * @param maxDocuments Limit passed to the database when loading memories
   *   (default 10 000); presumably memories beyond it are not indexed.
   */
  rebuildIndex(maxDocuments: number = RetrievalEngine.DEFAULT_REBUILD_LIMIT): void {
    this.tfIdf.clear();
    clearEmbeddingCache();
    const allMemories = this.db.getAll(maxDocuments);
    this.maxAccessCount = 1;
    for (const mem of allMemories) {
      this.tfIdf.addDocument(
        mem.id,
        RetrievalEngine.composeIndexText(mem.content, mem.tags, mem.category),
      );
      this.observeAccessCount(mem.accessCount);
    }
  }

  /** Add a single memory to the index (incremental). */
  indexMemory(
    id: string,
    content: string,
    tags: string[],
    category: string,
  ): void {
    this.tfIdf.addDocument(id, RetrievalEngine.composeIndexText(content, tags, category));
  }

  /** Remove a memory from the TF-IDF index and drop its cached embedding. */
  removeFromIndex(id: string): void {
    this.tfIdf.removeDocument(id);
    removeEmbedding(id);
  }

  /**
   * Retrieve memories ranked by composite score.
   * Used by both the prompt preprocessor and the explicit tools.
   *
   * Similarity candidates come from the first source that yields any hits:
   * embedding search (when available), then the TF-IDF index, then the
   * database's full-text search. Candidates are then re-ranked by a blend of
   * similarity, decay, access frequency, confidence and keyword overlap.
   *
   * @param query Free-text query.
   * @param limit Maximum number of memories to return; a non-positive value
   *   returns an empty result.
   * @param halfLifeDays Half-life used for the recency/decay signal.
   * @param touchAccess If false, skip updating access counters.
   *   The preprocessor sets this to false to prevent auto-inject
   *   from artificially inflating access counts.
   */
  async retrieve(
    query: string,
    limit: number = MAX_SEARCH_RESULTS,
    halfLifeDays: number = DECAY_HALF_LIFE_DAYS,
    touchAccess: boolean = true,
  ): Promise<RetrievalResult> {
    const start = performance.now();

    // A non-positive limit would make `slice(0, limit)` drop results from the
    // end (and touch their access counters) instead of returning none.
    if (!(limit > 0)) return this.emptyResult(query, start);

    const candidateLimit = Math.min(
      limit * RetrievalEngine.CANDIDATE_MULTIPLIER,
      RetrievalEngine.MAX_CANDIDATES,
    );

    // Phase 1: embedding similarity, when the embedding backend is usable.
    const embeddingHits = await this.embeddingSearch(query, candidateLimit);
    if (embeddingHits) {
      return this.rankBySimilarity(embeddingHits, limit, halfLifeDays, start, query, touchAccess);
    }

    // Phase 2: TF-IDF over the in-memory index.
    const tfIdfHits = this.tfIdf.search(query, candidateLimit);
    if (tfIdfHits.length > 0) {
      return this.rankBySimilarity(tfIdfHits, limit, halfLifeDays, start, query, touchAccess);
    }

    // Phase 3: database full-text search; these hits carry no similarity score.
    const ftsResults = this.db.ftsSearch(query, limit);
    if (ftsResults.length === 0) return this.emptyResult(query, start);
    // Keep the frequency normalizer in sync so frequency stays within [0, 1].
    for (const mem of ftsResults) this.observeAccessCount(mem.accessCount);
    return this.scoreAndRank(
      ftsResults,
      RetrievalEngine.FTS_FLAT_SIMILARITY,
      limit,
      halfLifeDays,
      start,
      query,
      undefined,
      touchAccess,
    );
  }

  /**
   * Run embedding search over the pool of valid memories.
   * Returns null when embeddings are unavailable, there is nothing to search,
   * the search yields no hits, or the search throws — so the caller can fall
   * back to TF-IDF.
   */
  private async embeddingSearch(
    query: string,
    candidateLimit: number,
  ): Promise<ReadonlyArray<readonly [string, number]> | null> {
    if (!(await isEmbeddingAvailable())) return null;

    const pool = this.db.getValid(RetrievalEngine.EMBEDDING_POOL_SIZE);
    if (pool.length === 0) return null;

    const candidates = pool.map((m) => ({
      id: m.id,
      content: RetrievalEngine.composeIndexText(m.content, m.tags, m.category),
    }));

    let hits: Awaited<ReturnType<typeof semanticSearch>>;
    try {
      hits = await semanticSearch(query, candidates, candidateLimit);
    } catch (err: unknown) {
      // Embeddings are an optional accelerator; a failure must not reject the
      // whole retrieval when TF-IDF can still answer.
      console.warn("[retrieval] embedding search failed; falling back to TF-IDF:", err);
      return null;
    }
    return hits && hits.length > 0 ? hits : null;
  }

  /**
   * Load the records for `[memoryId, similarity]` hits and rank them by
   * composite score. A repeated id keeps its last similarity score.
   */
  private rankBySimilarity(
    hits: ReadonlyArray<readonly [string, number]>,
    limit: number,
    halfLifeDays: number,
    startTime: number,
    query: string,
    touchAccess: boolean,
  ): RetrievalResult {
    const similarityMap = new Map<string, number>(hits);
    const memories = this.db.getByIds(Array.from(similarityMap.keys()));
    for (const mem of memories) this.observeAccessCount(mem.accessCount);
    return this.scoreAndRank(
      memories,
      null,
      limit,
      halfLifeDays,
      startTime,
      query,
      similarityMap,
      touchAccess,
    );
  }

  /**
   * Score candidates, drop expired or below-threshold ones, and return the
   * top `limit` by composite score.
   *
   * @param memories Candidate records.
   * @param flatSimilarity If non-null, used as the similarity of every
   *   candidate (takes precedence over `similarityMap`).
   * @param limit Maximum number of memories to return.
   * @param halfLifeDays Half-life for the decay signal.
   * @param startTime `performance.now()` at the start of the retrieval.
   * @param query Original query, used for keyword overlap and queryTerms.
   * @param similarityMap Per-id similarity; ids missing from it score 0.
   * @param touchAccess When true, update access counters of returned memories.
   */
  private scoreAndRank(
    memories: MemoryRecord[],
    flatSimilarity: number | null,
    limit: number,
    halfLifeDays: number,
    startTime: number,
    query: string,
    similarityMap?: Map<string, number>,
    touchAccess: boolean = true,
  ): RetrievalResult {
    const scored: ScoredMemory[] = [];
    const now = Date.now();
    const queryTokenSet = new Set(tokenize(query));

    for (const mem of memories) {
      // Skip expired memories
      if (mem.validTo && mem.validTo < now) continue;

      const similarity = flatSimilarity ?? similarityMap?.get(mem.id) ?? 0;
      const decay = computeDecay(mem.lastAccessedAt, halfLifeDays);
      const frequency = normalizeFrequency(mem.accessCount, this.maxAccessCount);
      const keywordBoost = RetrievalEngine.keywordOverlap(queryTokenSet, mem.content);

      const composite =
        SIMILARITY_WEIGHT * similarity +
        DECAY_WEIGHT * decay +
        FREQUENCY_WEIGHT * frequency +
        CONFIDENCE_WEIGHT * mem.confidence +
        KEYWORD_BOOST_WEIGHT * keywordBoost;

      if (composite < MIN_RELEVANCE_THRESHOLD) continue;

      scored.push({
        ...mem,
        relevanceScore: similarity,
        decayScore: decay,
        compositeScore: composite,
      });
    }

    scored.sort((a, b) => b.compositeScore - a.compositeScore);
    const results = scored.slice(0, limit);

    if (touchAccess && results.length > 0) {
      try {
        this.db.touchAccessBatch(results.map((m) => m.id));
      } catch (err: unknown) {
        // Access tracking is best-effort: a failed counter update must not
        // discard an otherwise valid retrieval result, but it should be visible.
        console.warn("[retrieval] failed to update access counters:", err);
      }
    }

    return {
      memories: results,
      totalMatched: scored.length,
      queryTerms: RetrievalEngine.extractQueryTerms(query),
      timeTakenMs: performance.now() - startTime,
    };
  }

  /** Result returned when no candidate memories were found. */
  private emptyResult(query: string, startTime: number): RetrievalResult {
    return {
      memories: [],
      totalMatched: 0,
      queryTerms: RetrievalEngine.extractQueryTerms(query),
      timeTakenMs: performance.now() - startTime,
    };
  }

  /** Raise the frequency-normalization ceiling if `count` exceeds it. */
  private observeAccessCount(count: number): void {
    if (count > this.maxAccessCount) this.maxAccessCount = count;
  }

  /** Text representing a memory in both the TF-IDF index and embedding search. */
  private static composeIndexText(
    content: string,
    tags: readonly string[],
    category: string,
  ): string {
    return `${content} ${tags.join(" ")} ${category}`;
  }

  /** Fraction of distinct query tokens that also occur in `content` (0 for an empty query). */
  private static keywordOverlap(queryTokens: ReadonlySet<string>, content: string): number {
    if (queryTokens.size === 0) return 0;
    const contentTokens = new Set(tokenize(content));
    let overlap = 0;
    for (const token of queryTokens) {
      if (contentTokens.has(token)) overlap++;
    }
    return overlap / queryTokens.size;
  }

  /** Lower-cased, whitespace-split query terms of at least two characters. */
  private static extractQueryTerms(query: string): string[] {
    return query
      .toLowerCase()
      .split(RetrievalEngine.WHITESPACE)
      .filter((t) => t.length >= 2);
  }

  get indexStats() {
    return this.tfIdf.stats;
  }
}