diff --git a/src/agents/models-config.providers.discovery.ts b/src/agents/models-config.providers.discovery.ts index dd0504d2a53..64e1a9abe61 100644 --- a/src/agents/models-config.providers.discovery.ts +++ b/src/agents/models-config.providers.discovery.ts @@ -10,6 +10,7 @@ import { } from "./huggingface-models.js"; import { discoverKilocodeModels } from "./kilocode-models.js"; import { + enrichOllamaModelsWithContext, OLLAMA_DEFAULT_CONTEXT_WINDOW, OLLAMA_DEFAULT_COST, OLLAMA_DEFAULT_MAX_TOKENS, @@ -46,38 +47,6 @@ type VllmModelsResponse = { }>; }; -async function queryOllamaContextWindow( - apiBase: string, - modelName: string, -): Promise<number | undefined> { - try { - const response = await fetch(`${apiBase}/api/show`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ name: modelName }), - signal: AbortSignal.timeout(3000), - }); - if (!response.ok) { - return undefined; - } - const data = (await response.json()) as { model_info?: Record<string, unknown> }; - if (!data.model_info) { - return undefined; - } - for (const [key, value] of Object.entries(data.model_info)) { - if (key.endsWith(".context_length") && typeof value === "number" && Number.isFinite(value)) { - const contextWindow = Math.floor(value); - if (contextWindow > 0) { - return contextWindow; - } - } - } - return undefined; - } catch { - return undefined; - } -} - async function discoverOllamaModels( baseUrl?: string, opts?: { quiet?: boolean }, @@ -107,27 +76,18 @@ async function discoverOllamaModels( `Capping Ollama /api/show inspection to ${OLLAMA_SHOW_MAX_MODELS} models (received ${data.models.length})`, ); } - const discovered: ModelDefinitionConfig[] = []; - for (let index = 0; index < modelsToInspect.length; index += OLLAMA_SHOW_CONCURRENCY) { - const batch = modelsToInspect.slice(index, index + OLLAMA_SHOW_CONCURRENCY); - const batchDiscovered = await Promise.all( - batch.map(async (model) => { - const modelId = model.name; - const contextWindow = await 
queryOllamaContextWindow(apiBase, modelId); - return { - id: modelId, - name: modelId, - reasoning: isReasoningModelHeuristic(modelId), - input: ["text"], - cost: OLLAMA_DEFAULT_COST, - contextWindow: contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW, - maxTokens: OLLAMA_DEFAULT_MAX_TOKENS, - } satisfies ModelDefinitionConfig; - }), - ); - discovered.push(...batchDiscovered); - } - return discovered; + const discovered = await enrichOllamaModelsWithContext(apiBase, modelsToInspect, { + concurrency: OLLAMA_SHOW_CONCURRENCY, + }); + return discovered.map((model) => ({ + id: model.name, + name: model.name, + reasoning: isReasoningModelHeuristic(model.name), + input: ["text"], + cost: OLLAMA_DEFAULT_COST, + contextWindow: model.contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW, + maxTokens: OLLAMA_DEFAULT_MAX_TOKENS, + })); } catch (error) { if (!opts?.quiet) { log.warn(`Failed to discover Ollama models: ${String(error)}`); diff --git a/src/agents/ollama-models.test.ts b/src/agents/ollama-models.test.ts new file mode 100644 index 00000000000..7877d40bdf9 --- /dev/null +++ b/src/agents/ollama-models.test.ts @@ -0,0 +1,61 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { + enrichOllamaModelsWithContext, + resolveOllamaApiBase, + type OllamaTagModel, +} from "./ollama-models.js"; + +function jsonResponse(body: unknown, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +function requestUrl(input: string | URL | Request): string { + if (typeof input === "string") { + return input; + } + if (input instanceof URL) { + return input.toString(); + } + return input.url; +} + +function requestBody(body: BodyInit | null | undefined): string { + return typeof body === "string" ? 
body : "{}"; +} + +describe("ollama-models", () => { + afterEach(() => { + vi.unstubAllGlobals(); + }); + + it("strips /v1 when resolving the Ollama API base", () => { + expect(resolveOllamaApiBase("http://127.0.0.1:11434/v1")).toBe("http://127.0.0.1:11434"); + expect(resolveOllamaApiBase("http://127.0.0.1:11434///")).toBe("http://127.0.0.1:11434"); + }); + + it("enriches discovered models with context windows from /api/show", async () => { + const models: OllamaTagModel[] = [{ name: "llama3:8b" }, { name: "deepseek-r1:14b" }]; + const fetchMock = vi.fn(async (input: string | URL | Request, init?: RequestInit) => { + const url = requestUrl(input); + if (!url.endsWith("/api/show")) { + throw new Error(`Unexpected fetch: ${url}`); + } + const body = JSON.parse(requestBody(init?.body)) as { name?: string }; + if (body.name === "llama3:8b") { + return jsonResponse({ model_info: { "llama.context_length": 65536 } }); + } + return jsonResponse({}); + }); + vi.stubGlobal("fetch", fetchMock); + + const enriched = await enrichOllamaModelsWithContext("http://127.0.0.1:11434", models); + + expect(enriched).toEqual([ + { name: "llama3:8b", contextWindow: 65536 }, + { name: "deepseek-r1:14b", contextWindow: undefined }, + ]); + }); +}); diff --git a/src/agents/ollama-models.ts b/src/agents/ollama-models.ts index 19d95605203..20406b3a80e 100644 --- a/src/agents/ollama-models.ts +++ b/src/agents/ollama-models.ts @@ -27,6 +27,12 @@ export type OllamaTagsResponse = { models?: OllamaTagModel[]; }; +export type OllamaModelWithContext = OllamaTagModel & { + contextWindow?: number; +}; + +const OLLAMA_SHOW_CONCURRENCY = 8; + /** * Derive the Ollama native API base URL from a configured base URL. 
 * @@ -43,6 +49,58 @@ export function resolveOllamaApiBase(configuredBaseUrl?: string): string { return trimmed.replace(/\/v1$/i, ""); } +export async function queryOllamaContextWindow( + apiBase: string, + modelName: string, +): Promise<number | undefined> { + try { + const response = await fetch(`${apiBase}/api/show`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ name: modelName }), + signal: AbortSignal.timeout(3000), + }); + if (!response.ok) { + return undefined; + } + const data = (await response.json()) as { model_info?: Record<string, unknown> }; + if (!data.model_info) { + return undefined; + } + for (const [key, value] of Object.entries(data.model_info)) { + if (key.endsWith(".context_length") && typeof value === "number" && Number.isFinite(value)) { + const contextWindow = Math.floor(value); + if (contextWindow > 0) { + return contextWindow; + } + } + } + return undefined; + } catch { + return undefined; + } +} + +export async function enrichOllamaModelsWithContext( + apiBase: string, + models: OllamaTagModel[], + opts?: { concurrency?: number }, +): Promise<OllamaModelWithContext[]> { + const concurrency = Math.max(1, Math.floor(opts?.concurrency ?? OLLAMA_SHOW_CONCURRENCY)); + const enriched: OllamaModelWithContext[] = []; + for (let index = 0; index < models.length; index += concurrency) { + const batch = models.slice(index, index + concurrency); + const batchResults = await Promise.all( + batch.map(async (model) => ({ + ...model, + contextWindow: await queryOllamaContextWindow(apiBase, model.name), + })), + ); + enriched.push(...batchResults); + } + return enriched; +} + /** Heuristic: treat models with "r1", "reasoning", or "think" in the name as reasoning models. 
 */ export function isReasoningModelHeuristic(modelId: string): boolean { return /r1|reasoning|think|reason/i.test(modelId); diff --git a/src/commands/ollama-setup.test.ts b/src/commands/ollama-setup.test.ts index 2313588f180..124254c53b2 100644 --- a/src/commands/ollama-setup.test.ts +++ b/src/commands/ollama-setup.test.ts @@ -30,6 +30,53 @@ function jsonResponse(body: unknown, status = 200): Response { }); } +function requestUrl(input: string | URL | Request): string { + if (typeof input === "string") { + return input; + } + if (input instanceof URL) { + return input.toString(); + } + return input.url; +} + +function requestBody(body: BodyInit | null | undefined): string { + return typeof body === "string" ? body : "{}"; +} + +function createOllamaFetchMock(params: { + tags?: string[]; + show?: Record<string, number>; + meResponses?: Response[]; + pullResponse?: Response; + tagsError?: Error; +}) { + const meResponses = [...(params.meResponses ?? [])]; + return vi.fn(async (input: string | URL | Request, init?: RequestInit) => { + const url = requestUrl(input); + if (url.endsWith("/api/tags")) { + if (params.tagsError) { + throw params.tagsError; + } + return jsonResponse({ models: (params.tags ?? []).map((name) => ({ name })) }); + } + if (url.endsWith("/api/show")) { + const body = JSON.parse(requestBody(init?.body)) as { name?: string }; + const contextWindow = body.name ? params.show?.[body.name] : undefined; + return contextWindow + ? jsonResponse({ model_info: { "llama.context_length": contextWindow } }) + : jsonResponse({}); + } + if (url.endsWith("/api/me")) { + return meResponses.shift() ?? jsonResponse({ username: "testuser" }); + } + if (url.endsWith("/api/pull")) { + return params.pullResponse ?? 
new Response('{"status":"success"}\n', { status: 200 }); + } + throw new Error(`Unexpected fetch: ${url}`); + }); +} + describe("ollama setup", () => { afterEach(() => { vi.unstubAllGlobals(); @@ -45,9 +92,7 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })); + const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b"] }); vi.stubGlobal("fetch", fetchMock); const result = await promptAndConfigureOllama({ cfg: {}, prompter }); @@ -62,10 +107,7 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) - .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b"] }); vi.stubGlobal("fetch", fetchMock); const result = await promptAndConfigureOllama({ cfg: {}, prompter }); @@ -80,11 +122,7 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce( - jsonResponse({ models: [{ name: "llama3:8b" }, { name: "glm-4.7-flash" }] }), - ); + const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b", "glm-4.7-flash"] }); vi.stubGlobal("fetch", fetchMock); const result = await promptAndConfigureOllama({ cfg: {}, prompter }); @@ -103,13 +141,13 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) - .mockResolvedValueOnce( + const fetchMock = createOllamaFetchMock({ + tags: ["llama3:8b"], + meResponses: [ jsonResponse({ error: "not signed in", signin_url: "https://ollama.com/signin" }, 401), - ) - .mockResolvedValueOnce(jsonResponse({ username: 
"testuser" })); + jsonResponse({ username: "testuser" }), + ], + }); vi.stubGlobal("fetch", fetchMock); await promptAndConfigureOllama({ cfg: {}, prompter }); @@ -127,13 +165,13 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) - .mockResolvedValueOnce( + const fetchMock = createOllamaFetchMock({ + tags: ["llama3:8b"], + meResponses: [ jsonResponse({ error: "not signed in", signin_url: "https://ollama.com/signin" }, 401), - ) - .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + jsonResponse({ username: "testuser" }), + ], + }); vi.stubGlobal("fetch", fetchMock); await promptAndConfigureOllama({ cfg: {}, prompter }); @@ -148,15 +186,16 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })); + const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b"] }); vi.stubGlobal("fetch", fetchMock); await promptAndConfigureOllama({ cfg: {}, prompter }); - expect(fetchMock).toHaveBeenCalledTimes(1); - expect(fetchMock.mock.calls[0][0]).toContain("/api/tags"); + expect(fetchMock).toHaveBeenCalledTimes(2); + expect(fetchMock.mock.calls[0]?.[0]).toContain("/api/tags"); + expect(fetchMock.mock.calls.some((call) => requestUrl(call[0]).includes("/api/me"))).toBe( + false, + ); }); it("suggested models appear first in model list (cloud+local)", async () => { @@ -166,14 +205,9 @@ describe("ollama setup", () => { note: vi.fn(async () => undefined), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce( - jsonResponse({ - models: [{ name: "llama3:8b" }, { name: "glm-4.7-flash" }, { name: "deepseek-r1:14b" }], - }), - ) - .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + const fetchMock = 
createOllamaFetchMock({ + tags: ["llama3:8b", "glm-4.7-flash", "deepseek-r1:14b"], + }); vi.stubGlobal("fetch", fetchMock); const result = await promptAndConfigureOllama({ cfg: {}, prompter }); @@ -189,6 +223,27 @@ describe("ollama setup", () => { ]); }); + it("uses /api/show context windows when building Ollama model configs", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("local"), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = createOllamaFetchMock({ + tags: ["llama3:8b"], + show: { "llama3:8b": 65536 }, + }); + vi.stubGlobal("fetch", fetchMock); + + const result = await promptAndConfigureOllama({ cfg: {}, prompter }); + const model = result.config.models?.providers?.ollama?.models?.find( + (m) => m.id === "llama3:8b", + ); + + expect(model?.contextWindow).toBe(65536); + }); + describe("ensureOllamaModelPulled", () => { it("pulls model when not available locally", async () => { const progress = { update: vi.fn(), stop: vi.fn() }; @@ -196,12 +251,10 @@ describe("ollama setup", () => { progress: vi.fn(() => progress), } as unknown as WizardPrompter; - const fetchMock = vi - .fn() - // /api/tags — model not present - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) - // /api/pull - .mockResolvedValueOnce(new Response('{"status":"success"}\n', { status: 200 })); + const fetchMock = createOllamaFetchMock({ + tags: ["llama3:8b"], + pullResponse: new Response('{"status":"success"}\n', { status: 200 }), + }); vi.stubGlobal("fetch", fetchMock); await ensureOllamaModelPulled({ @@ -219,9 +272,7 @@ describe("ollama setup", () => { it("skips pull when model is already available", async () => { const prompter = {} as unknown as WizardPrompter; - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "glm-4.7-flash" }] })); + const fetchMock = createOllamaFetchMock({ tags: 
["glm-4.7-flash"] }); vi.stubGlobal("fetch", fetchMock); await ensureOllamaModelPulled({ @@ -268,10 +319,10 @@ describe("ollama setup", () => { }); it("uses discovered model when requested non-interactive download fails", async () => { - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [{ name: "qwen2.5-coder:7b" }] })) - .mockResolvedValueOnce(new Response('{"error":"disk full"}\n', { status: 200 })); + const fetchMock = createOllamaFetchMock({ + tags: ["qwen2.5-coder:7b"], + pullResponse: new Response('{"error":"disk full"}\n', { status: 200 }), + }); vi.stubGlobal("fetch", fetchMock); const runtime = { @@ -306,10 +357,10 @@ describe("ollama setup", () => { }); it("normalizes ollama/ prefix in non-interactive custom model download", async () => { - const fetchMock = vi - .fn() - .mockResolvedValueOnce(jsonResponse({ models: [] })) - .mockResolvedValueOnce(new Response('{"status":"success"}\n', { status: 200 })); + const fetchMock = createOllamaFetchMock({ + tags: [], + pullResponse: new Response('{"status":"success"}\n', { status: 200 }), + }); vi.stubGlobal("fetch", fetchMock); const runtime = { @@ -328,14 +379,14 @@ describe("ollama setup", () => { }); const pullRequest = fetchMock.mock.calls[1]?.[1]; - expect(JSON.parse(String(pullRequest?.body))).toEqual({ name: "llama3.2:latest" }); + expect(JSON.parse(requestBody(pullRequest?.body))).toEqual({ name: "llama3.2:latest" }); expect(result.agents?.defaults?.model).toEqual( expect.objectContaining({ primary: "ollama/llama3.2:latest" }), ); }); it("accepts cloud models in non-interactive mode without pulling", async () => { - const fetchMock = vi.fn().mockResolvedValueOnce(jsonResponse({ models: [] })); + const fetchMock = createOllamaFetchMock({ tags: [] }); vi.stubGlobal("fetch", fetchMock); const runtime = { @@ -363,7 +414,9 @@ describe("ollama setup", () => { }); it("exits when Ollama is unreachable", async () => { - const fetchMock = vi.fn().mockRejectedValueOnce(new Error("connect 
ECONNREFUSED")); + const fetchMock = createOllamaFetchMock({ + tagsError: new Error("connect ECONNREFUSED"), + }); vi.stubGlobal("fetch", fetchMock); const runtime = { diff --git a/src/commands/ollama-setup.ts b/src/commands/ollama-setup.ts index 7af3e18cff1..f6aec85dafc 100644 --- a/src/commands/ollama-setup.ts +++ b/src/commands/ollama-setup.ts @@ -2,8 +2,10 @@ import { upsertAuthProfileWithLock } from "../agents/auth-profiles.js"; import { OLLAMA_DEFAULT_BASE_URL, buildOllamaModelDefinition, + enrichOllamaModelsWithContext, fetchOllamaModels, resolveOllamaApiBase, + type OllamaModelWithContext, } from "../agents/ollama-models.js"; import type { OpenClawConfig } from "../config/config.js"; import type { RuntimeEnv } from "../runtime.js"; @@ -239,14 +241,20 @@ async function pullOllamaModelNonInteractive( return true; } -function buildOllamaModelsConfig(modelNames: string[]) { - return modelNames.map((name) => buildOllamaModelDefinition(name)); +function buildOllamaModelsConfig( + modelNames: string[], + discoveredModelsByName?: Map<string, OllamaModelWithContext>, +) { + return modelNames.map((name) => + buildOllamaModelDefinition(name, discoveredModelsByName?.get(name)?.contextWindow), + ); +} function applyOllamaProviderConfig( cfg: OpenClawConfig, baseUrl: string, modelNames: string[], + discoveredModelsByName?: Map<string, OllamaModelWithContext>, ): OpenClawConfig { return { ...cfg, @@ -259,7 +267,7 @@ baseUrl, api: "ollama", apiKey: "OLLAMA_API_KEY", // pragma: allowlist secret - models: buildOllamaModelsConfig(modelNames), + models: buildOllamaModelsConfig(modelNames, discoveredModelsByName), }, }, }, @@ -299,7 +307,6 @@ export async function promptAndConfigureOllama(params: { // 2.
Check reachability const { reachable, models } = await fetchOllamaModels(baseUrl); - const modelNames = models.map((m) => m.name); if (!reachable) { await prompter.note( @@ -314,6 +321,10 @@ export async function promptAndConfigureOllama(params: { throw new WizardCancelledError("Ollama not reachable"); } + const enrichedModels = await enrichOllamaModelsWithContext(baseUrl, models.slice(0, 50)); + const discoveredModelsByName = new Map(enrichedModels.map((model) => [model.name, model])); + const modelNames = models.map((m) => m.name); + // 3. Mode selection const mode = (await prompter.select({ message: "Ollama mode", @@ -387,7 +398,12 @@ export async function promptAndConfigureOllama(params: { await storeOllamaCredential(params.agentDir); const defaultModelId = suggestedModels[0] ?? OLLAMA_DEFAULT_MODEL; - const config = applyOllamaProviderConfig(params.cfg, baseUrl, orderedModelNames); + const config = applyOllamaProviderConfig( + params.cfg, + baseUrl, + orderedModelNames, + discoveredModelsByName, + ); return { config, defaultModelId }; } @@ -405,7 +421,6 @@ export async function configureOllamaNonInteractive(params: { const baseUrl = resolveOllamaApiBase(configuredBaseUrl); const { reachable, models } = await fetchOllamaModels(baseUrl); - const modelNames = models.map((m) => m.name); const explicitModel = normalizeOllamaModelName(opts.customModelId); if (!reachable) { @@ -421,6 +436,10 @@ export async function configureOllamaNonInteractive(params: { await storeOllamaCredential(); + const enrichedModels = await enrichOllamaModelsWithContext(baseUrl, models.slice(0, 50)); + const discoveredModelsByName = new Map(enrichedModels.map((model) => [model.name, model])); + const modelNames = models.map((m) => m.name); + // Apply local suggested model ordering. 
const suggestedModels = OLLAMA_SUGGESTED_MODELS_LOCAL; const orderedModelNames = [ @@ -478,7 +497,12 @@ export async function configureOllamaNonInteractive(params: { } } - const config = applyOllamaProviderConfig(params.nextConfig, baseUrl, allModelNames); + const config = applyOllamaProviderConfig( + params.nextConfig, + baseUrl, + allModelNames, + discoveredModelsByName, + ); const modelRef = `ollama/${defaultModelId}`; runtime.log(`Default Ollama model: ${defaultModelId}`); return applyAgentDefaultModelPrimary(config, modelRef);