fix(agents): centralize native websocket endpoint checks

This commit is contained in:
Vincent Koc
2026-04-22 13:37:41 -07:00
parent 4c675216f1
commit e250ea3668
2 changed files with 42 additions and 41 deletions

View File

@@ -2528,6 +2528,39 @@ describe("createOpenAIWebSocketStreamFn", () => {
expect(secondPayload.metadata?.openclaw_turn_attempt).toBe("2");
});
it("does not attach native OpenAI session headers or metadata for custom responses endpoints", async () => {
const sessionId = "sess-custom-openai-endpoint";
const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId);
const customEndpointModel = {
...modelStub,
baseUrl: "http://127.0.0.1:4100/v1",
};
const stream = streamFn(
customEndpointModel as Parameters<typeof streamFn>[0],
contextStub as Parameters<typeof streamFn>[1],
{ transport: "websocket" } as Parameters<typeof streamFn>[2],
);
await new Promise((r) => setImmediate(r));
const manager = MockManager.lastInstance!;
manager.simulateEvent({
type: "response.completed",
response: makeResponseObject("resp-custom-endpoint", "custom endpoint"),
});
for await (const _ of await resolveStream(stream)) {
// consume
}
expect((manager.options as { headers?: Record<string, string> } | undefined)?.headers).toBe(
undefined,
);
const payload = manager.sentEvents[0] as { metadata?: Record<string, string> };
expect(payload.metadata?.openclaw_session_id).toBeUndefined();
expect(payload.metadata?.openclaw_transport).toBeUndefined();
releaseWsSession(sessionId);
});
it("keeps websocket degraded for the session until the cool-down expires", async () => {
openAIWsStreamTesting.setWsDegradeCooldownMsForTest(50);
MockManager.globalConnectShouldFail = true;

View File

@@ -40,7 +40,6 @@ import {
encodeAssistantTextSignature,
normalizeAssistantPhase,
} from "../shared/chat-message-content.js";
import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js";
import { resolveOpenAIStrictToolSetting } from "./openai-strict-tool-setting.js";
import {
getOpenAIWebSocketErrorDetails,
@@ -57,6 +56,7 @@ import {
} from "./openai-ws-message-conversion.js";
import { buildOpenAIWebSocketResponseCreatePayload } from "./openai-ws-request.js";
import { log } from "./pi-embedded-runner/logger.js";
import { resolveProviderEndpoint } from "./provider-attribution.js";
import { normalizeProviderId } from "./provider-id.js";
import { createBoundaryAwareStreamFnForModel } from "./provider-transport-stream.js";
import {
@@ -365,61 +365,29 @@ function resolveWsManagerConfigSignature(
const AZURE_OPENAI_PROVIDER_IDS = new Set(["azure-openai", "azure-openai-responses"]);
const OPENAI_CODEX_PROVIDER_ID = "openai-codex";
function isOpenAIApiBaseUrl(baseUrl?: string): boolean {
const trimmed = baseUrl?.trim();
if (!trimmed) {
return false;
}
try {
const url = new URL(trimmed);
return (
url.protocol === "https:" &&
normalizeLowercaseStringOrEmpty(url.hostname) === "api.openai.com" &&
/^\/v1\/?$/u.test(url.pathname)
);
} catch {
return false;
}
}
function isOpenAICodexBaseUrl(baseUrl?: string): boolean {
const trimmed = baseUrl?.trim();
if (!trimmed) {
return false;
}
return /^https?:\/\/chatgpt\.com\/backend-api\/?$/iu.test(trimmed);
}
function isAzureOpenAIBaseUrl(baseUrl?: string): boolean {
const trimmed = baseUrl?.trim();
if (!trimmed) {
return false;
}
try {
return normalizeLowercaseStringOrEmpty(new URL(trimmed).hostname).endsWith(".openai.azure.com");
} catch {
return false;
}
}
function normalizeTransportIdentityValue(value: string, maxLength = 160): string {
const trimmed = value.trim().replace(/[\r\n]+/gu, " ");
return trimmed.length > maxLength ? trimmed.slice(0, maxLength) : trimmed;
}
function usesNativeOpenAIRoute(provider: string, baseUrl?: string): boolean {
const endpointClass = resolveProviderEndpoint(baseUrl).endpointClass;
const normalizedProvider = normalizeProviderId(provider);
if (!normalizedProvider) {
return false;
}
if (normalizedProvider === "openai") {
return !baseUrl || isOpenAIApiBaseUrl(baseUrl);
return endpointClass === "default" || endpointClass === "openai-public";
}
if (AZURE_OPENAI_PROVIDER_IDS.has(normalizedProvider)) {
return !baseUrl || isAzureOpenAIBaseUrl(baseUrl);
return endpointClass === "default" || endpointClass === "azure-openai";
}
if (normalizedProvider === OPENAI_CODEX_PROVIDER_ID) {
return !baseUrl || isOpenAIApiBaseUrl(baseUrl) || isOpenAICodexBaseUrl(baseUrl);
return (
endpointClass === "default" ||
endpointClass === "openai-public" ||
endpointClass === "openai-codex"
);
}
return false;
}