From 8879ed153d27c5faa0d9683478a0c453d3e1876a Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 18 Apr 2026 21:27:50 +0100 Subject: [PATCH] refactor: share embedded stream event wrapper --- .../run/attempt.tool-call-argument-repair.ts | 180 ++++++++---------- .../run/attempt.tool-call-normalization.ts | 66 +++---- .../pi-embedded-runner/run/stream-wrapper.ts | 30 +++ 3 files changed, 129 insertions(+), 147 deletions(-) create mode 100644 src/agents/pi-embedded-runner/run/stream-wrapper.ts diff --git a/src/agents/pi-embedded-runner/run/attempt.tool-call-argument-repair.ts b/src/agents/pi-embedded-runner/run/attempt.tool-call-argument-repair.ts index fd88b2e09ff..d74f6ac333a 100644 --- a/src/agents/pi-embedded-runner/run/attempt.tool-call-argument-repair.ts +++ b/src/agents/pi-embedded-runner/run/attempt.tool-call-argument-repair.ts @@ -6,6 +6,7 @@ import { } from "../../../plugin-sdk/provider-stream-shared.js"; import { normalizeProviderId } from "../../model-selection.js"; import { log } from "../logger.js"; +import { wrapStreamObjectEvents } from "./stream-wrapper.js"; function isToolCallBlockType(type: unknown): boolean { return type === "toolCall" || type === "toolUse" || type === "functionCall"; @@ -254,111 +255,82 @@ function wrapStreamRepairMalformedToolCallArguments( return message; }; - const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream); - (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] = - function () { - const iterator = originalAsyncIterator(); - return { - async next() { - const result = await iterator.next(); - if (!result.done && result.value && typeof result.value === "object") { - const event = result.value as { - type?: unknown; - contentIndex?: unknown; - delta?: unknown; - partial?: unknown; - message?: unknown; - toolCall?: unknown; - }; - if ( - typeof event.contentIndex === "number" && - Number.isInteger(event.contentIndex) && - event.type === "toolcall_delta" && - typeof event.delta === "string" - ) { - if (disabledIndices.has(event.contentIndex)) { - return result; - } - const nextPartialJson = - (partialJsonByIndex.get(event.contentIndex) ?? "") + event.delta; - if (nextPartialJson.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) { - partialJsonByIndex.delete(event.contentIndex); - repairedArgsByIndex.delete(event.contentIndex); - disabledIndices.add(event.contentIndex); - return result; - } - partialJsonByIndex.set(event.contentIndex, nextPartialJson); - const shouldReevaluateRepair = - shouldAttemptMalformedToolCallRepair(nextPartialJson, event.delta) || - repairedArgsByIndex.has(event.contentIndex); - if (shouldReevaluateRepair) { - const hadRepairState = repairedArgsByIndex.has(event.contentIndex); - const repair = tryExtractUsableToolCallArguments(nextPartialJson); - if (repair) { - if ( - !hadRepairState && - (hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) || - hasMeaningfulToolCallArgumentsInMessage(event.message, event.contentIndex)) - ) { - hadPreexistingArgsByIndex.add(event.contentIndex); - } - repairedArgsByIndex.set(event.contentIndex, repair.args); - repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repair.args); - repairToolCallArgumentsInMessage(event.message, event.contentIndex, repair.args); - if (!loggedRepairIndices.has(event.contentIndex) && repair.kind === "repaired") { - loggedRepairIndices.add(event.contentIndex); - log.warn( - `repairing Kimi tool call arguments with ${repair.leadingPrefix.length} leading chars and ${repair.trailingSuffix.length} trailing chars`, - ); - } - } else { - repairedArgsByIndex.delete(event.contentIndex); - // Keep args that were already present on the streamed message, but - // clear repair-only state so stale repaired args do not get replayed. - const hadPreexistingArgs = - hadPreexistingArgsByIndex.has(event.contentIndex) || - (!hadRepairState && - (hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) || - hasMeaningfulToolCallArgumentsInMessage( - event.message, - event.contentIndex, - ))); - if (!hadPreexistingArgs) { - clearToolCallArgumentsInMessage(event.partial, event.contentIndex); - clearToolCallArgumentsInMessage(event.message, event.contentIndex); - } - } - } - } - if ( - typeof event.contentIndex === "number" && - Number.isInteger(event.contentIndex) && - event.type === "toolcall_end" - ) { - const repairedArgs = repairedArgsByIndex.get(event.contentIndex); - if (repairedArgs) { - if (event.toolCall && typeof event.toolCall === "object") { - (event.toolCall as { arguments?: unknown }).arguments = repairedArgs; - } - repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repairedArgs); - repairToolCallArgumentsInMessage(event.message, event.contentIndex, repairedArgs); - } - partialJsonByIndex.delete(event.contentIndex); - hadPreexistingArgsByIndex.delete(event.contentIndex); - disabledIndices.delete(event.contentIndex); - loggedRepairIndices.delete(event.contentIndex); - } + wrapStreamObjectEvents(stream, (event) => { + if ( + typeof event.contentIndex === "number" && + Number.isInteger(event.contentIndex) && + event.type === "toolcall_delta" && + typeof event.delta === "string" + ) { + if (disabledIndices.has(event.contentIndex)) { + return; + } + const nextPartialJson = (partialJsonByIndex.get(event.contentIndex) ?? "") + event.delta; + if (nextPartialJson.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) { + partialJsonByIndex.delete(event.contentIndex); + repairedArgsByIndex.delete(event.contentIndex); + disabledIndices.add(event.contentIndex); + return; + } + partialJsonByIndex.set(event.contentIndex, nextPartialJson); + const shouldReevaluateRepair = + shouldAttemptMalformedToolCallRepair(nextPartialJson, event.delta) || + repairedArgsByIndex.has(event.contentIndex); + if (shouldReevaluateRepair) { + const hadRepairState = repairedArgsByIndex.has(event.contentIndex); + const repair = tryExtractUsableToolCallArguments(nextPartialJson); + if (repair) { + if ( + !hadRepairState && + (hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) || + hasMeaningfulToolCallArgumentsInMessage(event.message, event.contentIndex)) + ) { + hadPreexistingArgsByIndex.add(event.contentIndex); } - return result; - }, - async return(value?: unknown) { - return iterator.return?.(value) ?? { done: true as const, value: undefined }; - }, - async throw(error?: unknown) { - return iterator.throw?.(error) ?? { done: true as const, value: undefined }; - }, - }; - }; + repairedArgsByIndex.set(event.contentIndex, repair.args); + repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repair.args); + repairToolCallArgumentsInMessage(event.message, event.contentIndex, repair.args); + if (!loggedRepairIndices.has(event.contentIndex) && repair.kind === "repaired") { + loggedRepairIndices.add(event.contentIndex); + log.warn( + `repairing Kimi tool call arguments with ${repair.leadingPrefix.length} leading chars and ${repair.trailingSuffix.length} trailing chars`, + ); + } + } else { + repairedArgsByIndex.delete(event.contentIndex); + // Keep args that were already present on the streamed message, but + // clear repair-only state so stale repaired args do not get replayed. + const hadPreexistingArgs = + hadPreexistingArgsByIndex.has(event.contentIndex) || + (!hadRepairState && + (hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) || + hasMeaningfulToolCallArgumentsInMessage(event.message, event.contentIndex))); + if (!hadPreexistingArgs) { + clearToolCallArgumentsInMessage(event.partial, event.contentIndex); + clearToolCallArgumentsInMessage(event.message, event.contentIndex); + } + } + } + } + if ( + typeof event.contentIndex === "number" && + Number.isInteger(event.contentIndex) && + event.type === "toolcall_end" + ) { + const repairedArgs = repairedArgsByIndex.get(event.contentIndex); + if (repairedArgs) { + if (event.toolCall && typeof event.toolCall === "object") { + (event.toolCall as { arguments?: unknown }).arguments = repairedArgs; + } + repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repairedArgs); + repairToolCallArgumentsInMessage(event.message, event.contentIndex, repairedArgs); + } + partialJsonByIndex.delete(event.contentIndex); + hadPreexistingArgsByIndex.delete(event.contentIndex); + disabledIndices.delete(event.contentIndex); + loggedRepairIndices.delete(event.contentIndex); + } + }); return stream; } diff --git a/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts b/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts index 544973df384..953f68a99da 100644 --- a/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts +++ b/src/agents/pi-embedded-runner/run/attempt.tool-call-normalization.ts @@ -12,6 +12,7 @@ import { hasUnredactedSessionsSpawnAttachments } from "../../tool-call-shared.js import { normalizeToolName } from "../../tool-policy.js"; import { shouldAllowProviderOwnedThinkingReplay } from "../../transcript-policy.js"; import type { TranscriptPolicy } from "../../transcript-policy.js"; +import { wrapStreamObjectEvents } from "./stream-wrapper.js"; type UnknownToolLoopGuardState = { lastUnknownToolName?: string; @@ -774,50 +775,29 @@ function wrapStreamTrimToolCallNames( return message; }; - const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream); - (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] = - function () { - const iterator = originalAsyncIterator(); - return { - async next() { - const result = await iterator.next(); - if (!result.done && result.value && typeof result.value === "object") { - const event = result.value as { - partial?: unknown; - message?: unknown; - }; - trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames); - trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames); - if (event.message && typeof event.message === "object") { - const countedStreamAttempt = guardUnknownToolLoopInMessage( - event.message, - unknownToolGuardState, - { - allowedToolNames, - threshold: options?.unknownToolThreshold, - countAttempt: !streamAttemptAlreadyCounted, - resetOnAllowedTool: true, - resetOnMissingUnknownTool: false, - }, - ); - streamAttemptAlreadyCounted ||= countedStreamAttempt; - } - guardUnknownToolLoopInMessage(event.partial, unknownToolGuardState, { - allowedToolNames, - threshold: options?.unknownToolThreshold, - countAttempt: false, - }); - } - return result; + wrapStreamObjectEvents(stream, (event) => { + trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames); + trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames); + if (event.message && typeof event.message === "object") { + const countedStreamAttempt = guardUnknownToolLoopInMessage( + event.message, + unknownToolGuardState, + { + allowedToolNames, + threshold: options?.unknownToolThreshold, + countAttempt: !streamAttemptAlreadyCounted, + resetOnAllowedTool: true, + resetOnMissingUnknownTool: false, }, - async return(value?: unknown) { - return iterator.return?.(value) ?? { done: true as const, value: undefined }; - }, - async throw(error?: unknown) { - return iterator.throw?.(error) ?? { done: true as const, value: undefined }; - }, - }; - }; + ); + streamAttemptAlreadyCounted ||= countedStreamAttempt; + } + guardUnknownToolLoopInMessage(event.partial, unknownToolGuardState, { + allowedToolNames, + threshold: options?.unknownToolThreshold, + countAttempt: false, + }); + }); return stream; } diff --git a/src/agents/pi-embedded-runner/run/stream-wrapper.ts b/src/agents/pi-embedded-runner/run/stream-wrapper.ts new file mode 100644 index 00000000000..6af05a4eeff --- /dev/null +++ b/src/agents/pi-embedded-runner/run/stream-wrapper.ts @@ -0,0 +1,30 @@ +import { streamSimple } from "@mariozechner/pi-ai"; + +type SimpleStream = ReturnType; + +export function wrapStreamObjectEvents( + stream: SimpleStream, + onEvent: (event: Record) => void | Promise, +): SimpleStream { + const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream); + (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] = + function () { + const iterator = originalAsyncIterator(); + return { + async next() { + const result = await iterator.next(); + if (!result.done && result.value && typeof result.value === "object") { + await onEvent(result.value as Record); + } + return result; + }, + async return(value?: unknown) { + return iterator.return?.(value) ?? { done: true as const, value: undefined }; + }, + async throw(error?: unknown) { + return iterator.throw?.(error) ?? { done: true as const, value: undefined }; + }, + }; + }; + return stream; +}