From c2035fea1c2e564154017fdec953dc2ac91fd283 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Wed, 27 May 2026 09:46:47 +0100 Subject: [PATCH] fix: honor explicit codex websocket transport --- .../providers/openai-codex-responses.test.ts | 64 ++++++++++++++++++- src/llm/providers/openai-codex-responses.ts | 22 +++++-- 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/src/llm/providers/openai-codex-responses.test.ts b/src/llm/providers/openai-codex-responses.test.ts index 57ba4fa4329..6cd45aa2684 100644 --- a/src/llm/providers/openai-codex-responses.test.ts +++ b/src/llm/providers/openai-codex-responses.test.ts @@ -1,5 +1,10 @@ -import { describe, expect, it } from "vitest"; -import { extractOpenAICodexAccountId } from "./openai-codex-responses.js"; +import { afterEach, describe, expect, it, vi } from "vitest"; +import type { Context, Model } from "../types.js"; +import { + extractOpenAICodexAccountId, + resetOpenAICodexWebSocketDebugStats, + streamOpenAICodexResponses, +} from "./openai-codex-responses.js"; function createJwt(payload: Record): string { const header = Buffer.from(JSON.stringify({ alg: "none", typ: "JWT" })).toString("base64url"); @@ -25,3 +30,58 @@ describe("extractOpenAICodexAccountId", () => { ); }); }); + +describe("streamOpenAICodexResponses transport", () => { + afterEach(() => { + vi.unstubAllGlobals(); + resetOpenAICodexWebSocketDebugStats(); + }); + + const model = { + id: "gpt-5.5", + name: "GPT-5.5", + api: "openai-codex-responses", + provider: "openai-codex", + baseUrl: "https://chatgpt.test/backend-api", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 128_000, + maxTokens: 16_000, + } satisfies Model<"openai-codex-responses">; + + const context = { + messages: [{ role: "user", content: "hi", timestamp: 1 }], + } satisfies Context; + + it("does not fall back to SSE when websocket transport is explicit", async () => { + const fetchMock = vi.fn(async () => { + throw new Error("fetch should not run"); + }); + vi.stubGlobal("fetch", fetchMock); + vi.stubGlobal( + "WebSocket", + class { + constructor() { + throw new Error("websocket connect failed"); + } + }, + ); + + const stream = streamOpenAICodexResponses(model, context, { + apiKey: createJwt({ + "https://api.openai.com/auth": { + chatgpt_account_id: "acct-1", + }, + }), + sessionId: "session-explicit-websocket", + transport: "websocket", + }); + + const result = await stream.result(); + + expect(fetchMock).not.toHaveBeenCalled(); + expect(result.stopReason).toBe("error"); + expect(result.errorMessage).toContain("websocket connect failed"); + }); +}); diff --git a/src/llm/providers/openai-codex-responses.ts b/src/llm/providers/openai-codex-responses.ts index 1da474b4ad9..05f5f0c84c6 100644 --- a/src/llm/providers/openai-codex-responses.ts +++ b/src/llm/providers/openai-codex-responses.ts @@ -196,7 +196,7 @@ export const streamOpenAICodexResponses: StreamFunction< const bodyJson = JSON.stringify(body); const transport = options?.transport || "auto"; const websocketDisabledForSession = - transport !== "sse" && isWebSocketSseFallbackActive(options?.sessionId); + transport === "auto" && isWebSocketSseFallbackActive(options?.sessionId); if (websocketDisabledForSession) { recordWebSocketSseFallback(options?.sessionId); } @@ -236,7 +236,7 @@ export const streamOpenAICodexResponses: StreamFunction< output, createAssistantMessageDiagnostic("provider_transport_failure", error, { configuredTransport: transport, - fallbackTransport: websocketStarted ? undefined : "sse", + fallbackTransport: transport === "auto" && !websocketStarted ? "sse" : undefined, eventsEmitted: websocketStarted, phase: websocketStarted ? "after_message_stream_start" @@ -244,8 +244,10 @@ export const streamOpenAICodexResponses: StreamFunction< requestBytes: new TextEncoder().encode(bodyJson).byteLength, }), ); - recordWebSocketFailure(options?.sessionId, error); - if (websocketStarted) { + recordWebSocketFailure(options?.sessionId, error, { + activateSseFallback: transport === "auto", + }); + if (websocketStarted || transport !== "auto") { throw error; } recordWebSocketSseFallback(options?.sessionId); @@ -795,16 +797,22 @@ function recordWebSocketSseFallback(sessionId: string | undefined): void { stats.websocketFallbackActive = isWebSocketSseFallbackActive(sessionId); } -function recordWebSocketFailure(sessionId: string | undefined, error: unknown): void { +function recordWebSocketFailure( + sessionId: string | undefined, + error: unknown, + options: { activateSseFallback: boolean }, +): void { if (!sessionId) { return; } - websocketSseFallbackSessions.add(sessionId); + if (options.activateSseFallback) { + websocketSseFallbackSessions.add(sessionId); + } const stats = getOrCreateWebSocketDebugStats(sessionId); stats.websocketFailures++; stats.lastWebSocketError = formatThrownValue(error); - stats.websocketFallbackActive = true; + stats.websocketFallbackActive = isWebSocketSseFallbackActive(sessionId); } type WebSocketConstructor = new (