mirror of
https://github.com/openclaw/openclaw.git
synced 2026-03-12 07:20:45 +00:00
feat(agents): make openai responses websocket-first with fallback
This commit is contained in:
@@ -43,6 +43,8 @@ OpenClaw ships with the pi‑ai catalog. These providers require **no**
|
||||
- Optional rotation: `OPENAI_API_KEYS`, `OPENAI_API_KEY_1`, `OPENAI_API_KEY_2`, plus `OPENCLAW_LIVE_OPENAI_KEY` (single override)
|
||||
- Example model: `openai/gpt-5.1-codex`
|
||||
- CLI: `openclaw onboard --auth-choice openai-api-key`
|
||||
- Default transport is `auto` (WebSocket-first, SSE fallback)
|
||||
- Override per model via `agents.defaults.models["openai/<model>"].params.transport` (`"sse"`, `"websocket"`, or `"auto"`)
|
||||
|
||||
```json5
|
||||
{
|
||||
|
||||
@@ -56,12 +56,14 @@ openclaw models auth login --provider openai-codex
|
||||
}
|
||||
```
|
||||
|
||||
### Codex transport default
|
||||
### Transport default
|
||||
|
||||
OpenClaw uses `pi-ai` for model streaming. For `openai-codex/*` models you can set
|
||||
`agents.defaults.models.<provider/model>.params.transport` to select transport:
|
||||
OpenClaw uses `pi-ai` for model streaming. For both `openai/*` and
|
||||
`openai-codex/*`, default transport is `"auto"` (WebSocket-first, then SSE
|
||||
fallback).
|
||||
|
||||
You can set `agents.defaults.models.<provider/model>.params.transport`:
|
||||
|
||||
- Default is `"auto"` (WebSocket-first, then SSE fallback).
|
||||
- `"sse"`: force SSE
|
||||
- `"websocket"`: force WebSocket
|
||||
- `"auto"`: try WebSocket, then fall back to SSE
|
||||
|
||||
528
src/agents/openai-ws-connection.ts
Normal file
528
src/agents/openai-ws-connection.ts
Normal file
@@ -0,0 +1,528 @@
|
||||
/**
|
||||
* OpenAI WebSocket Connection Manager
|
||||
*
|
||||
* Manages a persistent WebSocket connection to the OpenAI Responses API
|
||||
* (wss://api.openai.com/v1/responses) for multi-turn tool-call workflows.
|
||||
*
|
||||
* Features:
|
||||
* - Auto-reconnect with exponential backoff (max 5 retries: 1s/2s/4s/8s/16s)
|
||||
* - Tracks previous_response_id per connection for incremental turns
|
||||
* - Warm-up support (generate: false) to pre-load the connection
|
||||
* - Typed WebSocket event definitions matching the Responses API SSE spec
|
||||
*
|
||||
* @see https://developers.openai.com/api/docs/guides/websocket-mode
|
||||
*/
|
||||
|
||||
import { EventEmitter } from "node:events";
|
||||
import WebSocket from "ws";
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// WebSocket Event Types (Server → Client)
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** A Responses API response resource as carried in lifecycle events. */
export interface ResponseObject {
  /** Response identifier; reusable as `previous_response_id` on later turns. */
  id: string;
  object: "response";
  created_at: number;
  status: "in_progress" | "completed" | "failed" | "cancelled" | "incomplete";
  model: string;
  /** Ordered output items: assistant messages, function calls, reasoning. */
  output: OutputItem[];
  usage?: UsageInfo;
  error?: { code: string; message: string };
}

/** Token accounting for one response. */
export interface UsageInfo {
  input_tokens: number;
  output_tokens: number;
  total_tokens: number;
}

/** One item of a response's `output` array, discriminated on `type`. */
export type OutputItem =
  | {
      type: "message";
      id: string;
      role: "assistant";
      content: Array<{ type: "output_text"; text: string }>;
      status?: "in_progress" | "completed";
    }
  | {
      type: "function_call";
      id: string;
      call_id: string;
      name: string;
      arguments: string;
      status?: "in_progress" | "completed";
    }
  | {
      type: "reasoning";
      id: string;
      content?: string;
      summary?: string;
    };

/** Lifecycle event: a new response was created. */
export interface ResponseCreatedEvent {
  type: "response.created";
  response: ResponseObject;
}

/** Lifecycle event: the response is being generated. */
export interface ResponseInProgressEvent {
  type: "response.in_progress";
  response: ResponseObject;
}

/** Terminal lifecycle event: generation finished successfully. */
export interface ResponseCompletedEvent {
  type: "response.completed";
  response: ResponseObject;
}

/** Terminal lifecycle event: generation failed. */
export interface ResponseFailedEvent {
  type: "response.failed";
  response: ResponseObject;
}

/** A new output item appeared at `output_index`. */
export interface OutputItemAddedEvent {
  type: "response.output_item.added";
  output_index: number;
  item: OutputItem;
}

/** The output item at `output_index` is final. */
export interface OutputItemDoneEvent {
  type: "response.output_item.done";
  output_index: number;
  item: OutputItem;
}

/** A new content part was appended to an output item. */
export interface ContentPartAddedEvent {
  type: "response.content_part.added";
  item_id: string;
  output_index: number;
  content_index: number;
  part: { type: "output_text"; text: string };
}

/** A content part reached its final form. */
export interface ContentPartDoneEvent {
  type: "response.content_part.done";
  item_id: string;
  output_index: number;
  content_index: number;
  part: { type: "output_text"; text: string };
}

/** Incremental text chunk for a content part. */
export interface OutputTextDeltaEvent {
  type: "response.output_text.delta";
  item_id: string;
  output_index: number;
  content_index: number;
  delta: string;
}

/** Full final text for a content part. */
export interface OutputTextDoneEvent {
  type: "response.output_text.done";
  item_id: string;
  output_index: number;
  content_index: number;
  text: string;
}

/** Incremental chunk of a function call's JSON arguments string. */
export interface FunctionCallArgumentsDeltaEvent {
  type: "response.function_call_arguments.delta";
  item_id: string;
  output_index: number;
  call_id: string;
  delta: string;
}

/** Complete JSON arguments string for a function call. */
export interface FunctionCallArgumentsDoneEvent {
  type: "response.function_call_arguments.done";
  item_id: string;
  output_index: number;
  call_id: string;
  arguments: string;
}

/** Rate-limit snapshot pushed by the server. */
export interface RateLimitUpdatedEvent {
  type: "rate_limits.updated";
  rate_limits: Array<{
    name: string;
    limit: number;
    remaining: number;
    reset_seconds: number;
  }>;
}

/** Server-reported request or protocol error. */
export interface ErrorEvent {
  type: "error";
  code: string;
  message: string;
  param?: string;
}

/** Union of every server → client event, discriminated on `type`. */
export type OpenAIWebSocketEvent =
  | ResponseCreatedEvent
  | ResponseInProgressEvent
  | ResponseCompletedEvent
  | ResponseFailedEvent
  | OutputItemAddedEvent
  | OutputItemDoneEvent
  | ContentPartAddedEvent
  | ContentPartDoneEvent
  | OutputTextDeltaEvent
  | OutputTextDoneEvent
  | FunctionCallArgumentsDeltaEvent
  | FunctionCallArgumentsDoneEvent
  | RateLimitUpdatedEvent
  | ErrorEvent;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Client → Server Event Types
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Content parts accepted inside input "message" items. */
export type ContentPart =
  | { type: "input_text"; text: string }
  | { type: "output_text"; text: string }
  | {
      type: "input_image";
      source: { type: "url"; url: string } | { type: "base64"; media_type: string; data: string };
    };

/** One entry of a `response.create` `input` array. */
export type InputItem =
  | {
      type: "message";
      role: "system" | "developer" | "user" | "assistant";
      content: string | ContentPart[];
    }
  | { type: "function_call"; id?: string; call_id?: string; name: string; arguments: string }
  | { type: "function_call_output"; call_id: string; output: string }
  | { type: "reasoning"; content?: string; encrypted_content?: string; summary?: string }
  | { type: "item_reference"; id: string };

/** Tool-selection strategy for a turn. */
export type ToolChoice =
  | "auto"
  | "none"
  | "required"
  | { type: "function"; function: { name: string } };

/** Function tool definition sent with response.create. */
export interface FunctionToolDefinition {
  type: "function";
  function: {
    name: string;
    description?: string;
    parameters?: Record<string, unknown>;
  };
}

/** Standard response.create event payload (full turn) */
export interface ResponseCreateEvent {
  type: "response.create";
  model: string;
  store?: boolean;
  stream?: boolean;
  input?: string | InputItem[];
  instructions?: string;
  tools?: FunctionToolDefinition[];
  tool_choice?: ToolChoice;
  context_management?: unknown;
  previous_response_id?: string;
  max_output_tokens?: number;
  temperature?: number;
  top_p?: number;
  metadata?: Record<string, string>;
  reasoning?: { effort?: "low" | "medium" | "high"; summary?: "auto" | "concise" | "detailed" };
  truncation?: "auto" | "disabled";
  // Open index signature keeps the payload forward-compatible with
  // parameters not modeled above.
  [key: string]: unknown;
}

/** Warm-up payload: generate: false pre-loads connection without generating output */
export interface WarmUpEvent extends ResponseCreateEvent {
  generate: false;
}

/** Any event this client may send to the server. */
export type ClientEvent = ResponseCreateEvent | WarmUpEvent;
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Connection Manager
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Production Responses API WebSocket endpoint. */
const OPENAI_WS_URL = "wss://api.openai.com/v1/responses";
/** Default cap on reconnect attempts. */
const MAX_RETRIES = 5;
/** Backoff delays in ms: 1s, 2s, 4s, 8s, 16s */
const BACKOFF_DELAYS_MS = [1000, 2000, 4000, 8000, 16000] as const;

/** Construction-time knobs for OpenAIWebSocketManager. */
export interface OpenAIWebSocketManagerOptions {
  /** Override the default WebSocket URL (useful for testing) */
  url?: string;
  /** Maximum number of reconnect attempts (default: 5) */
  maxRetries?: number;
  /** Custom backoff delays in ms (default: [1000, 2000, 4000, 8000, 16000]) */
  backoffDelaysMs?: readonly number[];
}

/** Typed EventEmitter signatures emitted by OpenAIWebSocketManager. */
type InternalEvents = {
  /** A parsed server event (see OpenAIWebSocketEvent). */
  message: [event: OpenAIWebSocketEvent];
  /** Fired each time the socket reaches the OPEN state. */
  open: [];
  /** Fired when the socket closes, with the close code and reason text. */
  close: [code: number, reason: string];
  /** Fired on socket/protocol errors, only when a listener is attached. */
  error: [err: Error];
};
|
||||
|
||||
/**
|
||||
* Manages a persistent WebSocket connection to the OpenAI Responses API.
|
||||
*
|
||||
* Usage:
|
||||
* ```ts
|
||||
* const manager = new OpenAIWebSocketManager();
|
||||
* await manager.connect(apiKey);
|
||||
*
|
||||
* manager.onMessage((event) => {
|
||||
* if (event.type === "response.completed") {
|
||||
* console.log("Response ID:", event.response.id);
|
||||
* }
|
||||
* });
|
||||
*
|
||||
* manager.send({ type: "response.create", model: "gpt-5.2", input: [...] });
|
||||
* ```
|
||||
*/
|
||||
export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
|
||||
private ws: WebSocket | null = null;
|
||||
private apiKey: string | null = null;
|
||||
private retryCount = 0;
|
||||
private retryTimer: NodeJS.Timeout | null = null;
|
||||
private closed = false;
|
||||
|
||||
/** The ID of the most recent completed response on this connection. */
|
||||
private _previousResponseId: string | null = null;
|
||||
|
||||
private readonly wsUrl: string;
|
||||
private readonly maxRetries: number;
|
||||
private readonly backoffDelaysMs: readonly number[];
|
||||
|
||||
constructor(options: OpenAIWebSocketManagerOptions = {}) {
|
||||
super();
|
||||
this.wsUrl = options.url ?? OPENAI_WS_URL;
|
||||
this.maxRetries = options.maxRetries ?? MAX_RETRIES;
|
||||
this.backoffDelaysMs = options.backoffDelaysMs ?? BACKOFF_DELAYS_MS;
|
||||
}
|
||||
|
||||
// ─── Public API ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Returns the previous_response_id from the last completed response,
|
||||
* for use in subsequent response.create events.
|
||||
*/
|
||||
get previousResponseId(): string | null {
|
||||
return this._previousResponseId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Opens a WebSocket connection to the OpenAI Responses API.
|
||||
* Resolves when the connection is established (open event fires).
|
||||
* Rejects if the initial connection fails after max retries.
|
||||
*/
|
||||
connect(apiKey: string): Promise<void> {
|
||||
this.apiKey = apiKey;
|
||||
this.closed = false;
|
||||
this.retryCount = 0;
|
||||
return this._openConnection();
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a typed event to the OpenAI Responses API over the WebSocket.
|
||||
* Throws if the connection is not open.
|
||||
*/
|
||||
send(event: ClientEvent): void {
|
||||
if (!this.ws || this.ws.readyState !== WebSocket.OPEN) {
|
||||
throw new Error(
|
||||
`OpenAIWebSocketManager: cannot send — connection is not open (readyState=${this.ws?.readyState ?? "no socket"})`,
|
||||
);
|
||||
}
|
||||
this.ws.send(JSON.stringify(event));
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers a handler for incoming server-sent WebSocket events.
|
||||
* Returns an unsubscribe function.
|
||||
*/
|
||||
onMessage(handler: (event: OpenAIWebSocketEvent) => void): () => void {
|
||||
this.on("message", handler);
|
||||
return () => {
|
||||
this.off("message", handler);
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the WebSocket is currently open and ready to send.
|
||||
*/
|
||||
isConnected(): boolean {
|
||||
return this.ws !== null && this.ws.readyState === WebSocket.OPEN;
|
||||
}
|
||||
|
||||
/**
|
||||
* Permanently closes the WebSocket connection and disables auto-reconnect.
|
||||
*/
|
||||
close(): void {
|
||||
this.closed = true;
|
||||
this._cancelRetryTimer();
|
||||
if (this.ws) {
|
||||
this.ws.removeAllListeners();
|
||||
if (this.ws.readyState === WebSocket.OPEN || this.ws.readyState === WebSocket.CONNECTING) {
|
||||
this.ws.close(1000, "Client closed");
|
||||
}
|
||||
this.ws = null;
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Internal: Connection Lifecycle ────────────────────────────────────────
|
||||
|
||||
private _openConnection(): Promise<void> {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
if (!this.apiKey) {
|
||||
reject(new Error("OpenAIWebSocketManager: apiKey is required before connecting."));
|
||||
return;
|
||||
}
|
||||
|
||||
const socket = new WebSocket(this.wsUrl, {
|
||||
headers: {
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
"OpenAI-Beta": "responses-websocket=v1",
|
||||
},
|
||||
});
|
||||
|
||||
this.ws = socket;
|
||||
|
||||
const onOpen = () => {
|
||||
this.retryCount = 0;
|
||||
resolve();
|
||||
this.emit("open");
|
||||
};
|
||||
|
||||
const onError = (err: Error) => {
|
||||
// Remove open listener so we don't resolve after an error.
|
||||
socket.off("open", onOpen);
|
||||
// Emit "error" on the manager only when there are listeners; otherwise
|
||||
// the promise rejection below is the primary error channel for this
|
||||
// initial connection failure. (An uncaught "error" event in Node.js
|
||||
// throws synchronously and would prevent the promise from rejecting.)
|
||||
if (this.listenerCount("error") > 0) {
|
||||
this.emit("error", err);
|
||||
}
|
||||
reject(err);
|
||||
};
|
||||
|
||||
const onClose = (code: number, reason: Buffer) => {
|
||||
const reasonStr = reason.toString();
|
||||
this.emit("close", code, reasonStr);
|
||||
|
||||
if (!this.closed) {
|
||||
this._scheduleReconnect();
|
||||
}
|
||||
};
|
||||
|
||||
const onMessage = (data: WebSocket.RawData) => {
|
||||
this._handleMessage(data);
|
||||
};
|
||||
|
||||
socket.once("open", onOpen);
|
||||
socket.on("error", onError);
|
||||
socket.on("close", onClose);
|
||||
socket.on("message", onMessage);
|
||||
});
|
||||
}
|
||||
|
||||
private _scheduleReconnect(): void {
|
||||
if (this.closed) {
|
||||
return;
|
||||
}
|
||||
if (this.retryCount >= this.maxRetries) {
|
||||
this._safeEmitError(
|
||||
new Error(`OpenAIWebSocketManager: max reconnect retries (${this.maxRetries}) exceeded.`),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const delayMs =
|
||||
this.backoffDelaysMs[Math.min(this.retryCount, this.backoffDelaysMs.length - 1)] ?? 1000;
|
||||
this.retryCount++;
|
||||
|
||||
this.retryTimer = setTimeout(() => {
|
||||
if (this.closed) {
|
||||
return;
|
||||
}
|
||||
this._openConnection().catch((err: unknown) => {
|
||||
// onError handler already emitted error event; schedule next retry.
|
||||
void err;
|
||||
this._scheduleReconnect();
|
||||
});
|
||||
}, delayMs);
|
||||
}
|
||||
|
||||
/** Emit an error only if there are listeners; prevents Node.js from crashing
|
||||
* with "unhandled 'error' event" when no one is listening. */
|
||||
private _safeEmitError(err: Error): void {
|
||||
if (this.listenerCount("error") > 0) {
|
||||
this.emit("error", err);
|
||||
}
|
||||
}
|
||||
|
||||
private _cancelRetryTimer(): void {
|
||||
if (this.retryTimer !== null) {
|
||||
clearTimeout(this.retryTimer);
|
||||
this.retryTimer = null;
|
||||
}
|
||||
}
|
||||
|
||||
private _handleMessage(data: WebSocket.RawData): void {
|
||||
let text: string;
|
||||
if (typeof data === "string") {
|
||||
text = data;
|
||||
} else if (Buffer.isBuffer(data)) {
|
||||
text = data.toString("utf8");
|
||||
} else if (data instanceof ArrayBuffer) {
|
||||
text = Buffer.from(data).toString("utf8");
|
||||
} else {
|
||||
// Blob or other — coerce to string
|
||||
text = String(data);
|
||||
}
|
||||
|
||||
let parsed: unknown;
|
||||
try {
|
||||
parsed = JSON.parse(text);
|
||||
} catch {
|
||||
this._safeEmitError(
|
||||
new Error(`OpenAIWebSocketManager: failed to parse message: ${text.slice(0, 200)}`),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!parsed || typeof parsed !== "object" || !("type" in parsed)) {
|
||||
this._safeEmitError(
|
||||
new Error(
|
||||
`OpenAIWebSocketManager: unexpected message shape (no "type" field): ${text.slice(0, 200)}`,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const event = parsed as OpenAIWebSocketEvent;
|
||||
|
||||
// Track previous_response_id on completion
|
||||
if (event.type === "response.completed" && event.response?.id) {
|
||||
this._previousResponseId = event.response.id;
|
||||
}
|
||||
|
||||
this.emit("message", event);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sends a warm-up event to pre-load the connection and model without generating output.
|
||||
* Pass tools/instructions to prime the connection for the upcoming session.
|
||||
*/
|
||||
warmUp(params: { model: string; tools?: FunctionToolDefinition[]; instructions?: string }): void {
|
||||
const event: WarmUpEvent = {
|
||||
type: "response.create",
|
||||
generate: false,
|
||||
model: params.model,
|
||||
...(params.tools ? { tools: params.tools } : {}),
|
||||
...(params.instructions ? { instructions: params.instructions } : {}),
|
||||
};
|
||||
this.send(event);
|
||||
}
|
||||
}
|
||||
680
src/agents/openai-ws-stream.ts
Normal file
680
src/agents/openai-ws-stream.ts
Normal file
@@ -0,0 +1,680 @@
|
||||
/**
|
||||
* OpenAI WebSocket StreamFn Integration
|
||||
*
|
||||
* Wraps `OpenAIWebSocketManager` in a `StreamFn` that can be plugged into the
|
||||
* pi-embedded-runner agent in place of the default `streamSimple` HTTP function.
|
||||
*
|
||||
* Key behaviours:
|
||||
* - Per-session `OpenAIWebSocketManager` (keyed by sessionId)
|
||||
* - Tracks `previous_response_id` to send only incremental tool-result inputs
|
||||
* - Falls back to `streamSimple` (HTTP) if the WebSocket connection fails
|
||||
* - Cleanup helpers for releasing sessions after the run completes
|
||||
*
|
||||
* Complexity budget & risk mitigation:
|
||||
* - **Transport aware**: respects `transport` (`auto` | `websocket` | `sse`)
|
||||
* - **Transparent fallback in `auto` mode**: connect/send failures fall back to
|
||||
* the existing HTTP `streamSimple`; forced `websocket` mode surfaces WS errors
|
||||
* - **Zero shared state**: per-session registry; session cleanup on dispose prevents leaks
|
||||
* - **Full parity**: all generation options (temperature, top_p, max_output_tokens,
|
||||
* tool_choice, reasoning) forwarded identically to the HTTP path
|
||||
*
|
||||
* @see src/agents/openai-ws-connection.ts for the connection manager
|
||||
*/
|
||||
|
||||
import { randomUUID } from "node:crypto";
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import { createAssistantMessageEventStream, streamSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
OpenAIWebSocketManager,
|
||||
type ContentPart,
|
||||
type FunctionToolDefinition,
|
||||
type InputItem,
|
||||
type OpenAIWebSocketManagerOptions,
|
||||
type ResponseObject,
|
||||
} from "./openai-ws-connection.js";
|
||||
import { log } from "./pi-embedded-runner/logger.js";
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Per-session state
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Per-session state for one live WebSocket connection. */
interface WsSession {
  manager: OpenAIWebSocketManager;
  /** Number of messages that were in context.messages at the END of the last streamFn call. */
  lastContextLength: number;
  /** True if the connection has been established at least once. */
  everConnected: boolean;
  /** True if the session is permanently broken (no more reconnect). */
  broken: boolean;
}

/** Module-level registry: sessionId → WsSession */
const wsRegistry = new Map<string, WsSession>();
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Public registry helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Release and close the WebSocket session for the given sessionId.
|
||||
* Call this after the agent run completes to free the connection.
|
||||
*/
|
||||
export function releaseWsSession(sessionId: string): void {
|
||||
const session = wsRegistry.get(sessionId);
|
||||
if (session) {
|
||||
try {
|
||||
session.manager.close();
|
||||
} catch {
|
||||
// Ignore close errors — connection may already be gone.
|
||||
}
|
||||
wsRegistry.delete(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if a live WebSocket session exists for the given sessionId.
|
||||
*/
|
||||
export function hasWsSession(sessionId: string): boolean {
|
||||
const s = wsRegistry.get(sessionId);
|
||||
return !!(s && !s.broken && s.manager.isConnected());
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Message format converters
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Loose view of a pi-ai Message exposing `role`/`content` for manual dispatch. */
type AnyMessage = Message & { role: string; content: unknown };
|
||||
|
||||
/** Convert pi-ai content (string | ContentPart[]) to plain text. */
|
||||
function contentToText(content: unknown): string {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
if (!Array.isArray(content)) {
|
||||
return "";
|
||||
}
|
||||
return (content as Array<{ type?: string; text?: string }>)
|
||||
.filter((p) => p.type === "text" && typeof p.text === "string")
|
||||
.map((p) => p.text as string)
|
||||
.join("");
|
||||
}
|
||||
|
||||
/** Convert pi-ai content to OpenAI ContentPart[]. */
|
||||
function contentToOpenAIParts(content: unknown): ContentPart[] {
|
||||
if (typeof content === "string") {
|
||||
return content ? [{ type: "input_text", text: content }] : [];
|
||||
}
|
||||
if (!Array.isArray(content)) {
|
||||
return [];
|
||||
}
|
||||
const parts: ContentPart[] = [];
|
||||
for (const part of content as Array<{
|
||||
type?: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
}>) {
|
||||
if (part.type === "text" && typeof part.text === "string") {
|
||||
parts.push({ type: "input_text", text: part.text });
|
||||
} else if (part.type === "image" && typeof part.data === "string") {
|
||||
parts.push({
|
||||
type: "input_image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: part.mimeType ?? "image/jpeg",
|
||||
data: part.data,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
/** Convert pi-ai tool array to OpenAI FunctionToolDefinition[]. */
|
||||
export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] {
|
||||
if (!tools || tools.length === 0) {
|
||||
return [];
|
||||
}
|
||||
return tools.map((tool) => ({
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: typeof tool.description === "string" ? tool.description : undefined,
|
||||
parameters: (tool.parameters ?? {}) as Record<string, unknown>,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
/**
 * Convert the full pi-ai message history to an OpenAI `input` array.
 * Handles user messages, assistant text+tool-call messages, and tool results.
 *
 * Mapping rules (as implemented below):
 * - user → one "message" item; a single input_text part collapses to a string.
 * - assistant with array content → text blocks accumulate and are flushed as a
 *   "message" item before each "function_call", preserving ordering; thinking
 *   blocks are dropped.
 * - toolResult → one "function_call_output" item keyed by toolCallId.
 */
export function convertMessagesToInputItems(messages: Message[]): InputItem[] {
  const items: InputItem[] = [];

  for (const msg of messages) {
    // Widen to a loose shape; roles are dispatched manually below.
    const m = msg as AnyMessage;

    if (m.role === "user") {
      const parts = contentToOpenAIParts(m.content);
      items.push({
        type: "message",
        role: "user",
        // A lone text part is sent as a bare string for compactness.
        content:
          parts.length === 1 && parts[0]?.type === "input_text"
            ? (parts[0] as { type: "input_text"; text: string }).text
            : parts,
      });
      continue;
    }

    if (m.role === "assistant") {
      const content = m.content;
      if (Array.isArray(content)) {
        // Collect text blocks and tool calls separately
        const textParts: string[] = [];
        for (const block of content as Array<{
          type?: string;
          text?: string;
          id?: string;
          name?: string;
          arguments?: Record<string, unknown>;
          thinking?: string;
        }>) {
          if (block.type === "text" && typeof block.text === "string") {
            textParts.push(block.text);
          } else if (block.type === "thinking" && typeof block.thinking === "string") {
            // Skip thinking blocks — not sent back to the model
          } else if (block.type === "toolCall") {
            // Push accumulated text first (keeps text-before-call ordering)
            if (textParts.length > 0) {
              items.push({
                type: "message",
                role: "assistant",
                content: textParts.join(""),
              });
              textParts.length = 0;
            }
            // Push function_call item; synthesize a call_id if the block has no id
            items.push({
              type: "function_call",
              call_id: typeof block.id === "string" ? block.id : `call_${randomUUID()}`,
              name: block.name ?? "",
              arguments:
                typeof block.arguments === "string"
                  ? block.arguments
                  : JSON.stringify(block.arguments ?? {}),
            });
          }
        }
        // Flush any trailing text after the last tool call
        if (textParts.length > 0) {
          items.push({
            type: "message",
            role: "assistant",
            content: textParts.join(""),
          });
        }
      } else {
        // Plain-string (or unknown) content → single assistant message if non-empty
        const text = contentToText(m.content);
        if (text) {
          items.push({
            type: "message",
            role: "assistant",
            content: text,
          });
        }
      }
      continue;
    }

    if (m.role === "toolResult") {
      const tr = m as unknown as {
        toolCallId: string;
        content: unknown;
        isError: boolean;
      };
      // NOTE(review): `isError` is read into the shape but never forwarded —
      // the function_call_output item carries no error flag; confirm intentional.
      const outputText = contentToText(tr.content);
      items.push({
        type: "function_call_output",
        call_id: tr.toolCallId,
        output: outputText,
      });
      continue;
    }
  }

  return items;
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Response object → AssistantMessage
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
export function buildAssistantMessageFromResponse(
|
||||
response: ResponseObject,
|
||||
modelInfo: { api: string; provider: string; id: string },
|
||||
): AssistantMessage {
|
||||
const content: (TextContent | ToolCall)[] = [];
|
||||
|
||||
for (const item of response.output ?? []) {
|
||||
if (item.type === "message") {
|
||||
for (const part of item.content ?? []) {
|
||||
if (part.type === "output_text" && part.text) {
|
||||
content.push({ type: "text", text: part.text });
|
||||
}
|
||||
}
|
||||
} else if (item.type === "function_call") {
|
||||
content.push({
|
||||
type: "toolCall",
|
||||
id: item.call_id,
|
||||
name: item.name,
|
||||
arguments: (() => {
|
||||
try {
|
||||
return JSON.parse(item.arguments) as Record<string, unknown>;
|
||||
} catch {
|
||||
return {} as Record<string, unknown>;
|
||||
}
|
||||
})(),
|
||||
});
|
||||
}
|
||||
// "reasoning" items are informational only; skip.
|
||||
}
|
||||
|
||||
const hasToolCalls = content.some((c) => c.type === "toolCall");
|
||||
const stopReason: StopReason = hasToolCalls ? "toolUse" : "stop";
|
||||
|
||||
const usage: Usage = {
|
||||
input: response.usage?.input_tokens ?? 0,
|
||||
output: response.usage?.output_tokens ?? 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: response.usage?.total_tokens ?? 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
|
||||
return {
|
||||
role: "assistant",
|
||||
content,
|
||||
stopReason,
|
||||
api: modelInfo.api,
|
||||
provider: modelInfo.provider,
|
||||
model: modelInfo.id,
|
||||
usage,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// StreamFn factory
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Options accepted by createOpenAIWebSocketStreamFn. */
export interface OpenAIWebSocketStreamOptions {
  /** Manager options (url override, retry counts, etc.) */
  managerOptions?: OpenAIWebSocketManagerOptions;
  /** Abort signal forwarded from the run. */
  signal?: AbortSignal;
}
|
||||
|
||||
type WsTransport = "sse" | "websocket" | "auto";
|
||||
|
||||
function resolveWsTransport(options: Parameters<StreamFn>[2]): WsTransport {
|
||||
const transport = (options as { transport?: unknown } | undefined)?.transport;
|
||||
return transport === "sse" || transport === "websocket" || transport === "auto"
|
||||
? transport
|
||||
: "auto";
|
||||
}
|
||||
|
||||
/**
 * Creates a `StreamFn` backed by a persistent WebSocket connection to the
 * OpenAI Responses API. The first call for a given `sessionId` opens the
 * connection; subsequent calls reuse it, sending only incremental tool-result
 * inputs with `previous_response_id`.
 *
 * If the WebSocket connection is unavailable, the function falls back to the
 * standard `streamSimple` HTTP path and logs a warning. When the caller forces
 * `transport: "websocket"`, connection/send failures are surfaced as errors
 * instead of silently falling back.
 *
 * @param apiKey OpenAI API key
 * @param sessionId Agent session ID (used as the registry key)
 * @param opts Optional manager + abort signal overrides
 * @returns A `StreamFn` that emits start/text_delta/done (or error) events.
 */
export function createOpenAIWebSocketStreamFn(
  apiKey: string,
  sessionId: string,
  opts: OpenAIWebSocketStreamOptions = {},
): StreamFn {
  return (model, context, options) => {
    const eventStream = createAssistantMessageEventStream();

    // The whole turn runs asynchronously; the event stream is returned to the
    // caller immediately and fed from the microtask scheduled below.
    const run = async () => {
      const transport = resolveWsTransport(options);
      if (transport === "sse") {
        // Caller explicitly forced SSE: skip the WebSocket path entirely.
        return fallbackToHttp(model, context, options, eventStream, opts.signal);
      }

      // ── 1. Get or create session state ──────────────────────────────────
      let session = wsRegistry.get(sessionId);

      if (!session) {
        const manager = new OpenAIWebSocketManager(opts.managerOptions);
        session = {
          manager,
          lastContextLength: 0,
          everConnected: false,
          broken: false,
        };
        wsRegistry.set(sessionId, session);
      }

      // ── 2. Ensure connection is open ─────────────────────────────────────
      if (!session.manager.isConnected() && !session.broken) {
        try {
          await session.manager.connect(apiKey);
          session.everConnected = true;
          log.debug(`[ws-stream] connected for session=${sessionId}`);
        } catch (connErr) {
          // Cancel any background reconnect attempts before marking as broken.
          try {
            session.manager.close();
          } catch {
            /* ignore */
          }
          session.broken = true;
          wsRegistry.delete(sessionId);
          if (transport === "websocket") {
            // Forced-WebSocket mode: surface the failure instead of degrading.
            throw connErr instanceof Error ? connErr : new Error(String(connErr));
          }
          log.warn(
            `[ws-stream] WebSocket connect failed for session=${sessionId}; falling back to HTTP. error=${String(connErr)}`,
          );
          // Fall back to HTTP immediately
          return fallbackToHttp(model, context, options, eventStream, opts.signal);
        }
      }

      if (session.broken || !session.manager.isConnected()) {
        if (transport === "websocket") {
          throw new Error("WebSocket session disconnected");
        }
        log.warn(`[ws-stream] session=${sessionId} broken/disconnected; falling back to HTTP`);
        // Clean up stale session to prevent next turn from using stale
        // previousResponseId / lastContextLength after a mid-request drop.
        try {
          session.manager.close();
        } catch {
          /* ignore */
        }
        wsRegistry.delete(sessionId);
        return fallbackToHttp(model, context, options, eventStream, opts.signal);
      }

      // ── 3. Compute incremental vs full input ─────────────────────────────
      const prevResponseId = session.manager.previousResponseId;
      let inputItems: InputItem[];

      if (prevResponseId && session.lastContextLength > 0) {
        // Subsequent turn: only send new messages (tool results) since last call
        const newMessages = context.messages.slice(session.lastContextLength);
        // Filter to only tool results — the assistant message is already in server context
        const toolResults = newMessages.filter((m) => (m as AnyMessage).role === "toolResult");
        if (toolResults.length === 0) {
          // Shouldn't happen in a well-formed turn, but fall back to full context
          log.debug(
            `[ws-stream] session=${sessionId}: no new tool results found; sending full context`,
          );
          inputItems = buildFullInput(context);
        } else {
          inputItems = convertMessagesToInputItems(toolResults);
        }
        log.debug(
          `[ws-stream] session=${sessionId}: incremental send (${inputItems.length} tool results) previous_response_id=${prevResponseId}`,
        );
      } else {
        // First turn: send full context
        inputItems = buildFullInput(context);
        log.debug(
          `[ws-stream] session=${sessionId}: full context send (${inputItems.length} items)`,
        );
      }

      // ── 4. Build & send response.create ──────────────────────────────────
      const tools = convertTools(context.tools);

      // Forward generation options that the HTTP path (openai-responses provider) also uses.
      // Cast to record since SimpleStreamOptions carries openai-specific fields as unknown.
      const streamOpts = options as
        | (Record<string, unknown> & {
            temperature?: number;
            maxTokens?: number;
            topP?: number;
            toolChoice?: unknown;
          })
        | undefined;
      const extraParams: Record<string, unknown> = {};
      if (streamOpts?.temperature !== undefined) {
        extraParams.temperature = streamOpts.temperature;
      }
      // NOTE(review): truthiness check — maxTokens=0 is treated as unset here,
      // unlike the `!== undefined` checks around it. Confirm intentional.
      if (streamOpts?.maxTokens) {
        extraParams.max_output_tokens = streamOpts.maxTokens;
      }
      if (streamOpts?.topP !== undefined) {
        extraParams.top_p = streamOpts.topP;
      }
      if (streamOpts?.toolChoice !== undefined) {
        extraParams.tool_choice = streamOpts.toolChoice;
      }
      if (streamOpts?.reasoningEffort || streamOpts?.reasoningSummary) {
        const reasoning: { effort?: string; summary?: string } = {};
        if (streamOpts.reasoningEffort !== undefined) {
          reasoning.effort = streamOpts.reasoningEffort as string;
        }
        if (streamOpts.reasoningSummary !== undefined) {
          reasoning.summary = streamOpts.reasoningSummary as string;
        }
        extraParams.reasoning = reasoning;
      }

      // `store: false` keeps server-side state out of the durable store;
      // continuity across turns comes from `previous_response_id` instead.
      const payload: Record<string, unknown> = {
        type: "response.create",
        model: model.id,
        store: false,
        input: inputItems,
        instructions: context.systemPrompt ?? undefined,
        tools: tools.length > 0 ? tools : undefined,
        ...(prevResponseId ? { previous_response_id: prevResponseId } : {}),
        ...extraParams,
      };
      options?.onPayload?.(payload);

      try {
        session.manager.send(payload as Parameters<OpenAIWebSocketManager["send"]>[0]);
      } catch (sendErr) {
        if (transport === "websocket") {
          throw sendErr instanceof Error ? sendErr : new Error(String(sendErr));
        }
        log.warn(
          `[ws-stream] send failed for session=${sessionId}; falling back to HTTP. error=${String(sendErr)}`,
        );
        // Fully reset session state so the next WS turn doesn't use stale
        // previous_response_id or lastContextLength from before the failure.
        try {
          session.manager.close();
        } catch {
          /* ignore */
        }
        wsRegistry.delete(sessionId);
        return fallbackToHttp(model, context, options, eventStream, opts.signal);
      }

      // Emit a zero-usage placeholder so consumers see the turn begin before
      // any server events arrive.
      eventStream.push({
        type: "start",
        partial: {
          role: "assistant",
          content: [],
          stopReason: "stop",
          api: model.api,
          provider: model.provider,
          model: model.id,
          usage: {
            input: 0,
            output: 0,
            cacheRead: 0,
            cacheWrite: 0,
            totalTokens: 0,
            cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
          },
          timestamp: Date.now(),
        },
      });

      // ── 5. Wait for response.completed ───────────────────────────────────
      // Snapshot the context length NOW so a caller appending messages while
      // we await cannot skew the next turn's incremental slice.
      const capturedContextLength = context.messages.length;

      await new Promise<void>((resolve, reject) => {
        // Honour abort signal
        const abortHandler = () => {
          cleanup();
          reject(new Error("aborted"));
        };
        const signal = opts.signal ?? (options as { signal?: AbortSignal } | undefined)?.signal;
        if (signal?.aborted) {
          reject(new Error("aborted"));
          return;
        }
        signal?.addEventListener("abort", abortHandler, { once: true });

        // If the WebSocket drops mid-request, reject so we don't hang forever.
        const closeHandler = (code: number, reason: string) => {
          cleanup();
          reject(
            new Error(`WebSocket closed mid-request (code=${code}, reason=${reason || "unknown"})`),
          );
        };
        session.manager.on("close", closeHandler);

        // Detach every listener exactly once, whichever outcome fires first.
        const cleanup = () => {
          signal?.removeEventListener("abort", abortHandler);
          session.manager.off("close", closeHandler);
          unsubscribe();
        };

        const unsubscribe = session.manager.onMessage((event) => {
          if (event.type === "response.completed") {
            cleanup();
            // Update session state
            session.lastContextLength = capturedContextLength;
            // Build and emit the assistant message
            const assistantMsg = buildAssistantMessageFromResponse(event.response, {
              api: model.api,
              provider: model.provider,
              id: model.id,
            });
            const reason: Extract<StopReason, "stop" | "length" | "toolUse"> =
              assistantMsg.stopReason === "toolUse" ? "toolUse" : "stop";
            // NOTE(review): the success path never calls eventStream.end() —
            // presumably the stream finalises itself on a "done" event; confirm
            // against createAssistantMessageEventStream's contract.
            eventStream.push({ type: "done", reason, message: assistantMsg });
            resolve();
          } else if (event.type === "response.failed") {
            cleanup();
            const errMsg = event.response?.error?.message ?? "Response failed";
            reject(new Error(`OpenAI WebSocket response failed: ${errMsg}`));
          } else if (event.type === "error") {
            cleanup();
            reject(new Error(`OpenAI WebSocket error: ${event.message} (code=${event.code})`));
          } else if (event.type === "response.output_text.delta") {
            // Stream partial text updates for responsive UI
            // NOTE(review): `partial` carries only the latest delta, not the
            // accumulated text so far — confirm downstream consumers
            // accumulate `delta` themselves rather than reading `partial`.
            const partialMsg: AssistantMessage = {
              role: "assistant",
              content: [{ type: "text", text: event.delta }],
              stopReason: "stop",
              api: model.api,
              provider: model.provider,
              model: model.id,
              usage: {
                input: 0,
                output: 0,
                cacheRead: 0,
                cacheWrite: 0,
                totalTokens: 0,
                cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
              },
              timestamp: Date.now(),
            };
            eventStream.push({
              type: "text_delta",
              contentIndex: 0,
              delta: event.delta,
              partial: partialMsg,
            });
          }
        });
      });
    };

    // Run off the current tick so the caller receives the stream before any
    // events (or errors) are pushed into it.
    queueMicrotask(() =>
      run().catch((err) => {
        const errorMessage = err instanceof Error ? err.message : String(err);
        log.warn(`[ws-stream] session=${sessionId} run error: ${errorMessage}`);
        eventStream.push({
          type: "error",
          reason: "error",
          error: {
            role: "assistant" as const,
            content: [],
            stopReason: "error" as StopReason,
            errorMessage,
            api: model.api,
            provider: model.provider,
            model: model.id,
            usage: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              totalTokens: 0,
              cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
            },
            timestamp: Date.now(),
          },
        });
        eventStream.end();
      }),
    );

    return eventStream;
  };
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Helpers
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Build full input items from context (system prompt is passed via `instructions` field). */
|
||||
function buildFullInput(context: Context): InputItem[] {
|
||||
return convertMessagesToInputItems(context.messages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fall back to HTTP (`streamSimple`) and pipe events into the existing stream.
|
||||
* This is called when the WebSocket is broken or unavailable.
|
||||
*/
|
||||
async function fallbackToHttp(
|
||||
model: Parameters<StreamFn>[0],
|
||||
context: Parameters<StreamFn>[1],
|
||||
options: Parameters<StreamFn>[2],
|
||||
eventStream: ReturnType<typeof createAssistantMessageEventStream>,
|
||||
signal?: AbortSignal,
|
||||
): Promise<void> {
|
||||
const mergedOptions = signal ? { ...options, signal } : options;
|
||||
const httpStream = streamSimple(model, context, mergedOptions);
|
||||
for await (const event of httpStream) {
|
||||
eventStream.push(event);
|
||||
}
|
||||
}
|
||||
@@ -544,7 +544,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
expect(calls[0]?.transport).toBe("auto");
|
||||
});
|
||||
|
||||
it("does not set transport defaults for non-Codex providers", () => {
|
||||
it("defaults OpenAI transport to auto (WebSocket-first)", () => {
|
||||
const { calls, agent } = createOptionsCaptureAgent();
|
||||
|
||||
applyExtraParamsToAgent(agent, undefined, "openai", "gpt-5");
|
||||
@@ -558,7 +558,24 @@ describe("applyExtraParamsToAgent", () => {
|
||||
void agent.streamFn?.(model, context, {});
|
||||
|
||||
expect(calls).toHaveLength(1);
|
||||
expect(calls[0]?.transport).toBeUndefined();
|
||||
expect(calls[0]?.transport).toBe("auto");
|
||||
});
|
||||
|
||||
it("lets runtime options override OpenAI default transport", () => {
|
||||
const { calls, agent } = createOptionsCaptureAgent();
|
||||
|
||||
applyExtraParamsToAgent(agent, undefined, "openai", "gpt-5");
|
||||
|
||||
const model = {
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
id: "gpt-5",
|
||||
} as Model<"openai-responses">;
|
||||
const context: Context = { messages: [] };
|
||||
void agent.streamFn?.(model, context, { transport: "sse" });
|
||||
|
||||
expect(calls).toHaveLength(1);
|
||||
expect(calls[0]?.transport).toBe("sse");
|
||||
});
|
||||
|
||||
it("allows forcing Codex transport to SSE", () => {
|
||||
@@ -878,6 +895,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
contextWindow: 128_000,
|
||||
maxTokens: 16_384,
|
||||
compat: { supportsStore: false },
|
||||
<<<<<<< HEAD
|
||||
} as Model<"openai-responses"> & { compat?: { supportsStore?: boolean } },
|
||||
});
|
||||
expect(payload.store).toBe(false);
|
||||
|
||||
@@ -319,6 +319,15 @@ function createCodexDefaultTransportWrapper(baseStreamFn: StreamFn | undefined):
|
||||
});
|
||||
}
|
||||
|
||||
function createOpenAIDefaultTransportWrapper(baseStreamFn: StreamFn | undefined): StreamFn {
|
||||
const underlying = baseStreamFn ?? streamSimple;
|
||||
return (model, context, options) =>
|
||||
underlying(model, context, {
|
||||
...options,
|
||||
transport: options?.transport ?? "auto",
|
||||
});
|
||||
}
|
||||
|
||||
function isAnthropic1MModel(modelId: string): boolean {
|
||||
const normalized = modelId.trim().toLowerCase();
|
||||
return ANTHROPIC_1M_MODEL_PREFIXES.some((prefix) => normalized.startsWith(prefix));
|
||||
@@ -740,6 +749,9 @@ export function applyExtraParamsToAgent(
|
||||
if (provider === "openai-codex") {
|
||||
// Default Codex to WebSocket-first when nothing else specifies transport.
|
||||
agent.streamFn = createCodexDefaultTransportWrapper(agent.streamFn);
|
||||
} else if (provider === "openai") {
|
||||
// Default OpenAI Responses to WebSocket-first with transparent SSE fallback.
|
||||
agent.streamFn = createOpenAIDefaultTransportWrapper(agent.streamFn);
|
||||
}
|
||||
const override =
|
||||
extraParamsOverride && Object.keys(extraParamsOverride).length > 0
|
||||
|
||||
@@ -42,6 +42,7 @@ import { resolveImageSanitizationLimits } from "../../image-sanitization.js";
|
||||
import { resolveModelAuthMode } from "../../model-auth.js";
|
||||
import { normalizeProviderId, resolveDefaultModelForAgent } from "../../model-selection.js";
|
||||
import { createOllamaStreamFn, OLLAMA_NATIVE_BASE_URL } from "../../ollama-stream.js";
|
||||
import { createOpenAIWebSocketStreamFn, releaseWsSession } from "../../openai-ws-stream.js";
|
||||
import { resolveOwnerDisplaySetting } from "../../owner-display.js";
|
||||
import {
|
||||
isCloudCodeAssistFormatError,
|
||||
@@ -866,6 +867,16 @@ export async function runEmbeddedAttempt(
|
||||
typeof providerConfig?.baseUrl === "string" ? providerConfig.baseUrl.trim() : "";
|
||||
const ollamaBaseUrl = modelBaseUrl || providerBaseUrl || OLLAMA_NATIVE_BASE_URL;
|
||||
activeSession.agent.streamFn = createOllamaStreamFn(ollamaBaseUrl);
|
||||
} else if (params.model.api === "openai-responses" && params.provider === "openai") {
|
||||
const wsApiKey = await params.authStorage.getApiKey(params.provider);
|
||||
if (wsApiKey) {
|
||||
activeSession.agent.streamFn = createOpenAIWebSocketStreamFn(wsApiKey, params.sessionId, {
|
||||
signal: runAbortController.signal,
|
||||
});
|
||||
} else {
|
||||
log.warn(`[ws-stream] no API key for provider=${params.provider}; using HTTP transport`);
|
||||
activeSession.agent.streamFn = streamSimple;
|
||||
}
|
||||
} else {
|
||||
// Force a stable streamFn reference so vitest can reliably mock @mariozechner/pi-ai.
|
||||
activeSession.agent.streamFn = streamSimple;
|
||||
@@ -1548,6 +1559,7 @@ export async function runEmbeddedAttempt(
|
||||
sessionManager,
|
||||
});
|
||||
session?.dispose();
|
||||
releaseWsSession(params.sessionId);
|
||||
await sessionLock.release();
|
||||
}
|
||||
} finally {
|
||||
|
||||
Reference in New Issue
Block a user