// src/toolsProvider.ts

/**
 * @file toolsProvider.ts
 * Registers all five memory tools with LM Studio.
 *
 * Tools:
 *   1. Remember      — store a new memory
 *   2. Recall        — retrieve memories by topic/query
 *   3. Search Memory — advanced search with filters
 *   4. Forget        — delete memories
 *   5. Memory Status — stats and diagnostics
 */

import { tool } from "@lmstudio/sdk";
import { z } from "zod";
import { configSchematics } from "./config";
import { MemoryDatabase } from "./storage/db";
import { RetrievalEngine } from "./retrieval/engine";
import { extractFacts, detectConflicts } from "./processing/ai";
import {
  VALID_CATEGORIES,
  VALID_SCOPES,
  MAX_MEMORY_CONTENT_LENGTH,
  MAX_TAGS_PER_MEMORY,
  MAX_SEARCH_RESULTS,
  DEFAULT_SEARCH_RESULTS,
  MAX_SESSION_MEMORIES,
  type MemoryCategory,
  type MemoryScope,
} from "./constants";
import type { PluginController } from "./pluginTypes";
import type { MemoryRecord, ScoredMemory } from "./types";

/**
 * Snapshot the plugin settings into a plain object.
 *
 * Toggle settings are compared against the string "on". Numeric settings
 * and the storage path fall back to a default whenever the stored value is
 * falsy (unset, 0, or empty string).
 */
function readConfig(ctl: PluginController) {
  const pluginConfig = ctl.getPluginConfig(configSchematics);

  const autoInject = pluginConfig.get("autoInjectMemories") === "on";
  const contextCount = pluginConfig.get("contextMemoryCount") || 5;
  const enableAI = pluginConfig.get("enableAIExtraction") === "on";
  const enableConflict = pluginConfig.get("enableConflictDetection") === "on";
  const decayDays = pluginConfig.get("decayHalfLifeDays") || 30;
  const storagePath = pluginConfig.get("memoryStoragePath") || "";

  return {
    autoInject,
    contextCount,
    enableAI,
    enableConflict,
    decayDays,
    storagePath,
  };
}

/** Shared singleton instances (initialized once, reused across calls). */
let db: MemoryDatabase | null = null;
let engine: RetrievalEngine | null = null;
/** Storage path that the published `db`/`engine` pair was opened with. */
let currentPath: string = "";
/** Most recent initialization attempt; concurrent callers for the same path share it. */
let initPromise: Promise<{
  db: MemoryDatabase;
  engine: RetrievalEngine;
}> | null = null;
/** Storage path targeted by `initPromise` (differs from `currentPath` while a switch is in flight). */
let initPath: string | null = null;

/**
 * Return the shared database and retrieval engine for `storagePath`, opening
 * (or switching to) that store on first use.
 *
 * - Concurrent calls for the same path share one in-flight initialization.
 * - New instances are published to the module-level singletons only after
 *   `init()` and the index rebuild succeed; the previously published database
 *   is closed at that point. A failed attempt leaves the previous pair
 *   untouched and is not cached, so the next call retries.
 * - If a request for a different path starts while an older attempt is still
 *   running, the older attempt's result is still returned to its own callers
 *   but is not published (and is not closed here).
 *
 * @param storagePath Database location; an empty string is passed to
 *   `MemoryDatabase` as `undefined` (presumably its default location).
 * @returns The database/engine pair for `storagePath`.
 * @throws Whatever `MemoryDatabase#init` or `RetrievalEngine#rebuildIndex` throws.
 */
async function ensureInitialized(
  storagePath: string,
): Promise<{ db: MemoryDatabase; engine: RetrievalEngine }> {
  const resolved = storagePath || "";
  if (db && engine && currentPath === resolved) return { db, engine };

  // Join an in-flight attempt for the same path instead of racing it.
  if (initPromise && initPath === resolved) return initPromise;

  const attempt = (async () => {
    const freshDb = new MemoryDatabase(resolved || undefined);
    try {
      await freshDb.init();
      const freshEngine = new RetrievalEngine(freshDb);
      freshEngine.rebuildIndex();
      return { db: freshDb, engine: freshEngine };
    } catch (err) {
      // Release the half-opened handle; the original error is what callers need.
      try {
        freshDb.close();
      } catch {
        // best effort
      }
      throw err;
    }
  })();
  initPromise = attempt;
  initPath = resolved;

  let ready: { db: MemoryDatabase; engine: RetrievalEngine };
  try {
    ready = await attempt;
  } catch (err) {
    // Never cache a rejected promise, so the next call can retry.
    if (initPromise === attempt) {
      initPromise = null;
      initPath = null;
    }
    throw err;
  }

  // Publish only if no newer attempt has replaced this one meanwhile.
  if (initPromise === attempt) {
    const previous = db;
    db = ready.db;
    engine = ready.engine;
    currentPath = resolved;
    if (previous && previous !== ready.db) {
      try {
        previous.close();
      } catch {
        // best effort: the old handle is being discarded anyway
      }
    }
  }
  return ready;
}

/**
 * Hand the shared database/retrieval-engine pair to other plugin modules
 * (for example the prompt preprocessor) so they reuse the tools' instances
 * instead of opening their own.
 *
 * @param storagePath Storage location, as read from the plugin config.
 */
export async function getSharedInstances(storagePath: string) {
  const shared = await ensureInitialized(storagePath);
  return shared;
}

/**
 * Session memory store — in-memory only, never persisted to SQLite.
 * Cleared when LM Studio restarts or plugin reloads.
 *
 * Entries are only ever inserted once (never re-set under an existing id), so
 * the Map's insertion order is also creation order: the first key is always
 * the oldest memory.
 */
const sessionMemories: Map<string, MemoryRecord> = new Map();

/**
 * Store a temporary, session-scoped memory and return its generated id.
 *
 * When the store already holds MAX_SESSION_MEMORIES entries, the oldest one
 * is evicted first.
 *
 * @param content  Memory text.
 * @param category Memory category.
 * @param tags     Tags, stored as given (not normalized).
 * @returns The new id, prefixed with `sess_`.
 */
function storeSession(
  content: string,
  category: MemoryCategory,
  tags: string[],
): string {
  if (sessionMemories.size >= MAX_SESSION_MEMORIES) {
    // O(1) eviction: Map iteration follows insertion (= creation) order.
    const oldestId = sessionMemories.keys().next().value;
    if (oldestId !== undefined) sessionMemories.delete(oldestId);
  }
  const now = Date.now();
  const id = `sess_${now}_${Math.random().toString(36).slice(2, 8)}`;
  sessionMemories.set(id, {
    id,
    content,
    category,
    tags,
    confidence: 1.0,
    source: "tool-call",
    scope: "session",
    project: null,
    createdAt: now,
    updatedAt: now,
    lastAccessedAt: now,
    accessCount: 0,
    supersedes: null,
  });
  return id;
}

/**
 * Get session memories whose content or any tag contains `query`.
 * Matching is a case-insensitive substring test on both content and tags
 * (tags are stored un-normalized, so they must be lowercased here too).
 */
function searchSession(query: string): MemoryRecord[] {
  const needle = query.toLowerCase();
  return [...sessionMemories.values()].filter(
    (m) =>
      m.content.toLowerCase().includes(needle) ||
      m.tags.some((t) => t.toLowerCase().includes(needle)),
  );
}

/**
 * LM Studio tools provider: builds the five memory tools (Remember, Recall,
 * Search Memory, Forget, Memory Status). They are bound to the shared
 * database/retrieval engine for the configured storage path and to the
 * in-memory session store.
 *
 * Configuration is read once per call of this provider, so the returned tools
 * use the settings that were active when they were created.
 */
export async function toolsProvider(ctl: PluginController) {
  const cfg = readConfig(ctl);
  const { db, engine } = await ensureInitialized(cfg.storagePath);

  const rememberTool = tool({
    name: "Remember",
    description:
      `Store a fact, preference, project detail, or note in memory.\n\n` +
      `SCOPES:\n` +
      `• "global" (default) — persists forever across ALL conversations\n` +
      `• "project" — persists but only surfaces when that project is active (requires project name)\n` +
      `• "session" — temporary, lost when LM Studio closes. Use for short-lived context.\n\n` +
      `USE THIS when:\n` +
      `• The user tells you something about themselves (name, job, preferences)\n` +
      `• The user mentions a project they're working on\n` +
      `• The user asks you to remember something\n` +
      `• You learn an important fact that would be useful to recall later\n\n` +
      `Categories: ${VALID_CATEGORIES.join(", ")}`,
    parameters: {
      content: z
        .string()
        .min(3)
        .max(MAX_MEMORY_CONTENT_LENGTH)
        .describe(
          "The fact/preference/note to remember. Be concise but complete.",
        ),
      category: z.enum(VALID_CATEGORIES).describe("Category of this memory."),
      tags: z
        .array(z.string().max(50))
        .max(MAX_TAGS_PER_MEMORY)
        .optional()
        .describe(
          "Optional tags for easier retrieval (e.g., ['typescript', 'coding', 'preference']).",
        ),
      confidence: z
        .number()
        .min(0)
        .max(1)
        .optional()
        .describe(
          "How confident you are in this fact (0.0–1.0). Default 1.0 for explicit statements.",
        ),
      scope: z
        .enum(VALID_SCOPES)
        .optional()
        .describe(
          "Memory scope: 'global' (default, all chats), 'project' (project-specific), 'session' (temporary).",
        ),
      project: z
        .string()
        .max(60)
        .optional()
        .describe(
          "Project name (required when scope='project'). E.g., 'lms-memory-plugin', 'my-website'.",
        ),
    },
    implementation: async (
      { content, category, tags, confidence, scope, project },
      { status, warn },
    ) => {
      const memScope: MemoryScope = scope ?? "global";
      const memTags = tags ?? [];
      const memConfidence = confidence ?? 1.0;

      if (memScope === "project" && !project) {
        return {
          stored: false,
          error: "scope='project' requires a project name",
        };
      }

      // Session memories never touch SQLite or the retrieval index.
      if (memScope === "session") {
        const id = storeSession(content, category, memTags);
        status("Session memory stored (temporary)");
        return {
          stored: true,
          id,
          scope: "session",
          content,
          category,
          tags: memTags,
        };
      }

      status("Storing memory…");

      try {
        const id = db.store(
          content,
          category,
          memTags,
          memConfidence,
          "tool-call",
          null,
          memScope,
          project ?? null,
        );
        engine.indexMemory(id, content, memTags, category);

        let conflictInfo: {
          type: string;
          existingContent: string;
          resolution: string;
          supersededId?: string;
        } | null = null;

        if (cfg.enableConflict) {
          try {
            // Compare the new memory against its nearest existing neighbours.
            const existing = engine.retrieve(content, 5, cfg.decayDays);
            const others = existing.memories.filter((m) => m.id !== id);
            if (others.length > 0) {
              const conflicts = await detectConflicts(content, others);
              for (const conflict of conflicts) {
                if (
                  conflict.resolution === "skip" ||
                  conflict.conflictType === "duplicate"
                ) {
                  // Undo the store: the fact is already known.
                  db.delete(id);
                  engine.removeFromIndex(id);
                  status("Duplicate detected — removed");
                  return {
                    stored: false,
                    reason: "duplicate",
                    existingMemory: conflict.existingContent,
                  };
                }
                if (conflict.resolution === "supersede") {
                  // The new memory replaces an older one. Drop the outdated
                  // record so Recall stops surfacing it ("delete old, store
                  // new"). The old record is located by its exact content
                  // among the neighbours that were checked.
                  const superseded = others.find(
                    (m) => m.content === conflict.existingContent,
                  );
                  if (superseded) {
                    db.delete(superseded.id);
                    engine.removeFromIndex(superseded.id);
                  }
                  conflictInfo = {
                    type: "supersede",
                    existingContent: conflict.existingContent,
                    resolution: superseded
                      ? "New memory stored, supersedes older version"
                      : "New memory stored; older version could not be located and was kept",
                    ...(superseded ? { supersededId: superseded.id } : {}),
                  };
                  break;
                }
                if (conflict.conflictType === "contradiction") {
                  warn(
                    `Potential contradiction with: "${conflict.existingContent.slice(0, 100)}"`,
                  );
                  conflictInfo = {
                    type: "contradiction",
                    existingContent: conflict.existingContent,
                    resolution:
                      "Both memories kept — you may want to resolve this",
                  };
                }
              }
            }
          } catch (err) {
            // Conflict detection is best-effort: the memory is already stored,
            // so report the failure instead of failing the whole call.
            const msg = err instanceof Error ? err.message : String(err);
            warn(`Conflict detection failed (memory kept): ${msg}`);
          }
        }

        status("Memory stored successfully");
        return {
          stored: true,
          id,
          scope: memScope,
          content,
          category,
          tags: memTags,
          ...(project ? { project } : {}),
          ...(conflictInfo ? { conflict: conflictInfo } : {}),
        };
      } catch (err) {
        const msg = err instanceof Error ? err.message : String(err);
        warn(`Failed to store memory: ${msg}`);
        return { stored: false, error: msg };
      }
    },
  });

  const recallTool = tool({
    name: "Recall",
    description:
      `Search memory for relevant facts, preferences, or notes. ` +
      `Searches across all scopes (global + project + session) by default.\n\n` +
      `USE THIS when:\n` +
      `• You need to check what you know about the user\n` +
      `• The user references something from a past conversation\n` +
      `• You want context before answering a question\n` +
      `• The user asks "do you remember…" or "what do you know about…"`,
    parameters: {
      query: z
        .string()
        .min(2)
        .describe(
          "What to search for — natural language topic, keyword, or question.",
        ),
      limit: z
        .number()
        .int()
        .min(1)
        .max(MAX_SEARCH_RESULTS)
        .optional()
        .describe(
          `Max results to return (default: ${DEFAULT_SEARCH_RESULTS}).`,
        ),
      category: z
        .enum(VALID_CATEGORIES)
        .optional()
        .describe("Filter by category."),
      scope: z
        .enum(VALID_SCOPES)
        .optional()
        .describe("Filter by scope. Omit to search all scopes."),
      project: z
        .string()
        .max(60)
        .optional()
        .describe(
          "Filter by project name. Only returns memories from this project.",
        ),
    },
    implementation: async (
      { query, limit, category, scope, project },
      { status, warn },
    ) => {
      status(`Searching memories: "${query}"`);

      const maxResults = limit ?? DEFAULT_SEARCH_RESULTS;
      // Filters run after ranking, so pull a wider candidate set when any are
      // set; otherwise matches ranked just past `limit` would be dropped.
      const hasFilter = Boolean(category || scope || project);
      const candidateCount = hasFilter
        ? Math.max(maxResults, MAX_SEARCH_RESULTS)
        : maxResults;

      // Prefer AI-assisted retrieval when enabled, but degrade to the standard
      // ranking instead of failing the whole recall if it errors.
      const retrieveCandidates = async () => {
        if (!cfg.enableAI) {
          return engine.retrieve(query, candidateCount, cfg.decayDays);
        }
        try {
          return await engine.retrieveWithSRLM(
            query,
            candidateCount,
            cfg.decayDays,
            3,
          );
        } catch (err) {
          const msg = err instanceof Error ? err.message : String(err);
          warn(`AI-assisted recall failed, using standard retrieval: ${msg}`);
          return engine.retrieve(query, candidateCount, cfg.decayDays);
        }
      };

      const result = await retrieveCandidates();
      let memories: Array<ScoredMemory | MemoryRecord> = [...result.memories];

      const sessionHits = searchSession(query);
      if (sessionHits.length > 0) {
        memories.push(...sessionHits);
      }

      if (category) memories = memories.filter((m) => m.category === category);
      if (scope) memories = memories.filter((m) => m.scope === scope);
      if (project) memories = memories.filter((m) => m.project === project);

      // Session hits carry no composite score; rank them at a neutral 0.5.
      memories.sort((a, b) => {
        const scoreA = "compositeScore" in a ? a.compositeScore : 0.5;
        const scoreB = "compositeScore" in b ? b.compositeScore : 0.5;
        return scoreB - scoreA;
      });
      memories = memories.slice(0, maxResults);

      if (memories.length === 0) {
        status("No relevant memories found");
        return {
          found: 0,
          memories: [],
          suggestion:
            "No memories match this query. The user may need to share this information first.",
        };
      }

      status(
        `Found ${memories.length} relevant memories (${result.timeTakenMs.toFixed(1)}ms)`,
      );

      return {
        found: memories.length,
        totalMatched: result.totalMatched + sessionHits.length,
        searchTimeMs: Math.round(result.timeTakenMs),
        memories: memories.map((m) => ({
          id: m.id,
          content: m.content,
          category: m.category,
          tags: m.tags,
          confidence: m.confidence,
          scope: m.scope,
          ...(m.project ? { project: m.project } : {}),
          relevance:
            "compositeScore" in m ? Math.round(m.compositeScore * 100) : 50,
          lastAccessed: new Date(m.lastAccessedAt).toISOString(),
          accessCount: m.accessCount,
        })),
      };
    },
  });

  const searchTool = tool({
    name: "Search Memory",
    description:
      `Advanced memory search with filters by category, tag, or recency. ` +
      `Use 'Recall' for simple topic-based retrieval. Use this for precise filtering.`,
    parameters: {
      query: z
        .string()
        .optional()
        .describe("Optional text query for semantic search."),
      category: z
        .enum(VALID_CATEGORIES)
        .optional()
        .describe("Filter by memory category."),
      tag: z.string().optional().describe("Filter by tag (exact match)."),
      recent: z
        .number()
        .int()
        .min(1)
        .max(50)
        .optional()
        .describe(
          "Get the N most recently created memories (ignores query/filters).",
        ),
      limit: z
        .number()
        .int()
        .min(1)
        .max(MAX_SEARCH_RESULTS)
        .optional()
        .describe(`Max results (default: ${DEFAULT_SEARCH_RESULTS}).`),
    },
    implementation: async (
      { query, category, tag, recent, limit },
      { status },
    ) => {
      const maxResults = limit ?? DEFAULT_SEARCH_RESULTS;

      if (recent) {
        status(`Getting ${recent} most recent memories`);
        const memories = db.getRecent(recent);
        return {
          mode: "recent",
          found: memories.length,
          memories: memories.map((m) => ({
            id: m.id,
            content: m.content,
            category: m.category,
            tags: m.tags,
            confidence: m.confidence,
            created: new Date(m.createdAt).toISOString(),
          })),
        };
      }

      if (tag && !query) {
        status(`Searching by tag: "${tag}"`);
        const memories = db.getByTag(tag, maxResults);
        return {
          mode: "tag-filter",
          tag,
          found: memories.length,
          memories: memories.map((m) => ({
            id: m.id,
            content: m.content,
            category: m.category,
            tags: m.tags,
          })),
        };
      }

      if (category && !query) {
        status(`Searching by category: "${category}"`);
        const memories = db.getByCategory(category, maxResults);
        return {
          mode: "category-filter",
          category,
          found: memories.length,
          memories: memories.map((m) => ({
            id: m.id,
            content: m.content,
            tags: m.tags,
            confidence: m.confidence,
          })),
        };
      }

      if (query) {
        status(`Searching: "${query}"`);
        // Filters are applied after ranking; over-fetch so they don't starve
        // the result set, then trim back to the requested limit.
        const candidateCount =
          category || tag ? Math.max(maxResults, MAX_SEARCH_RESULTS) : maxResults;
        const result = engine.retrieve(query, candidateCount, cfg.decayDays);
        let memories = result.memories;
        if (category)
          memories = memories.filter((m) => m.category === category);
        if (tag) {
          const wantedTag = tag.toLowerCase();
          memories = memories.filter((m) =>
            m.tags.some((t) => t.toLowerCase() === wantedTag),
          );
        }
        memories = memories.slice(0, maxResults);

        return {
          mode: "semantic",
          found: memories.length,
          searchTimeMs: Math.round(result.timeTakenMs),
          memories: memories.map((m) => ({
            id: m.id,
            content: m.content,
            category: m.category,
            tags: m.tags,
            relevance: Math.round(m.compositeScore * 100),
          })),
        };
      }

      status("Returning recent memories");
      const memories = db.getRecent(maxResults);
      return {
        mode: "recent-fallback",
        found: memories.length,
        memories: memories.map((m) => ({
          id: m.id,
          content: m.content,
          category: m.category,
          tags: m.tags,
        })),
      };
    },
  });

  const forgetTool = tool({
    name: "Forget",
    description:
      `Delete memories by ID, content pattern, or clear all. ` +
      `Use when the user asks you to forget something or when information is outdated.\n\n` +
      `USE THIS when:\n` +
      `• User says "forget that" or "delete that memory"\n` +
      `• User corrects a fact (delete old, store new)\n` +
      `• User wants to clear their data`,
    parameters: {
      id: z
        .string()
        .optional()
        .describe("Exact memory ID to delete (from Recall/Search results)."),
      pattern: z
        .string()
        .optional()
        .describe("Delete all memories whose content contains this text."),
      deleteAll: z
        .boolean()
        .optional()
        .describe(
          "Set to true to delete ALL memories. Use with extreme caution.",
        ),
    },
    implementation: async ({ id, pattern, deleteAll }, { status, warn }) => {
      if (deleteAll === true) {
        const dbCount = db.deleteAll();
        const sessCount = sessionMemories.size;
        sessionMemories.clear();
        engine.rebuildIndex();
        status(
          `Deleted all ${dbCount + sessCount} memories (${dbCount} persistent + ${sessCount} session)`,
        );
        return { deleted: dbCount + sessCount, mode: "delete-all" };
      }

      if (id) {
        if (sessionMemories.has(id)) {
          sessionMemories.delete(id);
          status("Session memory deleted");
          return { deleted: 1, id, scope: "session" };
        }
        const existed = db.delete(id);
        if (existed) {
          engine.removeFromIndex(id);
          status("Memory deleted");
          return { deleted: 1, id };
        }
        return { deleted: 0, error: "Memory ID not found" };
      }

      if (pattern) {
        const needle = pattern.toLowerCase();
        let sessDeleted = 0;
        // Deleting the current entry while iterating a Map is safe in JS.
        for (const [sid, mem] of sessionMemories) {
          if (mem.content.toLowerCase().includes(needle)) {
            sessionMemories.delete(sid);
            sessDeleted++;
          }
        }
        const dbDeleted = db.deleteByPattern(pattern);
        if (dbDeleted > 0) engine.rebuildIndex();
        const total = dbDeleted + sessDeleted;
        status(`Deleted ${total} memories matching "${pattern}"`);
        return { deleted: total, pattern };
      }

      warn("No deletion target specified");
      return {
        deleted: 0,
        error: "Specify an id, pattern, or set deleteAll: true",
      };
    },
  });

  const statusTool = tool({
    name: "Memory Status",
    description:
      `Get statistics about the memory system: total count by scope, categories, ` +
      `session memory count, most accessed memory, database size, and index health.`,
    parameters: {},
    implementation: async (_, { status }) => {
      status("Gathering memory statistics…");
      const stats = db.getStats();
      const idxStats = engine.indexStats;

      return {
        ...stats,
        sessionMemories: sessionMemories.size,
        dbSizeKB: Math.round(stats.dbSizeBytes / 1024),
        indexVocabularySize: idxStats.vocabSize,
        indexedDocuments: idxStats.docCount,
        scopes: {
          global: "Persistent across all conversations",
          project: "Persistent, filtered by project name",
          session: `Temporary, in-memory only (${sessionMemories.size} active)`,
        },
      };
    },
  });

  return [rememberTool, recallTool, searchTool, forgetTool, statusTool];
}