From cfbad0a4f98529e864c40f0ceb8fd19b7dbdc880 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Thu, 2 Apr 2026 21:42:11 +0900 Subject: [PATCH] fix(providers): unify request policy resolution (#59653) * fix(providers): unify request policy resolution * fix(providers): preserve request config SDK contract * fix(providers): harden request header policy --- src/agents/openai-ws-connection.ts | 8 +- src/agents/openai-ws-stream.ts | 6 +- .../openai-stream-wrappers.ts | 11 +- .../proxy-stream-wrappers.ts | 12 +- src/agents/provider-request-config.test.ts | 75 +++++++- src/agents/provider-request-config.ts | 171 ++++++++++++++---- src/media-understanding/shared.test.ts | 14 +- src/media-understanding/shared.ts | 83 ++------- 8 files changed, 259 insertions(+), 121 deletions(-) diff --git a/src/agents/openai-ws-connection.ts b/src/agents/openai-ws-connection.ts index 40305baa414..34d653af665 100644 --- a/src/agents/openai-ws-connection.ts +++ b/src/agents/openai-ws-connection.ts @@ -15,7 +15,7 @@ import { EventEmitter } from "node:events"; import WebSocket, { type ClientOptions } from "ws"; -import { resolveProviderRequestHeaders } from "./provider-request-config.js"; +import { resolveProviderRequestPolicyConfig } from "./provider-request-config.js"; // ───────────────────────────────────────────────────────────────────────────── // WebSocket Event Types (Server → Client) @@ -403,18 +403,18 @@ export class OpenAIWebSocketManager extends EventEmitter { } const socket = this.socketFactory(this.wsUrl, { - headers: resolveProviderRequestHeaders({ + headers: resolveProviderRequestPolicyConfig({ provider: "openai", api: "openai-responses", baseUrl: this.wsUrl, capability: "llm", transport: "websocket", - defaultHeaders: { + providerHeaders: { Authorization: `Bearer ${this.apiKey}`, "OpenAI-Beta": "responses-websocket=v1", }, precedence: "defaults-win", - }), + }).headers, }); this.ws = socket; diff --git a/src/agents/openai-ws-stream.ts b/src/agents/openai-ws-stream.ts index 98c2291e2ec..013d1e3e4a6 100644 --- a/src/agents/openai-ws-stream.ts +++ b/src/agents/openai-ws-stream.ts @@ -42,7 +42,7 @@ import { } from "./openai-ws-message-conversion.js"; import { log } from "./pi-embedded-runner/logger.js"; import { resolveOpenAITextVerbosity } from "./pi-embedded-runner/openai-stream-wrappers.js"; -import { resolveProviderRequestCapabilities } from "./provider-attribution.js"; +import { resolveProviderRequestPolicyConfig } from "./provider-request-config.js"; import { buildAssistantMessageWithZeroUsage, buildStreamErrorAssistantMessage, @@ -486,14 +486,14 @@ export function createOpenAIWebSocketStreamFn( // Respect compat.supportsStore — providers like Gemini reject unknown // fields such as `store` with a 400 error. Fixes #39086. - const supportsResponsesStoreField = resolveProviderRequestCapabilities({ + const supportsResponsesStoreField = resolveProviderRequestPolicyConfig({ provider: typeof model.provider === "string" ? model.provider : undefined, api: typeof model.api === "string" ? model.api : undefined, baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined, compat: (model as { compat?: { supportsStore?: boolean } }).compat, capability: "llm", transport: "websocket", - }).supportsResponsesStoreField; + }).capabilities.supportsResponsesStoreField; const payload: Record = { type: "response.create", diff --git a/src/agents/pi-embedded-runner/openai-stream-wrappers.ts b/src/agents/pi-embedded-runner/openai-stream-wrappers.ts index 6e092f2fdb2..89da1c5cb6f 100644 --- a/src/agents/pi-embedded-runner/openai-stream-wrappers.ts +++ b/src/agents/pi-embedded-runner/openai-stream-wrappers.ts @@ -6,8 +6,7 @@ import { patchCodexNativeWebSearchPayload, resolveCodexNativeSearchActivation, } from "../codex-native-web-search.js"; -import { resolveProviderRequestCapabilities } from "../provider-attribution.js"; -import { resolveProviderRequestHeaders } from "../provider-request-config.js"; +import { resolveProviderRequestPolicyConfig } from "../provider-request-config.js"; import { log } from "./logger.js"; import { streamWithPayloadPatch } from "./stream-payload-utils.js"; @@ -28,14 +27,14 @@ function resolveOpenAIRequestCapabilities(model: { baseUrl?: unknown; compat?: { supportsStore?: boolean }; }) { - return resolveProviderRequestCapabilities({ + return resolveProviderRequestPolicyConfig({ provider: typeof model.provider === "string" ? model.provider : undefined, api: typeof model.api === "string" ? model.api : undefined, baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined, compat: model.compat, capability: "llm", transport: "stream", - }); + }).capabilities; } function shouldApplyOpenAIAttributionHeaders(model: { @@ -502,7 +501,7 @@ export function createOpenAIAttributionHeadersWrapper( } return underlying(model, context, { ...options, - headers: resolveProviderRequestHeaders({ + headers: resolveProviderRequestPolicyConfig({ provider: attributionProvider, api: typeof model.api === "string" ? model.api : undefined, baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined, @@ -510,7 +509,7 @@ export function createOpenAIAttributionHeadersWrapper( transport: "stream", callerHeaders: options?.headers, precedence: "defaults-win", - }), + }).headers, }); }; } diff --git a/src/agents/pi-embedded-runner/proxy-stream-wrappers.ts b/src/agents/pi-embedded-runner/proxy-stream-wrappers.ts index 8f20ce3d26d..740d72de3d3 100644 --- a/src/agents/pi-embedded-runner/proxy-stream-wrappers.ts +++ b/src/agents/pi-embedded-runner/proxy-stream-wrappers.ts @@ -1,7 +1,7 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import { streamSimple } from "@mariozechner/pi-ai"; import type { ThinkLevel } from "../../auto-reply/thinking.js"; -import { resolveProviderRequestHeaders } from "../provider-request-config.js"; +import { resolveProviderRequestPolicyConfig } from "../provider-request-config.js"; import { streamWithPayloadPatch } from "./stream-payload-utils.js"; const KILOCODE_FEATURE_HEADER = "X-KILOCODE-FEATURE"; const KILOCODE_FEATURE_DEFAULT = "openclaw"; @@ -111,7 +111,7 @@ export function createOpenRouterWrapper( ): StreamFn { const underlying = baseStreamFn ?? streamSimple; return (model, context, options) => { - const headers = resolveProviderRequestHeaders({ + const headers = resolveProviderRequestPolicyConfig({ provider: typeof model.provider === "string" ? model.provider : "openrouter", api: typeof model.api === "string" ? model.api : undefined, baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined, @@ -119,7 +119,7 @@ export function createOpenRouterWrapper( transport: "stream", callerHeaders: options?.headers, precedence: "caller-wins", - }); + }).headers; return streamWithPayloadPatch( underlying, model, @@ -145,16 +145,16 @@ export function createKilocodeWrapper( ): StreamFn { const underlying = baseStreamFn ?? streamSimple; return (model, context, options) => { - const headers = resolveProviderRequestHeaders({ + const headers = resolveProviderRequestPolicyConfig({ provider: typeof model.provider === "string" ? model.provider : "kilocode", api: typeof model.api === "string" ? model.api : undefined, baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined, capability: "llm", transport: "stream", callerHeaders: options?.headers, - defaultHeaders: resolveKilocodeAppHeaders(), + providerHeaders: resolveKilocodeAppHeaders(), precedence: "defaults-win", - }); + }).headers; return streamWithPayloadPatch( underlying, model, diff --git a/src/agents/provider-request-config.test.ts b/src/agents/provider-request-config.test.ts index f46cbc9ddb8..3eac0776cfe 100644 --- a/src/agents/provider-request-config.test.ts +++ b/src/agents/provider-request-config.test.ts @@ -1,5 +1,6 @@ import { describe, expect, it } from "vitest"; import { + resolveProviderRequestPolicyConfig, resolveProviderRequestConfig, resolveProviderRequestHeaders, } from "./provider-request-config.js"; @@ -103,10 +104,82 @@ describe("provider request config", () => { }); expect(resolved).toEqual({ - "HTTP-Referer": "https://example.com", + "HTTP-Referer": "https://openclaw.ai", "X-OpenRouter-Title": "OpenClaw", "X-OpenRouter-Categories": "cli-agent", "X-Custom": "1", }); }); + + it("merges header names case-insensitively", () => { + const resolved = resolveProviderRequestHeaders({ + provider: "openai", + api: "openai-responses", + baseUrl: "https://api.openai.com/v1", + capability: "llm", + transport: "stream", + callerHeaders: { + "user-agent": "custom-agent/1.0", + }, + precedence: "caller-wins", + }); + + expect( + Object.keys(resolved ?? {}).filter((key) => key.toLowerCase() === "user-agent"), + ).toHaveLength(1); + expect(resolved?.["User-Agent"]).toMatch(/^openclaw\//); + }); + + it("drops forbidden header keys while merging", () => { + const resolved = resolveProviderRequestHeaders({ + provider: "custom-openai", + callerHeaders: { + __proto__: "polluted", + constructor: "polluted", + "X-Custom": "1", + } as Record, + defaultHeaders: { + prototype: "polluted", + } as Record, + }); + + expect(resolved).toEqual({ + "X-Custom": "1", + }); + expect(Object.getPrototypeOf(resolved ?? {})).toBeNull(); + }); + + it("unifies policy, capabilities, headers, base URL, and private-network posture", () => { + const resolved = resolveProviderRequestPolicyConfig({ + provider: "openai", + api: "openai-responses", + baseUrl: "https://api.openai.com/v1/", + defaultBaseUrl: "https://fallback.example/v1/", + callerHeaders: { + "User-Agent": "custom-agent/1.0", + "X-Custom": "1", + }, + providerHeaders: { + authorization: "Bearer test-key", + }, + compat: { + supportsStore: true, + }, + capability: "llm", + transport: "stream", + precedence: "defaults-win", + }); + + expect(resolved.baseUrl).toBe("https://api.openai.com/v1"); + expect(resolved.allowPrivateNetwork).toBe(true); + expect(resolved.policy.endpointClass).toBe("openai-public"); + expect(resolved.capabilities.allowsResponsesStore).toBe(true); + expect(resolved.headers).toMatchObject({ + authorization: "Bearer test-key", + originator: "openclaw", + version: expect.any(String), + "User-Agent": expect.stringMatching(/^openclaw\//), + "X-Custom": "1", + }); + }); }); diff --git a/src/agents/provider-request-config.ts b/src/agents/provider-request-config.ts index cd79cb8d14d..8532b42b5b2 100644 --- a/src/agents/provider-request-config.ts +++ b/src/agents/provider-request-config.ts @@ -1,11 +1,15 @@ import type { Api } from "@mariozechner/pi-ai"; import type { ModelDefinitionConfig } from "../config/types.js"; import type { + ProviderRequestCapabilities, ProviderRequestCapability, - ProviderRequestPolicyResolution, ProviderRequestTransport, } from "./provider-attribution.js"; -import { resolveProviderRequestPolicy } from "./provider-attribution.js"; +import { + resolveProviderRequestCapabilities, + resolveProviderRequestPolicy, + type ProviderRequestPolicyResolution, +} from "./provider-attribution.js"; type RequestApi = Api | ModelDefinitionConfig["api"]; @@ -34,22 +38,133 @@ export type ResolvedProviderRequestConfig = { export type ProviderRequestHeaderPrecedence = "caller-wins" | "defaults-win"; +export type ResolvedProviderRequestPolicyConfig = ResolvedProviderRequestConfig & { + allowPrivateNetwork: boolean; + capabilities: ProviderRequestCapabilities; +}; + +const FORBIDDEN_HEADER_KEYS = new Set(["__proto__", "prototype", "constructor"]); + +type ResolveProviderRequestPolicyConfigParams = { + provider?: string; + api?: RequestApi; + baseUrl?: string; + defaultBaseUrl?: string; + capability?: ProviderRequestCapability; + transport?: ProviderRequestTransport; + discoveredHeaders?: Record; + providerHeaders?: Record; + modelHeaders?: Record; + callerHeaders?: Record; + precedence?: ProviderRequestHeaderPrecedence; + authHeader?: boolean; + compat?: { + supportsStore?: boolean; + } | null; + modelId?: string | null; + allowPrivateNetwork?: boolean; +}; + +export function normalizeBaseUrl(baseUrl: string | undefined, fallback: string): string; +export function normalizeBaseUrl( + baseUrl: string | undefined, + fallback?: string, +): string | undefined; +export function normalizeBaseUrl( + baseUrl: string | undefined, + fallback?: string, +): string | undefined { + const raw = baseUrl?.trim() || fallback?.trim(); + if (!raw) { + return undefined; + } + return raw.replace(/\/+$/, ""); +} + export function mergeProviderRequestHeaders( ...headerSets: Array | undefined> ): Record | undefined { let merged: Record | undefined; + const headerNamesByLowerKey = new Map(); for (const headers of headerSets) { if (!headers) { continue; } - merged = { - ...merged, - ...headers, - }; + if (!merged) { + merged = Object.create(null) as Record; + } + for (const [key, value] of Object.entries(headers)) { + const normalizedKey = key.toLowerCase(); + if (FORBIDDEN_HEADER_KEYS.has(normalizedKey)) { + continue; + } + const previousKey = headerNamesByLowerKey.get(normalizedKey); + if (previousKey && previousKey !== key) { + delete merged[previousKey]; + } + merged[key] = value; + headerNamesByLowerKey.set(normalizedKey, key); + } } return merged && Object.keys(merged).length > 0 ? merged : undefined; } +export function resolveProviderRequestPolicyConfig( + params: ResolveProviderRequestPolicyConfigParams, +): ResolvedProviderRequestPolicyConfig { + const baseUrl = normalizeBaseUrl(params.baseUrl, params.defaultBaseUrl); + const capability = params.capability ?? "llm"; + const transport = params.transport ?? "http"; + const policyInput = { + provider: params.provider, + api: params.api, + baseUrl, + capability, + transport, + } satisfies Parameters[0]; + const policy = resolveProviderRequestPolicy(policyInput); + const capabilities = resolveProviderRequestCapabilities({ + ...policyInput, + compat: params.compat, + modelId: params.modelId, + }); + const defaultHeaders = mergeProviderRequestHeaders( + params.discoveredHeaders, + params.providerHeaders, + params.modelHeaders, + ); + const protectedAttributionKeys = new Set( + Object.keys(policy.attributionHeaders ?? {}).map((key) => key.toLowerCase()), + ); + const unprotectedCallerHeaders = params.callerHeaders + ? Object.fromEntries( + Object.entries(params.callerHeaders).filter( + ([key]) => !protectedAttributionKeys.has(key.toLowerCase()), + ), + ) + : undefined; + const mergedDefaults = mergeProviderRequestHeaders(defaultHeaders, policy.attributionHeaders); + const headers = + params.precedence === "caller-wins" + ? mergeProviderRequestHeaders(mergedDefaults, unprotectedCallerHeaders) + : mergeProviderRequestHeaders(unprotectedCallerHeaders, mergedDefaults); + + return { + api: params.api, + baseUrl, + headers, + auth: { + mode: params.authHeader ? "authorization-bearer" : "provider-default", + injectAuthorizationHeader: params.authHeader === true, + }, + proxy: { configured: false }, + tls: { configured: false }, + policy, + capabilities, + allowPrivateNetwork: params.allowPrivateNetwork ?? Boolean(params.baseUrl?.trim()), + }; +} + export function resolveProviderRequestConfig(params: { provider: string; api?: RequestApi; @@ -61,31 +176,22 @@ export function resolveProviderRequestConfig(params: { modelHeaders?: Record; authHeader?: boolean; }): ResolvedProviderRequestConfig { - const policy = resolveProviderRequestPolicy({ - provider: params.provider, - api: params.api, - baseUrl: params.baseUrl, - capability: params.capability ?? "llm", - transport: params.transport ?? "http", - }); - + const resolved = resolveProviderRequestPolicyConfig(params); return { - api: params.api, - baseUrl: params.baseUrl, + api: resolved.api, + baseUrl: resolved.baseUrl, + // Model resolution intentionally excludes attribution headers. Those are + // applied later at transport/request time so native-host gating stays tied + // to the final resolved route instead of the catalog/config merge step. headers: mergeProviderRequestHeaders( params.discoveredHeaders, params.providerHeaders, params.modelHeaders, ), - auth: { - mode: params.authHeader ? "authorization-bearer" : "provider-default", - injectAuthorizationHeader: params.authHeader === true, - }, - // These slots are intentionally internal-first. Future provider request - // policy work can populate them without reshaping existing callers again. - proxy: { configured: false }, - tls: { configured: false }, - policy, + auth: resolved.auth, + proxy: resolved.proxy, + tls: resolved.tls, + policy: resolved.policy, }; } @@ -99,21 +205,14 @@ export function resolveProviderRequestHeaders(params: { defaultHeaders?: Record; precedence?: ProviderRequestHeaderPrecedence; }): Record | undefined { - const requestConfig = resolveProviderRequestConfig({ + return resolveProviderRequestPolicyConfig({ provider: params.provider, api: params.api, baseUrl: params.baseUrl, capability: params.capability, transport: params.transport, + callerHeaders: params.callerHeaders, providerHeaders: params.defaultHeaders, - }); - const mergedDefaults = mergeProviderRequestHeaders( - requestConfig.headers, - requestConfig.policy.attributionHeaders, - ); - // When precedence is omitted, defaults-win is the conservative choice: - // attribution/default headers cannot be silently overridden by callers. - return params.precedence === "caller-wins" - ? mergeProviderRequestHeaders(mergedDefaults, params.callerHeaders) - : mergeProviderRequestHeaders(params.callerHeaders, mergedDefaults); + precedence: params.precedence, + }).headers; } diff --git a/src/media-understanding/shared.test.ts b/src/media-understanding/shared.test.ts index b4242ee2c0c..1fb7ca391de 100644 --- a/src/media-understanding/shared.test.ts +++ b/src/media-understanding/shared.test.ts @@ -2,13 +2,14 @@ import { describe, expect, it } from "vitest"; import { resolveProviderHttpRequestConfig } from "./shared.js"; describe("resolveProviderHttpRequestConfig", () => { - it("preserves explicit caller headers over default and attribution headers", () => { + it("preserves explicit caller headers but protects attribution headers", () => { const resolved = resolveProviderHttpRequestConfig({ baseUrl: "https://api.openai.com/v1/", defaultBaseUrl: "https://api.openai.com/v1", headers: { authorization: "Bearer override", "User-Agent": "custom-agent/1.0", + originator: "spoofed", }, defaultHeaders: { authorization: "Bearer default-token", @@ -24,7 +25,7 @@ describe("resolveProviderHttpRequestConfig", () => { expect(resolved.allowPrivateNetwork).toBe(true); expect(resolved.headers.get("authorization")).toBe("Bearer override"); expect(resolved.headers.get("x-default")).toBe("1"); - expect(resolved.headers.get("user-agent")).toBe("custom-agent/1.0"); + expect(resolved.headers.get("user-agent")).toMatch(/^openclaw\//); expect(resolved.headers.get("originator")).toBe("openclaw"); expect(resolved.headers.get("version")).toBeTruthy(); }); @@ -63,4 +64,13 @@ describe("resolveProviderHttpRequestConfig", () => { expect(resolved.allowPrivateNetwork).toBe(false); expect(resolved.headers.get("x-goog-api-key")).toBe("test-key"); }); + + it("fails fast when no base URL can be resolved", () => { + expect(() => + resolveProviderHttpRequestConfig({ + baseUrl: " ", + defaultBaseUrl: " ", + }), + ).toThrow("Missing baseUrl"); + }); }); diff --git a/src/media-understanding/shared.ts b/src/media-understanding/shared.ts index 3c5578bbee3..f8d7e862196 100644 --- a/src/media-understanding/shared.ts +++ b/src/media-understanding/shared.ts @@ -2,58 +2,19 @@ import type { ProviderRequestCapability, ProviderRequestTransport, } from "../agents/provider-attribution.js"; -import { resolveProviderRequestAttributionHeaders } from "../agents/provider-attribution.js"; import { - resolveProviderRequestConfig, + normalizeBaseUrl, + resolveProviderRequestPolicyConfig, type ResolvedProviderRequestConfig, } from "../agents/provider-request-config.js"; import type { GuardedFetchResult } from "../infra/net/fetch-guard.js"; import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js"; import type { LookupFn, SsrFPolicy } from "../infra/net/ssrf.js"; export { fetchWithTimeout } from "../utils/fetch-timeout.js"; +export { normalizeBaseUrl } from "../agents/provider-request-config.js"; const MAX_ERROR_CHARS = 300; -export function normalizeBaseUrl(baseUrl: string | undefined, fallback: string): string { - const raw = baseUrl?.trim() || fallback; - return raw.replace(/\/+$/, ""); -} - -export function applyProviderRequestHeaders(params: { - headers?: HeadersInit; - defaultHeaders?: Record; - provider?: string; - api?: string; - baseUrl?: string; - capability?: ProviderRequestCapability; - transport?: ProviderRequestTransport; -}): Headers { - const headers = new Headers(params.headers); - if (params.defaultHeaders) { - for (const [key, value] of Object.entries(params.defaultHeaders)) { - if (!headers.has(key)) { - headers.set(key, value); - } - } - } - const attributionHeaders = resolveProviderRequestAttributionHeaders({ - provider: params.provider, - api: params.api, - baseUrl: params.baseUrl, - capability: params.capability ?? "other", - transport: params.transport ?? "http", - }); - if (!attributionHeaders) { - return headers; - } - for (const [key, value] of Object.entries(attributionHeaders)) { - if (!headers.has(key)) { - headers.set(key, value); - } - } - return headers; -} - export function resolveProviderHttpRequestConfig(params: { baseUrl?: string; defaultBaseUrl: string; @@ -70,33 +31,29 @@ export function resolveProviderHttpRequestConfig(params: { headers: Headers; requestConfig: ResolvedProviderRequestConfig; } { - const baseUrl = normalizeBaseUrl(params.baseUrl, params.defaultBaseUrl); - const requestConfigParams: Parameters[0] = { + const requestConfig = resolveProviderRequestPolicyConfig({ provider: params.provider ?? "", - baseUrl, + baseUrl: params.baseUrl, + defaultBaseUrl: params.defaultBaseUrl, capability: params.capability ?? "other", transport: params.transport ?? "http", - }; - if (params.api !== undefined) { - requestConfigParams.api = params.api; + callerHeaders: params.headers + ? Object.fromEntries(new Headers(params.headers).entries()) + : undefined, + providerHeaders: params.defaultHeaders, + precedence: "caller-wins", + allowPrivateNetwork: params.allowPrivateNetwork, + api: params.api, + }); + const headers = new Headers(requestConfig.headers); + if (!requestConfig.baseUrl) { + throw new Error("Missing baseUrl: provide baseUrl or defaultBaseUrl"); } - if (params.defaultHeaders !== undefined) { - requestConfigParams.providerHeaders = params.defaultHeaders; - } - const requestConfig = resolveProviderRequestConfig(requestConfigParams); return { - baseUrl, - allowPrivateNetwork: params.allowPrivateNetwork ?? Boolean(params.baseUrl?.trim()), - headers: applyProviderRequestHeaders({ - headers: params.headers, - defaultHeaders: requestConfig.headers, - provider: params.provider, - api: params.api, - baseUrl, - capability: params.capability, - transport: params.transport, - }), + baseUrl: requestConfig.baseUrl, + allowPrivateNetwork: requestConfig.allowPrivateNetwork, + headers, requestConfig, }; }