// src/toolsProvider.ts

import { text, tool, type Tool, type ToolsProviderController } from "@lmstudio/sdk";
import { configSchematics } from "./configSchematics";
import { z } from "zod";
import neo4j from "neo4j-driver";


/** Record count above which the tool warns that the query may be too broad. */
const LARGE_RESULT_THRESHOLD = 1000;

/**
 * Converts a value produced by neo4j-driver into plain, JSON-friendly data.
 *
 * neo4j-driver represents Cypher integers as `Integer` objects (`{ low, high }`)
 * and temporal values as class instances; serialized as-is they are hard for a
 * model to read. This helper:
 * - turns integers in the safe JS range into numbers, and larger ones into
 *   decimal strings (so no precision is silently lost);
 * - turns temporal values (date/time/duration) into their string form;
 * - recurses into arrays and objects (including Node, Relationship and Path
 *   instances, whose own fields are copied).
 *
 * @param value - Any value taken from a neo4j-driver record.
 * @returns An equivalent value built only from primitives, arrays and plain objects.
 */
function toPlainValue(value: unknown): unknown {
  if (neo4j.isInt(value)) {
    return neo4j.integer.inSafeRange(value) ? value.toNumber() : value.toString();
  }
  if (
    neo4j.isDate(value) ||
    neo4j.isDateTime(value) ||
    neo4j.isLocalDateTime(value) ||
    neo4j.isTime(value) ||
    neo4j.isLocalTime(value) ||
    neo4j.isDuration(value)
  ) {
    return value.toString();
  }
  if (Array.isArray(value)) {
    return value.map((item) => toPlainValue(item));
  }
  if (value !== null && typeof value === "object") {
    return Object.fromEntries(
      Object.entries(value).map(([key, inner]) => [key, toPlainValue(inner)]),
    );
  }
  return value;
}

/**
 * LM Studio tools provider that exposes a single `query_neo4j` tool.
 *
 * Reads the connection settings from the plugin configuration and creates one
 * Neo4j driver that all calls of the tool share; every call opens (and closes)
 * its own session against the configured database.
 *
 * @param ctl - Controller used to read the plugin configuration.
 * @returns The list of tools contributed by this plugin.
 */
export async function toolsProvider(ctl: ToolsProviderController): Promise<Tool[]> {
  const config = ctl.getPluginConfig(configSchematics);
  const connectionURIField = config.get("connection_uri");
  const databaseField = config.get("database");
  const usernameField = config.get("username");
  const passwordField = config.get("password");

  // One driver (and its connection pool) is shared across tool calls.
  // NOTE(review): nothing here ever calls driver.close(); if LM Studio invokes
  // toolsProvider again (e.g. after a config change) the previous driver is
  // leaked — confirm the SDK lifecycle and close it if a hook exists.
  const driver = neo4j.driver(connectionURIField, neo4j.auth.basic(usernameField, passwordField));

  const queryTool = tool({
    name: "query_neo4j",
    description: text`
      Execute a Cypher query on the Neo4j graph database and return the results.
      Provide the query as a string. Optionally, pass parameters as a JSON object.
      Results are returned as an array of records, each as a JSON object.
      Handle read and write queries; ensure the query is safe and efficient.
    `,
    parameters: {
      query: z.string().describe("The Cypher query to execute (e.g., 'MATCH (n) RETURN n')"),
      params: z.record(z.any()).optional().describe("Optional parameters for the query (e.g., { id: 123 })"),
    },
    implementation: async ({ query, params = {} }, { status, warn }) => {
      const session = driver.session({
        database: databaseField,
      });
      try {
        status("Connecting to Neo4j and executing query...");
        // The third argument of session.run is a TransactionConfig, which only
        // supports `timeout` and `metadata` — there is no AbortSignal option,
        // so a running query cannot be cancelled from here.
        const result = await session.run(query, params);

        if (result.records.length > LARGE_RESULT_THRESHOLD) {
          warn("Query returned a large number of records; consider optimizing.");
        }

        // One plain object per record, keyed by the query's RETURN columns,
        // with driver-specific types (Integer, temporals, nodes…) flattened.
        return result.records.map((record) => toPlainValue(record.toObject()));
      } catch (error) {
        // Returned rather than thrown so the model sees the failure and can retry.
        // `error` is `unknown` under strict mode, so narrow before reading .message.
        const message = error instanceof Error ? error.message : String(error);
        return `Error executing query: ${message}`;
      } finally {
        await session.close();
      }
    },
  });

  return [queryTool];
}