// src/promptPreprocessor.ts

import {type Chat, type ChatMessage, type FileHandle, type LLMDynamicHandle, type PredictionProcessStatusController, type PromptPreprocessorController, text,} from "@lmstudio/sdk";
import {dynamicConfig} from "./index";
import {CONFIG_KEYS, parseLanguageFromDisplay} from "./config";
import {setLanguage, t} from "./i18n";
import fs from "fs";

/** How content from newly attached documents is added to the user's prompt. */
type DocumentContextInjectionStrategy =
    | "none"
    | "inject-full-content"
    | "retrieval";


// Debug helper: appends timestamped lines to a local log file.
// The default is the original developer's Windows log path; override it with the
// RAG_FLEX_DEBUG_LOG environment variable or the optional `logPath` argument.
const DEFAULT_DEBUG_LOG_PATH = 'D:\\dev\\rag-flex\\logs\\lmstudio-debug.log';

/**
 * Appends `[ISO timestamp] msg <data as pretty JSON>` plus a newline to the debug log.
 *
 * Logging is best-effort. A missing directory, an unwritable file, or unserializable
 * (e.g. circular) data never throws. Callers invoke this inside try blocks whose catch
 * reports unrelated errors, so a throw here would be misreported.
 *
 * @param msg - Message text.
 * @param data - Optional payload. Omitted from the line only when `undefined`, so falsy
 *               values such as `0` or `""` are still logged.
 * @param logPath - Target file; defaults to $RAG_FLEX_DEBUG_LOG or DEFAULT_DEBUG_LOG_PATH.
 */
const debug = (
    msg: string,
    data?: unknown,
    logPath: string = process.env.RAG_FLEX_DEBUG_LOG ?? DEFAULT_DEBUG_LOG_PATH,
): void => {
    let payload = '';
    if (data !== undefined) {
        try {
            payload = ' ' + JSON.stringify(data, null, 2);
        } catch {
            // e.g. circular structures: fall back to a plain string form.
            payload = ' ' + String(data);
        }
    }
    const log = `[${new Date().toISOString()}] ${msg}${payload}\n`;
    try {
        fs.appendFileSync(logPath, log);
    } catch {
        // Best-effort: never let debug logging break the preprocessing pipeline.
    }
};

/**
 * Prompt-preprocessor entry point: enriches the user's message with document context.
 *
 * - The message carries new non-image attachments: choose a strategy (full-content
 *   injection or retrieval) and apply it.
 * - No new attachments, but non-image files exist anywhere in the history: run retrieval
 *   over all of them.
 * - Otherwise: return the message untouched.
 *
 * @param ctl - Preprocessor controller supplied by LM Studio.
 * @param userMessage - The message the user just sent.
 * @returns The rewritten message, a replacement prompt string, or the original message.
 */
export async function preprocess(ctl: PromptPreprocessorController, userMessage: ChatMessage) {
    // Apply the configured language before any user-visible text is produced.
    const config = ctl.getPluginConfig(dynamicConfig);
    setLanguage(parseLanguageFromDisplay(config.get(CONFIG_KEYS.LANGUAGE) as string));

    const prompt = userMessage.getText();
    const chatHistory = await ctl.pullHistory();
    chatHistory.append(userMessage);

    const isNotImage = (candidate: FileHandle) => candidate.type !== "image";
    const attachedNow = userMessage.getFiles(ctl.client).filter(isNotImage);
    const everyFile = chatHistory.getAllFiles(ctl.client).filter(isNotImage);

    if (attachedNow.length === 0) {
        // Nothing new attached: fall back to retrieval over earlier files, if there are any.
        return everyFile.length > 0
            ? await prepareRetrievalResultsContextInjection(ctl, prompt, everyFile)
            : userMessage;
    }

    switch (await chooseContextInjectionStrategy(ctl, prompt, attachedNow)) {
        case "inject-full-content":
            return await prepareDocumentContextInjection(ctl, userMessage);
        case "retrieval":
            return await prepareRetrievalResultsContextInjection(ctl, prompt, everyFile);
        default:
            return userMessage;
    }
}

/**
 * Retrieves the chunks of `files` most relevant to the user's prompt and builds the prompt
 * text to send to the LLM.
 *
 * Uses the configured embedding model. A custom model path takes precedence over the
 * dropdown selection. Chunks scoring at or below the configured affinity threshold are
 * dropped. Remaining chunks are listed as numbered citations ahead of the original prompt
 * and registered via `ctl.addCitations`. If none remain, a "no retrieval" note wrapping
 * the prompt is returned instead.
 *
 * @param ctl - Preprocessor controller (config, status UI, client, abort signal).
 * @param originalUserPrompt - The user's prompt text, also used as the retrieval query.
 * @param files - Non-image files to search.
 * @returns The processed prompt text. Errors while loading the model or retrieving are not
 *          thrown. A localized error string is returned instead, and model-load failures
 *          are reported separately from retrieval failures.
 */
async function prepareRetrievalResultsContextInjection(
    ctl: PromptPreprocessorController,
    originalUserPrompt: string,
    files: Array<FileHandle>,
): Promise<string> {
    const translations = t();
    const pluginConfig = ctl.getPluginConfig(dynamicConfig);
    const retrievalLimit = pluginConfig.get(CONFIG_KEYS.LIMIT);
    const retrievalAffinityThreshold = pluginConfig.get(CONFIG_KEYS.THRESHOLD);

    // Resolve the model path: a custom path takes precedence over the dropdown selection.
    const customModelPath = ((pluginConfig.get(CONFIG_KEYS.CUSTOM_MODEL_PATH) as string | undefined) ?? "").trim();
    const selectedModelPath = (pluginConfig.get(CONFIG_KEYS.MODEL_PATH) as string | undefined) ?? "";
    // Strip a trailing parenthesized "not downloaded" label from dropdown entries,
    // e.g. "(not downloaded)", "(未下載)", "(未ダウンロード)", to get the real model path.
    const modelPath = (customModelPath || selectedModelPath).replace(/ \([^)]+\)$/, "");

    // Per-file sub-statuses, created when the SDK reports which files need processing.
    const statusSteps = new Map<FileHandle, PredictionProcessStatusController>();

    const retrievingStatus = ctl.createStatus({
        status: "loading",
        text: translations.status.loadingEmbeddingModel(modelPath),
    });

    // Records which phase failed, so the catch block reports the right error.
    let modelLoaded = false;
    try {
        const model = await ctl.client.embedding.model(modelPath, {
            signal: ctl.abortSignal,
        });
        modelLoaded = true;

        retrievingStatus.setState({
            status: "loading",
            text: translations.status.retrievingCitations,
        });

        debug('model path:', model.path);
        debug('model identifier:', model.identifier);
        const result = await ctl.client.files.retrieve(originalUserPrompt, files, {
            embeddingModel: model,
            limit: retrievalLimit,
            signal: ctl.abortSignal,
            onFileProcessList(filesToProcess) {
                for (const file of filesToProcess) {
                    statusSteps.set(
                        file,
                        retrievingStatus.addSubStatus({
                            status: "waiting",
                            text: translations.status.processFileForRetrieval(file.name),
                        }),
                    );
                }
            },
            onFileProcessingStart(file) {
                statusSteps
                    .get(file)
                    ?.setState({status: "loading", text: translations.status.processingFileForRetrieval(file.name)});
            },
            onFileProcessingEnd(file) {
                statusSteps
                    .get(file)
                    ?.setState({status: "done", text: translations.status.processedFileForRetrieval(file.name)});
            },
            onFileProcessingStepProgress(file, step, progressInStep) {
                const verb = step === "loading" ? translations.verbs.loading :
                            step === "chunking" ? translations.verbs.chunking :
                            translations.verbs.embedding;
                statusSteps.get(file)?.setState({
                    status: "loading",
                    text: translations.status.fileProcessProgress(verb, file.name, `${(progressInStep * 100).toFixed(1)}%`),
                });
            },
        });

        // Guard before touching `result.entries`.
        if (!result) {
            retrievingStatus.setState({
                status: "error",
                text: translations.errors.retrievalFailed,
            });
            return translations.errors.retrievalFailed;
        }

        // Apply the affinity threshold client-side. The filtered list is stored back on
        // `result` because `result` is also passed to `ctl.addCitations` below.
        result.entries = result.entries.filter(entry => entry.score > retrievalAffinityThreshold);

        // Build the "processed" prompt content from the retrieval results.
        let processedContent = "";
        const numRetrievals = result.entries.length;
        if (numRetrievals > 0) {
            // Relevant citations found: report success and prepend them to the prompt.
            retrievingStatus.setState({
                status: "done",
                text: translations.status.retrievalSuccess(numRetrievals, retrievalAffinityThreshold),
            });
            ctl.debug("Retrieval results", result);
            processedContent += translations.llmPrompts.citationsPrefix;
            result.entries.forEach((entry, index) => {
                processedContent += `${translations.llmPrompts.citationLabel(index + 1)}: "${entry.content}"\n\n`;
            });
            await ctl.addCitations(result);
            processedContent += translations.llmPrompts.citationsSuffix(originalUserPrompt);
        } else {
            // Retrieval ran, but nothing cleared the threshold.
            retrievingStatus.setState({
                status: "canceled",
                text: translations.status.noRelevantContent(retrievalAffinityThreshold),
            });
            ctl.debug("No relevant citations found for user query");
            processedContent = translations.llmPrompts.noRetrievalNote(originalUserPrompt);
        }
        ctl.debug("Processed content", processedContent);

        // Size summary of the returned content.
        ctl.debug("最終回傳內容大小", {
            contentLength: processedContent.length,
            entriesCount: result.entries.length,
            firstEntry: result.entries[0]?.content?.substring(0, 100)
        });

        return processedContent;
    } catch (error) {
        if (!modelLoaded) {
            // Loading the embedding model failed: report it as model-not-found.
            retrievingStatus.setState({
                status: "error",
                text: translations.errors.modelNotFound(modelPath),
            });
            return translations.errors.modelNotFoundDetail(modelPath, error);
        }

        // The model loaded, but retrieval or result handling failed.
        // Do not misreport this as a missing model.
        ctl.debug("Retrieval failed", error);
        retrievingStatus.setState({
            status: "error",
            text: translations.errors.retrievalFailed,
        });
        return translations.errors.retrievalFailed;
    }
}

/**
 * Inlines the full parsed content of the message's non-image files into its text.
 *
 * Each file is parsed. The message text is then replaced with the enriched-context prefix,
 * each file's content between start/end markers, and the enriched-context suffix (which
 * wraps the original text). If no content was collected, the message is returned without
 * replacing its text, so the user's prompt is never blanked.
 *
 * @param ctl - Preprocessor controller (client, abort signal, debug log).
 * @param input - The user's message; its text is replaced in place.
 * @returns The same `input` message object.
 */
async function prepareDocumentContextInjection(
    ctl: PromptPreprocessorController,
    input: ChatMessage,
): Promise<ChatMessage> {
    const translations = t();
    const documentInjectionSnippets: Map<FileHandle, string> = new Map();
    // NOTE(review): consumeFiles presumably detaches these attachments from the message,
    // since their content is inlined below instead. Verify against the SDK.
    const files = input.consumeFiles(ctl.client, file => file.type !== "image");
    for (const file of files) {
        // This should take no time as the result is already in the cache
        // (the file was parsed while choosing the injection strategy).
        const {content} = await ctl.client.files.parseDocument(file, {
            signal: ctl.abortSignal,
        });

        // Log the file name; interpolating the FileHandle object itself is not a readable name.
        ctl.debug(text`
      Strategy: inject-full-content. Injecting full content of file '${file.name}' into the
      context. Length: ${content.length}.
    `);
        documentInjectionSnippets.set(file, content);
    }

    // Nothing to inject: keep the user's text rather than replacing it with an empty prompt.
    if (documentInjectionSnippets.size === 0) {
        return input;
    }

    // Format the final user prompt
    // TODO:
    //    Make this templatable and configurable
    //      https://github.com/lmstudio-ai/llmster/issues/1017
    let formattedFinalUserPrompt = translations.llmPrompts.enrichedContextPrefix;

    for (const [fileHandle, snippet] of documentInjectionSnippets) {
        formattedFinalUserPrompt += `\n\n${translations.llmPrompts.fileContentStart(fileHandle.name)}\n\n${snippet}\n\n${translations.llmPrompts.fileContentEnd(fileHandle.name)}\n\n`;
    }

    formattedFinalUserPrompt += translations.llmPrompts.enrichedContextSuffix(input.getText());

    input.replaceText(formattedFinalUserPrompt);
    return input;
}

/**
 * Measures how much of the model's context window the given chat already occupies.
 *
 * The chat is rendered with the model's prompt template and its tokens are counted.
 * The count is then compared against the model's context length.
 *
 * @param ctx - Chat whose formatted prompt is measured.
 * @param model - Loaded LLM used for templating, token counting and context length.
 * @returns Tokens used, the context length, the tokens left, and the used share in percent.
 */
async function measureContextWindow(ctx: Chat, model: LLMDynamicHandle) {
    const renderedPrompt = await model.applyPromptTemplate(ctx);
    const usedTokens = await model.countTokens(renderedPrompt);
    const capacity = await model.getContextLength();

    return {
        totalTokensInContext: usedTokens,
        modelContextLength: capacity,
        modelRemainingContextLength: capacity - usedTokens,
        contextOccupiedPercent: (usedTokens / capacity) * 100,
    };
}

/**
 * Decides whether newly attached files are injected verbatim or handled through retrieval.
 *
 * The current history is measured against the loaded LLM's context window. Each file is
 * then parsed and tokenized, stopping early once the running count exceeds the remaining
 * context. If file tokens plus prompt tokens exceed the computed budget, `"retrieval"` is
 * chosen; otherwise `"inject-full-content"`.
 *
 * @param ctl - Preprocessor controller (config, status UI, client, history, abort signal).
 * @param originalUserPrompt - The user's prompt; its tokens count toward the budget.
 * @param files - Newly attached non-image files to measure.
 * @returns `"retrieval"` or `"inject-full-content"` (this function never returns `"none"`).
 */
async function chooseContextInjectionStrategy(
    ctl: PromptPreprocessorController,
    originalUserPrompt: string,
    files: Array<FileHandle>,
): Promise<DocumentContextInjectionStrategy> {
    const translations = t();

    // 1. Read the configured context-use target. It is used as a fraction and shown to the
    //    user multiplied by 100.
    const pluginConfig = ctl.getPluginConfig(dynamicConfig);
    const targetContextUsePercent = pluginConfig.get(CONFIG_KEYS.CONTEXT_THRESHOLD);

    const status = ctl.createStatus({
        status: "loading",
        text: translations.status.decidingStrategy,
    });

    const model = await ctl.client.llm.model();
    // The user prompt is tokenized separately below.
    // NOTE(review): presumably this history does not yet contain it. Confirm.
    const ctx = await ctl.pullHistory();

    // Measure the context window
    const {
        totalTokensInContext,
        modelContextLength,
        modelRemainingContextLength,
        contextOccupiedPercent,
    } = await measureContextWindow(ctx, model);

    ctl.debug(
        `Context measurement result:\n\n` +
        `\tTotal tokens in context: ${totalTokensInContext}\n` +
        `\tModel context length: ${modelContextLength}\n` +
        `\tModel remaining context length: ${modelRemainingContextLength}\n` +
        `\tContext occupied percent: ${contextOccupiedPercent.toFixed(2)}%\n`,
    );

    // Get token count of provided files
    let totalFileTokenCount = 0;
    let totalReadTime = 0;
    let totalTokenizeTime = 0;
    for (const file of files) {
        const startTime = performance.now();

        const loadingStatus = status.addSubStatus({
            status: "loading",
            text: translations.status.loadingParser(file.name),
        });
        let actionProgressing = translations.verbs.reading;
        let parserIndicator = "";

        const {content} = await ctl.client.files
            .parseDocument(file, {
                signal: ctl.abortSignal,
                onParserLoaded: parser => {
                    loadingStatus.setState({
                        status: "loading",
                        text: translations.status.parserLoaded(parser.library, file.name),
                    });
                    // Update action names if we're using a parsing framework
                    if (parser.library !== "builtIn") {
                        actionProgressing = translations.verbs.parsing;
                        parserIndicator = ` with ${parser.library}`;
                    }
                },
                onProgress: progress => {
                    loadingStatus.setState({
                        status: "loading",
                        text: translations.status.fileProcessing(
                            actionProgressing,
                            file.name,
                            parserIndicator,
                            `${(progress * 100).toFixed(2)}%`
                        ),
                    });
                },
            })
            // Remove the per-file sub-status even if parsing fails or is aborted,
            // so no stale "loading" entry is left in the UI.
            .finally(() => {
                loadingStatus.remove();
            });

        totalReadTime += performance.now() - startTime;

        // tokenize file content
        const startTokenizeTime = performance.now();
        totalFileTokenCount += await model.countTokens(content);
        totalTokenizeTime += performance.now() - startTokenizeTime;
        if (totalFileTokenCount > modelRemainingContextLength) {
            // Early exit if we already have too many tokens. Helps with performance when there are a lot of files.
            break;
        }
    }
    ctl.debug(`Total file read time: ${totalReadTime.toFixed(2)} ms`);
    ctl.debug(`Total tokenize time: ${totalTokenizeTime.toFixed(2)} ms`);

    // Calculate total token count of files + user prompt
    ctl.debug(`Original User Prompt: ${originalUserPrompt}`);
    const userPromptTokenCount = (await model.tokenize(originalUserPrompt)).length;
    const totalFilePlusPromptTokenCount = totalFileTokenCount + userPromptTokenCount;

    // Calculate the available context tokens.
    // NOTE(review): `modelRemainingContextLength` already excludes the occupied tokens, and the
    // budget is scaled by (1 - occupied fraction) again. The effective budget is therefore
    // ≈ target · remaining² / contextLength. Kept as-is to preserve existing strategy choices;
    // confirm whether the double scaling is intended.
    const contextOccupiedFraction = contextOccupiedPercent / 100;
    const targetContextUsage = targetContextUsePercent * (1 - contextOccupiedFraction);
    const availableContextTokens = Math.floor(modelRemainingContextLength * targetContextUsage);

    // Debug log
    ctl.debug("Strategy Calculation:");
    ctl.debug(`\tTotal Tokens in All Files: ${totalFileTokenCount}`);
    ctl.debug(`\tTotal Tokens in User Prompt: ${userPromptTokenCount}`);
    ctl.debug(`\tModel Context Remaining: ${modelRemainingContextLength} tokens`);
    ctl.debug(`\tContext Occupied: ${contextOccupiedPercent.toFixed(2)}%`);
    ctl.debug(`\tAvailable Tokens: ${availableContextTokens}\n`);

    if (totalFilePlusPromptTokenCount > availableContextTokens) {
        const chosenStrategy = "retrieval";
        ctl.debug(
            `Chosen context injection strategy: '${chosenStrategy}'. Total file + prompt token count: ` +
            `${totalFilePlusPromptTokenCount} > ${
                targetContextUsage * 100
            }% * available context tokens: ${availableContextTokens}`,
        );
        status.setState({
            status: "done",
            text: translations.status.strategyRetrieval(Math.round(targetContextUsePercent * 100)),
        });
        return chosenStrategy;
    }

    // TODO:
    //
    //   Consider a more sophisticated strategy where we inject some header or summary content
    //   and then perform retrieval on the rest of the content.

    const chosenStrategy = "inject-full-content";
    status.setState({
        status: "done",
        text: translations.status.strategyInjectFull(Math.round(targetContextUsePercent * 100)),
    });
    return chosenStrategy;
}