perf: scope models list discovery by provider

This commit is contained in:
Shakker
2026-04-23 06:35:59 +01:00
committed by Shakker
parent 3ec5558f53
commit db1e4f811d
6 changed files with 83 additions and 17 deletions

View File

@@ -19,6 +19,7 @@ import { ensureAuthProfileStore } from "./auth-profiles/store.js";
import { resolveProviderEnvApiKeyCandidates } from "./model-auth-env-vars.js";
import { resolveEnvApiKey } from "./model-auth-env.js";
import { resolvePiCredentialMapFromStore, type PiCredentialMap } from "./pi-auth-credentials.js";
import { normalizeProviderId } from "./provider-id.js";
const PiAuthStorageClass = PiCodingAgent.AuthStorage;
const PiModelRegistryClass = PiCodingAgent.ModelRegistry;
@@ -33,6 +34,10 @@ type DiscoveredProviderRuntimeModelLike = Omit<ProviderRuntimeModelLike, "api">
api?: string | null;
};
type DiscoverModelsOptions = {
providerFilter?: string;
};
type InMemoryAuthStorageBackendLike = {
withLock<T>(
update: (current: string) => {
@@ -136,16 +141,24 @@ function createOpenClawModelRegistry(
authStorage: PiAuthStorage,
modelsJsonPath: string,
agentDir: string,
options?: DiscoverModelsOptions,
): PiModelRegistry {
const registry = instantiatePiModelRegistry(authStorage, modelsJsonPath);
const getAll = registry.getAll.bind(registry);
const getAvailable = registry.getAvailable.bind(registry);
const find = registry.find.bind(registry);
const providerFilter = options?.providerFilter ? normalizeProviderId(options.providerFilter) : "";
const matchesProviderFilter = (entry: Model<Api>) =>
!providerFilter || normalizeProviderId(entry.provider) === providerFilter;
registry.getAll = () =>
getAll().map((entry: Model<Api>) => normalizeDiscoveredPiModel(entry, agentDir));
getAll()
.filter((entry: Model<Api>) => matchesProviderFilter(entry))
.map((entry: Model<Api>) => normalizeDiscoveredPiModel(entry, agentDir));
registry.getAvailable = () =>
getAvailable().map((entry: Model<Api>) => normalizeDiscoveredPiModel(entry, agentDir));
getAvailable()
.filter((entry: Model<Api>) => matchesProviderFilter(entry))
.map((entry: Model<Api>) => normalizeDiscoveredPiModel(entry, agentDir));
registry.find = (provider: string, modelId: string) =>
normalizeDiscoveredPiModel(find(provider, modelId), agentDir);
@@ -299,6 +312,15 @@ export function discoverAuthStorage(agentDir: string): PiAuthStorage {
return createAuthStorage(PiAuthStorageClass, authPath, credentials);
}
export function discoverModels(authStorage: PiAuthStorage, agentDir: string): PiModelRegistry {
return createOpenClawModelRegistry(authStorage, path.join(agentDir, "models.json"), agentDir);
export function discoverModels(
authStorage: PiAuthStorage,
agentDir: string,
options?: DiscoverModelsOptions,
): PiModelRegistry {
return createOpenClawModelRegistry(
authStorage,
path.join(agentDir, "models.json"),
agentDir,
options,
);
}

View File

@@ -17,6 +17,7 @@ const listProfilesForProvider = vi.fn().mockReturnValue([]);
const resolveEnvApiKey = vi.fn().mockReturnValue(undefined);
const resolveAwsSdkEnvVarName = vi.fn().mockReturnValue(undefined);
const hasUsableCustomProviderApiKey = vi.fn().mockReturnValue(false);
const loadModelCatalog = vi.fn(async () => []);
const loadProviderCatalogModelsForList = vi.fn<() => Promise<Array<Record<string, unknown>>>>(
async () => [],
);
@@ -74,7 +75,7 @@ vi.mock("./models/list.runtime.js", () => {
resolveEnvApiKey,
resolveAwsSdkEnvVarName,
hasUsableCustomProviderApiKey,
loadModelCatalog: vi.fn(async () => []),
loadModelCatalog,
loadProviderCatalogModelsForList,
discoverAuthStorage: () => ({}) as unknown,
discoverModels: () => new MockModelRegistry() as unknown,
@@ -136,6 +137,8 @@ beforeEach(() => {
getRuntimeConfig.mockReturnValue({});
listProfilesForProvider.mockReturnValue([]);
ensureOpenClawModelsJson.mockClear();
loadModelCatalog.mockClear();
loadModelCatalog.mockResolvedValue([]);
loadProviderCatalogModelsForList.mockReset();
loadProviderCatalogModelsForList.mockResolvedValue([]);
shouldSuppressBuiltInModel.mockReset();
@@ -359,6 +362,7 @@ describe("models list/status", () => {
await modelsListCommand({ all: true, provider: "moonshot", json: true }, runtime);
const payload = parseJsonLog(runtime);
expect(loadModelCatalog).toHaveBeenCalledTimes(1);
expect(payload.models).toEqual([
expect.objectContaining({
key: "moonshot/kimi-k2.6",
@@ -369,6 +373,21 @@ describe("models list/status", () => {
]);
});
it("models list rejects provider display labels", async () => {
setDefaultZaiRegistry({ available: false });
const runtime = makeRuntime();
await modelsListCommand({ all: true, provider: "Moonshot AI", json: true }, runtime);
expect(runtime.error).toHaveBeenCalledWith(
'Invalid provider filter "Moonshot AI". Use a provider id such as "moonshot", not a display label.',
);
expect(runtime.log).not.toHaveBeenCalled();
expect(loadModelCatalog).not.toHaveBeenCalled();
expect(loadProviderCatalogModelsForList).not.toHaveBeenCalled();
expect(process.exitCode).toBe(1);
});
it("models list all local skips unauthenticated provider catalog rows", async () => {
setDefaultZaiRegistry({ available: false });
loadProviderCatalogModelsForList.mockResolvedValueOnce([MOONSHOT_MODEL]);

View File

@@ -206,6 +206,19 @@ beforeEach(() => {
describe("modelsListCommand forward-compat", () => {
describe("configured rows", () => {
it("passes provider filters into registry loading before row assembly", async () => {
const runtime = createRuntime();
await modelsListCommand({ json: true, provider: "moonshot" }, runtime as never);
expect(mocks.loadModelRegistry).toHaveBeenCalledWith(
mocks.resolvedConfig,
expect.objectContaining({
providerFilter: "moonshot",
}),
);
});
it("does not mark configured codex model as missing when forward-compat can build a fallback", async () => {
const runtime = createRuntime();

View File

@@ -29,6 +29,24 @@ export async function modelsListCommand(
runtime: RuntimeEnv,
) {
ensureFlagCompatibility(opts);
const providerFilter = (() => {
const raw = opts.provider?.trim();
if (!raw) {
return undefined;
}
if (/\s/u.test(raw)) {
runtime.error(
`Invalid provider filter "${raw}". Use a provider id such as "moonshot", not a display label.`,
);
process.exitCode = 1;
return null;
}
const parsed = parseModelRef(`${raw}/_`, DEFAULT_PROVIDER, DISPLAY_MODEL_PARSE_OPTIONS);
return parsed?.provider ?? normalizeLowercaseStringOrEmpty(raw);
})();
if (providerFilter === null) {
return;
}
const { ensureAuthProfileStore, ensureOpenClawModelsJson, resolveOpenClawAgentDir } =
await import("./list.runtime.js");
const { sourceConfig, resolvedConfig: cfg } = await loadModelsConfigWithSource({
@@ -37,14 +55,6 @@ export async function modelsListCommand(
});
const authStore = ensureAuthProfileStore();
const agentDir = resolveOpenClawAgentDir();
const providerFilter = (() => {
const raw = opts.provider?.trim();
if (!raw) {
return undefined;
}
const parsed = parseModelRef(`${raw}/_`, DEFAULT_PROVIDER, DISPLAY_MODEL_PARSE_OPTIONS);
return parsed?.provider ?? normalizeLowercaseStringOrEmpty(raw);
})();
let modelRegistry: ModelRegistry | undefined;
let discoveredKeys = new Set<string>();
@@ -56,7 +66,7 @@ export async function modelsListCommand(
// before building the read-only model registry view.
if (!useProviderCatalogFastPath) {
await ensureOpenClawModelsJson(sourceConfig ?? cfg);
const loaded = await loadListModelRegistry(cfg, { sourceConfig });
const loaded = await loadListModelRegistry(cfg, { sourceConfig, providerFilter });
modelRegistry = loaded.registry;
discoveredKeys = loaded.discoveredKeys;
availableKeys = loaded.availableKeys;

View File

@@ -110,11 +110,13 @@ function loadAvailableModels(registry: ModelRegistry, cfg: OpenClawConfig): Mode
export async function loadModelRegistry(
cfg: OpenClawConfig,
_opts?: { sourceConfig?: OpenClawConfig },
opts?: { sourceConfig?: OpenClawConfig; providerFilter?: string },
) {
const agentDir = resolveOpenClawAgentDir();
const authStorage = discoverAuthStorage(agentDir);
const registry = discoverModels(authStorage, agentDir);
const registry = discoverModels(authStorage, agentDir, {
providerFilter: opts?.providerFilter,
});
const models = registry.getAll().filter(
(model) =>
!shouldSuppressBuiltInModel({

View File

@@ -77,7 +77,7 @@ function shouldSuppressListModel(params: {
export async function loadListModelRegistry(
cfg: OpenClawConfig,
opts?: { sourceConfig?: OpenClawConfig },
opts?: { sourceConfig?: OpenClawConfig; providerFilter?: string },
) {
const loaded = await loadModelRegistry(cfg, opts);
return {