diff --git a/src/agents/pi-embedded-runner/run.overflow-compaction.loop.test.ts b/src/agents/pi-embedded-runner/run.overflow-compaction.loop.test.ts index 6c645359b6d..669a310713f 100644 --- a/src/agents/pi-embedded-runner/run.overflow-compaction.loop.test.ts +++ b/src/agents/pi-embedded-runner/run.overflow-compaction.loop.test.ts @@ -171,6 +171,68 @@ describe("overflow compaction in run loop", () => { expect(result.meta.error).toBeUndefined(); }); + it("retries without hitting compaction when attempt-level preflight truncation already handled the overflow", async () => { + mockedRunEmbeddedAttempt + .mockResolvedValueOnce( + makeAttemptResult({ + promptError: null, + preflightRecovery: { + route: "truncate_tool_results_only", + handled: true, + truncatedCount: 2, + }, + }), + ) + .mockResolvedValueOnce(makeAttemptResult({ promptError: null })); + + const result = await runEmbeddedPiAgent(baseParams); + + expect(mockedCompactDirect).not.toHaveBeenCalled(); + expect(mockedTruncateOversizedToolResultsInSession).not.toHaveBeenCalled(); + expect(mockedRunEmbeddedAttempt).toHaveBeenCalledTimes(2); + expect(mockedLog.info).toHaveBeenCalledWith( + expect.stringContaining("early recovery route=truncate_tool_results_only"), + ); + expect(result.meta.error).toBeUndefined(); + }); + + it("runs post-compaction tool-result truncation before retry for mixed precheck routes", async () => { + mockedRunEmbeddedAttempt + .mockResolvedValueOnce( + makeAttemptResult({ + promptError: makeOverflowError( + "Context overflow: prompt too large for the model (precheck).", + ), + preflightRecovery: { route: "compact_then_truncate" }, + }), + ) + .mockResolvedValueOnce(makeAttemptResult({ promptError: null })); + + mockedCompactDirect.mockResolvedValueOnce( + makeCompactionSuccess({ + summary: "Compacted session", + firstKeptEntryId: "entry-5", + tokensBefore: 150000, + }), + ); + mockedTruncateOversizedToolResultsInSession.mockResolvedValueOnce({ + truncated: true, + truncatedCount: 2, + }); 
+ + const result = await runEmbeddedPiAgent(baseParams); + + expect(mockedCompactDirect).toHaveBeenCalledTimes(1); + expect(mockedTruncateOversizedToolResultsInSession).toHaveBeenCalledWith( + expect.objectContaining({ sessionFile: "/tmp/session.json" }), + ); + expect(mockedRunEmbeddedAttempt).toHaveBeenCalledTimes(2); + expect(mockedLog.info).toHaveBeenCalledWith( + expect.stringContaining("post-compaction tool-result truncation succeeded"), + ); + expect(result.meta.error).toBeUndefined(); + }); + it("retries compaction up to 3 times before giving up", async () => { const overflowError = makeOverflowError(); diff --git a/src/agents/pi-embedded-runner/run.ts b/src/agents/pi-embedded-runner/run.ts index 1598cbc9758..d4dc99b4692 100644 --- a/src/agents/pi-embedded-runner/run.ts +++ b/src/agents/pi-embedded-runner/run.ts @@ -64,10 +64,6 @@ import { runContextEngineMaintenance } from "./context-engine-maintenance.js"; import { resolveGlobalLane, resolveSessionLane } from "./lanes.js"; import { log } from "./logger.js"; import { resolveModelAsync } from "./model.js"; -import { - sessionLikelyHasOversizedToolResults, - truncateOversizedToolResultsInSession, -} from "./tool-result-truncation.js"; import { handleAssistantFailover } from "./run/assistant-failover.js"; import { runEmbeddedAttempt } from "./run/attempt.js"; import { createEmbeddedRunAuthController } from "./run/auth-controller.js"; @@ -95,6 +91,10 @@ import type { RunEmbeddedPiAgentParams } from "./run/params.js"; import { buildEmbeddedRunPayloads } from "./run/payloads.js"; import { handleRetryLimitExhaustion } from "./run/retry-limit.js"; import { resolveEffectiveRuntimeModel, resolveHookModelSelection } from "./run/setup.js"; +import { + sessionLikelyHasOversizedToolResults, + truncateOversizedToolResultsInSession, +} from "./tool-result-truncation.js"; import type { EmbeddedPiAgentMeta, EmbeddedPiRunResult } from "./types.js"; import { createUsageAccumulator, mergeUsageIntoAccumulator } from 
"./usage-accumulator.js"; import { describeUnknownError } from "./utils.js"; @@ -616,6 +616,7 @@ export async function runEmbeddedPiAgent( const { aborted, promptError, + preflightRecovery, timedOut, timedOutDuringCompaction, sessionIdUsed, @@ -663,6 +664,13 @@ export async function runEmbeddedPiAgent( !attempt.lastToolError && attempt.toolMetas.length === 0 && attempt.assistantTexts.length === 0; + if (preflightRecovery?.handled) { + log.info( + `[context-overflow-precheck] early recovery route=${preflightRecovery.route} ` + + `completed for ${provider}/${modelId}; retrying prompt`, + ); + continue; + } const requestedSelection = shouldSwitchToLiveModel({ cfg: params.config, sessionKey: params.sessionKey, @@ -919,6 +927,25 @@ export async function runEmbeddedPiAgent( } await runOwnsCompactionAfterHook("overflow recovery", compactResult); if (compactResult.compacted) { + if (preflightRecovery?.route === "compact_then_truncate") { + const truncResult = await truncateOversizedToolResultsInSession({ + sessionFile: params.sessionFile, + contextWindowTokens: ctxInfo.tokens, + sessionId: params.sessionId, + sessionKey: params.sessionKey, + }); + if (truncResult.truncated) { + log.info( + `[context-overflow-precheck] post-compaction tool-result truncation succeeded for ` + + `${provider}/${modelId}; truncated ${truncResult.truncatedCount} tool result(s)`, + ); + } else { + log.warn( + `[context-overflow-precheck] post-compaction tool-result truncation did not help for ` + + `${provider}/${modelId}: ${truncResult.reason ?? 
"unknown"}`, + ); + } + } autoCompactionCount += 1; log.info(`auto-compaction succeeded for ${provider}/${modelId}; retrying prompt`); continue; @@ -960,7 +987,8 @@ export async function runEmbeddedPiAgent( } } if ( - (isCompactionFailure || overflowCompactionAttempts >= MAX_OVERFLOW_COMPACTION_ATTEMPTS) && + (isCompactionFailure || + overflowCompactionAttempts >= MAX_OVERFLOW_COMPACTION_ATTEMPTS) && log.isEnabled("debug") ) { log.debug( diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index 297ca90c170..9f63c0ecf2f 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -120,10 +120,6 @@ import { } from "../prompt-cache-observability.js"; import { resolveCacheRetention } from "../prompt-cache-retention.js"; import { sanitizeSessionHistory, validateReplayTurns } from "../replay-history.js"; -import { - PREEMPTIVE_OVERFLOW_ERROR_TEXT, - shouldPreemptivelyCompactBeforePrompt, -} from "./preemptive-compaction.js"; import { clearActiveEmbeddedRun, type EmbeddedPiQueueHandle, @@ -149,6 +145,7 @@ import { import { dropThinkingBlocks } from "../thinking.js"; import { collectAllowedToolNames } from "../tool-name-allowlist.js"; import { installToolResultContextGuard } from "../tool-result-context-guard.js"; +import { truncateOversizedToolResultsInSessionManager } from "../tool-result-truncation.js"; import { logProviderToolSchemaDiagnostics, normalizeProviderToolSchemas, @@ -210,6 +207,10 @@ import { pruneProcessedHistoryImages } from "./history-image-prune.js"; import { detectAndLoadPromptImages } from "./images.js"; import { buildAttemptReplayMetadata } from "./incomplete-turn.js"; import { resolveLlmIdleTimeoutMs, streamWithIdleTimeout } from "./llm-idle-timeout.js"; +import { + PREEMPTIVE_OVERFLOW_ERROR_TEXT, + shouldPreemptivelyCompactBeforePrompt, +} from "./preemptive-compaction.js"; import type { EmbeddedRunAttemptParams, EmbeddedRunAttemptResult } from 
"./types.js"; export { @@ -1525,8 +1526,10 @@ export async function runEmbeddedAttempt( const hookAgentId = sessionAgentId; let promptError: unknown = null; + let preflightRecovery: EmbeddedRunAttemptResult["preflightRecovery"]; let promptErrorSource: "prompt" | "compaction" | "precheck" | null = null; let prePromptMessageCount = activeSession.messages.length; + let skipPromptSubmission = false; try { const promptStartedAt = Date.now(); @@ -1773,32 +1776,81 @@ export async function runEmbeddedAttempt( contextTokenBudget: params.contextTokenBudget, reserveTokens, }); + if (preemptiveCompaction.route === "truncate_tool_results_only") { + const truncationResult = truncateOversizedToolResultsInSessionManager({ + sessionManager, + contextWindowTokens: params.contextTokenBudget, + sessionFile: params.sessionFile, + sessionId: params.sessionId, + sessionKey: params.sessionKey, + }); + if (truncationResult.truncated) { + preflightRecovery = { + route: "truncate_tool_results_only", + handled: true, + truncatedCount: truncationResult.truncatedCount, + }; + log.info( + `[context-overflow-precheck] early tool-result truncation succeeded for ` + + `${params.provider}/${params.modelId} route=${preemptiveCompaction.route} ` + + `truncatedCount=${truncationResult.truncatedCount} ` + + `estimatedPromptTokens=${preemptiveCompaction.estimatedPromptTokens} ` + + `promptBudgetBeforeReserve=${preemptiveCompaction.promptBudgetBeforeReserve} ` + + `overflowTokens=${preemptiveCompaction.overflowTokens} ` + + `toolResultReducibleChars=${preemptiveCompaction.toolResultReducibleChars} ` + + `sessionFile=${params.sessionFile}`, + ); + skipPromptSubmission = true; + } + if (!skipPromptSubmission) { + log.warn( + `[context-overflow-precheck] early tool-result truncation did not help for ` + + `${params.provider}/${params.modelId}; falling back to compaction ` + + `reason=${truncationResult.reason ?? 
"unknown"} sessionFile=${params.sessionFile}`, + ); + preflightRecovery = { route: "compact_only" }; + promptError = new Error(PREEMPTIVE_OVERFLOW_ERROR_TEXT); + promptErrorSource = "precheck"; + skipPromptSubmission = true; + } + } if (preemptiveCompaction.shouldCompact) { + preflightRecovery = + preemptiveCompaction.route === "compact_then_truncate" + ? { route: "compact_then_truncate" } + : { route: "compact_only" }; promptError = new Error(PREEMPTIVE_OVERFLOW_ERROR_TEXT); promptErrorSource = "precheck"; log.warn( `[context-overflow-precheck] sessionKey=${params.sessionKey ?? params.sessionId} ` + `provider=${params.provider}/${params.modelId} ` + + `route=${preemptiveCompaction.route} ` + `estimatedPromptTokens=${preemptiveCompaction.estimatedPromptTokens} ` + `promptBudgetBeforeReserve=${preemptiveCompaction.promptBudgetBeforeReserve} ` + + `overflowTokens=${preemptiveCompaction.overflowTokens} ` + + `toolResultReducibleChars=${preemptiveCompaction.toolResultReducibleChars} ` + `reserveTokens=${reserveTokens} sessionFile=${params.sessionFile}`, ); - return; + skipPromptSubmission = true; } - const btwSnapshotMessages = activeSession.messages.slice(-MAX_BTW_SNAPSHOT_MESSAGES); - updateActiveEmbeddedRunSnapshot(params.sessionId, { - transcriptLeafId, - messages: btwSnapshotMessages, - inFlightPrompt: effectivePrompt, - }); + if (!skipPromptSubmission) { + const btwSnapshotMessages = activeSession.messages.slice(-MAX_BTW_SNAPSHOT_MESSAGES); + updateActiveEmbeddedRunSnapshot(params.sessionId, { + transcriptLeafId, + messages: btwSnapshotMessages, + inFlightPrompt: effectivePrompt, + }); - // Only pass images option if there are actually images to pass - // This avoids potential issues with models that don't expect the images parameter - if (imageResult.images.length > 0) { - await abortable(activeSession.prompt(effectivePrompt, { images: imageResult.images })); - } else { - await abortable(activeSession.prompt(effectivePrompt)); + // Only pass images option if 
there are actually images to pass + // This avoids potential issues with models that don't expect the images parameter + if (imageResult.images.length > 0) { + await abortable( + activeSession.prompt(effectivePrompt, { images: imageResult.images }), + ); + } else { + await abortable(activeSession.prompt(effectivePrompt)); + } } } catch (err) { // Yield-triggered abort is intentional — treat as clean stop, not error. @@ -2160,6 +2212,7 @@ export async function runEmbeddedAttempt( timedOut, timedOutDuringCompaction, promptError, + preflightRecovery, sessionIdUsed, bootstrapPromptWarningSignaturesSeen: bootstrapPromptWarning.warningSignaturesSeen, bootstrapPromptWarningSignature: bootstrapPromptWarning.signature, diff --git a/src/agents/pi-embedded-runner/run/preemptive-compaction.test.ts b/src/agents/pi-embedded-runner/run/preemptive-compaction.test.ts index 81919a477f9..6fb6715755a 100644 --- a/src/agents/pi-embedded-runner/run/preemptive-compaction.test.ts +++ b/src/agents/pi-embedded-runner/run/preemptive-compaction.test.ts @@ -1,4 +1,5 @@ import { describe, expect, it } from "vitest"; +import { estimateToolResultReductionPotential } from "../tool-result-truncation.js"; import { PREEMPTIVE_OVERFLOW_ERROR_TEXT, estimatePrePromptTokens, @@ -43,6 +44,7 @@ describe("preemptive-compaction", () => { }); expect(result.shouldCompact).toBe(true); + expect(result.route).toBe("compact_only"); expect(result.estimatedPromptTokens).toBeGreaterThan(result.promptBudgetBeforeReserve); }); @@ -56,6 +58,85 @@ describe("preemptive-compaction", () => { }); expect(result.shouldCompact).toBe(false); + expect(result.route).toBe("fits"); expect(result.estimatedPromptTokens).toBeLessThan(result.promptBudgetBeforeReserve); }); + + it("routes to direct tool-result truncation when recent tool tails can clearly absorb the overflow", () => { + const medium = "alpha beta gamma delta epsilon ".repeat(2200); + const messages = [ + { role: "assistant", content: "short history" }, + { + role: 
"toolResult", + content: [ + { type: "text", text: medium }, + { type: "text", text: medium }, + { type: "text", text: medium }, + { type: "text", text: medium }, + ], + } as never, + ]; + const reserveTokens = 2_000; + const contextTokenBudget = 26_000; + const estimatedPromptTokens = estimatePrePromptTokens({ + messages, + systemPrompt: "sys", + prompt: "hello", + }); + const desiredOverflowTokens = 200; + const adjustedContextTokenBudget = + estimatedPromptTokens - desiredOverflowTokens + reserveTokens; + const result = shouldPreemptivelyCompactBeforePrompt({ + messages, + systemPrompt: "sys", + prompt: "hello", + contextTokenBudget: Math.max(contextTokenBudget, adjustedContextTokenBudget), + reserveTokens, + }); + + expect(result.route).toBe("truncate_tool_results_only"); + expect(result.shouldCompact).toBe(false); + expect(result.overflowTokens).toBeGreaterThan(0); + expect(result.toolResultReducibleChars).toBeGreaterThan(0); + }); + + it("routes to compact then truncate when recent tool tails help but cannot fully cover the overflow", () => { + const medium = "alpha beta gamma delta epsilon ".repeat(220); + const longHistory = "old discussion with substantial retained context and decisions ".repeat( + 5000, + ); + const messages = [ + { role: "assistant", content: longHistory }, + { role: "toolResult", content: [{ type: "text", text: medium }] } as never, + { role: "toolResult", content: [{ type: "text", text: medium }] } as never, + { role: "toolResult", content: [{ type: "text", text: medium }] } as never, + ]; + const reserveTokens = 500; + const baseContextTokenBudget = 3_500; + const estimatedPromptTokens = estimatePrePromptTokens({ + messages, + systemPrompt: verboseSystem, + prompt: verbosePrompt, + }); + const toolResultPotential = estimateToolResultReductionPotential({ + messages: messages as never, + contextWindowTokens: baseContextTokenBudget, + }); + const desiredOverflowTokens = Math.ceil((toolResultPotential.maxReducibleChars + 4_096) / 4); + 
const result = shouldPreemptivelyCompactBeforePrompt({ + messages, + systemPrompt: verboseSystem, + prompt: verbosePrompt, + contextTokenBudget: Math.max( + baseContextTokenBudget, + estimatedPromptTokens - desiredOverflowTokens + reserveTokens, + ), + reserveTokens, + }); + + expect(result.route).toBe("compact_then_truncate"); + expect(result.shouldCompact).toBe(true); + expect(result.overflowTokens).toBeGreaterThan(0); + expect(result.toolResultReducibleChars).toBeGreaterThan(0); + }); }); diff --git a/src/agents/pi-embedded-runner/run/preemptive-compaction.ts b/src/agents/pi-embedded-runner/run/preemptive-compaction.ts index baf153c0256..f8c805a8959 100644 --- a/src/agents/pi-embedded-runner/run/preemptive-compaction.ts +++ b/src/agents/pi-embedded-runner/run/preemptive-compaction.ts @@ -1,10 +1,20 @@ import type { AgentMessage } from "@mariozechner/pi-agent-core"; import { estimateTokens } from "@mariozechner/pi-coding-agent"; import { SAFETY_MARGIN, estimateMessagesTokens } from "../../compaction.js"; +import { estimateToolResultReductionPotential } from "../tool-result-truncation.js"; export const PREEMPTIVE_OVERFLOW_ERROR_TEXT = "Context overflow: prompt too large for the model (precheck)."; +const ESTIMATED_CHARS_PER_TOKEN = 4; +const TRUNCATION_ROUTE_BUFFER_TOKENS = 512; + +export type PreemptiveCompactionRoute = + | "fits" + | "compact_only" + | "truncate_tool_results_only" + | "compact_then_truncate"; + export function estimatePrePromptTokens(params: { messages: AgentMessage[]; systemPrompt?: string; @@ -30,18 +40,47 @@ export function shouldPreemptivelyCompactBeforePrompt(params: { contextTokenBudget: number; reserveTokens: number; }): { + route: PreemptiveCompactionRoute; shouldCompact: boolean; estimatedPromptTokens: number; promptBudgetBeforeReserve: number; + overflowTokens: number; + toolResultReducibleChars: number; } { const estimatedPromptTokens = estimatePrePromptTokens(params); const promptBudgetBeforeReserve = Math.max( 1, 
Math.floor(params.contextTokenBudget) - Math.max(0, Math.floor(params.reserveTokens)), ); + const overflowTokens = Math.max(0, estimatedPromptTokens - promptBudgetBeforeReserve); + const toolResultPotential = estimateToolResultReductionPotential({ + messages: params.messages, + contextWindowTokens: params.contextTokenBudget, + }); + const overflowChars = overflowTokens * ESTIMATED_CHARS_PER_TOKEN; + const truncationBufferChars = TRUNCATION_ROUTE_BUFFER_TOKENS * ESTIMATED_CHARS_PER_TOKEN; + const truncateOnlyThresholdChars = Math.max( + overflowChars + truncationBufferChars, + Math.ceil(overflowChars * 1.5), + ); + const toolResultReducibleChars = toolResultPotential.maxReducibleChars; + + let route: PreemptiveCompactionRoute = "fits"; + if (overflowTokens > 0) { + if (toolResultReducibleChars <= 0) { + route = "compact_only"; + } else if (toolResultReducibleChars >= truncateOnlyThresholdChars) { + route = "truncate_tool_results_only"; + } else { + route = "compact_then_truncate"; + } + } return { - shouldCompact: estimatedPromptTokens > promptBudgetBeforeReserve, + route, + shouldCompact: route === "compact_only" || route === "compact_then_truncate", estimatedPromptTokens, promptBudgetBeforeReserve, + overflowTokens, + toolResultReducibleChars, }; } diff --git a/src/agents/pi-embedded-runner/run/types.ts b/src/agents/pi-embedded-runner/run/types.ts index f3696f79f01..cc95d4ad4df 100644 --- a/src/agents/pi-embedded-runner/run/types.ts +++ b/src/agents/pi-embedded-runner/run/types.ts @@ -9,6 +9,7 @@ import type { MessagingToolSend } from "../../pi-embedded-messaging.js"; import type { ToolErrorSummary } from "../../tool-error-summary.js"; import type { NormalizedUsage } from "../../usage.js"; import type { RunEmbeddedPiAgentParams } from "./params.js"; +import type { PreemptiveCompactionRoute } from "./preemptive-compaction.js"; type EmbeddedRunAttemptBase = Omit< RunEmbeddedPiAgentParams, @@ -41,6 +42,16 @@ export type EmbeddedRunAttemptResult = { /** True if the 
timeout occurred while compaction was in progress or pending. */ timedOutDuringCompaction: boolean; promptError: unknown; + preflightRecovery?: + | { + route: Exclude<PreemptiveCompactionRoute, "fits">; + handled: true; + truncatedCount?: number; + } + | { + route: Exclude<PreemptiveCompactionRoute, "fits">; + handled?: false; + }; sessionIdUsed: string; bootstrapPromptWarningSignaturesSeen?: string[]; bootstrapPromptWarningSignature?: string; diff --git a/src/agents/pi-embedded-runner/tool-result-truncation.test.ts b/src/agents/pi-embedded-runner/tool-result-truncation.test.ts index 5e49b999d88..b900da283fe 100644 --- a/src/agents/pi-embedded-runner/tool-result-truncation.test.ts +++ b/src/agents/pi-embedded-runner/tool-result-truncation.test.ts @@ -2,8 +2,8 @@ import fs from "node:fs/promises"; import os from "node:os"; import path from "node:path"; import type { AgentMessage } from "@mariozechner/pi-agent-core"; -import { SessionManager } from "@mariozechner/pi-coding-agent"; import type { AssistantMessage, ToolResultMessage, UserMessage } from "@mariozechner/pi-ai"; +import { SessionManager } from "@mariozechner/pi-coding-agent"; import { afterEach, beforeEach, describe, expect, it } from "vitest"; import { makeAgentAssistantMessage } from "../test-helpers/agent-message-fixtures.js"; @@ -15,6 +15,7 @@ let truncateOversizedToolResultsInMessages: typeof import("./tool-result-truncat let truncateOversizedToolResultsInSession: typeof import("./tool-result-truncation.js").truncateOversizedToolResultsInSession; let isOversizedToolResult: typeof import("./tool-result-truncation.js").isOversizedToolResult; let sessionLikelyHasOversizedToolResults: typeof import("./tool-result-truncation.js").sessionLikelyHasOversizedToolResults; +let estimateToolResultReductionPotential: typeof import("./tool-result-truncation.js").estimateToolResultReductionPotential; let DEFAULT_MAX_LIVE_TOOL_RESULT_CHARS: typeof import("./tool-result-truncation.js").DEFAULT_MAX_LIVE_TOOL_RESULT_CHARS; let HARD_MAX_TOOL_RESULT_CHARS: typeof 
import("./tool-result-truncation.js").HARD_MAX_TOOL_RESULT_CHARS; let tmpDir: string | undefined; @@ -29,6 +30,7 @@ async function loadFreshToolResultTruncationModuleForTest() { truncateOversizedToolResultsInSession, isOversizedToolResult, sessionLikelyHasOversizedToolResults, + estimateToolResultReductionPotential, DEFAULT_MAX_LIVE_TOOL_RESULT_CHARS, HARD_MAX_TOOL_RESULT_CHARS, } = await import("./tool-result-truncation.js")); @@ -245,6 +247,58 @@ describe("sessionLikelyHasOversizedToolResults", () => { }); }); +describe("estimateToolResultReductionPotential", () => { + it("reports no reducible budget when tool results are already small", () => { + const messages: AgentMessage[] = [makeToolResult("small result")]; + + const estimate = estimateToolResultReductionPotential({ + messages, + contextWindowTokens: 128_000, + }); + + expect(estimate.toolResultCount).toBe(1); + expect(estimate.maxReducibleChars).toBe(0); + }); + + it("estimates reducible chars for aggregate medium tool-result tails", () => { + const medium = "alpha beta gamma delta epsilon ".repeat(600); + const messages: AgentMessage[] = [ + makeToolResult(medium, "call_1"), + makeToolResult(medium, "call_2"), + makeToolResult(medium, "call_3"), + ]; + + const estimate = estimateToolResultReductionPotential({ + messages, + contextWindowTokens: 128_000, + }); + + expect(estimate.toolResultCount).toBe(3); + expect(estimate.oversizedCount).toBe(0); + expect(estimate.aggregateReducibleChars).toBeGreaterThan(0); + expect(estimate.maxReducibleChars).toBe(estimate.aggregateReducibleChars); + }); + + it("does not count aggregate savings on top of oversized savings in a single pass", () => { + const oversized = "x".repeat(500_000); + const medium = "alpha beta gamma delta epsilon ".repeat(800); + const messages: AgentMessage[] = [ + makeToolResult(oversized, "call_1"), + makeToolResult(medium, "call_2"), + makeToolResult(medium, "call_3"), + ]; + + const estimate = estimateToolResultReductionPotential({ + 
messages, + contextWindowTokens: 128_000, + }); + + expect(estimate.oversizedCount).toBeGreaterThan(0); + expect(estimate.oversizedReducibleChars).toBeGreaterThan(0); + expect(estimate.maxReducibleChars).toBe(estimate.oversizedReducibleChars); + }); +}); + describe("truncateOversizedToolResultsInMessages", () => { it("returns unchanged messages when nothing is oversized", () => { const messages = [ diff --git a/src/agents/pi-embedded-runner/tool-result-truncation.ts b/src/agents/pi-embedded-runner/tool-result-truncation.ts index 55e507b7f03..4f6bf2c0022 100644 --- a/src/agents/pi-embedded-runner/tool-result-truncation.ts +++ b/src/agents/pi-embedded-runner/tool-result-truncation.ts @@ -3,8 +3,8 @@ import type { TextContent } from "@mariozechner/pi-ai"; import { SessionManager } from "@mariozechner/pi-coding-agent"; import { emitSessionTranscriptUpdate } from "../../sessions/transcript-events.js"; import { acquireSessionWriteLock } from "../session-write-lock.js"; -import { formatContextLimitTruncationNotice } from "./tool-result-context-guard.js"; import { log } from "./logger.js"; +import { formatContextLimitTruncationNotice } from "./tool-result-context-guard.js"; import { rewriteTranscriptEntriesInSessionManager } from "./transcript-rewrite.js"; /** @@ -42,7 +42,7 @@ type ToolResultTruncationOptions = { const DEFAULT_SUFFIX = (truncatedChars: number) => formatContextLimitTruncationNotice(truncatedChars); -const MIN_TRUNCATED_TEXT_CHARS = MIN_KEEP_CHARS + DEFAULT_SUFFIX(1).length; +export const MIN_TRUNCATED_TEXT_CHARS = MIN_KEEP_CHARS + DEFAULT_SUFFIX(1).length; /** * Marker inserted between head and tail when using head+tail truncation. @@ -82,7 +82,7 @@ export function truncateToolResultText( const suffixFactory: (truncatedChars: number) => string = typeof options.suffix === "function" ? options.suffix - : () => (options.suffix ?? DEFAULT_SUFFIX(1)); + : () => options.suffix ?? DEFAULT_SUFFIX(1); const minKeepChars = options.minKeepChars ?? 
MIN_KEEP_CHARS; if (text.length <= maxChars) { return text; @@ -175,7 +175,7 @@ export function truncateToolResultMessage( const suffixFactory: (truncatedChars: number) => string = typeof options.suffix === "function" ? options.suffix - : () => (options.suffix ?? DEFAULT_SUFFIX(1)); + : () => options.suffix ?? DEFAULT_SUFFIX(1); const minKeepChars = options.minKeepChars ?? MIN_KEEP_CHARS; const content = (msg as { content?: unknown }).content; if (!Array.isArray(content)) { @@ -251,6 +251,17 @@ function calculateAggregateToolResultChars(contextWindowTokens: number): number return Math.max(calculateMaxToolResultChars(contextWindowTokens), MIN_TRUNCATED_TEXT_CHARS); } +export type ToolResultReductionPotential = { + maxChars: number; + aggregateBudgetChars: number; + toolResultCount: number; + totalToolResultChars: number; + oversizedCount: number; + oversizedReducibleChars: number; + aggregateReducibleChars: number; + maxReducibleChars: number; +}; + function buildAggregateToolResultReplacements(params: { branch: Array<{ id: string; type: string; message?: AgentMessage }>; aggregateBudgetChars: number; @@ -258,7 +269,9 @@ function buildAggregateToolResultReplacements(params: { const candidates = params.branch .map((entry, index) => ({ entry, index })) .filter( - (item): item is { + ( + item, + ): item is { entry: { id: string; type: string; message: AgentMessage }; index: number; } => @@ -313,94 +326,131 @@ function buildAggregateToolResultReplacements(params: { return replacements; } -export async function truncateOversizedToolResultsInSession(params: { - sessionFile: string; +export function estimateToolResultReductionPotential(params: { + messages: AgentMessage[]; contextWindowTokens: number; - sessionId?: string; - sessionKey?: string; -}): Promise<{ truncated: boolean; truncatedCount: number; reason?: string }> { - const { sessionFile, contextWindowTokens } = params; +}): ToolResultReductionPotential { + const { messages, contextWindowTokens } = params; const 
maxChars = calculateMaxToolResultChars(contextWindowTokens); const aggregateBudgetChars = calculateAggregateToolResultChars(contextWindowTokens); - let sessionLock: Awaited<ReturnType<typeof acquireSessionWriteLock>> | undefined; - try { - sessionLock = await acquireSessionWriteLock({ sessionFile }); - const sessionManager = SessionManager.open(sessionFile); - const branch = sessionManager.getBranch(); + let toolResultCount = 0; + let totalToolResultChars = 0; + let oversizedCount = 0; + let oversizedReducibleChars = 0; + const individuallyTrimmedMessages = messages.slice(); - if (branch.length === 0) { - return { truncated: false, truncatedCount: 0, reason: "empty session" }; + for (let index = 0; index < messages.length; index += 1) { + const msg = messages[index]; + if ((msg as { role?: string }).role !== "toolResult") { + continue; } - - const oversizedIndices: number[] = []; - for (let i = 0; i < branch.length; i += 1) { - const entry = branch[i]; - if (entry.type !== "message") { - continue; - } - const msg = entry.message; - if ((msg as { role?: string }).role !== "toolResult") { - continue; - } - if (getToolResultTextLength(msg) > maxChars) { - oversizedIndices.push(i); - } + const textLength = getToolResultTextLength(msg); + if (textLength <= 0) { + continue; } + toolResultCount += 1; + totalToolResultChars += textLength; + if (textLength <= maxChars) { + continue; + } + oversizedCount += 1; + const truncatedMessage = truncateToolResultMessage(msg, maxChars); + individuallyTrimmedMessages[index] = truncatedMessage; + oversizedReducibleChars += Math.max(0, textLength - getToolResultTextLength(truncatedMessage)); + } - if (oversizedIndices.length === 0) { - const replacements = buildAggregateToolResultReplacements({ - branch: branch as Array<{ id: string; type: string; message?: AgentMessage }>, - aggregateBudgetChars, - }); - if (replacements.length === 0) { - return { - truncated: false, - truncatedCount: 0, - reason: "no oversized or aggregate tool results", - }; - } + const aggregateReplacements 
= buildAggregateToolResultReplacements({ + branch: individuallyTrimmedMessages.map((message, index) => ({ + id: `message-${index}`, + type: "message", + message, + })), + aggregateBudgetChars, + }); + const individuallyTrimmedBranch = individuallyTrimmedMessages.map((message, index) => ({ + id: `message-${index}`, + type: "message", + message, + })); + const aggregateReducibleChars = aggregateReplacements.reduce((sum, replacement) => { + const match = individuallyTrimmedBranch.find((entry) => entry.id === replacement.entryId); + const originalLength = + match && match.message + ? getToolResultTextLength(match.message) + : getToolResultTextLength(replacement.message); + const newLength = getToolResultTextLength(replacement.message); + return sum + Math.max(0, originalLength - newLength); + }, 0); + const maxReducibleChars = oversizedCount > 0 ? oversizedReducibleChars : aggregateReducibleChars; - const rewriteResult = rewriteTranscriptEntriesInSessionManager({ - sessionManager, - replacements, - }); - if (rewriteResult.changed) { - emitSessionTranscriptUpdate(sessionFile); - } + return { + maxChars, + aggregateBudgetChars, + toolResultCount, + totalToolResultChars, + oversizedCount, + oversizedReducibleChars, + aggregateReducibleChars, + maxReducibleChars, + }; +} - log.info( - `[tool-result-truncation] Aggregate-truncated ${rewriteResult.rewrittenEntries} tool result(s) in session ` + - `(contextWindow=${contextWindowTokens} aggregateBudgetChars=${aggregateBudgetChars}) ` + - `sessionKey=${params.sessionKey ?? params.sessionId ?? 
"unknown"}`, - ); +function truncateOversizedToolResultsInExistingSessionManager(params: { + sessionManager: SessionManager; + contextWindowTokens: number; + sessionFile?: string; + sessionId?: string; + sessionKey?: string; +}): { truncated: boolean; truncatedCount: number; reason?: string } { + const { sessionManager, contextWindowTokens } = params; + const maxChars = calculateMaxToolResultChars(contextWindowTokens); + const aggregateBudgetChars = calculateAggregateToolResultChars(contextWindowTokens); + const branch = sessionManager.getBranch(); + if (branch.length === 0) { + return { truncated: false, truncatedCount: 0, reason: "empty session" }; + } + + const oversizedIndices: number[] = []; + for (let i = 0; i < branch.length; i += 1) { + const entry = branch[i]; + if (entry.type !== "message") { + continue; + } + const msg = entry.message; + if ((msg as { role?: string }).role !== "toolResult") { + continue; + } + if (getToolResultTextLength(msg) > maxChars) { + oversizedIndices.push(i); + } + } + + if (oversizedIndices.length === 0) { + const replacements = buildAggregateToolResultReplacements({ + branch: branch as Array<{ id: string; type: string; message?: AgentMessage }>, + aggregateBudgetChars, + }); + if (replacements.length === 0) { return { - truncated: rewriteResult.changed, - truncatedCount: rewriteResult.rewrittenEntries, - reason: rewriteResult.reason, + truncated: false, + truncatedCount: 0, + reason: "no oversized or aggregate tool results", }; } - const replacements = oversizedIndices.flatMap((index) => { - const entry = branch[index]; - if (!entry || entry.type !== "message") { - return []; - } - return [{ entryId: entry.id, message: truncateToolResultMessage(entry.message, maxChars) }]; - }); - const rewriteResult = rewriteTranscriptEntriesInSessionManager({ sessionManager, replacements, }); - if (rewriteResult.changed) { - emitSessionTranscriptUpdate(sessionFile); + if (rewriteResult.changed && params.sessionFile) { + 
emitSessionTranscriptUpdate(params.sessionFile); } log.info( - `[tool-result-truncation] Truncated ${rewriteResult.rewrittenEntries} tool result(s) in session ` + - `(contextWindow=${contextWindowTokens} maxChars=${maxChars}) ` + + `[tool-result-truncation] Aggregate-truncated ${rewriteResult.rewrittenEntries} tool result(s) in session ` + + `(contextWindow=${contextWindowTokens} aggregateBudgetChars=${aggregateBudgetChars}) ` + `sessionKey=${params.sessionKey ?? params.sessionId ?? "unknown"}`, ); @@ -409,6 +459,72 @@ export async function truncateOversizedToolResultsInSession(params: { truncatedCount: rewriteResult.rewrittenEntries, reason: rewriteResult.reason, }; + } + + const replacements = oversizedIndices.flatMap((index) => { + const entry = branch[index]; + if (!entry || entry.type !== "message") { + return []; + } + return [{ entryId: entry.id, message: truncateToolResultMessage(entry.message, maxChars) }]; + }); + + const rewriteResult = rewriteTranscriptEntriesInSessionManager({ + sessionManager, + replacements, + }); + if (rewriteResult.changed && params.sessionFile) { + emitSessionTranscriptUpdate(params.sessionFile); + } + + log.info( + `[tool-result-truncation] Truncated ${rewriteResult.rewrittenEntries} tool result(s) in session ` + + `(contextWindow=${contextWindowTokens} maxChars=${maxChars}) ` + + `sessionKey=${params.sessionKey ?? params.sessionId ?? "unknown"}`, + ); + + return { + truncated: rewriteResult.changed, + truncatedCount: rewriteResult.rewrittenEntries, + reason: rewriteResult.reason, + }; +} + +export function truncateOversizedToolResultsInSessionManager(params: { + sessionManager: SessionManager; + contextWindowTokens: number; + sessionFile?: string; + sessionId?: string; + sessionKey?: string; +}): { truncated: boolean; truncatedCount: number; reason?: string } { + try { + return truncateOversizedToolResultsInExistingSessionManager(params); + } catch (err) { + const errMsg = err instanceof Error ? 
err.message : String(err); + log.warn(`[tool-result-truncation] Failed to truncate: ${errMsg}`); + return { truncated: false, truncatedCount: 0, reason: errMsg }; + } +} + +export async function truncateOversizedToolResultsInSession(params: { + sessionFile: string; + contextWindowTokens: number; + sessionId?: string; + sessionKey?: string; +}): Promise<{ truncated: boolean; truncatedCount: number; reason?: string }> { + const { sessionFile, contextWindowTokens } = params; + let sessionLock: Awaited<ReturnType<typeof acquireSessionWriteLock>> | undefined; + + try { + sessionLock = await acquireSessionWriteLock({ sessionFile }); + const sessionManager = SessionManager.open(sessionFile); + return truncateOversizedToolResultsInExistingSessionManager({ + sessionManager, + contextWindowTokens, + sessionFile, + sessionId: params.sessionId, + sessionKey: params.sessionKey, + }); } catch (err) { const errMsg = err instanceof Error ? err.message : String(err); log.warn(`[tool-result-truncation] Failed to truncate: ${errMsg}`); @@ -433,25 +549,6 @@ export function sessionLikelyHasOversizedToolResults(params: { messages: AgentMessage[]; contextWindowTokens: number; }): boolean { - const { messages, contextWindowTokens } = params; - const maxChars = calculateMaxToolResultChars(contextWindowTokens); - const aggregateBudgetChars = calculateAggregateToolResultChars(contextWindowTokens); - let totalToolResultChars = 0; - let toolResultCount = 0; - - for (const msg of messages) { - if ((msg as { role?: string }).role !== "toolResult") { - continue; - } - const textLength = getToolResultTextLength(msg); - if (textLength > maxChars) { - return true; - } - totalToolResultChars += textLength; - if (textLength > 0) { - toolResultCount += 1; - } - } - - return toolResultCount >= 2 && totalToolResultChars > aggregateBudgetChars; + const estimate = estimateToolResultReductionPotential(params); + return estimate.oversizedCount > 0 || estimate.aggregateReducibleChars > 0; }