refactor: split embedded attempt helpers

This commit is contained in:
Peter Steinberger
2026-03-27 00:19:09 +00:00
parent 48ae976333
commit 0d9e4f20d5
6 changed files with 1517 additions and 1452 deletions

View File

@@ -0,0 +1,100 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import type { OpenClawConfig } from "../../../config/config.js";
import { normalizeProviderId } from "../../model-selection.js";
export function isOllamaCompatProvider(model: {
provider?: string;
baseUrl?: string;
api?: string;
}): boolean {
const providerId = normalizeProviderId(model.provider ?? "");
if (providerId === "ollama") {
return true;
}
if (!model.baseUrl) {
return false;
}
try {
const parsed = new URL(model.baseUrl);
const hostname = parsed.hostname.toLowerCase();
const isLocalhost =
hostname === "localhost" ||
hostname === "127.0.0.1" ||
hostname === "::1" ||
hostname === "[::1]";
if (isLocalhost && parsed.port === "11434") {
return true;
}
// Allow remote/LAN Ollama OpenAI-compatible endpoints when the provider id
// itself indicates Ollama usage (e.g. "my-ollama").
const providerHintsOllama = providerId.includes("ollama");
const isOllamaPort = parsed.port === "11434";
const isOllamaCompatPath = parsed.pathname === "/" || /^\/v1\/?$/i.test(parsed.pathname);
return providerHintsOllama && isOllamaPort && isOllamaCompatPath;
} catch {
return false;
}
}
export function resolveOllamaCompatNumCtxEnabled(params: {
config?: OpenClawConfig;
providerId?: string;
}): boolean {
const providerId = params.providerId?.trim();
if (!providerId) {
return true;
}
const providers = params.config?.models?.providers;
if (!providers) {
return true;
}
const direct = providers[providerId];
if (direct) {
return direct.injectNumCtxForOpenAICompat ?? true;
}
const normalized = normalizeProviderId(providerId);
for (const [candidateId, candidate] of Object.entries(providers)) {
if (normalizeProviderId(candidateId) === normalized) {
return candidate.injectNumCtxForOpenAICompat ?? true;
}
}
return true;
}
export function shouldInjectOllamaCompatNumCtx(params: {
model: { api?: string; provider?: string; baseUrl?: string };
config?: OpenClawConfig;
providerId?: string;
}): boolean {
if (params.model.api !== "openai-completions") {
return false;
}
if (!isOllamaCompatProvider(params.model)) {
return false;
}
return resolveOllamaCompatNumCtxEnabled({
config: params.config,
providerId: params.providerId,
});
}
export function wrapOllamaCompatNumCtx(baseFn: StreamFn | undefined, numCtx: number): StreamFn {
const streamFn = baseFn ?? streamSimple;
return (model, context, options) =>
streamFn(model, context, {
...options,
onPayload: (payload: unknown) => {
if (!payload || typeof payload !== "object") {
return options?.onPayload?.(payload, model);
}
const payloadRecord = payload as Record<string, unknown>;
if (!payloadRecord.options || typeof payloadRecord.options !== "object") {
payloadRecord.options = {};
}
(payloadRecord.options as Record<string, unknown>).num_ctx = numCtx;
return options?.onPayload?.(payload, model);
},
});
}

View File

@@ -0,0 +1,168 @@
import type { OpenClawConfig } from "../../../config/config.js";
import type {
PluginHookAgentContext,
PluginHookBeforeAgentStartResult,
PluginHookBeforePromptBuildResult,
} from "../../../plugins/types.js";
import { isCronSessionKey, isSubagentSessionKey } from "../../../routing/session-key.js";
import { joinPresentTextSegments } from "../../../shared/text/join-segments.js";
import { resolveEffectiveToolFsWorkspaceOnly } from "../../tool-fs-policy.js";
import type { CompactEmbeddedPiSessionParams } from "../compact.js";
import { buildEmbeddedCompactionRuntimeContext } from "../compaction-runtime-context.js";
import { log } from "../logger.js";
import { shouldInjectHeartbeatPromptForTrigger } from "./trigger-policy.js";
import type { EmbeddedRunAttemptParams } from "./types.js";
export type PromptBuildHookRunner = {
hasHooks: (hookName: "before_prompt_build" | "before_agent_start") => boolean;
runBeforePromptBuild: (
event: { prompt: string; messages: unknown[] },
ctx: PluginHookAgentContext,
) => Promise<PluginHookBeforePromptBuildResult | undefined>;
runBeforeAgentStart: (
event: { prompt: string; messages: unknown[] },
ctx: PluginHookAgentContext,
) => Promise<PluginHookBeforeAgentStartResult | undefined>;
};
export async function resolvePromptBuildHookResult(params: {
prompt: string;
messages: unknown[];
hookCtx: PluginHookAgentContext;
hookRunner?: PromptBuildHookRunner | null;
legacyBeforeAgentStartResult?: PluginHookBeforeAgentStartResult;
}): Promise<PluginHookBeforePromptBuildResult> {
const promptBuildResult = params.hookRunner?.hasHooks("before_prompt_build")
? await params.hookRunner
.runBeforePromptBuild(
{
prompt: params.prompt,
messages: params.messages,
},
params.hookCtx,
)
.catch((hookErr: unknown) => {
log.warn(`before_prompt_build hook failed: ${String(hookErr)}`);
return undefined;
})
: undefined;
const legacyResult =
params.legacyBeforeAgentStartResult ??
(params.hookRunner?.hasHooks("before_agent_start")
? await params.hookRunner
.runBeforeAgentStart(
{
prompt: params.prompt,
messages: params.messages,
},
params.hookCtx,
)
.catch((hookErr: unknown) => {
log.warn(
`before_agent_start hook (legacy prompt build path) failed: ${String(hookErr)}`,
);
return undefined;
})
: undefined);
return {
systemPrompt: promptBuildResult?.systemPrompt ?? legacyResult?.systemPrompt,
prependContext: joinPresentTextSegments([
promptBuildResult?.prependContext,
legacyResult?.prependContext,
]),
prependSystemContext: joinPresentTextSegments([
promptBuildResult?.prependSystemContext,
legacyResult?.prependSystemContext,
]),
appendSystemContext: joinPresentTextSegments([
promptBuildResult?.appendSystemContext,
legacyResult?.appendSystemContext,
]),
};
}
export function resolvePromptModeForSession(sessionKey?: string): "minimal" | "full" {
if (!sessionKey) {
return "full";
}
return isSubagentSessionKey(sessionKey) || isCronSessionKey(sessionKey) ? "minimal" : "full";
}
export function shouldInjectHeartbeatPrompt(params: {
isDefaultAgent: boolean;
trigger?: EmbeddedRunAttemptParams["trigger"];
}): boolean {
return params.isDefaultAgent && shouldInjectHeartbeatPromptForTrigger(params.trigger);
}
export function resolveAttemptFsWorkspaceOnly(params: {
config?: OpenClawConfig;
sessionAgentId: string;
}): boolean {
return resolveEffectiveToolFsWorkspaceOnly({
cfg: params.config,
agentId: params.sessionAgentId,
});
}
export function prependSystemPromptAddition(params: {
systemPrompt: string;
systemPromptAddition?: string;
}): string {
if (!params.systemPromptAddition) {
return params.systemPrompt;
}
return `${params.systemPromptAddition}\n\n${params.systemPrompt}`;
}
/** Build runtime context passed into context-engine afterTurn hooks. */
export function buildAfterTurnRuntimeContext(params: {
attempt: Pick<
EmbeddedRunAttemptParams,
| "sessionKey"
| "messageChannel"
| "messageProvider"
| "agentAccountId"
| "currentChannelId"
| "currentThreadTs"
| "currentMessageId"
| "config"
| "skillsSnapshot"
| "senderIsOwner"
| "senderId"
| "provider"
| "modelId"
| "thinkLevel"
| "reasoningLevel"
| "bashElevated"
| "extraSystemPrompt"
| "ownerNumbers"
| "authProfileId"
>;
workspaceDir: string;
agentDir: string;
}): Partial<CompactEmbeddedPiSessionParams> {
return buildEmbeddedCompactionRuntimeContext({
sessionKey: params.attempt.sessionKey,
messageChannel: params.attempt.messageChannel,
messageProvider: params.attempt.messageProvider,
agentAccountId: params.attempt.agentAccountId,
currentChannelId: params.attempt.currentChannelId,
currentThreadTs: params.attempt.currentThreadTs,
currentMessageId: params.attempt.currentMessageId,
authProfileId: params.attempt.authProfileId,
workspaceDir: params.workspaceDir,
agentDir: params.agentDir,
config: params.attempt.config,
skillsSnapshot: params.attempt.skillsSnapshot,
senderIsOwner: params.attempt.senderIsOwner,
senderId: params.attempt.senderId,
provider: params.attempt.provider,
modelId: params.attempt.modelId,
thinkLevel: params.attempt.thinkLevel,
reasoningLevel: params.attempt.reasoningLevel,
bashElevated: params.attempt.bashElevated,
extraSystemPrompt: params.attempt.extraSystemPrompt,
ownerNumbers: params.attempt.ownerNumbers,
});
}

View File

@@ -0,0 +1,221 @@
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { log } from "../logger.js";
const SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE = "openclaw.sessions_yield_interrupt";
const SESSIONS_YIELD_CONTEXT_CUSTOM_TYPE = "openclaw.sessions_yield";
const SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS = process.env.OPENCLAW_TEST_FAST === "1" ? 250 : 2_000;
// Persist a hidden context reminder so the next turn knows why the runner stopped.
export function buildSessionsYieldContextMessage(message: string): string {
return `${message}\n\n[Context: The previous turn ended intentionally via sessions_yield while waiting for a follow-up event.]`;
}
export async function waitForSessionsYieldAbortSettle(params: {
settlePromise: Promise<void> | null;
runId: string;
sessionId: string;
}): Promise<void> {
if (!params.settlePromise) {
return;
}
let timeout: NodeJS.Timeout | undefined;
const outcome = await Promise.race([
params.settlePromise
.then(() => "settled" as const)
.catch((err) => {
log.warn(
`sessions_yield abort settle failed: runId=${params.runId} sessionId=${params.sessionId} err=${String(err)}`,
);
return "errored" as const;
}),
new Promise<"timed_out">((resolve) => {
timeout = setTimeout(() => resolve("timed_out"), SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS);
}),
]);
if (timeout) {
clearTimeout(timeout);
}
if (outcome === "timed_out") {
log.warn(
`sessions_yield abort settle timed out: runId=${params.runId} sessionId=${params.sessionId} timeoutMs=${SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS}`,
);
}
}
// Return a synthetic aborted response so pi-agent-core unwinds without a real provider call.
export function createYieldAbortedResponse(model: {
api?: string;
provider?: string;
id?: string;
}): {
[Symbol.asyncIterator]: () => AsyncGenerator<never, void, unknown>;
result: () => Promise<{
role: "assistant";
content: Array<{ type: "text"; text: string }>;
stopReason: "aborted";
api: string;
provider: string;
model: string;
usage: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
totalTokens: number;
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
total: number;
};
};
timestamp: number;
}>;
} {
const message = {
role: "assistant" as const,
content: [{ type: "text" as const, text: "" }],
stopReason: "aborted" as const,
api: model.api ?? "",
provider: model.provider ?? "",
model: model.id ?? "",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
timestamp: Date.now(),
};
return {
async *[Symbol.asyncIterator]() {},
result: async () => message,
};
}
// Queue a hidden steering message so pi-agent-core injects it before the next
// LLM call once the current assistant turn finishes executing its tool calls.
export function queueSessionsYieldInterruptMessage(activeSession: {
agent: { steer: (message: AgentMessage) => void };
}) {
activeSession.agent.steer({
role: "custom",
customType: SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE,
content: "[sessions_yield interrupt]",
display: false,
details: { source: "sessions_yield" },
timestamp: Date.now(),
});
}
// Append the caller-provided yield payload as a hidden session message once the run is idle.
export async function persistSessionsYieldContextMessage(
activeSession: {
sendCustomMessage: (
message: {
customType: string;
content: string;
display: boolean;
details?: Record<string, unknown>;
},
options?: { triggerTurn?: boolean },
) => Promise<void>;
},
message: string,
) {
await activeSession.sendCustomMessage(
{
customType: SESSIONS_YIELD_CONTEXT_CUSTOM_TYPE,
content: buildSessionsYieldContextMessage(message),
display: false,
details: { source: "sessions_yield", message },
},
{ triggerTurn: false },
);
}
// Remove the synthetic yield interrupt + aborted assistant entry from the live transcript.
export function stripSessionsYieldArtifacts(activeSession: {
messages: AgentMessage[];
agent: { replaceMessages: (messages: AgentMessage[]) => void };
sessionManager?: unknown;
}) {
const strippedMessages = activeSession.messages.slice();
while (strippedMessages.length > 0) {
const last = strippedMessages.at(-1) as
| AgentMessage
| { role?: string; customType?: string; stopReason?: string };
if (last?.role === "assistant" && "stopReason" in last && last.stopReason === "aborted") {
strippedMessages.pop();
continue;
}
if (
last?.role === "custom" &&
"customType" in last &&
last.customType === SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE
) {
strippedMessages.pop();
continue;
}
break;
}
if (strippedMessages.length !== activeSession.messages.length) {
activeSession.agent.replaceMessages(strippedMessages);
}
const sessionManager = activeSession.sessionManager as
| {
fileEntries?: Array<{
type?: string;
id?: string;
parentId?: string | null;
message?: { role?: string; stopReason?: string };
customType?: string;
}>;
byId?: Map<string, { id: string }>;
leafId?: string | null;
_rewriteFile?: () => void;
}
| undefined;
const fileEntries = sessionManager?.fileEntries;
const byId = sessionManager?.byId;
if (!fileEntries || !byId) {
return;
}
let changed = false;
while (fileEntries.length > 1) {
const last = fileEntries.at(-1);
if (!last || last.type === "session") {
break;
}
const isYieldAbortAssistant =
last.type === "message" &&
last.message?.role === "assistant" &&
last.message?.stopReason === "aborted";
const isYieldInterruptMessage =
last.type === "custom_message" && last.customType === SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE;
if (!isYieldAbortAssistant && !isYieldInterruptMessage) {
break;
}
fileEntries.pop();
if (last.id) {
byId.delete(last.id);
}
sessionManager.leafId = last.parentId ?? null;
changed = true;
}
if (changed) {
sessionManager._rewriteFile?.();
}
}

View File

@@ -0,0 +1,383 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import { normalizeProviderId } from "../../model-selection.js";
import { log } from "../logger.js";
function isToolCallBlockType(type: unknown): boolean {
return type === "toolCall" || type === "toolUse" || type === "functionCall";
}
function extractBalancedJsonPrefix(raw: string): string | null {
let start = 0;
while (start < raw.length && /\s/.test(raw[start] ?? "")) {
start += 1;
}
const startChar = raw[start];
if (startChar !== "{" && startChar !== "[") {
return null;
}
let depth = 0;
let inString = false;
let escaped = false;
for (let i = start; i < raw.length; i += 1) {
const char = raw[i];
if (char === undefined) {
break;
}
if (inString) {
if (escaped) {
escaped = false;
} else if (char === "\\") {
escaped = true;
} else if (char === '"') {
inString = false;
}
continue;
}
if (char === '"') {
inString = true;
continue;
}
if (char === "{" || char === "[") {
depth += 1;
continue;
}
if (char === "}" || char === "]") {
depth -= 1;
if (depth === 0) {
return raw.slice(start, i + 1);
}
}
}
return null;
}
const MAX_TOOLCALL_REPAIR_BUFFER_CHARS = 64_000;
const MAX_TOOLCALL_REPAIR_TRAILING_CHARS = 3;
const TOOLCALL_REPAIR_ALLOWED_TRAILING_RE = /^[^\s{}[\]":,\\]{1,3}$/;
function shouldAttemptMalformedToolCallRepair(partialJson: string, delta: string): boolean {
if (/[}\]]/.test(delta)) {
return true;
}
const trimmedDelta = delta.trim();
return (
trimmedDelta.length > 0 &&
trimmedDelta.length <= MAX_TOOLCALL_REPAIR_TRAILING_CHARS &&
/[}\]]/.test(partialJson)
);
}
type ToolCallArgumentRepair = {
args: Record<string, unknown>;
trailingSuffix: string;
};
function tryParseMalformedToolCallArguments(raw: string): ToolCallArgumentRepair | undefined {
if (!raw.trim()) {
return undefined;
}
try {
JSON.parse(raw);
return undefined;
} catch {
const jsonPrefix = extractBalancedJsonPrefix(raw);
if (!jsonPrefix) {
return undefined;
}
const suffix = raw.slice(raw.indexOf(jsonPrefix) + jsonPrefix.length).trim();
if (
suffix.length === 0 ||
suffix.length > MAX_TOOLCALL_REPAIR_TRAILING_CHARS ||
!TOOLCALL_REPAIR_ALLOWED_TRAILING_RE.test(suffix)
) {
return undefined;
}
try {
const parsed = JSON.parse(jsonPrefix) as unknown;
return parsed && typeof parsed === "object" && !Array.isArray(parsed)
? { args: parsed as Record<string, unknown>, trailingSuffix: suffix }
: undefined;
} catch {
return undefined;
}
}
}
function repairToolCallArgumentsInMessage(
message: unknown,
contentIndex: number,
repairedArgs: Record<string, unknown>,
): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
const block = content[contentIndex];
if (!block || typeof block !== "object") {
return;
}
const typedBlock = block as { type?: unknown; arguments?: unknown };
if (!isToolCallBlockType(typedBlock.type)) {
return;
}
typedBlock.arguments = repairedArgs;
}
function clearToolCallArgumentsInMessage(message: unknown, contentIndex: number): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
const block = content[contentIndex];
if (!block || typeof block !== "object") {
return;
}
const typedBlock = block as { type?: unknown; arguments?: unknown };
if (!isToolCallBlockType(typedBlock.type)) {
return;
}
typedBlock.arguments = {};
}
function repairMalformedToolCallArgumentsInMessage(
message: unknown,
repairedArgsByIndex: Map<number, Record<string, unknown>>,
): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
for (const [index, repairedArgs] of repairedArgsByIndex.entries()) {
repairToolCallArgumentsInMessage(message, index, repairedArgs);
}
}
function wrapStreamRepairMalformedToolCallArguments(
stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
const partialJsonByIndex = new Map<number, string>();
const repairedArgsByIndex = new Map<number, Record<string, unknown>>();
const disabledIndices = new Set<number>();
const loggedRepairIndices = new Set<number>();
const originalResult = stream.result.bind(stream);
stream.result = async () => {
const message = await originalResult();
repairMalformedToolCallArgumentsInMessage(message, repairedArgsByIndex);
partialJsonByIndex.clear();
repairedArgsByIndex.clear();
disabledIndices.clear();
loggedRepairIndices.clear();
return message;
};
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
function () {
const iterator = originalAsyncIterator();
return {
async next() {
const result = await iterator.next();
if (!result.done && result.value && typeof result.value === "object") {
const event = result.value as {
type?: unknown;
contentIndex?: unknown;
delta?: unknown;
partial?: unknown;
message?: unknown;
toolCall?: unknown;
};
if (
typeof event.contentIndex === "number" &&
Number.isInteger(event.contentIndex) &&
event.type === "toolcall_delta" &&
typeof event.delta === "string"
) {
if (disabledIndices.has(event.contentIndex)) {
return result;
}
const nextPartialJson =
(partialJsonByIndex.get(event.contentIndex) ?? "") + event.delta;
if (nextPartialJson.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) {
partialJsonByIndex.delete(event.contentIndex);
repairedArgsByIndex.delete(event.contentIndex);
disabledIndices.add(event.contentIndex);
return result;
}
partialJsonByIndex.set(event.contentIndex, nextPartialJson);
if (shouldAttemptMalformedToolCallRepair(nextPartialJson, event.delta)) {
const repair = tryParseMalformedToolCallArguments(nextPartialJson);
if (repair) {
repairedArgsByIndex.set(event.contentIndex, repair.args);
repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repair.args);
repairToolCallArgumentsInMessage(event.message, event.contentIndex, repair.args);
if (!loggedRepairIndices.has(event.contentIndex)) {
loggedRepairIndices.add(event.contentIndex);
log.warn(
`repairing Kimi tool call arguments after ${repair.trailingSuffix.length} trailing chars`,
);
}
} else {
repairedArgsByIndex.delete(event.contentIndex);
clearToolCallArgumentsInMessage(event.partial, event.contentIndex);
clearToolCallArgumentsInMessage(event.message, event.contentIndex);
}
}
}
if (
typeof event.contentIndex === "number" &&
Number.isInteger(event.contentIndex) &&
event.type === "toolcall_end"
) {
const repairedArgs = repairedArgsByIndex.get(event.contentIndex);
if (repairedArgs) {
if (event.toolCall && typeof event.toolCall === "object") {
(event.toolCall as { arguments?: unknown }).arguments = repairedArgs;
}
repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repairedArgs);
repairToolCallArgumentsInMessage(event.message, event.contentIndex, repairedArgs);
}
partialJsonByIndex.delete(event.contentIndex);
disabledIndices.delete(event.contentIndex);
loggedRepairIndices.delete(event.contentIndex);
}
}
return result;
},
async return(value?: unknown) {
return iterator.return?.(value) ?? { done: true as const, value: undefined };
},
async throw(error?: unknown) {
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
},
};
};
return stream;
}
export function wrapStreamFnRepairMalformedToolCallArguments(baseFn: StreamFn): StreamFn {
return (model, context, options) => {
const maybeStream = baseFn(model, context, options);
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
return Promise.resolve(maybeStream).then((stream) =>
wrapStreamRepairMalformedToolCallArguments(stream),
);
}
return wrapStreamRepairMalformedToolCallArguments(maybeStream);
};
}
export function shouldRepairMalformedAnthropicToolCallArguments(provider?: string): boolean {
return normalizeProviderId(provider ?? "") === "kimi";
}
const HTML_ENTITY_RE = /&(?:amp|lt|gt|quot|apos|#39|#x[0-9a-f]+|#\d+);/i;
function decodeHtmlEntities(value: string): string {
return value
.replace(/&amp;/gi, "&")
.replace(/&quot;/gi, '"')
.replace(/&#39;/gi, "'")
.replace(/&apos;/gi, "'")
.replace(/&lt;/gi, "<")
.replace(/&gt;/gi, ">")
.replace(/&#x([0-9a-f]+);/gi, (_, hex) => String.fromCodePoint(Number.parseInt(hex, 16)))
.replace(/&#(\d+);/gi, (_, dec) => String.fromCodePoint(Number.parseInt(dec, 10)));
}
export function decodeHtmlEntitiesInObject(obj: unknown): unknown {
if (typeof obj === "string") {
return HTML_ENTITY_RE.test(obj) ? decodeHtmlEntities(obj) : obj;
}
if (Array.isArray(obj)) {
return obj.map(decodeHtmlEntitiesInObject);
}
if (obj && typeof obj === "object") {
const result: Record<string, unknown> = {};
for (const [key, val] of Object.entries(obj as Record<string, unknown>)) {
result[key] = decodeHtmlEntitiesInObject(val);
}
return result;
}
return obj;
}
function decodeXaiToolCallArgumentsInMessage(message: unknown): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
const typedBlock = block as { type?: unknown; arguments?: unknown };
if (typedBlock.type !== "toolCall" || !typedBlock.arguments) {
continue;
}
if (typeof typedBlock.arguments === "object") {
typedBlock.arguments = decodeHtmlEntitiesInObject(typedBlock.arguments);
}
}
}
function wrapStreamDecodeXaiToolCallArguments(
stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
const originalResult = stream.result.bind(stream);
stream.result = async () => {
const message = await originalResult();
decodeXaiToolCallArgumentsInMessage(message);
return message;
};
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
function () {
const iterator = originalAsyncIterator();
return {
async next() {
const result = await iterator.next();
if (!result.done && result.value && typeof result.value === "object") {
const event = result.value as { partial?: unknown; message?: unknown };
decodeXaiToolCallArgumentsInMessage(event.partial);
decodeXaiToolCallArgumentsInMessage(event.message);
}
return result;
},
async return(value?: unknown) {
return iterator.return?.(value) ?? { done: true as const, value: undefined };
},
async throw(error?: unknown) {
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
},
};
};
return stream;
}
export function wrapStreamFnDecodeXaiToolCallArguments(baseFn: StreamFn): StreamFn {
return (model, context, options) => {
const maybeStream = baseFn(model, context, options);
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
return Promise.resolve(maybeStream).then((stream) =>
wrapStreamDecodeXaiToolCallArguments(stream),
);
}
return wrapStreamDecodeXaiToolCallArguments(maybeStream);
};
}

View File

@@ -0,0 +1,588 @@
import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import { validateAnthropicTurns, validateGeminiTurns } from "../../pi-embedded-helpers.js";
import { sanitizeToolUseResultPairing } from "../../session-transcript-repair.js";
import { normalizeToolName } from "../../tool-policy.js";
import type { TranscriptPolicy } from "../../transcript-policy.js";
function resolveCaseInsensitiveAllowedToolName(
rawName: string,
allowedToolNames?: Set<string>,
): string | null {
if (!allowedToolNames || allowedToolNames.size === 0) {
return null;
}
const folded = rawName.toLowerCase();
let caseInsensitiveMatch: string | null = null;
for (const name of allowedToolNames) {
if (name.toLowerCase() !== folded) {
continue;
}
if (caseInsensitiveMatch && caseInsensitiveMatch !== name) {
return null;
}
caseInsensitiveMatch = name;
}
return caseInsensitiveMatch;
}
function resolveExactAllowedToolName(
rawName: string,
allowedToolNames?: Set<string>,
): string | null {
if (!allowedToolNames || allowedToolNames.size === 0) {
return null;
}
if (allowedToolNames.has(rawName)) {
return rawName;
}
const normalized = normalizeToolName(rawName);
if (allowedToolNames.has(normalized)) {
return normalized;
}
return (
resolveCaseInsensitiveAllowedToolName(rawName, allowedToolNames) ??
resolveCaseInsensitiveAllowedToolName(normalized, allowedToolNames)
);
}
function buildStructuredToolNameCandidates(rawName: string): string[] {
const trimmed = rawName.trim();
if (!trimmed) {
return [];
}
const candidates: string[] = [];
const seen = new Set<string>();
const addCandidate = (value: string) => {
const candidate = value.trim();
if (!candidate || seen.has(candidate)) {
return;
}
seen.add(candidate);
candidates.push(candidate);
};
addCandidate(trimmed);
addCandidate(normalizeToolName(trimmed));
const normalizedDelimiter = trimmed.replace(/\//g, ".");
addCandidate(normalizedDelimiter);
addCandidate(normalizeToolName(normalizedDelimiter));
const segments = normalizedDelimiter
.split(".")
.map((segment) => segment.trim())
.filter(Boolean);
if (segments.length > 1) {
for (let index = 1; index < segments.length; index += 1) {
const suffix = segments.slice(index).join(".");
addCandidate(suffix);
addCandidate(normalizeToolName(suffix));
}
}
return candidates;
}
function resolveStructuredAllowedToolName(
rawName: string,
allowedToolNames?: Set<string>,
): string | null {
if (!allowedToolNames || allowedToolNames.size === 0) {
return null;
}
const candidateNames = buildStructuredToolNameCandidates(rawName);
for (const candidate of candidateNames) {
if (allowedToolNames.has(candidate)) {
return candidate;
}
}
for (const candidate of candidateNames) {
const caseInsensitiveMatch = resolveCaseInsensitiveAllowedToolName(candidate, allowedToolNames);
if (caseInsensitiveMatch) {
return caseInsensitiveMatch;
}
}
return null;
}
function inferToolNameFromToolCallId(
rawId: string | undefined,
allowedToolNames?: Set<string>,
): string | null {
if (!rawId || !allowedToolNames || allowedToolNames.size === 0) {
return null;
}
const id = rawId.trim();
if (!id) {
return null;
}
const candidateTokens = new Set<string>();
const addToken = (value: string) => {
const trimmed = value.trim();
if (!trimmed) {
return;
}
candidateTokens.add(trimmed);
candidateTokens.add(trimmed.replace(/[:._/-]\d+$/, ""));
candidateTokens.add(trimmed.replace(/\d+$/, ""));
const normalizedDelimiter = trimmed.replace(/\//g, ".");
candidateTokens.add(normalizedDelimiter);
candidateTokens.add(normalizedDelimiter.replace(/[:._-]\d+$/, ""));
candidateTokens.add(normalizedDelimiter.replace(/\d+$/, ""));
for (const prefixPattern of [/^functions?[._-]?/i, /^tools?[._-]?/i]) {
const stripped = normalizedDelimiter.replace(prefixPattern, "");
if (stripped !== normalizedDelimiter) {
candidateTokens.add(stripped);
candidateTokens.add(stripped.replace(/[:._-]\d+$/, ""));
candidateTokens.add(stripped.replace(/\d+$/, ""));
}
}
};
const preColon = id.split(":")[0] ?? id;
for (const seed of [id, preColon]) {
addToken(seed);
}
let singleMatch: string | null = null;
for (const candidate of candidateTokens) {
const matched = resolveStructuredAllowedToolName(candidate, allowedToolNames);
if (!matched) {
continue;
}
if (singleMatch && singleMatch !== matched) {
return null;
}
singleMatch = matched;
}
return singleMatch;
}
function looksLikeMalformedToolNameCounter(rawName: string): boolean {
const normalizedDelimiter = rawName.trim().replace(/\//g, ".");
return (
/^(?:functions?|tools?)[._-]?/i.test(normalizedDelimiter) &&
/(?:[:._-]\d+|\d+)$/.test(normalizedDelimiter)
);
}
function normalizeToolCallNameForDispatch(
rawName: string,
allowedToolNames?: Set<string>,
rawToolCallId?: string,
): string {
const trimmed = rawName.trim();
if (!trimmed) {
return inferToolNameFromToolCallId(rawToolCallId, allowedToolNames) ?? rawName;
}
if (!allowedToolNames || allowedToolNames.size === 0) {
return trimmed;
}
const exact = resolveExactAllowedToolName(trimmed, allowedToolNames);
if (exact) {
return exact;
}
const inferredFromName = inferToolNameFromToolCallId(trimmed, allowedToolNames);
if (inferredFromName) {
return inferredFromName;
}
if (looksLikeMalformedToolNameCounter(trimmed)) {
return trimmed;
}
return resolveStructuredAllowedToolName(trimmed, allowedToolNames) ?? trimmed;
}
function isToolCallBlockType(type: unknown): boolean {
return type === "toolCall" || type === "toolUse" || type === "functionCall";
}
const REPLAY_TOOL_CALL_NAME_MAX_CHARS = 64;
type ReplayToolCallBlock = {
type?: unknown;
id?: unknown;
name?: unknown;
input?: unknown;
arguments?: unknown;
};
type ReplayToolCallSanitizeReport = {
messages: AgentMessage[];
droppedAssistantMessages: number;
};
type AnthropicToolResultContentBlock = {
type?: unknown;
toolUseId?: unknown;
};
function isReplayToolCallBlock(block: unknown): block is ReplayToolCallBlock {
if (!block || typeof block !== "object") {
return false;
}
return isToolCallBlockType((block as { type?: unknown }).type);
}
function replayToolCallHasInput(block: ReplayToolCallBlock): boolean {
const hasInput = "input" in block ? block.input !== undefined && block.input !== null : false;
const hasArguments =
"arguments" in block ? block.arguments !== undefined && block.arguments !== null : false;
return hasInput || hasArguments;
}
function replayToolCallNonEmptyString(value: unknown): value is string {
return typeof value === "string" && value.trim().length > 0;
}
function resolveReplayToolCallName(
rawName: string,
rawId: string,
allowedToolNames?: Set<string>,
): string | null {
if (rawName.length > REPLAY_TOOL_CALL_NAME_MAX_CHARS * 2) {
return null;
}
const normalized = normalizeToolCallNameForDispatch(rawName, allowedToolNames, rawId);
const trimmed = normalized.trim();
if (!trimmed || trimmed.length > REPLAY_TOOL_CALL_NAME_MAX_CHARS || /\s/.test(trimmed)) {
return null;
}
if (!allowedToolNames || allowedToolNames.size === 0) {
return trimmed;
}
return resolveExactAllowedToolName(trimmed, allowedToolNames);
}
function sanitizeReplayToolCallInputs(
messages: AgentMessage[],
allowedToolNames?: Set<string>,
): ReplayToolCallSanitizeReport {
let changed = false;
let droppedAssistantMessages = 0;
const out: AgentMessage[] = [];
for (const message of messages) {
if (!message || typeof message !== "object" || message.role !== "assistant") {
out.push(message);
continue;
}
if (!Array.isArray(message.content)) {
out.push(message);
continue;
}
const nextContent: typeof message.content = [];
let messageChanged = false;
for (const block of message.content) {
if (!isReplayToolCallBlock(block)) {
nextContent.push(block);
continue;
}
const replayBlock = block as ReplayToolCallBlock;
if (!replayToolCallHasInput(replayBlock) || !replayToolCallNonEmptyString(replayBlock.id)) {
changed = true;
messageChanged = true;
continue;
}
const rawName = typeof replayBlock.name === "string" ? replayBlock.name : "";
const resolvedName = resolveReplayToolCallName(rawName, replayBlock.id, allowedToolNames);
if (!resolvedName) {
changed = true;
messageChanged = true;
continue;
}
if (replayBlock.name !== resolvedName) {
nextContent.push({ ...(block as object), name: resolvedName } as typeof block);
changed = true;
messageChanged = true;
continue;
}
nextContent.push(block);
}
if (messageChanged) {
changed = true;
if (nextContent.length > 0) {
out.push({ ...message, content: nextContent });
} else {
droppedAssistantMessages += 1;
}
continue;
}
out.push(message);
}
return {
messages: changed ? out : messages,
droppedAssistantMessages,
};
}
function sanitizeAnthropicReplayToolResults(messages: AgentMessage[]): AgentMessage[] {
let changed = false;
const out: AgentMessage[] = [];
for (let index = 0; index < messages.length; index += 1) {
const message = messages[index];
if (!message || typeof message !== "object" || message.role !== "user") {
out.push(message);
continue;
}
if (!Array.isArray(message.content)) {
out.push(message);
continue;
}
const previous = messages[index - 1];
const validToolUseIds = new Set<string>();
if (previous && typeof previous === "object" && previous.role === "assistant") {
const previousContent = (previous as { content?: unknown }).content;
if (Array.isArray(previousContent)) {
for (const block of previousContent) {
if (!block || typeof block !== "object") {
continue;
}
const typedBlock = block as { type?: unknown; id?: unknown };
if (typedBlock.type !== "toolUse" || typeof typedBlock.id !== "string") {
continue;
}
const trimmedId = typedBlock.id.trim();
if (trimmedId) {
validToolUseIds.add(trimmedId);
}
}
}
}
const nextContent = message.content.filter((block) => {
if (!block || typeof block !== "object") {
return true;
}
const typedBlock = block as AnthropicToolResultContentBlock;
if (typedBlock.type !== "toolResult" || typeof typedBlock.toolUseId !== "string") {
return true;
}
return validToolUseIds.size > 0 && validToolUseIds.has(typedBlock.toolUseId);
});
if (nextContent.length === message.content.length) {
out.push(message);
continue;
}
changed = true;
if (nextContent.length > 0) {
out.push({ ...message, content: nextContent });
continue;
}
out.push({
...message,
content: [{ type: "text", text: "[tool results omitted]" }],
} as AgentMessage);
}
return changed ? out : messages;
}
function normalizeToolCallIdsInMessage(message: unknown): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
const usedIds = new Set<string>();
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
const typedBlock = block as { type?: unknown; id?: unknown };
if (!isToolCallBlockType(typedBlock.type) || typeof typedBlock.id !== "string") {
continue;
}
const trimmedId = typedBlock.id.trim();
if (!trimmedId) {
continue;
}
usedIds.add(trimmedId);
}
let fallbackIndex = 1;
const assignedIds = new Set<string>();
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
const typedBlock = block as { type?: unknown; id?: unknown };
if (!isToolCallBlockType(typedBlock.type)) {
continue;
}
if (typeof typedBlock.id === "string") {
const trimmedId = typedBlock.id.trim();
if (trimmedId) {
if (!assignedIds.has(trimmedId)) {
if (typedBlock.id !== trimmedId) {
typedBlock.id = trimmedId;
}
assignedIds.add(trimmedId);
continue;
}
}
}
let fallbackId = "";
while (!fallbackId || usedIds.has(fallbackId) || assignedIds.has(fallbackId)) {
fallbackId = `call_auto_${fallbackIndex++}`;
}
typedBlock.id = fallbackId;
usedIds.add(fallbackId);
assignedIds.add(fallbackId);
}
}
function trimWhitespaceFromToolCallNamesInMessage(
message: unknown,
allowedToolNames?: Set<string>,
): void {
if (!message || typeof message !== "object") {
return;
}
const content = (message as { content?: unknown }).content;
if (!Array.isArray(content)) {
return;
}
for (const block of content) {
if (!block || typeof block !== "object") {
continue;
}
const typedBlock = block as { type?: unknown; name?: unknown; id?: unknown };
if (!isToolCallBlockType(typedBlock.type)) {
continue;
}
const rawId = typeof typedBlock.id === "string" ? typedBlock.id : undefined;
if (typeof typedBlock.name === "string") {
const normalized = normalizeToolCallNameForDispatch(typedBlock.name, allowedToolNames, rawId);
if (normalized !== typedBlock.name) {
typedBlock.name = normalized;
}
continue;
}
const inferred = inferToolNameFromToolCallId(rawId, allowedToolNames);
if (inferred) {
typedBlock.name = inferred;
}
}
normalizeToolCallIdsInMessage(message);
}
function wrapStreamTrimToolCallNames(
stream: ReturnType<typeof streamSimple>,
allowedToolNames?: Set<string>,
): ReturnType<typeof streamSimple> {
const originalResult = stream.result.bind(stream);
stream.result = async () => {
const message = await originalResult();
trimWhitespaceFromToolCallNamesInMessage(message, allowedToolNames);
return message;
};
const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
(stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
function () {
const iterator = originalAsyncIterator();
return {
async next() {
const result = await iterator.next();
if (!result.done && result.value && typeof result.value === "object") {
const event = result.value as {
partial?: unknown;
message?: unknown;
};
trimWhitespaceFromToolCallNamesInMessage(event.partial, allowedToolNames);
trimWhitespaceFromToolCallNamesInMessage(event.message, allowedToolNames);
}
return result;
},
async return(value?: unknown) {
return iterator.return?.(value) ?? { done: true as const, value: undefined };
},
async throw(error?: unknown) {
return iterator.throw?.(error) ?? { done: true as const, value: undefined };
},
};
};
return stream;
}
export function wrapStreamFnTrimToolCallNames(
baseFn: StreamFn,
allowedToolNames?: Set<string>,
): StreamFn {
return (model, context, options) => {
const maybeStream = baseFn(model, context, options);
if (maybeStream && typeof maybeStream === "object" && "then" in maybeStream) {
return Promise.resolve(maybeStream).then((stream) =>
wrapStreamTrimToolCallNames(stream, allowedToolNames),
);
}
return wrapStreamTrimToolCallNames(maybeStream, allowedToolNames);
};
}
export function wrapStreamFnSanitizeMalformedToolCalls(
baseFn: StreamFn,
allowedToolNames?: Set<string>,
transcriptPolicy?: Pick<TranscriptPolicy, "validateGeminiTurns" | "validateAnthropicTurns">,
): StreamFn {
return (model, context, options) => {
const ctx = context as unknown as { messages?: unknown };
const messages = ctx?.messages;
if (!Array.isArray(messages)) {
return baseFn(model, context, options);
}
const sanitized = sanitizeReplayToolCallInputs(messages as AgentMessage[], allowedToolNames);
if (sanitized.messages === messages) {
return baseFn(model, context, options);
}
let nextMessages = sanitizeToolUseResultPairing(sanitized.messages, {
preserveErroredAssistantResults: true,
});
if (transcriptPolicy?.validateAnthropicTurns) {
nextMessages = sanitizeAnthropicReplayToolResults(nextMessages);
}
if (sanitized.droppedAssistantMessages > 0 || transcriptPolicy?.validateAnthropicTurns) {
if (transcriptPolicy?.validateGeminiTurns) {
nextMessages = validateGeminiTurns(nextMessages);
}
if (transcriptPolicy?.validateAnthropicTurns) {
nextMessages = validateAnthropicTurns(nextMessages);
}
}
const nextContext = {
...(context as unknown as Record<string, unknown>),
messages: nextMessages,
} as unknown;
return baseFn(model, nextContext as typeof context, options);
};
}

File diff suppressed because it is too large Load Diff