diff --git a/src/agents/pi-embedded-runner/run/attempt.test.ts b/src/agents/pi-embedded-runner/run/attempt.test.ts index 6585b461c19..a745d298252 100644 --- a/src/agents/pi-embedded-runner/run/attempt.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.test.ts @@ -875,10 +875,59 @@ describe("wrapStreamFnSanitizeMalformedToolCalls", () => { name?: string; input?: { attachments?: Array<{ content?: string }> }; }; - expect(toolCall.name).toBe("SESSIONS_SPAWN"); + expect(toolCall.name).toBe("sessions_spawn"); expect(toolCall.input?.attachments?.[0]?.content).toBe(attachmentContent); }); + it("preserves allowlisted tool names that contain punctuation", async () => { + const messages = [ + { + role: "assistant", + content: [{ type: "toolUse", id: "call_1", name: "admin.export", input: { scope: "all" } }], + }, + ]; + const baseFn = vi.fn((_model, _context) => + createFakeStream({ events: [], resultMessage: { role: "assistant", content: [] } }), + ); + + const wrapped = wrapStreamFnSanitizeMalformedToolCalls( + baseFn as never, + new Set(["admin.export"]), + ); + const stream = wrapped({} as never, { messages } as never, {} as never) as + | FakeWrappedStream + | Promise; + await Promise.resolve(stream); + + expect(baseFn).toHaveBeenCalledTimes(1); + const seenContext = baseFn.mock.calls[0]?.[1] as { messages: unknown[] }; + expect(seenContext.messages).toBe(messages); + }); + + it("canonicalizes mixed-case allowlisted tool names on replay", async () => { + const messages = [ + { + role: "assistant", + content: [{ type: "toolCall", id: "call_1", name: "readfile", arguments: {} }], + }, + ]; + const baseFn = vi.fn((_model, _context) => + createFakeStream({ events: [], resultMessage: { role: "assistant", content: [] } }), + ); + + const wrapped = wrapStreamFnSanitizeMalformedToolCalls(baseFn as never, new Set(["ReadFile"])); + const stream = wrapped({} as never, { messages } as never, {} as never) as + | FakeWrappedStream + | Promise; + await Promise.resolve(stream); + + expect(baseFn).toHaveBeenCalledTimes(1); + const seenContext = baseFn.mock.calls[0]?.[1] as { + messages: Array<{ content?: Array<{ name?: string }> }>; + }; + expect(seenContext.messages[0]?.content?.[0]?.name).toBe("ReadFile"); + }); + it("drops orphaned tool results after replay sanitization removes a tool-call turn", async () => { const messages = [ { diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index e63cfd3f532..b38179e6b10 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -649,7 +649,6 @@ function isToolCallBlockType(type: unknown): boolean { } const REPLAY_TOOL_CALL_NAME_MAX_CHARS = 64; -const REPLAY_TOOL_CALL_NAME_RE = /^[A-Za-z0-9_-]+$/; type ReplayToolCallBlock = { type?: unknown; @@ -677,21 +676,21 @@ function replayToolCallNonEmptyString(value: unknown): value is string { return typeof value === "string" && value.trim().length > 0; } -function replayToolCallHasName( - block: ReplayToolCallBlock, +function resolveReplayAllowedToolName( + rawName: string, allowedToolNames?: Set, -): block is ReplayToolCallBlock & { name: string } { - if (!replayToolCallNonEmptyString(block.name)) { - return false; - } - const trimmed = block.name.trim(); - if (trimmed.length > REPLAY_TOOL_CALL_NAME_MAX_CHARS || !REPLAY_TOOL_CALL_NAME_RE.test(trimmed)) { - return false; +): string | null { + const trimmed = rawName.trim(); + if (!trimmed || trimmed.length > REPLAY_TOOL_CALL_NAME_MAX_CHARS || /\s/.test(trimmed)) { + return null; } if (!allowedToolNames || allowedToolNames.size === 0) { - return true; + return trimmed; } - return allowedToolNames.has(trimmed.toLowerCase()); + return ( + resolveExactAllowedToolName(trimmed, allowedToolNames) ?? + resolveStructuredAllowedToolName(trimmed, allowedToolNames) + ); } function sanitizeReplayToolCallInputs( @@ -723,16 +722,22 @@ function sanitizeReplayToolCallInputs( if ( !replayToolCallHasInput(block) || !replayToolCallNonEmptyString(block.id) || - !replayToolCallHasName(block, allowedToolNames) + !replayToolCallNonEmptyString(block.name) ) { changed = true; messageChanged = true; continue; } - const trimmedName = block.name.trim(); - if (block.name !== trimmedName) { - nextContent.push({ ...(block as object), name: trimmedName } as typeof block); + const resolvedName = resolveReplayAllowedToolName(block.name, allowedToolNames); + if (!resolvedName) { + changed = true; + messageChanged = true; + continue; + } + + if (block.name !== resolvedName) { + nextContent.push({ ...(block as object), name: resolvedName } as typeof block); changed = true; messageChanged = true; continue;