diff --git a/src/agents/pi-embedded-helpers.validate-turns.test.ts b/src/agents/pi-embedded-helpers.validate-turns.test.ts index 725249778f4..93271d85225 100644 --- a/src/agents/pi-embedded-helpers.validate-turns.test.ts +++ b/src/agents/pi-embedded-helpers.validate-turns.test.ts @@ -577,6 +577,34 @@ describe("validateAnthropicTurns strips dangling tool_use blocks", () => { ]); }); + it("does not trust future tool results with the right id but the wrong tool name", () => { + const msgs = asMessages([ + { role: "user", content: [{ type: "text", text: "Use tool" }] }, + { + role: "assistant", + content: [ + { type: "thinking", thinking: "internal", thinkingSignature: "sig_1" }, + { type: "toolCall", id: "tool-1", name: "gateway", arguments: {} }, + ], + }, + { + role: "toolResult", + toolCallId: "tool-1", + toolName: "exec", + content: [{ type: "text", text: "wrong tool" }], + isError: false, + }, + { role: "user", content: [{ type: "text", text: "Continue" }] }, + ]); + + const result = validateAnthropicTurns(msgs); + + expect(result).toHaveLength(4); + expect((result[1] as { content?: unknown[] }).content).toEqual([ + { type: "text", text: "[tool calls omitted]" }, + ]); + }); + it("drops redacted-thinking turns whose sibling tool calls are dangling", () => { const msgs = asMessages([ { role: "user", content: [{ type: "text", text: "Use tool" }] }, diff --git a/src/agents/pi-embedded-helpers/turns.ts b/src/agents/pi-embedded-helpers/turns.ts index ea40dedb5b3..b890214361d 100644 --- a/src/agents/pi-embedded-helpers/turns.ts +++ b/src/agents/pi-embedded-helpers/turns.ts @@ -28,19 +28,23 @@ function isAbortedAssistantTurn(message: AgentMessage): boolean { return stopReason === "aborted" || stopReason === "error"; } -function extractToolResultIdsFromRecord(record: Record): string[] { - const ids = [ - normalizeOptionalString(record.toolUseId), - normalizeOptionalString(record.toolCallId), - normalizeOptionalString(record.tool_use_id), - normalizeOptionalString(record.tool_call_id), - normalizeOptionalString(record.callId), - normalizeOptionalString(record.call_id), - ].filter((value): value is string => typeof value === "string"); - return [...new Set(ids)]; +function extractToolResultMatchId(record: Record): string | null { + return ( + normalizeOptionalString(record.toolUseId) ?? + normalizeOptionalString(record.toolCallId) ?? + normalizeOptionalString(record.tool_use_id) ?? + normalizeOptionalString(record.tool_call_id) ?? + normalizeOptionalString(record.callId) ?? + normalizeOptionalString(record.call_id) ?? + null + ); } -function collectMatchingToolResultIds(message: AgentMessage): Set { +function extractToolResultMatchName(record: Record): string | null { + return normalizeOptionalString(record.toolName) ?? normalizeOptionalString(record.name) ?? null; +} + +function collectAnyToolResultIds(message: AgentMessage): Set { const ids = new Set(); const role = (message as { role?: unknown }).role; if (role === "toolResult") { @@ -51,9 +55,9 @@ function collectMatchingToolResultIds(message: AgentMessage): Set { ids.add(toolResultId); } } else if (role === "tool") { - for (const id of extractToolResultIdsFromRecord( - message as unknown as Record, - )) { + const record = message as unknown as Record; + const id = extractToolResultMatchId(record); + if (id) { ids.add(id); } } @@ -71,7 +75,8 @@ function collectMatchingToolResultIds(message: AgentMessage): Set { if (record.type !== "toolResult" && record.type !== "tool") { continue; } - for (const id of extractToolResultIdsFromRecord(record)) { + const id = extractToolResultMatchId(record); + if (id) { ids.add(id); } } @@ -79,6 +84,56 @@ function collectMatchingToolResultIds(message: AgentMessage): Set { return ids; } +function collectTrustedToolResultMatches(message: AgentMessage): Map> { + const matches = new Map>(); + const role = (message as { role?: unknown }).role; + const addMatch = (id: string | null, toolName: string | null) => { + if (!id || !toolName) { + return; + } + const bucket = matches.get(id) ?? new Set(); + bucket.add(toolName); + matches.set(id, bucket); + }; + + if (role === "toolResult") { + const record = message as unknown as Record; + addMatch( + extractToolResultId(message as Extract), + extractToolResultMatchName(record), + ); + } else if (role === "tool") { + const record = message as unknown as Record; + addMatch(extractToolResultMatchId(record), extractToolResultMatchName(record)); + } + + return matches; +} + +function collectFutureToolResultMatches( + messages: AgentMessage[], + startIndex: number, +): Map> { + const matches = new Map>(); + for (let index = startIndex + 1; index < messages.length; index += 1) { + const candidate = messages[index]; + if (!candidate || typeof candidate !== "object") { + continue; + } + if ((candidate as { role?: unknown }).role === "assistant") { + break; + } + for (const [id, toolNames] of collectTrustedToolResultMatches(candidate)) { + const bucket = matches.get(id) ?? new Set(); + for (const toolName of toolNames) { + bucket.add(toolName); + } + matches.set(id, bucket); + } + } + return matches; +} + function collectFutureToolResultIds(messages: AgentMessage[], startIndex: number): Set { const ids = new Set(); for (let index = startIndex + 1; index < messages.length; index += 1) { @@ -89,7 +144,7 @@ function collectFutureToolResultIds(messages: AgentMessage[], startIndex: number if ((candidate as { role?: unknown }).role === "assistant") { break; } - for (const id of collectMatchingToolResultIds(candidate)) { + for (const id of collectAnyToolResultIds(candidate)) { ids.add(id); } } @@ -133,6 +188,7 @@ function stripDanglingAnthropicToolUses(messages: AgentMessage[]): AgentMessage[ continue; } const hasThinking = originalContent.some((block) => isThinkingLikeBlock(block)); + const validToolResultMatches = collectFutureToolResultMatches(messages, i); const validToolUseIds = collectFutureToolResultIds(messages, i); if (hasThinking) { @@ -141,7 +197,10 @@ function stripDanglingAnthropicToolUses(messages: AgentMessage[]): AgentMessage[ return true; } const blockId = normalizeOptionalString(block.id); - return blockId ? validToolUseIds.has(blockId) : false; + const blockName = normalizeOptionalString(block.name); + return blockId && blockName + ? validToolResultMatches.get(blockId)?.has(blockName) === true + : false; }); if (allToolCallsResolvable) { result.push(msg);