fix: require tool-name matches for signed anthropic replay pairing

This commit is contained in:
Shakker
2026-04-12 04:11:48 +01:00
committed by Shakker
parent 98e89f5939
commit 1a689240dc
2 changed files with 104 additions and 17 deletions

View File

@@ -577,6 +577,34 @@ describe("validateAnthropicTurns strips dangling tool_use blocks", () => {
]);
});
it("does not trust future tool results with the right id but the wrong tool name", () => {
  // Transcript under test: an assistant turn that carries a thinking block
  // with a signature and calls the "gateway" tool as "tool-1", followed by a
  // tool result that reuses id "tool-1" but reports the tool name "exec".
  const transcript = asMessages([
    { role: "user", content: [{ type: "text", text: "Use tool" }] },
    {
      role: "assistant",
      content: [
        { type: "thinking", thinking: "internal", thinkingSignature: "sig_1" },
        { type: "toolCall", id: "tool-1", name: "gateway", arguments: {} },
      ],
    },
    {
      role: "toolResult",
      toolCallId: "tool-1",
      toolName: "exec",
      content: [{ type: "text", text: "wrong tool" }],
      isError: false,
    },
    { role: "user", content: [{ type: "text", text: "Continue" }] },
  ]);

  const sanitized = validateAnthropicTurns(transcript);

  // All four messages are kept...
  expect(sanitized).toHaveLength(4);

  // ...but the assistant turn's content is replaced by the placeholder text
  // block, because the id matched while the tool name did not.
  const assistantTurn = sanitized[1] as { content?: unknown[] };
  expect(assistantTurn.content).toEqual([{ type: "text", text: "[tool calls omitted]" }]);
});
it("drops redacted-thinking turns whose sibling tool calls are dangling", () => {
const msgs = asMessages([
{ role: "user", content: [{ type: "text", text: "Use tool" }] },

View File

@@ -28,19 +28,23 @@ function isAbortedAssistantTurn(message: AgentMessage): boolean {
return stopReason === "aborted" || stopReason === "error";
}
function extractToolResultIdsFromRecord(record: Record<string, unknown>): string[] {
const ids = [
normalizeOptionalString(record.toolUseId),
normalizeOptionalString(record.toolCallId),
normalizeOptionalString(record.tool_use_id),
normalizeOptionalString(record.tool_call_id),
normalizeOptionalString(record.callId),
normalizeOptionalString(record.call_id),
].filter((value): value is string => typeof value === "string");
return [...new Set(ids)];
/**
 * Picks the single id that a tool-result record should be matched on.
 *
 * The record is probed under several spellings of the id field, in a fixed
 * priority order; the first one for which `normalizeOptionalString` yields a
 * non-nullish value wins. Later keys are not read once a match is found.
 *
 * @param record - Tool-result-like record to inspect.
 * @returns The first normalized id found, or `null` when no key yields one.
 */
function extractToolResultMatchId(record: Record<string, unknown>): string | null {
  // Priority order matters: earlier spellings shadow later ones.
  const idKeysByPriority = [
    "toolUseId",
    "toolCallId",
    "tool_use_id",
    "tool_call_id",
    "callId",
    "call_id",
  ];
  for (const key of idKeysByPriority) {
    const normalized = normalizeOptionalString(record[key]);
    // Only null/undefined fall through to the next key (same as `??`).
    if (normalized !== undefined && normalized !== null) {
      return normalized;
    }
  }
  return null;
}
function collectMatchingToolResultIds(message: AgentMessage): Set<string> {
/**
 * Picks the tool name that a tool-result record should be matched on.
 *
 * `toolName` takes precedence; `name` is consulted only when `toolName` does
 * not normalize to a non-nullish value.
 *
 * @param record - Tool-result-like record to inspect.
 * @returns The normalized tool name, or `null` when neither field yields one.
 */
function extractToolResultMatchName(record: Record<string, unknown>): string | null {
  const preferredName = normalizeOptionalString(record.toolName);
  if (preferredName !== undefined && preferredName !== null) {
    return preferredName;
  }
  return normalizeOptionalString(record.name) ?? null;
}
function collectAnyToolResultIds(message: AgentMessage): Set<string> {
const ids = new Set<string>();
const role = (message as { role?: unknown }).role;
if (role === "toolResult") {
@@ -51,9 +55,9 @@ function collectMatchingToolResultIds(message: AgentMessage): Set<string> {
ids.add(toolResultId);
}
} else if (role === "tool") {
for (const id of extractToolResultIdsFromRecord(
message as unknown as Record<string, unknown>,
)) {
const record = message as unknown as Record<string, unknown>;
const id = extractToolResultMatchId(record);
if (id) {
ids.add(id);
}
}
@@ -71,7 +75,8 @@ function collectMatchingToolResultIds(message: AgentMessage): Set<string> {
if (record.type !== "toolResult" && record.type !== "tool") {
continue;
}
for (const id of extractToolResultIdsFromRecord(record)) {
const id = extractToolResultMatchId(record);
if (id) {
ids.add(id);
}
}
@@ -79,6 +84,56 @@ function collectMatchingToolResultIds(message: AgentMessage): Set<string> {
return ids;
}
/**
 * Builds the id -> tool-name pairing advertised by a single message.
 *
 * Only messages whose top-level `role` is `"toolResult"` or `"tool"` are
 * considered; any other message yields an empty map. A pairing is recorded
 * only when both the id and the tool name are present and non-empty.
 *
 * @param message - Message to inspect.
 * @returns A map from tool-result id to the set of tool names seen for it
 *   (at most one entry with one name for a single message).
 */
function collectTrustedToolResultMatches(message: AgentMessage): Map<string, Set<string>> {
  const pairing = new Map<string, Set<string>>();
  const fields = message as unknown as Record<string, unknown>;

  // The id is resolved differently depending on the message role.
  let resultId: string | null;
  switch ((message as { role?: unknown }).role) {
    case "toolResult":
      resultId = extractToolResultId(message as Extract<AgentMessage, { role: "toolResult" }>);
      break;
    case "tool":
      resultId = extractToolResultMatchId(fields);
      break;
    default:
      return pairing;
  }

  const resultToolName = extractToolResultMatchName(fields);
  // Both halves are required; an id without a name is not recorded.
  if (resultId && resultToolName) {
    pairing.set(resultId, new Set<string>([resultToolName]));
  }
  return pairing;
}
/**
 * Gathers id -> tool-name pairings from the messages that follow `startIndex`,
 * stopping at the next message whose role is `"assistant"`.
 *
 * Entries that are missing or not objects are skipped. Pairings from each
 * remaining message come from `collectTrustedToolResultMatches` and are merged
 * per id, so an id may accumulate several tool names.
 *
 * @param messages - Full message list being scanned.
 * @param startIndex - Index of the message whose followers are scanned; the
 *   scan begins at `startIndex + 1`.
 * @returns A map from tool-result id to every tool name paired with it.
 */
function collectFutureToolResultMatches(
  messages: AgentMessage[],
  startIndex: number,
): Map<string, Set<string>> {
  const merged = new Map<string, Set<string>>();
  let cursor = startIndex + 1;
  while (cursor < messages.length) {
    const later = messages[cursor];
    cursor += 1;

    const isUsable = Boolean(later) && typeof later === "object";
    if (!isUsable) {
      continue;
    }
    // The next assistant turn ends the window of results for this turn.
    if ((later as { role?: unknown }).role === "assistant") {
      break;
    }

    collectTrustedToolResultMatches(later).forEach((names, id) => {
      const known = merged.get(id);
      if (known === undefined) {
        merged.set(id, new Set<string>(names));
        return;
      }
      names.forEach((name) => known.add(name));
    });
  }
  return merged;
}
function collectFutureToolResultIds(messages: AgentMessage[], startIndex: number): Set<string> {
const ids = new Set<string>();
for (let index = startIndex + 1; index < messages.length; index += 1) {
@@ -89,7 +144,7 @@ function collectFutureToolResultIds(messages: AgentMessage[], startIndex: number
if ((candidate as { role?: unknown }).role === "assistant") {
break;
}
for (const id of collectMatchingToolResultIds(candidate)) {
for (const id of collectAnyToolResultIds(candidate)) {
ids.add(id);
}
}
@@ -133,6 +188,7 @@ function stripDanglingAnthropicToolUses(messages: AgentMessage[]): AgentMessage[
continue;
}
const hasThinking = originalContent.some((block) => isThinkingLikeBlock(block));
const validToolResultMatches = collectFutureToolResultMatches(messages, i);
const validToolUseIds = collectFutureToolResultIds(messages, i);
if (hasThinking) {
@@ -141,7 +197,10 @@ function stripDanglingAnthropicToolUses(messages: AgentMessage[]): AgentMessage[
return true;
}
const blockId = normalizeOptionalString(block.id);
return blockId ? validToolUseIds.has(blockId) : false;
const blockName = normalizeOptionalString(block.name);
return blockId && blockName
? validToolResultMatches.get(blockId)?.has(blockName) === true
: false;
});
if (allToolCallsResolvable) {
result.push(msg);