// src/toolsProvider.ts

import { tool, Tool, ToolsProviderController } from "@lmstudio/sdk";
import { writeFile, access } from "fs/promises";
import { join } from "path";
import { z } from "zod";
import { configSchematics } from "./config";

/**
 * LM Studio tools provider for this plugin.
 *
 * Builds and returns two tools, "Visit Website" and "Download Images". The tools read
 * plugin settings via `ctl.getPluginConfig(configSchematics)` and save downloaded images
 * into `ctl.getWorkingDirectory()`.
 */
export async function toolsProvider(ctl:ToolsProviderController):Promise<Tool[]> {
	const tools: Tool[] = [];
	// Limits passed to buildImagePresentationPayload by the Download Images tool.
	// NOTE(review): that function currently does not use either value.
	const COMPACT_INLINE_LIMIT = 2;
	const COMPACT_GALLERY_LIMIT = 6;

	/** Outcome of fetching a page, either directly or through the r.jina.ai reader fallback. */
	type FetchHTMLResult = {
		html: string; // full response text
		head: string; // `<head>` section for direct fetches; always "" for Jina results
		body: string; // body markup (direct) or the reader's text output (Jina)
		source: "direct" | "jina";
		finalUrl: string; // response URL (after redirects) or the requested URL; the reader URL for Jina results
		statusCode?: number;
		serverHeader?: string; // value of the `Server` response header, "" when absent
	};
	/**
	 * GETs `targetUrl` (following redirects), retrying only when `fetch` itself rejects,
	 * i.e. on network-level failures. HTTP error statuses are returned to the caller as-is.
	 *
	 * @param attempts Total number of tries; values below 1 (or NaN) are treated as 1.
	 * @returns The first successful `Response`.
	 * @throws The fetch error immediately when `signal` is aborted, otherwise the last
	 *         fetch error once all attempts are used up.
	 */
	const fetchWithRetries = async (targetUrl:string, signal:AbortSignal, headers:Record<string, string>, attempts = 2) => {
		const totalAttempts = attempts >= 1 ? Math.floor(attempts) : 1;
		let lastError: unknown = null;
		for(let i = 0; i < totalAttempts; i++) {
			try {
				return await fetch(targetUrl, {
					method: "GET",
					signal,
					headers,
					redirect: "follow",
				});
			} catch(error) {
				lastError = error;
				if(signal.aborted) throw error;
				// Linear backoff between tries; waiting after the final try would only delay the error.
				if(i < totalAttempts - 1) {
					await new Promise(resolve => setTimeout(resolve, 350 * (i + 1)));
				}
			}
		}
		throw lastError;
	}
	/**
	 * Splits raw HTML into its head section and body markup.
	 *
	 * Tag matching is case-insensitive and tolerates attributes (`<HEAD lang="en">`).
	 * - `head`: from the `<head>` tag through `</head>`; "" when there is no `<head>` tag.
	 *   If `</head>` is missing (it is optional in HTML), the head runs up to `<body>`, or
	 *   is "" when there is no body tag either.
	 * - `body`: from the `<body>` tag (inclusive; the start of the document when absent) up
	 *   to, but excluding, the last `</body>`; to the end of the document when that is missing.
	 */
	const toParsedHtml = (html:string) => {
		const headOpen = /<head(?:\s[^>]*)?>/i.exec(html);
		const bodyOpen = /<body(?:\s[^>]*)?>/i.exec(html);

		let head = "";
		if(headOpen) {
			// Only look for the closing tag after the opening one.
			const headClose = /<\/head\s*>/gi;
			headClose.lastIndex = headOpen.index;
			const closeMatch = headClose.exec(html);
			if(closeMatch) {
				head = html.substring(headOpen.index, closeMatch.index + closeMatch[0].length);
			} else if(bodyOpen && bodyOpen.index > headOpen.index) {
				head = html.substring(headOpen.index, bodyOpen.index);
			}
		}

		const bodyStart = bodyOpen?.index ?? 0;
		// End before the last closing body tag, provided it comes after the body start.
		let bodyEnd = -1;
		for(const match of html.matchAll(/<\/body\s*>/gi)) bodyEnd = match.index ?? bodyEnd;
		const body = html.substring(bodyStart, bodyEnd > bodyStart ? bodyEnd : html.length);
		return { head, body };
	}
	/**
	 * Heuristic used by fetchHTML to decide whether to retry through the reader fallback:
	 * true for 401/403/429/503 responses, or when the page text contains a typical
	 * bot-wall / challenge marker. The server header is accepted but not consulted.
	 *
	 * NOTE(review): the marker test scans the whole page, so an ordinary page that merely
	 * mentions e.g. "captcha" is also treated as blocked.
	 */
	const looksBlocked = (statusCode:number, _serverHeader:string, html:string) => {
		const blockingStatuses = [401, 403, 429, 503];
		if(blockingStatuses.includes(statusCode)) return true;
		const challengeMarkers = /access denied|captcha|cf-chl|checking your browser|attention required|challenge-platform/i;
		return challengeMarkers.test(html);
	}
	/**
	 * Fallback fetch through the r.jina.ai reader, which returns a text rendering of `url`.
	 * Used by fetchHTML when the direct request fails or looks blocked.
	 *
	 * The target URL is forwarded unchanged after the reader prefix, so https targets are
	 * not downgraded to http.
	 *
	 * @returns A result with `source: "jina"`, the reader text in both `html` and `body`
	 *          and an empty `head`; or null (after calling `warn`) when the reader responds
	 *          with a non-2xx status, returns blank text, or the request fails.
	 * @throws Re-throws AbortError so user cancellation propagates.
	 */
	const fetchViaJina = async (url:string, signal:AbortSignal, warn:(msg:string) => void):Promise<FetchHTMLResult | null> => {
		const jinaUrl = `https://r.jina.ai/${url}`;
		try {
			const response = await fetchWithRetries(jinaUrl, signal, {
				"User-Agent": "Mozilla/5.0",
				"Accept": "text/plain,text/html;q=0.9,*/*;q=0.8",
			});
			if(!response.ok) {
				warn(`Jina fallback failed: ${response.status} ${response.statusText}`);
				return null;
			}
			const text = await response.text();
			if(!text.trim()) {
				warn("Jina fallback returned empty content.");
				return null;
			}
			return {
				html: text,
				head: "",
				body: text,
				source: "jina",
				finalUrl: response.url || jinaUrl,
				statusCode: response.status,
				serverHeader: response.headers.get("server") || "",
			};
		} catch(error: unknown) {
			if (error instanceof DOMException && error.name === "AbortError") throw error;
			const message = (error instanceof Error && error.message) || String(error);
			warn(`Jina fallback request failed: ${message}`);
			return null;
		}
	}
	/**
	 * Fetches `url` with browser-like headers and splits the HTML into head and body.
	 *
	 * Falls back to the Jina reader when the request fails at the network level, the
	 * status is not 2xx, or looksBlocked() flags the page. Abort errors from the direct
	 * request are re-thrown untouched.
	 *
	 * @throws Error("Failed to fetch website: …") when both the direct request and the
	 *         fallback fail.
	 */
	const fetchHTML = async (url:string, signal:AbortSignal, warn:(msg:string) => void):Promise<FetchHTMLResult> => {
		// Built outside the try so an invalid URL surfaces directly instead of triggering the fallback.
		const requestHeaders = spoofHeaders(url);
		let reply: Response;
		try {
			reply = await fetchWithRetries(url, signal, requestHeaders);
		} catch(error: any) {
			if (error instanceof DOMException && error.name === "AbortError") throw error;
			const readerResult = await fetchViaJina(url, signal, warn);
			if(readerResult) return readerResult;
			throw new Error(`Failed to fetch website: network error (${error?.message || String(error)})`);
		}

		const html = await reply.text();
		const statusCode = reply.status;
		const serverHeader = reply.headers.get("server") || "";
		const usable = reply.ok && !looksBlocked(statusCode, serverHeader, html);
		if (!usable) {
			const readerResult = await fetchViaJina(url, signal, warn);
			if(readerResult) return readerResult;
			const statusLine = `${statusCode} ${reply.statusText}`;
			warn(`Failed to fetch website: ${statusLine}`);
			throw new Error(`Failed to fetch website: ${statusLine} (server: ${serverHeader || "unknown"})`);
		}

		const { head, body } = toParsedHtml(html);
		return {
			html, head, body,
			source: "direct",
			finalUrl: reply.url || url,
			statusCode,
			serverHeader,
		};
	}

	/**
	 * Extracts absolute http(s) links from `<a href="…">…</a>` tags in `body`, ranks them,
	 * and returns up to `maxLinks` unique `[label, href]` pairs, best-scored first.
	 *
	 * - Labels have inner tags and whitespace replaced by spaces and are trimmed.
	 * - Root-relative (`/path`) and protocol-relative (`//host/path`) hrefs are resolved
	 *   against `url`. Hrefs `URL` cannot resolve (e.g. `//`) are skipped instead of throwing.
	 * - `&amp;` in href values is decoded so query strings stay intact.
	 * - Score = 100 − (label length + URL length + 20·position/count), plus 1000 for each
	 *   search term contained in the label (case-insensitive). Only the best-scored entry
	 *   per URL is kept.
	 *
	 * NOTE(review): this ranking matches the behaviour of the earlier digit-ratio formula,
	 * which always evaluated to a ratio of 1 and so never distinguished navigation links from
	 * content links. Adding a working heuristic would change which links are returned;
	 * decide on it deliberately.
	 */
	const extractLinks = (body:string, url:string, maxLinks:number, searchTerms?:string[]) => {
		const resolveHref = (rawHref:string): string | undefined => {
			const href = rawHref.replace(/&amp;/g, "&"); // attribute values are HTML-escaped
			if(!href.startsWith("/")) return href;
			try {
				return new URL(href, url).href;
			} catch {
				return undefined;
			}
		};
		const loweredTerms = searchTerms?.map(term => term.toLowerCase()) ?? [];

		const candidates = [...body.matchAll(/<a\s+[^>]*?href="([^"]+)"[^>]*>([\s\S]*?)<\/a>/g)]
			.map(match => ({
				label: match[2]?.replace(/\\[ntr]|\s|<(?:[^>"]|"[^"]*")+>/g, " ").trim() || "",
				link: match[1] ? resolveHref(match[1]) : undefined,
			}))
			.filter((x): x is { label: string; link: string } => !!x.link?.startsWith("http"));

		const scored = candidates.map((x, index, { length }) => {
			// Shorter label + URL and earlier position in the page score higher.
			const baseScore = 100 - (x.label.length + x.link.length + (20 * index / length));
			const label = x.label.toLowerCase();
			const score = loweredTerms.reduce((acc, term) => acc + (label.includes(term) ? 1000 : 0), baseScore);
			return { ...x, score };
		});

		const seen = new Set<string>();
		return scored
			.sort((a, b) => b.score - a.score)
			.filter(({ link }) => {
				if(seen.has(link)) return false;
				seen.add(link);
				return true;
			})
			.slice(0, maxLinks)
			.map(({ label, link }) => [label, link] as [string, string]);
	};

	/**
	 * Collects image URLs from `body` and returns up to `maxImages` unique `[alt, src]` pairs
	 * in document order.
	 *
	 * Sources come from two places:
	 * - `<img>` tags: `src`, then `data-src`, then `data-original`, then the first `srcset`
	 *   candidate.
	 * - Markdown `![alt](http…)` images, as found in Jina reader output.
	 *
	 * Handling and ranking:
	 * - `&amp;` in attribute values is decoded.
	 * - Root-relative and protocol-relative sources are resolved against `url`. Values `URL`
	 *   cannot resolve are skipped rather than throwing, and non-http(s) sources are dropped.
	 * - Ranking uses alt-text length plus 1000 per search term found in the alt text
	 *   (case-insensitive).
	 */
	const extractImages = (body:string, url:string, maxImages:number, searchTerms?:string[]) => {
		const resolveSrc = (src: string | undefined): string | undefined => {
			if(src === undefined || !src.startsWith("/")) return src;
			try {
				return new URL(src, url).href;
			} catch {
				return undefined;
			}
		};

		const htmlMatches = [...body.matchAll(/<img(\s+[^>]*)/g)]
			.filter(x => x[1])
			.map(([, attributes], index) => {
				const alt = attributes.match(/\salt="([^"]+)"/)?.[1] || "";
				const rawSrc = attributes.match(/\ssrc="([^"]+)"/)?.[1]
					|| attributes.match(/\sdata-src="([^"]+)"/)?.[1]
					|| attributes.match(/\sdata-original="([^"]+)"/)?.[1]
					|| attributes.match(/\ssrcset="([^"]+)"/)?.[1]?.split(",")?.[0]?.trim()?.split(/\s+/)?.[0];
				// Attribute values are HTML-escaped; decode "&amp;" so query strings stay intact.
				return { index, alt, src: rawSrc?.replace(/&amp;/g, "&") };
			});
		const markdownMatches = [...body.matchAll(/!\[([^\]]*)\]\((https?:\/\/[^)\s]+)\)/g)]
			.map((match, index) => ({
				index: index + htmlMatches.length,
				alt: match[1] || "",
				src: match[2],
			}));

		const loweredTerms = searchTerms?.map(term => term.toLowerCase()) ?? [];
		const seen = new Set<string>();
		return [...htmlMatches, ...markdownMatches]
			.map(x => ({ ...x, src: resolveSrc(x.src) }))
			.filter((x): x is { index: number; alt: string; src: string } => !!x.src && x.src.startsWith("http"))
			.map(x => {
				const alt = x.alt.toLowerCase();
				const score = loweredTerms.reduce((acc, term) => acc + (alt.includes(term) ? 1000 : 0), x.alt.length);
				return { ...x, score };
			})
			.sort((a, b) => b.score - a.score)
			// Keep the best-scored occurrence of each source.
			.filter(({ src }) => {
				if(seen.has(src)) return false;
				seen.add(src);
				return true;
			})
			.slice(0, maxImages)
			.sort((a, b) => a.index - b.index) // back to document order
			.map(({ src, alt }) => [alt, src] as [string, string]);
	}

	/**
	 * "Download Images" tool: saves remote images into the working directory and returns
	 * markdown pointing at the local files. Images can be given directly and/or found on
	 * `websiteURL`. A WebP thumbnail is created when possible.
	 *
	 * Local paths in `imageURLs` are not copied; they are only checked with
	 * resolveLocalImagePathForCurrentWorkingDirectory. The plugin's maxImages setting
	 * overrides the argument unless it is -1 ("auto").
	 *
	 * Return values:
	 * - `[]` when there is nothing to do.
	 * - A payload object describing the downloaded images.
	 * - A string on abort or error.
	 */
	const downloadImagesTool = tool({
		name: "Download Images",
		description: "Use with remote HTTP(S) URLs or websiteURL.",
		parameters: {
			imageURLs: z.array(z.string()).optional().describe("List of image references to view. Supports HTTP(S) URLs, local file paths, and markdown image links."),
			websiteURL: z.string().url().optional().describe("The URL of the website, whose images to view."),
			maxImages: z.number().int().min(1).max(200).optional().describe("Maximum number of images to view when websiteURL is provided."),
		},
		implementation: async ({ imageURLs, websiteURL, maxImages }, { status, warn, signal }) => {
			try {

				maxImages = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxImages"), -1)
					?? maxImages
					?? 10;

				const imageURLsToDownload = (imageURLs || [])
					.map(normalizeImageReference)
					.filter((x): x is string => !!x);
				if (imageURLsToDownload.length === 0 && !websiteURL) {
					return [];
				}

				if(websiteURL) {
					status("Fetching image URLs from website...");
					const { body } = await fetchHTML(websiteURL, signal, warn);
					const images = extractImages(body, websiteURL, maxImages).map(x => x[1]);
					imageURLsToDownload.push(...images);
				}

				status("Downloading images...");
				const workingDirectory = ctl.getWorkingDirectory();
				const timestamp = Date.now();
				const downloadPromises = imageURLsToDownload.map(async (url:string, i:number) => {
					if (isLocalImagePath(url)) {
						const safeLocal = await resolveLocalImagePathForCurrentWorkingDirectory(url, workingDirectory);
						if (!safeLocal) {
							return null;
						}
						return { fullPath: safeLocal, thumbPath: safeLocal }; // Already local, no resize needed here
					}

					const index = i + 1;
					try {
						const headers = spoofHeaders(url);
						const imageResponse = await fetch(url, {
							method: "GET",
							signal,
							headers,
						});
						if (!imageResponse.ok) {
							// statusText is often empty (HTTP/2 has no reason phrase), so include the numeric status.
							warn(`Failed to fetch image ${index}: ${imageResponse.status} ${imageResponse.statusText}`);
							return null; // Skip this image if download fails
						}
						// arrayBuffer() is available in every fetch implementation; Response.bytes() is much newer.
						const bytes = new Uint8Array(await imageResponse.arrayBuffer());
						if (bytes.length === 0) {
							warn(`Image ${index} is empty: ${url}`);
							return null; // Skip empty images
						}
						// File extension: MIME subtype, else the URL path's extension, else "jpg".
						// The path is used (not the full URL) so query strings, fragments and host names are ignored.
						const fileExtension = /image\/(\w+)/.exec(imageResponse.headers.get('content-type') || '')?.[1]
							|| /\.(\w+)$/.exec(new URL(url).pathname)?.[1]
							|| 'jpg';
						const fileName = `${timestamp}-${index}.${fileExtension}`;
						const filePath = join(workingDirectory, fileName);
						const localPath = fileName;
						await writeFile(filePath, bytes, 'binary');
						const thumbFileName = `${timestamp}-${index}-thumb.webp`;
						const thumbFilePath = join(workingDirectory, thumbFileName);
						const thumbCreated = await createThumbnailWebp(bytes, thumbFilePath);
						const thumbLocalPath = thumbCreated ? thumbFileName : localPath;
						return {
							fullPath: localPath,
							thumbPath: thumbLocalPath,
						};
					} catch (error: unknown) {
						if (error instanceof DOMException && error.name === "AbortError")
							return null; // Skip if download was aborted
						const message = error instanceof Error ? error.message : String(error);
						warn(`Error fetching image ${index}: ${message}`);
						return null; // Skip this image on error
					}
				});
				const downloadedImageEntries = (await Promise.all(downloadPromises))
					.filter((x): x is { fullPath: string; thumbPath: string } => !!x);
				// Aborted downloads resolve to null individually; report the abort instead of "no images".
				if (signal.aborted) {
					return "Image download aborted by user.";
				}
				if (downloadedImageEntries.length === 0) {
					// Avoid returning remote URLs as if they were downloaded.
					return {
						count: 0,
						images: [],
						compactGalleryMarkdown: "",
						hint: "No images were downloaded successfully. The listed source URLs may be blocked or unavailable.",
					};
				}

				status(`Downloaded ${downloadedImageEntries.length} images successfully.`);

				return buildImagePresentationPayload(downloadedImageEntries, COMPACT_INLINE_LIMIT, COMPACT_GALLERY_LIMIT);
			} catch (error: unknown) {
				if (error instanceof DOMException && error.name === "AbortError") {
					return "Image download aborted by user.";
				}
				console.error(error);
				const message = error instanceof Error ? error.message : String(error);
				warn(`Error during image download: ${message}`);
				return `Error: ${message}`;
			}
		}
	});

	/**
	 * Reduces whitespace-normalized page text to about `contentLimit` characters.
	 *
	 * - Without search terms, or when the text already fits, this is the first
	 *   `contentLimit` characters.
	 * - Otherwise it takes a context window around the first case-insensitive, literal match
	 *   of each term. Windows appear in page order; overlapping ones are merged and disjoint
	 *   ones are joined with " … ".
	 * - If no term matches, it falls back to the plain prefix.
	 */
	const selectPageContent = (allContent:string, contentLimit:number, searchTerms?:string[]):string => {
		if(!searchTerms?.length || contentLimit >= allContent.length) {
			return allContent.slice(0, contentLimit);
		}
		// Quantifier bounds must be integers: a bound like "{0,333.33}" is parsed as literal text.
		const radius = Math.floor(contentLimit / (searchTerms.length * 2));
		const padding = `.{0,${radius}}`;
		// Terms are user text, not patterns; escape so "C++" or "(beta)" match literally.
		const escapeRegExp = (text:string) => text.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
		const matches = searchTerms
			.map(term => new RegExp(padding + escapeRegExp(term) + padding, "i").exec(allContent))
			.filter((match): match is RegExpExecArray => !!match)
			.sort((a, b) => a.index - b.index);
		if(matches.length === 0) return allContent.slice(0, contentLimit);

		let content = "";
		let nextMinIndex = 0;
		for(const match of matches) {
			const end = match.index + match[0].length;
			// Entirely inside text already included; skipping also keeps nextMinIndex from moving back.
			if(end <= nextMinIndex) continue;
			if(match.index >= nextMinIndex) {
				if(content && match.index > nextMinIndex) content += " … ";
				content += match[0];
			} else {
				// Overlapping window: append only the part not yet included.
				content += match[0].slice(nextMinIndex - match.index);
			}
			nextMinIndex = end;
		}
		return content;
	};

	/**
	 * "Visit Website" tool: fetches a page directly, or via the Jina reader fallback. It
	 * returns the title, first h1/h2/h3, ranked links, locally downloaded images, and the
	 * text content.
	 *
	 * Plugin settings for maxLinks / maxImages / contentLimit override the arguments unless
	 * set to -1 ("auto"). A limit of 0 omits that section.
	 */
	const visitWebsiteTool = tool({
		name: "Visit Website",
		description: "Visit a website and return its title, headings, links, images, and text content. Images are automatically downloaded and viewable.",
		parameters: {
			url: z.string().url().describe("The URL of the website to visit"),
			findInPage: z.array(z.string()).optional().describe("Highly recommended! Optional search terms to prioritize which links, images, and content to return."),
			maxLinks: z.number().int().min(0).max(200).optional().describe("Maximum number of links to extract from the page."),
			maxImages: z.number().int().min(0).max(200).optional().describe("Maximum number of images to extract from the page."),
			contentLimit: z.number().int().min(0).max(10_000).optional().describe("Maximum text content length to extract from the page."),
		},
		implementation: async ({ url, maxLinks, maxImages, contentLimit, findInPage: searchTerms }, context) => {
			const { status, warn, signal } = context;
			status("Visiting website...");

			try {
				maxLinks = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxLinks"), -1)
					?? maxLinks
					?? 40;
				maxImages = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("maxImages"), -1)
					?? maxImages
					?? 10;
				contentLimit = undefinedIfAuto(ctl.getPluginConfig(configSchematics).get("contentLimit"), -1)
					?? contentLimit
					?? 2000;

				const { head, body, source, finalUrl, statusCode, serverHeader } = await fetchHTML(url, signal, warn);
				status("Website visited successfully.");

				// Resolve relative links/images against the post-redirect page URL. Jina results
				// report the reader's own URL as finalUrl, so use the requested URL for those.
				const baseUrl = source === "direct" ? finalUrl : url;

				const title = head.match(/<title>([^<]*)<\/title>/i)?.[1] || "";
				const h1 = body.match(/<h1[^>]*>([^<]*)<\/h1>/i)?.[1] || "";
				const h2 = body.match(/<h2[^>]*>([^<]*)<\/h2>/i)?.[1] || "";
				const h3 = body.match(/<h3[^>]*>([^<]*)<\/h3>/i)?.[1] || "";
				const links = maxLinks && extractLinks(body, baseUrl, maxLinks, searchTerms);
				const imagesToFetch = maxImages ? extractImages(body, baseUrl, maxImages, searchTerms) : [];
				// Reuse the Download Images tool so images are saved locally and thumbnailed.
				const imagesPayload = maxImages &&
					await downloadImagesTool.implementation({ imageURLs: imagesToFetch.map(x => x[1]) }, context) as any;
				const imageMarkdowns: string[] = Array.isArray(imagesPayload)
					? imagesPayload
					: imagesPayload?.images || [];
				// NOTE(review): failed downloads are dropped from imageMarkdowns, so after any failure
				// the alt text taken from imagesToFetch[index] may belong to a different image.
				const images = maxImages &&
					imageMarkdowns.map((markdown, index) => [imagesToFetch[index]?.[0] || "", markdown] as [string, string]);

				// Plain text of the body: drop scripts, styles and tags, then collapse whitespace.
				const allContent = contentLimit
					? body
						.replace(/<script[^>]*>[\s\S]*?<\/script>/gi, '')
						.replace(/<style[^>]*>[\s\S]*?<\/style>/gi, '')
						.replace(/<[^>]+>/g, '')
						.replace(/\s+/g, ' ')
						.trim()
					: '';
				const content = selectPageContent(allContent, contentLimit, searchTerms);

				return {
					url, title, h1, h2, h3,
					fetch: { source, finalUrl, statusCode, server: serverHeader || undefined },
					...(links ? { links } : {}),
					...(images ? { images } : {}),
					...(imagesPayload?.compactGalleryMarkdown ? { compactGalleryMarkdown: imagesPayload.compactGalleryMarkdown } : {}),
					...(imagesPayload?.hint ? { hint: imagesPayload.hint } : {}),
					...(content ? { content } : {}),
				};
			} catch (error: unknown) {
				if (error instanceof DOMException && error.name === "AbortError") {
					return "Website visit aborted by user.";
				}
				console.error(error);
				const message = error instanceof Error ? error.message : String(error);
				warn(`Error during website visit: ${message}`);
				return `Error: ${message}`;
			}
		},
	});


	// Register both tools; the resulting array is what LM Studio receives from this provider.
	tools.push(visitWebsiteTool);
	tools.push(downloadImagesTool);
	return tools;
}

/**
 * Returns `undefined` when a plugin setting equals its "auto" sentinel, so callers can
 * fall through to the next default with `??`. Otherwise it returns the value unchanged.
 * Generic so the real value type is preserved, instead of being cast to `undefined`.
 */
const undefinedIfAuto = <T>(value: T, autoValue: unknown): T | undefined =>
	value === autoValue ? undefined : value;

// Pool of browser user-agent strings; spoofHeaders() picks one at random on each call.
const spoofedUserAgents = [
	// Chrome, Edge, Firefox, Safari and Opera Mini on Android, macOS, Windows, ChromeOS and Linux
	"Mozilla/5.0 (Linux; Android 10; SM-M515F) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.141 Mobile Safari/537.36",
	"Mozilla/5.0 (Linux; Android 6.0; E5533) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.101 Mobile Safari/537.36",
	"Mozilla/5.0 (Linux; Android 8.1.0; AX1082) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.83 Mobile Safari/537.36",
	"Mozilla/5.0 (Linux; Android 8.1.0; TM-MID1020A) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/83.0.4103.96 Safari/537.36",
	"Mozilla/5.0 (Linux; Android 9; POT-LX1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.45 Mobile Safari/537.36",
	"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36",
	"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.80 Safari/537.36",
	"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/14.0.3 Safari/605.1.15",
	"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/18.3.1 Safari/605.1.15",
	"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:97.0) Gecko/20100101 Firefox/97.0",
	"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.0",
	"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36 Edg/97.0.1072.71",
	"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36",
	"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.80 Safari/537.36 Edg/98.0.1108.62",
	"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/98.0.4758.80 Safari/537.36",
	"Mozilla/5.0 (X11; CrOS x86_64 14541.0.0) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36",
	"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/97.0.4692.71 Safari/537.36",
	"Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:97.0) Gecko/20100101 Firefox/97.0",
	"Opera/9.80 (Android 7.0; Opera Mini/36.2.2254/119.132; U; id) Presto/2.12.423 Version/12.16",
]

/**
 * Builds browser-like request headers for `url`:
 * - one user agent picked at random from `spoofedUserAgents`;
 * - a Referer pointing at the URL's host root over https;
 * - navigation-style Accept and Sec-Fetch-* headers.
 *
 * @throws TypeError if `url` cannot be parsed by `new URL`.
 */
function spoofHeaders(url:string) {
	const { hostname } = new URL(url);
	const randomIndex = Math.floor(Math.random() * spoofedUserAgents.length);
	const userAgent = spoofedUserAgents[randomIndex];
	const referer = `https://${hostname}/`;
	return {
		'User-Agent': userAgent,
		'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8',
		'Accept-Language': 'en-US,en;q=0.9',
		'Accept-Encoding': 'gzip, deflate, br',
		'Referer': referer,
		'Connection': 'keep-alive',
		'Upgrade-Insecure-Requests': '1',
		'Sec-Fetch-Dest': 'document',
		'Sec-Fetch-Mode': 'navigate',
		'Sec-Fetch-Site': 'same-origin',
		'Sec-Fetch-User': '?1',
		'Cache-Control': 'max-age=0',
	};
}

/**
 * Turns downloaded images into the Download Images tool's response payload.
 *
 * The payload contains:
 * - markdown for each thumbnail, labelled "Image 1", "Image 2", …;
 * - the same markdown laid out by buildCompactGalleryMarkdown;
 * - a hint telling the model how to embed the images.
 *
 * NOTE(review): `inlineLimit` and `galleryLimit` are accepted but currently unused. Every
 * image appears in both `images` and `compactGalleryMarkdown`.
 */
function buildImagePresentationPayload(
	localImages: Array<{ fullPath: string; thumbPath: string }>,
	inlineLimit = 2,
	galleryLimit = 6,
) {
	const images: string[] = [];
	for (const [position, { thumbPath }] of localImages.entries()) {
		images.push(`![Image ${position + 1}](${thumbPath})`);
	}
	return {
		count: localImages.length,
		images,
		compactGalleryMarkdown: buildCompactGalleryMarkdown(images),
		hint: "These images are already downloaded and renderable (stored locally). Use compactGalleryMarkdown when you want a compact table gallery; otherwise, please, embed images where they fit with ![](imagefile). The user will not see images unless you embed them. Only standard Markdown (no HTML <img>). Do NOT call Download Images as these are already downloaded local files",
		thumbnailsUsed: true,
	};
}

/**
 * Normalizes one user- or model-supplied image reference.
 *
 * Accepts either a bare reference or a markdown image `![alt](target)`. For markdown:
 * - `<…>`-wrapped targets are unwrapped;
 * - an optional link title (`![a](url "title")`) is dropped by keeping only the first
 *   whitespace-separated token.
 *
 * @returns The reference if it is an http(s) URL, a POSIX absolute path, or a Windows
 *          drive path; otherwise null (including for empty input).
 */
function normalizeImageReference(input: string): string | null {
	const trimmed = input.trim();
	if (!trimmed) return null;
	const markdownTarget = trimmed.match(/!\[[^\]]*\]\(([^)]+)\)/)?.[1]?.trim();
	let extracted = trimmed;
	if (markdownTarget) {
		const angled = /^<([^>]*)>/.exec(markdownTarget);
		extracted = (angled ? angled[1] : markdownTarget.split(/\s+/)[0])?.trim() || trimmed;
	}
	if (/^https?:\/\//i.test(extracted)) return extracted;
	if (extracted.startsWith("/")) return extracted;
	if (/^[a-zA-Z]:[\\/]/.test(extracted)) return extracted;
	return null;
}

/**
 * Lays out image markdown as a two-column Markdown table with header "| 1 | 2 |".
 * Images fill rows left to right; an odd trailing image leaves the right cell empty.
 * Returns "" when there are no images.
 */
function buildCompactGalleryMarkdown(imageMarkdowns: string[]): string {
	if (!imageMarkdowns.length) return "";
	const header = ["| 1 | 2 |", "| --- | --- |"];
	const rowCount = Math.ceil(imageMarkdowns.length / 2);
	const rows = Array.from({ length: rowCount }, (_, row) => {
		const left = imageMarkdowns[row * 2] || "";
		const right = imageMarkdowns[row * 2 + 1] || "";
		return `| ${left} | ${right} |`;
	});
	return [...header, ...rows].join("\n");
}

// Maximum thumbnail width, passed to sharp's resize() by createThumbnailWebp together
// with fit "inside" and withoutEnlargement (so smaller images are not upscaled).
const THUMB_MAX_WIDTH = 360;
// Quality option passed to sharp's webp() encoder for thumbnails.
const THUMB_WEBP_QUALITY = 60;

/**
 * True for local path forms:
 * - POSIX absolute ("/…");
 * - explicit relative ("./…", "../…");
 * - Windows drive paths ("C:\…", "C:/…").
 */
function isLocalImagePath(value: string): boolean {
	const localPrefixes = ["/", "./", "../"];
	if (localPrefixes.some(prefix => value.startsWith(prefix))) return true;
	return /^[a-zA-Z]:[\\/]/.test(value);
}

/**
 * Best-effort WebP thumbnail using the optionally installed `sharp` module.
 *
 * Resizes `inputBytes` to at most THUMB_MAX_WIDTH wide and writes the result to
 * `outputPath`.
 *
 * @returns true if a non-empty thumbnail was written. Returns false when sharp cannot be
 *          loaded, sharp fails on the input, the output is empty, or writing fails. Errors
 *          are intentionally swallowed so callers can fall back to the full image.
 */
async function createThumbnailWebp(
	inputBytes: Uint8Array,
	outputPath: string,
): Promise<boolean> {
	try {
		const loaded: any = await import("sharp");
		const sharp = loaded.default || loaded;
		const pipeline = sharp(Buffer.from(inputBytes))
			.resize({ width: THUMB_MAX_WIDTH, fit: "inside", withoutEnlargement: true })
			.webp({ quality: THUMB_WEBP_QUALITY });
		const thumbnail = await pipeline.toBuffer();
		if (!thumbnail?.length) return false;
		await writeFile(outputPath, thumbnail, "binary");
		return true;
	} catch {
		return false;
	}
}

/**
 * Checks whether a local image reference exists and returns the form to embed.
 *
 * Candidates are tried in order:
 * 1. the path as given (made absolute with toAbsolutePath);
 * 2. the path with backslashes turned into "/" (also via toAbsolutePath);
 * 3. the bare file name joined onto `workingDirectory`.
 *
 * The first candidate that exists decides the result: the bare file name if it was the
 * file-name candidate, otherwise the "/"-normalized path. Returns null for empty input or
 * when no candidate exists.
 */
async function resolveLocalImagePathForCurrentWorkingDirectory(
	inputPath: string,
	workingDirectory: string,
): Promise<string | null> {
	if (!inputPath) return null;
	const forwardSlashed = inputPath.replace(/\\/g, "/");
	const baseName = forwardSlashed.split("/").pop() || "";
	const candidates = [inputPath, forwardSlashed, baseName].filter(Boolean);
	for (const candidate of candidates) {
		const isBaseName = candidate === baseName;
		const probe = isBaseName
			? join(workingDirectory, baseName)
			: toAbsolutePath(candidate, workingDirectory);
		const exists = await access(probe).then(() => true, () => false);
		if (exists) return isBaseName ? baseName : forwardSlashed;
	}
	return null;
}

/**
 * Maps a local path reference to the path that is probed on disk.
 *
 * - Windows drive paths are returned unchanged.
 * - "/Users/…" paths become a relative, backslash-separated path ("Users\\name\\…").
 * - Every other path, including other "/"-prefixed ones, is joined onto `workingDirectory`.
 *
 * NOTE(review): the "/Users/" branch produces a relative Windows-style path, so on macOS
 * a real absolute path like /Users/me/pic.png is not probed as-is. Confirm this is intended.
 */
function toAbsolutePath(pathLike: string, workingDirectory: string): string {
	const isDrivePath = /^[a-zA-Z]:[\\/]/.test(pathLike);
	if (isDrivePath) return pathLike;
	const isUsersPath = /^\/Users\//.test(pathLike);
	return isUsersPath
		? pathLike.replace(/^\//, "").replace(/\//g, "\\")
		: join(workingDirectory, pathLike);
}