diff --git a/extensions/kimi-coding/stream.ts b/extensions/kimi-coding/stream.ts index b4c2c686971..a1f495d45e3 100644 --- a/extensions/kimi-coding/stream.ts +++ b/extensions/kimi-coding/stream.ts @@ -1,7 +1,10 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import { streamSimple } from "@mariozechner/pi-ai"; import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry"; -import { streamWithPayloadPatch } from "openclaw/plugin-sdk/provider-stream-shared"; +import { + streamWithPayloadPatch, + wrapStreamMessageObjects, +} from "openclaw/plugin-sdk/provider-stream-shared"; import { normalizeOptionalLowercaseString } from "openclaw/plugin-sdk/text-runtime"; const TOOL_CALLS_SECTION_BEGIN = "<|tool_calls_section_begin|>"; @@ -173,53 +176,16 @@ function rewriteKimiTaggedToolCallsInMessage(message: unknown): void { } } -function wrapKimiTaggedToolCalls( - stream: ReturnType, -): ReturnType { - const originalResult = stream.result.bind(stream); - stream.result = async () => { - const message = await originalResult(); - rewriteKimiTaggedToolCallsInMessage(message); - return message; - }; - - const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream); - (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] = - function () { - const iterator = originalAsyncIterator(); - return { - async next() { - const result = await iterator.next(); - if (!result.done && result.value && typeof result.value === "object") { - const event = result.value as { - partial?: unknown; - message?: unknown; - }; - rewriteKimiTaggedToolCallsInMessage(event.partial); - rewriteKimiTaggedToolCallsInMessage(event.message); - } - return result; - }, - async return(value?: unknown) { - return iterator.return?.(value) ?? { done: true as const, value: undefined }; - }, - async throw(error?: unknown) { - return iterator.throw?.(error) ?? { done: true as const, value: undefined }; - }, - }; - }; - - return stream; -} - export function createKimiToolCallMarkupWrapper(baseStreamFn: StreamFn | undefined): StreamFn { const underlying = baseStreamFn ?? streamSimple; return (model, context, options) => { const maybeStream = underlying(model, context, options); if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) { - return Promise.resolve(maybeStream).then((stream) => wrapKimiTaggedToolCalls(stream)); + return Promise.resolve(maybeStream).then((stream) => + wrapStreamMessageObjects(stream, rewriteKimiTaggedToolCallsInMessage), + ); } - return wrapKimiTaggedToolCalls(maybeStream); + return wrapStreamMessageObjects(maybeStream, rewriteKimiTaggedToolCallsInMessage); }; } diff --git a/src/plugin-sdk/provider-stream-shared.ts b/src/plugin-sdk/provider-stream-shared.ts index b8e538142d4..8feafccdc9c 100644 --- a/src/plugin-sdk/provider-stream-shared.ts +++ b/src/plugin-sdk/provider-stream-shared.ts @@ -76,13 +76,14 @@ function decodeToolCallArgumentsHtmlEntitiesInMessage(message: unknown): void { }); } -function wrapStreamDecodeToolCallArgumentHtmlEntities( +export function wrapStreamMessageObjects( stream: ReturnType, + transformMessage: (message: unknown) => void, ): ReturnType { const originalResult = stream.result.bind(stream); stream.result = async () => { const message = await originalResult(); - decodeToolCallArgumentsHtmlEntitiesInMessage(message); + transformMessage(message); return message; }; @@ -95,8 +96,8 @@ function wrapStreamDecodeToolCallArgumentHtmlEntities( const result = await iterator.next(); if (!result.done && result.value && typeof result.value === "object") { const event = result.value as { partial?: unknown; message?: unknown }; - decodeToolCallArgumentsHtmlEntitiesInMessage(event.partial); - decodeToolCallArgumentsHtmlEntitiesInMessage(event.message); + transformMessage(event.partial); + transformMessage(event.message); } return result; }, @@ -119,10 +120,10 @@ export function createHtmlEntityToolCallArgumentDecodingWrapper( const maybeStream = underlying(model, context, options); if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) { return Promise.resolve(maybeStream).then((stream) => - wrapStreamDecodeToolCallArgumentHtmlEntities(stream), + wrapStreamMessageObjects(stream, decodeToolCallArgumentsHtmlEntitiesInMessage), ); } - return wrapStreamDecodeToolCallArgumentHtmlEntities(maybeStream); + return wrapStreamMessageObjects(maybeStream, decodeToolCallArgumentsHtmlEntitiesInMessage); }; }