Handle aggregate tool-result overflow fallback

This commit is contained in:
Tak Hoffman
2026-04-05 23:51:03 -05:00
committed by Peter Steinberger
parent 09b7c00dab
commit a8fb094c5b
2 changed files with 210 additions and 4 deletions

View File

@@ -1,6 +1,10 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import type { AgentMessage } from "@mariozechner/pi-agent-core";
import { SessionManager } from "@mariozechner/pi-coding-agent";
import type { AssistantMessage, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai";
import { beforeEach, describe, expect, it } from "vitest";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { makeAgentAssistantMessage } from "../test-helpers/agent-message-fixtures.js";
let truncateToolResultText: typeof import("./tool-result-truncation.js").truncateToolResultText;
@@ -8,9 +12,12 @@ let truncateToolResultMessage: typeof import("./tool-result-truncation.js").trun
let calculateMaxToolResultChars: typeof import("./tool-result-truncation.js").calculateMaxToolResultChars;
let getToolResultTextLength: typeof import("./tool-result-truncation.js").getToolResultTextLength;
let truncateOversizedToolResultsInMessages: typeof import("./tool-result-truncation.js").truncateOversizedToolResultsInMessages;
let truncateOversizedToolResultsInSession: typeof import("./tool-result-truncation.js").truncateOversizedToolResultsInSession;
let isOversizedToolResult: typeof import("./tool-result-truncation.js").isOversizedToolResult;
let sessionLikelyHasOversizedToolResults: typeof import("./tool-result-truncation.js").sessionLikelyHasOversizedToolResults;
let DEFAULT_MAX_LIVE_TOOL_RESULT_CHARS: typeof import("./tool-result-truncation.js").DEFAULT_MAX_LIVE_TOOL_RESULT_CHARS;
let HARD_MAX_TOOL_RESULT_CHARS: typeof import("./tool-result-truncation.js").HARD_MAX_TOOL_RESULT_CHARS;
let tmpDir: string | undefined;
async function loadFreshToolResultTruncationModuleForTest() {
({
@@ -19,7 +26,9 @@ async function loadFreshToolResultTruncationModuleForTest() {
calculateMaxToolResultChars,
getToolResultTextLength,
truncateOversizedToolResultsInMessages,
truncateOversizedToolResultsInSession,
isOversizedToolResult,
sessionLikelyHasOversizedToolResults,
DEFAULT_MAX_LIVE_TOOL_RESULT_CHARS,
HARD_MAX_TOOL_RESULT_CHARS,
} = await import("./tool-result-truncation.js"));
@@ -33,6 +42,13 @@ beforeEach(async () => {
await loadFreshToolResultTruncationModuleForTest();
});
// Best-effort removal of the temp directory recorded in `tmpDir` (if any) after each test.
afterEach(async () => {
  const dirToRemove = tmpDir;
  if (!dirToRemove) {
    return;
  }
  try {
    await fs.rm(dirToRemove, { recursive: true, force: true });
  } catch {
    // Cleanup failures are deliberately ignored so they never fail a test.
  }
  // Clear the marker so the next test does not try to remove the same path again.
  tmpDir = undefined;
});
function makeToolResult(text: string, toolCallId = "call_1"): ToolResultMessage {
return {
role: "toolResult",
@@ -69,6 +85,11 @@ function getFirstToolResultText(message: AgentMessage | ToolResultMessage): stri
return firstBlock && "text" in firstBlock ? firstBlock.text : "";
}
/**
 * Creates a uniquely named directory under the OS temp dir and records it in the
 * module-level `tmpDir` so the `afterEach` hook can remove it.
 *
 * @returns The path of the newly created directory.
 */
async function createTmpDir(): Promise<string> {
  const prefix = path.join(os.tmpdir(), "tool-result-truncation-test-");
  const createdDir = await fs.mkdtemp(prefix);
  tmpDir = createdDir;
  return createdDir;
}
describe("truncateToolResultText", () => {
it("returns text unchanged when under limit", () => {
const text = "hello world";
@@ -203,6 +224,27 @@ describe("isOversizedToolResult", () => {
});
});
describe("sessionLikelyHasOversizedToolResults", () => {
  // Both scenarios are evaluated against the same context window size.
  const contextWindowTokens = 128_000;

  it("returns true for individually oversized tool results", () => {
    // A single 500,000-character result.
    const messages: AgentMessage[] = [makeToolResult("x".repeat(500_000))];

    const detected = sessionLikelyHasOversizedToolResults({ messages, contextWindowTokens });

    expect(detected).toBe(true);
  });

  it("returns true for aggregate medium tool results that exceed the shared budget", () => {
    // Three results of 18,600 characters each (31-character phrase repeated 600 times).
    const medium = "alpha beta gamma delta epsilon ".repeat(600);
    const messages: AgentMessage[] = ["call_1", "call_2", "call_3"].map((toolCallId) =>
      makeToolResult(medium, toolCallId),
    );

    const detected = sessionLikelyHasOversizedToolResults({ messages, contextWindowTokens });

    expect(detected).toBe(true);
  });
});
describe("truncateOversizedToolResultsInMessages", () => {
it("returns unchanged messages when nothing is oversized", () => {
const messages = [
@@ -268,6 +310,64 @@ describe("truncateOversizedToolResultsInMessages", () => {
});
});
describe("truncateOversizedToolResultsInSession", () => {
  it("readably truncates aggregate medium tool results in a session file", async () => {
    // Re-opens the session from disk and returns only its tool-result messages.
    const readToolResultMessages = (file: string) =>
      SessionManager.open(file)
        .getBranch()
        .flatMap((entry) =>
          entry.type === "message" && entry.message.role === "toolResult" ? [entry.message] : [],
        );
    // Sums the text length of every message in the list.
    const sumTextLengths = (messages: readonly AgentMessage[]): number =>
      messages.reduce((sum, message) => sum + getToolResultTextLength(message), 0);

    // Seed a session with one user turn, one assistant turn and three
    // 18,600-character tool results.
    const dir = await createTmpDir();
    const sm = SessionManager.create(dir, dir);
    sm.appendMessage(makeUserMessage("hello"));
    sm.appendMessage(makeAssistantMessage("calling tools"));
    const medium = "alpha beta gamma delta epsilon ".repeat(600);
    for (const toolCallId of ["call_1", "call_2", "call_3"]) {
      sm.appendMessage(makeToolResult(medium, toolCallId));
    }
    const sessionFile = sm.getSessionFile()!;

    const totalBefore = sumTextLengths(readToolResultMessages(sessionFile));

    const result = await truncateOversizedToolResultsInSession({
      sessionFile,
      contextWindowTokens: 128_000,
    });

    expect(result.truncated).toBe(true);
    expect(result.truncatedCount).toBeGreaterThan(0);

    const toolResultsAfter = readToolResultMessages(sessionFile);
    const anyFirstTextIncludes = (needle: string): boolean =>
      toolResultsAfter.some((message) => getFirstToolResultText(message).includes(needle));

    // The persisted tool results must have shrunk overall.
    expect(sumTextLengths(toolResultsAfter)).toBeLessThan(totalBefore);
    // At least one result carries a "truncated" notice, and none carries a "[compacted:" marker.
    expect(anyFirstTextIncludes("truncated")).toBe(true);
    expect(anyFirstTextIncludes("[compacted:")).toBe(false);
  });
});
describe("truncateToolResultText head+tail strategy", () => {
it("preserves error content at the tail when present", () => {
const head = "Line 1\n".repeat(500);

View File

@@ -41,6 +41,7 @@ const TRUNCATION_SUFFIX =
"\n\n⚠ [Content truncated — original was too large for the model's context window. " +
"The content above is a partial view. If you need more, request specific sections or use " +
"offset/limit parameters to read smaller chunks.]";
// Floor used when choosing truncation targets: MIN_KEEP_CHARS plus the length of the
// default TRUNCATION_SUFFIX. Aggregate truncation never targets a tool result below
// this size, and the aggregate budget itself is never smaller than this value.
const MIN_TRUNCATED_TEXT_CHARS = MIN_KEEP_CHARS + TRUNCATION_SUFFIX.length;
type ToolResultTruncationOptions = {
suffix?: string | ((truncatedChars: number) => string);
@@ -250,6 +251,72 @@ export function truncateOversizedToolResultsInMessages(
return { messages: result, truncatedCount };
}
/**
 * Computes the character budget shared by all tool results in a session.
 *
 * The budget is whatever `calculateMaxToolResultChars` yields for the context
 * window, but never less than `MIN_TRUNCATED_TEXT_CHARS`.
 *
 * @param contextWindowTokens - Context window size, in tokens.
 * @returns The aggregate tool-result budget, in characters.
 */
function calculateAggregateToolResultChars(contextWindowTokens: number): number {
  const windowBasedChars = calculateMaxToolResultChars(contextWindowTokens);
  return windowBasedChars < MIN_TRUNCATED_TEXT_CHARS ? MIN_TRUNCATED_TEXT_CHARS : windowBasedChars;
}
/**
 * Plans truncations for tool results whose combined text exceeds a shared budget.
 *
 * Every `toolResult` message entry in `branch` with non-empty text is a candidate.
 * When there are at least two candidates and their total text length is above
 * `aggregateBudgetChars`, candidates are visited largest-first and each is
 * truncated (never targeting below `MIN_TRUNCATED_TEXT_CHARS`) until the excess
 * has been removed or no candidate can shrink further.
 *
 * @param params.branch - Session entries to scan; only `type === "message"` entries
 *   carrying a `toolResult` message are considered.
 * @param params.aggregateBudgetChars - Maximum combined tool-result text length.
 * @returns One replacement per entry that actually got shorter; empty when there
 *   are fewer than two candidates or the total already fits the budget.
 */
function buildAggregateToolResultReplacements(params: {
  branch: Array<{ id: string; type: string; message?: AgentMessage }>;
  aggregateBudgetChars: number;
}): Array<{ entryId: string; message: AgentMessage }> {
  const { branch, aggregateBudgetChars } = params;

  // Single pass: gather non-empty tool results and their combined length.
  const candidates: Array<{ entryId: string; message: AgentMessage; textLength: number }> = [];
  let totalChars = 0;
  for (const entry of branch) {
    const message = entry.message;
    if (entry.type !== "message" || !message) {
      continue;
    }
    if ((message as { role?: string }).role !== "toolResult") {
      continue;
    }
    const textLength = getToolResultTextLength(message);
    if (!(textLength > 0)) {
      continue;
    }
    candidates.push({ entryId: entry.id, message, textLength });
    totalChars += textLength;
  }

  // A lone tool result is not an aggregate problem, and a total within budget needs no work.
  if (candidates.length < 2 || totalChars <= aggregateBudgetChars) {
    return [];
  }

  const largestFirst = candidates.toSorted((a, b) => b.textLength - a.textLength);
  const replacements: Array<{ entryId: string; message: AgentMessage }> = [];
  let excessChars = totalChars - aggregateBudgetChars;

  for (const { entryId, message, textLength } of largestFirst) {
    if (excessChars <= 0) {
      break;
    }
    // How much this result may shrink before hitting the minimum truncated size.
    const shrinkableChars = textLength - MIN_TRUNCATED_TEXT_CHARS;
    if (shrinkableChars <= 0) {
      continue;
    }
    const targetChars = Math.max(
      MIN_TRUNCATED_TEXT_CHARS,
      textLength - Math.min(shrinkableChars, excessChars),
    );
    const truncatedMessage = truncateToolResultMessage(message, targetChars);
    // Measure what was really saved instead of trusting the requested target.
    const savedChars = textLength - getToolResultTextLength(truncatedMessage);
    if (savedChars <= 0) {
      continue;
    }
    replacements.push({ entryId, message: truncatedMessage });
    excessChars -= savedChars;
  }

  return replacements;
}
export async function truncateOversizedToolResultsInSession(params: {
sessionFile: string;
contextWindowTokens: number;
@@ -258,6 +325,7 @@ export async function truncateOversizedToolResultsInSession(params: {
}): Promise<{ truncated: boolean; truncatedCount: number; reason?: string }> {
const { sessionFile, contextWindowTokens } = params;
const maxChars = calculateMaxToolResultChars(contextWindowTokens);
const aggregateBudgetChars = calculateAggregateToolResultChars(contextWindowTokens);
let sessionLock: Awaited<ReturnType<typeof acquireSessionWriteLock>> | undefined;
try {
@@ -285,7 +353,37 @@ export async function truncateOversizedToolResultsInSession(params: {
}
if (oversizedIndices.length === 0) {
return { truncated: false, truncatedCount: 0, reason: "no oversized tool results" };
const replacements = buildAggregateToolResultReplacements({
branch: branch as Array<{ id: string; type: string; message?: AgentMessage }>,
aggregateBudgetChars,
});
if (replacements.length === 0) {
return {
truncated: false,
truncatedCount: 0,
reason: "no oversized or aggregate tool results",
};
}
const rewriteResult = rewriteTranscriptEntriesInSessionManager({
sessionManager,
replacements,
});
if (rewriteResult.changed) {
emitSessionTranscriptUpdate(sessionFile);
}
log.info(
`[tool-result-truncation] Aggregate-truncated ${rewriteResult.rewrittenEntries} tool result(s) in session ` +
`(contextWindow=${contextWindowTokens} aggregateBudgetChars=${aggregateBudgetChars}) ` +
`sessionKey=${params.sessionKey ?? params.sessionId ?? "unknown"}`,
);
return {
truncated: rewriteResult.changed,
truncatedCount: rewriteResult.rewrittenEntries,
reason: rewriteResult.reason,
};
}
const replacements = oversizedIndices.flatMap((index) => {
@@ -341,15 +439,23 @@ export function sessionLikelyHasOversizedToolResults(params: {
}): boolean {
const { messages, contextWindowTokens } = params;
const maxChars = calculateMaxToolResultChars(contextWindowTokens);
const aggregateBudgetChars = calculateAggregateToolResultChars(contextWindowTokens);
let totalToolResultChars = 0;
let toolResultCount = 0;
for (const msg of messages) {
if ((msg as { role?: string }).role !== "toolResult") {
continue;
}
if (getToolResultTextLength(msg) > maxChars) {
const textLength = getToolResultTextLength(msg);
if (textLength > maxChars) {
return true;
}
totalToolResultChars += textLength;
if (textLength > 0) {
toolResultCount += 1;
}
}
return false;
return toolResultCount >= 2 && totalToolResultChars > aggregateBudgetChars;
}