diff --git a/src/agents/pi-model-discovery.ts b/src/agents/pi-model-discovery.ts index 85768999c02..c50ec909563 100644 --- a/src/agents/pi-model-discovery.ts +++ b/src/agents/pi-model-discovery.ts @@ -19,6 +19,7 @@ import { ensureAuthProfileStore } from "./auth-profiles/store.js"; import { resolveProviderEnvApiKeyCandidates } from "./model-auth-env-vars.js"; import { resolveEnvApiKey } from "./model-auth-env.js"; import { resolvePiCredentialMapFromStore, type PiCredentialMap } from "./pi-auth-credentials.js"; +import { normalizeProviderId } from "./provider-id.js"; const PiAuthStorageClass = PiCodingAgent.AuthStorage; const PiModelRegistryClass = PiCodingAgent.ModelRegistry; @@ -33,6 +34,10 @@ type DiscoveredProviderRuntimeModelLike = Omit api?: string | null; }; +type DiscoverModelsOptions = { + providerFilter?: string; +}; + type InMemoryAuthStorageBackendLike = { withLock( update: (current: string) => { @@ -136,16 +141,24 @@ function createOpenClawModelRegistry( authStorage: PiAuthStorage, modelsJsonPath: string, agentDir: string, + options?: DiscoverModelsOptions, ): PiModelRegistry { const registry = instantiatePiModelRegistry(authStorage, modelsJsonPath); const getAll = registry.getAll.bind(registry); const getAvailable = registry.getAvailable.bind(registry); const find = registry.find.bind(registry); + const providerFilter = options?.providerFilter ? normalizeProviderId(options.providerFilter) : ""; + const matchesProviderFilter = (entry: Model) => + !providerFilter || normalizeProviderId(entry.provider) === providerFilter; registry.getAll = () => - getAll().map((entry: Model) => normalizeDiscoveredPiModel(entry, agentDir)); + getAll() + .filter((entry: Model) => matchesProviderFilter(entry)) + .map((entry: Model) => normalizeDiscoveredPiModel(entry, agentDir)); registry.getAvailable = () => - getAvailable().map((entry: Model) => normalizeDiscoveredPiModel(entry, agentDir)); + getAvailable() + .filter((entry: Model) => matchesProviderFilter(entry)) + .map((entry: Model) => normalizeDiscoveredPiModel(entry, agentDir)); registry.find = (provider: string, modelId: string) => normalizeDiscoveredPiModel(find(provider, modelId), agentDir); @@ -299,6 +312,15 @@ export function discoverAuthStorage(agentDir: string): PiAuthStorage { return createAuthStorage(PiAuthStorageClass, authPath, credentials); } -export function discoverModels(authStorage: PiAuthStorage, agentDir: string): PiModelRegistry { - return createOpenClawModelRegistry(authStorage, path.join(agentDir, "models.json"), agentDir); +export function discoverModels( + authStorage: PiAuthStorage, + agentDir: string, + options?: DiscoverModelsOptions, +): PiModelRegistry { + return createOpenClawModelRegistry( + authStorage, + path.join(agentDir, "models.json"), + agentDir, + options, + ); } diff --git a/src/commands/models.list.e2e.test.ts b/src/commands/models.list.e2e.test.ts index 8f83d87d5f3..fa1e43b71fd 100644 --- a/src/commands/models.list.e2e.test.ts +++ b/src/commands/models.list.e2e.test.ts @@ -17,6 +17,7 @@ const listProfilesForProvider = vi.fn().mockReturnValue([]); const resolveEnvApiKey = vi.fn().mockReturnValue(undefined); const resolveAwsSdkEnvVarName = vi.fn().mockReturnValue(undefined); const hasUsableCustomProviderApiKey = vi.fn().mockReturnValue(false); +const loadModelCatalog = vi.fn(async () => []); const loadProviderCatalogModelsForList = vi.fn<() => Promise>>>( async () => [], ); @@ -74,7 +75,7 @@ vi.mock("./models/list.runtime.js", () => { resolveEnvApiKey, resolveAwsSdkEnvVarName, hasUsableCustomProviderApiKey, - loadModelCatalog: vi.fn(async () => []), + loadModelCatalog, loadProviderCatalogModelsForList, discoverAuthStorage: () => ({}) as unknown, discoverModels: () => new MockModelRegistry() as unknown, @@ -136,6 +137,8 @@ beforeEach(() => { getRuntimeConfig.mockReturnValue({}); listProfilesForProvider.mockReturnValue([]); ensureOpenClawModelsJson.mockClear(); + loadModelCatalog.mockClear(); + loadModelCatalog.mockResolvedValue([]); loadProviderCatalogModelsForList.mockReset(); loadProviderCatalogModelsForList.mockResolvedValue([]); shouldSuppressBuiltInModel.mockReset(); @@ -359,6 +362,7 @@ describe("models list/status", () => { await modelsListCommand({ all: true, provider: "moonshot", json: true }, runtime); const payload = parseJsonLog(runtime); + expect(loadModelCatalog).toHaveBeenCalledTimes(1); expect(payload.models).toEqual([ expect.objectContaining({ key: "moonshot/kimi-k2.6", @@ -369,6 +373,21 @@ describe("models list/status", () => { ]); }); + it("models list rejects provider display labels", async () => { + setDefaultZaiRegistry({ available: false }); + const runtime = makeRuntime(); + + await modelsListCommand({ all: true, provider: "Moonshot AI", json: true }, runtime); + + expect(runtime.error).toHaveBeenCalledWith( + 'Invalid provider filter "Moonshot AI". Use a provider id such as "moonshot", not a display label.', + ); + expect(runtime.log).not.toHaveBeenCalled(); + expect(loadModelCatalog).not.toHaveBeenCalled(); + expect(loadProviderCatalogModelsForList).not.toHaveBeenCalled(); + expect(process.exitCode).toBe(1); + }); + it("models list all local skips unauthenticated provider catalog rows", async () => { setDefaultZaiRegistry({ available: false }); loadProviderCatalogModelsForList.mockResolvedValueOnce([MOONSHOT_MODEL]); diff --git a/src/commands/models/list.list-command.forward-compat.test.ts b/src/commands/models/list.list-command.forward-compat.test.ts index e8cb02333d0..a486477b3dd 100644 --- a/src/commands/models/list.list-command.forward-compat.test.ts +++ b/src/commands/models/list.list-command.forward-compat.test.ts @@ -206,6 +206,19 @@ beforeEach(() => { describe("modelsListCommand forward-compat", () => { describe("configured rows", () => { + it("passes provider filters into registry loading before row assembly", async () => { + const runtime = createRuntime(); + + await modelsListCommand({ json: true, provider: "moonshot" }, runtime as never); + + expect(mocks.loadModelRegistry).toHaveBeenCalledWith( + mocks.resolvedConfig, + expect.objectContaining({ + providerFilter: "moonshot", + }), + ); + }); + it("does not mark configured codex model as missing when forward-compat can build a fallback", async () => { const runtime = createRuntime(); diff --git a/src/commands/models/list.list-command.ts b/src/commands/models/list.list-command.ts index ab5768c8d15..b6afc796360 100644 --- a/src/commands/models/list.list-command.ts +++ b/src/commands/models/list.list-command.ts @@ -29,6 +29,24 @@ export async function modelsListCommand( runtime: RuntimeEnv, ) { ensureFlagCompatibility(opts); + const providerFilter = (() => { + const raw = opts.provider?.trim(); + if (!raw) { + return undefined; + } + if (/\s/u.test(raw)) { + runtime.error( + `Invalid provider filter "${raw}". Use a provider id such as "moonshot", not a display label.`, + ); + process.exitCode = 1; + return null; + } + const parsed = parseModelRef(`${raw}/_`, DEFAULT_PROVIDER, DISPLAY_MODEL_PARSE_OPTIONS); + return parsed?.provider ?? normalizeLowercaseStringOrEmpty(raw); + })(); + if (providerFilter === null) { + return; + } const { ensureAuthProfileStore, ensureOpenClawModelsJson, resolveOpenClawAgentDir } = await import("./list.runtime.js"); const { sourceConfig, resolvedConfig: cfg } = await loadModelsConfigWithSource({ @@ -37,14 +55,6 @@ export async function modelsListCommand( }); const authStore = ensureAuthProfileStore(); const agentDir = resolveOpenClawAgentDir(); - const providerFilter = (() => { - const raw = opts.provider?.trim(); - if (!raw) { - return undefined; - } - const parsed = parseModelRef(`${raw}/_`, DEFAULT_PROVIDER, DISPLAY_MODEL_PARSE_OPTIONS); - return parsed?.provider ?? normalizeLowercaseStringOrEmpty(raw); - })(); let modelRegistry: ModelRegistry | undefined; let discoveredKeys = new Set(); @@ -56,7 +66,7 @@ export async function modelsListCommand( // before building the read-only model registry view. if (!useProviderCatalogFastPath) { await ensureOpenClawModelsJson(sourceConfig ?? cfg); - const loaded = await loadListModelRegistry(cfg, { sourceConfig }); + const loaded = await loadListModelRegistry(cfg, { sourceConfig, providerFilter }); modelRegistry = loaded.registry; discoveredKeys = loaded.discoveredKeys; availableKeys = loaded.availableKeys; diff --git a/src/commands/models/list.registry.ts b/src/commands/models/list.registry.ts index 3ae30979ab4..40b64cf46bd 100644 --- a/src/commands/models/list.registry.ts +++ b/src/commands/models/list.registry.ts @@ -110,11 +110,13 @@ function loadAvailableModels(registry: ModelRegistry, cfg: OpenClawConfig): Mode export async function loadModelRegistry( cfg: OpenClawConfig, - _opts?: { sourceConfig?: OpenClawConfig }, + opts?: { sourceConfig?: OpenClawConfig; providerFilter?: string }, ) { const agentDir = resolveOpenClawAgentDir(); const authStorage = discoverAuthStorage(agentDir); - const registry = discoverModels(authStorage, agentDir); + const registry = discoverModels(authStorage, agentDir, { + providerFilter: opts?.providerFilter, + }); const models = registry.getAll().filter( (model) => !shouldSuppressBuiltInModel({ diff --git a/src/commands/models/list.rows.ts b/src/commands/models/list.rows.ts index f98b7db891e..89aaa4fbd79 100644 --- a/src/commands/models/list.rows.ts +++ b/src/commands/models/list.rows.ts @@ -77,7 +77,7 @@ function shouldSuppressListModel(params: { export async function loadListModelRegistry( cfg: OpenClawConfig, - opts?: { sourceConfig?: OpenClawConfig }, + opts?: { sourceConfig?: OpenClawConfig; providerFilter?: string }, ) { const loaded = await loadModelRegistry(cfg, opts); return {