fix: persist fallback overrides safely

This commit is contained in:
Peter Steinberger
2026-04-04 08:58:39 +09:00
parent 411282c36d
commit fe72474153
6 changed files with 417 additions and 7 deletions

View File

@@ -130,7 +130,7 @@ describe("live model switch", () => {
});
});
it("prefers persisted runtime model fields ahead of session overrides", async () => {
it("prefers persisted session overrides ahead of stale runtime model fields", async () => {
state.loadSessionStoreMock.mockReturnValue({
main: {
providerOverride: "anthropic",
@@ -152,7 +152,7 @@ describe("live model switch", () => {
}),
).toEqual({
provider: "anthropic",
model: "claude-sonnet-4-6",
model: "claude-opus-4-6",
authProfileId: undefined,
authProfileIdSource: undefined,
});

View File

@@ -35,13 +35,17 @@ export function resolveLiveSessionModelSelection(params: {
agentId,
});
const entry = loadSessionStore(storePath, { skipCache: true })[sessionKey];
const persisted = resolvePersistedModelRef({
const overrideSelection = resolvePersistedModelRef({
defaultProvider: defaultModelRef.provider,
runtimeProvider: entry?.modelProvider,
runtimeModel: entry?.model,
overrideProvider: entry?.providerOverride,
overrideModel: entry?.modelOverride,
});
const runtimeSelection = resolvePersistedModelRef({
defaultProvider: defaultModelRef.provider,
runtimeProvider: entry?.modelProvider,
runtimeModel: entry?.model,
});
const persisted = overrideSelection ?? runtimeSelection;
const provider =
persisted?.provider ?? entry?.providerOverride?.trim() ?? defaultModelRef.provider;
const model = persisted?.model ?? defaultModelRef.model;

View File

@@ -1,5 +1,6 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { LiveSessionModelSwitchError } from "../../agents/live-model-switch-error.js";
import type { SessionEntry } from "../../config/sessions.js";
import type { TemplateContext } from "../templating.js";
import type { GetReplyOptions } from "../types.js";
import { MAX_LIVE_SWITCH_RETRIES } from "./agent-runner-execution.js";
@@ -746,4 +747,63 @@ describe("runAgentTurnWithFallback", () => {
expect(followupRun.run.authProfileId).toBe("profile-c");
expect(followupRun.run.authProfileIdSource).toBe("auto");
});
it("does not roll back newer override changes after a failed fallback candidate", async () => {
state.runWithModelFallbackMock.mockImplementation(
async (params: { run: (provider: string, model: string) => Promise<unknown> }) => {
await expect(params.run("openai", "gpt-5.4")).rejects.toThrow("fallback failed");
throw new Error("fallback failed");
},
);
const sessionEntry: SessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
providerOverride: "anthropic",
modelOverride: "claude",
authProfileOverride: "anthropic:default",
authProfileOverrideSource: "user",
};
const sessionStore = { main: sessionEntry };
state.runEmbeddedPiAgentMock.mockImplementationOnce(async () => {
sessionEntry.providerOverride = "zai";
sessionEntry.modelOverride = "glm-5";
sessionEntry.authProfileOverride = "zai:work";
sessionEntry.authProfileOverrideSource = "user";
throw new Error("fallback failed");
});
const runAgentTurnWithFallback = await getRunAgentTurnWithFallback();
const result = await runAgentTurnWithFallback({
commandBody: "hello",
followupRun: createFollowupRun(),
sessionCtx: {
Provider: "whatsapp",
MessageSid: "msg",
} as unknown as TemplateContext,
opts: {},
typingSignals: createMockTypingSignaler(),
blockReplyPipeline: null,
blockStreamingEnabled: false,
resolvedBlockStreamingBreak: "message_end",
applyReplyToMode: (payload) => payload,
shouldEmitToolResult: () => true,
shouldEmitToolOutput: () => false,
pendingToolTasks: new Set(),
resetSessionAfterCompactionFailure: async () => false,
resetSessionAfterRoleOrderingConflict: async () => false,
isHeartbeat: false,
sessionKey: "main",
getActiveSessionEntry: () => sessionEntry,
activeSessionStore: sessionStore,
resolvedVerboseLevel: "off",
});
expect(result.kind).toBe("final");
expect(sessionEntry.providerOverride).toBe("zai");
expect(sessionEntry.modelOverride).toBe("glm-5");
expect(sessionEntry.authProfileOverride).toBe("zai:work");
expect(sessionEntry.authProfileOverrideSource).toBe("user");
expect(sessionStore.main.providerOverride).toBe("zai");
expect(sessionStore.main.modelOverride).toBe("glm-5");
});
});

View File

@@ -47,6 +47,7 @@ import {
SILENT_REPLY_TOKEN,
} from "../tokens.js";
import type { GetReplyOptions, ReplyPayload } from "../types.js";
import { resolveRunAuthProfile } from "./agent-runner-auth-profile.js";
import {
buildEmbeddedRunExecutionParams,
resolveModelFallbackOptions,
@@ -87,6 +88,142 @@ export type AgentRunLoopResult =
}
| { kind: "final"; payload: ReplyPayload };
type FallbackSelectionState = Pick<
SessionEntry,
| "providerOverride"
| "modelOverride"
| "authProfileOverride"
| "authProfileOverrideSource"
| "authProfileOverrideCompactionCount"
>;
const FALLBACK_SELECTION_STATE_KEYS = [
"providerOverride",
"modelOverride",
"authProfileOverride",
"authProfileOverrideSource",
"authProfileOverrideCompactionCount",
] as const satisfies ReadonlyArray<keyof FallbackSelectionState>;
function setFallbackSelectionStateField(
entry: SessionEntry,
key: keyof FallbackSelectionState,
value: FallbackSelectionState[keyof FallbackSelectionState],
): boolean {
switch (key) {
case "providerOverride":
if (entry.providerOverride !== value) {
entry.providerOverride = value as SessionEntry["providerOverride"];
return true;
}
return false;
case "modelOverride":
if (entry.modelOverride !== value) {
entry.modelOverride = value as SessionEntry["modelOverride"];
return true;
}
return false;
case "authProfileOverride":
if (entry.authProfileOverride !== value) {
entry.authProfileOverride = value as SessionEntry["authProfileOverride"];
return true;
}
return false;
case "authProfileOverrideSource":
if (entry.authProfileOverrideSource !== value) {
entry.authProfileOverrideSource = value as SessionEntry["authProfileOverrideSource"];
return true;
}
return false;
case "authProfileOverrideCompactionCount":
if (entry.authProfileOverrideCompactionCount !== value) {
entry.authProfileOverrideCompactionCount =
value as SessionEntry["authProfileOverrideCompactionCount"];
return true;
}
return false;
}
}
function snapshotFallbackSelectionState(entry: SessionEntry): FallbackSelectionState {
return {
providerOverride: entry.providerOverride,
modelOverride: entry.modelOverride,
authProfileOverride: entry.authProfileOverride,
authProfileOverrideSource: entry.authProfileOverrideSource,
authProfileOverrideCompactionCount: entry.authProfileOverrideCompactionCount,
};
}
function buildFallbackSelectionState(params: {
provider: string;
model: string;
authProfileId?: string;
authProfileIdSource?: "auto" | "user";
}): FallbackSelectionState {
return {
providerOverride: params.provider,
modelOverride: params.model,
authProfileOverride: params.authProfileId,
authProfileOverrideSource: params.authProfileId ? params.authProfileIdSource : undefined,
authProfileOverrideCompactionCount: undefined,
};
}
function applyFallbackSelectionState(
entry: SessionEntry,
nextState: FallbackSelectionState,
now = Date.now(),
): boolean {
let updated = false;
for (const key of FALLBACK_SELECTION_STATE_KEYS) {
const nextValue = nextState[key];
if (nextValue === undefined) {
if (Object.hasOwn(entry, key)) {
delete entry[key];
updated = true;
}
continue;
}
if (entry[key] !== nextValue) {
updated = setFallbackSelectionStateField(entry, key, nextValue) || updated;
}
}
if (updated) {
entry.updatedAt = now;
}
return updated;
}
function rollbackFallbackSelectionStateIfUnchanged(
entry: SessionEntry,
expectedState: FallbackSelectionState,
previousState: FallbackSelectionState,
now = Date.now(),
): boolean {
let updated = false;
for (const key of FALLBACK_SELECTION_STATE_KEYS) {
if (entry[key] !== expectedState[key]) {
continue;
}
const previousValue = previousState[key];
if (previousValue === undefined) {
if (Object.hasOwn(entry, key)) {
delete entry[key];
updated = true;
}
continue;
}
if (entry[key] !== previousValue) {
updated = setFallbackSelectionStateField(entry, key, previousValue) || updated;
}
}
if (updated) {
entry.updatedAt = now;
}
return updated;
}
/**
* Build a human-friendly rate-limit message from a FallbackSummaryError.
* Includes a countdown when the soonest cooldown expiry is known.
@@ -207,6 +344,74 @@ export async function runAgentTurnWithFallback(params: {
let bootstrapPromptWarningSignaturesSeen = resolveBootstrapWarningSignaturesSeen(
params.getActiveSessionEntry()?.systemPromptReport,
);
const persistFallbackCandidateSelection = async (provider: string, model: string) => {
if (
!params.sessionKey ||
!params.activeSessionStore ||
(provider === params.followupRun.run.provider && model === params.followupRun.run.model)
) {
return;
}
const activeSessionEntry =
params.getActiveSessionEntry() ?? params.activeSessionStore[params.sessionKey];
if (!activeSessionEntry) {
return;
}
const previousState = snapshotFallbackSelectionState(activeSessionEntry);
const scopedAuthProfile = resolveRunAuthProfile(params.followupRun.run, provider);
const nextState = buildFallbackSelectionState({
provider,
model,
authProfileId: scopedAuthProfile.authProfileId,
authProfileIdSource: scopedAuthProfile.authProfileIdSource,
});
if (!applyFallbackSelectionState(activeSessionEntry, nextState)) {
return;
}
params.activeSessionStore[params.sessionKey] = activeSessionEntry;
try {
if (params.storePath) {
await updateSessionStore(params.storePath, (store) => {
const persistedEntry = store[params.sessionKey!];
if (!persistedEntry) {
return;
}
applyFallbackSelectionState(persistedEntry, nextState);
store[params.sessionKey!] = persistedEntry;
});
}
} catch (error) {
rollbackFallbackSelectionStateIfUnchanged(activeSessionEntry, nextState, previousState);
params.activeSessionStore[params.sessionKey] = activeSessionEntry;
throw error;
}
return async () => {
const rolledBackInMemory = rollbackFallbackSelectionStateIfUnchanged(
activeSessionEntry,
nextState,
previousState,
);
if (rolledBackInMemory) {
params.activeSessionStore![params.sessionKey!] = activeSessionEntry;
}
if (!params.storePath) {
return;
}
await updateSessionStore(params.storePath, (store) => {
const persistedEntry = store[params.sessionKey!];
if (!persistedEntry) {
return;
}
if (rollbackFallbackSelectionStateIfUnchanged(persistedEntry, nextState, previousState)) {
store[params.sessionKey!] = persistedEntry;
}
});
};
};
while (true) {
try {
@@ -286,7 +491,7 @@ export async function runAgentTurnWithFallback(params: {
const fallbackResult = await runWithModelFallback({
...resolveModelFallbackOptions(params.followupRun.run),
runId,
run: (provider, model, runOptions) => {
run: async (provider, model, runOptions) => {
// Notify that model selection is complete (including after fallback).
// This allows responsePrefix template interpolation with the actual model.
params.opts?.onModelSelected?.({
@@ -294,6 +499,17 @@ export async function runAgentTurnWithFallback(params: {
model,
thinkLevel: params.followupRun.run.thinkLevel,
});
let rollbackFallbackCandidateSelection: (() => Promise<void>) | undefined;
try {
rollbackFallbackCandidateSelection = await persistFallbackCandidateSelection(
provider,
model,
);
} catch (error) {
logVerbose(
`failed to persist fallback candidate selection (non-fatal): ${String(error)}`,
);
}
if (isCliProvider(provider, params.followupRun.run.config)) {
const startedAt = Date.now();
@@ -372,6 +588,15 @@ export async function runAgentTurnWithFallback(params: {
return result;
} catch (err) {
if (rollbackFallbackCandidateSelection) {
try {
await rollbackFallbackCandidateSelection();
} catch (rollbackError) {
logVerbose(
`failed to roll back fallback candidate selection (non-fatal): ${String(rollbackError)}`,
);
}
}
emitAgentEvent({
runId,
stream: "lifecycle",
@@ -592,6 +817,17 @@ export async function runAgentTurnWithFallback(params: {
);
attemptCompactionCount = Math.max(attemptCompactionCount, resultCompactionCount);
return result;
} catch (err) {
if (rollbackFallbackCandidateSelection) {
try {
await rollbackFallbackCandidateSelection();
} catch (rollbackError) {
logVerbose(
`failed to roll back fallback candidate selection (non-fatal): ${String(rollbackError)}`,
);
}
}
throw err;
} finally {
autoCompactionCount += attemptCompactionCount;
}

View File

@@ -296,7 +296,7 @@ describe("runReplyAgent authProfileId fallback scoping", () => {
} as unknown as FollowupRun;
const sessionKey = "main";
const sessionEntry = {
const sessionEntry: SessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
totalTokens: 1,
@@ -338,6 +338,115 @@ describe("runReplyAgent authProfileId fallback scoping", () => {
expect(call.provider).toBe("openai-codex");
expect(call.authProfileId).toBeUndefined();
expect(call.authProfileIdSource).toBeUndefined();
expect(sessionEntry.providerOverride).toBe("openai-codex");
expect(sessionEntry.modelOverride).toBe("gpt-5.2");
expect(sessionEntry.authProfileOverride).toBeUndefined();
expect(sessionEntry.authProfileOverrideSource).toBeUndefined();
});
it("persists same-provider fallback model while keeping the scoped auth profile", async () => {
runWithModelFallbackMock.mockImplementationOnce(
async ({ run }: RunWithModelFallbackParams) => ({
result: await run("anthropic", "claude-sonnet"),
provider: "anthropic",
model: "claude-sonnet",
}),
);
runEmbeddedPiAgentMock.mockResolvedValue({ payloads: [{ text: "ok" }], meta: {} });
const typing = createMockTypingController();
const sessionCtx = {
Provider: "telegram",
OriginatingTo: "chat",
AccountId: "primary",
MessageSid: "msg",
Surface: "telegram",
} as unknown as TemplateContext;
const resolvedQueue = { mode: "interrupt" } as unknown as QueueSettings;
const followupRun = {
prompt: "hello",
summaryLine: "hello",
enqueuedAt: Date.now(),
run: {
agentId: "main",
agentDir: "/tmp/agent",
sessionId: "session",
sessionKey: "main",
messageProvider: "telegram",
sessionFile: "/tmp/session.jsonl",
workspaceDir: "/tmp",
config: createCliBackendTestConfig(),
skillsSnapshot: {},
provider: "anthropic",
model: "claude-opus",
authProfileId: "anthropic:openclaw",
authProfileIdSource: "user",
thinkLevel: "low",
verboseLevel: "off",
elevatedLevel: "off",
bashElevated: {
enabled: false,
allowed: false,
defaultLevel: "off",
},
timeoutMs: 5_000,
blockReplyBreak: "message_end",
},
} as unknown as FollowupRun;
const sessionKey = "main";
const sessionEntry: SessionEntry = {
sessionId: "session",
updatedAt: Date.now(),
totalTokens: 1,
compactionCount: 0,
authProfileOverride: "anthropic:openclaw",
authProfileOverrideSource: "user" as const,
};
await runReplyAgent({
commandBody: "hello",
followupRun,
queueKey: sessionKey,
resolvedQueue,
shouldSteer: false,
shouldFollowup: false,
isActive: false,
isStreaming: false,
typing,
sessionCtx,
sessionEntry,
sessionStore: { [sessionKey]: sessionEntry },
sessionKey,
storePath: undefined,
defaultModel: "anthropic/claude-opus-4-5",
agentCfgContextTokens: 100_000,
resolvedVerboseLevel: "off",
isNewSession: false,
blockStreamingEnabled: false,
resolvedBlockStreamingBreak: "message_end",
shouldInjectGroupIntro: false,
typingMode: "instant",
});
expect(runEmbeddedPiAgentMock).toHaveBeenCalledTimes(1);
const call = runEmbeddedPiAgentMock.mock.calls[0]?.[0] as {
authProfileId?: unknown;
authProfileIdSource?: unknown;
provider?: unknown;
model?: unknown;
};
expect(call.provider).toBe("anthropic");
expect(call.model).toBe("claude-sonnet");
expect(call.authProfileId).toBe("anthropic:openclaw");
expect(call.authProfileIdSource).toBe("user");
expect(sessionEntry.providerOverride).toBe("anthropic");
expect(sessionEntry.modelOverride).toBe("claude-sonnet");
expect(sessionEntry.authProfileOverride).toBe("anthropic:openclaw");
expect(sessionEntry.authProfileOverrideSource).toBe("user");
});
});