fix(openai): wait for realtime transcription session update

This commit is contained in:
Vincent Koc
2026-05-03 16:07:28 -07:00
parent 3546a54003
commit f66af6a5f5
2 changed files with 198 additions and 15 deletions

View File

@@ -1,7 +1,88 @@
import { describe, expect, it } from "vitest";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { buildOpenAIRealtimeTranscriptionProvider } from "./realtime-transcription-provider.js";
const { FakeWebSocket } = vi.hoisted(() => {
type Listener = (...args: unknown[]) => void;
class MockWebSocket {
static readonly OPEN = 1;
static readonly CLOSED = 3;
static instances: MockWebSocket[] = [];
readonly listeners = new Map<string, Listener[]>();
readyState = 0;
sent: string[] = [];
closed = false;
constructor() {
MockWebSocket.instances.push(this);
}
on(event: string, listener: Listener): this {
const listeners = this.listeners.get(event) ?? [];
listeners.push(listener);
this.listeners.set(event, listeners);
return this;
}
emit(event: string, ...args: unknown[]): void {
for (const listener of this.listeners.get(event) ?? []) {
listener(...args);
}
}
send(payload: string): void {
this.sent.push(payload);
}
close(code?: number, reason?: string): void {
this.closed = true;
this.readyState = MockWebSocket.CLOSED;
this.emit("close", code ?? 1000, Buffer.from(reason ?? ""));
}
}
return { FakeWebSocket: MockWebSocket };
});
vi.mock("ws", () => ({
default: FakeWebSocket,
}));
type FakeWebSocketInstance = InstanceType<typeof FakeWebSocket>;
type SentRealtimeEvent = {
type: string;
audio?: string;
session?: {
type?: string;
audio?: {
input?: {
format?: { type?: string };
transcription?: {
model?: string;
language?: string;
prompt?: string;
};
turn_detection?: {
type?: string;
threshold?: number;
prefix_padding_ms?: number;
silence_duration_ms?: number;
};
};
};
};
};
function parseSent(socket: FakeWebSocketInstance): SentRealtimeEvent[] {
return socket.sent.map((payload) => JSON.parse(payload) as SentRealtimeEvent);
}
describe("buildOpenAIRealtimeTranscriptionProvider", () => {
beforeEach(() => {
FakeWebSocket.instances = [];
});
it("normalizes OpenAI config defaults", () => {
const provider = buildOpenAIRealtimeTranscriptionProvider();
const resolved = provider.resolveConfig?.({
@@ -70,4 +151,88 @@ describe("buildOpenAIRealtimeTranscriptionProvider", () => {
const provider = buildOpenAIRealtimeTranscriptionProvider();
expect(provider.aliases).toContain("openai-realtime");
});
it("waits for the OpenAI session update before draining audio", async () => {
const provider = buildOpenAIRealtimeTranscriptionProvider();
const session = provider.createSession({
providerConfig: {
apiKey: "sk-test", // pragma: allowlist secret
language: "en",
model: "gpt-4o-transcribe",
prompt: "expect OpenClaw product names",
silenceDurationMs: 900,
vadThreshold: 0.45,
},
});
const connecting = session.connect();
const socket = FakeWebSocket.instances[0];
if (!socket) {
throw new Error("expected session to create a websocket");
}
socket.readyState = FakeWebSocket.OPEN;
socket.emit("open");
session.sendAudio(Buffer.from("before-ready"));
expect(session.isConnected()).toBe(false);
expect(parseSent(socket)).toEqual([
{
type: "transcription_session.update",
session: {
type: "transcription",
audio: {
input: {
format: { type: "audio/pcmu" },
transcription: {
model: "gpt-4o-transcribe",
language: "en",
prompt: "expect OpenClaw product names",
},
turn_detection: {
type: "server_vad",
threshold: 0.45,
prefix_padding_ms: 300,
silence_duration_ms: 900,
},
},
},
},
},
]);
socket.emit("message", Buffer.from(JSON.stringify({ type: "session.updated" })));
await connecting;
expect(session.isConnected()).toBe(true);
expect(parseSent(socket)).toEqual([
{
type: "transcription_session.update",
session: {
type: "transcription",
audio: {
input: {
format: { type: "audio/pcmu" },
transcription: {
model: "gpt-4o-transcribe",
language: "en",
prompt: "expect OpenClaw product names",
},
turn_detection: {
type: "server_vad",
threshold: 0.45,
prefix_padding_ms: 300,
silence_duration_ms: 900,
},
},
},
},
},
{
type: "input_audio_buffer.append",
audio: Buffer.from("before-ready").toString("base64"),
},
]);
session.close();
});
});

View File

@@ -72,8 +72,16 @@ function createOpenAIRealtimeTranscriptionSession(
): RealtimeTranscriptionSession {
let pendingTranscript = "";
const handleEvent = (event: RealtimeEvent) => {
const handleEvent = (
event: RealtimeEvent,
transport: RealtimeTranscriptionWebSocketTransport,
) => {
switch (event.type) {
case "session.updated":
case "transcription_session.updated":
transport.markReady();
return;
case "conversation.item.input_audio_transcription.delta":
if (event.delta) {
pendingTranscript += event.delta;
@@ -95,7 +103,11 @@ function createOpenAIRealtimeTranscriptionSession(
case "error": {
const detail = readRealtimeErrorDetail(event.error);
config.onError?.(new Error(detail));
const error = new Error(detail);
config.onError?.(error);
if (!transport.isReady()) {
transport.failConnect(error);
}
return;
}
@@ -121,7 +133,6 @@ function createOpenAIRealtimeTranscriptionSession(
Authorization: `Bearer ${config.apiKey}`,
"OpenAI-Beta": "realtime=v1",
},
readyOnOpen: true,
connectTimeoutMs: OPENAI_REALTIME_TRANSCRIPTION_CONNECT_TIMEOUT_MS,
maxReconnectAttempts: OPENAI_REALTIME_TRANSCRIPTION_MAX_RECONNECT_ATTEMPTS,
reconnectDelayMs: OPENAI_REALTIME_TRANSCRIPTION_RECONNECT_DELAY_MS,
@@ -137,17 +148,24 @@ function createOpenAIRealtimeTranscriptionSession(
transport.sendJson({
type: "transcription_session.update",
session: {
input_audio_format: "g711_ulaw",
input_audio_transcription: {
model: config.model,
...(config.language ? { language: config.language } : {}),
...(config.prompt ? { prompt: config.prompt } : {}),
},
turn_detection: {
type: "server_vad",
threshold: config.vadThreshold,
prefix_padding_ms: 300,
silence_duration_ms: config.silenceDurationMs,
type: "transcription",
audio: {
input: {
format: {
type: "audio/pcmu",
},
transcription: {
model: config.model,
...(config.language ? { language: config.language } : {}),
...(config.prompt ? { prompt: config.prompt } : {}),
},
turn_detection: {
type: "server_vad",
threshold: config.vadThreshold,
prefix_padding_ms: 300,
silence_duration_ms: config.silenceDurationMs,
},
},
},
},
});