diff --git a/extensions/nvidia/index.test.ts b/extensions/nvidia/index.test.ts index 50821902b9f..47838170d0e 100644 --- a/extensions/nvidia/index.test.ts +++ b/extensions/nvidia/index.test.ts @@ -8,7 +8,7 @@ import { describe, expect, it } from "vitest"; import plugin from "./index.js"; type NvidiaManifest = { - providerAuthChoices?: Array<{ choiceId?: string; method?: string; provider?: string }>; + providerAuthChoices?: Array>; }; type RegisteredModelCatalogProvider = Parameters< ReturnType["registerModelCatalogProvider"] @@ -45,15 +45,21 @@ describe("nvidia provider hooks", () => { }); expect(choice?.provider.id).toBe("nvidia"); expect(choice?.method.id).toBe("api-key"); - expect(readManifest().providerAuthChoices).toEqual( - expect.arrayContaining([ - expect.objectContaining({ - provider: "nvidia", - method: "api-key", - choiceId: "nvidia-api-key", - }), - ]), - ); + expect(readManifest().providerAuthChoices).toStrictEqual([ + { + provider: "nvidia", + method: "api-key", + choiceId: "nvidia-api-key", + choiceLabel: "NVIDIA API key", + groupId: "nvidia", + groupLabel: "NVIDIA", + groupHint: "Direct API key", + optionKey: "nvidiaApiKey", + cliFlag: "--nvidia-api-key", + cliOption: "--nvidia-api-key ", + cliDescription: "NVIDIA API key", + }, + ]); }); it("keeps nvidia auth setup metadata aligned", async () => { @@ -85,20 +91,24 @@ describe("nvidia provider hooks", () => { it("keeps nvidia wizard setup metadata aligned", async () => { const provider = await registerNvidiaProvider(); - expect(provider.wizard?.setup).toMatchObject({ + expect(provider.wizard?.setup).toStrictEqual({ choiceId: "nvidia-api-key", choiceLabel: "NVIDIA API key", groupId: "nvidia", groupLabel: "NVIDIA", groupHint: "Direct API key", methodId: "api-key", + modelSelection: { + promptWhenAuthChoiceProvided: true, + allowKeepCurrent: false, + }, }); }); it("keeps nvidia model picker metadata aligned", async () => { const provider = await registerNvidiaProvider(); - expect(provider.wizard?.modelPicker).toMatchObject({ + expect(provider.wizard?.modelPicker).toStrictEqual({ label: "NVIDIA (custom)", hint: "Use NVIDIA-hosted open models", methodId: "api-key", @@ -162,9 +172,9 @@ describe("nvidia provider hooks", () => { }), ); - expect(registeredProviders).toContain("nvidia"); - expect(registeredModelCatalogProviders.map((provider) => provider.provider)).toContain( + expect(registeredProviders).toStrictEqual(["nvidia"]); + expect(registeredModelCatalogProviders.map((provider) => provider.provider)).toStrictEqual([ "nvidia", - ); + ]); }); });