diff --git a/src/commands/configure.gateway-auth.ollama.integration.test.ts b/src/commands/configure.gateway-auth.ollama.integration.test.ts new file mode 100644 index 00000000000..8d7275be6ea --- /dev/null +++ b/src/commands/configure.gateway-auth.ollama.integration.test.ts @@ -0,0 +1,92 @@ +import { mkdtempSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { OpenClawConfig } from "../config/config.js"; +import type { WizardPrompter } from "../wizard/prompts.js"; +import { promptAuthConfig } from "./configure.gateway-auth.js"; +import { makePrompter, makeRuntime } from "./setup/__tests__/test-utils.js"; + +describe("promptAuthConfig Ollama setup", () => { + const originalFetch = globalThis.fetch; + + beforeEach(() => { + vi.clearAllMocks(); + vi.stubEnv("HOME", mkdtempSync(join(tmpdir(), "openclaw-ollama-config-"))); + vi.stubGlobal( + "fetch", + vi.fn(async (url: string | URL | Request) => { + const href = typeof url === "string" ? url : "url" in url ? url.url : String(url); + if (href.endsWith("/api/tags")) { + return new Response( + JSON.stringify({ + models: [{ name: "kimi-k2.5:cloud" }, { name: "gpt-oss:20b-cloud" }], + }), + { status: 200, headers: { "content-type": "application/json" } }, + ); + } + throw new Error(`unexpected fetch: ${href}`); + }), + ); + }); + + afterEach(() => { + vi.unstubAllEnvs(); + vi.stubGlobal("fetch", originalFetch); + }); + + it("shows the model picker after cloud-only setup when Ollama models were already configured", async () => { + const select = vi.fn(async (params) => { + if (params.message === "Model/auth provider") { + return "ollama"; + } + if (params.message === "Ollama mode") { + return "cloud-only"; + } + if (params.message === "How do you want to provide this API key?") { + return "plaintext"; + } + throw new Error(`unexpected select: ${params.message}`); + }) as WizardPrompter["select"]; + const text = vi.fn(async (params) => { + if (params.message === "Ollama API key") { + return "test-ollama-key"; + } + throw new Error(`unexpected text: ${params.message}`); + }); + const multiselect = vi.fn(async (params) => + params.options.map((option: { value: string }) => option.value), + ); + const progress = vi.fn(() => ({ update: vi.fn(), stop: vi.fn() })); + const prompter = makePrompter({ select, text, multiselect, progress }); + const config = { + models: { + providers: { + ollama: { + api: "ollama", + baseUrl: "https://ollama.com", + models: [ + { + id: "kimi-k2.5:cloud", + name: "Kimi K2.5", + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 128_000, + maxTokens: 8192, + }, + ], + }, + }, + }, + } as OpenClawConfig; + + const result = await promptAuthConfig(config, makeRuntime(), prompter); + + expect(multiselect).toHaveBeenCalled(); + expect( + multiselect.mock.calls[0]?.[0]?.options.map((option: { value: string }) => option.value), + ).toContain("ollama/kimi-k2.5:cloud"); + expect(result.agents?.defaults?.models).toHaveProperty("ollama/kimi-k2.5:cloud"); + }); +}); diff --git a/src/commands/configure.gateway-auth.prompt-auth-config.test.ts b/src/commands/configure.gateway-auth.prompt-auth-config.test.ts index 466655958b4..f13a1e1502e 100644 --- a/src/commands/configure.gateway-auth.prompt-auth-config.test.ts +++ b/src/commands/configure.gateway-auth.prompt-auth-config.test.ts @@ -139,15 +139,21 @@ describe("promptAuthConfig", () => { mocks.applyAuthChoice.mockResolvedValue({ config: {} }); mocks.promptModelAllowlist.mockResolvedValue({ models: undefined }); mocks.resolveProviderPluginChoice.mockReturnValue({ - provider: { id: "anthropic", label: "Anthropic", auth: [] }, - method: { id: "setup-token", label: "setup-token", kind: "token" }, - wizard: { - modelAllowlist: { - allowedKeys: ["anthropic/claude-sonnet-4-6"], - initialSelections: ["anthropic/claude-sonnet-4-6"], - message: "Anthropic OAuth models", + provider: { + id: "anthropic", + label: "Anthropic", + auth: [], + wizard: { + setup: { + modelAllowlist: { + allowedKeys: ["anthropic/claude-sonnet-4-6"], + initialSelections: ["anthropic/claude-sonnet-4-6"], + message: "Anthropic OAuth models", + }, + }, }, }, + method: { id: "setup-token", label: "setup-token", kind: "token" }, }); await promptAuthConfig({}, makeRuntime(), noopPrompter); @@ -180,14 +186,20 @@ describe("promptAuthConfig", () => { scopeKeys: ["anthropic/claude-opus-4-6", "anthropic/claude-sonnet-4-6"], }); mocks.resolveProviderPluginChoice.mockReturnValue({ - provider: { id: "anthropic", label: "Anthropic", auth: [] }, - method: { id: "setup-token", label: "setup-token", kind: "token" }, - wizard: { - modelAllowlist: { - allowedKeys: ["anthropic/claude-opus-4-6", "anthropic/claude-sonnet-4-6"], - initialSelections: ["anthropic/claude-sonnet-4-6"], + provider: { + id: "anthropic", + label: "Anthropic", + auth: [], + wizard: { + setup: { + modelAllowlist: { + allowedKeys: ["anthropic/claude-opus-4-6", "anthropic/claude-sonnet-4-6"], + initialSelections: ["anthropic/claude-sonnet-4-6"], + }, + }, }, }, + method: { id: "setup-token", label: "setup-token", kind: "token" }, }); const result = await promptAuthConfig({}, makeRuntime(), noopPrompter); @@ -223,14 +235,20 @@ describe("promptAuthConfig", () => { scopeKeys: ["openai/gpt-5.5", "openai/gpt-5.4-mini"], }); mocks.resolveProviderPluginChoice.mockReturnValue({ - provider: { id: "openai", label: "OpenAI", auth: [] }, - method: { id: "setup-token", label: "setup-token", kind: "token" }, - wizard: { - modelAllowlist: { - allowedKeys: ["openai/gpt-5.5", "openai/gpt-5.4-mini"], - initialSelections: ["openai/gpt-5.5"], + provider: { + id: "openai", + label: "OpenAI", + auth: [], + wizard: { + setup: { + modelAllowlist: { + allowedKeys: ["openai/gpt-5.5", "openai/gpt-5.4-mini"], + initialSelections: ["openai/gpt-5.5"], + }, + }, }, }, + method: { id: "setup-token", label: "setup-token", kind: "token" }, }); const result = await promptAuthConfig({}, makeRuntime(), noopPrompter); @@ -245,6 +263,7 @@ describe("promptAuthConfig", () => { }); it("scopes the allowlist picker to the selected provider when available", async () => { + vi.clearAllMocks(); mocks.promptAuthChoiceGrouped.mockResolvedValue("openai-api-key"); mocks.resolvePreferredProviderForAuthChoice.mockResolvedValue("openai"); mocks.applyAuthChoice.mockResolvedValue({ config: {} }); @@ -259,6 +278,39 @@ describe("promptAuthConfig", () => { ); }); + it("loads configured provider models after Ollama Cloud + Local and Cloud only setup", async () => { + vi.clearAllMocks(); + mocks.promptAuthChoiceGrouped.mockResolvedValue("ollama"); + mocks.resolvePreferredProviderForAuthChoice.mockResolvedValue(undefined); + mocks.applyAuthChoice.mockResolvedValue({ + config: { + models: { + providers: { + ollama: { + baseUrl: "https://ollama.com", + api: "ollama", + models: [ + { id: "kimi-k2.5:cloud", name: "kimi-k2.5:cloud" }, + { id: "qwen3-coder:480b-cloud", name: "qwen3-coder:480b-cloud" }, + ], + }, + }, + }, + }, + }); + mocks.promptModelAllowlist.mockResolvedValue({ models: undefined }); + mocks.resolveProviderPluginChoice.mockReturnValue(null); + + await promptAuthConfig({}, makeRuntime(), noopPrompter); + + expect(mocks.promptModelAllowlist).toHaveBeenCalledWith( + expect.objectContaining({ + preferredProvider: "ollama", + loadCatalog: true, + }), + ); + }); + it("returns to auth selection when plugin install onboarding asks for a retry", async () => { vi.clearAllMocks(); mocks.promptAuthChoiceGrouped diff --git a/src/commands/configure.gateway-auth.ts b/src/commands/configure.gateway-auth.ts index 5047a0e4342..24692137ad0 100644 --- a/src/commands/configure.gateway-auth.ts +++ b/src/commands/configure.gateway-auth.ts @@ -30,16 +30,18 @@ function sanitizeTokenValue(value: unknown): string | undefined { return trimmed; } -async function resolveProviderChoiceModelAllowlist(params: { +async function resolveProviderChoiceModelPrompt(params: { authChoice: string; config: OpenClawConfig; workspaceDir?: string; env?: NodeJS.ProcessEnv; }): Promise< | { + provider?: string; allowedKeys?: string[]; initialSelections?: string[]; message?: string; + loadCatalog?: boolean; } | undefined > { @@ -51,10 +53,62 @@ async function resolveProviderChoiceModelAllowlist(params: { env: params.env, mode: "setup", }); - return resolveProviderPluginChoice({ + const resolved = resolveProviderPluginChoice({ providers, choice: params.authChoice, - })?.wizard?.modelAllowlist; + }); + const wizard = resolved?.provider.wizard?.setup; + const provider = resolved?.provider.id; + if (!wizard) { + return provider ? { provider } : undefined; + } + return { + provider, + ...wizard.modelAllowlist, + ...(wizard.modelSelection?.promptWhenAuthChoiceProvided === true ? { loadCatalog: true } : {}), + }; +} + +function hasConfiguredProviderModels(cfg: OpenClawConfig, provider: string | undefined): boolean { + if (!provider) { + return false; + } + return (cfg.models?.providers?.[provider]?.models?.length ?? 0) > 0; +} + +function listConfiguredModelProviders(cfg: OpenClawConfig): string[] { + return Object.entries(cfg.models?.providers ?? {}) + .filter(([, provider]) => (provider.models?.length ?? 0) > 0) + .map(([provider]) => provider); +} + +function resolveSingleConfiguredProvider(cfg: OpenClawConfig): string | undefined { + const configuredProviders = listConfiguredModelProviders(cfg); + return configuredProviders.length === 1 ? configuredProviders[0] : undefined; +} + +function resolveConfiguredProviderFromAuthChange(params: { + before: OpenClawConfig; + after: OpenClawConfig; + preferredProvider?: string; +}): string | undefined { + if (hasConfiguredProviderModels(params.after, params.preferredProvider)) { + return params.preferredProvider; + } + + const beforeProviders = params.before.models?.providers ?? {}; + const configuredProviders = listConfiguredModelProviders(params.after); + const changedProviders = configuredProviders.filter((provider) => { + const beforeCount = beforeProviders[provider]?.models?.length ?? 0; + const afterCount = params.after.models?.providers?.[provider]?.models?.length ?? 0; + return afterCount > beforeCount; + }); + + if (changedProviders.length === 1) { + return changedProviders[0]; + } + + return configuredProviders.length === 1 ? configuredProviders[0] : params.preferredProvider; } export function buildGatewayAuthConfig(params: { @@ -148,6 +202,7 @@ export async function promptAuthConfig( break; } + const beforeAuthConfig = next; const applied = await applyAuthChoice({ authChoice, config: next, @@ -157,6 +212,11 @@ export async function promptAuthConfig( preserveExistingDefaultModel: true, }); next = applied.config; + preferredProvider = resolveConfiguredProviderFromAuthChange({ + before: beforeAuthConfig, + after: next, + preferredProvider, + }); if (applied.retrySelection) { continue; } @@ -164,20 +224,23 @@ export async function promptAuthConfig( } if (authChoice !== "custom-api-key") { - const modelAllowlist = await resolveProviderChoiceModelAllowlist({ + const modelPrompt = await resolveProviderChoiceModelPrompt({ authChoice, config: next, workspaceDir: resolveDefaultAgentWorkspaceDir(), env: process.env, }); + const promptProvider = + modelPrompt?.provider ?? preferredProvider ?? resolveSingleConfiguredProvider(next); const allowlistSelection = await promptModelAllowlist({ config: next, prompter, - allowedKeys: modelAllowlist?.allowedKeys, - initialSelections: modelAllowlist?.initialSelections, - message: modelAllowlist?.message, - preferredProvider, - loadCatalog: false, + allowedKeys: modelPrompt?.allowedKeys, + initialSelections: modelPrompt?.initialSelections, + message: modelPrompt?.message, + preferredProvider: promptProvider, + loadCatalog: + modelPrompt?.loadCatalog ?? hasConfiguredProviderModels(next, promptProvider) ?? false, }); if (allowlistSelection.models) { next = applyModelFallbacksFromSelection(next, allowlistSelection.models, { diff --git a/src/commands/model-picker.test.ts b/src/commands/model-picker.test.ts index 07c6a922d57..f462ee505bd 100644 --- a/src/commands/model-picker.test.ts +++ b/src/commands/model-picker.test.ts @@ -638,6 +638,45 @@ describe("promptModelAllowlist", () => { ]); }); + it("shows configured preferred provider models when the catalog has no entries", async () => { + loadModelCatalog.mockResolvedValue([]); + + const multiselect = createSelectAllMultiselect(); + const text = vi.fn(async () => ""); + const prompter = makePrompter({ multiselect, text }); + const config = { + models: { + providers: { + ollama: { + api: "ollama", + baseUrl: "https://ollama.com/v1", + models: [ + configuredTextModel("kimi-k2.5:cloud", "Kimi K2.5"), + configuredTextModel("gpt-oss:20b-cloud", "GPT OSS 20B"), + ], + }, + }, + }, + agents: { defaults: {} }, + } as OpenClawConfig; + + const result = await promptModelAllowlist({ + config, + prompter, + preferredProvider: "ollama", + loadCatalog: true, + }); + + expect(text).not.toHaveBeenCalled(); + expect( + multiselect.mock.calls[0]?.[0]?.options.map((option: { value: string }) => option.value), + ).toEqual(["ollama/kimi-k2.5:cloud", "ollama/gpt-oss:20b-cloud"]); + expect(result).toEqual({ + models: ["ollama/kimi-k2.5:cloud", "ollama/gpt-oss:20b-cloud"], + scopeKeys: ["ollama/kimi-k2.5:cloud", "ollama/gpt-oss:20b-cloud"], + }); + }); + it("seeds existing model fallbacks into unscoped allowlist selections", async () => { loadModelCatalog.mockResolvedValue([ { diff --git a/src/flows/model-picker.ts b/src/flows/model-picker.ts index 444d3e9eb25..d184206a940 100644 --- a/src/flows/model-picker.ts +++ b/src/flows/model-picker.ts @@ -909,6 +909,18 @@ export async function promptModelAllowlist(params: { } finally { allowlistProgress.stop(); } + if (preferredProvider) { + const configuredCatalog = buildConfiguredModelCatalog({ cfg }).filter( + (entry) => matchesPreferredProvider?.(entry.provider) === true, + ); + const configuredKeys = new Set( + configuredCatalog.map((entry) => modelKey(entry.provider, entry.id)), + ); + catalog = [ + ...configuredCatalog, + ...catalog.filter((entry) => !configuredKeys.has(modelKey(entry.provider, entry.id))), + ]; + } if (catalog.length === 0 && allowedKeys.length === 0) { const noCatalogInitialKeys = existingKeys.length > 0 ? normalizeModelKeys([...existingKeys, ...fallbackKeys]) : [];