diff --git a/src/tts/provider-registry-core.ts b/src/tts/provider-registry-core.ts new file mode 100644 index 00000000000..6d5315c7e69 --- /dev/null +++ b/src/tts/provider-registry-core.ts @@ -0,0 +1,58 @@ +import type { OpenClawConfig } from "../config/types.js"; +import { + buildCapabilityProviderMaps, + normalizeCapabilityProviderId, +} from "../plugins/provider-registry-shared.js"; +import type { SpeechProviderPlugin } from "../plugins/types.js"; +import type { SpeechProviderId } from "./provider-types.js"; + +export type SpeechProviderRegistryResolver = { + getProvider: (providerId: string, cfg?: OpenClawConfig) => SpeechProviderPlugin | undefined; + listProviders: (cfg?: OpenClawConfig) => SpeechProviderPlugin[]; +}; + +export function normalizeSpeechProviderId( + providerId: string | undefined, +): SpeechProviderId | undefined { + return normalizeCapabilityProviderId(providerId); +} + +export function createSpeechProviderRegistry(resolver: SpeechProviderRegistryResolver) { + const buildResolvedProviderMaps = (cfg?: OpenClawConfig) => + buildCapabilityProviderMaps(resolver.listProviders(cfg)); + + const listProviders = (cfg?: OpenClawConfig): SpeechProviderPlugin[] => [ + ...buildResolvedProviderMaps(cfg).canonical.values(), + ]; + + const getProvider = ( + providerId: string | undefined, + cfg?: OpenClawConfig, + ): SpeechProviderPlugin | undefined => { + const normalized = normalizeSpeechProviderId(providerId); + if (!normalized) { + return undefined; + } + return ( + resolver.getProvider(normalized, cfg) ?? + buildResolvedProviderMaps(cfg).aliases.get(normalized) + ); + }; + + const canonicalizeProviderId = ( + providerId: string | undefined, + cfg?: OpenClawConfig, + ): SpeechProviderId | undefined => { + const normalized = normalizeSpeechProviderId(providerId); + if (!normalized) { + return undefined; + } + return getProvider(normalized, cfg)?.id ?? normalized; + }; + + return { + canonicalizeSpeechProviderId: canonicalizeProviderId, + getSpeechProvider: getProvider, + listSpeechProviders: listProviders, + }; +} diff --git a/src/tts/provider-registry.test.ts b/src/tts/provider-registry.test.ts index a739af22b65..9584f0dadf9 100644 --- a/src/tts/provider-registry.test.ts +++ b/src/tts/provider-registry.test.ts @@ -1,19 +1,10 @@ -import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeEach, describe, expect, it } from "vitest"; import type { OpenClawConfig } from "../config/types.js"; import type { SpeechProviderPlugin } from "../plugins/types.js"; - -const resolvePluginCapabilityProviderMock = vi.hoisted(() => vi.fn()); -const resolvePluginCapabilityProvidersMock = vi.hoisted(() => vi.fn()); - -vi.mock("../plugins/capability-provider-runtime.js", () => ({ - resolvePluginCapabilityProvider: resolvePluginCapabilityProviderMock, - resolvePluginCapabilityProviders: resolvePluginCapabilityProvidersMock, -})); - -let getSpeechProvider: typeof import("./provider-registry.js").getSpeechProvider; -let listSpeechProviders: typeof import("./provider-registry.js").listSpeechProviders; -let canonicalizeSpeechProviderId: typeof import("./provider-registry.js").canonicalizeSpeechProviderId; -let normalizeSpeechProviderId: typeof import("./provider-registry.js").normalizeSpeechProviderId; +import { + createSpeechProviderRegistry, + normalizeSpeechProviderId, +} from "./provider-registry-core.js"; function createSpeechProvider(id: string, aliases?: string[]): SpeechProviderPlugin { return { @@ -31,59 +22,57 @@ function createSpeechProvider(id: string, aliases?: string[]): SpeechProviderPlu } describe("speech provider registry", () => { - beforeAll(async () => { - vi.resetModules(); - ({ - getSpeechProvider, - listSpeechProviders, - canonicalizeSpeechProviderId, - normalizeSpeechProviderId, - } = await import("./provider-registry.js")); - }); + const getProviderCalls: Array<{ providerId: string; cfg?: OpenClawConfig }> = []; + const listProvidersCalls: Array<{ cfg?: OpenClawConfig }> = []; + let providers: SpeechProviderPlugin[] = []; + let directProvider: SpeechProviderPlugin | undefined; + let registry: ReturnType; beforeEach(() => { - resolvePluginCapabilityProviderMock.mockReset(); - resolvePluginCapabilityProviderMock.mockReturnValue(undefined); - resolvePluginCapabilityProvidersMock.mockReset(); - resolvePluginCapabilityProvidersMock.mockReturnValue([]); + providers = []; + directProvider = undefined; + getProviderCalls.length = 0; + listProvidersCalls.length = 0; + registry = createSpeechProviderRegistry({ + getProvider: (providerId, cfg) => { + getProviderCalls.push({ providerId, cfg }); + return directProvider; + }, + listProviders: (cfg) => { + listProvidersCalls.push({ cfg }); + return providers; + }, + }); }); it("lists providers from the speech capability runtime", () => { const cfg = {} as OpenClawConfig; - resolvePluginCapabilityProvidersMock.mockReturnValue([createSpeechProvider("demo-speech")]); + providers = [createSpeechProvider("demo-speech")]; - expect(listSpeechProviders(cfg).map((provider) => provider.id)).toEqual(["demo-speech"]); - expect(resolvePluginCapabilityProvidersMock).toHaveBeenCalledWith({ - key: "speechProviders", - cfg, - }); + expect(registry.listSpeechProviders(cfg).map((provider) => provider.id)).toEqual([ + "demo-speech", + ]); + expect(listProvidersCalls).toEqual([{ cfg }]); }); it("gets providers by normalized id through the capability runtime", () => { const cfg = {} as OpenClawConfig; - const provider = createSpeechProvider("microsoft", ["edge"]); - resolvePluginCapabilityProviderMock.mockReturnValue(provider); + directProvider = createSpeechProvider("microsoft", ["edge"]); - expect(getSpeechProvider(" MICROSOFT ", cfg)).toBe(provider); - expect(resolvePluginCapabilityProviderMock).toHaveBeenCalledWith({ - key: "speechProviders", - providerId: "microsoft", - cfg, - }); + expect(registry.getSpeechProvider(" MICROSOFT ", cfg)).toBe(directProvider); + expect(getProviderCalls).toEqual([{ providerId: "microsoft", cfg }]); }); it("canonicalizes aliases from listed providers when direct lookup misses", () => { - resolvePluginCapabilityProvidersMock.mockReturnValue([ - createSpeechProvider("microsoft", ["edge"]), - ]); + providers = [createSpeechProvider("microsoft", ["edge"])]; expect(normalizeSpeechProviderId("edge")).toBe("edge"); - expect(canonicalizeSpeechProviderId("edge")).toBe("microsoft"); + expect(registry.canonicalizeSpeechProviderId("edge")).toBe("microsoft"); }); it("returns empty results when the capability runtime has no speech providers", () => { - expect(listSpeechProviders()).toEqual([]); - expect(getSpeechProvider("demo-speech")).toBeUndefined(); - expect(canonicalizeSpeechProviderId("demo-speech")).toBe("demo-speech"); + expect(registry.listSpeechProviders()).toEqual([]); + expect(registry.getSpeechProvider("demo-speech")).toBeUndefined(); + expect(registry.canonicalizeSpeechProviderId("demo-speech")).toBe("demo-speech"); }); }); diff --git a/src/tts/provider-registry.ts b/src/tts/provider-registry.ts index f43a746fd87..b673c21790e 100644 --- a/src/tts/provider-registry.ts +++ b/src/tts/provider-registry.ts @@ -3,18 +3,12 @@ import { resolvePluginCapabilityProvider, resolvePluginCapabilityProviders, } from "../plugins/capability-provider-runtime.js"; -import { - buildCapabilityProviderMaps, - normalizeCapabilityProviderId, -} from "../plugins/provider-registry-shared.js"; import type { SpeechProviderPlugin } from "../plugins/types.js"; -import type { SpeechProviderId } from "./provider-types.js"; - -export function normalizeSpeechProviderId( - providerId: string | undefined, -): SpeechProviderId | undefined { - return normalizeCapabilityProviderId(providerId); -} +export { normalizeSpeechProviderId } from "./provider-registry-core.js"; +import { + createSpeechProviderRegistry, + type SpeechProviderRegistryResolver, +} from "./provider-registry-core.js"; function resolveSpeechProviderPluginEntries(cfg?: OpenClawConfig): SpeechProviderPlugin[] { return resolvePluginCapabilityProviders({ @@ -23,41 +17,21 @@ function resolveSpeechProviderPluginEntries(cfg?: OpenClawConfig): SpeechProvide }); } -function buildProviderMaps(cfg?: OpenClawConfig): { - canonical: Map; - aliases: Map; -} { - return buildCapabilityProviderMaps(resolveSpeechProviderPluginEntries(cfg)); -} - -export function listSpeechProviders(cfg?: OpenClawConfig): SpeechProviderPlugin[] { - return [...buildProviderMaps(cfg).canonical.values()]; -} - -export function getSpeechProvider( - providerId: string | undefined, - cfg?: OpenClawConfig, -): SpeechProviderPlugin | undefined { - const normalized = normalizeSpeechProviderId(providerId); - if (!normalized) { - return undefined; - } - return ( +const defaultSpeechProviderRegistryResolver: SpeechProviderRegistryResolver = { + getProvider: (providerId, cfg) => resolvePluginCapabilityProvider({ key: "speechProviders", - providerId: normalized, + providerId, cfg, - }) ?? buildProviderMaps(cfg).aliases.get(normalized) - ); -} + }), + listProviders: resolveSpeechProviderPluginEntries, +}; -export function canonicalizeSpeechProviderId( - providerId: string | undefined, - cfg?: OpenClawConfig, -): SpeechProviderId | undefined { - const normalized = normalizeSpeechProviderId(providerId); - if (!normalized) { - return undefined; - } - return getSpeechProvider(normalized, cfg)?.id ?? normalized; -} +const defaultSpeechProviderRegistry = createSpeechProviderRegistry( + defaultSpeechProviderRegistryResolver, +); + +export const listSpeechProviders = defaultSpeechProviderRegistry.listSpeechProviders; +export const getSpeechProvider = defaultSpeechProviderRegistry.getSpeechProvider; +export const canonicalizeSpeechProviderId = + defaultSpeechProviderRegistry.canonicalizeSpeechProviderId; diff --git a/test/vitest/vitest.unit-fast-paths.mjs b/test/vitest/vitest.unit-fast-paths.mjs index cd7ba821f3a..7b39e73e1bb 100644 --- a/test/vitest/vitest.unit-fast-paths.mjs +++ b/test/vitest/vitest.unit-fast-paths.mjs @@ -93,6 +93,7 @@ export const forcedUnitFastTestFiles = [ "src/security/windows-acl.test.ts", "src/realtime-transcription/websocket-session.test.ts", "src/trajectory/export.test.ts", + "src/tts/provider-registry.test.ts", "src/tts/status-config.test.ts", "src/version.test.ts", ];