From 3dda75894b80a6521757489bef44a620b97b0a2e Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 4 Apr 2026 20:22:03 +0900 Subject: [PATCH] refactor(agents): centralize run wait helpers --- .../agent-step.test.ts => run-wait.test.ts} | 46 +++- src/agents/run-wait.ts | 226 ++++++++++++++++++ src/agents/subagent-control.ts | 8 +- src/agents/subagent-registry-run-manager.ts | 2 +- src/agents/tools/agent-step.ts | 201 +--------------- src/agents/tools/sessions-send-tool.a2a.ts | 3 +- src/agents/tools/sessions-send-tool.ts | 2 +- .../isolated-agent/subagent-followup.test.ts | 13 +- src/cron/isolated-agent/subagent-followup.ts | 13 +- 9 files changed, 290 insertions(+), 224 deletions(-) rename src/agents/{tools/agent-step.test.ts => run-wait.test.ts} (84%) create mode 100644 src/agents/run-wait.ts diff --git a/src/agents/tools/agent-step.test.ts b/src/agents/run-wait.test.ts similarity index 84% rename from src/agents/tools/agent-step.test.ts rename to src/agents/run-wait.test.ts index 74d098325f9..98ba73534ba 100644 --- a/src/agents/tools/agent-step.test.ts +++ b/src/agents/run-wait.test.ts @@ -1,7 +1,7 @@ import { beforeEach, describe, expect, it, vi } from "vitest"; const callGatewayMock = vi.fn(); -vi.mock("../../gateway/call.js", () => ({ +vi.mock("../gateway/call.js", () => ({ callGateway: (opts: unknown) => callGatewayMock(opts), })); @@ -10,9 +10,9 @@ import { readLatestAssistantReply, readLatestAssistantReplySnapshot, waitForAgentRun, - waitForAgentRunsUntilQuiescent, + waitForAgentRunsToDrain, waitForAgentRunAndReadUpdatedAssistantReply, -} from "./agent-step.js"; +} from "./run-wait.js"; describe("readLatestAssistantReply", () => { beforeEach(() => { @@ -181,7 +181,7 @@ describe("waitForAgentRunAndReadUpdatedAssistantReply", () => { }); }); -describe("waitForAgentRunsUntilQuiescent", () => { +describe("waitForAgentRunsToDrain", () => { beforeEach(() => { callGatewayMock.mockClear(); __testing.setDepsForTest({ @@ -204,7 +204,7 @@ describe("waitForAgentRunsUntilQuiescent", () => { return { status: "ok" }; }); - const result = await waitForAgentRunsUntilQuiescent({ + const result = await waitForAgentRunsToDrain({ timeoutMs: 1_000, getPendingRunIds: () => activeRunIds, }); @@ -236,13 +236,45 @@ describe("waitForAgentRunsUntilQuiescent", () => { it("deduplicates and trims pending run ids", async () => { callGatewayMock.mockResolvedValue({ status: "ok" }); + let activeRunIds = [" run-1 ", "run-1", "", "run-2"]; - const result = await waitForAgentRunsUntilQuiescent({ + const result = await waitForAgentRunsToDrain({ timeoutMs: 1_000, - getPendingRunIds: () => [" run-1 ", "run-1", "", "run-2"], + getPendingRunIds: () => { + const current = activeRunIds; + activeRunIds = []; + return current; + }, }); expect(result.timedOut).toBe(false); expect(callGatewayMock.mock.calls).toHaveLength(2); }); + + it("keeps the initial pending run ids before refreshing", async () => { + callGatewayMock.mockResolvedValue({ status: "ok" }); + let activeRunIds = ["run-2"]; + + const result = await waitForAgentRunsToDrain({ + timeoutMs: 1_000, + initialPendingRunIds: ["run-1"], + getPendingRunIds: () => { + const current = activeRunIds; + activeRunIds = []; + return current; + }, + }); + + expect(result.timedOut).toBe(false); + expect(callGatewayMock.mock.calls.map((call) => call[0])).toEqual([ + expect.objectContaining({ + method: "agent.wait", + params: expect.objectContaining({ runId: "run-1" }), + }), + expect.objectContaining({ + method: "agent.wait", + params: expect.objectContaining({ runId: "run-2" }), + }), + ]); + }); }); diff --git a/src/agents/run-wait.ts b/src/agents/run-wait.ts new file mode 100644 index 00000000000..49fe6762170 --- /dev/null +++ b/src/agents/run-wait.ts @@ -0,0 +1,226 @@ +import { callGateway } from "../gateway/call.js"; +import { extractAssistantText, stripToolMessages } from "./tools/chat-history-text.js"; + +type GatewayCaller = typeof callGateway; + +const defaultRunWaitDeps = { + callGateway, +}; + +let runWaitDeps: { + callGateway: GatewayCaller; +} = defaultRunWaitDeps; + +export type AssistantReplySnapshot = { + text?: string; + fingerprint?: string; +}; + +export type AgentWaitResult = { + status: "ok" | "timeout" | "error"; + error?: string; + startedAt?: number; + endedAt?: number; +}; + +export type AgentRunsDrainResult = { + timedOut: boolean; + pendingRunIds: string[]; + deadlineAtMs: number; +}; + +type RawAgentWaitResponse = { + status?: string; + error?: string; + startedAt?: unknown; + endedAt?: unknown; +}; + +function normalizeAgentWaitResult( + status: AgentWaitResult["status"], + wait?: RawAgentWaitResponse, +): AgentWaitResult { + return { + status, + error: typeof wait?.error === "string" ? wait.error : undefined, + startedAt: typeof wait?.startedAt === "number" ? wait.startedAt : undefined, + endedAt: typeof wait?.endedAt === "number" ? wait.endedAt : undefined, + }; +} + +function normalizePendingRunIds(runIds: Iterable): string[] { + const seen = new Set(); + for (const runId of runIds) { + const normalized = runId.trim(); + if (!normalized || seen.has(normalized)) { + continue; + } + seen.add(normalized); + } + return [...seen]; +} + +function resolveLatestAssistantReplySnapshot(messages: unknown[]): AssistantReplySnapshot { + for (let i = messages.length - 1; i >= 0; i -= 1) { + const candidate = messages[i]; + if (!candidate || typeof candidate !== "object") { + continue; + } + if ((candidate as { role?: unknown }).role !== "assistant") { + continue; + } + const text = extractAssistantText(candidate); + if (!text?.trim()) { + continue; + } + let fingerprint: string | undefined; + try { + fingerprint = JSON.stringify(candidate); + } catch { + fingerprint = text; + } + return { text, fingerprint }; + } + return {}; +} + +export async function readLatestAssistantReplySnapshot(params: { + sessionKey: string; + limit?: number; + callGateway?: GatewayCaller; +}): Promise { + const history = await (params.callGateway ?? runWaitDeps.callGateway)<{ + messages: Array; + }>({ + method: "chat.history", + params: { sessionKey: params.sessionKey, limit: params.limit ?? 50 }, + }); + return resolveLatestAssistantReplySnapshot( + stripToolMessages(Array.isArray(history?.messages) ? history.messages : []), + ); +} + +export async function readLatestAssistantReply(params: { + sessionKey: string; + limit?: number; + callGateway?: GatewayCaller; +}): Promise { + return ( + await readLatestAssistantReplySnapshot({ + sessionKey: params.sessionKey, + limit: params.limit, + callGateway: params.callGateway, + }) + ).text; +} + +export async function waitForAgentRun(params: { + runId: string; + timeoutMs: number; + callGateway?: GatewayCaller; +}): Promise { + const timeoutMs = Math.max(1, Math.floor(params.timeoutMs)); + try { + const wait = await (params.callGateway ?? runWaitDeps.callGateway)({ + method: "agent.wait", + params: { + runId: params.runId, + timeoutMs, + }, + timeoutMs: timeoutMs + 2000, + }); + if (wait?.status === "timeout") { + return normalizeAgentWaitResult("timeout", wait); + } + if (wait?.status === "error") { + return normalizeAgentWaitResult("error", wait); + } + return normalizeAgentWaitResult("ok", wait); + } catch (err) { + const error = err instanceof Error ? err.message : String(err); + return { + status: error.includes("gateway timeout") ? "timeout" : "error", + error, + }; + } +} + +export async function waitForAgentRunAndReadUpdatedAssistantReply(params: { + runId: string; + sessionKey: string; + timeoutMs: number; + limit?: number; + baseline?: AssistantReplySnapshot; + callGateway?: GatewayCaller; +}): Promise { + const wait = await waitForAgentRun({ + runId: params.runId, + timeoutMs: params.timeoutMs, + callGateway: params.callGateway, + }); + if (wait.status !== "ok") { + return wait; + } + + const latestReply = await readLatestAssistantReplySnapshot({ + sessionKey: params.sessionKey, + limit: params.limit, + callGateway: params.callGateway, + }); + const baselineFingerprint = params.baseline?.fingerprint; + const replyText = + latestReply.text && (!baselineFingerprint || latestReply.fingerprint !== baselineFingerprint) + ? latestReply.text + : undefined; + return { + status: "ok", + replyText, + }; +} + +export async function waitForAgentRunsToDrain(params: { + getPendingRunIds: () => Iterable; + initialPendingRunIds?: Iterable; + timeoutMs?: number; + deadlineAtMs?: number; + callGateway?: GatewayCaller; +}): Promise { + const deadlineAtMs = + params.deadlineAtMs ?? Date.now() + Math.max(1, Math.floor(params.timeoutMs ?? 0)); + + // Runs may finish and spawn more runs, so refresh until no pending IDs remain. + let pendingRunIds = new Set( + normalizePendingRunIds(params.initialPendingRunIds ?? params.getPendingRunIds()), + ); + + while (pendingRunIds.size > 0 && Date.now() < deadlineAtMs) { + const remainingMs = Math.max(1, deadlineAtMs - Date.now()); + await Promise.allSettled( + [...pendingRunIds].map((runId) => + waitForAgentRun({ + runId, + timeoutMs: remainingMs, + callGateway: params.callGateway, + }), + ), + ); + pendingRunIds = new Set(normalizePendingRunIds(params.getPendingRunIds())); + } + + return { + timedOut: pendingRunIds.size > 0, + pendingRunIds: [...pendingRunIds], + deadlineAtMs, + }; +} + +export const __testing = { + setDepsForTest(overrides?: Partial<{ callGateway: GatewayCaller }>) { + runWaitDeps = overrides + ? { + ...defaultRunWaitDeps, + ...overrides, + } + : defaultRunWaitDeps; + }, +}; diff --git a/src/agents/subagent-control.ts b/src/agents/subagent-control.ts index 3b3e02daded..eded9cf467b 100644 --- a/src/agents/subagent-control.ts +++ b/src/agents/subagent-control.ts @@ -26,6 +26,10 @@ import { INTERNAL_MESSAGE_CHANNEL } from "../utils/message-channel.js"; import { AGENT_LANE_SUBAGENT } from "./lanes.js"; import { resolveModelDisplayName, resolveModelDisplayRef } from "./model-selection-display.js"; import { abortEmbeddedPiRun } from "./pi-embedded.js"; +import { + readLatestAssistantReplySnapshot, + waitForAgentRunAndReadUpdatedAssistantReply, +} from "./run-wait.js"; import { resolveStoredSubagentCapabilities } from "./subagent-capabilities.js"; import { clearSubagentRunSteerRestart, @@ -39,10 +43,6 @@ import { replaceSubagentRunAfterSteer, type SubagentRunRecord, } from "./subagent-registry.js"; -import { - readLatestAssistantReplySnapshot, - waitForAgentRunAndReadUpdatedAssistantReply, -} from "./tools/agent-step.js"; import { resolveInternalSessionKey, resolveMainSessionAlias } from "./tools/sessions-helpers.js"; export const DEFAULT_RECENT_MINUTES = 30; diff --git a/src/agents/subagent-registry-run-manager.ts b/src/agents/subagent-registry-run-manager.ts index 96930ef8b10..daafa01a6da 100644 --- a/src/agents/subagent-registry-run-manager.ts +++ b/src/agents/subagent-registry-run-manager.ts @@ -3,6 +3,7 @@ import { callGateway } from "../gateway/call.js"; import { createSubsystemLogger } from "../logging/subsystem.js"; import { createRunningTaskRun } from "../tasks/task-executor.js"; import { type DeliveryContext, normalizeDeliveryContext } from "../utils/delivery-context.js"; +import { waitForAgentRun } from "./run-wait.js"; import type { ensureRuntimePluginsLoaded as ensureRuntimePluginsLoadedFn } from "./runtime-plugins.js"; import type { SubagentRunOutcome } from "./subagent-announce.js"; import { @@ -21,7 +22,6 @@ import { safeRemoveAttachmentsDir, } from "./subagent-registry-helpers.js"; import type { SubagentRunRecord } from "./subagent-registry.types.js"; -import { waitForAgentRun } from "./tools/agent-step.js"; const log = createSubsystemLogger("agents/subagent-registry"); diff --git a/src/agents/tools/agent-step.ts b/src/agents/tools/agent-step.ts index 4cf0cdebdca..c570d2a4d78 100644 --- a/src/agents/tools/agent-step.ts +++ b/src/agents/tools/agent-step.ts @@ -2,7 +2,9 @@ import crypto from "node:crypto"; import { callGateway } from "../../gateway/call.js"; import { INTERNAL_MESSAGE_CHANNEL } from "../../utils/message-channel.js"; import { AGENT_LANE_NESTED } from "../lanes.js"; -import { extractAssistantText, stripToolMessages } from "./chat-history-text.js"; +import { waitForAgentRunAndReadUpdatedAssistantReply } from "../run-wait.js"; + +export { readLatestAssistantReply } from "../run-wait.js"; type GatewayCaller = typeof callGateway; @@ -14,203 +16,6 @@ let agentStepDeps: { callGateway: GatewayCaller; } = defaultAgentStepDeps; -export type AssistantReplySnapshot = { - text?: string; - fingerprint?: string; -}; - -export type AgentWaitResult = { - status: "ok" | "timeout" | "error"; - error?: string; - startedAt?: number; - endedAt?: number; -}; - -export type AgentRunsQuiescentResult = { - timedOut: boolean; - pendingRunIds: string[]; - deadlineAtMs: number; -}; - -type RawAgentWaitResponse = { - status?: string; - error?: string; - startedAt?: unknown; - endedAt?: unknown; -}; - -function normalizeAgentWaitResult( - status: AgentWaitResult["status"], - wait?: RawAgentWaitResponse, -): AgentWaitResult { - return { - status, - error: typeof wait?.error === "string" ? wait.error : undefined, - startedAt: typeof wait?.startedAt === "number" ? wait.startedAt : undefined, - endedAt: typeof wait?.endedAt === "number" ? wait.endedAt : undefined, - }; -} - -function normalizePendingRunIds(runIds: Iterable): string[] { - const seen = new Set(); - for (const runId of runIds) { - const normalized = runId.trim(); - if (!normalized || seen.has(normalized)) { - continue; - } - seen.add(normalized); - } - return [...seen]; -} - -function resolveLatestAssistantReplySnapshot(messages: unknown[]): AssistantReplySnapshot { - for (let i = messages.length - 1; i >= 0; i -= 1) { - const candidate = messages[i]; - if (!candidate || typeof candidate !== "object") { - continue; - } - if ((candidate as { role?: unknown }).role !== "assistant") { - continue; - } - const text = extractAssistantText(candidate); - if (!text?.trim()) { - continue; - } - let fingerprint: string | undefined; - try { - fingerprint = JSON.stringify(candidate); - } catch { - fingerprint = text; - } - return { text, fingerprint }; - } - return {}; -} - -export async function readLatestAssistantReplySnapshot(params: { - sessionKey: string; - limit?: number; - callGateway?: GatewayCaller; -}): Promise { - const history = await (params.callGateway ?? agentStepDeps.callGateway)<{ - messages: Array; - }>({ - method: "chat.history", - params: { sessionKey: params.sessionKey, limit: params.limit ?? 50 }, - }); - return resolveLatestAssistantReplySnapshot( - stripToolMessages(Array.isArray(history?.messages) ? history.messages : []), - ); -} - -export async function readLatestAssistantReply(params: { - sessionKey: string; - limit?: number; -}): Promise { - return ( - await readLatestAssistantReplySnapshot({ - sessionKey: params.sessionKey, - limit: params.limit, - }) - ).text; -} - -export async function waitForAgentRun(params: { - runId: string; - timeoutMs: number; - callGateway?: GatewayCaller; -}): Promise { - const timeoutMs = Math.max(1, Math.floor(params.timeoutMs)); - try { - const wait = await (params.callGateway ?? agentStepDeps.callGateway)({ - method: "agent.wait", - params: { - runId: params.runId, - timeoutMs, - }, - timeoutMs: timeoutMs + 2000, - }); - if (wait?.status === "timeout") { - return normalizeAgentWaitResult("timeout", wait); - } - if (wait?.status === "error") { - return normalizeAgentWaitResult("error", wait); - } - return normalizeAgentWaitResult("ok", wait); - } catch (err) { - const error = err instanceof Error ? err.message : String(err); - return { - status: error.includes("gateway timeout") ? "timeout" : "error", - error, - }; - } -} - -export async function waitForAgentRunAndReadUpdatedAssistantReply(params: { - runId: string; - sessionKey: string; - timeoutMs: number; - limit?: number; - baseline?: AssistantReplySnapshot; - callGateway?: GatewayCaller; -}): Promise { - const wait = await waitForAgentRun({ - runId: params.runId, - timeoutMs: params.timeoutMs, - callGateway: params.callGateway, - }); - if (wait.status !== "ok") { - return wait; - } - - const latestReply = await readLatestAssistantReplySnapshot({ - sessionKey: params.sessionKey, - limit: params.limit, - callGateway: params.callGateway, - }); - const baselineFingerprint = params.baseline?.fingerprint; - const replyText = - latestReply.text && (!baselineFingerprint || latestReply.fingerprint !== baselineFingerprint) - ? latestReply.text - : undefined; - return { - status: "ok", - replyText, - }; -} - -export async function waitForAgentRunsUntilQuiescent(params: { - getPendingRunIds: () => Iterable; - timeoutMs?: number; - deadlineAtMs?: number; - callGateway?: GatewayCaller; -}): Promise { - const deadlineAtMs = - params.deadlineAtMs ?? Date.now() + Math.max(1, Math.floor(params.timeoutMs ?? 0)); - - let pendingRunIds = new Set(normalizePendingRunIds(params.getPendingRunIds())); - - while (pendingRunIds.size > 0 && Date.now() < deadlineAtMs) { - const remainingMs = Math.max(1, deadlineAtMs - Date.now()); - await Promise.allSettled( - [...pendingRunIds].map((runId) => - waitForAgentRun({ - runId, - timeoutMs: remainingMs, - callGateway: params.callGateway, - }), - ), - ); - pendingRunIds = new Set(normalizePendingRunIds(params.getPendingRunIds())); - } - - return { - timedOut: pendingRunIds.size > 0, - pendingRunIds: [...pendingRunIds], - deadlineAtMs, - }; -} - export async function runAgentStep(params: { sessionKey: string; message: string; diff --git a/src/agents/tools/sessions-send-tool.a2a.ts b/src/agents/tools/sessions-send-tool.a2a.ts index a4a7b1fc539..3371cc4287c 100644 --- a/src/agents/tools/sessions-send-tool.a2a.ts +++ b/src/agents/tools/sessions-send-tool.a2a.ts @@ -4,7 +4,8 @@ import { formatErrorMessage } from "../../infra/errors.js"; import { createSubsystemLogger } from "../../logging/subsystem.js"; import type { GatewayMessageChannel } from "../../utils/message-channel.js"; import { AGENT_LANE_NESTED } from "../lanes.js"; -import { readLatestAssistantReply, runAgentStep, waitForAgentRun } from "./agent-step.js"; +import { readLatestAssistantReply, waitForAgentRun } from "../run-wait.js"; +import { runAgentStep } from "./agent-step.js"; import { resolveAnnounceTarget } from "./sessions-announce-target.js"; import { buildAgentToAgentAnnounceContext, diff --git a/src/agents/tools/sessions-send-tool.ts b/src/agents/tools/sessions-send-tool.ts index 69ee8695eb5..46e9265b42f 100644 --- a/src/agents/tools/sessions-send-tool.ts +++ b/src/agents/tools/sessions-send-tool.ts @@ -12,7 +12,7 @@ import { AGENT_LANE_NESTED } from "../lanes.js"; import { readLatestAssistantReplySnapshot, waitForAgentRunAndReadUpdatedAssistantReply, -} from "./agent-step.js"; +} from "../run-wait.js"; import type { AnyAgentTool } from "./common.js"; import { jsonResult, readStringParam } from "./common.js"; import { diff --git a/src/cron/isolated-agent/subagent-followup.test.ts b/src/cron/isolated-agent/subagent-followup.test.ts index c4038c944de..af13b48d7a2 100644 --- a/src/cron/isolated-agent/subagent-followup.test.ts +++ b/src/cron/isolated-agent/subagent-followup.test.ts @@ -15,9 +15,9 @@ vi.mock("../../agents/subagent-registry-read.js", () => ({ listDescendantRunsForRequester: vi.fn().mockReturnValue([]), })); -vi.mock("../../agents/tools/agent-step.js", async () => { - const actual = await vi.importActual( - "../../agents/tools/agent-step.js", +vi.mock("../../agents/run-wait.js", async () => { + const actual = await vi.importActual( + "../../agents/run-wait.js", ); return { ...actual, @@ -30,7 +30,8 @@ vi.mock("../../gateway/call.js", () => ({ })); const { listDescendantRunsForRequester } = await import("../../agents/subagent-registry-read.js"); -const { readLatestAssistantReply } = await import("../../agents/tools/agent-step.js"); +const { __testing: runWaitTesting, readLatestAssistantReply } = + await import("../../agents/run-wait.js"); const { callGateway } = await import("../../gateway/call.js"); async function resolveAfterAdvancingTimers(promise: Promise, advanceMs = 100): Promise { @@ -237,10 +238,14 @@ describe("waitForDescendantSubagentSummary", () => { vi.mocked(listDescendantRunsForRequester).mockReturnValue([]); vi.mocked(readLatestAssistantReply).mockResolvedValue(undefined); vi.mocked(callGateway).mockResolvedValue({ status: "ok" }); + runWaitTesting.setDepsForTest({ + callGateway: ((opts) => vi.mocked(callGateway)(opts as never)) as typeof callGateway, + }); }); afterEach(() => { vi.useRealTimers(); + runWaitTesting.setDepsForTest(); }); it("returns initialReply immediately when no active descendants and observedActiveDescendants=false", async () => { diff --git a/src/cron/isolated-agent/subagent-followup.ts b/src/cron/isolated-agent/subagent-followup.ts index 5aa5617e13c..ca981bc7511 100644 --- a/src/cron/isolated-agent/subagent-followup.ts +++ b/src/cron/isolated-agent/subagent-followup.ts @@ -1,8 +1,5 @@ +import { readLatestAssistantReply, waitForAgentRunsToDrain } from "../../agents/run-wait.js"; import { listDescendantRunsForRequester } from "../../agents/subagent-registry-read.js"; -import { - readLatestAssistantReply, - waitForAgentRunsUntilQuiescent, -} from "../../agents/tools/agent-step.js"; import { SILENT_REPLY_TOKEN } from "../../auto-reply/tokens.js"; import { expectsSubagentFollowup, isLikelyInterimCronMessage } from "./subagent-followup-hints.js"; export { expectsSubagentFollowup, isLikelyInterimCronMessage } from "./subagent-followup-hints.js"; @@ -100,11 +97,11 @@ export async function waitForDescendantSubagentSummary(params: { return initialReply; } - // --- Push-based wait for all active descendants --- - // We iterate in case first-level descendants spawn their own subagents while - // we wait, so new active runs can appear between rounds. - await waitForAgentRunsUntilQuiescent({ + // Wait until no descendant runs remain active. Descendants can finish and + // spawn more descendants, so the helper refreshes the run set until it drains. + await waitForAgentRunsToDrain({ deadlineAtMs: deadline, + initialPendingRunIds: initialActiveRuns.map((entry) => entry.runId), getPendingRunIds: () => getActiveRuns().map((entry) => entry.runId), });