mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-06 05:30:42 +00:00
refactor: share stream message wrapper
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { streamWithPayloadPatch } from "openclaw/plugin-sdk/provider-stream-shared";
|
||||
import {
|
||||
streamWithPayloadPatch,
|
||||
wrapStreamMessageObjects,
|
||||
} from "openclaw/plugin-sdk/provider-stream-shared";
|
||||
import { normalizeOptionalLowercaseString } from "openclaw/plugin-sdk/text-runtime";
|
||||
|
||||
const TOOL_CALLS_SECTION_BEGIN = "<|tool_calls_section_begin|>";
|
||||
@@ -173,53 +176,16 @@ function rewriteKimiTaggedToolCallsInMessage(message: unknown): void {
|
||||
}
|
||||
}
|
||||
|
||||
function wrapKimiTaggedToolCalls(
|
||||
stream: ReturnType<typeof streamSimple>,
|
||||
): ReturnType<typeof streamSimple> {
|
||||
const originalResult = stream.result.bind(stream);
|
||||
stream.result = async () => {
|
||||
const message = await originalResult();
|
||||
rewriteKimiTaggedToolCallsInMessage(message);
|
||||
return message;
|
||||
};
|
||||
|
||||
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
|
||||
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
|
||||
function () {
|
||||
const iterator = originalAsyncIterator();
|
||||
return {
|
||||
async next() {
|
||||
const result = await iterator.next();
|
||||
if (!result.done && result.value && typeof result.value === "object") {
|
||||
const event = result.value as {
|
||||
partial?: unknown;
|
||||
message?: unknown;
|
||||
};
|
||||
rewriteKimiTaggedToolCallsInMessage(event.partial);
|
||||
rewriteKimiTaggedToolCallsInMessage(event.message);
|
||||
}
|
||||
return result;
|
||||
},
|
||||
async return(value?: unknown) {
|
||||
return iterator.return?.(value) ?? { done: true as const, value: undefined };
|
||||
},
|
||||
async throw(error?: unknown) {
|
||||
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
export function createKimiToolCallMarkupWrapper(baseStreamFn: StreamFn | undefined): StreamFn {
|
||||
const underlying = baseStreamFn ?? streamSimple;
|
||||
return (model, context, options) => {
|
||||
const maybeStream = underlying(model, context, options);
|
||||
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
|
||||
return Promise.resolve(maybeStream).then((stream) => wrapKimiTaggedToolCalls(stream));
|
||||
return Promise.resolve(maybeStream).then((stream) =>
|
||||
wrapStreamMessageObjects(stream, rewriteKimiTaggedToolCallsInMessage),
|
||||
);
|
||||
}
|
||||
return wrapKimiTaggedToolCalls(maybeStream);
|
||||
return wrapStreamMessageObjects(maybeStream, rewriteKimiTaggedToolCallsInMessage);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -76,13 +76,14 @@ function decodeToolCallArgumentsHtmlEntitiesInMessage(message: unknown): void {
|
||||
});
|
||||
}
|
||||
|
||||
function wrapStreamDecodeToolCallArgumentHtmlEntities(
|
||||
export function wrapStreamMessageObjects(
|
||||
stream: ReturnType<typeof streamSimple>,
|
||||
transformMessage: (message: unknown) => void,
|
||||
): ReturnType<typeof streamSimple> {
|
||||
const originalResult = stream.result.bind(stream);
|
||||
stream.result = async () => {
|
||||
const message = await originalResult();
|
||||
decodeToolCallArgumentsHtmlEntitiesInMessage(message);
|
||||
transformMessage(message);
|
||||
return message;
|
||||
};
|
||||
|
||||
@@ -95,8 +96,8 @@ function wrapStreamDecodeToolCallArgumentHtmlEntities(
|
||||
const result = await iterator.next();
|
||||
if (!result.done && result.value && typeof result.value === "object") {
|
||||
const event = result.value as { partial?: unknown; message?: unknown };
|
||||
decodeToolCallArgumentsHtmlEntitiesInMessage(event.partial);
|
||||
decodeToolCallArgumentsHtmlEntitiesInMessage(event.message);
|
||||
transformMessage(event.partial);
|
||||
transformMessage(event.message);
|
||||
}
|
||||
return result;
|
||||
},
|
||||
@@ -119,10 +120,10 @@ export function createHtmlEntityToolCallArgumentDecodingWrapper(
|
||||
const maybeStream = underlying(model, context, options);
|
||||
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
|
||||
return Promise.resolve(maybeStream).then((stream) =>
|
||||
wrapStreamDecodeToolCallArgumentHtmlEntities(stream),
|
||||
wrapStreamMessageObjects(stream, decodeToolCallArgumentsHtmlEntitiesInMessage),
|
||||
);
|
||||
}
|
||||
return wrapStreamDecodeToolCallArgumentHtmlEntities(maybeStream);
|
||||
return wrapStreamMessageObjects(maybeStream, decodeToolCallArgumentsHtmlEntitiesInMessage);
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user