fix(tools): normalize media model lookups (#66422)

* fix(tools): normalize media model lookups

* Update CHANGELOG.md
This commit is contained in:
Vincent Koc
2026-04-14 09:23:52 +01:00
committed by GitHub
parent aa0dc118f1
commit e59f5ecac3
3 changed files with 70 additions and 2 deletions

View File

@@ -0,0 +1,62 @@
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
const state = vi.hoisted(() => ({
normalizeModelRefMock: vi.fn(),
}));
vi.mock("../model-selection.js", async () => {
const actual =
await vi.importActual<typeof import("../model-selection.js")>("../model-selection.js");
return {
...actual,
normalizeModelRef: (...args: Parameters<typeof actual.normalizeModelRef>) =>
state.normalizeModelRefMock(...args),
};
});
let resolveModelFromRegistry: typeof import("./media-tool-shared.js").resolveModelFromRegistry;
describe("resolveModelFromRegistry", () => {
beforeAll(async () => {
({ resolveModelFromRegistry } = await import("./media-tool-shared.js"));
});
beforeEach(() => {
state.normalizeModelRefMock
.mockReset()
.mockImplementation((provider: string, model: string) => ({
provider: provider.trim().toLowerCase(),
model: model.trim().replace(/^ollama\//, ""),
}));
});
it("normalizes provider and model refs before registry lookup", () => {
const foundModel = { provider: "ollama", id: "qwen3.5:397b-cloud" };
const find = vi.fn(() => foundModel);
const result = resolveModelFromRegistry({
modelRegistry: { find },
provider: " OLLAMA ",
modelId: "ollama/qwen3.5:397b-cloud",
});
expect(state.normalizeModelRefMock).toHaveBeenCalledWith(
" OLLAMA ",
"ollama/qwen3.5:397b-cloud",
);
expect(find).toHaveBeenCalledWith("ollama", "qwen3.5:397b-cloud");
expect(result).toBe(foundModel);
});
it("reports the normalized ref when the registry lookup misses", () => {
const find = vi.fn(() => null);
expect(() =>
resolveModelFromRegistry({
modelRegistry: { find },
provider: " OLLAMA ",
modelId: "ollama/qwen3.5:397b-cloud",
}),
).toThrow("Unknown model: ollama/qwen3.5:397b-cloud");
});
});

View File

@@ -7,6 +7,7 @@ import {
normalizeOptionalLowercaseString,
normalizeOptionalString,
} from "../../shared/string-coerce.js";
import { normalizeModelRef } from "../model-selection.js";
import { normalizeProviderId } from "../provider-id.js";
import { ToolInputError, readStringArrayParam, readStringParam } from "./common.js";
import type { ImageModelConfig } from "./image-tool.helpers.js";
@@ -400,9 +401,13 @@ export function resolveModelFromRegistry(params: {
provider: string;
modelId: string;
}): Model<Api> {
const model = params.modelRegistry.find(params.provider, params.modelId) as Model<Api> | null;
const resolvedRef = normalizeModelRef(params.provider, params.modelId);
const model = params.modelRegistry.find(
resolvedRef.provider,
resolvedRef.model,
) as Model<Api> | null;
if (!model) {
throw new Error(`Unknown model: ${params.provider}/${params.modelId}`);
throw new Error(`Unknown model: ${resolvedRef.provider}/${resolvedRef.model}`);
}
return model;
}