fix: honor explicit codex websocket transport

This commit is contained in:
Peter Steinberger
2026-05-27 09:46:47 +01:00
parent d8d79ba24e
commit c2035fea1c
2 changed files with 77 additions and 9 deletions

View File

@@ -1,5 +1,10 @@
import { describe, expect, it } from "vitest";
import { extractOpenAICodexAccountId } from "./openai-codex-responses.js";
import { afterEach, describe, expect, it, vi } from "vitest";
import type { Context, Model } from "../types.js";
import {
extractOpenAICodexAccountId,
resetOpenAICodexWebSocketDebugStats,
streamOpenAICodexResponses,
} from "./openai-codex-responses.js";
function createJwt(payload: Record<string, unknown>): string {
const header = Buffer.from(JSON.stringify({ alg: "none", typ: "JWT" })).toString("base64url");
@@ -25,3 +30,58 @@ describe("extractOpenAICodexAccountId", () => {
);
});
});
describe("streamOpenAICodexResponses transport", () => {
afterEach(() => {
vi.unstubAllGlobals();
resetOpenAICodexWebSocketDebugStats();
});
const model = {
id: "gpt-5.5",
name: "GPT-5.5",
api: "openai-codex-responses",
provider: "openai-codex",
baseUrl: "https://chatgpt.test/backend-api",
reasoning: true,
input: ["text"],
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: 128_000,
maxTokens: 16_000,
} satisfies Model<"openai-codex-responses">;
const context = {
messages: [{ role: "user", content: "hi", timestamp: 1 }],
} satisfies Context;
it("does not fall back to SSE when websocket transport is explicit", async () => {
const fetchMock = vi.fn(async () => {
throw new Error("fetch should not run");
});
vi.stubGlobal("fetch", fetchMock);
vi.stubGlobal(
"WebSocket",
class {
constructor() {
throw new Error("websocket connect failed");
}
},
);
const stream = streamOpenAICodexResponses(model, context, {
apiKey: createJwt({
"https://api.openai.com/auth": {
chatgpt_account_id: "acct-1",
},
}),
sessionId: "session-explicit-websocket",
transport: "websocket",
});
const result = await stream.result();
expect(fetchMock).not.toHaveBeenCalled();
expect(result.stopReason).toBe("error");
expect(result.errorMessage).toContain("websocket connect failed");
});
});

View File

@@ -196,7 +196,7 @@ export const streamOpenAICodexResponses: StreamFunction<
const bodyJson = JSON.stringify(body);
const transport = options?.transport || "auto";
const websocketDisabledForSession =
transport !== "sse" && isWebSocketSseFallbackActive(options?.sessionId);
transport === "auto" && isWebSocketSseFallbackActive(options?.sessionId);
if (websocketDisabledForSession) {
recordWebSocketSseFallback(options?.sessionId);
}
@@ -236,7 +236,7 @@ export const streamOpenAICodexResponses: StreamFunction<
output,
createAssistantMessageDiagnostic("provider_transport_failure", error, {
configuredTransport: transport,
fallbackTransport: websocketStarted ? undefined : "sse",
fallbackTransport: transport === "auto" && !websocketStarted ? "sse" : undefined,
eventsEmitted: websocketStarted,
phase: websocketStarted
? "after_message_stream_start"
@@ -244,8 +244,10 @@ export const streamOpenAICodexResponses: StreamFunction<
requestBytes: new TextEncoder().encode(bodyJson).byteLength,
}),
);
recordWebSocketFailure(options?.sessionId, error);
if (websocketStarted) {
recordWebSocketFailure(options?.sessionId, error, {
activateSseFallback: transport === "auto",
});
if (websocketStarted || transport !== "auto") {
throw error;
}
recordWebSocketSseFallback(options?.sessionId);
@@ -795,16 +797,22 @@ function recordWebSocketSseFallback(sessionId: string | undefined): void {
stats.websocketFallbackActive = isWebSocketSseFallbackActive(sessionId);
}
function recordWebSocketFailure(sessionId: string | undefined, error: unknown): void {
function recordWebSocketFailure(
sessionId: string | undefined,
error: unknown,
options: { activateSseFallback: boolean },
): void {
if (!sessionId) {
return;
}
websocketSseFallbackSessions.add(sessionId);
if (options.activateSseFallback) {
websocketSseFallbackSessions.add(sessionId);
}
const stats = getOrCreateWebSocketDebugStats(sessionId);
stats.websocketFailures++;
stats.lastWebSocketError = formatThrownValue(error);
stats.websocketFallbackActive = true;
stats.websocketFallbackActive = isWebSocketSseFallbackActive(sessionId);
}
type WebSocketConstructor = new (