feat(agents): make openai responses websocket-first with fallback

This commit is contained in:
Peter Steinberger
2026-03-01 21:50:33 +00:00
parent 38da2d076c
commit 7ced38b5ef
7 changed files with 1260 additions and 6 deletions

View File

@@ -43,6 +43,8 @@ OpenClaw ships with the piai catalog. These providers require **no**
- Optional rotation: `OPENAI_API_KEYS`, `OPENAI_API_KEY_1`, `OPENAI_API_KEY_2`, plus `OPENCLAW_LIVE_OPENAI_KEY` (single override)
- Example model: `openai/gpt-5.1-codex`
- CLI: `openclaw onboard --auth-choice openai-api-key`
- Default transport is `auto` (WebSocket-first, SSE fallback)
- Override per model via `agents.defaults.models["openai/<model>"].params.transport` (`"sse"`, `"websocket"`, or `"auto"`)
```json5
{

View File

@@ -56,12 +56,14 @@ openclaw models auth login --provider openai-codex
}
```
### Codex transport default
### Transport default
OpenClaw uses `pi-ai` for model streaming. For `openai-codex/*` models you can set
`agents.defaults.models.<provider/model>.params.transport` to select transport:
OpenClaw uses `pi-ai` for model streaming. For both `openai/*` and
`openai-codex/*`, default transport is `"auto"` (WebSocket-first, then SSE
fallback).
You can set `agents.defaults.models.<provider/model>.params.transport`:
- Default is `"auto"` (WebSocket-first, then SSE fallback).
- `"sse"`: force SSE
- `"websocket"`: force WebSocket
- `"auto"`: try WebSocket, then fall back to SSE

View File

@@ -0,0 +1,528 @@
/**
* OpenAI WebSocket Connection Manager
*
* Manages a persistent WebSocket connection to the OpenAI Responses API
* (wss://api.openai.com/v1/responses) for multi-turn tool-call workflows.
*
* Features:
* - Auto-reconnect with exponential backoff (max 5 retries: 1s/2s/4s/8s/16s)
* - Tracks previous_response_id per connection for incremental turns
* - Warm-up support (generate: false) to pre-load the connection
* - Typed WebSocket event definitions matching the Responses API SSE spec
*
* @see https://developers.openai.com/api/docs/guides/websocket-mode
*/
import { EventEmitter } from "node:events";
import WebSocket from "ws";
// ─────────────────────────────────────────────────────────────────────────────
// WebSocket Event Types (Server → Client)
// ─────────────────────────────────────────────────────────────────────────────
/** A Responses API response object, as carried inside lifecycle events. */
export interface ResponseObject {
  id: string;
  object: "response";
  created_at: number;
  status: "in_progress" | "completed" | "failed" | "cancelled" | "incomplete";
  model: string;
  output: OutputItem[];
  usage?: UsageInfo;
  error?: { code: string; message: string };
}
/** Token accounting reported by the server on a response. */
export interface UsageInfo {
  input_tokens: number;
  output_tokens: number;
  total_tokens: number;
}
/** One entry in a response's `output` array: assistant text, a tool call, or reasoning. */
export type OutputItem =
  | {
      type: "message";
      id: string;
      role: "assistant";
      content: Array<{ type: "output_text"; text: string }>;
      status?: "in_progress" | "completed";
    }
  | {
      type: "function_call";
      id: string;
      call_id: string;
      name: string;
      // JSON-encoded argument object; parsed by the consumer.
      arguments: string;
      status?: "in_progress" | "completed";
    }
  | {
      type: "reasoning";
      id: string;
      content?: string;
      summary?: string;
    };
/** Server accepted a response.create; the response is now tracked by ID. */
export interface ResponseCreatedEvent {
  type: "response.created";
  response: ResponseObject;
}
/** Generation has started; `response.status` is "in_progress". */
export interface ResponseInProgressEvent {
  type: "response.in_progress";
  response: ResponseObject;
}
/** Terminal success event; `response.output` holds the full result. */
export interface ResponseCompletedEvent {
  type: "response.completed";
  response: ResponseObject;
}
/** Terminal failure event; details are in `response.error`. */
export interface ResponseFailedEvent {
  type: "response.failed";
  response: ResponseObject;
}
/** A new output item (message/function_call/reasoning) was appended. */
export interface OutputItemAddedEvent {
  type: "response.output_item.added";
  output_index: number;
  item: OutputItem;
}
/** An output item finished streaming; `item` is its final form. */
export interface OutputItemDoneEvent {
  type: "response.output_item.done";
  output_index: number;
  item: OutputItem;
}
/** A new content part was opened inside a message item. */
export interface ContentPartAddedEvent {
  type: "response.content_part.added";
  item_id: string;
  output_index: number;
  content_index: number;
  part: { type: "output_text"; text: string };
}
/** A content part finished; `part.text` is the complete text. */
export interface ContentPartDoneEvent {
  type: "response.content_part.done";
  item_id: string;
  output_index: number;
  content_index: number;
  part: { type: "output_text"; text: string };
}
/** Incremental text chunk for a streaming message part. */
export interface OutputTextDeltaEvent {
  type: "response.output_text.delta";
  item_id: string;
  output_index: number;
  content_index: number;
  delta: string;
}
/** Final accumulated text for a message part. */
export interface OutputTextDoneEvent {
  type: "response.output_text.done";
  item_id: string;
  output_index: number;
  content_index: number;
  text: string;
}
/** Incremental chunk of a function call's JSON argument string. */
export interface FunctionCallArgumentsDeltaEvent {
  type: "response.function_call_arguments.delta";
  item_id: string;
  output_index: number;
  call_id: string;
  delta: string;
}
/** Final JSON argument string for a function call. */
export interface FunctionCallArgumentsDoneEvent {
  type: "response.function_call_arguments.done";
  item_id: string;
  output_index: number;
  call_id: string;
  arguments: string;
}
/** Periodic rate-limit snapshot pushed by the server. */
export interface RateLimitUpdatedEvent {
  type: "rate_limits.updated";
  rate_limits: Array<{
    name: string;
    limit: number;
    remaining: number;
    reset_seconds: number;
  }>;
}
/** Protocol-level error not tied to a specific response object. */
export interface ErrorEvent {
  type: "error";
  code: string;
  message: string;
  param?: string;
}
/** Union of every server→client event this client understands. */
export type OpenAIWebSocketEvent =
  | ResponseCreatedEvent
  | ResponseInProgressEvent
  | ResponseCompletedEvent
  | ResponseFailedEvent
  | OutputItemAddedEvent
  | OutputItemDoneEvent
  | ContentPartAddedEvent
  | ContentPartDoneEvent
  | OutputTextDeltaEvent
  | OutputTextDoneEvent
  | FunctionCallArgumentsDeltaEvent
  | FunctionCallArgumentsDoneEvent
  | RateLimitUpdatedEvent
  | ErrorEvent;
// ─────────────────────────────────────────────────────────────────────────────
// Client → Server Event Types
// ─────────────────────────────────────────────────────────────────────────────
/** A single content part inside a client-supplied message item. */
export type ContentPart =
  | { type: "input_text"; text: string }
  | { type: "output_text"; text: string }
  | {
      type: "input_image";
      source: { type: "url"; url: string } | { type: "base64"; media_type: string; data: string };
    };
/** One entry in a response.create `input` array. */
export type InputItem =
  | {
      type: "message";
      role: "system" | "developer" | "user" | "assistant";
      content: string | ContentPart[];
    }
  | { type: "function_call"; id?: string; call_id?: string; name: string; arguments: string }
  | { type: "function_call_output"; call_id: string; output: string }
  | { type: "reasoning"; content?: string; encrypted_content?: string; summary?: string }
  | { type: "item_reference"; id: string };
/** Tool-selection strategy forwarded to the model. */
export type ToolChoice =
  | "auto"
  | "none"
  | "required"
  | { type: "function"; function: { name: string } };
/** Function tool definition in the OpenAI wire format (JSON-schema parameters). */
export interface FunctionToolDefinition {
  type: "function";
  function: {
    name: string;
    description?: string;
    parameters?: Record<string, unknown>;
  };
}
/** Standard response.create event payload (full turn) */
export interface ResponseCreateEvent {
  type: "response.create";
  model: string;
  store?: boolean;
  stream?: boolean;
  input?: string | InputItem[];
  instructions?: string;
  tools?: FunctionToolDefinition[];
  tool_choice?: ToolChoice;
  context_management?: unknown;
  previous_response_id?: string;
  max_output_tokens?: number;
  temperature?: number;
  top_p?: number;
  metadata?: Record<string, string>;
  reasoning?: { effort?: "low" | "medium" | "high"; summary?: "auto" | "concise" | "detailed" };
  truncation?: "auto" | "disabled";
  // Open index signature: provider-specific extras pass through untouched.
  [key: string]: unknown;
}
/** Warm-up payload: generate: false pre-loads connection without generating output */
export interface WarmUpEvent extends ResponseCreateEvent {
  generate: false;
}
/** Any event the client may send over the socket. */
export type ClientEvent = ResponseCreateEvent | WarmUpEvent;
// ─────────────────────────────────────────────────────────────────────────────
// Connection Manager
// ─────────────────────────────────────────────────────────────────────────────
/** Production endpoint for the Responses API WebSocket transport. */
const OPENAI_WS_URL = "wss://api.openai.com/v1/responses";
/** Default cap on reconnect attempts before giving up. */
const MAX_RETRIES = 5;
/** Backoff delays in ms: 1s, 2s, 4s, 8s, 16s */
const BACKOFF_DELAYS_MS = [1000, 2000, 4000, 8000, 16000] as const;
export interface OpenAIWebSocketManagerOptions {
  /** Override the default WebSocket URL (useful for testing) */
  url?: string;
  /** Maximum number of reconnect attempts (default: 5) */
  maxRetries?: number;
  /** Custom backoff delays in ms (default: [1000, 2000, 4000, 8000, 16000]) */
  backoffDelaysMs?: readonly number[];
}
/** Typed event map for the EventEmitter base class (event name → argument tuple). */
type InternalEvents = {
  message: [event: OpenAIWebSocketEvent];
  open: [];
  close: [code: number, reason: string];
  error: [err: Error];
};
/**
 * Manages a persistent WebSocket connection to the OpenAI Responses API.
 *
 * Usage:
 * ```ts
 * const manager = new OpenAIWebSocketManager();
 * await manager.connect(apiKey);
 *
 * manager.onMessage((event) => {
 *   if (event.type === "response.completed") {
 *     console.log("Response ID:", event.response.id);
 *   }
 * });
 *
 * manager.send({ type: "response.create", model: "gpt-5.2", input: [...] });
 * ```
 */
export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
  private ws: WebSocket | null = null;
  private apiKey: string | null = null;
  /** Reconnect attempts made since the last successful open. */
  private retryCount = 0;
  /** Pending reconnect timer, if a retry is currently scheduled. */
  private retryTimer: NodeJS.Timeout | null = null;
  /** True once close() has been called; disables all reconnects. */
  private closed = false;
  /** The ID of the most recent completed response on this connection. */
  private _previousResponseId: string | null = null;
  private readonly wsUrl: string;
  private readonly maxRetries: number;
  private readonly backoffDelaysMs: readonly number[];

  constructor(options: OpenAIWebSocketManagerOptions = {}) {
    super();
    this.wsUrl = options.url ?? OPENAI_WS_URL;
    this.maxRetries = options.maxRetries ?? MAX_RETRIES;
    this.backoffDelaysMs = options.backoffDelaysMs ?? BACKOFF_DELAYS_MS;
  }

  // ─── Public API ────────────────────────────────────────────────────────────

  /**
   * Returns the previous_response_id from the last completed response,
   * for use in subsequent response.create events.
   */
  get previousResponseId(): string | null {
    return this._previousResponseId;
  }

  /**
   * Opens a WebSocket connection to the OpenAI Responses API.
   * Resolves when the connection is established (open event fires).
   * Rejects if the initial connection attempt fails; background reconnects
   * (up to maxRetries) may continue independently until close() is called.
   */
  connect(apiKey: string): Promise<void> {
    this.apiKey = apiKey;
    this.closed = false;
    this.retryCount = 0;
    // Drop any retry left over from a previous connect cycle so a stale
    // timer cannot race the fresh connection.
    this._cancelRetryTimer();
    return this._openConnection();
  }

  /**
   * Sends a typed event to the OpenAI Responses API over the WebSocket.
   * Throws if the connection is not open.
   */
  send(event: ClientEvent): void {
    if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
      throw new Error(
        `OpenAIWebSocketManager: cannot send — connection is not open (readyState=${this.ws?.readyState ?? "no socket"})`,
      );
    }
    this.ws.send(JSON.stringify(event));
  }

  /**
   * Registers a handler for incoming server-sent WebSocket events.
   * Returns an unsubscribe function.
   */
  onMessage(handler: (event: OpenAIWebSocketEvent) => void): () => void {
    this.on("message", handler);
    return () => {
      this.off("message", handler);
    };
  }

  /**
   * Returns true if the WebSocket is currently open and ready to send.
   */
  isConnected(): boolean {
    return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
  }

  /**
   * Permanently closes the WebSocket connection and disables auto-reconnect.
   */
  close(): void {
    this.closed = true;
    this._cancelRetryTimer();
    if (this.ws) {
      // Detach listeners first so the close below cannot re-enter reconnect.
      this.ws.removeAllListeners();
      if (this.ws.readyState === WebSocket.OPEN || this.ws.readyState === WebSocket.CONNECTING) {
        this.ws.close(1000, "Client closed");
      }
      this.ws = null;
    }
  }

  // ─── Internal: Connection Lifecycle ────────────────────────────────────────

  /** Opens a fresh socket. Resolves on "open"; rejects on the first "error". */
  private _openConnection(): Promise<void> {
    return new Promise<void>((resolve, reject) => {
      if (!this.apiKey) {
        reject(new Error("OpenAIWebSocketManager: apiKey is required before connecting."));
        return;
      }
      const socket = new WebSocket(this.wsUrl, {
        headers: {
          Authorization: `Bearer ${this.apiKey}`,
          "OpenAI-Beta": "responses-websocket=v1",
        },
      });
      this.ws = socket;
      const onOpen = () => {
        this.retryCount = 0;
        resolve();
        this.emit("open");
      };
      const onError = (err: Error) => {
        // Remove open listener so we don't resolve after an error.
        socket.off("open", onOpen);
        // Emit "error" on the manager only when there are listeners; otherwise
        // the promise rejection below is the primary error channel for this
        // initial connection failure. (An uncaught "error" event in Node.js
        // throws synchronously and would prevent the promise from rejecting.)
        if (this.listenerCount("error") > 0) {
          this.emit("error", err);
        }
        reject(err);
      };
      const onClose = (code: number, reason: Buffer) => {
        const reasonStr = reason.toString();
        this.emit("close", code, reasonStr);
        if (!this.closed) {
          this._scheduleReconnect();
        }
      };
      const onMessage = (data: WebSocket.RawData) => {
        this._handleMessage(data);
      };
      socket.once("open", onOpen);
      socket.on("error", onError);
      socket.on("close", onClose);
      socket.on("message", onMessage);
    });
  }

  private _scheduleReconnect(): void {
    if (this.closed) {
      return;
    }
    // A failed socket surfaces as BOTH an "error" (via the retry promise's
    // catch handler) and a "close" event, and both paths land here. Bail if
    // a retry is already pending so each failure schedules exactly one
    // reconnect. (Previously both paths scheduled a timer, doubling the
    // retry chains on every failure, and only the last timer handle was
    // cancellable via _cancelRetryTimer.)
    if (this.retryTimer !== null) {
      return;
    }
    if (this.retryCount >= this.maxRetries) {
      this._safeEmitError(
        new Error(`OpenAIWebSocketManager: max reconnect retries (${this.maxRetries}) exceeded.`),
      );
      return;
    }
    const delayMs =
      this.backoffDelaysMs[Math.min(this.retryCount, this.backoffDelaysMs.length - 1)] ?? 1000;
    this.retryCount++;
    this.retryTimer = setTimeout(() => {
      // Timer has fired: clear the handle so the next failure can schedule.
      this.retryTimer = null;
      if (this.closed) {
        return;
      }
      this._openConnection().catch((err: unknown) => {
        // onError handler already surfaced the failure; line up the next retry.
        void err;
        this._scheduleReconnect();
      });
    }, delayMs);
  }

  /** Emit an error only if there are listeners; prevents Node.js from crashing
   * with "unhandled 'error' event" when no one is listening. */
  private _safeEmitError(err: Error): void {
    if (this.listenerCount("error") > 0) {
      this.emit("error", err);
    }
  }

  private _cancelRetryTimer(): void {
    if (this.retryTimer !== null) {
      clearTimeout(this.retryTimer);
      this.retryTimer = null;
    }
  }

  /** Decode, JSON-parse, and dispatch one inbound frame as a typed event. */
  private _handleMessage(data: WebSocket.RawData): void {
    let text: string;
    if (typeof data === "string") {
      text = data;
    } else if (Buffer.isBuffer(data)) {
      text = data.toString("utf8");
    } else if (Array.isArray(data)) {
      // Fragmented frame: ws delivers Buffer[] — concatenate before decoding.
      // (Previously this fell through to String(data), corrupting the payload.)
      text = Buffer.concat(data).toString("utf8");
    } else if (data instanceof ArrayBuffer) {
      text = Buffer.from(data).toString("utf8");
    } else {
      // Blob or other — coerce to string
      text = String(data);
    }
    let parsed: unknown;
    try {
      parsed = JSON.parse(text);
    } catch {
      this._safeEmitError(
        new Error(`OpenAIWebSocketManager: failed to parse message: ${text.slice(0, 200)}`),
      );
      return;
    }
    if (!parsed || typeof parsed !== "object" || !("type" in parsed)) {
      this._safeEmitError(
        new Error(
          `OpenAIWebSocketManager: unexpected message shape (no "type" field): ${text.slice(0, 200)}`,
        ),
      );
      return;
    }
    const event = parsed as OpenAIWebSocketEvent;
    // Track previous_response_id on completion
    if (event.type === "response.completed" && event.response?.id) {
      this._previousResponseId = event.response.id;
    }
    this.emit("message", event);
  }

  /**
   * Sends a warm-up event to pre-load the connection and model without generating output.
   * Pass tools/instructions to prime the connection for the upcoming session.
   */
  warmUp(params: { model: string; tools?: FunctionToolDefinition[]; instructions?: string }): void {
    const event: WarmUpEvent = {
      type: "response.create",
      generate: false,
      model: params.model,
      ...(params.tools ? { tools: params.tools } : {}),
      ...(params.instructions ? { instructions: params.instructions } : {}),
    };
    this.send(event);
  }
}

View File

@@ -0,0 +1,680 @@
/**
* OpenAI WebSocket StreamFn Integration
*
* Wraps `OpenAIWebSocketManager` in a `StreamFn` that can be plugged into the
* pi-embedded-runner agent in place of the default `streamSimple` HTTP function.
*
* Key behaviours:
* - Per-session `OpenAIWebSocketManager` (keyed by sessionId)
* - Tracks `previous_response_id` to send only incremental tool-result inputs
* - Falls back to `streamSimple` (HTTP) if the WebSocket connection fails
* - Cleanup helpers for releasing sessions after the run completes
*
* Complexity budget & risk mitigation:
* - **Transport aware**: respects `transport` (`auto` | `websocket` | `sse`)
* - **Transparent fallback in `auto` mode**: connect/send failures fall back to
* the existing HTTP `streamSimple`; forced `websocket` mode surfaces WS errors
* - **Zero shared state**: per-session registry; session cleanup on dispose prevents leaks
* - **Full parity**: all generation options (temperature, top_p, max_output_tokens,
* tool_choice, reasoning) forwarded identically to the HTTP path
*
* @see src/agents/openai-ws-connection.ts for the connection manager
*/
import { randomUUID } from "node:crypto";
import type { StreamFn } from "@mariozechner/pi-agent-core";
import type {
AssistantMessage,
Context,
Message,
StopReason,
TextContent,
ToolCall,
Usage,
} from "@mariozechner/pi-ai";
import { createAssistantMessageEventStream, streamSimple } from "@mariozechner/pi-ai";
import {
OpenAIWebSocketManager,
type ContentPart,
type FunctionToolDefinition,
type InputItem,
type OpenAIWebSocketManagerOptions,
type ResponseObject,
} from "./openai-ws-connection.js";
import { log } from "./pi-embedded-runner/logger.js";
// ─────────────────────────────────────────────────────────────────────────────
// Per-session state
// ─────────────────────────────────────────────────────────────────────────────
/** Per-session WebSocket connection state, persisted across streamFn calls. */
interface WsSession {
  manager: OpenAIWebSocketManager;
  /** Number of messages that were in context.messages at the END of the last streamFn call. */
  lastContextLength: number;
  /** True if the connection has been established at least once. */
  everConnected: boolean;
  /** True if the session is permanently broken (no more reconnect). */
  broken: boolean;
}
/** Module-level registry: sessionId → WsSession */
const wsRegistry = new Map<string, WsSession>();
// ─────────────────────────────────────────────────────────────────────────────
// Public registry helpers
// ─────────────────────────────────────────────────────────────────────────────
/**
* Release and close the WebSocket session for the given sessionId.
* Call this after the agent run completes to free the connection.
*/
/**
 * Release and close the WebSocket session for the given sessionId.
 * Call this after the agent run completes to free the connection.
 */
export function releaseWsSession(sessionId: string): void {
  const entry = wsRegistry.get(sessionId);
  if (!entry) {
    return;
  }
  try {
    entry.manager.close();
  } catch {
    // Connection may already be gone; a failed close is harmless here.
  }
  wsRegistry.delete(sessionId);
}
/**
* Returns true if a live WebSocket session exists for the given sessionId.
*/
/** Returns true if a live, unbroken WebSocket session exists for the given sessionId. */
export function hasWsSession(sessionId: string): boolean {
  const entry = wsRegistry.get(sessionId);
  if (!entry || entry.broken) {
    return false;
  }
  return entry.manager.isConnected();
}
// ─────────────────────────────────────────────────────────────────────────────
// Message format converters
// ─────────────────────────────────────────────────────────────────────────────
/** Loosely-typed view of a pi-ai Message, used for role/content sniffing. */
type AnyMessage = Message & { role: string; content: unknown };
/** Flatten pi-ai message content (string or content-part array) into plain text. */
function contentToText(content: unknown): string {
  if (typeof content === "string") {
    return content;
  }
  if (!Array.isArray(content)) {
    // Anything else (null, numbers, objects) carries no extractable text.
    return "";
  }
  let text = "";
  for (const part of content as Array<{ type?: string; text?: string }>) {
    if (part.type === "text" && typeof part.text === "string") {
      text += part.text;
    }
  }
  return text;
}
/** Translate pi-ai content into OpenAI Responses API content parts. */
function contentToOpenAIParts(content: unknown): ContentPart[] {
  if (typeof content === "string") {
    // Empty strings produce no parts at all.
    return content.length > 0 ? [{ type: "input_text", text: content }] : [];
  }
  if (!Array.isArray(content)) {
    return [];
  }
  const result: ContentPart[] = [];
  const entries = content as Array<{
    type?: string;
    text?: string;
    data?: string;
    mimeType?: string;
  }>;
  for (const entry of entries) {
    if (entry.type === "text" && typeof entry.text === "string") {
      result.push({ type: "input_text", text: entry.text });
    } else if (entry.type === "image" && typeof entry.data === "string") {
      result.push({
        type: "input_image",
        source: {
          type: "base64",
          // Default to JPEG when the source part omits a MIME type.
          media_type: entry.mimeType ?? "image/jpeg",
          data: entry.data,
        },
      });
    }
  }
  return result;
}
/** Map pi-ai tool definitions onto the OpenAI function-tool wire format. */
export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] {
  if (!tools?.length) {
    return [];
  }
  const converted: FunctionToolDefinition[] = [];
  for (const tool of tools) {
    converted.push({
      type: "function",
      function: {
        name: tool.name,
        // Non-string descriptions are dropped rather than coerced.
        description: typeof tool.description === "string" ? tool.description : undefined,
        parameters: (tool.parameters ?? {}) as Record<string, unknown>,
      },
    });
  }
  return converted;
}
/**
 * Convert the full pi-ai message history to an OpenAI `input` array.
 * Handles user messages, assistant text+tool-call messages, and tool results.
 * Ordering of assistant text relative to tool calls is preserved.
 */
export function convertMessagesToInputItems(messages: Message[]): InputItem[] {
  const items: InputItem[] = [];

  // Emit buffered assistant text as one message item, then clear the buffer.
  const flushAssistantText = (buffer: string[]): void => {
    if (buffer.length > 0) {
      items.push({ type: "message", role: "assistant", content: buffer.join("") });
      buffer.length = 0;
    }
  };

  for (const msg of messages) {
    const m = msg as AnyMessage;
    switch (m.role) {
      case "user": {
        const parts = contentToOpenAIParts(m.content);
        // A lone text part collapses to a plain string for compactness.
        const body =
          parts.length === 1 && parts[0]?.type === "input_text"
            ? (parts[0] as { type: "input_text"; text: string }).text
            : parts;
        items.push({ type: "message", role: "user", content: body });
        break;
      }
      case "assistant": {
        const content = m.content;
        if (!Array.isArray(content)) {
          const text = contentToText(content);
          if (text) {
            items.push({ type: "message", role: "assistant", content: text });
          }
          break;
        }
        const pendingText: string[] = [];
        const blocks = content as Array<{
          type?: string;
          text?: string;
          id?: string;
          name?: string;
          arguments?: Record<string, unknown>;
          thinking?: string;
        }>;
        for (const block of blocks) {
          if (block.type === "text" && typeof block.text === "string") {
            pendingText.push(block.text);
          } else if (block.type === "toolCall") {
            // Text seen so far becomes its own item so it precedes the call.
            flushAssistantText(pendingText);
            items.push({
              type: "function_call",
              call_id: typeof block.id === "string" ? block.id : `call_${randomUUID()}`,
              name: block.name ?? "",
              arguments:
                typeof block.arguments === "string"
                  ? block.arguments
                  : JSON.stringify(block.arguments ?? {}),
            });
          }
          // "thinking" blocks (and unrecognized types) are intentionally dropped —
          // they are never sent back to the model.
        }
        flushAssistantText(pendingText);
        break;
      }
      case "toolResult": {
        const tr = m as unknown as {
          toolCallId: string;
          content: unknown;
          isError: boolean;
        };
        items.push({
          type: "function_call_output",
          call_id: tr.toolCallId,
          output: contentToText(tr.content),
        });
        break;
      }
      default:
        // Unknown roles are skipped.
        break;
    }
  }
  return items;
}
// ─────────────────────────────────────────────────────────────────────────────
// Response object → AssistantMessage
// ─────────────────────────────────────────────────────────────────────────────
/**
 * Build a pi-ai AssistantMessage from a completed Responses API response
 * object. Text parts become text content, function calls become toolCall
 * content; reasoning items are skipped. Cache/cost usage fields are zeroed
 * (the WS transport does not report them).
 */
export function buildAssistantMessageFromResponse(
  response: ResponseObject,
  modelInfo: { api: string; provider: string; id: string },
): AssistantMessage {
  const content: (TextContent | ToolCall)[] = [];
  for (const item of response.output ?? []) {
    switch (item.type) {
      case "message": {
        for (const part of item.content ?? []) {
          if (part.type === "output_text" && part.text) {
            content.push({ type: "text", text: part.text });
          }
        }
        break;
      }
      case "function_call": {
        let parsedArgs: Record<string, unknown>;
        try {
          parsedArgs = JSON.parse(item.arguments) as Record<string, unknown>;
        } catch {
          // Malformed JSON from the model — degrade to empty arguments.
          parsedArgs = {};
        }
        content.push({
          type: "toolCall",
          id: item.call_id,
          name: item.name,
          arguments: parsedArgs,
        });
        break;
      }
      default:
        // "reasoning" items are informational only; skip.
        break;
    }
  }
  const usedToolCall = content.some((c) => c.type === "toolCall");
  const stopReason: StopReason = usedToolCall ? "toolUse" : "stop";
  const usage: Usage = {
    input: response.usage?.input_tokens ?? 0,
    output: response.usage?.output_tokens ?? 0,
    cacheRead: 0,
    cacheWrite: 0,
    totalTokens: response.usage?.total_tokens ?? 0,
    cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
  };
  return {
    role: "assistant",
    content,
    stopReason,
    api: modelInfo.api,
    provider: modelInfo.provider,
    model: modelInfo.id,
    usage,
    timestamp: Date.now(),
  };
}
// ─────────────────────────────────────────────────────────────────────────────
// StreamFn factory
// ─────────────────────────────────────────────────────────────────────────────
/** Options for createOpenAIWebSocketStreamFn. */
export interface OpenAIWebSocketStreamOptions {
  /** Manager options (url override, retry counts, etc.) */
  managerOptions?: OpenAIWebSocketManagerOptions;
  /** Abort signal forwarded from the run. */
  signal?: AbortSignal;
}
/** Transport selection: "auto" tries WebSocket first with SSE fallback. */
type WsTransport = "sse" | "websocket" | "auto";
/** Read the per-model `transport` param; anything unrecognized resolves to "auto". */
function resolveWsTransport(options: Parameters<StreamFn>[2]): WsTransport {
  const raw = (options as { transport?: unknown } | undefined)?.transport;
  switch (raw) {
    case "sse":
    case "websocket":
    case "auto":
      return raw;
    default:
      return "auto";
  }
}
/**
 * Creates a `StreamFn` backed by a persistent WebSocket connection to the
 * OpenAI Responses API. The first call for a given `sessionId` opens the
 * connection; subsequent calls reuse it, sending only incremental tool-result
 * inputs with `previous_response_id`.
 *
 * If the WebSocket connection is unavailable, the function falls back to the
 * standard `streamSimple` HTTP path and logs a warning.
 *
 * @param apiKey OpenAI API key
 * @param sessionId Agent session ID (used as the registry key)
 * @param opts Optional manager + abort signal overrides
 */
export function createOpenAIWebSocketStreamFn(
  apiKey: string,
  sessionId: string,
  opts: OpenAIWebSocketStreamOptions = {},
): StreamFn {
  return (model, context, options) => {
    const eventStream = createAssistantMessageEventStream();
    const run = async () => {
      const transport = resolveWsTransport(options);
      if (transport === "sse") {
        // Forced SSE: never touch the WebSocket path at all.
        return fallbackToHttp(model, context, options, eventStream, opts.signal);
      }
      // ── 1. Get or create session state ──────────────────────────────────
      let session = wsRegistry.get(sessionId);
      if (!session) {
        const manager = new OpenAIWebSocketManager(opts.managerOptions);
        session = {
          manager,
          lastContextLength: 0,
          everConnected: false,
          broken: false,
        };
        wsRegistry.set(sessionId, session);
      }
      // ── 2. Ensure connection is open ─────────────────────────────────────
      if (!session.manager.isConnected() && !session.broken) {
        try {
          await session.manager.connect(apiKey);
          session.everConnected = true;
          log.debug(`[ws-stream] connected for session=${sessionId}`);
        } catch (connErr) {
          // Cancel any background reconnect attempts before marking as broken.
          try {
            session.manager.close();
          } catch {
            /* ignore */
          }
          session.broken = true;
          wsRegistry.delete(sessionId);
          if (transport === "websocket") {
            // Forced WS mode surfaces the connect error instead of falling back.
            throw connErr instanceof Error ? connErr : new Error(String(connErr));
          }
          log.warn(
            `[ws-stream] WebSocket connect failed for session=${sessionId}; falling back to HTTP. error=${String(connErr)}`,
          );
          // Fall back to HTTP immediately
          return fallbackToHttp(model, context, options, eventStream, opts.signal);
        }
      }
      if (session.broken || !session.manager.isConnected()) {
        if (transport === "websocket") {
          throw new Error("WebSocket session disconnected");
        }
        log.warn(`[ws-stream] session=${sessionId} broken/disconnected; falling back to HTTP`);
        // Clean up stale session to prevent next turn from using stale
        // previousResponseId / lastContextLength after a mid-request drop.
        try {
          session.manager.close();
        } catch {
          /* ignore */
        }
        wsRegistry.delete(sessionId);
        return fallbackToHttp(model, context, options, eventStream, opts.signal);
      }
      // ── 3. Compute incremental vs full input ─────────────────────────────
      const prevResponseId = session.manager.previousResponseId;
      let inputItems: InputItem[];
      if (prevResponseId && session.lastContextLength > 0) {
        // Subsequent turn: only send new messages (tool results) since last call
        const newMessages = context.messages.slice(session.lastContextLength);
        // Filter to only tool results — the assistant message is already in server context
        const toolResults = newMessages.filter((m) => (m as AnyMessage).role === "toolResult");
        if (toolResults.length === 0) {
          // Shouldn't happen in a well-formed turn, but fall back to full context
          log.debug(
            `[ws-stream] session=${sessionId}: no new tool results found; sending full context`,
          );
          inputItems = buildFullInput(context);
        } else {
          inputItems = convertMessagesToInputItems(toolResults);
        }
        log.debug(
          `[ws-stream] session=${sessionId}: incremental send (${inputItems.length} tool results) previous_response_id=${prevResponseId}`,
        );
      } else {
        // First turn: send full context
        inputItems = buildFullInput(context);
        log.debug(
          `[ws-stream] session=${sessionId}: full context send (${inputItems.length} items)`,
        );
      }
      // ── 4. Build & send response.create ──────────────────────────────────
      const tools = convertTools(context.tools);
      // Forward generation options that the HTTP path (openai-responses provider) also uses.
      // Cast to record since SimpleStreamOptions carries openai-specific fields as unknown.
      const streamOpts = options as
        | (Record<string, unknown> & {
            temperature?: number;
            maxTokens?: number;
            topP?: number;
            toolChoice?: unknown;
          })
        | undefined;
      const extraParams: Record<string, unknown> = {};
      if (streamOpts?.temperature !== undefined) {
        extraParams.temperature = streamOpts.temperature;
      }
      // NOTE(review): truthiness check means maxTokens=0 is silently dropped
      // rather than forwarded — presumably intentional; confirm.
      if (streamOpts?.maxTokens) {
        extraParams.max_output_tokens = streamOpts.maxTokens;
      }
      if (streamOpts?.topP !== undefined) {
        extraParams.top_p = streamOpts.topP;
      }
      if (streamOpts?.toolChoice !== undefined) {
        extraParams.tool_choice = streamOpts.toolChoice;
      }
      if (streamOpts?.reasoningEffort || streamOpts?.reasoningSummary) {
        const reasoning: { effort?: string; summary?: string } = {};
        if (streamOpts.reasoningEffort !== undefined) {
          reasoning.effort = streamOpts.reasoningEffort as string;
        }
        if (streamOpts.reasoningSummary !== undefined) {
          reasoning.summary = streamOpts.reasoningSummary as string;
        }
        extraParams.reasoning = reasoning;
      }
      const payload: Record<string, unknown> = {
        type: "response.create",
        model: model.id,
        // store: false — the server keeps turn state on the socket via
        // previous_response_id; nothing is persisted server-side.
        store: false,
        input: inputItems,
        instructions: context.systemPrompt ?? undefined,
        tools: tools.length > 0 ? tools : undefined,
        ...(prevResponseId ? { previous_response_id: prevResponseId } : {}),
        ...extraParams,
      };
      options?.onPayload?.(payload);
      try {
        session.manager.send(payload as Parameters<OpenAIWebSocketManager["send"]>[0]);
      } catch (sendErr) {
        if (transport === "websocket") {
          throw sendErr instanceof Error ? sendErr : new Error(String(sendErr));
        }
        log.warn(
          `[ws-stream] send failed for session=${sessionId}; falling back to HTTP. error=${String(sendErr)}`,
        );
        // Fully reset session state so the next WS turn doesn't use stale
        // previous_response_id or lastContextLength from before the failure.
        try {
          session.manager.close();
        } catch {
          /* ignore */
        }
        wsRegistry.delete(sessionId);
        return fallbackToHttp(model, context, options, eventStream, opts.signal);
      }
      // Emit an empty "start" partial so consumers see the turn begin
      // immediately, before any server events arrive.
      eventStream.push({
        type: "start",
        partial: {
          role: "assistant",
          content: [],
          stopReason: "stop",
          api: model.api,
          provider: model.provider,
          model: model.id,
          usage: {
            input: 0,
            output: 0,
            cacheRead: 0,
            cacheWrite: 0,
            totalTokens: 0,
            cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
          },
          timestamp: Date.now(),
        },
      });
      // ── 5. Wait for response.completed ───────────────────────────────────
      // Captured BEFORE awaiting so lastContextLength reflects the context as
      // sent, even if the caller mutates context.messages meanwhile.
      const capturedContextLength = context.messages.length;
      await new Promise<void>((resolve, reject) => {
        // Honour abort signal
        const abortHandler = () => {
          cleanup();
          reject(new Error("aborted"));
        };
        const signal = opts.signal ?? (options as { signal?: AbortSignal } | undefined)?.signal;
        if (signal?.aborted) {
          reject(new Error("aborted"));
          return;
        }
        signal?.addEventListener("abort", abortHandler, { once: true });
        // If the WebSocket drops mid-request, reject so we don't hang forever.
        const closeHandler = (code: number, reason: string) => {
          cleanup();
          reject(
            new Error(`WebSocket closed mid-request (code=${code}, reason=${reason || "unknown"})`),
          );
        };
        session.manager.on("close", closeHandler);
        const cleanup = () => {
          signal?.removeEventListener("abort", abortHandler);
          session.manager.off("close", closeHandler);
          unsubscribe();
        };
        const unsubscribe = session.manager.onMessage((event) => {
          if (event.type === "response.completed") {
            cleanup();
            // Update session state
            session.lastContextLength = capturedContextLength;
            // Build and emit the assistant message
            const assistantMsg = buildAssistantMessageFromResponse(event.response, {
              api: model.api,
              provider: model.provider,
              id: model.id,
            });
            const reason: Extract<StopReason, "stop" | "length" | "toolUse"> =
              assistantMsg.stopReason === "toolUse" ? "toolUse" : "stop";
            // NOTE(review): eventStream.end() is not called on this success
            // path — presumably the "done" event terminates the stream inside
            // createAssistantMessageEventStream; confirm against pi-ai.
            eventStream.push({ type: "done", reason, message: assistantMsg });
            resolve();
          } else if (event.type === "response.failed") {
            cleanup();
            // NOTE(review): the session stays registered after a failed
            // response — the next turn will reuse it; verify this is intended.
            const errMsg = event.response?.error?.message ?? "Response failed";
            reject(new Error(`OpenAI WebSocket response failed: ${errMsg}`));
          } else if (event.type === "error") {
            cleanup();
            reject(new Error(`OpenAI WebSocket error: ${event.message} (code=${event.code})`));
          } else if (event.type === "response.output_text.delta") {
            // Stream partial text updates for responsive UI
            // NOTE(review): `partial.content` carries only the latest delta
            // chunk, not accumulated text — consumers appear to read `delta`.
            const partialMsg: AssistantMessage = {
              role: "assistant",
              content: [{ type: "text", text: event.delta }],
              stopReason: "stop",
              api: model.api,
              provider: model.provider,
              model: model.id,
              usage: {
                input: 0,
                output: 0,
                cacheRead: 0,
                cacheWrite: 0,
                totalTokens: 0,
                cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
              },
              timestamp: Date.now(),
            };
            eventStream.push({
              type: "text_delta",
              contentIndex: 0,
              delta: event.delta,
              partial: partialMsg,
            });
          }
        });
      });
    };
    // Run asynchronously so the eventStream is returned to the caller first;
    // any failure is converted into a terminal "error" stream event.
    queueMicrotask(() =>
      run().catch((err) => {
        const errorMessage = err instanceof Error ? err.message : String(err);
        log.warn(`[ws-stream] session=${sessionId} run error: ${errorMessage}`);
        eventStream.push({
          type: "error",
          reason: "error",
          error: {
            role: "assistant" as const,
            content: [],
            stopReason: "error" as StopReason,
            errorMessage,
            api: model.api,
            provider: model.provider,
            model: model.id,
            usage: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              totalTokens: 0,
              cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
            },
            timestamp: Date.now(),
          },
        });
        eventStream.end();
      }),
    );
    return eventStream;
  };
}
// ─────────────────────────────────────────────────────────────────────────────
// Helpers
// ─────────────────────────────────────────────────────────────────────────────
/**
 * Convert every message in the context into Responses API input items.
 * The system prompt is intentionally absent here — it travels separately
 * via the request's `instructions` field.
 */
function buildFullInput(context: Context): InputItem[] {
  const { messages } = context;
  return convertMessagesToInputItems(messages);
}
/**
 * Stream the request over HTTP (`streamSimple`) and forward every event
 * into the already-open event stream. Invoked when the WebSocket transport
 * is broken or unavailable, so callers see a seamless fallback.
 *
 * @param model       Model descriptor forwarded unchanged to `streamSimple`.
 * @param context     Conversation context forwarded unchanged.
 * @param options     Base stream options; `signal` is merged in when given.
 * @param eventStream Destination stream that receives each HTTP event.
 * @param signal      Optional abort signal to cancel the HTTP request.
 */
async function fallbackToHttp(
  model: Parameters<StreamFn>[0],
  context: Parameters<StreamFn>[1],
  options: Parameters<StreamFn>[2],
  eventStream: ReturnType<typeof createAssistantMessageEventStream>,
  signal?: AbortSignal,
): Promise<void> {
  // Only build a merged options object when a signal was actually supplied.
  const effectiveOptions = signal === undefined ? options : { ...options, signal };
  for await (const event of streamSimple(model, context, effectiveOptions)) {
    eventStream.push(event);
  }
}

View File

@@ -544,7 +544,7 @@ describe("applyExtraParamsToAgent", () => {
expect(calls[0]?.transport).toBe("auto");
});
it("does not set transport defaults for non-Codex providers", () => {
it("defaults OpenAI transport to auto (WebSocket-first)", () => {
const { calls, agent } = createOptionsCaptureAgent();
applyExtraParamsToAgent(agent, undefined, "openai", "gpt-5");
@@ -558,7 +558,24 @@ describe("applyExtraParamsToAgent", () => {
void agent.streamFn?.(model, context, {});
expect(calls).toHaveLength(1);
expect(calls[0]?.transport).toBeUndefined();
expect(calls[0]?.transport).toBe("auto");
});
it("lets runtime options override OpenAI default transport", () => {
const { calls, agent } = createOptionsCaptureAgent();
applyExtraParamsToAgent(agent, undefined, "openai", "gpt-5");
const model = {
api: "openai-responses",
provider: "openai",
id: "gpt-5",
} as Model<"openai-responses">;
const context: Context = { messages: [] };
void agent.streamFn?.(model, context, { transport: "sse" });
expect(calls).toHaveLength(1);
expect(calls[0]?.transport).toBe("sse");
});
it("allows forcing Codex transport to SSE", () => {
@@ -878,6 +895,7 @@ describe("applyExtraParamsToAgent", () => {
contextWindow: 128_000,
maxTokens: 16_384,
compat: { supportsStore: false },
<<<<<<< HEAD
} as Model<"openai-responses"> & { compat?: { supportsStore?: boolean } },
});
expect(payload.store).toBe(false);

View File

@@ -319,6 +319,15 @@ function createCodexDefaultTransportWrapper(baseStreamFn: StreamFn | undefined):
});
}
/**
 * Wrap a stream function so `openai/*` models default to the "auto" transport
 * (WebSocket-first with SSE fallback — see the provider docs in this commit)
 * whenever the caller does not specify one. A runtime `options.transport`
 * value always takes precedence over the injected default.
 */
function createOpenAIDefaultTransportWrapper(baseStreamFn: StreamFn | undefined): StreamFn {
// Fall back to plain HTTP streaming when no base stream function was wired up.
const underlying = baseStreamFn ?? streamSimple;
return (model, context, options) =>
underlying(model, context, {
...options,
// Preserve an explicit caller choice ("sse"/"websocket"); otherwise "auto".
transport: options?.transport ?? "auto",
});
}
function isAnthropic1MModel(modelId: string): boolean {
const normalized = modelId.trim().toLowerCase();
return ANTHROPIC_1M_MODEL_PREFIXES.some((prefix) => normalized.startsWith(prefix));
@@ -740,6 +749,9 @@ export function applyExtraParamsToAgent(
if (provider === "openai-codex") {
// Default Codex to WebSocket-first when nothing else specifies transport.
agent.streamFn = createCodexDefaultTransportWrapper(agent.streamFn);
} else if (provider === "openai") {
// Default OpenAI Responses to WebSocket-first with transparent SSE fallback.
agent.streamFn = createOpenAIDefaultTransportWrapper(agent.streamFn);
}
const override =
extraParamsOverride && Object.keys(extraParamsOverride).length > 0

View File

@@ -42,6 +42,7 @@ import { resolveImageSanitizationLimits } from "../../image-sanitization.js";
import { resolveModelAuthMode } from "../../model-auth.js";
import { normalizeProviderId, resolveDefaultModelForAgent } from "../../model-selection.js";
import { createOllamaStreamFn, OLLAMA_NATIVE_BASE_URL } from "../../ollama-stream.js";
import { createOpenAIWebSocketStreamFn, releaseWsSession } from "../../openai-ws-stream.js";
import { resolveOwnerDisplaySetting } from "../../owner-display.js";
import {
isCloudCodeAssistFormatError,
@@ -866,6 +867,16 @@ export async function runEmbeddedAttempt(
typeof providerConfig?.baseUrl === "string" ? providerConfig.baseUrl.trim() : "";
const ollamaBaseUrl = modelBaseUrl || providerBaseUrl || OLLAMA_NATIVE_BASE_URL;
activeSession.agent.streamFn = createOllamaStreamFn(ollamaBaseUrl);
} else if (params.model.api === "openai-responses" && params.provider === "openai") {
const wsApiKey = await params.authStorage.getApiKey(params.provider);
if (wsApiKey) {
activeSession.agent.streamFn = createOpenAIWebSocketStreamFn(wsApiKey, params.sessionId, {
signal: runAbortController.signal,
});
} else {
log.warn(`[ws-stream] no API key for provider=${params.provider}; using HTTP transport`);
activeSession.agent.streamFn = streamSimple;
}
} else {
// Force a stable streamFn reference so vitest can reliably mock @mariozechner/pi-ai.
activeSession.agent.streamFn = streamSimple;
@@ -1548,6 +1559,7 @@ export async function runEmbeddedAttempt(
sessionManager,
});
session?.dispose();
releaseWsSession(params.sessionId);
await sessionLock.release();
}
} finally {