diff --git a/src/gateway/server-http.ts b/src/gateway/server-http.ts index d426e1223c4..1f3e8b9263d 100644 --- a/src/gateway/server-http.ts +++ b/src/gateway/server-http.ts @@ -72,6 +72,7 @@ import { type PluginHttpRequestHandler, type PluginRoutePathContext, } from "./server/plugins-http.js"; +import type { PreauthConnectionBudget } from "./server/preauth-connection-budget.js"; import type { ReadinessChecker } from "./server/readiness.js"; import type { GatewayWsClient } from "./server/ws-types.js"; import { handleSessionKillHttpRequest } from "./session-kill-http.js"; @@ -1008,13 +1009,25 @@ export function attachGatewayUpgradeHandler(opts: { wss: WebSocketServer; canvasHost: CanvasHostHandler | null; clients: Set; + preauthConnectionBudget: PreauthConnectionBudget; resolvedAuth: ResolvedGatewayAuth; /** Optional rate limiter for auth brute-force protection. */ rateLimiter?: AuthRateLimiter; }) { - const { httpServer, wss, canvasHost, clients, resolvedAuth, rateLimiter } = opts; + const { + httpServer, + wss, + canvasHost, + clients, + preauthConnectionBudget, + resolvedAuth, + rateLimiter, + } = opts; httpServer.on("upgrade", (req, socket, head) => { void (async () => { + const configSnapshot = loadConfig(); + const trustedProxies = configSnapshot.gateway?.trustedProxies ?? []; + const allowRealIpFallback = configSnapshot.gateway?.allowRealIpFallback === true; const scopedCanvas = normalizeCanvasScopedUrl(req.url ?? "/"); if (scopedCanvas.malformedScopedPath) { writeUpgradeAuthFailure(socket, { ok: false, reason: "unauthorized" }); @@ -1027,9 +1040,6 @@ export function attachGatewayUpgradeHandler(opts: { if (canvasHost) { const url = new URL(req.url ?? "/", "http://localhost"); if (url.pathname === CANVAS_WS_PATH) { - const configSnapshot = loadConfig(); - const trustedProxies = configSnapshot.gateway?.trustedProxies ?? []; - const allowRealIpFallback = configSnapshot.gateway?.allowRealIpFallback === true; const ok = await authorizeCanvasRequest({ req, auth: resolvedAuth, @@ -1050,9 +1060,68 @@ export function attachGatewayUpgradeHandler(opts: { return; } } - wss.handleUpgrade(req, socket, head, (ws) => { - wss.emit("connection", ws, req); - }); + const preauthBudgetKey = resolveRequestClientIp(req, trustedProxies, allowRealIpFallback); + if (wss.listenerCount("connection") === 0) { + const responseBody = "Gateway websocket handlers unavailable"; + socket.write( + "HTTP/1.1 503 Service Unavailable\r\n" + + "Connection: close\r\n" + + "Content-Type: text/plain; charset=utf-8\r\n" + + `Content-Length: ${Buffer.byteLength(responseBody, "utf8")}\r\n` + + "\r\n" + + responseBody, + ); + socket.destroy(); + return; + } + if (!preauthConnectionBudget.acquire(preauthBudgetKey)) { + const responseBody = "Too many unauthenticated sockets"; + socket.write( + "HTTP/1.1 503 Service Unavailable\r\n" + + "Connection: close\r\n" + + "Content-Type: text/plain; charset=utf-8\r\n" + + `Content-Length: ${Buffer.byteLength(responseBody, "utf8")}\r\n` + + "\r\n" + + responseBody, + ); + socket.destroy(); + return; + } + let budgetTransferred = false; + const releaseUpgradeBudget = () => { + if (budgetTransferred) { + return; + } + budgetTransferred = true; + preauthConnectionBudget.release(preauthBudgetKey); + }; + socket.once("close", releaseUpgradeBudget); + try { + wss.handleUpgrade(req, socket, head, (ws) => { + ( + ws as unknown as import("ws").WebSocket & { + __openclawPreauthBudgetClaimed?: boolean; + __openclawPreauthBudgetKey?: string; + } + ).__openclawPreauthBudgetKey = preauthBudgetKey; + wss.emit("connection", ws, req); + const budgetClaimed = Boolean( + ( + ws as unknown as import("ws").WebSocket & { + __openclawPreauthBudgetClaimed?: boolean; + } + ).__openclawPreauthBudgetClaimed, + ); + if (budgetClaimed) { + budgetTransferred = true; + socket.off("close", releaseUpgradeBudget); + } + }); + } catch { + socket.off("close", releaseUpgradeBudget); + releaseUpgradeBudget(); + throw new Error("gateway websocket upgrade failed"); + } })().catch(() => { socket.destroy(); }); diff --git a/src/gateway/server-runtime-state.ts b/src/gateway/server-runtime-state.ts index 173b45878ff..dc1a20c267f 100644 --- a/src/gateway/server-runtime-state.ts +++ b/src/gateway/server-runtime-state.ts @@ -43,6 +43,10 @@ import { shouldEnforceGatewayAuthForPluginPath, type PluginRoutePathContext, } from "./server/plugins-http.js"; +import { + createPreauthConnectionBudget, + type PreauthConnectionBudget, +} from "./server/preauth-connection-budget.js"; import type { ReadinessChecker } from "./server/readiness.js"; import type { GatewayTlsRuntime } from "./server/tls.js"; import type { GatewayWsClient } from "./server/ws-types.js"; @@ -83,6 +87,7 @@ export async function createGatewayRuntimeState(params: { httpServers: HttpServer[]; httpBindHosts: string[]; wss: WebSocketServer; + preauthConnectionBudget: PreauthConnectionBudget; clients: Set; broadcast: GatewayBroadcastFn; broadcastToConnIds: GatewayBroadcastToConnIdsFn; @@ -213,12 +218,14 @@ export async function createGatewayRuntimeState(params: { noServer: true, maxPayload: MAX_PREAUTH_PAYLOAD_BYTES, }); + const preauthConnectionBudget = createPreauthConnectionBudget(); for (const server of httpServers) { attachGatewayUpgradeHandler({ httpServer: server, wss, canvasHost, clients, + preauthConnectionBudget, resolvedAuth: params.resolvedAuth, rateLimiter: params.rateLimiter, }); @@ -251,6 +258,7 @@ export async function createGatewayRuntimeState(params: { httpServers, httpBindHosts, wss, + preauthConnectionBudget, clients, broadcast, broadcastToConnIds, diff --git a/src/gateway/server-ws-runtime.ts b/src/gateway/server-ws-runtime.ts index 795a162818f..6ffd2559642 100644 --- a/src/gateway/server-ws-runtime.ts +++ b/src/gateway/server-ws-runtime.ts @@ -25,6 +25,7 @@ export function attachGatewayWsHandlers(params: GatewayWsRuntimeParams) { attachGatewayWsConnectionHandler({ wss: params.wss, clients: params.clients, + preauthConnectionBudget: params.preauthConnectionBudget, port: params.port, gatewayHost: params.gatewayHost, canvasHostEnabled: params.canvasHostEnabled, diff --git a/src/gateway/server.canvas-auth.test.ts b/src/gateway/server.canvas-auth.test.ts index 8adfc0ccae5..c0822712ba8 100644 --- a/src/gateway/server.canvas-auth.test.ts +++ b/src/gateway/server.canvas-auth.test.ts @@ -6,6 +6,7 @@ import { createAuthRateLimiter } from "./auth-rate-limit.js"; import type { ResolvedGatewayAuth } from "./auth.js"; import { CANVAS_CAPABILITY_PATH_PREFIX } from "./canvas-capability.js"; import { attachGatewayUpgradeHandler, createGatewayHttpServer } from "./server-http.js"; +import { createPreauthConnectionBudget } from "./server/preauth-connection-budget.js"; import type { GatewayWsClient } from "./server/ws-types.js"; import { withTempConfig } from "./test-temp-config.js"; @@ -158,6 +159,7 @@ async function withCanvasGatewayHarness(params: { wss, canvasHost, clients, + preauthConnectionBudget: createPreauthConnectionBudget(8), resolvedAuth: params.resolvedAuth, rateLimiter: params.rateLimiter, }); diff --git a/src/gateway/server.impl.ts b/src/gateway/server.impl.ts index ac056bb69e7..e2b2e9fe117 100644 --- a/src/gateway/server.impl.ts +++ b/src/gateway/server.impl.ts @@ -686,6 +686,7 @@ export async function startGatewayServer( httpServers, httpBindHosts, wss, + preauthConnectionBudget, clients, broadcast, broadcastToConnIds, @@ -1223,6 +1224,7 @@ export async function startGatewayServer( attachGatewayWsHandlers({ wss, clients, + preauthConnectionBudget, port, gatewayHost: bindHost ?? undefined, canvasHostEnabled: Boolean(canvasHost), diff --git a/src/gateway/server.preauth-hardening.test.ts b/src/gateway/server.preauth-hardening.test.ts index df5c312286f..4004eb99a54 100644 --- a/src/gateway/server.preauth-hardening.test.ts +++ b/src/gateway/server.preauth-hardening.test.ts @@ -1,6 +1,14 @@ +import http from "node:http"; import { afterEach, describe, expect, it } from "vitest"; +import { WebSocketServer } from "ws"; +import type { ResolvedGatewayAuth } from "./auth.js"; import { MAX_PREAUTH_PAYLOAD_BYTES } from "./server-constants.js"; +import { attachGatewayUpgradeHandler, createGatewayHttpServer } from "./server-http.js"; +import { createPreauthConnectionBudget } from "./server/preauth-connection-budget.js"; +import type { GatewayWsClient } from "./server/ws-types.js"; +import { testState } from "./test-helpers.mocks.js"; import { createGatewaySuiteHarness, readConnectChallengeNonce } from "./test-helpers.server.js"; +import { withTempConfig } from "./test-temp-config.js"; let cleanupEnv: Array<() => void> = []; @@ -11,6 +19,80 @@ afterEach(async () => { }); describe("gateway pre-auth hardening", () => { + it("rejects upgrades before websocket handlers attach without consuming pre-auth budget", async () => { + const clients = new Set(); + const resolvedAuth: ResolvedGatewayAuth = { mode: "none", allowTailscale: false }; + const httpServer = createGatewayHttpServer({ + canvasHost: null, + clients, + controlUiEnabled: false, + controlUiBasePath: "/__control__", + openAiChatCompletionsEnabled: false, + openResponsesEnabled: false, + handleHooksRequest: async () => false, + resolvedAuth, + }); + const wss = new WebSocketServer({ noServer: true }); + attachGatewayUpgradeHandler({ + httpServer, + wss, + canvasHost: null, + clients, + preauthConnectionBudget: createPreauthConnectionBudget(1), + resolvedAuth, + }); + + await new Promise((resolve) => httpServer.listen(0, "127.0.0.1", resolve)); + const address = httpServer.address(); + const port = typeof address === "object" && address ? address.port : 0; + const requestUpgrade = async () => + await new Promise<{ status: number; body: string }>((resolve, reject) => { + const req = http.request({ + host: "127.0.0.1", + port, + path: "/", + headers: { + Connection: "Upgrade", + Upgrade: "websocket", + "Sec-WebSocket-Key": "dGVzdC1rZXktMDEyMzQ1Ng==", + "Sec-WebSocket-Version": "13", + }, + }); + req.once("upgrade", (_res, socket) => { + socket.destroy(); + reject(new Error("expected websocket upgrade to be rejected")); + }); + req.once("response", (res) => { + let body = ""; + res.setEncoding("utf8"); + res.on("data", (chunk) => { + body += chunk; + }); + res.once("end", () => { + resolve({ status: res.statusCode ?? 0, body }); + }); + }); + req.once("error", reject); + req.end(); + }); + + try { + await expect(requestUpgrade()).resolves.toEqual({ + status: 503, + body: "Gateway websocket handlers unavailable", + }); + await expect(requestUpgrade()).resolves.toEqual({ + status: 503, + body: "Gateway websocket handlers unavailable", + }); + } finally { + wss.close(); + await new Promise((resolve, reject) => + httpServer.close((err) => (err ? reject(err) : resolve())), + ); + } + }); + it("closes idle unauthenticated sockets after the handshake timeout", async () => { const previous = process.env.OPENCLAW_TEST_HANDSHAKE_TIMEOUT_MS; process.env.OPENCLAW_TEST_HANDSHAKE_TIMEOUT_MS = "200"; @@ -22,7 +104,9 @@ describe("gateway pre-auth hardening", () => { } }); - const harness = await createGatewaySuiteHarness(); + const harness = await createGatewaySuiteHarness({ + serverOptions: { auth: { mode: "none" } }, + }); try { const ws = await harness.openWs(); await readConnectChallengeNonce(ws); @@ -74,4 +158,129 @@ describe("gateway pre-auth hardening", () => { await harness.close(); } }); + + it("rejects excess simultaneous unauthenticated sockets from the same client ip", async () => { + const previous = process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP; + process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP = "1"; + cleanupEnv.push(() => { + if (previous === undefined) { + delete process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP; + } else { + process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP = previous; + } + }); + const previousAuth = testState.gatewayAuth; + testState.gatewayAuth = { mode: "none" }; + cleanupEnv.push(() => { + testState.gatewayAuth = previousAuth; + }); + + const harness = await createGatewaySuiteHarness(); + try { + const firstWs = await harness.openWs(); + await readConnectChallengeNonce(firstWs); + + const rejectedStatus = await new Promise((resolve, reject) => { + const req = http.request({ + host: "127.0.0.1", + port: harness.port, + path: "/", + headers: { + Connection: "Upgrade", + Upgrade: "websocket", + "Sec-WebSocket-Key": "dGVzdC1rZXktMDEyMzQ1Ng==", + "Sec-WebSocket-Version": "13", + }, + }); + req.once("upgrade", (_res, socket) => { + socket.destroy(); + reject(new Error("expected websocket upgrade to be rejected")); + }); + req.once("response", (res) => { + res.resume(); + resolve(res.statusCode ?? 0); + }); + req.once("error", reject); + req.end(); + }); + expect(rejectedStatus).toBe(503); + + firstWs.close(); + } finally { + await harness.close(); + } + }); + + it("rejects excess simultaneous unauthenticated sockets when trusted proxy headers are missing", async () => { + const previous = process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP; + process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP = "1"; + cleanupEnv.push(() => { + if (previous === undefined) { + delete process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP; + } else { + process.env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP = previous; + } + }); + const previousAuth = testState.gatewayAuth; + testState.gatewayAuth = { mode: "none" }; + cleanupEnv.push(() => { + testState.gatewayAuth = previousAuth; + }); + + await withTempConfig({ + cfg: { + gateway: { + trustedProxies: ["127.0.0.1"], + }, + }, + prefix: "openclaw-preauth-proxy-", + run: async () => { + const harness = await createGatewaySuiteHarness(); + try { + const firstWs = await harness.openWs(); + await readConnectChallengeNonce(firstWs); + + const rejected = await new Promise<{ status: number; body: string }>( + (resolve, reject) => { + const req = http.request({ + host: "127.0.0.1", + port: harness.port, + path: "/", + headers: { + Connection: "Upgrade", + Upgrade: "websocket", + "Sec-WebSocket-Key": "dGVzdC1rZXktMDEyMzQ1Ng==", + "Sec-WebSocket-Version": "13", + }, + }); + req.once("upgrade", (_res, socket) => { + socket.destroy(); + reject(new Error("expected websocket upgrade to be rejected")); + }); + req.once("response", (res) => { + let body = ""; + res.setEncoding("utf8"); + res.on("data", (chunk) => { + body += chunk; + }); + res.once("end", () => { + resolve({ status: res.statusCode ?? 0, body }); + }); + }); + req.once("error", reject); + req.end(); + }, + ); + expect(rejected).toEqual({ + status: 503, + body: "Too many unauthenticated sockets", + }); + + firstWs.close(); + } finally { + await harness.close(); + } + }, + }); + }); }); diff --git a/src/gateway/server/preauth-connection-budget.ts b/src/gateway/server/preauth-connection-budget.ts new file mode 100644 index 00000000000..e359338c8da --- /dev/null +++ b/src/gateway/server/preauth-connection-budget.ts @@ -0,0 +1,58 @@ +const DEFAULT_MAX_PREAUTH_CONNECTIONS_PER_IP = 32; +const UNKNOWN_CLIENT_IP_BUDGET_KEY = "__openclaw_unknown_client_ip__"; + +export function getMaxPreauthConnectionsPerIpFromEnv(env: NodeJS.ProcessEnv = process.env): number { + const configured = + env.OPENCLAW_MAX_PREAUTH_CONNECTIONS_PER_IP || + (env.VITEST && env.OPENCLAW_TEST_MAX_PREAUTH_CONNECTIONS_PER_IP); + if (!configured) { + return DEFAULT_MAX_PREAUTH_CONNECTIONS_PER_IP; + } + const parsed = Number(configured); + if (!Number.isFinite(parsed) || parsed < 1) { + return DEFAULT_MAX_PREAUTH_CONNECTIONS_PER_IP; + } + return Math.max(1, Math.floor(parsed)); +} + +export type PreauthConnectionBudget = { + acquire(clientIp: string | undefined): boolean; + release(clientIp: string | undefined): void; +}; + +export function createPreauthConnectionBudget( + limit = getMaxPreauthConnectionsPerIpFromEnv(), +): PreauthConnectionBudget { + const counts = new Map(); + const normalizeBudgetKey = (clientIp: string | undefined) => { + const ip = clientIp?.trim(); + // Trusted-proxy mode can intentionally leave client IP unresolved when + // forwarded headers are missing or invalid; keep those upgrades capped + // under a shared fallback bucket instead of failing open. + return ip || UNKNOWN_CLIENT_IP_BUDGET_KEY; + }; + + return { + acquire(clientIp) { + const ip = normalizeBudgetKey(clientIp); + const next = (counts.get(ip) ?? 0) + 1; + if (next > limit) { + return false; + } + counts.set(ip, next); + return true; + }, + release(clientIp) { + const ip = normalizeBudgetKey(clientIp); + const current = counts.get(ip); + if (current === undefined) { + return; + } + if (current <= 1) { + counts.delete(ip); + return; + } + counts.set(ip, current - 1); + }, + }; +} diff --git a/src/gateway/server/ws-connection.ts b/src/gateway/server/ws-connection.ts index 4add43d3cbd..74880180348 100644 --- a/src/gateway/server/ws-connection.ts +++ b/src/gateway/server/ws-connection.ts @@ -14,6 +14,7 @@ import type { GatewayRequestContext, GatewayRequestHandlers } from "../server-me import { formatError } from "../server-utils.js"; import { logWs } from "../ws-log.js"; import { getHealthVersion, incrementPresenceVersion } from "./health-state.js"; +import type { PreauthConnectionBudget } from "./preauth-connection-budget.js"; import { broadcastPresenceSnapshot } from "./presence-events.js"; import { attachGatewayWsMessageHandler, @@ -61,6 +62,7 @@ const sanitizeLogValue = (value: string | undefined): string | undefined => { export type GatewayWsSharedHandlerParams = { wss: WebSocketServer; clients: Set; + preauthConnectionBudget: PreauthConnectionBudget; port: number; gatewayHost?: string; canvasHostEnabled: boolean; @@ -94,6 +96,7 @@ export function attachGatewayWsConnectionHandler(params: AttachGatewayWsConnecti const { wss, clients, + preauthConnectionBudget, port, gatewayHost, canvasHostEnabled, @@ -119,6 +122,17 @@ export function attachGatewayWsConnectionHandler(params: AttachGatewayWsConnecti const connId = randomUUID(); const remoteAddr = (socket as WebSocket & { _socket?: { remoteAddress?: string } })._socket ?.remoteAddress; + const preauthBudgetKey = ( + socket as WebSocket & { + __openclawPreauthBudgetClaimed?: boolean; + __openclawPreauthBudgetKey?: string; + } + ).__openclawPreauthBudgetKey; + ( + socket as WebSocket & { + __openclawPreauthBudgetClaimed?: boolean; + } + ).__openclawPreauthBudgetClaimed = true; const headerValue = (value: string | string[] | undefined) => Array.isArray(value) ? value[0] : value; const requestHost = headerValue(upgradeReq.headers.host); @@ -140,6 +154,7 @@ export function attachGatewayWsConnectionHandler(params: AttachGatewayWsConnecti logWs("in", "open", { connId, remoteAddr }); let handshakeState: "pending" | "connected" | "failed" = "pending"; + let holdsPreauthBudget = true; let closeCause: string | undefined; let closeMeta: Record = {}; let lastFrameType: string | undefined; @@ -155,6 +170,14 @@ export function attachGatewayWsConnectionHandler(params: AttachGatewayWsConnecti } }; + const releasePreauthBudget = () => { + if (!holdsPreauthBudget) { + return; + } + holdsPreauthBudget = false; + preauthConnectionBudget.release(preauthBudgetKey); + }; + const setLastFrameMeta = (meta: { type?: string; method?: string; id?: string }) => { if (meta.type || meta.method || meta.id) { lastFrameType = meta.type ?? lastFrameType; @@ -184,6 +207,7 @@ export function attachGatewayWsConnectionHandler(params: AttachGatewayWsConnecti } closed = true; clearTimeout(handshakeTimer); + releasePreauthBudget(); if (client) { clients.delete(client); } @@ -302,6 +326,7 @@ export function attachGatewayWsConnectionHandler(params: AttachGatewayWsConnecti clearHandshakeTimer: () => clearTimeout(handshakeTimer), getClient: () => client, setClient: (next) => { + releasePreauthBudget(); client = next; clients.add(next); },