refactor: share embedded stream event wrapper

This commit is contained in:
Peter Steinberger
2026-04-18 21:27:50 +01:00
parent 5d6ee4f73e
commit 8879ed153d
3 changed files with 129 additions and 147 deletions

View File

@@ -6,6 +6,7 @@ import {
} from "../../../plugin-sdk/provider-stream-shared.js";
import { normalizeProviderId } from "../../model-selection.js";
import { log } from "../logger.js";
import { wrapStreamObjectEvents } from "./stream-wrapper.js";
function isToolCallBlockType(type: unknown): boolean {
return type === "toolCall" || type === "toolUse" || type === "functionCall";
@@ -254,111 +255,82 @@ function wrapStreamRepairMalformedToolCallArguments(
return message;
};
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
function () {
const iterator = originalAsyncIterator();
return {
async next() {
const result = await iterator.next();
if (!result.done && result.value && typeof result.value === "object") {
const event = result.value as {
type?: unknown;
contentIndex?: unknown;
delta?: unknown;
partial?: unknown;
message?: unknown;
toolCall?: unknown;
};
if (
typeof event.contentIndex === "number" &&
Number.isInteger(event.contentIndex) &&
event.type === "toolcall_delta" &&
typeof event.delta === "string"
) {
if (disabledIndices.has(event.contentIndex)) {
return result;
}
const nextPartialJson =
(partialJsonByIndex.get(event.contentIndex) ?? "") + event.delta;
if (nextPartialJson.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) {
partialJsonByIndex.delete(event.contentIndex);
repairedArgsByIndex.delete(event.contentIndex);
disabledIndices.add(event.contentIndex);
return result;
}
partialJsonByIndex.set(event.contentIndex, nextPartialJson);
const shouldReevaluateRepair =
shouldAttemptMalformedToolCallRepair(nextPartialJson, event.delta) ||
repairedArgsByIndex.has(event.contentIndex);
if (shouldReevaluateRepair) {
const hadRepairState = repairedArgsByIndex.has(event.contentIndex);
const repair = tryExtractUsableToolCallArguments(nextPartialJson);
if (repair) {
if (
!hadRepairState &&
(hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) ||
hasMeaningfulToolCallArgumentsInMessage(event.message, event.contentIndex))
) {
hadPreexistingArgsByIndex.add(event.contentIndex);
}
repairedArgsByIndex.set(event.contentIndex, repair.args);
repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repair.args);
repairToolCallArgumentsInMessage(event.message, event.contentIndex, repair.args);
if (!loggedRepairIndices.has(event.contentIndex) && repair.kind === "repaired") {
loggedRepairIndices.add(event.contentIndex);
log.warn(
`repairing Kimi tool call arguments with ${repair.leadingPrefix.length} leading chars and ${repair.trailingSuffix.length} trailing chars`,
);
}
} else {
repairedArgsByIndex.delete(event.contentIndex);
// Keep args that were already present on the streamed message, but
// clear repair-only state so stale repaired args do not get replayed.
const hadPreexistingArgs =
hadPreexistingArgsByIndex.has(event.contentIndex) ||
(!hadRepairState &&
(hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) ||
hasMeaningfulToolCallArgumentsInMessage(
event.message,
event.contentIndex,
)));
if (!hadPreexistingArgs) {
clearToolCallArgumentsInMessage(event.partial, event.contentIndex);
clearToolCallArgumentsInMessage(event.message, event.contentIndex);
}
}
}
}
if (
typeof event.contentIndex === "number" &&
Number.isInteger(event.contentIndex) &&
event.type === "toolcall_end"
) {
const repairedArgs = repairedArgsByIndex.get(event.contentIndex);
if (repairedArgs) {
if (event.toolCall && typeof event.toolCall === "object") {
(event.toolCall as { arguments?: unknown }).arguments = repairedArgs;
}
repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repairedArgs);
repairToolCallArgumentsInMessage(event.message, event.contentIndex, repairedArgs);
}
partialJsonByIndex.delete(event.contentIndex);
hadPreexistingArgsByIndex.delete(event.contentIndex);
disabledIndices.delete(event.contentIndex);
loggedRepairIndices.delete(event.contentIndex);
}
wrapStreamObjectEvents(stream, (event) => {
if (
typeof event.contentIndex === "number" &&
Number.isInteger(event.contentIndex) &&
event.type === "toolcall_delta" &&
typeof event.delta === "string"
) {
if (disabledIndices.has(event.contentIndex)) {
return;
}
const nextPartialJson = (partialJsonByIndex.get(event.contentIndex) ?? "") + event.delta;
if (nextPartialJson.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) {
partialJsonByIndex.delete(event.contentIndex);
repairedArgsByIndex.delete(event.contentIndex);
disabledIndices.add(event.contentIndex);
return;
}
partialJsonByIndex.set(event.contentIndex, nextPartialJson);
const shouldReevaluateRepair =
shouldAttemptMalformedToolCallRepair(nextPartialJson, event.delta) ||
repairedArgsByIndex.has(event.contentIndex);
if (shouldReevaluateRepair) {
const hadRepairState = repairedArgsByIndex.has(event.contentIndex);
const repair = tryExtractUsableToolCallArguments(nextPartialJson);
if (repair) {
if (
!hadRepairState &&
(hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) ||
hasMeaningfulToolCallArgumentsInMessage(event.message, event.contentIndex))
) {
hadPreexistingArgsByIndex.add(event.contentIndex);
}
return result;
},
async return(value?: unknown) {
return iterator.return?.(value) ?? { done: true as const, value: undefined };
},
async throw(error?: unknown) {
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
},
};
};
repairedArgsByIndex.set(event.contentIndex, repair.args);
repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repair.args);
repairToolCallArgumentsInMessage(event.message, event.contentIndex, repair.args);
if (!loggedRepairIndices.has(event.contentIndex) && repair.kind === "repaired") {
loggedRepairIndices.add(event.contentIndex);
log.warn(
`repairing Kimi tool call arguments with ${repair.leadingPrefix.length} leading chars and ${repair.trailingSuffix.length} trailing chars`,
);
}
} else {
repairedArgsByIndex.delete(event.contentIndex);
// Keep args that were already present on the streamed message, but
// clear repair-only state so stale repaired args do not get replayed.
const hadPreexistingArgs =
hadPreexistingArgsByIndex.has(event.contentIndex) ||
(!hadRepairState &&
(hasMeaningfulToolCallArgumentsInMessage(event.partial, event.contentIndex) ||
hasMeaningfulToolCallArgumentsInMessage(event.message, event.contentIndex)));
if (!hadPreexistingArgs) {
clearToolCallArgumentsInMessage(event.partial, event.contentIndex);
clearToolCallArgumentsInMessage(event.message, event.contentIndex);
}
}
}
}
if (
typeof event.contentIndex === "number" &&
Number.isInteger(event.contentIndex) &&
event.type === "toolcall_end"
) {
const repairedArgs = repairedArgsByIndex.get(event.contentIndex);
if (repairedArgs) {
if (event.toolCall && typeof event.toolCall === "object") {
(event.toolCall as { arguments?: unknown }).arguments = repairedArgs;
}
repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repairedArgs);
repairToolCallArgumentsInMessage(event.message, event.contentIndex, repairedArgs);
}
partialJsonByIndex.delete(event.contentIndex);
hadPreexistingArgsByIndex.delete(event.contentIndex);
disabledIndices.delete(event.contentIndex);
loggedRepairIndices.delete(event.contentIndex);
}
});
return stream;
}

View File

@@ -12,6 +12,7 @@ import { hasUnredactedSessionsSpawnAttachments } from "../../tool-call-shared.js
import { normalizeToolName } from "../../tool-policy.js";
import { shouldAllowProviderOwnedThinkingReplay } from "../../transcript-policy.js";
import type { TranscriptPolicy } from "../../transcript-policy.js";
import { wrapStreamObjectEvents } from "./stream-wrapper.js";
type UnknownToolLoopGuardState = {
lastUnknownToolName?: string;
@@ -774,50 +775,29 @@ function wrapStreamTrimToolCallNames(
return message;
};
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
function () {
const iterator = originalAsyncIterator();
return {
async next() {
const result = await iterator.next();
if (!result.done && result.value && typeof result.value === "object") {
const event = result.value as {
partial?: unknown;
message?: unknown;
};
trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames);
trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames);
if (event.message && typeof event.message === "object") {
const countedStreamAttempt = guardUnknownToolLoopInMessage(
event.message,
unknownToolGuardState,
{
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: !streamAttemptAlreadyCounted,
resetOnAllowedTool: true,
resetOnMissingUnknownTool: false,
},
);
streamAttemptAlreadyCounted ||= countedStreamAttempt;
}
guardUnknownToolLoopInMessage(event.partial, unknownToolGuardState, {
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: false,
});
}
return result;
wrapStreamObjectEvents(stream, (event) => {
trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames);
trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames);
if (event.message && typeof event.message === "object") {
const countedStreamAttempt = guardUnknownToolLoopInMessage(
event.message,
unknownToolGuardState,
{
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: !streamAttemptAlreadyCounted,
resetOnAllowedTool: true,
resetOnMissingUnknownTool: false,
},
async return(value?: unknown) {
return iterator.return?.(value) ?? { done: true as const, value: undefined };
},
async throw(error?: unknown) {
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
},
};
};
);
streamAttemptAlreadyCounted ||= countedStreamAttempt;
}
guardUnknownToolLoopInMessage(event.partial, unknownToolGuardState, {
allowedToolNames,
threshold: options?.unknownToolThreshold,
countAttempt: false,
});
});
return stream;
}

View File

@@ -0,0 +1,30 @@
import { streamSimple } from "@mariozechner/pi-ai";
type SimpleStream = ReturnType<typeof streamSimple>;
export function wrapStreamObjectEvents(
stream: SimpleStream,
onEvent: (event: Record<string, unknown>) => void | Promise<void>,
): SimpleStream {
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
function () {
const iterator = originalAsyncIterator();
return {
async next() {
const result = await iterator.next();
if (!result.done && result.value && typeof result.value === "object") {
await onEvent(result.value as Record<string, unknown>);
}
return result;
},
async return(value?: unknown) {
return iterator.return?.(value) ?? { done: true as const, value: undefined };
},
async throw(error?: unknown) {
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
},
};
};
return stream;
}