mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-01 15:30:22 +00:00
refactor(openai): move native transport policy into extension
This commit is contained in:
@@ -501,6 +501,68 @@ describe("openai transport stream", () => {
|
||||
expect(params.tools?.[0]).not.toHaveProperty("strict");
|
||||
});
|
||||
|
||||
it("adds native OpenAI turn metadata on direct Responses routes", () => {
|
||||
const params = buildOpenAIResponsesParams(
|
||||
{
|
||||
id: "gpt-5.4",
|
||||
name: "GPT-5.4",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
} satisfies Model<"openai-responses">,
|
||||
{
|
||||
systemPrompt: "system",
|
||||
messages: [],
|
||||
tools: [],
|
||||
} as never,
|
||||
{ sessionId: "session-123" } as never,
|
||||
{
|
||||
openclaw_session_id: "session-123",
|
||||
openclaw_turn_id: "turn-123",
|
||||
openclaw_turn_attempt: "1",
|
||||
openclaw_transport: "stream",
|
||||
},
|
||||
) as { metadata?: Record<string, string> };
|
||||
|
||||
expect(params.metadata).toMatchObject({
|
||||
openclaw_session_id: "session-123",
|
||||
openclaw_turn_id: "turn-123",
|
||||
openclaw_turn_attempt: "1",
|
||||
openclaw_transport: "stream",
|
||||
});
|
||||
});
|
||||
|
||||
it("leaves proxy-like OpenAI Responses routes without native turn metadata by default", () => {
|
||||
const params = buildOpenAIResponsesParams(
|
||||
{
|
||||
id: "custom-model",
|
||||
name: "Custom Model",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
baseUrl: "https://proxy.example.com/v1",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
} satisfies Model<"openai-responses">,
|
||||
{
|
||||
systemPrompt: "system",
|
||||
messages: [],
|
||||
tools: [],
|
||||
} as never,
|
||||
{ sessionId: "session-123" } as never,
|
||||
undefined,
|
||||
) as { metadata?: Record<string, string> };
|
||||
|
||||
expect(params).not.toHaveProperty("metadata");
|
||||
});
|
||||
|
||||
it("gates responses service_tier to native OpenAI endpoints", () => {
|
||||
const nativeParams = buildOpenAIResponsesParams(
|
||||
{
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { randomUUID } from "node:crypto";
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import {
|
||||
calculateCost,
|
||||
@@ -18,6 +19,8 @@ import type {
|
||||
ResponseInput,
|
||||
ResponseInputMessageContentList,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { resolveProviderTransportTurnStateWithPlugin } from "../plugins/provider-runtime.js";
|
||||
import type { ProviderRuntimeModel } from "../plugins/types.js";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./copilot-dynamic-headers.js";
|
||||
import { resolveOpenAICompletionsCompatDefaultsFromCapabilities } from "./openai-completions-compat.js";
|
||||
import {
|
||||
@@ -27,7 +30,7 @@ import {
|
||||
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
|
||||
import { buildGuardedModelFetch } from "./provider-transport-fetch.js";
|
||||
import { transformTransportMessages } from "./transport-message-transform.js";
|
||||
import { sanitizeTransportPayloadText } from "./transport-stream-shared.js";
|
||||
import { mergeTransportMetadata, sanitizeTransportPayloadText } from "./transport-stream-shared.js";
|
||||
|
||||
const DEFAULT_AZURE_OPENAI_API_VERSION = "2024-12-01-preview";
|
||||
|
||||
@@ -561,6 +564,7 @@ function buildOpenAIClientHeaders(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
optionHeaders?: Record<string, string>,
|
||||
turnHeaders?: Record<string, string>,
|
||||
): Record<string, string> {
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
@@ -575,20 +579,47 @@ function buildOpenAIClientHeaders(
|
||||
if (optionHeaders) {
|
||||
Object.assign(headers, optionHeaders);
|
||||
}
|
||||
if (turnHeaders) {
|
||||
Object.assign(headers, turnHeaders);
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
function resolveProviderTransportTurnState(
|
||||
model: Model<Api>,
|
||||
params: {
|
||||
sessionId?: string;
|
||||
turnId: string;
|
||||
attempt: number;
|
||||
transport: "stream" | "websocket";
|
||||
},
|
||||
) {
|
||||
return resolveProviderTransportTurnStateWithPlugin({
|
||||
provider: model.provider,
|
||||
context: {
|
||||
provider: model.provider,
|
||||
modelId: model.id,
|
||||
model: model as ProviderRuntimeModel,
|
||||
sessionId: params.sessionId,
|
||||
turnId: params.turnId,
|
||||
attempt: params.attempt,
|
||||
transport: params.transport,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function createOpenAIResponsesClient(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
apiKey: string,
|
||||
optionHeaders?: Record<string, string>,
|
||||
turnHeaders?: Record<string, string>,
|
||||
) {
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders),
|
||||
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders, turnHeaders),
|
||||
fetch: buildGuardedModelFetch(model),
|
||||
});
|
||||
}
|
||||
@@ -617,12 +648,30 @@ export function createOpenAIResponsesTransportStreamFn(): StreamFn {
|
||||
};
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createOpenAIResponsesClient(model, context, apiKey, options?.headers);
|
||||
let params = buildOpenAIResponsesParams(model, context, options as OpenAIResponsesOptions);
|
||||
const turnState = resolveProviderTransportTurnState(model, {
|
||||
sessionId: options?.sessionId,
|
||||
turnId: randomUUID(),
|
||||
attempt: 1,
|
||||
transport: "stream",
|
||||
});
|
||||
const client = createOpenAIResponsesClient(
|
||||
model,
|
||||
context,
|
||||
apiKey,
|
||||
options?.headers,
|
||||
turnState?.headers,
|
||||
);
|
||||
let params = buildOpenAIResponsesParams(
|
||||
model,
|
||||
context,
|
||||
options as OpenAIResponsesOptions,
|
||||
turnState?.metadata,
|
||||
);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as typeof params;
|
||||
}
|
||||
params = mergeTransportMetadata(params, turnState?.metadata);
|
||||
const responseStream = (await client.responses.create(
|
||||
params as never,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
@@ -675,6 +724,7 @@ export function buildOpenAIResponsesParams(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options: OpenAIResponsesOptions | undefined,
|
||||
metadata?: Record<string, string>,
|
||||
) {
|
||||
const compat = getCompat(model as OpenAIModeModel);
|
||||
const supportsDeveloperRole =
|
||||
@@ -695,6 +745,7 @@ export function buildOpenAIResponsesParams(
|
||||
stream: true,
|
||||
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
|
||||
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
|
||||
...(metadata ? { metadata } : {}),
|
||||
};
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options.maxTokens;
|
||||
@@ -749,18 +800,32 @@ export function createAzureOpenAIResponsesTransportStreamFn(): StreamFn {
|
||||
};
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createAzureOpenAIClient(model, context, apiKey, options?.headers);
|
||||
const turnState = resolveProviderTransportTurnState(model, {
|
||||
sessionId: options?.sessionId,
|
||||
turnId: randomUUID(),
|
||||
attempt: 1,
|
||||
transport: "stream",
|
||||
});
|
||||
const client = createAzureOpenAIClient(
|
||||
model,
|
||||
context,
|
||||
apiKey,
|
||||
options?.headers,
|
||||
turnState?.headers,
|
||||
);
|
||||
const deploymentName = resolveAzureDeploymentName(model);
|
||||
let params = buildAzureOpenAIResponsesParams(
|
||||
model,
|
||||
context,
|
||||
options as OpenAIResponsesOptions | undefined,
|
||||
deploymentName,
|
||||
turnState?.metadata,
|
||||
);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as typeof params;
|
||||
}
|
||||
params = mergeTransportMetadata(params, turnState?.metadata);
|
||||
const responseStream = (await client.responses.create(
|
||||
params as never,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
@@ -808,12 +873,13 @@ function createAzureOpenAIClient(
|
||||
context: Context,
|
||||
apiKey: string,
|
||||
optionHeaders?: Record<string, string>,
|
||||
turnHeaders?: Record<string, string>,
|
||||
) {
|
||||
return new AzureOpenAI({
|
||||
apiKey,
|
||||
apiVersion: resolveAzureOpenAIApiVersion(),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders),
|
||||
defaultHeaders: buildOpenAIClientHeaders(model, context, optionHeaders, turnHeaders),
|
||||
baseURL: normalizeAzureBaseUrl(model.baseUrl),
|
||||
fetch: buildGuardedModelFetch(model),
|
||||
});
|
||||
@@ -824,8 +890,9 @@ function buildAzureOpenAIResponsesParams(
|
||||
context: Context,
|
||||
options: OpenAIResponsesOptions | undefined,
|
||||
deploymentName: string,
|
||||
metadata?: Record<string, string>,
|
||||
) {
|
||||
const params = buildOpenAIResponsesParams(model, context, options);
|
||||
const params = buildOpenAIResponsesParams(model, context, options, metadata);
|
||||
params.model = deploymentName;
|
||||
delete params.store;
|
||||
return params;
|
||||
@@ -1148,6 +1215,7 @@ type OpenAIResponsesRequestParams = {
|
||||
stream: true;
|
||||
prompt_cache_key?: string;
|
||||
prompt_cache_retention?: "24h";
|
||||
metadata?: Record<string, string>;
|
||||
store?: boolean;
|
||||
max_output_tokens?: number;
|
||||
temperature?: number;
|
||||
|
||||
@@ -252,6 +252,27 @@ describe("OpenAIWebSocketManager", () => {
|
||||
await connectPromise;
|
||||
});
|
||||
|
||||
it("merges native session headers into the websocket handshake", async () => {
|
||||
const manager = buildManager({
|
||||
headers: {
|
||||
"x-client-request-id": "session-123",
|
||||
"x-openclaw-session-id": "session-123",
|
||||
},
|
||||
});
|
||||
const connectPromise = manager.connect("sk-test-key");
|
||||
|
||||
const sock = lastSocket();
|
||||
expect(sock.options).toMatchObject({
|
||||
headers: expect.objectContaining({
|
||||
"x-client-request-id": "session-123",
|
||||
"x-openclaw-session-id": "session-123",
|
||||
}),
|
||||
});
|
||||
|
||||
sock.simulateOpen();
|
||||
await connectPromise;
|
||||
});
|
||||
|
||||
it("does not add hidden attribution headers on custom websocket endpoints", async () => {
|
||||
const manager = buildManager({
|
||||
url: "wss://proxy.example.com/v1/responses",
|
||||
|
||||
@@ -281,6 +281,8 @@ export interface OpenAIWebSocketManagerOptions {
|
||||
backoffDelaysMs?: readonly number[];
|
||||
/** Custom socket factory for tests. */
|
||||
socketFactory?: (url: string, options: ClientOptions) => WebSocket;
|
||||
/** Extra headers merged into the initial WebSocket handshake request. */
|
||||
headers?: Record<string, string>;
|
||||
/** Optional transport overrides for provider-owned auth or TLS wiring. */
|
||||
request?: ProviderRequestTransportOverrides;
|
||||
}
|
||||
@@ -338,6 +340,7 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
|
||||
private readonly maxRetries: number;
|
||||
private readonly backoffDelaysMs: readonly number[];
|
||||
private readonly socketFactory: (url: string, options: ClientOptions) => WebSocket;
|
||||
private readonly headers?: Record<string, string>;
|
||||
private readonly request?: ProviderRequestTransportOverrides;
|
||||
|
||||
constructor(options: OpenAIWebSocketManagerOptions = {}) {
|
||||
@@ -347,6 +350,7 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
|
||||
this.backoffDelaysMs = options.backoffDelaysMs ?? BACKOFF_DELAYS_MS;
|
||||
this.socketFactory =
|
||||
options.socketFactory ?? ((url, socketOptions) => new WebSocket(url, socketOptions));
|
||||
this.headers = options.headers;
|
||||
this.request = options.request;
|
||||
}
|
||||
|
||||
@@ -454,6 +458,7 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
|
||||
providerHeaders: {
|
||||
Authorization: `Bearer ${this.apiKey}`,
|
||||
"OpenAI-Beta": "responses-websocket=v1",
|
||||
...this.headers,
|
||||
},
|
||||
precedence: "defaults-win",
|
||||
request: this.request,
|
||||
@@ -607,7 +612,12 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
|
||||
* Sends a warm-up event to pre-load the connection and model without generating output.
|
||||
* Pass tools/instructions to prime the connection for the upcoming session.
|
||||
*/
|
||||
warmUp(params: { model: string; tools?: FunctionToolDefinition[]; instructions?: string }): void {
|
||||
warmUp(params: {
|
||||
model: string;
|
||||
tools?: FunctionToolDefinition[];
|
||||
instructions?: string;
|
||||
metadata?: Record<string, string>;
|
||||
}): void {
|
||||
const event = buildOpenAIWebSocketWarmUpPayload(params);
|
||||
this.send(event);
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ export function buildOpenAIWebSocketWarmUpPayload(params: {
|
||||
model: string;
|
||||
tools?: FunctionToolDefinition[];
|
||||
instructions?: string;
|
||||
metadata?: Record<string, string>;
|
||||
}): WarmUpEvent {
|
||||
return {
|
||||
type: "response.create",
|
||||
@@ -38,6 +39,7 @@ export function buildOpenAIWebSocketWarmUpPayload(params: {
|
||||
input: [],
|
||||
...(params.tools?.length ? { tools: params.tools } : {}),
|
||||
...(params.instructions ? { instructions: params.instructions } : {}),
|
||||
...(params.metadata ? { metadata: params.metadata } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -47,6 +49,7 @@ export function buildOpenAIWebSocketResponseCreatePayload(params: {
|
||||
options?: WsOptions;
|
||||
turnInput: PlannedWsTurnInput;
|
||||
tools: FunctionToolDefinition[];
|
||||
metadata?: Record<string, string>;
|
||||
}): ResponseCreateEvent {
|
||||
const extraParams: Record<string, unknown> = {};
|
||||
const streamOpts = params.options;
|
||||
@@ -108,6 +111,7 @@ export function buildOpenAIWebSocketResponseCreatePayload(params: {
|
||||
...(params.turnInput.previousResponseId
|
||||
? { previous_response_id: params.turnInput.previousResponseId }
|
||||
: {}),
|
||||
...(params.metadata ? { metadata: params.metadata } : {}),
|
||||
...extraParams,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -47,11 +47,17 @@ const { MockManager } = vi.hoisted(() => {
|
||||
sentEvents: unknown[] = [];
|
||||
connectCallCount = 0;
|
||||
closeCallCount = 0;
|
||||
options: unknown;
|
||||
|
||||
// Allow tests to override connect/send behaviour
|
||||
connectShouldFail = false;
|
||||
sendShouldFail = false;
|
||||
|
||||
constructor(options?: unknown) {
|
||||
super();
|
||||
this.options = options;
|
||||
}
|
||||
|
||||
get previousResponseId(): string | null {
|
||||
return this._previousResponseId;
|
||||
}
|
||||
@@ -201,9 +207,9 @@ const { MockManager } = vi.hoisted(() => {
|
||||
});
|
||||
|
||||
// Track if streamSimple (HTTP fallback) was called
|
||||
const streamSimpleCalls: Array<{ model: unknown; context: unknown }> = [];
|
||||
const mockStreamSimple = vi.fn((model: unknown, context: unknown) => {
|
||||
streamSimpleCalls.push({ model, context });
|
||||
const streamSimpleCalls: Array<{ model: unknown; context: unknown; options?: unknown }> = [];
|
||||
const mockStreamSimple = vi.fn((model: unknown, context: unknown, options?: unknown) => {
|
||||
streamSimpleCalls.push({ model, context, options });
|
||||
const stream = createAssistantMessageEventStream();
|
||||
queueMicrotask(() => {
|
||||
const msg = makeFakeAssistantMessage("http fallback response");
|
||||
@@ -1174,7 +1180,7 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
MockManager.reset();
|
||||
streamSimpleCalls.length = 0;
|
||||
openAIWsStreamTesting.setDepsForTest({
|
||||
createManager: (() => new MockManager()) as never,
|
||||
createManager: ((options?: unknown) => new MockManager(options)) as never,
|
||||
streamSimple: mockStreamSimple,
|
||||
});
|
||||
});
|
||||
@@ -1196,7 +1202,10 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
releaseWsSession("sess-store-compat");
|
||||
releaseWsSession("sess-max-tokens-zero");
|
||||
releaseWsSession("sess-runtime-fallback");
|
||||
releaseWsSession("sess-turn-metadata-retry");
|
||||
releaseWsSession("sess-degraded-cooldown");
|
||||
releaseWsSession("sess-drop");
|
||||
openAIWsStreamTesting.setWsDegradeCooldownMsForTest();
|
||||
openAIWsStreamTesting.setDepsForTest();
|
||||
});
|
||||
|
||||
@@ -1447,8 +1456,8 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
// streamSimple was called as part of HTTP fallback
|
||||
expect(streamSimpleCalls.length).toBeGreaterThanOrEqual(1);
|
||||
|
||||
// manager.close() must be called to cancel background reconnect attempts
|
||||
expect(MockManager.lastInstance!.closeCallCount).toBeGreaterThanOrEqual(1);
|
||||
// The failed manager is closed before the replacement session manager is installed.
|
||||
expect(MockManager.instances.some((instance) => instance.closeCallCount >= 1)).toBe(true);
|
||||
} finally {
|
||||
MockManager.globalConnectShouldFail = false;
|
||||
}
|
||||
@@ -1550,6 +1559,103 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
const doneEvent = events.find((event) => event.type === "done");
|
||||
expect(doneEvent?.message?.content?.[0]?.text).toBe("retry succeeded");
|
||||
});
|
||||
|
||||
it("keeps native turn metadata stable across websocket retries and increments attempt", async () => {
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-turn-metadata-retry");
|
||||
const stream = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
contextStub as Parameters<typeof streamFn>[1],
|
||||
{ transport: "auto" } as Parameters<typeof streamFn>[2],
|
||||
);
|
||||
|
||||
await new Promise((r) => setImmediate(r));
|
||||
const firstManager = MockManager.lastInstance!;
|
||||
firstManager.simulateClose(1006, "connection lost");
|
||||
|
||||
await new Promise((r) => setImmediate(r));
|
||||
const secondManager = MockManager.lastInstance!;
|
||||
secondManager.simulateEvent({
|
||||
type: "response.completed",
|
||||
response: makeResponseObject("resp-retried-meta", "retry succeeded"),
|
||||
});
|
||||
|
||||
for await (const _ of await resolveStream(stream)) {
|
||||
// consume
|
||||
}
|
||||
|
||||
const firstPayload = firstManager.sentEvents[0] as { metadata?: Record<string, string> };
|
||||
const secondPayload = secondManager.sentEvents[0] as { metadata?: Record<string, string> };
|
||||
expect(firstPayload.metadata?.openclaw_session_id).toBe("sess-turn-metadata-retry");
|
||||
expect(firstPayload.metadata?.openclaw_transport).toBe("websocket");
|
||||
expect(firstPayload.metadata?.openclaw_turn_id).toBeTruthy();
|
||||
expect(secondPayload.metadata?.openclaw_turn_id).toBe(firstPayload.metadata?.openclaw_turn_id);
|
||||
expect(firstPayload.metadata?.openclaw_turn_attempt).toBe("1");
|
||||
expect(secondPayload.metadata?.openclaw_turn_attempt).toBe("2");
|
||||
});
|
||||
|
||||
it("keeps websocket degraded for the session until the cool-down expires", async () => {
|
||||
openAIWsStreamTesting.setWsDegradeCooldownMsForTest(50);
|
||||
MockManager.globalConnectShouldFail = true;
|
||||
|
||||
try {
|
||||
const sessionId = "sess-degraded-cooldown";
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId);
|
||||
|
||||
const firstStream = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
contextStub as Parameters<typeof streamFn>[1],
|
||||
{ transport: "auto" } as Parameters<typeof streamFn>[2],
|
||||
);
|
||||
void firstStream;
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
expect(streamSimpleCalls.length).toBe(1);
|
||||
expect(MockManager.instances).toHaveLength(2);
|
||||
const cooledManager = MockManager.lastInstance!;
|
||||
expect(cooledManager.connectCallCount).toBe(0);
|
||||
|
||||
MockManager.globalConnectShouldFail = false;
|
||||
|
||||
const secondStream = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
contextStub as Parameters<typeof streamFn>[1],
|
||||
{ transport: "auto" } as Parameters<typeof streamFn>[2],
|
||||
);
|
||||
void secondStream;
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
|
||||
expect(streamSimpleCalls.length).toBe(2);
|
||||
expect(MockManager.instances).toHaveLength(2);
|
||||
expect(cooledManager.connectCallCount).toBe(0);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 60));
|
||||
|
||||
const thirdStream = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
contextStub as Parameters<typeof streamFn>[1],
|
||||
{ transport: "auto" } as Parameters<typeof streamFn>[2],
|
||||
);
|
||||
|
||||
void thirdStream;
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
expect(cooledManager.connectCallCount).toBe(1);
|
||||
expect(streamSimpleCalls.length).toBe(2);
|
||||
cooledManager.simulateEvent({
|
||||
type: "response.completed",
|
||||
response: makeResponseObject("resp-after-cooldown", "ws recovered"),
|
||||
});
|
||||
await new Promise((resolve) => setImmediate(resolve));
|
||||
} finally {
|
||||
MockManager.globalConnectShouldFail = false;
|
||||
openAIWsStreamTesting.setWsDegradeCooldownMsForTest();
|
||||
releaseWsSession("sess-degraded-cooldown");
|
||||
releaseWsSession("sess-turn-metadata-retry");
|
||||
}
|
||||
});
|
||||
|
||||
it("tracks previous_response_id across turns (incremental send)", async () => {
|
||||
const sessionId = "sess-incremental";
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId);
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
* @see src/agents/openai-ws-connection.ts for the connection manager
|
||||
*/
|
||||
|
||||
import { randomUUID } from "node:crypto";
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
@@ -29,6 +30,11 @@ import type {
|
||||
StopReason,
|
||||
} from "@mariozechner/pi-ai";
|
||||
import * as piAi from "@mariozechner/pi-ai";
|
||||
import {
|
||||
resolveProviderTransportTurnStateWithPlugin,
|
||||
resolveProviderWebSocketSessionPolicyWithPlugin,
|
||||
} from "../plugins/provider-runtime.js";
|
||||
import type { ProviderRuntimeModel, ProviderTransportTurnState } from "../plugins/types.js";
|
||||
import {
|
||||
getOpenAIWebSocketErrorDetails,
|
||||
OpenAIWebSocketManager,
|
||||
@@ -47,6 +53,7 @@ import {
|
||||
buildAssistantMessageWithZeroUsage,
|
||||
buildStreamErrorAssistantMessage,
|
||||
} from "./stream-message-shared.js";
|
||||
import { mergeTransportMetadata } from "./transport-stream-shared.js";
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Per-session state
|
||||
@@ -62,6 +69,9 @@ interface WsSession {
|
||||
warmUpAttempted: boolean;
|
||||
/** True if the session is permanently broken (no more reconnect). */
|
||||
broken: boolean;
|
||||
/** Session-scoped cool-down after repeated websocket failures. */
|
||||
degradedUntil: number | null;
|
||||
degradeCooldownMs: number;
|
||||
}
|
||||
|
||||
/** Module-level registry: sessionId → WsSession */
|
||||
@@ -208,6 +218,8 @@ export interface OpenAIWebSocketStreamOptions {
|
||||
type WsTransport = "sse" | "websocket" | "auto";
|
||||
const WARM_UP_TIMEOUT_MS = 8_000;
|
||||
const MAX_AUTO_WS_RUNTIME_RETRIES = 1;
|
||||
const DEFAULT_WS_DEGRADE_COOLDOWN_MS = 60_000;
|
||||
let wsDegradeCooldownMsOverride: number | undefined;
|
||||
|
||||
class OpenAIWebSocketRuntimeError extends Error {
|
||||
readonly kind: "disconnect" | "send" | "server";
|
||||
@@ -247,13 +259,100 @@ function resolveWsWarmup(options: Parameters<StreamFn>[2]): boolean {
|
||||
return warmup === true;
|
||||
}
|
||||
|
||||
function resetWsSession(params: { sessionId: string; session: WsSession }): void {
|
||||
function resetWsSession(params: {
|
||||
session: WsSession;
|
||||
createManager: () => OpenAIWebSocketManager;
|
||||
preserveDegradeUntil?: boolean;
|
||||
}): void {
|
||||
try {
|
||||
params.session.manager.close();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
wsRegistry.delete(params.sessionId);
|
||||
params.session.manager = params.createManager();
|
||||
params.session.everConnected = false;
|
||||
params.session.warmUpAttempted = false;
|
||||
params.session.broken = false;
|
||||
if (!params.preserveDegradeUntil) {
|
||||
params.session.degradedUntil = null;
|
||||
}
|
||||
}
|
||||
|
||||
function markWsSessionDegraded(session: WsSession): void {
|
||||
session.degradedUntil = Date.now() + session.degradeCooldownMs;
|
||||
}
|
||||
|
||||
function isWsSessionDegraded(session: WsSession): boolean {
|
||||
if (!session.degradedUntil) {
|
||||
return false;
|
||||
}
|
||||
if (session.degradedUntil <= Date.now()) {
|
||||
session.degradedUntil = null;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
function createWsManager(
|
||||
managerOptions: OpenAIWebSocketManagerOptions | undefined,
|
||||
sessionHeaders?: Record<string, string>,
|
||||
): OpenAIWebSocketManager {
|
||||
return openAIWsStreamDeps.createManager({
|
||||
...managerOptions,
|
||||
...(sessionHeaders
|
||||
? {
|
||||
headers: {
|
||||
...managerOptions?.headers,
|
||||
...sessionHeaders,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
});
|
||||
}
|
||||
|
||||
function resolveProviderTransportTurnState(
|
||||
model: Parameters<StreamFn>[0],
|
||||
params: {
|
||||
sessionId?: string;
|
||||
turnId: string;
|
||||
attempt: number;
|
||||
transport: "stream" | "websocket";
|
||||
},
|
||||
): ProviderTransportTurnState | undefined {
|
||||
return resolveProviderTransportTurnStateWithPlugin({
|
||||
provider: model.provider,
|
||||
context: {
|
||||
provider: model.provider,
|
||||
modelId: model.id,
|
||||
model: model as ProviderRuntimeModel,
|
||||
sessionId: params.sessionId,
|
||||
turnId: params.turnId,
|
||||
attempt: params.attempt,
|
||||
transport: params.transport,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
function resolveWebSocketSessionPolicy(
|
||||
model: Parameters<StreamFn>[0],
|
||||
sessionId: string,
|
||||
): { headers?: Record<string, string>; degradeCooldownMs: number } {
|
||||
const policy = resolveProviderWebSocketSessionPolicyWithPlugin({
|
||||
provider: model.provider,
|
||||
context: {
|
||||
provider: model.provider,
|
||||
modelId: model.id,
|
||||
model: model as ProviderRuntimeModel,
|
||||
sessionId,
|
||||
},
|
||||
});
|
||||
return {
|
||||
headers: policy?.headers,
|
||||
degradeCooldownMs: Math.max(
|
||||
0,
|
||||
wsDegradeCooldownMsOverride ?? policy?.degradeCooldownMs ?? DEFAULT_WS_DEGRADE_COOLDOWN_MS,
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
function formatOpenAIWebSocketError(
|
||||
@@ -311,6 +410,7 @@ async function runWarmUp(params: {
|
||||
modelId: string;
|
||||
tools: FunctionToolDefinition[];
|
||||
instructions?: string;
|
||||
metadata?: Record<string, string>;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<void> {
|
||||
if (params.signal?.aborted) {
|
||||
@@ -358,6 +458,7 @@ async function runWarmUp(params: {
|
||||
model: params.modelId,
|
||||
tools: params.tools.length > 0 ? params.tools : undefined,
|
||||
instructions: params.instructions,
|
||||
...(params.metadata ? { metadata: params.metadata } : {}),
|
||||
});
|
||||
});
|
||||
}
|
||||
@@ -392,34 +493,55 @@ export function createOpenAIWebSocketStreamFn(
|
||||
const signal = opts.signal ?? (options as WsOptions | undefined)?.signal;
|
||||
let emittedStart = false;
|
||||
let runtimeRetries = 0;
|
||||
const turnId = randomUUID();
|
||||
let turnAttempt = 0;
|
||||
const wsSessionPolicy = resolveWebSocketSessionPolicy(model, sessionId);
|
||||
const sessionHeaders = wsSessionPolicy.headers;
|
||||
|
||||
while (true) {
|
||||
let session = wsRegistry.get(sessionId);
|
||||
if (!session) {
|
||||
const manager = openAIWsStreamDeps.createManager(opts.managerOptions);
|
||||
const manager = createWsManager(opts.managerOptions, sessionHeaders);
|
||||
session = {
|
||||
manager,
|
||||
lastContextLength: 0,
|
||||
everConnected: false,
|
||||
warmUpAttempted: false,
|
||||
broken: false,
|
||||
degradedUntil: null,
|
||||
degradeCooldownMs: wsSessionPolicy.degradeCooldownMs,
|
||||
};
|
||||
wsRegistry.set(sessionId, session);
|
||||
}
|
||||
|
||||
if (transport !== "websocket" && isWsSessionDegraded(session)) {
|
||||
log.debug(
|
||||
`[ws-stream] session=${sessionId} in websocket cool-down; using HTTP fallback until ${new Date(session.degradedUntil!).toISOString()}`,
|
||||
);
|
||||
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
|
||||
suppressStart: emittedStart,
|
||||
turnState: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: Math.max(1, turnAttempt),
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
if (!session.manager.isConnected() && !session.broken) {
|
||||
try {
|
||||
await session.manager.connect(apiKey);
|
||||
session.everConnected = true;
|
||||
session.degradedUntil = null;
|
||||
log.debug(`[ws-stream] connected for session=${sessionId}`);
|
||||
} catch (connErr) {
|
||||
try {
|
||||
session.manager.close();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
session.broken = true;
|
||||
wsRegistry.delete(sessionId);
|
||||
markWsSessionDegraded(session);
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
preserveDegradeUntil: true,
|
||||
});
|
||||
if (transport === "websocket") {
|
||||
throw connErr instanceof Error ? connErr : new Error(String(connErr));
|
||||
}
|
||||
@@ -428,6 +550,12 @@ export function createOpenAIWebSocketStreamFn(
|
||||
);
|
||||
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
|
||||
suppressStart: emittedStart,
|
||||
turnState: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: Math.max(1, turnAttempt),
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -437,9 +565,20 @@ export function createOpenAIWebSocketStreamFn(
|
||||
throw new Error("WebSocket session disconnected");
|
||||
}
|
||||
log.warn(`[ws-stream] session=${sessionId} broken/disconnected; falling back to HTTP`);
|
||||
resetWsSession({ sessionId, session });
|
||||
markWsSessionDegraded(session);
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
preserveDegradeUntil: true,
|
||||
});
|
||||
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
|
||||
suppressStart: emittedStart,
|
||||
turnState: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: Math.max(1, turnAttempt),
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -452,6 +591,12 @@ export function createOpenAIWebSocketStreamFn(
|
||||
modelId: model.id,
|
||||
tools: convertTools(context.tools),
|
||||
instructions: context.systemPrompt ?? undefined,
|
||||
metadata: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: Math.max(1, turnAttempt),
|
||||
transport: "websocket",
|
||||
})?.metadata,
|
||||
signal,
|
||||
});
|
||||
log.debug(`[ws-stream] warm-up completed for session=${sessionId}`);
|
||||
@@ -471,12 +616,18 @@ export function createOpenAIWebSocketStreamFn(
|
||||
/* ignore */
|
||||
}
|
||||
try {
|
||||
session.manager = createWsManager(opts.managerOptions, sessionHeaders);
|
||||
await session.manager.connect(apiKey);
|
||||
session.everConnected = true;
|
||||
session.degradedUntil = null;
|
||||
log.debug(`[ws-stream] reconnected after warm-up failure for session=${sessionId}`);
|
||||
} catch (reconnectErr) {
|
||||
session.broken = true;
|
||||
wsRegistry.delete(sessionId);
|
||||
markWsSessionDegraded(session);
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
preserveDegradeUntil: true,
|
||||
});
|
||||
if (transport === "websocket") {
|
||||
throw reconnectErr instanceof Error
|
||||
? reconnectErr
|
||||
@@ -487,6 +638,12 @@ export function createOpenAIWebSocketStreamFn(
|
||||
);
|
||||
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
|
||||
suppressStart: emittedStart,
|
||||
turnState: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: Math.max(1, turnAttempt),
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -513,17 +670,27 @@ export function createOpenAIWebSocketStreamFn(
|
||||
);
|
||||
}
|
||||
|
||||
const payload = buildOpenAIWebSocketResponseCreatePayload({
|
||||
turnAttempt++;
|
||||
const turnState = resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: turnAttempt,
|
||||
transport: "websocket",
|
||||
});
|
||||
let payload = buildOpenAIWebSocketResponseCreatePayload({
|
||||
model,
|
||||
context,
|
||||
options: options as WsOptions | undefined,
|
||||
turnInput,
|
||||
tools: convertTools(context.tools),
|
||||
metadata: turnState?.metadata,
|
||||
}) as Record<string, unknown>;
|
||||
const nextPayload = options?.onPayload?.(payload, model);
|
||||
const requestPayload = (nextPayload ?? payload) as Parameters<
|
||||
OpenAIWebSocketManager["send"]
|
||||
>[0];
|
||||
payload = mergeTransportMetadata(
|
||||
(nextPayload ?? payload) as Record<string, unknown>,
|
||||
turnState?.metadata,
|
||||
);
|
||||
const requestPayload = payload as Parameters<OpenAIWebSocketManager["send"]>[0];
|
||||
|
||||
try {
|
||||
session.manager.send(requestPayload);
|
||||
@@ -538,16 +705,30 @@ export function createOpenAIWebSocketStreamFn(
|
||||
log.warn(
|
||||
`[ws-stream] retrying websocket turn after send failure for session=${sessionId} (${runtimeRetries}/${MAX_AUTO_WS_RUNTIME_RETRIES}). error=${normalizedErr.message}`,
|
||||
);
|
||||
resetWsSession({ sessionId, session });
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (transport !== "websocket") {
|
||||
log.warn(
|
||||
`[ws-stream] send failed for session=${sessionId}; falling back to HTTP. error=${normalizedErr.message}`,
|
||||
);
|
||||
resetWsSession({ sessionId, session });
|
||||
markWsSessionDegraded(session);
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
preserveDegradeUntil: true,
|
||||
});
|
||||
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
|
||||
suppressStart: emittedStart,
|
||||
turnState: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: turnAttempt,
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
throw normalizedErr;
|
||||
@@ -680,16 +861,30 @@ export function createOpenAIWebSocketStreamFn(
|
||||
log.warn(
|
||||
`[ws-stream] retrying websocket turn after retryable runtime failure for session=${sessionId} (${runtimeRetries}/${MAX_AUTO_WS_RUNTIME_RETRIES}). error=${normalizedErr.message}`,
|
||||
);
|
||||
resetWsSession({ sessionId, session });
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (transport !== "websocket" && !signal?.aborted && !sawWsOutput) {
|
||||
log.warn(
|
||||
`[ws-stream] session=${sessionId} runtime failure before output; falling back to HTTP. error=${normalizedErr.message}`,
|
||||
);
|
||||
resetWsSession({ sessionId, session });
|
||||
markWsSessionDegraded(session);
|
||||
resetWsSession({
|
||||
session,
|
||||
createManager: () => createWsManager(opts.managerOptions, sessionHeaders),
|
||||
preserveDegradeUntil: true,
|
||||
});
|
||||
return fallbackToHttp(model, context, options, apiKey, eventStream, opts.signal, {
|
||||
suppressStart: true,
|
||||
turnState: resolveProviderTransportTurnState(model, {
|
||||
sessionId,
|
||||
turnId,
|
||||
attempt: turnAttempt,
|
||||
transport: "stream",
|
||||
}),
|
||||
});
|
||||
}
|
||||
throw normalizedErr;
|
||||
@@ -728,11 +923,35 @@ async function fallbackToHttp(
|
||||
apiKey: string,
|
||||
eventStream: AssistantMessageEventStreamLike,
|
||||
signal?: AbortSignal,
|
||||
fallbackOptions?: { suppressStart?: boolean },
|
||||
fallbackOptions?: {
|
||||
suppressStart?: boolean;
|
||||
turnState?: ProviderTransportTurnState;
|
||||
},
|
||||
): Promise<void> {
|
||||
const baseOnPayload = streamOptions?.onPayload;
|
||||
const mergedOptions = {
|
||||
...streamOptions,
|
||||
apiKey,
|
||||
...(fallbackOptions?.turnState?.headers
|
||||
? {
|
||||
headers: {
|
||||
...streamOptions?.headers,
|
||||
...fallbackOptions.turnState.headers,
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
...(fallbackOptions?.turnState?.metadata
|
||||
? {
|
||||
onPayload: async (
|
||||
payload: unknown,
|
||||
payloadModel: Parameters<NonNullable<typeof baseOnPayload>>[1],
|
||||
) => {
|
||||
const nextPayload = await baseOnPayload?.(payload, payloadModel);
|
||||
const resolvedPayload = (nextPayload ?? payload) as Record<string, unknown>;
|
||||
return mergeTransportMetadata(resolvedPayload, fallbackOptions.turnState?.metadata);
|
||||
},
|
||||
}
|
||||
: {}),
|
||||
...(signal ? { signal } : {}),
|
||||
};
|
||||
const httpStream = openAIWsStreamDeps.streamSimple(model, context, mergedOptions);
|
||||
@@ -753,4 +972,7 @@ export const __testing = {
|
||||
}
|
||||
: defaultOpenAIWsStreamDeps;
|
||||
},
|
||||
setWsDegradeCooldownMsForTest(nextMs?: number) {
|
||||
wsDegradeCooldownMsOverride = nextMs;
|
||||
},
|
||||
};
|
||||
|
||||
@@ -38,6 +38,26 @@ export function mergeTransportHeaders(
|
||||
return Object.keys(merged).length > 0 ? merged : undefined;
|
||||
}
|
||||
|
||||
export function mergeTransportMetadata<T extends Record<string, unknown>>(
|
||||
payload: T,
|
||||
metadata?: Record<string, string>,
|
||||
): T {
|
||||
if (!metadata || Object.keys(metadata).length === 0) {
|
||||
return payload;
|
||||
}
|
||||
const existingMetadata =
|
||||
payload.metadata && typeof payload.metadata === "object" && !Array.isArray(payload.metadata)
|
||||
? (payload.metadata as Record<string, string>)
|
||||
: undefined;
|
||||
return {
|
||||
...payload,
|
||||
metadata: {
|
||||
...existingMetadata,
|
||||
...metadata,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
export function createEmptyTransportUsage(): TransportUsage {
|
||||
return {
|
||||
input: 0,
|
||||
|
||||
@@ -65,14 +65,18 @@ export type {
|
||||
ProviderReplaySessionEntry,
|
||||
ProviderReplaySessionState,
|
||||
ProviderResolveDynamicModelContext,
|
||||
ProviderResolveTransportTurnStateContext,
|
||||
ProviderResolveWebSocketSessionPolicyContext,
|
||||
ProviderResolvedUsageAuth,
|
||||
RealtimeTranscriptionProviderPlugin,
|
||||
ProviderSanitizeReplayHistoryContext,
|
||||
ProviderTransportTurnState,
|
||||
ProviderToolSchemaDiagnostic,
|
||||
ProviderResolveUsageAuthContext,
|
||||
ProviderRuntimeModel,
|
||||
ProviderThinkingPolicyContext,
|
||||
ProviderValidateReplayTurnsContext,
|
||||
ProviderWebSocketSessionPolicy,
|
||||
ProviderWrapStreamFnContext,
|
||||
SpeechProviderPlugin,
|
||||
} from "./plugin-entry.js";
|
||||
|
||||
@@ -49,12 +49,16 @@ import type {
|
||||
RealtimeTranscriptionProviderPlugin,
|
||||
ProviderResolvedUsageAuth,
|
||||
ProviderResolveDynamicModelContext,
|
||||
ProviderResolveTransportTurnStateContext,
|
||||
ProviderResolveWebSocketSessionPolicyContext,
|
||||
ProviderSanitizeReplayHistoryContext,
|
||||
ProviderTransportTurnState,
|
||||
ProviderToolSchemaDiagnostic,
|
||||
ProviderResolveUsageAuthContext,
|
||||
ProviderRuntimeModel,
|
||||
ProviderThinkingPolicyContext,
|
||||
ProviderValidateReplayTurnsContext,
|
||||
ProviderWebSocketSessionPolicy,
|
||||
ProviderWrapStreamFnContext,
|
||||
SpeechProviderPlugin,
|
||||
PluginCommandContext,
|
||||
@@ -101,12 +105,16 @@ export type {
|
||||
ProviderSanitizeReplayHistoryContext,
|
||||
ProviderResolveUsageAuthContext,
|
||||
ProviderResolveDynamicModelContext,
|
||||
ProviderResolveTransportTurnStateContext,
|
||||
ProviderResolveWebSocketSessionPolicyContext,
|
||||
ProviderNormalizeResolvedModelContext,
|
||||
ProviderRuntimeModel,
|
||||
RealtimeTranscriptionProviderPlugin,
|
||||
ProviderTransportTurnState,
|
||||
SpeechProviderPlugin,
|
||||
ProviderThinkingPolicyContext,
|
||||
ProviderValidateReplayTurnsContext,
|
||||
ProviderWebSocketSessionPolicy,
|
||||
ProviderWrapStreamFnContext,
|
||||
OpenClawPluginService,
|
||||
OpenClawPluginServiceContext,
|
||||
|
||||
@@ -39,9 +39,13 @@ import type {
|
||||
ProviderResolveUsageAuthContext,
|
||||
ProviderPlugin,
|
||||
ProviderResolveDynamicModelContext,
|
||||
ProviderResolveTransportTurnStateContext,
|
||||
ProviderResolveWebSocketSessionPolicyContext,
|
||||
ProviderRuntimeModel,
|
||||
ProviderThinkingPolicyContext,
|
||||
ProviderTransportTurnState,
|
||||
ProviderValidateReplayTurnsContext,
|
||||
ProviderWebSocketSessionPolicy,
|
||||
ProviderWrapStreamFnContext,
|
||||
} from "./types.js";
|
||||
|
||||
@@ -525,6 +529,30 @@ export function wrapProviderStreamFn(params: {
|
||||
return resolveProviderHookPlugin(params)?.wrapStreamFn?.(params.context) ?? undefined;
|
||||
}
|
||||
|
||||
export function resolveProviderTransportTurnStateWithPlugin(params: {
|
||||
provider: string;
|
||||
config?: OpenClawConfig;
|
||||
workspaceDir?: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
context: ProviderResolveTransportTurnStateContext;
|
||||
}): ProviderTransportTurnState | undefined {
|
||||
return (
|
||||
resolveProviderHookPlugin(params)?.resolveTransportTurnState?.(params.context) ?? undefined
|
||||
);
|
||||
}
|
||||
|
||||
export function resolveProviderWebSocketSessionPolicyWithPlugin(params: {
|
||||
provider: string;
|
||||
config?: OpenClawConfig;
|
||||
workspaceDir?: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
context: ProviderResolveWebSocketSessionPolicyContext;
|
||||
}): ProviderWebSocketSessionPolicy | undefined {
|
||||
return (
|
||||
resolveProviderHookPlugin(params)?.resolveWebSocketSessionPolicy?.(params.context) ?? undefined
|
||||
);
|
||||
}
|
||||
|
||||
export async function createProviderEmbeddingProvider(params: {
|
||||
provider: string;
|
||||
config?: OpenClawConfig;
|
||||
|
||||
@@ -697,6 +697,57 @@ export type ProviderWrapStreamFnContext = ProviderPrepareExtraParamsContext & {
|
||||
streamFn?: StreamFn;
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider-owned transport turn state.
|
||||
*
|
||||
* Use this for provider-native request headers or metadata that should stay
|
||||
* stable across retries while still being attached by generic core transports.
|
||||
*/
|
||||
export type ProviderTransportTurnState = {
|
||||
headers?: Record<string, string>;
|
||||
metadata?: Record<string, string>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider-owned request identity for transport turns.
|
||||
*
|
||||
* Use this when the provider exposes native request/session metadata that must
|
||||
* be attached by both HTTP and WebSocket transports.
|
||||
*/
|
||||
export type ProviderResolveTransportTurnStateContext = {
|
||||
provider: string;
|
||||
modelId: string;
|
||||
model?: ProviderRuntimeModel;
|
||||
sessionId?: string;
|
||||
turnId: string;
|
||||
attempt: number;
|
||||
transport: "stream" | "websocket";
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider-owned WebSocket session policy.
|
||||
*
|
||||
* Use this for session-scoped headers or cool-down behavior that should apply
|
||||
* before a generic WebSocket transport decides to retry or fall back.
|
||||
*/
|
||||
export type ProviderWebSocketSessionPolicy = {
|
||||
headers?: Record<string, string>;
|
||||
degradeCooldownMs?: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Provider-owned WebSocket session policy input.
|
||||
*
|
||||
* Use this when the provider wants to control native session handshake headers
|
||||
* or the post-failure cool-down window for a generic WebSocket transport.
|
||||
*/
|
||||
export type ProviderResolveWebSocketSessionPolicyContext = {
|
||||
provider: string;
|
||||
modelId: string;
|
||||
model?: ProviderRuntimeModel;
|
||||
sessionId?: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Generic embedding provider shape returned by provider plugins.
|
||||
*
|
||||
@@ -1166,6 +1217,26 @@ export type ProviderPlugin = {
|
||||
* transport implementation.
|
||||
*/
|
||||
wrapStreamFn?: (ctx: ProviderWrapStreamFnContext) => StreamFn | null | undefined;
|
||||
/**
|
||||
* Provider-owned native transport turn identity.
|
||||
*
|
||||
* Use this when a provider wants generic transports to attach provider-native
|
||||
* request headers or metadata on each turn without hardcoding vendor logic in
|
||||
* core.
|
||||
*/
|
||||
resolveTransportTurnState?: (
|
||||
ctx: ProviderResolveTransportTurnStateContext,
|
||||
) => ProviderTransportTurnState | null | undefined;
|
||||
/**
|
||||
* Provider-owned WebSocket session policy.
|
||||
*
|
||||
* Use this when a provider wants generic WebSocket transports to attach
|
||||
* native session headers or tune the session-scoped cool-down before HTTP
|
||||
* fallback.
|
||||
*/
|
||||
resolveWebSocketSessionPolicy?: (
|
||||
ctx: ProviderResolveWebSocketSessionPolicyContext,
|
||||
) => ProviderWebSocketSessionPolicy | null | undefined;
|
||||
/**
|
||||
* Provider-owned embedding provider factory.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user