test: inject thread-safe base seams

This commit is contained in:
Peter Steinberger
2026-03-23 04:58:18 -07:00
parent 8fd2fa13c6
commit 47db5abece
8 changed files with 265 additions and 98 deletions

View File

@@ -1,9 +1,22 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import type { Context, Model, SimpleStreamOptions } from "@mariozechner/pi-ai";
import { describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { __testing as extraParamsTesting } from "./pi-embedded-runner/extra-params.js";
import {
createOpenRouterSystemCacheWrapper,
createOpenRouterWrapper,
isProxyReasoningUnsupported,
} from "./pi-embedded-runner/proxy-stream-wrappers.js";
import type { ProviderCapabilities } from "./provider-capabilities.js";
import { __testing as providerCapabilitiesTesting } from "./provider-capabilities.js";
const resolveProviderCapabilitiesWithPluginMock = vi.fn(
(params: { provider: string; workspaceDir?: string }) => {
(params: {
provider: string;
config?: import("../config/config.js").OpenClawConfig;
workspaceDir?: string;
env?: NodeJS.ProcessEnv;
}): Partial<ProviderCapabilities> | undefined => {
if (
params.provider === "workspace-anthropic-proxy" &&
params.workspaceDir === "/tmp/workspace-capabilities"
@@ -17,20 +30,12 @@ const resolveProviderCapabilitiesWithPluginMock = vi.fn(
},
);
vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
const actual = await importOriginal<typeof import("../plugins/provider-runtime.js")>();
const {
createOpenRouterSystemCacheWrapper,
createOpenRouterWrapper,
isProxyReasoningUnsupported,
} = await import("./pi-embedded-runner/proxy-stream-wrappers.js");
import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js";
import { log } from "./pi-embedded-runner/logger.js";
return {
...actual,
prepareProviderExtraParams: (params: {
provider: string;
context: { extraParams?: Record<string, unknown> };
}) => {
beforeEach(() => {
extraParamsTesting.setProviderRuntimeDepsForTest({
prepareProviderExtraParams: (params) => {
if (params.provider !== "openai-codex") {
return undefined;
}
@@ -43,15 +48,7 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
transport: "auto",
};
},
wrapProviderStreamFn: (params: {
provider: string;
context: {
modelId: string;
thinkingLevel?: import("../auto-reply/thinking.js").ThinkLevel;
extraParams?: Record<string, unknown>;
streamFn?: StreamFn;
};
}) => {
wrapProviderStreamFn: (params) => {
if (params.provider !== "openrouter") {
return params.context.streamFn;
}
@@ -80,13 +77,17 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
const thinkingLevel = skipReasoningInjection ? undefined : params.context.thinkingLevel;
return createOpenRouterSystemCacheWrapper(createOpenRouterWrapper(streamFn, thinkingLevel));
},
resolveProviderCapabilitiesWithPlugin: (params: { provider: string; workspaceDir?: string }) =>
resolveProviderCapabilitiesWithPluginMock(params),
};
});
providerCapabilitiesTesting.setResolveProviderCapabilitiesWithPluginForTest(
resolveProviderCapabilitiesWithPluginMock,
);
resolveProviderCapabilitiesWithPluginMock.mockClear();
});
import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js";
import { log } from "./pi-embedded-runner/logger.js";
afterEach(() => {
extraParamsTesting.resetProviderRuntimeDepsForTest();
providerCapabilitiesTesting.resetDepsForTests();
});
describe("resolveExtraParams", () => {
it("returns undefined with no model config", () => {

View File

@@ -4,8 +4,8 @@ import { streamSimple } from "@mariozechner/pi-ai";
import type { ThinkLevel } from "../../auto-reply/thinking.js";
import type { OpenClawConfig } from "../../config/config.js";
import {
prepareProviderExtraParams,
wrapProviderStreamFn,
prepareProviderExtraParams as prepareProviderExtraParamsRuntime,
wrapProviderStreamFn as wrapProviderStreamFnRuntime,
} from "../../plugins/provider-runtime.js";
import {
createAnthropicBetaHeadersWrapper,
@@ -38,6 +38,31 @@ import {
} from "./openai-stream-wrappers.js";
import { createXaiFastModeWrapper } from "./xai-stream-wrappers.js";
const defaultProviderRuntimeDeps = {
prepareProviderExtraParams: prepareProviderExtraParamsRuntime,
wrapProviderStreamFn: wrapProviderStreamFnRuntime,
};
const providerRuntimeDeps = {
...defaultProviderRuntimeDeps,
};
export const __testing = {
setProviderRuntimeDepsForTest(
deps: Partial<typeof defaultProviderRuntimeDeps> | undefined,
): void {
providerRuntimeDeps.prepareProviderExtraParams =
deps?.prepareProviderExtraParams ?? defaultProviderRuntimeDeps.prepareProviderExtraParams;
providerRuntimeDeps.wrapProviderStreamFn =
deps?.wrapProviderStreamFn ?? defaultProviderRuntimeDeps.wrapProviderStreamFn;
},
resetProviderRuntimeDepsForTest(): void {
providerRuntimeDeps.prepareProviderExtraParams =
defaultProviderRuntimeDeps.prepareProviderExtraParams;
providerRuntimeDeps.wrapProviderStreamFn = defaultProviderRuntimeDeps.wrapProviderStreamFn;
},
};
/**
* Resolve provider-specific extra params from model config.
* Used to pass through stream params like temperature/maxTokens.
@@ -206,7 +231,7 @@ export function applyExtraParamsToAgent(
: undefined;
const merged = Object.assign({}, resolvedExtraParams, override);
const effectiveExtraParams =
prepareProviderExtraParams({
providerRuntimeDeps.prepareProviderExtraParams({
provider,
config: cfg,
context: {
@@ -257,7 +282,7 @@ export function applyExtraParamsToAgent(
workspaceDir,
});
const providerStreamBase = agent.streamFn;
const pluginWrappedStreamFn = wrapProviderStreamFn({
const pluginWrappedStreamFn = providerRuntimeDeps.wrapProviderStreamFn({
provider,
config: cfg,
context: {

View File

@@ -1,5 +1,5 @@
import type { OpenClawConfig } from "../config/config.js";
import { resolveProviderCapabilitiesWithPlugin } from "../plugins/provider-runtime.js";
import { resolveProviderCapabilitiesWithPlugin as resolveProviderCapabilitiesWithPluginRuntime } from "../plugins/provider-runtime.js";
import { normalizeProviderId } from "./model-selection.js";
export type ProviderCapabilities = {
@@ -82,13 +82,31 @@ const PLUGIN_CAPABILITIES_FALLBACKS: Record<string, Partial<ProviderCapabilities
},
};
const defaultResolveProviderCapabilitiesWithPlugin = resolveProviderCapabilitiesWithPluginRuntime;
const providerCapabilityDeps = {
resolveProviderCapabilitiesWithPlugin: defaultResolveProviderCapabilitiesWithPlugin,
};
export const __testing = {
setResolveProviderCapabilitiesWithPluginForTest(
resolveProviderCapabilitiesWithPlugin?: typeof defaultResolveProviderCapabilitiesWithPlugin,
): void {
providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin =
resolveProviderCapabilitiesWithPlugin ?? defaultResolveProviderCapabilitiesWithPlugin;
},
resetDepsForTests(): void {
providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin =
defaultResolveProviderCapabilitiesWithPlugin;
},
};
export function resolveProviderCapabilities(
provider?: string | null,
options?: ProviderCapabilityLookupOptions,
): ProviderCapabilities {
const normalized = normalizeProviderId(provider ?? "");
const pluginCapabilities = normalized
? resolveProviderCapabilitiesWithPlugin({
? providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin({
provider: normalized,
config: options?.config,
workspaceDir: options?.workspaceDir,