From 0e7bcf7588d279e68e866664551d5a2da7c58864 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Thu, 23 Apr 2026 03:35:19 +0100 Subject: [PATCH] feat(plugin-sdk): share realtime transcription websocket sessions --- CHANGELOG.md | 1 + .../.generated/plugin-sdk-api-baseline.sha256 | 4 +- docs/plugins/sdk-migration.md | 2 +- docs/plugins/sdk-overview.md | 2 +- docs/plugins/sdk-provider-plugins.md | 34 +- .../realtime-transcription-provider.ts | 321 +++----------- .../realtime-transcription-provider.ts | 326 +++----------- .../realtime-transcription-provider.ts | 328 +++----------- src/plugin-sdk/realtime-transcription.ts | 5 + .../websocket-session.test.ts | 151 +++++++ .../websocket-session.ts | 402 ++++++++++++++++++ 11 files changed, 756 insertions(+), 820 deletions(-) create mode 100644 src/realtime-transcription/websocket-session.test.ts create mode 100644 src/realtime-transcription/websocket-session.ts diff --git a/CHANGELOG.md b/CHANGELOG.md index b2388430861..414cf1c90ba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Docs: https://docs.openclaw.ai - Providers/GPT-5: move the GPT-5 prompt overlay into the shared provider runtime so compatible GPT-5 models receive the same behavior and heartbeat guidance through OpenAI, OpenRouter, OpenCode, Codex, and other GPT providers; add `agents.defaults.promptOverlays.gpt5.personality` as the global friendly-style toggle while keeping the OpenAI plugin setting as a fallback. - Providers/xAI: add image generation, text-to-speech, and speech-to-text support, including `grok-imagine-image` / `grok-imagine-image-pro`, reference-image edits, six live xAI voices, MP3/WAV/PCM/G.711 TTS formats, `grok-stt` audio transcription, and xAI realtime transcription for Voice Call streaming. (#68694) Thanks @KateWilkins. - Providers/STT: add Voice Call streaming transcription for Deepgram, ElevenLabs, and Mistral, and add ElevenLabs Scribe v2 batch audio transcription for inbound media. 
+- Plugin SDK/realtime transcription: add a shared WebSocket session helper for streaming STT providers, covering queueing, ready handshakes, proxy capture, reconnects, and close flushing without each plugin reimplementing the transport loop. - Models/commands: add `/models add ` so you can register a model from chat and use it without restarting the gateway; keep `/models` as a simple provider browser while adding clearer add guidance and copy-friendly command examples. (#70211) Thanks @Takhoffman. - Pi/models: update the bundled pi packages to `0.68.1` and let the OpenCode Go catalog come from pi instead of plugin-maintained model aliases, adding the refreshed `opencode-go/kimi-k2.6`, Qwen, GLM, MiMo, and MiniMax entries. - CLI/doctor plugins: lazy-load doctor plugin paths and prefer installed plugin `dist/*` runtime entries over source-adjacent JavaScript fallbacks, reducing the measured `doctor --non-interactive` runtime by about 74% while keeping cold doctor startup on built plugin artifacts. (#69840) Thanks @gumadeiras. 
diff --git a/docs/.generated/plugin-sdk-api-baseline.sha256 b/docs/.generated/plugin-sdk-api-baseline.sha256 index 81e57d732fc..01df532b278 100644 --- a/docs/.generated/plugin-sdk-api-baseline.sha256 +++ b/docs/.generated/plugin-sdk-api-baseline.sha256 @@ -1,2 +1,2 @@ -2b7093a57992029cc70126d33544e02eed6c3076a3a6b4ffa6aef7664da0f33d plugin-sdk-api-baseline.json -ea6a2f2326565517b6c42a4d334f615163fb434dbad5e0b8d134c92767714256 plugin-sdk-api-baseline.jsonl +e10f01ce10a381ecb098b805cee95b7278d16de42e02c7873f54448eb2b6c5cc plugin-sdk-api-baseline.json +918b646ff2e0849c4feba5ef930a08187a7bdad3a2d35ba4e1dd456fe3ea2cea plugin-sdk-api-baseline.jsonl diff --git a/docs/plugins/sdk-migration.md b/docs/plugins/sdk-migration.md index aae3e2f4b0e..ced0f07f5e9 100644 --- a/docs/plugins/sdk-migration.md +++ b/docs/plugins/sdk-migration.md @@ -296,7 +296,7 @@ Current bundled provider examples: | `plugin-sdk/text-chunking` | Text chunking helpers | Outbound text chunking helper | | `plugin-sdk/speech` | Speech helpers | Speech provider types plus provider-facing directive, registry, and validation helpers | | `plugin-sdk/speech-core` | Shared speech core | Speech provider types, registry, directives, normalization | - | `plugin-sdk/realtime-transcription` | Realtime transcription helpers | Provider types and registry helpers | + | `plugin-sdk/realtime-transcription` | Realtime transcription helpers | Provider types, registry helpers, and shared WebSocket session helper | | `plugin-sdk/realtime-voice` | Realtime voice helpers | Provider types and registry helpers | | `plugin-sdk/image-generation-core` | Shared image-generation core | Image-generation types, failover, auth, and registry helpers | | `plugin-sdk/music-generation` | Music-generation helpers | Music-generation provider/request/result types | diff --git a/docs/plugins/sdk-overview.md b/docs/plugins/sdk-overview.md index 60c815a94b9..5ce0e772c75 100644 --- a/docs/plugins/sdk-overview.md +++ b/docs/plugins/sdk-overview.md 
@@ -258,7 +258,7 @@ explicitly promotes one as public. | `plugin-sdk/text-chunking` | Outbound text chunking helper | | `plugin-sdk/speech` | Speech provider types plus provider-facing directive, registry, and validation helpers | | `plugin-sdk/speech-core` | Shared speech provider types, registry, directive, and normalization helpers | - | `plugin-sdk/realtime-transcription` | Realtime transcription provider types and registry helpers | + | `plugin-sdk/realtime-transcription` | Realtime transcription provider types, registry helpers, and shared WebSocket session helper | | `plugin-sdk/realtime-voice` | Realtime voice provider types and registry helpers | | `plugin-sdk/image-generation` | Image generation provider types | | `plugin-sdk/image-generation-core` | Shared image-generation types, failover, auth, and registry helpers | diff --git a/docs/plugins/sdk-provider-plugins.md b/docs/plugins/sdk-provider-plugins.md index 03c68fd10b0..07d8d24f541 100644 --- a/docs/plugins/sdk-provider-plugins.md +++ b/docs/plugins/sdk-provider-plugins.md @@ -599,12 +599,34 @@ API key auth, and dynamic model resolution. id: "acme-ai", label: "Acme Realtime Transcription", isConfigured: () => true, - createSession: (req) => ({ - connect: async () => {}, - sendAudio: () => {}, - close: () => {}, - isConnected: () => true, - }), + createSession: (req) => { + const apiKey = String(req.providerConfig.apiKey ?? 
""); + return createRealtimeTranscriptionWebSocketSession({ + providerId: "acme-ai", + callbacks: req, + url: "wss://api.example.com/v1/realtime-transcription", + headers: { Authorization: `Bearer ${apiKey}` }, + onMessage: (event, transport) => { + if (event.type === "session.created") { + transport.sendJson({ type: "session.update" }); + transport.markReady(); + return; + } + if (event.type === "transcript.final") { + req.onTranscript?.(event.text); + } + }, + sendAudio: (audio, transport) => { + transport.sendJson({ + type: "audio.append", + audio: audio.toString("base64"), + }); + }, + onClose: (transport) => { + transport.sendJson({ type: "audio.end" }); + }, + }); + }, }); api.registerRealtimeVoiceProvider({ diff --git a/extensions/deepgram/realtime-transcription-provider.ts b/extensions/deepgram/realtime-transcription-provider.ts index 28ded1d6a43..9401d538020 100644 --- a/extensions/deepgram/realtime-transcription-provider.ts +++ b/extensions/deepgram/realtime-transcription-provider.ts @@ -1,18 +1,12 @@ -import { randomUUID } from "node:crypto"; import { - captureWsEvent, - createDebugProxyWebSocketAgent, - resolveDebugProxySettings, -} from "openclaw/plugin-sdk/proxy-capture"; -import type { - RealtimeTranscriptionProviderConfig, - RealtimeTranscriptionProviderPlugin, - RealtimeTranscriptionSession, - RealtimeTranscriptionSessionCreateRequest, + createRealtimeTranscriptionWebSocketSession, + type RealtimeTranscriptionProviderConfig, + type RealtimeTranscriptionProviderPlugin, + type RealtimeTranscriptionSession, + type RealtimeTranscriptionSessionCreateRequest, } from "openclaw/plugin-sdk/realtime-transcription"; import { normalizeResolvedSecretInputString } from "openclaw/plugin-sdk/secret-input"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; -import WebSocket from "ws"; import { DEFAULT_DEEPGRAM_AUDIO_BASE_URL, DEFAULT_DEEPGRAM_AUDIO_MODEL } from "./audio.js"; type DeepgramRealtimeTranscriptionEncoding = "linear16" | "mulaw" 
| "alaw"; @@ -164,16 +158,6 @@ function normalizeProviderConfig( }; } -function rawWsDataToBuffer(data: WebSocket.RawData): Buffer { - if (Buffer.isBuffer(data)) { - return data; - } - if (Array.isArray(data)) { - return Buffer.concat(data); - } - return Buffer.from(data); -} - function readErrorDetail(value: unknown): string { if (typeof value === "string") { return value; @@ -188,284 +172,75 @@ function readTranscriptText(event: DeepgramRealtimeTranscriptionEvent): string | return normalizeOptionalString(event.channel?.alternatives?.[0]?.transcript); } -class DeepgramRealtimeTranscriptionSession implements RealtimeTranscriptionSession { - private ws: WebSocket | null = null; - private connected = false; - private closed = false; - private reconnectAttempts = 0; - private queuedAudio: Buffer[] = []; - private queuedBytes = 0; - private closeTimer: ReturnType | undefined; - private lastTranscript: string | undefined; - private speechStarted = false; - private reconnecting = false; - private readonly flowId = randomUUID(); +function createDeepgramRealtimeTranscriptionSession( + config: DeepgramRealtimeTranscriptionSessionConfig, +): RealtimeTranscriptionSession { + let lastTranscript: string | undefined; + let speechStarted = false; - constructor(private readonly config: DeepgramRealtimeTranscriptionSessionConfig) {} - - async connect(): Promise { - this.closed = false; - this.reconnectAttempts = 0; - await this.doConnect(); - } - - sendAudio(audio: Buffer): void { - if (this.closed || audio.byteLength === 0) { + const emitTranscript = (text: string) => { + if (text === lastTranscript) { return; } - if (this.ws?.readyState === WebSocket.OPEN) { - this.sendAudioFrame(audio); - return; - } - this.queueAudio(audio); - } + lastTranscript = text; + config.onTranscript?.(text); + }; - close(): void { - this.closed = true; - this.connected = false; - this.queuedAudio = []; - this.queuedBytes = 0; - if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { - 
this.forceClose(); - return; - } - this.sendEvent({ type: "Finalize" }); - this.closeTimer = setTimeout(() => this.forceClose(), DEEPGRAM_REALTIME_CLOSE_TIMEOUT_MS); - } - - isConnected(): boolean { - return this.connected; - } - - private async doConnect(): Promise { - await new Promise((resolve, reject) => { - const url = toDeepgramRealtimeWsUrl(this.config); - const debugProxy = resolveDebugProxySettings(); - const proxyAgent = createDebugProxyWebSocketAgent(debugProxy); - let settled = false; - let opened = false; - const finishConnect = () => { - if (settled) { - return; - } - settled = true; - clearTimeout(connectTimeout); - this.flushQueuedAudio(); - resolve(); - }; - const failConnect = (error: Error) => { - if (settled) { - return; - } - settled = true; - clearTimeout(connectTimeout); - this.config.onError?.(error); - this.closed = true; - this.forceClose(); - reject(error); - }; - this.ws = new WebSocket(url, { - headers: { - Authorization: `Token ${this.config.apiKey}`, - }, - ...(proxyAgent ? { agent: proxyAgent } : {}), - }); - - const connectTimeout = setTimeout(() => { - failConnect(new Error("Deepgram realtime transcription connection timeout")); - }, DEEPGRAM_REALTIME_CONNECT_TIMEOUT_MS); - - this.ws.on("open", () => { - opened = true; - this.connected = true; - this.reconnectAttempts = 0; - captureWsEvent({ - url, - direction: "local", - kind: "ws-open", - flowId: this.flowId, - meta: { provider: "deepgram", capability: "realtime-transcription" }, - }); - finishConnect(); - }); - - this.ws.on("message", (data) => { - const payload = rawWsDataToBuffer(data); - captureWsEvent({ - url, - direction: "inbound", - kind: "ws-frame", - flowId: this.flowId, - payload, - meta: { provider: "deepgram", capability: "realtime-transcription" }, - }); - try { - this.handleEvent(JSON.parse(payload.toString()) as DeepgramRealtimeTranscriptionEvent); - } catch (error) { - this.config.onError?.(error instanceof Error ? 
error : new Error(String(error))); - } - }); - - this.ws.on("error", (error) => { - captureWsEvent({ - url, - direction: "local", - kind: "error", - flowId: this.flowId, - errorText: error instanceof Error ? error.message : String(error), - meta: { provider: "deepgram", capability: "realtime-transcription" }, - }); - if (!opened) { - failConnect(error instanceof Error ? error : new Error(String(error))); - return; - } - this.config.onError?.(error instanceof Error ? error : new Error(String(error))); - }); - - this.ws.on("close", () => { - clearTimeout(connectTimeout); - this.connected = false; - if (this.closeTimer) { - clearTimeout(this.closeTimer); - this.closeTimer = undefined; - } - if (this.closed || !opened || !settled) { - return; - } - void this.attemptReconnect(); - }); - }); - } - - private async attemptReconnect(): Promise { - if (this.closed || this.reconnecting) { - return; - } - if (this.reconnectAttempts >= DEEPGRAM_REALTIME_MAX_RECONNECT_ATTEMPTS) { - this.config.onError?.(new Error("Deepgram realtime transcription reconnect limit reached")); - return; - } - this.reconnectAttempts += 1; - const delay = DEEPGRAM_REALTIME_RECONNECT_DELAY_MS * 2 ** (this.reconnectAttempts - 1); - this.reconnecting = true; - try { - await new Promise((resolve) => setTimeout(resolve, delay)); - if (!this.closed) { - await this.doConnect(); - } - } catch { - if (!this.closed) { - this.reconnecting = false; - await this.attemptReconnect(); - return; - } - } finally { - this.reconnecting = false; - } - } - - private handleEvent(event: DeepgramRealtimeTranscriptionEvent): void { + const handleEvent = (event: DeepgramRealtimeTranscriptionEvent) => { switch (event.type) { case "Results": { const text = readTranscriptText(event); if (!text) { return; } - if (!this.speechStarted) { - this.speechStarted = true; - this.config.onSpeechStart?.(); + if (!speechStarted) { + speechStarted = true; + config.onSpeechStart?.(); } if (event.is_final || event.speech_final) { - 
this.emitTranscript(text); + emitTranscript(text); if (event.speech_final) { - this.speechStarted = false; + speechStarted = false; } return; } - this.config.onPartial?.(text); + config.onPartial?.(text); return; } case "SpeechStarted": - this.speechStarted = true; - this.config.onSpeechStart?.(); + speechStarted = true; + config.onSpeechStart?.(); return; case "Error": case "error": - this.config.onError?.(new Error(readErrorDetail(event.error ?? event.message))); + config.onError?.(new Error(readErrorDetail(event.error ?? event.message))); return; default: return; } - } + }; - private emitTranscript(text: string): void { - if (text === this.lastTranscript) { - return; - } - this.lastTranscript = text; - this.config.onTranscript?.(text); - } - - private queueAudio(audio: Buffer): void { - this.queuedAudio.push(Buffer.from(audio)); - this.queuedBytes += audio.byteLength; - while (this.queuedBytes > DEEPGRAM_REALTIME_MAX_QUEUED_BYTES && this.queuedAudio.length > 0) { - const dropped = this.queuedAudio.shift(); - this.queuedBytes -= dropped?.byteLength ?? 
0; - } - } - - private flushQueuedAudio(): void { - for (const audio of this.queuedAudio) { - this.sendAudioFrame(audio); - } - this.queuedAudio = []; - this.queuedBytes = 0; - } - - private sendAudioFrame(audio: Buffer): void { - if (this.ws?.readyState !== WebSocket.OPEN) { - this.queueAudio(audio); - return; - } - captureWsEvent({ - url: toDeepgramRealtimeWsUrl(this.config), - direction: "outbound", - kind: "ws-frame", - flowId: this.flowId, - payload: audio, - meta: { provider: "deepgram", capability: "realtime-transcription" }, - }); - this.ws.send(audio); - } - - private sendEvent(event: unknown): void { - if (this.ws?.readyState !== WebSocket.OPEN) { - return; - } - const payload = JSON.stringify(event); - captureWsEvent({ - url: toDeepgramRealtimeWsUrl(this.config), - direction: "outbound", - kind: "ws-frame", - flowId: this.flowId, - payload, - meta: { provider: "deepgram", capability: "realtime-transcription" }, - }); - this.ws.send(payload); - } - - private forceClose(): void { - if (this.closeTimer) { - clearTimeout(this.closeTimer); - this.closeTimer = undefined; - } - this.connected = false; - if (this.ws) { - this.ws.close(1000, "Transcription session closed"); - this.ws = null; - } - } + return createRealtimeTranscriptionWebSocketSession({ + providerId: "deepgram", + callbacks: config, + url: () => toDeepgramRealtimeWsUrl(config), + headers: { Authorization: `Token ${config.apiKey}` }, + readyOnOpen: true, + connectTimeoutMs: DEEPGRAM_REALTIME_CONNECT_TIMEOUT_MS, + closeTimeoutMs: DEEPGRAM_REALTIME_CLOSE_TIMEOUT_MS, + maxReconnectAttempts: DEEPGRAM_REALTIME_MAX_RECONNECT_ATTEMPTS, + reconnectDelayMs: DEEPGRAM_REALTIME_RECONNECT_DELAY_MS, + maxQueuedBytes: DEEPGRAM_REALTIME_MAX_QUEUED_BYTES, + connectTimeoutMessage: "Deepgram realtime transcription connection timeout", + reconnectLimitMessage: "Deepgram realtime transcription reconnect limit reached", + sendAudio: (audio, transport) => { + transport.sendBinary(audio); + }, + onClose: (transport) => { 
+ transport.sendJson({ type: "Finalize" }); + }, + onMessage: handleEvent, + }); } export function buildDeepgramRealtimeTranscriptionProvider(): RealtimeTranscriptionProviderPlugin { @@ -483,7 +258,7 @@ export function buildDeepgramRealtimeTranscriptionProvider(): RealtimeTranscript if (!apiKey) { throw new Error("Deepgram API key missing"); } - return new DeepgramRealtimeTranscriptionSession({ + return createDeepgramRealtimeTranscriptionSession({ ...req, apiKey, baseUrl: normalizeDeepgramRealtimeBaseUrl(config.baseUrl), diff --git a/extensions/elevenlabs/realtime-transcription-provider.ts b/extensions/elevenlabs/realtime-transcription-provider.ts index 83751388148..3215c0120c1 100644 --- a/extensions/elevenlabs/realtime-transcription-provider.ts +++ b/extensions/elevenlabs/realtime-transcription-provider.ts @@ -1,18 +1,13 @@ -import { randomUUID } from "node:crypto"; import { - captureWsEvent, - createDebugProxyWebSocketAgent, - resolveDebugProxySettings, -} from "openclaw/plugin-sdk/proxy-capture"; -import type { - RealtimeTranscriptionProviderConfig, - RealtimeTranscriptionProviderPlugin, - RealtimeTranscriptionSession, - RealtimeTranscriptionSessionCreateRequest, + createRealtimeTranscriptionWebSocketSession, + type RealtimeTranscriptionProviderConfig, + type RealtimeTranscriptionProviderPlugin, + type RealtimeTranscriptionSession, + type RealtimeTranscriptionSessionCreateRequest, + type RealtimeTranscriptionWebSocketTransport, } from "openclaw/plugin-sdk/realtime-transcription"; import { normalizeResolvedSecretInputString } from "openclaw/plugin-sdk/secret-input"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; -import WebSocket from "ws"; import { resolveElevenLabsApiKeyWithProfileFallback } from "./config-api.js"; import { normalizeElevenLabsBaseUrl } from "./shared.js"; @@ -152,16 +147,6 @@ function toElevenLabsRealtimeWsUrl(config: ElevenLabsRealtimeTranscriptionSessio return url.toString(); } -function rawWsDataToBuffer(data: 
WebSocket.RawData): Buffer { - if (Buffer.isBuffer(data)) { - return data; - } - if (Array.isArray(data)) { - return Buffer.concat(data); - } - return Buffer.from(data); -} - function readErrorDetail(event: ElevenLabsRealtimeTranscriptionEvent): string { return ( normalizeOptionalString(event.error) ?? @@ -171,277 +156,86 @@ function readErrorDetail(event: ElevenLabsRealtimeTranscriptionEvent): string { ); } -class ElevenLabsRealtimeTranscriptionSession implements RealtimeTranscriptionSession { - private ws: WebSocket | null = null; - private connected = false; - private ready = false; - private closed = false; - private reconnectAttempts = 0; - private queuedAudio: Buffer[] = []; - private queuedBytes = 0; - private closeTimer: ReturnType | undefined; - private lastTranscript: string | undefined; - private reconnecting = false; - private readonly flowId = randomUUID(); +function createElevenLabsRealtimeTranscriptionSession( + config: ElevenLabsRealtimeTranscriptionSessionConfig, +): RealtimeTranscriptionSession { + let lastTranscript: string | undefined; - constructor(private readonly config: ElevenLabsRealtimeTranscriptionSessionConfig) {} - - async connect(): Promise { - this.closed = false; - this.reconnectAttempts = 0; - await this.doConnect(); - } - - sendAudio(audio: Buffer): void { - if (this.closed || audio.byteLength === 0) { + const emitTranscript = (text: string) => { + if (text === lastTranscript) { return; } - if (this.ws?.readyState === WebSocket.OPEN && this.ready) { - this.sendAudioChunk(audio); - return; - } - this.queueAudio(audio); - } + lastTranscript = text; + config.onTranscript?.(text); + }; - close(): void { - this.closed = true; - this.connected = false; - this.queuedAudio = []; - this.queuedBytes = 0; - if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { - this.forceClose(); - return; - } - this.sendJson({ + const sendAudioChunk = ( + audio: Buffer, + transport: RealtimeTranscriptionWebSocketTransport, + ): void => { + 
transport.sendJson({ message_type: "input_audio_chunk", - audio_base_64: "", - sample_rate: this.config.sampleRate, - commit: true, + audio_base_64: audio.toString("base64"), + sample_rate: config.sampleRate, + ...(config.commitStrategy === "manual" ? { commit: true } : {}), }); - this.closeTimer = setTimeout(() => this.forceClose(), ELEVENLABS_REALTIME_CLOSE_TIMEOUT_MS); - } + }; - isConnected(): boolean { - return this.connected; - } - - private async doConnect(): Promise { - await new Promise((resolve, reject) => { - const url = toElevenLabsRealtimeWsUrl(this.config); - const debugProxy = resolveDebugProxySettings(); - const proxyAgent = createDebugProxyWebSocketAgent(debugProxy); - let settled = false; - let opened = false; - const finishConnect = () => { - if (settled) { - return; - } - settled = true; - clearTimeout(connectTimeout); - this.ready = true; - this.flushQueuedAudio(); - resolve(); - }; - const failConnect = (error: Error) => { - if (settled) { - return; - } - settled = true; - clearTimeout(connectTimeout); - this.config.onError?.(error); - this.closed = true; - this.forceClose(); - reject(error); - }; - this.ready = false; - this.ws = new WebSocket(url, { - headers: { - "xi-api-key": this.config.apiKey, - }, - ...(proxyAgent ? 
{ agent: proxyAgent } : {}), - }); - - const connectTimeout = setTimeout(() => { - failConnect(new Error("ElevenLabs realtime transcription connection timeout")); - }, ELEVENLABS_REALTIME_CONNECT_TIMEOUT_MS); - - this.ws.on("open", () => { - opened = true; - this.connected = true; - this.reconnectAttempts = 0; - captureWsEvent({ - url, - direction: "local", - kind: "ws-open", - flowId: this.flowId, - meta: { provider: "elevenlabs", capability: "realtime-transcription" }, - }); - }); - - this.ws.on("message", (data) => { - const payload = rawWsDataToBuffer(data); - captureWsEvent({ - url, - direction: "inbound", - kind: "ws-frame", - flowId: this.flowId, - payload, - meta: { provider: "elevenlabs", capability: "realtime-transcription" }, - }); - try { - const event = JSON.parse(payload.toString()) as ElevenLabsRealtimeTranscriptionEvent; - if (event.message_type === "session_started") { - finishConnect(); - return; - } - if (!this.ready && event.message_type?.includes("error")) { - failConnect(new Error(readErrorDetail(event))); - return; - } - this.handleEvent(event); - } catch (error) { - this.config.onError?.(error instanceof Error ? error : new Error(String(error))); - } - }); - - this.ws.on("error", (error) => { - captureWsEvent({ - url, - direction: "local", - kind: "error", - flowId: this.flowId, - errorText: error instanceof Error ? error.message : String(error), - meta: { provider: "elevenlabs", capability: "realtime-transcription" }, - }); - if (!opened) { - failConnect(error instanceof Error ? error : new Error(String(error))); - return; - } - this.config.onError?.(error instanceof Error ? 
error : new Error(String(error))); - }); - - this.ws.on("close", () => { - clearTimeout(connectTimeout); - this.connected = false; - this.ready = false; - if (this.closed || !opened || !settled) { - return; - } - void this.attemptReconnect(); - }); - }); - } - - private async attemptReconnect(): Promise { - if (this.closed || this.reconnecting) { + const handleEvent = ( + event: ElevenLabsRealtimeTranscriptionEvent, + transport: RealtimeTranscriptionWebSocketTransport, + ) => { + if (event.message_type === "session_started") { + transport.markReady(); return; } - if (this.reconnectAttempts >= ELEVENLABS_REALTIME_MAX_RECONNECT_ATTEMPTS) { - this.config.onError?.(new Error("ElevenLabs realtime transcription reconnect limit reached")); + if (!transport.isReady() && event.message_type?.includes("error")) { + transport.failConnect(new Error(readErrorDetail(event))); return; } - this.reconnectAttempts += 1; - const delay = ELEVENLABS_REALTIME_RECONNECT_DELAY_MS * 2 ** (this.reconnectAttempts - 1); - this.reconnecting = true; - try { - await new Promise((resolve) => setTimeout(resolve, delay)); - if (!this.closed) { - await this.doConnect(); - } - } catch { - if (!this.closed) { - this.reconnecting = false; - await this.attemptReconnect(); - return; - } - } finally { - this.reconnecting = false; - } - } - - private handleEvent(event: ElevenLabsRealtimeTranscriptionEvent): void { switch (event.message_type) { case "partial_transcript": if (event.text) { - this.config.onPartial?.(event.text); + config.onPartial?.(event.text); } return; case "committed_transcript": case "committed_transcript_with_timestamps": if (event.text) { - this.emitTranscript(event.text); + emitTranscript(event.text); } return; default: if (event.message_type?.includes("error")) { - this.config.onError?.(new Error(readErrorDetail(event))); + config.onError?.(new Error(readErrorDetail(event))); } return; } - } + }; - private emitTranscript(text: string): void { - if (text === this.lastTranscript) { - 
return; - } - this.lastTranscript = text; - this.config.onTranscript?.(text); - } - - private queueAudio(audio: Buffer): void { - this.queuedAudio.push(Buffer.from(audio)); - this.queuedBytes += audio.byteLength; - while (this.queuedBytes > ELEVENLABS_REALTIME_MAX_QUEUED_BYTES && this.queuedAudio.length > 0) { - const dropped = this.queuedAudio.shift(); - this.queuedBytes -= dropped?.byteLength ?? 0; - } - } - - private flushQueuedAudio(): void { - for (const audio of this.queuedAudio) { - this.sendAudioChunk(audio); - } - this.queuedAudio = []; - this.queuedBytes = 0; - } - - private sendAudioChunk(audio: Buffer): void { - this.sendJson({ - message_type: "input_audio_chunk", - audio_base_64: audio.toString("base64"), - sample_rate: this.config.sampleRate, - ...(this.config.commitStrategy === "manual" ? { commit: true } : {}), - }); - } - - private sendJson(event: unknown): void { - if (this.ws?.readyState !== WebSocket.OPEN) { - return; - } - const payload = JSON.stringify(event); - captureWsEvent({ - url: toElevenLabsRealtimeWsUrl(this.config), - direction: "outbound", - kind: "ws-frame", - flowId: this.flowId, - payload, - meta: { provider: "elevenlabs", capability: "realtime-transcription" }, - }); - this.ws.send(payload); - } - - private forceClose(): void { - if (this.closeTimer) { - clearTimeout(this.closeTimer); - this.closeTimer = undefined; - } - this.connected = false; - this.ready = false; - if (this.ws) { - this.ws.close(1000, "Transcription session closed"); - this.ws = null; - } - } + return createRealtimeTranscriptionWebSocketSession({ + providerId: "elevenlabs", + callbacks: config, + url: () => toElevenLabsRealtimeWsUrl(config), + headers: { "xi-api-key": config.apiKey }, + connectTimeoutMs: ELEVENLABS_REALTIME_CONNECT_TIMEOUT_MS, + closeTimeoutMs: ELEVENLABS_REALTIME_CLOSE_TIMEOUT_MS, + maxReconnectAttempts: ELEVENLABS_REALTIME_MAX_RECONNECT_ATTEMPTS, + reconnectDelayMs: ELEVENLABS_REALTIME_RECONNECT_DELAY_MS, + maxQueuedBytes: 
ELEVENLABS_REALTIME_MAX_QUEUED_BYTES, + connectTimeoutMessage: "ElevenLabs realtime transcription connection timeout", + reconnectLimitMessage: "ElevenLabs realtime transcription reconnect limit reached", + sendAudio: sendAudioChunk, + onClose: (transport) => { + transport.sendJson({ + message_type: "input_audio_chunk", + audio_base_64: "", + sample_rate: config.sampleRate, + commit: true, + }); + }, + onMessage: handleEvent, + }); } export function buildElevenLabsRealtimeTranscriptionProvider(): RealtimeTranscriptionProviderPlugin { @@ -464,7 +258,7 @@ export function buildElevenLabsRealtimeTranscriptionProvider(): RealtimeTranscri if (!apiKey) { throw new Error("ElevenLabs API key missing"); } - return new ElevenLabsRealtimeTranscriptionSession({ + return createElevenLabsRealtimeTranscriptionSession({ ...req, apiKey, baseUrl: normalizeElevenLabsBaseUrl(config.baseUrl), diff --git a/extensions/mistral/realtime-transcription-provider.ts b/extensions/mistral/realtime-transcription-provider.ts index 290d485c714..c46efd4c984 100644 --- a/extensions/mistral/realtime-transcription-provider.ts +++ b/extensions/mistral/realtime-transcription-provider.ts @@ -1,18 +1,13 @@ -import { randomUUID } from "node:crypto"; import { - captureWsEvent, - createDebugProxyWebSocketAgent, - resolveDebugProxySettings, -} from "openclaw/plugin-sdk/proxy-capture"; -import type { - RealtimeTranscriptionProviderConfig, - RealtimeTranscriptionProviderPlugin, - RealtimeTranscriptionSession, - RealtimeTranscriptionSessionCreateRequest, + createRealtimeTranscriptionWebSocketSession, + type RealtimeTranscriptionProviderConfig, + type RealtimeTranscriptionProviderPlugin, + type RealtimeTranscriptionSession, + type RealtimeTranscriptionSessionCreateRequest, + type RealtimeTranscriptionWebSocketTransport, } from "openclaw/plugin-sdk/realtime-transcription"; import { normalizeResolvedSecretInputString } from "openclaw/plugin-sdk/secret-input"; import { normalizeOptionalString } from 
"openclaw/plugin-sdk/text-runtime"; -import WebSocket from "ws"; type MistralRealtimeTranscriptionEncoding = | "pcm_s16le" @@ -155,16 +150,6 @@ function normalizeProviderConfig( }; } -function rawWsDataToBuffer(data: WebSocket.RawData): Buffer { - if (Buffer.isBuffer(data)) { - return data; - } - if (Array.isArray(data)) { - return Buffer.concat(data); - } - return Buffer.from(data); -} - function readErrorDetail(event: MistralRealtimeTranscriptionEvent): string { const message = event.error?.message; if (typeof message === "string") { @@ -179,283 +164,84 @@ function readErrorDetail(event: MistralRealtimeTranscriptionEvent): string { return "Mistral realtime transcription error"; } -class MistralRealtimeTranscriptionSession implements RealtimeTranscriptionSession { - private ws: WebSocket | null = null; - private connected = false; - private ready = false; - private closed = false; - private reconnectAttempts = 0; - private queuedAudio: Buffer[] = []; - private queuedBytes = 0; - private closeTimer: ReturnType | undefined; - private partialText = ""; - private reconnecting = false; - private readonly flowId = randomUUID(); +function createMistralRealtimeTranscriptionSession( + config: MistralRealtimeTranscriptionSessionConfig, +): RealtimeTranscriptionSession { + let partialText = ""; - constructor(private readonly config: MistralRealtimeTranscriptionSessionConfig) {} - - async connect(): Promise { - this.closed = false; - this.reconnectAttempts = 0; - await this.doConnect(); - } - - sendAudio(audio: Buffer): void { - if (this.closed || audio.byteLength === 0) { - return; - } - if (this.ws?.readyState === WebSocket.OPEN && this.ready) { - this.sendJson({ - type: "input_audio.append", - audio: audio.toString("base64"), - }); - return; - } - this.queueAudio(audio); - } - - close(): void { - this.closed = true; - this.connected = false; - this.queuedAudio = []; - this.queuedBytes = 0; - if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { - this.forceClose(); - 
return; - } - this.sendJson({ type: "input_audio.flush" }); - this.sendJson({ type: "input_audio.end" }); - this.closeTimer = setTimeout(() => this.forceClose(), MISTRAL_REALTIME_CLOSE_TIMEOUT_MS); - } - - isConnected(): boolean { - return this.connected; - } - - private async doConnect(): Promise { - await new Promise((resolve, reject) => { - const url = toMistralRealtimeWsUrl(this.config); - const debugProxy = resolveDebugProxySettings(); - const proxyAgent = createDebugProxyWebSocketAgent(debugProxy); - let settled = false; - let opened = false; - const finishConnect = () => { - if (settled) { - return; - } - settled = true; - clearTimeout(connectTimeout); - this.ready = true; - this.flushQueuedAudio(); - resolve(); - }; - const failConnect = (error: Error) => { - if (settled) { - return; - } - settled = true; - clearTimeout(connectTimeout); - this.config.onError?.(error); - this.closed = true; - this.forceClose(); - reject(error); - }; - this.ready = false; - this.ws = new WebSocket(url, { - headers: { - Authorization: `Bearer ${this.config.apiKey}`, + const handleEvent = ( + event: MistralRealtimeTranscriptionEvent, + transport: RealtimeTranscriptionWebSocketTransport, + ) => { + if (event.type === "session.created") { + transport.sendJson({ + type: "session.update", + session: { + audio_format: { + encoding: config.encoding, + sample_rate: config.sampleRate, + }, }, - ...(proxyAgent ? 
{ agent: proxyAgent } : {}), }); - - const connectTimeout = setTimeout(() => { - failConnect(new Error("Mistral realtime transcription connection timeout")); - }, MISTRAL_REALTIME_CONNECT_TIMEOUT_MS); - - this.ws.on("open", () => { - opened = true; - this.connected = true; - this.reconnectAttempts = 0; - captureWsEvent({ - url, - direction: "local", - kind: "ws-open", - flowId: this.flowId, - meta: { provider: "mistral", capability: "realtime-transcription" }, - }); - }); - - this.ws.on("message", (data) => { - const payload = rawWsDataToBuffer(data); - captureWsEvent({ - url, - direction: "inbound", - kind: "ws-frame", - flowId: this.flowId, - payload, - meta: { provider: "mistral", capability: "realtime-transcription" }, - }); - try { - const event = JSON.parse(payload.toString()) as MistralRealtimeTranscriptionEvent; - if (event.type === "session.created") { - this.sendJson({ - type: "session.update", - session: { - audio_format: { - encoding: this.config.encoding, - sample_rate: this.config.sampleRate, - }, - }, - }); - finishConnect(); - return; - } - if (!this.ready && event.type === "error") { - failConnect(new Error(readErrorDetail(event))); - return; - } - this.handleEvent(event); - } catch (error) { - this.config.onError?.(error instanceof Error ? error : new Error(String(error))); - } - }); - - this.ws.on("error", (error) => { - captureWsEvent({ - url, - direction: "local", - kind: "error", - flowId: this.flowId, - errorText: error instanceof Error ? error.message : String(error), - meta: { provider: "mistral", capability: "realtime-transcription" }, - }); - if (!opened) { - failConnect(error instanceof Error ? error : new Error(String(error))); - return; - } - this.config.onError?.(error instanceof Error ? 
error : new Error(String(error))); - }); - - this.ws.on("close", () => { - clearTimeout(connectTimeout); - this.connected = false; - this.ready = false; - if (this.closeTimer) { - clearTimeout(this.closeTimer); - this.closeTimer = undefined; - } - if (this.closed || !opened || !settled) { - return; - } - void this.attemptReconnect(); - }); - }); - } - - private async attemptReconnect(): Promise { - if (this.closed || this.reconnecting) { + transport.markReady(); return; } - if (this.reconnectAttempts >= MISTRAL_REALTIME_MAX_RECONNECT_ATTEMPTS) { - this.config.onError?.(new Error("Mistral realtime transcription reconnect limit reached")); + if (!transport.isReady() && event.type === "error") { + transport.failConnect(new Error(readErrorDetail(event))); return; } - this.reconnectAttempts += 1; - const delay = MISTRAL_REALTIME_RECONNECT_DELAY_MS * 2 ** (this.reconnectAttempts - 1); - this.reconnecting = true; - try { - await new Promise((resolve) => setTimeout(resolve, delay)); - if (!this.closed) { - await this.doConnect(); - } - } catch { - if (!this.closed) { - this.reconnecting = false; - await this.attemptReconnect(); - return; - } - } finally { - this.reconnecting = false; - } - } - - private handleEvent(event: MistralRealtimeTranscriptionEvent): void { switch (event.type) { case "transcription.text.delta": if (event.text) { - this.partialText += event.text; - this.config.onPartial?.(this.partialText); + partialText += event.text; + config.onPartial?.(partialText); } return; case "transcription.segment": if (event.text) { - this.config.onTranscript?.(event.text); - this.partialText = ""; + config.onTranscript?.(event.text); + partialText = ""; } return; case "transcription.done": - if (this.partialText.trim()) { - this.config.onTranscript?.(this.partialText); - this.partialText = ""; + if (partialText.trim()) { + config.onTranscript?.(partialText); + partialText = ""; } - this.forceClose(); + transport.closeNow(); return; case "error": - 
this.config.onError?.(new Error(readErrorDetail(event))); + config.onError?.(new Error(readErrorDetail(event))); return; default: return; } - } + }; - private queueAudio(audio: Buffer): void { - this.queuedAudio.push(Buffer.from(audio)); - this.queuedBytes += audio.byteLength; - while (this.queuedBytes > MISTRAL_REALTIME_MAX_QUEUED_BYTES && this.queuedAudio.length > 0) { - const dropped = this.queuedAudio.shift(); - this.queuedBytes -= dropped?.byteLength ?? 0; - } - } - - private flushQueuedAudio(): void { - for (const audio of this.queuedAudio) { - this.sendJson({ + return createRealtimeTranscriptionWebSocketSession({ + providerId: "mistral", + callbacks: config, + url: () => toMistralRealtimeWsUrl(config), + headers: { Authorization: `Bearer ${config.apiKey}` }, + connectTimeoutMs: MISTRAL_REALTIME_CONNECT_TIMEOUT_MS, + closeTimeoutMs: MISTRAL_REALTIME_CLOSE_TIMEOUT_MS, + maxReconnectAttempts: MISTRAL_REALTIME_MAX_RECONNECT_ATTEMPTS, + reconnectDelayMs: MISTRAL_REALTIME_RECONNECT_DELAY_MS, + maxQueuedBytes: MISTRAL_REALTIME_MAX_QUEUED_BYTES, + connectTimeoutMessage: "Mistral realtime transcription connection timeout", + reconnectLimitMessage: "Mistral realtime transcription reconnect limit reached", + sendAudio: (audio, transport) => { + transport.sendJson({ type: "input_audio.append", audio: audio.toString("base64"), }); - } - this.queuedAudio = []; - this.queuedBytes = 0; - } - - private sendJson(event: unknown): void { - if (this.ws?.readyState !== WebSocket.OPEN) { - return; - } - const payload = JSON.stringify(event); - captureWsEvent({ - url: toMistralRealtimeWsUrl(this.config), - direction: "outbound", - kind: "ws-frame", - flowId: this.flowId, - payload, - meta: { provider: "mistral", capability: "realtime-transcription" }, - }); - this.ws.send(payload); - } - - private forceClose(): void { - if (this.closeTimer) { - clearTimeout(this.closeTimer); - this.closeTimer = undefined; - } - this.connected = false; - this.ready = false; - if (this.ws) { - 
this.ws.close(1000, "Transcription session closed");
-      this.ws = null;
-    }
-  }
+    },
+    onClose: (transport) => {
+      transport.sendJson({ type: "input_audio.flush" });
+      transport.sendJson({ type: "input_audio.end" });
+    },
+    onMessage: handleEvent,
+  });
 }
 
 export function buildMistralRealtimeTranscriptionProvider(): RealtimeTranscriptionProviderPlugin {
@@ -473,7 +259,7 @@ export function buildMistralRealtimeTranscriptionProvider(): RealtimeTranscripti
     if (!apiKey) {
       throw new Error("Mistral API key missing");
     }
-    return new MistralRealtimeTranscriptionSession({
+    return createMistralRealtimeTranscriptionSession({
       ...req,
       apiKey,
       baseUrl: normalizeMistralRealtimeBaseUrl(config.baseUrl),
diff --git a/src/plugin-sdk/realtime-transcription.ts b/src/plugin-sdk/realtime-transcription.ts
index e0f68005b07..4526276f81e 100644
--- a/src/plugin-sdk/realtime-transcription.ts
+++ b/src/plugin-sdk/realtime-transcription.ts
@@ -14,3 +14,8 @@ export {
   listRealtimeTranscriptionProviders,
   normalizeRealtimeTranscriptionProviderId,
 } from "../realtime-transcription/provider-registry.js";
+export {
+  createRealtimeTranscriptionWebSocketSession,
+  type RealtimeTranscriptionWebSocketSessionOptions,
+  type RealtimeTranscriptionWebSocketTransport,
+} from "../realtime-transcription/websocket-session.js";
diff --git a/src/realtime-transcription/websocket-session.test.ts b/src/realtime-transcription/websocket-session.test.ts
new file mode 100644
index 00000000000..95160e38e2c
--- /dev/null
+++ b/src/realtime-transcription/websocket-session.test.ts
@@ -0,0 +1,151 @@
+import { createServer } from "node:http";
+import type { AddressInfo } from "node:net";
+import { afterEach, describe, expect, it, vi } from "vitest";
+import type WebSocket from "ws";
+import { WebSocketServer } from "ws";
+import { createRealtimeTranscriptionWebSocketSession } from "./websocket-session.js";
+
+let cleanup: (() => Promise<void>) | undefined;
+
+afterEach(async () => {
+  await cleanup?.();
+  cleanup = undefined;
+});
+
+async function createRealtimeServer(params?: {
+  initialEvent?: unknown;
+  onBinary?: (payload: Buffer) => void;
+  onText?: (payload: unknown) => void;
+}) {
+  const server = createServer();
+  const wss = new WebSocketServer({ noServer: true });
+  const clients = new Set<WebSocket>();
+
+  server.on("upgrade", (request, socket, head) => {
+    wss.handleUpgrade(request, socket, head, (ws) => {
+      clients.add(ws);
+      ws.on("close", () => clients.delete(ws));
+      if (params?.initialEvent) {
+        ws.send(JSON.stringify(params.initialEvent));
+      }
+      ws.on("message", (data, isBinary) => {
+        const buffer = Buffer.isBuffer(data)
+          ? data
+          : Array.isArray(data)
+            ? Buffer.concat(data)
+            : Buffer.from(data);
+        if (isBinary) {
+          params?.onBinary?.(buffer);
+          return;
+        }
+        params?.onText?.(JSON.parse(buffer.toString()));
+      });
+    });
+  });
+
+  await new Promise<void>((resolve) => server.listen(0, "127.0.0.1", resolve));
+  cleanup = async () => {
+    for (const ws of clients) {
+      ws.terminate();
+    }
+    await new Promise<void>((resolve) => wss.close(() => resolve()));
+    await new Promise<void>((resolve) => server.close(() => resolve()));
+  };
+  const port = (server.address() as AddressInfo).port;
+  return { url: `ws://127.0.0.1:${port}` };
+}
+
+async function waitFor(expectation: () => void) {
+  const started = Date.now();
+  let lastError: unknown;
+  while (Date.now() - started < 3000) {
+    try {
+      expectation();
+      return;
+    } catch (error) {
+      lastError = error;
+      await new Promise((resolve) => setTimeout(resolve, 25));
+    }
+  }
+  throw lastError;
+}
+
+describe("createRealtimeTranscriptionWebSocketSession", () => {
+  it("flushes queued binary audio after an open-ready connection", async () => {
+    const frames: Buffer[] = [];
+    const server = await createRealtimeServer({ onBinary: (payload) => frames.push(payload) });
+    const session = createRealtimeTranscriptionWebSocketSession({
+      providerId: "test",
+      callbacks: {},
+      url: server.url,
+      readyOnOpen: true,
+      sendAudio: (audio, transport) => {
+        
transport.sendBinary(audio); + }, + }); + + session.sendAudio(Buffer.from("queued")); + await session.connect(); + session.sendAudio(Buffer.from("after")); + await waitFor(() => expect(Buffer.concat(frames).toString()).toBe("queuedafter")); + expect(session.isConnected()).toBe(true); + session.close(); + }); + + it("lets providers mark ready after a JSON handshake", async () => { + const frames: unknown[] = []; + const server = await createRealtimeServer({ + initialEvent: { type: "session.created" }, + onText: (payload) => frames.push(payload), + }); + const session = createRealtimeTranscriptionWebSocketSession<{ type?: string }>({ + providerId: "test", + callbacks: {}, + url: server.url, + onMessage: (event, transport) => { + if (event.type === "session.created") { + transport.sendJson({ type: "session.update" }); + transport.markReady(); + } + }, + sendAudio: (audio, transport) => { + transport.sendJson({ type: "input_audio.append", audio: audio.toString("base64") }); + }, + }); + + session.sendAudio(Buffer.from("queued")); + await session.connect(); + await waitFor(() => + expect(frames).toEqual([ + { type: "session.update" }, + { type: "input_audio.append", audio: Buffer.from("queued").toString("base64") }, + ]), + ); + session.close(); + }); + + it("rejects provider setup errors before ready", async () => { + const server = await createRealtimeServer({ initialEvent: { type: "error", message: "nope" } }); + const onError = vi.fn(); + const session = createRealtimeTranscriptionWebSocketSession<{ + type?: string; + message?: string; + }>({ + providerId: "test", + callbacks: { onError }, + url: server.url, + onMessage: (event, transport) => { + if (!transport.isReady() && event.type === "error") { + transport.failConnect(new Error(event.message)); + } + }, + sendAudio: (audio, transport) => { + transport.sendBinary(audio); + }, + }); + + await expect(session.connect()).rejects.toThrow("nope"); + expect(session.isConnected()).toBe(false); + 
expect(onError).toHaveBeenCalledWith(expect.any(Error));
+  });
+});
diff --git a/src/realtime-transcription/websocket-session.ts b/src/realtime-transcription/websocket-session.ts
new file mode 100644
index 00000000000..f6efc8662b0
--- /dev/null
+++ b/src/realtime-transcription/websocket-session.ts
@@ -0,0 +1,402 @@
+import { randomUUID } from "node:crypto";
+import WebSocket, { type RawData } from "ws";
+import { createDebugProxyWebSocketAgent, resolveDebugProxySettings } from "../proxy-capture/env.js";
+import { captureWsEvent } from "../proxy-capture/runtime.js";
+import type {
+  RealtimeTranscriptionSession,
+  RealtimeTranscriptionSessionCallbacks,
+} from "./provider-types.js";
+
+export type RealtimeTranscriptionWebSocketTransport = {
+  readonly callbacks: RealtimeTranscriptionSessionCallbacks;
+  closeNow(): void;
+  failConnect(error: Error): void;
+  isOpen(): boolean;
+  isReady(): boolean;
+  markReady(): void;
+  sendBinary(payload: Buffer): boolean;
+  sendJson(payload: unknown): boolean;
+};
+
+export type RealtimeTranscriptionWebSocketSessionOptions<Event = unknown> = {
+  callbacks: RealtimeTranscriptionSessionCallbacks;
+  connectTimeoutMessage?: string;
+  connectTimeoutMs?: number;
+  closeTimeoutMs?: number;
+  headers?: Record<string, string>;
+  maxQueuedBytes?: number;
+  maxReconnectAttempts?: number;
+  onClose?: (transport: RealtimeTranscriptionWebSocketTransport) => void;
+  onMessage?: (event: Event, transport: RealtimeTranscriptionWebSocketTransport) => void;
+  onOpen?: (transport: RealtimeTranscriptionWebSocketTransport) => void;
+  parseMessage?: (payload: Buffer) => Event;
+  providerId: string;
+  readyOnOpen?: boolean;
+  reconnectDelayMs?: number;
+  reconnectLimitMessage?: string;
+  sendAudio: (audio: Buffer, transport: RealtimeTranscriptionWebSocketTransport) => void;
+  url: string | (() => string);
+};
+
+const DEFAULT_CONNECT_TIMEOUT_MS = 10_000;
+const DEFAULT_CLOSE_TIMEOUT_MS = 5_000;
+const DEFAULT_MAX_RECONNECT_ATTEMPTS = 5;
+const DEFAULT_RECONNECT_DELAY_MS = 1000;
+const 
DEFAULT_MAX_QUEUED_BYTES = 2 * 1024 * 1024;
+
+function rawWsDataToBuffer(data: RawData): Buffer {
+  if (Buffer.isBuffer(data)) {
+    return data;
+  }
+  if (Array.isArray(data)) {
+    return Buffer.concat(data);
+  }
+  return Buffer.from(data);
+}
+
+function defaultParseMessage(payload: Buffer): unknown {
+  return JSON.parse(payload.toString());
+}
+
+class WebSocketRealtimeTranscriptionSession<Event> implements RealtimeTranscriptionSession {
+  private closeTimer: ReturnType<typeof setTimeout> | undefined;
+  private closed = false;
+  private connected = false;
+  private currentUrl = "";
+  private queuedAudio: Buffer[] = [];
+  private queuedBytes = 0;
+  private ready = false;
+  private reconnectAttempts = 0;
+  private reconnecting = false;
+  private suppressReconnect = false;
+  private ws: WebSocket | null = null;
+  private readonly flowId = randomUUID();
+  private readonly options: RealtimeTranscriptionWebSocketSessionOptions<Event>;
+  private readonly transport: RealtimeTranscriptionWebSocketTransport;
+  private failConnect: ((error: Error) => void) | undefined;
+  private markReady: (() => void) | undefined;
+
+  constructor(options: RealtimeTranscriptionWebSocketSessionOptions<Event>) {
+    this.options = options;
+    this.transport = {
+      callbacks: options.callbacks,
+      closeNow: () => {
+        this.closed = true;
+        this.forceClose();
+      },
+      failConnect: (error) => this.failConnect?.(error),
+      isOpen: () => this.ws?.readyState === WebSocket.OPEN,
+      isReady: () => this.ready,
+      markReady: () => this.markReady?.(),
+      sendBinary: (payload) => this.sendBinary(payload),
+      sendJson: (payload) => this.sendJson(payload),
+    };
+  }
+
+  async connect(): Promise<void> {
+    this.closed = false;
+    this.suppressReconnect = false;
+    this.reconnectAttempts = 0;
+    await this.doConnect();
+  }
+
+  sendAudio(audio: Buffer): void {
+    if (this.closed || audio.byteLength === 0) {
+      return;
+    }
+    if (this.ws?.readyState === WebSocket.OPEN && this.ready) {
+      this.options.sendAudio(audio, this.transport);
+      return;
+    }
+    this.queueAudio(audio);
+  }
+
+  close(): void {
+    this.closed = true;
+    this.connected = false;
+    this.ready = false;
+    this.queuedAudio = [];
+    this.queuedBytes = 0;
+    if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
+      this.forceClose();
+      return;
+    }
+    try {
+      this.options.onClose?.(this.transport);
+    } catch (error) {
+      this.emitError(error);
+    }
+    this.closeTimer = setTimeout(() => this.forceClose(), this.closeTimeoutMs);
+  }
+
+  isConnected(): boolean {
+    return this.connected && this.ready;
+  }
+
+  private get closeTimeoutMs(): number {
+    return this.options.closeTimeoutMs ?? DEFAULT_CLOSE_TIMEOUT_MS;
+  }
+
+  private get connectTimeoutMs(): number {
+    return this.options.connectTimeoutMs ?? DEFAULT_CONNECT_TIMEOUT_MS;
+  }
+
+  private get maxQueuedBytes(): number {
+    return this.options.maxQueuedBytes ?? DEFAULT_MAX_QUEUED_BYTES;
+  }
+
+  private get maxReconnectAttempts(): number {
+    return this.options.maxReconnectAttempts ?? DEFAULT_MAX_RECONNECT_ATTEMPTS;
+  }
+
+  private get reconnectDelayMs(): number {
+    return this.options.reconnectDelayMs ?? DEFAULT_RECONNECT_DELAY_MS;
+  }
+
+  private async doConnect(): Promise<void> {
+    await new Promise<void>((resolve, reject) => {
+      this.ready = false;
+      this.currentUrl =
+        typeof this.options.url === "function" ? 
this.options.url() : this.options.url;
+      const debugProxy = resolveDebugProxySettings();
+      const proxyAgent = createDebugProxyWebSocketAgent(debugProxy);
+      let settled = false;
+      let opened = false;
+      let connectTimeout: ReturnType<typeof setTimeout> | undefined;
+
+      const finishConnect = () => {
+        if (settled) {
+          return;
+        }
+        settled = true;
+        if (connectTimeout) {
+          clearTimeout(connectTimeout);
+        }
+        this.ready = true;
+        this.flushQueuedAudio();
+        resolve();
+      };
+
+      const failConnect = (error: Error) => {
+        if (settled) {
+          return;
+        }
+        settled = true;
+        if (connectTimeout) {
+          clearTimeout(connectTimeout);
+        }
+        this.emitError(error);
+        this.suppressReconnect = true;
+        this.forceClose();
+        reject(error);
+      };
+
+      this.markReady = finishConnect;
+      this.failConnect = failConnect;
+      this.ws = new WebSocket(this.currentUrl, {
+        headers: this.options.headers,
+        ...(proxyAgent ? { agent: proxyAgent } : {}),
+      });
+
+      connectTimeout = setTimeout(() => {
+        failConnect(
+          new Error(
+            this.options.connectTimeoutMessage ??
+              `${this.options.providerId} realtime transcription connection timeout`,
+          ),
+        );
+      }, this.connectTimeoutMs);
+
+      this.ws.on("open", () => {
+        opened = true;
+        this.connected = true;
+        this.reconnectAttempts = 0;
+        this.captureLocalOpen();
+        try {
+          this.options.onOpen?.(this.transport);
+          if (this.options.readyOnOpen) {
+            finishConnect();
+          }
+        } catch (error) {
+          failConnect(error instanceof Error ? error : new Error(String(error)));
+        }
+      });
+
+      this.ws.on("message", (data) => {
+        const payload = rawWsDataToBuffer(data);
+        this.captureFrame("inbound", payload);
+        try {
+          if (!this.options.onMessage) {
+            return;
+          }
+          const parseMessage = this.options.parseMessage ?? defaultParseMessage;
+          this.options.onMessage(parseMessage(payload) as Event, this.transport);
+        } catch (error) {
+          this.emitError(error);
+        }
+      });
+
+      this.ws.on("error", (error) => {
+        const normalized = error instanceof Error ? 
error : new Error(String(error));
+        this.captureError(normalized);
+        if (!opened || !settled) {
+          failConnect(normalized);
+          return;
+        }
+        this.emitError(normalized);
+      });
+
+      this.ws.on("close", () => {
+        if (connectTimeout) {
+          clearTimeout(connectTimeout);
+        }
+        this.connected = false;
+        this.ready = false;
+        if (this.closeTimer) {
+          clearTimeout(this.closeTimer);
+          this.closeTimer = undefined;
+        }
+        if (this.closed) {
+          return;
+        }
+        if (this.suppressReconnect) {
+          this.suppressReconnect = false;
+          return;
+        }
+        if (!opened || !settled) {
+          failConnect(
+            new Error(
+              this.options.connectTimeoutMessage ??
+                `${this.options.providerId} realtime transcription connection closed before ready`,
+            ),
+          );
+          return;
+        }
+        void this.attemptReconnect();
+      });
+    });
+  }
+
+  private async attemptReconnect(): Promise<void> {
+    if (this.closed || this.reconnecting) {
+      return;
+    }
+    if (this.reconnectAttempts >= this.maxReconnectAttempts) {
+      this.emitError(
+        new Error(
+          this.options.reconnectLimitMessage ??
+            `${this.options.providerId} realtime transcription reconnect limit reached`,
+        ),
+      );
+      return;
+    }
+    this.reconnectAttempts += 1;
+    const delay = this.reconnectDelayMs * 2 ** (this.reconnectAttempts - 1);
+    this.reconnecting = true;
+    try {
+      await new Promise((resolve) => setTimeout(resolve, delay));
+      if (!this.closed) {
+        await this.doConnect();
+      }
+    } catch {
+      if (!this.closed) {
+        this.reconnecting = false;
+        await this.attemptReconnect();
+        return;
+      }
+    } finally {
+      this.reconnecting = false;
+    }
+  }
+
+  private queueAudio(audio: Buffer): void {
+    this.queuedAudio.push(Buffer.from(audio));
+    this.queuedBytes += audio.byteLength;
+    while (this.queuedBytes > this.maxQueuedBytes && this.queuedAudio.length > 0) {
+      const dropped = this.queuedAudio.shift();
+      this.queuedBytes -= dropped?.byteLength ?? 
0; + } + } + + private flushQueuedAudio(): void { + for (const audio of this.queuedAudio) { + this.options.sendAudio(audio, this.transport); + } + this.queuedAudio = []; + this.queuedBytes = 0; + } + + private sendBinary(payload: Buffer): boolean { + if (this.ws?.readyState !== WebSocket.OPEN) { + return false; + } + this.captureFrame("outbound", payload); + this.ws.send(payload); + return true; + } + + private sendJson(payload: unknown): boolean { + if (this.ws?.readyState !== WebSocket.OPEN) { + return false; + } + const serialized = JSON.stringify(payload); + this.captureFrame("outbound", serialized); + this.ws.send(serialized); + return true; + } + + private forceClose(): void { + if (this.closeTimer) { + clearTimeout(this.closeTimer); + this.closeTimer = undefined; + } + this.connected = false; + this.ready = false; + if (this.ws) { + this.ws.close(1000, "Transcription session closed"); + this.ws = null; + } + } + + private emitError(error: unknown): void { + this.options.callbacks.onError?.(error instanceof Error ? 
error : new Error(String(error)));
+  }
+
+  private captureFrame(direction: "inbound" | "outbound", payload: Buffer | string): void {
+    captureWsEvent({
+      url: this.currentUrl,
+      direction,
+      kind: "ws-frame",
+      flowId: this.flowId,
+      payload,
+      meta: { provider: this.options.providerId, capability: "realtime-transcription" },
+    });
+  }
+
+  private captureLocalOpen(): void {
+    captureWsEvent({
+      url: this.currentUrl,
+      direction: "local",
+      kind: "ws-open",
+      flowId: this.flowId,
+      meta: { provider: this.options.providerId, capability: "realtime-transcription" },
+    });
+  }
+
+  private captureError(error: Error): void {
+    captureWsEvent({
+      url: this.currentUrl,
+      direction: "local",
+      kind: "error",
+      flowId: this.flowId,
+      errorText: error.message,
+      meta: { provider: this.options.providerId, capability: "realtime-transcription" },
+    });
+  }
+}
+
+export function createRealtimeTranscriptionWebSocketSession<Event = unknown>(
+  options: RealtimeTranscriptionWebSocketSessionOptions<Event>,
+): RealtimeTranscriptionSession {
+  return new WebSocketRealtimeTranscriptionSession(options);
+}