diff --git a/src/extension-host/runtime-backend-arbitration.test.ts b/src/extension-host/runtime-backend-arbitration.test.ts new file mode 100644 index 00000000000..a25c7e6dc5f --- /dev/null +++ b/src/extension-host/runtime-backend-arbitration.test.ts @@ -0,0 +1,92 @@ +import { describe, expect, it } from "vitest"; +import { + listExtensionHostRuntimeBackendCandidatesByArbitration, + listExtensionHostRuntimeBackendIdsByArbitration, + resolveExtensionHostRuntimeBackendOrderByArbitration, +} from "./runtime-backend-arbitration.js"; + +const entries = [ + { + id: "capability.runtime-backend:embedding:local", + family: "capability.runtime-backend", + subsystemId: "embedding", + backendId: "local", + source: "builtin", + defaultRank: 0, + selectorKeys: ["local"], + capabilities: ["embed.query", "embed.batch"], + metadata: { autoSelectable: true }, + }, + { + id: "capability.runtime-backend:embedding:openai", + family: "capability.runtime-backend", + subsystemId: "embedding", + backendId: "openai", + source: "builtin", + defaultRank: 1, + selectorKeys: ["openai"], + capabilities: ["embed.query", "embed.batch"], + metadata: { autoSelectable: true }, + }, + { + id: "capability.runtime-backend:embedding:custom", + family: "capability.runtime-backend", + subsystemId: "embedding", + backendId: "custom", + source: "builtin", + defaultRank: 2, + selectorKeys: ["custom"], + capabilities: ["embed.query", "embed.batch"], + metadata: { autoSelectable: false }, + }, + { + id: "capability.runtime-backend:tts:edge", + family: "capability.runtime-backend", + subsystemId: "tts", + backendId: "edge", + source: "builtin", + defaultRank: 1, + selectorKeys: ["edge"], + capabilities: ["tts.synthesis"], + }, +] as const; + +describe("runtime backend arbitration", () => { + it("keeps candidates ranked by default rank inside a subsystem", () => { + expect( + listExtensionHostRuntimeBackendCandidatesByArbitration({ + entries, + subsystemId: "embedding", + }).map((entry) => entry.backendId), + ).toEqual(["local", "openai", "custom"]); + }); + + it("supports filtered runtime-family arbitration", () => { + expect( + listExtensionHostRuntimeBackendIdsByArbitration({ + entries, + subsystemId: "embedding", + include: (entry) => entry.metadata?.autoSelectable === true && entry.backendId !== "local", + }), + ).toEqual(["openai"]); + }); + + it("keeps the preferred backend first without duplicating ranked entries", () => { + expect( + resolveExtensionHostRuntimeBackendOrderByArbitration({ + entries, + subsystemId: "embedding", + preferredBackendId: "openai", + }), + ).toEqual(["openai", "local", "custom"]); + + expect( + resolveExtensionHostRuntimeBackendOrderByArbitration({ + entries, + subsystemId: "embedding", + preferredBackendId: "fallback-only", + include: (entry) => entry.metadata?.autoSelectable === true, + }), + ).toEqual(["fallback-only", "local", "openai"]); + }); +}); diff --git a/src/extension-host/runtime-backend-arbitration.ts b/src/extension-host/runtime-backend-arbitration.ts new file mode 100644 index 00000000000..41f64905be3 --- /dev/null +++ b/src/extension-host/runtime-backend-arbitration.ts @@ -0,0 +1,45 @@ +import type { + ExtensionHostRuntimeBackendCatalogEntry, + ExtensionHostRuntimeBackendSubsystemId, +} from "./runtime-backend-catalog.js"; + +type ExtensionHostRuntimeBackendArbitrationPredicate = ( + entry: ExtensionHostRuntimeBackendCatalogEntry, +) => boolean; + +export function listExtensionHostRuntimeBackendCandidatesByArbitration(params: { + entries: readonly ExtensionHostRuntimeBackendCatalogEntry[]; + subsystemId: ExtensionHostRuntimeBackendSubsystemId; + include?: ExtensionHostRuntimeBackendArbitrationPredicate; +}): readonly ExtensionHostRuntimeBackendCatalogEntry[] { + const include = params.include ?? (() => true); + return params.entries + .filter((entry) => entry.subsystemId === params.subsystemId && include(entry)) + .toSorted((left, right) => left.defaultRank - right.defaultRank); +} + +export function listExtensionHostRuntimeBackendIdsByArbitration(params: { + entries: readonly ExtensionHostRuntimeBackendCatalogEntry[]; + subsystemId: ExtensionHostRuntimeBackendSubsystemId; + include?: ExtensionHostRuntimeBackendArbitrationPredicate; +}): readonly string[] { + return listExtensionHostRuntimeBackendCandidatesByArbitration(params).map( + (entry) => entry.backendId, + ); +} + +export function resolveExtensionHostRuntimeBackendOrderByArbitration(params: { + entries: readonly ExtensionHostRuntimeBackendCatalogEntry[]; + subsystemId: ExtensionHostRuntimeBackendSubsystemId; + preferredBackendId: string; + include?: ExtensionHostRuntimeBackendArbitrationPredicate; +}): readonly string[] { + const ordered = listExtensionHostRuntimeBackendIdsByArbitration(params); + if (!ordered.includes(params.preferredBackendId)) { + return [params.preferredBackendId, ...ordered]; + } + return [ + params.preferredBackendId, + ...ordered.filter((backendId) => backendId !== params.preferredBackendId), + ]; +} diff --git a/src/extension-host/runtime-backend-catalog.ts b/src/extension-host/runtime-backend-catalog.ts index 821ebdc4348..4416cd0500c 100644 --- a/src/extension-host/runtime-backend-catalog.ts +++ b/src/extension-host/runtime-backend-catalog.ts @@ -12,6 +12,10 @@ import { normalizeExtensionHostMediaProviderId, resolveExtensionHostMediaRuntimeDefaultModelMetadata, } from "./media-runtime-backends.js"; +import { + listExtensionHostRuntimeBackendIdsByArbitration, + resolveExtensionHostRuntimeBackendOrderByArbitration, +} from "./runtime-backend-arbitration.js"; import { listExtensionHostTtsRuntimeBackends } from "./tts-runtime-backends.js"; export const EXTENSION_HOST_RUNTIME_BACKEND_FAMILY = "capability.runtime-backend"; @@ -78,9 +82,11 @@ export function listExtensionHostEmbeddingRuntimeBackendCatalogEntries(): readon } export function listExtensionHostEmbeddingRemoteRuntimeBackendIds(): readonly EmbeddingProviderId[] { - return listExtensionHostEmbeddingRuntimeBackendCatalogEntries() - .filter((entry) => entry.backendId !== "local" && entry.metadata?.autoSelectable === true) - .map((entry) => entry.backendId as EmbeddingProviderId); + return listExtensionHostRuntimeBackendIdsByArbitration({ + entries: listExtensionHostEmbeddingRuntimeBackendCatalogEntries(), + subsystemId: "embedding", + include: (entry) => entry.backendId !== "local" && entry.metadata?.autoSelectable === true, + }).map((entry) => entry as EmbeddingProviderId); } export function listExtensionHostMediaRuntimeBackendCatalogEntries(): readonly ExtensionHostRuntimeBackendCatalogEntry[] { @@ -117,10 +123,11 @@ export function listExtensionHostMediaAutoRuntimeBackendIds( capability: MediaUnderstandingCapability, ): readonly string[] { const subsystemId = mapMediaCapabilityToSubsystem(capability); - return listExtensionHostMediaRuntimeBackendCatalogEntries() - .filter((entry) => entry.subsystemId === subsystemId && entry.metadata?.autoSelectable === true) - .toSorted((left, right) => left.defaultRank - right.defaultRank) - .map((entry) => entry.backendId); + return listExtensionHostRuntimeBackendIdsByArbitration({ + entries: listExtensionHostMediaRuntimeBackendCatalogEntries(), + subsystemId, + include: (entry) => entry.metadata?.autoSelectable === true, + }); } export function resolveExtensionHostMediaRuntimeDefaultModel(params: { @@ -163,21 +170,21 @@ export function listExtensionHostTtsRuntimeBackendIds(): readonly TtsProvider[] export function listExtensionHostRuntimeBackendIdsForSubsystem( subsystemId: ExtensionHostRuntimeBackendSubsystemId, ): readonly string[] { - return listExtensionHostRuntimeBackendCatalogEntries() - .filter((entry) => entry.subsystemId === subsystemId) - .toSorted((left, right) => left.defaultRank - right.defaultRank) - .map((entry) => entry.backendId); + return listExtensionHostRuntimeBackendIdsByArbitration({ + entries: listExtensionHostRuntimeBackendCatalogEntries(), + subsystemId, + }); } export function resolveExtensionHostRuntimeBackendOrderForSubsystem( subsystemId: ExtensionHostRuntimeBackendSubsystemId, preferredBackendId: string, ): readonly string[] { - const ordered = listExtensionHostRuntimeBackendIdsForSubsystem(subsystemId); - if (!ordered.includes(preferredBackendId)) { - return [preferredBackendId, ...ordered]; - } - return [preferredBackendId, ...ordered.filter((backendId) => backendId !== preferredBackendId)]; + return resolveExtensionHostRuntimeBackendOrderByArbitration({ + entries: listExtensionHostRuntimeBackendCatalogEntries(), + subsystemId, + preferredBackendId, + }); } export function listExtensionHostMediaRuntimeBackendIds( diff --git a/src/extension-host/tts-runtime-registry.ts b/src/extension-host/tts-runtime-registry.ts index 7af49fcc99e..15449b2c0d0 100644 --- a/src/extension-host/tts-runtime-registry.ts +++ b/src/extension-host/tts-runtime-registry.ts @@ -1,4 +1,5 @@ import type { TtsProvider } from "../config/types.tts.js"; +import { resolveExtensionHostTtsRuntimeBackendOrder } from "./runtime-backend-catalog.js"; import type { ResolvedTtsConfig } from "./tts-config.js"; import { EXTENSION_HOST_TTS_RUNTIME_BACKEND_IDS, @@ -36,7 +37,7 @@ export function isExtensionHostTtsProviderConfigured( } export function resolveExtensionHostTtsProviderOrder(primary: TtsProvider): TtsProvider[] { - return [primary, ...EXTENSION_HOST_TTS_PROVIDER_IDS.filter((provider) => provider !== primary)]; + return [...resolveExtensionHostTtsRuntimeBackendOrder(primary)]; } export function supportsExtensionHostTtsTelephony(provider: TtsProvider): boolean {