From 44881b0222e2c6e017a1543ac054008d3067d11b Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 7 Mar 2026 16:45:55 +0000 Subject: [PATCH] fix(diffs): harden proxied local viewer detection --- extensions/diffs/src/http.test.ts | 61 +++++++++++++++++++++++++++++- extensions/diffs/src/http.ts | 62 ++++++++++++++++++++++--------- 2 files changed, 104 insertions(+), 19 deletions(-) diff --git a/extensions/diffs/src/http.test.ts b/extensions/diffs/src/http.test.ts index b9a0fee6e59..5e8c2927691 100644 --- a/extensions/diffs/src/http.test.ts +++ b/extensions/diffs/src/http.test.ts @@ -135,6 +135,29 @@ describe("createDiffsHttpHandler", () => { expect(res.statusCode).toBe(404); }); + it("blocks loopback requests that carry proxy forwarding headers by default", async () => { + const artifact = await store.createArtifact({ + html: "viewer", + title: "Demo", + inputKind: "before_after", + fileCount: 1, + }); + + const handler = createDiffsHttpHandler({ store }); + const res = createMockServerResponse(); + const handled = await handler( + localReq({ + method: "GET", + url: artifact.viewerPath, + headers: { "x-forwarded-for": "203.0.113.10" }, + }), + res, + ); + + expect(handled).toBe(true); + expect(res.statusCode).toBe(404); + }); + it("allows remote access when allowRemoteViewer is enabled", async () => { const artifact = await store.createArtifact({ html: "viewer", @@ -158,6 +181,30 @@ describe("createDiffsHttpHandler", () => { expect(res.body).toBe("viewer"); }); + it("allows proxied loopback requests when allowRemoteViewer is enabled", async () => { + const artifact = await store.createArtifact({ + html: "viewer", + title: "Demo", + inputKind: "before_after", + fileCount: 1, + }); + + const handler = createDiffsHttpHandler({ store, allowRemoteViewer: true }); + const res = createMockServerResponse(); + const handled = await handler( + localReq({ + method: "GET", + url: artifact.viewerPath, + headers: { "x-forwarded-for": "203.0.113.10" }, + }), + res, + ); + + expect(handled).toBe(true); + expect(res.statusCode).toBe(200); + expect(res.body).toBe("viewer"); + }); + it("rate-limits repeated remote misses", async () => { const handler = createDiffsHttpHandler({ store, allowRemoteViewer: true }); @@ -185,16 +232,26 @@ describe("createDiffsHttpHandler", () => { }); }); -function localReq(input: { method: string; url: string }): IncomingMessage { +function localReq(input: { + method: string; + url: string; + headers?: Record; +}): IncomingMessage { return { ...input, + headers: input.headers ?? {}, socket: { remoteAddress: "127.0.0.1" }, } as unknown as IncomingMessage; } -function remoteReq(input: { method: string; url: string }): IncomingMessage { +function remoteReq(input: { + method: string; + url: string; + headers?: Record; +}): IncomingMessage { return { ...input, + headers: input.headers ?? {}, socket: { remoteAddress: "203.0.113.10" }, } as unknown as IncomingMessage; } diff --git a/extensions/diffs/src/http.ts b/extensions/diffs/src/http.ts index 0f17e77fd9e..df141ca0d38 100644 --- a/extensions/diffs/src/http.ts +++ b/extensions/diffs/src/http.ts @@ -42,9 +42,8 @@ export function createDiffsHttpHandler(params: { return false; } - const remoteKey = normalizeRemoteClientKey(req.socket?.remoteAddress); - const localRequest = isLoopbackClientIp(remoteKey); - if (!localRequest && params.allowRemoteViewer !== true) { + const access = resolveViewerAccess(req); + if (!access.localRequest && params.allowRemoteViewer !== true) { respondText(res, 404, "Diff not found"); return true; } @@ -54,8 +53,8 @@ export function createDiffsHttpHandler(params: { return true; } - if (!localRequest) { - const throttled = viewerFailureLimiter.check(remoteKey); + if (!access.localRequest) { + const throttled = viewerFailureLimiter.check(access.remoteKey); if (!throttled.allowed) { res.statusCode = 429; setSharedHeaders(res, "text/plain; charset=utf-8"); @@ -74,27 +73,21 @@ export function createDiffsHttpHandler(params: { !DIFF_ARTIFACT_ID_PATTERN.test(id) || !DIFF_ARTIFACT_TOKEN_PATTERN.test(token) ) { - if (!localRequest) { - viewerFailureLimiter.recordFailure(remoteKey); - } + recordRemoteFailure(viewerFailureLimiter, access); respondText(res, 404, "Diff not found"); return true; } const artifact = await params.store.getArtifact(id, token); if (!artifact) { - if (!localRequest) { - viewerFailureLimiter.recordFailure(remoteKey); - } + recordRemoteFailure(viewerFailureLimiter, access); respondText(res, 404, "Diff not found or expired"); return true; } try { const html = await params.store.readHtml(id); - if (!localRequest) { - viewerFailureLimiter.reset(remoteKey); - } + resetRemoteFailures(viewerFailureLimiter, access); res.statusCode = 200; setSharedHeaders(res, "text/html; charset=utf-8"); res.setHeader("content-security-policy", VIEWER_CONTENT_SECURITY_POLICY); @@ -105,9 +98,7 @@ export function createDiffsHttpHandler(params: { } return true; } catch (error) { - if (!localRequest) { - viewerFailureLimiter.recordFailure(remoteKey); - } + recordRemoteFailure(viewerFailureLimiter, access); params.logger?.warn(`Failed to serve diff artifact ${id}: ${String(error)}`); respondText(res, 500, "Failed to load diff"); return true; @@ -184,6 +175,43 @@ function isLoopbackClientIp(clientIp: string): boolean { return clientIp === "127.0.0.1" || clientIp === "::1"; } +function hasProxyForwardingHints(req: IncomingMessage): boolean { + return Boolean( + req.headers["x-forwarded-for"] || + req.headers["x-real-ip"] || + req.headers.forwarded || + req.headers["x-forwarded-host"] || + req.headers["x-forwarded-proto"], + ); +} + +function resolveViewerAccess(req: IncomingMessage): { + remoteKey: string; + localRequest: boolean; +} { + const remoteKey = normalizeRemoteClientKey(req.socket?.remoteAddress); + const localRequest = isLoopbackClientIp(remoteKey) && !hasProxyForwardingHints(req); + return { remoteKey, localRequest }; +} + +function recordRemoteFailure( + limiter: ViewerFailureLimiter, + access: { remoteKey: string; localRequest: boolean }, +): void { + if (!access.localRequest) { + limiter.recordFailure(access.remoteKey); + } +} + +function resetRemoteFailures( + limiter: ViewerFailureLimiter, + access: { remoteKey: string; localRequest: boolean }, +): void { + if (!access.localRequest) { + limiter.reset(access.remoteKey); + } +} + type RateLimitCheckResult = { allowed: boolean; retryAfterMs: number;