From 7ced38b5ef2a41b0b2357d34be2de1234d78ecc8 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 1 Mar 2026 21:50:33 +0000 Subject: [PATCH] feat(agents): make openai responses websocket-first with fallback --- docs/concepts/model-providers.md | 2 + docs/providers/openai.md | 10 +- src/agents/openai-ws-connection.ts | 528 ++++++++++++++ src/agents/openai-ws-stream.ts | 680 ++++++++++++++++++ .../pi-embedded-runner-extraparams.test.ts | 22 +- src/agents/pi-embedded-runner/extra-params.ts | 12 + src/agents/pi-embedded-runner/run/attempt.ts | 12 + 7 files changed, 1260 insertions(+), 6 deletions(-) create mode 100644 src/agents/openai-ws-connection.ts create mode 100644 src/agents/openai-ws-stream.ts diff --git a/docs/concepts/model-providers.md b/docs/concepts/model-providers.md index fccd0b84249..a64b92cecb9 100644 --- a/docs/concepts/model-providers.md +++ b/docs/concepts/model-providers.md @@ -43,6 +43,8 @@ OpenClaw ships with the pi‑ai catalog. These providers require **no** - Optional rotation: `OPENAI_API_KEYS`, `OPENAI_API_KEY_1`, `OPENAI_API_KEY_2`, plus `OPENCLAW_LIVE_OPENAI_KEY` (single override) - Example model: `openai/gpt-5.1-codex` - CLI: `openclaw onboard --auth-choice openai-api-key` +- Default transport is `auto` (WebSocket-first, SSE fallback) +- Override per model via `agents.defaults.models["openai/"].params.transport` (`"sse"`, `"websocket"`, or `"auto"`) ```json5 { diff --git a/docs/providers/openai.md b/docs/providers/openai.md index 8b26072f1a6..35018c9937a 100644 --- a/docs/providers/openai.md +++ b/docs/providers/openai.md @@ -56,12 +56,14 @@ openclaw models auth login --provider openai-codex } ``` -### Codex transport default +### Transport default -OpenClaw uses `pi-ai` for model streaming. For `openai-codex/*` models you can set -`agents.defaults.models..params.transport` to select transport: +OpenClaw uses `pi-ai` for model streaming. 
For both `openai/*` and +`openai-codex/*`, default transport is `"auto"` (WebSocket-first, then SSE +fallback). + +You can set `agents.defaults.models..params.transport`: -- Default is `"auto"` (WebSocket-first, then SSE fallback). - `"sse"`: force SSE - `"websocket"`: force WebSocket - `"auto"`: try WebSocket, then fall back to SSE diff --git a/src/agents/openai-ws-connection.ts b/src/agents/openai-ws-connection.ts new file mode 100644 index 00000000000..b3214c3e291 --- /dev/null +++ b/src/agents/openai-ws-connection.ts @@ -0,0 +1,528 @@ +/** + * OpenAI WebSocket Connection Manager + * + * Manages a persistent WebSocket connection to the OpenAI Responses API + * (wss://api.openai.com/v1/responses) for multi-turn tool-call workflows. + * + * Features: + * - Auto-reconnect with exponential backoff (max 5 retries: 1s/2s/4s/8s/16s) + * - Tracks previous_response_id per connection for incremental turns + * - Warm-up support (generate: false) to pre-load the connection + * - Typed WebSocket event definitions matching the Responses API SSE spec + * + * @see https://developers.openai.com/api/docs/guides/websocket-mode + */ + +import { EventEmitter } from "node:events"; +import WebSocket from "ws"; + +// ───────────────────────────────────────────────────────────────────────────── +// WebSocket Event Types (Server → Client) +// ───────────────────────────────────────────────────────────────────────────── + +export interface ResponseObject { + id: string; + object: "response"; + created_at: number; + status: "in_progress" | "completed" | "failed" | "cancelled" | "incomplete"; + model: string; + output: OutputItem[]; + usage?: UsageInfo; + error?: { code: string; message: string }; +} + +export interface UsageInfo { + input_tokens: number; + output_tokens: number; + total_tokens: number; +} + +export type OutputItem = + | { + type: "message"; + id: string; + role: "assistant"; + content: Array<{ type: "output_text"; text: string }>; + status?: "in_progress" | "completed"; + 
} + | { + type: "function_call"; + id: string; + call_id: string; + name: string; + arguments: string; + status?: "in_progress" | "completed"; + } + | { + type: "reasoning"; + id: string; + content?: string; + summary?: string; + }; + +export interface ResponseCreatedEvent { + type: "response.created"; + response: ResponseObject; +} + +export interface ResponseInProgressEvent { + type: "response.in_progress"; + response: ResponseObject; +} + +export interface ResponseCompletedEvent { + type: "response.completed"; + response: ResponseObject; +} + +export interface ResponseFailedEvent { + type: "response.failed"; + response: ResponseObject; +} + +export interface OutputItemAddedEvent { + type: "response.output_item.added"; + output_index: number; + item: OutputItem; +} + +export interface OutputItemDoneEvent { + type: "response.output_item.done"; + output_index: number; + item: OutputItem; +} + +export interface ContentPartAddedEvent { + type: "response.content_part.added"; + item_id: string; + output_index: number; + content_index: number; + part: { type: "output_text"; text: string }; +} + +export interface ContentPartDoneEvent { + type: "response.content_part.done"; + item_id: string; + output_index: number; + content_index: number; + part: { type: "output_text"; text: string }; +} + +export interface OutputTextDeltaEvent { + type: "response.output_text.delta"; + item_id: string; + output_index: number; + content_index: number; + delta: string; +} + +export interface OutputTextDoneEvent { + type: "response.output_text.done"; + item_id: string; + output_index: number; + content_index: number; + text: string; +} + +export interface FunctionCallArgumentsDeltaEvent { + type: "response.function_call_arguments.delta"; + item_id: string; + output_index: number; + call_id: string; + delta: string; +} + +export interface FunctionCallArgumentsDoneEvent { + type: "response.function_call_arguments.done"; + item_id: string; + output_index: number; + call_id: string; + 
arguments: string; +} + +export interface RateLimitUpdatedEvent { + type: "rate_limits.updated"; + rate_limits: Array<{ + name: string; + limit: number; + remaining: number; + reset_seconds: number; + }>; +} + +export interface ErrorEvent { + type: "error"; + code: string; + message: string; + param?: string; +} + +export type OpenAIWebSocketEvent = + | ResponseCreatedEvent + | ResponseInProgressEvent + | ResponseCompletedEvent + | ResponseFailedEvent + | OutputItemAddedEvent + | OutputItemDoneEvent + | ContentPartAddedEvent + | ContentPartDoneEvent + | OutputTextDeltaEvent + | OutputTextDoneEvent + | FunctionCallArgumentsDeltaEvent + | FunctionCallArgumentsDoneEvent + | RateLimitUpdatedEvent + | ErrorEvent; + +// ───────────────────────────────────────────────────────────────────────────── +// Client → Server Event Types +// ───────────────────────────────────────────────────────────────────────────── + +export type ContentPart = + | { type: "input_text"; text: string } + | { type: "output_text"; text: string } + | { + type: "input_image"; + source: { type: "url"; url: string } | { type: "base64"; media_type: string; data: string }; + }; + +export type InputItem = + | { + type: "message"; + role: "system" | "developer" | "user" | "assistant"; + content: string | ContentPart[]; + } + | { type: "function_call"; id?: string; call_id?: string; name: string; arguments: string } + | { type: "function_call_output"; call_id: string; output: string } + | { type: "reasoning"; content?: string; encrypted_content?: string; summary?: string } + | { type: "item_reference"; id: string }; + +export type ToolChoice = + | "auto" + | "none" + | "required" + | { type: "function"; function: { name: string } }; + +export interface FunctionToolDefinition { + type: "function"; + function: { + name: string; + description?: string; + parameters?: Record; + }; +} + +/** Standard response.create event payload (full turn) */ +export interface ResponseCreateEvent { + type: "response.create"; + 
model: string; + store?: boolean; + stream?: boolean; + input?: string | InputItem[]; + instructions?: string; + tools?: FunctionToolDefinition[]; + tool_choice?: ToolChoice; + context_management?: unknown; + previous_response_id?: string; + max_output_tokens?: number; + temperature?: number; + top_p?: number; + metadata?: Record; + reasoning?: { effort?: "low" | "medium" | "high"; summary?: "auto" | "concise" | "detailed" }; + truncation?: "auto" | "disabled"; + [key: string]: unknown; +} + +/** Warm-up payload: generate: false pre-loads connection without generating output */ +export interface WarmUpEvent extends ResponseCreateEvent { + generate: false; +} + +export type ClientEvent = ResponseCreateEvent | WarmUpEvent; + +// ───────────────────────────────────────────────────────────────────────────── +// Connection Manager +// ───────────────────────────────────────────────────────────────────────────── + +const OPENAI_WS_URL = "wss://api.openai.com/v1/responses"; +const MAX_RETRIES = 5; +/** Backoff delays in ms: 1s, 2s, 4s, 8s, 16s */ +const BACKOFF_DELAYS_MS = [1000, 2000, 4000, 8000, 16000] as const; + +export interface OpenAIWebSocketManagerOptions { + /** Override the default WebSocket URL (useful for testing) */ + url?: string; + /** Maximum number of reconnect attempts (default: 5) */ + maxRetries?: number; + /** Custom backoff delays in ms (default: [1000, 2000, 4000, 8000, 16000]) */ + backoffDelaysMs?: readonly number[]; +} + +type InternalEvents = { + message: [event: OpenAIWebSocketEvent]; + open: []; + close: [code: number, reason: string]; + error: [err: Error]; +}; + +/** + * Manages a persistent WebSocket connection to the OpenAI Responses API. 
+ * + * Usage: + * ```ts + * const manager = new OpenAIWebSocketManager(); + * await manager.connect(apiKey); + * + * manager.onMessage((event) => { + * if (event.type === "response.completed") { + * console.log("Response ID:", event.response.id); + * } + * }); + * + * manager.send({ type: "response.create", model: "gpt-5.2", input: [...] }); + * ``` + */ +export class OpenAIWebSocketManager extends EventEmitter { + private ws: WebSocket | null = null; + private apiKey: string | null = null; + private retryCount = 0; + private retryTimer: NodeJS.Timeout | null = null; + private closed = false; + + /** The ID of the most recent completed response on this connection. */ + private _previousResponseId: string | null = null; + + private readonly wsUrl: string; + private readonly maxRetries: number; + private readonly backoffDelaysMs: readonly number[]; + + constructor(options: OpenAIWebSocketManagerOptions = {}) { + super(); + this.wsUrl = options.url ?? OPENAI_WS_URL; + this.maxRetries = options.maxRetries ?? MAX_RETRIES; + this.backoffDelaysMs = options.backoffDelaysMs ?? BACKOFF_DELAYS_MS; + } + + // ─── Public API ──────────────────────────────────────────────────────────── + + /** + * Returns the previous_response_id from the last completed response, + * for use in subsequent response.create events. + */ + get previousResponseId(): string | null { + return this._previousResponseId; + } + + /** + * Opens a WebSocket connection to the OpenAI Responses API. + * Resolves when the connection is established (open event fires). + * Rejects if the initial connection fails after max retries. + */ + connect(apiKey: string): Promise { + this.apiKey = apiKey; + this.closed = false; + this.retryCount = 0; + return this._openConnection(); + } + + /** + * Sends a typed event to the OpenAI Responses API over the WebSocket. + * Throws if the connection is not open. 
+ */ + send(event: ClientEvent): void { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new Error( + `OpenAIWebSocketManager: cannot send — connection is not open (readyState=${this.ws?.readyState ?? "no socket"})`, + ); + } + this.ws.send(JSON.stringify(event)); + } + + /** + * Registers a handler for incoming server-sent WebSocket events. + * Returns an unsubscribe function. + */ + onMessage(handler: (event: OpenAIWebSocketEvent) => void): () => void { + this.on("message", handler); + return () => { + this.off("message", handler); + }; + } + + /** + * Returns true if the WebSocket is currently open and ready to send. + */ + isConnected(): boolean { + return this.ws !== null && this.ws.readyState === WebSocket.OPEN; + } + + /** + * Permanently closes the WebSocket connection and disables auto-reconnect. + */ + close(): void { + this.closed = true; + this._cancelRetryTimer(); + if (this.ws) { + this.ws.removeAllListeners(); + if (this.ws.readyState === WebSocket.OPEN || this.ws.readyState === WebSocket.CONNECTING) { + this.ws.close(1000, "Client closed"); + } + this.ws = null; + } + } + + // ─── Internal: Connection Lifecycle ──────────────────────────────────────── + + private _openConnection(): Promise { + return new Promise((resolve, reject) => { + if (!this.apiKey) { + reject(new Error("OpenAIWebSocketManager: apiKey is required before connecting.")); + return; + } + + const socket = new WebSocket(this.wsUrl, { + headers: { + Authorization: `Bearer ${this.apiKey}`, + "OpenAI-Beta": "responses-websocket=v1", + }, + }); + + this.ws = socket; + + const onOpen = () => { + this.retryCount = 0; + resolve(); + this.emit("open"); + }; + + const onError = (err: Error) => { + // Remove open listener so we don't resolve after an error. + socket.off("open", onOpen); + // Emit "error" on the manager only when there are listeners; otherwise + // the promise rejection below is the primary error channel for this + // initial connection failure. 
(An uncaught "error" event in Node.js + // throws synchronously and would prevent the promise from rejecting.) + if (this.listenerCount("error") > 0) { + this.emit("error", err); + } + reject(err); + }; + + const onClose = (code: number, reason: Buffer) => { + const reasonStr = reason.toString(); + this.emit("close", code, reasonStr); + + if (!this.closed) { + this._scheduleReconnect(); + } + }; + + const onMessage = (data: WebSocket.RawData) => { + this._handleMessage(data); + }; + + socket.once("open", onOpen); + socket.on("error", onError); + socket.on("close", onClose); + socket.on("message", onMessage); + }); + } + + private _scheduleReconnect(): void { + if (this.closed) { + return; + } + if (this.retryCount >= this.maxRetries) { + this._safeEmitError( + new Error(`OpenAIWebSocketManager: max reconnect retries (${this.maxRetries}) exceeded.`), + ); + return; + } + + const delayMs = + this.backoffDelaysMs[Math.min(this.retryCount, this.backoffDelaysMs.length - 1)] ?? 1000; + this.retryCount++; + + this.retryTimer = setTimeout(() => { + if (this.closed) { + return; + } + this._openConnection().catch((err: unknown) => { + // onError handler already emitted error event; schedule next retry. + void err; + this._scheduleReconnect(); + }); + }, delayMs); + } + + /** Emit an error only if there are listeners; prevents Node.js from crashing + * with "unhandled 'error' event" when no one is listening. 
*/ + private _safeEmitError(err: Error): void { + if (this.listenerCount("error") > 0) { + this.emit("error", err); + } + } + + private _cancelRetryTimer(): void { + if (this.retryTimer !== null) { + clearTimeout(this.retryTimer); + this.retryTimer = null; + } + } + + private _handleMessage(data: WebSocket.RawData): void { + let text: string; + if (typeof data === "string") { + text = data; + } else if (Buffer.isBuffer(data)) { + text = data.toString("utf8"); + } else if (data instanceof ArrayBuffer) { + text = Buffer.from(data).toString("utf8"); + } else { + // Blob or other — coerce to string + text = String(data); + } + + let parsed: unknown; + try { + parsed = JSON.parse(text); + } catch { + this._safeEmitError( + new Error(`OpenAIWebSocketManager: failed to parse message: ${text.slice(0, 200)}`), + ); + return; + } + + if (!parsed || typeof parsed !== "object" || !("type" in parsed)) { + this._safeEmitError( + new Error( + `OpenAIWebSocketManager: unexpected message shape (no "type" field): ${text.slice(0, 200)}`, + ), + ); + return; + } + + const event = parsed as OpenAIWebSocketEvent; + + // Track previous_response_id on completion + if (event.type === "response.completed" && event.response?.id) { + this._previousResponseId = event.response.id; + } + + this.emit("message", event); + } + + /** + * Sends a warm-up event to pre-load the connection and model without generating output. + * Pass tools/instructions to prime the connection for the upcoming session. + */ + warmUp(params: { model: string; tools?: FunctionToolDefinition[]; instructions?: string }): void { + const event: WarmUpEvent = { + type: "response.create", + generate: false, + model: params.model, + ...(params.tools ? { tools: params.tools } : {}), + ...(params.instructions ? 
{ instructions: params.instructions } : {}), + }; + this.send(event); + } +} diff --git a/src/agents/openai-ws-stream.ts b/src/agents/openai-ws-stream.ts new file mode 100644 index 00000000000..865ad775840 --- /dev/null +++ b/src/agents/openai-ws-stream.ts @@ -0,0 +1,680 @@ +/** + * OpenAI WebSocket StreamFn Integration + * + * Wraps `OpenAIWebSocketManager` in a `StreamFn` that can be plugged into the + * pi-embedded-runner agent in place of the default `streamSimple` HTTP function. + * + * Key behaviours: + * - Per-session `OpenAIWebSocketManager` (keyed by sessionId) + * - Tracks `previous_response_id` to send only incremental tool-result inputs + * - Falls back to `streamSimple` (HTTP) if the WebSocket connection fails + * - Cleanup helpers for releasing sessions after the run completes + * + * Complexity budget & risk mitigation: + * - **Transport aware**: respects `transport` (`auto` | `websocket` | `sse`) + * - **Transparent fallback in `auto` mode**: connect/send failures fall back to + * the existing HTTP `streamSimple`; forced `websocket` mode surfaces WS errors + * - **Zero shared state**: per-session registry; session cleanup on dispose prevents leaks + * - **Full parity**: all generation options (temperature, top_p, max_output_tokens, + * tool_choice, reasoning) forwarded identically to the HTTP path + * + * @see src/agents/openai-ws-connection.ts for the connection manager + */ + +import { randomUUID } from "node:crypto"; +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import type { + AssistantMessage, + Context, + Message, + StopReason, + TextContent, + ToolCall, + Usage, +} from "@mariozechner/pi-ai"; +import { createAssistantMessageEventStream, streamSimple } from "@mariozechner/pi-ai"; +import { + OpenAIWebSocketManager, + type ContentPart, + type FunctionToolDefinition, + type InputItem, + type OpenAIWebSocketManagerOptions, + type ResponseObject, +} from "./openai-ws-connection.js"; +import { log } from 
"./pi-embedded-runner/logger.js"; + +// ───────────────────────────────────────────────────────────────────────────── +// Per-session state +// ───────────────────────────────────────────────────────────────────────────── + +interface WsSession { + manager: OpenAIWebSocketManager; + /** Number of messages that were in context.messages at the END of the last streamFn call. */ + lastContextLength: number; + /** True if the connection has been established at least once. */ + everConnected: boolean; + /** True if the session is permanently broken (no more reconnect). */ + broken: boolean; +} + +/** Module-level registry: sessionId → WsSession */ +const wsRegistry = new Map(); + +// ───────────────────────────────────────────────────────────────────────────── +// Public registry helpers +// ───────────────────────────────────────────────────────────────────────────── + +/** + * Release and close the WebSocket session for the given sessionId. + * Call this after the agent run completes to free the connection. + */ +export function releaseWsSession(sessionId: string): void { + const session = wsRegistry.get(sessionId); + if (session) { + try { + session.manager.close(); + } catch { + // Ignore close errors — connection may already be gone. + } + wsRegistry.delete(sessionId); + } +} + +/** + * Returns true if a live WebSocket session exists for the given sessionId. + */ +export function hasWsSession(sessionId: string): boolean { + const s = wsRegistry.get(sessionId); + return !!(s && !s.broken && s.manager.isConnected()); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Message format converters +// ───────────────────────────────────────────────────────────────────────────── + +type AnyMessage = Message & { role: string; content: unknown }; + +/** Convert pi-ai content (string | ContentPart[]) to plain text. 
*/ +function contentToText(content: unknown): string { + if (typeof content === "string") { + return content; + } + if (!Array.isArray(content)) { + return ""; + } + return (content as Array<{ type?: string; text?: string }>) + .filter((p) => p.type === "text" && typeof p.text === "string") + .map((p) => p.text as string) + .join(""); +} + +/** Convert pi-ai content to OpenAI ContentPart[]. */ +function contentToOpenAIParts(content: unknown): ContentPart[] { + if (typeof content === "string") { + return content ? [{ type: "input_text", text: content }] : []; + } + if (!Array.isArray(content)) { + return []; + } + const parts: ContentPart[] = []; + for (const part of content as Array<{ + type?: string; + text?: string; + data?: string; + mimeType?: string; + }>) { + if (part.type === "text" && typeof part.text === "string") { + parts.push({ type: "input_text", text: part.text }); + } else if (part.type === "image" && typeof part.data === "string") { + parts.push({ + type: "input_image", + source: { + type: "base64", + media_type: part.mimeType ?? "image/jpeg", + data: part.data, + }, + }); + } + } + return parts; +} + +/** Convert pi-ai tool array to OpenAI FunctionToolDefinition[]. */ +export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] { + if (!tools || tools.length === 0) { + return []; + } + return tools.map((tool) => ({ + type: "function" as const, + function: { + name: tool.name, + description: typeof tool.description === "string" ? tool.description : undefined, + parameters: (tool.parameters ?? {}) as Record, + }, + })); +} + +/** + * Convert the full pi-ai message history to an OpenAI `input` array. + * Handles user messages, assistant text+tool-call messages, and tool results. 
+ */ +export function convertMessagesToInputItems(messages: Message[]): InputItem[] { + const items: InputItem[] = []; + + for (const msg of messages) { + const m = msg as AnyMessage; + + if (m.role === "user") { + const parts = contentToOpenAIParts(m.content); + items.push({ + type: "message", + role: "user", + content: + parts.length === 1 && parts[0]?.type === "input_text" + ? (parts[0] as { type: "input_text"; text: string }).text + : parts, + }); + continue; + } + + if (m.role === "assistant") { + const content = m.content; + if (Array.isArray(content)) { + // Collect text blocks and tool calls separately + const textParts: string[] = []; + for (const block of content as Array<{ + type?: string; + text?: string; + id?: string; + name?: string; + arguments?: Record; + thinking?: string; + }>) { + if (block.type === "text" && typeof block.text === "string") { + textParts.push(block.text); + } else if (block.type === "thinking" && typeof block.thinking === "string") { + // Skip thinking blocks — not sent back to the model + } else if (block.type === "toolCall") { + // Push accumulated text first + if (textParts.length > 0) { + items.push({ + type: "message", + role: "assistant", + content: textParts.join(""), + }); + textParts.length = 0; + } + // Push function_call item + items.push({ + type: "function_call", + call_id: typeof block.id === "string" ? block.id : `call_${randomUUID()}`, + name: block.name ?? "", + arguments: + typeof block.arguments === "string" + ? block.arguments + : JSON.stringify(block.arguments ?? 
{}), + }); + } + } + if (textParts.length > 0) { + items.push({ + type: "message", + role: "assistant", + content: textParts.join(""), + }); + } + } else { + const text = contentToText(m.content); + if (text) { + items.push({ + type: "message", + role: "assistant", + content: text, + }); + } + } + continue; + } + + if (m.role === "toolResult") { + const tr = m as unknown as { + toolCallId: string; + content: unknown; + isError: boolean; + }; + const outputText = contentToText(tr.content); + items.push({ + type: "function_call_output", + call_id: tr.toolCallId, + output: outputText, + }); + continue; + } + } + + return items; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Response object → AssistantMessage +// ───────────────────────────────────────────────────────────────────────────── + +export function buildAssistantMessageFromResponse( + response: ResponseObject, + modelInfo: { api: string; provider: string; id: string }, +): AssistantMessage { + const content: (TextContent | ToolCall)[] = []; + + for (const item of response.output ?? []) { + if (item.type === "message") { + for (const part of item.content ?? []) { + if (part.type === "output_text" && part.text) { + content.push({ type: "text", text: part.text }); + } + } + } else if (item.type === "function_call") { + content.push({ + type: "toolCall", + id: item.call_id, + name: item.name, + arguments: (() => { + try { + return JSON.parse(item.arguments) as Record; + } catch { + return {} as Record; + } + })(), + }); + } + // "reasoning" items are informational only; skip. + } + + const hasToolCalls = content.some((c) => c.type === "toolCall"); + const stopReason: StopReason = hasToolCalls ? "toolUse" : "stop"; + + const usage: Usage = { + input: response.usage?.input_tokens ?? 0, + output: response.usage?.output_tokens ?? 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: response.usage?.total_tokens ?? 
0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }; + + return { + role: "assistant", + content, + stopReason, + api: modelInfo.api, + provider: modelInfo.provider, + model: modelInfo.id, + usage, + timestamp: Date.now(), + }; +} + +// ───────────────────────────────────────────────────────────────────────────── +// StreamFn factory +// ───────────────────────────────────────────────────────────────────────────── + +export interface OpenAIWebSocketStreamOptions { + /** Manager options (url override, retry counts, etc.) */ + managerOptions?: OpenAIWebSocketManagerOptions; + /** Abort signal forwarded from the run. */ + signal?: AbortSignal; +} + +type WsTransport = "sse" | "websocket" | "auto"; + +function resolveWsTransport(options: Parameters[2]): WsTransport { + const transport = (options as { transport?: unknown } | undefined)?.transport; + return transport === "sse" || transport === "websocket" || transport === "auto" + ? transport + : "auto"; +} + +/** + * Creates a `StreamFn` backed by a persistent WebSocket connection to the + * OpenAI Responses API. The first call for a given `sessionId` opens the + * connection; subsequent calls reuse it, sending only incremental tool-result + * inputs with `previous_response_id`. + * + * If the WebSocket connection is unavailable, the function falls back to the + * standard `streamSimple` HTTP path and logs a warning. 
+ * + * @param apiKey OpenAI API key + * @param sessionId Agent session ID (used as the registry key) + * @param opts Optional manager + abort signal overrides + */ +export function createOpenAIWebSocketStreamFn( + apiKey: string, + sessionId: string, + opts: OpenAIWebSocketStreamOptions = {}, +): StreamFn { + return (model, context, options) => { + const eventStream = createAssistantMessageEventStream(); + + const run = async () => { + const transport = resolveWsTransport(options); + if (transport === "sse") { + return fallbackToHttp(model, context, options, eventStream, opts.signal); + } + + // ── 1. Get or create session state ────────────────────────────────── + let session = wsRegistry.get(sessionId); + + if (!session) { + const manager = new OpenAIWebSocketManager(opts.managerOptions); + session = { + manager, + lastContextLength: 0, + everConnected: false, + broken: false, + }; + wsRegistry.set(sessionId, session); + } + + // ── 2. Ensure connection is open ───────────────────────────────────── + if (!session.manager.isConnected() && !session.broken) { + try { + await session.manager.connect(apiKey); + session.everConnected = true; + log.debug(`[ws-stream] connected for session=${sessionId}`); + } catch (connErr) { + // Cancel any background reconnect attempts before marking as broken. + try { + session.manager.close(); + } catch { + /* ignore */ + } + session.broken = true; + wsRegistry.delete(sessionId); + if (transport === "websocket") { + throw connErr instanceof Error ? connErr : new Error(String(connErr)); + } + log.warn( + `[ws-stream] WebSocket connect failed for session=${sessionId}; falling back to HTTP. 
error=${String(connErr)}`, + ); + // Fall back to HTTP immediately + return fallbackToHttp(model, context, options, eventStream, opts.signal); + } + } + + if (session.broken || !session.manager.isConnected()) { + if (transport === "websocket") { + throw new Error("WebSocket session disconnected"); + } + log.warn(`[ws-stream] session=${sessionId} broken/disconnected; falling back to HTTP`); + // Clean up stale session to prevent next turn from using stale + // previousResponseId / lastContextLength after a mid-request drop. + try { + session.manager.close(); + } catch { + /* ignore */ + } + wsRegistry.delete(sessionId); + return fallbackToHttp(model, context, options, eventStream, opts.signal); + } + + // ── 3. Compute incremental vs full input ───────────────────────────── + const prevResponseId = session.manager.previousResponseId; + let inputItems: InputItem[]; + + if (prevResponseId && session.lastContextLength > 0) { + // Subsequent turn: only send new messages (tool results) since last call + const newMessages = context.messages.slice(session.lastContextLength); + // Filter to only tool results — the assistant message is already in server context + const toolResults = newMessages.filter((m) => (m as AnyMessage).role === "toolResult"); + if (toolResults.length === 0) { + // Shouldn't happen in a well-formed turn, but fall back to full context + log.debug( + `[ws-stream] session=${sessionId}: no new tool results found; sending full context`, + ); + inputItems = buildFullInput(context); + } else { + inputItems = convertMessagesToInputItems(toolResults); + } + log.debug( + `[ws-stream] session=${sessionId}: incremental send (${inputItems.length} tool results) previous_response_id=${prevResponseId}`, + ); + } else { + // First turn: send full context + inputItems = buildFullInput(context); + log.debug( + `[ws-stream] session=${sessionId}: full context send (${inputItems.length} items)`, + ); + } + + // ── 4. 
Build & send response.create ────────────────────────────────── + const tools = convertTools(context.tools); + + // Forward generation options that the HTTP path (openai-responses provider) also uses. + // Cast to record since SimpleStreamOptions carries openai-specific fields as unknown. + const streamOpts = options as + | (Record & { + temperature?: number; + maxTokens?: number; + topP?: number; + toolChoice?: unknown; + }) + | undefined; + const extraParams: Record = {}; + if (streamOpts?.temperature !== undefined) { + extraParams.temperature = streamOpts.temperature; + } + if (streamOpts?.maxTokens) { + extraParams.max_output_tokens = streamOpts.maxTokens; + } + if (streamOpts?.topP !== undefined) { + extraParams.top_p = streamOpts.topP; + } + if (streamOpts?.toolChoice !== undefined) { + extraParams.tool_choice = streamOpts.toolChoice; + } + if (streamOpts?.reasoningEffort || streamOpts?.reasoningSummary) { + const reasoning: { effort?: string; summary?: string } = {}; + if (streamOpts.reasoningEffort !== undefined) { + reasoning.effort = streamOpts.reasoningEffort as string; + } + if (streamOpts.reasoningSummary !== undefined) { + reasoning.summary = streamOpts.reasoningSummary as string; + } + extraParams.reasoning = reasoning; + } + + const payload: Record = { + type: "response.create", + model: model.id, + store: false, + input: inputItems, + instructions: context.systemPrompt ?? undefined, + tools: tools.length > 0 ? tools : undefined, + ...(prevResponseId ? { previous_response_id: prevResponseId } : {}), + ...extraParams, + }; + options?.onPayload?.(payload); + + try { + session.manager.send(payload as Parameters[0]); + } catch (sendErr) { + if (transport === "websocket") { + throw sendErr instanceof Error ? sendErr : new Error(String(sendErr)); + } + log.warn( + `[ws-stream] send failed for session=${sessionId}; falling back to HTTP. 
error=${String(sendErr)}`,
+      );
+      // Fully reset session state so the next WS turn doesn't use stale
+      // previous_response_id or lastContextLength from before the failure.
+      try {
+        session.manager.close();
+      } catch {
+        /* ignore */
+      }
+      wsRegistry.delete(sessionId);
+      return fallbackToHttp(model, context, options, eventStream, opts.signal);
+    }
+
+    eventStream.push({
+      type: "start",
+      partial: {
+        role: "assistant",
+        content: [],
+        stopReason: "stop",
+        api: model.api,
+        provider: model.provider,
+        model: model.id,
+        usage: {
+          input: 0,
+          output: 0,
+          cacheRead: 0,
+          cacheWrite: 0,
+          totalTokens: 0,
+          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
+        },
+        timestamp: Date.now(),
+      },
+    });
+
+    // ── 5. Wait for response.completed ───────────────────────────────────
+    const capturedContextLength = context.messages.length;
+
+    await new Promise<void>((resolve, reject) => {
+      // Honour abort signal
+      const abortHandler = () => {
+        cleanup();
+        reject(new Error("aborted"));
+      };
+      const signal = opts.signal ?? (options as { signal?: AbortSignal } | undefined)?.signal;
+      if (signal?.aborted) {
+        reject(new Error("aborted"));
+        return;
+      }
+      signal?.addEventListener("abort", abortHandler, { once: true });
+
+      // If the WebSocket drops mid-request, reject so we don't hang forever.
+      const closeHandler = (code: number, reason: string) => {
+        cleanup();
+        reject(
+          new Error(`WebSocket closed mid-request (code=${code}, reason=${reason || "unknown"})`),
+        );
+      };
+      session.manager.on("close", closeHandler);
+
+      const cleanup = () => {
+        signal?.removeEventListener("abort", abortHandler);
+        session.manager.off("close", closeHandler);
+        unsubscribe();
+      };
+
+      const unsubscribe = session.manager.onMessage((event) => {
+        if (event.type === "response.completed") {
+          cleanup();
+          // Update session state
+          session.lastContextLength = capturedContextLength;
+          // Build and emit the assistant message
+          const assistantMsg = buildAssistantMessageFromResponse(event.response, {
+            api: model.api,
+            provider: model.provider,
+            id: model.id,
+          });
+          const reason: Extract<StopReason, "stop" | "toolUse"> =
+            assistantMsg.stopReason === "toolUse" ? "toolUse" : "stop";
+          eventStream.push({ type: "done", reason, message: assistantMsg });
+          resolve();
+        } else if (event.type === "response.failed") {
+          cleanup();
+          const errMsg = event.response?.error?.message ?? 
"Response failed"; + reject(new Error(`OpenAI WebSocket response failed: ${errMsg}`)); + } else if (event.type === "error") { + cleanup(); + reject(new Error(`OpenAI WebSocket error: ${event.message} (code=${event.code})`)); + } else if (event.type === "response.output_text.delta") { + // Stream partial text updates for responsive UI + const partialMsg: AssistantMessage = { + role: "assistant", + content: [{ type: "text", text: event.delta }], + stopReason: "stop", + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + timestamp: Date.now(), + }; + eventStream.push({ + type: "text_delta", + contentIndex: 0, + delta: event.delta, + partial: partialMsg, + }); + } + }); + }); + }; + + queueMicrotask(() => + run().catch((err) => { + const errorMessage = err instanceof Error ? err.message : String(err); + log.warn(`[ws-stream] session=${sessionId} run error: ${errorMessage}`); + eventStream.push({ + type: "error", + reason: "error", + error: { + role: "assistant" as const, + content: [], + stopReason: "error" as StopReason, + errorMessage, + api: model.api, + provider: model.provider, + model: model.id, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + timestamp: Date.now(), + }, + }); + eventStream.end(); + }), + ); + + return eventStream; + }; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +/** Build full input items from context (system prompt is passed via `instructions` field). 
*/
+function buildFullInput(context: Context): InputItem[] {
+  return convertMessagesToInputItems(context.messages);
+}
+
+/**
+ * Fall back to HTTP (`streamSimple`) and pipe events into the existing stream.
+ * This is called when the WebSocket is broken or unavailable.
+ */
+async function fallbackToHttp(
+  model: Parameters<typeof streamSimple>[0],
+  context: Parameters<typeof streamSimple>[1],
+  options: Parameters<typeof streamSimple>[2],
+  eventStream: ReturnType,
+  signal?: AbortSignal,
+): Promise<void> {
+  const mergedOptions = signal ? { ...options, signal } : options;
+  const httpStream = streamSimple(model, context, mergedOptions);
+  for await (const event of httpStream) {
+    eventStream.push(event);
+  }
+}
diff --git a/src/agents/pi-embedded-runner-extraparams.test.ts b/src/agents/pi-embedded-runner-extraparams.test.ts
index 766604542e2..3f9e19909a3 100644
--- a/src/agents/pi-embedded-runner-extraparams.test.ts
+++ b/src/agents/pi-embedded-runner-extraparams.test.ts
@@ -544,7 +544,7 @@ describe("applyExtraParamsToAgent", () => {
     expect(calls[0]?.transport).toBe("auto");
   });
 
-  it("does not set transport defaults for non-Codex providers", () => {
+  it("defaults OpenAI transport to auto (WebSocket-first)", () => {
     const { calls, agent } = createOptionsCaptureAgent();
 
     applyExtraParamsToAgent(agent, undefined, "openai", "gpt-5");
@@ -558,7 +558,24 @@ describe("applyExtraParamsToAgent", () => {
     void agent.streamFn?.(model, context, {});
 
     expect(calls).toHaveLength(1);
-    expect(calls[0]?.transport).toBeUndefined();
+    expect(calls[0]?.transport).toBe("auto");
+  });
+
+  it("lets runtime options override OpenAI default transport", () => {
+    const { calls, agent } = createOptionsCaptureAgent();
+
+    applyExtraParamsToAgent(agent, undefined, "openai", "gpt-5");
+
+    const model = {
+      api: "openai-responses",
+      provider: "openai",
+      id: "gpt-5",
+    } as Model<"openai-responses">;
+    const context: Context = { messages: [] };
+    void agent.streamFn?.(model, context, { transport: "sse" });
+
+    expect(calls).toHaveLength(1);
+    
expect(calls[0]?.transport).toBe("sse");
+  });
 
   it("allows forcing Codex transport to SSE", () => {
diff --git a/src/agents/pi-embedded-runner/extra-params.ts b/src/agents/pi-embedded-runner/extra-params.ts
index 70678f08bb4..de1d552957b 100644
--- a/src/agents/pi-embedded-runner/extra-params.ts
+++ b/src/agents/pi-embedded-runner/extra-params.ts
@@ -319,6 +319,15 @@ function createCodexDefaultTransportWrapper(baseStreamFn: StreamFn | undefined):
   });
 }
 
+function createOpenAIDefaultTransportWrapper(baseStreamFn: StreamFn | undefined): StreamFn {
+  const underlying = baseStreamFn ?? streamSimple;
+  return (model, context, options) =>
+    underlying(model, context, {
+      ...options,
+      transport: options?.transport ?? "auto",
+    });
+}
+
 function isAnthropic1MModel(modelId: string): boolean {
   const normalized = modelId.trim().toLowerCase();
   return ANTHROPIC_1M_MODEL_PREFIXES.some((prefix) => normalized.startsWith(prefix));
@@ -740,6 +749,9 @@ export function applyExtraParamsToAgent(
   if (provider === "openai-codex") {
     // Default Codex to WebSocket-first when nothing else specifies transport.
     agent.streamFn = createCodexDefaultTransportWrapper(agent.streamFn);
+  } else if (provider === "openai") {
+    // Default OpenAI Responses to WebSocket-first with transparent SSE fallback. 
+ agent.streamFn = createOpenAIDefaultTransportWrapper(agent.streamFn); } const override = extraParamsOverride && Object.keys(extraParamsOverride).length > 0 diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index 035a84ba015..7d7f473825f 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -42,6 +42,7 @@ import { resolveImageSanitizationLimits } from "../../image-sanitization.js"; import { resolveModelAuthMode } from "../../model-auth.js"; import { normalizeProviderId, resolveDefaultModelForAgent } from "../../model-selection.js"; import { createOllamaStreamFn, OLLAMA_NATIVE_BASE_URL } from "../../ollama-stream.js"; +import { createOpenAIWebSocketStreamFn, releaseWsSession } from "../../openai-ws-stream.js"; import { resolveOwnerDisplaySetting } from "../../owner-display.js"; import { isCloudCodeAssistFormatError, @@ -866,6 +867,16 @@ export async function runEmbeddedAttempt( typeof providerConfig?.baseUrl === "string" ? providerConfig.baseUrl.trim() : ""; const ollamaBaseUrl = modelBaseUrl || providerBaseUrl || OLLAMA_NATIVE_BASE_URL; activeSession.agent.streamFn = createOllamaStreamFn(ollamaBaseUrl); + } else if (params.model.api === "openai-responses" && params.provider === "openai") { + const wsApiKey = await params.authStorage.getApiKey(params.provider); + if (wsApiKey) { + activeSession.agent.streamFn = createOpenAIWebSocketStreamFn(wsApiKey, params.sessionId, { + signal: runAbortController.signal, + }); + } else { + log.warn(`[ws-stream] no API key for provider=${params.provider}; using HTTP transport`); + activeSession.agent.streamFn = streamSimple; + } } else { // Force a stable streamFn reference so vitest can reliably mock @mariozechner/pi-ai. 
activeSession.agent.streamFn = streamSimple; @@ -1548,6 +1559,7 @@ export async function runEmbeddedAttempt( sessionManager, }); session?.dispose(); + releaseWsSession(params.sessionId); await sessionLock.release(); } } finally {