diff --git a/extensions/openai/realtime-transcription-provider.test.ts b/extensions/openai/realtime-transcription-provider.test.ts index 5fb6829df3a..6ee64399cd1 100644 --- a/extensions/openai/realtime-transcription-provider.test.ts +++ b/extensions/openai/realtime-transcription-provider.test.ts @@ -1,7 +1,88 @@ -import { describe, expect, it } from "vitest"; +import { beforeEach, describe, expect, it, vi } from "vitest"; import { buildOpenAIRealtimeTranscriptionProvider } from "./realtime-transcription-provider.js"; +const { FakeWebSocket } = vi.hoisted(() => { + type Listener = (...args: unknown[]) => void; + + class MockWebSocket { + static readonly OPEN = 1; + static readonly CLOSED = 3; + static instances: MockWebSocket[] = []; + + readonly listeners = new Map(); + readyState = 0; + sent: string[] = []; + closed = false; + + constructor() { + MockWebSocket.instances.push(this); + } + + on(event: string, listener: Listener): this { + const listeners = this.listeners.get(event) ?? []; + listeners.push(listener); + this.listeners.set(event, listeners); + return this; + } + + emit(event: string, ...args: unknown[]): void { + for (const listener of this.listeners.get(event) ?? []) { + listener(...args); + } + } + + send(payload: string): void { + this.sent.push(payload); + } + + close(code?: number, reason?: string): void { + this.closed = true; + this.readyState = MockWebSocket.CLOSED; + this.emit("close", code ?? 1000, Buffer.from(reason ?? "")); + } + } + + return { FakeWebSocket: MockWebSocket }; +}); + +vi.mock("ws", () => ({ + default: FakeWebSocket, +})); + +type FakeWebSocketInstance = InstanceType; +type SentRealtimeEvent = { + type: string; + audio?: string; + session?: { + type?: string; + audio?: { + input?: { + format?: { type?: string }; + transcription?: { + model?: string; + language?: string; + prompt?: string; + }; + turn_detection?: { + type?: string; + threshold?: number; + prefix_padding_ms?: number; + silence_duration_ms?: number; + }; + }; + }; + }; +}; + +function parseSent(socket: FakeWebSocketInstance): SentRealtimeEvent[] { + return socket.sent.map((payload) => JSON.parse(payload) as SentRealtimeEvent); +} + describe("buildOpenAIRealtimeTranscriptionProvider", () => { + beforeEach(() => { + FakeWebSocket.instances = []; + }); + it("normalizes OpenAI config defaults", () => { const provider = buildOpenAIRealtimeTranscriptionProvider(); const resolved = provider.resolveConfig?.({ @@ -70,4 +151,88 @@ describe("buildOpenAIRealtimeTranscriptionProvider", () => { const provider = buildOpenAIRealtimeTranscriptionProvider(); expect(provider.aliases).toContain("openai-realtime"); }); + + it("waits for the OpenAI session update before draining audio", async () => { + const provider = buildOpenAIRealtimeTranscriptionProvider(); + const session = provider.createSession({ + providerConfig: { + apiKey: "sk-test", // pragma: allowlist secret + language: "en", + model: "gpt-4o-transcribe", + prompt: "expect OpenClaw product names", + silenceDurationMs: 900, + vadThreshold: 0.45, + }, + }); + + const connecting = session.connect(); + const socket = FakeWebSocket.instances[0]; + if (!socket) { + throw new Error("expected session to create a websocket"); + } + + socket.readyState = FakeWebSocket.OPEN; + socket.emit("open"); + session.sendAudio(Buffer.from("before-ready")); + + expect(session.isConnected()).toBe(false); + expect(parseSent(socket)).toEqual([ + { + type: "transcription_session.update", + session: { + type: "transcription", + audio: { + input: { + format: { type: "audio/pcmu" }, + transcription: { + model: "gpt-4o-transcribe", + language: "en", + prompt: "expect OpenClaw product names", + }, + turn_detection: { + type: "server_vad", + threshold: 0.45, + prefix_padding_ms: 300, + silence_duration_ms: 900, + }, + }, + }, + }, + }, + ]); + + socket.emit("message", Buffer.from(JSON.stringify({ type: "session.updated" }))); + await connecting; + + expect(session.isConnected()).toBe(true); + expect(parseSent(socket)).toEqual([ + { + type: "transcription_session.update", + session: { + type: "transcription", + audio: { + input: { + format: { type: "audio/pcmu" }, + transcription: { + model: "gpt-4o-transcribe", + language: "en", + prompt: "expect OpenClaw product names", + }, + turn_detection: { + type: "server_vad", + threshold: 0.45, + prefix_padding_ms: 300, + silence_duration_ms: 900, + }, + }, + }, + }, + }, + { + type: "input_audio_buffer.append", + audio: Buffer.from("before-ready").toString("base64"), + }, + ]); + session.close(); + }); }); diff --git a/extensions/openai/realtime-transcription-provider.ts b/extensions/openai/realtime-transcription-provider.ts index 38b2993024b..852148f80b3 100644 --- a/extensions/openai/realtime-transcription-provider.ts +++ b/extensions/openai/realtime-transcription-provider.ts @@ -72,8 +72,16 @@ function createOpenAIRealtimeTranscriptionSession( ): RealtimeTranscriptionSession { let pendingTranscript = ""; - const handleEvent = (event: RealtimeEvent) => { + const handleEvent = ( + event: RealtimeEvent, + transport: RealtimeTranscriptionWebSocketTransport, + ) => { switch (event.type) { + case "session.updated": + case "transcription_session.updated": + transport.markReady(); + return; + case "conversation.item.input_audio_transcription.delta": if (event.delta) { pendingTranscript += event.delta; @@ -95,7 +103,11 @@ function createOpenAIRealtimeTranscriptionSession( case "error": { const detail = readRealtimeErrorDetail(event.error); - config.onError?.(new Error(detail)); + const error = new Error(detail); + config.onError?.(error); + if (!transport.isReady()) { + transport.failConnect(error); + } return; } @@ -121,7 +133,6 @@ function createOpenAIRealtimeTranscriptionSession( Authorization: `Bearer ${config.apiKey}`, "OpenAI-Beta": "realtime=v1", }, - readyOnOpen: true, connectTimeoutMs: OPENAI_REALTIME_TRANSCRIPTION_CONNECT_TIMEOUT_MS, maxReconnectAttempts: OPENAI_REALTIME_TRANSCRIPTION_MAX_RECONNECT_ATTEMPTS, reconnectDelayMs: OPENAI_REALTIME_TRANSCRIPTION_RECONNECT_DELAY_MS, @@ -137,17 +148,24 @@ function createOpenAIRealtimeTranscriptionSession( transport.sendJson({ type: "transcription_session.update", session: { - input_audio_format: "g711_ulaw", - input_audio_transcription: { - model: config.model, - ...(config.language ? { language: config.language } : {}), - ...(config.prompt ? { prompt: config.prompt } : {}), - }, - turn_detection: { - type: "server_vad", - threshold: config.vadThreshold, - prefix_padding_ms: 300, - silence_duration_ms: config.silenceDurationMs, + type: "transcription", + audio: { + input: { + format: { + type: "audio/pcmu", + }, + transcription: { + model: config.model, + ...(config.language ? { language: config.language } : {}), + ...(config.prompt ? { prompt: config.prompt } : {}), + }, + turn_detection: { + type: "server_vad", + threshold: config.vadThreshold, + prefix_padding_ms: 300, + silence_duration_ms: config.silenceDurationMs, + }, + }, }, }, });