Plugins: extract provider model selection hook

This commit is contained in:
Gustavo Madeira Santana
2026-03-15 17:53:09 +00:00
parent e7e59a862d
commit bb4681fca6
5 changed files with 54 additions and 39 deletions

View File

@@ -16,10 +16,12 @@ vi.mock("../plugins/providers.js", () => ({
const resolveProviderPluginChoice = vi.hoisted(() =>
vi.fn<() => { provider: ProviderPlugin; method: ProviderAuthMethod } | null>(),
);
const runProviderModelSelectedHook = vi.hoisted(() => vi.fn(async () => {}));
const runExtensionHostProviderModelSelectedHook = vi.hoisted(() => vi.fn(async () => {}));
vi.mock("../plugins/provider-wizard.js", () => ({
resolveProviderPluginChoice,
runProviderModelSelectedHook,
}));
vi.mock("../extension-host/provider-model-selection.js", () => ({
runExtensionHostProviderModelSelectedHook,
}));
const upsertAuthProfile = vi.hoisted(() => vi.fn());
@@ -130,7 +132,7 @@ describe("applyAuthChoiceLoadedPluginProvider", () => {
config: {},
agentModelOverride: "ollama/qwen3:4b",
});
expect(runProviderModelSelectedHook).not.toHaveBeenCalled();
expect(runExtensionHostProviderModelSelectedHook).not.toHaveBeenCalled();
});
it("applies the default model and runs provider post-setup hooks", async () => {
@@ -155,7 +157,7 @@ describe("applyAuthChoiceLoadedPluginProvider", () => {
},
agentDir: "/tmp/agent",
});
expect(runProviderModelSelectedHook).toHaveBeenCalledWith({
expect(runExtensionHostProviderModelSelectedHook).toHaveBeenCalledWith({
config: result?.config,
model: "ollama/qwen3:4b",
prompter: expect.objectContaining({ note: expect.any(Function) }),
@@ -279,7 +281,7 @@ describe("applyAuthChoiceLoadedPluginProvider", () => {
},
},
});
expect(runProviderModelSelectedHook).not.toHaveBeenCalled();
expect(runExtensionHostProviderModelSelectedHook).not.toHaveBeenCalled();
expect(note).toHaveBeenCalledWith(
'Default model set to ollama/qwen3:4b for agent "worker".',
"Model configured",

View File

@@ -401,7 +401,7 @@ export async function promptDefaultModel(
workspaceDir: params.workspaceDir,
});
if (applied.defaultModel) {
await runProviderModelSelectedHook({
await runExtensionHostProviderModelSelectedHook({
config: applied.config,
model: applied.defaultModel,
prompter: params.prompter,

View File

@@ -15,10 +15,7 @@ import { createVpsAwareOAuthHandlers } from "../commands/oauth-flow.js";
import { applyAuthProfileConfig } from "../commands/onboard-auth.js";
import { openUrl } from "../commands/onboard-helpers.js";
import { enablePluginInConfig } from "../plugins/enable.js";
import {
resolveProviderPluginChoice,
runProviderModelSelectedHook,
} from "../plugins/provider-wizard.js";
import { resolveProviderPluginChoice } from "../plugins/provider-wizard.js";
import { resolvePluginProviders } from "../plugins/providers.js";
import type { ProviderAuthMethod } from "../plugins/types.js";
import {
@@ -27,6 +24,7 @@ import {
pickExtensionHostAuthMethod,
resolveExtensionHostProviderMatch,
} from "./provider-auth.js";
import { runExtensionHostProviderModelSelectedHook } from "./provider-model-selection.js";
export type ExtensionHostPluginProviderAuthChoiceOptions = {
authChoice: string;
@@ -135,7 +133,7 @@ export async function applyExtensionHostLoadedPluginProvider(
if (applied.defaultModel) {
if (params.setDefaultModel) {
const nextConfig = applyExtensionHostDefaultModel(applied.config, applied.defaultModel);
await runProviderModelSelectedHook({
await runExtensionHostProviderModelSelectedHook({
config: nextConfig,
model: applied.defaultModel,
prompter: params.prompter,
@@ -211,7 +209,7 @@ export async function applyExtensionHostPluginProvider(
if (applied.defaultModel) {
if (params.setDefaultModel) {
nextConfig = applyExtensionHostDefaultModel(nextConfig, applied.defaultModel);
await runProviderModelSelectedHook({
await runExtensionHostProviderModelSelectedHook({
config: nextConfig,
model: applied.defaultModel,
prompter: params.prompter,

View File

@@ -0,0 +1,40 @@
import { DEFAULT_PROVIDER } from "../agents/defaults.js";
import { parseModelRef } from "../agents/model-ref.js";
import { normalizeProviderId } from "../agents/provider-id.js";
import type { OpenClawConfig } from "../config/config.js";
import { resolvePluginProviders } from "../plugins/providers.js";
import type { WizardPrompter } from "../wizard/prompts.js";
export async function runExtensionHostProviderModelSelectedHook(params: {
config: OpenClawConfig;
model: string;
prompter: WizardPrompter;
agentDir?: string;
workspaceDir?: string;
env?: NodeJS.ProcessEnv;
}): Promise<void> {
const parsed = parseModelRef(params.model, DEFAULT_PROVIDER);
if (!parsed) {
return;
}
const providers = resolvePluginProviders({
config: params.config,
workspaceDir: params.workspaceDir,
env: params.env,
});
const provider = providers.find(
(entry) => normalizeProviderId(entry.id) === normalizeProviderId(parsed.provider),
);
if (!provider?.onModelSelected) {
return;
}
await provider.onModelSelected({
config: params.config,
model: params.model,
prompter: params.prompter,
agentDir: params.agentDir,
workspaceDir: params.workspaceDir,
});
}

View File

@@ -1,7 +1,5 @@
import { DEFAULT_PROVIDER } from "../agents/defaults.js";
import { parseModelRef } from "../agents/model-ref.js";
import { normalizeProviderId } from "../agents/provider-id.js";
import type { OpenClawConfig } from "../config/config.js";
import { runExtensionHostProviderModelSelectedHook } from "../extension-host/provider-model-selection.js";
import {
buildExtensionHostProviderMethodChoice,
resolveExtensionHostProviderChoice,
@@ -64,28 +62,5 @@ export async function runProviderModelSelectedHook(params: {
workspaceDir?: string;
env?: NodeJS.ProcessEnv;
}): Promise<void> {
const parsed = parseModelRef(params.model, DEFAULT_PROVIDER);
if (!parsed) {
return;
}
const providers = resolvePluginProviders({
config: params.config,
workspaceDir: params.workspaceDir,
env: params.env,
});
const provider = providers.find(
(entry) => normalizeProviderId(entry.id) === normalizeProviderId(parsed.provider),
);
if (!provider?.onModelSelected) {
return;
}
await provider.onModelSelected({
config: params.config,
model: params.model,
prompter: params.prompter,
agentDir: params.agentDir,
workspaceDir: params.workspaceDir,
});
await runExtensionHostProviderModelSelectedHook(params);
}