diff --git a/src/agents/model-fallback.run-embedded.e2e.test.ts b/src/agents/model-fallback.run-embedded.e2e.test.ts index 39d766889e1..2035d7caa22 100644 --- a/src/agents/model-fallback.run-embedded.e2e.test.ts +++ b/src/agents/model-fallback.run-embedded.e2e.test.ts @@ -103,6 +103,12 @@ const OVERLOADED_ERROR_PAYLOAD = const RATE_LIMIT_ERROR_MESSAGE = "rate limit exceeded"; const NO_ENDPOINTS_FOUND_ERROR_MESSAGE = "404 No endpoints found for deepseek/deepseek-r1:free."; +type EmbeddedAttemptParams = { + provider: string; + modelId?: string; + authProfileId?: string; +}; + function makeConfig(): OpenClawConfig { const apiKeyField = ["api", "Key"].join(""); return { @@ -277,85 +283,69 @@ function mockPrimaryOverloadedThenFallbackSuccess() { mockPrimaryErrorThenFallbackSuccess(OVERLOADED_ERROR_PAYLOAD); } -function mockPrimaryPromptErrorThenFallbackSuccess(errorMessage: string) { +function makeFallbackSuccessAttempt(): EmbeddedRunAttemptResult { + return makeEmbeddedRunnerAttempt({ + assistantTexts: ["fallback ok"], + lastAssistant: buildEmbeddedRunnerAssistant({ + provider: "groq", + model: "mock-2", + stopReason: "stop", + content: [{ type: "text", text: "fallback ok" }], + }), + }); +} + +function mockPrimaryFailureThenFallbackSuccess( + makePrimaryAttempt: ( + attemptParams: EmbeddedAttemptParams, + ) => EmbeddedRunAttemptResult | Promise, +) { runEmbeddedAttemptMock.mockImplementation(async (params: unknown) => { - const attemptParams = params as { provider: string }; + const attemptParams = params as EmbeddedAttemptParams; if (attemptParams.provider === "openai") { - return makeEmbeddedRunnerAttempt({ - promptError: new Error(errorMessage), - }); + return await makePrimaryAttempt(attemptParams); } if (attemptParams.provider === "groq") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: ["fallback ok"], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "groq", - model: "mock-2", - stopReason: "stop", - content: [{ type: "text", text: "fallback ok" }], - }), - }); + return makeFallbackSuccessAttempt(); } throw new Error(`Unexpected provider ${attemptParams.provider}`); }); } +function mockPrimaryPromptErrorThenFallbackSuccess(errorMessage: string) { + mockPrimaryFailureThenFallbackSuccess(() => + makeEmbeddedRunnerAttempt({ + promptError: new Error(errorMessage), + }), + ); +} + function mockPrimaryErrorThenFallbackSuccess(errorMessage: string) { - runEmbeddedAttemptMock.mockImplementation(async (params: unknown) => { - const attemptParams = params as { provider: string; modelId: string; authProfileId?: string }; - if (attemptParams.provider === "openai") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: [], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "openai", - model: "mock-1", - stopReason: "error", - errorMessage, - }), - }); - } - if (attemptParams.provider === "groq") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: ["fallback ok"], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "groq", - model: "mock-2", - stopReason: "stop", - content: [{ type: "text", text: "fallback ok" }], - }), - }); - } - throw new Error(`Unexpected provider ${attemptParams.provider}`); - }); + mockPrimaryFailureThenFallbackSuccess(() => + makeEmbeddedRunnerAttempt({ + assistantTexts: [], + lastAssistant: buildEmbeddedRunnerAssistant({ + provider: "openai", + model: "mock-1", + stopReason: "error", + errorMessage, + }), + }), + ); } function mockPrimaryRunLoopRateLimitThenFallbackSuccess(errorMessage: string) { - runEmbeddedAttemptMock.mockImplementation(async (params: unknown) => { - const attemptParams = params as { provider: string }; - if (attemptParams.provider === "openai") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: [], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "openai", - model: "mock-1", - stopReason: "length", - errorMessage, - }), - }); - } - if (attemptParams.provider === "groq") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: ["fallback ok"], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "groq", - model: "mock-2", - stopReason: "stop", - content: [{ type: "text", text: "fallback ok" }], - }), - }); - } - throw new Error(`Unexpected provider ${attemptParams.provider}`); - }); + mockPrimaryFailureThenFallbackSuccess(() => + makeEmbeddedRunnerAttempt({ + assistantTexts: [], + lastAssistant: buildEmbeddedRunnerAssistant({ + provider: "openai", + model: "mock-1", + stopReason: "length", + errorMessage, + }), + }), + ); } function expectOpenAiThenGroqAttemptOrder(params?: { expectOpenAiAuthProfileId?: string }) { @@ -391,6 +381,17 @@ function mockAllProvidersOverloaded() { }); } +function countProviderAttempts(provider: string) { + return runEmbeddedAttemptMock.mock.calls.filter( + (call) => (call[0] as { provider?: string })?.provider === provider, + ).length; +} + +function expectProviderAttemptCounts(expected: { openai: number; groq: number }) { + expect(countProviderAttempts("openai")).toBe(expected.openai); + expect(countProviderAttempts("groq")).toBe(expected.groq); +} + describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { it("falls back on OpenRouter-style no-endpoints assistant errors", async () => { await withAgentWorkspace(async ({ agentDir, workspaceDir }) => { @@ -627,62 +628,8 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { // cap profile rotations at overloadedProfileRotations=1 and escalate // to cross-provider fallback immediately. await withAgentWorkspace(async ({ agentDir, workspaceDir }) => { - // Write auth store with multiple profiles for openai - await fs.writeFile( - path.join(agentDir, "auth-profiles.json"), - JSON.stringify({ - version: 1, - profiles: { - "openai:p1": { type: "api_key", provider: "openai", key: "sk-openai-1" }, - "openai:p2": { type: "api_key", provider: "openai", key: "sk-openai-2" }, - "openai:p3": { type: "api_key", provider: "openai", key: "sk-openai-3" }, - "groq:p1": { type: "api_key", provider: "groq", key: "sk-groq" }, - }, - }), - ); - await fs.writeFile( - path.join(agentDir, "auth-state.json"), - JSON.stringify({ - version: 1, - usageStats: { - "openai:p1": { lastUsed: 1 }, - "openai:p2": { lastUsed: 2 }, - "openai:p3": { lastUsed: 3 }, - "groq:p1": { lastUsed: 4 }, - }, - }), - ); - - runEmbeddedAttemptMock.mockImplementation(async (params: unknown) => { - const attemptParams = params as { - provider: string; - modelId: string; - authProfileId?: string; - }; - if (attemptParams.provider === "openai") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: [], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "openai", - model: "mock-1", - stopReason: "error", - errorMessage: OVERLOADED_ERROR_PAYLOAD, - }), - }); - } - if (attemptParams.provider === "groq") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: ["fallback ok"], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "groq", - model: "mock-2", - stopReason: "stop", - content: [{ type: "text", text: "fallback ok" }], - }), - }); - } - throw new Error(`Unexpected provider ${attemptParams.provider}`); - }); + await writeMultiProfileAuthStore(agentDir); + mockPrimaryOverloadedThenFallbackSuccess(); const result = await runEmbeddedFallback({ agentDir, @@ -701,47 +648,14 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { // - 1 rotation to p2 (capped) // - escalation to groq (1 attempt) // Total: 3 attempts, NOT 4 (which would mean all 3 openai profiles tried) - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(2); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 2, groq: 1 }); }); }); it("respects overloadedProfileRotations=0 and falls back immediately", async () => { await withAgentWorkspace(async ({ agentDir, workspaceDir }) => { await writeMultiProfileAuthStore(agentDir); - - runEmbeddedAttemptMock.mockImplementation(async (params: unknown) => { - const attemptParams = params as { provider: string }; - if (attemptParams.provider === "openai") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: [], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "openai", - model: "mock-1", - stopReason: "error", - errorMessage: OVERLOADED_ERROR_PAYLOAD, - }), - }); - } - if (attemptParams.provider === "groq") { - return makeEmbeddedRunnerAttempt({ - assistantTexts: ["fallback ok"], - lastAssistant: buildEmbeddedRunnerAssistant({ - provider: "groq", - model: "mock-2", - stopReason: "stop", - content: [{ type: "text", text: "fallback ok" }], - }), - }); - } - throw new Error(`Unexpected provider ${attemptParams.provider}`); - }); + mockPrimaryOverloadedThenFallbackSuccess(); const result = await runEmbeddedFallback({ agentDir, @@ -755,14 +669,7 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { }); expect(result.provider).toBe("groq"); - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(1); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 1, groq: 1 }); }); }); @@ -783,14 +690,7 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { expect(result.model).toBe("mock-2"); expect(result.result.payloads?.[0]?.text ?? "").toContain("fallback ok"); - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(2); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 2, groq: 1 }); }); }); @@ -816,14 +716,7 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { expect(result.attempts[0]?.reason).toBe("rate_limit"); expect(result.result.payloads?.[0]?.text ?? "").toContain("fallback ok"); - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(3); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 3, groq: 1 }); }); }); @@ -845,14 +738,7 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { }); expect(result.provider).toBe("groq"); - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(1); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 1, groq: 1 }); }); }); @@ -872,14 +758,7 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { expect(result.provider).toBe("groq"); expect(result.model).toBe("mock-2"); - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(2); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 2, groq: 1 }); }); }); @@ -901,14 +780,7 @@ describe("runWithModelFallback + runEmbeddedPiAgent failover behavior", () => { }); expect(result.provider).toBe("groq"); - const openaiAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "openai", - ); - const groqAttempts = runEmbeddedAttemptMock.mock.calls.filter( - (call) => (call[0] as { provider?: string })?.provider === "groq", - ); - expect(openaiAttempts.length).toBe(1); - expect(groqAttempts.length).toBe(1); + expectProviderAttemptCounts({ openai: 1, groq: 1 }); }); }); });