// src/memory/retrieval/embedding.ts

/**
 * @file retrieval/embedding.ts
 * Embedding-based semantic retrieval via local LM Studio API.
 *
 * Uses the loaded embedding model (e.g., nomic-embed-text) to compute
 * dense vector similarity. Falls back gracefully if no embedding model
 * is available.
 */

const EMBEDDING_ENDPOINT = "http://localhost:1234/v1/embeddings"; // LM Studio's OpenAI-compatible embeddings route
const EMBEDDING_TIMEOUT_MS = 5_000; // abort the request after 5s so retrieval degrades instead of hanging

/**
 * Cosine similarity between two vectors.
 *
 * Iterates over the shared prefix (min of both lengths) so a dimension
 * mismatch degrades to a partial score instead of silently producing NaN
 * from out-of-bounds `undefined` reads.
 *
 * @returns Similarity in [-1, 1]; 0 when either vector has zero magnitude.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  const len = Math.min(a.length, b.length);
  let dot = 0, normA = 0, normB = 0;
  for (let i = 0; i < len; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  const denom = Math.sqrt(normA) * Math.sqrt(normB);
  // Guard the zero-vector case: 0/0 would be NaN.
  return denom === 0 ? 0 : dot / denom;
}

/**
 * Cache of document embeddings keyed by memory ID.
 * NOTE(review): unbounded — entries leave only via removeEmbedding /
 * clearEmbeddingCache; consider an LRU cap if the memory set grows large.
 */
const embeddingCache = new Map<string, number[]>();

/**
 * Tri-state availability flag: null = not yet probed; true/false = result
 * of the most recent getEmbeddings call (updated on every call, not just
 * the first — "checked once" holds only while calls keep succeeding).
 */
let embeddingAvailable: boolean | null = null;

/**
 * Call the local embedding API for a batch of texts.
 *
 * Updates the module-level `embeddingAvailable` flag as a side effect:
 * true on a successful response, false on any HTTP or network failure.
 *
 * @param texts Strings to embed (order of the result matches input order).
 * @returns One vector per input text, or null if the API is unreachable,
 *          returns a non-2xx status, or times out.
 */
async function getEmbeddings(texts: string[]): Promise<number[][] | null> {
  const controller = new AbortController();
  // Abort the whole request (including body read) if it runs too long.
  const timeout = setTimeout(() => controller.abort(), EMBEDDING_TIMEOUT_MS);
  try {
    const response = await fetch(EMBEDDING_ENDPOINT, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ input: texts }),
      signal: controller.signal,
    });

    if (!response.ok) {
      embeddingAvailable = false;
      return null;
    }

    // Minimal typing of the OpenAI-style response; a malformed body throws
    // and is handled by the catch below.
    const data = (await response.json()) as {
      data: Array<{ embedding: number[] }>;
    };
    embeddingAvailable = true;
    return data.data.map((d) => d.embedding);
  } catch {
    // Network error, abort, or unparseable body — report unavailable.
    embeddingAvailable = false;
    return null;
  } finally {
    // Always clear the timer: the original leaked it (and a late abort)
    // whenever fetch or JSON parsing threw.
    clearTimeout(timeout);
  }
}

/**
 * Probe whether the embedding API is reachable.
 *
 * Returns the memoized result when a previous call to the API has already
 * recorded success or failure; otherwise issues a one-text probe request.
 */
export async function isEmbeddingAvailable(): Promise<boolean> {
  if (embeddingAvailable === null) {
    const probe = await getEmbeddings(["test"]);
    return probe !== null;
  }
  return embeddingAvailable;
}

/**
 * Drop the cached vector for a single memory ID.
 * No-op when the ID was never cached.
 */
export function removeEmbedding(id: string): void {
  void embeddingCache.delete(id);
}

/**
 * Discard every cached embedding at once — used when the underlying
 * memory index is rebuilt and all vectors may be stale.
 */
export function clearEmbeddingCache(): void {
  embeddingCache.clear();
}

/**
 * Semantic search: compute the query embedding and rank candidate memories
 * by cosine similarity, reusing cached candidate vectors where possible.
 *
 * @param query User query string
 * @param candidates Array of { id, content } to rank
 * @param limit Max results to return (values <= 0 yield an empty result)
 * @returns Sorted [id, similarity] pairs (descending), or null if the
 *          embedding API is unavailable or returned a malformed response
 */
export async function semanticSearch(
  query: string,
  candidates: Array<{ id: string; content: string }>,
  limit: number,
): Promise<Array<[string, number]> | null> {
  if (candidates.length === 0) return [];

  // Embed the query plus only the candidates we have not seen before.
  const uncached = candidates.filter((c) => !embeddingCache.has(c.id));
  const textsToEmbed = [query, ...uncached.map((c) => c.content)];

  const embeddings = await getEmbeddings(textsToEmbed);
  // Treat a short/partial response as failure: the original indexed
  // embeddings[i + 1] unchecked, which could cache `undefined` vectors
  // and silently corrupt future rankings.
  if (!embeddings || embeddings.length !== textsToEmbed.length) return null;

  const queryEmbedding = embeddings[0];

  // Cache the freshly computed candidate vectors (offset by 1 for the query).
  uncached.forEach((c, i) => embeddingCache.set(c.id, embeddings[i + 1]));

  // Score every candidate that has a vector (all should, after caching).
  const scored: Array<[string, number]> = [];
  for (const { id } of candidates) {
    const vec = embeddingCache.get(id);
    if (!vec) continue;
    scored.push([id, cosineSimilarity(queryEmbedding, vec)]);
  }

  scored.sort((a, b) => b[1] - a[1]);
  // Clamp limit: slice(0, negative) would return "all but the last N".
  return scored.slice(0, Math.max(0, limit));
}