diff --git a/src/agents/pi-embedded-runner/run/attempt.test.ts b/src/agents/pi-embedded-runner/run/attempt.test.ts index 67e182bad5f..0466f4402e9 100644 --- a/src/agents/pi-embedded-runner/run/attempt.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.test.ts @@ -722,6 +722,225 @@ describe("wrapStreamFnTrimToolCallNames", () => { ]); }); + it("counts the final unknown-tool retry when streamed messages omit the tool name", async () => { + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { + type: "toolcall_delta", + message: { role: "assistant", content: [{ type: "toolCall", name: "" }] }, + }, + ], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " exec ", arguments: { command: "echo retry" } }], + }, + }), + ); + const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["read"]), { + unknownToolThreshold: 1, + }); + + const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + await firstStream.result(); + + const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + for await (const _item of secondStream) { + // drain + } + const secondResult = (await secondStream.result()) as { + role: string; + content: Array<{ type: string; text?: string; name?: string }>; + }; + + expect(secondResult.role).toBe("assistant"); + expect(secondResult.content).toEqual([ + expect.objectContaining({ + type: "text", + text: expect.stringContaining('"exec"'), + }), + ]); + }); + + it("resets a provisional streamed unknown-tool retry when later chunks resolve to an allowed tool", async () => { + const baseFn = vi + .fn() + .mockImplementationOnce(() => + createFakeStream({ + events: [ + { + type: "toolcall_delta", + message: { role: "assistant", content: [{ type: "toolCall", name: " ex " }] }, + }, + { + type: "toolcall_delta", + message: { role: "assistant", content: [{ type: "toolCall", name: " exec " }] }, + }, + ], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " exec ", arguments: { command: "echo ok" } }], + }, + }), + ) + .mockImplementationOnce(() => + createFakeStream({ + events: [], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " ex ", arguments: { command: "echo retry" } }], + }, + }), + ); + const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["exec"]), { + unknownToolThreshold: 1, + }); + + const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + for await (const _item of firstStream) { + // drain + } + await firstStream.result(); + + const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + const secondResult = (await secondStream.result()) as { + role: string; + content: Array<{ type: string; text?: string; name?: string }>; + }; + + expect(secondResult.role).toBe("assistant"); + expect(secondResult.content).toEqual([ + expect.objectContaining({ + type: "toolCall", + name: "ex", + }), + ]); + }); + + it("keeps processing later streamed messages after one streamed unknown-tool retry was counted", async () => { + const baseFn = vi + .fn() + .mockImplementationOnce(() => + createFakeStream({ + events: [ + { + type: "toolcall_delta", + message: { role: "assistant", content: [{ type: "toolCall", name: " re " }] }, + }, + { + type: "toolcall_delta", + message: { role: "assistant", content: [{ type: "toolCall", name: " read " }] }, + }, + ], + resultMessage: { + role: "assistant", + content: [{ type: "text", text: "resolved to allowed tool" }], + }, + }), + ) + .mockImplementationOnce(() => + createFakeStream({ + events: [], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " re ", arguments: { command: "echo retry" } }], + }, + }), + ); + const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["read"]), { + unknownToolThreshold: 1, + }); + + const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + for await (const _item of firstStream) { + // drain + } + await firstStream.result(); + + const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + const secondResult = (await secondStream.result()) as { + role: string; + content: Array<{ type: string; text?: string; name?: string }>; + }; + + expect(secondResult.role).toBe("assistant"); + expect(secondResult.content).toEqual([ + expect.objectContaining({ + type: "toolCall", + name: "re", + }), + ]); + }); + + it("resets a stale unknown-tool streak when a streamed message mixes allowed and unknown tools", async () => { + const baseFn = vi + .fn() + .mockImplementationOnce(() => + createFakeStream({ + events: [], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " ex ", arguments: { command: "echo first" } }], + }, + }), + ) + .mockImplementationOnce(() => + createFakeStream({ + events: [ + { + type: "toolcall_delta", + message: { + role: "assistant", + content: [ + { type: "toolCall", name: " exec ", arguments: { command: "echo allowed" } }, + { type: "toolCall", name: " ex ", arguments: { command: "echo provisional" } }, + ], + }, + }, + ], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " exec ", arguments: { command: "echo ok" } }], + }, + }), + ) + .mockImplementationOnce(() => + createFakeStream({ + events: [], + resultMessage: { + role: "assistant", + content: [{ type: "toolCall", name: " ex ", arguments: { command: "echo retry" } }], + }, + }), + ); + const wrappedFn = wrapStreamFnTrimToolCallNames(baseFn as never, new Set(["exec"]), { + unknownToolThreshold: 1, + }); + + const firstStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + await firstStream.result(); + + const secondStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + for await (const _item of secondStream) { + // drain + } + await secondStream.result(); + + const thirdStream = await Promise.resolve(wrappedFn({} as never, {} as never, {} as never)); + const thirdResult = (await thirdStream.result()) as { + role: string; + content: Array<{ type: string; text?: string; name?: string }>; + }; + + expect(thirdResult.role).toBe("assistant"); + expect(thirdResult.content).toEqual([ + expect.objectContaining({ + type: "toolCall", + name: "ex", + }), + ]); + }); + it("infers tool names from malformed toolCallId variants when allowlist is present", async () => { const partialToolCall = { type: "toolCall", id: "functions.read:0", name: "" }; const finalToolCallA = { type: "toolCall", id: "functionsread3", name: "" }; diff --git a/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts b/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts index e4cda7a091c..dc75e37bc17 100644 --- a/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts +++ b/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts @@ -636,20 +636,26 @@ function trimWhitespaceFromToolCallNamesInMessage( normalizeToolCallIdsInMessage(message); } -function collectUnknownToolNameFromMessage( +function classifyToolCallMessage( message: unknown, allowedToolNames?: Set, -): string | undefined { +): + | { kind: "none" } + | { kind: "allowed" } + | { kind: "incomplete" } + | { kind: "unknown"; toolName: string } { if (!message || typeof message !== "object" || !allowedToolNames || allowedToolNames.size === 0) { - return undefined; + return { kind: "none" }; } const content = (message as { content?: unknown }).content; if (!Array.isArray(content)) { - return undefined; + return { kind: "none" }; } let unknownToolName: string | undefined; let sawToolCall = false; + let sawAllowedToolCall = false; + let sawIncompleteToolCall = false; for (const block of content) { if (!block || typeof block !== "object") { continue; @@ -661,10 +667,12 @@ function collectUnknownToolNameFromMessage( sawToolCall = true; const rawName = typeof typedBlock.name === "string" ? typedBlock.name.trim() : ""; if (!rawName) { - return undefined; + sawIncompleteToolCall = true; + continue; } if (resolveExactAllowedToolName(rawName, allowedToolNames)) { - return undefined; + sawAllowedToolCall = true; + continue; } const normalizedUnknownToolName = normalizeToolName(rawName); if (!unknownToolName) { @@ -672,11 +680,20 @@ function collectUnknownToolNameFromMessage( continue; } if (unknownToolName !== normalizedUnknownToolName) { - return undefined; + sawIncompleteToolCall = true; } } - return sawToolCall ? unknownToolName : undefined; + if (!sawToolCall) { + return { kind: "none" }; + } + if (sawAllowedToolCall) { + return { kind: "allowed" }; + } + if (sawIncompleteToolCall) { + return { kind: "incomplete" }; + } + return unknownToolName ? { kind: "unknown", toolName: unknownToolName } : { kind: "incomplete" }; } function rewriteUnknownToolLoopMessage(message: unknown, toolName: string): void { @@ -694,27 +711,41 @@ function rewriteUnknownToolLoopMessage(message: unknown, toolName: string): void function guardUnknownToolLoopInMessage( message: unknown, state: UnknownToolLoopGuardState, - params: { allowedToolNames?: Set; threshold?: number; countAttempt: boolean }, -): void { + params: { + allowedToolNames?: Set; + threshold?: number; + countAttempt: boolean; + resetOnAllowedTool?: boolean; + resetOnMissingUnknownTool?: boolean; + }, +): boolean { const threshold = params.threshold; if (threshold === undefined || threshold <= 0) { - return; + return false; } - const unknownToolName = collectUnknownToolNameFromMessage(message, params.allowedToolNames); - if (!unknownToolName) { - if (params.countAttempt) { + const toolCallState = classifyToolCallMessage(message, params.allowedToolNames); + if (toolCallState.kind === "allowed") { + if (params.resetOnAllowedTool === true) { state.lastUnknownToolName = undefined; state.count = 0; } - return; + return false; } + if (toolCallState.kind !== "unknown") { + if (params.countAttempt && params.resetOnMissingUnknownTool !== false) { + state.lastUnknownToolName = undefined; + state.count = 0; + } + return false; + } + const unknownToolName = toolCallState.toolName; if (!params.countAttempt) { if (state.lastUnknownToolName === unknownToolName && state.count > threshold) { rewriteUnknownToolLoopMessage(message, unknownToolName); } - return; + return false; } if (message && typeof message === "object") { @@ -722,7 +753,7 @@ function guardUnknownToolLoopInMessage( if (state.lastUnknownToolName === unknownToolName && state.count > threshold) { rewriteUnknownToolLoopMessage(message, unknownToolName); } - return; + return true; } state.countedMessages.add(message); } @@ -737,6 +768,7 @@ function guardUnknownToolLoopInMessage( if (state.count > threshold) { rewriteUnknownToolLoopMessage(message, unknownToolName); } + return true; } function wrapStreamTrimToolCallNames( @@ -757,6 +789,7 @@ function wrapStreamTrimToolCallNames( allowedToolNames, threshold: options?.unknownToolThreshold, countAttempt: !streamAttemptAlreadyCounted, + resetOnAllowedTool: true, }); return message; }; @@ -776,12 +809,18 @@ function wrapStreamTrimToolCallNames( trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames); trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames); if (event.message && typeof event.message === "object") { - guardUnknownToolLoopInMessage(event.message, unknownToolGuardState, { - allowedToolNames, - threshold: options?.unknownToolThreshold, - countAttempt: true, - }); - streamAttemptAlreadyCounted = true; + const countedStreamAttempt = guardUnknownToolLoopInMessage( + event.message, + unknownToolGuardState, + { + allowedToolNames, + threshold: options?.unknownToolThreshold, + countAttempt: !streamAttemptAlreadyCounted, + resetOnAllowedTool: true, + resetOnMissingUnknownTool: false, + }, + ); + streamAttemptAlreadyCounted ||= countedStreamAttempt; } guardUnknownToolLoopInMessage(event.partial, unknownToolGuardState, { allowedToolNames,