diff --git a/src/agents/agent-command.live-model-switch.test.ts b/src/agents/agent-command.live-model-switch.test.ts index 45b8b20fcbc..ecaed234557 100644 --- a/src/agents/agent-command.live-model-switch.test.ts +++ b/src/agents/agent-command.live-model-switch.test.ts @@ -290,6 +290,8 @@ type FallbackRunnerParams = { run: (provider: string, model: string) => Promise; }; +type ModelSwitchOptions = ConstructorParameters[0]; + function makeSuccessResult(provider: string, model: string) { return { payloads: [{ text: "ok" }], @@ -302,6 +304,42 @@ function makeSuccessResult(provider: string, model: string) { }; } +function setupModelSwitchRetry(switchOptions: ModelSwitchOptions) { + let invocation = 0; + state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { + invocation += 1; + if (invocation === 1) { + throw new LiveSessionModelSwitchError(switchOptions); + } + const result = await params.run(params.provider, params.model); + return { + result, + provider: params.provider, + model: params.model, + attempts: [], + }; + }); +} + +async function runBasicAgentCommand() { + const agentCommand = await getAgentCommand(); + await agentCommand({ + message: "hello", + to: "+1234567890", + senderIsOwner: true, + }); +} + +function expectFallbackOverrideCalls(first: boolean, second: boolean) { + expect(state.resolveEffectiveModelFallbacksMock).toHaveBeenCalledTimes(2); + expect(state.resolveEffectiveModelFallbacksMock.mock.calls[0][0]).toMatchObject({ + hasSessionModelOverride: first, + }); + expect(state.resolveEffectiveModelFallbacksMock.mock.calls[1][0]).toMatchObject({ + hasSessionModelOverride: second, + }); +} + describe("agentCommand – LiveSessionModelSwitchError retry", () => { beforeEach(() => { vi.clearAllMocks(); @@ -314,32 +352,14 @@ describe("agentCommand – LiveSessionModelSwitchError retry", () => { }); it("retries with the switched provider/model when LiveSessionModelSwitchError is thrown", async () => { - let invocation = 0; - state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { - invocation += 1; - if (invocation === 1) { - throw new LiveSessionModelSwitchError({ - provider: "openai", - model: "gpt-5.4", - }); - } - const result = await params.run(params.provider, params.model); - return { - result, - provider: params.provider, - model: params.model, - attempts: [], - }; + setupModelSwitchRetry({ + provider: "openai", + model: "gpt-5.4", }); state.runAgentAttemptMock.mockResolvedValue(makeSuccessResult("openai", "gpt-5.4")); - const agentCommand = await getAgentCommand(); - await agentCommand({ - message: "hello", - to: "+1234567890", - senderIsOwner: true, - }); + await runBasicAgentCommand(); expect(state.runWithModelFallbackMock).toHaveBeenCalledTimes(2); @@ -385,32 +405,14 @@ describe("agentCommand – LiveSessionModelSwitchError retry", () => { }); it("resets lifecycleEnded flag between retry iterations", async () => { - let invocation = 0; - state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { - invocation += 1; - if (invocation === 1) { - throw new LiveSessionModelSwitchError({ - provider: "openai", - model: "gpt-5.4", - }); - } - const result = await params.run(params.provider, params.model); - return { - result, - provider: params.provider, - model: params.model, - attempts: [], - }; + setupModelSwitchRetry({ + provider: "openai", + model: "gpt-5.4", }); state.runAgentAttemptMock.mockResolvedValue(makeSuccessResult("openai", "gpt-5.4")); - const agentCommand = await getAgentCommand(); - await agentCommand({ - message: "hello", - to: "+1234567890", - senderIsOwner: true, - }); + await runBasicAgentCommand(); const lifecycleEndCalls = state.emitAgentEventMock.mock.calls.filter((call: unknown[]) => { const arg = call[0] as { stream?: string; data?: { phase?: string } }; @@ -420,25 +422,12 @@ describe("agentCommand – LiveSessionModelSwitchError retry", () => { }); it("propagates authProfileId from the switch error to the retried session entry", async () => { - let invocation = 0; let capturedAuthProfileProvider: string | undefined; - state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { - invocation += 1; - if (invocation === 1) { - throw new LiveSessionModelSwitchError({ - provider: "openai", - model: "gpt-5.4", - authProfileId: "profile-openai-prod", - authProfileIdSource: "user", - }); - } - const result = await params.run(params.provider, params.model); - return { - result, - provider: params.provider, - model: params.model, - attempts: [], - }; + setupModelSwitchRetry({ + provider: "openai", + model: "gpt-5.4", + authProfileId: "profile-openai-prod", + authProfileIdSource: "user", }); state.runAgentAttemptMock.mockImplementation(async (...args: unknown[]) => { @@ -447,130 +436,53 @@ describe("agentCommand – LiveSessionModelSwitchError retry", () => { return makeSuccessResult("openai", "gpt-5.4"); }); - const agentCommand = await getAgentCommand(); - await agentCommand({ - message: "hello", - to: "+1234567890", - senderIsOwner: true, - }); + await runBasicAgentCommand(); expect(capturedAuthProfileProvider).toBe("openai"); expect(state.runWithModelFallbackMock).toHaveBeenCalledTimes(2); }); it("updates hasSessionModelOverride for fallback resolution after switch", async () => { - let invocation = 0; - state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { - invocation += 1; - if (invocation === 1) { - throw new LiveSessionModelSwitchError({ - provider: "openai", - model: "gpt-5.4", - }); - } - const result = await params.run(params.provider, params.model); - return { - result, - provider: params.provider, - model: params.model, - attempts: [], - }; + setupModelSwitchRetry({ + provider: "openai", + model: "gpt-5.4", }); state.runAgentAttemptMock.mockResolvedValue(makeSuccessResult("openai", "gpt-5.4")); state.resolveEffectiveModelFallbacksMock.mockClear(); - const agentCommand = await getAgentCommand(); - await agentCommand({ - message: "hello", - to: "+1234567890", - senderIsOwner: true, - }); + await runBasicAgentCommand(); - expect(state.resolveEffectiveModelFallbacksMock).toHaveBeenCalledTimes(2); - expect(state.resolveEffectiveModelFallbacksMock.mock.calls[0][0]).toMatchObject({ - hasSessionModelOverride: false, - }); - expect(state.resolveEffectiveModelFallbacksMock.mock.calls[1][0]).toMatchObject({ - hasSessionModelOverride: true, - }); + expectFallbackOverrideCalls(false, true); }); it("does not flip hasSessionModelOverride on auth-only switch with same model", async () => { - let invocation = 0; - state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { - invocation += 1; - if (invocation === 1) { - throw new LiveSessionModelSwitchError({ - provider: "anthropic", - model: "claude", - authProfileId: "profile-99", - authProfileIdSource: "user", - }); - } - const result = await params.run(params.provider, params.model); - return { - result, - provider: params.provider, - model: params.model, - attempts: [], - }; + setupModelSwitchRetry({ + provider: "anthropic", + model: "claude", + authProfileId: "profile-99", + authProfileIdSource: "user", }); state.runAgentAttemptMock.mockResolvedValue(makeSuccessResult("anthropic", "claude")); state.resolveEffectiveModelFallbacksMock.mockClear(); - const agentCommand = await getAgentCommand(); - await agentCommand({ - message: "hello", - to: "+1234567890", - senderIsOwner: true, - }); + await runBasicAgentCommand(); - expect(state.resolveEffectiveModelFallbacksMock).toHaveBeenCalledTimes(2); - expect(state.resolveEffectiveModelFallbacksMock.mock.calls[0][0]).toMatchObject({ - hasSessionModelOverride: false, - }); - expect(state.resolveEffectiveModelFallbacksMock.mock.calls[1][0]).toMatchObject({ - hasSessionModelOverride: false, - }); + expectFallbackOverrideCalls(false, false); }); it("flips hasSessionModelOverride on provider-only switch with same model", async () => { - let invocation = 0; - state.runWithModelFallbackMock.mockImplementation(async (params: FallbackRunnerParams) => { - invocation += 1; - if (invocation === 1) { - throw new LiveSessionModelSwitchError({ - provider: "openai", - model: "claude", - }); - } - const result = await params.run(params.provider, params.model); - return { - result, - provider: params.provider, - model: params.model, - attempts: [], - }; + setupModelSwitchRetry({ + provider: "openai", + model: "claude", }); state.runAgentAttemptMock.mockResolvedValue(makeSuccessResult("openai", "claude")); state.resolveEffectiveModelFallbacksMock.mockClear(); - const agentCommand = await getAgentCommand(); - await agentCommand({ - message: "hello", - to: "+1234567890", - senderIsOwner: true, - }); + await runBasicAgentCommand(); - expect(state.resolveEffectiveModelFallbacksMock).toHaveBeenCalledTimes(2); - expect(state.resolveEffectiveModelFallbacksMock.mock.calls[0][0]).toMatchObject({ - hasSessionModelOverride: false, - }); - expect(state.resolveEffectiveModelFallbacksMock.mock.calls[1][0]).toMatchObject({ - hasSessionModelOverride: true, - }); + expectFallbackOverrideCalls(false, true); }); });