diff --git a/CHANGELOG.md b/CHANGELOG.md index da464767e3e..e547e6eee84 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Docs: https://docs.openclaw.ai ### Changes +- CLI/infer: add a first-class `openclaw infer ...` hub for provider-backed inference workflows across model, media, web, and embedding tasks. Thanks @Takhoffman. - Plugins/webhooks: add a bundled webhook ingress plugin so external automation can create and drive bound TaskFlows through per-route shared-secret endpoints. (#61892) Thanks @mbelinky. - Tools/media generation: preserve intent across auth-backed image, music, and video provider fallback, remap size, aspect ratio, resolution, and duration hints to the closest supported option, and surface explicit provider capabilities plus mode-aware video-to-video support. - Memory/wiki: restore the bundled `memory-wiki` stack with plugin, CLI, sync/query/apply tooling, and memory-host integration for wiki-backed memory workflows. @@ -29,6 +30,7 @@ Docs: https://docs.openclaw.ai ### Fixes - Plugins/media: when `plugins.allow` is set, capability fallback now merges bundled capability plugin ids into the allowlist (not only `plugins.entries`), so media understanding providers such as OpenAI-compatible STT load for voice transcription without requiring `openai` in `plugins.allow`. (#62205) Thanks @neeravmakwana. +- CLI/infer: keep provider-backed infer behavior aligned with actual runtime execution by fixing explicit TTS override handling, profile-aware gateway TTS prefs resolution, per-request transcription `prompt`/`language` overrides, image output MIME/extension mismatches, configured web-search fallback behavior, and agent-vs-CLI web-search execution drift. - Auth/OpenAI Codex OAuth: reload fresh on-disk credentials inside the locked refresh path and retry once after `refresh_token_reused` rotates only the stored refresh token, so relogin/restart recovery stops getting stuck on stale cached auth state. Thanks @owen-ever. - Agents/history and replies: buffer phaseless OpenAI WS text until a real assistant phase arrives, keep replay and SSE history sequence tracking aligned, hide commentary and leaked tool XML from user-visible history, and keep history-based follow-up replies on `final_answer` text only. (#61729, #61747, #61829, #61855, #61954) Thanks @100yenadmin, @afurm, and @openperf. - Plugins/channels: keep bundled channel artifact and secret-contract loading stable under lazy loading, preserve plugin-schema defaults during install, and fix Windows `file://` plus native-Jiti plugin loader paths so onboarding, doctor, `openclaw secret`, and bundled plugin installs work again. (#61832, #61836, #61853, #61856) Thanks @Zeesejo and @SuperMarioYL. diff --git a/docs/cli/capability.md b/docs/cli/capability.md new file mode 100644 index 00000000000..9128681fc69 --- /dev/null +++ b/docs/cli/capability.md @@ -0,0 +1,119 @@ +--- +summary: "Infer-first CLI for provider-backed model, image, audio, TTS, video, web, and embedding workflows" +read_when: + - Adding or modifying `openclaw infer` commands + - Designing stable headless capability automation +title: "Inference CLI" +--- + +# Inference CLI + +`openclaw infer` is the canonical headless surface for provider-backed inference workflows. + +`openclaw capability` remains supported as a fallback alias for compatibility. + +It intentionally exposes capability families, not raw gateway RPC names and not raw agent tool ids. + +## Command tree + +```text + openclaw infer + list + inspect + + model + run + list + inspect + providers + auth login + auth logout + auth status + + image + generate + edit + describe + describe-many + providers + + audio + transcribe + providers + + tts + convert + voices + providers + status + enable + disable + set-provider + + video + generate + describe + providers + + web + search + fetch + providers + + embedding + create + providers +``` + +## Transport + +Supported transport flags: + +- `--local` +- `--gateway` + +Default transport is implicit auto at the command-family level: + +- Stateless execution commands default to local. +- Gateway-managed state commands default to gateway. + +Examples: + +```bash +openclaw infer model run --prompt "hello" --json +openclaw infer image generate --prompt "friendly lobster" --json +openclaw infer tts status --json +openclaw infer embedding create --text "hello world" --json +``` + +## JSON output + +Capability commands normalize JSON output under a shared envelope: + +```json +{ + "ok": true, + "capability": "image.generate", + "transport": "local", + "provider": "openai", + "model": "gpt-image-1", + "attempts": [], + "outputs": [] +} +``` + +Top-level fields are stable: + +- `ok` +- `capability` +- `transport` +- `provider` +- `model` +- `attempts` +- `outputs` +- `error` + +## Notes + +- `model run` reuses the agent runtime so provider/model overrides behave like normal agent execution. +- `tts status` defaults to gateway because it reflects gateway-managed TTS state. diff --git a/docs/cli/index.md b/docs/cli/index.md index 571f03b38fe..0cc8179507e 100644 --- a/docs/cli/index.md +++ b/docs/cli/index.md @@ -35,6 +35,7 @@ This page describes the current CLI behavior. If commands change, update this do - [`logs`](/cli/logs) - [`system`](/cli/system) - [`models`](/cli/models) +- [`infer`](/cli/capability) - [`memory`](/cli/memory) - [`directory`](/cli/directory) - [`nodes`](/cli/nodes) @@ -248,6 +249,16 @@ openclaw [--dev] [--profile ] fallbacks list|add|remove|clear image-fallbacks list|add|remove|clear scan + infer (alias: capability) + list + inspect + model run|list|inspect|providers|auth login|logout|status + image generate|edit|describe|describe-many|providers + audio transcribe|providers + tts convert|voices|providers|status|enable|disable|set-provider + video generate|describe|providers + web search|fetch|providers + embedding create|providers auth add|login|login-github-copilot|setup-token|paste-token auth order get|set|clear sandbox diff --git a/extensions/memory-core/runtime-api.ts b/extensions/memory-core/runtime-api.ts index ea1fe308897..480dd4d28f4 100644 --- a/extensions/memory-core/runtime-api.ts +++ b/extensions/memory-core/runtime-api.ts @@ -4,7 +4,9 @@ export { DEFAULT_LOCAL_MODEL, getBuiltinMemoryEmbeddingProviderDoctorMetadata, listBuiltinAutoSelectMemoryEmbeddingProviderDoctorMetadata, + registerBuiltInMemoryEmbeddingProviders, } from "./src/memory/provider-adapters.js"; +export { createEmbeddingProvider } from "./src/memory/embeddings.js"; export { resolveMemoryCacheSummary, resolveMemoryFtsState, diff --git a/extensions/speech-core/runtime-api.ts b/extensions/speech-core/runtime-api.ts index 0a58b4cefa3..81a6c3a83c3 100644 --- a/extensions/speech-core/runtime-api.ts +++ b/extensions/speech-core/runtime-api.ts @@ -9,6 +9,7 @@ export { isTtsProviderConfigured, listSpeechVoices, maybeApplyTtsToPayload, + resolveExplicitTtsOverrides, resolveTtsAutoMode, resolveTtsConfig, resolveTtsPrefsPath, diff --git a/extensions/speech-core/src/tts.ts b/extensions/speech-core/src/tts.ts index a9f840c9e77..e2b60fd8094 100644 --- a/extensions/speech-core/src/tts.ts +++ b/extensions/speech-core/src/tts.ts @@ -25,8 +25,8 @@ import type { ReplyPayload } from "openclaw/plugin-sdk/reply-runtime"; import { isVerbose, logVerbose } from "openclaw/plugin-sdk/runtime-env"; import { resolvePreferredOpenClawTmpDir } from "openclaw/plugin-sdk/sandbox"; import { - CONFIG_DIR, normalizeOptionalString, + resolveConfigDir, resolveUserPath, stripMarkdown, } from "openclaw/plugin-sdk/text-runtime"; @@ -41,6 +41,7 @@ import { summarizeText, type SpeechModelOverridePolicy, type SpeechProviderConfig, + type SpeechProviderOverrides, type SpeechVoiceOption, type TtsDirectiveOverrides, type TtsDirectiveParseResult, @@ -173,7 +174,7 @@ function resolveTtsPrefsPathValue(prefsPath: string | undefined): string { if (envPath) { return resolveUserPath(envPath); } - return path.join(CONFIG_DIR, "settings", "tts.json"); + return path.join(resolveConfigDir(process.env), "settings", "tts.json"); } function resolveModelOverridePolicy( @@ -502,6 +503,66 @@ export function setTtsProvider(prefsPath: string, provider: TtsProvider): void { }); } +export function resolveExplicitTtsOverrides(params: { + cfg: OpenClawConfig; + prefsPath?: string; + provider?: string; + modelId?: string; + voiceId?: string; +}): TtsDirectiveOverrides { + const providerInput = params.provider?.trim(); + const modelId = params.modelId?.trim(); + const voiceId = params.voiceId?.trim(); + const config = resolveTtsConfig(params.cfg); + const prefsPath = params.prefsPath ?? resolveTtsPrefsPath(config); + const selectedProvider = + canonicalizeSpeechProviderId(providerInput, params.cfg) ?? + (modelId || voiceId ? getTtsProvider(config, prefsPath) : undefined); + + if (providerInput && !selectedProvider) { + throw new Error(`Unknown TTS provider "${providerInput}".`); + } + + if (!modelId && !voiceId) { + return selectedProvider ? { provider: selectedProvider } : {}; + } + + if (!selectedProvider) { + throw new Error("TTS model or voice overrides require a resolved provider."); + } + + const provider = getSpeechProvider(selectedProvider, params.cfg); + if (!provider) { + throw new Error(`speech provider ${selectedProvider} is not registered`); + } + if (!provider.resolveTalkOverrides) { + throw new Error( + `TTS provider "${selectedProvider}" does not support model or voice overrides.`, + ); + } + + const providerOverrides = provider.resolveTalkOverrides({ + talkProviderConfig: {}, + params: { + ...(voiceId ? { voiceId } : {}), + ...(modelId ? { modelId } : {}), + }, + }); + if ((voiceId || modelId) && (!providerOverrides || Object.keys(providerOverrides).length === 0)) { + throw new Error( + `TTS provider "${selectedProvider}" ignored the requested model or voice overrides.`, + ); + } + + const overridesRecord = providerOverrides as SpeechProviderOverrides; + return { + provider: selectedProvider, + providerOverrides: { + [provider.id]: overridesRecord, + }, + }; +} + export function getTtsMaxLength(prefsPath: string): number { const prefs = readPrefs(prefsPath); return prefs.tts?.maxLength ?? DEFAULT_TTS_MAX_LENGTH; diff --git a/src/agents/pi-embedded-runner/model.ts b/src/agents/pi-embedded-runner/model.ts index d493c14f1f2..cea4f232a92 100644 --- a/src/agents/pi-embedded-runner/model.ts +++ b/src/agents/pi-embedded-runner/model.ts @@ -232,8 +232,7 @@ function findInlineModelMatch(params: { ); } -export { buildModelAliasLines }; -export { buildInlineProviderModels }; +export { buildModelAliasLines, buildInlineProviderModels }; function resolveConfiguredProviderConfig( cfg: OpenClawConfig | undefined, @@ -336,7 +335,6 @@ function applyConfiguredProviderOverrides(params: { providerRequest, ); } - function resolveExplicitModelWithRegistry(params: { provider: string; modelId: string; diff --git a/src/agents/tools/web-search.ts b/src/agents/tools/web-search.ts index 783bba577f2..1cd107c9a07 100644 --- a/src/agents/tools/web-search.ts +++ b/src/agents/tools/web-search.ts @@ -4,6 +4,7 @@ import type { RuntimeWebSearchMetadata } from "../../secrets/runtime-web-tools.t import { resolveWebSearchDefinition, resolveWebSearchProviderId, + runWebSearch, } from "../../web-search/runtime.js"; import type { AnyAgentTool } from "./common.js"; import { jsonResult } from "./common.js"; @@ -16,16 +17,17 @@ export function createWebSearchTool(options?: { }): AnyAgentTool | null { const runtimeProviderId = options?.runtimeWebSearch?.selectedProvider ?? options?.runtimeWebSearch?.providerConfigured; + const preferRuntimeProviders = + Boolean(runtimeProviderId) && + !resolveManifestContractOwnerPluginId({ + contract: "webSearchProviders", + value: runtimeProviderId, + origin: "bundled", + config: options?.config, + }); const resolved = resolveWebSearchDefinition({ ...options, - preferRuntimeProviders: - Boolean(runtimeProviderId) && - !resolveManifestContractOwnerPluginId({ - contract: "webSearchProviders", - value: runtimeProviderId, - origin: "bundled", - config: options?.config, - }), + preferRuntimeProviders, }); if (!resolved) { return null; @@ -36,7 +38,19 @@ export function createWebSearchTool(options?: { name: "web_search", description: resolved.definition.description, parameters: resolved.definition.parameters, - execute: async (_toolCallId, args) => jsonResult(await resolved.definition.execute(args)), + execute: async (_toolCallId, args) => { + const result = await runWebSearch({ + config: options?.config, + sandboxed: options?.sandboxed, + runtimeWebSearch: options?.runtimeWebSearch, + preferRuntimeProviders, + args, + }); + return jsonResult({ + ...result.result, + provider: result.provider, + }); + }, }; } diff --git a/src/cli/capability-cli.test.ts b/src/cli/capability-cli.test.ts new file mode 100644 index 00000000000..90f31265ea8 --- /dev/null +++ b/src/cli/capability-cli.test.ts @@ -0,0 +1,903 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { Command } from "commander"; +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { runRegisteredCli } from "../test-utils/command-runner.js"; +import { registerCapabilityCli } from "./capability-cli.js"; + +const mocks = vi.hoisted(() => ({ + runtime: { + log: vi.fn(), + error: vi.fn(), + exit: vi.fn((code: number) => { + throw new Error(`exit ${code}`); + }), + writeJson: vi.fn(), + writeStdout: vi.fn(), + }, + loadConfig: vi.fn(() => ({})), + loadAuthProfileStoreForRuntime: vi.fn(() => ({ profiles: {}, order: {} })), + listProfilesForProvider: vi.fn(() => []), + updateAuthProfileStoreWithLock: vi.fn( + async ({ updater }: { updater: (store: any) => boolean }) => { + const store = { + version: 1, + profiles: {}, + order: {}, + lastGood: {}, + usageStats: {}, + }; + updater(store); + return store; + }, + ), + resolveMemorySearchConfig: vi.fn(() => null), + loadModelCatalog: vi.fn(async () => []), + agentCommand: vi.fn(async () => ({ + payloads: [{ text: "local reply" }], + meta: { agentMeta: { provider: "openai", model: "gpt-5.4" } }, + })), + callGateway: vi.fn(async ({ method }: { method: string }) => { + if (method === "tts.status") { + return { enabled: true, provider: "openai" }; + } + if (method === "agent") { + return { + result: { + payloads: [{ text: "gateway reply" }], + meta: { agentMeta: { provider: "anthropic", model: "claude-sonnet-4-6" } }, + }, + }; + } + return {}; + }), + describeImageFile: vi.fn(async () => ({ + text: "friendly lobster", + provider: "openai", + model: "gpt-4.1-mini", + })), + generateImage: vi.fn(), + transcribeAudioFile: vi.fn(async () => ({ text: "meeting notes" })), + textToSpeech: vi.fn(async () => ({ + success: true, + audioPath: "/tmp/tts-source.mp3", + provider: "openai", + outputFormat: "mp3", + voiceCompatible: false, + attempts: [], + })), + setTtsProvider: vi.fn(), + resolveExplicitTtsOverrides: vi.fn( + ({ + provider, + modelId, + voiceId, + }: { + provider?: string; + modelId?: string; + voiceId?: string; + }) => ({ + ...(provider ? { provider } : {}), + ...(modelId || voiceId + ? { + providerOverrides: { + [provider ?? "openai"]: { + ...(modelId ? { modelId } : {}), + ...(voiceId ? { voiceId } : {}), + }, + }, + } + : {}), + }), + ), + createEmbeddingProvider: vi.fn(async () => ({ + provider: { + id: "openai", + model: "text-embedding-3-small", + embedQuery: async () => [0.1, 0.2], + embedBatch: async (texts: string[]) => texts.map(() => [0.1, 0.2]), + }, + })), + registerMemoryEmbeddingProvider: vi.fn(), + listMemoryEmbeddingProviders: vi.fn(() => [ + { id: "openai", defaultModel: "text-embedding-3-small", transport: "remote" }, + ]), + registerBuiltInMemoryEmbeddingProviders: vi.fn(), + isWebSearchProviderConfigured: vi.fn(() => false), + isWebFetchProviderConfigured: vi.fn(() => false), + modelsStatusCommand: vi.fn( + async (_opts: unknown, runtime: { log: (...args: unknown[]) => void }) => { + runtime.log(JSON.stringify({ ok: true, providers: [{ id: "openai" }] })); + }, + ), +})); + +vi.mock("../runtime.js", () => ({ + defaultRuntime: mocks.runtime, + writeRuntimeJson: (runtime: { writeJson: (value: unknown) => void }, value: unknown) => + runtime.writeJson(value), +})); + +vi.mock("../config/config.js", () => ({ + loadConfig: (...args: unknown[]) => mocks.loadConfig(...args), +})); + +vi.mock("../agents/agent-command.js", () => ({ + agentCommand: (...args: unknown[]) => mocks.agentCommand(...args), +})); + +vi.mock("../agents/agent-scope.js", () => ({ + resolveDefaultAgentId: () => "main", + resolveAgentDir: () => "/tmp/agent", +})); + +vi.mock("../agents/model-catalog.js", () => ({ + loadModelCatalog: (...args: unknown[]) => mocks.loadModelCatalog(...args), +})); + +vi.mock("../agents/auth-profiles.js", () => ({ + loadAuthProfileStoreForRuntime: (...args: unknown[]) => + mocks.loadAuthProfileStoreForRuntime(...args), + listProfilesForProvider: (...args: unknown[]) => mocks.listProfilesForProvider(...args), +})); + +vi.mock("../agents/auth-profiles/store.js", () => ({ + updateAuthProfileStoreWithLock: (...args: unknown[]) => + mocks.updateAuthProfileStoreWithLock(...args), +})); + +vi.mock("../agents/memory-search.js", () => ({ + resolveMemorySearchConfig: (...args: unknown[]) => mocks.resolveMemorySearchConfig(...args), +})); + +vi.mock("../commands/models.js", () => ({ + modelsAuthLoginCommand: vi.fn(), + modelsStatusCommand: (...args: unknown[]) => mocks.modelsStatusCommand(...args), +})); + +vi.mock("../gateway/call.js", () => ({ + callGateway: (...args: unknown[]) => mocks.callGateway(...args), + randomIdempotencyKey: () => "run-1", +})); + +vi.mock("../gateway/connection-details.js", () => ({ + buildGatewayConnectionDetailsWithResolvers: vi.fn(() => ({ + url: "ws://127.0.0.1:18789", + urlSource: "local loopback", + message: "Gateway target: ws://127.0.0.1:18789", + })), +})); + +vi.mock("../media-understanding/runtime.js", () => ({ + describeImageFile: (...args: unknown[]) => mocks.describeImageFile(...args), + describeVideoFile: vi.fn(), + transcribeAudioFile: (...args: unknown[]) => mocks.transcribeAudioFile(...args), +})); + +vi.mock("../plugins/memory-embedding-providers.js", () => ({ + listMemoryEmbeddingProviders: (...args: unknown[]) => mocks.listMemoryEmbeddingProviders(...args), + registerMemoryEmbeddingProvider: (...args: unknown[]) => + mocks.registerMemoryEmbeddingProvider(...args), +})); + +vi.mock("../../extensions/memory-core/runtime-api.js", () => ({ + createEmbeddingProvider: (...args: unknown[]) => mocks.createEmbeddingProvider(...args), + registerBuiltInMemoryEmbeddingProviders: (...args: unknown[]) => + mocks.registerBuiltInMemoryEmbeddingProviders(...args), +})); + +vi.mock("../image-generation/runtime.js", () => ({ + generateImage: (...args: unknown[]) => mocks.generateImage(...args), + listRuntimeImageGenerationProviders: vi.fn(() => []), +})); + +vi.mock("../video-generation/runtime.js", () => ({ + generateVideo: vi.fn(), + listRuntimeVideoGenerationProviders: vi.fn(() => []), +})); + +vi.mock("../tts/tts.js", () => ({ + getTtsProvider: vi.fn(() => "openai"), + listSpeechVoices: vi.fn(async () => []), + resolveTtsConfig: vi.fn(() => ({})), + resolveTtsPrefsPath: vi.fn(() => "/tmp/tts.json"), + setTtsEnabled: vi.fn(), + setTtsProvider: (...args: unknown[]) => mocks.setTtsProvider(...args), + resolveExplicitTtsOverrides: (...args: unknown[]) => mocks.resolveExplicitTtsOverrides(...args), + textToSpeech: (...args: unknown[]) => mocks.textToSpeech(...args), +})); + +vi.mock("../tts/provider-registry.js", () => ({ + canonicalizeSpeechProviderId: vi.fn((provider: string) => provider), + listSpeechProviders: vi.fn(() => []), +})); + +vi.mock("../web-search/runtime.js", () => ({ + listWebSearchProviders: vi.fn(() => []), + isWebSearchProviderConfigured: (...args: unknown[]) => + mocks.isWebSearchProviderConfigured(...args), + runWebSearch: vi.fn(), +})); + +vi.mock("../web-fetch/runtime.js", () => ({ + listWebFetchProviders: vi.fn(() => []), + isWebFetchProviderConfigured: (...args: unknown[]) => mocks.isWebFetchProviderConfigured(...args), + resolveWebFetchDefinition: vi.fn(), +})); + +describe("capability cli", () => { + beforeEach(() => { + mocks.runtime.log.mockClear(); + mocks.runtime.error.mockClear(); + mocks.runtime.writeJson.mockClear(); + mocks.loadModelCatalog + .mockReset() + .mockResolvedValue([{ id: "gpt-5.4", provider: "openai", name: "GPT-5.4" }]); + mocks.loadAuthProfileStoreForRuntime.mockReset().mockReturnValue({ profiles: {}, order: {} }); + mocks.listProfilesForProvider.mockReset().mockReturnValue([]); + mocks.updateAuthProfileStoreWithLock + .mockReset() + .mockImplementation(async ({ updater }: { updater: (store: any) => boolean }) => { + const store = { + version: 1, + profiles: {}, + order: {}, + lastGood: {}, + usageStats: {}, + }; + updater(store); + return store; + }); + mocks.resolveMemorySearchConfig.mockReset().mockReturnValue(null); + mocks.agentCommand.mockClear(); + mocks.callGateway.mockClear().mockImplementation(async ({ method }: { method: string }) => { + if (method === "tts.status") { + return { enabled: true, provider: "openai" }; + } + if (method === "agent") { + return { + result: { + payloads: [{ text: "gateway reply" }], + meta: { agentMeta: { provider: "anthropic", model: "claude-sonnet-4-6" } }, + }, + }; + } + return {}; + }); + mocks.describeImageFile.mockClear(); + mocks.generateImage.mockReset(); + mocks.transcribeAudioFile.mockClear(); + mocks.textToSpeech.mockClear(); + mocks.setTtsProvider.mockClear(); + mocks.resolveExplicitTtsOverrides.mockClear(); + mocks.createEmbeddingProvider.mockClear(); + mocks.registerMemoryEmbeddingProvider.mockClear(); + mocks.registerBuiltInMemoryEmbeddingProviders.mockClear(); + mocks.isWebSearchProviderConfigured.mockReset().mockReturnValue(false); + mocks.isWebFetchProviderConfigured.mockReset().mockReturnValue(false); + mocks.modelsStatusCommand.mockClear(); + mocks.callGateway.mockImplementation(async ({ method }: { method: string }) => { + if (method === "tts.status") { + return { enabled: true, provider: "openai" }; + } + if (method === "tts.convert") { + return { + audioPath: "/tmp/gateway-tts.mp3", + provider: "openai", + outputFormat: "mp3", + voiceCompatible: false, + }; + } + if (method === "agent") { + return { + result: { + payloads: [{ text: "gateway reply" }], + meta: { agentMeta: { provider: "anthropic", model: "claude-sonnet-4-6" } }, + }, + }; + } + return {}; + }); + }); + + it("lists canonical capabilities", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "list", "--json"], + }); + + const payload = mocks.runtime.writeJson.mock.calls[0]?.[0] as Array<{ id: string }>; + expect(payload.some((entry) => entry.id === "model.run")).toBe(true); + expect(payload.some((entry) => entry.id === "image.describe")).toBe(true); + }); + + it("defaults model run to local transport", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "model", "run", "--prompt", "hello", "--json"], + }); + + expect(mocks.agentCommand).toHaveBeenCalledTimes(1); + expect(mocks.callGateway).not.toHaveBeenCalled(); + expect(mocks.runtime.writeJson).toHaveBeenCalledWith( + expect.objectContaining({ + capability: "model.run", + transport: "local", + }), + ); + }); + + it("defaults tts status to gateway transport", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "tts", "status", "--json"], + }); + + expect(mocks.callGateway).toHaveBeenCalledWith( + expect.objectContaining({ method: "tts.status" }), + ); + expect(mocks.runtime.writeJson).toHaveBeenCalledWith( + expect.objectContaining({ transport: "gateway" }), + ); + }); + + it("routes image describe through media understanding, not generation", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "image", "describe", "--file", "photo.jpg", "--json"], + }); + + expect(mocks.describeImageFile).toHaveBeenCalledWith( + expect.objectContaining({ filePath: expect.stringMatching(/photo\.jpg$/) }), + ); + expect(mocks.runtime.writeJson).toHaveBeenCalledWith( + expect.objectContaining({ + capability: "image.describe", + outputs: [expect.objectContaining({ kind: "image.description" })], + }), + ); + }); + + it("fails image describe when no description text is returned", async () => { + mocks.describeImageFile.mockResolvedValueOnce({ + text: undefined, + provider: undefined, + model: undefined, + }); + + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "image", "describe", "--file", "photo.jpg", "--json"], + }), + ).rejects.toThrow("exit 1"); + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringMatching(/No description returned for image/), + ); + }); + + it("rewrites mismatched explicit image output extensions to the detected file type", async () => { + const jpegBase64 = + "/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxAQEBUQEBAVFRUVFRUVFRUVFRUVFRUVFRUXFhUVFRUYHSggGBolHRUVITEhJSkrLi4uFx8zODMsNygtLisBCgoKDg0OGhAQGi0fHyUtLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLS0tLf/AABEIAAEAAQMBIgACEQEDEQH/xAAXAAEBAQEAAAAAAAAAAAAAAAAAAQID/8QAFhEBAQEAAAAAAAAAAAAAAAAAAAER/9oADAMBAAIQAxAAAAH2AP/EABgQAQEAAwAAAAAAAAAAAAAAAAEAEQIS/9oACAEBAAEFAk1o7//EABYRAQEBAAAAAAAAAAAAAAAAAAABEf/aAAgBAwEBPwGn/8QAFhEBAQEAAAAAAAAAAAAAAAAAABEB/9oACAECAQE/AYf/xAAaEAACAgMAAAAAAAAAAAAAAAABEQAhMUFh/9oACAEBAAY/AjK9cY2f/8QAGhABAQACAwAAAAAAAAAAAAAAAAERITFBUf/aAAgBAQABPyGQk7W5jVYkA//Z"; + mocks.generateImage.mockResolvedValue({ + provider: "openai", + model: "gpt-image-1", + attempts: [], + images: [ + { + buffer: Buffer.from(jpegBase64, "base64"), + mimeType: "image/png", + fileName: "provider-output.png", + }, + ], + }); + + const tempOutput = path.join(os.tmpdir(), `openclaw-image-mismatch-${Date.now()}.png`); + await fs.rm(tempOutput, { force: true }); + await fs.rm(tempOutput.replace(/\.png$/, ".jpg"), { force: true }); + + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "image", + "generate", + "--prompt", + "friendly lobster", + "--output", + tempOutput, + "--json", + ], + }); + + expect(mocks.runtime.writeJson).toHaveBeenCalledWith( + expect.objectContaining({ + outputs: [ + expect.objectContaining({ + path: tempOutput.replace(/\.png$/, ".jpg"), + mimeType: "image/jpeg", + }), + ], + }), + ); + }); + + it("routes audio transcribe through transcription, not realtime", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "audio", "transcribe", "--file", "memo.m4a", "--json"], + }); + + expect(mocks.transcribeAudioFile).toHaveBeenCalledWith( + expect.objectContaining({ filePath: expect.stringMatching(/memo\.m4a$/) }), + ); + expect(mocks.runtime.writeJson).toHaveBeenCalledWith( + expect.objectContaining({ + capability: "audio.transcribe", + outputs: [expect.objectContaining({ kind: "audio.transcription" })], + }), + ); + }); + + it("fails audio transcribe when no transcript text is returned", async () => { + mocks.transcribeAudioFile.mockResolvedValueOnce({ text: undefined }); + + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "audio", "transcribe", "--file", "memo.m4a", "--json"], + }), + ).rejects.toThrow("exit 1"); + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringMatching(/No transcript returned for audio/), + ); + }); + + it("forwards transcription prompt and language hints", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "audio", + "transcribe", + "--file", + "memo.m4a", + "--language", + "en", + "--prompt", + "Focus on names", + "--json", + ], + }); + + expect(mocks.transcribeAudioFile).toHaveBeenCalledWith( + expect.objectContaining({ + filePath: expect.stringMatching(/memo\.m4a$/), + language: "en", + prompt: "Focus on names", + }), + ); + }); + + it("uses request-scoped TTS overrides without mutating prefs", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "tts", + "convert", + "--text", + "hello", + "--model", + "openai/gpt-4o-mini-tts", + "--voice", + "alloy", + "--json", + ], + }); + + expect(mocks.textToSpeech).toHaveBeenCalledWith( + expect.objectContaining({ + overrides: expect.objectContaining({ + provider: "openai", + providerOverrides: expect.objectContaining({ + openai: expect.objectContaining({ + modelId: "gpt-4o-mini-tts", + voiceId: "alloy", + }), + }), + }), + }), + ); + expect(mocks.setTtsProvider).not.toHaveBeenCalled(); + }); + + it("disables TTS fallback when explicit provider or voice/model selection is requested", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "tts", + "convert", + "--text", + "hello", + "--model", + "openai/gpt-4o-mini-tts", + "--voice", + "alloy", + "--json", + ], + }); + + expect(mocks.textToSpeech).toHaveBeenCalledWith( + expect.objectContaining({ + disableFallback: true, + }), + ); + }); + + it("does not infer and forward a local provider guess for gateway TTS overrides", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "tts", + "convert", + "--gateway", + "--text", + "hello", + "--voice", + "alloy", + "--json", + ], + }); + + expect(mocks.callGateway).toHaveBeenCalledWith( + expect.objectContaining({ + method: "tts.convert", + params: expect.objectContaining({ + provider: undefined, + voiceId: "alloy", + }), + }), + ); + }); + + it("fails clearly when gateway TTS output is requested against a remote gateway", async () => { + const gatewayConnection = await import("../gateway/connection-details.js"); + vi.mocked(gatewayConnection.buildGatewayConnectionDetailsWithResolvers).mockReturnValueOnce({ + url: "wss://gateway.example.com", + urlSource: "config gateway.remote.url", + message: "Gateway target: wss://gateway.example.com", + }); + + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "tts", + "convert", + "--gateway", + "--text", + "hello", + "--output", + "hello.mp3", + "--json", + ], + }), + ).rejects.toThrow("exit 1"); + + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringContaining("--output is not supported for remote gateway TTS yet"), + ); + }); + + it("uses only embedding providers for embedding creation", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "embedding", "create", "--text", "hello", "--json"], + }); + + expect(mocks.createEmbeddingProvider).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "auto", + fallback: "none", + }), + ); + expect(mocks.runtime.writeJson).toHaveBeenCalledWith( + expect.objectContaining({ + capability: "embedding.create", + provider: "openai", + model: "text-embedding-3-small", + }), + ); + }); + + it("derives the embedding provider from a provider/model override", async () => { + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "embedding", + "create", + "--text", + "hello", + "--model", + "openai/text-embedding-3-large", + "--json", + ], + }); + + expect(mocks.createEmbeddingProvider).toHaveBeenCalledWith( + expect.objectContaining({ + provider: "openai", + fallback: "none", + model: "text-embedding-3-large", + }), + ); + }); + + it("cleans provider auth profiles and usage stats on logout", async () => { + mocks.loadAuthProfileStoreForRuntime.mockReturnValue({ + profiles: { + "openai:default": { id: "openai:default" }, + "openai:secondary": { id: "openai:secondary" }, + "anthropic:default": { id: "anthropic:default" }, + }, + order: { openai: ["openai:default", "openai:secondary"] }, + lastGood: { openai: "openai:secondary" }, + usageStats: { + "openai:default": { errorCount: 2 }, + "openai:secondary": { errorCount: 1 }, + "anthropic:default": { errorCount: 3 }, + }, + }); + mocks.listProfilesForProvider.mockReturnValue(["openai:default", "openai:secondary"]); + + let updatedStore: Record | null = null; + mocks.updateAuthProfileStoreWithLock.mockImplementationOnce( + async ({ updater }: { updater: (store: any) => boolean }) => { + const store = { + version: 1, + profiles: { + "openai:default": { id: "openai:default" }, + "openai:secondary": { id: "openai:secondary" }, + "anthropic:default": { id: "anthropic:default" }, + }, + order: { openai: ["openai:default", "openai:secondary"] }, + lastGood: { openai: "openai:secondary" }, + usageStats: { + "openai:default": { errorCount: 2 }, + "openai:secondary": { errorCount: 1 }, + "anthropic:default": { errorCount: 3 }, + }, + }; + updater(store); + updatedStore = store; + return store; + }, + ); + + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "model", "auth", "logout", "--provider", "openai", "--json"], + }); + + expect(updatedStore).toMatchObject({ + profiles: { + "anthropic:default": { id: "anthropic:default" }, + }, + order: {}, + lastGood: {}, + usageStats: { + "anthropic:default": { errorCount: 3 }, + }, + }); + expect(mocks.runtime.writeJson).toHaveBeenCalledWith({ + provider: "openai", + removedProfiles: ["openai:default", "openai:secondary"], + }); + }); + + it("fails logout if the auth store update does not complete", async () => { + mocks.listProfilesForProvider.mockReturnValue(["openai:default"]); + mocks.updateAuthProfileStoreWithLock.mockResolvedValueOnce(null); + + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "model", "auth", "logout", "--provider", "openai", "--json"], + }), + ).rejects.toThrow("exit 1"); + + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringContaining("Failed to remove saved auth profiles for provider openai."), + ); + }); + + it("rejects providerless audio model overrides", async () => { + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "audio", + "transcribe", + "--file", + "memo.m4a", + "--model", + "whisper-1", + "--json", + ], + }), + ).rejects.toThrow("exit 1"); + + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringContaining("Model overrides must use the form ."), + ); + expect(mocks.transcribeAudioFile).not.toHaveBeenCalled(); + }); + + it("rejects providerless image describe model overrides", async () => { + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "image", + "describe", + "--file", + "photo.jpg", + "--model", + "gpt-4.1-mini", + "--json", + ], + }), + ).rejects.toThrow("exit 1"); + + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringContaining("Model overrides must use the form ."), + ); + expect(mocks.describeImageFile).not.toHaveBeenCalled(); + }); + + it("rejects providerless video describe model overrides", async () => { + const mediaRuntime = await import("../media-understanding/runtime.js"); + vi.mocked(mediaRuntime.describeVideoFile).mockResolvedValue({ + text: "friendly lobster", + provider: "openai", + model: "gpt-4.1-mini", + } as never); + + await expect( + runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: [ + "capability", + "video", + "describe", + "--file", + "clip.mp4", + "--model", + "gpt-4.1-mini", + "--json", + ], + }), + ).rejects.toThrow("exit 1"); + + expect(mocks.runtime.error).toHaveBeenCalledWith( + expect.stringContaining("Model overrides must use the form ."), + ); + expect(vi.mocked(mediaRuntime.describeVideoFile)).not.toHaveBeenCalled(); + }); + + it("bootstraps built-in embedding providers when the registry is empty", async () => { + mocks.listMemoryEmbeddingProviders.mockReturnValueOnce([]); + + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "embedding", "providers", "--json"], + }); + + expect(mocks.registerBuiltInMemoryEmbeddingProviders).toHaveBeenCalledWith( + expect.objectContaining({ + registerMemoryEmbeddingProvider: expect.any(Function), + }), + ); + }); + + it("surfaces available, configured, and selected for web providers", async () => { + mocks.loadConfig.mockReturnValue({ + tools: { + web: { + search: { provider: "gemini" }, + fetch: { provider: "firecrawl" }, + }, + }, + }); + const webSearchRuntime = await import("../web-search/runtime.js"); + const webFetchRuntime = await import("../web-fetch/runtime.js"); + vi.mocked(webSearchRuntime.listWebSearchProviders).mockReturnValue([ + { id: "brave", envVars: ["BRAVE_API_KEY"] } as never, + { id: "gemini", envVars: ["GEMINI_API_KEY"] } as never, + ]); + vi.mocked(webFetchRuntime.listWebFetchProviders).mockReturnValue([ + { id: "firecrawl", envVars: ["FIRECRAWL_API_KEY"] } as never, + ]); + mocks.isWebSearchProviderConfigured.mockReturnValueOnce(false).mockReturnValueOnce(true); + mocks.isWebFetchProviderConfigured.mockReturnValueOnce(true); + + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "web", "providers", "--json"], + }); + + expect(mocks.runtime.writeJson).toHaveBeenCalledWith({ + search: [ + { + available: true, + configured: false, + selected: false, + id: "brave", + envVars: ["BRAVE_API_KEY"], + }, + { + available: true, + configured: true, + selected: true, + id: "gemini", + envVars: ["GEMINI_API_KEY"], + }, + ], + fetch: [ + { + available: true, + configured: true, + selected: true, + id: "firecrawl", + envVars: ["FIRECRAWL_API_KEY"], + }, + ], + }); + }); + + it("surfaces selected and configured embedding provider state", async () => { + mocks.loadConfig.mockReturnValue({}); + mocks.resolveMemorySearchConfig.mockReturnValue({ + provider: "gemini", + model: "gemini-embedding-001", + }); + mocks.listMemoryEmbeddingProviders.mockReturnValue([ + { id: "openai", defaultModel: "text-embedding-3-small", transport: "remote" }, + { id: "gemini", defaultModel: "gemini-embedding-001", transport: "remote" }, + ]); + + await runRegisteredCli({ + register: registerCapabilityCli as (program: Command) => void, + argv: ["capability", "embedding", "providers", "--json"], + }); + + expect(mocks.runtime.writeJson).toHaveBeenCalledWith([ + { + available: true, + configured: false, + selected: false, + id: "openai", + defaultModel: "text-embedding-3-small", + transport: "remote", + autoSelectPriority: undefined, + }, + { + available: true, + configured: true, + selected: true, + id: "gemini", + defaultModel: "gemini-embedding-001", + transport: "remote", + autoSelectPriority: undefined, + }, + ]); + }); +}); diff --git a/src/cli/capability-cli.ts b/src/cli/capability-cli.ts new file mode 100644 index 00000000000..a19e8b26588 --- /dev/null +++ b/src/cli/capability-cli.ts @@ -0,0 +1,1822 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import type { Command } from "commander"; +import { + createEmbeddingProvider, + registerBuiltInMemoryEmbeddingProviders, +} from "../../extensions/memory-core/runtime-api.js"; +import { agentCommand } from "../agents/agent-command.js"; +import { resolveAgentDir, resolveDefaultAgentId } from "../agents/agent-scope.js"; +import { + listProfilesForProvider, + loadAuthProfileStoreForRuntime, +} from "../agents/auth-profiles.js"; +import { updateAuthProfileStoreWithLock } from "../agents/auth-profiles/store.js"; +import { resolveMemorySearchConfig } from "../agents/memory-search.js"; +import { loadModelCatalog } from "../agents/model-catalog.js"; +import { modelsAuthLoginCommand, modelsStatusCommand } from "../commands/models.js"; +import { loadConfig } from "../config/config.js"; +import { callGateway, randomIdempotencyKey } from "../gateway/call.js"; +import { buildGatewayConnectionDetailsWithResolvers } from "../gateway/connection-details.js"; +import { isLoopbackHost } from "../gateway/net.js"; +import { generateImage, listRuntimeImageGenerationProviders } from "../image-generation/runtime.js"; +import { buildMediaUnderstandingRegistry } from "../media-understanding/provider-registry.js"; +import { + describeImageFile, + describeVideoFile, + transcribeAudioFile, +} from "../media-understanding/runtime.js"; +import { getImageMetadata } from "../media/image-ops.js"; +import { detectMime, extensionForMime, normalizeMimeType } from "../media/mime.js"; +import { saveMediaBuffer } from "../media/store.js"; +import { + listMemoryEmbeddingProviders, + registerMemoryEmbeddingProvider, +} from "../plugins/memory-embedding-providers.js"; +import { writeRuntimeJson, defaultRuntime, type RuntimeEnv } from "../runtime.js"; +import { formatDocsLink } from "../terminal/links.js"; +import { theme } from "../terminal/theme.js"; +import { canonicalizeSpeechProviderId, listSpeechProviders } from "../tts/provider-registry.js"; +import { + getTtsProvider, + listSpeechVoices, + resolveExplicitTtsOverrides, + resolveTtsConfig, + resolveTtsPrefsPath, + setTtsEnabled, + setTtsProvider, + textToSpeech, +} from "../tts/tts.js"; +import { GATEWAY_CLIENT_MODES, GATEWAY_CLIENT_NAMES } from "../utils/message-channel.js"; +import { generateVideo, listRuntimeVideoGenerationProviders } from "../video-generation/runtime.js"; +import { + isWebFetchProviderConfigured, + resolveWebFetchDefinition, + listWebFetchProviders, +} from "../web-fetch/runtime.js"; +import { + isWebSearchProviderConfigured, + listWebSearchProviders, + runWebSearch, +} from "../web-search/runtime.js"; +import { runCommandWithRuntime } from "./cli-utils.js"; +import { createDefaultDeps } from "./deps.js"; +import { collectOption } from "./program/helpers.js"; + +type CapabilityTransport = "local" | "gateway"; + +type CapabilityMetadata = { + id: string; + description: string; + transports: Array; + flags: string[]; + resultShape: string; +}; + +type CapabilityEnvelope = { + ok: boolean; + capability: string; + transport: CapabilityTransport; + provider?: string; + model?: string; + attempts: Array>; + outputs: Array>; + error?: string; +}; + +const CAPABILITY_METADATA: CapabilityMetadata[] = [ + { + id: "model.run", + description: "Run a one-shot text inference turn through the agent runtime.", + transports: ["local", "gateway"], + flags: ["--prompt", "--model", "--local", "--gateway", "--json"], + resultShape: "normalized payloads plus provider/model attribution", + }, + { + id: "model.list", + description: "List known models from the model catalog.", + transports: ["local"], + flags: ["--json"], + resultShape: "catalog entries", + }, + { + id: "model.inspect", + description: "Inspect one model catalog entry.", + transports: ["local"], + flags: ["--model", "--json"], + resultShape: "single catalog entry", + }, + { + id: "model.providers", + description: "List model providers discovered from the catalog.", + transports: ["local"], + flags: ["--json"], + resultShape: "provider ids with counts and defaults", + }, + { + id: "model.auth.login", + description: "Run the existing provider auth login flow.", + transports: ["local"], + flags: ["--provider"], + resultShape: "interactive auth result", + }, + { + id: "model.auth.logout", + description: "Remove saved auth profiles for one provider.", + transports: ["local"], + flags: ["--provider", "--json"], + resultShape: "removed profile ids", + }, + { + id: "model.auth.status", + description: "Show configured model auth state.", + transports: ["local"], + flags: ["--json"], + resultShape: "model status summary", + }, + { + id: "image.generate", + description: "Generate raster images with configured image providers.", + transports: ["local"], + flags: [ + "--prompt", + "--model", + "--count", + "--size", + "--aspect-ratio", + "--resolution", + "--output", + "--json", + ], + resultShape: "saved image files plus attempts", + }, + { + id: "image.edit", + description: "Generate edited images from one or more input files.", + transports: ["local"], + flags: ["--file", "--prompt", "--model", "--output", "--json"], + resultShape: "saved image files plus attempts", + }, + { + id: "image.describe", + description: "Describe one image file through media-understanding providers.", + transports: ["local"], + flags: ["--file", "--prompt", "--model", "--json"], + resultShape: "normalized text output", + }, + { + id: "image.describe-many", + description: "Describe multiple image files independently.", + transports: ["local"], + flags: ["--file", "--prompt", "--model", "--json"], + resultShape: "one text output per file", + }, + { + id: "image.providers", + description: "List image generation providers.", + transports: ["local"], + flags: ["--json"], + resultShape: "provider ids and defaults", + }, + { + id: "audio.transcribe", + description: "Transcribe one audio file.", + transports: ["local"], + flags: ["--file", "--model", "--json"], + resultShape: "normalized text output", + }, + { + id: "audio.providers", + description: "List audio transcription providers.", + transports: ["local"], + flags: ["--json"], + resultShape: "provider ids and capabilities", + }, + { + id: "tts.convert", + description: "Convert text to speech.", + transports: ["local", "gateway"], + flags: [ + "--text", + "--channel", + "--voice", + "--model", + "--output", + "--local", + "--gateway", + "--json", + ], + resultShape: "saved audio file plus attempts", + }, + { + id: "tts.voices", + description: "List voices for a speech provider.", + transports: ["local"], + flags: ["--provider", "--json"], + resultShape: "voice entries", + }, + { + id: "tts.providers", + description: "List speech providers.", + transports: ["local", "gateway"], + flags: ["--local", "--gateway", "--json"], + resultShape: "provider ids, configured state, models, voices", + }, + { + id: "tts.status", + description: "Show gateway-managed TTS state.", + transports: ["gateway"], + flags: ["--gateway", "--json"], + resultShape: "enabled/provider state", + }, + { + id: "tts.enable", + description: "Enable TTS in prefs.", + transports: ["local", "gateway"], + flags: ["--local", "--gateway", "--json"], + resultShape: "enabled state", + }, + { + id: "tts.disable", + description: "Disable TTS in prefs.", + transports: ["local", "gateway"], + flags: ["--local", "--gateway", "--json"], + resultShape: "enabled state", + }, + { + id: "tts.set-provider", + description: "Set the active TTS provider.", + transports: ["local", "gateway"], + flags: ["--provider", "--local", "--gateway", "--json"], + resultShape: "selected provider", + }, + { + id: "video.generate", + description: "Generate video files with configured video providers.", + transports: ["local"], + flags: ["--prompt", "--model", "--output", "--json"], + resultShape: "saved video files plus attempts", + }, + { + id: "video.describe", + description: "Describe one video file through media-understanding providers.", + transports: ["local"], + flags: ["--file", "--model", "--json"], + resultShape: "normalized text output", + }, + { + id: "video.providers", + description: "List video generation and description providers.", + transports: ["local"], + flags: ["--json"], + resultShape: "provider ids and defaults", + }, + { + id: "web.search", + description: "Run provider-backed web search.", + transports: ["local"], + flags: ["--query", "--provider", "--limit", "--json"], + resultShape: "search provider result", + }, + { + id: "web.fetch", + description: "Fetch URL content through configured web fetch providers.", + transports: ["local"], + flags: ["--url", "--provider", "--format", "--json"], + resultShape: "fetch provider result", + }, + { + id: "web.providers", + description: "List web search and fetch providers.", + transports: ["local"], + flags: ["--json"], + resultShape: "provider ids grouped by family", + }, + { + id: "embedding.create", + description: "Create embeddings through embedding providers.", + transports: ["local"], + flags: ["--text", "--provider", "--model", "--json"], + resultShape: "vectors with provider/model attribution", + }, + { + id: "embedding.providers", + description: "List embedding providers.", + transports: ["local"], + flags: ["--json"], + resultShape: "provider ids and default models", + }, +]; + +function findCapabilityMetadata(id: string): CapabilityMetadata | undefined { + return CAPABILITY_METADATA.find((entry) => entry.id === id); +} + +function resolveTransport(opts: { + local?: boolean; + gateway?: boolean; + supported: Array; + defaultTransport: CapabilityTransport; +}): CapabilityTransport { + if (opts.local && opts.gateway) { + throw new Error("Pass only one of --local or --gateway."); + } + if (opts.local) { + if (!opts.supported.includes("local")) { + throw new Error("This command does not support --local."); + } + return "local"; + } + if (opts.gateway) { + if (!opts.supported.includes("gateway")) { + throw new Error("This command does not support --gateway."); + } + return "gateway"; + } + return opts.defaultTransport; +} + +function emitJsonOrText( + runtime: RuntimeEnv, + json: boolean | undefined, + value: unknown, + textFormatter: (value: unknown) => string, +) { + if (json) { + writeRuntimeJson(runtime, value); + return; + } + runtime.log(textFormatter(value)); +} + +function formatEnvelopeForText(value: unknown): string { + const envelope = value as CapabilityEnvelope; + if (!envelope.ok) { + return `${envelope.capability} failed: ${envelope.error ?? "unknown error"}`; + } + const lines = [ + `${envelope.capability} via ${envelope.transport}`, + ...(envelope.provider ? [`provider: ${envelope.provider}`] : []), + ...(envelope.model ? [`model: ${envelope.model}`] : []), + `outputs: ${String(envelope.outputs.length)}`, + ]; + for (const output of envelope.outputs) { + const pathValue = typeof output.path === "string" ? output.path : undefined; + const textValue = typeof output.text === "string" ? output.text : undefined; + if (pathValue) { + lines.push(pathValue); + } else if (textValue) { + lines.push(textValue); + } else { + lines.push(JSON.stringify(output)); + } + } + return lines.join("\n"); +} + +function providerSummaryText(value: unknown): string { + const providers = value as Array>; + return providers.map((entry) => JSON.stringify(entry)).join("\n"); +} + +function hasOwnKeys(value: unknown): boolean { + return Boolean( + value && typeof value === "object" && Object.keys(value as Record).length > 0, + ); +} + +function resolveSelectedProviderFromModelRef(modelRef: string | undefined): string | undefined { + return resolveModelRefOverride(modelRef).provider; +} + +function getAuthProfileIdsForProvider( + cfg: ReturnType, + providerId: string, +): string[] { + const agentDir = resolveAgentDir(cfg, resolveDefaultAgentId(cfg)); + const store = loadAuthProfileStoreForRuntime(agentDir); + return listProfilesForProvider(store, providerId); +} + +function providerHasGenericConfig(params: { + cfg: ReturnType; + providerId: string; + envVars?: string[]; +}): boolean { + const modelsProviders = (params.cfg.models?.providers ?? {}) as Record; + const pluginEntries = (params.cfg.plugins?.entries ?? {}) as Record; + const ttsProviders = (params.cfg.messages?.tts?.providers ?? {}) as Record; + const envConfigured = (params.envVars ?? []).some((envVar) => + Boolean(process.env[envVar]?.trim()), + ); + return ( + getAuthProfileIdsForProvider(params.cfg, params.providerId).length > 0 || + hasOwnKeys(modelsProviders[params.providerId]) || + hasOwnKeys(pluginEntries[params.providerId]?.config) || + hasOwnKeys(ttsProviders[params.providerId]) || + envConfigured + ); +} + +async function writeOutputAsset(params: { + buffer: Buffer; + mimeType?: string; + originalFilename?: string; + outputPath?: string; + outputIndex: number; + outputCount: number; + subdir: string; +}) { + if (!params.outputPath) { + const saved = await saveMediaBuffer( + params.buffer, + params.mimeType, + params.subdir, + Number.MAX_SAFE_INTEGER, + params.originalFilename, + ); + return { path: saved.path, mimeType: saved.contentType, size: saved.size }; + } + + const resolvedOutput = path.resolve(params.outputPath); + const parsed = path.parse(resolvedOutput); + const detectedMime = + (await detectMime({ + buffer: params.buffer, + headerMime: params.mimeType, + })) ?? params.mimeType; + const requestedMime = normalizeMimeType(await detectMime({ filePath: resolvedOutput })); + const detectedNormalized = normalizeMimeType(detectedMime); + const canonicalDetectedExt = extensionForMime(detectedNormalized); + const fallbackExt = parsed.ext || path.extname(params.originalFilename ?? "") || ""; + const ext = + parsed.ext && requestedMime === detectedNormalized + ? parsed.ext + : (canonicalDetectedExt ?? fallbackExt); + const filePath = + params.outputCount <= 1 + ? path.join(parsed.dir, `${parsed.name}${ext}`) + : path.join(parsed.dir, `${parsed.name}-${String(params.outputIndex + 1)}${ext}`); + await fs.mkdir(path.dirname(filePath), { recursive: true }); + await fs.writeFile(filePath, params.buffer); + return { + path: filePath, + mimeType: detectedNormalized ?? params.mimeType, + size: params.buffer.byteLength, + }; +} + +async function readInputFiles(files: string[]): Promise> { + return await Promise.all( + files.map(async (filePath) => ({ + path: path.resolve(filePath), + buffer: await fs.readFile(path.resolve(filePath)), + })), + ); +} + +function resolveModelRefOverride(raw: string | undefined): { provider?: string; model?: string } { + const trimmed = raw?.trim(); + if (!trimmed) { + return {}; + } + const slash = trimmed.indexOf("/"); + if (slash <= 0 || slash === trimmed.length - 1) { + return { model: trimmed }; + } + return { + provider: trimmed.slice(0, slash), + model: trimmed.slice(slash + 1), + }; +} + +function requireProviderModelOverride( + raw: string | undefined, +): { provider: string; model: string } | undefined { + const resolved = resolveModelRefOverride(raw); + if (!raw?.trim()) { + return undefined; + } + if (!resolved.provider || !resolved.model) { + throw new Error("Model overrides must use the form ."); + } + return { + provider: resolved.provider, + model: resolved.model, + }; +} + +async function runModelRun(params: { + prompt: string; + model?: string; + transport: CapabilityTransport; +}) { + const cfg = loadConfig(); + const agentId = resolveDefaultAgentId(cfg); + if (params.transport === "local") { + const result = await agentCommand( + { + message: params.prompt, + agentId, + model: params.model, + json: false, + }, + { + ...defaultRuntime, + log: () => {}, + }, + createDefaultDeps(), + ); + return { + ok: true, + capability: "model.run", + transport: "local" as const, + provider: result?.meta?.agentMeta?.provider, + model: result?.meta?.agentMeta?.model, + attempts: [], + outputs: (result?.payloads ?? []).map((payload) => ({ + text: payload.text, + mediaUrl: payload.mediaUrl, + mediaUrls: payload.mediaUrls, + })), + } satisfies CapabilityEnvelope; + } + + const { provider, model } = resolveModelRefOverride(params.model); + const response = await callGateway<{ + result?: { + payloads?: Array<{ text?: string; mediaUrl?: string | null; mediaUrls?: string[] }>; + meta?: { agentMeta?: { provider?: string; model?: string } }; + }; + }>({ + method: "agent", + params: { + agentId, + message: params.prompt, + provider, + model, + idempotencyKey: randomIdempotencyKey(), + }, + expectFinal: true, + timeoutMs: 120_000, + clientName: GATEWAY_CLIENT_NAMES.CLI, + mode: GATEWAY_CLIENT_MODES.CLI, + }); + return { + ok: true, + capability: "model.run", + transport: "gateway" as const, + provider: response?.result?.meta?.agentMeta?.provider, + model: response?.result?.meta?.agentMeta?.model, + attempts: [], + outputs: (response?.result?.payloads ?? []).map((payload) => ({ + text: payload.text, + mediaUrl: payload.mediaUrl, + mediaUrls: payload.mediaUrls, + })), + } satisfies CapabilityEnvelope; +} + +async function buildModelProviders() { + const cfg = loadConfig(); + const catalog = await loadModelCatalog({ config: cfg }); + const selectedProvider = resolveSelectedProviderFromModelRef( + cfg.agents?.defaults?.model?.primary, + ); + const grouped = new Map< + string, + { + provider: string; + count: number; + defaults: string[]; + available: boolean; + configured: boolean; + selected: boolean; + } + >(); + for (const entry of catalog) { + const current = grouped.get(entry.provider) ?? { + provider: entry.provider, + count: 0, + defaults: [], + available: true, + configured: providerHasGenericConfig({ cfg, providerId: entry.provider }), + selected: selectedProvider === entry.provider, + }; + current.count += 1; + if (current.defaults.length < 3) { + current.defaults.push(entry.id); + } + grouped.set(entry.provider, current); + } + return [...grouped.values()].toSorted((a, b) => a.provider.localeCompare(b.provider)); +} + +async function runModelAuthStatus() { + const captured: string[] = []; + await modelsStatusCommand( + { json: true }, + { + log: (...args) => captured.push(args.join(" ")), + error: (message) => { + throw new Error(message); + }, + exit: (code) => { + throw new Error(`exit ${code}`); + }, + }, + ); + const raw = captured.find((line) => line.trim().startsWith("{")); + return raw ? (JSON.parse(raw) as Record) : {}; +} + +async function runModelAuthLogout(provider: string) { + const cfg = loadConfig(); + const agentDir = resolveAgentDir(cfg, resolveDefaultAgentId(cfg)); + const store = loadAuthProfileStoreForRuntime(agentDir); + const profileIds = listProfilesForProvider(store, provider); + const updated = await updateAuthProfileStoreWithLock({ + agentDir, + updater: (nextStore) => { + let changed = false; + for (const profileId of profileIds) { + if (nextStore.profiles[profileId]) { + delete nextStore.profiles[profileId]; + changed = true; + } + if (nextStore.usageStats?.[profileId]) { + delete nextStore.usageStats[profileId]; + changed = true; + } + } + if (nextStore.order?.[provider]) { + delete nextStore.order[provider]; + changed = true; + } + if (nextStore.lastGood?.[provider]) { + delete nextStore.lastGood[provider]; + changed = true; + } + return changed; + }, + }); + if (!updated) { + throw new Error(`Failed to remove saved auth profiles for provider ${provider}.`); + } + return { + provider, + removedProfiles: profileIds, + }; +} + +async function runImageGenerate(params: { + capability: "image.generate" | "image.edit"; + prompt: string; + model?: string; + count?: number; + size?: string; + aspectRatio?: string; + resolution?: "1K" | "2K" | "4K"; + file?: string[]; + output?: string; +}) { + const cfg = loadConfig(); + const agentDir = resolveAgentDir(cfg, resolveDefaultAgentId(cfg)); + const inputImages = + params.file && params.file.length > 0 + ? await Promise.all( + (await readInputFiles(params.file)).map(async (entry) => ({ + buffer: entry.buffer, + fileName: path.basename(entry.path), + mimeType: + (await detectMime({ buffer: entry.buffer, filePath: entry.path })) ?? "image/png", + })), + ) + : undefined; + const result = await generateImage({ + cfg, + agentDir, + prompt: params.prompt, + modelOverride: params.model, + count: params.count, + size: params.size, + aspectRatio: params.aspectRatio, + resolution: params.resolution, + inputImages, + }); + const outputs = await Promise.all( + result.images.map(async (image, index) => { + const written = await writeOutputAsset({ + buffer: image.buffer, + mimeType: image.mimeType, + originalFilename: image.fileName, + outputPath: params.output, + outputIndex: index, + outputCount: result.images.length, + subdir: "generated", + }); + const metadata = await getImageMetadata(written.path).catch(() => undefined); + return { + ...written, + width: metadata?.width, + height: metadata?.height, + revisedPrompt: image.revisedPrompt, + }; + }), + ); + return { + ok: true, + capability: params.capability, + transport: "local" as const, + provider: result.provider, + model: result.model, + attempts: result.attempts, + outputs, + } satisfies CapabilityEnvelope; +} + +async function runImageDescribe(params: { + capability: "image.describe" | "image.describe-many"; + files: string[]; + model?: string; +}) { + const cfg = loadConfig(); + const activeModel = requireProviderModelOverride(params.model); + const outputs = await Promise.all( + params.files.map(async (filePath) => { + const result = await describeImageFile({ + filePath: path.resolve(filePath), + cfg, + activeModel, + }); + if (!result.text) { + throw new Error(`No description returned for image: ${path.resolve(filePath)}`); + } + return { + path: path.resolve(filePath), + text: result.text, + provider: result.provider, + model: result.model, + kind: "image.description", + }; + }), + ); + return { + ok: true, + capability: params.capability, + transport: "local" as const, + provider: outputs[0]?.provider, + model: outputs[0]?.model, + attempts: [], + outputs, + } satisfies CapabilityEnvelope; +} + +async function runAudioTranscribe(params: { + file: string; + language?: string; + model?: string; + prompt?: string; +}) { + const cfg = loadConfig(); + const activeModel = requireProviderModelOverride(params.model); + const result = await transcribeAudioFile({ + filePath: path.resolve(params.file), + cfg, + language: params.language, + activeModel, + prompt: params.prompt, + }); + if (!result.text) { + throw new Error(`No transcript returned for audio: ${path.resolve(params.file)}`); + } + return { + ok: true, + capability: "audio.transcribe", + transport: "local" as const, + attempts: [], + outputs: [{ path: path.resolve(params.file), text: result.text, kind: "audio.transcription" }], + } satisfies CapabilityEnvelope; +} + +async function runVideoGenerate(params: { prompt: string; model?: string; output?: string }) { + const cfg = loadConfig(); + const agentDir = resolveAgentDir(cfg, resolveDefaultAgentId(cfg)); + const result = await generateVideo({ + cfg, + agentDir, + prompt: params.prompt, + modelOverride: params.model, + }); + const outputs = await Promise.all( + result.videos.map(async (video, index) => ({ + ...(await writeOutputAsset({ + buffer: video.buffer, + mimeType: video.mimeType, + originalFilename: video.fileName, + outputPath: params.output, + outputIndex: index, + outputCount: result.videos.length, + subdir: "generated", + })), + })), + ); + return { + ok: true, + capability: "video.generate", + transport: "local" as const, + provider: result.provider, + model: result.model, + attempts: result.attempts, + outputs, + } satisfies CapabilityEnvelope; +} + +async function runVideoDescribe(params: { file: string; model?: string }) { + const cfg = loadConfig(); + const activeModel = requireProviderModelOverride(params.model); + const result = await describeVideoFile({ + filePath: path.resolve(params.file), + cfg, + activeModel, + }); + if (!result.text) { + throw new Error(`No description returned for video: ${path.resolve(params.file)}`); + } + return { + ok: true, + capability: "video.describe", + transport: "local" as const, + provider: result.provider, + model: result.model, + attempts: [], + outputs: [{ path: path.resolve(params.file), text: result.text, kind: "video.description" }], + } satisfies CapabilityEnvelope; +} + +async function runTtsConvert(params: { + text: string; + channel?: string; + provider?: string; + modelId?: string; + voiceId?: string; + output?: string; + transport: CapabilityTransport; +}) { + if (params.transport === "gateway") { + const gatewayConnection = buildGatewayConnectionDetailsWithResolvers({ config: loadConfig() }); + const result = await callGateway<{ + audioPath?: string; + provider?: string; + outputFormat?: string; + voiceCompatible?: boolean; + }>({ + method: "tts.convert", + params: { + text: params.text, + channel: params.channel, + provider: params.provider?.trim() || undefined, + modelId: params.modelId, + voiceId: params.voiceId, + }, + timeoutMs: 120_000, + }); + let outputPath = result.audioPath; + if (params.output && result.audioPath) { + const gatewayHost = new URL(gatewayConnection.url).hostname; + if (!isLoopbackHost(gatewayHost)) { + throw new Error( + `--output is not supported for remote gateway TTS yet (gateway target: ${gatewayConnection.url}).`, + ); + } + const target = path.resolve(params.output); + await fs.mkdir(path.dirname(target), { recursive: true }); + await fs.copyFile(result.audioPath, target); + outputPath = target; + } + return { + ok: true, + capability: "tts.convert", + transport: "gateway" as const, + provider: result.provider, + attempts: [], + outputs: [ + { + path: outputPath, + format: result.outputFormat, + voiceCompatible: result.voiceCompatible, + }, + ], + } satisfies CapabilityEnvelope; + } + + const cfg = loadConfig(); + const overrides = resolveExplicitTtsOverrides({ + cfg, + provider: params.provider, + modelId: params.modelId, + voiceId: params.voiceId, + }); + const hasExplicitSelection = Boolean( + overrides.provider || params.modelId?.trim() || params.voiceId?.trim(), + ); + const result = await textToSpeech({ + text: params.text, + cfg, + channel: params.channel, + overrides, + disableFallback: hasExplicitSelection, + }); + if (!result.success || !result.audioPath) { + throw new Error(result.error ?? "TTS conversion failed"); + } + let outputPath = result.audioPath; + if (params.output) { + const target = path.resolve(params.output); + await fs.mkdir(path.dirname(target), { recursive: true }); + await fs.copyFile(result.audioPath, target); + outputPath = target; + } + return { + ok: true, + capability: "tts.convert", + transport: "local" as const, + provider: result.provider, + attempts: result.attempts ?? [], + outputs: [ + { + path: outputPath, + format: result.outputFormat, + voiceCompatible: result.voiceCompatible, + }, + ], + } satisfies CapabilityEnvelope; +} + +async function runTtsProviders(transport: CapabilityTransport) { + const cfg = loadConfig(); + if (transport === "gateway") { + const payload = await callGateway<{ + providers?: Array>; + active?: string; + }>({ + method: "tts.providers", + timeoutMs: 30_000, + }); + return { + ...payload, + providers: (payload.providers ?? []).map((provider) => { + const id = typeof provider.id === "string" ? provider.id : ""; + return { + available: true, + configured: + typeof provider.configured === "boolean" + ? provider.configured + : providerHasGenericConfig({ cfg, providerId: id }), + selected: Boolean(id && payload.active === id), + ...provider, + }; + }), + }; + } + const config = resolveTtsConfig(cfg); + const prefsPath = resolveTtsPrefsPath(config); + const active = getTtsProvider(config, prefsPath); + return { + providers: listSpeechProviders(cfg).map((provider) => ({ + available: true, + configured: + active === provider.id || providerHasGenericConfig({ cfg, providerId: provider.id }), + selected: active === provider.id, + id: provider.id, + name: provider.label, + models: [...(provider.models ?? [])], + voices: [...(provider.voices ?? [])], + })), + active, + }; +} + +async function runTtsVoices(providerRaw?: string) { + const cfg = loadConfig(); + const config = resolveTtsConfig(cfg); + const prefsPath = resolveTtsPrefsPath(config); + const provider = providerRaw?.trim() || getTtsProvider(config, prefsPath); + return await listSpeechVoices({ + provider, + cfg, + config, + }); +} + +async function runTtsStateMutation(params: { + capability: "tts.enable" | "tts.disable" | "tts.set-provider"; + transport: CapabilityTransport; + provider?: string; +}) { + if (params.transport === "gateway") { + const method = + params.capability === "tts.enable" + ? "tts.enable" + : params.capability === "tts.disable" + ? "tts.disable" + : "tts.setProvider"; + const payload = await callGateway({ + method, + params: params.provider ? { provider: params.provider } : undefined, + timeoutMs: 30_000, + }); + return payload; + } + + const cfg = loadConfig(); + const config = resolveTtsConfig(cfg); + const prefsPath = resolveTtsPrefsPath(config); + if (params.capability === "tts.enable") { + setTtsEnabled(prefsPath, true); + return { enabled: true }; + } + if (params.capability === "tts.disable") { + setTtsEnabled(prefsPath, false); + return { enabled: false }; + } + if (!params.provider) { + throw new Error("--provider is required"); + } + const provider = canonicalizeSpeechProviderId(params.provider, cfg); + if (!provider) { + throw new Error(`Unknown speech provider: ${params.provider}`); + } + setTtsProvider(prefsPath, provider); + return { provider }; +} + +async function runWebSearchCommand(params: { query: string; provider?: string; limit?: number }) { + const cfg = loadConfig(); + const result = await runWebSearch({ + config: cfg, + providerId: params.provider, + args: { + query: params.query, + count: params.limit, + limit: params.limit, + }, + }); + return { + ok: true, + capability: "web.search", + transport: "local" as const, + provider: result.provider, + attempts: [], + outputs: [{ result: result.result }], + } satisfies CapabilityEnvelope; +} + +async function runWebFetchCommand(params: { url: string; provider?: string; format?: string }) { + const cfg = loadConfig(); + const resolved = resolveWebFetchDefinition({ + config: cfg, + providerId: params.provider, + }); + if (!resolved) { + throw new Error("web.fetch is disabled or no provider is available."); + } + const result = await resolved.definition.execute({ + url: params.url, + format: params.format, + }); + return { + ok: true, + capability: "web.fetch", + transport: "local" as const, + provider: resolved.provider.id, + attempts: [], + outputs: [{ result }], + } satisfies CapabilityEnvelope; +} + +async function runMemoryEmbeddingCreate(params: { + texts: string[]; + provider?: string; + model?: string; +}) { + ensureMemoryEmbeddingProvidersRegistered(); + const cfg = loadConfig(); + const modelRef = resolveModelRefOverride(params.model); + const requestedProvider = params.provider?.trim() || modelRef.provider || "auto"; + const result = await createEmbeddingProvider({ + config: cfg, + agentDir: resolveAgentDir(cfg, resolveDefaultAgentId(cfg)), + provider: requestedProvider, + fallback: "none", + model: modelRef.model ?? "", + }); + if (!result.provider) { + throw new Error(result.providerUnavailableReason ?? "No embedding provider available."); + } + const embeddings = await result.provider.embedBatch(params.texts); + return { + ok: true, + capability: "embedding.create", + transport: "local" as const, + provider: result.provider.id, + model: result.provider.model, + attempts: result.fallbackFrom + ? [{ provider: result.fallbackFrom, outcome: "failed", error: result.fallbackReason }] + : [], + outputs: embeddings.map((embedding, index) => ({ + text: params.texts[index], + embedding, + dimensions: embedding.length, + })), + } satisfies CapabilityEnvelope; +} + +function ensureMemoryEmbeddingProvidersRegistered(): void { + if (listMemoryEmbeddingProviders().length > 0) { + return; + } + registerBuiltInMemoryEmbeddingProviders({ + registerMemoryEmbeddingProvider, + }); +} + +function registerCapabilityListAndInspect(capability: Command) { + capability + .command("list") + .description("List canonical capability ids and supported transports") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = CAPABILITY_METADATA.map((entry) => ({ + id: entry.id, + transports: entry.transports, + description: entry.description, + })); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, providerSummaryText); + }); + }); + + capability + .command("inspect") + .description("Inspect one canonical capability id") + .requiredOption("--name ", "Capability id") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const entry = findCapabilityMetadata(String(opts.name)); + if (!entry) { + throw new Error(`Unknown capability: ${String(opts.name)}`); + } + emitJsonOrText(defaultRuntime, Boolean(opts.json), entry, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); +} + +export function registerCapabilityCli(program: Command) { + const capability = program + .command("infer") + .alias("capability") + .description("Run provider-backed inference commands through a stable CLI surface") + .addHelpText( + "after", + () => + `\n${theme.muted("Docs:")} ${formatDocsLink("/cli/capability", "docs.openclaw.ai/cli/capability")}\n`, + ); + + registerCapabilityListAndInspect(capability); + + const model = capability + .command("model") + .description("Text inference and model catalog commands"); + + model + .command("run") + .description("Run a one-shot model turn") + .requiredOption("--prompt ", "Prompt text") + .option("--model ", "Model override") + .option("--local", "Force local execution", false) + .option("--gateway", "Force gateway execution", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const transport = resolveTransport({ + local: Boolean(opts.local), + gateway: Boolean(opts.gateway), + supported: ["local", "gateway"], + defaultTransport: "local", + }); + const result = await runModelRun({ + prompt: String(opts.prompt), + model: opts.model as string | undefined, + transport, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + model + .command("list") + .description("List known models") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await loadModelCatalog({ config: loadConfig() }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, providerSummaryText); + }); + }); + + model + .command("inspect") + .description("Inspect one model catalog entry") + .requiredOption("--model ", "Model id") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const target = String(opts.model).trim(); + const catalog = await loadModelCatalog({ config: loadConfig() }); + const entry = + catalog.find((candidate) => `${candidate.provider}/${candidate.id}` === target) ?? + catalog.find((candidate) => candidate.id === target); + if (!entry) { + throw new Error(`Model not found: ${target}`); + } + emitJsonOrText(defaultRuntime, Boolean(opts.json), entry, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + model + .command("providers") + .description("List model providers from the catalog") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await buildModelProviders(); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, providerSummaryText); + }); + }); + + const modelAuth = model.command("auth").description("Provider auth helpers"); + + modelAuth + .command("login") + .description("Run provider auth login") + .requiredOption("--provider ", "Provider id") + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + await modelsAuthLoginCommand({ provider: String(opts.provider) }, defaultRuntime); + }); + }); + + modelAuth + .command("logout") + .description("Remove saved auth profiles for one provider") + .requiredOption("--provider ", "Provider id") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runModelAuthLogout(String(opts.provider)); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + modelAuth + .command("status") + .description("Show configured auth state") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runModelAuthStatus(); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + const image = capability.command("image").description("Image generation and description"); + + image + .command("generate") + .description("Generate images") + .requiredOption("--prompt ", "Prompt text") + .option("--model ", "Model override") + .option("--count ", "Number of images") + .option("--size ", "Size hint like 1024x1024") + .option("--aspect-ratio ", "Aspect ratio hint like 16:9") + .option("--resolution ", "Resolution hint: 1K, 2K, or 4K") + .option("--output ", "Output path") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runImageGenerate({ + capability: "image.generate", + prompt: String(opts.prompt), + model: opts.model as string | undefined, + count: opts.count ? Number.parseInt(String(opts.count), 10) : undefined, + size: opts.size as string | undefined, + aspectRatio: opts.aspectRatio as string | undefined, + resolution: opts.resolution as "1K" | "2K" | "4K" | undefined, + output: opts.output as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + image + .command("edit") + .description("Edit images with one or more input files") + .requiredOption("--file ", "Input file", collectOption, []) + .requiredOption("--prompt ", "Prompt text") + .option("--model ", "Model override") + .option("--output ", "Output path") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const files = Array.isArray(opts.file) ? (opts.file as string[]) : [String(opts.file)]; + const result = await runImageGenerate({ + capability: "image.edit", + prompt: String(opts.prompt), + model: opts.model as string | undefined, + file: files, + output: opts.output as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + image + .command("describe") + .description("Describe one image file") + .requiredOption("--file ", "Image file") + .option("--model ", "Model override") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runImageDescribe({ + capability: "image.describe", + files: [String(opts.file)], + model: opts.model as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + image + .command("describe-many") + .description("Describe multiple image files") + .requiredOption("--file ", "Image file", collectOption, []) + .option("--model ", "Model override") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runImageDescribe({ + capability: "image.describe-many", + files: opts.file as string[], + model: opts.model as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + image + .command("providers") + .description("List image generation providers") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const cfg = loadConfig(); + const selectedProvider = resolveSelectedProviderFromModelRef( + cfg.agents?.defaults?.imageGenerationModel?.primary, + ); + const result = listRuntimeImageGenerationProviders({ config: cfg }).map((provider) => ({ + available: true, + configured: + selectedProvider === provider.id || + providerHasGenericConfig({ cfg, providerId: provider.id }), + selected: selectedProvider === provider.id, + id: provider.id, + label: provider.label, + defaultModel: provider.defaultModel, + models: provider.models ?? [], + capabilities: provider.capabilities, + })); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, providerSummaryText); + }); + }); + + const audio = capability.command("audio").description("Audio transcription"); + + audio + .command("transcribe") + .description("Transcribe one audio file") + .requiredOption("--file ", "Audio file") + .option("--language ", "Language hint") + .option("--prompt ", "Prompt hint") + .option("--model ", "Model override") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runAudioTranscribe({ + file: String(opts.file), + language: opts.language as string | undefined, + model: opts.model as string | undefined, + prompt: opts.prompt as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + audio + .command("providers") + .description("List audio transcription providers") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const cfg = loadConfig(); + const providers = [...buildMediaUnderstandingRegistry(undefined, cfg).values()] + .filter((provider) => provider.capabilities?.includes("audio")) + .map((provider) => ({ + available: true, + configured: providerHasGenericConfig({ cfg, providerId: provider.id }), + selected: false, + id: provider.id, + capabilities: provider.capabilities, + defaultModels: provider.defaultModels, + })); + emitJsonOrText(defaultRuntime, Boolean(opts.json), providers, providerSummaryText); + }); + }); + + const tts = capability.command("tts").description("Text to speech"); + + tts + .command("convert") + .description("Convert text to speech") + .requiredOption("--text ", "Input text") + .option("--channel ", "Channel hint") + .option("--voice ", "Voice hint") + .option("--model ", "Model override") + .option("--output ", "Output path") + .option("--local", "Force local execution", false) + .option("--gateway", "Force gateway execution", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const transport = resolveTransport({ + local: Boolean(opts.local), + gateway: Boolean(opts.gateway), + supported: ["local", "gateway"], + defaultTransport: "local", + }); + const modelRef = resolveModelRefOverride(opts.model as string | undefined); + if (opts.model && !modelRef.provider) { + throw new Error("TTS model overrides must use the form ."); + } + const result = await runTtsConvert({ + text: String(opts.text), + channel: opts.channel as string | undefined, + provider: modelRef.provider, + modelId: modelRef.provider ? modelRef.model : undefined, + voiceId: opts.voice as string | undefined, + output: opts.output as string | undefined, + transport, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + tts + .command("voices") + .description("List voices for a TTS provider") + .option("--provider ", "Speech provider id") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const voices = await runTtsVoices(opts.provider as string | undefined); + emitJsonOrText(defaultRuntime, Boolean(opts.json), voices, providerSummaryText); + }); + }); + + tts + .command("providers") + .description("List speech providers") + .option("--local", "Force local execution", false) + .option("--gateway", "Force gateway execution", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const transport = resolveTransport({ + local: Boolean(opts.local), + gateway: Boolean(opts.gateway), + supported: ["local", "gateway"], + defaultTransport: "local", + }); + const result = await runTtsProviders(transport); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + tts + .command("status") + .description("Show TTS status") + .option("--gateway", "Force gateway execution", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const transport = resolveTransport({ + gateway: Boolean(opts.gateway), + supported: ["gateway"], + defaultTransport: "gateway", + }); + const result = await callGateway({ + method: "tts.status", + timeoutMs: 30_000, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), { transport, ...result }, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + for (const [commandName, capabilityId] of [ + ["enable", "tts.enable"], + ["disable", "tts.disable"], + ] as const) { + tts + .command(commandName) + .description(`${commandName === "enable" ? "Enable" : "Disable"} TTS`) + .option("--local", "Force local execution", false) + .option("--gateway", "Force gateway execution", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const transport = resolveTransport({ + local: Boolean(opts.local), + gateway: Boolean(opts.gateway), + supported: ["local", "gateway"], + defaultTransport: "gateway", + }); + const result = await runTtsStateMutation({ + capability: capabilityId, + transport, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + } + + tts + .command("set-provider") + .description("Set the active TTS provider") + .requiredOption("--provider ", "Speech provider id") + .option("--local", "Force local execution", false) + .option("--gateway", "Force gateway execution", false) + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const transport = resolveTransport({ + local: Boolean(opts.local), + gateway: Boolean(opts.gateway), + supported: ["local", "gateway"], + defaultTransport: "gateway", + }); + const result = await runTtsStateMutation({ + capability: "tts.set-provider", + provider: String(opts.provider), + transport, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + const video = capability.command("video").description("Video generation and description"); + + video + .command("generate") + .description("Generate video") + .requiredOption("--prompt ", "Prompt text") + .option("--model ", "Model override") + .option("--output ", "Output path") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runVideoGenerate({ + prompt: String(opts.prompt), + model: opts.model as string | undefined, + output: opts.output as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + video + .command("describe") + .description("Describe one video file") + .requiredOption("--file ", "Video file") + .option("--model ", "Model override") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runVideoDescribe({ + file: String(opts.file), + model: opts.model as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + video + .command("providers") + .description("List video generation and description providers") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const cfg = loadConfig(); + const selectedGenerationProvider = resolveSelectedProviderFromModelRef( + cfg.agents?.defaults?.videoGenerationModel?.primary, + ); + const result = { + generation: listRuntimeVideoGenerationProviders({ config: cfg }).map((provider) => ({ + available: true, + configured: + selectedGenerationProvider === provider.id || + providerHasGenericConfig({ cfg, providerId: provider.id }), + selected: selectedGenerationProvider === provider.id, + id: provider.id, + label: provider.label, + defaultModel: provider.defaultModel, + models: provider.models ?? [], + capabilities: provider.capabilities, + })), + description: [...buildMediaUnderstandingRegistry(undefined, cfg).values()] + .filter((provider) => provider.capabilities?.includes("video")) + .map((provider) => ({ + available: true, + configured: providerHasGenericConfig({ cfg, providerId: provider.id }), + selected: false, + id: provider.id, + capabilities: provider.capabilities, + defaultModels: provider.defaultModels, + })), + }; + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + const web = capability.command("web").description("Web capabilities"); + + web + .command("search") + .description("Run web search") + .requiredOption("--query ", "Search query") + .option("--provider ", "Provider id") + .option("--limit ", "Result limit") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runWebSearchCommand({ + query: String(opts.query), + provider: opts.provider as string | undefined, + limit: opts.limit ? Number.parseInt(String(opts.limit), 10) : undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + web + .command("fetch") + .description("Fetch one URL") + .requiredOption("--url ", "URL") + .option("--provider ", "Provider id") + .option("--format ", "Format hint") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runWebFetchCommand({ + url: String(opts.url), + provider: opts.provider as string | undefined, + format: opts.format as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + web + .command("providers") + .description("List web providers") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const cfg = loadConfig(); + const selectedSearchProvider = + typeof cfg.tools?.web?.search?.provider === "string" + ? cfg.tools.web.search.provider.trim().toLowerCase() + : ""; + const selectedFetchProvider = + typeof cfg.tools?.web?.fetch?.provider === "string" + ? cfg.tools.web.fetch.provider.trim().toLowerCase() + : ""; + const result = { + search: listWebSearchProviders({ config: cfg }).map((provider) => ({ + available: true, + configured: isWebSearchProviderConfigured({ provider, config: cfg }), + selected: provider.id === selectedSearchProvider, + id: provider.id, + envVars: provider.envVars, + })), + fetch: listWebFetchProviders({ config: cfg }).map((provider) => ({ + available: true, + configured: isWebFetchProviderConfigured({ provider, config: cfg }), + selected: provider.id === selectedFetchProvider, + id: provider.id, + envVars: provider.envVars, + })), + }; + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, (value) => + JSON.stringify(value, null, 2), + ); + }); + }); + + const embedding = capability.command("embedding").description("Embedding providers"); + + embedding + .command("create") + .description("Create embeddings") + .requiredOption("--text ", "Input text", collectOption, []) + .option("--provider ", "Provider id") + .option("--model ", "Model override") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + const result = await runMemoryEmbeddingCreate({ + texts: opts.text as string[], + provider: opts.provider as string | undefined, + model: opts.model as string | undefined, + }); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, formatEnvelopeForText); + }); + }); + + embedding + .command("providers") + .description("List embedding providers") + .option("--json", "Output JSON", false) + .action(async (opts) => { + await runCommandWithRuntime(defaultRuntime, async () => { + ensureMemoryEmbeddingProvidersRegistered(); + const cfg = loadConfig(); + const agentId = resolveDefaultAgentId(cfg); + const resolvedMemory = resolveMemorySearchConfig(cfg, agentId); + const selectedProvider = + resolvedMemory?.provider && resolvedMemory.provider !== "auto" + ? resolvedMemory.provider + : undefined; + const autoSelectedProvider = + resolvedMemory?.provider === "auto" + ? ( + await createEmbeddingProvider({ + config: cfg, + agentDir: resolveAgentDir(cfg, agentId), + provider: "auto", + fallback: "none", + model: resolvedMemory.model, + local: resolvedMemory.local, + remote: resolvedMemory.remote, + outputDimensionality: resolvedMemory.outputDimensionality, + }).catch(() => ({ provider: null })) + )?.provider?.id + : undefined; + const result = listMemoryEmbeddingProviders().map((provider) => ({ + available: true, + configured: + provider.id === selectedProvider || + provider.id === autoSelectedProvider || + providerHasGenericConfig({ + cfg, + providerId: provider.id, + }), + selected: provider.id === selectedProvider || provider.id === autoSelectedProvider, + id: provider.id, + defaultModel: provider.defaultModel, + transport: provider.transport, + autoSelectPriority: provider.autoSelectPriority, + })); + emitJsonOrText(defaultRuntime, Boolean(opts.json), result, providerSummaryText); + }); + }); +} diff --git a/src/cli/program/register.subclis.test.ts b/src/cli/program/register.subclis.test.ts index c74aca02e3a..1beb89f6d22 100644 --- a/src/cli/program/register.subclis.test.ts +++ b/src/cli/program/register.subclis.test.ts @@ -26,9 +26,18 @@ const { registerQaCli } = vi.hoisted(() => ({ }), })); +const { inferAction, registerCapabilityCli } = vi.hoisted(() => { + const action = vi.fn(); + const register = vi.fn((program: Command) => { + program.command("infer").alias("capability").action(action); + }); + return { inferAction: action, registerCapabilityCli: register }; +}); + vi.mock("../acp-cli.js", () => ({ registerAcpCli })); vi.mock("../nodes-cli.js", () => ({ registerNodesCli })); vi.mock("../qa-cli.js", () => ({ registerQaCli })); +vi.mock("../capability-cli.js", () => ({ registerCapabilityCli })); describe("registerSubCliCommands", () => { const originalArgv = process.argv; @@ -54,6 +63,8 @@ describe("registerSubCliCommands", () => { acpAction.mockClear(); registerNodesCli.mockClear(); nodesAction.mockClear(); + registerCapabilityCli.mockClear(); + inferAction.mockClear(); }); afterEach(() => { @@ -98,6 +109,17 @@ describe("registerSubCliCommands", () => { expect(nodesAction).toHaveBeenCalledTimes(1); }); + it("registers the infer placeholder and dispatches through the capability registrar", async () => { + const program = createRegisteredProgram(["node", "openclaw", "infer"], "openclaw"); + + expect(program.commands.map((cmd) => cmd.name())).toEqual(["infer"]); + + await program.parseAsync(["infer"], { from: "user" }); + + expect(registerCapabilityCli).toHaveBeenCalledTimes(1); + expect(inferAction).toHaveBeenCalledTimes(1); + }); + it("replaces placeholder when registering a subcommand by name", async () => { const program = createRegisteredProgram(["node", "openclaw", "acp", "--help"], "openclaw"); diff --git a/src/cli/program/register.subclis.ts b/src/cli/program/register.subclis.ts index 121abde8445..997bd9c7880 100644 --- a/src/cli/program/register.subclis.ts +++ b/src/cli/program/register.subclis.ts @@ -74,6 +74,11 @@ const entrySpecs: readonly CommandGroupDescriptorSpec[] = [ loadModule: () => import("../models-cli.js"), exportName: "registerModelsCli", }, + { + commandNames: ["infer", "capability"], + loadModule: () => import("../capability-cli.js"), + exportName: "registerCapabilityCli", + }, { commandNames: ["approvals"], loadModule: () => import("../exec-approvals-cli.js"), diff --git a/src/cli/program/subcli-descriptors.ts b/src/cli/program/subcli-descriptors.ts index 64dc5b09ec0..01eabf96b6e 100644 --- a/src/cli/program/subcli-descriptors.ts +++ b/src/cli/program/subcli-descriptors.ts @@ -22,6 +22,16 @@ const subCliCommandCatalog = defineCommandDescriptorCatalog([ description: "Discover, scan, and configure models", hasSubcommands: true, }, + { + name: "infer", + description: "Run provider-backed inference commands", + hasSubcommands: true, + }, + { + name: "capability", + description: "Run provider-backed inference commands (fallback alias: infer)", + hasSubcommands: true, + }, { name: "approvals", description: "Manage exec approvals (gateway or node host)", diff --git a/src/gateway/server-methods/tts.test.ts b/src/gateway/server-methods/tts.test.ts new file mode 100644 index 00000000000..4b9aab920b6 --- /dev/null +++ b/src/gateway/server-methods/tts.test.ts @@ -0,0 +1,83 @@ +import { beforeEach, describe, expect, it, vi } from "vitest"; +import { ErrorCodes } from "../protocol/index.js"; + +const mocks = vi.hoisted(() => ({ + loadConfig: vi.fn(() => ({})), + resolveExplicitTtsOverrides: vi.fn(() => ({})), + textToSpeech: vi.fn(async () => ({ + success: true, + audioPath: "/tmp/tts.mp3", + provider: "openai", + outputFormat: "mp3", + voiceCompatible: false, + })), +})); + +vi.mock("../../config/config.js", () => ({ + loadConfig: (...args: unknown[]) => mocks.loadConfig(...args), +})); + +vi.mock("../../tts/provider-registry.js", () => ({ + canonicalizeSpeechProviderId: vi.fn(), + getSpeechProvider: vi.fn(), + listSpeechProviders: vi.fn(() => []), +})); + +vi.mock("../../tts/tts.js", () => ({ + getResolvedSpeechProviderConfig: vi.fn(), + getTtsProvider: vi.fn(() => "openai"), + isTtsEnabled: vi.fn(() => true), + isTtsProviderConfigured: vi.fn(() => true), + resolveExplicitTtsOverrides: (...args: unknown[]) => mocks.resolveExplicitTtsOverrides(...args), + resolveTtsAutoMode: vi.fn(() => false), + resolveTtsConfig: vi.fn(() => ({})), + resolveTtsPrefsPath: vi.fn(() => "/tmp/tts.json"), + resolveTtsProviderOrder: vi.fn(() => ["openai"]), + setTtsEnabled: vi.fn(), + setTtsProvider: vi.fn(), + textToSpeech: (...args: unknown[]) => mocks.textToSpeech(...args), +})); + +describe("ttsHandlers", () => { + beforeEach(() => { + mocks.loadConfig.mockReset(); + mocks.loadConfig.mockReturnValue({}); + mocks.resolveExplicitTtsOverrides.mockReset(); + mocks.resolveExplicitTtsOverrides.mockReturnValue({}); + mocks.textToSpeech.mockReset(); + mocks.textToSpeech.mockResolvedValue({ + success: true, + audioPath: "/tmp/tts.mp3", + provider: "openai", + outputFormat: "mp3", + voiceCompatible: false, + }); + }); + + it("returns INVALID_REQUEST when TTS override validation fails", async () => { + mocks.resolveExplicitTtsOverrides.mockImplementation(() => { + throw new Error('Unknown TTS provider "bad".'); + }); + + const { ttsHandlers } = await import("./tts.js"); + const respond = vi.fn(); + + await ttsHandlers["tts.convert"]({ + params: { + text: "hello", + provider: "bad", + }, + respond, + } as never); + + expect(respond).toHaveBeenCalledWith( + false, + undefined, + expect.objectContaining({ + code: ErrorCodes.INVALID_REQUEST, + message: 'Error: Unknown TTS provider "bad".', + }), + ); + expect(mocks.textToSpeech).not.toHaveBeenCalled(); + }); +}); diff --git a/src/gateway/server-methods/tts.ts b/src/gateway/server-methods/tts.ts index 28345c53e8c..cddb43aa9d2 100644 --- a/src/gateway/server-methods/tts.ts +++ b/src/gateway/server-methods/tts.ts @@ -9,6 +9,7 @@ import { getTtsProvider, isTtsEnabled, isTtsProviderConfigured, + resolveExplicitTtsOverrides, resolveTtsAutoMode, resolveTtsConfig, resolveTtsPrefsPath, @@ -89,7 +90,28 @@ export const ttsHandlers: GatewayRequestHandlers = { try { const cfg = loadConfig(); const channel = typeof params.channel === "string" ? params.channel.trim() : undefined; - const result = await textToSpeech({ text, cfg, channel }); + const providerRaw = typeof params.provider === "string" ? params.provider.trim() : undefined; + const modelId = typeof params.modelId === "string" ? params.modelId.trim() : undefined; + const voiceId = typeof params.voiceId === "string" ? params.voiceId.trim() : undefined; + let overrides; + try { + overrides = resolveExplicitTtsOverrides({ + cfg, + provider: providerRaw, + modelId, + voiceId, + }); + } catch (err) { + respond(false, undefined, errorShape(ErrorCodes.INVALID_REQUEST, formatForLog(err))); + return; + } + const result = await textToSpeech({ + text, + cfg, + channel, + overrides, + disableFallback: Boolean(overrides.provider || modelId || voiceId), + }); if (result.success && result.audioPath) { respond(true, { audioPath: result.audioPath, diff --git a/src/media-understanding/runner.auto-audio.test.ts b/src/media-understanding/runner.auto-audio.test.ts index 6c8ceeba584..cef0a92f258 100644 --- a/src/media-understanding/runner.auto-audio.test.ts +++ b/src/media-understanding/runner.auto-audio.test.ts @@ -121,6 +121,43 @@ describe("runCapability auto audio entries", () => { expect(seenModel).toBe("whisper-1"); }); + it("lets per-request transcription hints override configured model-entry hints", async () => { + let seenLanguage: string | undefined; + let seenPrompt: string | undefined; + const result = await runAutoAudioCase({ + transcribeAudio: async (req) => { + seenLanguage = req.language; + seenPrompt = req.prompt; + return { text: "ok", model: req.model ?? "unknown" }; + }, + cfgExtra: { + tools: { + media: { + audio: { + enabled: true, + prompt: "configured prompt", + language: "fr", + _requestPromptOverride: "Focus on names", + _requestLanguageOverride: "en", + models: [ + { + provider: "openai", + model: "whisper-1", + prompt: "entry prompt", + language: "de", + }, + ], + }, + }, + }, + } as Partial, + }); + + expect(result.outputs[0]?.text).toBe("ok"); + expect(seenLanguage).toBe("en"); + expect(seenPrompt).toBe("Focus on names"); + }); + it("uses mistral when only mistral key is configured", async () => { const isolatedAgentDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-audio-agent-")); let runResult: Awaited> | undefined; diff --git a/src/media-understanding/runner.cli-audio.test.ts b/src/media-understanding/runner.cli-audio.test.ts new file mode 100644 index 00000000000..3b96ff8a3d3 --- /dev/null +++ b/src/media-understanding/runner.cli-audio.test.ts @@ -0,0 +1,67 @@ +import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import type { OpenClawConfig } from "../config/config.js"; +import { withAudioFixture } from "./runner.test-utils.js"; + +const runExecMock = vi.hoisted(() => vi.fn()); + +vi.mock("../process/exec.js", () => ({ + runExec: (...args: unknown[]) => runExecMock(...args), +})); + +let runCliEntry: typeof import("./runner.entries.js").runCliEntry; + +describe("media-understanding CLI audio entry", () => { + beforeAll(async () => { + ({ runCliEntry } = await import("./runner.entries.js")); + }); + + beforeEach(() => { + runExecMock.mockReset().mockResolvedValue({ stdout: "cli transcript" }); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it("applies per-request prompt and language overrides to CLI transcription templating", async () => { + await withAudioFixture("openclaw-cli-audio", async ({ ctx, cache }) => { + await runCliEntry({ + capability: "audio", + entry: { + type: "cli", + command: "mock-transcriber", + args: ["--prompt", "{{Prompt}}", "--language", "{{Language}}", "--file", "{{MediaPath}}"], + prompt: "entry prompt", + language: "de", + }, + cfg: { + tools: { + media: { + audio: { + prompt: "configured prompt", + language: "fr", + _requestPromptOverride: "Focus on names", + _requestLanguageOverride: "en", + }, + }, + }, + } as OpenClawConfig, + ctx, + attachmentIndex: 0, + cache, + config: { + prompt: "configured prompt", + language: "fr", + _requestPromptOverride: "Focus on names", + _requestLanguageOverride: "en", + } as never, + }); + }); + + expect(runExecMock).toHaveBeenCalledWith( + "mock-transcriber", + expect.arrayContaining(["--prompt", "Focus on names", "--language", "en"]), + expect.any(Object), + ); + }); +}); diff --git a/src/media-understanding/runner.entries.ts b/src/media-understanding/runner.entries.ts index ff702a2eeed..2173f80649b 100644 --- a/src/media-understanding/runner.entries.ts +++ b/src/media-understanding/runner.entries.ts @@ -372,6 +372,20 @@ function resolveEntryRunOptions(params: { return { maxBytes, maxChars, timeoutMs, prompt }; } +function resolveAudioRequestOverrides(config: MediaUnderstandingConfig | undefined): { + prompt?: string; + language?: string; +} { + const overrides = (config ?? {}) as MediaUnderstandingConfig & { + _requestPromptOverride?: string; + _requestLanguageOverride?: string; + }; + return { + prompt: overrides._requestPromptOverride, + language: overrides._requestLanguageOverride, + }; +} + async function resolveProviderExecutionAuth(params: { providerId: string; cfg: OpenClawConfig; @@ -530,6 +544,7 @@ export async function runProviderEntry(params: { throw new Error(`Audio transcription provider "${providerId}" not available.`); } const transcribeAudio = provider.transcribeAudio; + const requestOverrides = resolveAudioRequestOverrides(params.config); const media = await params.cache.getBuffer({ attachmentIndex: params.attachmentIndex, maxBytes, @@ -569,8 +584,12 @@ export async function runProviderEntry(params: { headers, request, model, - language: entry.language ?? params.config?.language ?? cfg.tools?.media?.audio?.language, - prompt, + language: + requestOverrides.language ?? + entry.language ?? + params.config?.language ?? + cfg.tools?.media?.audio?.language, + prompt: requestOverrides.prompt ?? prompt, query: providerQuery, timeoutMs, fetchFn, @@ -651,6 +670,7 @@ export async function runCliEntry(params: { if (!command) { throw new Error(`CLI entry missing command for ${capability}`); } + const requestOverrides = resolveAudioRequestOverrides(params.config); const { maxBytes, maxChars, timeoutMs, prompt } = resolveEntryRunOptions({ capability, entry, @@ -683,7 +703,8 @@ export async function runCliEntry(params: { MediaDir: path.dirname(mediaPath), OutputDir: outputDir, OutputBase: outputBase, - Prompt: prompt, + Prompt: requestOverrides.prompt ?? prompt, + ...(requestOverrides.language ? { Language: requestOverrides.language } : {}), MaxChars: maxChars, }; const argv = [command, ...args].map((part, index) => diff --git a/src/media-understanding/runtime.ts b/src/media-understanding/runtime.ts index 9c7b647957f..f3142b2bf08 100644 --- a/src/media-understanding/runtime.ts +++ b/src/media-understanding/runtime.ts @@ -150,7 +150,28 @@ export async function transcribeAudioFile(params: { agentDir?: string; mime?: string; activeModel?: ActiveMediaModel; + language?: string; + prompt?: string; }): Promise<{ text: string | undefined }> { - const result = await runMediaUnderstandingFile({ ...params, capability: "audio" }); + const cfg = + params.language || params.prompt + ? { + ...params.cfg, + tools: { + ...params.cfg.tools, + media: { + ...params.cfg.tools?.media, + audio: { + ...params.cfg.tools?.media?.audio, + ...(params.language ? { _requestLanguageOverride: params.language } : {}), + ...(params.prompt ? { _requestPromptOverride: params.prompt } : {}), + ...(params.language ? { language: params.language } : {}), + ...(params.prompt ? { prompt: params.prompt } : {}), + }, + }, + }, + } + : params.cfg; + const result = await runMediaUnderstandingFile({ ...params, cfg, capability: "audio" }); return { text: result.text }; } diff --git a/src/plugin-sdk/tts-runtime.ts b/src/plugin-sdk/tts-runtime.ts index fa4f7b6b49a..e4e5fe03821 100644 --- a/src/plugin-sdk/tts-runtime.ts +++ b/src/plugin-sdk/tts-runtime.ts @@ -34,6 +34,8 @@ export const listSpeechVoices: FacadeModule["listSpeechVoices"] = createLazyFacadeValue("listSpeechVoices"); export const maybeApplyTtsToPayload: FacadeModule["maybeApplyTtsToPayload"] = createLazyFacadeValue("maybeApplyTtsToPayload"); +export const resolveExplicitTtsOverrides: FacadeModule["resolveExplicitTtsOverrides"] = + createLazyFacadeValue("resolveExplicitTtsOverrides"); export const resolveTtsAutoMode: FacadeModule["resolveTtsAutoMode"] = createLazyFacadeValue("resolveTtsAutoMode"); export const resolveTtsConfig: FacadeModule["resolveTtsConfig"] = diff --git a/src/tts/status-config.test.ts b/src/tts/status-config.test.ts index 5533e331af0..ee92fb044f9 100644 --- a/src/tts/status-config.test.ts +++ b/src/tts/status-config.test.ts @@ -1,6 +1,6 @@ import fs from "node:fs"; import path from "node:path"; -import { describe, expect, it } from "vitest"; +import { describe, expect, it, vi } from "vitest"; import { withTempHome } from "../../test/helpers/temp-home.js"; import type { OpenClawConfig } from "../config/config.js"; import { resolveStatusTtsSnapshot } from "./status-config.js"; @@ -61,4 +61,44 @@ describe("resolveStatusTtsSnapshot", () => { }); }); }); + + it("derives the default prefs path from OPENCLAW_CONFIG_PATH when set", async () => { + await withTempHome( + async (home) => { + const stateDir = path.join(home, ".openclaw-dev"); + const prefsPath = path.join(stateDir, "settings", "tts.json"); + fs.mkdirSync(path.dirname(prefsPath), { recursive: true }); + fs.writeFileSync( + prefsPath, + JSON.stringify({ + tts: { + auto: "always", + provider: "openai", + }, + }), + ); + + vi.stubEnv("OPENCLAW_CONFIG_PATH", path.join(stateDir, "openclaw.json")); + try { + expect( + resolveStatusTtsSnapshot({ + cfg: { + messages: { + tts: {}, + }, + } as OpenClawConfig, + }), + ).toEqual({ + autoMode: "always", + provider: "openai", + maxLength: 1500, + summarize: true, + }); + } finally { + vi.unstubAllEnvs(); + } + }, + { env: { OPENCLAW_STATE_DIR: undefined } }, + ); + }); }); diff --git a/src/tts/status-config.ts b/src/tts/status-config.ts index a9412cfa6c2..696638ab954 100644 --- a/src/tts/status-config.ts +++ b/src/tts/status-config.ts @@ -6,7 +6,7 @@ import { normalizeOptionalLowercaseString, normalizeOptionalString, } from "../shared/string-coerce.js"; -import { CONFIG_DIR, resolveUserPath } from "../utils.js"; +import { resolveConfigDir, resolveUserPath } from "../utils.js"; import { normalizeTtsAutoMode } from "./tts-auto-mode.js"; const DEFAULT_TTS_MAX_LENGTH = 1500; @@ -52,7 +52,7 @@ function resolveTtsPrefsPathValue(prefsPath: string | undefined): string { if (envPath) { return resolveUserPath(envPath); } - return path.join(CONFIG_DIR, "settings", "tts.json"); + return path.join(resolveConfigDir(process.env), "settings", "tts.json"); } function readPrefs(prefsPath: string): TtsUserPrefs { diff --git a/src/tts/tts.ts b/src/tts/tts.ts index 43f098f4504..5fb831e3631 100644 --- a/src/tts/tts.ts +++ b/src/tts/tts.ts @@ -10,6 +10,7 @@ export { isTtsProviderConfigured, listSpeechVoices, maybeApplyTtsToPayload, + resolveExplicitTtsOverrides, resolveTtsAutoMode, resolveTtsConfig, resolveTtsPrefsPath, diff --git a/src/utils.test.ts b/src/utils.test.ts index 1daa84408c4..5fa48d2166d 100644 --- a/src/utils.test.ts +++ b/src/utils.test.ts @@ -50,6 +50,15 @@ describe("resolveConfigDir", () => { expect(resolveConfigDir(env)).toBe(path.resolve("/tmp/openclaw-home", "state")); }); + + it("falls back to the config file directory when only OPENCLAW_CONFIG_PATH is set", () => { + const env = { + HOME: "/tmp/openclaw-home", + OPENCLAW_CONFIG_PATH: "~/profiles/dev/openclaw.json", + } as NodeJS.ProcessEnv; + + expect(resolveConfigDir(env)).toBe(path.resolve("/tmp/openclaw-home", "profiles", "dev")); + }); }); describe("resolveHomeDir", () => { diff --git a/src/utils.ts b/src/utils.ts index 7b5810f9e12..ca6206b3756 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -141,6 +141,10 @@ export function resolveConfigDir( if (override) { return resolveUserPath(override, env, homedir); } + const configPath = env.OPENCLAW_CONFIG_PATH?.trim(); + if (configPath) { + return path.dirname(resolveUserPath(configPath, env, homedir)); + } const newDir = path.join(resolveRequiredHomeDir(env, homedir), ".openclaw"); try { const hasNew = fs.existsSync(newDir); diff --git a/src/web-fetch/runtime.ts b/src/web-fetch/runtime.ts index 3c400fa8f1b..4036871989c 100644 --- a/src/web-fetch/runtime.ts +++ b/src/web-fetch/runtime.ts @@ -61,6 +61,16 @@ function hasEntryCredential( }); } +export function isWebFetchProviderConfigured(params: { + provider: Pick< + PluginWebFetchProviderEntry, + "envVars" | "getConfiguredCredentialValue" | "getCredentialValue" | "requiresCredential" + >; + config?: OpenClawConfig; +}): boolean { + return hasEntryCredential(params.provider, params.config, resolveFetchConfig(params.config)); +} + export function listWebFetchProviders(params?: { config?: OpenClawConfig; }): PluginWebFetchProviderEntry[] { diff --git a/src/web-search/runtime.test.ts b/src/web-search/runtime.test.ts index 75b3ac30391..d13fc8c604f 100644 --- a/src/web-search/runtime.test.ts +++ b/src/web-search/runtime.test.ts @@ -289,4 +289,289 @@ describe("web search runtime", () => { result: { query: "runtime", provider: "beta", runtimeSelectedProvider: "beta" }, }); }); + + it("falls back to another provider when auto-selected search execution fails", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + createTool: () => ({ + description: "google", + parameters: {}, + execute: async () => { + throw new Error("google aborted"); + }, + }), + }), + createProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + createTool: () => ({ + description: "duckduckgo", + parameters: {}, + execute: async (args) => ({ ...args, provider: "duckduckgo" }), + }), + }), + ]); + + await expect( + runWebSearch({ + config: {}, + args: { query: "fallback" }, + }), + ).resolves.toEqual({ + provider: "duckduckgo", + result: { query: "fallback", provider: "duckduckgo" }, + }); + }); + + it("does not prebuild fallback provider tools before attempting the selected provider", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + createTool: () => ({ + description: "google", + parameters: {}, + execute: async (args) => ({ ...args, provider: "google" }), + }), + }), + createProvider({ + pluginId: "broken-fallback", + id: "broken-fallback", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + createTool: () => { + throw new Error("fallback createTool exploded"); + }, + }), + ]); + + await expect( + runWebSearch({ + config: {}, + args: { query: "selected-first" }, + }), + ).resolves.toEqual({ + provider: "google", + result: { query: "selected-first", provider: "google" }, + }); + }); + + it("does not fall back when the provider came from explicit config selection", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + createTool: () => ({ + description: "google", + parameters: {}, + execute: async () => { + throw new Error("google aborted"); + }, + }), + }), + createProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + createTool: () => ({ + description: "duckduckgo", + parameters: {}, + execute: async (args) => ({ ...args, provider: "duckduckgo" }), + }), + }), + ]); + + await expect( + runWebSearch({ + config: { + tools: { + web: { + search: { + provider: "google", + }, + }, + }, + }, + args: { query: "configured" }, + }), + ).rejects.toThrow("google aborted"); + }); + + it("does not fall back when the caller explicitly selects a provider", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + createTool: () => ({ + description: "google", + parameters: {}, + execute: async () => { + throw new Error("google aborted"); + }, + }), + }), + createProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + }), + ]); + + await expect( + runWebSearch({ + config: {}, + providerId: "google", + args: { query: "explicit" }, + }), + ).rejects.toThrow("google aborted"); + }); + + it("fails fast when an explicit provider cannot create a tool", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + createTool: () => null, + }), + createProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + }), + ]); + + await expect( + runWebSearch({ + config: {}, + providerId: "google", + args: { query: "explicit-null-tool" }, + }), + ).rejects.toThrow('web_search provider "google" is not available.'); + }); + + it("fails fast when the caller explicitly selects an unknown provider", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + }), + createProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + }), + ]); + + await expect( + runWebSearch({ + config: {}, + providerId: "missing-id", + args: { query: "explicit-missing" }, + }), + ).rejects.toThrow('Unknown web_search provider "missing-id".'); + }); + + it("honors preferRuntimeProviders during execution", async () => { + const configuredProvider = createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + }); + const runtimeProvider = createProvider({ + pluginId: "runtime-search", + id: "runtime-search", + credentialPath: "", + autoDetectOrder: 0, + requiresCredential: false, + }); + resolveRuntimeWebSearchProvidersMock.mockReturnValue([configuredProvider, runtimeProvider]); + resolvePluginWebSearchProvidersMock.mockReturnValue([configuredProvider]); + + await expect( + runWebSearch({ + config: { + tools: { + web: { + search: { + provider: "google", + }, + }, + }, + }, + runtimeWebSearch: { + enabled: true, + providerConfigured: "runtime-search", + selectedProvider: "runtime-search", + providerSource: "runtime", + }, + preferRuntimeProviders: false, + args: { query: "prefer-config" }, + }), + ).resolves.toEqual({ + provider: "google", + result: { query: "prefer-config", provider: "google" }, + }); + }); + + it("returns a clear error when every fallback-capable provider is unavailable", async () => { + resolveRuntimeWebSearchProvidersMock.mockReturnValue([ + createProvider({ + pluginId: "google", + id: "google", + credentialPath: "tools.web.search.google.apiKey", + autoDetectOrder: 1, + getCredentialValue: () => "configured", + createTool: () => null, + }), + createProvider({ + pluginId: "duckduckgo", + id: "duckduckgo", + credentialPath: "", + autoDetectOrder: 100, + requiresCredential: false, + createTool: () => null, + }), + ]); + + await expect( + runWebSearch({ + config: {}, + args: { query: "all-null-tools" }, + }), + ).rejects.toThrow("web_search is enabled but no provider is currently available."); + }); }); diff --git a/src/web-search/runtime.ts b/src/web-search/runtime.ts index 9bfc289ff42..682de0ba150 100644 --- a/src/web-search/runtime.ts +++ b/src/web-search/runtime.ts @@ -79,6 +79,21 @@ function hasEntryCredential( }); } +export function isWebSearchProviderConfigured(params: { + provider: Pick< + PluginWebSearchProviderEntry, + | "credentialPath" + | "id" + | "envVars" + | "getConfiguredCredentialValue" + | "getCredentialValue" + | "requiresCredential" + >; + config?: OpenClawConfig; +}): boolean { + return hasEntryCredential(params.provider, params.config, resolveSearchConfig(params.config)); +} + export function listWebSearchProviders(params?: { config?: OpenClawConfig; }): PluginWebSearchProviderEntry[] { @@ -198,21 +213,130 @@ export function resolveWebSearchDefinition( }); } +function resolveWebSearchCandidates( + options?: ResolveWebSearchDefinitionParams, +): PluginWebSearchProviderEntry[] { + const search = resolveSearchConfig(options?.config); + const runtimeWebSearch = options?.runtimeWebSearch ?? getActiveRuntimeWebToolsMetadata()?.search; + if (!resolveWebSearchEnabled({ search, sandboxed: options?.sandboxed })) { + return []; + } + + const providers = sortWebSearchProvidersForAutoDetect( + options?.preferRuntimeProviders + ? resolveRuntimeWebSearchProviders({ + config: options?.config, + bundledAllowlistCompat: true, + }) + : resolvePluginWebSearchProviders({ + config: options?.config, + bundledAllowlistCompat: true, + origin: "bundled", + }), + ).filter(Boolean); + if (providers.length === 0) { + return []; + } + + const preferredIds = [ + options?.providerId, + runtimeWebSearch?.selectedProvider, + runtimeWebSearch?.providerConfigured, + resolveWebSearchProviderId({ config: options?.config, search, providers }), + ].filter( + (value, index, array): value is string => Boolean(value) && array.indexOf(value) === index, + ); + + const explicitProviderId = options?.providerId?.trim(); + if (explicitProviderId && !providers.some((entry) => entry.id === explicitProviderId)) { + throw new Error(`Unknown web_search provider "${explicitProviderId}".`); + } + + const orderedProviders = [ + ...preferredIds + .map((id) => providers.find((entry) => entry.id === id)) + .filter((entry): entry is PluginWebSearchProviderEntry => Boolean(entry)), + ...providers.filter((entry) => !preferredIds.includes(entry.id)), + ]; + return orderedProviders; +} + +function hasExplicitWebSearchSelection(params: { + search?: WebSearchConfig; + runtimeWebSearch?: RuntimeWebSearchMetadata; + providerId?: string; +}): boolean { + if (params.providerId?.trim()) { + return true; + } + if ( + params.search && + "provider" in params.search && + typeof params.search.provider === "string" && + params.search.provider.trim() + ) { + return true; + } + return params.runtimeWebSearch?.providerSource === "configured"; +} + export async function runWebSearch( params: RunWebSearchParams, ): Promise<{ provider: string; result: Record }> { - const resolved = resolveWebSearchDefinition({ ...params, preferRuntimeProviders: true }); - if (!resolved) { + const search = resolveSearchConfig(params.config); + const runtimeWebSearch = params.runtimeWebSearch ?? getActiveRuntimeWebToolsMetadata()?.search; + const candidates = resolveWebSearchCandidates({ + ...params, + runtimeWebSearch, + preferRuntimeProviders: params.preferRuntimeProviders ?? true, + }); + if (candidates.length === 0) { throw new Error("web_search is disabled or no provider is available."); } - return { - provider: resolved.provider.id, - result: await resolved.definition.execute(params.args), - }; + const allowFallback = !hasExplicitWebSearchSelection({ + search, + runtimeWebSearch, + providerId: params.providerId, + }); + let lastError: unknown; + let sawUnavailableProvider = false; + + for (const candidate of candidates) { + try { + const definition = candidate.createTool({ + config: params.config, + searchConfig: search as Record | undefined, + runtimeMetadata: runtimeWebSearch, + }); + if (!definition) { + if (!allowFallback) { + throw new Error(`web_search provider "${candidate.id}" is not available.`); + } + sawUnavailableProvider = true; + continue; + } + return { + provider: candidate.id, + result: await definition.execute(params.args), + }; + } catch (error) { + lastError = error; + if (!allowFallback) { + throw error; + } + } + } + + if (sawUnavailableProvider && lastError === undefined) { + throw new Error("web_search is enabled but no provider is currently available."); + } + throw lastError instanceof Error ? lastError : new Error(String(lastError)); } export const __testing = { resolveSearchConfig, resolveSearchProvider: resolveWebSearchProviderId, resolveWebSearchProviderId, + resolveWebSearchCandidates, + hasExplicitWebSearchSelection, };