diff --git a/src/commands/model-picker.test.ts b/src/commands/model-picker.test.ts index f4bbd52b9c7..704fc5d793a 100644 --- a/src/commands/model-picker.test.ts +++ b/src/commands/model-picker.test.ts @@ -34,6 +34,21 @@ vi.mock("../agents/model-auth.js", () => ({ hasUsableCustomProviderApiKey, })); +const resolveOwningPluginIdsForProvider = vi.hoisted(() => + vi.fn(({ provider }: { provider: string }) => { + if (provider === "byteplus" || provider === "byteplus-plan") { + return ["byteplus"]; + } + if (provider === "volcengine" || provider === "volcengine-plan") { + return ["volcengine"]; + } + return undefined; + }), +); +vi.mock("../plugins/providers.js", () => ({ + resolveOwningPluginIdsForProvider, +})); + const providerModelPickerContributionRuntime = vi.hoisted(() => ({ enabled: false, resolve: vi.fn(() => []), @@ -85,6 +100,15 @@ function createSelectAllMultiselect() { beforeEach(() => { vi.clearAllMocks(); providerModelPickerContributionRuntime.enabled = false; + resolveOwningPluginIdsForProvider.mockImplementation(({ provider }: { provider: string }) => { + if (provider === "byteplus" || provider === "byteplus-plan") { + return ["byteplus"]; + } + if (provider === "volcengine" || provider === "volcengine-plan") { + return ["volcengine"]; + } + return undefined; + }); }); describe("promptDefaultModel", () => { @@ -167,6 +191,12 @@ describe("promptDefaultModel", () => { expect(optionValues[1]).toBe("byteplus-plan/ark-code-latest"); expect(select.mock.calls[0]?.[0]?.initialValue).toBe("byteplus-plan/ark-code-latest"); expect(result.model).toBe("byteplus-plan/ark-code-latest"); + expect(resolveOwningPluginIdsForProvider).toHaveBeenCalledWith( + expect.objectContaining({ provider: "byteplus" }), + ); + expect(resolveOwningPluginIdsForProvider).toHaveBeenCalledWith( + expect.objectContaining({ provider: "byteplus-plan" }), + ); }); it("supports configuring vLLM during setup", async () => { diff --git a/src/commands/openai-model-default.test.ts b/src/commands/openai-model-default.test.ts index 612bf52f81e..970efd954ef 100644 --- a/src/commands/openai-model-default.test.ts +++ b/src/commands/openai-model-default.test.ts @@ -1,9 +1,9 @@ import { describe, expect, it, vi } from "vitest"; -import type { OpenClawConfig } from "../config/config.js"; import { applyOpencodeZenModelDefault, OPENCODE_ZEN_DEFAULT_MODEL, -} from "../plugins/provider-model-defaults.js"; +} from "../../extensions/opencode/api.js"; +import type { OpenClawConfig } from "../config/config.js"; import type { WizardPrompter } from "../wizard/prompts.js"; import { applyDefaultModelChoice } from "./auth-choice.default-model.js"; diff --git a/src/flows/model-picker.ts b/src/flows/model-picker.ts index 03c55d93249..cd0ab1f85d3 100644 --- a/src/flows/model-picker.ts +++ b/src/flows/model-picker.ts @@ -13,6 +13,7 @@ import { formatTokenK } from "../commands/models/shared.js"; import type { OpenClawConfig } from "../config/config.js"; import { resolveAgentModelPrimaryValue } from "../config/model-input.js"; import { applyPrimaryModel } from "../plugins/provider-model-primary.js"; +import { resolveOwningPluginIdsForProvider } from "../plugins/providers.js"; import type { ProviderPlugin } from "../plugins/types.js"; import type { RuntimeEnv } from "../runtime.js"; import { createLazyRuntimeSurface } from "../shared/lazy-runtime.js"; @@ -172,14 +173,43 @@ function addModelSelectOption(params: { params.seen.add(key); } -function matchesPreferredProvider(entryProvider: string, preferredProvider: string): boolean { - if (preferredProvider === "volcengine") { - return entryProvider === "volcengine" || entryProvider === "volcengine-plan"; - } - if (preferredProvider === "byteplus") { - return entryProvider === "byteplus" || entryProvider === "byteplus-plan"; - } - return entryProvider === preferredProvider; +function createPreferredProviderMatcher(params: { + preferredProvider: string; + cfg: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; +}): (entryProvider: string) => boolean { + const normalizedPreferredProvider = normalizeProviderId(params.preferredProvider); + const preferredOwnerPluginIds = resolveOwningPluginIdsForProvider({ + provider: normalizedPreferredProvider, + config: params.cfg, + workspaceDir: params.workspaceDir, + env: params.env, + }); + const preferredOwnerPluginIdSet = preferredOwnerPluginIds + ? new Set(preferredOwnerPluginIds) + : undefined; + const entryProviderCache = new Map(); + return (entryProvider: string) => { + const normalizedEntryProvider = normalizeProviderId(entryProvider); + if (normalizedEntryProvider === normalizedPreferredProvider) { + return true; + } + const cached = entryProviderCache.get(normalizedEntryProvider); + if (cached !== undefined) { + return cached; + } + const value = + !!preferredOwnerPluginIdSet && + !!resolveOwningPluginIdsForProvider({ + provider: normalizedEntryProvider, + config: params.cfg, + workspaceDir: params.workspaceDir, + env: params.env, + })?.some((pluginId) => preferredOwnerPluginIdSet.has(pluginId)); + entryProviderCache.set(normalizedEntryProvider, value); + return value; + }; } async function promptManualModel(params: { @@ -226,6 +256,9 @@ async function maybeFilterModelsByProvider(params: { }>; preferredProvider?: string; prompter: WizardPrompter; + cfg: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; }): Promise { const providerIds = Array.from(new Set(params.models.map((entry) => entry.provider))).toSorted( (a, b) => a.localeCompare(b), @@ -236,6 +269,14 @@ async function maybeFilterModelsByProvider(params: { providerIds.length > 1 && params.models.length > PROVIDER_FILTER_THRESHOLD; let next = params.models; + const matchesPreferredProvider = params.preferredProvider + ? createPreferredProviderMatcher({ + preferredProvider: params.preferredProvider, + cfg: params.cfg, + workspaceDir: params.workspaceDir, + env: params.env, + }) + : undefined; if (shouldPromptProvider) { const selection = await params.prompter.select({ message: "Filter models by provider", @@ -246,9 +287,7 @@ async function maybeFilterModelsByProvider(params: { } } if (hasPreferredProvider && params.preferredProvider) { - const filtered = next.filter((entry) => - matchesPreferredProvider(entry.provider, params.preferredProvider!), - ); + const filtered = next.filter((entry) => matchesPreferredProvider?.(entry.provider)); if (filtered.length > 0) { next = filtered; } @@ -418,9 +457,20 @@ export async function promptDefaultModel( models, preferredProvider, prompter: params.prompter, + cfg, + workspaceDir: params.workspaceDir, + env: params.env, }); + const matchesPreferredProvider = preferredProvider + ? createPreferredProviderMatcher({ + preferredProvider, + cfg, + workspaceDir: params.workspaceDir, + env: params.env, + }) + : undefined; const hasPreferredProvider = preferredProvider - ? filteredModels.some((entry) => matchesPreferredProvider(entry.provider, preferredProvider)) + ? filteredModels.some((entry) => matchesPreferredProvider?.(entry.provider)) : false; const hasAuth = createProviderAuthChecker({ cfg, agentDir: params.agentDir }); @@ -465,7 +515,7 @@ export async function promptDefaultModel( allowKeep && hasPreferredProvider && preferredProvider && - !matchesPreferredProvider(resolved.provider, preferredProvider) + !matchesPreferredProvider?.(resolved.provider) ) { const firstModel = filteredModels[0]; if (firstModel) { @@ -570,6 +620,12 @@ export async function promptModelAllowlist(params: { defaultProvider: DEFAULT_PROVIDER, }); const hasAuth = createProviderAuthChecker({ cfg, agentDir: params.agentDir }); + const matchesPreferredProvider = preferredProvider + ? createPreferredProviderMatcher({ + preferredProvider, + cfg, + }) + : undefined; const options: WizardSelectOption[] = []; const seen = new Set(); @@ -577,11 +633,8 @@ export async function promptModelAllowlist(params: { ? catalog.filter((entry) => allowedKeySet.has(modelKey(entry.provider, entry.id))) : catalog; const filteredCatalog = - preferredProvider && - allowedCatalog.some((entry) => matchesPreferredProvider(entry.provider, preferredProvider)) - ? allowedCatalog.filter((entry) => - matchesPreferredProvider(entry.provider, preferredProvider), - ) + preferredProvider && allowedCatalog.some((entry) => matchesPreferredProvider?.(entry.provider)) + ? allowedCatalog.filter((entry) => matchesPreferredProvider?.(entry.provider)) : allowedCatalog; for (const entry of filteredCatalog) { diff --git a/src/plugins/provider-model-defaults.ts b/src/plugins/provider-model-defaults.ts index 9ce577fce10..d4b8464ae78 100644 --- a/src/plugins/provider-model-defaults.ts +++ b/src/plugins/provider-model-defaults.ts @@ -1,6 +1,10 @@ import type { OpenClawConfig } from "../config/config.js"; import { ensureModelAllowlistEntry } from "./provider-model-allowlist.js"; import { applyAgentDefaultPrimaryModel } from "./provider-model-primary.js"; +export { + applyOpencodeZenModelDefault, + OPENCODE_ZEN_DEFAULT_MODEL, +} from "../../extensions/opencode/api.js"; export const OPENAI_DEFAULT_MODEL = "openai/gpt-5.4"; export const OPENAI_CODEX_DEFAULT_MODEL = "openai-codex/gpt-5.4"; @@ -12,12 +16,6 @@ export const OPENAI_DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"; export const GOOGLE_GEMINI_DEFAULT_MODEL = "google/gemini-3.1-pro-preview"; export const OLLAMA_DEFAULT_BASE_URL = "http://127.0.0.1:11434"; export const OPENCODE_GO_DEFAULT_MODEL_REF = "opencode-go/kimi-k2.5"; -export const OPENCODE_ZEN_DEFAULT_MODEL = "opencode/claude-opus-4-6"; - -const LEGACY_OPENCODE_ZEN_DEFAULT_MODELS = new Set([ - "opencode/claude-opus-4-5", - "opencode-zen/claude-opus-4-5", -]); export function applyGoogleGeminiModelDefault(cfg: OpenClawConfig): { next: OpenClawConfig; @@ -75,14 +73,3 @@ export function applyOpencodeGoModelDefault(cfg: OpenClawConfig): { } { return applyAgentDefaultPrimaryModel({ cfg, model: OPENCODE_GO_DEFAULT_MODEL_REF }); } - -export function applyOpencodeZenModelDefault(cfg: OpenClawConfig): { - next: OpenClawConfig; - changed: boolean; -} { - return applyAgentDefaultPrimaryModel({ - cfg, - model: OPENCODE_ZEN_DEFAULT_MODEL, - legacyModels: LEGACY_OPENCODE_ZEN_DEFAULT_MODELS, - }); -}