fix(ollama): share model context discovery

This commit is contained in:
Peter Steinberger
2026-03-11 20:08:35 +00:00
parent 9329a0ab24
commit 620bae4ec7
5 changed files with 273 additions and 117 deletions

View File

@@ -10,6 +10,7 @@ import {
} from "./huggingface-models.js";
import { discoverKilocodeModels } from "./kilocode-models.js";
import {
enrichOllamaModelsWithContext,
OLLAMA_DEFAULT_CONTEXT_WINDOW,
OLLAMA_DEFAULT_COST,
OLLAMA_DEFAULT_MAX_TOKENS,
@@ -46,38 +47,6 @@ type VllmModelsResponse = {
}>;
};
/**
 * Best-effort lookup of a model's context window via Ollama's POST /api/show.
 *
 * Scans `model_info` for any key ending in ".context_length" and returns the
 * first finite, positive value (floored to an integer). Returns undefined on
 * any failure: non-OK response, missing/odd payload, timeout, or network error.
 */
async function queryOllamaContextWindow(
  apiBase: string,
  modelName: string,
): Promise<number | undefined> {
  let payload: { model_info?: Record<string, unknown> };
  try {
    const res = await fetch(`${apiBase}/api/show`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ name: modelName }),
      // Keep discovery snappy; a slow local daemon should not stall setup.
      signal: AbortSignal.timeout(3000),
    });
    if (!res.ok) {
      return undefined;
    }
    payload = (await res.json()) as { model_info?: Record<string, unknown> };
  } catch {
    // Treat timeouts / connection errors / bad JSON as "unknown context window".
    return undefined;
  }
  const info = payload.model_info;
  if (!info) {
    return undefined;
  }
  // Keys are architecture-prefixed (e.g. "llama.context_length"), so match by suffix.
  for (const key of Object.keys(info)) {
    if (!key.endsWith(".context_length")) {
      continue;
    }
    const raw = info[key];
    if (typeof raw !== "number" || !Number.isFinite(raw)) {
      continue;
    }
    const floored = Math.floor(raw);
    if (floored > 0) {
      return floored;
    }
  }
  return undefined;
}
async function discoverOllamaModels(
baseUrl?: string,
opts?: { quiet?: boolean },
@@ -107,27 +76,18 @@ async function discoverOllamaModels(
`Capping Ollama /api/show inspection to ${OLLAMA_SHOW_MAX_MODELS} models (received ${data.models.length})`,
);
}
const discovered: ModelDefinitionConfig[] = [];
for (let index = 0; index < modelsToInspect.length; index += OLLAMA_SHOW_CONCURRENCY) {
const batch = modelsToInspect.slice(index, index + OLLAMA_SHOW_CONCURRENCY);
const batchDiscovered = await Promise.all(
batch.map(async (model) => {
const modelId = model.name;
const contextWindow = await queryOllamaContextWindow(apiBase, modelId);
return {
id: modelId,
name: modelId,
reasoning: isReasoningModelHeuristic(modelId),
input: ["text"],
cost: OLLAMA_DEFAULT_COST,
contextWindow: contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW,
maxTokens: OLLAMA_DEFAULT_MAX_TOKENS,
} satisfies ModelDefinitionConfig;
}),
);
discovered.push(...batchDiscovered);
}
return discovered;
const discovered = await enrichOllamaModelsWithContext(apiBase, modelsToInspect, {
concurrency: OLLAMA_SHOW_CONCURRENCY,
});
return discovered.map((model) => ({
id: model.name,
name: model.name,
reasoning: isReasoningModelHeuristic(model.name),
input: ["text"],
cost: OLLAMA_DEFAULT_COST,
contextWindow: model.contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW,
maxTokens: OLLAMA_DEFAULT_MAX_TOKENS,
}));
} catch (error) {
if (!opts?.quiet) {
log.warn(`Failed to discover Ollama models: ${String(error)}`);

View File

@@ -0,0 +1,61 @@
import { afterEach, describe, expect, it, vi } from "vitest";
import {
enrichOllamaModelsWithContext,
resolveOllamaApiBase,
type OllamaTagModel,
} from "./ollama-models.js";
/** Build a Response carrying `body` serialized as JSON, with the given status. */
function jsonResponse(body: unknown, status = 200): Response {
  const init: ResponseInit = {
    status,
    headers: { "Content-Type": "application/json" },
  };
  return new Response(JSON.stringify(body), init);
}
/** Normalize a fetch() input (string | URL | Request) to its URL string. */
function requestUrl(input: string | URL | Request): string {
  if (input instanceof URL) {
    return input.toString();
  }
  if (typeof input === "string") {
    return input;
  }
  return input.url;
}
/** Extract a request body as a string; non-string bodies fall back to "{}". */
function requestBody(body: BodyInit | null | undefined): string {
  if (typeof body === "string") {
    return body;
  }
  return "{}";
}
// Unit tests for the shared Ollama model helpers.
describe("ollama-models", () => {
// Restore the real global fetch after every test so stubs never leak.
afterEach(() => {
vi.unstubAllGlobals();
});
// resolveOllamaApiBase trims trailing slashes and a trailing /v1
// (the OpenAI-compat suffix) down to the native Ollama API base.
it("strips /v1 when resolving the Ollama API base", () => {
expect(resolveOllamaApiBase("http://127.0.0.1:11434/v1")).toBe("http://127.0.0.1:11434");
expect(resolveOllamaApiBase("http://127.0.0.1:11434///")).toBe("http://127.0.0.1:11434");
});
// enrichOllamaModelsWithContext POSTs each model name to /api/show and copies
// any *.context_length found; models without one keep contextWindow undefined
// rather than being dropped from the result.
it("enriches discovered models with context windows from /api/show", async () => {
const models: OllamaTagModel[] = [{ name: "llama3:8b" }, { name: "deepseek-r1:14b" }];
// Fake fetch: only /api/show is expected; llama3:8b reports a 65536 context.
const fetchMock = vi.fn(async (input: string | URL | Request, init?: RequestInit) => {
const url = requestUrl(input);
if (!url.endsWith("/api/show")) {
throw new Error(`Unexpected fetch: ${url}`);
}
const body = JSON.parse(requestBody(init?.body)) as { name?: string };
if (body.name === "llama3:8b") {
return jsonResponse({ model_info: { "llama.context_length": 65536 } });
}
return jsonResponse({});
});
vi.stubGlobal("fetch", fetchMock);
const enriched = await enrichOllamaModelsWithContext("http://127.0.0.1:11434", models);
expect(enriched).toEqual([
{ name: "llama3:8b", contextWindow: 65536 },
{ name: "deepseek-r1:14b", contextWindow: undefined },
]);
});
});

View File

@@ -27,6 +27,12 @@ export type OllamaTagsResponse = {
models?: OllamaTagModel[];
};
/**
 * An /api/tags model entry augmented with the context window discovered via
 * /api/show; contextWindow is absent when discovery failed for that model.
 */
export type OllamaModelWithContext = OllamaTagModel & {
contextWindow?: number;
};
// Default number of concurrent /api/show requests used while enriching models.
const OLLAMA_SHOW_CONCURRENCY = 8;
/**
* Derive the Ollama native API base URL from a configured base URL.
*
@@ -43,6 +49,58 @@ export function resolveOllamaApiBase(configuredBaseUrl?: string): string {
return trimmed.replace(/\/v1$/i, "");
}
/**
 * Best-effort lookup of a model's context window via Ollama's POST /api/show.
 *
 * Scans the `model_info` map for any key ending in ".context_length" (keys are
 * architecture-prefixed, e.g. "llama.context_length") and returns the first
 * finite, positive value floored to an integer.
 *
 * @param apiBase - Native Ollama API base URL (no trailing /v1).
 * @param modelName - Model tag to inspect, e.g. "llama3:8b".
 * @param opts.timeoutMs - Request timeout in milliseconds (default 3000, the
 *   previous hard-coded value), so batch callers can tune discovery latency.
 * @returns The context window, or undefined on non-OK response, missing or
 *   malformed payload, timeout, or network error (discovery is best-effort).
 */
export async function queryOllamaContextWindow(
  apiBase: string,
  modelName: string,
  opts?: { timeoutMs?: number },
): Promise<number | undefined> {
  const timeoutMs = opts?.timeoutMs ?? 3000;
  try {
    const response = await fetch(`${apiBase}/api/show`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ name: modelName }),
      signal: AbortSignal.timeout(timeoutMs),
    });
    if (!response.ok) {
      return undefined;
    }
    const data = (await response.json()) as { model_info?: Record<string, unknown> };
    if (!data.model_info) {
      return undefined;
    }
    for (const [key, value] of Object.entries(data.model_info)) {
      if (key.endsWith(".context_length") && typeof value === "number" && Number.isFinite(value)) {
        const contextWindow = Math.floor(value);
        if (contextWindow > 0) {
          return contextWindow;
        }
      }
    }
    return undefined;
  } catch {
    // Swallow errors deliberately: an unknown context window is not fatal.
    return undefined;
  }
}
/**
 * Attach a discovered context window to each model from /api/tags.
 *
 * Queries /api/show for every model, limiting in-flight requests by processing
 * the list in fixed-size batches (Promise.all per batch). Input order is
 * preserved; models whose lookup fails keep contextWindow undefined.
 *
 * @param apiBase - Native Ollama API base URL.
 * @param models - Models discovered via /api/tags.
 * @param opts.concurrency - Batch size (default OLLAMA_SHOW_CONCURRENCY);
 *   clamped to at least 1 and floored to an integer.
 */
export async function enrichOllamaModelsWithContext(
  apiBase: string,
  models: OllamaTagModel[],
  opts?: { concurrency?: number },
): Promise<OllamaModelWithContext[]> {
  const batchSize = Math.max(1, Math.floor(opts?.concurrency ?? OLLAMA_SHOW_CONCURRENCY));
  const results: OllamaModelWithContext[] = [];
  let cursor = 0;
  while (cursor < models.length) {
    const batch = models.slice(cursor, cursor + batchSize);
    cursor += batchSize;
    const enrichedBatch = await Promise.all(
      batch.map(async (model) => {
        const contextWindow = await queryOllamaContextWindow(apiBase, model.name);
        return { ...model, contextWindow };
      }),
    );
    results.push(...enrichedBatch);
  }
  return results;
}
/** Heuristic: treat models with "r1", "reasoning", or "think" in the name as reasoning models. */
export function isReasoningModelHeuristic(modelId: string): boolean {
return /r1|reasoning|think|reason/i.test(modelId);

View File

@@ -30,6 +30,53 @@ function jsonResponse(body: unknown, status = 200): Response {
});
}
/** Resolve a fetch() input (string | URL | Request) to its URL string. */
function requestUrl(input: string | URL | Request): string {
  if (typeof input === "string") {
    return input;
  }
  return input instanceof URL ? input.toString() : input.url;
}
/** Return a string request body as-is; anything else becomes the empty object "{}". */
function requestBody(body: BodyInit | null | undefined): string {
  return typeof body !== "string" ? "{}" : body;
}
/**
 * Build a vi.fn() fetch stub emulating the Ollama endpoints the setup wizard
 * hits, routed by URL suffix:
 * - /api/tags: lists `tags` as models, or throws `tagsError` when supplied.
 * - /api/show: reports `show[name]` as "llama.context_length", else an empty body.
 * - /api/me: consumes `meResponses` in order, then defaults to a signed-in user.
 * - /api/pull: returns `pullResponse` or a canned success stream.
 * Any other URL throws, so unexpected requests fail the test loudly.
 */
function createOllamaFetchMock(params: {
  tags?: string[];
  show?: Record<string, number | undefined>;
  meResponses?: Response[];
  pullResponse?: Response;
  tagsError?: Error;
}) {
  // Copy so shift() below never mutates the caller's array.
  const pendingMeResponses = [...(params.meResponses ?? [])];
  const handleTags = (): Response => {
    if (params.tagsError) {
      throw params.tagsError;
    }
    const names = params.tags ?? [];
    return jsonResponse({ models: names.map((name) => ({ name })) });
  };
  const handleShow = (init?: RequestInit): Response => {
    const parsed = JSON.parse(requestBody(init?.body)) as { name?: string };
    const contextWindow = parsed.name ? params.show?.[parsed.name] : undefined;
    if (contextWindow) {
      return jsonResponse({ model_info: { "llama.context_length": contextWindow } });
    }
    return jsonResponse({});
  };
  return vi.fn(async (input: string | URL | Request, init?: RequestInit) => {
    const url = requestUrl(input);
    if (url.endsWith("/api/tags")) {
      return handleTags();
    }
    if (url.endsWith("/api/show")) {
      return handleShow(init);
    }
    if (url.endsWith("/api/me")) {
      return pendingMeResponses.shift() ?? jsonResponse({ username: "testuser" });
    }
    if (url.endsWith("/api/pull")) {
      return params.pullResponse ?? new Response('{"status":"success"}\n', { status: 200 });
    }
    throw new Error(`Unexpected fetch: ${url}`);
  });
}
describe("ollama setup", () => {
afterEach(() => {
vi.unstubAllGlobals();
@@ -45,9 +92,7 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] }));
const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b"] });
vi.stubGlobal("fetch", fetchMock);
const result = await promptAndConfigureOllama({ cfg: {}, prompter });
@@ -62,10 +107,7 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] }))
.mockResolvedValueOnce(jsonResponse({ username: "testuser" }));
const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b"] });
vi.stubGlobal("fetch", fetchMock);
const result = await promptAndConfigureOllama({ cfg: {}, prompter });
@@ -80,11 +122,7 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(
jsonResponse({ models: [{ name: "llama3:8b" }, { name: "glm-4.7-flash" }] }),
);
const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b", "glm-4.7-flash"] });
vi.stubGlobal("fetch", fetchMock);
const result = await promptAndConfigureOllama({ cfg: {}, prompter });
@@ -103,13 +141,13 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] }))
.mockResolvedValueOnce(
const fetchMock = createOllamaFetchMock({
tags: ["llama3:8b"],
meResponses: [
jsonResponse({ error: "not signed in", signin_url: "https://ollama.com/signin" }, 401),
)
.mockResolvedValueOnce(jsonResponse({ username: "testuser" }));
jsonResponse({ username: "testuser" }),
],
});
vi.stubGlobal("fetch", fetchMock);
await promptAndConfigureOllama({ cfg: {}, prompter });
@@ -127,13 +165,13 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] }))
.mockResolvedValueOnce(
const fetchMock = createOllamaFetchMock({
tags: ["llama3:8b"],
meResponses: [
jsonResponse({ error: "not signed in", signin_url: "https://ollama.com/signin" }, 401),
)
.mockResolvedValueOnce(jsonResponse({ username: "testuser" }));
jsonResponse({ username: "testuser" }),
],
});
vi.stubGlobal("fetch", fetchMock);
await promptAndConfigureOllama({ cfg: {}, prompter });
@@ -148,15 +186,16 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] }));
const fetchMock = createOllamaFetchMock({ tags: ["llama3:8b"] });
vi.stubGlobal("fetch", fetchMock);
await promptAndConfigureOllama({ cfg: {}, prompter });
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock.mock.calls[0][0]).toContain("/api/tags");
expect(fetchMock).toHaveBeenCalledTimes(2);
expect(fetchMock.mock.calls[0]?.[0]).toContain("/api/tags");
expect(fetchMock.mock.calls.some((call) => requestUrl(call[0]).includes("/api/me"))).toBe(
false,
);
});
it("suggested models appear first in model list (cloud+local)", async () => {
@@ -166,14 +205,9 @@ describe("ollama setup", () => {
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(
jsonResponse({
models: [{ name: "llama3:8b" }, { name: "glm-4.7-flash" }, { name: "deepseek-r1:14b" }],
}),
)
.mockResolvedValueOnce(jsonResponse({ username: "testuser" }));
const fetchMock = createOllamaFetchMock({
tags: ["llama3:8b", "glm-4.7-flash", "deepseek-r1:14b"],
});
vi.stubGlobal("fetch", fetchMock);
const result = await promptAndConfigureOllama({ cfg: {}, prompter });
@@ -189,6 +223,27 @@ describe("ollama setup", () => {
]);
});
it("uses /api/show context windows when building Ollama model configs", async () => {
const prompter = {
text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"),
select: vi.fn().mockResolvedValueOnce("local"),
note: vi.fn(async () => undefined),
} as unknown as WizardPrompter;
const fetchMock = createOllamaFetchMock({
tags: ["llama3:8b"],
show: { "llama3:8b": 65536 },
});
vi.stubGlobal("fetch", fetchMock);
const result = await promptAndConfigureOllama({ cfg: {}, prompter });
const model = result.config.models?.providers?.ollama?.models?.find(
(m) => m.id === "llama3:8b",
);
expect(model?.contextWindow).toBe(65536);
});
describe("ensureOllamaModelPulled", () => {
it("pulls model when not available locally", async () => {
const progress = { update: vi.fn(), stop: vi.fn() };
@@ -196,12 +251,10 @@ describe("ollama setup", () => {
progress: vi.fn(() => progress),
} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
// /api/tags — model not present
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] }))
// /api/pull
.mockResolvedValueOnce(new Response('{"status":"success"}\n', { status: 200 }));
const fetchMock = createOllamaFetchMock({
tags: ["llama3:8b"],
pullResponse: new Response('{"status":"success"}\n', { status: 200 }),
});
vi.stubGlobal("fetch", fetchMock);
await ensureOllamaModelPulled({
@@ -219,9 +272,7 @@ describe("ollama setup", () => {
it("skips pull when model is already available", async () => {
const prompter = {} as unknown as WizardPrompter;
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "glm-4.7-flash" }] }));
const fetchMock = createOllamaFetchMock({ tags: ["glm-4.7-flash"] });
vi.stubGlobal("fetch", fetchMock);
await ensureOllamaModelPulled({
@@ -268,10 +319,10 @@ describe("ollama setup", () => {
});
it("uses discovered model when requested non-interactive download fails", async () => {
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [{ name: "qwen2.5-coder:7b" }] }))
.mockResolvedValueOnce(new Response('{"error":"disk full"}\n', { status: 200 }));
const fetchMock = createOllamaFetchMock({
tags: ["qwen2.5-coder:7b"],
pullResponse: new Response('{"error":"disk full"}\n', { status: 200 }),
});
vi.stubGlobal("fetch", fetchMock);
const runtime = {
@@ -306,10 +357,10 @@ describe("ollama setup", () => {
});
it("normalizes ollama/ prefix in non-interactive custom model download", async () => {
const fetchMock = vi
.fn()
.mockResolvedValueOnce(jsonResponse({ models: [] }))
.mockResolvedValueOnce(new Response('{"status":"success"}\n', { status: 200 }));
const fetchMock = createOllamaFetchMock({
tags: [],
pullResponse: new Response('{"status":"success"}\n', { status: 200 }),
});
vi.stubGlobal("fetch", fetchMock);
const runtime = {
@@ -328,14 +379,14 @@ describe("ollama setup", () => {
});
const pullRequest = fetchMock.mock.calls[1]?.[1];
expect(JSON.parse(String(pullRequest?.body))).toEqual({ name: "llama3.2:latest" });
expect(JSON.parse(requestBody(pullRequest?.body))).toEqual({ name: "llama3.2:latest" });
expect(result.agents?.defaults?.model).toEqual(
expect.objectContaining({ primary: "ollama/llama3.2:latest" }),
);
});
it("accepts cloud models in non-interactive mode without pulling", async () => {
const fetchMock = vi.fn().mockResolvedValueOnce(jsonResponse({ models: [] }));
const fetchMock = createOllamaFetchMock({ tags: [] });
vi.stubGlobal("fetch", fetchMock);
const runtime = {
@@ -363,7 +414,9 @@ describe("ollama setup", () => {
});
it("exits when Ollama is unreachable", async () => {
const fetchMock = vi.fn().mockRejectedValueOnce(new Error("connect ECONNREFUSED"));
const fetchMock = createOllamaFetchMock({
tagsError: new Error("connect ECONNREFUSED"),
});
vi.stubGlobal("fetch", fetchMock);
const runtime = {

View File

@@ -2,8 +2,10 @@ import { upsertAuthProfileWithLock } from "../agents/auth-profiles.js";
import {
OLLAMA_DEFAULT_BASE_URL,
buildOllamaModelDefinition,
enrichOllamaModelsWithContext,
fetchOllamaModels,
resolveOllamaApiBase,
type OllamaModelWithContext,
} from "../agents/ollama-models.js";
import type { OpenClawConfig } from "../config/config.js";
import type { RuntimeEnv } from "../runtime.js";
@@ -239,14 +241,20 @@ async function pullOllamaModelNonInteractive(
return true;
}
function buildOllamaModelsConfig(modelNames: string[]) {
return modelNames.map((name) => buildOllamaModelDefinition(name));
/**
 * Map Ollama tag names to model definition configs, attaching the context
 * window discovered via /api/show when one exists for that name.
 */
function buildOllamaModelsConfig(
  modelNames: string[],
  discoveredModelsByName?: Map<string, OllamaModelWithContext>,
) {
  return modelNames.map((name) => {
    const contextWindow = discoveredModelsByName?.get(name)?.contextWindow;
    return buildOllamaModelDefinition(name, contextWindow);
  });
}
function applyOllamaProviderConfig(
cfg: OpenClawConfig,
baseUrl: string,
modelNames: string[],
discoveredModelsByName?: Map<string, OllamaModelWithContext>,
): OpenClawConfig {
return {
...cfg,
@@ -259,7 +267,7 @@ function applyOllamaProviderConfig(
baseUrl,
api: "ollama",
apiKey: "OLLAMA_API_KEY", // pragma: allowlist secret
models: buildOllamaModelsConfig(modelNames),
models: buildOllamaModelsConfig(modelNames, discoveredModelsByName),
},
},
},
@@ -299,7 +307,6 @@ export async function promptAndConfigureOllama(params: {
// 2. Check reachability
const { reachable, models } = await fetchOllamaModels(baseUrl);
const modelNames = models.map((m) => m.name);
if (!reachable) {
await prompter.note(
@@ -314,6 +321,10 @@ export async function promptAndConfigureOllama(params: {
throw new WizardCancelledError("Ollama not reachable");
}
const enrichedModels = await enrichOllamaModelsWithContext(baseUrl, models.slice(0, 50));
const discoveredModelsByName = new Map(enrichedModels.map((model) => [model.name, model]));
const modelNames = models.map((m) => m.name);
// 3. Mode selection
const mode = (await prompter.select({
message: "Ollama mode",
@@ -387,7 +398,12 @@ export async function promptAndConfigureOllama(params: {
await storeOllamaCredential(params.agentDir);
const defaultModelId = suggestedModels[0] ?? OLLAMA_DEFAULT_MODEL;
const config = applyOllamaProviderConfig(params.cfg, baseUrl, orderedModelNames);
const config = applyOllamaProviderConfig(
params.cfg,
baseUrl,
orderedModelNames,
discoveredModelsByName,
);
return { config, defaultModelId };
}
@@ -405,7 +421,6 @@ export async function configureOllamaNonInteractive(params: {
const baseUrl = resolveOllamaApiBase(configuredBaseUrl);
const { reachable, models } = await fetchOllamaModels(baseUrl);
const modelNames = models.map((m) => m.name);
const explicitModel = normalizeOllamaModelName(opts.customModelId);
if (!reachable) {
@@ -421,6 +436,10 @@ export async function configureOllamaNonInteractive(params: {
await storeOllamaCredential();
const enrichedModels = await enrichOllamaModelsWithContext(baseUrl, models.slice(0, 50));
const discoveredModelsByName = new Map(enrichedModels.map((model) => [model.name, model]));
const modelNames = models.map((m) => m.name);
// Apply local suggested model ordering.
const suggestedModels = OLLAMA_SUGGESTED_MODELS_LOCAL;
const orderedModelNames = [
@@ -478,7 +497,12 @@ export async function configureOllamaNonInteractive(params: {
}
}
const config = applyOllamaProviderConfig(params.nextConfig, baseUrl, allModelNames);
const config = applyOllamaProviderConfig(
params.nextConfig,
baseUrl,
allModelNames,
discoveredModelsByName,
);
const modelRef = `ollama/${defaultModelId}`;
runtime.log(`Default Ollama model: ${defaultModelId}`);
return applyAgentDefaultModelPrimary(config, modelRef);