diff --git a/src/image-generation/provider-registry.test.ts b/src/image-generation/provider-registry.test.ts index 6c4ed231fc1..7cfa36343b3 100644 --- a/src/image-generation/provider-registry.test.ts +++ b/src/image-generation/provider-registry.test.ts @@ -34,6 +34,15 @@ async function loadProviderRegistry() { return await import("./provider-registry.js"); } +function requireImageProvider(id: string): ImageGenerationProviderPlugin { + const provider = getImageGenerationProvider(id); + expect(provider).toBeDefined(); + if (!provider) { + throw new Error(`expected image generation provider ${id}`); + } + return provider; +} + describe("image-generation provider registry", () => { beforeEach(async () => { resolvePluginCapabilityProvidersMock.mockReset(); @@ -56,7 +65,7 @@ describe("image-generation provider registry", () => { const provider = getImageGenerationProvider("custom-image"); - expect(provider?.id).toBe("custom-image"); + expect(provider).toMatchObject({ id: "custom-image" }); expect(resolvePluginCapabilityProvidersMock).toHaveBeenCalledWith({ key: "imageGenerationProviders", cfg: undefined, @@ -72,6 +81,6 @@ describe("image-generation provider registry", () => { expect(listImageGenerationProviders().map((provider) => provider.id)).toEqual(["safe-image"]); expect(getImageGenerationProvider("__proto__")).toBeUndefined(); expect(getImageGenerationProvider("constructor")).toBeUndefined(); - expect(getImageGenerationProvider("safe-alias")?.id).toBe("safe-image"); + expect(requireImageProvider("safe-alias")).toMatchObject({ id: "safe-image" }); }); });