refactor: unify channel/plugin ssrf fetch policy and auth fallback

This commit is contained in:
Peter Steinberger
2026-02-26 16:43:44 +01:00
parent 2e97d0dd95
commit 57334cd7d8
13 changed files with 749 additions and 595 deletions

View File

@@ -0,0 +1,94 @@
import { describe, expect, it, vi } from "vitest";
import { fetchWithBearerAuthScopeFallback } from "./fetch-auth.js";
describe("fetchWithBearerAuthScopeFallback", () => {
it("rejects non-https urls when https is required", async () => {
await expect(
fetchWithBearerAuthScopeFallback({
url: "http://example.com/file",
scopes: [],
requireHttps: true,
}),
).rejects.toThrow("URL must use HTTPS");
});
it("returns immediately when the first attempt succeeds", async () => {
const fetchFn = vi.fn(async () => new Response("ok", { status: 200 }));
const tokenProvider = { getAccessToken: vi.fn(async () => "unused") };
const response = await fetchWithBearerAuthScopeFallback({
url: "https://example.com/file",
scopes: ["https://graph.microsoft.com"],
fetchFn,
tokenProvider,
});
expect(response.status).toBe(200);
expect(fetchFn).toHaveBeenCalledTimes(1);
expect(tokenProvider.getAccessToken).not.toHaveBeenCalled();
});
it("retries with auth scopes after a 401 response", async () => {
const fetchFn = vi
.fn()
.mockResolvedValueOnce(new Response("unauthorized", { status: 401 }))
.mockResolvedValueOnce(new Response("ok", { status: 200 }));
const tokenProvider = { getAccessToken: vi.fn(async () => "token-1") };
const response = await fetchWithBearerAuthScopeFallback({
url: "https://graph.microsoft.com/v1.0/me",
scopes: ["https://graph.microsoft.com", "https://api.botframework.com"],
fetchFn,
tokenProvider,
});
expect(response.status).toBe(200);
expect(fetchFn).toHaveBeenCalledTimes(2);
expect(tokenProvider.getAccessToken).toHaveBeenCalledWith("https://graph.microsoft.com");
const secondCall = fetchFn.mock.calls[1] as [string, RequestInit | undefined];
const secondHeaders = new Headers(secondCall[1]?.headers);
expect(secondHeaders.get("authorization")).toBe("Bearer token-1");
});
it("does not attach auth when host predicate rejects url", async () => {
const fetchFn = vi.fn(async () => new Response("unauthorized", { status: 401 }));
const tokenProvider = { getAccessToken: vi.fn(async () => "token-1") };
const response = await fetchWithBearerAuthScopeFallback({
url: "https://example.com/file",
scopes: ["https://graph.microsoft.com"],
fetchFn,
tokenProvider,
shouldAttachAuth: () => false,
});
expect(response.status).toBe(401);
expect(fetchFn).toHaveBeenCalledTimes(1);
expect(tokenProvider.getAccessToken).not.toHaveBeenCalled();
});
it("continues across scopes when token retrieval fails", async () => {
const fetchFn = vi
.fn()
.mockResolvedValueOnce(new Response("unauthorized", { status: 401 }))
.mockResolvedValueOnce(new Response("ok", { status: 200 }));
const tokenProvider = {
getAccessToken: vi
.fn()
.mockRejectedValueOnce(new Error("first scope failed"))
.mockResolvedValueOnce("token-2"),
};
const response = await fetchWithBearerAuthScopeFallback({
url: "https://graph.microsoft.com/v1.0/me",
scopes: ["https://first.example", "https://second.example"],
fetchFn,
tokenProvider,
});
expect(response.status).toBe(200);
expect(tokenProvider.getAccessToken).toHaveBeenCalledTimes(2);
expect(tokenProvider.getAccessToken).toHaveBeenNthCalledWith(1, "https://first.example");
expect(tokenProvider.getAccessToken).toHaveBeenNthCalledWith(2, "https://second.example");
});
});

View File

@@ -0,0 +1,71 @@
export type ScopeTokenProvider = {
getAccessToken: (scope: string) => Promise<string>;
};
function isAuthFailureStatus(status: number): boolean {
return status === 401 || status === 403;
}
export async function fetchWithBearerAuthScopeFallback(params: {
url: string;
scopes: readonly string[];
tokenProvider?: ScopeTokenProvider;
fetchFn?: typeof fetch;
requestInit?: RequestInit;
requireHttps?: boolean;
shouldAttachAuth?: (url: string) => boolean;
shouldRetry?: (response: Response) => boolean;
}): Promise<Response> {
const fetchFn = params.fetchFn ?? fetch;
let parsedUrl: URL;
try {
parsedUrl = new URL(params.url);
} catch {
throw new Error(`Invalid URL: ${params.url}`);
}
if (params.requireHttps === true && parsedUrl.protocol !== "https:") {
throw new Error(`URL must use HTTPS: ${params.url}`);
}
const fetchOnce = (headers?: Headers): Promise<Response> =>
fetchFn(params.url, {
...params.requestInit,
...(headers ? { headers } : {}),
});
const firstAttempt = await fetchOnce();
if (firstAttempt.ok) {
return firstAttempt;
}
if (!params.tokenProvider) {
return firstAttempt;
}
const shouldRetry =
params.shouldRetry ?? ((response: Response) => isAuthFailureStatus(response.status));
if (!shouldRetry(firstAttempt)) {
return firstAttempt;
}
if (params.shouldAttachAuth && !params.shouldAttachAuth(params.url)) {
return firstAttempt;
}
for (const scope of params.scopes) {
try {
const token = await params.tokenProvider.getAccessToken(scope);
const authHeaders = new Headers(params.requestInit?.headers);
authHeaders.set("Authorization", `Bearer ${token}`);
const authAttempt = await fetchOnce(authHeaders);
if (authAttempt.ok) {
return authAttempt;
}
if (!shouldRetry(authAttempt)) {
continue;
}
} catch {
// Ignore token/fetch errors and continue trying remaining scopes.
}
}
return firstAttempt;
}

View File

@@ -292,6 +292,13 @@ export {
isPrivateIpAddress,
} from "../infra/net/ssrf.js";
export type { LookupFn, SsrFPolicy } from "../infra/net/ssrf.js";
export {
buildHostnameAllowlistPolicyFromSuffixAllowlist,
isHttpsUrlAllowedByHostnameSuffixAllowlist,
normalizeHostnameSuffixAllowlist,
} from "./ssrf-policy.js";
export { fetchWithBearerAuthScopeFallback } from "./fetch-auth.js";
export type { ScopeTokenProvider } from "./fetch-auth.js";
export { rawDataToString } from "../infra/ws.js";
export { isWSLSync, isWSL2Sync, isWSLEnv } from "../infra/wsl.js";
export { isTruthyEnvValue } from "../infra/env.js";

View File

@@ -0,0 +1,84 @@
import { describe, expect, it } from "vitest";
import {
buildHostnameAllowlistPolicyFromSuffixAllowlist,
isHttpsUrlAllowedByHostnameSuffixAllowlist,
normalizeHostnameSuffixAllowlist,
} from "./ssrf-policy.js";
describe("normalizeHostnameSuffixAllowlist", () => {
it("uses defaults when input is missing", () => {
expect(normalizeHostnameSuffixAllowlist(undefined, ["GRAPH.MICROSOFT.COM"])).toEqual([
"graph.microsoft.com",
]);
});
it("normalizes wildcard prefixes and deduplicates", () => {
expect(
normalizeHostnameSuffixAllowlist([
"*.TrafficManager.NET",
".trafficmanager.net.",
" * ",
"x",
]),
).toEqual(["*"]);
});
});
describe("isHttpsUrlAllowedByHostnameSuffixAllowlist", () => {
it("requires https", () => {
expect(
isHttpsUrlAllowedByHostnameSuffixAllowlist("http://a.example.com/x", ["example.com"]),
).toBe(false);
});
it("supports exact and suffix match", () => {
expect(
isHttpsUrlAllowedByHostnameSuffixAllowlist("https://example.com/x", ["example.com"]),
).toBe(true);
expect(
isHttpsUrlAllowedByHostnameSuffixAllowlist("https://a.example.com/x", ["example.com"]),
).toBe(true);
expect(isHttpsUrlAllowedByHostnameSuffixAllowlist("https://evil.com/x", ["example.com"])).toBe(
false,
);
});
it("supports wildcard allowlist", () => {
expect(isHttpsUrlAllowedByHostnameSuffixAllowlist("https://evil.com/x", ["*"])).toBe(true);
});
});
describe("buildHostnameAllowlistPolicyFromSuffixAllowlist", () => {
it("returns undefined when allowHosts is empty", () => {
expect(buildHostnameAllowlistPolicyFromSuffixAllowlist()).toBeUndefined();
expect(buildHostnameAllowlistPolicyFromSuffixAllowlist([])).toBeUndefined();
});
it("returns undefined when wildcard host is present", () => {
expect(buildHostnameAllowlistPolicyFromSuffixAllowlist(["*"])).toBeUndefined();
expect(buildHostnameAllowlistPolicyFromSuffixAllowlist(["example.com", "*"])).toBeUndefined();
});
it("expands a suffix entry to exact + wildcard hostname allowlist patterns", () => {
expect(buildHostnameAllowlistPolicyFromSuffixAllowlist(["sharepoint.com"])).toEqual({
hostnameAllowlist: ["sharepoint.com", "*.sharepoint.com"],
});
});
it("normalizes wildcard prefixes, leading/trailing dots, and deduplicates patterns", () => {
expect(
buildHostnameAllowlistPolicyFromSuffixAllowlist([
"*.TrafficManager.NET",
".trafficmanager.net.",
" blob.core.windows.net ",
]),
).toEqual({
hostnameAllowlist: [
"trafficmanager.net",
"*.trafficmanager.net",
"blob.core.windows.net",
"*.blob.core.windows.net",
],
});
});
});

View File

@@ -0,0 +1,85 @@
import type { SsrFPolicy } from "../infra/net/ssrf.js";
function normalizeHostnameSuffix(value: string): string {
const trimmed = value.trim().toLowerCase();
if (!trimmed) {
return "";
}
if (trimmed === "*" || trimmed === "*.") {
return "*";
}
const withoutWildcard = trimmed.replace(/^\*\.?/, "");
const withoutLeadingDot = withoutWildcard.replace(/^\.+/, "");
return withoutLeadingDot.replace(/\.+$/, "");
}
function isHostnameAllowedBySuffixAllowlist(
hostname: string,
allowlist: readonly string[],
): boolean {
if (allowlist.includes("*")) {
return true;
}
const normalized = hostname.toLowerCase();
return allowlist.some((entry) => normalized === entry || normalized.endsWith(`.${entry}`));
}
export function normalizeHostnameSuffixAllowlist(
input?: readonly string[],
defaults?: readonly string[],
): string[] {
const source = input && input.length > 0 ? input : defaults;
if (!source || source.length === 0) {
return [];
}
const normalized = source.map(normalizeHostnameSuffix).filter(Boolean);
if (normalized.includes("*")) {
return ["*"];
}
return Array.from(new Set(normalized));
}
export function isHttpsUrlAllowedByHostnameSuffixAllowlist(
url: string,
allowlist: readonly string[],
): boolean {
try {
const parsed = new URL(url);
if (parsed.protocol !== "https:") {
return false;
}
return isHostnameAllowedBySuffixAllowlist(parsed.hostname, allowlist);
} catch {
return false;
}
}
/**
* Converts suffix-style host allowlists (for example "example.com") into SSRF
* hostname allowlist patterns used by the shared fetch guard.
*
* Suffix semantics:
* - "example.com" allows "example.com" and "*.example.com"
* - "*" disables hostname allowlist restrictions
*/
export function buildHostnameAllowlistPolicyFromSuffixAllowlist(
allowHosts?: readonly string[],
): SsrFPolicy | undefined {
const normalizedAllowHosts = normalizeHostnameSuffixAllowlist(allowHosts);
if (normalizedAllowHosts.length === 0) {
return undefined;
}
const patterns = new Set<string>();
for (const normalized of normalizedAllowHosts) {
if (normalized === "*") {
return undefined;
}
patterns.add(normalized);
patterns.add(`*.${normalized}`);
}
if (patterns.size === 0) {
return undefined;
}
return { hostnameAllowlist: Array.from(patterns) };
}