// src/memory/processing/ai.ts

/**
 * @file processing/ai.ts
 * AI-powered memory processing using the loaded LM Studio model.
 *
 * Provides: fact extraction, conflict detection.
 * All calls are optional (graceful fallback if no model loaded).
 * All calls have strict timeouts to never block the chat flow.
 */

import { LMStudioClient } from "@lmstudio/sdk";
import {
  AI_EXTRACT_MAX_TOKENS,
  AI_EXTRACT_TEMPERATURE,
  AI_CALL_TIMEOUT_MS,
  AI_CONFLICT_MAX_TOKENS,
  AI_CONFLICT_TEMPERATURE,
  VALID_CATEGORIES,
  type MemoryCategory,
} from "../constants";
import type { MemoryConflict, MemoryRecord } from "../types";

/**
 * Lazily created SDK client shared by every call in this module.
 * Reset to null elsewhere in the file to force a fresh client on next use.
 */
let cachedClient: LMStudioClient | null = null;

/**
 * Returns the shared LM Studio client, constructing it on first use
 * (or again after the cache has been cleared).
 */
function getClient(): LMStudioClient {
  if (cachedClient === null) {
    cachedClient = new LMStudioClient();
  }
  return cachedClient;
}

/**
 * Sends a single-turn prompt to the first loaded LM Studio model and collects
 * the streamed reply.
 *
 * The whole call — listing loaded models, resolving the model handle and
 * streaming the response — is bounded by one deadline of `timeoutMs`, so a
 * slow or stalled generation cannot hold up the caller.
 *
 * @param prompt      User-role message content sent to the model.
 * @param maxTokens   Generation cap forwarded to the SDK.
 * @param temperature Sampling temperature forwarded to the SDK.
 * @param timeoutMs   Overall deadline in milliseconds (defaults to AI_CALL_TIMEOUT_MS).
 * @returns The trimmed reply and its length, or null when no model is loaded,
 *          the reply is empty, the deadline expires, or any SDK call fails.
 *          Never throws.
 */
async function callModelWithMeta(
  prompt: string,
  maxTokens: number,
  temperature: number,
  timeoutMs: number = AI_CALL_TIMEOUT_MS,
): Promise<{ text: string; charCount: number } | null> {
  // Shared between the worker and the deadline timer so the worker can stop
  // early and an in-flight prediction can be cancelled once time is up.
  const state: { expired: boolean; cancel: (() => void) | null } = {
    expired: false,
    cancel: null,
  };
  let timer: ReturnType<typeof setTimeout> | undefined;

  const deadline = new Promise<never>((_, reject) => {
    timer = setTimeout(() => {
      state.expired = true;
      state.cancel?.();
      reject(new Error("timeout"));
    }, timeoutMs);
  });

  const generate = async (): Promise<string | null> => {
    const client = getClient();
    const models = await client.llm.listLoaded();
    if (state.expired || !Array.isArray(models) || models.length === 0) {
      return null;
    }

    const model = await client.llm.model(models[0].identifier);
    // Do not start a generation nobody is waiting for any more.
    if (state.expired) return null;

    const prediction = model.respond([{ role: "user", content: prompt }], {
      maxTokens,
      temperature,
    });
    state.cancel = () => {
      try {
        // Best effort: the caller has already been released by the deadline,
        // so a failing cancel must not surface anywhere.
        void Promise.resolve(prediction.cancel()).catch(() => undefined);
      } catch {
        // Ignored on purpose (see above).
      }
    };

    let text = "";
    for await (const chunk of prediction) {
      if (state.expired) break;
      text += chunk.content ?? "";
    }
    return text;
  };

  try {
    const text = await Promise.race([generate(), deadline]);
    const trimmed = text?.trim();
    if (!trimmed) return null;
    return { text: trimmed, charCount: trimmed.length };
  } catch {
    // Drop the client after any failure (including a timeout) so the next
    // call starts from a fresh connection.
    cachedClient = null;
    return null;
  } finally {
    // Without this, every successful call would leave a live timer behind
    // for up to timeoutMs.
    clearTimeout(timer);
  }
}

/**
 * Text-only convenience wrapper around callModelWithMeta, kept for backward
 * compatibility. Resolves to the reply text, or null when no reply was produced.
 */
async function callModel(
  prompt: string,
  maxTokens: number,
  temperature: number,
  timeoutMs?: number,
): Promise<string | null> {
  const meta = await callModelWithMeta(prompt, maxTokens, temperature, timeoutMs);
  return meta === null ? null : meta.text;
}

/** A single fact parsed from the model's extraction output (see extractFacts). */
export interface ExtractedFact {
  /** The fact text; extractFacts only emits contents of 5 to 500 characters. */
  content: string;
  /** Memory category; extractFacts falls back to "note" when the model's value is not in VALID_CATEGORIES. */
  category: MemoryCategory;
  /** Lower-cased tags; extractFacts keeps at most 5, each 2 to 50 characters long. */
  tags: string[];
  /** Confidence score; extractFacts clamps parsed values to the range 0–1 and uses 0.7 when the model gives none. */
  confidence: number;
}

/**
 * Parses one line of model output in the
 * `FACT: … | CATEGORY: … | TAGS: … | CONFIDENCE: …` format.
 *
 * @returns The parsed fact, or null when the line is not a FACT line or its
 *          content is shorter than 5 or longer than 500 characters.
 */
function parseFactLine(line: string): ExtractedFact | null {
  const trimmed = line.trim();
  if (!trimmed.startsWith("FACT:")) return null;

  // Every field runs up to the next "|" separator or the end of the line, so
  // a field is still picked up when the model omits the ones after it.
  const content = (/^FACT:\s*([^|]+)/.exec(trimmed)?.[1] ?? "").trim();
  if (content.length < 5 || content.length > 500) return null;

  const rawCat =
    /CATEGORY:\s*(\w+)/i.exec(trimmed)?.[1]?.toLowerCase() ?? "note";
  const category = VALID_CATEGORIES.includes(rawCat as MemoryCategory)
    ? (rawCat as MemoryCategory)
    : "note";

  const tags = (/TAGS:\s*([^|]*)/i.exec(trimmed)?.[1] ?? "")
    .split(",")
    .map((t) => t.trim().toLowerCase())
    .filter((t) => t.length >= 2 && t.length <= 50)
    .slice(0, 5);

  // "[\d.]+" also matches strings such as "." that parseFloat turns into NaN;
  // treat those like a missing value instead of letting NaN escape.
  const parsed = parseFloat(
    /CONFIDENCE:\s*([\d.]+)/i.exec(trimmed)?.[1] ?? "",
  );
  const confidence = Number.isFinite(parsed)
    ? Math.max(0, Math.min(1, parsed))
    : 0.7;

  return { content, category, tags, confidence };
}

/**
 * Extract structured facts from conversation text.
 * Called by the preprocessor when AI extraction is enabled.
 *
 * Only the first 2000 characters of the conversation are sent to the model.
 *
 * @param conversationText Raw conversation text to mine for facts.
 * @param existingSummary  Optional summary of already-known information; when
 *                         non-empty the model is told to extract only new facts.
 * @returns The facts parsed from the model's reply. Empty when the model call
 *          yields nothing, the reply contains a line reading just "NONE", or
 *          no line parses as a fact.
 */
export async function extractFacts(
  conversationText: string,
  existingSummary: string = "",
): Promise<ExtractedFact[]> {
  const prompt = `You are a memory extraction system. Extract key facts, preferences, and information from this conversation that would be useful to remember for future conversations.

CONVERSATION:
${conversationText.slice(0, 2000)}

${existingSummary ? `ALREADY KNOWN:\n${existingSummary}\n\nOnly extract NEW information not already known.` : ""}

For each fact, output ONE line in this exact format:
FACT: <the information> | CATEGORY: <${VALID_CATEGORIES.join("/")}> | TAGS: <comma-separated tags> | CONFIDENCE: <0.0-1.0>

Rules:
- Extract only clearly stated facts, not guesses or ambiguous statements
- CONFIDENCE: 1.0 = explicitly stated, 0.7 = strongly implied, 0.5 = inferred
- Keep each fact concise (one sentence)
- Maximum 5 facts per extraction
- If no useful facts to extract, output: NONE

OUTPUT:`;

  const raw = await callModel(
    prompt,
    AI_EXTRACT_MAX_TOKENS,
    AI_EXTRACT_TEMPERATURE,
  );
  if (!raw || /^NONE$/im.test(raw.trim())) return [];

  const facts: ExtractedFact[] = [];
  for (const line of raw.split("\n")) {
    const fact = parseFactLine(line);
    if (fact) facts.push(fact);
  }
  return facts;
}

/**
 * Asks the model whether a new memory conflicts with existing ones.
 *
 * Only the first 8 existing memories are shown to the model, numbered from 1;
 * indices in the reply are resolved against that same shown list, so the
 * model can never point at a memory it was not given.
 *
 * @param newContent       Text of the memory about to be stored.
 * @param existingMemories Candidate memories to compare against.
 * @returns One entry per valid CONFLICT line in the reply. Empty when there
 *          are no existing memories, the model call yields nothing, or the
 *          reply contains a line reading just "NONE". An unrecognised TYPE
 *          defaults to "contradiction" and an unrecognised ACTION to
 *          "keep_both".
 */
export async function detectConflicts(
  newContent: string,
  existingMemories: MemoryRecord[],
): Promise<MemoryConflict[]> {
  if (existingMemories.length === 0) return [];

  // The exact list the model sees; reply indices are only meaningful for it.
  const shown = existingMemories.slice(0, 8);
  const existing = shown
    .map((m, i) => `[${i + 1}] (id=${m.id}) ${m.content}`)
    .join("\n");

  const prompt = `You are a memory conflict detector. Check if this NEW memory conflicts with any EXISTING memories.

NEW MEMORY: "${newContent}"

EXISTING MEMORIES:
${existing}

For each conflict found, output ONE line:
CONFLICT: <existing_index> | TYPE: <contradiction/update/duplicate> | ACTION: <keep_both/supersede/skip>

Rules:
- "contradiction": memories disagree on a fact → keep_both (let user decide)
- "update": new memory is a newer version of the same fact → supersede
- "duplicate": essentially the same information → skip
- If no conflicts: output NONE
- Max 3 conflicts

OUTPUT:`;

  const raw = await callModel(
    prompt,
    AI_CONFLICT_MAX_TOKENS,
    AI_CONFLICT_TEMPERATURE,
  );
  if (!raw || /^NONE$/im.test(raw.trim())) return [];

  const conflictTypes = ["contradiction", "update", "duplicate"] as const;
  const resolutions = ["keep_both", "supersede", "skip"] as const;

  const conflicts: MemoryConflict[] = [];
  for (const line of raw.split("\n")) {
    const trimmed = line.trim();
    if (!trimmed.startsWith("CONFLICT:")) continue;

    const idxMatch = /CONFLICT:\s*(\d+)/i.exec(trimmed);
    if (!idxMatch) continue;

    // Convert the 1-based index from the prompt to 0-based. Anything outside
    // the shown list (0, or a number past its end) yields undefined here.
    const mem = shown[parseInt(idxMatch[1], 10) - 1];
    if (mem === undefined) continue;

    const rawType = /TYPE:\s*(\w+)/i.exec(trimmed)?.[1]?.toLowerCase();
    const rawAction = /ACTION:\s*([\w_]+)/i.exec(trimmed)?.[1]?.toLowerCase();

    const conflictType =
      conflictTypes.find((t) => t === rawType) ?? "contradiction";
    const resolution = resolutions.find((r) => r === rawAction) ?? "keep_both";

    conflicts.push({
      existingId: mem.id,
      existingContent: mem.content,
      newContent,
      conflictType,
      resolution,
    });
  }

  return conflicts;
}