refactor: share message content block visitor

This commit is contained in:
Peter Steinberger
2026-04-20 14:53:31 +01:00
parent 17c77f1307
commit 85c1c59c5f
3 changed files with 27 additions and 27 deletions

View File

@@ -1,5 +1,6 @@
import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import { visitObjectContentBlocks } from "../../../shared/message-content-blocks.js";
import { normalizeLowercaseStringOrEmpty } from "../../../shared/string-coerce.js";
import { validateAnthropicTurns, validateGeminiTurns } from "../../pi-embedded-helpers.js";
import { sanitizeToolUseResultPairing } from "../../session-transcript-repair.js";
@@ -586,20 +587,10 @@ function trimWhitespaceFromToolCallNamesInMessage(
message: unknown,
allowedToolNames?: Set<string>,
): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
visitObjectContentBlocks(message, (block) => {
const typedBlock = block as { type?: unknown; name?: unknown; id?: unknown };
if (!isToolCallBlockType(typedBlock.type)) {
continue;
return;
}
const rawId = typeof typedBlock.id === "string" ? typedBlock.id : undefined;
if (typeof typedBlock.name === "string") {
@@ -607,13 +598,13 @@ function trimWhitespaceFromToolCallNamesInMessage(
if (normalized !== typedBlock.name) {
typedBlock.name = normalized;
}
continue;
return;
}
const inferred = inferToolNameFromToolCallId(rawId, allowedToolNames);
if (inferred) {
typedBlock.name = inferred;
}
}
});
normalizeToolCallIdsInMessage(message);
}

View File

@@ -1,6 +1,7 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import { streamWithPayloadPatch } from "../agents/pi-embedded-runner/stream-payload-utils.js";
import { visitObjectContentBlocks } from "../shared/message-content-blocks.js";
import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js";
import type { ProviderWrapStreamFnContext } from "./plugin-entry.js";
@@ -64,25 +65,15 @@ export function decodeHtmlEntitiesInObject(value: unknown): unknown {
}
function decodeToolCallArgumentsHtmlEntitiesInMessage(message: unknown): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
visitObjectContentBlocks(message, (block) => {
const typedBlock = block as { type?: unknown; arguments?: unknown };
if (typedBlock.type !== "toolCall" || !typedBlock.arguments) {
continue;
return;
}
if (typeof typedBlock.arguments === "object") {
typedBlock.arguments = decodeHtmlEntitiesInObject(typedBlock.arguments);
}
}
});
}
function wrapStreamDecodeToolCallArgumentHtmlEntities(

View File

@@ -0,0 +1,18 @@
export function visitObjectContentBlocks(
message: unknown,
visitor: (block: Record<string, unknown>) => void,
): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
visitor(block as Record<string, unknown>);
}
}