mirror of
https://github.com/openclaw/openclaw.git
synced 2026-03-22 07:20:59 +00:00
fix(agents): register simple completion transports
This commit is contained in:
96
src/agents/simple-completion-transport.test.ts
Normal file
96
src/agents/simple-completion-transport.test.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import type { Model } from "@mariozechner/pi-ai";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
|
||||
const createAnthropicVertexStreamFnForModel = vi.hoisted(() => vi.fn());
|
||||
const ensureCustomApiRegistered = vi.hoisted(() => vi.fn());
|
||||
const createConfiguredOllamaStreamFn = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("./anthropic-vertex-stream.js", () => ({
|
||||
createAnthropicVertexStreamFnForModel,
|
||||
}));
|
||||
|
||||
vi.mock("./custom-api-registry.js", () => ({
|
||||
ensureCustomApiRegistered,
|
||||
}));
|
||||
|
||||
vi.mock("./ollama-stream.js", () => ({
|
||||
createConfiguredOllamaStreamFn,
|
||||
}));
|
||||
|
||||
import { prepareModelForSimpleCompletion } from "./simple-completion-transport.js";
|
||||
|
||||
describe("prepareModelForSimpleCompletion", () => {
|
||||
beforeEach(() => {
|
||||
createAnthropicVertexStreamFnForModel.mockReset();
|
||||
ensureCustomApiRegistered.mockReset();
|
||||
createConfiguredOllamaStreamFn.mockReset();
|
||||
createAnthropicVertexStreamFnForModel.mockReturnValue("vertex-stream");
|
||||
createConfiguredOllamaStreamFn.mockReturnValue("ollama-stream");
|
||||
});
|
||||
|
||||
it("registers the configured Ollama transport and keeps the original api", () => {
|
||||
const model: Model<"ollama"> = {
|
||||
id: "llama3",
|
||||
name: "Llama 3",
|
||||
api: "ollama",
|
||||
provider: "ollama",
|
||||
baseUrl: "http://localhost:11434",
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 8192,
|
||||
maxTokens: 4096,
|
||||
headers: {},
|
||||
};
|
||||
const cfg: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
ollama: {
|
||||
baseUrl: "http://remote-ollama:11434",
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = prepareModelForSimpleCompletion({
|
||||
model,
|
||||
cfg,
|
||||
});
|
||||
|
||||
expect(createConfiguredOllamaStreamFn).toHaveBeenCalledWith({
|
||||
model,
|
||||
providerBaseUrl: "http://remote-ollama:11434",
|
||||
});
|
||||
expect(ensureCustomApiRegistered).toHaveBeenCalledWith("ollama", "ollama-stream");
|
||||
expect(result).toBe(model);
|
||||
});
|
||||
|
||||
it("uses a custom api alias for Anthropic Vertex simple completions", () => {
|
||||
const model: Model<"anthropic-messages"> = {
|
||||
id: "claude-sonnet",
|
||||
name: "Claude Sonnet",
|
||||
api: "anthropic-messages",
|
||||
provider: "anthropic-vertex",
|
||||
baseUrl: "https://us-central1-aiplatform.googleapis.com",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 200000,
|
||||
maxTokens: 8192,
|
||||
};
|
||||
|
||||
const result = prepareModelForSimpleCompletion({ model });
|
||||
|
||||
expect(createAnthropicVertexStreamFnForModel).toHaveBeenCalledWith(model);
|
||||
expect(ensureCustomApiRegistered).toHaveBeenCalledWith(
|
||||
"openclaw-anthropic-vertex-simple:https%3A%2F%2Fus-central1-aiplatform.googleapis.com",
|
||||
"vertex-stream",
|
||||
);
|
||||
expect(result).toEqual({
|
||||
...model,
|
||||
api: "openclaw-anthropic-vertex-simple:https%3A%2F%2Fus-central1-aiplatform.googleapis.com",
|
||||
});
|
||||
});
|
||||
});
|
||||
39
src/agents/simple-completion-transport.ts
Normal file
39
src/agents/simple-completion-transport.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import type { Api, Model } from "@mariozechner/pi-ai";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { createAnthropicVertexStreamFnForModel } from "./anthropic-vertex-stream.js";
|
||||
import { ensureCustomApiRegistered } from "./custom-api-registry.js";
|
||||
import { createConfiguredOllamaStreamFn } from "./ollama-stream.js";
|
||||
|
||||
function resolveAnthropicVertexSimpleApi(baseUrl?: string): Api {
|
||||
const suffix = baseUrl?.trim() ? encodeURIComponent(baseUrl.trim()) : "default";
|
||||
return `openclaw-anthropic-vertex-simple:${suffix}`;
|
||||
}
|
||||
|
||||
export function prepareModelForSimpleCompletion<TApi extends Api>(params: {
|
||||
model: Model<TApi>;
|
||||
cfg?: OpenClawConfig;
|
||||
}): Model<Api> {
|
||||
const { model, cfg } = params;
|
||||
if (model.api === "ollama") {
|
||||
const providerBaseUrl =
|
||||
typeof cfg?.models?.providers?.[model.provider]?.baseUrl === "string"
|
||||
? cfg.models.providers[model.provider]?.baseUrl
|
||||
: undefined;
|
||||
ensureCustomApiRegistered(
|
||||
model.api,
|
||||
createConfiguredOllamaStreamFn({
|
||||
model,
|
||||
providerBaseUrl,
|
||||
}),
|
||||
);
|
||||
return model;
|
||||
}
|
||||
|
||||
if (model.provider === "anthropic-vertex") {
|
||||
const api = resolveAnthropicVertexSimpleApi(model.baseUrl);
|
||||
ensureCustomApiRegistered(api, createAnthropicVertexStreamFnForModel(model));
|
||||
return { ...model, api };
|
||||
}
|
||||
|
||||
return model;
|
||||
}
|
||||
@@ -5,6 +5,7 @@ const getApiKeyForModel = vi.hoisted(() => vi.fn());
|
||||
const requireApiKey = vi.hoisted(() => vi.fn());
|
||||
const resolveDefaultModelForAgent = vi.hoisted(() => vi.fn());
|
||||
const resolveModelAsync = vi.hoisted(() => vi.fn());
|
||||
const prepareModelForSimpleCompletion = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@mariozechner/pi-ai", () => ({
|
||||
completeSimple,
|
||||
@@ -23,6 +24,10 @@ vi.mock("../../agents/pi-embedded-runner/model.js", () => ({
|
||||
resolveModelAsync,
|
||||
}));
|
||||
|
||||
vi.mock("../../agents/simple-completion-transport.js", () => ({
|
||||
prepareModelForSimpleCompletion,
|
||||
}));
|
||||
|
||||
import { generateTopicLabel, resolveAutoTopicLabelConfig } from "./auto-topic-label.js";
|
||||
|
||||
describe("resolveAutoTopicLabelConfig", () => {
|
||||
@@ -117,6 +122,7 @@ describe("generateTopicLabel", () => {
|
||||
requireApiKey.mockReset();
|
||||
resolveDefaultModelForAgent.mockReset();
|
||||
resolveModelAsync.mockReset();
|
||||
prepareModelForSimpleCompletion.mockReset();
|
||||
|
||||
resolveDefaultModelForAgent.mockReturnValue({ provider: "openai", model: "gpt-test" });
|
||||
resolveModelAsync.mockResolvedValue({
|
||||
@@ -124,6 +130,7 @@ describe("generateTopicLabel", () => {
|
||||
authStorage: {},
|
||||
modelRegistry: {},
|
||||
});
|
||||
prepareModelForSimpleCompletion.mockImplementation(({ model }) => model);
|
||||
getApiKeyForModel.mockResolvedValue({ apiKey: "resolved-key", mode: "api-key" });
|
||||
requireApiKey.mockReturnValue("resolved-key");
|
||||
completeSimple.mockResolvedValue({
|
||||
@@ -155,5 +162,9 @@ describe("generateTopicLabel", () => {
|
||||
cfg: {},
|
||||
agentDir: "/tmp/agents/billing/agent",
|
||||
});
|
||||
expect(prepareModelForSimpleCompletion).toHaveBeenCalledWith({
|
||||
model: { provider: "openai" },
|
||||
cfg: {},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,6 +9,7 @@ import { completeSimple, type TextContent } from "@mariozechner/pi-ai";
|
||||
import { getApiKeyForModel, requireApiKey } from "../../agents/model-auth.js";
|
||||
import { resolveDefaultModelForAgent } from "../../agents/model-selection.js";
|
||||
import { resolveModelAsync } from "../../agents/pi-embedded-runner/model.js";
|
||||
import { prepareModelForSimpleCompletion } from "../../agents/simple-completion-transport.js";
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import { logVerbose } from "../../globals.js";
|
||||
|
||||
@@ -53,9 +54,10 @@ export async function generateTopicLabel(params: {
|
||||
logVerbose(`auto-topic-label: failed to resolve model ${modelRef.provider}/${modelRef.model}`);
|
||||
return null;
|
||||
}
|
||||
const completionModel = prepareModelForSimpleCompletion({ model: resolved.model, cfg });
|
||||
|
||||
const apiKey = requireApiKey(
|
||||
await getApiKeyForModel({ model: resolved.model, cfg, agentDir }),
|
||||
await getApiKeyForModel({ model: completionModel, cfg, agentDir }),
|
||||
modelRef.provider,
|
||||
);
|
||||
|
||||
@@ -63,7 +65,7 @@ export async function generateTopicLabel(params: {
|
||||
const timeout = setTimeout(() => controller.abort(), TIMEOUT_MS);
|
||||
try {
|
||||
const result = await completeSimple(
|
||||
resolved.model,
|
||||
completionModel,
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { rmSync, statSync } from "node:fs";
|
||||
import { completeSimple, type TextContent } from "@mariozechner/pi-ai";
|
||||
import { EdgeTTS } from "node-edge-tts";
|
||||
import { ensureCustomApiRegistered } from "../agents/custom-api-registry.js";
|
||||
import { getApiKeyForModel, requireApiKey } from "../agents/model-auth.js";
|
||||
import {
|
||||
buildModelAliasIndex,
|
||||
@@ -9,8 +8,8 @@ import {
|
||||
resolveModelRefFromString,
|
||||
type ModelRef,
|
||||
} from "../agents/model-selection.js";
|
||||
import { createConfiguredOllamaStreamFn } from "../agents/ollama-stream.js";
|
||||
import { resolveModelAsync } from "../agents/pi-embedded-runner/model.js";
|
||||
import { prepareModelForSimpleCompletion } from "../agents/simple-completion-transport.js";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import type {
|
||||
ResolvedTtsConfig,
|
||||
@@ -463,8 +462,9 @@ export async function summarizeText(params: {
|
||||
if (!resolved.model) {
|
||||
throw new Error(resolved.error ?? `Unknown summary model: ${ref.provider}/${ref.model}`);
|
||||
}
|
||||
const completionModel = prepareModelForSimpleCompletion({ model: resolved.model, cfg });
|
||||
const apiKey = requireApiKey(
|
||||
await getApiKeyForModel({ model: resolved.model, cfg }),
|
||||
await getApiKeyForModel({ model: completionModel, cfg }),
|
||||
ref.provider,
|
||||
);
|
||||
|
||||
@@ -473,21 +473,8 @@ export async function summarizeText(params: {
|
||||
const timeout = setTimeout(() => controller.abort(), timeoutMs);
|
||||
|
||||
try {
|
||||
if (resolved.model.api === "ollama") {
|
||||
const providerBaseUrl =
|
||||
typeof cfg.models?.providers?.[resolved.model.provider]?.baseUrl === "string"
|
||||
? cfg.models.providers[resolved.model.provider]?.baseUrl
|
||||
: undefined;
|
||||
ensureCustomApiRegistered(
|
||||
resolved.model.api,
|
||||
createConfiguredOllamaStreamFn({
|
||||
model: resolved.model,
|
||||
providerBaseUrl,
|
||||
}),
|
||||
);
|
||||
}
|
||||
const res = await completeSimple(
|
||||
resolved.model,
|
||||
completionModel,
|
||||
{
|
||||
messages: [
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user