test(live): read gateway provider models

This commit is contained in:
Vincent Koc
2026-05-06 12:00:03 -07:00
parent d47497c99f
commit 6587832f25

View File

@@ -1410,6 +1410,83 @@ type LiveModelRegistry = {
getAll(): Array<Model<Api>>;
};
function toGatewayLiveModel(params: {
provider: string;
providerConfig: ModelProviderConfig;
modelConfig: NonNullable<ModelProviderConfig["models"]>[number];
}): Model<Api> | null {
const id = params.modelConfig.id?.trim();
const api = params.modelConfig.api ?? params.providerConfig.api;
const baseUrl = params.modelConfig.baseUrl ?? params.providerConfig.baseUrl;
if (!id || !api || !baseUrl) {
return null;
}
const input = params.modelConfig.input.filter(
(value): value is "text" | "image" => value === "text" || value === "image",
);
return {
id,
name: params.modelConfig.name ?? id,
api: api as Api,
provider: params.provider,
baseUrl,
reasoning: params.modelConfig.reasoning ?? false,
thinkingLevelMap: params.modelConfig.thinkingLevelMap,
input: input.length > 0 ? input : ["text"],
cost: params.modelConfig.cost ?? {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
},
contextWindow: params.modelConfig.contextWindow ?? 128_000,
maxTokens: params.modelConfig.maxTokens ?? 16_384,
compat: params.modelConfig.compat ?? params.providerConfig.compat,
};
}
async function loadProviderScopedConfiguredModels(params: {
agentDir: string;
providerList: readonly string[];
}): Promise<Array<Model<Api>>> {
const modelsPath = path.join(params.agentDir, "models.json");
let parsed: { providers?: Record<string, ModelProviderConfig> };
try {
parsed = JSON.parse(await fs.readFile(modelsPath, "utf8")) as {
providers?: Record<string, ModelProviderConfig>;
};
} catch {
return [];
}
const providers = parsed.providers ?? {};
const models: Array<Model<Api>> = [];
const seen = new Set<string>();
for (const rawProvider of params.providerList) {
const normalizedProvider = normalizeProviderId(rawProvider);
const entry = Object.entries(providers).find(
([provider]) => normalizeProviderId(provider) === normalizedProvider,
);
if (!entry) {
continue;
}
const [provider, providerConfig] = entry;
for (const modelConfig of providerConfig.models ?? []) {
const model = toGatewayLiveModel({ provider, providerConfig, modelConfig });
if (!model) {
continue;
}
const key = `${normalizeProviderId(model.provider)}/${model.id.toLowerCase()}`;
if (seen.has(key)) {
continue;
}
seen.add(key);
models.push(model);
}
}
return models;
}
function loadProviderScopedBuiltInModels(providerList: readonly string[]): Array<Model<Api>> {
const models: Array<Model<Api>> = [];
const seen = new Set<string>();
@@ -1430,6 +1507,17 @@ function loadProviderScopedBuiltInModels(providerList: readonly string[]): Array
return models;
}
async function loadProviderScopedModels(params: {
agentDir: string;
providerList: readonly string[];
}): Promise<Array<Model<Api>>> {
const configured = await loadProviderScopedConfiguredModels(params);
if (configured.length > 0) {
return configured;
}
return loadProviderScopedBuiltInModels(params.providerList);
}
function createStaticLiveModelRegistry(models: Array<Model<Api>>): LiveModelRegistry {
return {
find(provider, modelId) {
@@ -2459,7 +2547,10 @@ describeLive("gateway live (dev agent, profile keys)", () => {
let all: Array<Model<Api>>;
if (useProviderScopedBuiltIns) {
logProgress("[all-models] loading provider-scoped model refs");
all = loadProviderScopedBuiltInModels(providerList);
all = await withGatewayLiveSetupTimeout(
loadProviderScopedModels({ agentDir, providerList }),
"[all-models] load provider-scoped model refs",
);
modelRegistry = createStaticLiveModelRegistry(all);
} else {
logProgress("[all-models] loading auth profiles");