From 16807824cc2b9567519bd2f14dc4769bcaee7d5a Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Mon, 1 Jun 2026 19:12:20 +0200 Subject: [PATCH] refactor: share node invoke approval test helpers --- ...server.node-invoke-approval-bypass.test.ts | 166 +++++++++--------- 1 file changed, 80 insertions(+), 86 deletions(-) diff --git a/src/gateway/server.node-invoke-approval-bypass.test.ts b/src/gateway/server.node-invoke-approval-bypass.test.ts index 3da255f8240..fd966319da5 100644 --- a/src/gateway/server.node-invoke-approval-bypass.test.ts +++ b/src/gateway/server.node-invoke-approval-bypass.test.ts @@ -51,6 +51,38 @@ async function expectNoForwardedInvoke(hasInvoke: () => boolean): Promise expect(hasInvoke()).toBe(false); } +function parseInvokeParamsJSON(payload: unknown): Record | null { + const obj = payload as { paramsJSON?: unknown }; + const raw = typeof obj?.paramsJSON === "string" ? obj.paramsJSON : ""; + return raw ? (JSON.parse(raw) as Record) : null; +} + +function createInvokeParamCapture() { + let invokeCount = 0; + let lastInvokeParams: Record | null = null; + return { + count: () => invokeCount, + onInvoke: (payload: unknown) => { + invokeCount += 1; + lastInvokeParams = parseInvokeParamsJSON(payload); + }, + waitForParams: async () => { + await vi.waitFor( + () => { + if (!lastInvokeParams) { + throw new Error("expected forwarded invoke params"); + } + }, + { + timeout: 5_000, + interval: 50, + }, + ); + return requireRecord(lastInvokeParams, "forwarded invoke params"); + }, + }; +} + function requireNonEmptyString(value: string | null | undefined, label: string): string { if (!value) { throw new Error(`expected ${label}`); @@ -126,6 +158,38 @@ async function requestAllowOnceApproval( return approvalId; } +function approvedSystemRunParams( + command: string[], + rawCommand: string, + runId: string, + extra: Record = {}, +): Record { + return { + command, + rawCommand, + runId, + approved: true, + approvalDecision: "allow-once", + ...extra, + }; +} + +function approvedChatSystemRunParams( + context: ChatApprovalContext, + runId: string, + extra: Record = {}, +): Record { + return approvedSystemRunParams(["echo", "chat"], "echo chat", runId, { + agentId: context.agentId, + sessionKey: context.sessionKey, + turnSourceChannel: context.turnSourceChannel, + turnSourceTo: context.turnSourceTo, + turnSourceAccountId: context.turnSourceAccountId, + turnSourceThreadId: context.turnSourceThreadId, + ...extra, + }); +} + type ChatApprovalContext = { agentId: string; sessionKey: string; @@ -509,18 +573,8 @@ describe("node.invoke approval bypass", () => { }); test("binds approvals to decision/device and blocks cross-device replay", async () => { - let invokeCount = 0; - let lastInvokeParams: Record | null = null; - const node = await connectLinuxNode((payload) => { - invokeCount += 1; - const obj = payload as { paramsJSON?: unknown }; - const raw = typeof obj?.paramsJSON === "string" ? obj.paramsJSON : ""; - if (!raw) { - lastInvokeParams = null; - return; - } - lastInvokeParams = JSON.parse(raw) as Record; - }); + const invokeCapture = createInvokeParamCapture(); + const node = await connectLinuxNode(invokeCapture.onInvoke); const wsApprover = await connectOperator(["operator.write", "operator.approvals"]); const wsCaller = await connectOperator(["operator.write"]); @@ -534,50 +588,29 @@ describe("node.invoke approval bypass", () => { const invoke = await rpcReq(wsCaller, "node.invoke", { nodeId, command: "system.run", - params: { - command: ["echo", "hi"], - rawCommand: "echo hi", - runId: approvalId, - approved: true, + params: approvedSystemRunParams(["echo", "hi"], "echo hi", approvalId, { approvalDecision: "allow-always", injected: "nope", - }, + }), idempotencyKey: crypto.randomUUID(), }); expect(invoke.ok).toBe(true); - await vi.waitFor( - () => { - if (!lastInvokeParams) { - throw new Error("expected forwarded invoke params"); - } - }, - { - timeout: 5_000, - interval: 50, - }, - ); - const forwardedParams = requireRecord(lastInvokeParams, "forwarded invoke params"); + const forwardedParams = await invokeCapture.waitForParams(); expect(forwardedParams["approved"]).toBe(true); expect(forwardedParams["approvalDecision"]).toBe("allow-once"); expect(forwardedParams["injected"]).toBeUndefined(); const replayApprovalId = await requestAllowOnceApproval(wsApprover, "echo hi", nodeId); - const invokeCountBeforeReplay = invokeCount; + const invokeCountBeforeReplay = invokeCapture.count(); const replay = await rpcReq(wsOtherDevice, "node.invoke", { nodeId, command: "system.run", - params: { - command: ["echo", "hi"], - rawCommand: "echo hi", - runId: replayApprovalId, - approved: true, - approvalDecision: "allow-once", - }, + params: approvedSystemRunParams(["echo", "hi"], "echo hi", replayApprovalId), idempotencyKey: crypto.randomUUID(), }); expect(replay.ok).toBe(false); expect(replay.error?.message ?? "").toContain("not valid for this device"); - await expectNoForwardedInvoke(() => invokeCount > invokeCountBeforeReplay); + await expectNoForwardedInvoke(() => invokeCapture.count() > invokeCountBeforeReplay); } finally { wsApprover.close(); wsCaller.close(); @@ -587,14 +620,8 @@ describe("node.invoke approval bypass", () => { }); test("bridges no-device chat approvals across backend reconnects only for the same turn source", async () => { - let invokeCount = 0; - let lastInvokeParams: Record | null = null; - const node = await connectLinuxNode((payload) => { - invokeCount += 1; - const obj = payload as { paramsJSON?: unknown }; - const raw = typeof obj?.paramsJSON === "string" ? obj.paramsJSON : ""; - lastInvokeParams = raw ? (JSON.parse(raw) as Record) : null; - }); + const invokeCapture = createInvokeParamCapture(); + const node = await connectLinuxNode(invokeCapture.onInvoke); const wsRequest = await connectTrustedBackend(["operator.write", "operator.approvals"]); const wsReplay = await connectTrustedBackend(["operator.write", "operator.approvals"]); @@ -619,34 +646,11 @@ describe("node.invoke approval bypass", () => { const invoke = await rpcReq(wsReplay, "node.invoke", { nodeId, command: "system.run", - params: { - command: ["echo", "chat"], - rawCommand: "echo chat", - agentId: context.agentId, - sessionKey: context.sessionKey, - turnSourceChannel: context.turnSourceChannel, - turnSourceTo: context.turnSourceTo, - turnSourceAccountId: context.turnSourceAccountId, - turnSourceThreadId: context.turnSourceThreadId, - runId: approvalId, - approved: true, - approvalDecision: "allow-once", - }, + params: approvedChatSystemRunParams(context, approvalId), idempotencyKey: crypto.randomUUID(), }); expect(invoke.ok).toBe(true); - await vi.waitFor( - () => { - if (!lastInvokeParams) { - throw new Error("expected forwarded invoke params"); - } - }, - { - timeout: 5_000, - interval: 50, - }, - ); - const forwardedParams = requireRecord(lastInvokeParams, "forwarded invoke params"); + const forwardedParams = await invokeCapture.waitForParams(); expect(forwardedParams["approved"]).toBe(true); expect(forwardedParams["approvalDecision"]).toBe("allow-once"); expect(forwardedParams["turnSourceTo"]).toBeUndefined(); @@ -657,28 +661,18 @@ describe("node.invoke approval bypass", () => { nodeId, context, }); - const invokeCountBeforeMismatch = invokeCount; + const invokeCountBeforeMismatch = invokeCapture.count(); const mismatch = await rpcReq(wsReplay, "node.invoke", { nodeId, command: "system.run", - params: { - command: ["echo", "chat"], - rawCommand: "echo chat", - agentId: context.agentId, - sessionKey: context.sessionKey, - turnSourceChannel: context.turnSourceChannel, + params: approvedChatSystemRunParams(context, mismatchApprovalId, { turnSourceTo: "telegram:67890", - turnSourceAccountId: context.turnSourceAccountId, - turnSourceThreadId: context.turnSourceThreadId, - runId: mismatchApprovalId, - approved: true, - approvalDecision: "allow-once", - }, + }), idempotencyKey: crypto.randomUUID(), }); expect(mismatch.ok).toBe(false); expect(mismatch.error?.message ?? "").toContain("not valid for this client"); - await expectNoForwardedInvoke(() => invokeCount > invokeCountBeforeMismatch); + await expectNoForwardedInvoke(() => invokeCapture.count() > invokeCountBeforeMismatch); } finally { wsRequest.close(); wsReplay.close();