// Regression test (belongs inside describe("wrapStreamFnSanitizeMalformedToolCalls")):
// a replayed `sessions_spawn` tool call must keep its inline attachment payload
// while its padded name is trimmed.
it("preserves sessions_spawn attachment payloads on replay", async () => {
  const attachmentContent = "INLINE_ATTACHMENT_PAYLOAD";
  const messages = [
    {
      role: "assistant",
      content: [
        {
          type: "toolUse",
          id: "call_1",
          // Padded and upper-cased on purpose: the wrapper is expected to trim the
          // name and match it case-insensitively against the allow-list.
          name: " SESSIONS_SPAWN ",
          input: {
            task: "inspect attachment",
            attachments: [{ name: "snapshot.txt", content: attachmentContent }],
          },
        },
      ],
    },
  ];
  const baseFn = vi.fn((_model, _context) =>
    createFakeStream({ events: [], resultMessage: { role: "assistant", content: [] } }),
  );

  const wrapped = wrapStreamFnSanitizeMalformedToolCalls(
    baseFn as never,
    new Set(["sessions_spawn"]),
  );
  // NOTE(review): the Promise type argument was lost in extraction;
  // FakeWrappedStream is the reconstruction — confirm against the original patch.
  const stream = wrapped({} as never, { messages } as never, {} as never) as
    | FakeWrappedStream
    | Promise<FakeWrappedStream>;
  await Promise.resolve(stream);

  expect(baseFn).toHaveBeenCalledTimes(1);
  // Inspect the context that was actually forwarded to the underlying stream fn.
  const seenContext = baseFn.mock.calls[0]?.[1] as {
    messages: Array<{ content?: Array<Record<string, unknown>> }>;
  };
  const toolCall = seenContext.messages[0]?.content?.[0] as {
    name?: string;
    input?: { attachments?: Array<{ content?: string }> };
  };
  expect(toolCall.name).toBe("SESSIONS_SPAWN");
  expect(toolCall.input?.attachments?.[0]?.content).toBe(attachmentContent);
});
// Replay-time tool-call sanitizer used by wrapStreamFnSanitizeMalformedToolCalls,
// which calls sanitizeReplayToolCallInputs(messages, allowedToolNames) in place of
// the previously imported sanitizeToolCallInputs helper.

/** Returns true when `type` is one of the block tags used for tool invocations. */
function isToolCallBlockType(type: unknown): boolean {
  return type === "toolCall" || type === "toolUse" || type === "functionCall";
}

/** Tool names longer than this (after trimming) are rejected on replay. */
const REPLAY_TOOL_CALL_NAME_MAX_CHARS = 64;
/** Trimmed tool names must consist solely of ASCII letters, digits, `_` or `-`. */
const REPLAY_TOOL_CALL_NAME_RE = /^[A-Za-z0-9_-]+$/;

/** Loose view of a tool-call content block; every field is validated before use. */
type ReplayToolCallBlock = {
  type?: unknown;
  id?: unknown;
  name?: unknown;
  input?: unknown;
  arguments?: unknown;
};

/** Type guard: `block` is an object whose `type` marks it as a tool call. */
function isReplayToolCallBlock(block: unknown): block is ReplayToolCallBlock {
  if (!block || typeof block !== "object") {
    return false;
  }
  return isToolCallBlockType((block as { type?: unknown }).type);
}

/** True when the block carries a non-nullish `input` or `arguments` payload. */
function replayToolCallHasInput(block: ReplayToolCallBlock): boolean {
  const hasInput = block.input !== undefined && block.input !== null;
  const hasArguments = block.arguments !== undefined && block.arguments !== null;
  return hasInput || hasArguments;
}

/** Type guard: `value` is a string with at least one non-whitespace character. */
function replayToolCallNonEmptyString(value: unknown): value is string {
  return typeof value === "string" && value.trim().length > 0;
}

/**
 * Type guard: the block has a usable tool name.
 *
 * The trimmed name must be non-empty, at most REPLAY_TOOL_CALL_NAME_MAX_CHARS long
 * and match REPLAY_TOOL_CALL_NAME_RE. When `allowedToolNames` is provided and
 * non-empty, the lower-cased trimmed name must also be a member of it.
 *
 * @param block - Candidate tool-call block.
 * @param allowedToolNames - Optional allow-list; entries are compared against the
 *   lower-cased name, so they are expected to be lower-case already.
 */
function replayToolCallHasName(
  block: ReplayToolCallBlock,
  allowedToolNames?: Set<string>,
): block is ReplayToolCallBlock & { name: string } {
  if (!replayToolCallNonEmptyString(block.name)) {
    return false;
  }
  const trimmed = block.name.trim();
  if (trimmed.length > REPLAY_TOOL_CALL_NAME_MAX_CHARS || !REPLAY_TOOL_CALL_NAME_RE.test(trimmed)) {
    return false;
  }
  if (!allowedToolNames || allowedToolNames.size === 0) {
    return true;
  }
  return allowedToolNames.has(trimmed.toLowerCase());
}

/**
 * Drops malformed tool-call blocks from assistant messages before they are replayed.
 *
 * A tool-call block is kept only if it has a non-nullish `input`/`arguments`, a
 * non-blank string `id`, and a valid (optionally allow-listed) name. Kept blocks are
 * passed through untouched — their payload is never rewritten — except that a name
 * with surrounding whitespace is replaced by its trimmed form on a shallow copy.
 * Assistant messages left with no content are removed entirely. Non-assistant
 * messages and messages whose content is not an array pass through as-is.
 *
 * @param messages - Transcript to sanitize; never mutated.
 * @param allowedToolNames - Optional allow-list of tool names. Matching is
 *   case-insensitive: entries are trimmed and lower-cased once per call, so a set
 *   built from mixed-case names still matches.
 * @returns The same `messages` array instance when nothing changed, otherwise a new
 *   array.
 */
function sanitizeReplayToolCallInputs(
  messages: AgentMessage[],
  allowedToolNames?: Set<string>,
): AgentMessage[] {
  // Normalise the allow-list once so the comparison is case-insensitive on both
  // sides. An empty or missing list means "no restriction".
  const allowed =
    allowedToolNames && allowedToolNames.size > 0
      ? new Set(Array.from(allowedToolNames, (name) => name.trim().toLowerCase()))
      : undefined;

  let changed = false;
  const out: AgentMessage[] = [];

  for (const message of messages) {
    if (!message || typeof message !== "object" || message.role !== "assistant") {
      out.push(message);
      continue;
    }
    if (!Array.isArray(message.content)) {
      out.push(message);
      continue;
    }

    const nextContent: typeof message.content = [];
    let messageChanged = false;

    for (const block of message.content) {
      if (!isReplayToolCallBlock(block)) {
        nextContent.push(block);
        continue;
      }

      if (
        !replayToolCallHasInput(block) ||
        !replayToolCallNonEmptyString(block.id) ||
        !replayToolCallHasName(block, allowed)
      ) {
        // Malformed or disallowed tool call: omit it from the replayed message.
        messageChanged = true;
        continue;
      }

      const trimmedName = block.name.trim();
      if (block.name !== trimmedName) {
        // Shallow copy keeps input/arguments (e.g. attachments) by reference.
        nextContent.push({ ...(block as object), name: trimmedName } as typeof block);
        messageChanged = true;
        continue;
      }
      nextContent.push(block);
    }

    if (!messageChanged) {
      out.push(message);
      continue;
    }

    changed = true;
    // An assistant message that lost all of its content is dropped altogether.
    if (nextContent.length > 0) {
      out.push({ ...message, content: nextContent });
    }
  }

  // Preserve identity when untouched so callers can detect "no change" with ===.
  return changed ? out : messages;
}