diff --git a/src/agents/huggingface-models.ts b/src/agents/huggingface-models.ts index 7d3755adefb..0e7ae4270f7 100644 --- a/src/agents/huggingface-models.ts +++ b/src/agents/huggingface-models.ts @@ -1,5 +1,6 @@ import type { ModelDefinitionConfig } from "../config/types.models.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; +import { isReasoningModelHeuristic } from "./ollama-models.js"; const log = createSubsystemLogger("huggingface-models"); @@ -125,7 +126,7 @@ export function buildHuggingfaceModelDefinition( */ function inferredMetaFromModelId(id: string): { name: string; reasoning: boolean } { const base = id.split("/").pop() ?? id; - const reasoning = /r1|reasoning|thinking|reason/i.test(id) || /-\d+[tb]?-thinking/i.test(base); + const reasoning = isReasoningModelHeuristic(id); const name = base.replace(/-/g, " ").replace(/\b(\w)/g, (c) => c.toUpperCase()); return { name, reasoning }; } diff --git a/src/agents/models-config.providers.discovery.ts b/src/agents/models-config.providers.discovery.ts index caab5cafb4e..dd0504d2a53 100644 --- a/src/agents/models-config.providers.discovery.ts +++ b/src/agents/models-config.providers.discovery.ts @@ -9,27 +9,26 @@ import { buildHuggingfaceModelDefinition, } from "./huggingface-models.js"; import { discoverKilocodeModels } from "./kilocode-models.js"; -import { OLLAMA_NATIVE_BASE_URL } from "./ollama-stream.js"; +import { + OLLAMA_DEFAULT_CONTEXT_WINDOW, + OLLAMA_DEFAULT_COST, + OLLAMA_DEFAULT_MAX_TOKENS, + isReasoningModelHeuristic, + resolveOllamaApiBase, + type OllamaTagsResponse, +} from "./ollama-models.js"; import { discoverVeniceModels, VENICE_BASE_URL } from "./venice-models.js"; import { discoverVercelAiGatewayModels, VERCEL_AI_GATEWAY_BASE_URL } from "./vercel-ai-gateway.js"; +export { resolveOllamaApiBase } from "./ollama-models.js"; + type ModelsConfig = NonNullable; type ProviderConfig = NonNullable[string]; const log = createSubsystemLogger("agents/model-providers"); -const OLLAMA_BASE_URL = OLLAMA_NATIVE_BASE_URL; -const OLLAMA_API_BASE_URL = OLLAMA_BASE_URL; const OLLAMA_SHOW_CONCURRENCY = 8; const OLLAMA_SHOW_MAX_MODELS = 200; -const OLLAMA_DEFAULT_CONTEXT_WINDOW = 128000; -const OLLAMA_DEFAULT_MAX_TOKENS = 8192; -const OLLAMA_DEFAULT_COST = { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, -}; const VLLM_BASE_URL = "http://127.0.0.1:8000/v1"; const VLLM_DEFAULT_CONTEXT_WINDOW = 128000; @@ -41,44 +40,12 @@ const VLLM_DEFAULT_COST = { cacheWrite: 0, }; -interface OllamaModel { - name: string; - modified_at: string; - size: number; - digest: string; - details?: { - family?: string; - parameter_size?: string; - }; -} - -interface OllamaTagsResponse { - models: OllamaModel[]; -} - type VllmModelsResponse = { data?: Array<{ id?: string; }>; }; -/** - * Derive the Ollama native API base URL from a configured base URL. - * - * Users typically configure `baseUrl` with a `/v1` suffix (e.g. - * `http://192.168.20.14:11434/v1`) for the OpenAI-compatible endpoint. - * The native Ollama API lives at the root (e.g. `/api/tags`), so we - * strip the `/v1` suffix when present. - */ -export function resolveOllamaApiBase(configuredBaseUrl?: string): string { - if (!configuredBaseUrl) { - return OLLAMA_API_BASE_URL; - } - // Strip trailing slash, then strip /v1 suffix if present - const trimmed = configuredBaseUrl.replace(/\/+$/, ""); - return trimmed.replace(/\/v1$/i, ""); -} - async function queryOllamaContextWindow( apiBase: string, modelName: string, @@ -147,12 +114,10 @@ async function discoverOllamaModels( batch.map(async (model) => { const modelId = model.name; const contextWindow = await queryOllamaContextWindow(apiBase, modelId); - const isReasoning = - modelId.toLowerCase().includes("r1") || modelId.toLowerCase().includes("reasoning"); return { id: modelId, name: modelId, - reasoning: isReasoning, + reasoning: isReasoningModelHeuristic(modelId), input: ["text"], cost: OLLAMA_DEFAULT_COST, contextWindow: contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW, @@ -204,13 +169,10 @@ async function discoverVllmModels( .filter((model) => Boolean(model.id)) .map((model) => { const modelId = model.id; - const lower = modelId.toLowerCase(); - const isReasoning = - lower.includes("r1") || lower.includes("reasoning") || lower.includes("think"); return { id: modelId, name: modelId, - reasoning: isReasoning, + reasoning: isReasoningModelHeuristic(modelId), input: ["text"], cost: VLLM_DEFAULT_COST, contextWindow: VLLM_DEFAULT_CONTEXT_WINDOW, diff --git a/src/agents/ollama-models.ts b/src/agents/ollama-models.ts new file mode 100644 index 00000000000..19d95605203 --- /dev/null +++ b/src/agents/ollama-models.ts @@ -0,0 +1,85 @@ +import type { ModelDefinitionConfig } from "../config/types.models.js"; +import { OLLAMA_NATIVE_BASE_URL } from "./ollama-stream.js"; + +export const OLLAMA_DEFAULT_BASE_URL = OLLAMA_NATIVE_BASE_URL; +export const OLLAMA_DEFAULT_CONTEXT_WINDOW = 128000; +export const OLLAMA_DEFAULT_MAX_TOKENS = 8192; +export const OLLAMA_DEFAULT_COST = { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, +}; + +export type OllamaTagModel = { + name: string; + modified_at?: string; + size?: number; + digest?: string; + remote_host?: string; + details?: { + family?: string; + parameter_size?: string; + }; +}; + +export type OllamaTagsResponse = { + models?: OllamaTagModel[]; +}; + +/** + * Derive the Ollama native API base URL from a configured base URL. + * + * Users typically configure `baseUrl` with a `/v1` suffix (e.g. + * `http://192.168.20.14:11434/v1`) for the OpenAI-compatible endpoint. + * The native Ollama API lives at the root (e.g. `/api/tags`), so we + * strip the `/v1` suffix when present. + */ +export function resolveOllamaApiBase(configuredBaseUrl?: string): string { + if (!configuredBaseUrl) { + return OLLAMA_DEFAULT_BASE_URL; + } + const trimmed = configuredBaseUrl.replace(/\/+$/, ""); + return trimmed.replace(/\/v1$/i, ""); +} + +/** Heuristic: treat models with "r1", "reasoning", or "think" in the name as reasoning models. */ +export function isReasoningModelHeuristic(modelId: string): boolean { + return /r1|reasoning|think|reason/i.test(modelId); +} + +/** Build a ModelDefinitionConfig for an Ollama model with default values. */ +export function buildOllamaModelDefinition( + modelId: string, + contextWindow?: number, +): ModelDefinitionConfig { + return { + id: modelId, + name: modelId, + reasoning: isReasoningModelHeuristic(modelId), + input: ["text"], + cost: OLLAMA_DEFAULT_COST, + contextWindow: contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW, + maxTokens: OLLAMA_DEFAULT_MAX_TOKENS, + }; +} + +/** Fetch the model list from a running Ollama instance. */ +export async function fetchOllamaModels( + baseUrl: string, +): Promise<{ reachable: boolean; models: OllamaTagModel[] }> { + try { + const apiBase = resolveOllamaApiBase(baseUrl); + const response = await fetch(`${apiBase}/api/tags`, { + signal: AbortSignal.timeout(5000), + }); + if (!response.ok) { + return { reachable: true, models: [] }; + } + const data = (await response.json()) as OllamaTagsResponse; + const models = (data.models ?? []).filter((m) => m.name); + return { reachable: true, models }; + } catch { + return { reachable: false, models: [] }; + } +} diff --git a/src/commands/auth-choice-options.test.ts b/src/commands/auth-choice-options.test.ts index e86f5d5c361..462dbb32d11 100644 --- a/src/commands/auth-choice-options.test.ts +++ b/src/commands/auth-choice-options.test.ts @@ -42,6 +42,7 @@ describe("buildAuthChoiceOptions", () => { "byteplus-api-key", "vllm", "opencode-go", + "ollama", ]) { expect(options.some((opt) => opt.value === value)).toBe(true); } @@ -93,4 +94,15 @@ describe("buildAuthChoiceOptions", () => { expect(openCodeGroup?.options.some((opt) => opt.value === "opencode-zen")).toBe(true); expect(openCodeGroup?.options.some((opt) => opt.value === "opencode-go")).toBe(true); }); + + it("shows Ollama in grouped provider selection", () => { + const { groups } = buildAuthChoiceGroups({ + store: EMPTY_STORE, + includeSkip: false, + }); + const ollamaGroup = groups.find((group) => group.value === "ollama"); + + expect(ollamaGroup).toBeDefined(); + expect(ollamaGroup?.options.some((opt) => opt.value === "ollama")).toBe(true); + }); }); diff --git a/src/commands/auth-choice-options.ts b/src/commands/auth-choice-options.ts index 33b3752e585..077fee024b9 100644 --- a/src/commands/auth-choice-options.ts +++ b/src/commands/auth-choice-options.ts @@ -47,6 +47,12 @@ const AUTH_CHOICE_GROUP_DEFS: { hint: "Local/self-hosted OpenAI-compatible", choices: ["vllm"], }, + { + value: "ollama", + label: "Ollama", + hint: "Cloud and local open models", + choices: ["ollama"], + }, { value: "minimax", label: "MiniMax", @@ -238,6 +244,11 @@ const BASE_AUTH_CHOICE_OPTIONS: ReadonlyArray = [ label: "vLLM (custom URL + model)", hint: "Local/self-hosted OpenAI-compatible server", }, + { + value: "ollama", + label: "Ollama", + hint: "Cloud and local open models", + }, ...buildProviderAuthChoiceOptions(), { value: "moonshot-api-key-cn", diff --git a/src/commands/auth-choice.apply.ollama.test.ts b/src/commands/auth-choice.apply.ollama.test.ts new file mode 100644 index 00000000000..f6739a88ad1 --- /dev/null +++ b/src/commands/auth-choice.apply.ollama.test.ts @@ -0,0 +1,83 @@ +import { describe, expect, it, vi } from "vitest"; +import type { ApplyAuthChoiceParams } from "./auth-choice.apply.js"; +import { applyAuthChoiceOllama } from "./auth-choice.apply.ollama.js"; + +type PromptAndConfigureOllama = typeof import("./ollama-setup.js").promptAndConfigureOllama; + +const promptAndConfigureOllama = vi.hoisted(() => + vi.fn(async ({ cfg }) => ({ + config: cfg, + defaultModelId: "qwen3.5:35b", + })), +); +const ensureOllamaModelPulled = vi.hoisted(() => vi.fn(async () => {})); +vi.mock("./ollama-setup.js", () => ({ + promptAndConfigureOllama, + ensureOllamaModelPulled, +})); + +function buildParams(overrides: Partial = {}): ApplyAuthChoiceParams { + return { + authChoice: "ollama", + config: {}, + prompter: {} as ApplyAuthChoiceParams["prompter"], + runtime: {} as ApplyAuthChoiceParams["runtime"], + setDefaultModel: false, + ...overrides, + }; +} + +describe("applyAuthChoiceOllama", () => { + it("returns agentModelOverride when setDefaultModel is false", async () => { + const config = { agents: { defaults: { model: { primary: "openai/gpt-4o-mini" } } } }; + promptAndConfigureOllama.mockResolvedValueOnce({ + config, + defaultModelId: "qwen2.5-coder:7b", + }); + + const result = await applyAuthChoiceOllama( + buildParams({ + config, + setDefaultModel: false, + }), + ); + + expect(result).toEqual({ + config, + agentModelOverride: "ollama/qwen2.5-coder:7b", + }); + // Pull is deferred — the wizard model picker handles it. + expect(ensureOllamaModelPulled).not.toHaveBeenCalled(); + }); + + it("sets global default model and preserves fallbacks when setDefaultModel is true", async () => { + const config = { + agents: { + defaults: { + model: { + primary: "openai/gpt-4o-mini", + fallbacks: ["anthropic/claude-sonnet-4-5"], + }, + }, + }, + }; + promptAndConfigureOllama.mockResolvedValueOnce({ + config, + defaultModelId: "qwen2.5-coder:7b", + }); + + const result = await applyAuthChoiceOllama( + buildParams({ + config, + setDefaultModel: true, + }), + ); + + expect(result?.agentModelOverride).toBeUndefined(); + expect(result?.config.agents?.defaults?.model).toEqual({ + primary: "ollama/qwen2.5-coder:7b", + fallbacks: ["anthropic/claude-sonnet-4-5"], + }); + expect(ensureOllamaModelPulled).toHaveBeenCalledOnce(); + }); +}); diff --git a/src/commands/auth-choice.apply.ollama.ts b/src/commands/auth-choice.apply.ollama.ts new file mode 100644 index 00000000000..640b57431cf --- /dev/null +++ b/src/commands/auth-choice.apply.ollama.ts @@ -0,0 +1,31 @@ +import type { ApplyAuthChoiceParams, ApplyAuthChoiceResult } from "./auth-choice.apply.js"; +import { ensureOllamaModelPulled, promptAndConfigureOllama } from "./ollama-setup.js"; +import { applyAgentDefaultModelPrimary } from "./onboard-auth.config-shared.js"; + +export async function applyAuthChoiceOllama( + params: ApplyAuthChoiceParams, +): Promise { + if (params.authChoice !== "ollama") { + return null; + } + + const { config, defaultModelId } = await promptAndConfigureOllama({ + cfg: params.config, + prompter: params.prompter, + agentDir: params.agentDir, + }); + + // Set an Ollama default so the model picker pre-selects an Ollama model. + const defaultModel = `ollama/${defaultModelId}`; + const configWithDefault = applyAgentDefaultModelPrimary(config, defaultModel); + + if (!params.setDefaultModel) { + // Defer pulling: the interactive wizard will show a model picker next, + // so avoid downloading a model the user may not choose. + return { config, agentModelOverride: defaultModel }; + } + + await ensureOllamaModelPulled({ config: configWithDefault, prompter: params.prompter }); + + return { config: configWithDefault }; +} diff --git a/src/commands/auth-choice.apply.ts b/src/commands/auth-choice.apply.ts index e6dfa9ed52a..36591304da0 100644 --- a/src/commands/auth-choice.apply.ts +++ b/src/commands/auth-choice.apply.ts @@ -9,6 +9,7 @@ import { applyAuthChoiceGitHubCopilot } from "./auth-choice.apply.github-copilot import { applyAuthChoiceGoogleGeminiCli } from "./auth-choice.apply.google-gemini-cli.js"; import { applyAuthChoiceMiniMax } from "./auth-choice.apply.minimax.js"; import { applyAuthChoiceOAuth } from "./auth-choice.apply.oauth.js"; +import { applyAuthChoiceOllama } from "./auth-choice.apply.ollama.js"; import { applyAuthChoiceOpenAI } from "./auth-choice.apply.openai.js"; import { applyAuthChoiceQwenPortal } from "./auth-choice.apply.qwen-portal.js"; import { applyAuthChoiceVllm } from "./auth-choice.apply.vllm.js"; @@ -38,6 +39,7 @@ export async function applyAuthChoice( const handlers: Array<(p: ApplyAuthChoiceParams) => Promise> = [ applyAuthChoiceAnthropic, applyAuthChoiceVllm, + applyAuthChoiceOllama, applyAuthChoiceOpenAI, applyAuthChoiceOAuth, applyAuthChoiceApiProviders, diff --git a/src/commands/auth-choice.preferred-provider.ts b/src/commands/auth-choice.preferred-provider.ts index 4f94e0e4d6f..7ebc0b24ea1 100644 --- a/src/commands/auth-choice.preferred-provider.ts +++ b/src/commands/auth-choice.preferred-provider.ts @@ -7,6 +7,7 @@ const PREFERRED_PROVIDER_BY_AUTH_CHOICE: Partial> = { token: "anthropic", apiKey: "anthropic", vllm: "vllm", + ollama: "ollama", "openai-codex": "openai-codex", "codex-cli": "openai-codex", chutes: "chutes", diff --git a/src/commands/auth-choice.test.ts b/src/commands/auth-choice.test.ts index 200471971a2..6cdf32fa1d2 100644 --- a/src/commands/auth-choice.test.ts +++ b/src/commands/auth-choice.test.ts @@ -22,6 +22,7 @@ import { } from "./test-wizard-helpers.js"; type DetectZaiEndpoint = typeof import("./zai-endpoint-detect.js").detectZaiEndpoint; +type PromptAndConfigureOllama = typeof import("./ollama-setup.js").promptAndConfigureOllama; vi.mock("../providers/github-copilot-auth.js", () => ({ githubCopilotLoginCommand: vi.fn(async () => {}), @@ -44,6 +45,16 @@ vi.mock("./zai-endpoint-detect.js", () => ({ detectZaiEndpoint, })); +const promptAndConfigureOllama = vi.hoisted(() => + vi.fn(async ({ cfg }) => ({ + config: cfg, + defaultModelId: "qwen3.5:35b", + })), +); +vi.mock("./ollama-setup.js", () => ({ + promptAndConfigureOllama, +})); + type StoredAuthProfile = { key?: string; keyRef?: { source: string; provider: string; id: string }; @@ -131,6 +142,11 @@ describe("applyAuthChoice", () => { detectZaiEndpoint.mockResolvedValue(null); loginOpenAICodexOAuth.mockReset(); loginOpenAICodexOAuth.mockResolvedValue(null); + promptAndConfigureOllama.mockReset(); + promptAndConfigureOllama.mockImplementation(async ({ cfg }) => ({ + config: cfg, + defaultModelId: "qwen3.5:35b", + })); await lifecycle.cleanup(); activeStateDir = null; }); @@ -1350,6 +1366,7 @@ describe("resolvePreferredProviderForAuthChoice", () => { { authChoice: "github-copilot" as const, expectedProvider: "github-copilot" }, { authChoice: "qwen-portal" as const, expectedProvider: "qwen-portal" }, { authChoice: "mistral-api-key" as const, expectedProvider: "mistral" }, + { authChoice: "ollama" as const, expectedProvider: "ollama" }, { authChoice: "unknown" as AuthChoice, expectedProvider: undefined }, ] as const; for (const scenario of scenarios) { diff --git a/src/commands/ollama-setup.test.ts b/src/commands/ollama-setup.test.ts new file mode 100644 index 00000000000..2313588f180 --- /dev/null +++ b/src/commands/ollama-setup.test.ts @@ -0,0 +1,391 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import type { RuntimeEnv } from "../runtime.js"; +import type { WizardPrompter } from "../wizard/prompts.js"; +import { + configureOllamaNonInteractive, + ensureOllamaModelPulled, + promptAndConfigureOllama, +} from "./ollama-setup.js"; + +const upsertAuthProfileWithLock = vi.hoisted(() => vi.fn(async () => {})); +vi.mock("../agents/auth-profiles.js", () => ({ + upsertAuthProfileWithLock, +})); + +const openUrlMock = vi.hoisted(() => vi.fn(async () => false)); +vi.mock("./onboard-helpers.js", async (importOriginal) => { + const original = await importOriginal(); + return { ...original, openUrl: openUrlMock }; +}); + +const isRemoteEnvironmentMock = vi.hoisted(() => vi.fn(() => false)); +vi.mock("./oauth-env.js", () => ({ + isRemoteEnvironment: isRemoteEnvironmentMock, +})); + +function jsonResponse(body: unknown, status = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "Content-Type": "application/json" }, + }); +} + +describe("ollama setup", () => { + afterEach(() => { + vi.unstubAllGlobals(); + upsertAuthProfileWithLock.mockClear(); + openUrlMock.mockClear(); + isRemoteEnvironmentMock.mockReset().mockReturnValue(false); + }); + + it("returns suggested default model for local mode", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("local"), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })); + vi.stubGlobal("fetch", fetchMock); + + const result = await promptAndConfigureOllama({ cfg: {}, prompter }); + + expect(result.defaultModelId).toBe("glm-4.7-flash"); + }); + + it("returns suggested default model for remote mode", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("remote"), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) + .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + vi.stubGlobal("fetch", fetchMock); + + const result = await promptAndConfigureOllama({ cfg: {}, prompter }); + + expect(result.defaultModelId).toBe("kimi-k2.5:cloud"); + }); + + it("mode selection affects model ordering (local)", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("local"), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce( + jsonResponse({ models: [{ name: "llama3:8b" }, { name: "glm-4.7-flash" }] }), + ); + vi.stubGlobal("fetch", fetchMock); + + const result = await promptAndConfigureOllama({ cfg: {}, prompter }); + + expect(result.defaultModelId).toBe("glm-4.7-flash"); + const modelIds = result.config.models?.providers?.ollama?.models?.map((m) => m.id); + expect(modelIds?.[0]).toBe("glm-4.7-flash"); + expect(modelIds).toContain("llama3:8b"); + }); + + it("cloud+local mode triggers /api/me check and opens sign-in URL", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("remote"), + confirm: vi.fn().mockResolvedValueOnce(true), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) + .mockResolvedValueOnce( + jsonResponse({ error: "not signed in", signin_url: "https://ollama.com/signin" }, 401), + ) + .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + vi.stubGlobal("fetch", fetchMock); + + await promptAndConfigureOllama({ cfg: {}, prompter }); + + expect(openUrlMock).toHaveBeenCalledWith("https://ollama.com/signin"); + expect(prompter.confirm).toHaveBeenCalled(); + }); + + it("cloud+local mode does not open browser in remote environment", async () => { + isRemoteEnvironmentMock.mockReturnValue(true); + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("remote"), + confirm: vi.fn().mockResolvedValueOnce(true), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) + .mockResolvedValueOnce( + jsonResponse({ error: "not signed in", signin_url: "https://ollama.com/signin" }, 401), + ) + .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + vi.stubGlobal("fetch", fetchMock); + + await promptAndConfigureOllama({ cfg: {}, prompter }); + + expect(openUrlMock).not.toHaveBeenCalled(); + }); + + it("local mode does not trigger cloud auth", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("local"), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })); + vi.stubGlobal("fetch", fetchMock); + + await promptAndConfigureOllama({ cfg: {}, prompter }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(fetchMock.mock.calls[0][0]).toContain("/api/tags"); + }); + + it("suggested models appear first in model list (cloud+local)", async () => { + const prompter = { + text: vi.fn().mockResolvedValueOnce("http://127.0.0.1:11434"), + select: vi.fn().mockResolvedValueOnce("remote"), + note: vi.fn(async () => undefined), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce( + jsonResponse({ + models: [{ name: "llama3:8b" }, { name: "glm-4.7-flash" }, { name: "deepseek-r1:14b" }], + }), + ) + .mockResolvedValueOnce(jsonResponse({ username: "testuser" })); + vi.stubGlobal("fetch", fetchMock); + + const result = await promptAndConfigureOllama({ cfg: {}, prompter }); + const modelIds = result.config.models?.providers?.ollama?.models?.map((m) => m.id); + + expect(modelIds).toEqual([ + "kimi-k2.5:cloud", + "minimax-m2.5:cloud", + "glm-5:cloud", + "llama3:8b", + "glm-4.7-flash", + "deepseek-r1:14b", + ]); + }); + + describe("ensureOllamaModelPulled", () => { + it("pulls model when not available locally", async () => { + const progress = { update: vi.fn(), stop: vi.fn() }; + const prompter = { + progress: vi.fn(() => progress), + } as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + // /api/tags — model not present + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "llama3:8b" }] })) + // /api/pull + .mockResolvedValueOnce(new Response('{"status":"success"}\n', { status: 200 })); + vi.stubGlobal("fetch", fetchMock); + + await ensureOllamaModelPulled({ + config: { + agents: { defaults: { model: { primary: "ollama/glm-4.7-flash" } } }, + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434", models: [] } } }, + }, + prompter, + }); + + expect(fetchMock).toHaveBeenCalledTimes(2); + expect(fetchMock.mock.calls[1][0]).toContain("/api/pull"); + }); + + it("skips pull when model is already available", async () => { + const prompter = {} as unknown as WizardPrompter; + + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "glm-4.7-flash" }] })); + vi.stubGlobal("fetch", fetchMock); + + await ensureOllamaModelPulled({ + config: { + agents: { defaults: { model: { primary: "ollama/glm-4.7-flash" } } }, + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434", models: [] } } }, + }, + prompter, + }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + }); + + it("skips pull for cloud models", async () => { + const prompter = {} as unknown as WizardPrompter; + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + await ensureOllamaModelPulled({ + config: { + agents: { defaults: { model: { primary: "ollama/kimi-k2.5:cloud" } } }, + models: { providers: { ollama: { baseUrl: "http://127.0.0.1:11434", models: [] } } }, + }, + prompter, + }); + + expect(fetchMock).not.toHaveBeenCalled(); + }); + + it("skips when model is not an ollama model", async () => { + const prompter = {} as unknown as WizardPrompter; + const fetchMock = vi.fn(); + vi.stubGlobal("fetch", fetchMock); + + await ensureOllamaModelPulled({ + config: { + agents: { defaults: { model: { primary: "openai/gpt-4o" } } }, + }, + prompter, + }); + + expect(fetchMock).not.toHaveBeenCalled(); + }); + }); + + it("uses discovered model when requested non-interactive download fails", async () => { + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [{ name: "qwen2.5-coder:7b" }] })) + .mockResolvedValueOnce(new Response('{"error":"disk full"}\n', { status: 200 })); + vi.stubGlobal("fetch", fetchMock); + + const runtime = { + log: vi.fn(), + error: vi.fn(), + exit: vi.fn(), + } as unknown as RuntimeEnv; + + const result = await configureOllamaNonInteractive({ + nextConfig: { + agents: { + defaults: { + model: { + primary: "openai/gpt-4o-mini", + fallbacks: ["anthropic/claude-sonnet-4-5"], + }, + }, + }, + }, + opts: { + customBaseUrl: "http://127.0.0.1:11434", + customModelId: "missing-model", + }, + runtime, + }); + + expect(runtime.error).toHaveBeenCalledWith("Download failed: disk full"); + expect(result.agents?.defaults?.model).toEqual({ + primary: "ollama/qwen2.5-coder:7b", + fallbacks: ["anthropic/claude-sonnet-4-5"], + }); + }); + + it("normalizes ollama/ prefix in non-interactive custom model download", async () => { + const fetchMock = vi + .fn() + .mockResolvedValueOnce(jsonResponse({ models: [] })) + .mockResolvedValueOnce(new Response('{"status":"success"}\n', { status: 200 })); + vi.stubGlobal("fetch", fetchMock); + + const runtime = { + log: vi.fn(), + error: vi.fn(), + exit: vi.fn(), + } as unknown as RuntimeEnv; + + const result = await configureOllamaNonInteractive({ + nextConfig: {}, + opts: { + customBaseUrl: "http://127.0.0.1:11434", + customModelId: "ollama/llama3.2:latest", + }, + runtime, + }); + + const pullRequest = fetchMock.mock.calls[1]?.[1]; + expect(JSON.parse(String(pullRequest?.body))).toEqual({ name: "llama3.2:latest" }); + expect(result.agents?.defaults?.model).toEqual( + expect.objectContaining({ primary: "ollama/llama3.2:latest" }), + ); + }); + + it("accepts cloud models in non-interactive mode without pulling", async () => { + const fetchMock = vi.fn().mockResolvedValueOnce(jsonResponse({ models: [] })); + vi.stubGlobal("fetch", fetchMock); + + const runtime = { + log: vi.fn(), + error: vi.fn(), + exit: vi.fn(), + } as unknown as RuntimeEnv; + + const result = await configureOllamaNonInteractive({ + nextConfig: {}, + opts: { + customBaseUrl: "http://127.0.0.1:11434", + customModelId: "kimi-k2.5:cloud", + }, + runtime, + }); + + expect(fetchMock).toHaveBeenCalledTimes(1); + expect(result.models?.providers?.ollama?.models?.map((model) => model.id)).toContain( + "kimi-k2.5:cloud", + ); + expect(result.agents?.defaults?.model).toEqual( + expect.objectContaining({ primary: "ollama/kimi-k2.5:cloud" }), + ); + }); + + it("exits when Ollama is unreachable", async () => { + const fetchMock = vi.fn().mockRejectedValueOnce(new Error("connect ECONNREFUSED")); + vi.stubGlobal("fetch", fetchMock); + + const runtime = { + log: vi.fn(), + error: vi.fn(), + exit: vi.fn(), + } as unknown as RuntimeEnv; + const nextConfig = {}; + + const result = await configureOllamaNonInteractive({ + nextConfig, + opts: { + customBaseUrl: "http://127.0.0.1:11435", + customModelId: "llama3.2:latest", + }, + runtime, + }); + + expect(runtime.error).toHaveBeenCalledWith( + expect.stringContaining("Ollama could not be reached at http://127.0.0.1:11435."), + ); + expect(runtime.exit).toHaveBeenCalledWith(1); + expect(result).toBe(nextConfig); + }); +}); diff --git a/src/commands/ollama-setup.ts b/src/commands/ollama-setup.ts new file mode 100644 index 00000000000..7bffaf729e5 --- /dev/null +++ b/src/commands/ollama-setup.ts @@ -0,0 +1,511 @@ +import { upsertAuthProfileWithLock } from "../agents/auth-profiles.js"; +import { + OLLAMA_DEFAULT_BASE_URL, + buildOllamaModelDefinition, + fetchOllamaModels, + resolveOllamaApiBase, +} from "../agents/ollama-models.js"; +import type { OpenClawConfig } from "../config/config.js"; +import type { RuntimeEnv } from "../runtime.js"; +import { WizardCancelledError, type WizardPrompter } from "../wizard/prompts.js"; +import { isRemoteEnvironment } from "./oauth-env.js"; +import { applyAgentDefaultModelPrimary } from "./onboard-auth.config-shared.js"; +import { openUrl } from "./onboard-helpers.js"; +import type { OnboardMode, OnboardOptions } from "./onboard-types.js"; + +export { OLLAMA_DEFAULT_BASE_URL } from "../agents/ollama-models.js"; +export const OLLAMA_DEFAULT_MODEL = "glm-4.7-flash"; + +const OLLAMA_SUGGESTED_MODELS_LOCAL = ["glm-4.7-flash"]; +const OLLAMA_SUGGESTED_MODELS_CLOUD = [ + "kimi-k2.5:cloud", + "minimax-m2.5:cloud", + "glm-5:cloud", +]; + +function normalizeOllamaModelName(value: string | undefined): string | undefined { + const trimmed = value?.trim(); + if (!trimmed) { + return undefined; + } + if (trimmed.toLowerCase().startsWith("ollama/")) { + const withoutPrefix = trimmed.slice("ollama/".length).trim(); + return withoutPrefix || undefined; + } + return trimmed; +} + +function isOllamaCloudModel(modelName: string | undefined): boolean { + return Boolean(modelName?.trim().toLowerCase().endsWith(":cloud")); +} + +function formatOllamaPullStatus(status: string): { text: string; hidePercent: boolean } { + const trimmed = status.trim(); + const partStatusMatch = trimmed.match(/^([a-z-]+)\s+(?:sha256:)?[a-f0-9]{8,}$/i); + if (partStatusMatch) { + return { text: `${partStatusMatch[1]} part`, hidePercent: false }; + } + if (/^verifying\b.*\bdigest\b/i.test(trimmed)) { + return { text: "verifying digest", hidePercent: true }; + } + return { text: trimmed, hidePercent: false }; +} + +type OllamaCloudAuthResult = { + signedIn: boolean; + signinUrl?: string; +}; + +/** Check if the user is signed in to Ollama cloud via /api/me. */ +async function checkOllamaCloudAuth(baseUrl: string): Promise { + try { + const apiBase = resolveOllamaApiBase(baseUrl); + const response = await fetch(`${apiBase}/api/me`, { + method: "POST", + signal: AbortSignal.timeout(5000), + }); + if (response.status === 401) { + // 401 body contains { error, signin_url } + const data = (await response.json()) as { signin_url?: string }; + return { signedIn: false, signinUrl: data.signin_url }; + } + if (!response.ok) { + return { signedIn: false }; + } + return { signedIn: true }; + } catch { + // /api/me not supported or unreachable — fail closed so cloud mode + // doesn't silently skip auth; the caller handles the fallback. + return { signedIn: false }; + } +} + +type OllamaPullChunk = { + status?: string; + total?: number; + completed?: number; + error?: string; +}; + +type OllamaPullFailureKind = "http" | "no-body" | "chunk-error" | "network"; +type OllamaPullResult = + | { ok: true } + | { + ok: false; + kind: OllamaPullFailureKind; + message: string; + }; + +async function pullOllamaModelCore(params: { + baseUrl: string; + modelName: string; + onStatus?: (status: string, percent: number | null) => void; +}): Promise { + const { onStatus } = params; + const baseUrl = resolveOllamaApiBase(params.baseUrl); + const modelName = normalizeOllamaModelName(params.modelName) ?? params.modelName.trim(); + try { + const response = await fetch(`${baseUrl}/api/pull`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ name: modelName }), + }); + if (!response.ok) { + return { + ok: false, + kind: "http", + message: `Failed to download ${modelName} (HTTP ${response.status})`, + }; + } + if (!response.body) { + return { + ok: false, + kind: "no-body", + message: `Failed to download ${modelName} (no response body)`, + }; + } + + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let buffer = ""; + const layers = new Map(); + + const parseLine = (line: string): OllamaPullResult => { + const trimmed = line.trim(); + if (!trimmed) { + return { ok: true }; + } + try { + const chunk = JSON.parse(trimmed) as OllamaPullChunk; + if (chunk.error) { + return { + ok: false, + kind: "chunk-error", + message: `Download failed: ${chunk.error}`, + }; + } + if (!chunk.status) { + return { ok: true }; + } + if (chunk.total && chunk.completed !== undefined) { + layers.set(chunk.status, { total: chunk.total, completed: chunk.completed }); + let totalSum = 0; + let completedSum = 0; + for (const layer of layers.values()) { + totalSum += layer.total; + completedSum += layer.completed; + } + const percent = totalSum > 0 ? Math.round((completedSum / totalSum) * 100) : null; + onStatus?.(chunk.status, percent); + } else { + onStatus?.(chunk.status, null); + } + } catch { + // Ignore malformed lines from streaming output. + } + return { ok: true }; + }; + + for (;;) { + const { done, value } = await reader.read(); + if (done) { + break; + } + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() ?? ""; + for (const line of lines) { + const parsed = parseLine(line); + if (!parsed.ok) { + return parsed; + } + } + } + + const trailing = buffer.trim(); + if (trailing) { + const parsed = parseLine(trailing); + if (!parsed.ok) { + return parsed; + } + } + + return { ok: true }; + } catch (err) { + const reason = err instanceof Error ? err.message : String(err); + return { + ok: false, + kind: "network", + message: `Failed to download ${modelName}: ${reason}`, + }; + } +} + +/** Pull a model from Ollama, streaming progress updates. */ +async function pullOllamaModel( + baseUrl: string, + modelName: string, + prompter: WizardPrompter, +): Promise { + const spinner = prompter.progress(`Downloading ${modelName}...`); + const result = await pullOllamaModelCore({ + baseUrl, + modelName, + onStatus: (status, percent) => { + const displayStatus = formatOllamaPullStatus(status); + if (displayStatus.hidePercent) { + spinner.update(`Downloading ${modelName} - ${displayStatus.text}`); + } else { + spinner.update(`Downloading ${modelName} - ${displayStatus.text} - ${percent ?? 0}%`); + } + }, + }); + if (!result.ok) { + spinner.stop(result.message); + return false; + } + spinner.stop(`Downloaded ${modelName}`); + return true; +} + +async function pullOllamaModelNonInteractive( + baseUrl: string, + modelName: string, + runtime: RuntimeEnv, +): Promise { + runtime.log(`Downloading ${modelName}...`); + const result = await pullOllamaModelCore({ baseUrl, modelName }); + if (!result.ok) { + runtime.error(result.message); + return false; + } + runtime.log(`Downloaded ${modelName}`); + return true; +} + +function buildOllamaModelsConfig(modelNames: string[]) { + return modelNames.map((name) => buildOllamaModelDefinition(name)); +} + +function applyOllamaProviderConfig( + cfg: OpenClawConfig, + baseUrl: string, + modelNames: string[], +): OpenClawConfig { + return { + ...cfg, + models: { + ...cfg.models, + mode: cfg.models?.mode ?? "merge", + providers: { + ...cfg.models?.providers, + ollama: { + baseUrl, + api: "ollama", + apiKey: "OLLAMA_API_KEY", // pragma: allowlist secret + models: buildOllamaModelsConfig(modelNames), + }, + }, + }, + }; +} + +async function storeOllamaCredential(agentDir?: string): Promise { + await upsertAuthProfileWithLock({ + profileId: "ollama:default", + credential: { type: "api_key", provider: "ollama", key: "ollama-local" }, + agentDir, + }); +} + +/** + * Interactive: prompt for base URL, discover models, configure provider. + * Model selection is handled by the standard model picker downstream. + */ +export async function promptAndConfigureOllama(params: { + cfg: OpenClawConfig; + prompter: WizardPrompter; + agentDir?: string; +}): Promise<{ config: OpenClawConfig; defaultModelId: string }> { + const { prompter } = params; + + // 1. Prompt base URL + const baseUrlRaw = await prompter.text({ + message: "Ollama base URL", + initialValue: OLLAMA_DEFAULT_BASE_URL, + placeholder: OLLAMA_DEFAULT_BASE_URL, + validate: (value) => (value?.trim() ? undefined : "Required"), + }); + const configuredBaseUrl = String(baseUrlRaw ?? "") + .trim() + .replace(/\/+$/, ""); + const baseUrl = resolveOllamaApiBase(configuredBaseUrl); + + // 2. Check reachability + const { reachable, models } = await fetchOllamaModels(baseUrl); + const modelNames = models.map((m) => m.name); + + if (!reachable) { + await prompter.note( + [ + `Ollama could not be reached at ${baseUrl}.`, + "Download it at https://ollama.com/download", + "", + "Start Ollama and re-run onboarding.", + ].join("\n"), + "Ollama", + ); + throw new WizardCancelledError("Ollama not reachable"); + } + + // 3. Mode selection + const mode = (await prompter.select({ + message: "Ollama mode", + options: [ + { value: "remote", label: "Cloud + Local", hint: "Ollama cloud models + local models" }, + { value: "local", label: "Local", hint: "Local models only" }, + ], + })) as OnboardMode; + + // 4. Cloud auth — check /api/me upfront for remote (cloud+local) mode + let cloudAuthVerified = false; + if (mode === "remote") { + const authResult = await checkOllamaCloudAuth(baseUrl); + if (!authResult.signedIn) { + if (authResult.signinUrl) { + if (!isRemoteEnvironment()) { + await openUrl(authResult.signinUrl); + } + await prompter.note( + ["Sign in to Ollama Cloud:", authResult.signinUrl].join("\n"), + "Ollama Cloud", + ); + const confirmed = await prompter.confirm({ + message: "Have you signed in?", + }); + if (!confirmed) { + throw new WizardCancelledError("Ollama cloud sign-in cancelled"); + } + // Re-check after user claims sign-in + const recheck = await checkOllamaCloudAuth(baseUrl); + if (!recheck.signedIn) { + throw new WizardCancelledError("Ollama cloud sign-in required"); + } + cloudAuthVerified = true; + } else { + // No signin URL available (older server, unreachable /api/me, or custom gateway). + await prompter.note( + [ + "Could not verify Ollama Cloud authentication.", + "Cloud models may not work until you sign in at https://ollama.com.", + ].join("\n"), + "Ollama Cloud", + ); + const continueAnyway = await prompter.confirm({ + message: "Continue without cloud auth?", + }); + if (!continueAnyway) { + throw new WizardCancelledError("Ollama cloud auth could not be verified"); + } + // Cloud auth unverified — fall back to local defaults so the model + // picker doesn't steer toward cloud models that may fail. + } + } else { + cloudAuthVerified = true; + } + } + + // 5. Model ordering — suggested models first. + // Use cloud defaults only when auth was actually verified; otherwise fall + // back to local defaults so the user isn't steered toward cloud models + // that may fail at runtime. + const suggestedModels = + mode === "local" || !cloudAuthVerified + ? OLLAMA_SUGGESTED_MODELS_LOCAL + : OLLAMA_SUGGESTED_MODELS_CLOUD; + const orderedModelNames = [ + ...suggestedModels, + ...modelNames.filter((name) => !suggestedModels.includes(name)), + ]; + + await storeOllamaCredential(params.agentDir); + + const defaultModelId = suggestedModels[0] ?? OLLAMA_DEFAULT_MODEL; + const config = applyOllamaProviderConfig(params.cfg, baseUrl, orderedModelNames); + return { config, defaultModelId }; +} + +/** Non-interactive: auto-discover models and configure provider. */ +export async function configureOllamaNonInteractive(params: { + nextConfig: OpenClawConfig; + opts: OnboardOptions; + runtime: RuntimeEnv; +}): Promise { + const { opts, runtime } = params; + const configuredBaseUrl = (opts.customBaseUrl?.trim() || OLLAMA_DEFAULT_BASE_URL).replace( + /\/+$/, + "", + ); + const baseUrl = resolveOllamaApiBase(configuredBaseUrl); + + const { reachable, models } = await fetchOllamaModels(baseUrl); + const modelNames = models.map((m) => m.name); + const explicitModel = normalizeOllamaModelName(opts.customModelId); + + if (!reachable) { + runtime.error( + [ + `Ollama could not be reached at ${baseUrl}.`, + "Download it at https://ollama.com/download", + ].join("\n"), + ); + runtime.exit(1); + return params.nextConfig; + } + + await storeOllamaCredential(); + + // Apply local suggested model ordering. + const suggestedModels = OLLAMA_SUGGESTED_MODELS_LOCAL; + const orderedModelNames = [ + ...suggestedModels, + ...modelNames.filter((name) => !suggestedModels.includes(name)), + ]; + + const requestedDefaultModelId = explicitModel ?? suggestedModels[0]; + let pulledRequestedModel = false; + const availableModelNames = new Set(modelNames); + const requestedCloudModel = isOllamaCloudModel(requestedDefaultModelId); + + if (requestedCloudModel) { + availableModelNames.add(requestedDefaultModelId); + } + + // Pull if model not in discovered list and Ollama is reachable + if (!requestedCloudModel && !modelNames.includes(requestedDefaultModelId)) { + pulledRequestedModel = await pullOllamaModelNonInteractive( + baseUrl, + requestedDefaultModelId, + runtime, + ); + if (pulledRequestedModel) { + availableModelNames.add(requestedDefaultModelId); + } + } + + let allModelNames = orderedModelNames; + let defaultModelId = requestedDefaultModelId; + if ((pulledRequestedModel || requestedCloudModel) && !allModelNames.includes(requestedDefaultModelId)) { + allModelNames = [...allModelNames, requestedDefaultModelId]; + } + if (!availableModelNames.has(requestedDefaultModelId)) { + if (availableModelNames.size > 0) { + const firstAvailableModel = + allModelNames.find((name) => availableModelNames.has(name)) ?? + Array.from(availableModelNames)[0]; + defaultModelId = firstAvailableModel; + runtime.log( + `Ollama model ${requestedDefaultModelId} was not available; using ${defaultModelId} instead.`, + ); + } else { + runtime.error( + [ + `No Ollama models are available at ${baseUrl}.`, + "Pull a model first, then re-run onboarding.", + ].join("\n"), + ); + runtime.exit(1); + return params.nextConfig; + } + } + + const config = applyOllamaProviderConfig(params.nextConfig, baseUrl, allModelNames); + const modelRef = `ollama/${defaultModelId}`; + runtime.log(`Default Ollama model: ${defaultModelId}`); + return applyAgentDefaultModelPrimary(config, modelRef); +} + +/** Pull the configured default Ollama model if it isn't already available locally. */ +export async function ensureOllamaModelPulled(params: { + config: OpenClawConfig; + prompter: WizardPrompter; +}): Promise { + const modelCfg = params.config.agents?.defaults?.model; + const modelId = typeof modelCfg === "string" ? modelCfg : modelCfg?.primary; + if (!modelId?.startsWith("ollama/")) { + return; + } + const baseUrl = params.config.models?.providers?.ollama?.baseUrl ?? OLLAMA_DEFAULT_BASE_URL; + const modelName = modelId.slice("ollama/".length); + if (isOllamaCloudModel(modelName)) { + return; + } + const { models } = await fetchOllamaModels(baseUrl); + if (models.some((m) => m.name === modelName)) { + return; + } + const pulled = await pullOllamaModel(baseUrl, modelName, params.prompter); + if (!pulled) { + throw new WizardCancelledError("Failed to download selected Ollama model"); + } +} diff --git a/src/commands/onboard-non-interactive/local/auth-choice.ts b/src/commands/onboard-non-interactive/local/auth-choice.ts index 7636e64d6d6..af119c12efe 100644 --- a/src/commands/onboard-non-interactive/local/auth-choice.ts +++ b/src/commands/onboard-non-interactive/local/auth-choice.ts @@ -10,6 +10,7 @@ import { normalizeSecretInputModeInput } from "../../auth-choice.apply-helpers.j import { buildTokenProfileId, validateAnthropicSetupToken } from "../../auth-token.js"; import { applyGoogleGeminiModelDefault } from "../../google-gemini-model-default.js"; import { applyPrimaryModel } from "../../model-picker.js"; +import { configureOllamaNonInteractive } from "../../ollama-setup.js"; import { applyAuthProfileConfig, applyCloudflareAiGatewayConfig, @@ -174,6 +175,10 @@ export async function applyNonInteractiveAuthChoice(params: { return null; } + if (authChoice === "ollama") { + return configureOllamaNonInteractive({ nextConfig, opts, runtime }); + } + if (authChoice === "apiKey") { const resolved = await resolveApiKey({ provider: "anthropic", diff --git a/src/commands/onboard-types.ts b/src/commands/onboard-types.ts index bb8bf150a0b..40a02e85c15 100644 --- a/src/commands/onboard-types.ts +++ b/src/commands/onboard-types.ts @@ -10,6 +10,7 @@ export type AuthChoice = | "token" | "chutes" | "vllm" + | "ollama" | "openai-codex" | "openai-api-key" | "openrouter-api-key" @@ -59,6 +60,7 @@ export type AuthChoiceGroupId = | "anthropic" | "chutes" | "vllm" + | "ollama" | "google" | "copilot" | "openrouter" diff --git a/src/memory/embeddings-ollama.ts b/src/memory/embeddings-ollama.ts index 4c9326df874..7ccdff6560d 100644 --- a/src/memory/embeddings-ollama.ts +++ b/src/memory/embeddings-ollama.ts @@ -1,4 +1,5 @@ import { resolveEnvApiKey } from "../agents/model-auth.js"; +import { resolveOllamaApiBase } from "../agents/ollama-models.js"; import { formatErrorMessage } from "../infra/errors.js"; import type { SsrFPolicy } from "../infra/net/ssrf.js"; import { normalizeOptionalSecretInput } from "../utils/normalize-secret-input.js"; @@ -17,7 +18,6 @@ export type OllamaEmbeddingClient = { type OllamaEmbeddingClientConfig = Omit; export const DEFAULT_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text"; -const DEFAULT_OLLAMA_BASE_URL = "http://127.0.0.1:11434"; function sanitizeAndNormalizeEmbedding(vec: number[]): number[] { const sanitized = vec.map((value) => (Number.isFinite(value) ? value : 0)); @@ -36,14 +36,6 @@ function normalizeOllamaModel(model: string): string { }); } -function resolveOllamaApiBase(configuredBaseUrl?: string): string { - if (!configuredBaseUrl) { - return DEFAULT_OLLAMA_BASE_URL; - } - const trimmed = configuredBaseUrl.replace(/\/+$/, ""); - return trimmed.replace(/\/v1$/i, ""); -} - function resolveOllamaApiKey(options: EmbeddingProviderOptions): string | undefined { const remoteApiKey = resolveMemorySecretInputString({ value: options.remote?.apiKey, diff --git a/src/wizard/onboarding.ts b/src/wizard/onboarding.ts index 47825eeae52..554c8046b60 100644 --- a/src/wizard/onboarding.ts +++ b/src/wizard/onboarding.ts @@ -442,13 +442,17 @@ export async function runOnboardingWizard( config: nextConfig, prompter, runtime, - setDefaultModel: true, + setDefaultModel: !(authChoiceFromPrompt && authChoice === "ollama"), opts: { tokenProvider: opts.tokenProvider, token: opts.authChoice === "apiKey" && opts.token ? opts.token : undefined, }, }); nextConfig = authResult.config; + + if (authResult.agentModelOverride) { + nextConfig = applyPrimaryModel(nextConfig, authResult.agentModelOverride); + } } if (authChoiceFromPrompt && authChoice !== "custom-api-key") { @@ -468,6 +472,11 @@ export async function runOnboardingWizard( } } + if (authChoice === "ollama") { + const { ensureOllamaModelPulled } = await import("../commands/ollama-setup.js"); + await ensureOllamaModelPulled({ config: nextConfig, prompter }); + } + await warnIfModelConfigLooksOff(nextConfig, prompter); const { configureGatewayForOnboarding } = await import("./onboarding.gateway-config.js");