refactor: centralize context prompt token resolution

This commit is contained in:
Peter Steinberger
2026-04-25 20:34:14 +01:00
parent 5fe06f3cdc
commit bb6cf75463
4 changed files with 105 additions and 190 deletions

View File

@@ -1,9 +1,10 @@
import { describe, expect, it } from "vitest";
import {
normalizeUsage,
hasNonzeroUsage,
deriveContextPromptTokens,
derivePromptTokens,
deriveSessionTotalTokens,
hasNonzeroUsage,
normalizeUsage,
toOpenAiChatCompletionsUsage,
} from "./usage.js";
@@ -282,6 +283,35 @@ describe("derivePromptTokens", () => {
});
});
describe("deriveContextPromptTokens", () => {
  it("prefers explicit prompt snapshot over provider usage", () => {
    // An explicit promptTokens snapshot wins over every usage-derived value.
    const result = deriveContextPromptTokens({
      promptTokens: 44_000,
      lastCallUsage: { input: 55_000, cacheRead: 25_000 },
      usage: { input: 75_000, cacheRead: 25_000, output: 5_000, total: 105_000 },
    });
    expect(result).toBe(44_000);
  });

  it("falls back to last-call prompt usage before accumulated usage", () => {
    // 55k input + 25k cacheRead + 1k cacheWrite = 81k from the last call.
    const result = deriveContextPromptTokens({
      lastCallUsage: { input: 55_000, cacheRead: 25_000, cacheWrite: 1_000 },
      usage: { input: 75_000, cacheRead: 25_000, output: 5_000, total: 105_000 },
    });
    expect(result).toBe(81_000);
  });

  it("falls back to accumulated usage when no prompt snapshot exists", () => {
    // Accumulated 75k input + 25k cacheRead = 100k; output is excluded.
    const result = deriveContextPromptTokens({
      usage: { input: 75_000, cacheRead: 25_000, output: 5_000, total: 105_000 },
    });
    expect(result).toBe(100_000);
  });
});
describe("deriveSessionTotalTokens", () => {
it("includes cache tokens in total calculation", () => {
const totalTokens = deriveSessionTotalTokens({

View File

@@ -219,6 +219,19 @@ export function derivePromptTokens(usage?: {
return sum > 0 ? sum : undefined;
}
export function deriveContextPromptTokens(params: {
lastCallUsage?: NormalizedUsage;
promptTokens?: number;
usage?: NormalizedUsage;
}): number | undefined {
const promptOverride = params.promptTokens;
if (typeof promptOverride === "number" && Number.isFinite(promptOverride) && promptOverride > 0) {
return promptOverride;
}
return derivePromptTokens(params.lastCallUsage) ?? derivePromptTokens(params.usage);
}
export function deriveSessionTotalTokens(params: {
usage?: {
input?: number;
@@ -241,13 +254,10 @@ export function deriveSessionTotalTokens(params: {
// NOTE: SessionEntry.totalTokens is used as a prompt/context snapshot.
// It intentionally excludes completion/output tokens.
const promptTokens = hasPromptOverride
? promptOverride
: derivePromptTokens({
input: usage?.input,
cacheRead: usage?.cacheRead,
cacheWrite: usage?.cacheWrite,
});
const promptTokens = deriveContextPromptTokens({
promptTokens: hasPromptOverride ? promptOverride : undefined,
usage,
});
if (!(typeof promptTokens === "number") || !Number.isFinite(promptTokens) || promptTokens <= 0) {
return undefined;

View File

@@ -239,8 +239,12 @@ describe("runReplyAgent auto-compaction token update", () => {
return { typing, sessionCtx, resolvedQueue, followupRun };
}
it("updates totalTokens from lastCallUsage even without compaction", async () => {
const tmp = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-usage-last-"));
async function runBaseReplyWithAgentMeta(params: {
agentMeta: Record<string, unknown>;
collectDiagnostics?: boolean;
tmpPrefix: string;
}) {
const tmp = await fs.mkdtemp(path.join(os.tmpdir(), params.tmpPrefix));
const storePath = path.join(tmp, "sessions.json");
const sessionKey = "main";
const sessionEntry = {
@@ -254,76 +258,16 @@ describe("runReplyAgent auto-compaction token update", () => {
runEmbeddedPiAgentMock.mockResolvedValue({
payloads: [{ text: "ok" }],
meta: {
agentMeta: {
// Tool-use loop: accumulated input is higher than last call's input
usage: { input: 75_000, output: 5_000, total: 80_000 },
lastCallUsage: { input: 55_000, output: 2_000, total: 57_000 },
},
},
});
const { typing, sessionCtx, resolvedQueue, followupRun } = createBaseRun({
storePath,
sessionEntry,
});
await runReplyAgent({
commandBody: "hello",
followupRun,
queueKey: "main",
resolvedQueue,
shouldSteer: false,
shouldFollowup: false,
isActive: false,
isStreaming: false,
typing,
sessionCtx,
sessionEntry,
sessionStore: { [sessionKey]: sessionEntry },
sessionKey,
storePath,
defaultModel: "anthropic/claude-opus-4-6",
agentCfgContextTokens: 200_000,
resolvedVerboseLevel: "off",
isNewSession: false,
blockStreamingEnabled: false,
resolvedBlockStreamingBreak: "message_end",
shouldInjectGroupIntro: false,
typingMode: "instant",
});
const stored = JSON.parse(await fs.readFile(storePath, "utf-8"));
// totalTokens should use lastCallUsage (55k), not accumulated (75k)
expect(stored[sessionKey].totalTokens).toBe(55_000);
});
it("reports live diagnostic context from promptTokens, not provider usage totals", async () => {
const tmp = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-usage-diagnostic-"));
const storePath = path.join(tmp, "sessions.json");
const sessionKey = "main";
const sessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
totalTokens: 50_000,
};
await seedSessionStore({ storePath, sessionKey, entry: sessionEntry });
runEmbeddedPiAgentMock.mockResolvedValue({
payloads: [{ text: "ok" }],
meta: {
agentMeta: {
usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 },
lastCallUsage: { input: 55_000, output: 2_000, cacheRead: 25_000, total: 82_000 },
promptTokens: 44_000,
},
agentMeta: params.agentMeta,
},
});
const diagnostics: DiagnosticEventPayload[] = [];
const unsubscribe = onInternalDiagnosticEvent((event) => {
diagnostics.push(event);
});
const unsubscribe = params.collectDiagnostics
? onInternalDiagnosticEvent((event) => {
diagnostics.push(event);
})
: undefined;
const { typing, sessionCtx, resolvedQueue, followupRun } = createBaseRun({
storePath,
sessionEntry,
@@ -355,10 +299,39 @@ describe("runReplyAgent auto-compaction token update", () => {
typingMode: "instant",
});
} finally {
unsubscribe();
unsubscribe?.();
}
const stored = JSON.parse(await fs.readFile(storePath, "utf-8"));
const usageEvent = diagnostics.find((event) => event.type === "model.usage");
return { sessionKey, stored, usageEvent };
}
it("updates totalTokens from lastCallUsage even without compaction", async () => {
  const run = await runBaseReplyWithAgentMeta({
    tmpPrefix: "openclaw-usage-last-",
    agentMeta: {
      // Tool-use loop: accumulated input is higher than last call's input
      usage: { input: 75_000, output: 5_000, total: 80_000 },
      lastCallUsage: { input: 55_000, output: 2_000, total: 57_000 },
    },
  });
  // totalTokens should use lastCallUsage (55k), not accumulated (75k)
  expect(run.stored[run.sessionKey].totalTokens).toBe(55_000);
});
it("reports live diagnostic context from promptTokens, not provider usage totals", async () => {
const { usageEvent } = await runBaseReplyWithAgentMeta({
tmpPrefix: "openclaw-usage-diagnostic-",
collectDiagnostics: true,
agentMeta: {
usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 },
lastCallUsage: { input: 55_000, output: 2_000, cacheRead: 25_000, total: 82_000 },
promptTokens: 44_000,
},
});
expect(usageEvent).toMatchObject({
type: "model.usage",
usage: {
@@ -376,72 +349,21 @@ describe("runReplyAgent auto-compaction token update", () => {
});
it("falls back to last-call prompt usage for live diagnostic context", async () => {
const tmp = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-usage-diagnostic-last-"));
const storePath = path.join(tmp, "sessions.json");
const sessionKey = "main";
const sessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
totalTokens: 50_000,
};
await seedSessionStore({ storePath, sessionKey, entry: sessionEntry });
runEmbeddedPiAgentMock.mockResolvedValue({
payloads: [{ text: "ok" }],
meta: {
agentMeta: {
usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 },
lastCallUsage: {
input: 55_000,
output: 2_000,
cacheRead: 25_000,
cacheWrite: 1_000,
total: 83_000,
},
const { usageEvent } = await runBaseReplyWithAgentMeta({
tmpPrefix: "openclaw-usage-diagnostic-last-",
collectDiagnostics: true,
agentMeta: {
usage: { input: 75_000, output: 5_000, cacheRead: 25_000, total: 105_000 },
lastCallUsage: {
input: 55_000,
output: 2_000,
cacheRead: 25_000,
cacheWrite: 1_000,
total: 83_000,
},
},
});
const diagnostics: DiagnosticEventPayload[] = [];
const unsubscribe = onInternalDiagnosticEvent((event) => {
diagnostics.push(event);
});
const { typing, sessionCtx, resolvedQueue, followupRun } = createBaseRun({
storePath,
sessionEntry,
});
try {
await runReplyAgent({
commandBody: "hello",
followupRun,
queueKey: "main",
resolvedQueue,
shouldSteer: false,
shouldFollowup: false,
isActive: false,
isStreaming: false,
typing,
sessionCtx,
sessionEntry,
sessionStore: { [sessionKey]: sessionEntry },
sessionKey,
storePath,
defaultModel: "anthropic/claude-opus-4-6",
agentCfgContextTokens: 200_000,
resolvedVerboseLevel: "off",
isNewSession: false,
blockStreamingEnabled: false,
resolvedBlockStreamingBreak: "message_end",
shouldInjectGroupIntro: false,
typingMode: "instant",
});
} finally {
unsubscribe();
}
const usageEvent = diagnostics.find((event) => event.type === "model.usage");
expect(usageEvent).toMatchObject({
type: "model.usage",
usage: {

View File

@@ -5,7 +5,7 @@ import { DEFAULT_CONTEXT_TOKENS } from "../../agents/defaults.js";
import { resolveModelAuthMode } from "../../agents/model-auth.js";
import { isCliProvider } from "../../agents/model-selection.js";
import { queueEmbeddedPiMessage } from "../../agents/pi-embedded-runner/runs.js";
import { hasNonzeroUsage, normalizeUsage } from "../../agents/usage.js";
import { deriveContextPromptTokens, hasNonzeroUsage, normalizeUsage } from "../../agents/usage.js";
import {
loadSessionStore,
resolveSessionPluginStatusLines,
@@ -568,53 +568,6 @@ async function accumulateSessionUsageFromTranscript(params: {
}
}
/**
 * Resolve the prompt/context token count to report for a request trace.
 *
 * Resolution order:
 *  1. an explicit `promptTokens` snapshot, when it is a finite positive number;
 *  2. the last call's prompt-side usage (input + cacheRead + cacheWrite);
 *  3. the accumulated prompt-side usage, same formula.
 *
 * Output tokens are deliberately excluded: this value models context/prompt
 * occupancy, not total spend. Returns `undefined` when no source is positive.
 */
function resolveRequestPromptTokens(params: {
  lastCallUsage?: {
    input?: number;
    output?: number;
    cacheRead?: number;
    cacheWrite?: number;
    total?: number;
  };
  promptTokens?: number;
  usage?: {
    input?: number;
    output?: number;
    cacheRead?: number;
    cacheWrite?: number;
    total?: number;
  };
}): number | undefined {
  const { promptTokens, lastCallUsage, usage } = params;
  if (typeof promptTokens === "number" && Number.isFinite(promptTokens) && promptTokens > 0) {
    return promptTokens;
  }
  // Prefer the last call's snapshot (closest to the live context window),
  // falling back to accumulated totals.
  return sumPromptSideTokens(lastCallUsage) ?? sumPromptSideTokens(usage);
}

/**
 * Sum the prompt-side components of a usage record
 * (input + cacheRead + cacheWrite); `undefined` when the sum is not positive.
 */
function sumPromptSideTokens(usage?: {
  input?: number;
  cacheRead?: number;
  cacheWrite?: number;
}): number | undefined {
  if (!usage) {
    return undefined;
  }
  const sum = (usage.input ?? 0) + (usage.cacheRead ?? 0) + (usage.cacheWrite ?? 0);
  return sum > 0 ? sum : undefined;
}
function formatRequestContextTraceBlock(params: {
provider?: string;
model?: string;
@@ -785,7 +738,7 @@ function buildInlineRawTracePayload(params: {
if (params.entry?.traceLevel !== "raw") {
return undefined;
}
const resolvedPromptTokens = resolveRequestPromptTokens({
const resolvedPromptTokens = deriveContextPromptTokens({
lastCallUsage: params.lastCallUsage,
promptTokens: params.promptTokens,
usage: params.usage,
@@ -1430,7 +1383,7 @@ export async function runReplyAgent(params: {
const cacheWrite = usage.cacheWrite ?? 0;
const usagePromptTokens = input + cacheRead + cacheWrite;
const totalTokens = usage.total ?? usagePromptTokens + output;
const contextUsedTokens = resolveRequestPromptTokens({
const contextUsedTokens = deriveContextPromptTokens({
lastCallUsage: runResult.meta?.agentMeta?.lastCallUsage,
promptTokens,
usage,