diff --git a/src/gateway/server.auth.browser-hardening.test.ts b/src/gateway/server.auth.browser-hardening.test.ts index 070addbdc53..e9550a8b1aa 100644 --- a/src/gateway/server.auth.browser-hardening.test.ts +++ b/src/gateway/server.auth.browser-hardening.test.ts @@ -152,4 +152,28 @@ describe("gateway auth browser hardening", () => { } }); }); + + test("rejects forged loopback origin for control-ui when proxy headers make client non-local", async () => { + testState.gatewayAuth = { mode: "token", token: "secret" }; + await withGatewayServer(async ({ port }) => { + const ws = await openWs(port, { + origin: originForPort(port), + "x-forwarded-for": "203.0.113.50", + }); + try { + const res = await connectReq(ws, { + token: "secret", + client: { + ...TEST_OPERATOR_CLIENT, + id: GATEWAY_CLIENT_NAMES.CONTROL_UI, + mode: GATEWAY_CLIENT_MODES.UI, + }, + }); + expect(res.ok).toBe(false); + expect(res.error?.message ?? "").toContain("origin not allowed"); + } finally { + ws.close(); + } + }); + }); }); diff --git a/src/gateway/server/ws-connection.ts b/src/gateway/server/ws-connection.ts index 3abc8d6e1b9..c2fad8059e8 100644 --- a/src/gateway/server/ws-connection.ts +++ b/src/gateway/server/ws-connection.ts @@ -15,7 +15,10 @@ import { formatError } from "../server-utils.js"; import { logWs } from "../ws-log.js"; import { getHealthVersion, incrementPresenceVersion } from "./health-state.js"; import { broadcastPresenceSnapshot } from "./presence-events.js"; -import { attachGatewayWsMessageHandler } from "./ws-connection/message-handler.js"; +import { + attachGatewayWsMessageHandler, + type WsOriginCheckMetrics, +} from "./ws-connection/message-handler.js"; import type { GatewayWsClient } from "./ws-types.js"; type SubsystemLogger = ReturnType; @@ -102,6 +105,7 @@ export function attachGatewayWsConnectionHandler(params: { broadcast, buildRequestContext, } = params; + const originCheckMetrics: WsOriginCheckMetrics = { hostHeaderFallbackAccepted: 0 }; wss.on("connection", (socket, upgradeReq) => { let client: GatewayWsClient | null = null; @@ -300,6 +304,7 @@ export function attachGatewayWsConnectionHandler(params: { }, setCloseCause, setLastFrameMeta, + originCheckMetrics, logGateway, logHealth, logWsControl, diff --git a/src/gateway/server/ws-connection/message-handler.ts b/src/gateway/server/ws-connection/message-handler.ts index 58b5c9c2ab4..1ecbb330c7c 100644 --- a/src/gateway/server/ws-connection/message-handler.ts +++ b/src/gateway/server/ws-connection/message-handler.ts @@ -90,7 +90,10 @@ type SubsystemLogger = ReturnType; const DEVICE_SIGNATURE_SKEW_MS = 2 * 60 * 1000; const BROWSER_ORIGIN_LOOPBACK_RATE_LIMIT_IP = "198.18.0.1"; -let hostHeaderFallbackAcceptedCount = 0; + +export type WsOriginCheckMetrics = { + hostHeaderFallbackAccepted: number; +}; type HandshakeBrowserSecurityContext = { hasBrowserOriginHeader: boolean; @@ -260,6 +263,7 @@ export function attachGatewayWsMessageHandler(params: { setHandshakeState: (state: "pending" | "connected" | "failed") => void; setCloseCause: (cause: string, meta?: Record) => void; setLastFrameMeta: (meta: { type?: string; method?: string; id?: string }) => void; + originCheckMetrics: WsOriginCheckMetrics; logGateway: SubsystemLogger; logHealth: SubsystemLogger; logWsControl: SubsystemLogger; @@ -292,6 +296,7 @@ export function attachGatewayWsMessageHandler(params: { setHandshakeState, setCloseCause, setLastFrameMeta, + originCheckMetrics, logGateway, logHealth, logWsControl, @@ -514,9 +519,9 @@ export function attachGatewayWsMessageHandler(params: { return; } if (originCheck.matchedBy === "host-header-fallback") { - hostHeaderFallbackAcceptedCount += 1; + originCheckMetrics.hostHeaderFallbackAccepted += 1; logWsControl.warn( - `security warning: websocket origin accepted via Host-header fallback conn=${connId} count=${hostHeaderFallbackAcceptedCount} host=${requestHost ?? "n/a"} origin=${requestOrigin ?? "n/a"}`, + `security warning: websocket origin accepted via Host-header fallback conn=${connId} count=${originCheckMetrics.hostHeaderFallbackAccepted} host=${requestHost ?? "n/a"} origin=${requestOrigin ?? "n/a"}`, ); if (hostHeaderOriginFallbackEnabled) { logGateway.warn(