fix(agents): register simple completion transports

This commit is contained in:
Ayaan Zaidi
2026-03-21 17:22:31 +05:30
parent 466debb75c
commit 42e708d005
5 changed files with 154 additions and 19 deletions

View File

@@ -0,0 +1,96 @@
import type { Model } from "@mariozechner/pi-ai";
import { beforeEach, describe, expect, it, vi } from "vitest";
import type { OpenClawConfig } from "../config/config.js";
// vi.hoisted() runs these factories before vitest hoists the vi.mock calls
// below to the top of the module, so the mock fns exist when the module
// factories execute.
const createAnthropicVertexStreamFnForModel = vi.hoisted(() => vi.fn());
const ensureCustomApiRegistered = vi.hoisted(() => vi.fn());
const createConfiguredOllamaStreamFn = vi.hoisted(() => vi.fn());
// Stub out the three transport helpers so the subject under test can be
// observed without constructing real stream transports.
vi.mock("./anthropic-vertex-stream.js", () => ({
createAnthropicVertexStreamFnForModel,
}));
vi.mock("./custom-api-registry.js", () => ({
ensureCustomApiRegistered,
}));
vi.mock("./ollama-stream.js", () => ({
createConfiguredOllamaStreamFn,
}));
import { prepareModelForSimpleCompletion } from "./simple-completion-transport.js";
describe("prepareModelForSimpleCompletion", () => {
  beforeEach(() => {
    // Fresh mocks per case, with canned stream handles we can match on.
    for (const mock of [
      createAnthropicVertexStreamFnForModel,
      ensureCustomApiRegistered,
      createConfiguredOllamaStreamFn,
    ]) {
      mock.mockReset();
    }
    createAnthropicVertexStreamFnForModel.mockReturnValue("vertex-stream");
    createConfiguredOllamaStreamFn.mockReturnValue("ollama-stream");
  });

  it("registers the configured Ollama transport and keeps the original api", () => {
    const ollamaModel: Model<"ollama"> = {
      id: "llama3",
      name: "Llama 3",
      api: "ollama",
      provider: "ollama",
      baseUrl: "http://localhost:11434",
      reasoning: false,
      input: ["text"],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
      contextWindow: 8192,
      maxTokens: 4096,
      headers: {},
    };
    const config: OpenClawConfig = {
      models: {
        providers: {
          ollama: { baseUrl: "http://remote-ollama:11434", models: [] },
        },
      },
    };

    const prepared = prepareModelForSimpleCompletion({ model: ollamaModel, cfg: config });

    // The stream fn must be built against the provider-level base URL,
    // not the model's own baseUrl.
    expect(createConfiguredOllamaStreamFn).toHaveBeenCalledWith({
      model: ollamaModel,
      providerBaseUrl: "http://remote-ollama:11434",
    });
    expect(ensureCustomApiRegistered).toHaveBeenCalledWith("ollama", "ollama-stream");
    // Same object back: Ollama keeps its api id, no alias needed.
    expect(prepared).toBe(ollamaModel);
  });

  it("uses a custom api alias for Anthropic Vertex simple completions", () => {
    const vertexModel: Model<"anthropic-messages"> = {
      id: "claude-sonnet",
      name: "Claude Sonnet",
      api: "anthropic-messages",
      provider: "anthropic-vertex",
      baseUrl: "https://us-central1-aiplatform.googleapis.com",
      reasoning: true,
      input: ["text"],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
      contextWindow: 200000,
      maxTokens: 8192,
    };
    // Alias embeds the URL-encoded base URL so distinct endpoints get distinct apis.
    const expectedAlias =
      "openclaw-anthropic-vertex-simple:https%3A%2F%2Fus-central1-aiplatform.googleapis.com";

    const prepared = prepareModelForSimpleCompletion({ model: vertexModel });

    expect(createAnthropicVertexStreamFnForModel).toHaveBeenCalledWith(vertexModel);
    expect(ensureCustomApiRegistered).toHaveBeenCalledWith(expectedAlias, "vertex-stream");
    // A copy is returned with only the api swapped for the alias.
    expect(prepared).toEqual({ ...vertexModel, api: expectedAlias });
  });
});

View File

@@ -0,0 +1,39 @@
import type { Api, Model } from "@mariozechner/pi-ai";
import type { OpenClawConfig } from "../config/config.js";
import { createAnthropicVertexStreamFnForModel } from "./anthropic-vertex-stream.js";
import { ensureCustomApiRegistered } from "./custom-api-registry.js";
import { createConfiguredOllamaStreamFn } from "./ollama-stream.js";
/**
 * Builds the custom Api id used to register an Anthropic Vertex simple-completion
 * transport. The trimmed base URL is URL-encoded into the id so that models
 * pointing at different Vertex endpoints register distinct transports; a missing
 * or blank base URL maps to the shared "default" id.
 */
function resolveAnthropicVertexSimpleApi(baseUrl?: string): Api {
  const trimmed = baseUrl?.trim() ?? "";
  const suffix = trimmed === "" ? "default" : encodeURIComponent(trimmed);
  return `openclaw-anthropic-vertex-simple:${suffix}`;
}
/**
 * Prepares a model for use with `completeSimple` by registering any custom
 * stream transport it needs.
 *
 * - Ollama models: registers a stream fn configured with the provider-level
 *   base URL from `cfg` (if one is set) under the existing "ollama" api id,
 *   and returns the model unchanged.
 * - Anthropic Vertex models: registers a Vertex stream fn under a per-base-URL
 *   alias api id and returns a copy of the model pointing at that alias.
 * - Anything else is returned as-is.
 *
 * @param params.model The resolved model to prepare.
 * @param params.cfg   Optional config consulted for provider base URLs.
 * @returns The model to pass to `completeSimple` (same object, or a copy with
 *          a substituted `api` for Anthropic Vertex).
 */
export function prepareModelForSimpleCompletion<TApi extends Api>(params: {
  model: Model<TApi>;
  cfg?: OpenClawConfig;
}): Model<Api> {
  const { model, cfg } = params;
  if (model.api === "ollama") {
    // Read the provider-level base URL once; the previous form repeated the
    // full optional chain for both the typeof guard and the value.
    const configuredBaseUrl = cfg?.models?.providers?.[model.provider]?.baseUrl;
    const providerBaseUrl = typeof configuredBaseUrl === "string" ? configuredBaseUrl : undefined;
    ensureCustomApiRegistered(
      model.api,
      createConfiguredOllamaStreamFn({ model, providerBaseUrl }),
    );
    return model;
  }
  if (model.provider === "anthropic-vertex") {
    // Alias the api per base URL so different Vertex endpoints do not clobber
    // each other's registered transport.
    const api = resolveAnthropicVertexSimpleApi(model.baseUrl);
    ensureCustomApiRegistered(api, createAnthropicVertexStreamFnForModel(model));
    return { ...model, api };
  }
  return model;
}

View File

@@ -5,6 +5,7 @@ const getApiKeyForModel = vi.hoisted(() => vi.fn());
const requireApiKey = vi.hoisted(() => vi.fn());
const resolveDefaultModelForAgent = vi.hoisted(() => vi.fn());
const resolveModelAsync = vi.hoisted(() => vi.fn());
const prepareModelForSimpleCompletion = vi.hoisted(() => vi.fn());
vi.mock("@mariozechner/pi-ai", () => ({
completeSimple,
@@ -23,6 +24,10 @@ vi.mock("../../agents/pi-embedded-runner/model.js", () => ({
resolveModelAsync,
}));
vi.mock("../../agents/simple-completion-transport.js", () => ({
prepareModelForSimpleCompletion,
}));
import { generateTopicLabel, resolveAutoTopicLabelConfig } from "./auto-topic-label.js";
describe("resolveAutoTopicLabelConfig", () => {
@@ -117,6 +122,7 @@ describe("generateTopicLabel", () => {
requireApiKey.mockReset();
resolveDefaultModelForAgent.mockReset();
resolveModelAsync.mockReset();
prepareModelForSimpleCompletion.mockReset();
resolveDefaultModelForAgent.mockReturnValue({ provider: "openai", model: "gpt-test" });
resolveModelAsync.mockResolvedValue({
@@ -124,6 +130,7 @@ describe("generateTopicLabel", () => {
authStorage: {},
modelRegistry: {},
});
prepareModelForSimpleCompletion.mockImplementation(({ model }) => model);
getApiKeyForModel.mockResolvedValue({ apiKey: "resolved-key", mode: "api-key" });
requireApiKey.mockReturnValue("resolved-key");
completeSimple.mockResolvedValue({
@@ -155,5 +162,9 @@ describe("generateTopicLabel", () => {
cfg: {},
agentDir: "/tmp/agents/billing/agent",
});
expect(prepareModelForSimpleCompletion).toHaveBeenCalledWith({
model: { provider: "openai" },
cfg: {},
});
});
});

View File

@@ -9,6 +9,7 @@ import { completeSimple, type TextContent } from "@mariozechner/pi-ai";
import { getApiKeyForModel, requireApiKey } from "../../agents/model-auth.js";
import { resolveDefaultModelForAgent } from "../../agents/model-selection.js";
import { resolveModelAsync } from "../../agents/pi-embedded-runner/model.js";
import { prepareModelForSimpleCompletion } from "../../agents/simple-completion-transport.js";
import type { OpenClawConfig } from "../../config/config.js";
import { logVerbose } from "../../globals.js";
@@ -53,9 +54,10 @@ export async function generateTopicLabel(params: {
logVerbose(`auto-topic-label: failed to resolve model ${modelRef.provider}/${modelRef.model}`);
return null;
}
const completionModel = prepareModelForSimpleCompletion({ model: resolved.model, cfg });
const apiKey = requireApiKey(
await getApiKeyForModel({ model: resolved.model, cfg, agentDir }),
await getApiKeyForModel({ model: completionModel, cfg, agentDir }),
modelRef.provider,
);
@@ -63,7 +65,7 @@ export async function generateTopicLabel(params: {
const timeout = setTimeout(() => controller.abort(), TIMEOUT_MS);
try {
const result = await completeSimple(
resolved.model,
completionModel,
{
messages: [
{

View File

@@ -1,7 +1,6 @@
import { rmSync, statSync } from "node:fs";
import { completeSimple, type TextContent } from "@mariozechner/pi-ai";
import { EdgeTTS } from "node-edge-tts";
import { ensureCustomApiRegistered } from "../agents/custom-api-registry.js";
import { getApiKeyForModel, requireApiKey } from "../agents/model-auth.js";
import {
buildModelAliasIndex,
@@ -9,8 +8,8 @@ import {
resolveModelRefFromString,
type ModelRef,
} from "../agents/model-selection.js";
import { createConfiguredOllamaStreamFn } from "../agents/ollama-stream.js";
import { resolveModelAsync } from "../agents/pi-embedded-runner/model.js";
import { prepareModelForSimpleCompletion } from "../agents/simple-completion-transport.js";
import type { OpenClawConfig } from "../config/config.js";
import type {
ResolvedTtsConfig,
@@ -463,8 +462,9 @@ export async function summarizeText(params: {
if (!resolved.model) {
throw new Error(resolved.error ?? `Unknown summary model: ${ref.provider}/${ref.model}`);
}
const completionModel = prepareModelForSimpleCompletion({ model: resolved.model, cfg });
const apiKey = requireApiKey(
await getApiKeyForModel({ model: resolved.model, cfg }),
await getApiKeyForModel({ model: completionModel, cfg }),
ref.provider,
);
@@ -473,21 +473,8 @@ export async function summarizeText(params: {
const timeout = setTimeout(() => controller.abort(), timeoutMs);
try {
if (resolved.model.api === "ollama") {
const providerBaseUrl =
typeof cfg.models?.providers?.[resolved.model.provider]?.baseUrl === "string"
? cfg.models.providers[resolved.model.provider]?.baseUrl
: undefined;
ensureCustomApiRegistered(
resolved.model.api,
createConfiguredOllamaStreamFn({
model: resolved.model,
providerBaseUrl,
}),
);
}
const res = await completeSimple(
resolved.model,
completionModel,
{
messages: [
{