test: inject thread-safe base seams

This commit is contained in:
Peter Steinberger
2026-03-23 04:58:18 -07:00
parent 8fd2fa13c6
commit 47db5abece
8 changed files with 265 additions and 98 deletions

View File

@@ -1,9 +1,22 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import type { Context, Model, SimpleStreamOptions } from "@mariozechner/pi-ai";
import { describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { __testing as extraParamsTesting } from "./pi-embedded-runner/extra-params.js";
import {
createOpenRouterSystemCacheWrapper,
createOpenRouterWrapper,
isProxyReasoningUnsupported,
} from "./pi-embedded-runner/proxy-stream-wrappers.js";
import type { ProviderCapabilities } from "./provider-capabilities.js";
import { __testing as providerCapabilitiesTesting } from "./provider-capabilities.js";
const resolveProviderCapabilitiesWithPluginMock = vi.fn(
(params: { provider: string; workspaceDir?: string }) => {
(params: {
provider: string;
config?: import("../config/config.js").OpenClawConfig;
workspaceDir?: string;
env?: NodeJS.ProcessEnv;
}): Partial<ProviderCapabilities> | undefined => {
if (
params.provider === "workspace-anthropic-proxy" &&
params.workspaceDir === "/tmp/workspace-capabilities"
@@ -17,20 +30,12 @@ const resolveProviderCapabilitiesWithPluginMock = vi.fn(
},
);
vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
const actual = await importOriginal<typeof import("../plugins/provider-runtime.js")>();
const {
createOpenRouterSystemCacheWrapper,
createOpenRouterWrapper,
isProxyReasoningUnsupported,
} = await import("./pi-embedded-runner/proxy-stream-wrappers.js");
import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js";
import { log } from "./pi-embedded-runner/logger.js";
return {
...actual,
prepareProviderExtraParams: (params: {
provider: string;
context: { extraParams?: Record<string, unknown> };
}) => {
beforeEach(() => {
extraParamsTesting.setProviderRuntimeDepsForTest({
prepareProviderExtraParams: (params) => {
if (params.provider !== "openai-codex") {
return undefined;
}
@@ -43,15 +48,7 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
transport: "auto",
};
},
wrapProviderStreamFn: (params: {
provider: string;
context: {
modelId: string;
thinkingLevel?: import("../auto-reply/thinking.js").ThinkLevel;
extraParams?: Record<string, unknown>;
streamFn?: StreamFn;
};
}) => {
wrapProviderStreamFn: (params) => {
if (params.provider !== "openrouter") {
return params.context.streamFn;
}
@@ -80,13 +77,17 @@ vi.mock("../plugins/provider-runtime.js", async (importOriginal) => {
const thinkingLevel = skipReasoningInjection ? undefined : params.context.thinkingLevel;
return createOpenRouterSystemCacheWrapper(createOpenRouterWrapper(streamFn, thinkingLevel));
},
resolveProviderCapabilitiesWithPlugin: (params: { provider: string; workspaceDir?: string }) =>
resolveProviderCapabilitiesWithPluginMock(params),
};
});
providerCapabilitiesTesting.setResolveProviderCapabilitiesWithPluginForTest(
resolveProviderCapabilitiesWithPluginMock,
);
resolveProviderCapabilitiesWithPluginMock.mockClear();
});
import { applyExtraParamsToAgent, resolveExtraParams } from "./pi-embedded-runner.js";
import { log } from "./pi-embedded-runner/logger.js";
afterEach(() => {
extraParamsTesting.resetProviderRuntimeDepsForTest();
providerCapabilitiesTesting.resetDepsForTests();
});
describe("resolveExtraParams", () => {
it("returns undefined with no model config", () => {

View File

@@ -4,8 +4,8 @@ import { streamSimple } from "@mariozechner/pi-ai";
import type { ThinkLevel } from "../../auto-reply/thinking.js";
import type { OpenClawConfig } from "../../config/config.js";
import {
prepareProviderExtraParams,
wrapProviderStreamFn,
prepareProviderExtraParams as prepareProviderExtraParamsRuntime,
wrapProviderStreamFn as wrapProviderStreamFnRuntime,
} from "../../plugins/provider-runtime.js";
import {
createAnthropicBetaHeadersWrapper,
@@ -38,6 +38,31 @@ import {
} from "./openai-stream-wrappers.js";
import { createXaiFastModeWrapper } from "./xai-stream-wrappers.js";
const defaultProviderRuntimeDeps = {
prepareProviderExtraParams: prepareProviderExtraParamsRuntime,
wrapProviderStreamFn: wrapProviderStreamFnRuntime,
};
const providerRuntimeDeps = {
...defaultProviderRuntimeDeps,
};
export const __testing = {
setProviderRuntimeDepsForTest(
deps: Partial<typeof defaultProviderRuntimeDeps> | undefined,
): void {
providerRuntimeDeps.prepareProviderExtraParams =
deps?.prepareProviderExtraParams ?? defaultProviderRuntimeDeps.prepareProviderExtraParams;
providerRuntimeDeps.wrapProviderStreamFn =
deps?.wrapProviderStreamFn ?? defaultProviderRuntimeDeps.wrapProviderStreamFn;
},
resetProviderRuntimeDepsForTest(): void {
providerRuntimeDeps.prepareProviderExtraParams =
defaultProviderRuntimeDeps.prepareProviderExtraParams;
providerRuntimeDeps.wrapProviderStreamFn = defaultProviderRuntimeDeps.wrapProviderStreamFn;
},
};
/**
* Resolve provider-specific extra params from model config.
* Used to pass through stream params like temperature/maxTokens.
@@ -206,7 +231,7 @@ export function applyExtraParamsToAgent(
: undefined;
const merged = Object.assign({}, resolvedExtraParams, override);
const effectiveExtraParams =
prepareProviderExtraParams({
providerRuntimeDeps.prepareProviderExtraParams({
provider,
config: cfg,
context: {
@@ -257,7 +282,7 @@ export function applyExtraParamsToAgent(
workspaceDir,
});
const providerStreamBase = agent.streamFn;
const pluginWrappedStreamFn = wrapProviderStreamFn({
const pluginWrappedStreamFn = providerRuntimeDeps.wrapProviderStreamFn({
provider,
config: cfg,
context: {

View File

@@ -1,5 +1,5 @@
import type { OpenClawConfig } from "../config/config.js";
import { resolveProviderCapabilitiesWithPlugin } from "../plugins/provider-runtime.js";
import { resolveProviderCapabilitiesWithPlugin as resolveProviderCapabilitiesWithPluginRuntime } from "../plugins/provider-runtime.js";
import { normalizeProviderId } from "./model-selection.js";
export type ProviderCapabilities = {
@@ -82,13 +82,31 @@ const PLUGIN_CAPABILITIES_FALLBACKS: Record<string, Partial<ProviderCapabilities
},
};
const defaultResolveProviderCapabilitiesWithPlugin = resolveProviderCapabilitiesWithPluginRuntime;
const providerCapabilityDeps = {
resolveProviderCapabilitiesWithPlugin: defaultResolveProviderCapabilitiesWithPlugin,
};
export const __testing = {
setResolveProviderCapabilitiesWithPluginForTest(
resolveProviderCapabilitiesWithPlugin?: typeof defaultResolveProviderCapabilitiesWithPlugin,
): void {
providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin =
resolveProviderCapabilitiesWithPlugin ?? defaultResolveProviderCapabilitiesWithPlugin;
},
resetDepsForTests(): void {
providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin =
defaultResolveProviderCapabilitiesWithPlugin;
},
};
export function resolveProviderCapabilities(
provider?: string | null,
options?: ProviderCapabilityLookupOptions,
): ProviderCapabilities {
const normalized = normalizeProviderId(provider ?? "");
const pluginCapabilities = normalized
? resolveProviderCapabilitiesWithPlugin({
? providerCapabilityDeps.resolveProviderCapabilitiesWithPlugin({
provider: normalized,
config: options?.config,
workspaceDir: options?.workspaceDir,

View File

@@ -1,10 +1,11 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import { afterEach, describe, expect, it, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import type { SubagentRunRecord } from "../../agents/subagent-registry.js";
import type { OpenClawConfig } from "../../config/config.js";
import {
__testing as abortTesting,
getAbortMemory,
getAbortMemorySizeForTest,
isAbortRequestText,
@@ -17,6 +18,7 @@ import {
tryFastAbortFromMessage,
} from "./abort.js";
import { enqueueFollowupRun, getFollowupQueueDepth, type FollowupRun } from "./queue.js";
import { __testing as queueCleanupTesting } from "./queue/cleanup.js";
import { initSessionState } from "./session.js";
import { buildTestCtx } from "./test-ctx.js";
@@ -26,7 +28,7 @@ vi.mock("../../agents/pi-embedded.js", () => ({
}));
const commandQueueMocks = vi.hoisted(() => ({
clearCommandLane: vi.fn(),
clearCommandLane: vi.fn(() => 1),
}));
vi.mock("../../process/command-queue.js", () => commandQueueMocks);
@@ -162,8 +164,29 @@ describe("abort detection", () => {
expect(commandQueueMocks.clearCommandLane).toHaveBeenCalledWith(`session:${sessionKey}`);
}
beforeEach(() => {
abortTesting.setDepsForTests({
getAcpSessionManager: (() =>
({
resolveSession: acpManagerMocks.resolveSession,
cancelSession: acpManagerMocks.cancelSession,
}) as never) as never,
abortEmbeddedPiRun: () => true,
listSubagentRunsForController: subagentRegistryMocks.listSubagentRunsForRequester,
markSubagentRunTerminated: subagentRegistryMocks.markSubagentRunTerminated,
});
queueCleanupTesting.setDepsForTests({
resolveEmbeddedSessionLane: (key) => `session:${key.trim() || "main"}`,
clearCommandLane: commandQueueMocks.clearCommandLane,
});
commandQueueMocks.clearCommandLane.mockClear().mockReturnValue(1);
});
afterEach(() => {
resetAbortMemoryForTest();
abortTesting.resetDepsForTests();
queueCleanupTesting.resetDepsForTests();
commandQueueMocks.clearCommandLane.mockClear().mockReturnValue(1);
acpManagerMocks.resolveSession.mockReset().mockReturnValue({ kind: "none" });
acpManagerMocks.cancelSession.mockReset().mockResolvedValue(undefined);
});

View File

@@ -47,6 +47,35 @@ export {
setAbortMemory,
};
const defaultAbortDeps = {
getAcpSessionManager,
abortEmbeddedPiRun,
listSubagentRunsForController,
markSubagentRunTerminated,
};
const abortDeps = {
...defaultAbortDeps,
};
export const __testing = {
setDepsForTests(deps: Partial<typeof defaultAbortDeps> | undefined): void {
abortDeps.getAcpSessionManager =
deps?.getAcpSessionManager ?? defaultAbortDeps.getAcpSessionManager;
abortDeps.abortEmbeddedPiRun = deps?.abortEmbeddedPiRun ?? defaultAbortDeps.abortEmbeddedPiRun;
abortDeps.listSubagentRunsForController =
deps?.listSubagentRunsForController ?? defaultAbortDeps.listSubagentRunsForController;
abortDeps.markSubagentRunTerminated =
deps?.markSubagentRunTerminated ?? defaultAbortDeps.markSubagentRunTerminated;
},
resetDepsForTests(): void {
abortDeps.getAcpSessionManager = defaultAbortDeps.getAcpSessionManager;
abortDeps.abortEmbeddedPiRun = defaultAbortDeps.abortEmbeddedPiRun;
abortDeps.listSubagentRunsForController = defaultAbortDeps.listSubagentRunsForController;
abortDeps.markSubagentRunTerminated = defaultAbortDeps.markSubagentRunTerminated;
},
};
export function formatAbortReplyText(stoppedSubagents?: number): string {
if (typeof stoppedSubagents !== "number" || stoppedSubagents <= 0) {
return "⚙️ Agent was aborted.";
@@ -107,7 +136,7 @@ export function stopSubagentsForRequester(params: {
if (!requesterKey) {
return { stopped: 0 };
}
const runs = listSubagentRunsForController(requesterKey);
const runs = abortDeps.listSubagentRunsForController(requesterKey);
if (runs.length === 0) {
return { stopped: 0 };
}
@@ -134,9 +163,9 @@ export function stopSubagentsForRequester(params: {
}
const entry = store[childKey];
const sessionId = entry?.sessionId;
const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false;
const aborted = sessionId ? abortDeps.abortEmbeddedPiRun(sessionId) : false;
const markedTerminated =
markSubagentRunTerminated({
abortDeps.markSubagentRunTerminated({
runId: run.runId,
childSessionKey: childKey,
reason: "killed",
@@ -198,7 +227,7 @@ export async function tryFastAbortFromMessage(params: {
const store = loadSessionStore(storePath);
const { entry, key, legacyKeys } = resolveSessionEntryForKey(store, targetKey);
const resolvedTargetKey = key ?? targetKey;
const acpManager = getAcpSessionManager();
const acpManager = abortDeps.getAcpSessionManager();
const acpResolution = acpManager.resolveSession({
cfg,
sessionKey: resolvedTargetKey,
@@ -217,7 +246,7 @@ export async function tryFastAbortFromMessage(params: {
}
}
const sessionId = entry?.sessionId;
const aborted = sessionId ? abortEmbeddedPiRun(sessionId) : false;
const aborted = sessionId ? abortDeps.abortEmbeddedPiRun(sessionId) : false;
const cleared = clearSessionQueues([resolvedTargetKey, sessionId]);
if (cleared.followupCleared > 0 || cleared.laneCleared > 0) {
logVerbose(

View File

@@ -9,6 +9,29 @@ export type ClearSessionQueueResult = {
keys: string[];
};
const defaultQueueCleanupDeps = {
resolveEmbeddedSessionLane,
clearCommandLane,
};
const queueCleanupDeps = {
...defaultQueueCleanupDeps,
};
export const __testing = {
setDepsForTests(deps: Partial<typeof defaultQueueCleanupDeps> | undefined): void {
queueCleanupDeps.resolveEmbeddedSessionLane =
deps?.resolveEmbeddedSessionLane ?? defaultQueueCleanupDeps.resolveEmbeddedSessionLane;
queueCleanupDeps.clearCommandLane =
deps?.clearCommandLane ?? defaultQueueCleanupDeps.clearCommandLane;
},
resetDepsForTests(): void {
queueCleanupDeps.resolveEmbeddedSessionLane =
defaultQueueCleanupDeps.resolveEmbeddedSessionLane;
queueCleanupDeps.clearCommandLane = defaultQueueCleanupDeps.clearCommandLane;
},
};
export function clearSessionQueues(keys: Array<string | undefined>): ClearSessionQueueResult {
const seen = new Set<string>();
let followupCleared = 0;
@@ -24,7 +47,9 @@ export function clearSessionQueues(keys: Array<string | undefined>): ClearSessio
clearedKeys.push(cleaned);
followupCleared += clearFollowupQueue(cleaned);
clearFollowupDrainCallback(cleaned);
laneCleared += clearCommandLane(resolveEmbeddedSessionLane(cleaned));
laneCleared += queueCleanupDeps.clearCommandLane(
queueCleanupDeps.resolveEmbeddedSessionLane(cleaned),
);
}
return { followupCleared, laneCleared, keys: clearedKeys };

View File

@@ -28,54 +28,42 @@ let startMode: StartMode = "hello";
let closeCode = 1006;
let closeReason = "";
let helloMethods: string[] | undefined = ["health", "secrets.resolve"];
vi.mock("./client.js", () => ({
describeGatewayCloseCode: (code: number) => {
if (code === 1000) {
return "normal closure";
}
if (code === 1006) {
return "abnormal closure (no close frame)";
}
return undefined;
},
GatewayClient: class {
constructor(opts: {
url?: string;
token?: string;
password?: string;
scopes?: string[];
onHelloOk?: (hello: { features?: { methods?: string[] } }) => void | Promise<void>;
onClose?: (code: number, reason: string) => void;
}) {
lastClientOptions = opts;
}
async request(
method: string,
params: unknown,
opts?: { expectFinal?: boolean; timeoutMs?: number | null },
) {
lastRequestOptions = { method, params, opts };
return { ok: true };
}
start() {
if (startMode === "hello") {
void lastClientOptions?.onHelloOk?.({
features: {
methods: helloMethods,
},
});
} else if (startMode === "close") {
lastClientOptions?.onClose?.(closeCode, closeReason);
}
}
stop() {}
},
}));
const { buildGatewayConnectionDetails, callGateway, callGatewayCli, callGatewayScoped } =
const { __testing, buildGatewayConnectionDetails, callGateway, callGatewayCli, callGatewayScoped } =
await import("./call.js");
class StubGatewayClient {
constructor(opts: {
url?: string;
token?: string;
password?: string;
scopes?: string[];
onHelloOk?: (hello: { features?: { methods?: string[] } }) => void | Promise<void>;
onClose?: (code: number, reason: string) => void;
}) {
lastClientOptions = opts;
}
async request(
method: string,
params: unknown,
opts?: { expectFinal?: boolean; timeoutMs?: number | null },
) {
lastRequestOptions = { method, params, opts };
return { ok: true };
}
start() {
if (startMode === "hello") {
void lastClientOptions?.onHelloOk?.({
features: {
methods: helloMethods,
},
});
} else if (startMode === "close") {
lastClientOptions?.onClose?.(closeCode, closeReason);
}
}
stop() {}
}
function resetGatewayCallMocks() {
loadConfig.mockClear();
resolveGatewayPort.mockClear();
@@ -87,6 +75,17 @@ function resetGatewayCallMocks() {
closeCode = 1006;
closeReason = "";
helloMethods = ["health", "secrets.resolve"];
const loadConfigForTests = loadConfig as unknown as () => OpenClawConfig;
const resolveGatewayPortForTests = resolveGatewayPort as unknown as (
cfg?: OpenClawConfig,
env?: NodeJS.ProcessEnv,
) => number;
__testing.setDepsForTests({
createGatewayClient: (opts) =>
new StubGatewayClient(opts as ConstructorParameters<typeof StubGatewayClient>[0]) as never,
loadConfig: loadConfigForTests,
resolveGatewayPort: resolveGatewayPortForTests,
});
}
function setGatewayNetworkDefaults(port = 18789) {
@@ -126,6 +125,7 @@ describe("callGateway url resolution", () => {
afterEach(() => {
envSnapshot.restore();
__testing.resetDepsForTests();
});
it.each([

View File

@@ -17,7 +17,7 @@ import {
type GatewayClientName,
} from "../utils/message-channel.js";
import { VERSION } from "../version.js";
import { GatewayClient } from "./client.js";
import { GatewayClient, type GatewayClientOptions } from "./client.js";
import {
GatewaySecretRefUnavailableError,
resolveGatewayCredentialsFromConfig,
@@ -81,6 +81,47 @@ export type GatewayConnectionDetails = {
message: string;
};
const defaultCreateGatewayClient = (opts: GatewayClientOptions) => new GatewayClient(opts);
const defaultGatewayCallDeps = {
createGatewayClient: defaultCreateGatewayClient,
loadConfig,
resolveGatewayPort,
resolveConfigPath,
resolveStateDir,
loadGatewayTlsRuntime,
};
const gatewayCallDeps = {
...defaultGatewayCallDeps,
};
export const __testing = {
setDepsForTests(deps: Partial<typeof defaultGatewayCallDeps> | undefined): void {
gatewayCallDeps.createGatewayClient =
deps?.createGatewayClient ?? defaultGatewayCallDeps.createGatewayClient;
gatewayCallDeps.loadConfig = deps?.loadConfig ?? defaultGatewayCallDeps.loadConfig;
gatewayCallDeps.resolveGatewayPort =
deps?.resolveGatewayPort ?? defaultGatewayCallDeps.resolveGatewayPort;
gatewayCallDeps.resolveConfigPath =
deps?.resolveConfigPath ?? defaultGatewayCallDeps.resolveConfigPath;
gatewayCallDeps.resolveStateDir =
deps?.resolveStateDir ?? defaultGatewayCallDeps.resolveStateDir;
gatewayCallDeps.loadGatewayTlsRuntime =
deps?.loadGatewayTlsRuntime ?? defaultGatewayCallDeps.loadGatewayTlsRuntime;
},
setCreateGatewayClientForTests(createGatewayClient?: typeof defaultCreateGatewayClient): void {
gatewayCallDeps.createGatewayClient =
createGatewayClient ?? defaultGatewayCallDeps.createGatewayClient;
},
resetDepsForTests(): void {
gatewayCallDeps.createGatewayClient = defaultGatewayCallDeps.createGatewayClient;
gatewayCallDeps.loadConfig = defaultGatewayCallDeps.loadConfig;
gatewayCallDeps.resolveGatewayPort = defaultGatewayCallDeps.resolveGatewayPort;
gatewayCallDeps.resolveConfigPath = defaultGatewayCallDeps.resolveConfigPath;
gatewayCallDeps.resolveStateDir = defaultGatewayCallDeps.resolveStateDir;
gatewayCallDeps.loadGatewayTlsRuntime = defaultGatewayCallDeps.loadGatewayTlsRuntime;
},
};
function shouldAttachDeviceIdentityForGatewayCall(params: {
url: string;
token?: string;
@@ -155,13 +196,14 @@ export function buildGatewayConnectionDetails(
urlSource?: "cli" | "env";
} = {},
): GatewayConnectionDetails {
const config = options.config ?? loadConfig();
const config = options.config ?? gatewayCallDeps.loadConfig();
const configPath =
options.configPath ?? resolveConfigPath(process.env, resolveStateDir(process.env));
options.configPath ??
gatewayCallDeps.resolveConfigPath(process.env, gatewayCallDeps.resolveStateDir(process.env));
const isRemoteMode = config.gateway?.mode === "remote";
const remote = isRemoteMode ? config.gateway?.remote : undefined;
const tlsEnabled = config.gateway?.tls?.enabled === true;
const localPort = resolveGatewayPort(config);
const localPort = gatewayCallDeps.resolveGatewayPort(config);
const bindMode = config.gateway?.bind ?? "loopback";
const scheme = tlsEnabled ? "wss" : "ws";
// Self-connections should always target loopback; bind mode only controls listener exposure.
@@ -273,9 +315,10 @@ function resolveGatewayCallTimeout(timeoutValue: unknown): {
}
function resolveGatewayCallContext(opts: CallGatewayBaseOptions): ResolvedGatewayCallContext {
const config = opts.config ?? loadConfig();
const config = opts.config ?? gatewayCallDeps.loadConfig();
const configPath =
opts.configPath ?? resolveConfigPath(process.env, resolveStateDir(process.env));
opts.configPath ??
gatewayCallDeps.resolveConfigPath(process.env, gatewayCallDeps.resolveStateDir(process.env));
const isRemoteMode = config.gateway?.mode === "remote";
const remote = isRemoteMode
? (config.gateway?.remote as GatewayRemoteSettings | undefined)
@@ -683,7 +726,10 @@ export async function resolveGatewayCredentialsWithSecretInputs(params: {
: undefined;
const context: ResolvedGatewayCallContext = {
config: params.config,
configPath: resolveConfigPath(process.env, resolveStateDir(process.env)),
configPath: gatewayCallDeps.resolveConfigPath(
process.env,
gatewayCallDeps.resolveStateDir(process.env),
),
isRemoteMode,
remote: remoteFromOverride ?? remoteFromConfig,
urlOverride: trimToUndefined(params.urlOverride),
@@ -715,7 +761,7 @@ async function resolveGatewayTlsFingerprint(params: {
!context.remoteUrl &&
url.startsWith("wss://");
const tlsRuntime = useLocalTls
? await loadGatewayTlsRuntime(context.config.gateway?.tls)
? await gatewayCallDeps.loadGatewayTlsRuntime(context.config.gateway?.tls)
: undefined;
const overrideTlsFingerprint = trimToUndefined(opts.tlsFingerprint);
const remoteTlsFingerprint =
@@ -809,7 +855,7 @@ async function executeGatewayRequestWithScopes<T>(params: {
}
};
const client = new GatewayClient({
const client = gatewayCallDeps.createGatewayClient({
url,
token,
password,