From bb6cf75463219374b200718b40eb0986865c477b Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sat, 25 Apr 2026 20:34:14 +0100 Subject: [PATCH] refactor: centralize context prompt token resolution --- src/agents/usage.test.ts | 34 +++- src/agents/usage.ts | 24 ++- .../agent-runner.misc.runreplyagent.test.ts | 184 +++++------------- src/auto-reply/reply/agent-runner.ts | 53 +---- 4 files changed, 105 insertions(+), 190 deletions(-) diff --git a/src/agents/usage.test.ts b/src/agents/usage.test.ts index 68f69113b67..986fccbdf95 100644 --- a/src/agents/usage.test.ts +++ b/src/agents/usage.test.ts @@ -1,9 +1,10 @@ import { describe, expect, it } from "vitest"; import { - normalizeUsage, - hasNonzeroUsage, + deriveContextPromptTokens, derivePromptTokens, deriveSessionTotalTokens, + hasNonzeroUsage, + normalizeUsage, toOpenAiChatCompletionsUsage, } from "./usage.js"; @@ -282,6 +283,35 @@ describe("derivePromptTokens", () => { }); }); +describe("deriveContextPromptTokens", () => { + it("prefers explicit prompt snapshot over provider usage", () => { + expect( + deriveContextPromptTokens({ + promptTokens: 44_000, + lastCallUsage: { input: 55_000, cacheRead: 25_000 }, + usage: { input: 75_000, cacheRead: 25_000, output: 5_000, total: 105_000 }, + }), + ).toBe(44_000); + }); + + it("falls back to last-call prompt usage before accumulated usage", () => { + expect( + deriveContextPromptTokens({ + lastCallUsage: { input: 55_000, cacheRead: 25_000, cacheWrite: 1_000 }, + usage: { input: 75_000, cacheRead: 25_000, output: 5_000, total: 105_000 }, + }), + ).toBe(81_000); + }); + + it("falls back to accumulated usage when no prompt snapshot exists", () => { + expect( + deriveContextPromptTokens({ + usage: { input: 75_000, cacheRead: 25_000, output: 5_000, total: 105_000 }, + }), + ).toBe(100_000); + }); +}); + describe("deriveSessionTotalTokens", () => { it("includes cache tokens in total calculation", () => { const totalTokens = deriveSessionTotalTokens({ diff --git a/src/agents/usage.ts b/src/agents/usage.ts index 4f437063585..f373bd957e8 100644 --- a/src/agents/usage.ts +++ b/src/agents/usage.ts @@ -219,6 +219,19 @@ export function derivePromptTokens(usage?: { return sum > 0 ? sum : undefined; } +export function deriveContextPromptTokens(params: { + lastCallUsage?: NormalizedUsage; + promptTokens?: number; + usage?: NormalizedUsage; +}): number | undefined { + const promptOverride = params.promptTokens; + if (typeof promptOverride === "number" && Number.isFinite(promptOverride) && promptOverride > 0) { + return promptOverride; + } + + return derivePromptTokens(params.lastCallUsage) ?? derivePromptTokens(params.usage); +} + export function deriveSessionTotalTokens(params: { usage?: { input?: number; @@ -241,13 +254,10 @@ export function deriveSessionTotalTokens(params: { // NOTE: SessionEntry.totalTokens is used as a prompt/context snapshot. // It intentionally excludes completion/output tokens. - const promptTokens = hasPromptOverride - ? promptOverride - : derivePromptTokens({ - input: usage?.input, - cacheRead: usage?.cacheRead, - cacheWrite: usage?.cacheWrite, - }); + const promptTokens = deriveContextPromptTokens({ + promptTokens: hasPromptOverride ? promptOverride : undefined, + usage, + }); if (!(typeof promptTokens === "number") || !Number.isFinite(promptTokens) || promptTokens <= 0) { return undefined; diff --git a/src/auto-reply/reply/agent-runner.misc.runreplyagent.test.ts b/src/auto-reply/reply/agent-runner.misc.runreplyagent.test.ts index cded996fb32..839a84c56ad 100644 --- a/src/auto-reply/reply/agent-runner.misc.runreplyagent.test.ts +++ b/src/auto-reply/reply/agent-runner.misc.runreplyagent.test.ts @@ -239,8 +239,12 @@ describe("runReplyAgent auto-compaction token update", () => { return { typing, sessionCtx, resolvedQueue, followupRun }; } - it("updates totalTokens from lastCallUsage even without compaction", async () => { - const tmp = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-usage-last-")); + async function runBaseReplyWithAgentMeta(params: { + agentMeta: Record; + collectDiagnostics?: boolean; + tmpPrefix: string; + }) { + const tmp = await fs.mkdtemp(path.join(os.tmpdir(), params.tmpPrefix)); const storePath = path.join(tmp, "sessions.json"); const sessionKey = "main"; const sessionEntry = { @@ -254,76 +258,16 @@ describe("runReplyAgent auto-compaction token update", () => { runEmbeddedPiAgentMock.mockResolvedValue({ payloads: [{ text: "ok" }], meta: { - agentMeta: { - // Tool-use loop: accumulated input is higher than last call's input - usage: { input: 75_000, output: 5_000, total: 80_000 }, - lastCallUsage: { input: 55_000, output: 2_000, total: 57_000 }, - }, - }, - }); - - const { typing, sessionCtx, resolvedQueue, followupRun } = createBaseRun({ - storePath, - sessionEntry, - }); - - await runReplyAgent({ - commandBody: "hello", - followupRun, - queueKey: "main", - resolvedQueue, - shouldSteer: false, - shouldFollowup: false, - isActive: false, - isStreaming: false, - typing, - sessionCtx, - sessionEntry, - sessionStore: { [sessionKey]: sessionEntry }, - sessionKey, - storePath, - defaultModel: "anthropic/claude-opus-4-6", - agentCfgContextTokens: 200_000, - resolvedVerboseLevel: "off", - isNewSession: false, - blockStreamingEnabled: false, - resolvedBlockStreamingBreak: "message_end", - shouldInjectGroupIntro: false, - typingMode: "instant", - }); - - const stored = JSON.parse(await fs.readFile(storePath, "utf-8")); - // totalTokens should use lastCallUsage (55k), not accumulated (75k) - expect(stored[sessionKey].totalTokens).toBe(55_000); - }); - - it("reports live diagnostic context from promptTokens, not provider usage totals", async () => { - const tmp = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-usage-diagnostic-")); - const storePath = path.join(tmp, "sessions.json"); - const sessionKey = "main"; - const sessionEntry = { - sessionId: "session", - updatedAt: Date.now(), - totalTokens: 50_000, - }; - - await seedSessionStore({ storePath, sessionKey, entry: sessionEntry }); - - runEmbeddedPiAgentMock.mockResolvedValue({ - payloads: [{ text: "ok" }], - meta: { - agentMeta: { - usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 }, - lastCallUsage: { input: 55_000, output: 2_000, cacheRead: 25_000, total: 82_000 }, - promptTokens: 44_000, - }, + agentMeta: params.agentMeta, }, }); const diagnostics: DiagnosticEventPayload[] = []; - const unsubscribe = onInternalDiagnosticEvent((event) => { - diagnostics.push(event); - }); + const unsubscribe = params.collectDiagnostics + ? onInternalDiagnosticEvent((event) => { + diagnostics.push(event); + }) + : undefined; const { typing, sessionCtx, resolvedQueue, followupRun } = createBaseRun({ storePath, sessionEntry, @@ -355,10 +299,39 @@ describe("runReplyAgent auto-compaction token update", () => { typingMode: "instant", }); } finally { - unsubscribe(); + unsubscribe?.(); } + const stored = JSON.parse(await fs.readFile(storePath, "utf-8")); const usageEvent = diagnostics.find((event) => event.type === "model.usage"); + return { sessionKey, stored, usageEvent }; + } + + it("updates totalTokens from lastCallUsage even without compaction", async () => { + const { sessionKey, stored } = await runBaseReplyWithAgentMeta({ + tmpPrefix: "openclaw-usage-last-", + agentMeta: { + // Tool-use loop: accumulated input is higher than last call's input + usage: { input: 75_000, output: 5_000, total: 80_000 }, + lastCallUsage: { input: 55_000, output: 2_000, total: 57_000 }, + }, + }); + + // totalTokens should use lastCallUsage (55k), not accumulated (75k) + expect(stored[sessionKey].totalTokens).toBe(55_000); + }); + + it("reports live diagnostic context from promptTokens, not provider usage totals", async () => { + const { usageEvent } = await runBaseReplyWithAgentMeta({ + tmpPrefix: "openclaw-usage-diagnostic-", + collectDiagnostics: true, + agentMeta: { + usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 }, + lastCallUsage: { input: 55_000, output: 2_000, cacheRead: 25_000, total: 82_000 }, + promptTokens: 44_000, + }, + }); + expect(usageEvent).toMatchObject({ type: "model.usage", usage: { @@ -376,72 +349,21 @@ describe("runReplyAgent auto-compaction token update", () => { }); it("falls back to last-call prompt usage for live diagnostic context", async () => { - const tmp = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-usage-diagnostic-last-")); - const storePath = path.join(tmp, "sessions.json"); - const sessionKey = "main"; - const sessionEntry = { - sessionId: "session", - updatedAt: Date.now(), - totalTokens: 50_000, - }; - - await seedSessionStore({ storePath, sessionKey, entry: sessionEntry }); - - runEmbeddedPiAgentMock.mockResolvedValue({ - payloads: [{ text: "ok" }], - meta: { - agentMeta: { - usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 }, - lastCallUsage: { - input: 55_000, - output: 2_000, - cacheRead: 25_000, - cacheWrite: 1_000, - total: 83_000, - }, + const { usageEvent } = await runBaseReplyWithAgentMeta({ + tmpPrefix: "openclaw-usage-diagnostic-last-", + collectDiagnostics: true, + agentMeta: { + usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 }, + lastCallUsage: { + input: 55_000, + output: 2_000, + cacheRead: 25_000, + cacheWrite: 1_000, + total: 83_000, }, }, }); - const diagnostics: DiagnosticEventPayload[] = []; - const unsubscribe = onInternalDiagnosticEvent((event) => { - diagnostics.push(event); - }); - const { typing, sessionCtx, resolvedQueue, followupRun } = createBaseRun({ - storePath, - sessionEntry, - }); - - try { - await runReplyAgent({ - commandBody: "hello", - followupRun, - queueKey: "main", - resolvedQueue, - shouldSteer: false, - shouldFollowup: false, - isActive: false, - isStreaming: false, - typing, - sessionCtx, - sessionEntry, - sessionStore: { [sessionKey]: sessionEntry }, - sessionKey, - storePath, - defaultModel: "anthropic/claude-opus-4-6", - agentCfgContextTokens: 200_000, - resolvedVerboseLevel: "off", - isNewSession: false, - blockStreamingEnabled: false, - resolvedBlockStreamingBreak: "message_end", - shouldInjectGroupIntro: false, - typingMode: "instant", - }); - } finally { - unsubscribe(); - } - - const usageEvent = diagnostics.find((event) => event.type === "model.usage"); expect(usageEvent).toMatchObject({ type: "model.usage", usage: { diff --git a/src/auto-reply/reply/agent-runner.ts b/src/auto-reply/reply/agent-runner.ts index 317e00cfd3e..1e4e8d7dc1f 100644 --- a/src/auto-reply/reply/agent-runner.ts +++ b/src/auto-reply/reply/agent-runner.ts @@ -5,7 +5,7 @@ import { DEFAULT_CONTEXT_TOKENS } from "../../agents/defaults.js"; import { resolveModelAuthMode } from "../../agents/model-auth.js"; import { isCliProvider } from "../../agents/model-selection.js"; import { queueEmbeddedPiMessage } from "../../agents/pi-embedded-runner/runs.js"; -import { hasNonzeroUsage, normalizeUsage } from "../../agents/usage.js"; +import { deriveContextPromptTokens, hasNonzeroUsage, normalizeUsage } from "../../agents/usage.js"; import { loadSessionStore, resolveSessionPluginStatusLines, @@ -568,53 +568,6 @@ async function accumulateSessionUsageFromTranscript(params: { } } -function resolveRequestPromptTokens(params: { - lastCallUsage?: { - input?: number; - output?: number; - cacheRead?: number; - cacheWrite?: number; - total?: number; - }; - promptTokens?: number; - usage?: { - input?: number; - output?: number; - cacheRead?: number; - cacheWrite?: number; - total?: number; - }; -}): number | undefined { - if ( - typeof params.promptTokens === "number" && - Number.isFinite(params.promptTokens) && - params.promptTokens > 0 - ) { - return params.promptTokens; - } - const lastCall = params.lastCallUsage; - if (lastCall) { - const input = lastCall.input ?? 0; - const cacheRead = lastCall.cacheRead ?? 0; - const cacheWrite = lastCall.cacheWrite ?? 0; - const sum = input + cacheRead + cacheWrite; - if (sum > 0) { - return sum; - } - } - const usage = params.usage; - if (usage) { - const input = usage.input ?? 0; - const cacheRead = usage.cacheRead ?? 0; - const cacheWrite = usage.cacheWrite ?? 0; - const sum = input + cacheRead + cacheWrite; - if (sum > 0) { - return sum; - } - } - return undefined; -} - function formatRequestContextTraceBlock(params: { provider?: string; model?: string; @@ -785,7 +738,7 @@ function buildInlineRawTracePayload(params: { if (params.entry?.traceLevel !== "raw") { return undefined; } - const resolvedPromptTokens = resolveRequestPromptTokens({ + const resolvedPromptTokens = deriveContextPromptTokens({ lastCallUsage: params.lastCallUsage, promptTokens: params.promptTokens, usage: params.usage, @@ -1430,7 +1383,7 @@ export async function runReplyAgent(params: { const cacheWrite = usage.cacheWrite ?? 0; const usagePromptTokens = input + cacheRead + cacheWrite; const totalTokens = usage.total ?? usagePromptTokens + output; - const contextUsedTokens = resolveRequestPromptTokens({ + const contextUsedTokens = deriveContextPromptTokens({ lastCallUsage: runResult.meta?.agentMeta?.lastCallUsage, promptTokens, usage,