From 5aac7939dbfdb4dd44d16f27bb4217372bcfe566 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Sun, 17 May 2026 18:08:51 +0800 Subject: [PATCH] fix(gateway): drain replies during restart close --- CHANGELOG.md | 1 + src/cli/gateway-cli/run-loop.test.ts | 105 ++++---- src/cli/gateway-cli/run-loop.ts | 16 ++ src/gateway/server-close.test.ts | 225 +++++++++++++++++- src/gateway/server-close.ts | 222 ++++++++++++++++- .../server-startup-post-attach.test.ts | 7 +- src/gateway/server.impl.ts | 84 ++++--- 7 files changed, 578 insertions(+), 82 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89c09ee152b..afe6ecea931 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Docs: https://docs.openclaw.ai ### Fixes - Gateway/secrets: split the lightweight secrets runtime state and auth-store cache from the full secrets runtime and take a startup fast path when the gateway startup config has no SecretRef values, speeding up secrets startup while preserving cleanup and refresh semantics. +- Gateway/restart: drain pending replies and active chat runs during restart shutdown before sockets and channels close, aborting timed-out chat runs through the normal cleanup path. (#69121) Thanks @alexlomt. - QA-Lab: wake qa-bus long polls that arrive with stale future cursors after a bus restart, preserving reconnect readiness for harness clients. (#67142) Thanks @hxy91819. - QA-Lab: stage Multipass transfer scripts under OpenClaw's preferred temp root instead of raw OS temp paths, keeping the VM runner inside temp-path guardrails. (#64098) Thanks @ImLukeF. - Agents/replies: keep surviving reply media and append a warning when other media references fail, so partial media normalization no longer drops failures silently. Thanks @Jerry-Xin. diff --git a/src/cli/gateway-cli/run-loop.test.ts b/src/cli/gateway-cli/run-loop.test.ts index 58e20d77744..b7b21f5e072 100644 --- a/src/cli/gateway-cli/run-loop.test.ts +++ b/src/cli/gateway-cli/run-loop.test.ts @@ -1,4 +1,5 @@ import { describe, expect, it, vi } from "vitest"; +import type { GatewayServer } from "../../gateway/server.impl.js"; import type { GatewayBonjourBeacon } from "../../infra/bonjour-discovery.js"; import { pickBeaconHost, pickGatewayPort } from "./discover.js"; @@ -249,13 +250,33 @@ function createRuntimeWithExitSignal(exitCallOrder?: string[]) { return { runtime, exited }; } -type GatewayCloseFn = (...args: unknown[]) => Promise; +type GatewayCloseFn = GatewayServer["close"]; type LoopRuntime = { log: (...args: unknown[]) => void; error: (...args: unknown[]) => void; exit: (code: number) => void; }; +function createCloseMock() { + return vi.fn(async (_opts) => {}); +} + +function expectRestartCloseCall( + close: ReturnType, + maxDrainTimeoutMs: number, +) { + expect(close).toHaveBeenCalledWith( + expect.objectContaining({ + reason: "gateway restarting", + restartExpectedMs: 1500, + drainTimeoutMs: expect.any(Number), + }), + ); + const closeArgs = close.mock.calls[0]?.[0]; + expect(closeArgs?.drainTimeoutMs).toBeLessThanOrEqual(maxDrainTimeoutMs); + expect(closeArgs?.drainTimeoutMs).toBeGreaterThanOrEqual(0); +} + function createSignaledStart(close: GatewayCloseFn) { let resolveStarted: (() => void) | null = null; const started = new Promise((resolve) => { @@ -304,7 +325,7 @@ async function waitForLoopCondition(predicate: () => boolean, message: string) { } async function createSignaledLoopHarness(exitCallOrder?: string[]) { - const close = vi.fn(async () => {}); + const close = createCloseMock(); const { start, started } = createSignaledStart(close); const { runtime, exited } = createRuntimeWithExitSignal(exitCallOrder); const { loopPromise } = await runLoopWithStart({ start, runtime }); @@ -361,8 +382,8 @@ describe("runGatewayLoop", () => { getActiveTaskCount.mockReturnValueOnce(1).mockReturnValue(0); await withIsolatedSignals(async ({ captureSignal }) => { - const closeFirst = vi.fn(async () => {}); - const closeSecond = vi.fn(async () => {}); + const closeFirst = createCloseMock(); + const closeSecond = createCloseMock(); const { runtime, exited } = createRuntimeWithExitSignal(); let resolveSecond: (() => void) | null = null; const startedSecond = new Promise((resolve) => { @@ -391,10 +412,7 @@ describe("runGatewayLoop", () => { expect(consumeGatewayRestartIntentPayloadSync).toHaveBeenCalledOnce(); expect(markGatewayDraining).toHaveBeenCalledOnce(); expect(waitForActiveTasks).toHaveBeenCalledWith(90_000); - expect(closeFirst).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(closeFirst, 90_000); await startedSecond; expect(start).toHaveBeenCalledTimes(2); await new Promise((resolve) => setImmediate(resolve)); @@ -430,6 +448,27 @@ describe("runGatewayLoop", () => { }); }); + it("caps reply drain time for unbounded SIGTERM restarts", async () => { + vi.clearAllMocks(); + consumeGatewayRestartIntentPayloadSync.mockReturnValueOnce({ waitMs: 0 }); + + await withIsolatedSignals(async ({ captureSignal }) => { + const { close, start, exited } = await createSignaledLoopHarness(); + const sigterm = captureSignal("SIGTERM"); + const sigint = captureSignal("SIGINT"); + + sigterm(); + await new Promise((resolve) => setImmediate(resolve)); + await new Promise((resolve) => setImmediate(resolve)); + + expectRestartCloseCall(close, 15_000); + expect(start).toHaveBeenCalledTimes(2); + + sigint(); + await expect(exited).resolves.toBe(0); + }); + }); + it("aborts active embedded runs after a short restart drain grace", async () => { vi.clearAllMocks(); consumeGatewayRestartIntentPayloadSync.mockReturnValueOnce({}); @@ -473,10 +512,7 @@ describe("runGatewayLoop", () => { expect(gatewayLog.warn).toHaveBeenCalledWith( "failed to mark interrupted main sessions for restart recovery: Error: store read-only", ); - expect(close).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(close, 90_000); expect(start).toHaveBeenCalledTimes(2); sigint(); @@ -567,12 +603,12 @@ describe("runGatewayLoop", () => { waitForActiveEmbeddedRuns.mockResolvedValueOnce({ drained: true }); type StartServer = () => Promise<{ - close: (opts: { reason: string; restartExpectedMs: number | null }) => Promise; + close: GatewayCloseFn; }>; - const closeFirst = vi.fn(async () => {}); - const closeSecond = vi.fn(async () => {}); - const closeThird = vi.fn(async () => {}); + const closeFirst = createCloseMock(); + const closeSecond = createCloseMock(); + const closeThird = createCloseMock(); const { runtime, exited } = createRuntimeWithExitSignal(); const start = vi.fn(); @@ -639,10 +675,7 @@ describe("runGatewayLoop", () => { }); expect(markGatewayDraining).toHaveBeenCalledTimes(1); expect(gatewayLog.warn).toHaveBeenCalledWith(DRAIN_TIMEOUT_LOG); - expect(closeFirst).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(closeFirst, 1_234); expect(markGatewaySigusr1RestartHandled).toHaveBeenCalledTimes(1); expect(resetAllLanes).toHaveBeenCalledTimes(1); expect(resetGatewayRestartStateForInProcessRestart).toHaveBeenCalledTimes(1); @@ -652,10 +685,7 @@ describe("runGatewayLoop", () => { await startedThird; await new Promise((resolve) => setImmediate(resolve)); - expect(closeSecond).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(closeSecond, 1_234); expect(markGatewaySigusr1RestartHandled).toHaveBeenCalledTimes(2); expect(markGatewayDraining).toHaveBeenCalledTimes(2); expect(resetAllLanes).toHaveBeenCalledTimes(2); @@ -681,8 +711,8 @@ describe("runGatewayLoop", () => { }); await withIsolatedSignals(async ({ captureSignal }) => { - const closeFirst = vi.fn(async () => {}); - const closeSecond = vi.fn(async () => {}); + const closeFirst = createCloseMock(); + const closeSecond = createCloseMock(); const { runtime, exited } = createRuntimeWithExitSignal(); let releaseFirstStart!: () => void; const firstStartMayReturn = new Promise((resolve) => { @@ -729,10 +759,7 @@ describe("runGatewayLoop", () => { "expected queued SIGUSR1 to trigger the second gateway start", ); await startedSecond; - expect(closeFirst).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(closeFirst, 90_000); expect(markGatewaySigusr1RestartHandled).toHaveBeenCalledTimes(1); expect(markGatewayDraining).toHaveBeenCalledTimes(1); expect(resetAllLanes).toHaveBeenCalledTimes(1); @@ -869,8 +896,8 @@ describe("runGatewayLoop", () => { }); await withIsolatedSignals(async ({ captureSignal }) => { - const closeFirst = vi.fn(async () => {}); - const closeThird = vi.fn(async () => {}); + const closeFirst = createCloseMock(); + const closeThird = createCloseMock(); const { runtime, exited } = createRuntimeWithExitSignal(); let sigusr1: (() => void) | null = null; let resolveThirdStart: (() => void) | null = null; @@ -909,10 +936,7 @@ describe("runGatewayLoop", () => { "expected queued SIGUSR1 to advance past failed restart startup", ); await startedThird; - expect(closeFirst).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(closeFirst, 90_000); expect(markGatewaySigusr1RestartHandled).toHaveBeenCalledTimes(2); expect(markGatewayDraining).toHaveBeenCalledTimes(2); expect(resetAllLanes).toHaveBeenCalledTimes(2); @@ -938,8 +962,8 @@ describe("runGatewayLoop", () => { }); await withIsolatedSignals(async ({ captureSignal }) => { - const closeFirst = vi.fn(async () => {}); - const closeThird = vi.fn(async () => {}); + const closeFirst = createCloseMock(); + const closeThird = createCloseMock(); const { runtime, exited } = createRuntimeWithExitSignal(); let resolveThirdStart: (() => void) | null = null; const startedThird = new Promise((resolve) => { @@ -980,10 +1004,7 @@ describe("runGatewayLoop", () => { "expected post-failure SIGUSR1 to retry gateway startup", ); await startedThird; - expect(closeFirst).toHaveBeenCalledWith({ - reason: "gateway restarting", - restartExpectedMs: 1500, - }); + expectRestartCloseCall(closeFirst, 90_000); expect(markGatewaySigusr1RestartHandled).toHaveBeenCalledTimes(2); expect(markGatewayDraining).toHaveBeenCalledTimes(2); expect(resetAllLanes).toHaveBeenCalledTimes(2); diff --git a/src/cli/gateway-cli/run-loop.ts b/src/cli/gateway-cli/run-loop.ts index ccdf17d4d34..f2cb7fccfd1 100644 --- a/src/cli/gateway-cli/run-loop.ts +++ b/src/cli/gateway-cli/run-loop.ts @@ -19,6 +19,7 @@ const LAUNCHD_SUPERVISED_RESTART_EXIT_DELAY_MS = 1500; const DEFAULT_RESTART_DRAIN_TIMEOUT_MS = 300_000; const RESTART_ACTIVE_EMBEDDED_RUN_ABORT_GRACE_MS = 30_000; const RESTART_DRAIN_STILL_PENDING_WARN_MS = 30_000; +const RESTART_CLOSE_REPLY_DRAIN_SHUTDOWN_RESERVE_MS = 10_000; const UPDATE_RESPAWN_HEALTH_TIMEOUT_MS = 10_000; const UPDATE_RESPAWN_HEALTH_POLL_MS = 200; @@ -404,6 +405,10 @@ export async function runGatewayLoop(params: { const restartDrainTimeoutMs = isRestart ? await resolveRestartDrainTimeoutMs(restartIntent) : 0; + const restartDrainDeadlineAt = + isRestart && restartDrainTimeoutMs !== undefined + ? Date.now() + restartDrainTimeoutMs + : undefined; if (!isRestart) { armForceExitTimer(SHUTDOWN_TIMEOUT_MS); } else if (restartDrainTimeoutMs !== undefined) { @@ -429,6 +434,15 @@ export async function runGatewayLoop(params: { armForceExitTimer(SHUTDOWN_TIMEOUT_MS); } }; + const resolveRestartCloseDrainTimeoutMs = () => { + if (!isRestart) { + return null; + } + if (restartDrainTimeoutMs === undefined) { + return Math.max(0, SHUTDOWN_TIMEOUT_MS - RESTART_CLOSE_REPLY_DRAIN_SHUTDOWN_RESERVE_MS); + } + return Math.max(0, (restartDrainDeadlineAt ?? Date.now()) - Date.now()); + }; try { // On restart, wait for in-flight agent turns to finish before @@ -595,9 +609,11 @@ export async function runGatewayLoop(params: { } armCloseForceExitTimerForIndefiniteRestart(); + const closeDrainTimeoutMs = resolveRestartCloseDrainTimeoutMs(); await server?.close({ reason: isRestart ? "gateway restarting" : "gateway stopping", restartExpectedMs: isRestart ? 1500 : null, + ...(closeDrainTimeoutMs !== null ? { drainTimeoutMs: closeDrainTimeoutMs } : {}), }); } catch (err) { gatewayLog.error(`shutdown error: ${String(err)}`); diff --git a/src/gateway/server-close.test.ts b/src/gateway/server-close.test.ts index 54c74c31453..04584208111 100644 --- a/src/gateway/server-close.test.ts +++ b/src/gateway/server-close.test.ts @@ -66,6 +66,7 @@ vi.mock("../logging/subsystem.js", () => ({ })); const { createGatewayCloseHandler } = await import("./server-close.js"); +const { createChatRunState } = await import("./server-chat-state.js"); const { finishGatewayRestartTrace, recordGatewayRestartTraceSpan, @@ -83,6 +84,13 @@ function firstMockCall(mock: { mock: { calls: read return mock.mock.calls[0]; } +function createTestChatRunState() { + const state = createChatRunState(); + const clear = state.clear; + state.clear = vi.fn(() => clear()); + return state; +} + function createGatewayCloseTestDeps( overrides: Partial = {}, ): GatewayCloseHandlerParams { @@ -105,7 +113,12 @@ function createGatewayCloseTestDeps( heartbeatUnsub: null, transcriptUnsub: null, lifecycleUnsub: null, - chatRunState: { clear: vi.fn() }, + chatRunState: createTestChatRunState(), + chatAbortControllers: new Map(), + removeChatRun: vi.fn(), + agentRunSeq: new Map(), + nodeSendToSession: vi.fn(), + getPendingReplyCount: vi.fn(() => 0), clients: new Set(), configReloader: { stop: vi.fn(async () => undefined) }, wss: { @@ -328,6 +341,30 @@ describe("createGatewayCloseHandler", () => { expect(firstMockCall(drainActiveSessionsForShutdown)?.[0]?.reason).toBe("restart"); }); + it("drains pending restart replies before emitting session-end hooks", async () => { + const order: string[] = []; + const drainActiveSessionsForShutdown = vi.fn(async () => { + order.push("session-end"); + return { + emittedSessionIds: ["session-A"], + timedOut: false, + }; + }); + const close = createGatewayCloseHandler( + createGatewayCloseTestDeps({ + drainActiveSessionsForShutdown, + getPendingReplyCount: () => { + order.push("reply-drain"); + return 0; + }, + }), + ); + + await close({ reason: "gateway restarting", restartExpectedMs: 123, drainTimeoutMs: 100 }); + + expect(order).toStrictEqual(["reply-drain", "session-end"]); + }); + it("records a warning and continues shutdown when the session-end drain reports a timeout", async () => { const drainActiveSessionsForShutdown = vi.fn(async () => ({ emittedSessionIds: ["session-A"], @@ -356,6 +393,192 @@ describe("createGatewayCloseHandler", () => { expect(result.warnings).not.toContain("session-end-drain"); }); + it("waits for pending replies to settle before restart shutdown", async () => { + vi.useFakeTimers(); + let pendingReplies = 1; + const close = createGatewayCloseHandler( + createGatewayCloseTestDeps({ + getPendingReplyCount: () => pendingReplies, + }), + ); + + const closePromise = close({ + reason: "gateway restarting", + restartExpectedMs: 123, + drainTimeoutMs: 200, + }); + await vi.advanceTimersByTimeAsync(100); + pendingReplies = 0; + await vi.advanceTimersByTimeAsync(100); + const result = await closePromise; + + expect(result.warnings).not.toContain("restart-reply-drain"); + expect( + mocks.logInfo.mock.calls.some(([message]) => + String(message).includes("waiting for 1 pending reply(ies) before restart shutdown"), + ), + ).toBe(true); + expect( + mocks.logInfo.mock.calls.some(([message]) => + String(message).includes("restart reply drain completed after"), + ), + ).toBe(true); + }); + + it("aborts active chat runs when restart reply drain times out", async () => { + vi.useFakeTimers(); + const controller = new AbortController(); + const agentController = new AbortController(); + const chatRunState = createChatRunState(); + chatRunState.buffers.set("run-1", "partial reply"); + chatRunState.deltaSentAt.set("run-1", Date.now()); + chatRunState.deltaLastBroadcastLen.set("run-1", 3); + chatRunState.deltaLastBroadcastText.set("run-1", "par"); + chatRunState.agentDeltaSentAt.set("run-1:assistant", Date.now()); + chatRunState.bufferedAgentEvents.set("run-1:assistant", { + sessionKey: "session-1", + payload: {} as never, + }); + const chatAbortControllers = new Map([ + [ + "run-1", + { + controller, + sessionId: "run-1", + sessionKey: "session-1", + startedAtMs: Date.now(), + expiresAtMs: Date.now() + 60_000, + }, + ], + [ + "agent-run-1", + { + controller: agentController, + sessionId: "agent-run-1", + sessionKey: "session-1", + startedAtMs: Date.now(), + expiresAtMs: Date.now() + 60_000, + kind: "agent" as const, + }, + ], + ]); + const broadcast = vi.fn(); + const nodeSendToSession = vi.fn(); + const close = createGatewayCloseHandler( + createGatewayCloseTestDeps({ + broadcast, + nodeSendToSession, + chatRunState, + chatAbortControllers, + removeChatRun: vi.fn(() => ({ sessionKey: "session-1", clientRunId: "run-1" })), + }), + ); + + const closePromise = close({ + reason: "gateway restarting", + restartExpectedMs: 123, + drainTimeoutMs: 100, + }); + await vi.advanceTimersByTimeAsync(100); + const result = await closePromise; + + expect(result.warnings).toContain("restart-reply-drain"); + expect(controller.signal.aborted).toBe(true); + expect(agentController.signal.aborted).toBe(false); + expect(chatAbortControllers.has("run-1")).toBe(false); + expect(chatAbortControllers.has("agent-run-1")).toBe(true); + expect(chatRunState.buffers.has("run-1")).toBe(false); + expect(chatRunState.deltaSentAt.has("run-1")).toBe(false); + expect(chatRunState.deltaLastBroadcastLen.has("run-1")).toBe(false); + expect(chatRunState.deltaLastBroadcastText.has("run-1")).toBe(false); + expect(chatRunState.agentDeltaSentAt.has("run-1:assistant")).toBe(false); + expect(chatRunState.bufferedAgentEvents.has("run-1:assistant")).toBe(false); + expect( + mocks.logWarn.mock.calls.some(([message]) => + String(message).includes( + "restart reply drain timed out after 100ms with 1 active chat run(s) still active", + ), + ), + ).toBe(true); + expect( + mocks.logWarn.mock.calls.some(([message]) => + String(message).includes("aborted 1 active chat run(s) during restart shutdown"), + ), + ).toBe(true); + expect(broadcast).toHaveBeenCalledWith( + "chat", + expect.objectContaining({ runId: "run-1", state: "aborted", stopReason: "restart" }), + ); + expect(nodeSendToSession).toHaveBeenCalledWith( + "session-1", + "chat", + expect.objectContaining({ runId: "run-1", state: "aborted", stopReason: "restart" }), + ); + }); + + it("does not drain or abort active chat runs for normal shutdown", async () => { + const controller = new AbortController(); + const chatAbortControllers = new Map([ + [ + "run-1", + { + controller, + sessionId: "run-1", + sessionKey: "session-1", + startedAtMs: Date.now(), + expiresAtMs: Date.now() + 60_000, + }, + ], + ]); + const close = createGatewayCloseHandler( + createGatewayCloseTestDeps({ + chatAbortControllers, + }), + ); + + const result = await close({ reason: "SIGTERM", drainTimeoutMs: 0 }); + + expect(result.warnings).not.toContain("restart-reply-drain"); + expect(controller.signal.aborted).toBe(false); + expect(chatAbortControllers.size).toBe(1); + }); + + it("aborts active chat runs immediately when restart drain budget is exhausted", async () => { + const controller = new AbortController(); + const chatAbortControllers = new Map([ + [ + "run-1", + { + controller, + sessionId: "run-1", + sessionKey: "session-1", + startedAtMs: Date.now(), + expiresAtMs: Date.now() + 60_000, + }, + ], + ]); + const close = createGatewayCloseHandler( + createGatewayCloseTestDeps({ + chatAbortControllers, + }), + ); + + const result = await close({ + reason: "gateway restarting", + restartExpectedMs: 123, + drainTimeoutMs: 0, + }); + + expect(result.warnings).toContain("restart-reply-drain"); + expect(controller.signal.aborted).toBe(true); + expect(chatAbortControllers.size).toBe(0); + expect( + mocks.logWarn.mock.calls.some(([message]) => + String(message).includes("restart reply drain timed out after 0ms"), + ), + ).toBe(true); + }); + it("continues restart shutdown and records a warning when gateway pre-restart hook stalls", async () => { vi.useFakeTimers(); mocks.triggerInternalHook.mockImplementation((event: InternalHookEvent) => { diff --git a/src/gateway/server-close.ts b/src/gateway/server-close.ts index d070fff250a..7efc635ba11 100644 --- a/src/gateway/server-close.ts +++ b/src/gateway/server-close.ts @@ -9,11 +9,13 @@ import { createSubsystemLogger } from "../logging/subsystem.js"; import { closePluginStateSqliteStore } from "../plugin-state/plugin-state-store.js"; import type { PluginServicesHandle } from "../plugins/services.js"; import { normalizeOptionalString } from "../shared/string-coerce.js"; +import { abortChatRunById, type ChatAbortControllerEntry } from "./chat-abort.js"; import { collectGatewayProcessMemoryUsageMb, measureGatewayRestartTrace, recordGatewayRestartTrace, } from "./restart-trace.js"; +import type { ChatRunState } from "./server-chat-state.js"; import type { GatewayPostReadySidecarHandle } from "./server-startup-post-attach.js"; const shutdownLog = createSubsystemLogger("gateway/shutdown"); @@ -26,6 +28,9 @@ const HTTP_CLOSE_GRACE_MS = 1_000; const HTTP_CLOSE_FORCE_WAIT_MS = 5_000; const MCP_RUNTIME_CLOSE_GRACE_MS = 5_000; const LSP_RUNTIME_CLOSE_GRACE_MS = 5_000; +const RESTART_REPLY_DRAIN_POLL_MS = 100; +const RESTART_REPLY_POST_ABORT_DRAIN_TIMEOUT_MS = 1_000; +const RESTART_REPLY_POST_ABORT_DRAIN_POLL_MS = 50; export type ShutdownResult = { durationMs: number; @@ -81,6 +86,187 @@ function recordShutdownWarning(warnings: string[], name: string): void { } } +function getRestartReplyDrainCounts(params: { + getPendingReplyCount: () => number; + chatAbortControllers: Map; +}) { + const pendingReplyCount = params.getPendingReplyCount(); + const activeChatRuns = listRestartDrainChatRuns(params.chatAbortControllers).length; + return { + pendingReplies: + Number.isFinite(pendingReplyCount) && pendingReplyCount > 0 + ? Math.floor(pendingReplyCount) + : 0, + activeChatRuns, + }; +} + +function listRestartDrainChatRuns( + chatAbortControllers: Map, +): Array<[string, ChatAbortControllerEntry]> { + return Array.from(chatAbortControllers.entries()).filter(([, entry]) => entry.kind !== "agent"); +} + +function formatRestartReplyDrainDetails(counts: { + pendingReplies: number; + activeChatRuns: number; +}): string { + const details: string[] = []; + if (counts.pendingReplies > 0) { + details.push(`${counts.pendingReplies} pending reply(ies)`); + } + if (counts.activeChatRuns > 0) { + details.push(`${counts.activeChatRuns} active chat run(s)`); + } + return details.length > 0 ? details.join(", ") : "no pending reply work"; +} + +async function sleepForRestartReplyDrain(delayMs: number): Promise { + await new Promise((resolve) => { + const timer = setTimeout(resolve, delayMs); + timer.unref?.(); + }); +} + +async function waitForRestartReplyDrain(params: { + getPendingReplyCount: () => number; + chatAbortControllers: Map; + timeoutMs: number; + pollMs?: number; +}): Promise<{ + drained: boolean; + elapsedMs: number; + counts: { pendingReplies: number; activeChatRuns: number }; +}> { + const timeoutMs = Math.max(0, Math.floor(params.timeoutMs)); + const pollMs = Math.max(25, Math.floor(params.pollMs ?? RESTART_REPLY_DRAIN_POLL_MS)); + let counts = getRestartReplyDrainCounts(params); + if (counts.pendingReplies <= 0 && counts.activeChatRuns <= 0) { + return { drained: true, elapsedMs: 0, counts }; + } + if (timeoutMs <= 0) { + return { drained: false, elapsedMs: 0, counts }; + } + + const startedAt = Date.now(); + for (;;) { + const elapsedMs = Date.now() - startedAt; + if (elapsedMs >= timeoutMs) { + return { drained: false, elapsedMs, counts }; + } + await sleepForRestartReplyDrain(Math.min(pollMs, timeoutMs - elapsedMs)); + counts = getRestartReplyDrainCounts(params); + if (counts.pendingReplies <= 0 && counts.activeChatRuns <= 0) { + return { drained: true, elapsedMs: Date.now() - startedAt, counts }; + } + } +} + +function abortActiveChatRunsForRestart(params: { + chatAbortControllers: Map; + chatRunState: ChatRunState; + removeChatRun: ( + sessionId: string, + clientRunId: string, + sessionKey?: string, + ) => { sessionKey: string; clientRunId: string } | undefined; + agentRunSeq: Map; + broadcast: (event: string, payload: unknown, opts?: { dropIfSlow?: boolean }) => void; + nodeSendToSession: (sessionKey: string, event: string, payload: unknown) => void; +}): number { + let aborted = 0; + for (const [runId, entry] of listRestartDrainChatRuns(params.chatAbortControllers)) { + const result = abortChatRunById( + { + chatAbortControllers: params.chatAbortControllers, + chatRunBuffers: params.chatRunState.buffers, + chatDeltaSentAt: params.chatRunState.deltaSentAt, + chatDeltaLastBroadcastLen: params.chatRunState.deltaLastBroadcastLen, + chatDeltaLastBroadcastText: params.chatRunState.deltaLastBroadcastText, + agentDeltaSentAt: params.chatRunState.agentDeltaSentAt, + bufferedAgentEvents: params.chatRunState.bufferedAgentEvents, + chatAbortedRuns: params.chatRunState.abortedRuns, + removeChatRun: params.removeChatRun, + agentRunSeq: params.agentRunSeq, + broadcast: params.broadcast, + nodeSendToSession: params.nodeSendToSession, + }, + { + runId, + sessionKey: entry.sessionKey, + stopReason: "restart", + }, + ); + if (result.aborted) { + aborted += 1; + } + } + return aborted; +} + +async function drainRestartPendingRepliesForShutdown(params: { + getPendingReplyCount: () => number; + chatAbortControllers: Map; + chatRunState: ChatRunState; + removeChatRun: ( + sessionId: string, + clientRunId: string, + sessionKey?: string, + ) => { sessionKey: string; clientRunId: string } | undefined; + agentRunSeq: Map; + broadcast: (event: string, payload: unknown, opts?: { dropIfSlow?: boolean }) => void; + nodeSendToSession: (sessionKey: string, event: string, payload: unknown) => void; + timeoutMs: number; + warnings: string[]; +}): Promise { + const initialCounts = getRestartReplyDrainCounts(params); + if (initialCounts.pendingReplies <= 0 && initialCounts.activeChatRuns <= 0) { + return; + } + + const timeoutMs = Math.max(0, Math.floor(params.timeoutMs)); + if (timeoutMs > 0) { + shutdownLog.info( + `waiting for ${formatRestartReplyDrainDetails(initialCounts)} before restart shutdown (timeout ${timeoutMs}ms)`, + ); + } + + const drainResult = await waitForRestartReplyDrain({ + getPendingReplyCount: params.getPendingReplyCount, + chatAbortControllers: params.chatAbortControllers, + timeoutMs, + }); + if (drainResult.drained) { + shutdownLog.info(`restart reply drain completed after ${drainResult.elapsedMs}ms`); + return; + } + + shutdownLog.warn( + `restart reply drain timed out after ${drainResult.elapsedMs}ms with ${formatRestartReplyDrainDetails(drainResult.counts)} still active; continuing shutdown`, + ); + recordShutdownWarning(params.warnings, "restart-reply-drain"); + + if (drainResult.counts.activeChatRuns <= 0) { + return; + } + + const abortedRuns = abortActiveChatRunsForRestart(params); + if (abortedRuns <= 0) { + return; + } + + shutdownLog.warn(`aborted ${abortedRuns} active chat run(s) during restart shutdown`); + const postAbortDrain = await waitForRestartReplyDrain({ + getPendingReplyCount: params.getPendingReplyCount, + chatAbortControllers: params.chatAbortControllers, + timeoutMs: RESTART_REPLY_POST_ABORT_DRAIN_TIMEOUT_MS, + pollMs: RESTART_REPLY_POST_ABORT_DRAIN_POLL_MS, + }); + if (postAbortDrain.drained) { + shutdownLog.info("restart reply drain completed after abort cleanup"); + } +} + async function triggerGatewayLifecycleHookWithTimeout(params: { event: ReturnType; hookName: "gateway:shutdown" | "gateway:pre-restart"; @@ -199,7 +385,16 @@ export function createGatewayCloseHandler(params: { heartbeatUnsub: (() => void) | null; transcriptUnsub: (() => void) | null; lifecycleUnsub: (() => void) | null; - chatRunState: { clear: () => void }; + chatRunState: ChatRunState; + chatAbortControllers: Map; + removeChatRun: ( + sessionId: string, + clientRunId: string, + sessionKey?: string, + ) => { sessionKey: string; clientRunId: string } | undefined; + agentRunSeq: Map; + nodeSendToSession: (sessionKey: string, event: string, payload: unknown) => void; + getPendingReplyCount?: () => number; clients: Set<{ socket: { close: (code: number, reason: string) => void } }>; configReloader: { stop: () => Promise }; wss: WebSocketServer; @@ -213,6 +408,7 @@ export function createGatewayCloseHandler(params: { return async (opts?: { reason?: string; restartExpectedMs?: number | null; + drainTimeoutMs?: number | null; }): Promise => { const start = Date.now(); const warnings: string[] = []; @@ -279,6 +475,30 @@ export function createGatewayCloseHandler(params: { ), ); } + if (restartExpectedMs !== null && params.getPendingReplyCount) { + const drainTimeoutMs = + typeof opts?.drainTimeoutMs === "number" && Number.isFinite(opts.drainTimeoutMs) + ? Math.max(0, Math.floor(opts.drainTimeoutMs)) + : 0; + await measureCloseStep("reply-drain", () => + shutdownStep( + "restart-reply-drain", + () => + drainRestartPendingRepliesForShutdown({ + getPendingReplyCount: params.getPendingReplyCount!, + chatAbortControllers: params.chatAbortControllers, + chatRunState: params.chatRunState, + removeChatRun: params.removeChatRun, + agentRunSeq: params.agentRunSeq, + broadcast: params.broadcast, + nodeSendToSession: params.nodeSendToSession, + timeoutMs: drainTimeoutMs, + warnings, + }), + warnings, + ), + ); + } if (params.drainActiveSessionsForShutdown) { await measureCloseStep("session-end-drain", () => shutdownStep( diff --git a/src/gateway/server-startup-post-attach.test.ts b/src/gateway/server-startup-post-attach.test.ts index 23a78b76524..ab66de2d039 100644 --- a/src/gateway/server-startup-post-attach.test.ts +++ b/src/gateway/server-startup-post-attach.test.ts @@ -1040,6 +1040,7 @@ describe("startGatewayPostAttachRuntime", () => { const stopChannel = vi.fn(async () => {}); const pluginServices = { stop: vi.fn(async () => {}) }; const { createGatewayCloseHandler } = await import("./server-close.js"); + const { createChatRunState } = await import("./server-chat-state.js"); const close = createGatewayCloseHandler({ bonjourStop: null, @@ -1060,7 +1061,11 @@ describe("startGatewayPostAttachRuntime", () => { heartbeatUnsub: null, transcriptUnsub: null, lifecycleUnsub: null, - chatRunState: { clear: vi.fn() }, + chatRunState: createChatRunState(), + chatAbortControllers: new Map(), + removeChatRun: vi.fn(), + agentRunSeq: new Map(), + nodeSendToSession: vi.fn(), clients: new Set(), configReloader: { stop: vi.fn(async () => {}) }, wss: { close: vi.fn((callback: () => void) => callback()) } as never, diff --git a/src/gateway/server.impl.ts b/src/gateway/server.impl.ts index d1f38536a3f..a9562f3d2e7 100644 --- a/src/gateway/server.impl.ts +++ b/src/gateway/server.impl.ts @@ -450,8 +450,14 @@ function createGatewayAuthRateLimiters(rateLimitConfig: AuthRateLimitConfig | un return { rateLimiter, browserRateLimiter }; } +export type GatewayCloseOptions = { + reason?: string; + restartExpectedMs?: number | null; + drainTimeoutMs?: number | null; +}; + export type GatewayServer = { - close: (opts?: { reason?: string; restartExpectedMs?: number | null }) => Promise; + close: (opts?: GatewayCloseOptions) => Promise; }; export type GatewayServerOptions = { @@ -973,42 +979,46 @@ export async function startGatewayServer( postReadySidecar.stop(); } }; - const createCloseHandler = - () => async (opts?: { reason?: string; restartExpectedMs?: number | null }) => { - const channelIds = listLoadedChannelPlugins().map((plugin) => plugin.id as ChannelId); - const { createGatewayCloseHandler, drainActiveSessionsForShutdown } = - await loadGatewayCloseModule(); - await createGatewayCloseHandler({ - bonjourStop: runtimeState.bonjourStop, - tailscaleCleanup: runtimeState.tailscaleCleanup, - releasePluginRouteRegistry, - channelIds, - stopChannel, - pluginServices: runtimeState.pluginServices, - postReadySidecars: runtimeState.postReadySidecars, - cron: runtimeState.cronState.cron, - heartbeatRunner: runtimeState.heartbeatRunner, - updateCheckStop: runtimeState.stopGatewayUpdateCheck, - stopTaskRegistryMaintenance: stopTaskRegistryMaintenanceOnDemand, - nodePresenceTimers, - broadcast, - tickInterval: runtimeState.tickInterval, - healthInterval: runtimeState.healthInterval, - dedupeCleanup: runtimeState.dedupeCleanup, - mediaCleanup: runtimeState.mediaCleanup, - agentUnsub: runtimeState.agentUnsub, - heartbeatUnsub: runtimeState.heartbeatUnsub, - transcriptUnsub: runtimeState.transcriptUnsub, - lifecycleUnsub: runtimeState.lifecycleUnsub, - chatRunState, - clients, - configReloader: runtimeState.configReloader, - wss, - httpServer, - httpServers, - drainActiveSessionsForShutdown, - })(opts); - }; + const createCloseHandler = () => async (opts?: GatewayCloseOptions) => { + const channelIds = listLoadedChannelPlugins().map((plugin) => plugin.id as ChannelId); + const { createGatewayCloseHandler, drainActiveSessionsForShutdown } = + await loadGatewayCloseModule(); + await createGatewayCloseHandler({ + bonjourStop: runtimeState.bonjourStop, + tailscaleCleanup: runtimeState.tailscaleCleanup, + releasePluginRouteRegistry, + channelIds, + stopChannel, + pluginServices: runtimeState.pluginServices, + postReadySidecars: runtimeState.postReadySidecars, + cron: runtimeState.cronState.cron, + heartbeatRunner: runtimeState.heartbeatRunner, + updateCheckStop: runtimeState.stopGatewayUpdateCheck, + stopTaskRegistryMaintenance: stopTaskRegistryMaintenanceOnDemand, + nodePresenceTimers, + broadcast, + tickInterval: runtimeState.tickInterval, + healthInterval: runtimeState.healthInterval, + dedupeCleanup: runtimeState.dedupeCleanup, + mediaCleanup: runtimeState.mediaCleanup, + agentUnsub: runtimeState.agentUnsub, + heartbeatUnsub: runtimeState.heartbeatUnsub, + transcriptUnsub: runtimeState.transcriptUnsub, + lifecycleUnsub: runtimeState.lifecycleUnsub, + chatRunState, + chatAbortControllers, + removeChatRun, + agentRunSeq, + nodeSendToSession, + getPendingReplyCount: getTotalPendingReplies, + clients, + configReloader: runtimeState.configReloader, + wss, + httpServer, + httpServers, + drainActiveSessionsForShutdown, + })(opts); + }; let clearFallbackGatewayContextForServer = () => {}; const closeOnStartupFailure = async () => { try {