diff --git a/packages/tool-call-repair/package.json b/packages/tool-call-repair/package.json new file mode 100644 index 00000000000..011f4b51b4f --- /dev/null +++ b/packages/tool-call-repair/package.json @@ -0,0 +1,9 @@ +{ + "name": "@openclaw/tool-call-repair", + "version": "0.0.0-private", + "private": true, + "type": "module", + "exports": { + ".": "./src/index.ts" + } +} diff --git a/packages/tool-call-repair/src/grammar.ts b/packages/tool-call-repair/src/grammar.ts new file mode 100644 index 00000000000..37b177bed4e --- /dev/null +++ b/packages/tool-call-repair/src/grammar.ts @@ -0,0 +1,218 @@ +export const END_TOOL_REQUEST = "[END_TOOL_REQUEST]"; +export const HARMONY_CHANNEL_MARKER = "<|channel|>"; +export const HARMONY_MESSAGE_MARKER = "<|message|>"; +export const HARMONY_CALL_MARKER = "<|call|>"; + +export function matchesLiteralPrefix(text: string, literal: string): boolean { + return literal.startsWith(text) || text.startsWith(literal); +} + +export function isPlainTextToolNameChar(char: string | undefined): boolean { + return Boolean(char && /[A-Za-z0-9_-]/.test(char)); +} + +export function isXmlishNameChar(char: string | undefined): boolean { + return Boolean(char && /[A-Za-z0-9_.:-]/.test(char)); +} + +export function skipHorizontalWhitespace(text: string, start: number): number { + let index = start; + while (index < text.length && (text[index] === " " || text[index] === "\t")) { + index += 1; + } + return index; +} + +export function skipWhitespace(text: string, start: number): number { + let index = start; + while (index < text.length && /\s/.test(text[index] ?? "")) { + index += 1; + } + return index; +} + +export function consumeLineBreak(text: string, start: number): number | null { + if (text[start] === "\r") { + return text[start + 1] === "\n" ? start + 2 : start + 1; + } + if (text[start] === "\n") { + return start + 1; + } + return null; +} + +export function findJsonObjectEnd( + text: string, + start: number, + maxPayloadBytes?: number, +): number | null { + let depth = 0; + let inString = false; + let escaped = false; + for (let index = start; index < text.length; index += 1) { + if (maxPayloadBytes !== undefined && index + 1 - start > maxPayloadBytes) { + return null; + } + const char = text[index]; + if (inString) { + if (escaped) { + escaped = false; + } else if (char === "\\") { + escaped = true; + } else if (char === '"') { + inString = false; + } + continue; + } + if (char === '"') { + inString = true; + continue; + } + if (char === "{") { + depth += 1; + continue; + } + if (char === "}") { + depth -= 1; + if (depth === 0) { + return index + 1; + } + } + } + return null; +} + +export function skipSerializedToolCallTrailingLineBreak(text: string, cursor: number): number { + const afterLineBreak = consumeLineBreak(text, cursor); + return afterLineBreak ?? cursor; +} + +export function consumeJsonToolClosingMarker(text: string, cursor: number): number { + let markerStart = cursor; + while (markerStart < text.length && /\s/.test(text[markerStart] ?? "")) { + markerStart += 1; + } + const rest = text.slice(markerStart); + if (rest.startsWith(END_TOOL_REQUEST)) { + return skipSerializedToolCallTrailingLineBreak(text, markerStart + END_TOOL_REQUEST.length); + } + const bracketClose = /^\[\/[A-Za-z0-9_-]+\]/.exec(rest); + if (bracketClose) { + return skipSerializedToolCallTrailingLineBreak(text, markerStart + bracketClose[0].length); + } + if (rest.startsWith(HARMONY_CALL_MARKER)) { + return skipSerializedToolCallTrailingLineBreak(text, markerStart + HARMONY_CALL_MARKER.length); + } + return skipSerializedToolCallTrailingLineBreak(text, cursor); +} + +export function findBracketedJsonPayloadStart(text: string): number | null { + if (!text.startsWith("[")) { + return null; + } + const close = text.indexOf("]"); + if (close === -1) { + return null; + } + let cursor = close + 1; + cursor = skipHorizontalWhitespace(text, cursor); + cursor = skipSerializedToolCallTrailingLineBreak(text, cursor); + cursor = skipHorizontalWhitespace(text, cursor); + return text[cursor] === "{" ? cursor : null; +} + +export function findHarmonyJsonPayloadStart(text: string): number | null { + let cursor = 0; + if (text.startsWith(HARMONY_CHANNEL_MARKER)) { + cursor = HARMONY_CHANNEL_MARKER.length; + } + const rest = text.slice(cursor); + const channel = ["commentary", "analysis", "final"].find((candidate) => + rest.startsWith(candidate), + ); + if (!channel) { + return null; + } + cursor += channel.length; + cursor = skipHorizontalWhitespace(text, cursor); + if (!text.slice(cursor).startsWith("to=")) { + return null; + } + cursor += "to=".length; + const nameStart = cursor; + while (isPlainTextToolNameChar(text[cursor])) { + cursor += 1; + } + if (cursor === nameStart) { + return null; + } + cursor = skipHorizontalWhitespace(text, cursor); + if (!text.slice(cursor).startsWith("code")) { + return null; + } + cursor += "code".length; + cursor = skipWhitespace(text, cursor); + if (text.slice(cursor).startsWith(HARMONY_MESSAGE_MARKER)) { + cursor = skipWhitespace(text, cursor + HARMONY_MESSAGE_MARKER.length); + } + return text[cursor] === "{" ? cursor : null; +} + +export function startsWithAsciiMarkerIgnoreCase( + text: string, + cursor: number, + marker: string, +): boolean { + return text.slice(cursor, cursor + marker.length).toLowerCase() === marker; +} + +export function indexOfAsciiMarkerIgnoreCase(text: string, marker: string, start: number): number { + let cursor = start; + while (cursor < text.length) { + const next = text.indexOf(marker[0] ?? "", cursor); + if (next === -1) { + return -1; + } + if (startsWithAsciiMarkerIgnoreCase(text, next, marker)) { + return next; + } + cursor = next + 1; + } + return -1; +} + +export function findXmlishToolCallEnd(text: string): number | null { + let cursor: number; + const xmlFunction = /^/i.exec(text); + if (xmlFunction) { + cursor = xmlFunction[0].length; + } else { + const bracketed = /^\[(?:tool:)?[A-Za-z0-9_-]+\]/.exec(text); + if (!bracketed) { + return null; + } + cursor = bracketed[0].length; + cursor = skipHorizontalWhitespace(text, cursor); + cursor = skipSerializedToolCallTrailingLineBreak(text, cursor); + } + + cursor = skipWhitespace(text, cursor); + if (!startsWithAsciiMarkerIgnoreCase(text, cursor, "", cursor); + if (parameterClose === -1) { + return null; + } + cursor = skipWhitespace(text, parameterClose + "".length); + if (startsWithAsciiMarkerIgnoreCase(text, cursor, "")) { + return skipSerializedToolCallTrailingLineBreak(text, cursor + "".length); + } + if (!startsWithAsciiMarkerIgnoreCase(text, cursor, "; + end: number; + name: string; + raw: string; + start: number; +}; + +export type PlainTextToolCallParseOptions = { + allowedToolNames?: Iterable; + maxPayloadBytes?: number; +}; + +const DEFAULT_MAX_PLAIN_TEXT_TOOL_PAYLOAD_BYTES = 256_000; + +type PlainTextToolCallOpening = { + allowsOptionalXmlishClose?: boolean; + end: number; + name: string; + requiresClosing: boolean; +}; + +function parseBracketOpening(text: string, start: number): PlainTextToolCallOpening | null { + if (text[start] !== "[") { + return null; + } + let cursor = start + 1; + if (text.startsWith("tool:", cursor)) { + cursor += "tool:".length; + const nameStart = cursor; + while (isPlainTextToolNameChar(text[cursor])) { + cursor += 1; + } + if (cursor === nameStart || text[cursor] !== "]") { + return null; + } + return { + allowsOptionalXmlishClose: true, + end: cursor + 1, + name: text.slice(nameStart, cursor), + requiresClosing: false, + }; + } + const nameStart = cursor; + while (isPlainTextToolNameChar(text[cursor])) { + cursor += 1; + } + if (cursor === nameStart || text[cursor] !== "]") { + return null; + } + const name = text.slice(nameStart, cursor); + cursor += 1; + cursor = skipHorizontalWhitespace(text, cursor); + const afterLineBreak = consumeLineBreak(text, cursor); + if (afterLineBreak === null) { + return null; + } + return { end: afterLineBreak, name, requiresClosing: true }; +} + +function parseHarmonyOpening(text: string, start: number): PlainTextToolCallOpening | null { + let cursor = start; + if (text.startsWith(HARMONY_CHANNEL_MARKER, cursor)) { + cursor += HARMONY_CHANNEL_MARKER.length; + } + const channelStart = cursor; + while (/[A-Za-z_]/.test(text[cursor] ?? "")) { + cursor += 1; + } + const channel = text.slice(channelStart, cursor); + if (channel !== "commentary" && channel !== "analysis" && channel !== "final") { + return null; + } + cursor = skipHorizontalWhitespace(text, cursor); + if (!text.startsWith("to=", cursor)) { + return null; + } + cursor += 3; + const nameStart = cursor; + while (isPlainTextToolNameChar(text[cursor])) { + cursor += 1; + } + if (cursor === nameStart) { + return null; + } + const name = text.slice(nameStart, cursor); + cursor = skipHorizontalWhitespace(text, cursor); + if (!text.startsWith("code", cursor)) { + return null; + } + cursor += 4; + cursor = skipWhitespace(text, cursor); + if (text.startsWith(HARMONY_MESSAGE_MARKER, cursor)) { + cursor = skipWhitespace(text, cursor + HARMONY_MESSAGE_MARKER.length); + } + return { end: cursor, name, requiresClosing: false }; +} + +function parseXmlishFunctionOpening(text: string, start: number): PlainTextToolCallOpening | null { + const match = /^\s*/i.exec(text.slice(start)); + if (!match?.[1]) { + return null; + } + return { end: start + match[0].length, name: match[1], requiresClosing: false }; +} + +function parseOpening(text: string, start: number): PlainTextToolCallOpening | null { + return parseBracketOpening(text, start) ?? parseHarmonyOpening(text, start); +} + +function consumeJsonObject( + text: string, + start: number, + maxPayloadBytes: number, +): { end: number; value: Record } | null { + const cursor = skipWhitespace(text, start); + if (text[cursor] !== "{") { + return null; + } + const end = findJsonObjectEnd(text, cursor, maxPayloadBytes); + if (end === null) { + return null; + } + const rawJson = text.slice(cursor, end); + try { + const parsed = JSON.parse(rawJson) as unknown; + if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) { + return null; + } + return { end, value: parsed as Record }; + } catch { + return null; + } +} + +function parseClosing(text: string, start: number, name: string): number | null { + const cursor = skipWhitespace(text, start); + if (text.startsWith(END_TOOL_REQUEST, cursor)) { + return cursor + END_TOOL_REQUEST.length; + } + const namedClosing = `[/${name}]`; + if (text.startsWith(namedClosing, cursor)) { + return cursor + namedClosing.length; + } + return null; +} + +function parseOptionalHarmonyClosing(text: string, start: number): number { + const cursor = skipWhitespace(text, start); + if (text.startsWith(HARMONY_CALL_MARKER, cursor)) { + return cursor + HARMONY_CALL_MARKER.length; + } + return start; +} + +function parsePlainTextToolCallBlockAt( + text: string, + start: number, + options?: PlainTextToolCallParseOptions, +): PlainTextToolCallBlock | null { + const opening = parseOpening(text, start); + if (!opening) { + return null; + } + const allowedToolNames = options?.allowedToolNames + ? new Set(options.allowedToolNames) + : undefined; + if (allowedToolNames && !allowedToolNames.has(opening.name)) { + return null; + } + const payload = consumeJsonObject( + text, + opening.end, + options?.maxPayloadBytes ?? DEFAULT_MAX_PLAIN_TEXT_TOOL_PAYLOAD_BYTES, + ); + if (!payload) { + return null; + } + const closingEnd = opening.requiresClosing + ? parseClosing(text, payload.end, opening.name) + : parseOptionalHarmonyClosing(text, payload.end); + if (closingEnd === null) { + return null; + } + return { + arguments: payload.value, + end: closingEnd, + name: opening.name, + raw: text.slice(start, closingEnd), + start, + }; +} + +type XmlishParameterBlockBounds = { + closeStart: number; + end: number; + name: string; + payloadStart: number; + start: number; +}; + +function findXmlishParameterBlock(text: string, start: number): XmlishParameterBlockBounds | null { + const cursor = skipWhitespace(text, start); + const openMatch = /^/i.exec(text.slice(cursor)); + if (!openMatch?.[1]) { + return null; + } + const payloadStart = cursor + openMatch[0].length; + const closeMatch = /<\/parameter>/i.exec(text.slice(payloadStart)); + if (!closeMatch) { + return null; + } + const closeStart = payloadStart + closeMatch.index; + const closeEnd = closeStart + closeMatch[0].length; + return { + closeStart, + end: closeEnd, + name: openMatch[1], + payloadStart, + start: cursor, + }; +} + +function consumeXmlishParameterBlock( + text: string, + start: number, + maxPayloadBytes: number, +): { end: number; name: string; value: string } | null { + const bounds = findXmlishParameterBlock(text, start); + if (!bounds) { + return null; + } + if (bounds.end - bounds.start > maxPayloadBytes) { + return null; + } + return { + end: bounds.end, + name: bounds.name, + value: extractXmlishParameterValue(text, bounds.payloadStart, bounds.closeStart), + }; +} + +function extractXmlishParameterValue(text: string, start: number, end: number): string { + let payloadStart = start; + let payloadEnd = end; + const afterOpeningLineBreak = consumeLineBreak(text, payloadStart); + if (afterOpeningLineBreak !== null) { + payloadStart = afterOpeningLineBreak; + if (payloadEnd > payloadStart && text[payloadEnd - 1] === "\n") { + payloadEnd -= 1; + if (payloadEnd > payloadStart && text[payloadEnd - 1] === "\r") { + payloadEnd -= 1; + } + } else if (payloadEnd > payloadStart && text[payloadEnd - 1] === "\r") { + payloadEnd -= 1; + } + } + return text.slice(payloadStart, payloadEnd); +} + +function consumeXmlishFunctionClose(text: string, start: number): number | null { + const cursor = skipWhitespace(text, start); + return text.slice(cursor).toLowerCase().startsWith("") + ? cursor + "".length + : null; +} + +function consumeOptionalXmlishFunctionClose(text: string, start: number): number { + return consumeXmlishFunctionClose(text, start) ?? start; +} + +function parseXmlishPlainTextToolCallBlockEndAt( + text: string, + start: number, + options?: PlainTextToolCallParseOptions, +): number | null { + const opening = parseXmlishOpening(text, start); + if (!opening) { + return null; + } + const allowedToolNames = options?.allowedToolNames + ? new Set(options.allowedToolNames) + : undefined; + if (allowedToolNames && !allowedToolNames.has(opening.name)) { + return null; + } + + let cursor = opening.end; + let parameterCount = 0; + while (true) { + const parameter = findXmlishParameterBlock(text, cursor); + if (!parameter) { + break; + } + parameterCount += 1; + cursor = parameter.end; + } + if (parameterCount === 0) { + return null; + } + return opening.allowsOptionalXmlishClose + ? consumeOptionalXmlishFunctionClose(text, cursor) + : consumeXmlishFunctionClose(text, cursor); +} + +function parseXmlishOpening(text: string, start: number): PlainTextToolCallOpening | null { + return parseBracketOpening(text, start) ?? parseXmlishFunctionOpening(text, start); +} + +function parseXmlishPlainTextToolCallBlockAt( + text: string, + start: number, + options?: PlainTextToolCallParseOptions, +): PlainTextToolCallBlock | null { + const opening = parseXmlishOpening(text, start); + if (!opening) { + return null; + } + const allowedToolNames = options?.allowedToolNames + ? new Set(options.allowedToolNames) + : undefined; + if (allowedToolNames && !allowedToolNames.has(opening.name)) { + return null; + } + + const maxPayloadBytes = options?.maxPayloadBytes ?? DEFAULT_MAX_PLAIN_TEXT_TOOL_PAYLOAD_BYTES; + const args: Record = {}; + let cursor = opening.end; + let parameterCount = 0; + while (true) { + const parameter = consumeXmlishParameterBlock(text, cursor, maxPayloadBytes); + if (!parameter) { + break; + } + if (parameter.end - opening.end > maxPayloadBytes) { + return null; + } + args[parameter.name] = parameter.value; + parameterCount += 1; + cursor = parameter.end; + } + if (parameterCount === 0) { + return null; + } + + const end = opening.allowsOptionalXmlishClose + ? consumeOptionalXmlishFunctionClose(text, cursor) + : consumeXmlishFunctionClose(text, cursor); + if (end === null) { + return null; + } + return { + arguments: args, + end, + name: opening.name, + raw: text.slice(start, end), + start, + }; +} + +export function parseStandalonePlainTextToolCallBlocks( + text: string, + options?: PlainTextToolCallParseOptions, +): PlainTextToolCallBlock[] | null { + const blocks: PlainTextToolCallBlock[] = []; + let cursor = skipWhitespace(text, 0); + while (cursor < text.length) { + const block = + parsePlainTextToolCallBlockAt(text, cursor, options) ?? + parseXmlishPlainTextToolCallBlockAt(text, cursor, options); + if (!block) { + return null; + } + blocks.push(block); + cursor = skipWhitespace(text, block.end); + } + return blocks.length > 0 ? blocks : null; +} + +export function stripPlainTextToolCallBlocks(text: string): string { + if ( + !text || + (!/\[(?:tool:)?[A-Za-z0-9_-]+\]/.test(text) && + !/(?:^|\n)\s*(?:<\|channel\|>)?(?:commentary|analysis|final)\s+to=/.test(text) && + !/(?:^|\n)\s*/i.test(text)) + ) { + return text; + } + let result = ""; + let cursor = 0; + let index = 0; + while (index < text.length) { + const lineStart = index === 0 || text[index - 1] === "\n"; + if (!lineStart) { + index += 1; + continue; + } + const blockStart = skipHorizontalWhitespace(text, index); + const block = parsePlainTextToolCallBlockAt(text, blockStart); + const blockEnd = block?.end ?? parseXmlishPlainTextToolCallBlockEndAt(text, blockStart); + if (blockEnd === null) { + index += 1; + continue; + } + result += text.slice(cursor, index); + cursor = blockEnd; + const afterBlockLineBreak = consumeLineBreak(text, cursor); + if (afterBlockLineBreak !== null) { + cursor = afterBlockLineBreak; + } + index = cursor; + } + result += text.slice(cursor); + return result; +} diff --git a/packages/tool-call-repair/src/promote.ts b/packages/tool-call-repair/src/promote.ts new file mode 100644 index 00000000000..7469895cc5f --- /dev/null +++ b/packages/tool-call-repair/src/promote.ts @@ -0,0 +1,257 @@ +import { parseStandalonePlainTextToolCallBlocks, type PlainTextToolCallBlock } from "./payload.js"; + +export type ToolCallRepairNameResolver = ( + rawName: string, + allowedToolNames: Set, +) => string | null; + +export type PromotedPlainTextToolCallBlockFactory = ( + block: PlainTextToolCallBlock, + resolvedName: string, +) => Record; + +export type PlainTextToolCallPromotionOptions = { + allowedStopReasons?: ReadonlySet; + allowedToolNames: Set; + createToolCallBlock: PromotedPlainTextToolCallBlockFactory; + isRetainableNonTextBlock?: (block: Record) => boolean; + message: unknown; + requireAssistantRole?: boolean; + resolveToolName?: ToolCallRepairNameResolver; +}; + +function asRecord(value: unknown): Record | undefined { + return value && typeof value === "object" ? (value as Record) : undefined; +} + +function resolveExactToolName(rawName: string, allowedToolNames: Set): string | null { + return allowedToolNames.has(rawName) ? rawName : null; +} + +function createPromotedToolCallBlocks( + text: string, + options: PlainTextToolCallPromotionOptions, +): Record[] | undefined { + const parsedBlocks = parseStandalonePlainTextToolCallBlocks(text); + if (!parsedBlocks) { + return undefined; + } + + const resolveToolName = options.resolveToolName ?? resolveExactToolName; + const toolCalls: Record[] = []; + for (const block of parsedBlocks) { + const resolvedName = resolveToolName(block.name, options.allowedToolNames); + if (!resolvedName) { + return undefined; + } + toolCalls.push(options.createToolCallBlock(block, resolvedName)); + } + return toolCalls; +} + +function createPromotedToolCallBlocksFromTextParts( + textParts: readonly string[], + options: PlainTextToolCallPromotionOptions, +): Record[] | undefined { + const exactText = textParts.join("").trim(); + if (!exactText) { + return []; + } + for (const text of createTextPartPromotionCandidates(textParts, exactText)) { + const toolCalls = createPromotedToolCallBlocks(text, options); + if (toolCalls) { + return toolCalls; + } + } + return undefined; +} + +function createTextPartPromotionCandidates( + textParts: readonly string[], + exactText: string, +): string[] { + const repairedText = joinTextPartsWithStructuralLineBreaks(textParts).trim(); + const newlineJoinedText = textParts.join("\n").trim(); + return [...new Set([repairedText, exactText, newlineJoinedText].filter(Boolean))]; +} + +function joinTextPartsWithStructuralLineBreaks(textParts: readonly string[]): string { + let text = ""; + for (const part of textParts) { + if (text && shouldInsertStructuralLineBreak(text, part)) { + text += "\n"; + } + text += part; + } + return text; +} + +function shouldInsertStructuralLineBreak(left: string, right: string): boolean { + if (!left || !right || /[\r\n]$/u.test(left) || /^\s/u.test(right)) { + return false; + } + const trimmedLeft = left.trimEnd(); + return ( + /$/iu.test(trimmedLeft) || + /^\[[A-Za-z0-9_-]+\]$/u.test(trimmedLeft) + ); +} + +function shouldPromoteMessage(options: PlainTextToolCallPromotionOptions): boolean { + if (options.allowedToolNames.size === 0) { + return false; + } + const messageRecord = asRecord(options.message); + if (!messageRecord) { + return false; + } + if (options.requireAssistantRole && messageRecord.role !== "assistant") { + return false; + } + return !options.allowedStopReasons || options.allowedStopReasons.has(messageRecord.stopReason); +} + +export function extractStandalonePlainTextToolCallText(params: { + allowOtherNonTextBlocks?: boolean; + allowedStopReasons?: ReadonlySet; + isRetainableNonTextBlock?: (block: Record) => boolean; + message: unknown; + requireAssistantRole?: boolean; +}): string | undefined { + const record = asRecord(params.message); + if (!record) { + return undefined; + } + if (params.requireAssistantRole && record.role !== "assistant") { + return undefined; + } + if (params.allowedStopReasons && !params.allowedStopReasons.has(record.stopReason)) { + return undefined; + } + + const content = record.content; + if (typeof content === "string") { + const text = content.trim(); + return text || undefined; + } + if (!Array.isArray(content)) { + return undefined; + } + + const textParts: string[] = []; + for (const block of content) { + const blockRecord = asRecord(block); + if (!blockRecord) { + return undefined; + } + if (blockRecord.type === "text") { + if (typeof blockRecord.text !== "string") { + return undefined; + } + if (blockRecord.text.trim()) { + textParts.push(blockRecord.text); + } + continue; + } + if (params.isRetainableNonTextBlock?.(blockRecord) || params.allowOtherNonTextBlocks) { + continue; + } + return undefined; + } + + const text = textParts.join("").trim(); + return text || undefined; +} + +export function promoteStandalonePlainTextToolCallMessage( + options: PlainTextToolCallPromotionOptions, +): Record | undefined { + if (!shouldPromoteMessage(options)) { + return undefined; + } + const messageRecord = asRecord(options.message); + if (!messageRecord) { + return undefined; + } + + const originalContent = messageRecord.content; + if (typeof originalContent === "string") { + const text = originalContent.trim(); + if (!text) { + return undefined; + } + const toolCalls = createPromotedToolCallBlocks(text, options); + if (!toolCalls) { + return undefined; + } + return { + ...messageRecord, + content: toolCalls, + stopReason: "toolUse", + }; + } + + if (!Array.isArray(originalContent)) { + return undefined; + } + + const content: Array> = []; + let promotedTextBlock = false; + let textParts: string[] = []; + const flushTextParts = (): boolean | undefined => { + if (textParts.length === 0) { + return false; + } + const toolCalls = createPromotedToolCallBlocksFromTextParts(textParts, options); + textParts = []; + if (toolCalls?.length === 0) { + return false; + } + if (!toolCalls) { + return undefined; + } + content.push(...toolCalls); + return true; + }; + + for (const block of originalContent) { + const blockRecord = asRecord(block); + if (!blockRecord) { + return undefined; + } + if (blockRecord.type === "text") { + if (typeof blockRecord.text !== "string") { + return undefined; + } + if (blockRecord.text.trim()) { + textParts.push(blockRecord.text); + } + continue; + } + const promotedTextRun = flushTextParts(); + if (promotedTextRun === undefined) { + return undefined; + } + promotedTextBlock ||= promotedTextRun; + if (options.isRetainableNonTextBlock?.(blockRecord)) { + content.push(blockRecord); + continue; + } + return undefined; + } + + const promotedTrailingTextRun = flushTextParts(); + if (promotedTrailingTextRun === undefined) { + return undefined; + } + promotedTextBlock ||= promotedTrailingTextRun; + if (!promotedTextBlock) { + return undefined; + } + + return { + ...messageRecord, + content, + stopReason: "toolUse", + }; +} diff --git a/packages/tool-call-repair/src/stream-normalizer.ts b/packages/tool-call-repair/src/stream-normalizer.ts new file mode 100644 index 00000000000..ea94bc80b3d --- /dev/null +++ b/packages/tool-call-repair/src/stream-normalizer.ts @@ -0,0 +1,1353 @@ +import { + consumeJsonToolClosingMarker, + END_TOOL_REQUEST, + findBracketedJsonPayloadStart, + findHarmonyJsonPayloadStart, + findJsonObjectEnd, + findXmlishToolCallEnd, + isPlainTextToolNameChar, + isXmlishNameChar, + matchesLiteralPrefix, +} from "./grammar.js"; + +export type PlainTextToolCallNameMatcher = { + hasExactName(name: string): boolean; + hasNamePrefix(prefix: string): boolean; +}; + +export type PlainTextToolCallMessageNormalization = + | { kind: "promoted" | "scrubbed"; message: Record } + | undefined; + +export type PlainTextToolCallStreamNormalizerOptions = { + createPromotedToolCallEvents(message: Record): Iterable; + matcher: PlainTextToolCallNameMatcher; + normalizeDoneMessage(params: { + message: unknown; + reason: unknown; + }): PlainTextToolCallMessageNormalization; + stopAfterDone?: boolean; +}; + +const TEXT_TOOL_CALL_BUFFER_MAX_CHARS = 256_000; + +const TEXT_TOOL_CALL_SUPPRESSED_SCAN_MAX_CHARS = TEXT_TOOL_CALL_BUFFER_MAX_CHARS + 64_000; +const TEXT_TOOL_CALL_SUPPRESSED_TAIL_CHARS = + TEXT_TOOL_CALL_SUPPRESSED_SCAN_MAX_CHARS - TEXT_TOOL_CALL_BUFFER_MAX_CHARS; +const TEXT_TOOL_CALL_SUPPRESSED_MARKER_SCAN_CHARS = 2_048; + +type PlainTextToolCallBufferState = "possible" | "impossible" | "over-cap"; + +function asRecord(value: unknown): Record | undefined { + return value && typeof value === "object" ? (value as Record) : undefined; +} + +function couldStillBeJsonPayload(text: string, start: number): boolean { + let cursor = start; + while (cursor < text.length && /\s/.test(text[cursor] ?? "")) { + cursor += 1; + } + return cursor >= text.length || text[cursor] === "{"; +} + +function couldStillBeXmlishParameterPayload(text: string, start: number): boolean { + let cursor = start; + while (cursor < text.length && /\s/.test(text[cursor] ?? "")) { + cursor += 1; + } + if (cursor >= text.length) { + return true; + } + return matchesLiteralPrefix(text.slice(cursor).toLowerCase(), "= text.length) { + return true; + } + if (text[cursor] !== "]") { + return false; + } + if (!matcher.hasExactName(name)) { + return false; + } + return ( + couldStillBeJsonPayload(text, cursor + 1) || + couldStillBeXmlishParameterPayload(text, cursor + 1) + ); + } + + let cursor = 1; + while (isPlainTextToolNameChar(text[cursor])) { + cursor += 1; + } + const name = text.slice(1, cursor); + if (!name || !matcher.hasNamePrefix(name)) { + return false; + } + if (cursor >= text.length) { + return true; + } + if (text[cursor] !== "]") { + return false; + } + if (!matcher.hasExactName(name)) { + return false; + } + + cursor += 1; + while (text[cursor] === " " || text[cursor] === "\t") { + cursor += 1; + } + if (cursor >= text.length) { + return true; + } + if (text[cursor] === "\r") { + if (cursor + 1 >= text.length) { + return true; + } + const payloadStart = text[cursor + 1] === "\n" ? cursor + 2 : cursor + 1; + return ( + couldStillBeJsonPayload(text, payloadStart) || + couldStillBeXmlishParameterPayload(text, payloadStart) + ); + } + if (text[cursor] !== "\n") { + return false; + } + return ( + couldStillBeJsonPayload(text, cursor + 1) || + couldStillBeXmlishParameterPayload(text, cursor + 1) + ); +} + +function couldStillBeXmlishFunctionToolCall( + text: string, + matcher: PlainTextToolCallNameMatcher, +): boolean { + const marker = "= text.length) { + return true; + } + if (text[cursor] !== ">") { + return false; + } + if (!matcher.hasExactName(name)) { + return false; + } + return couldStillBeXmlishParameterPayload(text, cursor + 1); +} + +function couldStillBeHarmonyStandaloneToolCall( + text: string, + matcher: PlainTextToolCallNameMatcher, +): boolean { + const channelMarker = "<|channel|>"; + let cursor = 0; + if (matchesLiteralPrefix(text, channelMarker)) { + if (text.length <= channelMarker.length) { + return true; + } + cursor = channelMarker.length; + } + + const rest = text.slice(cursor); + const channel = ["commentary", "analysis", "final"].find((candidate) => + matchesLiteralPrefix(rest, candidate), + ); + if (!channel) { + return false; + } + if (rest.length <= channel.length) { + return true; + } + + cursor += channel.length; + while (text[cursor] === " " || text[cursor] === "\t") { + cursor += 1; + } + if (cursor >= text.length) { + return true; + } + + const toMarker = "to="; + const toRest = text.slice(cursor); + if (!matchesLiteralPrefix(toRest, toMarker)) { + return false; + } + if (toRest.length <= toMarker.length) { + return true; + } + + cursor += toMarker.length; + const nameStart = cursor; + while (isPlainTextToolNameChar(text[cursor])) { + cursor += 1; + } + const name = text.slice(nameStart, cursor); + if (!name || !matcher.hasNamePrefix(name)) { + return false; + } + if (cursor >= text.length) { + return true; + } + + while (text[cursor] === " " || text[cursor] === "\t") { + cursor += 1; + } + if (cursor >= text.length) { + return true; + } + if (!matcher.hasExactName(name)) { + return false; + } + + const codeMarker = "code"; + const codeRest = text.slice(cursor); + if (!matchesLiteralPrefix(codeRest, codeMarker)) { + return false; + } + if (codeRest.length <= codeMarker.length) { + return true; + } + + cursor += codeMarker.length; + while (cursor < text.length && /\s/.test(text[cursor] ?? "")) { + cursor += 1; + } + if (cursor >= text.length) { + return true; + } + + const messageMarker = "<|message|>"; + const messageRest = text.slice(cursor); + if (matchesLiteralPrefix(messageRest, messageMarker)) { + return true; + } + return text[cursor] === "{"; +} + +function hasExactSerializedToolCallPrefix( + text: string, + matcher: PlainTextToolCallNameMatcher, +): boolean { + const bracketed = /^\[(?:tool:)?([A-Za-z0-9_-]+)\]/.exec(text); + if (bracketed?.[1]) { + return matcher.hasExactName(bracketed[1]); + } + const xmlish = /^/i.exec(text); + if (xmlish?.[1]) { + return matcher.hasExactName(xmlish[1]); + } + const harmony = + /^(?:<\|channel\|>)?(?:commentary|analysis|final)\s+to=([A-Za-z0-9_-]+)\s+code\b/.exec(text); + return Boolean(harmony?.[1] && matcher.hasExactName(harmony[1])); +} + +function stripCompleteSerializedToolCallPrefix( + text: string, + matcher?: PlainTextToolCallNameMatcher, +): string | null { + if (matcher && !hasExactSerializedToolCallPrefix(text, matcher)) { + return null; + } + const xmlishEnd = findXmlishToolCallEnd(text); + if (xmlishEnd !== null) { + return text.slice(xmlishEnd); + } + const jsonStart = findBracketedJsonPayloadStart(text) ?? findHarmonyJsonPayloadStart(text); + if (jsonStart === null) { + return null; + } + const jsonEnd = findJsonObjectEnd(text, jsonStart); + if (jsonEnd === null) { + return null; + } + return text.slice(consumeJsonToolClosingMarker(text, jsonEnd)); +} + +function stripSerializedToolCallPrefixes( + text: string, + matcher: PlainTextToolCallNameMatcher, +): string | null { + let current = text; + let changed = false; + for (let count = 0; count < 32; count += 1) { + const next = stripCompleteSerializedToolCallPrefix(current.trimStart(), matcher); + if (next === null) { + if (changed && hasExactSerializedToolCallPrefix(current.trimStart(), matcher)) { + return ""; + } + return changed ? current : null; + } + changed = true; + current = next; + if (!current.trim()) { + return current; + } + } + return hasExactSerializedToolCallPrefix(current.trimStart(), matcher) ? "" : current; +} + +function getPlainTextToolCallBufferState( + text: string, + matcher: PlainTextToolCallNameMatcher, +): PlainTextToolCallBufferState { + const trimmed = text.trimStart(); + if (trimmed.length === 0) { + return text.length > TEXT_TOOL_CALL_BUFFER_MAX_CHARS ? "impossible" : "possible"; + } + const toolCallLike = + couldStillBeBracketedStandaloneToolCall(trimmed, matcher) || + couldStillBeXmlishFunctionToolCall(trimmed, matcher) || + couldStillBeHarmonyStandaloneToolCall(trimmed, matcher); + if (!toolCallLike) { + return "impossible"; + } + if (text.length <= TEXT_TOOL_CALL_BUFFER_MAX_CHARS) { + return "possible"; + } + const textAfterCompleteToolBlocks = stripSerializedToolCallPrefixes(trimmed, matcher); + return textAfterCompleteToolBlocks !== null && textAfterCompleteToolBlocks.trim() + ? "impossible" + : "over-cap"; +} + +function getTextToolCallEventText(event: Record): string | undefined { + if (typeof event.delta === "string") { + return event.delta; + } + return typeof event.content === "string" ? event.content : undefined; +} + +function appendTextToolCallBuffer(bufferedText: string, event: Record): string { + const text = getTextToolCallEventText(event); + if (text === undefined) { + return bufferedText; + } + if (typeof event.content === "string" && !bufferedText) { + return text; + } + return typeof event.delta === "string" ? bufferedText + text : bufferedText; +} + +function hasSuppressedToolCallClosingMarker(text: string): boolean { + if (!text) { + return false; + } + const lowerText = text.toLowerCase(); + return ( + lowerText.includes("") || + lowerText.includes("") || + text.includes(END_TOOL_REQUEST) || + text.includes("<|call|>") || + text.includes("}") || + /\[\/[A-Za-z0-9_.:-]+\]/.test(text) + ); +} + +function shouldRescanSuppressedTextToolCallBuffer( + previousBufferedText: string, + event: Record, +): boolean { + const eventText = getTextToolCallEventText(event); + if (!eventText) { + return false; + } + return hasSuppressedToolCallClosingMarker( + previousBufferedText.slice(-TEXT_TOOL_CALL_SUPPRESSED_MARKER_SCAN_CHARS) + eventText, + ); +} + +function truncateSuppressedTextToolCallBuffer(text: string): string { + if (text.length <= TEXT_TOOL_CALL_SUPPRESSED_SCAN_MAX_CHARS) { + return text; + } + return ( + text.slice(0, TEXT_TOOL_CALL_BUFFER_MAX_CHARS) + + text.slice(-TEXT_TOOL_CALL_SUPPRESSED_TAIL_CHARS) + ); +} + +function appendSuppressedTextToolCallBuffer( + bufferedText: string, + event: Record, +): { changed: boolean; scanText: string; text: string } { + const nextText = appendTextToolCallBuffer(bufferedText, event); + if (nextText === bufferedText) { + return { changed: false, scanText: bufferedText, text: bufferedText }; + } + return { + changed: true, + scanText: nextText, + text: truncateSuppressedTextToolCallBuffer(nextText), + }; +} + +function shouldSuppressBufferedTextBlock(blockText: string, bufferedText: string): boolean { + const normalizedBlock = blockText.trim(); + const normalizedBuffer = bufferedText.trim(); + const normalizedSuppressedPrefix = bufferedText.slice(0, TEXT_TOOL_CALL_BUFFER_MAX_CHARS).trim(); + return ( + Boolean(normalizedBlock && normalizedBuffer) && + (normalizedBuffer.startsWith(normalizedBlock) || + normalizedBlock.startsWith(normalizedBuffer) || + (bufferedText.length >= TEXT_TOOL_CALL_SUPPRESSED_SCAN_MAX_CHARS && + Boolean(normalizedSuppressedPrefix) && + normalizedBlock.startsWith(normalizedSuppressedPrefix))) + ); +} + +function scrubBufferedTextFromContent( + content: unknown, + bufferedText: string, + matcher: PlainTextToolCallNameMatcher, + options?: { onlyTextIndex?: unknown; preserveEmptyTextBlocks?: boolean }, +): { changed: boolean; content: unknown } { + if (Array.isArray(content)) { + if (typeof options?.onlyTextIndex === "number") { + const block = content[options.onlyTextIndex]; + const record = asRecord(block); + if ( + record?.type !== "text" || + typeof record.text !== "string" || + !shouldSuppressBufferedTextBlock(record.text, bufferedText) + ) { + return { changed: false, content }; + } + const nextContent = [...content]; + if (options.preserveEmptyTextBlocks) { + nextContent[options.onlyTextIndex] = { ...record, text: "" }; + } else { + nextContent.splice(options.onlyTextIndex, 1); + } + return { changed: true, content: nextContent }; + } + + const overCapPrefix = scrubOverCapTextPrefixFromContent(content, matcher, options); + if (overCapPrefix.changed) { + return overCapPrefix; + } + + let changed = false; + const nextContent = content.flatMap((block) => { + const record = asRecord(block); + if ( + record?.type === "text" && + typeof record.text === "string" && + shouldSuppressBufferedTextBlock(record.text, bufferedText) + ) { + changed = true; + return options?.preserveEmptyTextBlocks ? [{ ...record, text: "" }] : []; + } + return [block]; + }); + return changed ? { changed, content: nextContent } : { changed: false, content }; + } + if (typeof content === "string" && shouldSuppressBufferedTextBlock(content, bufferedText)) { + return { changed: true, content: "" }; + } + return { changed: false, content }; +} + +function scrubOverCapTextPrefixFromContent( + content: readonly unknown[], + matcher: PlainTextToolCallNameMatcher, + options?: { preserveEmptyTextBlocks?: boolean }, +): { changed: boolean; content: unknown } { + let currentContent: readonly unknown[] = content; + let changed = false; + for (let count = 0; count < 32; count += 1) { + const scrubbed = scrubFirstOverCapTextPrefixFromContent(currentContent, matcher, options); + if (!scrubbed.changed || !Array.isArray(scrubbed.content)) { + return changed ? { changed: true, content: currentContent } : scrubbed; + } + currentContent = scrubbed.content; + changed = true; + } + return { changed, content: currentContent }; +} + +function scrubFirstOverCapTextPrefixFromContent( + content: readonly unknown[], + matcher: PlainTextToolCallNameMatcher, + options?: { preserveEmptyTextBlocks?: boolean }, +): { changed: boolean; content: unknown } { + const suppressedTextIndexes = new Set(); + let accumulated = ""; + let reachedOverCap = false; + for (let index = 0; index < content.length; index += 1) { + const record = asRecord(content[index]); + if (record?.type !== "text" || typeof record.text !== "string") { + continue; + } + if (!record.text.trim()) { + continue; + } + if (!accumulated && !hasExactSerializedToolCallPrefix(record.text.trimStart(), matcher)) { + continue; + } + if (reachedOverCap && hasExactSerializedToolCallPrefix(record.text.trimStart(), matcher)) { + break; + } + if ( + reachedOverCap && + suppressedTextIndexes.size === 1 && + !hasSuppressedToolCallClosingMarker(record.text) + ) { + break; + } + + accumulated = accumulated ? `${accumulated}\n${record.text}` : record.text; + suppressedTextIndexes.add(index); + + const state = getPlainTextToolCallBufferState(accumulated, matcher); + if (state === "over-cap") { + reachedOverCap = true; + const strippedSuffix = stripSerializedToolCallPrefixes(accumulated, matcher); + if (strippedSuffix !== null) { + return scrubSuppressedTextIndexesFromContent( + content, + suppressedTextIndexes, + options, + strippedSuffix, + index, + ); + } + continue; + } + if (state === "impossible") { + if (reachedOverCap) { + const strippedSuffix = stripSerializedToolCallPrefixes(accumulated, matcher); + if (strippedSuffix !== null) { + return scrubSuppressedTextIndexesFromContent( + content, + suppressedTextIndexes, + options, + strippedSuffix, + index, + ); + } + return scrubSuppressedTextIndexesFromContent(content, suppressedTextIndexes, options); + } + accumulated = ""; + suppressedTextIndexes.clear(); + reachedOverCap = false; + } + } + if (reachedOverCap) { + return scrubSuppressedTextIndexesFromContent(content, suppressedTextIndexes, options); + } + return { changed: false, content }; +} + +function scrubSuppressedTextIndexesFromContent( + content: readonly unknown[], + suppressedTextIndexes: ReadonlySet, + options?: { preserveEmptyTextBlocks?: boolean }, + visibleSuffix?: string, + visibleSuffixIndex?: number, +): { changed: boolean; content: unknown } { + const nextContent = content.flatMap((block, blockIndex) => { + if (!suppressedTextIndexes.has(blockIndex)) { + return [block]; + } + const blockRecord = asRecord(block); + if ( + visibleSuffixIndex === blockIndex && + visibleSuffix !== undefined && + visibleSuffix.trim() && + blockRecord + ) { + return [{ ...blockRecord, text: visibleSuffix }]; + } + return options?.preserveEmptyTextBlocks && blockRecord ? [{ ...blockRecord, text: "" }] : []; + }); + return { changed: true, content: nextContent }; +} + +function stripPlainTextToolCallsFromContent( + content: unknown, + matcher: PlainTextToolCallNameMatcher, + options?: { preserveEmptyTextBlocks?: boolean }, +): { changed: boolean; content: unknown } { + if (Array.isArray(content)) { + const textBlocks = content + .map((block, index) => ({ index, record: asRecord(block) })) + .filter( + (entry): entry is { index: number; record: Record } => + entry.record?.type === "text" && typeof entry.record.text === "string", + ); + const joinedText = textBlocks.map((entry) => String(entry.record.text)).join("\n"); + if (joinedText.trim()) { + const strippedJoined = stripSerializedToolCallPrefixes(joinedText.trim(), matcher); + if (strippedJoined !== null && strippedJoined !== joinedText) { + const firstTextIndex = textBlocks[0]?.index; + const nextContent = content.flatMap((block, index) => { + const record = asRecord(block); + if (record?.type !== "text" || typeof record.text !== "string") { + return [block]; + } + if (options?.preserveEmptyTextBlocks) { + return [ + { + ...record, + text: index === firstTextIndex && strippedJoined.trim() ? strippedJoined : "", + }, + ]; + } + return index === firstTextIndex && strippedJoined.trim() + ? [{ ...record, text: strippedJoined }] + : []; + }); + return { changed: true, content: nextContent }; + } + } + + let changed = false; + const nextContent: unknown[] = []; + for (const block of content) { + const record = asRecord(block); + if (record?.type !== "text" || typeof record.text !== "string") { + nextContent.push(block); + continue; + } + const strippedText = stripSerializedToolCallPrefixes(record.text, matcher); + if (strippedText === null || strippedText === record.text) { + nextContent.push(block); + continue; + } + changed = true; + if (strippedText.trim()) { + nextContent.push({ ...record, text: strippedText }); + } else if (options?.preserveEmptyTextBlocks) { + nextContent.push({ ...record, text: "" }); + } + } + return changed ? { changed, content: nextContent } : { changed: false, content }; + } + if (typeof content === "string") { + const strippedText = stripSerializedToolCallPrefixes(content, matcher); + if (strippedText !== null && strippedText !== content) { + return { changed: true, content: strippedText }; + } + } + return { changed: false, content }; +} + +function stripOverCapPlainTextToolCallsFromContent( + content: unknown, + matcher: PlainTextToolCallNameMatcher, + options?: { preserveEmptyTextBlocks?: boolean }, +): { changed: boolean; content: unknown } { + if (Array.isArray(content)) { + let changed = false; + const nextContent: unknown[] = []; + for (const block of content) { + const record = asRecord(block); + if ( + record?.type !== "text" || + typeof record.text !== "string" || + record.text.length <= TEXT_TOOL_CALL_BUFFER_MAX_CHARS + ) { + nextContent.push(block); + continue; + } + const strippedText = stripSerializedToolCallPrefixes(record.text, matcher); + if (strippedText === null || strippedText === record.text) { + nextContent.push(block); + continue; + } + changed = true; + if (strippedText.trim()) { + nextContent.push({ ...record, text: strippedText }); + } else if (options?.preserveEmptyTextBlocks) { + nextContent.push({ ...record, text: "" }); + } + } + return changed ? { changed, content: nextContent } : { changed: false, content }; + } + if (typeof content === "string" && content.length > TEXT_TOOL_CALL_BUFFER_MAX_CHARS) { + const strippedText = stripSerializedToolCallPrefixes(content, matcher); + if (strippedText !== null && strippedText !== content) { + return { changed: true, content: strippedText }; + } + } + return { changed: false, content }; +} + +function scrubPlainTextToolCallContent( + content: unknown, + bufferedText: string, + matcher: PlainTextToolCallNameMatcher, + options?: { onlyTextIndex?: unknown; preserveEmptyTextBlocks?: boolean }, +): { changed: boolean; content: unknown } { + const scrubbed = scrubBufferedTextFromContent(content, bufferedText, matcher, options); + const stripped = + options?.onlyTextIndex === undefined + ? stripPlainTextToolCallsFromContent(scrubbed.content, matcher, options) + : { changed: false, content: scrubbed.content }; + return stripped.changed ? stripped : scrubbed; +} + +function shouldPreserveEmptyTextBlocksForEventIndex( + content: unknown, + bufferedText: string, + matcher: PlainTextToolCallNameMatcher, + eventContentIndex: unknown, +): boolean { + if ( + typeof eventContentIndex !== "number" || + !Number.isInteger(eventContentIndex) || + eventContentIndex < 0 || + !Array.isArray(content) + ) { + return false; + } + const currentBlock = content[eventContentIndex]; + if (currentBlock === undefined) { + return false; + } + const scrubbed = scrubPlainTextToolCallContent(content, bufferedText, matcher); + return ( + scrubbed.changed && + Array.isArray(scrubbed.content) && + scrubbed.content[eventContentIndex] !== currentBlock + ); +} + +function scrubBufferedTextFromPartial( + event: Record, + bufferedText: string, + matcher: PlainTextToolCallNameMatcher, + contentIndex?: unknown, + options?: { preserveEmptyTextBlocks?: boolean }, +): Record { + const partial = asRecord(event.partial); + if (!partial) { + return event; + } + const preserveEmptyTextBlocks = + options?.preserveEmptyTextBlocks === true || + shouldPreserveEmptyTextBlocksForEventIndex( + partial.content, + bufferedText, + matcher, + event.contentIndex, + ); + const scrubbed = scrubPlainTextToolCallContent(partial.content, bufferedText, matcher, { + onlyTextIndex: contentIndex, + preserveEmptyTextBlocks, + }); + if (!scrubbed.changed) { + return event; + } + return { + ...event, + partial: { + ...partial, + content: scrubbed.content, + }, + }; +} + +function scrubBufferedTextFromMessage( + event: Record, + bufferedText: string, + matcher: PlainTextToolCallNameMatcher, + contentIndex?: unknown, +): Record { + const message = asRecord(event.message); + if (!message) { + return event; + } + const scrubbed = scrubPlainTextToolCallContent(message.content, bufferedText, matcher, { + onlyTextIndex: contentIndex, + }); + if (!scrubbed.changed) { + return event; + } + return { + ...event, + message: { + ...message, + content: scrubbed.content, + }, + }; +} + +function scrubBufferedTextFromError( + event: Record, + bufferedText: string, + matcher: PlainTextToolCallNameMatcher, + contentIndex?: unknown, +): Record { + const error = asRecord(event.error); + if (!error) { + return event; + } + const scrubbed = scrubPlainTextToolCallContent(error.content, bufferedText, matcher, { + onlyTextIndex: contentIndex, + }); + if (!scrubbed.changed) { + return event; + } + return { + ...event, + error: { + ...error, + content: scrubbed.content, + }, + }; +} + +function replaceTextContentWithVisibleSuffix( + record: Record, + visibleText: string, + contentIndex?: unknown, + matcher?: PlainTextToolCallNameMatcher, +): Record { + if (typeof record.content === "string") { + return { ...record, content: visibleText }; + } + if (!Array.isArray(record.content)) { + return record; + } + const originalContent = record.content; + if (typeof contentIndex === "number") { + const content = originalContent.flatMap((block, index) => { + if (index !== contentIndex) { + return [block]; + } + const blockRecord = asRecord(block); + if (blockRecord?.type !== "text" || typeof blockRecord.text !== "string") { + return [block]; + } + if (matcher && !hasExactSerializedToolCallPrefix(blockRecord.text.trimStart(), matcher)) { + return [block]; + } + return visibleText.trim() ? [{ ...blockRecord, text: visibleText }] : []; + }); + if (matcher && content.every((block, index) => block === originalContent[index])) { + return replaceTextContentWithVisibleSuffix(record, visibleText, undefined, matcher); + } + return { ...record, content }; + } + const textBlockCount = originalContent.filter((block) => { + const blockRecord = asRecord(block); + return blockRecord?.type === "text" && typeof blockRecord.text === "string"; + }).length; + if (textBlockCount !== 1) { + if (!matcher) { + return record; + } + let replaced = false; + const content = originalContent.flatMap((block) => { + const blockRecord = asRecord(block); + if (blockRecord?.type !== "text" || typeof blockRecord.text !== "string") { + return [block]; + } + if (replaced) { + return [block]; + } + if (!hasExactSerializedToolCallPrefix(blockRecord.text.trimStart(), matcher)) { + return [block]; + } + replaced = true; + return visibleText.trim() ? [{ ...blockRecord, text: visibleText }] : []; + }); + return replaced ? { ...record, content } : record; + } + let replaced = false; + const content = originalContent.flatMap((block) => { + const blockRecord = asRecord(block); + if (blockRecord?.type !== "text" || typeof blockRecord.text !== "string") { + return [block]; + } + if (replaced) { + return []; + } + replaced = true; + return visibleText.trim() ? [{ ...blockRecord, text: visibleText }] : []; + }); + return { ...record, content }; +} + +function scrubReclassifiedMixedTextFromPartial( + event: Record, + visibleText: string, + contentIndex?: unknown, + matcher?: PlainTextToolCallNameMatcher, +): Record { + const partial = asRecord(event.partial); + if (!partial) { + return event; + } + return { + ...event, + partial: replaceTextContentWithVisibleSuffix(partial, visibleText, contentIndex, matcher), + }; +} + +function scrubReclassifiedMixedTextFromError( + event: Record, + visibleText: string, + contentIndex?: unknown, + matcher?: PlainTextToolCallNameMatcher, +): Record { + const error = asRecord(event.error); + if (!error) { + return event; + } + return { + ...event, + error: replaceTextContentWithVisibleSuffix(error, visibleText, contentIndex, matcher), + }; +} + +export function scrubOverCapPlainTextToolCallMessage(params: { + candidateText: string | undefined; + matcher: PlainTextToolCallNameMatcher; + message: unknown; +}): Record | undefined { + const record = asRecord(params.message); + const candidateText = params.candidateText; + if (!record || !candidateText) { + return undefined; + } + const bufferState = getPlainTextToolCallBufferState(candidateText, params.matcher); + if (bufferState === "impossible") { + if (candidateText.length <= TEXT_TOOL_CALL_BUFFER_MAX_CHARS) { + return undefined; + } + const visibleText = stripSerializedToolCallPrefixes(candidateText, params.matcher); + if (visibleText?.trim() && !Array.isArray(record.content)) { + const replaced = replaceTextContentWithVisibleSuffix( + record, + visibleText, + undefined, + params.matcher, + ); + if (replaced !== record) { + return replaced; + } + } + if (Array.isArray(record.content)) { + const overCap = scrubOverCapTextPrefixFromContent(record.content, params.matcher); + const stripped = stripOverCapPlainTextToolCallsFromContent(overCap.content, params.matcher); + if (!overCap.changed && !stripped.changed) { + return undefined; + } + return { + ...record, + content: stripped.changed ? stripped.content : overCap.content, + }; + } + return undefined; + } + if (bufferState !== "over-cap") { + return undefined; + } + const scrubbed = scrubPlainTextToolCallContent(record.content, candidateText, params.matcher); + return { + ...record, + content: scrubbed.content, + }; +} + +function createScrubbedTextDeltaEvent( + event: Record, + text: string, +): Record { + const partial = asRecord(event.partial); + const syntheticContent = + typeof event.contentIndex === "number" + ? Array.from({ length: event.contentIndex + 1 }, (_, index) => ({ + type: "text", + text: index === event.contentIndex ? text : "", + })) + : [{ type: "text", text }]; + const scrubbedPartial = partial + ? replaceTextContentWithVisibleSuffix(partial, text, event.contentIndex) + : { role: "assistant", content: syntheticContent }; + const eventWithoutTextEndContent = { ...event }; + delete eventWithoutTextEndContent.content; + return { + ...eventWithoutTextEndContent, + type: "text_delta", + delta: text, + partial: scrubbedPartial, + }; +} + +function appendReclassifiedVisibleDelta( + visibleText: string, + event: Record, +): string { + return typeof event.delta === "string" ? `${visibleText}${event.delta}` : visibleText; +} + +function isAllowedTextToolCallLikeEvent( + event: Record, + matcher: PlainTextToolCallNameMatcher, +): boolean { + const text = getTextToolCallEventText(event); + return Boolean(text?.trim() && getPlainTextToolCallBufferState(text, matcher) !== "impossible"); +} + +function isBufferedTextEvent(bufferedEvent: unknown): boolean { + const bufferedRecord = asRecord(bufferedEvent); + const bufferedType = typeof bufferedRecord?.type === "string" ? bufferedRecord.type : ""; + return ( + bufferedType === "text_start" || bufferedType === "text_delta" || bufferedType === "text_end" + ); +} + +export async function* normalizePlainTextToolCallStreamEvents( + source: AsyncIterable, + options: PlainTextToolCallStreamNormalizerOptions, +): AsyncGenerator { + const bufferedEvents: unknown[] = []; + let bufferedText = ""; + let suppressingOverCapTextToolCall = false; + let suppressedTextContentIndex: unknown; + let hasSuppressedTextContentIndex = false; + let reclassifiedMixedTextContentIndex: unknown; + let hasReclassifiedMixedTextContentIndex = false; + let scrubReclassifiedMixedTextFromDone = false; + let reclassifiedMixedVisibleText: string | undefined; + + const flushBufferedEvents = () => { + const events = bufferedEvents.splice(0); + bufferedText = ""; + return events; + }; + + function* flushScrubbedBufferedNonTextEvents(resetBufferedText: boolean) { + const events = bufferedEvents.splice(0); + const textToScrub = bufferedText; + if (resetBufferedText) { + bufferedText = ""; + } + for (const bufferedEvent of events) { + if (isBufferedTextEvent(bufferedEvent)) { + continue; + } + const bufferedRecord = asRecord(bufferedEvent); + yield bufferedRecord + ? scrubBufferedTextFromPartial( + bufferedRecord, + textToScrub, + options.matcher, + hasSuppressedTextContentIndex ? suppressedTextContentIndex : undefined, + { preserveEmptyTextBlocks: suppressingOverCapTextToolCall }, + ) + : bufferedEvent; + } + } + + function* suppressBufferedTextEvents() { + suppressingOverCapTextToolCall = true; + yield* flushScrubbedBufferedNonTextEvents(false); + } + + for await (const event of source) { + const record = asRecord(event); + if (!record) { + yield event; + continue; + } + const type = typeof record.type === "string" ? record.type : ""; + if (type === "text_start" || type === "text_delta" || type === "text_end") { + if ( + type === "text_end" && + hasReclassifiedMixedTextContentIndex && + record.contentIndex === reclassifiedMixedTextContentIndex + ) { + continue; + } + if ( + scrubReclassifiedMixedTextFromDone && + reclassifiedMixedVisibleText !== undefined && + hasReclassifiedMixedTextContentIndex && + record.contentIndex === reclassifiedMixedTextContentIndex + ) { + reclassifiedMixedVisibleText = appendReclassifiedVisibleDelta( + reclassifiedMixedVisibleText, + record, + ); + yield scrubReclassifiedMixedTextFromPartial( + record, + reclassifiedMixedVisibleText, + reclassifiedMixedTextContentIndex, + options.matcher, + ); + continue; + } + if (suppressingOverCapTextToolCall) { + if (hasSuppressedTextContentIndex && record.contentIndex !== suppressedTextContentIndex) { + if (isAllowedTextToolCallLikeEvent(record, options.matcher)) { + continue; + } + yield scrubBufferedTextFromPartial( + record, + bufferedText, + options.matcher, + suppressedTextContentIndex, + { preserveEmptyTextBlocks: true }, + ); + continue; + } + const previousBufferedText = bufferedText; + const appended = appendSuppressedTextToolCallBuffer(bufferedText, record); + bufferedText = appended.text; + const shouldRescan = + appended.changed && + shouldRescanSuppressedTextToolCallBuffer(previousBufferedText, record); + const bufferState = shouldRescan + ? getPlainTextToolCallBufferState(appended.scanText, options.matcher) + : "over-cap"; + if (bufferState === "impossible") { + const visibleText = + stripSerializedToolCallPrefixes(appended.scanText, options.matcher) ?? ""; + yield* flushScrubbedBufferedNonTextEvents(true); + suppressingOverCapTextToolCall = false; + suppressedTextContentIndex = undefined; + hasSuppressedTextContentIndex = false; + reclassifiedMixedTextContentIndex = record.contentIndex; + hasReclassifiedMixedTextContentIndex = true; + scrubReclassifiedMixedTextFromDone = true; + reclassifiedMixedVisibleText = visibleText; + if (visibleText.trim()) { + yield createScrubbedTextDeltaEvent(record, visibleText); + } + } + continue; + } + bufferedEvents.push(event); + bufferedText = appendTextToolCallBuffer(bufferedText, record); + const scanBufferedText = truncateSuppressedTextToolCallBuffer(bufferedText); + const scanWasTruncated = scanBufferedText.length !== bufferedText.length; + const bufferState = getPlainTextToolCallBufferState(scanBufferedText, options.matcher); + if (bufferState === "impossible") { + const visibleText = + !scanWasTruncated && bufferedText.length > TEXT_TOOL_CALL_BUFFER_MAX_CHARS + ? stripSerializedToolCallPrefixes(bufferedText.trimStart(), options.matcher) + : null; + if (visibleText?.trim()) { + yield* flushScrubbedBufferedNonTextEvents(true); + reclassifiedMixedTextContentIndex = record.contentIndex; + hasReclassifiedMixedTextContentIndex = true; + scrubReclassifiedMixedTextFromDone = true; + reclassifiedMixedVisibleText = visibleText; + yield createScrubbedTextDeltaEvent(record, visibleText); + } else if ( + scanWasTruncated && + stripSerializedToolCallPrefixes(scanBufferedText.trimStart(), options.matcher) !== null + ) { + bufferedText = scanBufferedText; + suppressedTextContentIndex = record.contentIndex; + hasSuppressedTextContentIndex = true; + yield* suppressBufferedTextEvents(); + } else { + yield* flushBufferedEvents(); + } + } else if (bufferState === "over-cap") { + bufferedText = scanBufferedText; + suppressedTextContentIndex = record.contentIndex; + hasSuppressedTextContentIndex = true; + yield* suppressBufferedTextEvents(); + } + continue; + } + + if (type === "done") { + const normalizedMessage = options.normalizeDoneMessage({ + message: record.message, + reason: record.reason, + }); + if (normalizedMessage?.kind === "promoted") { + yield* flushScrubbedBufferedNonTextEvents(true); + suppressingOverCapTextToolCall = false; + suppressedTextContentIndex = undefined; + hasSuppressedTextContentIndex = false; + scrubReclassifiedMixedTextFromDone = false; + reclassifiedMixedTextContentIndex = undefined; + hasReclassifiedMixedTextContentIndex = false; + reclassifiedMixedVisibleText = undefined; + yield* options.createPromotedToolCallEvents(normalizedMessage.message); + yield { ...record, reason: "toolUse", message: normalizedMessage.message }; + if (options.stopAfterDone) { + return; + } + continue; + } + if (normalizedMessage?.kind === "scrubbed") { + yield* flushScrubbedBufferedNonTextEvents(true); + suppressingOverCapTextToolCall = false; + suppressedTextContentIndex = undefined; + hasSuppressedTextContentIndex = false; + scrubReclassifiedMixedTextFromDone = false; + reclassifiedMixedTextContentIndex = undefined; + hasReclassifiedMixedTextContentIndex = false; + reclassifiedMixedVisibleText = undefined; + yield { ...record, message: normalizedMessage.message }; + if (options.stopAfterDone) { + return; + } + continue; + } + const mixedMessageRecord = scrubReclassifiedMixedTextFromDone + ? asRecord(record.message) + : undefined; + const strippedMixedMessage = + mixedMessageRecord && reclassifiedMixedVisibleText !== undefined + ? replaceTextContentWithVisibleSuffix( + mixedMessageRecord, + reclassifiedMixedVisibleText, + hasReclassifiedMixedTextContentIndex ? reclassifiedMixedTextContentIndex : undefined, + options.matcher, + ) + : undefined; + if (strippedMixedMessage) { + yield* flushScrubbedBufferedNonTextEvents(true); + scrubReclassifiedMixedTextFromDone = false; + reclassifiedMixedTextContentIndex = undefined; + hasReclassifiedMixedTextContentIndex = false; + reclassifiedMixedVisibleText = undefined; + yield { ...record, message: strippedMixedMessage }; + if (options.stopAfterDone) { + return; + } + continue; + } + if (suppressingOverCapTextToolCall) { + const scrubbedDoneEvent = scrubBufferedTextFromMessage( + record, + bufferedText, + options.matcher, + hasSuppressedTextContentIndex ? suppressedTextContentIndex : undefined, + ); + yield* flushScrubbedBufferedNonTextEvents(true); + suppressingOverCapTextToolCall = false; + suppressedTextContentIndex = undefined; + hasSuppressedTextContentIndex = false; + scrubReclassifiedMixedTextFromDone = false; + reclassifiedMixedTextContentIndex = undefined; + hasReclassifiedMixedTextContentIndex = false; + reclassifiedMixedVisibleText = undefined; + yield scrubbedDoneEvent; + if (options.stopAfterDone) { + return; + } + continue; + } + yield* flushBufferedEvents(); + yield event; + if (options.stopAfterDone) { + return; + } + continue; + } + + if (type === "error") { + if (!suppressingOverCapTextToolCall) { + yield* flushBufferedEvents(); + } + yield suppressingOverCapTextToolCall + ? scrubBufferedTextFromError( + scrubBufferedTextFromPartial( + record, + bufferedText, + options.matcher, + hasSuppressedTextContentIndex ? suppressedTextContentIndex : undefined, + { preserveEmptyTextBlocks: true }, + ), + bufferedText, + options.matcher, + hasSuppressedTextContentIndex ? suppressedTextContentIndex : undefined, + ) + : scrubReclassifiedMixedTextFromDone && reclassifiedMixedVisibleText !== undefined + ? scrubReclassifiedMixedTextFromError( + scrubReclassifiedMixedTextFromPartial( + record, + reclassifiedMixedVisibleText, + hasReclassifiedMixedTextContentIndex + ? reclassifiedMixedTextContentIndex + : undefined, + options.matcher, + ), + reclassifiedMixedVisibleText, + hasReclassifiedMixedTextContentIndex ? reclassifiedMixedTextContentIndex : undefined, + options.matcher, + ) + : event; + return; + } + + if (scrubReclassifiedMixedTextFromDone && reclassifiedMixedVisibleText !== undefined) { + yield scrubReclassifiedMixedTextFromPartial( + record, + reclassifiedMixedVisibleText, + hasReclassifiedMixedTextContentIndex ? reclassifiedMixedTextContentIndex : undefined, + options.matcher, + ); + continue; + } + + if (bufferedEvents.length > 0 && !suppressingOverCapTextToolCall) { + bufferedEvents.push(event); + continue; + } + + yield suppressingOverCapTextToolCall + ? scrubBufferedTextFromPartial( + record, + bufferedText, + options.matcher, + hasSuppressedTextContentIndex ? suppressedTextContentIndex : undefined, + { preserveEmptyTextBlocks: suppressingOverCapTextToolCall }, + ) + : event; + } + + if (!suppressingOverCapTextToolCall) { + yield* flushBufferedEvents(); + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d024217ffa4..5a95660c753 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1855,6 +1855,8 @@ importers: packages/web-content-core: {} + packages/tool-call-repair: {} + ui: dependencies: '@create-markdown/preview': diff --git a/src/agents/embedded-agent-helpers/sanitize-user-facing-text.ts b/src/agents/embedded-agent-helpers/sanitize-user-facing-text.ts index 31e50629c33..963d4b7dfe9 100644 --- a/src/agents/embedded-agent-helpers/sanitize-user-facing-text.ts +++ b/src/agents/embedded-agent-helpers/sanitize-user-facing-text.ts @@ -1,5 +1,5 @@ +import { stripPlainTextToolCallBlocks } from "../../../packages/tool-call-repair/src/index.js"; import { stripInboundMetadata } from "../../auto-reply/reply/strip-inbound-meta.js"; -import { stripPlainTextToolCallBlocks } from "../../plugin-sdk/tool-payload.js"; import { extractLeadingHttpStatus, formatRawAssistantErrorForUi, diff --git a/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.test.ts b/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.test.ts index 165af7fd38e..cbdacb254e2 100644 --- a/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.test.ts +++ b/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.test.ts @@ -1,13 +1,51 @@ import type { AgentMessage } from "openclaw/plugin-sdk/agent-core"; -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { sanitizeOpenAIResponsesReplayForStream, sanitizeReplayToolCallIdsForStream, shouldApplyReplayToolCallIdSanitizer, + wrapStreamFnPromoteStandaloneTextToolCalls, } from "./attempt.tool-call-normalization.js"; type AssistantMessage = Extract; type ToolResultMessage = Extract; +type FakeWrappedStream = { + result: () => Promise; + [Symbol.asyncIterator]: () => AsyncIterator; +}; + +function createFakeStream(params: { + events: unknown[]; + resultMessage: unknown; +}): FakeWrappedStream { + return { + async result() { + return params.resultMessage; + }, + [Symbol.asyncIterator]() { + return (async function* () { + for (const event of params.events) { + yield event; + } + })(); + }, + }; +} + +async function collectStreamEvents(stream: AsyncIterable): Promise { + const events: unknown[] = []; + for await (const event of stream) { + events.push(event); + } + return events; +} + +function requireRecord(value: unknown, label: string): Record { + if (!value || typeof value !== "object") { + throw new Error(`expected ${label}`); + } + return value as Record; +} function requireAssistantMessage(message: AgentMessage | undefined): AssistantMessage { if (!message || message.role !== "assistant") { @@ -50,6 +88,731 @@ function toolResultSummary(message: AgentMessage | undefined) { }; } +describe("wrapStreamFnPromoteStandaloneTextToolCalls", () => { + it("promotes standalone serialized parameter XML text to structured tool calls", async () => { + const rawToolText = [ + "[tool:exec]", + "", + "cat /proc/mounts 2>/dev/null | head -20", + "", + "", + "", + "", + "", + "find / -maxdepth 4 -type d 2>/dev/null | head -20", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [ + { type: "thinking", thinking: "Need to audit the mount." }, + { type: "text", text: rawToolText }, + ], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "start", partial: { content: [] } }, + { + type: "text_start", + contentIndex: 1, + partial: { content: [{ type: "text", text: "" }] }, + }, + { type: "text_delta", contentIndex: 1, delta: rawToolText }, + { type: "text_end", contentIndex: 1, content: rawToolText }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "start", + "toolcall_start", + "toolcall_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + expect(requireRecord(events.at(-1), "done").reason).toBe("toolUse"); + expect(result.stopReason).toBe("toolUse"); + const content = result.content as Array>; + expect(content).toHaveLength(3); + expect(content[0]).toEqual({ type: "thinking", thinking: "Need to audit the mount." }); + expect(content[1]).toMatchObject({ + type: "toolCall", + name: "exec", + arguments: { command: "cat /proc/mounts 2>/dev/null | head -20" }, + partialArgs: '{"command":"cat /proc/mounts 2>/dev/null | head -20"}', + }); + expect(String(content[1].id)).toMatch(/^call_[a-f0-9]{24}$/); + expect(content[2]).toMatchObject({ + type: "toolCall", + name: "exec", + arguments: { command: "find / -maxdepth 4 -type d 2>/dev/null | head -20" }, + }); + }); + + it("preserves content indexes when promoting text before thinking", async () => { + const rawToolText = [ + "[tool:exec]", + "", + "pwd", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "Need the current directory." }, + ], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "Need the current directory.", + partial: { + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "Need the current directory." }, + ], + }, + }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + expect(requireRecord(events[0], "thinking event").contentIndex).toBe(1); + expect(requireRecord(events[1], "toolcall start").contentIndex).toBe(0); + expect((result.content as Array>).map((block) => block.type)).toEqual([ + "toolCall", + "thinking", + ]); + }); + + it("preserves intervening thinking when promoting multiple text blocks", async () => { + const firstRawToolText = [ + "[tool:exec]", + "", + "pwd", + "", + "", + ].join("\n"); + const secondRawToolText = [ + "[tool:exec]", + "", + "whoami", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [ + { type: "text", text: firstRawToolText }, + { type: "thinking", thinking: "Need one more check." }, + { type: "text", text: secondRawToolText }, + ], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: firstRawToolText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "Need one more check.", + partial: { + content: [ + { type: "text", text: firstRawToolText }, + { type: "thinking", thinking: "Need one more check." }, + { type: "text", text: secondRawToolText }, + ], + }, + }, + { type: "text_delta", contentIndex: 2, delta: secondRawToolText }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + expect(requireRecord(events[0], "thinking event").contentIndex).toBe(1); + expect(requireRecord(events[1], "first toolcall start").contentIndex).toBe(0); + expect(requireRecord(events[3], "second toolcall start").contentIndex).toBe(2); + expect((result.content as Array>).map((block) => block.type)).toEqual([ + "toolCall", + "thinking", + "toolCall", + ]); + expect(requireRecord((result.content as unknown[])[0], "first tool call")).toMatchObject({ + name: "exec", + arguments: { command: "pwd" }, + }); + expect(requireRecord((result.content as unknown[])[2], "second tool call")).toMatchObject({ + name: "exec", + arguments: { command: "whoami" }, + }); + }); + + it("promotes serialized tool calls split across adjacent text blocks", async () => { + const resultMessage = { + role: "assistant", + content: [ + { type: "text", text: "[tool:exec]\n\n" }, + { type: "text", text: "pwd\n\n" }, + { type: "thinking", thinking: "Checking location." }, + ], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: "[tool:exec]\n\n" }, + { type: "text_delta", contentIndex: 1, delta: "pwd\n\n" }, + { + type: "thinking_delta", + contentIndex: 2, + delta: "Checking location.", + partial: { content: resultMessage.content }, + }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + expect(requireRecord(events[0], "thinking event").contentIndex).toBe(2); + expect(requireRecord(events[1], "toolcall start").contentIndex).toBe(0); + expect((result.content as Array>).map((block) => block.type)).toEqual([ + "toolCall", + "thinking", + ]); + expect(requireRecord((result.content as unknown[])[0], "tool call")).toMatchObject({ + name: "exec", + arguments: { command: "pwd" }, + }); + }); + + it("buffers case-insensitive tool-name prefixes until final promotion", async () => { + const rawToolText = [ + "[tool:read]", + "", + "src/index.ts", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: "[tool:rea" }, + { type: "text_delta", contentIndex: 0, delta: rawToolText.slice("[tool:rea".length) }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["Read"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "toolcall_start", + "toolcall_delta", + "done", + ]); + expect(result.stopReason).toBe("toolUse"); + expect(requireRecord((result.content as unknown[])[0], "tool call")).toMatchObject({ + type: "toolCall", + name: "Read", + arguments: { path: "src/index.ts" }, + }); + }); + + it("buffers normalized alias tool-name prefixes until final promotion", async () => { + const rawToolText = [ + "[tool:bash]", + "", + "pwd", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: "[tool:ba" }, + { type: "text_delta", contentIndex: 0, delta: rawToolText.slice("[tool:ba".length) }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "toolcall_start", + "toolcall_delta", + "done", + ]); + expect(requireRecord((result.content as unknown[])[0], "tool call")).toMatchObject({ + type: "toolCall", + name: "exec", + arguments: { command: "pwd" }, + }); + }); + + it("keeps possible tool-call text buffered across interleaved non-text events", async () => { + const rawToolText = [ + "[tool:exec]", + "", + "pwd", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [ + { type: "thinking", thinking: "Need shell state." }, + { type: "text", text: rawToolText }, + ], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 1, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 0, + delta: "Need shell state.", + partial: { + content: [ + { type: "thinking", thinking: "Need shell state." }, + { type: "text", text: rawToolText }, + ], + }, + }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + const thinkingEvent = requireRecord(events[0], "thinking event"); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "thinking", thinking: "Need shell state." }, + ]); + expect(JSON.stringify(events)).not.toContain(rawToolText); + }); + + it("preserves interleaved event content indexes when buffered text is scrubbed first", async () => { + const rawToolText = [ + "[tool:exec]", + "", + "pwd", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "Need shell state." }, + ], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "Need shell state.", + partial: { + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "Need shell state." }, + ], + }, + }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + const thinkingEvent = requireRecord(events[0], "thinking event"); + expect(thinkingEvent.contentIndex).toBe(1); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "text", text: "" }, + { type: "thinking", thinking: "Need shell state." }, + ]); + expect(JSON.stringify(events)).not.toContain(rawToolText); + }); + + it("closes the underlying stream iterator when consumers stop early", async () => { + const returnIterator = vi.fn(async () => ({ done: true, value: undefined })); + const nextIterator = vi + .fn() + .mockResolvedValueOnce({ done: false, value: { type: "start", partial: { content: [] } } }) + .mockResolvedValue({ done: true, value: undefined }); + const baseFn = vi.fn(() => ({ + async result() { + return { role: "assistant", content: [], stopReason: "stop" }; + }, + [Symbol.asyncIterator]() { + return { + next: nextIterator, + return: returnIterator, + }; + }, + })); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { type: "start", partial: { content: [] } }, + }); + await iterator.return?.(); + + expect(returnIterator).toHaveBeenCalledTimes(1); + }); + + it("flushes buffered text before terminal error events", async () => { + const rawToolText = "[tool:exec]"; + const errorEvent = { type: "error", error: new Error("stream failed") }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [{ type: "text_delta", contentIndex: 0, delta: rawToolText }, errorEvent], + resultMessage: { role: "assistant", content: [], stopReason: "stop" }, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + + expect(events).toEqual([ + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + errorEvent, + ]); + }); + + it("buffers split XML function markers until final promotion", async () => { + const rawToolText = [ + "", + "", + "pwd", + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: "<" }, + { type: "text_delta", contentIndex: 0, delta: rawToolText.slice(1) }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "toolcall_start", + "toolcall_delta", + "done", + ]); + }); + + it("suppresses over-cap serialized XMLish text instead of flushing it", async () => { + const rawToolText = [ + "[tool:exec]", + "", + "x".repeat(256_001), + "", + "", + ].join("\n"); + const resultMessage = { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "start", partial: { content: [] } }, + { + type: "text_start", + contentIndex: 0, + partial: { content: [{ type: "text", text: "" }] }, + }, + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "still thinking", + partial: { + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "still thinking" }, + ], + }, + }, + { type: "text_end", contentIndex: 0, content: rawToolText }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "start", + "thinking_delta", + "done", + ]); + const thinkingEvent = requireRecord(events[1], "thinking event"); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "text", text: "" }, + { type: "thinking", thinking: "still thinking" }, + ]); + const doneEvent = requireRecord(events[2], "done event"); + expect(doneEvent.reason).toBe("stop"); + expect(doneEvent.message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(result).toMatchObject({ role: "assistant", content: [], stopReason: "stop" }); + expect(JSON.stringify(events)).not.toContain("[tool:exec]"); + expect(JSON.stringify(result)).not.toContain("[tool:exec]"); + }); + + it("scrubs split over-cap serialized XMLish text blocks from done messages", async () => { + const rawToolTextParts = [ + "[tool:exec]\n", + ["x".repeat(256_001), "", ""].join("\n"), + ]; + const resultMessage = { + role: "assistant", + content: rawToolTextParts.map((text) => ({ type: "text", text })), + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [{ type: "done", reason: "stop", message: resultMessage }], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + const result = requireRecord(await stream.result(), "result message"); + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(result).toMatchObject({ role: "assistant", content: [], stopReason: "stop" }); + expect(JSON.stringify(events)).not.toContain("[tool:exec]"); + expect(JSON.stringify(result)).not.toContain(""); + }); + + it("preserves visible suffix text after an over-cap JSON tool payload", async () => { + const visibleSuffix = "Visible answer after oversized JSON."; + const rawText = [`[tool:exec] {"command":"${"x".repeat(256_001)}"}`, visibleSuffix].join("\n"); + const resultMessage = { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: rawText }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + + expect(events.map((event) => requireRecord(event, "event").type)).toEqual([ + "text_delta", + "done", + ]); + const textEvent = requireRecord(events[0], "text event"); + expect(String(textEvent.delta)).toBe(visibleSuffix); + expect(requireRecord(textEvent.partial, "text partial").content).toEqual([ + { type: "text", text: visibleSuffix }, + ]); + expect(JSON.stringify(events)).not.toContain("[tool:exec]"); + }); + + it("does not buffer normal prose that starts like a final answer", async () => { + const resultMessage = { + role: "assistant", + content: [{ type: "text", text: "Finally, the audit is done." }], + stopReason: "stop", + }; + const baseFn = vi.fn(() => + createFakeStream({ + events: [ + { type: "text_delta", contentIndex: 0, delta: "Finally, the audit is done." }, + { type: "done", reason: "stop", message: resultMessage }, + ], + resultMessage, + }), + ); + const wrapped = wrapStreamFnPromoteStandaloneTextToolCalls(baseFn as never, new Set(["exec"])); + const stream = (await Promise.resolve( + wrapped({} as never, {} as never, {} as never), + )) as FakeWrappedStream; + + const events = await collectStreamEvents(stream); + + expect(events).toEqual([ + { type: "text_delta", contentIndex: 0, delta: "Finally, the audit is done." }, + { type: "done", reason: "stop", message: resultMessage }, + ]); + }); +}); + describe("sanitizeReplayToolCallIdsForStream", () => { it("skips strict stream id sanitization when provider policy opts out", () => { expect( diff --git a/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.ts b/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.ts index ef00e5e71cd..f69421afa3a 100644 --- a/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.ts +++ b/src/agents/embedded-agent-runner/run/attempt.tool-call-normalization.ts @@ -1,5 +1,15 @@ +import { randomUUID } from "node:crypto"; +import { + extractStandalonePlainTextToolCallText, + normalizePlainTextToolCallStreamEvents, + promoteStandalonePlainTextToolCallMessage as promotePlainTextToolCallMessage, + scrubOverCapPlainTextToolCallMessage, + type PlainTextToolCallBlock, + type PlainTextToolCallNameMatcher, +} from "../../../../packages/tool-call-repair/src/index.js"; import { visitObjectContentBlocks } from "../../../shared/message-content-blocks.js"; import { normalizeLowercaseStringOrEmpty } from "../../../shared/string-coerce.js"; +import { normalizeStringEntries } from "../../../shared/string-normalization.js"; import { downgradeOpenAIFunctionCallReasoningPairs, downgradeOpenAIReasoningBlocks, @@ -9,14 +19,13 @@ import { } from "../../embedded-agent-helpers.js"; import type { AgentMessage, StreamFn } from "../../runtime/index.js"; import { sanitizeToolUseResultPairing } from "../../session-transcript-repair.js"; -import type { MutableAssistantMessageEventStream } from "../../stream-compat.js"; import { extractToolCallsFromAssistant, extractToolResultIds, sanitizeToolCallIdsForCloudCodeAssist, type ToolCallIdMode, } from "../../tool-call-id.js"; -import { normalizeToolName } from "../../tool-policy.js"; +import { couldNormalizeToolNamePrefixToAllowedTool, normalizeToolName } from "../../tool-policy.js"; import { shouldAllowProviderOwnedThinkingReplay } from "../../transcript-policy.js"; import type { TranscriptPolicy } from "../../transcript-policy.js"; import { wrapStreamObjectEvents } from "./stream-wrapper.js"; @@ -28,6 +37,7 @@ type UnknownToolLoopGuardState = { count: number; countedMessages: WeakSet; }; +type AssistantStream = Awaited>; function resolveCaseInsensitiveAllowedToolName( rawName: string, @@ -94,10 +104,7 @@ function buildStructuredToolNameCandidates(rawName: string): string[] { addCandidate(normalizedDelimiter); addCandidate(normalizeToolName(normalizedDelimiter)); - const segments = normalizedDelimiter - .split(".") - .map((segment) => segment.trim()) - .filter(Boolean); + const segments = normalizeStringEntries(normalizedDelimiter.split(".")); if (segments.length > 1) { for (let index = 1; index < segments.length; index += 1) { const suffix = segments.slice(index).join("."); @@ -849,11 +856,209 @@ function guardUnknownToolLoopInMessage( return true; } +type PromotedTextToolCallBlock = { + type: "toolCall"; + id: string; + name: string; + arguments: Record; + partialArgs: string; +}; + +function asRecord(value: unknown): Record | undefined { + return value && typeof value === "object" ? (value as Record) : undefined; +} + +function createStandaloneTextToolCallId(): string { + return `call_${randomUUID().replace(/-/g, "").slice(0, 24)}`; +} + +function createPromotedTextToolCallBlock( + block: PlainTextToolCallBlock, + name: string, +): PromotedTextToolCallBlock { + return { + type: "toolCall", + id: createStandaloneTextToolCallId(), + name, + arguments: block.arguments, + partialArgs: JSON.stringify(block.arguments), + }; +} + +function isRetainableNonVisibleBlock(block: Record): boolean { + return block.type === "thinking" || block.type === "redacted_thinking"; +} + +const STANDALONE_TEXT_TOOL_CALL_PROMOTION_STOP_REASONS = new Set(["stop"]); +const STANDALONE_TEXT_TOOL_CALL_SCRUB_STOP_REASONS = new Set(["stop", "length"]); + +function extractStandaloneTextToolCallCandidateForStopReasons( + message: unknown, + allowedStopReasons: ReadonlySet, +): + | { + text: string; + } + | undefined { + const text = extractStandalonePlainTextToolCallText({ + allowedStopReasons, + isRetainableNonTextBlock: isRetainableNonVisibleBlock, + message, + requireAssistantRole: true, + }); + return text ? { text } : undefined; +} + +function promoteStandaloneTextToolCallMessage( + message: unknown, + allowedToolNames?: Set, +): Record | undefined { + if (!allowedToolNames) { + return undefined; + } + return promotePlainTextToolCallMessage({ + allowedStopReasons: STANDALONE_TEXT_TOOL_CALL_PROMOTION_STOP_REASONS, + allowedToolNames, + createToolCallBlock: createPromotedTextToolCallBlock, + isRetainableNonTextBlock: isRetainableNonVisibleBlock, + message, + requireAssistantRole: true, + resolveToolName: resolveExactAllowedToolName, + }); +} + +function createPromotedToolCallEvents( + message: Record, +): Array> { + const content = Array.isArray(message.content) ? message.content : []; + const events: Array> = []; + content.forEach((block, contentIndex) => { + const record = asRecord(block); + if (record?.type !== "toolCall") { + return; + } + events.push({ type: "toolcall_start", contentIndex, partial: message }); + events.push({ + type: "toolcall_delta", + contentIndex, + delta: typeof record.partialArgs === "string" ? record.partialArgs : "{}", + partial: message, + }); + }); + return events; +} + +function createStandaloneToolCallNameMatcher( + allowedToolNames: Set, +): PlainTextToolCallNameMatcher { + return { + hasExactName: (name) => Boolean(resolveExactAllowedToolName(name, allowedToolNames)), + hasNamePrefix: (prefix) => couldNormalizeToolNamePrefixToAllowedTool(prefix, allowedToolNames), + }; +} + +function wrapStreamPromoteStandaloneTextToolCalls( + stream: AssistantStream, + allowedToolNames: Set, +): AssistantStream { + const matcher = createStandaloneToolCallNameMatcher(allowedToolNames); + const normalizedMessages = new WeakMap< + object, + { kind: "promoted" | "scrubbed"; message: Record } + >(); + const normalizeMessage = ( + message: unknown, + ): { kind: "promoted" | "scrubbed"; message: Record } | undefined => { + if (!message || typeof message !== "object") { + return undefined; + } + const cached = normalizedMessages.get(message); + if (cached) { + return cached; + } + const promoted = promoteStandaloneTextToolCallMessage(message, allowedToolNames); + if (promoted) { + const result = { kind: "promoted" as const, message: promoted }; + normalizedMessages.set(message, result); + return result; + } + const scrubbed = scrubOverCapPlainTextToolCallMessage({ + candidateText: extractStandaloneTextToolCallCandidateForStopReasons( + message, + STANDALONE_TEXT_TOOL_CALL_SCRUB_STOP_REASONS, + )?.text, + matcher, + message, + }); + if (scrubbed) { + const result = { kind: "scrubbed" as const, message: scrubbed }; + normalizedMessages.set(message, result); + return result; + } + return undefined; + }; + + const originalResult = stream.result.bind(stream); + stream.result = async () => { + const message = await originalResult(); + return (normalizeMessage(message)?.message ?? message) as Awaited< + ReturnType + >; + }; + + const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream); + (stream as unknown as { [Symbol.asyncIterator]: () => AsyncIterator })[ + Symbol.asyncIterator + ] = async function* () { + const source = { + [Symbol.asyncIterator]: originalAsyncIterator, + } as AsyncIterable; + yield* normalizePlainTextToolCallStreamEvents(source, { + createPromotedToolCallEvents, + matcher, + normalizeDoneMessage: ({ message, reason }) => { + if (reason === "stop") { + return normalizeMessage(message); + } + const scrubbed = scrubOverCapPlainTextToolCallMessage({ + candidateText: extractStandaloneTextToolCallCandidateForStopReasons( + message, + STANDALONE_TEXT_TOOL_CALL_SCRUB_STOP_REASONS, + )?.text, + matcher, + message, + }); + return scrubbed ? { kind: "scrubbed", message: scrubbed } : undefined; + }, + }); + }; + + return stream; +} + +export function wrapStreamFnPromoteStandaloneTextToolCalls( + baseFn: StreamFn, + allowedToolNames?: Set, +): StreamFn { + if (!allowedToolNames || allowedToolNames.size === 0) { + return baseFn; + } + return (model, context, streamOptions) => { + const maybeStream = baseFn(model, context, streamOptions); + if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) { + return Promise.resolve(maybeStream).then((stream) => + wrapStreamPromoteStandaloneTextToolCalls(stream, allowedToolNames), + ); + } + return wrapStreamPromoteStandaloneTextToolCalls(maybeStream, allowedToolNames); + }; +} + function wrapStreamTrimToolCallNames( - stream: MutableAssistantMessageEventStream, + stream: AssistantStream, allowedToolNames?: Set, options?: { unknownToolThreshold?: number; state?: UnknownToolLoopGuardState }, -): MutableAssistantMessageEventStream { +): AssistantStream { const unknownToolGuardState = options?.state ?? { count: 0, countedMessages: new WeakSet(), diff --git a/src/agents/embedded-agent-runner/run/attempt.ts b/src/agents/embedded-agent-runner/run/attempt.ts index 779128a9760..cedf814e296 100644 --- a/src/agents/embedded-agent-runner/run/attempt.ts +++ b/src/agents/embedded-agent-runner/run/attempt.ts @@ -379,6 +379,7 @@ import { sanitizeOpenAIResponsesReplayForStream, sanitizeReplayToolCallIdsForStream, shouldApplyReplayToolCallIdSanitizer, + wrapStreamFnPromoteStandaloneTextToolCalls, wrapStreamFnSanitizeMalformedToolCalls, wrapStreamFnTrimToolCallNames, } from "./attempt.tool-call-normalization.js"; @@ -454,6 +455,7 @@ export { wrapStreamFnRepairMalformedToolCallArguments, } from "./attempt.tool-call-argument-repair.js"; export { + wrapStreamFnPromoteStandaloneTextToolCalls, wrapStreamFnSanitizeMalformedToolCalls, wrapStreamFnTrimToolCallNames, } from "./attempt.tool-call-normalization.js"; @@ -2616,6 +2618,10 @@ export async function runEmbeddedAttempt( transcriptPolicy, params.provider, ); + activeSession.agent.streamFn = wrapStreamFnPromoteStandaloneTextToolCalls( + activeSession.agent.streamFn, + allowedToolNames, + ); activeSession.agent.streamFn = wrapStreamFnTrimToolCallNames( activeSession.agent.streamFn, allowedToolNames, diff --git a/src/agents/tool-policy-shared.ts b/src/agents/tool-policy-shared.ts index 35c67c16863..8645e3191b9 100644 --- a/src/agents/tool-policy-shared.ts +++ b/src/agents/tool-policy-shared.ts @@ -23,6 +23,50 @@ export function normalizeToolName(name: string) { return TOOL_NAME_ALIASES[normalized] ?? normalized; } +export function couldNormalizeToolNamePrefixToAllowedTool( + prefix: string, + allowedToolNames: Set, +): boolean { + const normalizedPrefix = normalizeLowercaseStringOrEmpty(prefix); + if (!normalizedPrefix) { + return false; + } + + const allowed = new Set(); + for (const toolName of allowedToolNames) { + const normalizedToolName = normalizeToolName(toolName); + const foldedToolName = normalizeLowercaseStringOrEmpty(toolName); + if (normalizedToolName) { + allowed.add(normalizedToolName); + } + if (foldedToolName) { + allowed.add(foldedToolName); + } + if ( + normalizedToolName.startsWith(normalizedPrefix) || + foldedToolName.startsWith(normalizedPrefix) + ) { + return true; + } + } + + const resolvedPrefix = normalizeToolName(normalizedPrefix); + if (resolvedPrefix !== normalizedPrefix) { + for (const toolName of allowed) { + if (toolName.startsWith(resolvedPrefix)) { + return true; + } + } + } + + for (const [alias, toolName] of Object.entries(TOOL_NAME_ALIASES)) { + if (alias.startsWith(normalizedPrefix) && allowed.has(toolName)) { + return true; + } + } + return false; +} + export function normalizeToolList(list?: string[]) { if (!list) { return []; diff --git a/src/agents/tool-policy.ts b/src/agents/tool-policy.ts index 258cdfd1eaa..4db3ad743d9 100644 --- a/src/agents/tool-policy.ts +++ b/src/agents/tool-policy.ts @@ -3,6 +3,7 @@ import { uniqueStrings } from "../shared/string-normalization.js"; import { IMPLICIT_ALLOW_ALL_FROM_ALSO_ALLOW } from "./sandbox-tool-policy.js"; import { expandToolGroups, normalizeToolList, normalizeToolName } from "./tool-policy-shared.js"; export { + couldNormalizeToolNamePrefixToAllowedTool, expandToolGroups, normalizeToolList, normalizeToolName, diff --git a/src/infra/outbound/message-action-runner.ts b/src/infra/outbound/message-action-runner.ts index daa362adb02..3425f197f22 100644 --- a/src/infra/outbound/message-action-runner.ts +++ b/src/infra/outbound/message-action-runner.ts @@ -1,4 +1,5 @@ import { resolveSendableOutboundReplyParts } from "openclaw/plugin-sdk/reply-payload"; +import { stripPlainTextToolCallBlocks } from "../../../packages/tool-call-repair/src/index.js"; import { resolveSessionAgentId } from "../../agents/agent-scope.js"; import type { AgentToolResult } from "../../agents/runtime/index.js"; import { @@ -30,7 +31,6 @@ import { import type { OutboundMediaAccess } from "../../media/load-options.js"; import { getAgentScopedMediaLocalRoots } from "../../media/local-roots.js"; import { resolveAgentScopedOutboundMediaAccess } from "../../media/read-capability.js"; -import { stripPlainTextToolCallBlocks } from "../../plugin-sdk/tool-payload.js"; import { hasPollCreationParams } from "../../poll-params.js"; import { resolvePollMaxSelections } from "../../polls.js"; import { resolveFirstBoundAccountId } from "../../routing/bound-account-read.js"; diff --git a/src/infra/outbound/sanitize-text.ts b/src/infra/outbound/sanitize-text.ts index 7531d9db574..f702757cf87 100644 --- a/src/infra/outbound/sanitize-text.ts +++ b/src/infra/outbound/sanitize-text.ts @@ -11,7 +11,7 @@ * @see https://github.com/openclaw/openclaw/issues/18558 */ -import { stripPlainTextToolCallBlocks } from "../../plugin-sdk/tool-payload.js"; +import { stripPlainTextToolCallBlocks } from "../../../packages/tool-call-repair/src/index.js"; const INTERNAL_RUNTIME_SCAFFOLDING_TAGS = ["system-reminder", "previous_response"] as const; const INTERNAL_RUNTIME_SCAFFOLDING_TAG_PATTERN = INTERNAL_RUNTIME_SCAFFOLDING_TAGS.join("|"); diff --git a/src/plugin-sdk/provider-stream-shared.test.ts b/src/plugin-sdk/provider-stream-shared.test.ts index 2f3e6e1b40c..b96c7769828 100644 --- a/src/plugin-sdk/provider-stream-shared.test.ts +++ b/src/plugin-sdk/provider-stream-shared.test.ts @@ -1,14 +1,68 @@ import type { StreamFn } from "openclaw/plugin-sdk/agent-core"; import { describe, expect, it } from "vitest"; +import { createAssistantMessageEventStream } from "../llm/utils/event-stream.js"; import { createDeepSeekV4OpenAICompatibleThinkingWrapper, createAnthropicThinkingPrefillPayloadWrapper, createPayloadPatchStreamWrapper, + createPlainTextToolCallCompatWrapper, defaultToolStreamExtraParams, isOpenAICompatibleThinkingEnabled, stripTrailingAnthropicAssistantPrefillWhenThinking, } from "./provider-stream-shared.js"; +type StreamEvent = { type: string } & Record; + +function requireRecord(value: unknown, label: string): Record { + if (!value || typeof value !== "object" || Array.isArray(value)) { + throw new Error(`expected ${label} to be a record`); + } + return value as Record; +} + +function createEventStream(events: unknown[]): ReturnType { + const output = createAssistantMessageEventStream(); + const stream = output as unknown as { push(event: unknown): void; end(): void }; + queueMicrotask(() => { + for (const event of events) { + stream.push(event); + } + stream.end(); + }); + return output as ReturnType; +} + +function createControlledPlainTextToolCallCompatStream() { + const source = createAssistantMessageEventStream(); + const baseStream: StreamFn = () => source as ReturnType; + const wrapped = createPlainTextToolCallCompatWrapper(baseStream); + const stream = wrapped( + { provider: "test", api: "openai-completions", id: "test-model" } as never, + { + messages: [], + tools: [{ name: "read", description: "Read", parameters: { type: "object" } }], + } as never, + {}, + ); + return { source, stream }; +} + +async function resolveStream(stream: ReturnType) { + return stream instanceof Promise ? await stream : stream; +} + +async function nextEvent(iterator: AsyncIterator, label: string): Promise { + const result = await Promise.race([ + iterator.next(), + new Promise<"timed out">((resolve) => setTimeout(() => resolve("timed out"), 50)), + ]); + if (result === "timed out") { + throw new Error(`timed out waiting for ${label}`); + } + expect(result.done).toBe(false); + return result.value as StreamEvent; +} + describe("defaultToolStreamExtraParams", () => { it("defaults tool_stream on when absent", () => { expect(defaultToolStreamExtraParams()).toEqual({ tool_stream: true }); @@ -139,6 +193,1822 @@ describe("createPayloadPatchStreamWrapper", () => { }); }); +describe("createPlainTextToolCallCompatWrapper", () => { + it("promotes standalone text tool calls into tool-call stream events", async () => { + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_start", content: "" }, + { type: "text_delta", delta: '[tool:read] {"path":"/tmp/file.txt"}' }, + { type: "text_end" }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: '[tool:read] {"path":"/tmp/file.txt"}', + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "toolcall_start", + "toolcall_delta", + "done", + ]); + const done = events.at(-1) as { message?: { content?: unknown; stopReason?: unknown } }; + expect(done.message?.stopReason).toBe("toolUse"); + expect(done.message?.content).toEqual([ + expect.objectContaining({ + type: "toolCall", + name: "read", + arguments: { path: "/tmp/file.txt" }, + }), + ]); + }); + + it("promotes complete under-cap text tool calls for non-stop terminal reasons", async () => { + const rawToolText = '[tool:read] {"path":"/tmp/file.txt"}'; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "length", + message: { + role: "assistant", + content: rawToolText, + stopReason: "length", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "toolcall_start", + "toolcall_delta", + "done", + ]); + const done = events.at(-1) as { reason?: unknown; message?: { stopReason?: unknown } }; + expect(done.reason).toBe("toolUse"); + expect(done.message?.stopReason).toBe("toolUse"); + }); + + it("passes through bracketed text when no configured tool names match", async () => { + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", delta: "[note] keep streaming" }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: "[note] keep streaming", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + }); + + it("converts standalone plain-text tool calls for result consumers", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const resultPromise = (await resolveStream(stream)).result(); + const rawToolText = '[tool:read] {"path":"src/index.ts"}'; + + source.push({ type: "start", partial: { content: [] } } as never); + source.push({ + type: "text_delta", + contentIndex: 0, + delta: rawToolText, + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + } as never); + source.end(); + + const message = requireRecord(await resultPromise, "result message"); + expect(message.stopReason).toBe("toolUse"); + expect(requireRecord((message.content as unknown[])[0], "tool call")).toMatchObject({ + type: "toolCall", + name: "read", + arguments: { path: "src/index.ts" }, + }); + }); + + it("promotes serialized tool calls split across adjacent text blocks", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const resultPromise = (await resolveStream(stream)).result(); + const rawToolText = [ + "[tool:read]", + "", + "src/index.ts", + "", + "", + ].join("\n"); + + source.push({ type: "start", partial: { content: [] } } as never); + source.push({ + type: "text_delta", + contentIndex: 0, + delta: rawToolText, + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: "[tool:read]\n" }, + { type: "text", text: "src/index.ts\n\n" }, + ], + stopReason: "stop", + }, + } as never); + source.end(); + + const message = requireRecord(await resultPromise, "result message"); + expect(message.stopReason).toBe("toolUse"); + expect(requireRecord((message.content as unknown[])[0], "tool call")).toMatchObject({ + type: "toolCall", + name: "read", + arguments: { path: "src/index.ts" }, + }); + }); + + it("preserves exact text block adjacency inside promoted arguments", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const resultPromise = (await resolveStream(stream)).result(); + + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: "[tool:read]\n\nsrc/ind" }, + { type: "text", text: "ex.ts\n\n" }, + ], + stopReason: "stop", + }, + } as never); + source.end(); + + const message = requireRecord(await resultPromise, "result message"); + expect(requireRecord((message.content as unknown[])[0], "tool call")).toMatchObject({ + type: "toolCall", + name: "read", + arguments: { path: "src/index.ts" }, + }); + }); + + it("repairs bracketed tool-call block boundaries when providers split header text", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const resultPromise = (await resolveStream(stream)).result(); + + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: "[read]" }, + { type: "text", text: '{"path":"src/index.ts"}\n[END_TOOL_REQUEST]' }, + ], + stopReason: "stop", + }, + } as never); + source.end(); + + const message = requireRecord(await resultPromise, "result message"); + expect(requireRecord((message.content as unknown[])[0], "tool call")).toMatchObject({ + type: "toolCall", + name: "read", + arguments: { path: "src/index.ts" }, + }); + }); + + it("keeps possible tool-call text buffered across interleaved non-text events", async () => { + const rawToolText = [ + "[tool:read]", + "", + "src/index.ts", + "", + "", + ].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 1, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 0, + delta: "Need file contents.", + partial: { + content: [ + { type: "thinking", thinking: "Need file contents." }, + { type: "text", text: rawToolText }, + ], + }, + }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "thinking", thinking: "Need file contents." }, + { type: "text", text: rawToolText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const stream = await resolveStream( + wrapped({} as never, { tools: [{ name: "read" }] } as never, {}), + ); + const events: unknown[] = []; + + for await (const event of stream as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + const thinkingEvent = requireRecord(events[0], "thinking event"); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "thinking", thinking: "Need file contents." }, + ]); + expect(JSON.stringify(events)).not.toContain(rawToolText); + }); + + it("preserves interleaved event content indexes when buffered text is scrubbed first", async () => { + const rawToolText = [ + "[tool:read]", + "", + "src/index.ts", + "", + "", + ].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "Need file contents.", + partial: { + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "Need file contents." }, + ], + }, + }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "Need file contents." }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const stream = await resolveStream( + wrapped({} as never, { tools: [{ name: "read" }] } as never, {}), + ); + const events: unknown[] = []; + + for await (const event of stream as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "thinking_delta", + "toolcall_start", + "toolcall_delta", + "done", + ]); + const thinkingEvent = requireRecord(events[0], "thinking event"); + expect(thinkingEvent.contentIndex).toBe(1); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "text", text: "" }, + { type: "thinking", thinking: "Need file contents." }, + ]); + expect(JSON.stringify(events)).not.toContain(rawToolText); + }); + + it("flushes false-positive buffered prefixes around interleaved events in source order", async () => { + const firstText = "[tool:re"; + const secondText = " not a call"; + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: firstText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "Need file contents.", + partial: { + content: [ + { type: "text", text: firstText }, + { type: "thinking", thinking: "Need file contents." }, + ], + }, + }, + { type: "text_delta", contentIndex: 0, delta: secondText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: `${firstText}${secondText}` }, + { type: "thinking", thinking: "Need file contents." }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const stream = await resolveStream( + wrapped({} as never, { tools: [{ name: "read" }] } as never, {}), + ); + const events: unknown[] = []; + + for await (const event of stream as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "thinking_delta", + "text_delta", + "done", + ]); + expect(requireRecord(events[0], "first text").delta).toBe(firstText); + const thinkingEvent = requireRecord(events[1], "thinking event"); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "text", text: firstText }, + { type: "thinking", thinking: "Need file contents." }, + ]); + expect(requireRecord(events[2], "second text").delta).toBe(secondText); + }); + + it("keeps CR-separated bracketed tool calls buffered for conversion", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const iterator = (await resolveStream(stream))[Symbol.asyncIterator](); + + try { + source.push({ type: "start", partial: { content: [] } } as never); + expect((await nextEvent(iterator, "start")).type).toBe("start"); + + source.push({ + type: "text_delta", + contentIndex: 0, + delta: '[read]\r{"path":"src/index.ts"}\r[END_TOOL_REQUEST]', + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: '[read]\r{"path":"src/index.ts"}\r[END_TOOL_REQUEST]' }], + stopReason: "stop", + }, + } as never); + + const event = await nextEvent(iterator, "converted CR tool call"); + expect(event.type).toBe("toolcall_start"); + } finally { + source.end(); + await iterator.return?.(); + } + }); + + it("keeps bracketed XML parameter tool calls buffered for conversion", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const iterator = (await resolveStream(stream))[Symbol.asyncIterator](); + const rawToolText = [ + "[tool:read]", + "", + "src/index.ts", + "", + "", + ].join("\n"); + + try { + source.push({ type: "start", partial: { content: [] } } as never); + expect((await nextEvent(iterator, "start")).type).toBe("start"); + + source.push({ + type: "text_delta", + contentIndex: 0, + delta: rawToolText, + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + } as never); + + const event = await nextEvent(iterator, "converted bracketed XML tool call"); + expect(event.type).toBe("toolcall_start"); + } finally { + source.end(); + await iterator.return?.(); + } + }); + + it("suppresses over-cap bracketed XML parameter text instead of streaming it", async () => { + const oversizedPath = "x".repeat(256_001); + const rawToolText = [ + "[tool:read]", + "", + oversizedPath, + "", + "", + ].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "start", partial: { content: [] } }, + { type: "text_start", contentIndex: 0, content: "" }, + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "thinking_delta", + contentIndex: 1, + delta: "checking", + partial: { + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "checking" }, + ], + }, + }, + { type: "text_end", contentIndex: 0, content: rawToolText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "start", + "thinking_delta", + "done", + ]); + const thinkingEvent = requireRecord(events[1], "thinking event"); + expect(requireRecord(thinkingEvent.partial, "thinking partial").content).toEqual([ + { type: "text", text: "" }, + { type: "thinking", thinking: "checking" }, + ]); + const doneEvent = requireRecord(events[2], "done event"); + expect(doneEvent.reason).toBe("stop"); + expect(doneEvent.message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("scrubs over-cap bracketed XML parameter text from terminal error partials", async () => { + const rawToolText = ["[tool:read]", "", "x".repeat(256_001)].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "error", + partial: { + content: [ + { type: "text", text: rawToolText }, + { type: "thinking", thinking: "checking" }, + ], + }, + error: { + content: [{ type: "text", text: rawToolText }], + errorMessage: "stream failed", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual(["error"]); + const errorEvent = requireRecord(events[0], "error event"); + expect(requireRecord(errorEvent.partial, "error partial").content).toEqual([ + { type: "text", text: "" }, + { type: "thinking", thinking: "checking" }, + ]); + expect(requireRecord(errorEvent.error, "error body").content).toEqual([]); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("scrubs over-cap bracketed XML parameter text from done-message-only streams", async () => { + const rawToolText = [ + "[tool:read]", + "", + "x".repeat(256_001), + "", + "", + ].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual(["done"]); + const doneEvent = requireRecord(events[0], "done event"); + expect(doneEvent.reason).toBe("stop"); + expect(doneEvent.message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("scrubs over-cap bracketed XML parameter text from length terminal messages", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const output = await resolveStream(stream); + const resultPromise = output.result(); + const eventsPromise = (async () => { + const events: unknown[] = []; + for await (const event of output as AsyncIterable) { + events.push(event); + } + return events; + })(); + const rawToolText = [ + "[tool:read]", + "", + "x".repeat(256_001), + "", + "", + ].join("\n"); + + source.push({ + type: "done", + reason: "length", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "length", + }, + } as never); + source.end(); + + const events = await eventsPromise; + const result = requireRecord(await resultPromise, "result message"); + + expect(requireRecord(events[0], "done event")).toMatchObject({ + reason: "length", + message: { role: "assistant", content: [], stopReason: "length" }, + }); + expect(result).toMatchObject({ role: "assistant", content: [], stopReason: "length" }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(result)).not.toContain("[tool:read]"); + }); + + it("scrubs split over-cap bracketed XML parameter text from done messages", async () => { + const rawToolTextParts = [ + "[tool:read]\n", + ["x".repeat(256_001), "", ""].join("\n"), + ]; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: rawToolTextParts.map((text) => ({ type: "text", text })), + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + const doneEvent = requireRecord(events[0], "done event"); + expect(doneEvent.reason).toBe("stop"); + expect(doneEvent.message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("scrubs split over-cap bracketed XML tails before later visible text", async () => { + const rawToolTextParts = [ + "[tool:read]\n", + "x".repeat(256_001), + ["", ""].join("\n"), + ]; + const visibleText = "Visible text after the tool-looking blocks."; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + ...rawToolTextParts.map((text) => ({ type: "text", text })), + { type: "text", text: visibleText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "text", text: visibleText }], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("scrubs split over-cap bracketed XML around non-text blocks", async () => { + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: "[tool:read]\n" }, + { type: "thinking", thinking: "Checking path." }, + { + type: "text", + text: ["x".repeat(256_001), "", ""].join("\n"), + }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "thinking", thinking: "Checking path." }], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("scrubs closing tails after a single over-cap bracketed XML block", async () => { + const rawToolTextParts = [ + ["[tool:read]", "", "x".repeat(256_001)].join("\n"), + ["", ""].join("\n"), + ]; + const visibleText = "Visible text after the tool-looking blocks."; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + ...rawToolTextParts.map((text) => ({ type: "text", text })), + { type: "text", text: visibleText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "text", text: visibleText }], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("scrubs closing tails after a single over-cap bracketed XML block without visible text", async () => { + const rawToolTextParts = [ + ["[tool:read]", "", "x".repeat(256_001)].join("\n"), + ["", ""].join("\n"), + ]; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: rawToolTextParts.map((text) => ({ type: "text", text })), + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("scrubs over-cap buffers even when later text blocks contain complete tool calls", async () => { + const incompleteOverCapTool = ["[tool:read]", "", "x".repeat(256_001)].join( + "\n", + ); + const completeTool = '[tool:read] {"path":"src/index.ts"}'; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: incompleteOverCapTool }, + { type: "text", text: completeTool }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain("src/index.ts"); + }); + + it("scrubs multiple incomplete over-cap tool blocks from done messages", async () => { + const firstOverCapTool = ["[tool:read]", "", "x".repeat(256_001)].join("\n"); + const secondOverCapTool = ["[tool:read]", "", "y".repeat(256_001)].join("\n"); + const visibleText = "Visible text after the tool-looking blocks."; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: firstOverCapTool }, + { type: "text", text: secondOverCapTool }, + { type: "text", text: visibleText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "text", text: visibleText }], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain("x".repeat(256_001)); + expect(JSON.stringify(events)).not.toContain("y".repeat(256_001)); + }); + + it("scrubs done-message over-cap blocks after visible text", async () => { + const intro = "Visible intro."; + const incompleteOverCapTool = ["[tool:read]", "", "x".repeat(256_001)].join( + "\n", + ); + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: intro }, + { type: "text", text: incompleteOverCapTool }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "text", text: intro }], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("scrubs split done-message over-cap blocks after visible text", async () => { + const intro = "Visible intro."; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: intro }, + { type: "text", text: "[tool:read]\n" }, + { type: "text", text: "x".repeat(256_001) }, + { type: "text", text: ["", ""].join("\n") }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "text", text: intro }], + stopReason: "stop", + }); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("preserves small complete tool calls after over-cap visible text", async () => { + const visibleText = `Visible intro ${"x".repeat(256_001)}`; + const toolText = '[tool:read] {"path":"src/index.ts"}'; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: visibleText }, + { type: "text", text: toolText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual(["done"]); + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [ + { type: "text", text: visibleText }, + { type: "text", text: toolText }, + ], + stopReason: "stop", + }); + }); + + it("does not leak over-cap buffers when stripped later tool blocks are followed by text", async () => { + const incompleteOverCapTool = ["[tool:read]", "", "x".repeat(256_001)].join( + "\n", + ); + const completeTool = '[tool:read] {"path":"src/index.ts"}'; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: incompleteOverCapTool }, + { type: "text", text: completeTool }, + { type: "text", text: "Visible text after the tool-looking blocks." }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + expect(JSON.stringify(events)).not.toContain("src/index.ts"); + expect(requireRecord(events[0], "done event").message).toMatchObject({ + role: "assistant", + content: [{ type: "text", text: "Visible text after the tool-looking blocks." }], + stopReason: "stop", + }); + }); + + it("preserves unallowed tool-looking text while scrubbing an over-cap allowed tool block", async () => { + const allowedOverCapTool = [ + "[tool:read]", + "", + "x".repeat(256_001), + "", + "", + ].join("\n"); + const unallowedToolText = '[tool:write] {"path":"keep-visible"}'; + const baseStreamFn: StreamFn = () => + createEventStream([ + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: allowedOverCapTool }, + { type: "text", text: unallowedToolText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(JSON.stringify(events)).toContain("[tool:write]"); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("flushes over-cap text for closed tool names that only prefix-match configured tools", async () => { + const rawToolText = [ + "[tool:read]", + "", + "x".repeat(256_001), + "", + "", + ].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: rawToolText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read_file" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toContain("[tool:read]"); + }); + + it("flushes long mixed text after a complete serialized tool-call prefix", async () => { + const rawText = ['[tool:read] {"path":"src/index.ts"}', "A".repeat(256_001)].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: rawText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toContain("AAAA"); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves visible suffix text after an over-cap JSON tool payload", async () => { + const visibleSuffix = "Visible answer after oversized JSON."; + const rawText = [`[tool:read] {"path":"${"x".repeat(256_001)}"}`, visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: rawText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + const textEvent = requireRecord(events[0], "text event"); + expect(String(textEvent.delta)).toBe(visibleSuffix); + expect(requireRecord(textEvent.partial, "text partial").content).toEqual([ + { type: "text", text: visibleSuffix }, + ]); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("reclassifies split over-cap mixed text and streams the visible suffix", async () => { + const toolPrefix = ["[tool:read]", "", "x".repeat(256_001)].join("\n"); + const visibleSuffix = "Visible answer after the tool-looking prefix."; + const rawText = [toolPrefix, "", "", visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: toolPrefix }, + { + type: "text_delta", + contentIndex: 0, + delta: ["", "", visibleSuffix].join("\n"), + }, + { type: "text_end", contentIndex: 0, content: rawText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves XML visible suffix after Unicode payload text", async () => { + const toolPrefix = ["[tool:read]", "", `${"x".repeat(256_001)}İ`].join("\n"); + const visibleSuffix = "Visible suffix after Unicode payload."; + const rawText = [toolPrefix, "", "", visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: toolPrefix }, + { + type: "text_delta", + contentIndex: 0, + delta: ["", "", visibleSuffix].join("\n"), + }, + { type: "text_end", contentIndex: 0, content: rawText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect(JSON.stringify(events)).not.toContain(""); + expect(JSON.stringify(events)).not.toContain(""); + }); + + it("scrubs reclassified mixed text from terminal error partials", async () => { + const toolPrefix = ["[tool:read]", "", "x".repeat(256_001)].join("\n"); + const visibleSuffix = "Visible answer before the stream error."; + const rawText = [toolPrefix, "", "", visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: toolPrefix }, + { + type: "text_delta", + contentIndex: 0, + delta: ["", "", visibleSuffix].join("\n"), + }, + { + type: "error", + partial: { content: [{ type: "text", text: rawText }] }, + error: { + content: [{ type: "text", text: rawText }], + message: "stream failed", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "error", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect( + requireRecord(requireRecord(events[1], "error event").partial, "error partial").content, + ).toEqual([{ type: "text", text: visibleSuffix }]); + expect( + requireRecord(requireRecord(events[1], "error event").error, "error record").content, + ).toEqual([{ type: "text", text: visibleSuffix }]); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves visible suffix text when the tool terminator arrives after the scan cap", async () => { + const toolPrefix = ["[tool:read]", "", "x".repeat(400_000)].join("\n"); + const visibleSuffix = "Visible answer after a very large tool-looking prefix."; + const rawText = [toolPrefix, "", "", visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: toolPrefix }, + { + type: "text_delta", + contentIndex: 0, + delta: ["", "", visibleSuffix].join("\n"), + }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves visible suffix text when the over-cap terminator is split across chunks", async () => { + const toolPrefix = ["[tool:read]", "", "x".repeat(400_000)].join("\n"); + const visibleSuffix = "Visible answer after a split terminator."; + const rawText = [toolPrefix, "", "", visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: toolPrefix }, + { type: "text_delta", contentIndex: 0, delta: "", "", visibleSuffix].join("\n"), + }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves long visible suffix text after an over-cap terminator", async () => { + const toolPrefix = ["[tool:read]", "", "x".repeat(400_000)].join("\n"); + const visibleSuffix = `Visible answer ${"y".repeat(70_000)}`; + const rawText = [toolPrefix, "", "", visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: toolPrefix }, + { + type: "text_delta", + contentIndex: 0, + delta: ["", "", visibleSuffix].join("\n"), + }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("does not duplicate visible suffix text when mixed over-cap events omit contentIndex", async () => { + const visibleSuffix = "Visible answer from an index-less stream."; + const rawText = [`[tool:read] {"path":"${"x".repeat(256_001)}"}`, visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", delta: rawText }, + { type: "text_end", content: rawText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "done", + ]); + expect(String(requireRecord(events[0], "text event").delta)).toBe(visibleSuffix); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("keeps partial snapshots current for multi-delta visible suffix text", async () => { + const firstVisible = "Visible answer "; + const secondVisible = "continues."; + const rawPrefix = `[tool:read] {"path":"${"x".repeat(256_001)}"}`; + const firstChunk = [rawPrefix, firstVisible].join("\n"); + const rawText = `${firstChunk}${secondVisible}`; + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: firstChunk }, + { + type: "text_delta", + contentIndex: 0, + delta: secondVisible, + partial: { content: [{ type: "text", text: rawText }] }, + }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawText }], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + const secondEvent = requireRecord(events[1], "second text event"); + expect(events.map((event) => (event as { type?: string }).type)).toEqual([ + "text_delta", + "text_delta", + "done", + ]); + expect(secondEvent.delta).toBe(secondVisible); + expect(requireRecord(secondEvent.partial, "second partial").content).toEqual([ + { type: "text", text: `${firstVisible}${secondVisible}` }, + ]); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves unrelated done-message text blocks when replacing a reclassified suffix", async () => { + const introText = "Intro text before the reclassified block."; + const visibleSuffix = "Visible suffix from the reclassified block."; + const rawToolText = [`[tool:read] {"path":"${"x".repeat(256_001)}"}`, visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", contentIndex: 0, delta: introText }, + { type: "text_delta", contentIndex: 1, delta: rawToolText }, + { type: "text_end", contentIndex: 1, content: rawToolText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: introText }, + { type: "text", text: rawToolText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + const doneMessage = requireRecord( + requireRecord(events.at(-1), "done event").message, + "done message", + ); + expect(doneMessage.content).toEqual([ + { type: "text", text: introText }, + { type: "text", text: visibleSuffix }, + ]); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("preserves later done-message text blocks when replacing an indexless reclassified suffix", async () => { + const visibleSuffix = "Visible suffix from the reclassified block."; + const laterText = "Additional visible answer text."; + const rawToolText = [`[tool:read] {"path":"${"x".repeat(256_001)}"}`, visibleSuffix].join("\n"); + const baseStreamFn: StreamFn = () => + createEventStream([ + { type: "text_delta", delta: rawToolText }, + { + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [ + { type: "text", text: rawToolText }, + { type: "text", text: laterText }, + ], + stopReason: "stop", + }, + }, + ]); + const wrapped = createPlainTextToolCallCompatWrapper(baseStreamFn); + const events: unknown[] = []; + + for await (const event of wrapped( + {} as never, + { tools: [{ name: "read" }] } as never, + {}, + ) as AsyncIterable) { + events.push(event); + } + + const doneMessage = requireRecord( + requireRecord(events.at(-1), "done event").message, + "done message", + ); + expect(doneMessage.content).toEqual([ + { type: "text", text: visibleSuffix }, + { type: "text", text: laterText }, + ]); + expect(JSON.stringify(events)).not.toContain("[tool:read]"); + }); + + it("keeps legacy bracketed XML parameter tool calls buffered for conversion", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const iterator = (await resolveStream(stream))[Symbol.asyncIterator](); + const rawToolText = [ + "[read]", + "", + "src/index.ts", + "", + "", + ].join("\n"); + + try { + source.push({ type: "start", partial: { content: [] } } as never); + expect((await nextEvent(iterator, "start")).type).toBe("start"); + + source.push({ + type: "text_delta", + contentIndex: 0, + delta: rawToolText, + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + } as never); + + const event = await nextEvent(iterator, "converted legacy bracketed XML tool call"); + expect(event.type).toBe("toolcall_start"); + } finally { + source.end(); + await iterator.return?.(); + } + }); + + it("keeps CRLF legacy bracketed XML parameter tool calls buffered for conversion", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const iterator = (await resolveStream(stream))[Symbol.asyncIterator](); + const rawToolText = [ + "[read]", + "", + "src/index.ts", + "", + "", + ].join("\r\n"); + + try { + source.push({ type: "start", partial: { content: [] } } as never); + expect((await nextEvent(iterator, "start")).type).toBe("start"); + + source.push({ + type: "text_delta", + contentIndex: 0, + delta: rawToolText, + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + } as never); + + const event = await nextEvent(iterator, "converted CRLF legacy XML tool call"); + expect(event.type).toBe("toolcall_start"); + } finally { + source.end(); + await iterator.return?.(); + } + }); + + it("keeps split XML function tool-call markers buffered for conversion", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const iterator = (await resolveStream(stream))[Symbol.asyncIterator](); + const rawToolText = [ + "", + "", + "src/index.ts", + "", + "", + ].join("\n"); + + try { + source.push({ type: "start", partial: { content: [] } } as never); + expect((await nextEvent(iterator, "start")).type).toBe("start"); + + source.push({ type: "text_delta", contentIndex: 0, delta: "<" } as never); + source.push({ + type: "text_delta", + contentIndex: 0, + delta: rawToolText.slice(1), + } as never); + source.push({ + type: "done", + reason: "stop", + message: { + role: "assistant", + content: [{ type: "text", text: rawToolText }], + stopReason: "stop", + }, + } as never); + + const event = await nextEvent(iterator, "converted split XML tool call"); + expect(event.type).toBe("toolcall_start"); + } finally { + source.end(); + await iterator.return?.(); + } + }); + + it("does not buffer normal final prose until done", async () => { + const { source, stream } = createControlledPlainTextToolCallCompatStream(); + const iterator = (await resolveStream(stream))[Symbol.asyncIterator](); + + try { + source.push({ type: "start", partial: { content: [] } } as never); + expect((await nextEvent(iterator, "start")).type).toBe("start"); + + source.push({ + type: "text_delta", + contentIndex: 0, + delta: "final answer starts here", + } as never); + + const event = await nextEvent(iterator, "normal final prose"); + expect(event).toMatchObject({ type: "text_delta", delta: "final answer starts here" }); + } finally { + source.push({ type: "done", reason: "stop", message: {} } as never); + source.end(); + await iterator.return?.(); + } + }); +}); + describe("stripTrailingAnthropicAssistantPrefillWhenThinking", () => { it("removes trailing assistant text turns when Anthropic thinking is enabled", () => { const payload = { diff --git a/src/plugin-sdk/provider-stream-shared.ts b/src/plugin-sdk/provider-stream-shared.ts index 0d4a46ed310..69404814779 100644 --- a/src/plugin-sdk/provider-stream-shared.ts +++ b/src/plugin-sdk/provider-stream-shared.ts @@ -1,11 +1,18 @@ import { randomUUID } from "node:crypto"; +import { + extractStandalonePlainTextToolCallText, + normalizePlainTextToolCallStreamEvents, + promoteStandalonePlainTextToolCallMessage, + scrubOverCapPlainTextToolCallMessage, + type PlainTextToolCallNameMatcher, + type PlainTextToolCallMessageNormalization, +} from "../../packages/tool-call-repair/src/index.js"; import type { StreamFn } from "../agents/runtime/index.js"; import { streamWithPayloadPatch } from "../llm/providers/stream-wrappers/stream-payload-utils.js"; import { streamSimple } from "../llm/stream.js"; import { createAssistantMessageEventStream } from "../llm/utils/event-stream.js"; import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js"; import type { ProviderWrapStreamFnContext } from "./plugin-entry.js"; -import { parseStandalonePlainTextToolCallBlocks } from "./tool-payload.js"; export type ProviderStreamWrapperFactory = | ((streamFn: StreamFn | undefined) => StreamFn | undefined) @@ -41,164 +48,6 @@ function resolveContextToolNames(context: Parameters[1]): Set return new Set(names); } -function matchesLiteralPrefix(text: string, literal: string): boolean { - return literal.startsWith(text) || text.startsWith(literal); -} - -function skipHorizontalWhitespace(text: string, start: number): number { - let cursor = start; - while (cursor < text.length && /[ \t]/.test(text[cursor] ?? "")) { - cursor += 1; - } - return cursor; -} - -function matchesAnyToolNamePrefix(text: string, toolNames: Set): boolean { - if (!text) { - return true; - } - for (const toolName of toolNames) { - if (toolName.startsWith(text) || text.startsWith(toolName)) { - return true; - } - } - return false; -} - -function couldStillBeJsonPayload(text: string, start: number): boolean { - let cursor = start; - while (cursor < text.length && /\s/.test(text[cursor] ?? "")) { - cursor += 1; - } - return cursor >= text.length || text[cursor] === "{"; -} - -function couldStillBeBracketedToolCall(text: string, toolNames: Set): boolean { - if (!text.startsWith("[")) { - return false; - } - - const toolPrefix = "[tool:"; - if (matchesLiteralPrefix(text, toolPrefix)) { - if (text.length <= toolPrefix.length) { - return true; - } - const nameStart = toolPrefix.length; - let cursor = nameStart; - while (cursor < text.length && text[cursor] !== "]") { - cursor += 1; - } - const name = text.slice(nameStart, cursor).trim(); - if (!matchesAnyToolNamePrefix(name, toolNames)) { - return false; - } - if (cursor >= text.length) { - return true; - } - if (text[cursor] !== "]") { - return false; - } - return couldStillBeJsonPayload(text, cursor + 1); - } - - let cursor = 1; - while (cursor < text.length && text[cursor] !== "\n" && text[cursor] !== "]") { - cursor += 1; - } - const firstLine = text.slice(1, cursor); - if (!matchesAnyToolNamePrefix(firstLine.trim(), toolNames)) { - return false; - } - if (cursor >= text.length) { - return true; - } - if (text[cursor] === "]") { - return couldStillBeJsonPayload(text, text[cursor + 1] === "\n" ? cursor + 2 : cursor + 1); - } - if (text[cursor] !== "\n") { - return false; - } - return couldStillBeJsonPayload(text, cursor + 1); -} - -function couldStillBeHarmonyToolCall(text: string, toolNames: Set): boolean { - const harmonyChannelPrefix = "<|channel|>"; - let cursor = 0; - if (matchesLiteralPrefix(text, harmonyChannelPrefix)) { - if (text.length <= harmonyChannelPrefix.length) { - return true; - } - cursor = harmonyChannelPrefix.length; - } - - const channelRest = text.slice(cursor); - const channelName = ["commentary", "analysis", "final"].find((marker) => - matchesLiteralPrefix(channelRest, marker), - ); - if (channelName) { - if (channelRest.length <= channelName.length) { - return true; - } - cursor += channelName.length; - } else if (cursor === 0) { - return false; - } else { - return false; - } - - const constraintMarker = " to="; - const constraintRest = text.slice(cursor); - if (matchesLiteralPrefix(constraintRest, constraintMarker)) { - if (constraintRest.length <= constraintMarker.length) { - return true; - } - cursor += constraintMarker.length; - const nameStart = cursor; - while (cursor < text.length && text[cursor] !== " " && text[cursor] !== "\n") { - cursor += 1; - } - const name = text.slice(nameStart, cursor).trim(); - if (!matchesAnyToolNamePrefix(name, toolNames)) { - return false; - } - } - - cursor = skipHorizontalWhitespace(text, cursor); - if (cursor >= text.length) { - return true; - } - const codeMarker = "code"; - const codeRest = text.slice(cursor); - if (matchesLiteralPrefix(codeRest, codeMarker)) { - if (codeRest.length <= codeMarker.length) { - return true; - } - cursor += codeMarker.length; - cursor = skipHorizontalWhitespace(text, cursor); - if (cursor >= text.length) { - return true; - } - } - const messageMarker = "<|message|>"; - const messageRest = text.slice(cursor); - if (matchesLiteralPrefix(messageRest, messageMarker)) { - return true; - } - return text[cursor] === "{"; -} - -function couldStillBePlainTextToolCall(text: string, toolNames: Set): boolean { - if (text.length > 256_000) { - return false; - } - const trimmed = text.trimStart(); - return ( - trimmed.length === 0 || - couldStillBeBracketedToolCall(trimmed, toolNames) || - couldStillBeHarmonyToolCall(trimmed, toolNames) - ); -} - function createSyntheticToolCallId(): string { return `call_${randomUUID().replace(/-/g, "").slice(0, 24)}`; } @@ -221,65 +70,18 @@ function promotePlainTextToolCalls( toolNames: Set, ): Record | undefined { const messageRecord = toRecord(message); - if (!messageRecord) { - return undefined; - } - if (!Array.isArray(messageRecord.content)) { - if (typeof messageRecord.content !== "string" || !messageRecord.content.trim()) { - return undefined; - } - const parsed = parseStandalonePlainTextToolCallBlocks(messageRecord.content, { - allowedToolNames: toolNames, - }); - if (!parsed) { - return undefined; - } - return { - ...messageRecord, - content: parsed.map(createPlainTextToolCallBlock), - stopReason: "toolUse", - }; - } if ( - messageRecord.content.some((block) => toRecord(block)?.type === "toolCall") || - messageRecord.content.length === 0 + Array.isArray(messageRecord?.content) && + messageRecord.content.some((block) => toRecord(block)?.type === "toolCall") ) { return undefined; } - - let promoted = false; - const nextContent: Array> = []; - for (const block of messageRecord.content) { - const blockRecord = toRecord(block); - if (!blockRecord) { - return undefined; - } - if (blockRecord.type !== "text") { - nextContent.push(blockRecord); - continue; - } - const text = typeof blockRecord.text === "string" ? blockRecord.text : ""; - if (!text.trim()) { - continue; - } - const parsed = parseStandalonePlainTextToolCallBlocks(text, { - allowedToolNames: toolNames, - }); - if (!parsed) { - return undefined; - } - nextContent.push(...parsed.map(createPlainTextToolCallBlock)); - promoted = true; - } - - if (!promoted) { - return undefined; - } - return { - ...messageRecord, - content: nextContent, - stopReason: "toolUse", - }; + return promoteStandalonePlainTextToolCallMessage({ + allowedToolNames: toolNames, + createToolCallBlock: (block, name) => createPlainTextToolCallBlock({ ...block, name }), + isRetainableNonTextBlock: () => true, + message, + }); } function emitPromotedToolCallEvents( @@ -302,6 +104,44 @@ function emitPromotedToolCallEvents( }); } +function extractPlainTextToolCallCandidate(message: unknown): string | undefined { + return extractStandalonePlainTextToolCallText({ + allowOtherNonTextBlocks: true, + message, + }); +} + +function createProviderToolNameMatcher(toolNames: Set): PlainTextToolCallNameMatcher { + return { + hasExactName: (name) => toolNames.has(name), + hasNamePrefix: (prefix) => { + for (const toolName of toolNames) { + if (toolName.startsWith(prefix)) { + return true; + } + } + return false; + }, + }; +} + +function normalizeProviderDoneMessage( + message: unknown, + toolNames: Set, + matcher: PlainTextToolCallNameMatcher, +): PlainTextToolCallMessageNormalization { + const scrubbedMessage = scrubOverCapPlainTextToolCallMessage({ + candidateText: extractPlainTextToolCallCandidate(message), + matcher, + message, + }); + if (scrubbedMessage) { + return { kind: "scrubbed", message: scrubbedMessage }; + } + const promotedMessage = promotePlainTextToolCalls(message, toolNames); + return promotedMessage ? { kind: "promoted", message: promotedMessage } : undefined; +} + function wrapPlainTextToolCallStream( source: ReturnType, context: Parameters[1], @@ -310,12 +150,11 @@ function wrapPlainTextToolCallStream( if (toolNames.size === 0) { return source; } + const matcher = createProviderToolNameMatcher(toolNames); const output = createAssistantMessageEventStream(); const stream = output as unknown as { push(event: unknown): void; end(): void }; void (async () => { - const bufferedTextEvents: unknown[] = []; - let bufferedText = ""; let ended = false; const endStream = () => { if (!ended) { @@ -323,54 +162,25 @@ function wrapPlainTextToolCallStream( stream.end(); } }; - const flushBufferedTextEvents = () => { - for (const event of bufferedTextEvents.splice(0)) { - stream.push(event); - } - bufferedText = ""; - }; try { - for await (const event of source as AsyncIterable) { - const record = toRecord(event); - const type = typeof record?.type === "string" ? record.type : ""; - - if (type === "text_start" || type === "text_delta" || type === "text_end") { - bufferedTextEvents.push(event); - if (typeof record?.delta === "string") { - bufferedText += record.delta; - } else if (typeof record?.content === "string" && !bufferedText) { - bufferedText = record.content; - } - if (!couldStillBePlainTextToolCall(bufferedText, toolNames)) { - flushBufferedTextEvents(); - } - continue; - } - - if (type === "done") { - const promotedMessage = promotePlainTextToolCalls(record?.message, toolNames); - if (promotedMessage) { - bufferedTextEvents.splice(0); - bufferedText = ""; - emitPromotedToolCallEvents(stream, promotedMessage); - stream.push({ ...record, reason: "toolUse", message: promotedMessage }); - } else { - flushBufferedTextEvents(); - stream.push(event); - } - endStream(); - return; - } - - flushBufferedTextEvents(); + const normalizedEvents = normalizePlainTextToolCallStreamEvents( + source as AsyncIterable, + { + createPromotedToolCallEvents: (message) => { + const events: unknown[] = []; + emitPromotedToolCallEvents({ push: (event: unknown) => events.push(event) }, message); + return events; + }, + matcher, + normalizeDoneMessage: ({ message }) => + normalizeProviderDoneMessage(message, toolNames, matcher), + stopAfterDone: true, + }, + ); + for await (const event of normalizedEvents) { stream.push(event); - if (type === "error") { - endStream(); - return; - } } - flushBufferedTextEvents(); } catch (error) { stream.push({ type: "error", diff --git a/src/plugin-sdk/tool-payload.test.ts b/src/plugin-sdk/tool-payload.test.ts index ec29df68e5c..828807f214d 100644 --- a/src/plugin-sdk/tool-payload.test.ts +++ b/src/plugin-sdk/tool-payload.test.ts @@ -121,6 +121,126 @@ describe("parseStandalonePlainTextToolCallBlocks", () => { ]); }); + it("parses serialized parameter XML tool calls", () => { + const firstRaw = [ + "[tool:exec]", + "", + 'cat /proc/mounts 2>/dev/null | grep -i "libra|rav|openclaw" | head -20', + "", + "", + ].join("\n"); + const secondRaw = [ + "", + "", + 'find / -maxdepth 4 -type d \\( -name "ravdb" -o -name "librav" \\) 2>/dev/null | head -20', + "", + "", + ].join("\n"); + const raw = [firstRaw, "", secondRaw].join("\n"); + const blocks = parseStandalonePlainTextToolCallBlocks(raw, { + allowedToolNames: ["exec"], + }); + + expect(blocks).toEqual([ + { + name: "exec", + arguments: { + command: 'cat /proc/mounts 2>/dev/null | grep -i "libra|rav|openclaw" | head -20', + }, + start: 0, + end: firstRaw.length, + raw: firstRaw, + }, + { + name: "exec", + arguments: { + command: + 'find / -maxdepth 4 -type d \\( -name "ravdb" -o -name "librav" \\) 2>/dev/null | head -20', + }, + start: firstRaw.length + 2, + end: raw.length, + raw: secondRaw, + }, + ]); + }); + + it("preserves whitespace inside serialized XML parameter values", () => { + const raw = [ + "", + "", + " first line", + " second line", + "", + "", + "", + ].join("\n"); + const blocks = parseStandalonePlainTextToolCallBlocks(raw, { + allowedToolNames: ["write"], + }); + + expect(blocks?.[0]?.arguments).toEqual({ + content: " first line\n second line\n", + }); + }); + + it("rejects serialized XML parameter calls without a function close", () => { + const raw = ["", "", "pwd", ""].join("\n"); + + expect( + parseStandalonePlainTextToolCallBlocks(raw, { + allowedToolNames: ["exec"], + }), + ).toBeNull(); + }); + + it("parses legacy tool-prefixed XML parameter calls without a function close", () => { + const raw = ["[tool:exec]", "", "pwd", ""].join("\n"); + + expect( + parseStandalonePlainTextToolCallBlocks(raw, { + allowedToolNames: ["exec"], + }), + ).toEqual([ + { + arguments: { command: "pwd" }, + end: raw.length, + name: "exec", + raw, + start: 0, + }, + ]); + }); + + it("finds XML parameter close tags without lowercased string offsets", () => { + const dottedCapitalI = "\u0130"; + const raw = [ + "", + "", + dottedCapitalI, + "", + "", + ].join("\n"); + const blocks = parseStandalonePlainTextToolCallBlocks(raw, { + allowedToolNames: ["write"], + }); + + expect(blocks?.[0]?.arguments).toEqual({ content: dottedCapitalI }); + }); + + it("rejects XML parameter blocks whose cumulative payload exceeds the cap", () => { + const firstParameter = ["", "alpha", ""].join("\n"); + const secondParameter = ["", "beta", ""].join("\n"); + const raw = ["", firstParameter, secondParameter, ""].join("\n"); + const maxPayloadBytes = Math.max(firstParameter.length, secondParameter.length) + 1; + + expect( + parseStandalonePlainTextToolCallBlocks(raw, { + allowedToolNames: ["write"], + maxPayloadBytes, + }), + ).toBeNull(); + }); + it("respects allowed tool names for Harmony calls", () => { const blocks = parseStandalonePlainTextToolCallBlocks( 'commentary to=write code {"path":"/tmp/file.txt","content":"x"}', @@ -170,6 +290,7 @@ describe("stripPlainTextToolCallBlocks", () => { "", 'cat /proc/mounts 2>/dev/null | grep -i "libra|rav|openclaw" | head -20', "", + "", "", "", "", @@ -184,4 +305,59 @@ describe("stripPlainTextToolCallBlocks", () => { ), ).toBe("before\n\nafter"); }); + + it("keeps legacy bracketed XML parameter blocks scrubbed", () => { + expect( + stripPlainTextToolCallBlocks( + [ + "before", + "[exec]", + "", + "pwd", + "", + "", + "after", + ].join("\n"), + ), + ).toBe("before\nafter"); + }); + + it("preserves incomplete XML parameter blocks when stripping visible text", () => { + const text = ["before", "[exec]", "", "pwd", "", "after"].join( + "\n", + ); + + expect(stripPlainTextToolCallBlocks(text)).toBe(text); + }); + + it("strips legacy tool-prefixed XML parameter blocks without a function close", () => { + expect( + stripPlainTextToolCallBlocks( + ["before", "[tool:exec]", "", "pwd", "", "after"].join("\n"), + ), + ).toBe("before\nafter"); + }); + + it("strips oversized XML parameter tool calls without promoting them", () => { + const largeValue = "x".repeat(140_000); + const block = [ + "", + "", + largeValue, + "", + "", + largeValue, + "", + "", + ].join("\n"); + + expect( + parseStandalonePlainTextToolCallBlocks(block, { + allowedToolNames: ["write"], + }), + ).toBeNull(); + expect(stripPlainTextToolCallBlocks(["before", block, "after"].join("\n"))).toBe( + "before\nafter", + ); + }); }); diff --git a/src/plugin-sdk/tool-payload.ts b/src/plugin-sdk/tool-payload.ts index 40bff0f067e..a569f6acbf4 100644 --- a/src/plugin-sdk/tool-payload.ts +++ b/src/plugin-sdk/tool-payload.ts @@ -1,3 +1,32 @@ +import { + parseStandalonePlainTextToolCallBlocks as parseStandaloneRepairToolCallBlocks, + stripPlainTextToolCallBlocks as stripRepairToolCallBlocks, +} from "../../packages/tool-call-repair/src/index.js"; + +export type PlainTextToolCallBlock = { + arguments: Record; + end: number; + name: string; + raw: string; + start: number; +}; + +export type PlainTextToolCallParseOptions = { + allowedToolNames?: Iterable; + maxPayloadBytes?: number; +}; + +export function parseStandalonePlainTextToolCallBlocks( + text: string, + options?: PlainTextToolCallParseOptions, +): PlainTextToolCallBlock[] | null { + return parseStandaloneRepairToolCallBlocks(text, options); +} + +export function stripPlainTextToolCallBlocks(text: string): string { + return stripRepairToolCallBlocks(text); +} + type ToolPayloadTextBlock = { type: "text"; text: string; @@ -41,378 +70,3 @@ export function extractToolPayload(result: ToolPayloadCarrier | null | undefined return text; } } - -export type PlainTextToolCallBlock = { - arguments: Record; - end: number; - name: string; - raw: string; - start: number; -}; - -export type PlainTextToolCallParseOptions = { - allowedToolNames?: Iterable; - maxPayloadBytes?: number; -}; - -const DEFAULT_MAX_PLAIN_TEXT_TOOL_PAYLOAD_BYTES = 256_000; -const END_TOOL_REQUEST = "[END_TOOL_REQUEST]"; -const HARMONY_CHANNEL_MARKER = "<|channel|>"; -const HARMONY_MESSAGE_MARKER = "<|message|>"; -const HARMONY_CALL_MARKER = "<|call|>"; -const XMLISH_PARAMETER_CLOSE = ""; - -type PlainTextToolCallOpening = { - end: number; - name: string; - requiresClosing: boolean; -}; - -function isToolNameChar(char: string | undefined): boolean { - return Boolean(char && /[A-Za-z0-9_-]/.test(char)); -} - -function skipHorizontalWhitespace(text: string, start: number): number { - let index = start; - while (index < text.length && (text[index] === " " || text[index] === "\t")) { - index += 1; - } - return index; -} - -function skipWhitespace(text: string, start: number): number { - let index = start; - while (index < text.length && /\s/.test(text[index] ?? "")) { - index += 1; - } - return index; -} - -function consumeLineBreak(text: string, start: number): number | null { - if (text[start] === "\r") { - return text[start + 1] === "\n" ? start + 2 : start + 1; - } - if (text[start] === "\n") { - return start + 1; - } - return null; -} - -function parseBracketOpening(text: string, start: number): PlainTextToolCallOpening | null { - if (text[start] !== "[") { - return null; - } - let cursor = start + 1; - if (text.startsWith("tool:", cursor)) { - cursor += "tool:".length; - const nameStart = cursor; - while (isToolNameChar(text[cursor])) { - cursor += 1; - } - if (cursor === nameStart || text[cursor] !== "]") { - return null; - } - return { end: cursor + 1, name: text.slice(nameStart, cursor), requiresClosing: false }; - } - const nameStart = cursor; - while (isToolNameChar(text[cursor])) { - cursor += 1; - } - if (cursor === nameStart || text[cursor] !== "]") { - return null; - } - const name = text.slice(nameStart, cursor); - cursor += 1; - cursor = skipHorizontalWhitespace(text, cursor); - const afterLineBreak = consumeLineBreak(text, cursor); - if (afterLineBreak === null) { - return null; - } - return { end: afterLineBreak, name, requiresClosing: true }; -} - -function parseHarmonyOpening(text: string, start: number): PlainTextToolCallOpening | null { - let cursor = start; - if (text.startsWith(HARMONY_CHANNEL_MARKER, cursor)) { - cursor += HARMONY_CHANNEL_MARKER.length; - } - const channelStart = cursor; - while (/[A-Za-z_]/.test(text[cursor] ?? "")) { - cursor += 1; - } - const channel = text.slice(channelStart, cursor); - if (channel !== "commentary" && channel !== "analysis" && channel !== "final") { - return null; - } - cursor = skipHorizontalWhitespace(text, cursor); - if (!text.startsWith("to=", cursor)) { - return null; - } - cursor += 3; - const nameStart = cursor; - while (isToolNameChar(text[cursor])) { - cursor += 1; - } - if (cursor === nameStart) { - return null; - } - const name = text.slice(nameStart, cursor); - cursor = skipHorizontalWhitespace(text, cursor); - if (!text.startsWith("code", cursor)) { - return null; - } - cursor += 4; - cursor = skipWhitespace(text, cursor); - if (text.startsWith(HARMONY_MESSAGE_MARKER, cursor)) { - cursor = skipWhitespace(text, cursor + HARMONY_MESSAGE_MARKER.length); - } - return { end: cursor, name, requiresClosing: false }; -} - -function parseXmlishFunctionOpening(text: string, start: number): PlainTextToolCallOpening | null { - const match = /^\s*/i.exec(text.slice(start)); - if (!match?.[1]) { - return null; - } - return { end: start + match[0].length, name: match[1], requiresClosing: false }; -} - -function parseOpening(text: string, start: number): PlainTextToolCallOpening | null { - return parseBracketOpening(text, start) ?? parseHarmonyOpening(text, start); -} - -function consumeJsonObject( - text: string, - start: number, - maxPayloadBytes: number, -): { end: number; value: Record } | null { - const cursor = skipWhitespace(text, start); - if (text[cursor] !== "{") { - return null; - } - let depth = 0; - let inString = false; - let escaped = false; - for (let index = cursor; index < text.length; index += 1) { - const char = text[index]; - if (index + 1 - cursor > maxPayloadBytes) { - return null; - } - if (inString) { - if (escaped) { - escaped = false; - } else if (char === "\\") { - escaped = true; - } else if (char === '"') { - inString = false; - } - continue; - } - if (char === '"') { - inString = true; - continue; - } - if (char === "{") { - depth += 1; - } else if (char === "}") { - depth -= 1; - if (depth === 0) { - const rawJson = text.slice(cursor, index + 1); - try { - const parsed = JSON.parse(rawJson) as unknown; - if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) { - return null; - } - return { end: index + 1, value: parsed as Record }; - } catch { - return null; - } - } - } - } - return null; -} - -function parseClosing(text: string, start: number, name: string): number | null { - const cursor = skipWhitespace(text, start); - if (text.startsWith(END_TOOL_REQUEST, cursor)) { - return cursor + END_TOOL_REQUEST.length; - } - const namedClosing = `[/${name}]`; - if (text.startsWith(namedClosing, cursor)) { - return cursor + namedClosing.length; - } - return null; -} - -function parseOptionalHarmonyClosing(text: string, start: number): number { - const cursor = skipWhitespace(text, start); - if (text.startsWith(HARMONY_CALL_MARKER, cursor)) { - return cursor + HARMONY_CALL_MARKER.length; - } - return start; -} - -function parsePlainTextToolCallBlockAt( - text: string, - start: number, - options?: PlainTextToolCallParseOptions, -): PlainTextToolCallBlock | null { - const opening = parseOpening(text, start); - if (!opening) { - return null; - } - const allowedToolNames = options?.allowedToolNames - ? new Set(options.allowedToolNames) - : undefined; - if (allowedToolNames && !allowedToolNames.has(opening.name)) { - return null; - } - const payload = consumeJsonObject( - text, - opening.end, - options?.maxPayloadBytes ?? DEFAULT_MAX_PLAIN_TEXT_TOOL_PAYLOAD_BYTES, - ); - if (!payload) { - return null; - } - const closingEnd = opening.requiresClosing - ? parseClosing(text, payload.end, opening.name) - : parseOptionalHarmonyClosing(text, payload.end); - if (closingEnd === null) { - return null; - } - return { - arguments: payload.value, - end: closingEnd, - name: opening.name, - raw: text.slice(start, closingEnd), - start, - }; -} - -function consumeXmlishParameterBlock( - text: string, - start: number, - maxPayloadBytes: number, -): number | null { - const cursor = skipWhitespace(text, start); - const openMatch = /^\s*/i.exec(text.slice(cursor)); - if (!openMatch) { - return null; - } - const payloadStart = cursor + openMatch[0].length; - const closeStart = text.toLowerCase().indexOf(XMLISH_PARAMETER_CLOSE, payloadStart); - if (closeStart === -1 || closeStart + XMLISH_PARAMETER_CLOSE.length - cursor > maxPayloadBytes) { - return null; - } - return closeStart + XMLISH_PARAMETER_CLOSE.length; -} - -function consumeXmlishParameterBlocks( - text: string, - start: number, - maxPayloadBytes: number, -): number | null { - let cursor = start; - let consumed = false; - while (true) { - const next = consumeXmlishParameterBlock(text, cursor, maxPayloadBytes); - if (next === null) { - break; - } - if (next - start > maxPayloadBytes) { - return null; - } - cursor = next; - consumed = true; - } - return consumed ? cursor : null; -} - -function consumeOptionalXmlishFunctionClose(text: string, start: number): number { - const cursor = skipWhitespace(text, start); - return text.slice(cursor).toLowerCase().startsWith("") - ? cursor + "".length - : start; -} - -function parseXmlishPlainTextToolCallBlockEndAt( - text: string, - start: number, - options?: PlainTextToolCallParseOptions, -): number | null { - const opening = parseBracketOpening(text, start) ?? parseXmlishFunctionOpening(text, start); - if (!opening) { - return null; - } - const allowedToolNames = options?.allowedToolNames - ? new Set(options.allowedToolNames) - : undefined; - if (allowedToolNames && !allowedToolNames.has(opening.name)) { - return null; - } - const payloadEnd = consumeXmlishParameterBlocks( - text, - opening.end, - options?.maxPayloadBytes ?? DEFAULT_MAX_PLAIN_TEXT_TOOL_PAYLOAD_BYTES, - ); - if (payloadEnd === null) { - return null; - } - return consumeOptionalXmlishFunctionClose(text, payloadEnd); -} - -export function parseStandalonePlainTextToolCallBlocks( - text: string, - options?: PlainTextToolCallParseOptions, -): PlainTextToolCallBlock[] | null { - const blocks: PlainTextToolCallBlock[] = []; - let cursor = skipWhitespace(text, 0); - while (cursor < text.length) { - const block = parsePlainTextToolCallBlockAt(text, cursor, options); - if (!block) { - return null; - } - blocks.push(block); - cursor = skipWhitespace(text, block.end); - } - return blocks.length > 0 ? blocks : null; -} - -export function stripPlainTextToolCallBlocks(text: string): string { - if ( - !text || - (!/\[(?:tool:)?[A-Za-z0-9_-]+\]/.test(text) && - !/(?:^|\n)\s*(?:<\|channel\|>)?(?:commentary|analysis|final)\s+to=/.test(text) && - !/(?:^|\n)\s*/i.test(text)) - ) { - return text; - } - let result = ""; - let cursor = 0; - let index = 0; - while (index < text.length) { - const lineStart = index === 0 || text[index - 1] === "\n"; - if (!lineStart) { - index += 1; - continue; - } - const blockStart = skipHorizontalWhitespace(text, index); - const block = parsePlainTextToolCallBlockAt(text, blockStart); - const blockEnd = block?.end ?? parseXmlishPlainTextToolCallBlockEndAt(text, blockStart); - if (blockEnd === null) { - index += 1; - continue; - } - result += text.slice(cursor, index); - cursor = blockEnd; - const afterBlockLineBreak = consumeLineBreak(text, cursor); - if (afterBlockLineBreak !== null) { - cursor = afterBlockLineBreak; - } - index = cursor; - } - result += text.slice(cursor); - return result; -} diff --git a/src/shared/text/plain-text-tool-call-blocks.ts b/src/shared/text/plain-text-tool-call-blocks.ts index 3ecf9dd82b0..f4a6a994b8f 100644 --- a/src/shared/text/plain-text-tool-call-blocks.ts +++ b/src/shared/text/plain-text-tool-call-blocks.ts @@ -1 +1 @@ -export { stripPlainTextToolCallBlocks } from "../../plugin-sdk/tool-payload.js"; +export { stripPlainTextToolCallBlocks } from "../../../packages/tool-call-repair/src/index.js";