diff --git a/src/commands/auth-choice.apply.plugin-provider.test.ts b/src/commands/auth-choice.apply.plugin-provider.test.ts index 2557fcd2f5c..c3709a6ce75 100644 --- a/src/commands/auth-choice.apply.plugin-provider.test.ts +++ b/src/commands/auth-choice.apply.plugin-provider.test.ts @@ -16,10 +16,12 @@ vi.mock("../plugins/providers.js", () => ({ const resolveProviderPluginChoice = vi.hoisted(() => vi.fn<() => { provider: ProviderPlugin; method: ProviderAuthMethod } | null>(), ); -const runProviderModelSelectedHook = vi.hoisted(() => vi.fn(async () => {})); +const runExtensionHostProviderModelSelectedHook = vi.hoisted(() => vi.fn(async () => {})); vi.mock("../plugins/provider-wizard.js", () => ({ resolveProviderPluginChoice, - runProviderModelSelectedHook, +})); +vi.mock("../extension-host/provider-model-selection.js", () => ({ + runExtensionHostProviderModelSelectedHook, })); const upsertAuthProfile = vi.hoisted(() => vi.fn()); @@ -130,7 +132,7 @@ describe("applyAuthChoiceLoadedPluginProvider", () => { config: {}, agentModelOverride: "ollama/qwen3:4b", }); - expect(runProviderModelSelectedHook).not.toHaveBeenCalled(); + expect(runExtensionHostProviderModelSelectedHook).not.toHaveBeenCalled(); }); it("applies the default model and runs provider post-setup hooks", async () => { @@ -155,7 +157,7 @@ describe("applyAuthChoiceLoadedPluginProvider", () => { }, agentDir: "/tmp/agent", }); - expect(runProviderModelSelectedHook).toHaveBeenCalledWith({ + expect(runExtensionHostProviderModelSelectedHook).toHaveBeenCalledWith({ config: result?.config, model: "ollama/qwen3:4b", prompter: expect.objectContaining({ note: expect.any(Function) }), @@ -279,7 +281,7 @@ describe("applyAuthChoiceLoadedPluginProvider", () => { }, }, }); - expect(runProviderModelSelectedHook).not.toHaveBeenCalled(); + expect(runExtensionHostProviderModelSelectedHook).not.toHaveBeenCalled(); expect(note).toHaveBeenCalledWith( 'Default model set to ollama/qwen3:4b for agent "worker".', "Model configured", diff --git a/src/commands/model-picker.ts b/src/commands/model-picker.ts index 64d9e533e1f..af7c0cee92b 100644 --- a/src/commands/model-picker.ts +++ b/src/commands/model-picker.ts @@ -401,7 +401,7 @@ export async function promptDefaultModel( workspaceDir: params.workspaceDir, }); if (applied.defaultModel) { - await runProviderModelSelectedHook({ + await runExtensionHostProviderModelSelectedHook({ config: applied.config, model: applied.defaultModel, prompter: params.prompter, diff --git a/src/extension-host/provider-auth-flow.ts b/src/extension-host/provider-auth-flow.ts index 606be794cbd..7cc80bc8854 100644 --- a/src/extension-host/provider-auth-flow.ts +++ b/src/extension-host/provider-auth-flow.ts @@ -15,10 +15,7 @@ import { createVpsAwareOAuthHandlers } from "../commands/oauth-flow.js"; import { applyAuthProfileConfig } from "../commands/onboard-auth.js"; import { openUrl } from "../commands/onboard-helpers.js"; import { enablePluginInConfig } from "../plugins/enable.js"; -import { - resolveProviderPluginChoice, - runProviderModelSelectedHook, -} from "../plugins/provider-wizard.js"; +import { resolveProviderPluginChoice } from "../plugins/provider-wizard.js"; import { resolvePluginProviders } from "../plugins/providers.js"; import type { ProviderAuthMethod } from "../plugins/types.js"; import { @@ -27,6 +24,7 @@ import { pickExtensionHostAuthMethod, resolveExtensionHostProviderMatch, } from "./provider-auth.js"; +import { runExtensionHostProviderModelSelectedHook } from "./provider-model-selection.js"; export type ExtensionHostPluginProviderAuthChoiceOptions = { authChoice: string; @@ -135,7 +133,7 @@ export async function applyExtensionHostLoadedPluginProvider( if (applied.defaultModel) { if (params.setDefaultModel) { const nextConfig = applyExtensionHostDefaultModel(applied.config, applied.defaultModel); - await runProviderModelSelectedHook({ + await runExtensionHostProviderModelSelectedHook({ config: nextConfig, model: applied.defaultModel, prompter: params.prompter, @@ -211,7 +209,7 @@ export async function applyExtensionHostPluginProvider( if (applied.defaultModel) { if (params.setDefaultModel) { nextConfig = applyExtensionHostDefaultModel(nextConfig, applied.defaultModel); - await runProviderModelSelectedHook({ + await runExtensionHostProviderModelSelectedHook({ config: nextConfig, model: applied.defaultModel, prompter: params.prompter, diff --git a/src/extension-host/provider-model-selection.ts b/src/extension-host/provider-model-selection.ts new file mode 100644 index 00000000000..0f970fe4988 --- /dev/null +++ b/src/extension-host/provider-model-selection.ts @@ -0,0 +1,40 @@ +import { DEFAULT_PROVIDER } from "../agents/defaults.js"; +import { parseModelRef } from "../agents/model-ref.js"; +import { normalizeProviderId } from "../agents/provider-id.js"; +import type { OpenClawConfig } from "../config/config.js"; +import { resolvePluginProviders } from "../plugins/providers.js"; +import type { WizardPrompter } from "../wizard/prompts.js"; + +export async function runExtensionHostProviderModelSelectedHook(params: { + config: OpenClawConfig; + model: string; + prompter: WizardPrompter; + agentDir?: string; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; +}): Promise { + const parsed = parseModelRef(params.model, DEFAULT_PROVIDER); + if (!parsed) { + return; + } + + const providers = resolvePluginProviders({ + config: params.config, + workspaceDir: params.workspaceDir, + env: params.env, + }); + const provider = providers.find( + (entry) => normalizeProviderId(entry.id) === normalizeProviderId(parsed.provider), + ); + if (!provider?.onModelSelected) { + return; + } + + await provider.onModelSelected({ + config: params.config, + model: params.model, + prompter: params.prompter, + agentDir: params.agentDir, + workspaceDir: params.workspaceDir, + }); +} diff --git a/src/plugins/provider-wizard.ts b/src/plugins/provider-wizard.ts index ac5ab29e2f1..4f5efb64d95 100644 --- a/src/plugins/provider-wizard.ts +++ b/src/plugins/provider-wizard.ts @@ -1,7 +1,5 @@ -import { DEFAULT_PROVIDER } from "../agents/defaults.js"; -import { parseModelRef } from "../agents/model-ref.js"; -import { normalizeProviderId } from "../agents/provider-id.js"; import type { OpenClawConfig } from "../config/config.js"; +import { runExtensionHostProviderModelSelectedHook } from "../extension-host/provider-model-selection.js"; import { buildExtensionHostProviderMethodChoice, resolveExtensionHostProviderChoice, @@ -64,28 +62,5 @@ export async function runProviderModelSelectedHook(params: { workspaceDir?: string; env?: NodeJS.ProcessEnv; }): Promise { - const parsed = parseModelRef(params.model, DEFAULT_PROVIDER); - if (!parsed) { - return; - } - - const providers = resolvePluginProviders({ - config: params.config, - workspaceDir: params.workspaceDir, - env: params.env, - }); - const provider = providers.find( - (entry) => normalizeProviderId(entry.id) === normalizeProviderId(parsed.provider), - ); - if (!provider?.onModelSelected) { - return; - } - - await provider.onModelSelected({ - config: params.config, - model: params.model, - prompter: params.prompter, - agentDir: params.agentDir, - workspaceDir: params.workspaceDir, - }); + await runExtensionHostProviderModelSelectedHook(params); }