From e250ea3668f64f61cdf7a03f9f4e835a3b283a1d Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Wed, 22 Apr 2026 13:37:41 -0700 Subject: [PATCH] fix(agents): centralize native websocket endpoint checks --- src/agents/openai-ws-stream.test.ts | 33 +++++++++++++++++++ src/agents/openai-ws-stream.ts | 50 ++++++----------------------- 2 files changed, 42 insertions(+), 41 deletions(-) diff --git a/src/agents/openai-ws-stream.test.ts b/src/agents/openai-ws-stream.test.ts index 4c9bdf5c5e3..59d8494efd0 100644 --- a/src/agents/openai-ws-stream.test.ts +++ b/src/agents/openai-ws-stream.test.ts @@ -2528,6 +2528,39 @@ describe("createOpenAIWebSocketStreamFn", () => { expect(secondPayload.metadata?.openclaw_turn_attempt).toBe("2"); }); + it("does not attach native OpenAI session headers or metadata for custom responses endpoints", async () => { + const sessionId = "sess-custom-openai-endpoint"; + const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId); + const customEndpointModel = { + ...modelStub, + baseUrl: "http://127.0.0.1:4100/v1", + }; + const stream = streamFn( + customEndpointModel as Parameters[0], + contextStub as Parameters[1], + { transport: "websocket" } as Parameters[2], + ); + + await new Promise((r) => setImmediate(r)); + const manager = MockManager.lastInstance!; + manager.simulateEvent({ + type: "response.completed", + response: makeResponseObject("resp-custom-endpoint", "custom endpoint"), + }); + + for await (const _ of await resolveStream(stream)) { + // consume + } + + expect((manager.options as { headers?: Record } | undefined)?.headers).toBe( + undefined, + ); + const payload = manager.sentEvents[0] as { metadata?: Record }; + expect(payload.metadata?.openclaw_session_id).toBeUndefined(); + expect(payload.metadata?.openclaw_transport).toBeUndefined(); + releaseWsSession(sessionId); + }); + it("keeps websocket degraded for the session until the cool-down expires", async () => { openAIWsStreamTesting.setWsDegradeCooldownMsForTest(50); MockManager.globalConnectShouldFail = true; diff --git a/src/agents/openai-ws-stream.ts b/src/agents/openai-ws-stream.ts index bc98120c81c..0a7c5cd66eb 100644 --- a/src/agents/openai-ws-stream.ts +++ b/src/agents/openai-ws-stream.ts @@ -40,7 +40,6 @@ import { encodeAssistantTextSignature, normalizeAssistantPhase, } from "../shared/chat-message-content.js"; -import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js"; import { resolveOpenAIStrictToolSetting } from "./openai-strict-tool-setting.js"; import { getOpenAIWebSocketErrorDetails, @@ -57,6 +56,7 @@ import { } from "./openai-ws-message-conversion.js"; import { buildOpenAIWebSocketResponseCreatePayload } from "./openai-ws-request.js"; import { log } from "./pi-embedded-runner/logger.js"; +import { resolveProviderEndpoint } from "./provider-attribution.js"; import { normalizeProviderId } from "./provider-id.js"; import { createBoundaryAwareStreamFnForModel } from "./provider-transport-stream.js"; import { @@ -365,61 +365,29 @@ function resolveWsManagerConfigSignature( const AZURE_OPENAI_PROVIDER_IDS = new Set(["azure-openai", "azure-openai-responses"]); const OPENAI_CODEX_PROVIDER_ID = "openai-codex"; -function isOpenAIApiBaseUrl(baseUrl?: string): boolean { - const trimmed = baseUrl?.trim(); - if (!trimmed) { - return false; - } - try { - const url = new URL(trimmed); - return ( - url.protocol === "https:" && - normalizeLowercaseStringOrEmpty(url.hostname) === "api.openai.com" && - /^\/v1\/?$/u.test(url.pathname) - ); - } catch { - return false; - } -} - -function isOpenAICodexBaseUrl(baseUrl?: string): boolean { - const trimmed = baseUrl?.trim(); - if (!trimmed) { - return false; - } - return /^https?:\/\/chatgpt\.com\/backend-api\/?$/iu.test(trimmed); -} - -function isAzureOpenAIBaseUrl(baseUrl?: string): boolean { - const trimmed = baseUrl?.trim(); - if (!trimmed) { - return false; - } - try { - return normalizeLowercaseStringOrEmpty(new URL(trimmed).hostname).endsWith(".openai.azure.com"); - } catch { - return false; - } -} - function normalizeTransportIdentityValue(value: string, maxLength = 160): string { const trimmed = value.trim().replace(/[\r\n]+/gu, " "); return trimmed.length > maxLength ? trimmed.slice(0, maxLength) : trimmed; } function usesNativeOpenAIRoute(provider: string, baseUrl?: string): boolean { + const endpointClass = resolveProviderEndpoint(baseUrl).endpointClass; const normalizedProvider = normalizeProviderId(provider); if (!normalizedProvider) { return false; } if (normalizedProvider === "openai") { - return !baseUrl || isOpenAIApiBaseUrl(baseUrl); + return endpointClass === "default" || endpointClass === "openai-public"; } if (AZURE_OPENAI_PROVIDER_IDS.has(normalizedProvider)) { - return !baseUrl || isAzureOpenAIBaseUrl(baseUrl); + return endpointClass === "default" || endpointClass === "azure-openai"; } if (normalizedProvider === OPENAI_CODEX_PROVIDER_ID) { - return !baseUrl || isOpenAIApiBaseUrl(baseUrl) || isOpenAICodexBaseUrl(baseUrl); + return ( + endpointClass === "default" || + endpointClass === "openai-public" || + endpointClass === "openai-codex" + ); } return false; }