diff --git a/src/gateway/gateway-models.profiles.live.test.ts b/src/gateway/gateway-models.profiles.live.test.ts index d9d27512ed1..b5cc8ca0d05 100644 --- a/src/gateway/gateway-models.profiles.live.test.ts +++ b/src/gateway/gateway-models.profiles.live.test.ts @@ -1410,6 +1410,83 @@ type LiveModelRegistry = { getAll(): Array>; }; +function toGatewayLiveModel(params: { + provider: string; + providerConfig: ModelProviderConfig; + modelConfig: NonNullable[number]; +}): Model | null { + const id = params.modelConfig.id?.trim(); + const api = params.modelConfig.api ?? params.providerConfig.api; + const baseUrl = params.modelConfig.baseUrl ?? params.providerConfig.baseUrl; + if (!id || !api || !baseUrl) { + return null; + } + const input = params.modelConfig.input.filter( + (value): value is "text" | "image" => value === "text" || value === "image", + ); + return { + id, + name: params.modelConfig.name ?? id, + api: api as Api, + provider: params.provider, + baseUrl, + reasoning: params.modelConfig.reasoning ?? false, + thinkingLevelMap: params.modelConfig.thinkingLevelMap, + input: input.length > 0 ? input : ["text"], + cost: params.modelConfig.cost ?? { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + }, + contextWindow: params.modelConfig.contextWindow ?? 128_000, + maxTokens: params.modelConfig.maxTokens ?? 16_384, + compat: params.modelConfig.compat ?? params.providerConfig.compat, + }; +} + +async function loadProviderScopedConfiguredModels(params: { + agentDir: string; + providerList: readonly string[]; +}): Promise>> { + const modelsPath = path.join(params.agentDir, "models.json"); + let parsed: { providers?: Record }; + try { + parsed = JSON.parse(await fs.readFile(modelsPath, "utf8")) as { + providers?: Record; + }; + } catch { + return []; + } + + const providers = parsed.providers ?? {}; + const models: Array> = []; + const seen = new Set(); + for (const rawProvider of params.providerList) { + const normalizedProvider = normalizeProviderId(rawProvider); + const entry = Object.entries(providers).find( + ([provider]) => normalizeProviderId(provider) === normalizedProvider, + ); + if (!entry) { + continue; + } + const [provider, providerConfig] = entry; + for (const modelConfig of providerConfig.models ?? []) { + const model = toGatewayLiveModel({ provider, providerConfig, modelConfig }); + if (!model) { + continue; + } + const key = `${normalizeProviderId(model.provider)}/${model.id.toLowerCase()}`; + if (seen.has(key)) { + continue; + } + seen.add(key); + models.push(model); + } + } + return models; +} + function loadProviderScopedBuiltInModels(providerList: readonly string[]): Array> { const models: Array> = []; const seen = new Set(); @@ -1430,6 +1507,17 @@ function loadProviderScopedBuiltInModels(providerList: readonly string[]): Array return models; } +async function loadProviderScopedModels(params: { + agentDir: string; + providerList: readonly string[]; +}): Promise>> { + const configured = await loadProviderScopedConfiguredModels(params); + if (configured.length > 0) { + return configured; + } + return loadProviderScopedBuiltInModels(params.providerList); +} + function createStaticLiveModelRegistry(models: Array>): LiveModelRegistry { return { find(provider, modelId) { @@ -2459,7 +2547,10 @@ describeLive("gateway live (dev agent, profile keys)", () => { let all: Array>; if (useProviderScopedBuiltIns) { logProgress("[all-models] loading provider-scoped model refs"); - all = loadProviderScopedBuiltInModels(providerList); + all = await withGatewayLiveSetupTimeout( + loadProviderScopedModels({ agentDir, providerList }), + "[all-models] load provider-scoped model refs", + ); modelRegistry = createStaticLiveModelRegistry(all); } else { logProgress("[all-models] loading auth profiles");