mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-20 05:31:30 +00:00
refactor: split embedded attempt helpers
This commit is contained in:
100
src/agents/pi-embedded-runner/run/attempt.ollama-compat.ts
Normal file
100
src/agents/pi-embedded-runner/run/attempt.ollama-compat.ts
Normal file
@@ -0,0 +1,100 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import type { OpenClawConfig } from "../../../config/config.js";
|
||||
import { normalizeProviderId } from "../../model-selection.js";
|
||||
|
||||
/**
 * Decide whether a model entry should be treated as an Ollama
 * OpenAI-compatible endpoint.
 *
 * A model qualifies when any of the following holds:
 * - its normalized provider id is exactly `"ollama"`;
 * - its `baseUrl` points at a loopback host on port 11434;
 * - its normalized provider id contains `"ollama"` (e.g. "my-ollama") and its
 *   `baseUrl` uses port 11434 with a path of `/` or `/v1` (optional trailing
 *   slash), which allows remote/LAN endpoints.
 *
 * @param model Provider id and base URL of the model (`api` is not consulted).
 * @returns `true` when the model matches one of the rules above; `false`
 *   otherwise, including when `baseUrl` is missing or cannot be parsed.
 */
export function isOllamaCompatProvider(model: {
  provider?: string;
  baseUrl?: string;
  api?: string;
}): boolean {
  const normalizedProvider = normalizeProviderId(model.provider ?? "");
  if (normalizedProvider === "ollama") {
    return true;
  }
  if (!model.baseUrl) {
    return false;
  }

  try {
    const url = new URL(model.baseUrl);
    // Every remaining rule requires this port, so reject anything else up front.
    if (url.port !== "11434") {
      return false;
    }
    const host = url.hostname.toLowerCase();
    if (["localhost", "127.0.0.1", "::1", "[::1]"].includes(host)) {
      return true;
    }
    // Non-loopback hosts additionally need a provider id hinting at Ollama
    // and a root or /v1 path.
    if (!normalizedProvider.includes("ollama")) {
      return false;
    }
    return url.pathname === "/" || /^\/v1\/?$/i.test(url.pathname);
  } catch {
    return false;
  }
}
|
||||
|
||||
/**
 * Resolve whether `num_ctx` injection is enabled for a provider.
 *
 * The provider is looked up in `config.models.providers`, first by its trimmed
 * id and then by comparing normalized ids. The matching entry's
 * `injectNumCtxForOpenAICompat` flag is returned.
 *
 * @param params.config Application config; may be omitted.
 * @param params.providerId Provider id to look up; surrounding whitespace is ignored.
 * @returns The configured flag, or `true` when the id is blank, no providers
 *   are configured, no entry matches, or the matching entry leaves the flag unset.
 */
export function resolveOllamaCompatNumCtxEnabled(params: {
  config?: OpenClawConfig;
  providerId?: string;
}): boolean {
  const requestedId = params.providerId?.trim();
  const providerTable = params.config?.models?.providers;
  if (!requestedId || !providerTable) {
    return true;
  }

  const exactEntry = providerTable[requestedId];
  if (exactEntry) {
    return exactEntry.injectNumCtxForOpenAICompat ?? true;
  }

  // Fall back to the first configured id that normalizes to the same value.
  const wantedKey = normalizeProviderId(requestedId);
  const aliasEntry = Object.entries(providerTable).find(
    ([candidateId]) => normalizeProviderId(candidateId) === wantedKey,
  );
  return aliasEntry?.[1].injectNumCtxForOpenAICompat ?? true;
}
|
||||
|
||||
/**
 * Decide whether `num_ctx` should be injected into requests for a model.
 *
 * Injection applies only when the model uses the `"openai-completions"` API,
 * is recognized by `isOllamaCompatProvider`, and the provider's configuration
 * (via `resolveOllamaCompatNumCtxEnabled`) has not disabled it.
 *
 * @param params.model Model descriptor (api, provider, baseUrl).
 * @param params.config Application config used for the per-provider flag.
 * @param params.providerId Provider id used for the config lookup.
 * @returns `true` when `num_ctx` should be injected.
 */
export function shouldInjectOllamaCompatNumCtx(params: {
  model: { api?: string; provider?: string; baseUrl?: string };
  config?: OpenClawConfig;
  providerId?: string;
}): boolean {
  const { model, config, providerId } = params;
  const eligible = model.api === "openai-completions" && isOllamaCompatProvider(model);
  return eligible ? resolveOllamaCompatNumCtxEnabled({ config, providerId }) : false;
}
|
||||
|
||||
/**
 * Wrap a stream function so every object payload it reports through
 * `onPayload` gets `options.num_ctx` set to `numCtx`.
 *
 * The payload is mutated in place: a missing or non-object `options` field is
 * replaced with a fresh object before `num_ctx` is written. Non-object
 * payloads are left untouched. In both cases the caller's own `onPayload`
 * (if any) is invoked afterwards and its return value is passed through.
 *
 * @param baseFn Stream function to wrap; falls back to `streamSimple` when undefined.
 * @param numCtx Value written to `options.num_ctx`.
 * @returns A stream function with the same signature as `baseFn`.
 */
export function wrapOllamaCompatNumCtx(baseFn: StreamFn | undefined, numCtx: number): StreamFn {
  const delegate = baseFn ?? streamSimple;
  return (model, context, options) => {
    const onPayload = (payload: unknown) => {
      if (payload && typeof payload === "object") {
        const record = payload as Record<string, unknown>;
        if (!record.options || typeof record.options !== "object") {
          record.options = {};
        }
        (record.options as Record<string, unknown>).num_ctx = numCtx;
      }
      return options?.onPayload?.(payload, model);
    };
    return delegate(model, context, { ...options, onPayload });
  };
}
|
||||
168
src/agents/pi-embedded-runner/run/attempt.prompt-helpers.ts
Normal file
168
src/agents/pi-embedded-runner/run/attempt.prompt-helpers.ts
Normal file
@@ -0,0 +1,168 @@
|
||||
import type { OpenClawConfig } from "../../../config/config.js";
|
||||
import type {
|
||||
PluginHookAgentContext,
|
||||
PluginHookBeforeAgentStartResult,
|
||||
PluginHookBeforePromptBuildResult,
|
||||
} from "../../../plugins/types.js";
|
||||
import { isCronSessionKey, isSubagentSessionKey } from "../../../routing/session-key.js";
|
||||
import { joinPresentTextSegments } from "../../../shared/text/join-segments.js";
|
||||
import { resolveEffectiveToolFsWorkspaceOnly } from "../../tool-fs-policy.js";
|
||||
import type { CompactEmbeddedPiSessionParams } from "../compact.js";
|
||||
import { buildEmbeddedCompactionRuntimeContext } from "../compaction-runtime-context.js";
|
||||
import { log } from "../logger.js";
|
||||
import { shouldInjectHeartbeatPromptForTrigger } from "./trigger-policy.js";
|
||||
import type { EmbeddedRunAttemptParams } from "./types.js";
|
||||
|
||||
/** Event payload handed to both prompt-related hooks. */
type PromptBuildHookEventPayload = { prompt: string; messages: unknown[] };

/**
 * Subset of the plugin hook runner used while building the prompt:
 * a presence check plus runners for the `before_prompt_build` and
 * `before_agent_start` hooks.
 */
export type PromptBuildHookRunner = {
  /** Reports whether any handler is registered for the given hook name. */
  hasHooks: (hookName: "before_prompt_build" | "before_agent_start") => boolean;
  /** Runs the `before_prompt_build` hooks. */
  runBeforePromptBuild: (
    event: PromptBuildHookEventPayload,
    ctx: PluginHookAgentContext,
  ) => Promise<PluginHookBeforePromptBuildResult | undefined>;
  /** Runs the `before_agent_start` hooks. */
  runBeforeAgentStart: (
    event: PromptBuildHookEventPayload,
    ctx: PluginHookAgentContext,
  ) => Promise<PluginHookBeforeAgentStartResult | undefined>;
};
|
||||
|
||||
/**
 * Run the prompt-build hooks and merge their results.
 *
 * `before_prompt_build` runs first (when registered). The legacy result is
 * taken from `legacyBeforeAgentStartResult` when provided; otherwise the
 * `before_agent_start` hooks run (when registered). A rejected hook promise is
 * logged as a warning and treated as "no result".
 *
 * @param params.prompt Prompt text passed to the hooks.
 * @param params.messages Messages passed to the hooks.
 * @param params.hookCtx Agent context passed to the hooks.
 * @param params.hookRunner Hook runner; no hooks run when absent.
 * @param params.legacyBeforeAgentStartResult Precomputed `before_agent_start`
 *   result; when present the legacy hook is not run again.
 * @returns `systemPrompt` from the prompt-build result, falling back to the
 *   legacy one; each context field joins both results (prompt-build first)
 *   through `joinPresentTextSegments`.
 */
export async function resolvePromptBuildHookResult(params: {
  prompt: string;
  messages: unknown[];
  hookCtx: PluginHookAgentContext;
  hookRunner?: PromptBuildHookRunner | null;
  legacyBeforeAgentStartResult?: PluginHookBeforeAgentStartResult;
}): Promise<PluginHookBeforePromptBuildResult> {
  const runner = params.hookRunner;
  // Each hook gets its own event object so one hook cannot mutate the other's input.
  const buildEvent = () => ({ prompt: params.prompt, messages: params.messages });

  let promptBuildResult: PluginHookBeforePromptBuildResult | undefined;
  if (runner?.hasHooks("before_prompt_build")) {
    promptBuildResult = await runner
      .runBeforePromptBuild(buildEvent(), params.hookCtx)
      .catch((hookErr: unknown) => {
        log.warn(`before_prompt_build hook failed: ${String(hookErr)}`);
        return undefined;
      });
  }

  let legacyResult: PluginHookBeforeAgentStartResult | undefined =
    params.legacyBeforeAgentStartResult ?? undefined;
  if (legacyResult === undefined && runner?.hasHooks("before_agent_start")) {
    legacyResult = await runner
      .runBeforeAgentStart(buildEvent(), params.hookCtx)
      .catch((hookErr: unknown) => {
        log.warn(
          `before_agent_start hook (legacy prompt build path) failed: ${String(hookErr)}`,
        );
        return undefined;
      });
  }

  const systemPrompt = promptBuildResult?.systemPrompt ?? legacyResult?.systemPrompt;
  const prependContext = joinPresentTextSegments([
    promptBuildResult?.prependContext,
    legacyResult?.prependContext,
  ]);
  const prependSystemContext = joinPresentTextSegments([
    promptBuildResult?.prependSystemContext,
    legacyResult?.prependSystemContext,
  ]);
  const appendSystemContext = joinPresentTextSegments([
    promptBuildResult?.appendSystemContext,
    legacyResult?.appendSystemContext,
  ]);
  return { systemPrompt, prependContext, prependSystemContext, appendSystemContext };
}
|
||||
|
||||
/**
 * Pick the prompt mode for a session.
 *
 * @param sessionKey Session key; may be omitted or empty.
 * @returns `"minimal"` when the key identifies a subagent or cron session,
 *   `"full"` otherwise (including when no key is given).
 */
export function resolvePromptModeForSession(sessionKey?: string): "minimal" | "full" {
  if (sessionKey && (isSubagentSessionKey(sessionKey) || isCronSessionKey(sessionKey))) {
    return "minimal";
  }
  return "full";
}
|
||||
|
||||
/**
 * Decide whether the heartbeat prompt should be injected for this run.
 *
 * @param params.isDefaultAgent Whether the run belongs to the default agent.
 * @param params.trigger Run trigger, evaluated by `shouldInjectHeartbeatPromptForTrigger`.
 * @returns `false` for non-default agents; otherwise the trigger policy's answer.
 */
export function shouldInjectHeartbeatPrompt(params: {
  isDefaultAgent: boolean;
  trigger?: EmbeddedRunAttemptParams["trigger"];
}): boolean {
  if (!params.isDefaultAgent) {
    return false;
  }
  return shouldInjectHeartbeatPromptForTrigger(params.trigger);
}
|
||||
|
||||
/**
 * Resolve the workspace-only filesystem tool policy for the attempt's agent.
 *
 * Thin adapter over `resolveEffectiveToolFsWorkspaceOnly` that maps this
 * module's parameter names onto that helper's (`cfg`, `agentId`).
 *
 * @param params.config Application config; may be omitted.
 * @param params.sessionAgentId Id of the agent running the session.
 * @returns The value produced by `resolveEffectiveToolFsWorkspaceOnly`.
 */
export function resolveAttemptFsWorkspaceOnly(params: {
  config?: OpenClawConfig;
  sessionAgentId: string;
}): boolean {
  const { config: cfg, sessionAgentId: agentId } = params;
  return resolveEffectiveToolFsWorkspaceOnly({ cfg, agentId });
}
|
||||
|
||||
/**
 * Put an optional addition in front of the system prompt.
 *
 * @param params.systemPrompt Base system prompt.
 * @param params.systemPromptAddition Text to prepend; ignored when missing or empty.
 * @returns The addition, a blank line, then the base prompt; or the base
 *   prompt unchanged when there is no addition.
 */
export function prependSystemPromptAddition(params: {
  systemPrompt: string;
  systemPromptAddition?: string;
}): string {
  const { systemPrompt, systemPromptAddition } = params;
  return systemPromptAddition
    ? systemPromptAddition + "\n\n" + systemPrompt
    : systemPrompt;
}
|
||||
|
||||
/**
 * Build runtime context passed into context-engine afterTurn hooks.
 *
 * Copies the listed attempt fields plus the workspace and agent directories
 * into `buildEmbeddedCompactionRuntimeContext`. Fields are forwarded one by
 * one, so extra properties present on `attempt` at runtime are not passed on.
 *
 * @param params.attempt Attempt parameters (only the picked fields are read).
 * @param params.workspaceDir Workspace directory for the run.
 * @param params.agentDir Agent directory for the run.
 * @returns The context produced by `buildEmbeddedCompactionRuntimeContext`.
 */
export function buildAfterTurnRuntimeContext(params: {
  attempt: Pick<
    EmbeddedRunAttemptParams,
    | "sessionKey"
    | "messageChannel"
    | "messageProvider"
    | "agentAccountId"
    | "currentChannelId"
    | "currentThreadTs"
    | "currentMessageId"
    | "config"
    | "skillsSnapshot"
    | "senderIsOwner"
    | "senderId"
    | "provider"
    | "modelId"
    | "thinkLevel"
    | "reasoningLevel"
    | "bashElevated"
    | "extraSystemPrompt"
    | "ownerNumbers"
    | "authProfileId"
  >;
  workspaceDir: string;
  agentDir: string;
}): Partial<CompactEmbeddedPiSessionParams> {
  const { attempt, workspaceDir, agentDir } = params;
  return buildEmbeddedCompactionRuntimeContext({
    sessionKey: attempt.sessionKey,
    messageChannel: attempt.messageChannel,
    messageProvider: attempt.messageProvider,
    agentAccountId: attempt.agentAccountId,
    currentChannelId: attempt.currentChannelId,
    currentThreadTs: attempt.currentThreadTs,
    currentMessageId: attempt.currentMessageId,
    authProfileId: attempt.authProfileId,
    workspaceDir,
    agentDir,
    config: attempt.config,
    skillsSnapshot: attempt.skillsSnapshot,
    senderIsOwner: attempt.senderIsOwner,
    senderId: attempt.senderId,
    provider: attempt.provider,
    modelId: attempt.modelId,
    thinkLevel: attempt.thinkLevel,
    reasoningLevel: attempt.reasoningLevel,
    bashElevated: attempt.bashElevated,
    extraSystemPrompt: attempt.extraSystemPrompt,
    ownerNumbers: attempt.ownerNumbers,
  });
}
|
||||
221
src/agents/pi-embedded-runner/run/attempt.sessions-yield.ts
Normal file
221
src/agents/pi-embedded-runner/run/attempt.sessions-yield.ts
Normal file
@@ -0,0 +1,221 @@
|
||||
import type { AgentMessage } from "@mariozechner/pi-agent-core";
|
||||
import { log } from "../logger.js";
|
||||
|
||||
// customType of the hidden steering message queued by
// queueSessionsYieldInterruptMessage; stripSessionsYieldArtifacts matches on it
// to remove that message again.
const SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE = "openclaw.sessions_yield_interrupt";
|
||||
// customType of the hidden context message written by
// persistSessionsYieldContextMessage.
const SESSIONS_YIELD_CONTEXT_CUSTOM_TYPE = "openclaw.sessions_yield";
|
||||
// Upper bound (ms) that waitForSessionsYieldAbortSettle waits for the abort to
// settle: 250 when OPENCLAW_TEST_FAST is "1", otherwise 2000.
const SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS = process.env.OPENCLAW_TEST_FAST === "1" ? 250 : 2_000;
|
||||
|
||||
/**
 * Build the hidden context reminder that tells the next turn why the runner
 * stopped.
 *
 * @param message Caller-provided yield message.
 * @returns `message`, a blank line, then a fixed bracketed note stating that
 *   the previous turn ended intentionally via sessions_yield.
 */
export function buildSessionsYieldContextMessage(message: string): string {
  const contextNote =
    "[Context: The previous turn ended intentionally via sessions_yield while waiting for a follow-up event.]";
  return message + "\n\n" + contextNote;
}
|
||||
|
||||
/**
 * Wait for a sessions_yield abort to settle, bounded by
 * `SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS`.
 *
 * Never rejects: a rejected settle promise and a timeout are both logged as
 * warnings and the function then resolves normally.
 *
 * @param params.settlePromise Promise that settles once the abort has
 *   finished; `null` means there is nothing to wait for.
 * @param params.runId Run id included in warning logs.
 * @param params.sessionId Session id included in warning logs.
 */
export async function waitForSessionsYieldAbortSettle(params: {
  settlePromise: Promise<void> | null;
  runId: string;
  sessionId: string;
}): Promise<void> {
  const { settlePromise, runId, sessionId } = params;
  if (!settlePromise) {
    return;
  }

  const settleOutcome = settlePromise
    .then(() => "settled" as const)
    .catch((err) => {
      log.warn(
        `sessions_yield abort settle failed: runId=${runId} sessionId=${sessionId} err=${String(err)}`,
      );
      return "errored" as const;
    });

  let timer: NodeJS.Timeout | undefined;
  const timeoutOutcome = new Promise<"timed_out">((resolve) => {
    timer = setTimeout(() => resolve("timed_out"), SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS);
  });

  const outcome = await Promise.race([settleOutcome, timeoutOutcome]);
  // Cancel the pending timer so it does not outlive this call.
  if (timer) {
    clearTimeout(timer);
  }
  if (outcome === "timed_out") {
    log.warn(
      `sessions_yield abort settle timed out: runId=${runId} sessionId=${sessionId} timeoutMs=${SESSIONS_YIELD_ABORT_SETTLE_TIMEOUT_MS}`,
    );
  }
}
|
||||
|
||||
/** Zeroed usage counters carried by the synthetic aborted message. */
type YieldAbortedUsage = {
  input: number;
  output: number;
  cacheRead: number;
  cacheWrite: number;
  totalTokens: number;
  cost: {
    input: number;
    output: number;
    cacheRead: number;
    cacheWrite: number;
    total: number;
  };
};

/** Shape of the synthetic assistant message returned by `result()`. */
type YieldAbortedMessage = {
  role: "assistant";
  content: Array<{ type: "text"; text: string }>;
  stopReason: "aborted";
  api: string;
  provider: string;
  model: string;
  usage: YieldAbortedUsage;
  timestamp: number;
};

/**
 * Return a synthetic aborted response so pi-agent-core unwinds without a real
 * provider call.
 *
 * The response yields no stream events. `result()` resolves to one assistant
 * message (built once, when this function is called) with a single empty text
 * block, `stopReason: "aborted"` and all usage counters at zero.
 *
 * @param model Model descriptor; missing `api`, `provider` or `id` become `""`.
 * @returns An object exposing an empty async iterator and `result()`.
 */
export function createYieldAbortedResponse(model: {
  api?: string;
  provider?: string;
  id?: string;
}): {
  [Symbol.asyncIterator]: () => AsyncGenerator<never, void, unknown>;
  result: () => Promise<YieldAbortedMessage>;
} {
  const abortedMessage: YieldAbortedMessage = {
    role: "assistant",
    content: [{ type: "text", text: "" }],
    stopReason: "aborted",
    api: model.api ?? "",
    provider: model.provider ?? "",
    model: model.id ?? "",
    usage: {
      input: 0,
      output: 0,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: 0,
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
    },
    timestamp: Date.now(),
  };

  async function* noEvents(): AsyncGenerator<never, void, unknown> {
    // Intentionally empty: the synthetic response emits no stream events.
  }

  return {
    [Symbol.asyncIterator]: noEvents,
    result: () => Promise.resolve(abortedMessage),
  };
}
|
||||
|
||||
/**
 * Queue a hidden steering message so pi-agent-core injects it before the next
 * LLM call once the current assistant turn finishes executing its tool calls.
 *
 * The message is a non-displayed custom message tagged with
 * `SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE`.
 *
 * @param activeSession Session whose agent receives the steering message.
 */
export function queueSessionsYieldInterruptMessage(activeSession: {
  agent: { steer: (message: AgentMessage) => void };
}) {
  const interruptMessage: AgentMessage = {
    role: "custom",
    customType: SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE,
    content: "[sessions_yield interrupt]",
    display: false,
    details: { source: "sessions_yield" },
    timestamp: Date.now(),
  };
  activeSession.agent.steer(interruptMessage);
}
|
||||
|
||||
/**
 * Append the caller-provided yield payload as a hidden session message once
 * the run is idle.
 *
 * The stored content is `message` expanded by
 * `buildSessionsYieldContextMessage`; the raw `message` is kept under
 * `details.message`. The message is sent with `triggerTurn: false`.
 *
 * @param activeSession Session exposing `sendCustomMessage`.
 * @param message Yield payload supplied by the caller.
 */
export async function persistSessionsYieldContextMessage(
  activeSession: {
    sendCustomMessage: (
      message: {
        customType: string;
        content: string;
        display: boolean;
        details?: Record<string, unknown>;
      },
      options?: { triggerTurn?: boolean },
    ) => Promise<void>;
  },
  message: string,
) {
  const hiddenContextMessage = {
    customType: SESSIONS_YIELD_CONTEXT_CUSTOM_TYPE,
    content: buildSessionsYieldContextMessage(message),
    display: false,
    details: { source: "sessions_yield", message },
  };
  await activeSession.sendCustomMessage(hiddenContextMessage, { triggerTurn: false });
}
|
||||
|
||||
/** True for an in-memory message that is an aborted assistant turn or the yield interrupt marker. */
function isSessionsYieldArtifactMessage(entry: unknown): boolean {
  const candidate = entry as
    | { role?: string; customType?: string; stopReason?: string }
    | undefined;
  if (candidate?.role === "assistant") {
    return "stopReason" in candidate && candidate.stopReason === "aborted";
  }
  if (candidate?.role === "custom") {
    return (
      "customType" in candidate && candidate.customType === SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE
    );
  }
  return false;
}

/** True for a persisted entry that is an aborted assistant message or the yield interrupt marker. */
function isSessionsYieldArtifactEntry(entry: {
  type?: string;
  message?: { role?: string; stopReason?: string };
  customType?: string;
}): boolean {
  if (entry.type === "message") {
    return entry.message?.role === "assistant" && entry.message?.stopReason === "aborted";
  }
  return (
    entry.type === "custom_message" && entry.customType === SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE
  );
}

/**
 * Remove the synthetic yield interrupt + aborted assistant entry from the live
 * transcript.
 *
 * Two passes, both trimming only from the tail:
 * 1. In-memory messages: trailing aborted assistant messages and yield
 *    interrupt markers are dropped, and `agent.replaceMessages` is called only
 *    if something was removed.
 * 2. Session manager internals (`fileEntries`, `byId`, `leafId`,
 *    `_rewriteFile`), when present: matching trailing entries are popped, the
 *    id index and leaf pointer are updated, and the file is rewritten once if
 *    anything changed. The first entry and any `"session"` entry are never removed.
 *
 * @param activeSession Session whose transcript should be cleaned up.
 */
export function stripSessionsYieldArtifacts(activeSession: {
  messages: AgentMessage[];
  agent: { replaceMessages: (messages: AgentMessage[]) => void };
  sessionManager?: unknown;
}) {
  const liveMessages = activeSession.messages;
  let keepCount = liveMessages.length;
  while (keepCount > 0 && isSessionsYieldArtifactMessage(liveMessages[keepCount - 1])) {
    keepCount -= 1;
  }
  if (keepCount !== liveMessages.length) {
    activeSession.agent.replaceMessages(liveMessages.slice(0, keepCount));
  }

  // NOTE(review): relies on private-looking session manager fields; the cast
  // below documents the shape this code expects.
  const manager = activeSession.sessionManager as
    | {
        fileEntries?: Array<{
          type?: string;
          id?: string;
          parentId?: string | null;
          message?: { role?: string; stopReason?: string };
          customType?: string;
        }>;
        byId?: Map<string, { id: string }>;
        leafId?: string | null;
        _rewriteFile?: () => void;
      }
    | undefined;
  const entries = manager?.fileEntries;
  const entryIndex = manager?.byId;
  if (!manager || !entries || !entryIndex) {
    return;
  }

  let removedAny = false;
  while (entries.length > 1) {
    const tail = entries[entries.length - 1];
    if (!tail || tail.type === "session" || !isSessionsYieldArtifactEntry(tail)) {
      break;
    }
    entries.pop();
    if (tail.id) {
      entryIndex.delete(tail.id);
    }
    // The removed entry's parent becomes the new leaf.
    manager.leafId = tail.parentId ?? null;
    removedAny = true;
  }
  if (removedAny) {
    manager._rewriteFile?.();
  }
}
|
||||
@@ -0,0 +1,383 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import { normalizeProviderId } from "../../model-selection.js";
|
||||
import { log } from "../logger.js";
|
||||
|
||||
/**
 * Check whether a content block's `type` marks it as a tool call.
 *
 * @param type Raw `type` value of a content block.
 * @returns `true` for `"toolCall"`, `"toolUse"` or `"functionCall"` (case-sensitive).
 */
function isToolCallBlockType(type: unknown): boolean {
  switch (type) {
    case "toolCall":
    case "toolUse":
    case "functionCall":
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
/**
 * Extract the leading bracket-balanced JSON-like value from a string.
 *
 * After skipping leading whitespace, the text must start with `{` or `[`. The
 * scan tracks nesting depth and ignores brackets inside double-quoted strings
 * (honoring backslash escapes). Bracket kinds are not matched against each
 * other, so the result is only a candidate that the caller still has to parse.
 *
 * @param raw Text that may start with a JSON object or array.
 * @returns The substring from the opening bracket through the bracket that
 *   returns the depth to zero, or `null` when there is no such prefix.
 */
function extractBalancedJsonPrefix(raw: string): string | null {
  const begin = raw.search(/\S/);
  if (begin < 0) {
    return null;
  }
  const opener = raw.charAt(begin);
  if (opener !== "{" && opener !== "[") {
    return null;
  }

  let nesting = 0;
  let insideString = false;
  let afterBackslash = false;
  for (let pos = begin; pos < raw.length; pos += 1) {
    const ch = raw.charAt(pos);
    if (insideString) {
      if (afterBackslash) {
        afterBackslash = false;
      } else if (ch === "\\") {
        afterBackslash = true;
      } else if (ch === '"') {
        insideString = false;
      }
      continue;
    }
    switch (ch) {
      case '"':
        insideString = true;
        break;
      case "{":
      case "[":
        nesting += 1;
        break;
      case "}":
      case "]":
        nesting -= 1;
        if (nesting === 0) {
          return raw.slice(begin, pos + 1);
        }
        break;
      default:
        break;
    }
  }
  return null;
}
|
||||
|
||||
// Cap on accumulated tool-call delta text per content index; once exceeded,
// the stream wrapper stops attempting repair for that index.
const MAX_TOOLCALL_REPAIR_BUFFER_CHARS = 64_000;
|
||||
// Maximum number of stray characters after the balanced JSON prefix that is
// still considered repairable.
const MAX_TOOLCALL_REPAIR_TRAILING_CHARS = 3;
|
||||
// Accepts a trailing suffix of 1-3 characters containing no whitespace and
// none of the JSON structural characters { } [ ] " : , or backslash.
const TOOLCALL_REPAIR_ALLOWED_TRAILING_RE = /^[^\s{}[\]":,\\]{1,3}$/;
|
||||
|
||||
/**
 * Cheap pre-check deciding whether a repair attempt is worth running for the
 * current delta.
 *
 * @param partialJson Argument text accumulated so far, including `delta`.
 * @param delta Newest chunk of argument text.
 * @returns `true` when `delta` contains `}` or `]`, or when `delta` is a short
 *   non-blank chunk (at most `MAX_TOOLCALL_REPAIR_TRAILING_CHARS` after
 *   trimming) and the accumulated text already contains a closing bracket.
 */
function shouldAttemptMalformedToolCallRepair(partialJson: string, delta: string): boolean {
  const closingBracket = /[}\]]/;
  if (closingBracket.test(delta)) {
    return true;
  }
  const visibleLength = delta.trim().length;
  if (visibleLength === 0 || visibleLength > MAX_TOOLCALL_REPAIR_TRAILING_CHARS) {
    return false;
  }
  return closingBracket.test(partialJson);
}
|
||||
|
||||
/** Outcome of a successful malformed tool-call argument repair. */
type ToolCallArgumentRepair = {
  /** Arguments parsed from the balanced JSON prefix. */
  args: Record<string, unknown>;
  /** Trimmed stray text that followed the JSON prefix. */
  trailingSuffix: string;
};
|
||||
|
||||
/**
 * Try to recover tool-call arguments from text consisting of a JSON object
 * followed by a few stray characters.
 *
 * A repair is produced only when all of the following hold:
 * - `raw` is non-blank and is NOT already valid JSON;
 * - `raw` starts with a bracket-balanced prefix;
 * - the trimmed text after that prefix is 1 to
 *   `MAX_TOOLCALL_REPAIR_TRAILING_CHARS` characters long and matches
 *   `TOOLCALL_REPAIR_ALLOWED_TRAILING_RE`;
 * - the prefix parses to a plain (non-array) object.
 *
 * @param raw Accumulated argument text.
 * @returns The parsed arguments and the discarded suffix, or `undefined` when
 *   no repair applies.
 */
function tryParseMalformedToolCallArguments(raw: string): ToolCallArgumentRepair | undefined {
  if (raw.trim().length === 0) {
    return undefined;
  }

  let alreadyValid = true;
  try {
    JSON.parse(raw);
  } catch {
    alreadyValid = false;
  }
  if (alreadyValid) {
    // Well-formed input needs no repair.
    return undefined;
  }

  const balanced = extractBalancedJsonPrefix(raw);
  if (!balanced) {
    return undefined;
  }
  const trailing = raw.slice(raw.indexOf(balanced) + balanced.length).trim();
  const trailingIsRepairable =
    trailing.length > 0 &&
    trailing.length <= MAX_TOOLCALL_REPAIR_TRAILING_CHARS &&
    TOOLCALL_REPAIR_ALLOWED_TRAILING_RE.test(trailing);
  if (!trailingIsRepairable) {
    return undefined;
  }

  let candidate: unknown;
  try {
    candidate = JSON.parse(balanced) as unknown;
  } catch {
    return undefined;
  }
  if (!candidate || typeof candidate !== "object" || Array.isArray(candidate)) {
    return undefined;
  }
  return { args: candidate as Record<string, unknown>, trailingSuffix: trailing };
}
|
||||
|
||||
/**
 * Overwrite the `arguments` of the tool-call block at `contentIndex` in a
 * message, in place.
 *
 * Does nothing when `message` is not an object, its `content` is not an array,
 * the indexed entry is not an object, or that entry's `type` is not a
 * tool-call type.
 *
 * @param message Message (or partial message) to patch.
 * @param contentIndex Index into `message.content`.
 * @param repairedArgs Arguments to store on the block.
 */
function repairToolCallArgumentsInMessage(
  message: unknown,
  contentIndex: number,
  repairedArgs: Record<string, unknown>,
): void {
  const blocks =
    message && typeof message === "object"
      ? (message as { content?: unknown }).content
      : undefined;
  const target: unknown = Array.isArray(blocks) ? blocks[contentIndex] : undefined;
  if (!target || typeof target !== "object") {
    return;
  }
  const toolCallBlock = target as { type?: unknown; arguments?: unknown };
  if (isToolCallBlockType(toolCallBlock.type)) {
    toolCallBlock.arguments = repairedArgs;
  }
}
|
||||
|
||||
/**
 * Reset the `arguments` of the tool-call block at `contentIndex` to a fresh
 * empty object, in place.
 *
 * Does nothing when `message` is not an object, its `content` is not an array,
 * the indexed entry is not an object, or that entry's `type` is not a
 * tool-call type.
 *
 * @param message Message (or partial message) to patch.
 * @param contentIndex Index into `message.content`.
 */
function clearToolCallArgumentsInMessage(message: unknown, contentIndex: number): void {
  if (message === null || typeof message !== "object") {
    return;
  }
  const { content } = message as { content?: unknown };
  const target: unknown = Array.isArray(content) ? content[contentIndex] : undefined;
  if (!target || typeof target !== "object") {
    return;
  }
  const toolCallBlock = target as { type?: unknown; arguments?: unknown };
  if (isToolCallBlockType(toolCallBlock.type)) {
    toolCallBlock.arguments = {};
  }
}
|
||||
|
||||
/**
 * Apply every recorded argument repair to a message, in place.
 *
 * @param message Message to patch; ignored unless it is an object with an
 *   array `content`.
 * @param repairedArgsByIndex Repaired arguments keyed by content index.
 */
function repairMalformedToolCallArgumentsInMessage(
  message: unknown,
  repairedArgsByIndex: Map<number, Record<string, unknown>>,
): void {
  const hasContentArray =
    Boolean(message) &&
    typeof message === "object" &&
    Array.isArray((message as { content?: unknown }).content);
  if (!hasContentArray) {
    return;
  }
  repairedArgsByIndex.forEach((repairedArgs, index) => {
    repairToolCallArgumentsInMessage(message, index, repairedArgs);
  });
}
|
||||
|
||||
/**
 * Patch a stream in place so malformed tool-call arguments are repaired while
 * events flow and again on the final result.
 *
 * Per content index the wrapper accumulates `toolcall_delta` text. Whenever
 * `shouldAttemptMalformedToolCallRepair` says so it tries
 * `tryParseMalformedToolCallArguments`:
 * - on success the repaired arguments are remembered and written into the
 *   event's `partial` and `message` (a warning is logged once per index);
 * - otherwise any remembered repair is dropped and the arguments in `partial`
 *   and `message` are reset to `{}`.
 * If the accumulated text exceeds `MAX_TOOLCALL_REPAIR_BUFFER_CHARS`, repair
 * is disabled for that index. On `toolcall_end` a remembered repair is applied
 * to `toolCall`, `partial` and `message`, and the per-index buffers are
 * released. `result()` applies all remembered repairs to the final message and
 * then clears every tracker.
 *
 * @param stream Stream to patch; its `result` and async iterator are replaced.
 * @returns The same `stream` instance.
 */
function wrapStreamRepairMalformedToolCallArguments(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  type StreamEventView = {
    type?: unknown;
    contentIndex?: unknown;
    delta?: unknown;
    partial?: unknown;
    message?: unknown;
    toolCall?: unknown;
  };

  const bufferedJson = new Map<number, string>();
  const repairedArgs = new Map<number, Record<string, unknown>>();
  const overflowed = new Set<number>();
  const alreadyLogged = new Set<number>();

  const baseResult = stream.result.bind(stream);
  stream.result = async () => {
    const finalMessage = await baseResult();
    repairMalformedToolCallArgumentsInMessage(finalMessage, repairedArgs);
    bufferedJson.clear();
    repairedArgs.clear();
    overflowed.clear();
    alreadyLogged.clear();
    return finalMessage;
  };

  // Accumulate one delta chunk and (re)evaluate the repair for its index.
  const handleDelta = (event: StreamEventView, index: number, delta: string): void => {
    if (overflowed.has(index)) {
      return;
    }
    const accumulated = (bufferedJson.get(index) ?? "") + delta;
    if (accumulated.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) {
      bufferedJson.delete(index);
      repairedArgs.delete(index);
      overflowed.add(index);
      return;
    }
    bufferedJson.set(index, accumulated);
    if (!shouldAttemptMalformedToolCallRepair(accumulated, delta)) {
      return;
    }
    const repair = tryParseMalformedToolCallArguments(accumulated);
    if (!repair) {
      repairedArgs.delete(index);
      clearToolCallArgumentsInMessage(event.partial, index);
      clearToolCallArgumentsInMessage(event.message, index);
      return;
    }
    repairedArgs.set(index, repair.args);
    repairToolCallArgumentsInMessage(event.partial, index, repair.args);
    repairToolCallArgumentsInMessage(event.message, index, repair.args);
    if (!alreadyLogged.has(index)) {
      alreadyLogged.add(index);
      log.warn(
        `repairing Kimi tool call arguments after ${repair.trailingSuffix.length} trailing chars`,
      );
    }
  };

  // Apply the remembered repair (if any) to the end event and free buffers.
  // The repaired arguments stay recorded so result() can still apply them.
  const handleEnd = (event: StreamEventView, index: number): void => {
    const finalArgs = repairedArgs.get(index);
    if (finalArgs) {
      if (event.toolCall && typeof event.toolCall === "object") {
        (event.toolCall as { arguments?: unknown }).arguments = finalArgs;
      }
      repairToolCallArgumentsInMessage(event.partial, index, finalArgs);
      repairToolCallArgumentsInMessage(event.message, index, finalArgs);
    }
    bufferedJson.delete(index);
    overflowed.delete(index);
    alreadyLogged.delete(index);
  };

  const inspectEvent = (event: StreamEventView): void => {
    const index = event.contentIndex;
    if (typeof index !== "number" || !Number.isInteger(index)) {
      return;
    }
    if (event.type === "toolcall_delta") {
      if (typeof event.delta === "string") {
        handleDelta(event, index, event.delta);
      }
      return;
    }
    if (event.type === "toolcall_end") {
      handleEnd(event, index);
    }
  };

  const baseIteratorFactory = stream[Symbol.asyncIterator].bind(stream);
  (stream as { [Symbol.asyncIterator]: typeof baseIteratorFactory })[Symbol.asyncIterator] =
    function () {
      const inner = baseIteratorFactory();
      return {
        async next() {
          const step = await inner.next();
          if (!step.done && step.value && typeof step.value === "object") {
            inspectEvent(step.value as StreamEventView);
          }
          return step;
        },
        async return(value?: unknown) {
          return inner.return?.(value) ?? { done: true as const, value: undefined };
        },
        async throw(error?: unknown) {
          return inner.throw?.(error) ?? { done: true as const, value: undefined };
        },
      };
    };

  return stream;
}
|
||||
|
||||
/**
 * Wrap a stream function so each stream it produces goes through
 * `wrapStreamRepairMalformedToolCallArguments`.
 *
 * Handles both return styles of `baseFn`: a stream returned directly, or a
 * thenable resolving to a stream (detected by the presence of `then`).
 *
 * @param baseFn Stream function to wrap.
 * @returns A stream function with the same signature.
 */
export function wrapStreamFnRepairMalformedToolCallArguments(baseFn: StreamFn): StreamFn {
  return (model, context, options) => {
    const produced = baseFn(model, context, options);
    return produced && typeof produced === "object" && "then" in produced
      ? Promise.resolve(produced).then((resolved) =>
          wrapStreamRepairMalformedToolCallArguments(resolved),
        )
      : wrapStreamRepairMalformedToolCallArguments(produced);
  };
}
|
||||
|
||||
/**
 * Decide whether malformed tool-call argument repair applies to a provider.
 *
 * @param provider Provider id; a missing id is treated as `""`.
 * @returns `true` only when the normalized provider id is `"kimi"`.
 */
export function shouldRepairMalformedAnthropicToolCallArguments(provider?: string): boolean {
  const normalizedProvider = normalizeProviderId(provider ?? "");
  return normalizedProvider === "kimi";
}
|
||||
|
||||
// Detects whether a string contains at least one supported HTML entity:
// amp, lt, gt, quot, apos, or a decimal/hex numeric reference (case-insensitive).
// Not global, so .test() keeps no state between calls.
const HTML_ENTITY_RE = /&(?:amp|lt|gt|quot|apos|#39|#x[0-9a-f]+|#\d+);/i;
|
||||
|
||||
/** Characters for the supported named entities, keyed by lower-cased entity name. */
const HTML_NAMED_ENTITY_CHARS: Readonly<Record<string, string>> = {
  amp: "&",
  quot: '"',
  apos: "'",
  lt: "<",
  gt: ">",
};

/**
 * Matches one supported entity. Capture groups: (1) entity name,
 * (2) hexadecimal digits, (3) decimal digits.
 */
const HTML_ENTITY_TOKEN_RE = /&(?:(amp|quot|apos|lt|gt)|#x([0-9a-f]+)|#(\d+));/gi;

/**
 * Decode the supported HTML entities in a string: `&amp;`, `&quot;`,
 * `&apos;`, `&lt;`, `&gt;` and decimal/hexadecimal numeric references
 * (`&#39;`, `&#x27;`, ...). Matching is case-insensitive.
 *
 * Decoding happens in a single pass, so text produced by one replacement is
 * never decoded again (`&amp;lt;` becomes `&lt;`, not `<`). Numeric references
 * outside the Unicode range are left untouched instead of throwing.
 *
 * @param value Text that may contain HTML entities.
 * @returns The text with every supported entity replaced by its character.
 */
function decodeHtmlEntities(value: string): string {
  return value.replace(
    HTML_ENTITY_TOKEN_RE,
    (entity: string, name?: string, hex?: string, dec?: string): string => {
      if (name !== undefined) {
        return HTML_NAMED_ENTITY_CHARS[name.toLowerCase()] ?? entity;
      }
      const codePoint =
        hex !== undefined ? Number.parseInt(hex, 16) : Number.parseInt(dec ?? "", 10);
      // String.fromCodePoint throws RangeError above U+10FFFF; keep such text verbatim.
      const isValidCodePoint =
        Number.isInteger(codePoint) && codePoint >= 0 && codePoint <= 0x10ffff;
      return isValidCodePoint ? String.fromCodePoint(codePoint) : entity;
    },
  );
}
|
||||
|
||||
/**
 * Recursively decode HTML entities in every string inside a JSON-like value.
 *
 * Strings are decoded only when `HTML_ENTITY_RE` finds an entity. Arrays and
 * objects are rebuilt (the input is not mutated); other values are returned
 * as-is. Object keys are copied verbatim and are not decoded.
 *
 * Objects are rebuilt with `Object.fromEntries`, which defines each key as an
 * own property. This matters for untrusted input such as model-produced tool
 * arguments: an own `"__proto__"` key is copied as data and does not replace
 * the prototype of the returned object.
 *
 * @param obj Value to process.
 * @returns A decoded copy for arrays and objects, the decoded string for
 *   strings, or the original value otherwise.
 */
export function decodeHtmlEntitiesInObject(obj: unknown): unknown {
  if (typeof obj === "string") {
    return HTML_ENTITY_RE.test(obj) ? decodeHtmlEntities(obj) : obj;
  }
  if (Array.isArray(obj)) {
    return obj.map((item) => decodeHtmlEntitiesInObject(item));
  }
  if (obj && typeof obj === "object") {
    return Object.fromEntries(
      Object.entries(obj as Record<string, unknown>).map(
        ([key, val]): [string, unknown] => [key, decodeHtmlEntitiesInObject(val)],
      ),
    );
  }
  return obj;
}
|
||||
|
||||
/**
 * Decode HTML entities inside the arguments of every `"toolCall"` block of a
 * message, in place.
 *
 * Only blocks whose `type` is exactly `"toolCall"` and whose `arguments` is a
 * non-null object are touched; their `arguments` is replaced with the decoded
 * copy returned by `decodeHtmlEntitiesInObject`.
 *
 * @param message Message (or partial message) to patch; ignored unless it is
 *   an object with an array `content`.
 */
function decodeXaiToolCallArgumentsInMessage(message: unknown): void {
  const blocks =
    message && typeof message === "object"
      ? (message as { content?: unknown }).content
      : undefined;
  if (!Array.isArray(blocks)) {
    return;
  }
  for (const entry of blocks) {
    if (entry === null || typeof entry !== "object") {
      continue;
    }
    const toolCallBlock = entry as { type?: unknown; arguments?: unknown };
    const args = toolCallBlock.arguments;
    if (toolCallBlock.type === "toolCall" && args && typeof args === "object") {
      toolCallBlock.arguments = decodeHtmlEntitiesInObject(args);
    }
  }
}
|
||||
|
||||
/**
 * Patch a stream in place so HTML entities in tool-call arguments are decoded
 * on every event's `partial` / `message` and on the final result.
 *
 * @param stream Stream to patch; its `result` and async iterator are replaced.
 * @returns The same `stream` instance.
 */
function wrapStreamDecodeXaiToolCallArguments(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  const baseResult = stream.result.bind(stream);
  stream.result = async () => {
    const finalMessage = await baseResult();
    decodeXaiToolCallArgumentsInMessage(finalMessage);
    return finalMessage;
  };

  const baseIteratorFactory = stream[Symbol.asyncIterator].bind(stream);
  (stream as { [Symbol.asyncIterator]: typeof baseIteratorFactory })[Symbol.asyncIterator] =
    function () {
      const inner = baseIteratorFactory();
      return {
        async next() {
          const step = await inner.next();
          if (step.done || !step.value || typeof step.value !== "object") {
            return step;
          }
          const { partial, message } = step.value as { partial?: unknown; message?: unknown };
          decodeXaiToolCallArgumentsInMessage(partial);
          decodeXaiToolCallArgumentsInMessage(message);
          return step;
        },
        async return(value?: unknown) {
          return inner.return?.(value) ?? { done: true as const, value: undefined };
        },
        async throw(error?: unknown) {
          return inner.throw?.(error) ?? { done: true as const, value: undefined };
        },
      };
    };
  return stream;
}
|
||||
|
||||
/**
 * Wrap a stream function so each stream it produces goes through
 * `wrapStreamDecodeXaiToolCallArguments`.
 *
 * Handles both return styles of `baseFn`: a stream returned directly, or a
 * thenable resolving to a stream (detected by the presence of `then`).
 *
 * @param baseFn Stream function to wrap.
 * @returns A stream function with the same signature.
 */
export function wrapStreamFnDecodeXaiToolCallArguments(baseFn: StreamFn): StreamFn {
  return (model, context, options) => {
    const produced = baseFn(model, context, options);
    return produced && typeof produced === "object" && "then" in produced
      ? Promise.resolve(produced).then((resolved) =>
          wrapStreamDecodeXaiToolCallArguments(resolved),
        )
      : wrapStreamDecodeXaiToolCallArguments(produced);
  };
}
|
||||
@@ -0,0 +1,588 @@
|
||||
import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import { validateAnthropicTurns, validateGeminiTurns } from "../../pi-embedded-helpers.js";
|
||||
import { sanitizeToolUseResultPairing } from "../../session-transcript-repair.js";
|
||||
import { normalizeToolName } from "../../tool-policy.js";
|
||||
import type { TranscriptPolicy } from "../../transcript-policy.js";
|
||||
|
||||
/**
 * Looks up `rawName` in the allow-list, ignoring letter case.
 *
 * @param rawName - Tool name to look up.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The allow-listed spelling when exactly one entry matches
 *   case-insensitively; `null` when nothing matches, when several distinct
 *   entries match (ambiguous), or when no allow-list is available.
 */
function resolveCaseInsensitiveAllowedToolName(
  rawName: string,
  allowedToolNames?: Set<string>,
): string | null {
  if (!allowedToolNames?.size) {
    return null;
  }
  const target = rawName.toLowerCase();
  const hits = Array.from(allowedToolNames).filter(
    (allowed) => allowed.toLowerCase() === target,
  );
  // Set entries are unique, so more than one hit means distinct spellings.
  if (hits.length !== 1) {
    return null;
  }
  return hits[0] ?? null;
}
|
||||
|
||||
/**
 * Resolves `rawName` to an allow-listed tool name without any structural
 * rewriting of the name.
 *
 * Lookup order: the raw name as-is, the `normalizeToolName` form, then a
 * case-insensitive lookup of the raw name, then of the normalized form.
 *
 * @param rawName - Tool name to resolve.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The matching allow-listed name, or `null` when nothing matches or
 *   no allow-list is available.
 */
function resolveExactAllowedToolName(
  rawName: string,
  allowedToolNames?: Set<string>,
): string | null {
  if (!allowedToolNames?.size) {
    return null;
  }
  if (allowedToolNames.has(rawName)) {
    return rawName;
  }
  const canonical = normalizeToolName(rawName);
  if (allowedToolNames.has(canonical)) {
    return canonical;
  }
  for (const attempt of [rawName, canonical]) {
    const hit = resolveCaseInsensitiveAllowedToolName(attempt, allowedToolNames);
    if (hit !== null) {
      return hit;
    }
  }
  return null;
}
|
||||
|
||||
/**
 * Builds the ordered, de-duplicated list of name variants to try when matching
 * a structured (namespaced) tool name against an allow-list.
 *
 * Variants, in order: the trimmed name, the name with `/` replaced by `.`, and
 * every dotted suffix of that form (dropping leading segments one at a time).
 * Each variant is followed by its `normalizeToolName` form.
 *
 * @param rawName - Tool name as received.
 * @returns Non-empty, trimmed candidates in priority order; an empty array
 *   when `rawName` is blank.
 */
function buildStructuredToolNameCandidates(rawName: string): string[] {
  const base = rawName.trim();
  if (base.length === 0) {
    return [];
  }

  // A Set keeps insertion order, which gives ordered de-duplication.
  const ordered = new Set<string>();
  const register = (value: string): void => {
    const cleaned = value.trim();
    if (cleaned) {
      ordered.add(cleaned);
    }
  };
  const registerWithNormalized = (value: string): void => {
    register(value);
    register(normalizeToolName(value));
  };

  registerWithNormalized(base);

  const dotted = base.replace(/\//g, ".");
  registerWithNormalized(dotted);

  const parts = dotted
    .split(".")
    .map((part) => part.trim())
    .filter((part) => part.length > 0);
  // With a single segment the loop body never runs.
  for (let start = 1; start < parts.length; start += 1) {
    registerWithNormalized(parts.slice(start).join("."));
  }

  return [...ordered];
}
|
||||
|
||||
/**
 * Resolves a possibly namespaced tool name to an allow-listed name by trying
 * the variants produced by `buildStructuredToolNameCandidates`.
 *
 * All candidates are first tried for an exact allow-list hit; only if none
 * hits are they tried again case-insensitively, in the same order.
 *
 * @param rawName - Tool name as received.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The first matching allow-listed name, or `null`.
 */
function resolveStructuredAllowedToolName(
  rawName: string,
  allowedToolNames?: Set<string>,
): string | null {
  if (!allowedToolNames?.size) {
    return null;
  }
  const allowed = allowedToolNames;
  const candidates = buildStructuredToolNameCandidates(rawName);

  const exactHit = candidates.find((candidate) => allowed.has(candidate));
  if (exactHit !== undefined) {
    return exactHit;
  }

  for (const candidate of candidates) {
    const looseHit = resolveCaseInsensitiveAllowedToolName(candidate, allowed);
    if (looseHit) {
      return looseHit;
    }
  }

  return null;
}
|
||||
|
||||
/**
 * Tries to recover an allow-listed tool name from a tool-call id.
 *
 * Candidate tokens are derived from the whole id and from the part before its
 * first `:`. For each seed the function also tries the seed with a trailing
 * numeric counter removed, the `/` -> `.` form, and that form with a leading
 * `function(s)` / `tool(s)` prefix stripped. A name is returned only when
 * every token that resolves agrees on the same allow-listed name.
 *
 * @param rawId - Tool-call id, if any.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The single inferred tool name, or `null` when the id is blank,
 *   nothing resolves, or the tokens resolve to different names.
 */
function inferToolNameFromToolCallId(
  rawId: string | undefined,
  allowedToolNames?: Set<string>,
): string | null {
  const id = rawId?.trim();
  if (!id || !allowedToolNames?.size) {
    return null;
  }

  const tokens = new Set<string>();
  // Adds the value itself plus two counter-stripped variants: one that also
  // removes the separator before the digits, and one that removes digits only.
  const addWithCounterVariants = (value: string, separatorCounter: RegExp): void => {
    tokens.add(value);
    tokens.add(value.replace(separatorCounter, ""));
    tokens.add(value.replace(/\d+$/, ""));
  };
  const expandSeed = (seed: string): void => {
    const cleaned = seed.trim();
    if (!cleaned) {
      return;
    }
    addWithCounterVariants(cleaned, /[:._/-]\d+$/);

    const dotted = cleaned.replace(/\//g, ".");
    addWithCounterVariants(dotted, /[:._-]\d+$/);

    for (const prefix of [/^functions?[._-]?/i, /^tools?[._-]?/i]) {
      const withoutPrefix = dotted.replace(prefix, "");
      if (withoutPrefix !== dotted) {
        addWithCounterVariants(withoutPrefix, /[:._-]\d+$/);
      }
    }
  };

  expandSeed(id);
  expandSeed(id.split(":")[0] ?? id);

  let resolved: string | null = null;
  for (const token of tokens) {
    const hit = resolveStructuredAllowedToolName(token, allowedToolNames);
    if (!hit || hit === resolved) {
      continue;
    }
    if (resolved) {
      // Two tokens point at different tools: refuse to guess.
      return null;
    }
    resolved = hit;
  }

  return resolved;
}
|
||||
|
||||
/**
 * Reports whether a tool name has the shape "function(s)/tool(s) prefix ...
 * trailing digits", e.g. `functions.read:0`.
 *
 * The name is trimmed and `/` is treated as `.` before matching.
 *
 * @param rawName - Tool name as received.
 * @returns `true` when the name starts with `function`, `functions`, `tool`
 *   or `tools` (case-insensitive) and ends in one or more digits.
 */
function looksLikeMalformedToolNameCounter(rawName: string): boolean {
  const dotted = rawName.trim().replace(/\//g, ".");
  if (!/^(?:functions?|tools?)[._-]?/i.test(dotted)) {
    return false;
  }
  return /(?:[:._-]\d+|\d+)$/.test(dotted);
}
|
||||
|
||||
/**
 * Maps a tool name emitted by a model onto the name used for dispatch.
 *
 * Resolution order when an allow-list is present:
 * 1. exact / case-insensitive allow-list match of the trimmed name;
 * 2. treating the name itself as if it were a tool-call id and inferring the
 *    tool from it;
 * 3. if the name looks like a prefixed counter (see
 *    `looksLikeMalformedToolNameCounter`), the trimmed name is returned as-is;
 * 4. structured (namespaced) matching, falling back to the trimmed name.
 *
 * @param rawName - Tool name as received.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @param rawToolCallId - Tool-call id, consulted only when `rawName` is blank.
 * @returns The resolved name. A blank `rawName` yields the name inferred from
 *   the id, or `rawName` unchanged when nothing can be inferred. Without an
 *   allow-list the trimmed name is returned.
 */
function normalizeToolCallNameForDispatch(
  rawName: string,
  allowedToolNames?: Set<string>,
  rawToolCallId?: string,
): string {
  const name = rawName.trim();
  if (name.length === 0) {
    return inferToolNameFromToolCallId(rawToolCallId, allowedToolNames) ?? rawName;
  }
  if (!allowedToolNames?.size) {
    return name;
  }

  // `||` keeps the second lookup lazy: it only runs when the first one fails.
  const direct =
    resolveExactAllowedToolName(name, allowedToolNames) ||
    inferToolNameFromToolCallId(name, allowedToolNames);
  if (direct) {
    return direct;
  }

  if (looksLikeMalformedToolNameCounter(name)) {
    return name;
  }

  return resolveStructuredAllowedToolName(name, allowedToolNames) ?? name;
}
|
||||
|
||||
/**
 * Reports whether a content-block `type` tag denotes a tool call.
 *
 * @param type - The block's `type` value, of any runtime type.
 * @returns `true` for exactly `"toolCall"`, `"toolUse"` or `"functionCall"`.
 */
function isToolCallBlockType(type: unknown): boolean {
  switch (type) {
    case "toolCall":
    case "toolUse":
    case "functionCall":
      return true;
    default:
      return false;
  }
}
|
||||
|
||||
// Longest normalized tool-call name accepted when sanitizing replayed tool
// calls. Raw names longer than twice this value are rejected before
// normalization is even attempted.
const REPLAY_TOOL_CALL_NAME_MAX_CHARS = 64;
|
||||
|
||||
/**
 * Loose view of a tool-call content block found in a replayed transcript.
 * Every field is optional and untyped because replayed data is validated at
 * runtime rather than trusted.
 */
type ReplayToolCallBlock = {
  /** Block type tag; tool calls use `toolCall`, `toolUse` or `functionCall`. */
  type?: unknown;
  /** Tool-call id. */
  id?: unknown;
  /** Tool name. */
  name?: unknown;
  /** Call payload under the `input` key. */
  input?: unknown;
  /** Call payload under the `arguments` key. */
  arguments?: unknown;
};
|
||||
|
||||
/** Outcome of `sanitizeReplayToolCallInputs`. */
type ReplayToolCallSanitizeReport = {
  /** Sanitized transcript; the input array itself when nothing changed. */
  messages: AgentMessage[];
  /** Number of assistant messages removed because no content block survived. */
  droppedAssistantMessages: number;
};
|
||||
|
||||
/**
 * Loose view of a content block inside a user message, used to spot tool
 * results (`type === "toolResult"`) and the tool-use id they answer.
 */
type AnthropicToolResultContentBlock = {
  /** Block type tag. */
  type?: unknown;
  /** Id of the tool-use block this result belongs to. */
  toolUseId?: unknown;
};
|
||||
|
||||
/**
 * Type guard: is `block` a non-null object whose `type` tag denotes a tool
 * call (as decided by `isToolCallBlockType`)?
 *
 * @param block - Any content-block value.
 * @returns `true` when `block` can be treated as a `ReplayToolCallBlock`.
 */
function isReplayToolCallBlock(block: unknown): block is ReplayToolCallBlock {
  return (
    typeof block === "object" &&
    block !== null &&
    isToolCallBlockType((block as { type?: unknown }).type)
  );
}
|
||||
|
||||
/**
 * Reports whether a replayed tool call carries a payload.
 *
 * @param block - Tool-call block to inspect.
 * @returns `true` when `input` or `arguments` is present and is neither
 *   `undefined` nor `null`.
 */
function replayToolCallHasInput(block: ReplayToolCallBlock): boolean {
  const isPresent = (value: unknown): boolean => value !== undefined && value !== null;
  const inputPresent = "input" in block && isPresent(block.input);
  const argumentsPresent = "arguments" in block && isPresent(block.arguments);
  return inputPresent || argumentsPresent;
}
|
||||
|
||||
/**
 * Type guard for a string that contains at least one non-whitespace character.
 *
 * @param value - Any value.
 * @returns `true` when `value` is a string that is not blank after trimming.
 */
function replayToolCallNonEmptyString(value: unknown): value is string {
  if (typeof value !== "string") {
    return false;
  }
  return value.trim() !== "";
}
|
||||
|
||||
/**
 * Resolves the tool name of a replayed tool call, rejecting names that cannot
 * be trusted.
 *
 * @param rawName - Tool name recorded in the transcript (may be empty).
 * @param rawId - Tool-call id recorded in the transcript.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The usable tool name, or `null` when the raw name exceeds twice
 *   `REPLAY_TOOL_CALL_NAME_MAX_CHARS`, when the normalized name is blank,
 *   longer than `REPLAY_TOOL_CALL_NAME_MAX_CHARS` or contains whitespace, or
 *   when an allow-list is present and the name does not resolve against it.
 */
function resolveReplayToolCallName(
  rawName: string,
  rawId: string,
  allowedToolNames?: Set<string>,
): string | null {
  const rawLimit = REPLAY_TOOL_CALL_NAME_MAX_CHARS * 2;
  if (rawName.length > rawLimit) {
    return null;
  }

  const candidate = normalizeToolCallNameForDispatch(rawName, allowedToolNames, rawId).trim();
  const isUsable =
    candidate.length > 0 &&
    candidate.length <= REPLAY_TOOL_CALL_NAME_MAX_CHARS &&
    !/\s/.test(candidate);
  if (!isUsable) {
    return null;
  }

  return allowedToolNames?.size
    ? resolveExactAllowedToolName(candidate, allowedToolNames)
    : candidate;
}
|
||||
|
||||
/**
 * Removes or repairs malformed tool-call blocks in the assistant messages of a
 * replayed transcript.
 *
 * For every assistant message with array content, each tool-call block is:
 * - dropped when it has no payload, no non-blank string id, or a name that
 *   `resolveReplayToolCallName` rejects;
 * - copied with the resolved name when that differs from the recorded one;
 * - kept untouched otherwise.
 * An assistant message left with no content blocks is removed and counted.
 * Messages and blocks are never mutated in place.
 *
 * @param messages - Transcript to sanitize.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The sanitized transcript (the very same `messages` array when
 *   nothing changed) and the number of dropped assistant messages.
 */
function sanitizeReplayToolCallInputs(
  messages: AgentMessage[],
  allowedToolNames?: Set<string>,
): ReplayToolCallSanitizeReport {
  // Name to use for a tool-call block, or null when the block must be dropped.
  const resolveBlockName = (toolCall: ReplayToolCallBlock): string | null => {
    if (!replayToolCallHasInput(toolCall) || !replayToolCallNonEmptyString(toolCall.id)) {
      return null;
    }
    const recordedName = typeof toolCall.name === "string" ? toolCall.name : "";
    return resolveReplayToolCallName(recordedName, toolCall.id, allowedToolNames);
  };

  const sanitized: AgentMessage[] = [];
  let anyChange = false;
  let droppedAssistantMessages = 0;

  for (const message of messages) {
    if (
      !message ||
      typeof message !== "object" ||
      message.role !== "assistant" ||
      !Array.isArray(message.content)
    ) {
      sanitized.push(message);
      continue;
    }

    const keptBlocks: typeof message.content = [];
    let touched = false;

    for (const block of message.content) {
      if (!isReplayToolCallBlock(block)) {
        keptBlocks.push(block);
        continue;
      }
      const toolCall = block as ReplayToolCallBlock;

      const resolvedName = resolveBlockName(toolCall);
      if (!resolvedName) {
        touched = true;
        continue;
      }

      if (toolCall.name === resolvedName) {
        keptBlocks.push(block);
        continue;
      }

      keptBlocks.push({ ...(block as object), name: resolvedName } as typeof block);
      touched = true;
    }

    if (!touched) {
      sanitized.push(message);
      continue;
    }

    anyChange = true;
    if (keptBlocks.length === 0) {
      droppedAssistantMessages += 1;
      continue;
    }
    sanitized.push({ ...message, content: keptBlocks });
  }

  return {
    // Returning the input array lets callers detect "nothing changed" by identity.
    messages: anyChange ? sanitized : messages,
    droppedAssistantMessages,
  };
}
|
||||
|
||||
/**
 * Drops tool-result blocks from user messages when the immediately preceding
 * message does not contain the `toolUse` block they refer to.
 *
 * Only user messages with array content are inspected. A `toolResult` block
 * with a string `toolUseId` is kept only if that id matches the trimmed id of
 * a `toolUse` block in the previous message, and only when that message is an
 * assistant message. If filtering empties a user message, its content is
 * replaced by a single text block reading `[tool results omitted]`.
 *
 * @param messages - Transcript to sanitize.
 * @returns A new array when anything was removed; otherwise the very same
 *   `messages` array.
 */
function sanitizeAnthropicReplayToolResults(messages: AgentMessage[]): AgentMessage[] {
  // Trimmed ids of the `toolUse` blocks of an assistant message; empty for anything else.
  const collectToolUseIds = (candidate: unknown): Set<string> => {
    const ids = new Set<string>();
    if (!candidate || typeof candidate !== "object") {
      return ids;
    }
    const { role, content } = candidate as { role?: unknown; content?: unknown };
    if (role !== "assistant" || !Array.isArray(content)) {
      return ids;
    }
    for (const entry of content) {
      if (!entry || typeof entry !== "object") {
        continue;
      }
      const { type, id } = entry as { type?: unknown; id?: unknown };
      if (type === "toolUse" && typeof id === "string") {
        const trimmedId = id.trim();
        if (trimmedId) {
          ids.add(trimmedId);
        }
      }
    }
    return ids;
  };

  const result: AgentMessage[] = [];
  let mutated = false;

  for (const [position, message] of messages.entries()) {
    if (
      !message ||
      typeof message !== "object" ||
      message.role !== "user" ||
      !Array.isArray(message.content)
    ) {
      result.push(message);
      continue;
    }

    // The lookup uses the original transcript, not the partially built result.
    const answerableIds = collectToolUseIds(messages[position - 1]);

    const retained = message.content.filter((block) => {
      if (!block || typeof block !== "object") {
        return true;
      }
      const { type, toolUseId } = block as AnthropicToolResultContentBlock;
      if (type !== "toolResult" || typeof toolUseId !== "string") {
        return true;
      }
      return answerableIds.has(toolUseId);
    });

    if (retained.length === message.content.length) {
      result.push(message);
      continue;
    }

    mutated = true;
    if (retained.length > 0) {
      result.push({ ...message, content: retained });
    } else {
      result.push({
        ...message,
        content: [{ type: "text", text: "[tool results omitted]" }],
      } as AgentMessage);
    }
  }

  return mutated ? result : messages;
}
|
||||
|
||||
/**
 * Ensures that every tool-call block in a message has a unique, non-blank,
 * trimmed id. The blocks are mutated in place.
 *
 * The first block to use a given (trimmed) id keeps it. Blocks with a missing,
 * non-string or blank id, or with an id already taken by an earlier block,
 * receive a generated `call_auto_<n>` id that collides with no id present in
 * the message.
 *
 * @param message - Message whose `content` array is scanned; anything that is
 *   not an object with array content is ignored.
 */
function normalizeToolCallIdsInMessage(message: unknown): void {
  if (typeof message !== "object" || message === null) {
    return;
  }
  const { content } = message as { content?: unknown };
  if (!Array.isArray(content)) {
    return;
  }

  const toolCalls: Array<{ type?: unknown; id?: unknown }> = [];
  for (const entry of content) {
    if (!entry || typeof entry !== "object") {
      continue;
    }
    const candidate = entry as { type?: unknown; id?: unknown };
    if (isToolCallBlockType(candidate.type)) {
      toolCalls.push(candidate);
    }
  }

  const trimmedIdOf = (toolCall: { id?: unknown }): string =>
    typeof toolCall.id === "string" ? toolCall.id.trim() : "";

  // Ids already present anywhere in the message; generated ids must avoid them.
  const reservedIds = new Set<string>();
  for (const toolCall of toolCalls) {
    const existing = trimmedIdOf(toolCall);
    if (existing) {
      reservedIds.add(existing);
    }
  }

  const claimedIds = new Set<string>();
  let nextSuffix = 1;
  for (const toolCall of toolCalls) {
    const existing = trimmedIdOf(toolCall);
    if (existing && !claimedIds.has(existing)) {
      if (toolCall.id !== existing) {
        toolCall.id = existing;
      }
      claimedIds.add(existing);
      continue;
    }

    let generated: string;
    do {
      generated = `call_auto_${nextSuffix++}`;
    } while (reservedIds.has(generated) || claimedIds.has(generated));

    toolCall.id = generated;
    reservedIds.add(generated);
    claimedIds.add(generated);
  }
}
|
||||
|
||||
/**
 * Normalizes the tool-call blocks of a message in place.
 *
 * For each tool-call block, a string `name` is replaced by the result of
 * `normalizeToolCallNameForDispatch` when that differs. A block without a
 * string name gets one only if it can be inferred from the block's id.
 * Afterwards the ids of the message are normalized with
 * `normalizeToolCallIdsInMessage`.
 *
 * @param message - Message to update; anything that is not an object with
 *   array content is ignored.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 */
function trimWhitespaceFromToolCallNamesInMessage(
  message: unknown,
  allowedToolNames?: Set<string>,
): void {
  if (typeof message !== "object" || message === null) {
    return;
  }
  const { content } = message as { content?: unknown };
  if (!Array.isArray(content)) {
    return;
  }

  for (const entry of content) {
    if (!entry || typeof entry !== "object") {
      continue;
    }
    const toolCall = entry as { type?: unknown; name?: unknown; id?: unknown };
    if (!isToolCallBlockType(toolCall.type)) {
      continue;
    }
    const callId = typeof toolCall.id === "string" ? toolCall.id : undefined;

    if (typeof toolCall.name !== "string") {
      const inferredName = inferToolNameFromToolCallId(callId, allowedToolNames);
      if (inferredName) {
        toolCall.name = inferredName;
      }
      continue;
    }

    const dispatchName = normalizeToolCallNameForDispatch(
      toolCall.name,
      allowedToolNames,
      callId,
    );
    if (dispatchName !== toolCall.name) {
      toolCall.name = dispatchName;
    }
  }

  normalizeToolCallIdsInMessage(message);
}
|
||||
|
||||
/**
 * Patches a stream in place so that tool-call names and ids are normalized on
 * everything it yields.
 *
 * Both the final message returned by `result()` and the `partial` / `message`
 * payloads of each iterated event are passed through
 * `trimWhitespaceFromToolCallNamesInMessage`.
 *
 * @param stream - Stream to patch.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @returns The same `stream` instance.
 */
function wrapStreamTrimToolCallNames(
  stream: ReturnType<typeof streamSimple>,
  allowedToolNames?: Set<string>,
): ReturnType<typeof streamSimple> {
  const fixToolCalls = (candidate: unknown): void =>
    trimWhitespaceFromToolCallNamesInMessage(candidate, allowedToolNames);

  const baseResult = stream.result.bind(stream);
  stream.result = async () => {
    const finalMessage = await baseResult();
    fixToolCalls(finalMessage);
    return finalMessage;
  };

  const baseAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
  (stream as { [Symbol.asyncIterator]: typeof baseAsyncIterator })[Symbol.asyncIterator] =
    function () {
      const inner = baseAsyncIterator();
      return {
        async next() {
          const step = await inner.next();
          if (!step.done && step.value && typeof step.value === "object") {
            const { partial, message } = step.value as {
              partial?: unknown;
              message?: unknown;
            };
            fixToolCalls(partial);
            fixToolCalls(message);
          }
          return step;
        },
        // Forward early termination to the inner iterator when it supports it.
        async return(value?: unknown) {
          return inner.return?.(value) ?? { done: true as const, value: undefined };
        },
        async throw(error?: unknown) {
          return inner.throw?.(error) ?? { done: true as const, value: undefined };
        },
      };
    };

  return stream;
}
|
||||
|
||||
/**
 * Decorates a stream function so that every stream it produces is patched by
 * `wrapStreamTrimToolCallNames`.
 *
 * @param baseFn - The stream function to decorate.
 * @param allowedToolNames - Allow-list of tool names forwarded to the wrapper.
 * @returns A stream function with the same call signature. When `baseFn`
 *   returns a thenable, the stream is patched once it resolves; otherwise it
 *   is patched synchronously.
 */
export function wrapStreamFnTrimToolCallNames(
  baseFn: StreamFn,
  allowedToolNames?: Set<string>,
): StreamFn {
  return (model, context, options) => {
    const produced = baseFn(model, context, options);
    // Anything that is not an object exposing `then` is treated as a ready stream.
    if (!produced || typeof produced !== "object" || !("then" in produced)) {
      return wrapStreamTrimToolCallNames(produced, allowedToolNames);
    }
    return Promise.resolve(produced).then((resolved) =>
      wrapStreamTrimToolCallNames(resolved, allowedToolNames),
    );
  };
}
|
||||
|
||||
/**
 * Decorates a stream function so that malformed tool calls are removed from
 * the outgoing context before the underlying function is invoked.
 *
 * If `context.messages` is not an array, or sanitizing changes nothing, the
 * call is forwarded untouched. Otherwise the sanitized messages are passed
 * through `sanitizeToolUseResultPairing`, then (when the policy enables
 * Anthropic turn validation) through `sanitizeAnthropicReplayToolResults`.
 * Turn validators run only when an assistant message was dropped or Anthropic
 * validation is enabled. The original context object is never mutated; a
 * shallow copy carrying the new messages is forwarded.
 *
 * @param baseFn - The stream function to decorate.
 * @param allowedToolNames - Allow-list of tool names; may be absent or empty.
 * @param transcriptPolicy - Flags selecting which turn validators to apply.
 * @returns A stream function with the same call signature as `baseFn`.
 */
export function wrapStreamFnSanitizeMalformedToolCalls(
  baseFn: StreamFn,
  allowedToolNames?: Set<string>,
  transcriptPolicy?: Pick<TranscriptPolicy, "validateGeminiTurns" | "validateAnthropicTurns">,
): StreamFn {
  return (model, context, options) => {
    const forwardUnchanged = () => baseFn(model, context, options);

    const originalMessages = (context as unknown as { messages?: unknown })?.messages;
    if (!Array.isArray(originalMessages)) {
      return forwardUnchanged();
    }

    const { messages: cleaned, droppedAssistantMessages } = sanitizeReplayToolCallInputs(
      originalMessages as AgentMessage[],
      allowedToolNames,
    );
    // The sanitizer hands back the input array itself when nothing changed.
    if (cleaned === originalMessages) {
      return forwardUnchanged();
    }

    const enforceAnthropic = Boolean(transcriptPolicy?.validateAnthropicTurns);
    const enforceGemini = Boolean(transcriptPolicy?.validateGeminiTurns);

    let repaired = sanitizeToolUseResultPairing(cleaned, {
      preserveErroredAssistantResults: true,
    });
    if (enforceAnthropic) {
      repaired = sanitizeAnthropicReplayToolResults(repaired);
    }

    const needsTurnValidation = droppedAssistantMessages > 0 || enforceAnthropic;
    if (needsTurnValidation && enforceGemini) {
      repaired = validateGeminiTurns(repaired);
    }
    if (needsTurnValidation && enforceAnthropic) {
      repaired = validateAnthropicTurns(repaired);
    }

    const patchedContext = {
      ...(context as unknown as Record<string, unknown>),
      messages: repaired,
    } as unknown;
    return baseFn(model, patchedContext as typeof context, options);
  };
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user