diff --git a/src/agents/pi-embedded-runner-extraparams.test.ts b/src/agents/pi-embedded-runner-extraparams.test.ts index 1799b77b96f..de77a86759e 100644 --- a/src/agents/pi-embedded-runner-extraparams.test.ts +++ b/src/agents/pi-embedded-runner-extraparams.test.ts @@ -140,7 +140,7 @@ import { } from "./pi-embedded-runner/openai-stream-wrappers.js"; type WrapProviderStreamFnParams = Parameters< - typeof import("../plugins/provider-runtime.js").wrapProviderStreamFn + typeof import("../plugins/provider-hook-runtime.js").wrapProviderStreamFn >[0]; function createTestOpenAIProviderWrapper( diff --git a/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts b/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts index 01888ad65b7..610239cca2e 100644 --- a/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts +++ b/src/agents/pi-embedded-runner/extra-params.cache-retention-default.test.ts @@ -4,15 +4,6 @@ import { isOpenRouterAnthropicModelRef } from "./anthropic-family-cache-semantic import { __testing as extraParamsTesting, applyExtraParamsToAgent } from "./extra-params.js"; import { resolveCacheRetention } from "./prompt-cache-retention.js"; -vi.mock("../../plugins/provider-runtime.js", () => ({ - prepareProviderExtraParams: ({ - context, - }: { - context: { extraParams: Record }; - }) => context.extraParams, - wrapProviderStreamFn: () => undefined, -})); - function applyAndExpectWrapped(params: { cfg?: Parameters[1]; extraParamsOverride?: Parameters[4]; diff --git a/src/agents/pi-embedded-runner/extra-params.google.test.ts b/src/agents/pi-embedded-runner/extra-params.google.test.ts index 486c6de004f..92a2356d5f9 100644 --- a/src/agents/pi-embedded-runner/extra-params.google.test.ts +++ b/src/agents/pi-embedded-runner/extra-params.google.test.ts @@ -4,15 +4,6 @@ import { createPiAiStreamSimpleMock } from "../../../test/helpers/agents/pi-ai-s import { __testing as extraParamsTesting } from "./extra-params.js"; import { runExtraParamsCase } from "./extra-params.test-support.js"; -vi.mock("../../plugins/provider-runtime.js", () => ({ - prepareProviderExtraParams: ({ - context, - }: { - context: { extraParams: Record }; - }) => context.extraParams, - wrapProviderStreamFn: () => undefined, -})); - vi.mock("@mariozechner/pi-ai", async () => createPiAiStreamSimpleMock(() => vi.importActual("@mariozechner/pi-ai"), diff --git a/src/agents/pi-embedded-runner/extra-params.ollama.test.ts b/src/agents/pi-embedded-runner/extra-params.ollama.test.ts index 6a30e7b0108..7a29bc62e4e 100644 --- a/src/agents/pi-embedded-runner/extra-params.ollama.test.ts +++ b/src/agents/pi-embedded-runner/extra-params.ollama.test.ts @@ -3,15 +3,6 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; import { __testing as extraParamsTesting } from "./extra-params.js"; import { runExtraParamsCase } from "./extra-params.test-support.js"; -vi.mock("../../plugins/provider-runtime.js", () => ({ - prepareProviderExtraParams: ({ - context, - }: { - context: { extraParams: Record }; - }) => context.extraParams, - wrapProviderStreamFn: () => undefined, -})); - vi.mock("@mariozechner/pi-ai", async () => { const original = await vi.importActual("@mariozechner/pi-ai"); diff --git a/src/agents/pi-embedded-runner/extra-params.test-support.ts b/src/agents/pi-embedded-runner/extra-params.test-support.ts index cce8028b8f3..93a82e6ea86 100644 --- a/src/agents/pi-embedded-runner/extra-params.test-support.ts +++ b/src/agents/pi-embedded-runner/extra-params.test-support.ts @@ -1,19 +1,9 @@ import type { StreamFn } from "@mariozechner/pi-agent-core"; import type { Context, Model, SimpleStreamOptions } from "@mariozechner/pi-ai"; -import { vi } from "vitest"; import type { ThinkLevel } from "../../auto-reply/thinking.shared.js"; import type { OpenClawConfig } from "../../config/types.openclaw.js"; import { __testing as extraParamsTesting, applyExtraParamsToAgent } from "./extra-params.js"; -vi.mock("../../plugins/provider-runtime.js", () => ({ - prepareProviderExtraParams: ({ - context, - }: { - context: { extraParams: Record }; - }) => context.extraParams, - wrapProviderStreamFn: () => undefined, -})); - export type ExtraParamsCapture> = { headers?: Record; options?: SimpleStreamOptions; diff --git a/src/agents/pi-embedded-runner/extra-params.ts b/src/agents/pi-embedded-runner/extra-params.ts index 9d4ff6fe421..8c5afa6e6bc 100644 --- a/src/agents/pi-embedded-runner/extra-params.ts +++ b/src/agents/pi-embedded-runner/extra-params.ts @@ -4,11 +4,11 @@ import { streamSimple } from "@mariozechner/pi-ai"; import type { SettingsManager } from "@mariozechner/pi-coding-agent"; import type { ThinkLevel } from "../../auto-reply/thinking.js"; import type { OpenClawConfig } from "../../config/types.openclaw.js"; -import type { ProviderRuntimeModel } from "../../plugins/provider-runtime-model.types.js"; import { prepareProviderExtraParams as prepareProviderExtraParamsRuntime, wrapProviderStreamFn as wrapProviderStreamFnRuntime, -} from "../../plugins/provider-runtime.js"; +} from "../../plugins/provider-hook-runtime.js"; +import type { ProviderRuntimeModel } from "../../plugins/provider-runtime-model.types.js"; import { createGoogleThinkingPayloadWrapper } from "./google-stream-wrappers.js"; import { log } from "./logger.js"; import { createMinimaxThinkingDisabledWrapper } from "./minimax-stream-wrappers.js"; diff --git a/src/agents/transcript-policy.ts b/src/agents/transcript-policy.ts index 00c1411fb16..8c991535e64 100644 --- a/src/agents/transcript-policy.ts +++ b/src/agents/transcript-policy.ts @@ -1,7 +1,7 @@ import type { OpenClawConfig } from "../config/types.openclaw.js"; +import { resolveProviderRuntimePlugin } from "../plugins/provider-hook-runtime.js"; import { shouldPreserveThinkingBlocks } from "../plugins/provider-replay-helpers.js"; import type { ProviderRuntimeModel } from "../plugins/provider-runtime-model.types.js"; -import { resolveProviderRuntimePlugin } from "../plugins/provider-runtime.js"; import type { ProviderReplayPolicy } from "../plugins/types.js"; import { normalizeLowercaseStringOrEmpty } from "../shared/string-coerce.js"; import { normalizeProviderId } from "./model-selection.js"; diff --git a/src/plugins/provider-hook-runtime.ts b/src/plugins/provider-hook-runtime.ts new file mode 100644 index 00000000000..3f8e6c1c531 --- /dev/null +++ b/src/plugins/provider-hook-runtime.ts @@ -0,0 +1,193 @@ +import { normalizeProviderId } from "../agents/provider-id.js"; +import type { OpenClawConfig } from "../config/types.openclaw.js"; +import { normalizePluginIdScope, serializePluginIdScope } from "./plugin-scope.js"; +import { isPluginProvidersLoadInFlight, resolvePluginProviders } from "./providers.runtime.js"; +import { resolvePluginCacheInputs } from "./roots.js"; +import { getActivePluginRegistryWorkspaceDirFromState } from "./runtime-state.js"; +import type { + ProviderPlugin, + ProviderPrepareExtraParamsContext, + ProviderWrapStreamFnContext, +} from "./types.js"; + +function matchesProviderId(provider: ProviderPlugin, providerId: string): boolean { + const normalized = normalizeProviderId(providerId); + if (!normalized) { + return false; + } + if (normalizeProviderId(provider.id) === normalized) { + return true; + } + return [...(provider.aliases ?? []), ...(provider.hookAliases ?? [])].some( + (alias) => normalizeProviderId(alias) === normalized, + ); +} + +let cachedHookProvidersWithoutConfig = new WeakMap< + NodeJS.ProcessEnv, + Map +>(); +let cachedHookProvidersByConfig = new WeakMap< + OpenClawConfig, + WeakMap> +>(); + +function resolveHookProviderCacheBucket(params: { + config?: OpenClawConfig; + env: NodeJS.ProcessEnv; +}) { + if (!params.config) { + let bucket = cachedHookProvidersWithoutConfig.get(params.env); + if (!bucket) { + bucket = new Map(); + cachedHookProvidersWithoutConfig.set(params.env, bucket); + } + return bucket; + } + + let envBuckets = cachedHookProvidersByConfig.get(params.config); + if (!envBuckets) { + envBuckets = new WeakMap>(); + cachedHookProvidersByConfig.set(params.config, envBuckets); + } + let bucket = envBuckets.get(params.env); + if (!bucket) { + bucket = new Map(); + envBuckets.set(params.env, bucket); + } + return bucket; +} + +function buildHookProviderCacheKey(params: { + config?: OpenClawConfig; + workspaceDir?: string; + onlyPluginIds?: string[]; + providerRefs?: string[]; + env?: NodeJS.ProcessEnv; +}) { + const { roots } = resolvePluginCacheInputs({ + workspaceDir: params.workspaceDir, + env: params.env, + }); + const onlyPluginIds = normalizePluginIdScope(params.onlyPluginIds); + return `${roots.workspace ?? ""}::${roots.global}::${roots.stock ?? ""}::${JSON.stringify(params.config ?? null)}::${serializePluginIdScope(onlyPluginIds)}::${JSON.stringify(params.providerRefs ?? [])}`; +} + +export function clearProviderRuntimeHookCache(): void { + cachedHookProvidersWithoutConfig = new WeakMap< + NodeJS.ProcessEnv, + Map + >(); + cachedHookProvidersByConfig = new WeakMap< + OpenClawConfig, + WeakMap> + >(); +} + +export function resetProviderRuntimeHookCacheForTest(): void { + clearProviderRuntimeHookCache(); +} + +export const __testing = { + buildHookProviderCacheKey, +} as const; + +export function resolveProviderPluginsForHooks(params: { + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; + onlyPluginIds?: string[]; + providerRefs?: string[]; +}): ProviderPlugin[] { + const env = params.env ?? process.env; + const workspaceDir = params.workspaceDir ?? getActivePluginRegistryWorkspaceDirFromState(); + const cacheBucket = resolveHookProviderCacheBucket({ + config: params.config, + env, + }); + const cacheKey = buildHookProviderCacheKey({ + config: params.config, + workspaceDir, + onlyPluginIds: params.onlyPluginIds, + providerRefs: params.providerRefs, + env, + }); + const cached = cacheBucket.get(cacheKey); + if (cached) { + return cached; + } + if ( + isPluginProvidersLoadInFlight({ + ...params, + workspaceDir, + env, + activate: false, + cache: false, + bundledProviderAllowlistCompat: true, + bundledProviderVitestCompat: true, + }) + ) { + return []; + } + const resolved = resolvePluginProviders({ + ...params, + workspaceDir, + env, + activate: false, + cache: false, + bundledProviderAllowlistCompat: true, + bundledProviderVitestCompat: true, + }); + cacheBucket.set(cacheKey, resolved); + return resolved; +} + +export function resolveProviderRuntimePlugin(params: { + provider: string; + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; +}): ProviderPlugin | undefined { + return resolveProviderPluginsForHooks({ + config: params.config, + workspaceDir: params.workspaceDir ?? getActivePluginRegistryWorkspaceDirFromState(), + env: params.env, + providerRefs: [params.provider], + }).find((plugin) => matchesProviderId(plugin, params.provider)); +} + +export function resolveProviderHookPlugin(params: { + provider: string; + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; +}): ProviderPlugin | undefined { + return ( + resolveProviderRuntimePlugin(params) ?? + resolveProviderPluginsForHooks({ + config: params.config, + workspaceDir: params.workspaceDir, + env: params.env, + }).find((candidate) => matchesProviderId(candidate, params.provider)) + ); +} + +export function prepareProviderExtraParams(params: { + provider: string; + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; + context: ProviderPrepareExtraParamsContext; +}) { + return resolveProviderRuntimePlugin(params)?.prepareExtraParams?.(params.context) ?? undefined; +} + +export function wrapProviderStreamFn(params: { + provider: string; + config?: OpenClawConfig; + workspaceDir?: string; + env?: NodeJS.ProcessEnv; + context: ProviderWrapStreamFnContext; +}) { + return resolveProviderHookPlugin(params)?.wrapStreamFn?.(params.context) ?? undefined; +} diff --git a/src/plugins/provider-runtime.ts b/src/plugins/provider-runtime.ts index 99ae3b96efb..9f4e0138af8 100644 --- a/src/plugins/provider-runtime.ts +++ b/src/plugins/provider-runtime.ts @@ -3,17 +3,23 @@ import { applyPluginTextReplacements, mergePluginTextTransforms, } from "../agents/plugin-text-transforms.js"; -import { normalizeProviderId } from "../agents/provider-id.js"; import type { ProviderSystemPromptContribution } from "../agents/system-prompt-contribution.js"; import type { ModelProviderConfig } from "../config/types.js"; import type { OpenClawConfig } from "../config/types.openclaw.js"; import { normalizeOptionalString } from "../shared/string-coerce.js"; -import { normalizePluginIdScope, serializePluginIdScope } from "./plugin-scope.js"; +import { + __testing as providerHookRuntimeTesting, + clearProviderRuntimeHookCache, + prepareProviderExtraParams, + resetProviderRuntimeHookCacheForTest, + resolveProviderHookPlugin, + resolveProviderPluginsForHooks, + resolveProviderRuntimePlugin, + wrapProviderStreamFn, +} from "./provider-hook-runtime.js"; import { resolveBundledProviderPolicySurface } from "./provider-public-artifacts.js"; import type { ProviderRuntimeModel } from "./provider-runtime-model.types.js"; import { resolveCatalogHookProviderPluginIds } from "./providers.js"; -import { isPluginProvidersLoadInFlight, resolvePluginProviders } from "./providers.runtime.js"; -import { resolvePluginCacheInputs } from "./roots.js"; import { getActivePluginRegistryWorkspaceDirFromState } from "./runtime-state.js"; import { resolveRuntimeTextTransforms } from "./text-transforms.runtime.js"; import type { @@ -41,7 +47,6 @@ import type { ProviderNormalizeResolvedModelContext, ProviderNormalizeTransportContext, ProviderModernModelPolicyContext, - ProviderPrepareExtraParamsContext, ProviderPrepareDynamicModelContext, ProviderPreferRuntimeResolvedModelContext, ProviderResolveExternalAuthProfilesContext, @@ -61,142 +66,20 @@ import type { ProviderTransportTurnState, ProviderValidateReplayTurnsContext, ProviderWebSocketSessionPolicy, - ProviderWrapStreamFnContext, PluginTextTransforms, } from "./types.js"; - -function matchesProviderId(provider: ProviderPlugin, providerId: string): boolean { - const normalized = normalizeProviderId(providerId); - if (!normalized) { - return false; - } - if (normalizeProviderId(provider.id) === normalized) { - return true; - } - return [...(provider.aliases ?? []), ...(provider.hookAliases ?? [])].some( - (alias) => normalizeProviderId(alias) === normalized, - ); -} - -let cachedHookProvidersWithoutConfig = new WeakMap< - NodeJS.ProcessEnv, - Map ->(); -let cachedHookProvidersByConfig = new WeakMap< - OpenClawConfig, - WeakMap> ->(); - -function resolveHookProviderCacheBucket(params: { - config?: OpenClawConfig; - env: NodeJS.ProcessEnv; -}) { - if (!params.config) { - let bucket = cachedHookProvidersWithoutConfig.get(params.env); - if (!bucket) { - bucket = new Map(); - cachedHookProvidersWithoutConfig.set(params.env, bucket); - } - return bucket; - } - - let envBuckets = cachedHookProvidersByConfig.get(params.config); - if (!envBuckets) { - envBuckets = new WeakMap>(); - cachedHookProvidersByConfig.set(params.config, envBuckets); - } - let bucket = envBuckets.get(params.env); - if (!bucket) { - bucket = new Map(); - envBuckets.set(params.env, bucket); - } - return bucket; -} - -function buildHookProviderCacheKey(params: { - config?: OpenClawConfig; - workspaceDir?: string; - onlyPluginIds?: string[]; - providerRefs?: string[]; - env?: NodeJS.ProcessEnv; -}) { - const { roots } = resolvePluginCacheInputs({ - workspaceDir: params.workspaceDir, - env: params.env, - }); - const onlyPluginIds = normalizePluginIdScope(params.onlyPluginIds); - return `${roots.workspace ?? ""}::${roots.global}::${roots.stock ?? ""}::${JSON.stringify(params.config ?? null)}::${serializePluginIdScope(onlyPluginIds)}::${JSON.stringify(params.providerRefs ?? [])}`; -} - -export function clearProviderRuntimeHookCache(): void { - cachedHookProvidersWithoutConfig = new WeakMap< - NodeJS.ProcessEnv, - Map - >(); - cachedHookProvidersByConfig = new WeakMap< - OpenClawConfig, - WeakMap> - >(); -} - -export function resetProviderRuntimeHookCacheForTest(): void { - clearProviderRuntimeHookCache(); -} +export { + clearProviderRuntimeHookCache, + prepareProviderExtraParams, + resetProviderRuntimeHookCacheForTest, + resolveProviderRuntimePlugin, + wrapProviderStreamFn, +}; export const __testing = { - buildHookProviderCacheKey, + ...providerHookRuntimeTesting, } as const; -function resolveProviderPluginsForHooks(params: { - config?: OpenClawConfig; - workspaceDir?: string; - env?: NodeJS.ProcessEnv; - onlyPluginIds?: string[]; - providerRefs?: string[]; -}): ProviderPlugin[] { - const env = params.env ?? process.env; - const workspaceDir = params.workspaceDir ?? getActivePluginRegistryWorkspaceDirFromState(); - const cacheBucket = resolveHookProviderCacheBucket({ - config: params.config, - env, - }); - const cacheKey = buildHookProviderCacheKey({ - config: params.config, - workspaceDir, - onlyPluginIds: params.onlyPluginIds, - providerRefs: params.providerRefs, - env, - }); - const cached = cacheBucket.get(cacheKey); - if (cached) { - return cached; - } - if ( - isPluginProvidersLoadInFlight({ - ...params, - workspaceDir, - env, - activate: false, - cache: false, - bundledProviderAllowlistCompat: true, - bundledProviderVitestCompat: true, - }) - ) { - return []; - } - const resolved = resolvePluginProviders({ - ...params, - workspaceDir, - env, - activate: false, - cache: false, - bundledProviderAllowlistCompat: true, - bundledProviderVitestCompat: true, - }); - cacheBucket.set(cacheKey, resolved); - return resolved; -} - function resolveProviderPluginsForCatalogHooks(params: { config?: OpenClawConfig; workspaceDir?: string; @@ -218,20 +101,6 @@ function resolveProviderPluginsForCatalogHooks(params: { }); } -export function resolveProviderRuntimePlugin(params: { - provider: string; - config?: OpenClawConfig; - workspaceDir?: string; - env?: NodeJS.ProcessEnv; -}): ProviderPlugin | undefined { - return resolveProviderPluginsForHooks({ - config: params.config, - workspaceDir: params.workspaceDir ?? getActivePluginRegistryWorkspaceDirFromState(), - env: params.env, - providerRefs: [params.provider], - }).find((plugin) => matchesProviderId(plugin, params.provider)); -} - export function runProviderDynamicModel(params: { provider: string; config?: OpenClawConfig; @@ -433,22 +302,6 @@ export function applyProviderResolvedTransportWithPlugin(params: { }; } -function resolveProviderHookPlugin(params: { - provider: string; - config?: OpenClawConfig; - workspaceDir?: string; - env?: NodeJS.ProcessEnv; -}): ProviderPlugin | undefined { - return ( - resolveProviderRuntimePlugin(params) ?? - resolveProviderPluginsForHooks({ - config: params.config, - workspaceDir: params.workspaceDir, - env: params.env, - }).find((candidate) => matchesProviderId(candidate, params.provider)) - ); -} - export function normalizeProviderModelIdWithPlugin(params: { provider: string; config?: OpenClawConfig; @@ -612,16 +465,6 @@ export function resolveProviderReasoningOutputModeWithPlugin(params: { return mode === "native" || mode === "tagged" ? mode : undefined; } -export function prepareProviderExtraParams(params: { - provider: string; - config?: OpenClawConfig; - workspaceDir?: string; - env?: NodeJS.ProcessEnv; - context: ProviderPrepareExtraParamsContext; -}) { - return resolveProviderRuntimePlugin(params)?.prepareExtraParams?.(params.context) ?? undefined; -} - export function resolveProviderStreamFn(params: { provider: string; config?: OpenClawConfig; @@ -632,16 +475,6 @@ export function resolveProviderStreamFn(params: { return resolveProviderRuntimePlugin(params)?.createStreamFn?.(params.context) ?? undefined; } -export function wrapProviderStreamFn(params: { - provider: string; - config?: OpenClawConfig; - workspaceDir?: string; - env?: NodeJS.ProcessEnv; - context: ProviderWrapStreamFnContext; -}) { - return resolveProviderHookPlugin(params)?.wrapStreamFn?.(params.context) ?? undefined; -} - export function resolveProviderTransportTurnStateWithPlugin(params: { provider: string; config?: OpenClawConfig;