diff --git a/src/gateway/server.talk-provider-contracts.test.ts b/src/gateway/server.talk-provider-contracts.test.ts index 6f2cf35da78..df2fcc65452 100644 --- a/src/gateway/server.talk-provider-contracts.test.ts +++ b/src/gateway/server.talk-provider-contracts.test.ts @@ -2,9 +2,23 @@ import { describe, expect, it } from "vitest"; import { vi } from "vitest"; import { createEmptyPluginRegistry } from "../plugins/registry-empty.js"; import { getActivePluginRegistry, setActivePluginRegistry } from "../plugins/runtime.js"; -import { withFetchPreconnect } from "../test-utils/fetch-mock.js"; import { talkHandlers } from "./server-methods/talk.js"; +const synthesizeSpeechMock = vi.hoisted(() => + vi.fn(async () => ({ + success: true, + audioBuffer: Buffer.from([4, 5, 6]), + provider: "elevenlabs", + outputFormat: "pcm_44100", + fileExtension: ".pcm", + voiceCompatible: false, + })), +); + +vi.mock("../tts/tts.js", () => ({ + synthesizeSpeech: synthesizeSpeechMock, +})); + type TalkSpeakPayload = { audioBase64?: string; provider?: string; @@ -71,104 +85,62 @@ describe("gateway talk provider contracts", () => { }, }); - const originalFetch = globalThis.fetch; - let fetchUrl: string | undefined; - const requestInits: RequestInit[] = []; - const fetchMock = vi.fn(async (input: RequestInfo | URL, init?: RequestInit) => { - fetchUrl = typeof input === "string" ? input : input instanceof URL ? input.href : input.url; - if (init) { - requestInits.push(init); - } - return new Response(new Uint8Array([4, 5, 6]), { status: 200 }); - }); - globalThis.fetch = withFetchPreconnect(fetchMock); - - try { - const res = await withSpeechProviders( - [ - { - pluginId: "elevenlabs-test", - source: "test", - provider: { - id: "elevenlabs", - label: "ElevenLabs", - isConfigured: () => true, - resolveTalkOverrides: ({ params }) => ({ - ...(typeof params.voiceId === "string" && params.voiceId.trim().length > 0 - ? { voiceId: params.voiceId.trim() } - : {}), - ...(typeof params.modelId === "string" && params.modelId.trim().length > 0 - ? { modelId: params.modelId.trim() } - : {}), - ...(typeof params.outputFormat === "string" && params.outputFormat.trim().length > 0 - ? { outputFormat: params.outputFormat.trim() } - : {}), - ...(typeof params.latencyTier === "number" - ? { latencyTier: params.latencyTier } - : {}), - }), - synthesize: async (req) => { - const config = req.providerConfig as Record; - const overrides = (req.providerOverrides ?? {}) as Record; - const voiceId = - (typeof overrides.voiceId === "string" && overrides.voiceId.trim().length > 0 - ? overrides.voiceId.trim() - : undefined) ?? - (typeof config.voiceId === "string" && config.voiceId.trim().length > 0 - ? config.voiceId.trim() - : undefined) ?? - DEFAULT_STUB_VOICE_ID; - const outputFormat = - typeof overrides.outputFormat === "string" && - overrides.outputFormat.trim().length > 0 - ? overrides.outputFormat.trim() - : "mp3"; - const url = new URL(`https://api.elevenlabs.io/v1/text-to-speech/${voiceId}`); - url.searchParams.set("output_format", outputFormat); - const response = await globalThis.fetch(url.href, { - method: "POST", - headers: { "content-type": "application/json" }, - body: JSON.stringify({ - text: req.text, - ...(typeof overrides.latencyTier === "number" - ? { latency_optimization_level: overrides.latencyTier } - : {}), - }), - }); - return { - audioBuffer: Buffer.from(await response.arrayBuffer()), - outputFormat, - fileExtension: outputFormat.startsWith("pcm") ? ".pcm" : ".mp3", - voiceCompatible: false, - }; - }, + const res = await withSpeechProviders( + [ + { + pluginId: "elevenlabs-test", + source: "test", + provider: { + id: "elevenlabs", + label: "ElevenLabs", + isConfigured: () => true, + resolveTalkOverrides: ({ params }) => ({ + ...(typeof params.voiceId === "string" && params.voiceId.trim().length > 0 + ? { voiceId: params.voiceId.trim() } + : {}), + ...(typeof params.outputFormat === "string" && params.outputFormat.trim().length > 0 + ? { outputFormat: params.outputFormat.trim() } + : {}), + ...(typeof params.latencyTier === "number" + ? { latencyTier: params.latencyTier } + : {}), + }), + synthesize: async () => { + throw new Error("synthesize should be mocked at the handler boundary"); }, }, - ], - async () => - await invokeTalkSpeakDirect({ - text: "Hello from talk mode.", - voiceId: "clawd", - outputFormat: "pcm_44100", - latencyTier: 3, - }), - ); - expect(res?.ok, JSON.stringify(res?.error)).toBe(true); - expect((res?.payload as TalkSpeakPayload | undefined)?.provider).toBe("elevenlabs"); - expect((res?.payload as TalkSpeakPayload | undefined)?.outputFormat).toBe("pcm_44100"); - expect((res?.payload as TalkSpeakPayload | undefined)?.audioBase64).toBe( - Buffer.from([4, 5, 6]).toString("base64"), - ); + }, + ], + async () => + await invokeTalkSpeakDirect({ + text: "Hello from talk mode.", + voiceId: "clawd", + outputFormat: "pcm_44100", + latencyTier: 3, + }), + ); + expect(res?.ok, JSON.stringify(res?.error)).toBe(true); + expect((res?.payload as TalkSpeakPayload | undefined)?.provider).toBe("elevenlabs"); + expect((res?.payload as TalkSpeakPayload | undefined)?.outputFormat).toBe("pcm_44100"); + expect((res?.payload as TalkSpeakPayload | undefined)?.audioBase64).toBe( + Buffer.from([4, 5, 6]).toString("base64"), + ); - expect(fetchMock).toHaveBeenCalled(); - expect(fetchUrl).toContain(`/v1/text-to-speech/${ALIAS_STUB_VOICE_ID}`); - expect(fetchUrl).toContain("output_format=pcm_44100"); - const init = requestInits[0]; - const bodyText = typeof init?.body === "string" ? init.body : "{}"; - const body = JSON.parse(bodyText) as Record; - expect(body.latency_optimization_level).toBe(3); - } finally { - globalThis.fetch = originalFetch; - } + expect(synthesizeSpeechMock).toHaveBeenCalledWith( + expect.objectContaining({ + text: "Hello from talk mode.", + overrides: { + provider: "elevenlabs", + providerOverrides: { + elevenlabs: { + voiceId: ALIAS_STUB_VOICE_ID, + outputFormat: "pcm_44100", + latencyTier: 3, + }, + }, + }, + disableFallback: true, + }), + ); }); }); diff --git a/src/gateway/server.talk-runtime.test.ts b/src/gateway/server.talk-runtime.test.ts index 5d0133c6fdc..4112a353c2c 100644 --- a/src/gateway/server.talk-runtime.test.ts +++ b/src/gateway/server.talk-runtime.test.ts @@ -1,8 +1,23 @@ -import { describe, expect, it } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; import { createEmptyPluginRegistry } from "../plugins/registry-empty.js"; import { getActivePluginRegistry, setActivePluginRegistry } from "../plugins/runtime.js"; import { talkHandlers } from "./server-methods/talk.js"; +const synthesizeSpeechMock = vi.hoisted(() => + vi.fn(async () => ({ + success: true, + audioBuffer: Buffer.from([7, 8, 9]), + provider: "acme", + outputFormat: "mp3", + fileExtension: ".mp3", + voiceCompatible: false, + })), +); + +vi.mock("../tts/tts.js", () => ({ + synthesizeSpeech: synthesizeSpeechMock, +})); + type TalkSpeakPayload = { audioBase64?: string; provider?: string; @@ -46,6 +61,79 @@ async function withSpeechProviders( } describe("gateway talk runtime", () => { + beforeEach(() => { + synthesizeSpeechMock.mockReset(); + synthesizeSpeechMock.mockResolvedValue({ + success: true, + audioBuffer: Buffer.from([7, 8, 9]), + provider: "acme", + outputFormat: "mp3", + fileExtension: ".mp3", + voiceCompatible: false, + }); + }); + + it("allows extension speech providers through the talk setup", async () => { + const { writeConfigFile } = await import("../config/config.js"); + await writeConfigFile({ + talk: { + provider: "acme", + providers: { + acme: { + voiceId: "plugin-voice", + }, + }, + }, + }); + + await withSpeechProviders( + [ + { + pluginId: "acme-plugin", + source: "test", + provider: { + id: "acme", + label: "Acme Speech", + isConfigured: () => true, + resolveTalkConfig: ({ talkProviderConfig }) => ({ + ...talkProviderConfig, + resolvedBy: "acme-test-provider", + }), + synthesize: async () => { + throw new Error("synthesize should be mocked at the handler boundary"); + }, + }, + }, + ], + async () => { + const res = await invokeTalkSpeakDirect({ + text: "Hello from talk mode.", + }); + expect(res?.ok, JSON.stringify(res?.error)).toBe(true); + expect(synthesizeSpeechMock).toHaveBeenCalledWith( + expect.objectContaining({ + text: "Hello from talk mode.", + overrides: { provider: "acme" }, + disableFallback: true, + cfg: expect.objectContaining({ + messages: expect.objectContaining({ + tts: expect.objectContaining({ + provider: "acme", + providers: expect.objectContaining({ + acme: expect.objectContaining({ + resolvedBy: "acme-test-provider", + voiceId: "plugin-voice", + }), + }), + }), + }), + }), + }), + ); + }, + ); + }); + it("allows extension speech providers through talk.speak", async () => { const { writeConfigFile } = await import("../config/config.js"); await writeConfigFile({ @@ -125,13 +213,15 @@ describe("gateway talk runtime", () => { id: "acme", label: "Acme Speech", isConfigured: () => true, - synthesize: async () => { - throw new Error("provider failed"); - }, + synthesize: async () => ({}) as never, }, }, ], async () => { + synthesizeSpeechMock.mockResolvedValue({ + success: false, + error: "provider failed", + }); const res = await invokeTalkSpeakDirect({ text: "Hello from talk mode." }); expect(res?.ok).toBe(false); expect(res?.error?.details).toEqual({ @@ -164,16 +254,19 @@ describe("gateway talk runtime", () => { id: "acme", label: "Acme Speech", isConfigured: () => true, - synthesize: async () => ({ - audioBuffer: Buffer.alloc(0), - outputFormat: "mp3", - fileExtension: ".mp3", - voiceCompatible: false, - }), + synthesize: async () => ({}) as never, }, }, ], async () => { + synthesizeSpeechMock.mockResolvedValue({ + success: true, + audioBuffer: Buffer.alloc(0), + provider: "acme", + outputFormat: "mp3", + fileExtension: ".mp3", + voiceCompatible: false, + }); const res = await invokeTalkSpeakDirect({ text: "Hello from talk mode." }); expect(res?.ok).toBe(false); expect(res?.error?.details).toEqual({