fix(providers): unify request policy resolution (#59653)

* fix(providers): unify request policy resolution

* fix(providers): preserve request config SDK contract

* fix(providers): harden request header policy
This commit is contained in:
Vincent Koc
2026-04-02 21:42:11 +09:00
committed by GitHub
parent d4f69878da
commit cfbad0a4f9
8 changed files with 259 additions and 121 deletions

View File

@@ -15,7 +15,7 @@
import { EventEmitter } from "node:events";
import WebSocket, { type ClientOptions } from "ws";
import { resolveProviderRequestHeaders } from "./provider-request-config.js";
import { resolveProviderRequestPolicyConfig } from "./provider-request-config.js";
// ─────────────────────────────────────────────────────────────────────────────
// WebSocket Event Types (Server → Client)
@@ -403,18 +403,18 @@ export class OpenAIWebSocketManager extends EventEmitter<InternalEvents> {
}
const socket = this.socketFactory(this.wsUrl, {
headers: resolveProviderRequestHeaders({
headers: resolveProviderRequestPolicyConfig({
provider: "openai",
api: "openai-responses",
baseUrl: this.wsUrl,
capability: "llm",
transport: "websocket",
defaultHeaders: {
providerHeaders: {
Authorization: `Bearer ${this.apiKey}`,
"OpenAI-Beta": "responses-websocket=v1",
},
precedence: "defaults-win",
}),
}).headers,
});
this.ws = socket;

View File

@@ -42,7 +42,7 @@ import {
} from "./openai-ws-message-conversion.js";
import { log } from "./pi-embedded-runner/logger.js";
import { resolveOpenAITextVerbosity } from "./pi-embedded-runner/openai-stream-wrappers.js";
import { resolveProviderRequestCapabilities } from "./provider-attribution.js";
import { resolveProviderRequestPolicyConfig } from "./provider-request-config.js";
import {
buildAssistantMessageWithZeroUsage,
buildStreamErrorAssistantMessage,
@@ -486,14 +486,14 @@ export function createOpenAIWebSocketStreamFn(
// Respect compat.supportsStore — providers like Gemini reject unknown
// fields such as `store` with a 400 error. Fixes #39086.
const supportsResponsesStoreField = resolveProviderRequestCapabilities({
const supportsResponsesStoreField = resolveProviderRequestPolicyConfig({
provider: typeof model.provider === "string" ? model.provider : undefined,
api: typeof model.api === "string" ? model.api : undefined,
baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined,
compat: (model as { compat?: { supportsStore?: boolean } }).compat,
capability: "llm",
transport: "websocket",
}).supportsResponsesStoreField;
}).capabilities.supportsResponsesStoreField;
const payload: Record<string, unknown> = {
type: "response.create",

View File

@@ -6,8 +6,7 @@ import {
patchCodexNativeWebSearchPayload,
resolveCodexNativeSearchActivation,
} from "../codex-native-web-search.js";
import { resolveProviderRequestCapabilities } from "../provider-attribution.js";
import { resolveProviderRequestHeaders } from "../provider-request-config.js";
import { resolveProviderRequestPolicyConfig } from "../provider-request-config.js";
import { log } from "./logger.js";
import { streamWithPayloadPatch } from "./stream-payload-utils.js";
@@ -28,14 +27,14 @@ function resolveOpenAIRequestCapabilities(model: {
baseUrl?: unknown;
compat?: { supportsStore?: boolean };
}) {
return resolveProviderRequestCapabilities({
return resolveProviderRequestPolicyConfig({
provider: typeof model.provider === "string" ? model.provider : undefined,
api: typeof model.api === "string" ? model.api : undefined,
baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined,
compat: model.compat,
capability: "llm",
transport: "stream",
});
}).capabilities;
}
function shouldApplyOpenAIAttributionHeaders(model: {
@@ -502,7 +501,7 @@ export function createOpenAIAttributionHeadersWrapper(
}
return underlying(model, context, {
...options,
headers: resolveProviderRequestHeaders({
headers: resolveProviderRequestPolicyConfig({
provider: attributionProvider,
api: typeof model.api === "string" ? model.api : undefined,
baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined,
@@ -510,7 +509,7 @@ export function createOpenAIAttributionHeadersWrapper(
transport: "stream",
callerHeaders: options?.headers,
precedence: "defaults-win",
}),
}).headers,
});
};
}

View File

@@ -1,7 +1,7 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import type { ThinkLevel } from "../../auto-reply/thinking.js";
import { resolveProviderRequestHeaders } from "../provider-request-config.js";
import { resolveProviderRequestPolicyConfig } from "../provider-request-config.js";
import { streamWithPayloadPatch } from "./stream-payload-utils.js";
const KILOCODE_FEATURE_HEADER = "X-KILOCODE-FEATURE";
const KILOCODE_FEATURE_DEFAULT = "openclaw";
@@ -111,7 +111,7 @@ export function createOpenRouterWrapper(
): StreamFn {
const underlying = baseStreamFn ?? streamSimple;
return (model, context, options) => {
const headers = resolveProviderRequestHeaders({
const headers = resolveProviderRequestPolicyConfig({
provider: typeof model.provider === "string" ? model.provider : "openrouter",
api: typeof model.api === "string" ? model.api : undefined,
baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined,
@@ -119,7 +119,7 @@ export function createOpenRouterWrapper(
transport: "stream",
callerHeaders: options?.headers,
precedence: "caller-wins",
});
}).headers;
return streamWithPayloadPatch(
underlying,
model,
@@ -145,16 +145,16 @@ export function createKilocodeWrapper(
): StreamFn {
const underlying = baseStreamFn ?? streamSimple;
return (model, context, options) => {
const headers = resolveProviderRequestHeaders({
const headers = resolveProviderRequestPolicyConfig({
provider: typeof model.provider === "string" ? model.provider : "kilocode",
api: typeof model.api === "string" ? model.api : undefined,
baseUrl: typeof model.baseUrl === "string" ? model.baseUrl : undefined,
capability: "llm",
transport: "stream",
callerHeaders: options?.headers,
defaultHeaders: resolveKilocodeAppHeaders(),
providerHeaders: resolveKilocodeAppHeaders(),
precedence: "defaults-win",
});
}).headers;
return streamWithPayloadPatch(
underlying,
model,

View File

@@ -1,5 +1,6 @@
import { describe, expect, it } from "vitest";
import {
resolveProviderRequestPolicyConfig,
resolveProviderRequestConfig,
resolveProviderRequestHeaders,
} from "./provider-request-config.js";
@@ -103,10 +104,82 @@ describe("provider request config", () => {
});
expect(resolved).toEqual({
"HTTP-Referer": "https://example.com",
"HTTP-Referer": "https://openclaw.ai",
"X-OpenRouter-Title": "OpenClaw",
"X-OpenRouter-Categories": "cli-agent",
"X-Custom": "1",
});
});
it("merges header names case-insensitively", () => {
const resolved = resolveProviderRequestHeaders({
provider: "openai",
api: "openai-responses",
baseUrl: "https://api.openai.com/v1",
capability: "llm",
transport: "stream",
callerHeaders: {
"user-agent": "custom-agent/1.0",
},
precedence: "caller-wins",
});
expect(
Object.keys(resolved ?? {}).filter((key) => key.toLowerCase() === "user-agent"),
).toHaveLength(1);
expect(resolved?.["User-Agent"]).toMatch(/^openclaw\//);
});
it("drops forbidden header keys while merging", () => {
const resolved = resolveProviderRequestHeaders({
provider: "custom-openai",
callerHeaders: {
__proto__: "polluted",
constructor: "polluted",
"X-Custom": "1",
} as Record<string, string>,
defaultHeaders: {
prototype: "polluted",
} as Record<string, string>,
});
expect(resolved).toEqual({
"X-Custom": "1",
});
expect(Object.getPrototypeOf(resolved ?? {})).toBeNull();
});
it("unifies policy, capabilities, headers, base URL, and private-network posture", () => {
const resolved = resolveProviderRequestPolicyConfig({
provider: "openai",
api: "openai-responses",
baseUrl: "https://api.openai.com/v1/",
defaultBaseUrl: "https://fallback.example/v1/",
callerHeaders: {
"User-Agent": "custom-agent/1.0",
"X-Custom": "1",
},
providerHeaders: {
authorization: "Bearer test-key",
},
compat: {
supportsStore: true,
},
capability: "llm",
transport: "stream",
precedence: "defaults-win",
});
expect(resolved.baseUrl).toBe("https://api.openai.com/v1");
expect(resolved.allowPrivateNetwork).toBe(true);
expect(resolved.policy.endpointClass).toBe("openai-public");
expect(resolved.capabilities.allowsResponsesStore).toBe(true);
expect(resolved.headers).toMatchObject({
authorization: "Bearer test-key",
originator: "openclaw",
version: expect.any(String),
"User-Agent": expect.stringMatching(/^openclaw\//),
"X-Custom": "1",
});
});
});

View File

@@ -1,11 +1,15 @@
import type { Api } from "@mariozechner/pi-ai";
import type { ModelDefinitionConfig } from "../config/types.js";
import type {
ProviderRequestCapabilities,
ProviderRequestCapability,
ProviderRequestPolicyResolution,
ProviderRequestTransport,
} from "./provider-attribution.js";
import { resolveProviderRequestPolicy } from "./provider-attribution.js";
import {
resolveProviderRequestCapabilities,
resolveProviderRequestPolicy,
type ProviderRequestPolicyResolution,
} from "./provider-attribution.js";
type RequestApi = Api | ModelDefinitionConfig["api"];
@@ -34,22 +38,133 @@ export type ResolvedProviderRequestConfig = {
export type ProviderRequestHeaderPrecedence = "caller-wins" | "defaults-win";
export type ResolvedProviderRequestPolicyConfig = ResolvedProviderRequestConfig & {
allowPrivateNetwork: boolean;
capabilities: ProviderRequestCapabilities;
};
const FORBIDDEN_HEADER_KEYS = new Set(["__proto__", "prototype", "constructor"]);
type ResolveProviderRequestPolicyConfigParams = {
provider?: string;
api?: RequestApi;
baseUrl?: string;
defaultBaseUrl?: string;
capability?: ProviderRequestCapability;
transport?: ProviderRequestTransport;
discoveredHeaders?: Record<string, string>;
providerHeaders?: Record<string, string>;
modelHeaders?: Record<string, string>;
callerHeaders?: Record<string, string>;
precedence?: ProviderRequestHeaderPrecedence;
authHeader?: boolean;
compat?: {
supportsStore?: boolean;
} | null;
modelId?: string | null;
allowPrivateNetwork?: boolean;
};
export function normalizeBaseUrl(baseUrl: string | undefined, fallback: string): string;
export function normalizeBaseUrl(
baseUrl: string | undefined,
fallback?: string,
): string | undefined;
export function normalizeBaseUrl(
baseUrl: string | undefined,
fallback?: string,
): string | undefined {
const raw = baseUrl?.trim() || fallback?.trim();
if (!raw) {
return undefined;
}
return raw.replace(/\/+$/, "");
}
export function mergeProviderRequestHeaders(
...headerSets: Array<Record<string, string> | undefined>
): Record<string, string> | undefined {
let merged: Record<string, string> | undefined;
const headerNamesByLowerKey = new Map<string, string>();
for (const headers of headerSets) {
if (!headers) {
continue;
}
merged = {
...merged,
...headers,
};
if (!merged) {
merged = Object.create(null) as Record<string, string>;
}
for (const [key, value] of Object.entries(headers)) {
const normalizedKey = key.toLowerCase();
if (FORBIDDEN_HEADER_KEYS.has(normalizedKey)) {
continue;
}
const previousKey = headerNamesByLowerKey.get(normalizedKey);
if (previousKey && previousKey !== key) {
delete merged[previousKey];
}
merged[key] = value;
headerNamesByLowerKey.set(normalizedKey, key);
}
}
return merged && Object.keys(merged).length > 0 ? merged : undefined;
}
export function resolveProviderRequestPolicyConfig(
params: ResolveProviderRequestPolicyConfigParams,
): ResolvedProviderRequestPolicyConfig {
const baseUrl = normalizeBaseUrl(params.baseUrl, params.defaultBaseUrl);
const capability = params.capability ?? "llm";
const transport = params.transport ?? "http";
const policyInput = {
provider: params.provider,
api: params.api,
baseUrl,
capability,
transport,
} satisfies Parameters<typeof resolveProviderRequestPolicy>[0];
const policy = resolveProviderRequestPolicy(policyInput);
const capabilities = resolveProviderRequestCapabilities({
...policyInput,
compat: params.compat,
modelId: params.modelId,
});
const defaultHeaders = mergeProviderRequestHeaders(
params.discoveredHeaders,
params.providerHeaders,
params.modelHeaders,
);
const protectedAttributionKeys = new Set(
Object.keys(policy.attributionHeaders ?? {}).map((key) => key.toLowerCase()),
);
const unprotectedCallerHeaders = params.callerHeaders
? Object.fromEntries(
Object.entries(params.callerHeaders).filter(
([key]) => !protectedAttributionKeys.has(key.toLowerCase()),
),
)
: undefined;
const mergedDefaults = mergeProviderRequestHeaders(defaultHeaders, policy.attributionHeaders);
const headers =
params.precedence === "caller-wins"
? mergeProviderRequestHeaders(mergedDefaults, unprotectedCallerHeaders)
: mergeProviderRequestHeaders(unprotectedCallerHeaders, mergedDefaults);
return {
api: params.api,
baseUrl,
headers,
auth: {
mode: params.authHeader ? "authorization-bearer" : "provider-default",
injectAuthorizationHeader: params.authHeader === true,
},
proxy: { configured: false },
tls: { configured: false },
policy,
capabilities,
allowPrivateNetwork: params.allowPrivateNetwork ?? Boolean(params.baseUrl?.trim()),
};
}
export function resolveProviderRequestConfig(params: {
provider: string;
api?: RequestApi;
@@ -61,31 +176,22 @@ export function resolveProviderRequestConfig(params: {
modelHeaders?: Record<string, string>;
authHeader?: boolean;
}): ResolvedProviderRequestConfig {
const policy = resolveProviderRequestPolicy({
provider: params.provider,
api: params.api,
baseUrl: params.baseUrl,
capability: params.capability ?? "llm",
transport: params.transport ?? "http",
});
const resolved = resolveProviderRequestPolicyConfig(params);
return {
api: params.api,
baseUrl: params.baseUrl,
api: resolved.api,
baseUrl: resolved.baseUrl,
// Model resolution intentionally excludes attribution headers. Those are
// applied later at transport/request time so native-host gating stays tied
// to the final resolved route instead of the catalog/config merge step.
headers: mergeProviderRequestHeaders(
params.discoveredHeaders,
params.providerHeaders,
params.modelHeaders,
),
auth: {
mode: params.authHeader ? "authorization-bearer" : "provider-default",
injectAuthorizationHeader: params.authHeader === true,
},
// These slots are intentionally internal-first. Future provider request
// policy work can populate them without reshaping existing callers again.
proxy: { configured: false },
tls: { configured: false },
policy,
auth: resolved.auth,
proxy: resolved.proxy,
tls: resolved.tls,
policy: resolved.policy,
};
}
@@ -99,21 +205,14 @@ export function resolveProviderRequestHeaders(params: {
defaultHeaders?: Record<string, string>;
precedence?: ProviderRequestHeaderPrecedence;
}): Record<string, string> | undefined {
const requestConfig = resolveProviderRequestConfig({
return resolveProviderRequestPolicyConfig({
provider: params.provider,
api: params.api,
baseUrl: params.baseUrl,
capability: params.capability,
transport: params.transport,
callerHeaders: params.callerHeaders,
providerHeaders: params.defaultHeaders,
});
const mergedDefaults = mergeProviderRequestHeaders(
requestConfig.headers,
requestConfig.policy.attributionHeaders,
);
// When precedence is omitted, defaults-win is the conservative choice:
// attribution/default headers cannot be silently overridden by callers.
return params.precedence === "caller-wins"
? mergeProviderRequestHeaders(mergedDefaults, params.callerHeaders)
: mergeProviderRequestHeaders(params.callerHeaders, mergedDefaults);
precedence: params.precedence,
}).headers;
}

View File

@@ -2,13 +2,14 @@ import { describe, expect, it } from "vitest";
import { resolveProviderHttpRequestConfig } from "./shared.js";
describe("resolveProviderHttpRequestConfig", () => {
it("preserves explicit caller headers over default and attribution headers", () => {
it("preserves explicit caller headers but protects attribution headers", () => {
const resolved = resolveProviderHttpRequestConfig({
baseUrl: "https://api.openai.com/v1/",
defaultBaseUrl: "https://api.openai.com/v1",
headers: {
authorization: "Bearer override",
"User-Agent": "custom-agent/1.0",
originator: "spoofed",
},
defaultHeaders: {
authorization: "Bearer default-token",
@@ -24,7 +25,7 @@ describe("resolveProviderHttpRequestConfig", () => {
expect(resolved.allowPrivateNetwork).toBe(true);
expect(resolved.headers.get("authorization")).toBe("Bearer override");
expect(resolved.headers.get("x-default")).toBe("1");
expect(resolved.headers.get("user-agent")).toBe("custom-agent/1.0");
expect(resolved.headers.get("user-agent")).toMatch(/^openclaw\//);
expect(resolved.headers.get("originator")).toBe("openclaw");
expect(resolved.headers.get("version")).toBeTruthy();
});
@@ -63,4 +64,13 @@ describe("resolveProviderHttpRequestConfig", () => {
expect(resolved.allowPrivateNetwork).toBe(false);
expect(resolved.headers.get("x-goog-api-key")).toBe("test-key");
});
it("fails fast when no base URL can be resolved", () => {
expect(() =>
resolveProviderHttpRequestConfig({
baseUrl: " ",
defaultBaseUrl: " ",
}),
).toThrow("Missing baseUrl");
});
});

View File

@@ -2,58 +2,19 @@ import type {
ProviderRequestCapability,
ProviderRequestTransport,
} from "../agents/provider-attribution.js";
import { resolveProviderRequestAttributionHeaders } from "../agents/provider-attribution.js";
import {
resolveProviderRequestConfig,
normalizeBaseUrl,
resolveProviderRequestPolicyConfig,
type ResolvedProviderRequestConfig,
} from "../agents/provider-request-config.js";
import type { GuardedFetchResult } from "../infra/net/fetch-guard.js";
import { fetchWithSsrFGuard } from "../infra/net/fetch-guard.js";
import type { LookupFn, SsrFPolicy } from "../infra/net/ssrf.js";
export { fetchWithTimeout } from "../utils/fetch-timeout.js";
export { normalizeBaseUrl } from "../agents/provider-request-config.js";
const MAX_ERROR_CHARS = 300;
export function normalizeBaseUrl(baseUrl: string | undefined, fallback: string): string {
const raw = baseUrl?.trim() || fallback;
return raw.replace(/\/+$/, "");
}
export function applyProviderRequestHeaders(params: {
headers?: HeadersInit;
defaultHeaders?: Record<string, string>;
provider?: string;
api?: string;
baseUrl?: string;
capability?: ProviderRequestCapability;
transport?: ProviderRequestTransport;
}): Headers {
const headers = new Headers(params.headers);
if (params.defaultHeaders) {
for (const [key, value] of Object.entries(params.defaultHeaders)) {
if (!headers.has(key)) {
headers.set(key, value);
}
}
}
const attributionHeaders = resolveProviderRequestAttributionHeaders({
provider: params.provider,
api: params.api,
baseUrl: params.baseUrl,
capability: params.capability ?? "other",
transport: params.transport ?? "http",
});
if (!attributionHeaders) {
return headers;
}
for (const [key, value] of Object.entries(attributionHeaders)) {
if (!headers.has(key)) {
headers.set(key, value);
}
}
return headers;
}
export function resolveProviderHttpRequestConfig(params: {
baseUrl?: string;
defaultBaseUrl: string;
@@ -70,33 +31,29 @@ export function resolveProviderHttpRequestConfig(params: {
headers: Headers;
requestConfig: ResolvedProviderRequestConfig;
} {
const baseUrl = normalizeBaseUrl(params.baseUrl, params.defaultBaseUrl);
const requestConfigParams: Parameters<typeof resolveProviderRequestConfig>[0] = {
const requestConfig = resolveProviderRequestPolicyConfig({
provider: params.provider ?? "",
baseUrl,
baseUrl: params.baseUrl,
defaultBaseUrl: params.defaultBaseUrl,
capability: params.capability ?? "other",
transport: params.transport ?? "http",
};
if (params.api !== undefined) {
requestConfigParams.api = params.api;
callerHeaders: params.headers
? Object.fromEntries(new Headers(params.headers).entries())
: undefined,
providerHeaders: params.defaultHeaders,
precedence: "caller-wins",
allowPrivateNetwork: params.allowPrivateNetwork,
api: params.api,
});
const headers = new Headers(requestConfig.headers);
if (!requestConfig.baseUrl) {
throw new Error("Missing baseUrl: provide baseUrl or defaultBaseUrl");
}
if (params.defaultHeaders !== undefined) {
requestConfigParams.providerHeaders = params.defaultHeaders;
}
const requestConfig = resolveProviderRequestConfig(requestConfigParams);
return {
baseUrl,
allowPrivateNetwork: params.allowPrivateNetwork ?? Boolean(params.baseUrl?.trim()),
headers: applyProviderRequestHeaders({
headers: params.headers,
defaultHeaders: requestConfig.headers,
provider: params.provider,
api: params.api,
baseUrl,
capability: params.capability,
transport: params.transport,
}),
baseUrl: requestConfig.baseUrl,
allowPrivateNetwork: requestConfig.allowPrivateNetwork,
headers,
requestConfig,
};
}