From 42e708d0055b2c519fcdd41ad976304d851ec8ab Mon Sep 17 00:00:00 2001 From: Ayaan Zaidi Date: Sat, 21 Mar 2026 17:22:31 +0530 Subject: [PATCH] fix(agents): register simple completion transports --- .../simple-completion-transport.test.ts | 96 +++++++++++++++++++ src/agents/simple-completion-transport.ts | 39 ++++++++ src/auto-reply/reply/auto-topic-label.test.ts | 11 +++ src/auto-reply/reply/auto-topic-label.ts | 6 +- src/tts/tts-core.ts | 21 +--- 5 files changed, 154 insertions(+), 19 deletions(-) create mode 100644 src/agents/simple-completion-transport.test.ts create mode 100644 src/agents/simple-completion-transport.ts diff --git a/src/agents/simple-completion-transport.test.ts b/src/agents/simple-completion-transport.test.ts new file mode 100644 index 00000000000..9621faade47 --- /dev/null +++ b/src/agents/simple-completion-transport.test.ts @@ -0,0 +1,96 @@ +import type { Model } from "@mariozechner/pi-ai"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import type { OpenClawConfig } from "../config/config.js"; + +const createAnthropicVertexStreamFnForModel = vi.hoisted(() => vi.fn()); +const ensureCustomApiRegistered = vi.hoisted(() => vi.fn()); +const createConfiguredOllamaStreamFn = vi.hoisted(() => vi.fn()); + +vi.mock("./anthropic-vertex-stream.js", () => ({ + createAnthropicVertexStreamFnForModel, +})); + +vi.mock("./custom-api-registry.js", () => ({ + ensureCustomApiRegistered, +})); + +vi.mock("./ollama-stream.js", () => ({ + createConfiguredOllamaStreamFn, +})); + +import { prepareModelForSimpleCompletion } from "./simple-completion-transport.js"; + +describe("prepareModelForSimpleCompletion", () => { + beforeEach(() => { + createAnthropicVertexStreamFnForModel.mockReset(); + ensureCustomApiRegistered.mockReset(); + createConfiguredOllamaStreamFn.mockReset(); + createAnthropicVertexStreamFnForModel.mockReturnValue("vertex-stream"); + createConfiguredOllamaStreamFn.mockReturnValue("ollama-stream"); + }); + + it("registers the configured Ollama transport and keeps the original api", () => { + const model: Model<"ollama"> = { + id: "llama3", + name: "Llama 3", + api: "ollama", + provider: "ollama", + baseUrl: "http://localhost:11434", + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 8192, + maxTokens: 4096, + headers: {}, + }; + const cfg: OpenClawConfig = { + models: { + providers: { + ollama: { + baseUrl: "http://remote-ollama:11434", + models: [], + }, + }, + }, + }; + + const result = prepareModelForSimpleCompletion({ + model, + cfg, + }); + + expect(createConfiguredOllamaStreamFn).toHaveBeenCalledWith({ + model, + providerBaseUrl: "http://remote-ollama:11434", + }); + expect(ensureCustomApiRegistered).toHaveBeenCalledWith("ollama", "ollama-stream"); + expect(result).toBe(model); + }); + + it("uses a custom api alias for Anthropic Vertex simple completions", () => { + const model: Model<"anthropic-messages"> = { + id: "claude-sonnet", + name: "Claude Sonnet", + api: "anthropic-messages", + provider: "anthropic-vertex", + baseUrl: "https://us-central1-aiplatform.googleapis.com", + reasoning: true, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: 200000, + maxTokens: 8192, + }; + + const result = prepareModelForSimpleCompletion({ model }); + + expect(createAnthropicVertexStreamFnForModel).toHaveBeenCalledWith(model); + expect(ensureCustomApiRegistered).toHaveBeenCalledWith( + "openclaw-anthropic-vertex-simple:https%3A%2F%2Fus-central1-aiplatform.googleapis.com", + "vertex-stream", + ); + expect(result).toEqual({ + ...model, + api: "openclaw-anthropic-vertex-simple:https%3A%2F%2Fus-central1-aiplatform.googleapis.com", + }); + }); +}); diff --git a/src/agents/simple-completion-transport.ts b/src/agents/simple-completion-transport.ts new file mode 100644 index 00000000000..74f0e9a466c --- /dev/null +++ b/src/agents/simple-completion-transport.ts @@ -0,0 +1,39 @@ +import type { Api, Model } from "@mariozechner/pi-ai"; +import type { OpenClawConfig } from "../config/config.js"; +import { createAnthropicVertexStreamFnForModel } from "./anthropic-vertex-stream.js"; +import { ensureCustomApiRegistered } from "./custom-api-registry.js"; +import { createConfiguredOllamaStreamFn } from "./ollama-stream.js"; + +function resolveAnthropicVertexSimpleApi(baseUrl?: string): Api { + const suffix = baseUrl?.trim() ? encodeURIComponent(baseUrl.trim()) : "default"; + return `openclaw-anthropic-vertex-simple:${suffix}`; +} + +export function prepareModelForSimpleCompletion(params: { + model: Model; + cfg?: OpenClawConfig; +}): Model { + const { model, cfg } = params; + if (model.api === "ollama") { + const providerBaseUrl = + typeof cfg?.models?.providers?.[model.provider]?.baseUrl === "string" + ? cfg.models.providers[model.provider]?.baseUrl + : undefined; + ensureCustomApiRegistered( + model.api, + createConfiguredOllamaStreamFn({ + model, + providerBaseUrl, + }), + ); + return model; + } + + if (model.provider === "anthropic-vertex") { + const api = resolveAnthropicVertexSimpleApi(model.baseUrl); + ensureCustomApiRegistered(api, createAnthropicVertexStreamFnForModel(model)); + return { ...model, api }; + } + + return model; +} diff --git a/src/auto-reply/reply/auto-topic-label.test.ts b/src/auto-reply/reply/auto-topic-label.test.ts index 2db02b4eebc..0cf3b60af47 100644 --- a/src/auto-reply/reply/auto-topic-label.test.ts +++ b/src/auto-reply/reply/auto-topic-label.test.ts @@ -5,6 +5,7 @@ const getApiKeyForModel = vi.hoisted(() => vi.fn()); const requireApiKey = vi.hoisted(() => vi.fn()); const resolveDefaultModelForAgent = vi.hoisted(() => vi.fn()); const resolveModelAsync = vi.hoisted(() => vi.fn()); +const prepareModelForSimpleCompletion = vi.hoisted(() => vi.fn()); vi.mock("@mariozechner/pi-ai", () => ({ completeSimple, @@ -23,6 +24,10 @@ vi.mock("../../agents/pi-embedded-runner/model.js", () => ({ resolveModelAsync, })); +vi.mock("../../agents/simple-completion-transport.js", () => ({ + prepareModelForSimpleCompletion, +})); + import { generateTopicLabel, resolveAutoTopicLabelConfig } from "./auto-topic-label.js"; describe("resolveAutoTopicLabelConfig", () => { @@ -117,6 +122,7 @@ describe("generateTopicLabel", () => { requireApiKey.mockReset(); resolveDefaultModelForAgent.mockReset(); resolveModelAsync.mockReset(); + prepareModelForSimpleCompletion.mockReset(); resolveDefaultModelForAgent.mockReturnValue({ provider: "openai", model: "gpt-test" }); resolveModelAsync.mockResolvedValue({ @@ -124,6 +130,7 @@ describe("generateTopicLabel", () => { authStorage: {}, modelRegistry: {}, }); + prepareModelForSimpleCompletion.mockImplementation(({ model }) => model); getApiKeyForModel.mockResolvedValue({ apiKey: "resolved-key", mode: "api-key" }); requireApiKey.mockReturnValue("resolved-key"); completeSimple.mockResolvedValue({ @@ -155,5 +162,9 @@ describe("generateTopicLabel", () => { cfg: {}, agentDir: "/tmp/agents/billing/agent", }); + expect(prepareModelForSimpleCompletion).toHaveBeenCalledWith({ + model: { provider: "openai" }, + cfg: {}, + }); }); }); diff --git a/src/auto-reply/reply/auto-topic-label.ts b/src/auto-reply/reply/auto-topic-label.ts index bf42378e18c..911f7842329 100644 --- a/src/auto-reply/reply/auto-topic-label.ts +++ b/src/auto-reply/reply/auto-topic-label.ts @@ -9,6 +9,7 @@ import { completeSimple, type TextContent } from "@mariozechner/pi-ai"; import { getApiKeyForModel, requireApiKey } from "../../agents/model-auth.js"; import { resolveDefaultModelForAgent } from "../../agents/model-selection.js"; import { resolveModelAsync } from "../../agents/pi-embedded-runner/model.js"; +import { prepareModelForSimpleCompletion } from "../../agents/simple-completion-transport.js"; import type { OpenClawConfig } from "../../config/config.js"; import { logVerbose } from "../../globals.js"; @@ -53,9 +54,10 @@ export async function generateTopicLabel(params: { logVerbose(`auto-topic-label: failed to resolve model ${modelRef.provider}/${modelRef.model}`); return null; } + const completionModel = prepareModelForSimpleCompletion({ model: resolved.model, cfg }); const apiKey = requireApiKey( - await getApiKeyForModel({ model: resolved.model, cfg, agentDir }), + await getApiKeyForModel({ model: completionModel, cfg, agentDir }), modelRef.provider, ); @@ -63,7 +65,7 @@ export async function generateTopicLabel(params: { const timeout = setTimeout(() => controller.abort(), TIMEOUT_MS); try { const result = await completeSimple( - resolved.model, + completionModel, { messages: [ { diff --git a/src/tts/tts-core.ts b/src/tts/tts-core.ts index 7bdc8f56288..f665b005a51 100644 --- a/src/tts/tts-core.ts +++ b/src/tts/tts-core.ts @@ -1,7 +1,6 @@ import { rmSync, statSync } from "node:fs"; import { completeSimple, type TextContent } from "@mariozechner/pi-ai"; import { EdgeTTS } from "node-edge-tts"; -import { ensureCustomApiRegistered } from "../agents/custom-api-registry.js"; import { getApiKeyForModel, requireApiKey } from "../agents/model-auth.js"; import { buildModelAliasIndex, @@ -9,8 +8,8 @@ import { resolveModelRefFromString, type ModelRef, } from "../agents/model-selection.js"; -import { createConfiguredOllamaStreamFn } from "../agents/ollama-stream.js"; import { resolveModelAsync } from "../agents/pi-embedded-runner/model.js"; +import { prepareModelForSimpleCompletion } from "../agents/simple-completion-transport.js"; import type { OpenClawConfig } from "../config/config.js"; import type { ResolvedTtsConfig, @@ -463,8 +462,9 @@ export async function summarizeText(params: { if (!resolved.model) { throw new Error(resolved.error ?? `Unknown summary model: ${ref.provider}/${ref.model}`); } + const completionModel = prepareModelForSimpleCompletion({ model: resolved.model, cfg }); const apiKey = requireApiKey( - await getApiKeyForModel({ model: resolved.model, cfg }), + await getApiKeyForModel({ model: completionModel, cfg }), ref.provider, ); @@ -473,21 +473,8 @@ export async function summarizeText(params: { const timeout = setTimeout(() => controller.abort(), timeoutMs); try { - if (resolved.model.api === "ollama") { - const providerBaseUrl = - typeof cfg.models?.providers?.[resolved.model.provider]?.baseUrl === "string" - ? cfg.models.providers[resolved.model.provider]?.baseUrl - : undefined; - ensureCustomApiRegistered( - resolved.model.api, - createConfiguredOllamaStreamFn({ - model: resolved.model, - providerBaseUrl, - }), - ); - } const res = await completeSimple( - resolved.model, + completionModel, { messages: [ {