refactor(openai): move native transport policy into extension

This commit is contained in:
Peter Steinberger
2026-04-04 04:25:36 +01:00
parent 585b1c9413
commit eb9051cc7c
21 changed files with 1310 additions and 305 deletions

View File

@@ -501,6 +501,68 @@ describe("openai transport stream", () => {
expect(params.tools?.[0]).not.toHaveProperty("strict");
});
// When the model points at the native OpenAI endpoint, turn metadata handed to
// buildOpenAIResponsesParams must be surfaced verbatim on params.metadata.
it("adds native OpenAI turn metadata on direct Responses routes", () => {
const params = buildOpenAIResponsesParams(
{
id: "gpt-5.4",
name: "GPT-5.4",
api: "openai-responses",
provider: "openai",
baseUrl: "https://api.openai.com/v1",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 200000,
maxTokens: 8192,
} satisfies Model<"openai-responses">,
{
systemPrompt: "system",
messages: [],
tools: [],
} as never,
{ sessionId: "session-123" } as never,
// Turn metadata as supplied by the transport layer's turn-state resolver.
{
openclaw_session_id: "session-123",
openclaw_turn_id: "turn-123",
openclaw_turn_attempt: "1",
openclaw_transport: "stream",
},
) as { metadata?: Record<string, string> };
expect(params.metadata).toMatchObject({
openclaw_session_id: "session-123",
openclaw_turn_id: "turn-123",
openclaw_turn_attempt: "1",
openclaw_transport: "stream",
});
});
// Proxy-style base URLs get no implicit turn metadata: when the caller passes
// undefined, the built params must omit the `metadata` key entirely.
it("leaves proxy-like OpenAI Responses routes without native turn metadata by default", () => {
const params = buildOpenAIResponsesParams(
{
id: "custom-model",
name: "Custom Model",
api: "openai-responses",
provider: "openai",
baseUrl: "https://proxy.example.com/v1",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 200000,
maxTokens: 8192,
} satisfies Model<"openai-responses">,
{
systemPrompt: "system",
messages: [],
tools: [],
} as never,
{ sessionId: "session-123" } as never,
undefined,
) as { metadata?: Record<string, string> };
expect(params).not.toHaveProperty("metadata");
});
it("gates responses service_tier to native OpenAI endpoints", () => {
const nativeParams = buildOpenAIResponsesParams(
{

View File

@@ -1,3 +1,4 @@
import { randomUUID } from "node:crypto";
import type { StreamFn } from "@mariozechner/pi-agent-core";
import {
calculateCost,
@@ -18,6 +19,8 @@ import type {
ResponseInput,
ResponseInputMessageContentList,
} from "openai/resources/responses/responses.js";
import { resolveProviderTransportTurnStateWithPlugin } from "../plugins/provider-runtime.js";
import type { ProviderRuntimeModel } from "../plugins/types.js";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./copilot-dynamic-headers.js";
import { resolveOpenAICompletionsCompatDefaultsFromCapabilities } from "./openai-completions-compat.js";
import {
@@ -27,7 +30,7 @@ import {
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
import { transformTransportMessages } from "./transport-message-transform.js";
import { sanitizeTransportPayloadText } from "./transport-stream-shared.js";
import { mergeTransportMetadata, sanitizeTransportPayloadText } from "./transport-stream-shared.js";
const DEFAULT_AZURE_OPENAI_API_VERSION = "2024-12-01-preview";
@@ -561,6 +564,7 @@ function buildOpenAIClientHeaders(
model: Model<Api>,
context: Context,
optionHeaders?: Record<string, string>,
turnHeaders?: Record<string, string>,
): Record<string, string> {
const headers = { ...model.headers };
if (model.provider === "github-copilot") {
@@ -575,20 +579,47 @@ function buildOpenAIClientHeaders(
if (optionHeaders) {
Object.assign(headers, optionHeaders);
}
if (turnHeaders) {
Object.assign(headers, turnHeaders);
}
return headers;
}
/**
 * Asks the provider plugin (if one is registered) for per-turn transport state
 * (headers/metadata) covering this model, session, turn, attempt, and transport.
 */
function resolveProviderTransportTurnState(
  model: Model<Api>,
  params: {
    sessionId?: string;
    turnId: string;
    attempt: number;
    transport: "stream" | "websocket";
  },
) {
  // Build the plugin context once so the call site below stays flat.
  const context = {
    provider: model.provider,
    modelId: model.id,
    model: model as ProviderRuntimeModel,
    sessionId: params.sessionId,
    turnId: params.turnId,
    attempt: params.attempt,
    transport: params.transport,
  };
  return resolveProviderTransportTurnStateWithPlugin({ provider: model.provider, context });
}
function createOpenAIResponsesClient(
model: Model<Api>,
context: Context,
apiKey: string,
optionHeaders?: Record<string, string>,
turnHeaders?: Record<string, string>,
) {
return new OpenAI({
apiKey,
baseURL: model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders),
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders, turnHeaders),
fetch: buildGuardedModelFetch(model),
});
}
@@ -617,12 +648,30 @@ export function createOpenAIResponsesTransportStreamFn(): StreamFn {
};
try {
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createOpenAIResponsesClient(model, context, apiKey, options?.headers);
let params = buildOpenAIResponsesParams(model, context, options as OpenAIResponsesOptions);
const turnState = resolveProviderTransportTurnState(model, {
sessionId: options?.sessionId,
turnId: randomUUID(),
attempt: 1,
transport: "stream",
});
const client = createOpenAIResponsesClient(
model,
context,
apiKey,
options?.headers,
turnState?.headers,
);
let params = buildOpenAIResponsesParams(
model,
context,
options as OpenAIResponsesOptions,
turnState?.metadata,
);
const nextParams = await options?.onPayload?.(params, model);
if (nextParams !== undefined) {
params = nextParams as typeof params;
}
params = mergeTransportMetadata(params, turnState?.metadata);
const responseStream = (await client.responses.create(
params as never,
options?.signal ? { signal: options.signal } : undefined,
@@ -675,6 +724,7 @@ export function buildOpenAIResponsesParams(
model: Model<Api>,
context: Context,
options: OpenAIResponsesOptions | undefined,
metadata?: Record<string, string>,
) {
const compat = getCompat(model as OpenAIModeModel);
const supportsDeveloperRole =
@@ -695,6 +745,7 @@ export function buildOpenAIResponsesParams(
stream: true,
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
...(metadata ? { metadata } : {}),
};
if (options?.maxTokens) {
params.max_output_tokens = options.maxTokens;
@@ -749,18 +800,32 @@ export function createAzureOpenAIResponsesTransportStreamFn(): StreamFn {
};
try {
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
const client = createAzureOpenAIClient(model, context, apiKey, options?.headers);
const turnState = resolveProviderTransportTurnState(model, {
sessionId: options?.sessionId,
turnId: randomUUID(),
attempt: 1,
transport: "stream",
});
const client = createAzureOpenAIClient(
model,
context,
apiKey,
options?.headers,
turnState?.headers,
);
const deploymentName = resolveAzureDeploymentName(model);
let params = buildAzureOpenAIResponsesParams(
model,
context,
options as OpenAIResponsesOptions | undefined,
deploymentName,
turnState?.metadata,
);
const nextParams = await options?.onPayload?.(params, model);
if (nextParams !== undefined) {
params = nextParams as typeof params;
}
params = mergeTransportMetadata(params, turnState?.metadata);
const responseStream = (await client.responses.create(
params as never,
options?.signal ? { signal: options.signal } : undefined,
@@ -808,12 +873,13 @@ function createAzureOpenAIClient(
context: Context,
apiKey: string,
optionHeaders?: Record<string, string>,
turnHeaders?: Record<string, string>,
) {
return new AzureOpenAI({
apiKey,
apiVersion: resolveAzureOpenAIApiVersion(),
dangerouslyAllowBrowser: true,
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders),
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders, turnHeaders),
baseURL: normalizeAzureBaseUrl(model.baseUrl),
fetch: buildGuardedModelFetch(model),
});
@@ -824,8 +890,9 @@ function buildAzureOpenAIResponsesParams(
context: Context,
options: OpenAIResponsesOptions | undefined,
deploymentName: string,
metadata?: Record<string, string>,
) {
const params = buildOpenAIResponsesParams(model, context, options);
const params = buildOpenAIResponsesParams(model, context, options, metadata);
params.model = deploymentName;
delete params.store;
return params;
@@ -1148,6 +1215,7 @@ type OpenAIResponsesRequestParams = {
stream: true;
prompt_cache_key?: string;
prompt_cache_retention?: "24h";
metadata?: Record<string, string>;
store?: boolean;
max_output_tokens?: number;
temperature?: number;

View File

@@ -252,6 +252,27 @@ describe("OpenAIWebSocketManager", () => {
await connectPromise;
});
// Headers handed to the manager via options must be forwarded into the
// WebSocket handshake request options.
it("merges native session headers into the websocket handshake", async () => {
const manager = buildManager({
headers: {
"x-client-request-id": "session-123",
"x-openclaw-session-id": "session-123",
},
});
const connectPromise = manager.connect("sk-test-key");
const sock = lastSocket();
expect(sock.options).toMatchObject({
headers: expect.objectContaining({
"x-client-request-id": "session-123",
"x-openclaw-session-id": "session-123",
}),
});
sock.simulateOpen();
await connectPromise;
});
it("does not add hidden attribution headers on custom websocket endpoints", async () => {
const manager = buildManager({
url: "wss://proxy.example.com/v1/responses",

View File

@@ -281,6 +281,8 @@ export interface OpenAIWebSocketManagerOptions {
backoffDelaysMs?: readonly number[];
/** Custom socket factory for tests. */
socketFactory?: (url: string, options: ClientOptions) => WebSocket;
/** Extra headers merged into the initial WebSocket handshake request. */
headers?: Record<string, string>;
/** Optional transport overrides for provider-owned auth or TLS wiring. */
request?: ProviderRequestTransportOverrides;
}
@@ -338,6 +340,7 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
private readonly maxRetries: number;
private readonly backoffDelaysMs: readonly number[];
private readonly socketFactory: (url: string, options: ClientOptions) => WebSocket;
private readonly headers?: Record<string, string>;
private readonly request?: ProviderRequestTransportOverrides;
constructor(options: OpenAIWebSocketManagerOptions = {}) {
@@ -347,6 +350,7 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
this.backoffDelaysMs = options.backoffDelaysMs ?? BACKOFF_DELAYS_MS;
this.socketFactory =
options.socketFactory ?? ((url, socketOptions) => new WebSocket(url, socketOptions));
this.headers = options.headers;
this.request = options.request;
}
@@ -454,6 +458,7 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
providerHeaders: {
Authorization: `Bearer ${this.apiKey}`,
"OpenAI-Beta": "responses-websocket=v1",
...this.headers,
},
precedence: "defaults-win",
request: this.request,
@@ -607,7 +612,12 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
* Sends a warm-up event to pre-load the connection and model without generating output.
* Pass tools/instructions to prime the connection for the upcoming session.
*/
warmUp(params: { model: string; tools?: FunctionToolDefinition[]; instructions?: string }): void {
warmUp(params: {
model: string;
tools?: FunctionToolDefinition[];
instructions?: string;
metadata?: Record<string, string>;
}): void {
const event = buildOpenAIWebSocketWarmUpPayload(params);
this.send(event);
}

View File

@@ -30,6 +30,7 @@ export function buildOpenAIWebSocketWarmUpPayload(params: {
model: string;
tools?: FunctionToolDefinition[];
instructions?: string;
metadata?: Record<string, string>;
}): WarmUpEvent {
return {
type: "response.create",
@@ -38,6 +39,7 @@ export function buildOpenAIWebSocketWarmUpPayload(params: {
input: [],
...(params.tools?.length ? { tools: params.tools } : {}),
...(params.instructions ? { instructions: params.instructions } : {}),
...(params.metadata ? { metadata: params.metadata } : {}),
};
}
@@ -47,6 +49,7 @@ export function buildOpenAIWebSocketResponseCreatePayload(params: {
options?: WsOptions;
turnInput: PlannedWsTurnInput;
tools: FunctionToolDefinition[];
metadata?: Record<string, string>;
}): ResponseCreateEvent {
const extraParams: Record<string, unknown> = {};
const streamOpts = params.options;
@@ -108,6 +111,7 @@ export function buildOpenAIWebSocketResponseCreatePayload(params: {
...(params.turnInput.previousResponseId
? { previous_response_id: params.turnInput.previousResponseId }
: {}),
...(params.metadata ? { metadata: params.metadata } : {}),
...extraParams,
};
}

View File

@@ -47,11 +47,17 @@ const { MockManager } = vi.hoisted(() => {
sentEvents: unknown[] = [];
connectCallCount = 0;
closeCallCount = 0;
options: unknown;
// Allow tests to override connect/send behaviour
connectShouldFail = false;
sendShouldFail = false;
constructor(options?: unknown) {
super();
this.options = options;
}
get previousResponseId(): string | null {
return this._previousResponseId;
}
@@ -201,9 +207,9 @@ const { MockManager } = vi.hoisted(() => {
});
// Track if streamSimple (HTTP fallback) was called
const streamSimpleCalls: Array<{ model: unknown; context: unknown }> = [];
const mockStreamSimple = vi.fn((model: unknown, context: unknown) => {
streamSimpleCalls.push({ model, context });
const streamSimpleCalls: Array<{ model: unknown; context: unknown; options?: unknown }> = [];
const mockStreamSimple = vi.fn((model: unknown, context: unknown, options?: unknown) => {
streamSimpleCalls.push({ model, context, options });
const stream = createAssistantMessageEventStream();
queueMicrotask(() => {
const msg = makeFakeAssistantMessage("http fallback response");
@@ -1174,7 +1180,7 @@ describe("createOpenAIWebSocketStreamFn", () => {
MockManager.reset();
streamSimpleCalls.length = 0;
openAIWsStreamTesting.setDepsForTest({
createManager: (() => new MockManager()) as never,
createManager: ((options?: unknown) => new MockManager(options)) as never,
streamSimple: mockStreamSimple,
});
});
@@ -1196,7 +1202,10 @@ describe("createOpenAIWebSocketStreamFn", () => {
releaseWsSession("sess-store-compat");
releaseWsSession("sess-max-tokens-zero");
releaseWsSession("sess-runtime-fallback");
releaseWsSession("sess-turn-metadata-retry");
releaseWsSession("sess-degraded-cooldown");
releaseWsSession("sess-drop");
openAIWsStreamTesting.setWsDegradeCooldownMsForTest();
openAIWsStreamTesting.setDepsForTest();
});
@@ -1447,8 +1456,8 @@ describe("createOpenAIWebSocketStreamFn", () => {
// streamSimple was called as part of HTTP fallback
expect(streamSimpleCalls.length).toBeGreaterThanOrEqual(1);
// manager.close() must be called to cancel background reconnect attempts
expect(MockManager.lastInstance!.closeCallCount).toBeGreaterThanOrEqual(1);
// The failed manager is closed before the replacement session manager is installed.
expect(MockManager.instances.some((instance) => instance.closeCallCount >= 1)).toBe(true);
} finally {
MockManager.globalConnectShouldFail = false;
}
@@ -1550,6 +1559,103 @@ describe("createOpenAIWebSocketStreamFn", () => {
const doneEvent = events.find((event) => event.type === "done");
expect(doneEvent?.message?.content?.[0]?.text).toBe("retry succeeded");
});
// Across a mid-turn websocket reconnect, the turn id must stay stable while the
// attempt counter increments, so providers can correlate retries of one turn.
it("keeps native turn metadata stable across websocket retries and increments attempt", async () => {
const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-turn-metadata-retry");
const stream = streamFn(
modelStub as Parameters<typeof streamFn>[0],
contextStub as Parameters<typeof streamFn>[1],
{ transport: "auto" } as Parameters<typeof streamFn>[2],
);
await new Promise((r) => setImmediate(r));
const firstManager = MockManager.lastInstance!;
// Simulate an abnormal close (1006) to force the retry path onto a new manager.
firstManager.simulateClose(1006, "connection lost");
await new Promise((r) => setImmediate(r));
const secondManager = MockManager.lastInstance!;
secondManager.simulateEvent({
type: "response.completed",
response: makeResponseObject("resp-retried-meta", "retry succeeded"),
});
for await (const _ of await resolveStream(stream)) {
// consume
}
const firstPayload = firstManager.sentEvents[0] as { metadata?: Record<string, string> };
const secondPayload = secondManager.sentEvents[0] as { metadata?: Record<string, string> };
expect(firstPayload.metadata?.openclaw_session_id).toBe("sess-turn-metadata-retry");
expect(firstPayload.metadata?.openclaw_transport).toBe("websocket");
expect(firstPayload.metadata?.openclaw_turn_id).toBeTruthy();
expect(secondPayload.metadata?.openclaw_turn_id).toBe(firstPayload.metadata?.openclaw_turn_id);
expect(firstPayload.metadata?.openclaw_turn_attempt).toBe("1");
expect(secondPayload.metadata?.openclaw_turn_attempt).toBe("2");
});
// After a connect failure the session enters a cool-down: subsequent turns use
// HTTP fallback without re-dialing until the cool-down expires, then websocket
// connects are attempted again.
it("keeps websocket degraded for the session until the cool-down expires", async () => {
openAIWsStreamTesting.setWsDegradeCooldownMsForTest(50);
MockManager.globalConnectShouldFail = true;
try {
const sessionId = "sess-degraded-cooldown";
const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId);
// Phase 1: connect fails, so the turn falls back to HTTP and starts cool-down.
const firstStream = streamFn(
modelStub as Parameters<typeof streamFn>[0],
contextStub as Parameters<typeof streamFn>[1],
{ transport: "auto" } as Parameters<typeof streamFn>[2],
);
void firstStream;
await new Promise((resolve) => setImmediate(resolve));
await new Promise((resolve) => setImmediate(resolve));
expect(streamSimpleCalls.length).toBe(1);
expect(MockManager.instances).toHaveLength(2);
const cooledManager = MockManager.lastInstance!;
expect(cooledManager.connectCallCount).toBe(0);
MockManager.globalConnectShouldFail = false;
// Phase 2: still inside the cool-down window — no new connect attempt,
// HTTP fallback used again even though connects would now succeed.
const secondStream = streamFn(
modelStub as Parameters<typeof streamFn>[0],
contextStub as Parameters<typeof streamFn>[1],
{ transport: "auto" } as Parameters<typeof streamFn>[2],
);
void secondStream;
await new Promise((resolve) => setImmediate(resolve));
await new Promise((resolve) => setImmediate(resolve));
expect(streamSimpleCalls.length).toBe(2);
expect(MockManager.instances).toHaveLength(2);
expect(cooledManager.connectCallCount).toBe(0);
// Phase 3: wait past the 50ms cool-down; websocket connect resumes and no
// further HTTP fallback occurs.
await new Promise((resolve) => setTimeout(resolve, 60));
const thirdStream = streamFn(
modelStub as Parameters<typeof streamFn>[0],
contextStub as Parameters<typeof streamFn>[1],
{ transport: "auto" } as Parameters<typeof streamFn>[2],
);
void thirdStream;
await new Promise((resolve) => setImmediate(resolve));
await new Promise((resolve) => setImmediate(resolve));
expect(cooledManager.connectCallCount).toBe(1);
expect(streamSimpleCalls.length).toBe(2);
cooledManager.simulateEvent({
type: "response.completed",
response: makeResponseObject("resp-after-cooldown", "ws recovered"),
});
await new Promise((resolve) => setImmediate(resolve));
} finally {
MockManager.globalConnectShouldFail = false;
openAIWsStreamTesting.setWsDegradeCooldownMsForTest();
releaseWsSession("sess-degraded-cooldown");
releaseWsSession("sess-turn-metadata-retry");
}
});
it("tracks previous_response_id across turns (incremental send)", async () => {
const sessionId = "sess-incremental";
const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId);

View File

@@ -21,6 +21,7 @@
* @see src/agents/openai-ws-connection.ts for the connection manager
*/
import { randomUUID } from "node:crypto";
import type { StreamFn } from "@mariozechner/pi-agent-core";
import type {
AssistantMessage,
@@ -29,6 +30,11 @@ import type {
StopReason,
} from "@mariozechner/pi-ai";
import * as piAi from "@mariozechner/pi-ai";
import {
resolveProviderTransportTurnStateWithPlugin,
resolveProviderWebSocketSessionPolicyWithPlugin,
} from "../plugins/provider-runtime.js";
import type { ProviderRuntimeModel, ProviderTransportTurnState } from "../plugins/types.js";
import {
getOpenAIWebSocketErrorDetails,
OpenAIWebSocketManager,
@@ -47,6 +53,7 @@ import {
buildAssistantMessageWithZeroUsage,
buildStreamErrorAssistantMessage,
} from "./stream-message-shared.js";
import { mergeTransportMetadata } from "./transport-stream-shared.js";
// ─────────────────────────────────────────────────────────────────────────────
// Per-session state
@@ -62,6 +69,9 @@ interface WsSession {
warmUpAttempted: boolean;
/** True if the session is permanently broken (no more reconnect). */
broken: boolean;
/** Session-scoped cool-down after repeated websocket failures. */
degradedUntil: number | null;
degradeCooldownMs: number;
}
/** Module-level registry: sessionId → WsSession */
@@ -208,6 +218,8 @@ export interface OpenAIWebSocketStreamOptions {
type WsTransport = "sse" | "websocket" | "auto";
const WARM_UP_TIMEOUT_MS = 8_000;
const MAX_AUTO_WS_RUNTIME_RETRIES = 1;
const DEFAULT_WS_DEGRADE_COOLDOWN_MS = 60_000;
let wsDegradeCooldownMsOverride: number | undefined;
class OpenAIWebSocketRuntimeError extends Error {
readonly kind: "disconnect" | "send" | "server";
@@ -247,13 +259,100 @@ function resolveWsWarmup(options: Parameters<StreamFn>[2]): boolean {
return warmup === true;
}
function resetWsSession(params: { sessionId: string; session: WsSession }): void {
function resetWsSession(params: {
session: WsSession;
createManager: () => OpenAIWebSocketManager;
preserveDegradeUntil?: boolean;
}): void {
try {
params.session.manager.close();
} catch {
/* ignore */
}
wsRegistry.delete(params.sessionId);
params.session.manager = params.createManager();
params.session.everConnected = false;
params.session.warmUpAttempted = false;
params.session.broken = false;
if (!params.preserveDegradeUntil) {
params.session.degradedUntil = null;
}
}
/** Starts (or restarts) the session's websocket cool-down window from now. */
function markWsSessionDegraded(session: WsSession): void {
session.degradedUntil = Date.now() + session.degradeCooldownMs;
}
/**
 * Reports whether the session is still inside its websocket cool-down.
 * An expired cool-down is cleared as a side effect so later calls are cheap.
 */
function isWsSessionDegraded(session: WsSession): boolean {
  const until = session.degradedUntil;
  if (!until) {
    return false;
  }
  const stillCooling = until > Date.now();
  if (!stillCooling) {
    // Cool-down has elapsed — reset the marker so the session can reconnect.
    session.degradedUntil = null;
  }
  return stillCooling;
}
/**
 * Builds a websocket manager, layering session-scoped handshake headers over
 * any headers already present in the base manager options.
 */
function createWsManager(
  managerOptions: OpenAIWebSocketManagerOptions | undefined,
  sessionHeaders?: Record<string, string>,
): OpenAIWebSocketManager {
  if (!sessionHeaders) {
    return openAIWsStreamDeps.createManager({ ...managerOptions });
  }
  // Session headers win on key collisions with the base options' headers.
  const headers = { ...managerOptions?.headers, ...sessionHeaders };
  return openAIWsStreamDeps.createManager({ ...managerOptions, headers });
}
/**
 * Resolves plugin-provided per-turn transport state (headers/metadata) for the
 * given model/session/turn/attempt/transport combination, if a plugin opts in.
 */
function resolveProviderTransportTurnState(
  model: Parameters<StreamFn>[0],
  params: {
    sessionId?: string;
    turnId: string;
    attempt: number;
    transport: "stream" | "websocket";
  },
): ProviderTransportTurnState | undefined {
  const { sessionId, turnId, attempt, transport } = params;
  const context = {
    provider: model.provider,
    modelId: model.id,
    model: model as ProviderRuntimeModel,
    sessionId,
    turnId,
    attempt,
    transport,
  };
  return resolveProviderTransportTurnStateWithPlugin({ provider: model.provider, context });
}
/**
 * Fetches the provider plugin's websocket session policy for this model and
 * normalizes the degrade cool-down to a non-negative number of milliseconds.
 */
function resolveWebSocketSessionPolicy(
  model: Parameters<StreamFn>[0],
  sessionId: string,
): { headers?: Record<string, string>; degradeCooldownMs: number } {
  const policy = resolveProviderWebSocketSessionPolicyWithPlugin({
    provider: model.provider,
    context: {
      provider: model.provider,
      modelId: model.id,
      model: model as ProviderRuntimeModel,
      sessionId,
    },
  });
  // Precedence: test override, then plugin policy, then the module default.
  const cooldown =
    wsDegradeCooldownMsOverride ?? policy?.degradeCooldownMs ?? DEFAULT_WS_DEGRADE_COOLDOWN_MS;
  return {
    headers: policy?.headers,
    degradeCooldownMs: cooldown < 0 ? 0 : cooldown,
  };
}
function formatOpenAIWebSocketError(
@@ -311,6 +410,7 @@ async function runWarmUp(params: {
modelId: string;
tools: FunctionToolDefinition[];
instructions?: string;
metadata?: Record<string, string>;
signal?: AbortSignal;
}): Promise<void> {
if (params.signal?.aborted) {
@@ -358,6 +458,7 @@ async function runWarmUp(params: {
model: params.modelId,
tools: params.tools.length > 0 ? params.tools : undefined,
instructions: params.instructions,
...(params.metadata ? { metadata: params.metadata } : {}),
});
});
}
@@ -392,34 +493,55 @@ export function createOpenAIWebSocketStreamFn(
const signal = opts.signal ?? (options as WsOptions | undefined)?.signal;
let emittedStart = false;
let runtimeRetries = 0;
const turnId = randomUUID();
let turnAttempt = 0;
const wsSessionPolicy = resolveWebSocketSessionPolicy(model, sessionId);
const sessionHeaders = wsSessionPolicy.headers;
while (true) {
let session = wsRegistry.get(sessionId);
if (!session) {
const manager = openAIWsStreamDeps.createManager(opts.managerOptions);
const manager = createWsManager(opts.managerOptions, sessionHeaders);
session = {
manager,
lastContextLength: 0,
everConnected: false,
warmUpAttempted: false,
broken: false,
degradedUntil: null,
degradeCooldownMs: wsSessionPolicy.degradeCooldownMs,
};
wsRegistry.set(sessionId, session);
}
if (transport !== "websocket" && isWsSessionDegraded(session)) {
log.debug(
`[ws-stream] session=${sessionId} in websocket cool-down; using HTTP fallback until ${new Date(session.degradedUntil!).toISOString()}`,
);
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
suppressStart: emittedStart,
turnState: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: Math.max(1, turnAttempt),
transport: "stream",
}),
});
}
if (!session.manager.isConnected() && !session.broken) {
try {
await session.manager.connect(apiKey);
session.everConnected = true;
session.degradedUntil = null;
log.debug(`[ws-stream] connected for session=${sessionId}`);
} catch (connErr) {
try {
session.manager.close();
} catch {
/* ignore */
}
session.broken = true;
wsRegistry.delete(sessionId);
markWsSessionDegraded(session);
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
preserveDegradeUntil: true,
});
if (transport === "websocket") {
throw connErr instanceof Error ? connErr : new Error(String(connErr));
}
@@ -428,6 +550,12 @@ export function createOpenAIWebSocketStreamFn(
);
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
suppressStart: emittedStart,
turnState: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: Math.max(1, turnAttempt),
transport: "stream",
}),
});
}
}
@@ -437,9 +565,20 @@ export function createOpenAIWebSocketStreamFn(
throw new Error("WebSocket session disconnected");
}
log.warn(`[ws-stream] session=${sessionId} broken/disconnected; falling back to HTTP`);
resetWsSession({ sessionId, session });
markWsSessionDegraded(session);
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
preserveDegradeUntil: true,
});
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
suppressStart: emittedStart,
turnState: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: Math.max(1, turnAttempt),
transport: "stream",
}),
});
}
@@ -452,6 +591,12 @@ export function createOpenAIWebSocketStreamFn(
modelId: model.id,
tools: convertTools(context.tools),
instructions: context.systemPrompt ?? undefined,
metadata: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: Math.max(1, turnAttempt),
transport: "websocket",
})?.metadata,
signal,
});
log.debug(`[ws-stream] warm-up completed for session=${sessionId}`);
@@ -471,12 +616,18 @@ export function createOpenAIWebSocketStreamFn(
/* ignore */
}
try {
session.manager = createWsManager(opts.managerOptions, sessionHeaders);
await session.manager.connect(apiKey);
session.everConnected = true;
session.degradedUntil = null;
log.debug(`[ws-stream] reconnected after warm-up failure for session=${sessionId}`);
} catch (reconnectErr) {
session.broken = true;
wsRegistry.delete(sessionId);
markWsSessionDegraded(session);
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
preserveDegradeUntil: true,
});
if (transport === "websocket") {
throw reconnectErr instanceof Error
? reconnectErr
@@ -487,6 +638,12 @@ export function createOpenAIWebSocketStreamFn(
);
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
suppressStart: emittedStart,
turnState: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: Math.max(1, turnAttempt),
transport: "stream",
}),
});
}
}
@@ -513,17 +670,27 @@ export function createOpenAIWebSocketStreamFn(
);
}
const payload = buildOpenAIWebSocketResponseCreatePayload({
turnAttempt++;
const turnState = resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: turnAttempt,
transport: "websocket",
});
let payload = buildOpenAIWebSocketResponseCreatePayload({
model,
context,
options: options as WsOptions | undefined,
turnInput,
tools: convertTools(context.tools),
metadata: turnState?.metadata,
}) as Record<string, unknown>;
const nextPayload = options?.onPayload?.(payload, model);
const requestPayload = (nextPayload ?? payload) as Parameters<
OpenAIWebSocketManager["send"]
>[0];
payload = mergeTransportMetadata(
(nextPayload ?? payload) as Record<string, unknown>,
turnState?.metadata,
);
const requestPayload = payload as Parameters<OpenAIWebSocketManager["send"]>[0];
try {
session.manager.send(requestPayload);
@@ -538,16 +705,30 @@ export function createOpenAIWebSocketStreamFn(
log.warn(
`[ws-stream] retrying websocket turn after send failure for session=${sessionId} (${runtimeRetries}/${MAX_AUTO_WS_RUNTIME_RETRIES}). error=${normalizedErr.message}`,
);
resetWsSession({ sessionId, session });
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
});
continue;
}
if (transport !== "websocket") {
log.warn(
`[ws-stream] send failed for session=${sessionId}; falling back to HTTP. error=${normalizedErr.message}`,
);
resetWsSession({ sessionId, session });
markWsSessionDegraded(session);
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
preserveDegradeUntil: true,
});
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
suppressStart: emittedStart,
turnState: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: turnAttempt,
transport: "stream",
}),
});
}
throw normalizedErr;
@@ -680,16 +861,30 @@ export function createOpenAIWebSocketStreamFn(
log.warn(
`[ws-stream] retrying websocket turn after retryable runtime failure for session=${sessionId} (${runtimeRetries}/${MAX_AUTO_WS_RUNTIME_RETRIES}). error=${normalizedErr.message}`,
);
resetWsSession({ sessionId, session });
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
});
continue;
}
if (transport !== "websocket" && !signal?.aborted && !sawWsOutput) {
log.warn(
`[ws-stream] session=${sessionId} runtime failure before output; falling back to HTTP. error=${normalizedErr.message}`,
);
resetWsSession({ sessionId, session });
markWsSessionDegraded(session);
resetWsSession({
session,
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
preserveDegradeUntil: true,
});
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
suppressStart: true,
turnState: resolveProviderTransportTurnState(model, {
sessionId,
turnId,
attempt: turnAttempt,
transport: "stream",
}),
});
}
throw normalizedErr;
@@ -728,11 +923,35 @@ async function fallbackToHttp(
apiKey: string,
eventStream: AssistantMessageEventStreamLike,
signal?: AbortSignal,
fallbackOptions?: { suppressStart?: boolean },
fallbackOptions?: {
suppressStart?: boolean;
turnState?: ProviderTransportTurnState;
},
): Promise<void> {
const baseOnPayload = streamOptions?.onPayload;
const mergedOptions = {
...streamOptions,
apiKey,
...(fallbackOptions?.turnState?.headers
? {
headers: {
...streamOptions?.headers,
...fallbackOptions.turnState.headers,
},
}
: {}),
...(fallbackOptions?.turnState?.metadata
? {
onPayload: async (
payload: unknown,
payloadModel: Parameters<NonNullable<typeof baseOnPayload>>[1],
) => {
const nextPayload = await baseOnPayload?.(payload, payloadModel);
const resolvedPayload = (nextPayload ?? payload) as Record<string, unknown>;
return mergeTransportMetadata(resolvedPayload, fallbackOptions.turnState?.metadata);
},
}
: {}),
...(signal ? { signal } : {}),
};
const httpStream = openAIWsStreamDeps.streamSimple(model, context, mergedOptions);
@@ -753,4 +972,7 @@ export const __testing = {
}
: defaultOpenAIWsStreamDeps;
},
setWsDegradeCooldownMsForTest(nextMs?: number) {
wsDegradeCooldownMsOverride = nextMs;
},
};

View File

@@ -38,6 +38,26 @@ export function mergeTransportHeaders(
return Object.keys(merged).length > 0 ? merged : undefined;
}
/**
 * Merges transport-level metadata into a request payload.
 *
 * Incoming `metadata` keys win over keys already present on
 * `payload.metadata`. A non-object (or array) `payload.metadata` is replaced
 * wholesale. When `metadata` is absent or empty the original payload object
 * is returned unchanged (same reference, no copy).
 */
export function mergeTransportMetadata<T extends Record<string, unknown>>(
  payload: T,
  metadata?: Record<string, string>,
): T {
  // Nothing to merge: hand back the payload untouched.
  if (!metadata || Object.keys(metadata).length === 0) {
    return payload;
  }
  const current = payload.metadata;
  const mergeable =
    typeof current === "object" && current !== null && !Array.isArray(current);
  const base = mergeable ? (current as Record<string, string>) : undefined;
  const combined: Record<string, string> = { ...base, ...metadata };
  return { ...payload, metadata: combined };
}
export function createEmptyTransportUsage(): TransportUsage {
return {
input: 0,

View File

@@ -65,14 +65,18 @@ export type {
ProviderReplaySessionEntry,
ProviderReplaySessionState,
ProviderResolveDynamicModelContext,
ProviderResolveTransportTurnStateContext,
ProviderResolveWebSocketSessionPolicyContext,
ProviderResolvedUsageAuth,
RealtimeTranscriptionProviderPlugin,
ProviderSanitizeReplayHistoryContext,
ProviderTransportTurnState,
ProviderToolSchemaDiagnostic,
ProviderResolveUsageAuthContext,
ProviderRuntimeModel,
ProviderThinkingPolicyContext,
ProviderValidateReplayTurnsContext,
ProviderWebSocketSessionPolicy,
ProviderWrapStreamFnContext,
SpeechProviderPlugin,
} from "./plugin-entry.js";

View File

@@ -49,12 +49,16 @@ import type {
RealtimeTranscriptionProviderPlugin,
ProviderResolvedUsageAuth,
ProviderResolveDynamicModelContext,
ProviderResolveTransportTurnStateContext,
ProviderResolveWebSocketSessionPolicyContext,
ProviderSanitizeReplayHistoryContext,
ProviderTransportTurnState,
ProviderToolSchemaDiagnostic,
ProviderResolveUsageAuthContext,
ProviderRuntimeModel,
ProviderThinkingPolicyContext,
ProviderValidateReplayTurnsContext,
ProviderWebSocketSessionPolicy,
ProviderWrapStreamFnContext,
SpeechProviderPlugin,
PluginCommandContext,
@@ -101,12 +105,16 @@ export type {
ProviderSanitizeReplayHistoryContext,
ProviderResolveUsageAuthContext,
ProviderResolveDynamicModelContext,
ProviderResolveTransportTurnStateContext,
ProviderResolveWebSocketSessionPolicyContext,
ProviderNormalizeResolvedModelContext,
ProviderRuntimeModel,
RealtimeTranscriptionProviderPlugin,
ProviderTransportTurnState,
SpeechProviderPlugin,
ProviderThinkingPolicyContext,
ProviderValidateReplayTurnsContext,
ProviderWebSocketSessionPolicy,
ProviderWrapStreamFnContext,
OpenClawPluginService,
OpenClawPluginServiceContext,

View File

@@ -39,9 +39,13 @@ import type {
ProviderResolveUsageAuthContext,
ProviderPlugin,
ProviderResolveDynamicModelContext,
ProviderResolveTransportTurnStateContext,
ProviderResolveWebSocketSessionPolicyContext,
ProviderRuntimeModel,
ProviderThinkingPolicyContext,
ProviderTransportTurnState,
ProviderValidateReplayTurnsContext,
ProviderWebSocketSessionPolicy,
ProviderWrapStreamFnContext,
} from "./types.js";
@@ -525,6 +529,30 @@ export function wrapProviderStreamFn(params: {
return resolveProviderHookPlugin(params)?.wrapStreamFn?.(params.context) ?? undefined;
}
/**
 * Asks the provider's hook plugin (if any) for per-turn transport state.
 *
 * Looks up the hook plugin for `params.provider` and delegates to its
 * optional `resolveTransportTurnState` hook with the supplied context.
 * Returns `undefined` when no plugin is registered, when the plugin does not
 * implement the hook, or when the hook returns `null`/`undefined`.
 */
export function resolveProviderTransportTurnStateWithPlugin(params: {
  provider: string;
  config?: OpenClawConfig;
  workspaceDir?: string;
  env?: NodeJS.ProcessEnv;
  context: ProviderResolveTransportTurnStateContext;
}): ProviderTransportTurnState | undefined {
  const plugin = resolveProviderHookPlugin(params);
  const turnState = plugin?.resolveTransportTurnState?.(params.context);
  // Normalize a plugin's `null` to `undefined` for callers.
  return turnState ?? undefined;
}
/**
 * Asks the provider's hook plugin (if any) for a WebSocket session policy.
 *
 * Looks up the hook plugin for `params.provider` and delegates to its
 * optional `resolveWebSocketSessionPolicy` hook with the supplied context.
 * Returns `undefined` when no plugin is registered, when the plugin does not
 * implement the hook, or when the hook returns `null`/`undefined`.
 */
export function resolveProviderWebSocketSessionPolicyWithPlugin(params: {
  provider: string;
  config?: OpenClawConfig;
  workspaceDir?: string;
  env?: NodeJS.ProcessEnv;
  context: ProviderResolveWebSocketSessionPolicyContext;
}): ProviderWebSocketSessionPolicy | undefined {
  const plugin = resolveProviderHookPlugin(params);
  const policy = plugin?.resolveWebSocketSessionPolicy?.(params.context);
  // Normalize a plugin's `null` to `undefined` for callers.
  return policy ?? undefined;
}
export async function createProviderEmbeddingProvider(params: {
provider: string;
config?: OpenClawConfig;

View File

@@ -697,6 +697,57 @@ export type ProviderWrapStreamFnContext = ProviderPrepareExtraParamsContext & {
streamFn?: StreamFn;
};
/**
 * Provider-owned transport turn state.
 *
 * Use this for provider-native request headers or metadata that should stay
 * stable across retries while still being attached by generic core transports.
 */
export type ProviderTransportTurnState = {
  // Extra request headers a transport should attach for this turn.
  headers?: Record<string, string>;
  // String key/value metadata a transport should merge into the request
  // payload for this turn (see mergeTransportMetadata).
  metadata?: Record<string, string>;
};
/**
 * Provider-owned request identity for transport turns.
 *
 * Use this when the provider exposes native request/session metadata that must
 * be attached by both HTTP and WebSocket transports.
 */
export type ProviderResolveTransportTurnStateContext = {
  // Provider id the hook is being resolved for.
  provider: string;
  // Id of the model handling the turn.
  modelId: string;
  // Resolved runtime model, when available.
  model?: ProviderRuntimeModel;
  // Session the turn belongs to, when known.
  sessionId?: string;
  // Stable id for this turn across retries.
  turnId: string;
  // Retry attempt counter for the turn — presumably 1-based; confirm at call sites.
  attempt: number;
  // Which generic transport is issuing the request.
  transport: "stream" | "websocket";
};
/**
 * Provider-owned WebSocket session policy.
 *
 * Use this for session-scoped headers or cool-down behavior that should apply
 * before a generic WebSocket transport decides to retry or fall back.
 */
export type ProviderWebSocketSessionPolicy = {
  // Headers to attach to the WebSocket session handshake.
  headers?: Record<string, string>;
  // Cool-down window (milliseconds) after a session is marked degraded,
  // before the transport attempts WebSocket again instead of HTTP fallback.
  degradeCooldownMs?: number;
};
/**
 * Provider-owned WebSocket session policy input.
 *
 * Use this when the provider wants to control native session handshake headers
 * or the post-failure cool-down window for a generic WebSocket transport.
 */
export type ProviderResolveWebSocketSessionPolicyContext = {
  // Provider id the policy is being resolved for.
  provider: string;
  // Id of the model the session will serve.
  modelId: string;
  // Resolved runtime model, when available.
  model?: ProviderRuntimeModel;
  // Session the policy applies to, when known.
  sessionId?: string;
};
/**
* Generic embedding provider shape returned by provider plugins.
*
@@ -1166,6 +1217,26 @@ export type ProviderPlugin = {
* transport implementation.
*/
wrapStreamFn?: (ctx: ProviderWrapStreamFnContext) => StreamFn | null | undefined;
/**
* Provider-owned native transport turn identity.
*
* Use this when a provider wants generic transports to attach provider-native
* request headers or metadata on each turn without hardcoding vendor logic in
* core.
*/
resolveTransportTurnState?: (
ctx: ProviderResolveTransportTurnStateContext,
) => ProviderTransportTurnState | null | undefined;
/**
* Provider-owned WebSocket session policy.
*
* Use this when a provider wants generic WebSocket transports to attach
* native session headers or tune the session-scoped cool-down before HTTP
* fallback.
*/
resolveWebSocketSessionPolicy?: (
ctx: ProviderResolveWebSocketSessionPolicyContext,
) => ProviderWebSocketSessionPolicy | null | undefined;
/**
* Provider-owned embedding provider factory.
*