mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-29 10:02:04 +00:00
Merge branch 'main' into dashboard-v2-views-refactor
This commit is contained in:
@@ -44,11 +44,11 @@ import {
|
||||
type TurnLatencyStats,
|
||||
} from "./manager.types.js";
|
||||
import {
|
||||
canonicalizeAcpSessionKey,
|
||||
createUnsupportedControlError,
|
||||
hasLegacyAcpIdentityProjection,
|
||||
normalizeAcpErrorCode,
|
||||
normalizeActorKey,
|
||||
normalizeSessionKey,
|
||||
requireReadySessionMeta,
|
||||
resolveAcpAgentFromSessionKey,
|
||||
resolveAcpSessionResolutionError,
|
||||
@@ -87,7 +87,7 @@ export class AcpSessionManager {
|
||||
constructor(private readonly deps: AcpSessionManagerDeps = DEFAULT_DEPS) {}
|
||||
|
||||
resolveSession(params: { cfg: OpenClawConfig; sessionKey: string }): AcpSessionResolution {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
if (!sessionKey) {
|
||||
return {
|
||||
kind: "none",
|
||||
@@ -213,7 +213,10 @@ export class AcpSessionManager {
|
||||
handle: AcpRuntimeHandle;
|
||||
meta: SessionAcpMeta;
|
||||
}> {
|
||||
const sessionKey = normalizeSessionKey(input.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey({
|
||||
cfg: input.cfg,
|
||||
sessionKey: input.sessionKey,
|
||||
});
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -321,7 +324,7 @@ export class AcpSessionManager {
|
||||
sessionKey: string;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<AcpSessionStatus> {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -397,7 +400,7 @@ export class AcpSessionManager {
|
||||
sessionKey: string;
|
||||
runtimeMode: string;
|
||||
}): Promise<AcpSessionRuntimeOptions> {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -452,7 +455,7 @@ export class AcpSessionManager {
|
||||
key: string;
|
||||
value: string;
|
||||
}): Promise<AcpSessionRuntimeOptions> {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -525,7 +528,7 @@ export class AcpSessionManager {
|
||||
sessionKey: string;
|
||||
patch: Partial<AcpSessionRuntimeOptions>;
|
||||
}): Promise<AcpSessionRuntimeOptions> {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
const validatedPatch = validateRuntimeOptionPatch(params.patch);
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
@@ -555,7 +558,7 @@ export class AcpSessionManager {
|
||||
cfg: OpenClawConfig;
|
||||
sessionKey: string;
|
||||
}): Promise<AcpSessionRuntimeOptions> {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -591,7 +594,10 @@ export class AcpSessionManager {
|
||||
}
|
||||
|
||||
async runTurn(input: AcpRunTurnInput): Promise<void> {
|
||||
const sessionKey = normalizeSessionKey(input.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey({
|
||||
cfg: input.cfg,
|
||||
sessionKey: input.sessionKey,
|
||||
});
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -738,7 +744,7 @@ export class AcpSessionManager {
|
||||
sessionKey: string;
|
||||
reason?: string;
|
||||
}): Promise<void> {
|
||||
const sessionKey = normalizeSessionKey(params.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey(params);
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
@@ -806,7 +812,10 @@ export class AcpSessionManager {
|
||||
}
|
||||
|
||||
async closeSession(input: AcpCloseSessionInput): Promise<AcpCloseSessionResult> {
|
||||
const sessionKey = normalizeSessionKey(input.sessionKey);
|
||||
const sessionKey = canonicalizeAcpSessionKey({
|
||||
cfg: input.cfg,
|
||||
sessionKey: input.sessionKey,
|
||||
});
|
||||
if (!sessionKey) {
|
||||
throw new AcpRuntimeError("ACP_SESSION_INIT_FAILED", "ACP session key is required.");
|
||||
}
|
||||
|
||||
@@ -170,6 +170,57 @@ describe("AcpSessionManager", () => {
|
||||
expect(resolved.error.message).toContain("ACP metadata is missing");
|
||||
});
|
||||
|
||||
it("canonicalizes the main alias before ACP rehydrate after restart", async () => {
|
||||
const runtimeState = createRuntime();
|
||||
hoisted.requireAcpRuntimeBackendMock.mockReturnValue({
|
||||
id: "acpx",
|
||||
runtime: runtimeState.runtime,
|
||||
});
|
||||
hoisted.readAcpSessionEntryMock.mockImplementation((paramsUnknown: unknown) => {
|
||||
const sessionKey = (paramsUnknown as { sessionKey?: string }).sessionKey;
|
||||
if (sessionKey !== "agent:main:main") {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
sessionKey,
|
||||
storeSessionKey: sessionKey,
|
||||
acp: {
|
||||
...readySessionMeta(),
|
||||
agent: "main",
|
||||
runtimeSessionName: sessionKey,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
const manager = new AcpSessionManager();
|
||||
const cfg = {
|
||||
...baseCfg,
|
||||
session: { mainKey: "main" },
|
||||
agents: { list: [{ id: "main", default: true }] },
|
||||
} as OpenClawConfig;
|
||||
|
||||
await manager.runTurn({
|
||||
cfg,
|
||||
sessionKey: "main",
|
||||
text: "after restart",
|
||||
mode: "prompt",
|
||||
requestId: "r-main",
|
||||
});
|
||||
|
||||
expect(hoisted.readAcpSessionEntryMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
cfg,
|
||||
sessionKey: "agent:main:main",
|
||||
}),
|
||||
);
|
||||
expect(runtimeState.ensureSession).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
agent: "main",
|
||||
sessionKey: "agent:main:main",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("serializes concurrent turns for the same ACP session", async () => {
|
||||
const runtimeState = createRuntime();
|
||||
hoisted.requireAcpRuntimeBackendMock.mockReturnValue({
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
import type { OpenClawConfig } from "../../config/config.js";
|
||||
import {
|
||||
canonicalizeMainSessionAlias,
|
||||
resolveMainSessionKey,
|
||||
} from "../../config/sessions/main-session.js";
|
||||
import type { SessionAcpMeta } from "../../config/sessions/types.js";
|
||||
import { normalizeAgentId, parseAgentSessionKey } from "../../routing/session-key.js";
|
||||
import {
|
||||
normalizeAgentId,
|
||||
normalizeMainKey,
|
||||
parseAgentSessionKey,
|
||||
} from "../../routing/session-key.js";
|
||||
import { ACP_ERROR_CODES, AcpRuntimeError } from "../runtime/errors.js";
|
||||
import type { AcpSessionResolution } from "./manager.types.js";
|
||||
|
||||
@@ -42,6 +50,33 @@ export function normalizeSessionKey(sessionKey: string): string {
|
||||
return sessionKey.trim();
|
||||
}
|
||||
|
||||
export function canonicalizeAcpSessionKey(params: {
|
||||
cfg: OpenClawConfig;
|
||||
sessionKey: string;
|
||||
}): string {
|
||||
const normalized = normalizeSessionKey(params.sessionKey);
|
||||
if (!normalized) {
|
||||
return "";
|
||||
}
|
||||
const lowered = normalized.toLowerCase();
|
||||
if (lowered === "global" || lowered === "unknown") {
|
||||
return lowered;
|
||||
}
|
||||
const parsed = parseAgentSessionKey(lowered);
|
||||
if (parsed) {
|
||||
return canonicalizeMainSessionAlias({
|
||||
cfg: params.cfg,
|
||||
agentId: parsed.agentId,
|
||||
sessionKey: lowered,
|
||||
});
|
||||
}
|
||||
const mainKey = normalizeMainKey(params.cfg.session?.mainKey);
|
||||
if (lowered === "main" || lowered === mainKey) {
|
||||
return resolveMainSessionKey(params.cfg);
|
||||
}
|
||||
return lowered;
|
||||
}
|
||||
|
||||
export function normalizeActorKey(sessionKey: string): string {
|
||||
return sessionKey.trim().toLowerCase();
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ function createSetSessionModeRequest(sessionId: string, modeId: string): SetSess
|
||||
function createSetSessionConfigOptionRequest(
|
||||
sessionId: string,
|
||||
configId: string,
|
||||
value: string,
|
||||
value: string | boolean,
|
||||
): SetSessionConfigOptionRequest {
|
||||
return {
|
||||
sessionId,
|
||||
@@ -644,6 +644,55 @@ describe("acp setSessionConfigOption bridge behavior", () => {
|
||||
|
||||
sessionStore.clearAllSessionsForTest();
|
||||
});
|
||||
|
||||
it("rejects non-string ACP config option values", async () => {
|
||||
const sessionStore = createInMemorySessionStore();
|
||||
const connection = createAcpConnection();
|
||||
const request = vi.fn(async (method: string) => {
|
||||
if (method === "sessions.list") {
|
||||
return {
|
||||
ts: Date.now(),
|
||||
path: "/tmp/sessions.json",
|
||||
count: 1,
|
||||
defaults: {
|
||||
modelProvider: null,
|
||||
model: null,
|
||||
contextTokens: null,
|
||||
},
|
||||
sessions: [
|
||||
{
|
||||
key: "bool-config-session",
|
||||
kind: "direct",
|
||||
updatedAt: Date.now(),
|
||||
thinkingLevel: "minimal",
|
||||
modelProvider: "openai",
|
||||
model: "gpt-5.4",
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
return { ok: true };
|
||||
}) as GatewayClient["request"];
|
||||
const agent = new AcpGatewayAgent(connection, createAcpGateway(request), {
|
||||
sessionStore,
|
||||
});
|
||||
|
||||
await agent.loadSession(createLoadSessionRequest("bool-config-session"));
|
||||
|
||||
await expect(
|
||||
agent.setSessionConfigOption(
|
||||
createSetSessionConfigOptionRequest("bool-config-session", "thought_level", false),
|
||||
),
|
||||
).rejects.toThrow(
|
||||
'ACP bridge does not support non-string session config option values for "thought_level".',
|
||||
);
|
||||
expect(request).not.toHaveBeenCalledWith(
|
||||
"sessions.patch",
|
||||
expect.objectContaining({ key: "bool-config-session" }),
|
||||
);
|
||||
|
||||
sessionStore.clearAllSessionsForTest();
|
||||
});
|
||||
});
|
||||
|
||||
describe("acp tool streaming bridge behavior", () => {
|
||||
|
||||
@@ -937,11 +937,16 @@ export class AcpGatewayAgent implements Agent {
|
||||
|
||||
private resolveSessionConfigPatch(
|
||||
configId: string,
|
||||
value: string,
|
||||
value: string | boolean,
|
||||
): {
|
||||
overrides: Partial<GatewaySessionPresentationRow>;
|
||||
patch: Record<string, string>;
|
||||
} {
|
||||
if (typeof value !== "string") {
|
||||
throw new Error(
|
||||
`ACP bridge does not support non-string session config option values for "${configId}".`,
|
||||
);
|
||||
}
|
||||
switch (configId) {
|
||||
case ACP_THOUGHT_LEVEL_CONFIG_ID:
|
||||
return {
|
||||
|
||||
@@ -207,7 +207,7 @@ describe("resolveProfilesUnavailableReason", () => {
|
||||
).toBe("overloaded");
|
||||
});
|
||||
|
||||
it("falls back to rate_limit when active cooldown has no reason history", () => {
|
||||
it("falls back to unknown when active cooldown has no reason history", () => {
|
||||
const now = Date.now();
|
||||
const store = makeStore({
|
||||
"anthropic:default": {
|
||||
@@ -221,7 +221,7 @@ describe("resolveProfilesUnavailableReason", () => {
|
||||
profileIds: ["anthropic:default"],
|
||||
now,
|
||||
}),
|
||||
).toBe("rate_limit");
|
||||
).toBe("unknown");
|
||||
});
|
||||
|
||||
it("ignores expired windows and returns null when no profile is actively unavailable", () => {
|
||||
|
||||
@@ -110,7 +110,11 @@ export function resolveProfilesUnavailableReason(params: {
|
||||
recordedReason = true;
|
||||
}
|
||||
if (!recordedReason) {
|
||||
addScore("rate_limit", 1);
|
||||
// No failure counts recorded for this cooldown window. Previously this
|
||||
// defaulted to "rate_limit", which caused false "rate limit reached"
|
||||
// warnings when the actual reason was unknown (e.g. transient network
|
||||
// blip or server error without a classified failure count).
|
||||
addScore("unknown", 1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -274,6 +274,8 @@ describe("failover-error", () => {
|
||||
it("infers timeout from common node error codes", () => {
|
||||
expect(resolveFailoverReasonFromError({ code: "ETIMEDOUT" })).toBe("timeout");
|
||||
expect(resolveFailoverReasonFromError({ code: "ECONNRESET" })).toBe("timeout");
|
||||
expect(resolveFailoverReasonFromError({ code: "EHOSTDOWN" })).toBe("timeout");
|
||||
expect(resolveFailoverReasonFromError({ code: "EPIPE" })).toBe("timeout");
|
||||
});
|
||||
|
||||
it("infers timeout from abort/error stop-reason messages", () => {
|
||||
|
||||
@@ -170,7 +170,9 @@ export function resolveFailoverReasonFromError(err: unknown): FailoverReason | n
|
||||
"ECONNREFUSED",
|
||||
"ENETUNREACH",
|
||||
"EHOSTUNREACH",
|
||||
"EHOSTDOWN",
|
||||
"ENETRESET",
|
||||
"EPIPE",
|
||||
"EAI_AGAIN",
|
||||
].includes(code)
|
||||
) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import type { ModelDefinitionConfig } from "../config/types.models.js";
|
||||
import { createSubsystemLogger } from "../logging/subsystem.js";
|
||||
import { isReasoningModelHeuristic } from "./ollama-models.js";
|
||||
|
||||
const log = createSubsystemLogger("huggingface-models");
|
||||
|
||||
@@ -125,7 +126,7 @@ export function buildHuggingfaceModelDefinition(
|
||||
*/
|
||||
function inferredMetaFromModelId(id: string): { name: string; reasoning: boolean } {
|
||||
const base = id.split("/").pop() ?? id;
|
||||
const reasoning = /r1|reasoning|thinking|reason/i.test(id) || /-\d+[tb]?-thinking/i.test(base);
|
||||
const reasoning = isReasoningModelHeuristic(id);
|
||||
const name = base.replace(/-/g, " ").replace(/\b(\w)/g, (c) => c.toUpperCase());
|
||||
return { name, reasoning };
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ export function isModernModelRef(ref: ModelRef): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (provider === "openrouter" || provider === "opencode") {
|
||||
if (provider === "openrouter" || provider === "opencode" || provider === "opencode-go") {
|
||||
// OpenRouter/opencode are pass-through proxies; accept any model ID
|
||||
// rather than restricting to a static prefix list.
|
||||
return true;
|
||||
|
||||
@@ -131,6 +131,113 @@ describe("memory search config", () => {
|
||||
expect(resolved?.extraPaths).toEqual(["/shared/notes", "docs", "../team-notes"]);
|
||||
});
|
||||
|
||||
it("normalizes multimodal settings", () => {
|
||||
const cfg = asConfig({
|
||||
agents: {
|
||||
defaults: {
|
||||
memorySearch: {
|
||||
provider: "gemini",
|
||||
model: "gemini-embedding-2-preview",
|
||||
multimodal: {
|
||||
enabled: true,
|
||||
modalities: ["all"],
|
||||
maxFileBytes: 8192,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
const resolved = resolveMemorySearchConfig(cfg, "main");
|
||||
expect(resolved?.multimodal).toEqual({
|
||||
enabled: true,
|
||||
modalities: ["image", "audio"],
|
||||
maxFileBytes: 8192,
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps an explicit empty multimodal modalities list empty", () => {
|
||||
const cfg = asConfig({
|
||||
agents: {
|
||||
defaults: {
|
||||
memorySearch: {
|
||||
provider: "gemini",
|
||||
model: "gemini-embedding-2-preview",
|
||||
multimodal: {
|
||||
enabled: true,
|
||||
modalities: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
const resolved = resolveMemorySearchConfig(cfg, "main");
|
||||
expect(resolved?.multimodal).toEqual({
|
||||
enabled: true,
|
||||
modalities: [],
|
||||
maxFileBytes: 10 * 1024 * 1024,
|
||||
});
|
||||
expect(resolved?.provider).toBe("gemini");
|
||||
});
|
||||
|
||||
it("does not enforce multimodal provider validation when no modalities are active", () => {
|
||||
const cfg = asConfig({
|
||||
agents: {
|
||||
defaults: {
|
||||
memorySearch: {
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
fallback: "openai",
|
||||
multimodal: {
|
||||
enabled: true,
|
||||
modalities: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
const resolved = resolveMemorySearchConfig(cfg, "main");
|
||||
expect(resolved?.multimodal).toEqual({
|
||||
enabled: true,
|
||||
modalities: [],
|
||||
maxFileBytes: 10 * 1024 * 1024,
|
||||
});
|
||||
});
|
||||
|
||||
it("rejects multimodal memory on unsupported providers", () => {
|
||||
const cfg = asConfig({
|
||||
agents: {
|
||||
defaults: {
|
||||
memorySearch: {
|
||||
provider: "openai",
|
||||
model: "text-embedding-3-small",
|
||||
multimodal: { enabled: true, modalities: ["image"] },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
expect(() => resolveMemorySearchConfig(cfg, "main")).toThrow(
|
||||
/memorySearch\.multimodal requires memorySearch\.provider = "gemini"/,
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects multimodal memory when fallback is configured", () => {
|
||||
const cfg = asConfig({
|
||||
agents: {
|
||||
defaults: {
|
||||
memorySearch: {
|
||||
provider: "gemini",
|
||||
model: "gemini-embedding-2-preview",
|
||||
fallback: "openai",
|
||||
multimodal: { enabled: true, modalities: ["image"] },
|
||||
},
|
||||
},
|
||||
},
|
||||
});
|
||||
expect(() => resolveMemorySearchConfig(cfg, "main")).toThrow(
|
||||
/memorySearch\.multimodal does not support memorySearch\.fallback/,
|
||||
);
|
||||
});
|
||||
|
||||
it("includes batch defaults for openai without remote overrides", () => {
|
||||
const cfg = configWithDefaultProvider("openai");
|
||||
const resolved = resolveMemorySearchConfig(cfg, "main");
|
||||
|
||||
@@ -3,6 +3,12 @@ import path from "node:path";
|
||||
import type { OpenClawConfig, MemorySearchConfig } from "../config/config.js";
|
||||
import { resolveStateDir } from "../config/paths.js";
|
||||
import type { SecretInput } from "../config/types.secrets.js";
|
||||
import {
|
||||
isMemoryMultimodalEnabled,
|
||||
normalizeMemoryMultimodalSettings,
|
||||
supportsMemoryMultimodalEmbeddings,
|
||||
type MemoryMultimodalSettings,
|
||||
} from "../memory/multimodal.js";
|
||||
import { clampInt, clampNumber, resolveUserPath } from "../utils.js";
|
||||
import { resolveAgentConfig } from "./agent-scope.js";
|
||||
|
||||
@@ -10,6 +16,7 @@ export type ResolvedMemorySearchConfig = {
|
||||
enabled: boolean;
|
||||
sources: Array<"memory" | "sessions">;
|
||||
extraPaths: string[];
|
||||
multimodal: MemoryMultimodalSettings;
|
||||
provider: "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama" | "auto";
|
||||
remote?: {
|
||||
baseUrl?: string;
|
||||
@@ -28,6 +35,7 @@ export type ResolvedMemorySearchConfig = {
|
||||
};
|
||||
fallback: "openai" | "gemini" | "local" | "voyage" | "mistral" | "ollama" | "none";
|
||||
model: string;
|
||||
outputDimensionality?: number;
|
||||
local: {
|
||||
modelPath?: string;
|
||||
modelCacheDir?: string;
|
||||
@@ -193,6 +201,7 @@ function mergeConfig(
|
||||
? DEFAULT_OLLAMA_MODEL
|
||||
: undefined;
|
||||
const model = overrides?.model ?? defaults?.model ?? modelDefault ?? "";
|
||||
const outputDimensionality = overrides?.outputDimensionality ?? defaults?.outputDimensionality;
|
||||
const local = {
|
||||
modelPath: overrides?.local?.modelPath ?? defaults?.local?.modelPath,
|
||||
modelCacheDir: overrides?.local?.modelCacheDir ?? defaults?.local?.modelCacheDir,
|
||||
@@ -202,6 +211,11 @@ function mergeConfig(
|
||||
.map((value) => value.trim())
|
||||
.filter(Boolean);
|
||||
const extraPaths = Array.from(new Set(rawPaths));
|
||||
const multimodal = normalizeMemoryMultimodalSettings({
|
||||
enabled: overrides?.multimodal?.enabled ?? defaults?.multimodal?.enabled,
|
||||
modalities: overrides?.multimodal?.modalities ?? defaults?.multimodal?.modalities,
|
||||
maxFileBytes: overrides?.multimodal?.maxFileBytes ?? defaults?.multimodal?.maxFileBytes,
|
||||
});
|
||||
const vector = {
|
||||
enabled: overrides?.store?.vector?.enabled ?? defaults?.store?.vector?.enabled ?? true,
|
||||
extensionPath:
|
||||
@@ -305,6 +319,7 @@ function mergeConfig(
|
||||
enabled,
|
||||
sources,
|
||||
extraPaths,
|
||||
multimodal,
|
||||
provider,
|
||||
remote,
|
||||
experimental: {
|
||||
@@ -312,6 +327,7 @@ function mergeConfig(
|
||||
},
|
||||
fallback,
|
||||
model,
|
||||
outputDimensionality,
|
||||
local,
|
||||
store,
|
||||
chunking: { tokens: Math.max(1, chunking.tokens), overlap },
|
||||
@@ -362,5 +378,22 @@ export function resolveMemorySearchConfig(
|
||||
if (!resolved.enabled) {
|
||||
return null;
|
||||
}
|
||||
const multimodalActive = isMemoryMultimodalEnabled(resolved.multimodal);
|
||||
if (
|
||||
multimodalActive &&
|
||||
!supportsMemoryMultimodalEmbeddings({
|
||||
provider: resolved.provider,
|
||||
model: resolved.model,
|
||||
})
|
||||
) {
|
||||
throw new Error(
|
||||
'agents.*.memorySearch.multimodal requires memorySearch.provider = "gemini" and model = "gemini-embedding-2-preview".',
|
||||
);
|
||||
}
|
||||
if (multimodalActive && resolved.fallback !== "none") {
|
||||
throw new Error(
|
||||
'agents.*.memorySearch.multimodal does not support memorySearch.fallback. Set fallback to "none".',
|
||||
);
|
||||
}
|
||||
return resolved;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ export const PROVIDER_ENV_API_KEY_CANDIDATES: Record<string, string[]> = {
|
||||
chutes: ["CHUTES_OAUTH_TOKEN", "CHUTES_API_KEY"],
|
||||
zai: ["ZAI_API_KEY", "Z_AI_API_KEY"],
|
||||
opencode: ["OPENCODE_API_KEY", "OPENCODE_ZEN_API_KEY"],
|
||||
"opencode-go": ["OPENCODE_API_KEY", "OPENCODE_ZEN_API_KEY"],
|
||||
"qwen-portal": ["QWEN_OAUTH_TOKEN", "QWEN_PORTAL_API_KEY"],
|
||||
volcengine: ["VOLCANO_ENGINE_API_KEY"],
|
||||
"volcengine-plan": ["VOLCANO_ENGINE_API_KEY"],
|
||||
|
||||
@@ -412,4 +412,18 @@ describe("getApiKeyForModel", () => {
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("resolveEnvApiKey('opencode-go') falls back to OPENCODE_ZEN_API_KEY", async () => {
|
||||
await withEnvAsync(
|
||||
{
|
||||
OPENCODE_API_KEY: undefined,
|
||||
OPENCODE_ZEN_API_KEY: "sk-opencode-zen-fallback", // pragma: allowlist secret
|
||||
},
|
||||
async () => {
|
||||
const resolved = resolveEnvApiKey("opencode-go");
|
||||
expect(resolved?.apiKey).toBe("sk-opencode-zen-fallback");
|
||||
expect(resolved?.source).toContain("OPENCODE_ZEN_API_KEY");
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -313,6 +313,12 @@ describe("isModernModelRef", () => {
|
||||
expect(isModernModelRef({ provider: "opencode", id: "claude-opus-4-6" })).toBe(true);
|
||||
expect(isModernModelRef({ provider: "opencode", id: "gemini-3-pro" })).toBe(true);
|
||||
});
|
||||
|
||||
it("accepts all opencode-go models without zen exclusions", () => {
|
||||
expect(isModernModelRef({ provider: "opencode-go", id: "kimi-k2.5" })).toBe(true);
|
||||
expect(isModernModelRef({ provider: "opencode-go", id: "glm-5" })).toBe(true);
|
||||
expect(isModernModelRef({ provider: "opencode-go", id: "minimax-m2.5" })).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("resolveForwardCompatModel", () => {
|
||||
|
||||
@@ -555,7 +555,7 @@ describe("runWithModelFallback", () => {
|
||||
usageStat: {
|
||||
cooldownUntil: Date.now() + 5 * 60_000,
|
||||
},
|
||||
expectedReason: "rate_limit",
|
||||
expectedReason: "unknown",
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -449,7 +449,7 @@ function resolveCooldownDecision(params: {
|
||||
store: params.authStore,
|
||||
profileIds: params.profileIds,
|
||||
now: params.now,
|
||||
}) ?? "rate_limit";
|
||||
}) ?? "unknown";
|
||||
const isPersistentAuthIssue = inferredReason === "auth" || inferredReason === "auth_permanent";
|
||||
if (isPersistentAuthIssue) {
|
||||
return {
|
||||
@@ -483,7 +483,10 @@ function resolveCooldownDecision(params: {
|
||||
// limits, which are often model-scoped and can recover on a sibling model.
|
||||
const shouldAttemptDespiteCooldown =
|
||||
(params.isPrimary && (!params.requestedModel || shouldProbe)) ||
|
||||
(!params.isPrimary && (inferredReason === "rate_limit" || inferredReason === "overloaded"));
|
||||
(!params.isPrimary &&
|
||||
(inferredReason === "rate_limit" ||
|
||||
inferredReason === "overloaded" ||
|
||||
inferredReason === "unknown"));
|
||||
if (!shouldAttemptDespiteCooldown) {
|
||||
return {
|
||||
type: "skip",
|
||||
@@ -588,13 +591,16 @@ export async function runWithModelFallback<T>(params: {
|
||||
if (
|
||||
decision.reason === "rate_limit" ||
|
||||
decision.reason === "overloaded" ||
|
||||
decision.reason === "billing"
|
||||
decision.reason === "billing" ||
|
||||
decision.reason === "unknown"
|
||||
) {
|
||||
// Probe at most once per provider per fallback run when all profiles
|
||||
// are cooldowned. Re-probing every same-provider candidate can stall
|
||||
// cross-provider fallback on providers with long internal retries.
|
||||
const isTransientCooldownReason =
|
||||
decision.reason === "rate_limit" || decision.reason === "overloaded";
|
||||
decision.reason === "rate_limit" ||
|
||||
decision.reason === "overloaded" ||
|
||||
decision.reason === "unknown";
|
||||
if (isTransientCooldownReason && cooldownProbeUsedProviders.has(candidate.provider)) {
|
||||
const error = `Provider ${candidate.provider} is in cooldown (probe already attempted this run)`;
|
||||
attempts.push({
|
||||
|
||||
@@ -326,12 +326,12 @@ async function probeImage(
|
||||
}
|
||||
|
||||
function ensureImageInput(model: OpenAIModel): OpenAIModel {
|
||||
if (model.input.includes("image")) {
|
||||
if (model.input?.includes("image")) {
|
||||
return model;
|
||||
}
|
||||
return {
|
||||
...model,
|
||||
input: Array.from(new Set([...model.input, "image"])),
|
||||
input: Array.from(new Set([...(model.input ?? []), "image"])),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -472,7 +472,7 @@ export async function scanOpenRouterModels(
|
||||
};
|
||||
|
||||
const toolResult = await probeTool(model, apiKey, timeoutMs);
|
||||
const imageResult = model.input.includes("image")
|
||||
const imageResult = model.input?.includes("image")
|
||||
? await probeImage(ensureImageInput(model), apiKey, timeoutMs)
|
||||
: { ok: false, latencyMs: null, skipped: true };
|
||||
|
||||
|
||||
@@ -46,6 +46,9 @@ export function normalizeProviderId(provider: string): string {
|
||||
if (normalized === "opencode-zen") {
|
||||
return "opencode";
|
||||
}
|
||||
if (normalized === "opencode-go-auth") {
|
||||
return "opencode-go";
|
||||
}
|
||||
if (normalized === "qwen") {
|
||||
return "qwen-portal";
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
type ExistingProviderConfig,
|
||||
} from "./models-config.merge.js";
|
||||
import {
|
||||
enforceSourceManagedProviderSecrets,
|
||||
normalizeProviders,
|
||||
resolveImplicitProviders,
|
||||
type ProviderConfig,
|
||||
@@ -86,6 +87,7 @@ async function resolveProvidersForMode(params: {
|
||||
|
||||
export async function planOpenClawModelsJson(params: {
|
||||
cfg: OpenClawConfig;
|
||||
sourceConfigForSecrets?: OpenClawConfig;
|
||||
agentDir: string;
|
||||
env: NodeJS.ProcessEnv;
|
||||
existingRaw: string;
|
||||
@@ -106,6 +108,8 @@ export async function planOpenClawModelsJson(params: {
|
||||
agentDir,
|
||||
env,
|
||||
secretDefaults: cfg.secrets?.defaults,
|
||||
sourceProviders: params.sourceConfigForSecrets?.models?.providers,
|
||||
sourceSecretDefaults: params.sourceConfigForSecrets?.secrets?.defaults,
|
||||
secretRefManagedProviders,
|
||||
}) ?? providers;
|
||||
const mergedProviders = await resolveProvidersForMode({
|
||||
@@ -115,7 +119,14 @@ export async function planOpenClawModelsJson(params: {
|
||||
secretRefManagedProviders,
|
||||
explicitBaseUrlProviders: resolveExplicitBaseUrlProviders(cfg.models),
|
||||
});
|
||||
const nextContents = `${JSON.stringify({ providers: mergedProviders }, null, 2)}\n`;
|
||||
const secretEnforcedProviders =
|
||||
enforceSourceManagedProviderSecrets({
|
||||
providers: mergedProviders,
|
||||
sourceProviders: params.sourceConfigForSecrets?.models?.providers,
|
||||
sourceSecretDefaults: params.sourceConfigForSecrets?.secrets?.defaults,
|
||||
secretRefManagedProviders,
|
||||
}) ?? mergedProviders;
|
||||
const nextContents = `${JSON.stringify({ providers: secretEnforcedProviders }, null, 2)}\n`;
|
||||
|
||||
if (params.existingRaw === nextContents) {
|
||||
return { action: "noop" };
|
||||
|
||||
@@ -9,27 +9,27 @@ import {
|
||||
buildHuggingfaceModelDefinition,
|
||||
} from "./huggingface-models.js";
|
||||
import { discoverKilocodeModels } from "./kilocode-models.js";
|
||||
import { OLLAMA_NATIVE_BASE_URL } from "./ollama-stream.js";
|
||||
import {
|
||||
enrichOllamaModelsWithContext,
|
||||
OLLAMA_DEFAULT_CONTEXT_WINDOW,
|
||||
OLLAMA_DEFAULT_COST,
|
||||
OLLAMA_DEFAULT_MAX_TOKENS,
|
||||
isReasoningModelHeuristic,
|
||||
resolveOllamaApiBase,
|
||||
type OllamaTagsResponse,
|
||||
} from "./ollama-models.js";
|
||||
import { discoverVeniceModels, VENICE_BASE_URL } from "./venice-models.js";
|
||||
import { discoverVercelAiGatewayModels, VERCEL_AI_GATEWAY_BASE_URL } from "./vercel-ai-gateway.js";
|
||||
|
||||
export { resolveOllamaApiBase } from "./ollama-models.js";
|
||||
|
||||
type ModelsConfig = NonNullable<OpenClawConfig["models"]>;
|
||||
type ProviderConfig = NonNullable<ModelsConfig["providers"]>[string];
|
||||
|
||||
const log = createSubsystemLogger("agents/model-providers");
|
||||
|
||||
const OLLAMA_BASE_URL = OLLAMA_NATIVE_BASE_URL;
|
||||
const OLLAMA_API_BASE_URL = OLLAMA_BASE_URL;
|
||||
const OLLAMA_SHOW_CONCURRENCY = 8;
|
||||
const OLLAMA_SHOW_MAX_MODELS = 200;
|
||||
const OLLAMA_DEFAULT_CONTEXT_WINDOW = 128000;
|
||||
const OLLAMA_DEFAULT_MAX_TOKENS = 8192;
|
||||
const OLLAMA_DEFAULT_COST = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
};
|
||||
|
||||
const VLLM_BASE_URL = "http://127.0.0.1:8000/v1";
|
||||
const VLLM_DEFAULT_CONTEXT_WINDOW = 128000;
|
||||
@@ -41,76 +41,12 @@ const VLLM_DEFAULT_COST = {
|
||||
cacheWrite: 0,
|
||||
};
|
||||
|
||||
interface OllamaModel {
|
||||
name: string;
|
||||
modified_at: string;
|
||||
size: number;
|
||||
digest: string;
|
||||
details?: {
|
||||
family?: string;
|
||||
parameter_size?: string;
|
||||
};
|
||||
}
|
||||
|
||||
interface OllamaTagsResponse {
|
||||
models: OllamaModel[];
|
||||
}
|
||||
|
||||
type VllmModelsResponse = {
|
||||
data?: Array<{
|
||||
id?: string;
|
||||
}>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Derive the Ollama native API base URL from a configured base URL.
|
||||
*
|
||||
* Users typically configure `baseUrl` with a `/v1` suffix (e.g.
|
||||
* `http://192.168.20.14:11434/v1`) for the OpenAI-compatible endpoint.
|
||||
* The native Ollama API lives at the root (e.g. `/api/tags`), so we
|
||||
* strip the `/v1` suffix when present.
|
||||
*/
|
||||
export function resolveOllamaApiBase(configuredBaseUrl?: string): string {
|
||||
if (!configuredBaseUrl) {
|
||||
return OLLAMA_API_BASE_URL;
|
||||
}
|
||||
// Strip trailing slash, then strip /v1 suffix if present
|
||||
const trimmed = configuredBaseUrl.replace(/\/+$/, "");
|
||||
return trimmed.replace(/\/v1$/i, "");
|
||||
}
|
||||
|
||||
/**
 * Ask Ollama's `/api/show` endpoint for a model's context length.
 *
 * Sends a JSON POST of `{ name: modelName }` with a 3000 ms abort timeout and
 * scans the response's `model_info` for the first entry (in enumeration order)
 * whose key ends in `.context_length` and whose value is a finite number that
 * floors to a positive integer.
 *
 * @param apiBase - Base URL that `/api/show` is appended to.
 * @param modelName - Model name sent in the request body.
 * @returns The floored context length, or `undefined` when the request throws,
 *   the response is not OK, `model_info` is absent, or no entry qualifies.
 */
async function queryOllamaContextWindow(
  apiBase: string,
  modelName: string,
): Promise<number | undefined> {
  try {
    const showResponse = await fetch(`${apiBase}/api/show`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ name: modelName }),
      signal: AbortSignal.timeout(3000),
    });
    if (!showResponse.ok) {
      return undefined;
    }
    const payload = (await showResponse.json()) as { model_info?: Record<string, unknown> };
    const modelInfo = payload.model_info;
    if (!modelInfo) {
      return undefined;
    }
    for (const [infoKey, infoValue] of Object.entries(modelInfo)) {
      if (!infoKey.endsWith(".context_length")) {
        continue;
      }
      if (typeof infoValue !== "number" || !Number.isFinite(infoValue)) {
        continue;
      }
      const floored = Math.floor(infoValue);
      if (floored > 0) {
        return floored;
      }
    }
    return undefined;
  } catch {
    // Best-effort lookup: network, timeout, and parse failures all yield undefined.
    return undefined;
  }
}
|
||||
|
||||
async function discoverOllamaModels(
|
||||
baseUrl?: string,
|
||||
opts?: { quiet?: boolean },
|
||||
@@ -140,29 +76,18 @@ async function discoverOllamaModels(
|
||||
`Capping Ollama /api/show inspection to ${OLLAMA_SHOW_MAX_MODELS} models (received ${data.models.length})`,
|
||||
);
|
||||
}
|
||||
const discovered: ModelDefinitionConfig[] = [];
|
||||
for (let index = 0; index < modelsToInspect.length; index += OLLAMA_SHOW_CONCURRENCY) {
|
||||
const batch = modelsToInspect.slice(index, index + OLLAMA_SHOW_CONCURRENCY);
|
||||
const batchDiscovered = await Promise.all(
|
||||
batch.map(async (model) => {
|
||||
const modelId = model.name;
|
||||
const contextWindow = await queryOllamaContextWindow(apiBase, modelId);
|
||||
const isReasoning =
|
||||
modelId.toLowerCase().includes("r1") || modelId.toLowerCase().includes("reasoning");
|
||||
return {
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
reasoning: isReasoning,
|
||||
input: ["text"],
|
||||
cost: OLLAMA_DEFAULT_COST,
|
||||
contextWindow: contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW,
|
||||
maxTokens: OLLAMA_DEFAULT_MAX_TOKENS,
|
||||
} satisfies ModelDefinitionConfig;
|
||||
}),
|
||||
);
|
||||
discovered.push(...batchDiscovered);
|
||||
}
|
||||
return discovered;
|
||||
const discovered = await enrichOllamaModelsWithContext(apiBase, modelsToInspect, {
|
||||
concurrency: OLLAMA_SHOW_CONCURRENCY,
|
||||
});
|
||||
return discovered.map((model) => ({
|
||||
id: model.name,
|
||||
name: model.name,
|
||||
reasoning: isReasoningModelHeuristic(model.name),
|
||||
input: ["text"],
|
||||
cost: OLLAMA_DEFAULT_COST,
|
||||
contextWindow: model.contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW,
|
||||
maxTokens: OLLAMA_DEFAULT_MAX_TOKENS,
|
||||
}));
|
||||
} catch (error) {
|
||||
if (!opts?.quiet) {
|
||||
log.warn(`Failed to discover Ollama models: ${String(error)}`);
|
||||
@@ -204,13 +129,10 @@ async function discoverVllmModels(
|
||||
.filter((model) => Boolean(model.id))
|
||||
.map((model) => {
|
||||
const modelId = model.id;
|
||||
const lower = modelId.toLowerCase();
|
||||
const isReasoning =
|
||||
lower.includes("r1") || lower.includes("reasoning") || lower.includes("think");
|
||||
return {
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
reasoning: isReasoning,
|
||||
reasoning: isReasoningModelHeuristic(modelId),
|
||||
input: ["text"],
|
||||
cost: VLLM_DEFAULT_COST,
|
||||
contextWindow: VLLM_DEFAULT_CONTEXT_WINDOW,
|
||||
|
||||
@@ -4,7 +4,10 @@ import path from "node:path";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { NON_ENV_SECRETREF_MARKER } from "./model-auth-markers.js";
|
||||
import { normalizeProviders } from "./models-config.providers.js";
|
||||
import {
|
||||
enforceSourceManagedProviderSecrets,
|
||||
normalizeProviders,
|
||||
} from "./models-config.providers.js";
|
||||
|
||||
describe("normalizeProviders", () => {
|
||||
it("trims provider keys so image models remain discoverable for custom providers", async () => {
|
||||
@@ -136,4 +139,38 @@ describe("normalizeProviders", () => {
|
||||
await fs.rm(agentDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
it("ignores non-object provider entries during source-managed enforcement", () => {
|
||||
const providers = {
|
||||
openai: null,
|
||||
moonshot: {
|
||||
baseUrl: "https://api.moonshot.ai/v1",
|
||||
api: "openai-completions",
|
||||
apiKey: "sk-runtime-moonshot", // pragma: allowlist secret
|
||||
models: [],
|
||||
},
|
||||
} as unknown as NonNullable<NonNullable<OpenClawConfig["models"]>["providers"]>;
|
||||
|
||||
const sourceProviders: NonNullable<NonNullable<OpenClawConfig["models"]>["providers"]> = {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
api: "openai-completions",
|
||||
apiKey: { source: "env", provider: "default", id: "OPENAI_API_KEY" }, // pragma: allowlist secret
|
||||
models: [],
|
||||
},
|
||||
moonshot: {
|
||||
baseUrl: "https://api.moonshot.ai/v1",
|
||||
api: "openai-completions",
|
||||
apiKey: { source: "env", provider: "default", id: "MOONSHOT_API_KEY" }, // pragma: allowlist secret
|
||||
models: [],
|
||||
},
|
||||
};
|
||||
|
||||
const enforced = enforceSourceManagedProviderSecrets({
|
||||
providers,
|
||||
sourceProviders,
|
||||
});
|
||||
expect((enforced as Record<string, unknown>).openai).toBeNull();
|
||||
expect(enforced?.moonshot?.apiKey).toBe("MOONSHOT_API_KEY"); // pragma: allowlist secret
|
||||
});
|
||||
});
|
||||
|
||||
@@ -429,6 +429,24 @@ export function buildOpenrouterProvider(): ProviderConfig {
|
||||
contextWindow: OPENROUTER_DEFAULT_CONTEXT_WINDOW,
|
||||
maxTokens: OPENROUTER_DEFAULT_MAX_TOKENS,
|
||||
},
|
||||
{
|
||||
id: "openrouter/hunter-alpha",
|
||||
name: "Hunter Alpha",
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: OPENROUTER_DEFAULT_COST,
|
||||
contextWindow: 1048576,
|
||||
maxTokens: 65536,
|
||||
},
|
||||
{
|
||||
id: "openrouter/healer-alpha",
|
||||
name: "Healer Alpha",
|
||||
reasoning: true,
|
||||
input: ["text", "image"],
|
||||
cost: OPENROUTER_DEFAULT_COST,
|
||||
contextWindow: 262144,
|
||||
maxTokens: 65536,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
DEFAULT_COPILOT_API_BASE_URL,
|
||||
resolveCopilotApiToken,
|
||||
} from "../providers/github-copilot-token.js";
|
||||
import { isRecord } from "../utils.js";
|
||||
import { normalizeOptionalSecretInput } from "../utils/normalize-secret-input.js";
|
||||
import { ensureAuthProfileStore, listProfilesForProvider } from "./auth-profiles.js";
|
||||
import { discoverBedrockModels } from "./bedrock-discovery.js";
|
||||
@@ -70,6 +71,11 @@ export { resolveOllamaApiBase } from "./models-config.providers.discovery.js";
|
||||
|
||||
type ModelsConfig = NonNullable<OpenClawConfig["models"]>;
|
||||
export type ProviderConfig = NonNullable<ModelsConfig["providers"]>[string];
|
||||
type SecretDefaults = {
|
||||
env?: string;
|
||||
file?: string;
|
||||
exec?: string;
|
||||
};
|
||||
|
||||
const ENV_VAR_NAME_RE = /^[A-Z_][A-Z0-9_]*$/;
|
||||
|
||||
@@ -97,13 +103,7 @@ function resolveAwsSdkApiKeyVarName(env: NodeJS.ProcessEnv = process.env): strin
|
||||
|
||||
function normalizeHeaderValues(params: {
|
||||
headers: ProviderConfig["headers"] | undefined;
|
||||
secretDefaults:
|
||||
| {
|
||||
env?: string;
|
||||
file?: string;
|
||||
exec?: string;
|
||||
}
|
||||
| undefined;
|
||||
secretDefaults: SecretDefaults | undefined;
|
||||
}): { headers: ProviderConfig["headers"] | undefined; mutated: boolean } {
|
||||
const { headers } = params;
|
||||
if (!headers) {
|
||||
@@ -276,15 +276,155 @@ function normalizeAntigravityProvider(provider: ProviderConfig): ProviderConfig
|
||||
return normalizeProviderModels(provider, normalizeAntigravityModelId);
|
||||
}
|
||||
|
||||
/**
 * Index source providers by their trimmed key.
 *
 * Entries whose key is blank after trimming, or whose value does not pass
 * `isRecord`, are left out. When two keys trim to the same string, the entry
 * enumerated later overwrites the earlier one.
 *
 * @param providers - Source provider map; may be undefined.
 * @returns A fresh lookup object (empty when `providers` is falsy).
 */
function normalizeSourceProviderLookup(
  providers: ModelsConfig["providers"] | undefined,
): Record<string, ProviderConfig> {
  const lookup: Record<string, ProviderConfig> = {};
  for (const [rawKey, candidate] of Object.entries(providers ?? {})) {
    const trimmedKey = rawKey.trim();
    if (trimmedKey && isRecord(candidate)) {
      lookup[trimmedKey] = candidate;
    }
  }
  return lookup;
}
|
||||
|
||||
/**
 * Compute the marker that stands in for a provider's API key when the source
 * config supplies that key through a secret reference.
 *
 * The source `apiKey` is resolved with `resolveSecretInputRef`. For an `env`
 * reference the marker is the trimmed reference id; for any other source it is
 * whatever `resolveNonEnvSecretRefApiKeyMarker` returns for that source.
 *
 * @param params.sourceProvider - Provider entry from the source config, if any.
 * @param params.sourceSecretDefaults - Defaults forwarded to the ref resolver.
 * @returns The marker string, or `undefined` when no reference is resolved or
 *   its id is blank.
 */
function resolveSourceManagedApiKeyMarker(params: {
  sourceProvider: ProviderConfig | undefined;
  sourceSecretDefaults: SecretDefaults | undefined;
}): string | undefined {
  const { ref } = resolveSecretInputRef({
    value: params.sourceProvider?.apiKey,
    defaults: params.sourceSecretDefaults,
  });
  if (!ref) {
    return undefined;
  }
  const refId = ref.id.trim();
  if (!refId) {
    return undefined;
  }
  if (ref.source === "env") {
    return refId;
  }
  return resolveNonEnvSecretRefApiKeyMarker(ref.source);
}
|
||||
|
||||
/**
 * Build a header-name → marker map for every source header that is supplied
 * through a secret reference.
 *
 * Each header value is resolved with `resolveSecretInputRef`; headers with no
 * reference or a blank reference id are skipped. `env` references map to
 * `resolveEnvSecretRefHeaderValueMarker(id)` (the id is passed untrimmed), all
 * other sources to `resolveNonEnvSecretRefHeaderValueMarker(source)`.
 *
 * @param params.sourceProvider - Provider entry from the source config, if any.
 * @param params.sourceSecretDefaults - Defaults forwarded to the ref resolver.
 * @returns A fresh marker map; empty when the provider has no record-shaped
 *   `headers` or none of them is reference-backed.
 */
function resolveSourceManagedHeaderMarkers(params: {
  sourceProvider: ProviderConfig | undefined;
  sourceSecretDefaults: SecretDefaults | undefined;
}): Record<string, string> {
  const markers: Record<string, string> = {};
  const rawHeaders = params.sourceProvider?.headers;
  if (!isRecord(rawHeaders)) {
    return markers;
  }
  for (const [headerName, headerValue] of Object.entries(
    rawHeaders as Record<string, unknown>,
  )) {
    const { ref } = resolveSecretInputRef({
      value: headerValue,
      defaults: params.sourceSecretDefaults,
    });
    if (!ref || !ref.id.trim()) {
      continue;
    }
    if (ref.source === "env") {
      markers[headerName] = resolveEnvSecretRefHeaderValueMarker(ref.id);
    } else {
      markers[headerName] = resolveNonEnvSecretRefHeaderValueMarker(ref.source);
    }
  }
  return markers;
}
|
||||
|
||||
/**
 * Overwrite provider API keys and headers with source-derived markers wherever
 * the source config manages them through secret references.
 *
 * For each record-shaped provider that has a source counterpart (matched by
 * trimmed key):
 * - if the source API key yields a marker, the trimmed key is added to
 *   `secretRefManagedProviders` (when given) and `apiKey` is set to the marker;
 * - every source header that yields a marker replaces the same-named header.
 *
 * Inputs are never mutated: changed providers and the outer map are shallow
 * copied on first change, and the original `providers` reference is returned
 * when nothing changed. Non-record provider entries are left as they are.
 *
 * @param params.providers - Providers to enforce markers on.
 * @param params.sourceProviders - Providers from the source config.
 * @param params.sourceSecretDefaults - Defaults used when resolving source refs.
 * @param params.secretRefManagedProviders - Optional set that collects the
 *   trimmed keys of providers whose API key is source-managed.
 * @returns The original map, or a shallow copy containing the rewritten providers.
 */
export function enforceSourceManagedProviderSecrets(params: {
  providers: ModelsConfig["providers"];
  sourceProviders: ModelsConfig["providers"] | undefined;
  sourceSecretDefaults?: SecretDefaults;
  secretRefManagedProviders?: Set<string>;
}): ModelsConfig["providers"] {
  const { providers, sourceSecretDefaults, secretRefManagedProviders } = params;
  if (!providers) {
    return providers;
  }
  const sourceLookup = normalizeSourceProviderLookup(params.sourceProviders);
  if (Object.keys(sourceLookup).length === 0) {
    return providers;
  }

  // Allocated lazily so an untouched map is handed back by identity.
  let rewritten: Record<string, ProviderConfig> | null = null;
  for (const [providerKey, provider] of Object.entries(providers)) {
    if (!isRecord(provider)) {
      continue;
    }
    const lookupKey = providerKey.trim();
    const sourceProvider = sourceLookup[lookupKey];
    if (!sourceProvider) {
      continue;
    }

    // `updated` only diverges from `provider` via a fresh spread copy, so an
    // identity comparison at the end tells us whether anything changed.
    let updated = provider;

    const apiKeyMarker = resolveSourceManagedApiKeyMarker({
      sourceProvider,
      sourceSecretDefaults,
    });
    if (apiKeyMarker) {
      // Recorded even when the key already equals the marker.
      secretRefManagedProviders?.add(lookupKey);
      if (updated.apiKey !== apiKeyMarker) {
        updated = {
          ...updated,
          apiKey: apiKeyMarker,
        };
      }
    }

    const headerMarkerEntries = Object.entries(
      resolveSourceManagedHeaderMarkers({
        sourceProvider,
        sourceSecretDefaults,
      }),
    );
    if (headerMarkerEntries.length > 0) {
      const existingHeaders = isRecord(updated.headers)
        ? (updated.headers as Record<string, unknown>)
        : undefined;
      const mergedHeaders = {
        ...(existingHeaders as Record<string, NonNullable<ProviderConfig["headers"]>[string]>),
      };
      // A provider without a usable headers record always receives the new one.
      let headersChanged = !existingHeaders;
      for (const [headerName, marker] of headerMarkerEntries) {
        if (mergedHeaders[headerName] !== marker) {
          mergedHeaders[headerName] = marker;
          headersChanged = true;
        }
      }
      if (headersChanged) {
        updated = {
          ...updated,
          headers: mergedHeaders,
        };
      }
    }

    if (updated === provider) {
      continue;
    }
    if (rewritten === null) {
      rewritten = { ...providers };
    }
    rewritten[providerKey] = updated;
  }

  return rewritten ?? providers;
}
|
||||
|
||||
export function normalizeProviders(params: {
|
||||
providers: ModelsConfig["providers"];
|
||||
agentDir: string;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
secretDefaults?: {
|
||||
env?: string;
|
||||
file?: string;
|
||||
exec?: string;
|
||||
};
|
||||
secretDefaults?: SecretDefaults;
|
||||
sourceProviders?: ModelsConfig["providers"];
|
||||
sourceSecretDefaults?: SecretDefaults;
|
||||
secretRefManagedProviders?: Set<string>;
|
||||
}): ModelsConfig["providers"] {
|
||||
const { providers } = params;
|
||||
@@ -434,7 +574,13 @@ export function normalizeProviders(params: {
|
||||
next[normalizedKey] = normalizedProvider;
|
||||
}
|
||||
|
||||
return mutated ? next : providers;
|
||||
const normalizedProviders = mutated ? next : providers;
|
||||
return enforceSourceManagedProviderSecrets({
|
||||
providers: normalizedProviders,
|
||||
sourceProviders: params.sourceProviders,
|
||||
sourceSecretDefaults: params.sourceSecretDefaults,
|
||||
secretRefManagedProviders: params.secretRefManagedProviders,
|
||||
});
|
||||
}
|
||||
|
||||
type ImplicitProviderParams = {
|
||||
|
||||
@@ -209,4 +209,152 @@ describe("models-config runtime source snapshot", () => {
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps source markers when runtime projection is skipped for incompatible top-level shape", async () => {
|
||||
await withTempHome(async () => {
|
||||
const sourceConfig: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
apiKey: { source: "env", provider: "default", id: "OPENAI_API_KEY" }, // pragma: allowlist secret
|
||||
api: "openai-completions" as const,
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
gateway: {
|
||||
auth: {
|
||||
mode: "token",
|
||||
},
|
||||
},
|
||||
};
|
||||
const runtimeConfig: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
apiKey: "sk-runtime-resolved", // pragma: allowlist secret
|
||||
api: "openai-completions" as const,
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
gateway: {
|
||||
auth: {
|
||||
mode: "token",
|
||||
},
|
||||
},
|
||||
};
|
||||
const incompatibleCandidate: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
apiKey: "sk-runtime-resolved", // pragma: allowlist secret
|
||||
api: "openai-completions" as const,
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
try {
|
||||
setRuntimeConfigSnapshot(runtimeConfig, sourceConfig);
|
||||
await ensureOpenClawModelsJson(incompatibleCandidate);
|
||||
|
||||
const parsed = await readGeneratedModelsJson<{
|
||||
providers: Record<string, { apiKey?: string }>;
|
||||
}>();
|
||||
expect(parsed.providers.openai?.apiKey).toBe("OPENAI_API_KEY"); // pragma: allowlist secret
|
||||
} finally {
|
||||
clearRuntimeConfigSnapshot();
|
||||
clearConfigCache();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps source header markers when runtime projection is skipped for incompatible top-level shape", async () => {
|
||||
await withTempHome(async () => {
|
||||
const sourceConfig: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
api: "openai-completions" as const,
|
||||
headers: {
|
||||
Authorization: {
|
||||
source: "env",
|
||||
provider: "default",
|
||||
id: "OPENAI_HEADER_TOKEN", // pragma: allowlist secret
|
||||
},
|
||||
"X-Tenant-Token": {
|
||||
source: "file",
|
||||
provider: "vault",
|
||||
id: "/providers/openai/tenantToken",
|
||||
},
|
||||
},
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
gateway: {
|
||||
auth: {
|
||||
mode: "token",
|
||||
},
|
||||
},
|
||||
};
|
||||
const runtimeConfig: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
api: "openai-completions" as const,
|
||||
headers: {
|
||||
Authorization: "Bearer runtime-openai-token",
|
||||
"X-Tenant-Token": "runtime-tenant-token",
|
||||
},
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
gateway: {
|
||||
auth: {
|
||||
mode: "token",
|
||||
},
|
||||
},
|
||||
};
|
||||
const incompatibleCandidate: OpenClawConfig = {
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
baseUrl: "https://api.openai.com/v1",
|
||||
api: "openai-completions" as const,
|
||||
headers: {
|
||||
Authorization: "Bearer runtime-openai-token",
|
||||
"X-Tenant-Token": "runtime-tenant-token",
|
||||
},
|
||||
models: [],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
try {
|
||||
setRuntimeConfigSnapshot(runtimeConfig, sourceConfig);
|
||||
await ensureOpenClawModelsJson(incompatibleCandidate);
|
||||
|
||||
const parsed = await readGeneratedModelsJson<{
|
||||
providers: Record<string, { headers?: Record<string, string> }>;
|
||||
}>();
|
||||
expect(parsed.providers.openai?.headers?.Authorization).toBe(
|
||||
"secretref-env:OPENAI_HEADER_TOKEN", // pragma: allowlist secret
|
||||
);
|
||||
expect(parsed.providers.openai?.headers?.["X-Tenant-Token"]).toBe(NON_ENV_SECRETREF_MARKER);
|
||||
} finally {
|
||||
clearRuntimeConfigSnapshot();
|
||||
clearConfigCache();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -42,15 +42,31 @@ async function writeModelsFileAtomic(targetPath: string, contents: string): Prom
|
||||
await fs.rename(tempPath, targetPath);
|
||||
}
|
||||
|
||||
function resolveModelsConfigInput(config?: OpenClawConfig): OpenClawConfig {
|
||||
function resolveModelsConfigInput(config?: OpenClawConfig): {
|
||||
config: OpenClawConfig;
|
||||
sourceConfigForSecrets: OpenClawConfig;
|
||||
} {
|
||||
const runtimeSource = getRuntimeConfigSourceSnapshot();
|
||||
if (!config) {
|
||||
return runtimeSource ?? loadConfig();
|
||||
const loaded = loadConfig();
|
||||
return {
|
||||
config: runtimeSource ?? loaded,
|
||||
sourceConfigForSecrets: runtimeSource ?? loaded,
|
||||
};
|
||||
}
|
||||
if (!runtimeSource) {
|
||||
return config;
|
||||
return {
|
||||
config,
|
||||
sourceConfigForSecrets: config,
|
||||
};
|
||||
}
|
||||
return projectConfigOntoRuntimeSourceSnapshot(config);
|
||||
const projected = projectConfigOntoRuntimeSourceSnapshot(config);
|
||||
return {
|
||||
config: projected,
|
||||
// If projection is skipped (for example incompatible top-level shape),
|
||||
// keep managed secret persistence anchored to the active source snapshot.
|
||||
sourceConfigForSecrets: projected === config ? runtimeSource : projected,
|
||||
};
|
||||
}
|
||||
|
||||
async function withModelsJsonWriteLock<T>(targetPath: string, run: () => Promise<T>): Promise<T> {
|
||||
@@ -76,7 +92,8 @@ export async function ensureOpenClawModelsJson(
|
||||
config?: OpenClawConfig,
|
||||
agentDirOverride?: string,
|
||||
): Promise<{ agentDir: string; wrote: boolean }> {
|
||||
const cfg = resolveModelsConfigInput(config);
|
||||
const resolved = resolveModelsConfigInput(config);
|
||||
const cfg = resolved.config;
|
||||
const agentDir = agentDirOverride?.trim() ? agentDirOverride.trim() : resolveOpenClawAgentDir();
|
||||
const targetPath = path.join(agentDir, "models.json");
|
||||
|
||||
@@ -87,6 +104,7 @@ export async function ensureOpenClawModelsJson(
|
||||
const existingModelsFile = await readExistingModelsFile(targetPath);
|
||||
const plan = await planOpenClawModelsJson({
|
||||
cfg,
|
||||
sourceConfigForSecrets: resolved.sourceConfigForSecrets,
|
||||
agentDir,
|
||||
env,
|
||||
existingRaw: existingModelsFile.raw,
|
||||
|
||||
@@ -9,10 +9,6 @@ import {
|
||||
isAnthropicBillingError,
|
||||
isAnthropicRateLimitError,
|
||||
} from "./live-auth-keys.js";
|
||||
import {
|
||||
isMiniMaxModelNotFoundErrorMessage,
|
||||
isModelNotFoundErrorMessage,
|
||||
} from "./live-model-errors.js";
|
||||
import { isModernModelRef } from "./live-model-filter.js";
|
||||
import { getApiKeyForModel, requireApiKey } from "./model-auth.js";
|
||||
import { ensureOpenClawModelsJson } from "./models-config.js";
|
||||
@@ -86,6 +82,35 @@ function isGoogleModelNotFoundError(err: unknown): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
 * Decide whether an error message looks like a "model not found" failure.
 *
 * A trimmed, non-empty message matches when any of these hold:
 * - it contains a standalone `404` and a "not found" phrase (the two words may
 *   be joined directly or by whitespace, underscores, or hyphens);
 * - it contains `not_found_error` (case-insensitive);
 * - it contains `model: <id>` and a "not found" phrase.
 *
 * @param raw - Error message text.
 * @returns `true` when the message matches one of the patterns above.
 */
function isModelNotFoundErrorMessage(raw: string): boolean {
  const message = raw.trim();
  if (message.length === 0) {
    return false;
  }
  const mentionsNotFound = /not(?:[\s_-]+)?found/i.test(message);
  const hasStatusOrModelTag = /\b404\b/.test(message) || /model:\s*[a-z0-9._-]+/i.test(message);
  if (mentionsNotFound && hasStatusOrModelTag) {
    return true;
  }
  return /not_found_error/i.test(message);
}
|
||||
|
||||
describe("isModelNotFoundErrorMessage", () => {
|
||||
it("matches whitespace-separated not found errors", () => {
|
||||
expect(isModelNotFoundErrorMessage("404 model not found")).toBe(true);
|
||||
expect(isModelNotFoundErrorMessage("model: minimax-text-01 not found")).toBe(true);
|
||||
});
|
||||
|
||||
it("still matches underscore and hyphen variants", () => {
|
||||
expect(isModelNotFoundErrorMessage("404 model not_found")).toBe(true);
|
||||
expect(isModelNotFoundErrorMessage("404 model not-found")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
function isChatGPTUsageLimitErrorMessage(raw: string): boolean {
|
||||
const msg = raw.toLowerCase();
|
||||
return msg.includes("hit your chatgpt usage limit") && msg.includes("try again in");
|
||||
@@ -475,11 +500,7 @@ describeLive("live models (profile keys)", () => {
|
||||
|
||||
if (ok.res.stopReason === "error") {
|
||||
const msg = ok.res.errorMessage ?? "";
|
||||
if (
|
||||
allowNotFoundSkip &&
|
||||
(isModelNotFoundErrorMessage(msg) ||
|
||||
(model.provider === "minimax" && isMiniMaxModelNotFoundErrorMessage(msg)))
|
||||
) {
|
||||
if (allowNotFoundSkip && isModelNotFoundErrorMessage(msg)) {
|
||||
skipped.push({ model: id, reason: msg });
|
||||
logProgress(`${progressLabel}: skip (model not found)`);
|
||||
break;
|
||||
@@ -500,7 +521,9 @@ describeLive("live models (profile keys)", () => {
|
||||
}
|
||||
if (
|
||||
ok.text.length === 0 &&
|
||||
(model.provider === "openrouter" || model.provider === "opencode")
|
||||
(model.provider === "openrouter" ||
|
||||
model.provider === "opencode" ||
|
||||
model.provider === "opencode-go")
|
||||
) {
|
||||
skipped.push({
|
||||
model: id,
|
||||
@@ -563,15 +586,6 @@ describeLive("live models (profile keys)", () => {
|
||||
logProgress(`${progressLabel}: skip (google model not found)`);
|
||||
break;
|
||||
}
|
||||
if (
|
||||
allowNotFoundSkip &&
|
||||
model.provider === "minimax" &&
|
||||
isMiniMaxModelNotFoundErrorMessage(message)
|
||||
) {
|
||||
skipped.push({ model: id, reason: message });
|
||||
logProgress(`${progressLabel}: skip (model not found)`);
|
||||
break;
|
||||
}
|
||||
if (
|
||||
allowNotFoundSkip &&
|
||||
model.provider === "minimax" &&
|
||||
@@ -592,7 +606,7 @@ describeLive("live models (profile keys)", () => {
|
||||
}
|
||||
if (
|
||||
allowNotFoundSkip &&
|
||||
model.provider === "opencode" &&
|
||||
(model.provider === "opencode" || model.provider === "opencode-go") &&
|
||||
isRateLimitErrorMessage(message)
|
||||
) {
|
||||
skipped.push({ model: id, reason: message });
|
||||
|
||||
61
src/agents/ollama-models.test.ts
Normal file
61
src/agents/ollama-models.test.ts
Normal file
@@ -0,0 +1,61 @@
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import {
|
||||
enrichOllamaModelsWithContext,
|
||||
resolveOllamaApiBase,
|
||||
type OllamaTagModel,
|
||||
} from "./ollama-models.js";
|
||||
|
||||
/**
 * Build a `Response` whose body is `body` serialized as JSON.
 *
 * @param body - Value passed to `JSON.stringify`.
 * @param status - HTTP status code; defaults to 200.
 * @returns A response carrying a `Content-Type: application/json` header.
 */
function jsonResponse(body: unknown, status = 200): Response {
  const serialized = JSON.stringify(body);
  const init = {
    status,
    headers: { "Content-Type": "application/json" },
  };
  return new Response(serialized, init);
}
|
||||
|
||||
/**
 * Extract the URL string from any of the input shapes `fetch` accepts.
 *
 * @param input - A URL string, a `URL` instance, or a `Request`.
 * @returns The string itself, `URL#toString()`, or `Request#url` respectively.
 */
function requestUrl(input: string | URL | Request): string {
  if (typeof input === "string") {
    return input;
  }
  return input instanceof URL ? input.toString() : input.url;
}
|
||||
|
||||
/**
 * Return a request body as text for JSON parsing in the fetch mock.
 *
 * @param body - The `init.body` passed to `fetch`.
 * @returns `body` unchanged when it is a string; otherwise the literal `"{}"`.
 */
function requestBody(body: BodyInit | null | undefined): string {
  if (typeof body !== "string") {
    return "{}";
  }
  return body;
}
|
||||
|
||||
describe("ollama-models", () => {
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("strips /v1 when resolving the Ollama API base", () => {
|
||||
expect(resolveOllamaApiBase("http://127.0.0.1:11434/v1")).toBe("http://127.0.0.1:11434");
|
||||
expect(resolveOllamaApiBase("http://127.0.0.1:11434///")).toBe("http://127.0.0.1:11434");
|
||||
});
|
||||
|
||||
it("enriches discovered models with context windows from /api/show", async () => {
|
||||
const models: OllamaTagModel[] = [{ name: "llama3:8b" }, { name: "deepseek-r1:14b" }];
|
||||
const fetchMock = vi.fn(async (input: string | URL | Request, init?: RequestInit) => {
|
||||
const url = requestUrl(input);
|
||||
if (!url.endsWith("/api/show")) {
|
||||
throw new Error(`Unexpected fetch: ${url}`);
|
||||
}
|
||||
const body = JSON.parse(requestBody(init?.body)) as { name?: string };
|
||||
if (body.name === "llama3:8b") {
|
||||
return jsonResponse({ model_info: { "llama.context_length": 65536 } });
|
||||
}
|
||||
return jsonResponse({});
|
||||
});
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
const enriched = await enrichOllamaModelsWithContext("http://127.0.0.1:11434", models);
|
||||
|
||||
expect(enriched).toEqual([
|
||||
{ name: "llama3:8b", contextWindow: 65536 },
|
||||
{ name: "deepseek-r1:14b", contextWindow: undefined },
|
||||
]);
|
||||
});
|
||||
});
|
||||
143
src/agents/ollama-models.ts
Normal file
143
src/agents/ollama-models.ts
Normal file
@@ -0,0 +1,143 @@
|
||||
import type { ModelDefinitionConfig } from "../config/types.models.js";
|
||||
import { OLLAMA_NATIVE_BASE_URL } from "./ollama-stream.js";
|
||||
|
||||
export const OLLAMA_DEFAULT_BASE_URL = OLLAMA_NATIVE_BASE_URL;
|
||||
export const OLLAMA_DEFAULT_CONTEXT_WINDOW = 128000;
|
||||
export const OLLAMA_DEFAULT_MAX_TOKENS = 8192;
|
||||
export const OLLAMA_DEFAULT_COST = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
};
|
||||
|
||||
export type OllamaTagModel = {
|
||||
name: string;
|
||||
modified_at?: string;
|
||||
size?: number;
|
||||
digest?: string;
|
||||
remote_host?: string;
|
||||
details?: {
|
||||
family?: string;
|
||||
parameter_size?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type OllamaTagsResponse = {
|
||||
models?: OllamaTagModel[];
|
||||
};
|
||||
|
||||
export type OllamaModelWithContext = OllamaTagModel & {
|
||||
contextWindow?: number;
|
||||
};
|
||||
|
||||
const OLLAMA_SHOW_CONCURRENCY = 8;
|
||||
|
||||
/**
 * Derive the base URL of Ollama's native API from a configured base URL.
 *
 * A configured `baseUrl` typically carries the `/v1` suffix of the
 * OpenAI-compatible endpoint (e.g. `http://192.168.20.14:11434/v1`), but the
 * native API (e.g. `/api/tags`) is served from the root. This helper strips any
 * run of trailing slashes and then one case-insensitive `/v1` suffix.
 *
 * @param configuredBaseUrl - Base URL from configuration; may be omitted.
 * @returns The native API base, or `OLLAMA_DEFAULT_BASE_URL` when the argument
 *   is missing or empty.
 */
export function resolveOllamaApiBase(configuredBaseUrl?: string): string {
  if (!configuredBaseUrl) {
    return OLLAMA_DEFAULT_BASE_URL;
  }
  const withoutTrailingSlashes = configuredBaseUrl.replace(/\/+$/, "");
  const withoutVersionSuffix = withoutTrailingSlashes.replace(/\/v1$/i, "");
  return withoutVersionSuffix;
}
|
||||
|
||||
/**
 * Look up a model's context length through Ollama's `/api/show` endpoint.
 *
 * Issues a JSON POST of `{ name: modelName }` with a 3000 ms abort timeout,
 * then walks the response's `model_info` and returns the first entry (in
 * enumeration order) whose key ends in `.context_length` and whose value is a
 * finite number that floors to a positive integer.
 *
 * @param apiBase - Base URL that `/api/show` is appended to.
 * @param modelName - Model name sent in the request body.
 * @returns The floored context length, or `undefined` when the request throws,
 *   the response is not OK, `model_info` is absent, or no entry qualifies.
 */
export async function queryOllamaContextWindow(
  apiBase: string,
  modelName: string,
): Promise<number | undefined> {
  try {
    const showResponse = await fetch(`${apiBase}/api/show`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ name: modelName }),
      signal: AbortSignal.timeout(3000),
    });
    if (!showResponse.ok) {
      return undefined;
    }
    const payload = (await showResponse.json()) as { model_info?: Record<string, unknown> };
    const modelInfo = payload.model_info;
    if (!modelInfo) {
      return undefined;
    }
    for (const [infoKey, infoValue] of Object.entries(modelInfo)) {
      const isContextLengthKey = infoKey.endsWith(".context_length");
      if (!isContextLengthKey || typeof infoValue !== "number" || !Number.isFinite(infoValue)) {
        continue;
      }
      const floored = Math.floor(infoValue);
      if (floored > 0) {
        return floored;
      }
    }
    return undefined;
  } catch {
    // Best-effort lookup: network, timeout, and parse failures all yield undefined.
    return undefined;
  }
}
|
||||
|
||||
/**
 * Attach a `contextWindow` to each Ollama model by querying `/api/show`.
 *
 * Models are processed in consecutive batches of at most `concurrency` items;
 * the lookups inside a batch run in parallel and the next batch starts only
 * after the previous one settles. Output order matches input order. A model
 * whose lookup yields nothing gets `contextWindow: undefined`.
 *
 * @param apiBase - Native Ollama API base URL passed to each lookup.
 * @param models - Models to enrich; the input objects are copied, not mutated.
 * @param opts.concurrency - Maximum lookups per batch. Floored and clamped to
 *   at least 1; defaults to `OLLAMA_SHOW_CONCURRENCY` when omitted or NaN.
 * @returns A new array of models with `contextWindow` added.
 */
export async function enrichOllamaModelsWithContext(
  apiBase: string,
  models: OllamaTagModel[],
  opts?: { concurrency?: number },
): Promise<OllamaModelWithContext[]> {
  const requested = Math.floor(opts?.concurrency ?? OLLAMA_SHOW_CONCURRENCY);
  // Math.max(1, NaN) is NaN; a NaN batch size makes the loop below slice an
  // empty batch and then stop, silently returning [] for a non-empty input.
  // Treat NaN like "not provided" instead.
  const concurrency = Number.isNaN(requested)
    ? OLLAMA_SHOW_CONCURRENCY
    : Math.max(1, requested);
  const enriched: OllamaModelWithContext[] = [];
  for (let index = 0; index < models.length; index += concurrency) {
    const batch = models.slice(index, index + concurrency);
    const batchResults = await Promise.all(
      batch.map(async (model) => ({
        ...model,
        contextWindow: await queryOllamaContextWindow(apiBase, model.name),
      })),
    );
    enriched.push(...batchResults);
  }
  return enriched;
}
|
||||
|
||||
/**
 * Heuristic: treat a model as a reasoning model when its id contains "r1",
 * "reason" (which also covers "reasoning" and "reasoner"), or "think".
 *
 * Matching is a case-insensitive substring test, so unrelated ids that happen
 * to contain one of these fragments (for example "...coder1") also match.
 *
 * @param modelId - Model identifier to inspect.
 * @returns `true` when any fragment occurs in `modelId`.
 */
export function isReasoningModelHeuristic(modelId: string): boolean {
  // "reasoning" is not listed separately: every string containing it already
  // contains "reason".
  return /r1|reason|think/i.test(modelId);
}
|
||||
|
||||
/**
 * Build a `ModelDefinitionConfig` for an Ollama model using the module defaults.
 *
 * The id doubles as the display name, input is text-only, `reasoning` comes
 * from `isReasoningModelHeuristic`, and `cost` is the shared
 * `OLLAMA_DEFAULT_COST` object (same reference, not a copy).
 *
 * @param modelId - Ollama model name, used for both `id` and `name`.
 * @param contextWindow - Known context window; falls back to
 *   `OLLAMA_DEFAULT_CONTEXT_WINDOW` when null or undefined.
 * @returns The model definition.
 */
export function buildOllamaModelDefinition(
  modelId: string,
  contextWindow?: number,
): ModelDefinitionConfig {
  const reasoning = isReasoningModelHeuristic(modelId);
  const resolvedContextWindow = contextWindow ?? OLLAMA_DEFAULT_CONTEXT_WINDOW;
  return {
    id: modelId,
    name: modelId,
    reasoning,
    input: ["text"],
    cost: OLLAMA_DEFAULT_COST,
    contextWindow: resolvedContextWindow,
    maxTokens: OLLAMA_DEFAULT_MAX_TOKENS,
  };
}
|
||||
|
||||
/**
 * Fetch the model list from a running Ollama instance via `/api/tags`.
 *
 * The base URL is normalized with `resolveOllamaApiBase` and the request is
 * aborted after 5000 ms. Entries with a falsy `name` are dropped.
 *
 * @param baseUrl - Configured Ollama base URL (with or without `/v1`).
 * @returns `reachable: true` with the listed models when the server answers
 *   (an empty list for a non-OK status), or `reachable: false` with an empty
 *   list when anything in the request/parse path throws.
 */
export async function fetchOllamaModels(
  baseUrl: string,
): Promise<{ reachable: boolean; models: OllamaTagModel[] }> {
  try {
    const tagsUrl = `${resolveOllamaApiBase(baseUrl)}/api/tags`;
    const tagsResponse = await fetch(tagsUrl, {
      signal: AbortSignal.timeout(5000),
    });
    if (!tagsResponse.ok) {
      // The server answered, just not successfully: reachable, nothing listed.
      return { reachable: true, models: [] };
    }
    const payload = (await tagsResponse.json()) as OllamaTagsResponse;
    const listed = payload.models ?? [];
    return { reachable: true, models: listed.filter((entry) => entry.name) };
  } catch {
    return { reachable: false, models: [] };
  }
}
|
||||
@@ -30,6 +30,13 @@ function extractInputTypes(input: unknown[]) {
|
||||
.filter((t): t is string => typeof t === "string");
|
||||
}
|
||||
|
||||
/**
 * Select the entries of a Responses-API input array that are message items.
 *
 * @param input - Raw input items of unknown shape.
 * @returns The entries that are non-null objects with `type === "message"`,
 *   in their original order and by reference.
 */
function extractInputMessages(input: unknown[]) {
  const isMessageItem = (item: unknown): item is Record<string, unknown> => {
    if (!item || typeof item !== "object") {
      return false;
    }
    return (item as Record<string, unknown>).type === "message";
  };
  return input.filter(isMessageItem);
}
|
||||
|
||||
const ZERO_USAGE = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
@@ -184,4 +191,36 @@ describe("openai-responses reasoning replay", () => {
|
||||
expect(types).toContain("reasoning");
|
||||
expect(types).toContain("message");
|
||||
});
|
||||
|
||||
it.each(["commentary", "final_answer"] as const)(
|
||||
"replays assistant message phase metadata for %s",
|
||||
async (phase) => {
|
||||
const assistantWithText = buildAssistantMessage({
|
||||
stopReason: "stop",
|
||||
content: [
|
||||
buildReasoningPart(),
|
||||
{
|
||||
type: "text",
|
||||
text: "hello",
|
||||
textSignature: JSON.stringify({ v: 1, id: `msg_${phase}`, phase }),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const { input, types } = await runAbortedOpenAIResponsesStream({
|
||||
messages: [
|
||||
{ role: "user", content: "Hi", timestamp: Date.now() },
|
||||
assistantWithText,
|
||||
{ role: "user", content: "Ok", timestamp: Date.now() },
|
||||
],
|
||||
});
|
||||
|
||||
expect(types).toContain("message");
|
||||
|
||||
const replayedMessage = extractInputMessages(input).find(
|
||||
(item) => item.id === `msg_${phase}`,
|
||||
);
|
||||
expect(replayedMessage?.phase).toBe(phase);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
@@ -595,14 +595,12 @@ describe("OpenAIWebSocketManager", () => {
|
||||
|
||||
manager.warmUp({
|
||||
model: "gpt-5.2",
|
||||
tools: [{ type: "function", function: { name: "exec", description: "Run a command" } }],
|
||||
tools: [{ type: "function", name: "exec", description: "Run a command" }],
|
||||
});
|
||||
|
||||
const sent = JSON.parse(sock.sentMessages[0] ?? "{}") as Record<string, unknown>;
|
||||
expect(sent["tools"]).toHaveLength(1);
|
||||
expect((sent["tools"] as Array<{ function?: { name?: string } }>)[0]?.function?.name).toBe(
|
||||
"exec",
|
||||
);
|
||||
expect((sent["tools"] as Array<{ name?: string }>)[0]?.name).toBe("exec");
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -37,12 +37,15 @@ export interface UsageInfo {
|
||||
total_tokens: number;
|
||||
}
|
||||
|
||||
export type OpenAIResponsesAssistantPhase = "commentary" | "final_answer";
|
||||
|
||||
export type OutputItem =
|
||||
| {
|
||||
type: "message";
|
||||
id: string;
|
||||
role: "assistant";
|
||||
content: Array<{ type: "output_text"; text: string }>;
|
||||
phase?: OpenAIResponsesAssistantPhase;
|
||||
status?: "in_progress" | "completed";
|
||||
}
|
||||
| {
|
||||
@@ -190,6 +193,7 @@ export type InputItem =
|
||||
type: "message";
|
||||
role: "system" | "developer" | "user" | "assistant";
|
||||
content: string | ContentPart[];
|
||||
phase?: OpenAIResponsesAssistantPhase;
|
||||
}
|
||||
| { type: "function_call"; id?: string; call_id?: string; name: string; arguments: string }
|
||||
| { type: "function_call_output"; call_id: string; output: string }
|
||||
@@ -204,11 +208,10 @@ export type ToolChoice =
|
||||
|
||||
export interface FunctionToolDefinition {
|
||||
type: "function";
|
||||
function: {
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters?: Record<string, unknown>;
|
||||
};
|
||||
name: string;
|
||||
description?: string;
|
||||
parameters?: Record<string, unknown>;
|
||||
strict?: boolean;
|
||||
}
|
||||
|
||||
/** Standard response.create event payload (full turn) */
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
* Skipped in CI — no API key available and we avoid billable external calls.
|
||||
*/
|
||||
|
||||
import type { AssistantMessage, Context } from "@mariozechner/pi-ai";
|
||||
import { describe, it, expect, afterEach } from "vitest";
|
||||
import {
|
||||
createOpenAIWebSocketStreamFn,
|
||||
@@ -28,14 +29,13 @@ const testFn = LIVE ? it : it.skip;
|
||||
const model = {
|
||||
api: "openai-responses" as const,
|
||||
provider: "openai",
|
||||
id: "gpt-4o-mini",
|
||||
name: "gpt-4o-mini",
|
||||
baseUrl: "",
|
||||
reasoning: false,
|
||||
input: { maxTokens: 128_000 },
|
||||
output: { maxTokens: 16_384 },
|
||||
cache: false,
|
||||
compat: {},
|
||||
id: "gpt-5.2",
|
||||
name: "gpt-5.2",
|
||||
contextWindow: 128_000,
|
||||
maxTokens: 4_096,
|
||||
reasoning: true,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
} as unknown as Parameters<ReturnType<typeof createOpenAIWebSocketStreamFn>>[0];
|
||||
|
||||
type StreamFnParams = Parameters<ReturnType<typeof createOpenAIWebSocketStreamFn>>;
|
||||
@@ -47,6 +47,61 @@ function makeContext(userMessage: string): StreamFnParams[1] {
|
||||
} as unknown as StreamFnParams[1];
|
||||
}
|
||||
|
||||
/**
 * Build a stream context holding a fixed system prompt, a single user message
 * and one parameterless `noop` tool definition.
 *
 * @param userMessage - Text used as the content of the only user message.
 */
function makeToolContext(userMessage: string): StreamFnParams[1] {
  const noopTool = {
    name: "noop",
    description: "Return the supplied tool result to the user.",
    parameters: {
      type: "object",
      additionalProperties: false,
      properties: {},
    },
  };
  const context = {
    systemPrompt: "You are a precise assistant. Follow tool instructions exactly.",
    messages: [{ role: "user" as const, content: userMessage }],
    tools: [noopTool],
  };
  return context as unknown as Context;
}
|
||||
|
||||
/**
 * Build a non-error `toolResult` message for the `noop` tool, stamped with the
 * current time.
 *
 * @param callId - Stored as the message's `toolCallId`.
 * @param output - Text placed in the message's single text content block.
 */
function makeToolResultMessage(
  callId: string,
  output: string,
): StreamFnParams[1]["messages"][number] {
  const textBlock = { type: "text" as const, text: output };
  const message = {
    role: "toolResult" as const,
    toolCallId: callId,
    toolName: "noop",
    content: [textBlock],
    isError: false,
    timestamp: Date.now(),
  };
  return message as unknown as StreamFnParams[1]["messages"][number];
}
|
||||
|
||||
/**
 * Drain a stream returned by the WebSocket stream function into an array.
 *
 * @param stream - Stream to consume; it is iterated as an async iterable of
 *   `{ type, message? }` events.
 * @returns Every event yielded by the stream, in arrival order.
 */
async function collectEvents(
  stream: ReturnType<ReturnType<typeof createOpenAIWebSocketStreamFn>>,
): Promise<Array<{ type: string; message?: AssistantMessage }>> {
  type StreamEvent = { type: string; message?: AssistantMessage };
  const collected: StreamEvent[] = [];
  const source = stream as AsyncIterable<StreamEvent>;
  for await (const next of source) {
    collected.push(next);
  }
  return collected;
}
|
||||
|
||||
/**
 * Assert that the event list contains a `done` event carrying a message, and
 * return that message.
 *
 * @param events - Events collected from a stream.
 * @returns The `message` of the first event whose `type` is "done".
 */
function expectDone(events: Array<{ type: string; message?: AssistantMessage }>): AssistantMessage {
  const doneEvent = events.find(({ type }) => type === "done");
  const message = doneEvent?.message;
  expect(message).toBeDefined();
  // The expectation above has already failed the test if no message was found.
  return message!;
}
|
||||
|
||||
/** Concatenate, in order, the `text` of every content block whose `type` is "text". */
function assistantText(message: AssistantMessage): string {
  const pieces: string[] = [];
  for (const block of message.content) {
    if (block.type === "text") {
      pieces.push(block.text);
    }
  }
  return pieces.join("");
}
|
||||
|
||||
/** Each test gets a unique session ID to avoid cross-test interference. */
|
||||
const sessions: string[] = [];
|
||||
function freshSession(name: string): string {
|
||||
@@ -68,26 +123,14 @@ describe("OpenAI WebSocket e2e", () => {
|
||||
async () => {
|
||||
const sid = freshSession("single");
|
||||
const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid);
|
||||
const stream = streamFn(model, makeContext("What is 2+2?"), {});
|
||||
const stream = streamFn(model, makeContext("What is 2+2?"), { transport: "websocket" });
|
||||
const done = expectDone(await collectEvents(stream));
|
||||
|
||||
const events: Array<{ type: string }> = [];
|
||||
for await (const event of stream as AsyncIterable<{ type: string }>) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
const done = events.find((e) => e.type === "done") as
|
||||
| { type: "done"; message: { content: Array<{ type: string; text?: string }> } }
|
||||
| undefined;
|
||||
expect(done).toBeDefined();
|
||||
expect(done!.message.content.length).toBeGreaterThan(0);
|
||||
|
||||
const text = done!.message.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
expect(done.content.length).toBeGreaterThan(0);
|
||||
const text = assistantText(done);
|
||||
expect(text).toMatch(/4/);
|
||||
},
|
||||
30_000,
|
||||
45_000,
|
||||
);
|
||||
|
||||
testFn(
|
||||
@@ -96,19 +139,80 @@ describe("OpenAI WebSocket e2e", () => {
|
||||
const sid = freshSession("temp");
|
||||
const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid);
|
||||
const stream = streamFn(model, makeContext("Pick a random number between 1 and 1000."), {
|
||||
transport: "websocket",
|
||||
temperature: 0.8,
|
||||
});
|
||||
|
||||
const events: Array<{ type: string }> = [];
|
||||
for await (const event of stream as AsyncIterable<{ type: string }>) {
|
||||
events.push(event);
|
||||
}
|
||||
const events = await collectEvents(stream);
|
||||
|
||||
// Stream must complete (done or error with fallback) — must NOT hang.
|
||||
const hasTerminal = events.some((e) => e.type === "done" || e.type === "error");
|
||||
expect(hasTerminal).toBe(true);
|
||||
},
|
||||
30_000,
|
||||
45_000,
|
||||
);
|
||||
|
||||
testFn(
|
||||
"reuses the websocket session for tool-call follow-up turns",
|
||||
async () => {
|
||||
const sid = freshSession("tool-roundtrip");
|
||||
const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid);
|
||||
const firstContext = makeToolContext(
|
||||
"Call the tool `noop` with {}. After the tool result arrives, reply with exactly the tool output and nothing else.",
|
||||
);
|
||||
const firstEvents = await collectEvents(
|
||||
streamFn(model, firstContext, {
|
||||
transport: "websocket",
|
||||
toolChoice: "required",
|
||||
maxTokens: 128,
|
||||
} as unknown as StreamFnParams[2]),
|
||||
);
|
||||
const firstDone = expectDone(firstEvents);
|
||||
const toolCall = firstDone.content.find((block) => block.type === "toolCall") as
|
||||
| { type: "toolCall"; id: string; name: string }
|
||||
| undefined;
|
||||
expect(toolCall?.name).toBe("noop");
|
||||
expect(toolCall?.id).toBeTruthy();
|
||||
|
||||
const secondContext = {
|
||||
...firstContext,
|
||||
messages: [
|
||||
...firstContext.messages,
|
||||
firstDone,
|
||||
makeToolResultMessage(toolCall!.id, "TOOL_OK"),
|
||||
],
|
||||
} as unknown as StreamFnParams[1];
|
||||
const secondDone = expectDone(
|
||||
await collectEvents(
|
||||
streamFn(model, secondContext, {
|
||||
transport: "websocket",
|
||||
maxTokens: 128,
|
||||
}),
|
||||
),
|
||||
);
|
||||
|
||||
expect(assistantText(secondDone)).toMatch(/TOOL_OK/);
|
||||
},
|
||||
60_000,
|
||||
);
|
||||
|
||||
testFn(
|
||||
"supports websocket warm-up before the first request",
|
||||
async () => {
|
||||
const sid = freshSession("warmup");
|
||||
const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid);
|
||||
const done = expectDone(
|
||||
await collectEvents(
|
||||
streamFn(model, makeContext("Reply with the word warmed."), {
|
||||
transport: "websocket",
|
||||
openaiWsWarmup: true,
|
||||
maxTokens: 32,
|
||||
} as unknown as StreamFnParams[2]),
|
||||
),
|
||||
);
|
||||
|
||||
expect(assistantText(done).toLowerCase()).toContain("warmed");
|
||||
},
|
||||
45_000,
|
||||
);
|
||||
|
||||
testFn(
|
||||
@@ -119,16 +223,13 @@ describe("OpenAI WebSocket e2e", () => {
|
||||
|
||||
expect(hasWsSession(sid)).toBe(false);
|
||||
|
||||
const stream = streamFn(model, makeContext("Say hello."), {});
|
||||
for await (const _ of stream as AsyncIterable<unknown>) {
|
||||
/* consume */
|
||||
}
|
||||
await collectEvents(streamFn(model, makeContext("Say hello."), { transport: "websocket" }));
|
||||
|
||||
expect(hasWsSession(sid)).toBe(true);
|
||||
releaseWsSession(sid);
|
||||
expect(hasWsSession(sid)).toBe(false);
|
||||
},
|
||||
30_000,
|
||||
45_000,
|
||||
);
|
||||
|
||||
testFn(
|
||||
@@ -137,15 +238,11 @@ describe("OpenAI WebSocket e2e", () => {
|
||||
const sid = freshSession("fallback");
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-invalid-key", sid);
|
||||
const stream = streamFn(model, makeContext("Hello"), {});
|
||||
|
||||
const events: Array<{ type: string }> = [];
|
||||
for await (const event of stream as AsyncIterable<{ type: string }>) {
|
||||
events.push(event);
|
||||
}
|
||||
const events = await collectEvents(stream);
|
||||
|
||||
const hasTerminal = events.some((e) => e.type === "done" || e.type === "error");
|
||||
expect(hasTerminal).toBe(true);
|
||||
},
|
||||
30_000,
|
||||
45_000,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -224,6 +224,7 @@ type FakeMessage =
|
||||
| {
|
||||
role: "assistant";
|
||||
content: unknown[];
|
||||
phase?: "commentary" | "final_answer";
|
||||
stopReason: string;
|
||||
api: string;
|
||||
provider: string;
|
||||
@@ -247,6 +248,7 @@ function userMsg(text: string): FakeMessage {
|
||||
function assistantMsg(
|
||||
textBlocks: string[],
|
||||
toolCalls: Array<{ id: string; name: string; args: Record<string, unknown> }> = [],
|
||||
phase?: "commentary" | "final_answer",
|
||||
): FakeMessage {
|
||||
const content: unknown[] = [];
|
||||
for (const t of textBlocks) {
|
||||
@@ -258,6 +260,7 @@ function assistantMsg(
|
||||
return {
|
||||
role: "assistant",
|
||||
content,
|
||||
phase,
|
||||
stopReason: toolCalls.length > 0 ? "toolUse" : "stop",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
@@ -302,6 +305,7 @@ function makeResponseObject(
|
||||
id: string,
|
||||
outputText?: string,
|
||||
toolCallName?: string,
|
||||
phase?: "commentary" | "final_answer",
|
||||
): ResponseObject {
|
||||
const output: ResponseObject["output"] = [];
|
||||
if (outputText) {
|
||||
@@ -310,6 +314,7 @@ function makeResponseObject(
|
||||
id: "item_1",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: outputText }],
|
||||
phase,
|
||||
});
|
||||
}
|
||||
if (toolCallName) {
|
||||
@@ -357,18 +362,16 @@ describe("convertTools", () => {
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0]).toMatchObject({
|
||||
type: "function",
|
||||
function: {
|
||||
name: "exec",
|
||||
description: "Run a command",
|
||||
parameters: { type: "object", properties: { cmd: { type: "string" } } },
|
||||
},
|
||||
name: "exec",
|
||||
description: "Run a command",
|
||||
parameters: { type: "object", properties: { cmd: { type: "string" } } },
|
||||
});
|
||||
});
|
||||
|
||||
it("handles tools without description", () => {
|
||||
const tools = [{ name: "ping", description: "", parameters: {} }];
|
||||
const result = convertTools(tools as Parameters<typeof convertTools>[0]);
|
||||
expect(result[0]?.function?.name).toBe("ping");
|
||||
expect(result[0]?.name).toBe("ping");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -391,6 +394,19 @@ describe("convertMessagesToInputItems", () => {
|
||||
expect(items[0]).toMatchObject({ type: "message", role: "assistant", content: "Hi there." });
|
||||
});
|
||||
|
||||
it("preserves assistant phase on replayed assistant messages", () => {
|
||||
const items = convertMessagesToInputItems([
|
||||
assistantMsg(["Working on it."], [], "commentary"),
|
||||
] as Parameters<typeof convertMessagesToInputItems>[0]);
|
||||
expect(items).toHaveLength(1);
|
||||
expect(items[0]).toMatchObject({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: "Working on it.",
|
||||
phase: "commentary",
|
||||
});
|
||||
});
|
||||
|
||||
it("converts an assistant message with a tool call", () => {
|
||||
const msg = assistantMsg(
|
||||
["Let me run that."],
|
||||
@@ -408,10 +424,58 @@ describe("convertMessagesToInputItems", () => {
|
||||
call_id: "call_1",
|
||||
name: "exec",
|
||||
});
|
||||
expect(textItem).not.toHaveProperty("phase");
|
||||
const fc = fcItem as { arguments: string };
|
||||
expect(JSON.parse(fc.arguments)).toEqual({ cmd: "ls" });
|
||||
});
|
||||
|
||||
it("preserves assistant phase on commentary text before tool calls", () => {
|
||||
const msg = assistantMsg(
|
||||
["Let me run that."],
|
||||
[{ id: "call_1", name: "exec", args: { cmd: "ls" } }],
|
||||
"commentary",
|
||||
);
|
||||
const items = convertMessagesToInputItems([msg] as Parameters<
|
||||
typeof convertMessagesToInputItems
|
||||
>[0]);
|
||||
const textItem = items.find((i) => i.type === "message");
|
||||
expect(textItem).toMatchObject({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: "Let me run that.",
|
||||
phase: "commentary",
|
||||
});
|
||||
});
|
||||
|
||||
it("preserves assistant phase from textSignature metadata without local phase field", () => {
|
||||
const msg = {
|
||||
role: "assistant" as const,
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: "Working on it.",
|
||||
textSignature: JSON.stringify({ v: 1, id: "msg_sig", phase: "commentary" }),
|
||||
},
|
||||
],
|
||||
stopReason: "stop",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "gpt-5.2",
|
||||
usage: {},
|
||||
timestamp: 0,
|
||||
};
|
||||
const items = convertMessagesToInputItems([msg] as Parameters<
|
||||
typeof convertMessagesToInputItems
|
||||
>[0]);
|
||||
expect(items).toHaveLength(1);
|
||||
expect(items[0]).toMatchObject({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: "Working on it.",
|
||||
phase: "commentary",
|
||||
});
|
||||
});
|
||||
|
||||
it("converts a tool result message", () => {
|
||||
const items = convertMessagesToInputItems([toolResultMsg("call_1", "file.txt")] as Parameters<
|
||||
typeof convertMessagesToInputItems
|
||||
@@ -518,6 +582,34 @@ describe("convertMessagesToInputItems", () => {
|
||||
expect((items[0] as { content?: unknown }).content).toBe("Here is my answer.");
|
||||
});
|
||||
|
||||
it("replays reasoning blocks from thinking signatures", () => {
|
||||
const msg = {
|
||||
role: "assistant" as const,
|
||||
content: [
|
||||
{
|
||||
type: "thinking" as const,
|
||||
thinking: "internal reasoning...",
|
||||
thinkingSignature: JSON.stringify({
|
||||
type: "reasoning",
|
||||
id: "rs_test",
|
||||
summary: [],
|
||||
}),
|
||||
},
|
||||
{ type: "text" as const, text: "Here is my answer." },
|
||||
],
|
||||
stopReason: "stop",
|
||||
api: "openai-responses",
|
||||
provider: "openai",
|
||||
model: "gpt-5.2",
|
||||
usage: {},
|
||||
timestamp: 0,
|
||||
};
|
||||
const items = convertMessagesToInputItems([msg] as Parameters<
|
||||
typeof convertMessagesToInputItems
|
||||
>[0]);
|
||||
expect(items.map((item) => item.type)).toEqual(["reasoning", "message"]);
|
||||
});
|
||||
|
||||
it("returns empty array for empty messages", () => {
|
||||
expect(convertMessagesToInputItems([])).toEqual([]);
|
||||
});
|
||||
@@ -594,6 +686,16 @@ describe("buildAssistantMessageFromResponse", () => {
|
||||
expect(msg.content).toEqual([]);
|
||||
expect(msg.stopReason).toBe("stop");
|
||||
});
|
||||
|
||||
it("preserves phase from assistant message output items", () => {
|
||||
const response = makeResponseObject("resp_8", "Final answer", undefined, "final_answer");
|
||||
const msg = buildAssistantMessageFromResponse(response, modelInfo) as {
|
||||
phase?: string;
|
||||
content: Array<{ type: string; text?: string }>;
|
||||
};
|
||||
expect(msg.phase).toBe("final_answer");
|
||||
expect(msg.content[0]?.text).toBe("Final answer");
|
||||
});
|
||||
});
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
@@ -633,6 +735,7 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
releaseWsSession("sess-fallback");
|
||||
releaseWsSession("sess-incremental");
|
||||
releaseWsSession("sess-full");
|
||||
releaseWsSession("sess-phase");
|
||||
releaseWsSession("sess-tools");
|
||||
releaseWsSession("sess-store-default");
|
||||
releaseWsSession("sess-store-compat");
|
||||
@@ -795,6 +898,40 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
expect(doneEvent?.message.content[0]?.text).toBe("Hello back!");
|
||||
});
|
||||
|
||||
it("keeps assistant phase on completed WebSocket responses", async () => {
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-phase");
|
||||
const stream = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
contextStub as Parameters<typeof streamFn>[1],
|
||||
);
|
||||
|
||||
const events: unknown[] = [];
|
||||
const done = (async () => {
|
||||
for await (const ev of await resolveStream(stream)) {
|
||||
events.push(ev);
|
||||
}
|
||||
})();
|
||||
|
||||
await new Promise((r) => setImmediate(r));
|
||||
const manager = MockManager.lastInstance!;
|
||||
manager.simulateEvent({
|
||||
type: "response.completed",
|
||||
response: makeResponseObject("resp_phase", "Working...", "exec", "commentary"),
|
||||
});
|
||||
|
||||
await done;
|
||||
|
||||
const doneEvent = events.find((e) => (e as { type?: string }).type === "done") as
|
||||
| {
|
||||
type: string;
|
||||
reason: string;
|
||||
message: { phase?: string; stopReason: string };
|
||||
}
|
||||
| undefined;
|
||||
expect(doneEvent?.message.phase).toBe("commentary");
|
||||
expect(doneEvent?.message.stopReason).toBe("toolUse");
|
||||
});
|
||||
|
||||
it("falls back to HTTP when WebSocket connect fails (session pre-broken via flag)", async () => {
|
||||
// Set the class-level flag BEFORE calling streamFn so the new instance
|
||||
// fails on connect(). We patch the static default via MockManager directly.
|
||||
|
||||
@@ -37,6 +37,7 @@ import {
|
||||
type ContentPart,
|
||||
type FunctionToolDefinition,
|
||||
type InputItem,
|
||||
type OpenAIResponsesAssistantPhase,
|
||||
type OpenAIWebSocketManagerOptions,
|
||||
type ResponseObject,
|
||||
} from "./openai-ws-connection.js";
|
||||
@@ -100,6 +101,8 @@ export function hasWsSession(sessionId: string): boolean {
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
type AnyMessage = Message & { role: string; content: unknown };
|
||||
type AssistantMessageWithPhase = AssistantMessage & { phase?: OpenAIResponsesAssistantPhase };
|
||||
type ReplayModelInfo = { input?: ReadonlyArray<string> };
|
||||
|
||||
function toNonEmptyString(value: unknown): string | null {
|
||||
if (typeof value !== "string") {
|
||||
@@ -109,6 +112,50 @@ function toNonEmptyString(value: unknown): string | null {
|
||||
return trimmed.length > 0 ? trimmed : null;
|
||||
}
|
||||
|
||||
/**
 * Narrow an arbitrary value to a known assistant phase.
 *
 * @returns The phase when `value` is exactly "commentary" or "final_answer";
 *   `undefined` for anything else.
 */
function normalizeAssistantPhase(value: unknown): OpenAIResponsesAssistantPhase | undefined {
  if (value === "commentary") {
    return "commentary";
  }
  if (value === "final_answer") {
    return "final_answer";
  }
  return undefined;
}
|
||||
|
||||
/**
 * Serialize an assistant message id, and optionally its phase, as a JSON
 * envelope of the form `{ v: 1, id, phase? }`.
 *
 * @param params.id - Message id to embed.
 * @param params.phase - Included in the envelope only when truthy.
 */
function encodeAssistantTextSignature(params: {
  id: string;
  phase?: OpenAIResponsesAssistantPhase;
}): string {
  const { id, phase } = params;
  const envelope: { v: 1; id: string; phase?: OpenAIResponsesAssistantPhase } = { v: 1, id };
  if (phase) {
    envelope.phase = phase;
  }
  return JSON.stringify(envelope);
}
|
||||
|
||||
/**
 * Decode a text signature into a message id and optional phase.
 *
 * @param value - Candidate signature of any type.
 * @returns `null` for non-strings, blank strings, unparseable JSON, or a JSON
 *   envelope whose `v` is not 1 or whose `id` is not a string. A string that
 *   does not start with `{` is returned as the id itself. Otherwise the
 *   envelope's `id`, plus `phase` when it is a recognized phase value.
 */
function parseAssistantTextSignature(
  value: unknown,
): { id: string; phase?: OpenAIResponsesAssistantPhase } | null {
  if (typeof value !== "string") {
    return null;
  }
  if (value.trim().length === 0) {
    return null;
  }
  // Anything that is not a JSON object is taken to be the id verbatim.
  if (!value.startsWith("{")) {
    return { id: value };
  }

  let envelope: { v?: unknown; id?: unknown; phase?: unknown };
  try {
    envelope = JSON.parse(value) as { v?: unknown; id?: unknown; phase?: unknown };
  } catch {
    return null;
  }

  if (envelope.v !== 1 || typeof envelope.id !== "string") {
    return null;
  }
  const phase = normalizeAssistantPhase(envelope.phase);
  return phase ? { id: envelope.id, phase } : { id: envelope.id };
}
|
||||
|
||||
/**
 * Decide whether image parts may be included for the given model info.
 *
 * @returns `true` when no `input` array is declared, or when the declared
 *   array contains "image"; `false` otherwise.
 */
function supportsImageInput(modelOverride?: ReplayModelInfo): boolean {
  const declaredInputs = modelOverride?.input;
  if (!Array.isArray(declaredInputs)) {
    return true;
  }
  return declaredInputs.includes("image");
}
|
||||
|
||||
/** Convert pi-ai content (string | ContentPart[]) to plain text. */
|
||||
function contentToText(content: unknown): string {
|
||||
if (typeof content === "string") {
|
||||
@@ -117,30 +164,50 @@ function contentToText(content: unknown): string {
|
||||
if (!Array.isArray(content)) {
|
||||
return "";
|
||||
}
|
||||
return (content as Array<{ type?: string; text?: string }>)
|
||||
.filter((p) => p.type === "text" && typeof p.text === "string")
|
||||
.map((p) => p.text as string)
|
||||
return content
|
||||
.filter(
|
||||
(part): part is { type?: string; text?: string } => Boolean(part) && typeof part === "object",
|
||||
)
|
||||
.filter(
|
||||
(part) =>
|
||||
(part.type === "text" || part.type === "input_text" || part.type === "output_text") &&
|
||||
typeof part.text === "string",
|
||||
)
|
||||
.map((part) => part.text as string)
|
||||
.join("");
|
||||
}
|
||||
|
||||
/** Convert pi-ai content to OpenAI ContentPart[]. */
|
||||
function contentToOpenAIParts(content: unknown): ContentPart[] {
|
||||
function contentToOpenAIParts(content: unknown, modelOverride?: ReplayModelInfo): ContentPart[] {
|
||||
if (typeof content === "string") {
|
||||
return content ? [{ type: "input_text", text: content }] : [];
|
||||
}
|
||||
if (!Array.isArray(content)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const includeImages = supportsImageInput(modelOverride);
|
||||
const parts: ContentPart[] = [];
|
||||
for (const part of content as Array<{
|
||||
type?: string;
|
||||
text?: string;
|
||||
data?: string;
|
||||
mimeType?: string;
|
||||
source?: unknown;
|
||||
}>) {
|
||||
if (part.type === "text" && typeof part.text === "string") {
|
||||
if (
|
||||
(part.type === "text" || part.type === "input_text" || part.type === "output_text") &&
|
||||
typeof part.text === "string"
|
||||
) {
|
||||
parts.push({ type: "input_text", text: part.text });
|
||||
} else if (part.type === "image" && typeof part.data === "string") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!includeImages) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (part.type === "image" && typeof part.data === "string") {
|
||||
parts.push({
|
||||
type: "input_image",
|
||||
source: {
|
||||
@@ -149,11 +216,60 @@ function contentToOpenAIParts(content: unknown): ContentPart[] {
|
||||
data: part.data,
|
||||
},
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (
|
||||
part.type === "input_image" &&
|
||||
part.source &&
|
||||
typeof part.source === "object" &&
|
||||
typeof (part.source as { type?: unknown }).type === "string"
|
||||
) {
|
||||
parts.push({
|
||||
type: "input_image",
|
||||
source: part.source as
|
||||
| { type: "url"; url: string }
|
||||
| { type: "base64"; media_type: string; data: string },
|
||||
});
|
||||
}
|
||||
}
|
||||
return parts;
|
||||
}
|
||||
|
||||
/**
 * Rebuild a `reasoning` input item from an already-decoded value.
 *
 * @param value - Candidate object of any type.
 * @returns `null` unless `value` is an object whose `type` is "reasoning";
 *   otherwise a new item carrying only those of `content`,
 *   `encrypted_content` and `summary` that are strings. Every other property
 *   of `value` is dropped.
 */
function parseReasoningItem(value: unknown): Extract<InputItem, { type: "reasoning" }> | null {
  if (!value || typeof value !== "object") {
    return null;
  }
  const record = value as {
    type?: unknown;
    content?: unknown;
    encrypted_content?: unknown;
    summary?: unknown;
  };
  if (record.type !== "reasoning") {
    return null;
  }

  const { content, encrypted_content: encryptedContent, summary } = record;
  const optional: { content?: string; encrypted_content?: string; summary?: string } = {};
  if (typeof content === "string") {
    optional.content = content;
  }
  if (typeof encryptedContent === "string") {
    optional.encrypted_content = encryptedContent;
  }
  if (typeof summary === "string") {
    optional.summary = summary;
  }
  return { type: "reasoning", ...optional };
}
|
||||
|
||||
/**
 * Decode a thinking signature into a `reasoning` input item.
 *
 * @param value - Candidate signature of any type.
 * @returns `null` for non-strings, blank strings, or when JSON decoding or
 *   `parseReasoningItem` throws; otherwise whatever `parseReasoningItem`
 *   returns for the decoded JSON.
 */
function parseThinkingSignature(value: unknown): Extract<InputItem, { type: "reasoning" }> | null {
  if (typeof value !== "string") {
    return null;
  }
  if (value.trim().length === 0) {
    return null;
  }
  try {
    const decoded: unknown = JSON.parse(value);
    return parseReasoningItem(decoded);
  } catch {
    return null;
  }
}
|
||||
|
||||
/** Convert pi-ai tool array to OpenAI FunctionToolDefinition[]. */
|
||||
export function convertTools(tools: Context["tools"]): FunctionToolDefinition[] {
|
||||
if (!tools || tools.length === 0) {
|
||||
@@ -161,11 +277,9 @@ export function convertTools(tools: Context["tools"]): FunctionToolDefinition[]
|
||||
}
|
||||
return tools.map((tool) => ({
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: typeof tool.description === "string" ? tool.description : undefined,
|
||||
parameters: (tool.parameters ?? {}) as Record<string, unknown>,
|
||||
},
|
||||
name: tool.name,
|
||||
description: typeof tool.description === "string" ? tool.description : undefined,
|
||||
parameters: (tool.parameters ?? {}) as Record<string, unknown>,
|
||||
}));
|
||||
}
|
||||
|
||||
@@ -173,14 +287,24 @@ export function convertTools(tools: Context["tools"]): FunctionToolDefinition[]
|
||||
* Convert the full pi-ai message history to an OpenAI `input` array.
|
||||
* Handles user messages, assistant text+tool-call messages, and tool results.
|
||||
*/
|
||||
export function convertMessagesToInputItems(messages: Message[]): InputItem[] {
|
||||
export function convertMessagesToInputItems(
|
||||
messages: Message[],
|
||||
modelOverride?: ReplayModelInfo,
|
||||
): InputItem[] {
|
||||
const items: InputItem[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
const m = msg as AnyMessage;
|
||||
const m = msg as AnyMessage & {
|
||||
phase?: unknown;
|
||||
toolCallId?: unknown;
|
||||
toolUseId?: unknown;
|
||||
};
|
||||
|
||||
if (m.role === "user") {
|
||||
const parts = contentToOpenAIParts(m.content);
|
||||
const parts = contentToOpenAIParts(m.content, modelOverride);
|
||||
if (parts.length === 0) {
|
||||
continue;
|
||||
}
|
||||
items.push({
|
||||
type: "message",
|
||||
role: "user",
|
||||
@@ -194,87 +318,116 @@ export function convertMessagesToInputItems(messages: Message[]): InputItem[] {
|
||||
|
||||
if (m.role === "assistant") {
|
||||
const content = m.content;
|
||||
let assistantPhase = normalizeAssistantPhase(m.phase);
|
||||
if (Array.isArray(content)) {
|
||||
// Collect text blocks and tool calls separately
|
||||
const textParts: string[] = [];
|
||||
for (const block of content as Array<{
|
||||
type?: string;
|
||||
text?: string;
|
||||
id?: string;
|
||||
name?: string;
|
||||
arguments?: Record<string, unknown>;
|
||||
thinking?: string;
|
||||
}>) {
|
||||
if (block.type === "text" && typeof block.text === "string") {
|
||||
textParts.push(block.text);
|
||||
} else if (block.type === "thinking" && typeof block.thinking === "string") {
|
||||
// Skip thinking blocks — not sent back to the model
|
||||
} else if (block.type === "toolCall") {
|
||||
// Push accumulated text first
|
||||
if (textParts.length > 0) {
|
||||
items.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: textParts.join(""),
|
||||
});
|
||||
textParts.length = 0;
|
||||
}
|
||||
const callId = toNonEmptyString(block.id);
|
||||
const toolName = toNonEmptyString(block.name);
|
||||
if (!callId || !toolName) {
|
||||
continue;
|
||||
}
|
||||
// Push function_call item
|
||||
items.push({
|
||||
type: "function_call",
|
||||
call_id: callId,
|
||||
name: toolName,
|
||||
arguments:
|
||||
typeof block.arguments === "string"
|
||||
? block.arguments
|
||||
: JSON.stringify(block.arguments ?? {}),
|
||||
});
|
||||
const pushAssistantText = () => {
|
||||
if (textParts.length === 0) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (textParts.length > 0) {
|
||||
items.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: textParts.join(""),
|
||||
...(assistantPhase ? { phase: assistantPhase } : {}),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
const text = contentToText(m.content);
|
||||
if (text) {
|
||||
textParts.length = 0;
|
||||
};
|
||||
|
||||
for (const block of content as Array<{
|
||||
type?: string;
|
||||
text?: string;
|
||||
textSignature?: unknown;
|
||||
id?: unknown;
|
||||
name?: unknown;
|
||||
arguments?: unknown;
|
||||
thinkingSignature?: unknown;
|
||||
}>) {
|
||||
if (block.type === "text" && typeof block.text === "string") {
|
||||
const parsedSignature = parseAssistantTextSignature(block.textSignature);
|
||||
if (!assistantPhase) {
|
||||
assistantPhase = parsedSignature?.phase;
|
||||
}
|
||||
textParts.push(block.text);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (block.type === "thinking") {
|
||||
pushAssistantText();
|
||||
const reasoningItem = parseThinkingSignature(block.thinkingSignature);
|
||||
if (reasoningItem) {
|
||||
items.push(reasoningItem);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (block.type !== "toolCall") {
|
||||
continue;
|
||||
}
|
||||
|
||||
pushAssistantText();
|
||||
const callIdRaw = toNonEmptyString(block.id);
|
||||
const toolName = toNonEmptyString(block.name);
|
||||
if (!callIdRaw || !toolName) {
|
||||
continue;
|
||||
}
|
||||
const [callId, itemId] = callIdRaw.split("|", 2);
|
||||
items.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: text,
|
||||
type: "function_call",
|
||||
...(itemId ? { id: itemId } : {}),
|
||||
call_id: callId,
|
||||
name: toolName,
|
||||
arguments:
|
||||
typeof block.arguments === "string"
|
||||
? block.arguments
|
||||
: JSON.stringify(block.arguments ?? {}),
|
||||
});
|
||||
}
|
||||
|
||||
pushAssistantText();
|
||||
continue;
|
||||
}
|
||||
|
||||
const text = contentToText(content);
|
||||
if (!text) {
|
||||
continue;
|
||||
}
|
||||
items.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: text,
|
||||
...(assistantPhase ? { phase: assistantPhase } : {}),
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (m.role === "toolResult") {
|
||||
const tr = m as unknown as {
|
||||
toolCallId?: string;
|
||||
toolUseId?: string;
|
||||
content: unknown;
|
||||
isError: boolean;
|
||||
};
|
||||
const callId = toNonEmptyString(tr.toolCallId) ?? toNonEmptyString(tr.toolUseId);
|
||||
if (!callId) {
|
||||
continue;
|
||||
}
|
||||
const outputText = contentToText(tr.content);
|
||||
items.push({
|
||||
type: "function_call_output",
|
||||
call_id: callId,
|
||||
output: outputText,
|
||||
});
|
||||
if (m.role !== "toolResult") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolCallId = toNonEmptyString(m.toolCallId) ?? toNonEmptyString(m.toolUseId);
|
||||
if (!toolCallId) {
|
||||
continue;
|
||||
}
|
||||
const [callId] = toolCallId.split("|", 2);
|
||||
const parts = Array.isArray(m.content) ? contentToOpenAIParts(m.content, modelOverride) : [];
|
||||
const textOutput = contentToText(m.content);
|
||||
const imageParts = parts.filter((part) => part.type === "input_image");
|
||||
items.push({
|
||||
type: "function_call_output",
|
||||
call_id: callId,
|
||||
output: textOutput || (imageParts.length > 0 ? "(see attached image)" : ""),
|
||||
});
|
||||
if (imageParts.length > 0) {
|
||||
items.push({
|
||||
type: "message",
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "input_text", text: "Attached image(s) from tool result:" },
|
||||
...imageParts,
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return items;
|
||||
@@ -289,12 +442,24 @@ export function buildAssistantMessageFromResponse(
|
||||
modelInfo: { api: string; provider: string; id: string },
|
||||
): AssistantMessage {
|
||||
const content: (TextContent | ToolCall)[] = [];
|
||||
let assistantPhase: OpenAIResponsesAssistantPhase | undefined;
|
||||
|
||||
for (const item of response.output ?? []) {
|
||||
if (item.type === "message") {
|
||||
const itemPhase = normalizeAssistantPhase(item.phase);
|
||||
if (itemPhase) {
|
||||
assistantPhase = itemPhase;
|
||||
}
|
||||
for (const part of item.content ?? []) {
|
||||
if (part.type === "output_text" && part.text) {
|
||||
content.push({ type: "text", text: part.text });
|
||||
content.push({
|
||||
type: "text",
|
||||
text: part.text,
|
||||
textSignature: encodeAssistantTextSignature({
|
||||
id: item.id,
|
||||
...(itemPhase ? { phase: itemPhase } : {}),
|
||||
}),
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (item.type === "function_call") {
|
||||
@@ -321,7 +486,7 @@ export function buildAssistantMessageFromResponse(
|
||||
const hasToolCalls = content.some((c) => c.type === "toolCall");
|
||||
const stopReason: StopReason = hasToolCalls ? "toolUse" : "stop";
|
||||
|
||||
return buildAssistantMessage({
|
||||
const message = buildAssistantMessage({
|
||||
model: modelInfo,
|
||||
content,
|
||||
stopReason,
|
||||
@@ -331,6 +496,10 @@ export function buildAssistantMessageFromResponse(
|
||||
totalTokens: response.usage?.total_tokens ?? 0,
|
||||
}),
|
||||
});
|
||||
|
||||
return assistantPhase
|
||||
? ({ ...message, phase: assistantPhase } as AssistantMessageWithPhase)
|
||||
: message;
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
@@ -504,6 +673,7 @@ export function createOpenAIWebSocketStreamFn(
|
||||
|
||||
if (resolveWsWarmup(options) && !session.warmUpAttempted) {
|
||||
session.warmUpAttempted = true;
|
||||
let warmupFailed = false;
|
||||
try {
|
||||
await runWarmUp({
|
||||
manager: session.manager,
|
||||
@@ -517,10 +687,33 @@ export function createOpenAIWebSocketStreamFn(
|
||||
if (signal?.aborted) {
|
||||
throw warmErr instanceof Error ? warmErr : new Error(String(warmErr));
|
||||
}
|
||||
warmupFailed = true;
|
||||
log.warn(
|
||||
`[ws-stream] warm-up failed for session=${sessionId}; continuing without warm-up. error=${String(warmErr)}`,
|
||||
);
|
||||
}
|
||||
if (warmupFailed && !session.manager.isConnected()) {
|
||||
try {
|
||||
session.manager.close();
|
||||
} catch {
|
||||
/* ignore */
|
||||
}
|
||||
try {
|
||||
await session.manager.connect(apiKey);
|
||||
session.everConnected = true;
|
||||
log.debug(`[ws-stream] reconnected after warm-up failure for session=${sessionId}`);
|
||||
} catch (reconnectErr) {
|
||||
session.broken = true;
|
||||
wsRegistry.delete(sessionId);
|
||||
if (transport === "websocket") {
|
||||
throw reconnectErr instanceof Error ? reconnectErr : new Error(String(reconnectErr));
|
||||
}
|
||||
log.warn(
|
||||
`[ws-stream] reconnect after warm-up failed for session=${sessionId}; falling back to HTTP. error=${String(reconnectErr)}`,
|
||||
);
|
||||
return fallbackToHttp(model, context, options, eventStream, opts.signal);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── 3. Compute incremental vs full input ─────────────────────────────
|
||||
@@ -537,16 +730,16 @@ export function createOpenAIWebSocketStreamFn(
|
||||
log.debug(
|
||||
`[ws-stream] session=${sessionId}: no new tool results found; sending full context`,
|
||||
);
|
||||
inputItems = buildFullInput(context);
|
||||
inputItems = buildFullInput(context, model);
|
||||
} else {
|
||||
inputItems = convertMessagesToInputItems(toolResults);
|
||||
inputItems = convertMessagesToInputItems(toolResults, model);
|
||||
}
|
||||
log.debug(
|
||||
`[ws-stream] session=${sessionId}: incremental send (${inputItems.length} tool results) previous_response_id=${prevResponseId}`,
|
||||
);
|
||||
} else {
|
||||
// First turn: send full context
|
||||
inputItems = buildFullInput(context);
|
||||
inputItems = buildFullInput(context, model);
|
||||
log.debug(
|
||||
`[ws-stream] session=${sessionId}: full context send (${inputItems.length} items)`,
|
||||
);
|
||||
@@ -605,10 +798,9 @@ export function createOpenAIWebSocketStreamFn(
|
||||
...extraParams,
|
||||
};
|
||||
const nextPayload = await options?.onPayload?.(payload, model);
|
||||
const requestPayload =
|
||||
nextPayload && typeof nextPayload === "object"
|
||||
? (nextPayload as Parameters<OpenAIWebSocketManager["send"]>[0])
|
||||
: (payload as Parameters<OpenAIWebSocketManager["send"]>[0]);
|
||||
const requestPayload = (nextPayload ?? payload) as Parameters<
|
||||
OpenAIWebSocketManager["send"]
|
||||
>[0];
|
||||
|
||||
try {
|
||||
session.manager.send(requestPayload);
|
||||
@@ -734,8 +926,8 @@ export function createOpenAIWebSocketStreamFn(
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Build full input items from context (system prompt is passed via `instructions` field). */
|
||||
function buildFullInput(context: Context): InputItem[] {
|
||||
return convertMessagesToInputItems(context.messages);
|
||||
function buildFullInput(context: Context, model: ReplayModelInfo): InputItem[] {
|
||||
return convertMessagesToInputItems(context.messages, model);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,6 +2,22 @@ import { describe, expect, it, vi } from "vitest";
|
||||
|
||||
const loadSessionStoreMock = vi.fn();
|
||||
const updateSessionStoreMock = vi.fn();
|
||||
const callGatewayMock = vi.fn();
|
||||
|
||||
const createMockConfig = () => ({
|
||||
session: { mainKey: "main", scope: "per-sender" },
|
||||
agents: {
|
||||
defaults: {
|
||||
model: { primary: "anthropic/claude-opus-4-5" },
|
||||
models: {},
|
||||
},
|
||||
},
|
||||
tools: {
|
||||
agentToAgent: { enabled: false },
|
||||
},
|
||||
});
|
||||
|
||||
let mockConfig: Record<string, unknown> = createMockConfig();
|
||||
|
||||
vi.mock("../config/sessions.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../config/sessions.js")>();
|
||||
@@ -22,19 +38,15 @@ vi.mock("../config/sessions.js", async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("../gateway/call.js", () => ({
|
||||
callGateway: (opts: unknown) => callGatewayMock(opts),
|
||||
}));
|
||||
|
||||
vi.mock("../config/config.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../config/config.js")>();
|
||||
return {
|
||||
...actual,
|
||||
loadConfig: () => ({
|
||||
session: { mainKey: "main", scope: "per-sender" },
|
||||
agents: {
|
||||
defaults: {
|
||||
model: { primary: "anthropic/claude-opus-4-5" },
|
||||
models: {},
|
||||
},
|
||||
},
|
||||
}),
|
||||
loadConfig: () => mockConfig,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -82,13 +94,17 @@ import { createOpenClawTools } from "./openclaw-tools.js";
|
||||
function resetSessionStore(store: Record<string, unknown>) {
|
||||
loadSessionStoreMock.mockClear();
|
||||
updateSessionStoreMock.mockClear();
|
||||
callGatewayMock.mockClear();
|
||||
loadSessionStoreMock.mockReturnValue(store);
|
||||
callGatewayMock.mockResolvedValue({});
|
||||
mockConfig = createMockConfig();
|
||||
}
|
||||
|
||||
function getSessionStatusTool(agentSessionKey = "main") {
|
||||
const tool = createOpenClawTools({ agentSessionKey }).find(
|
||||
(candidate) => candidate.name === "session_status",
|
||||
);
|
||||
function getSessionStatusTool(agentSessionKey = "main", options?: { sandboxed?: boolean }) {
|
||||
const tool = createOpenClawTools({
|
||||
agentSessionKey,
|
||||
sandboxed: options?.sandboxed,
|
||||
}).find((candidate) => candidate.name === "session_status");
|
||||
expect(tool).toBeDefined();
|
||||
if (!tool) {
|
||||
throw new Error("missing session_status tool");
|
||||
@@ -176,6 +192,153 @@ describe("session_status tool", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("blocks sandboxed child session_status access outside its tree before store lookup", async () => {
|
||||
resetSessionStore({
|
||||
"agent:main:subagent:child": {
|
||||
sessionId: "s-child",
|
||||
updatedAt: 20,
|
||||
},
|
||||
"agent:main:main": {
|
||||
sessionId: "s-parent",
|
||||
updatedAt: 10,
|
||||
},
|
||||
});
|
||||
mockConfig = {
|
||||
session: { mainKey: "main", scope: "per-sender" },
|
||||
tools: {
|
||||
sessions: { visibility: "all" },
|
||||
agentToAgent: { enabled: true, allow: ["*"] },
|
||||
},
|
||||
agents: {
|
||||
defaults: {
|
||||
model: { primary: "anthropic/claude-opus-4-5" },
|
||||
models: {},
|
||||
sandbox: { sessionToolsVisibility: "spawned" },
|
||||
},
|
||||
},
|
||||
};
|
||||
callGatewayMock.mockImplementation(async (opts: unknown) => {
|
||||
const request = opts as { method?: string; params?: Record<string, unknown> };
|
||||
if (request.method === "sessions.list") {
|
||||
return { sessions: [] };
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
const tool = getSessionStatusTool("agent:main:subagent:child", {
|
||||
sandboxed: true,
|
||||
});
|
||||
const expectedError = "Session status visibility is restricted to the current session tree";
|
||||
|
||||
await expect(
|
||||
tool.execute("call6", {
|
||||
sessionKey: "agent:main:main",
|
||||
model: "anthropic/claude-sonnet-4-5",
|
||||
}),
|
||||
).rejects.toThrow(expectedError);
|
||||
|
||||
await expect(
|
||||
tool.execute("call7", {
|
||||
sessionKey: "agent:main:subagent:missing",
|
||||
}),
|
||||
).rejects.toThrow(expectedError);
|
||||
|
||||
expect(loadSessionStoreMock).not.toHaveBeenCalled();
|
||||
expect(updateSessionStoreMock).not.toHaveBeenCalled();
|
||||
expect(callGatewayMock).toHaveBeenCalledTimes(2);
|
||||
expect(callGatewayMock).toHaveBeenNthCalledWith(1, {
|
||||
method: "sessions.list",
|
||||
params: {
|
||||
includeGlobal: false,
|
||||
includeUnknown: false,
|
||||
limit: 500,
|
||||
spawnedBy: "agent:main:subagent:child",
|
||||
},
|
||||
});
|
||||
expect(callGatewayMock).toHaveBeenNthCalledWith(2, {
|
||||
method: "sessions.list",
|
||||
params: {
|
||||
includeGlobal: false,
|
||||
includeUnknown: false,
|
||||
limit: 500,
|
||||
spawnedBy: "agent:main:subagent:child",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("keeps legacy main requester keys for sandboxed session tree checks", async () => {
|
||||
resetSessionStore({
|
||||
"agent:main:main": {
|
||||
sessionId: "s-main",
|
||||
updatedAt: 10,
|
||||
},
|
||||
"agent:main:subagent:child": {
|
||||
sessionId: "s-child",
|
||||
updatedAt: 20,
|
||||
},
|
||||
});
|
||||
mockConfig = {
|
||||
session: { mainKey: "main", scope: "per-sender" },
|
||||
tools: {
|
||||
sessions: { visibility: "all" },
|
||||
agentToAgent: { enabled: true, allow: ["*"] },
|
||||
},
|
||||
agents: {
|
||||
defaults: {
|
||||
model: { primary: "anthropic/claude-opus-4-5" },
|
||||
models: {},
|
||||
sandbox: { sessionToolsVisibility: "spawned" },
|
||||
},
|
||||
},
|
||||
};
|
||||
callGatewayMock.mockImplementation(async (opts: unknown) => {
|
||||
const request = opts as { method?: string; params?: Record<string, unknown> };
|
||||
if (request.method === "sessions.list") {
|
||||
return {
|
||||
sessions:
|
||||
request.params?.spawnedBy === "main" ? [{ key: "agent:main:subagent:child" }] : [],
|
||||
};
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
const tool = getSessionStatusTool("main", {
|
||||
sandboxed: true,
|
||||
});
|
||||
|
||||
const mainResult = await tool.execute("call8", {});
|
||||
const mainDetails = mainResult.details as { ok?: boolean; sessionKey?: string };
|
||||
expect(mainDetails.ok).toBe(true);
|
||||
expect(mainDetails.sessionKey).toBe("agent:main:main");
|
||||
|
||||
const childResult = await tool.execute("call9", {
|
||||
sessionKey: "agent:main:subagent:child",
|
||||
});
|
||||
const childDetails = childResult.details as { ok?: boolean; sessionKey?: string };
|
||||
expect(childDetails.ok).toBe(true);
|
||||
expect(childDetails.sessionKey).toBe("agent:main:subagent:child");
|
||||
|
||||
expect(callGatewayMock).toHaveBeenCalledTimes(2);
|
||||
expect(callGatewayMock).toHaveBeenNthCalledWith(1, {
|
||||
method: "sessions.list",
|
||||
params: {
|
||||
includeGlobal: false,
|
||||
includeUnknown: false,
|
||||
limit: 500,
|
||||
spawnedBy: "main",
|
||||
},
|
||||
});
|
||||
expect(callGatewayMock).toHaveBeenNthCalledWith(2, {
|
||||
method: "sessions.list",
|
||||
params: {
|
||||
includeGlobal: false,
|
||||
includeUnknown: false,
|
||||
limit: 500,
|
||||
spawnedBy: "main",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("scopes bare session keys to the requester agent", async () => {
|
||||
loadSessionStoreMock.mockClear();
|
||||
updateSessionStoreMock.mockClear();
|
||||
|
||||
@@ -85,7 +85,10 @@ describe("sessions_spawn depth + child limits", () => {
|
||||
});
|
||||
|
||||
it("rejects spawning when caller depth reaches maxSpawnDepth", async () => {
|
||||
const tool = createSessionsSpawnTool({ agentSessionKey: "agent:main:subagent:parent" });
|
||||
const tool = createSessionsSpawnTool({
|
||||
agentSessionKey: "agent:main:subagent:parent",
|
||||
workspaceDir: "/parent/workspace",
|
||||
});
|
||||
const result = await tool.execute("call-depth-reject", { task: "hello" });
|
||||
|
||||
expect(result.details).toMatchObject({
|
||||
@@ -109,8 +112,13 @@ describe("sessions_spawn depth + child limits", () => {
|
||||
const calls = callGatewayMock.mock.calls.map(
|
||||
(call) => call[0] as { method?: string; params?: Record<string, unknown> },
|
||||
);
|
||||
const agentCall = calls.find((entry) => entry.method === "agent");
|
||||
expect(agentCall?.params?.spawnedBy).toBe("agent:main:subagent:parent");
|
||||
const spawnedByPatch = calls.find(
|
||||
(entry) =>
|
||||
entry.method === "sessions.patch" &&
|
||||
entry.params?.spawnedBy === "agent:main:subagent:parent",
|
||||
);
|
||||
expect(spawnedByPatch?.params?.key).toMatch(/^agent:main:subagent:/);
|
||||
expect(typeof spawnedByPatch?.params?.spawnedWorkspaceDir).toBe("string");
|
||||
|
||||
const spawnDepthPatch = calls.find(
|
||||
(entry) => entry.method === "sessions.patch" && entry.params?.spawnDepth === 2,
|
||||
|
||||
@@ -200,6 +200,7 @@ export function createOpenClawTools(
|
||||
createSessionStatusTool({
|
||||
agentSessionKey: options?.agentSessionKey,
|
||||
config: options?.config,
|
||||
sandboxed: options?.sandboxed,
|
||||
}),
|
||||
...(webSearchTool ? [webSearchTool] : []),
|
||||
...(webFetchTool ? [webFetchTool] : []),
|
||||
|
||||
@@ -106,6 +106,9 @@ describe("isBillingErrorMessage", () => {
|
||||
"Payment Required",
|
||||
"HTTP 402 Payment Required",
|
||||
"plans & billing",
|
||||
// Venice returns "Insufficient USD or Diem balance" which has extra words
|
||||
// between "insufficient" and "balance"
|
||||
"Insufficient USD or Diem balance to complete request. Visit https://venice.ai/settings/api to add credits.",
|
||||
];
|
||||
for (const sample of samples) {
|
||||
expect(isBillingErrorMessage(sample)).toBe(true);
|
||||
@@ -149,6 +152,11 @@ describe("isBillingErrorMessage", () => {
|
||||
expect(longResponse.length).toBeGreaterThan(512);
|
||||
expect(isBillingErrorMessage(longResponse)).toBe(false);
|
||||
});
|
||||
it("does not false-positive on short non-billing text that mentions insufficient and balance", () => {
|
||||
const sample = "The evidence is insufficient to reconcile the final balance after compaction.";
|
||||
expect(isBillingErrorMessage(sample)).toBe(false);
|
||||
expect(classifyFailoverReason(sample)).toBeNull();
|
||||
});
|
||||
it("still matches explicit 402 markers in long payloads", () => {
|
||||
const longStructuredError =
|
||||
'{"error":{"code":402,"message":"payment required","details":"' + "x".repeat(700) + '"}}';
|
||||
@@ -439,6 +447,18 @@ describe("isLikelyContextOverflowError", () => {
|
||||
expect(isLikelyContextOverflowError(sample)).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("excludes billing errors even when text matches context overflow patterns", () => {
|
||||
const samples = [
|
||||
"402 Payment Required: request token limit exceeded for this billing plan",
|
||||
"insufficient credits: request size exceeds your current plan limits",
|
||||
"Your credit balance is too low. Maximum request token limit exceeded.",
|
||||
];
|
||||
for (const sample of samples) {
|
||||
expect(isBillingErrorMessage(sample)).toBe(true);
|
||||
expect(isLikelyContextOverflowError(sample)).toBe(false);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe("isTransientHttpError", () => {
|
||||
@@ -515,6 +535,23 @@ describe("isFailoverErrorMessage", () => {
|
||||
}
|
||||
});
|
||||
|
||||
it("matches network errno codes in serialized error messages", () => {
|
||||
const samples = [
|
||||
"Error: connect ETIMEDOUT 10.0.0.1:443",
|
||||
"Error: connect ESOCKETTIMEDOUT 10.0.0.1:443",
|
||||
"Error: connect EHOSTUNREACH 10.0.0.1:443",
|
||||
"Error: connect ENETUNREACH 10.0.0.1:443",
|
||||
"Error: write EPIPE",
|
||||
"Error: read ENETRESET",
|
||||
"Error: connect EHOSTDOWN 192.168.1.1:443",
|
||||
];
|
||||
for (const sample of samples) {
|
||||
expect(isTimeoutErrorMessage(sample)).toBe(true);
|
||||
expect(classifyFailoverReason(sample)).toBe("timeout");
|
||||
expect(isFailoverErrorMessage(sample)).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it("does not classify MALFORMED_FUNCTION_CALL as timeout", () => {
|
||||
const sample = "Unhandled stop reason: MALFORMED_FUNCTION_CALL";
|
||||
expect(isTimeoutErrorMessage(sample)).toBe(false);
|
||||
@@ -638,6 +675,12 @@ describe("classifyFailoverReason", () => {
|
||||
expect(classifyFailoverReason(TOGETHER_ENGINE_OVERLOADED_MESSAGE)).toBe("overloaded");
|
||||
expect(classifyFailoverReason(GROQ_TOO_MANY_REQUESTS_MESSAGE)).toBe("rate_limit");
|
||||
expect(classifyFailoverReason(GROQ_SERVICE_UNAVAILABLE_MESSAGE)).toBe("overloaded");
|
||||
// Venice 402 billing error with extra words between "insufficient" and "balance"
|
||||
expect(
|
||||
classifyFailoverReason(
|
||||
"Insufficient USD or Diem balance to complete request. Visit https://venice.ai/settings/api to add credits.",
|
||||
),
|
||||
).toBe("billing");
|
||||
});
|
||||
|
||||
it("classifies internal and compatibility error messages", () => {
|
||||
|
||||
@@ -138,6 +138,13 @@ export function isLikelyContextOverflowError(errorMessage?: string): boolean {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Billing/quota errors can contain patterns like "request size exceeds" or
|
||||
// "maximum token limit exceeded" that match the context overflow heuristic.
|
||||
// Billing is a more specific error class — exclude it early.
|
||||
if (isBillingErrorMessage(errorMessage)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (CONTEXT_WINDOW_TOO_SMALL_RE.test(errorMessage)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -37,6 +37,13 @@ const ERROR_PATTERNS = {
|
||||
"fetch failed",
|
||||
"socket hang up",
|
||||
/\beconn(?:refused|reset|aborted)\b/i,
|
||||
/\benetunreach\b/i,
|
||||
/\behostunreach\b/i,
|
||||
/\behostdown\b/i,
|
||||
/\benetreset\b/i,
|
||||
/\betimedout\b/i,
|
||||
/\besockettimedout\b/i,
|
||||
/\bepipe\b/i,
|
||||
/\benotfound\b/i,
|
||||
/\beai_again\b/i,
|
||||
/without sending (?:any )?chunks?/i,
|
||||
@@ -52,6 +59,7 @@ const ERROR_PATTERNS = {
|
||||
"credit balance",
|
||||
"plans & billing",
|
||||
"insufficient balance",
|
||||
"insufficient usd or diem balance",
|
||||
],
|
||||
authPermanent: [
|
||||
/api[_ ]?key[_ ]?(?:revoked|invalid|deactivated|deleted)/i,
|
||||
|
||||
@@ -276,7 +276,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { model: "deepseek/deepseek-r1" };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -308,7 +308,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = {};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -332,7 +332,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { reasoning_effort: "high" };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -357,7 +357,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { reasoning: { max_tokens: 256 } };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -381,7 +381,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { reasoning_effort: "medium" };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -588,7 +588,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { thinking: "off" };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -619,7 +619,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { thinking: "off" };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -650,7 +650,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = {};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -674,7 +674,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { tool_choice: "required" };
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -699,7 +699,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = {};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -749,7 +749,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
],
|
||||
tool_choice: { type: "tool", name: "read" },
|
||||
};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -793,7 +793,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
],
|
||||
};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -832,7 +832,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
],
|
||||
};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -896,7 +896,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
},
|
||||
};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -943,7 +943,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
},
|
||||
};
|
||||
options?.onPayload?.(payload, _model);
|
||||
options?.onPayload?.(payload, model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -1081,7 +1081,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
|
||||
expect(calls).toHaveLength(1);
|
||||
expect(calls[0]?.transport).toBe("auto");
|
||||
expect(calls[0]?.openaiWsWarmup).toBe(true);
|
||||
expect(calls[0]?.openaiWsWarmup).toBe(false);
|
||||
});
|
||||
|
||||
it("lets runtime options override OpenAI default transport", () => {
|
||||
@@ -1449,6 +1449,20 @@ describe("applyExtraParamsToAgent", () => {
|
||||
expect(payload.store).toBe(true);
|
||||
});
|
||||
|
||||
it("forces store=true for azure-openai provider with openai-responses API (#42800)", () => {
|
||||
const payload = runResponsesPayloadMutationCase({
|
||||
applyProvider: "azure-openai",
|
||||
applyModelId: "gpt-5-mini",
|
||||
model: {
|
||||
api: "openai-responses",
|
||||
provider: "azure-openai",
|
||||
id: "gpt-5-mini",
|
||||
baseUrl: "https://myresource.openai.azure.com/openai/v1",
|
||||
} as unknown as Model<"openai-responses">,
|
||||
});
|
||||
expect(payload.store).toBe(true);
|
||||
});
|
||||
|
||||
it("injects configured OpenAI service_tier into Responses payloads", () => {
|
||||
const payload = runResponsesPayloadMutationCase({
|
||||
applyProvider: "openai",
|
||||
|
||||
@@ -981,7 +981,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
).rejects.toMatchObject({
|
||||
name: "FailoverError",
|
||||
reason: "rate_limit",
|
||||
reason: "unknown",
|
||||
provider: "openai",
|
||||
model: "mock-1",
|
||||
});
|
||||
@@ -1153,7 +1153,7 @@ describe("runEmbeddedPiAgent auth profile rotation", () => {
|
||||
}),
|
||||
).rejects.toMatchObject({
|
||||
name: "FailoverError",
|
||||
reason: "rate_limit",
|
||||
reason: "unknown",
|
||||
provider: "openai",
|
||||
model: "mock-1",
|
||||
});
|
||||
|
||||
@@ -7,6 +7,7 @@ const {
|
||||
sessionCompactImpl,
|
||||
triggerInternalHook,
|
||||
sanitizeSessionHistoryMock,
|
||||
contextEngineCompactMock,
|
||||
} = vi.hoisted(() => ({
|
||||
hookRunner: {
|
||||
hasHooks: vi.fn(),
|
||||
@@ -28,6 +29,14 @@ const {
|
||||
})),
|
||||
triggerInternalHook: vi.fn(),
|
||||
sanitizeSessionHistoryMock: vi.fn(async (params: { messages: unknown[] }) => params.messages),
|
||||
contextEngineCompactMock: vi.fn(async () => ({
|
||||
ok: true as boolean,
|
||||
compacted: true as boolean,
|
||||
reason: undefined as string | undefined,
|
||||
result: { summary: "engine-summary", tokensAfter: 50 } as
|
||||
| { summary: string; tokensAfter: number }
|
||||
| undefined,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("../../plugins/hook-runner-global.js", () => ({
|
||||
@@ -123,6 +132,27 @@ vi.mock("../session-write-lock.js", () => ({
|
||||
resolveSessionLockMaxHoldFromTimeout: vi.fn(() => 0),
|
||||
}));
|
||||
|
||||
vi.mock("../../context-engine/index.js", () => ({
|
||||
ensureContextEnginesInitialized: vi.fn(),
|
||||
resolveContextEngine: vi.fn(async () => ({
|
||||
info: { ownsCompaction: true },
|
||||
compact: contextEngineCompactMock,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock("../../process/command-queue.js", () => ({
|
||||
enqueueCommandInLane: vi.fn((_lane: unknown, task: () => unknown) => task()),
|
||||
}));
|
||||
|
||||
vi.mock("./lanes.js", () => ({
|
||||
resolveSessionLane: vi.fn(() => "test-session-lane"),
|
||||
resolveGlobalLane: vi.fn(() => "test-global-lane"),
|
||||
}));
|
||||
|
||||
vi.mock("../context-window-guard.js", () => ({
|
||||
resolveContextWindowInfo: vi.fn(() => ({ tokens: 128_000 })),
|
||||
}));
|
||||
|
||||
vi.mock("../bootstrap-files.js", () => ({
|
||||
makeBootstrapWarn: vi.fn(() => () => {}),
|
||||
resolveBootstrapContextForRun: vi.fn(async () => ({ contextFiles: [] })),
|
||||
@@ -160,7 +190,7 @@ vi.mock("../transcript-policy.js", () => ({
|
||||
}));
|
||||
|
||||
vi.mock("./extensions.js", () => ({
|
||||
buildEmbeddedExtensionFactories: vi.fn(() => []),
|
||||
buildEmbeddedExtensionFactories: vi.fn(() => ({ factories: [] })),
|
||||
}));
|
||||
|
||||
vi.mock("./history.js", () => ({
|
||||
@@ -251,7 +281,7 @@ vi.mock("./utils.js", () => ({
|
||||
|
||||
import { getApiProvider, unregisterApiProviders } from "@mariozechner/pi-ai";
|
||||
import { getCustomApiRegistrySourceId } from "../custom-api-registry.js";
|
||||
import { compactEmbeddedPiSessionDirect } from "./compact.js";
|
||||
import { compactEmbeddedPiSessionDirect, compactEmbeddedPiSession } from "./compact.js";
|
||||
|
||||
const sessionHook = (action: string) =>
|
||||
triggerInternalHook.mock.calls.find(
|
||||
@@ -436,3 +466,103 @@ describe("compactEmbeddedPiSessionDirect hooks", () => {
|
||||
expect(result.ok).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("compactEmbeddedPiSession hooks (ownsCompaction engine)", () => {
  type CompactParams = Parameters<typeof compactEmbeddedPiSession>[0];

  const SESSION_KEY = "agent:main:session-1";
  const SESSION_FILE = "/tmp/session.jsonl";

  // Calls compactEmbeddedPiSession with the fixture shared by every case;
  // `overrides` adds or replaces individual fields. `enqueue` runs the task
  // inline so no real queue is involved.
  const runCompaction = (overrides: Partial<CompactParams> = {}) =>
    compactEmbeddedPiSession({
      sessionId: "session-1",
      sessionKey: SESSION_KEY,
      sessionFile: SESSION_FILE,
      workspaceDir: "/tmp",
      customInstructions: "focus on decisions",
      enqueue: (task) => task(),
      ...overrides,
    });

  beforeEach(() => {
    // Start every case from pristine mocks.
    hookRunner.hasHooks.mockReset();
    hookRunner.runBeforeCompaction.mockReset();
    hookRunner.runAfterCompaction.mockReset();
    contextEngineCompactMock.mockReset();
    resolveModelMock.mockReset();

    // Default engine outcome: a successful compaction reporting 50 tokens.
    contextEngineCompactMock.mockResolvedValue({
      ok: true,
      compacted: true,
      reason: undefined,
      result: { summary: "engine-summary", tokensAfter: 50 },
    });
    resolveModelMock.mockReturnValue({
      model: { provider: "openai", api: "responses", id: "fake", input: [] },
      error: null,
      authStorage: { setRuntimeApiKey: vi.fn() },
      modelRegistry: {},
    });
  });

  it("fires before_compaction with sentinel -1 and after_compaction on success", async () => {
    hookRunner.hasHooks.mockReturnValue(true);

    const outcome = await runCompaction({ messageChannel: "telegram" });

    expect(outcome.ok).toBe(true);
    expect(outcome.compacted).toBe(true);

    // Both hooks receive the same context; counts are the -1 sentinel.
    const expectedCtx = expect.objectContaining({
      sessionKey: SESSION_KEY,
      messageProvider: "telegram",
    });
    expect(hookRunner.runBeforeCompaction).toHaveBeenCalledWith(
      { messageCount: -1, sessionFile: SESSION_FILE },
      expectedCtx,
    );
    expect(hookRunner.runAfterCompaction).toHaveBeenCalledWith(
      {
        messageCount: -1,
        compactedCount: -1,
        tokenCount: 50,
        sessionFile: SESSION_FILE,
      },
      expectedCtx,
    );
  });

  it("does not fire after_compaction when compaction fails", async () => {
    hookRunner.hasHooks.mockReturnValue(true);
    contextEngineCompactMock.mockResolvedValue({
      ok: false,
      compacted: false,
      reason: "nothing to compact",
      result: undefined,
    });

    const outcome = await runCompaction();

    expect(outcome.ok).toBe(false);
    expect(hookRunner.runBeforeCompaction).toHaveBeenCalled();
    expect(hookRunner.runAfterCompaction).not.toHaveBeenCalled();
  });

  it("catches and logs hook exceptions without aborting compaction", async () => {
    hookRunner.hasHooks.mockReturnValue(true);
    hookRunner.runBeforeCompaction.mockRejectedValue(new Error("hook boom"));

    const outcome = await runCompaction();

    // The rejected before-hook must not prevent the engine from compacting.
    expect(outcome.ok).toBe(true);
    expect(outcome.compacted).toBe(true);
    expect(contextEngineCompactMock).toHaveBeenCalled();
  });
});
|
||||
|
||||
@@ -936,6 +936,43 @@ export async function compactEmbeddedPiSession(
|
||||
modelContextWindow: ceModel?.contextWindow,
|
||||
defaultTokens: DEFAULT_CONTEXT_TOKENS,
|
||||
});
|
||||
// When the context engine owns compaction, its compact() implementation
|
||||
// bypasses compactEmbeddedPiSessionDirect (which fires the hooks internally).
|
||||
// Fire before_compaction / after_compaction hooks here so plugin subscribers
|
||||
// are notified regardless of which engine is active.
|
||||
const engineOwnsCompaction = contextEngine.info.ownsCompaction === true;
|
||||
const hookRunner = engineOwnsCompaction ? getGlobalHookRunner() : null;
|
||||
const hookSessionKey = params.sessionKey?.trim() || params.sessionId;
|
||||
const { sessionAgentId } = resolveSessionAgentIds({
|
||||
sessionKey: params.sessionKey,
|
||||
config: params.config,
|
||||
});
|
||||
const resolvedMessageProvider = params.messageChannel ?? params.messageProvider;
|
||||
const hookCtx = {
|
||||
sessionId: params.sessionId,
|
||||
agentId: sessionAgentId,
|
||||
sessionKey: hookSessionKey,
|
||||
workspaceDir: resolveUserPath(params.workspaceDir),
|
||||
messageProvider: resolvedMessageProvider,
|
||||
};
|
||||
// Engine-owned compaction doesn't load the transcript at this level, so
|
||||
// message counts are unavailable. We pass sessionFile so hook subscribers
|
||||
// can read the transcript themselves if they need exact counts.
|
||||
if (hookRunner?.hasHooks("before_compaction")) {
|
||||
try {
|
||||
await hookRunner.runBeforeCompaction(
|
||||
{
|
||||
messageCount: -1,
|
||||
sessionFile: params.sessionFile,
|
||||
},
|
||||
hookCtx,
|
||||
);
|
||||
} catch (err) {
|
||||
log.warn("before_compaction hook failed", {
|
||||
errorMessage: err instanceof Error ? err.message : String(err),
|
||||
});
|
||||
}
|
||||
}
|
||||
const result = await contextEngine.compact({
|
||||
sessionId: params.sessionId,
|
||||
sessionFile: params.sessionFile,
|
||||
@@ -944,6 +981,23 @@ export async function compactEmbeddedPiSession(
|
||||
force: params.trigger === "manual",
|
||||
runtimeContext: params as Record<string, unknown>,
|
||||
});
|
||||
if (result.ok && result.compacted && hookRunner?.hasHooks("after_compaction")) {
|
||||
try {
|
||||
await hookRunner.runAfterCompaction(
|
||||
{
|
||||
messageCount: -1,
|
||||
compactedCount: -1,
|
||||
tokenCount: result.result?.tokensAfter,
|
||||
sessionFile: params.sessionFile,
|
||||
},
|
||||
hookCtx,
|
||||
);
|
||||
} catch (err) {
|
||||
log.warn("after_compaction hook failed", {
|
||||
errorMessage: err instanceof Error ? err.message : String(err),
|
||||
});
|
||||
}
|
||||
}
|
||||
return {
|
||||
ok: result.ok,
|
||||
compacted: result.compacted,
|
||||
|
||||
@@ -202,6 +202,42 @@ describe("buildInlineProviderModels", () => {
|
||||
});
|
||||
|
||||
describe("resolveModel", () => {
|
||||
it("defaults model input to text when discovery omits input", () => {
  const providerId = "custom";
  const modelId = "missing-input";
  const baseUrl = "http://localhost:9999";

  mockDiscoveredModel({
    provider: providerId,
    modelId,
    templateModel: {
      id: modelId,
      name: modelId,
      api: "openai-completions",
      provider: providerId,
      baseUrl,
      reasoning: false,
      // NOTE: `input` is left out on purpose to mimic buggy/custom catalogs.
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
      contextWindow: 8192,
      maxTokens: 1024,
    },
  });

  // Minimal provider config; the discovered model supplies everything else.
  const cfg = {
    models: {
      providers: {
        custom: {
          baseUrl,
          api: "openai-completions",
          models: [{ id: modelId, name: modelId }],
        },
      },
    },
  } as unknown as OpenClawConfig;

  const { error, model } = resolveModel(providerId, modelId, "/tmp/agent", cfg);

  expect(error).toBeUndefined();
  expect(Array.isArray(model?.input)).toBe(true);
  expect(model?.input).toEqual(["text"]);
});
|
||||
|
||||
it("includes provider baseUrl in fallback model", () => {
|
||||
const cfg = {
|
||||
models: {
|
||||
@@ -346,6 +382,40 @@ describe("resolveModel", () => {
|
||||
expect(result.model?.reasoning).toBe(true);
|
||||
});
|
||||
|
||||
it("matches prefixed OpenRouter native ids in configured fallback models", () => {
  // The model id itself carries the "openrouter/" prefix.
  const prefixedId = "openrouter/healer-alpha";

  const cfg = {
    models: {
      providers: {
        openrouter: {
          baseUrl: "https://openrouter.ai/api/v1",
          api: "openai-completions",
          models: [
            {
              ...makeModel(prefixedId),
              reasoning: true,
              input: ["text", "image"],
              contextWindow: 262144,
              maxTokens: 65536,
            },
          ],
        },
      },
    },
  } as OpenClawConfig;

  const { error, model } = resolveModel("openrouter", prefixedId, "/tmp/agent", cfg);

  expect(error).toBeUndefined();
  // The configured entry's capabilities must be carried onto the result.
  expect(model).toMatchObject({
    provider: "openrouter",
    id: prefixedId,
    reasoning: true,
    input: ["text", "image"],
    contextWindow: 262144,
    maxTokens: 65536,
  });
});
|
||||
|
||||
it("prefers configured provider api metadata over discovered registry model", () => {
|
||||
mockDiscoveredModel({
|
||||
provider: "onehub",
|
||||
|
||||
@@ -93,12 +93,18 @@ function applyConfiguredProviderOverrides(params: {
|
||||
headers: discoveredHeaders,
|
||||
};
|
||||
}
|
||||
const resolvedInput = configuredModel?.input ?? discoveredModel.input;
|
||||
const normalizedInput =
|
||||
Array.isArray(resolvedInput) && resolvedInput.length > 0
|
||||
? resolvedInput.filter((item) => item === "text" || item === "image")
|
||||
: (["text"] as Array<"text" | "image">);
|
||||
|
||||
return {
|
||||
...discoveredModel,
|
||||
api: configuredModel?.api ?? providerConfig.api ?? discoveredModel.api,
|
||||
baseUrl: providerConfig.baseUrl ?? discoveredModel.baseUrl,
|
||||
reasoning: configuredModel?.reasoning ?? discoveredModel.reasoning,
|
||||
input: configuredModel?.input ?? discoveredModel.input,
|
||||
input: normalizedInput,
|
||||
cost: configuredModel?.cost ?? discoveredModel.cost,
|
||||
contextWindow: configuredModel?.contextWindow ?? discoveredModel.contextWindow,
|
||||
maxTokens: configuredModel?.maxTokens ?? discoveredModel.maxTokens,
|
||||
|
||||
@@ -6,7 +6,7 @@ import { log } from "./logger.js";
|
||||
type OpenAIServiceTier = "auto" | "default" | "flex" | "priority";
|
||||
|
||||
/** API identifiers in the OpenAI Responses family. */
const OPENAI_RESPONSES_APIS = new Set(["openai-responses"]);

/**
 * Provider identifiers grouped with the OpenAI Responses APIs. Declared
 * exactly once; includes the plain "azure-openai" id alongside
 * "azure-openai-responses".
 */
const OPENAI_RESPONSES_PROVIDERS = new Set(["openai", "azure-openai", "azure-openai-responses"]);
|
||||
|
||||
function isDirectOpenAIBaseUrl(baseUrl: unknown): boolean {
|
||||
if (typeof baseUrl !== "string" || !baseUrl.trim()) {
|
||||
@@ -250,7 +250,7 @@ export function createOpenAIDefaultTransportWrapper(baseStreamFn: StreamFn | und
|
||||
const mergedOptions = {
|
||||
...options,
|
||||
transport: options?.transport ?? "auto",
|
||||
openaiWsWarmup: typedOptions?.openaiWsWarmup ?? true,
|
||||
openaiWsWarmup: typedOptions?.openaiWsWarmup ?? false,
|
||||
} as SimpleStreamOptions;
|
||||
return underlying(model, context, mergedOptions);
|
||||
};
|
||||
|
||||
@@ -9,16 +9,18 @@ export function makeOverflowError(message: string = DEFAULT_OVERFLOW_ERROR_MESSA
|
||||
|
||||
/**
 * Builds a successful compaction result fixture.
 *
 * @param params.summary - Summary text, always present on the result.
 * @param params.firstKeptEntryId - Included on the result only when truthy.
 * @param params.tokensBefore - Included only when not `undefined` (0 is kept).
 * @param params.tokensAfter - Included only when not `undefined` (0 is kept).
 * @returns `{ ok: true, compacted: true, result }` where `result` carries no
 *   keys for omitted optional fields.
 */
export function makeCompactionSuccess(params: {
  summary: string;
  firstKeptEntryId?: string;
  tokensBefore?: number;
  tokensAfter?: number;
}) {
  const { summary, firstKeptEntryId, tokensBefore, tokensAfter } = params;
  return {
    ok: true as const,
    compacted: true as const,
    result: {
      summary,
      // Conditional spreads keep absent fields off the object entirely rather
      // than emitting keys whose value is `undefined`.
      ...(firstKeptEntryId ? { firstKeptEntryId } : {}),
      ...(tokensBefore !== undefined ? { tokensBefore } : {}),
      ...(tokensAfter !== undefined ? { tokensAfter } : {}),
    },
  };
}
|
||||
@@ -55,8 +57,9 @@ type MockCompactDirect = {
|
||||
compacted: true;
|
||||
result: {
|
||||
summary: string;
|
||||
firstKeptEntryId: string;
|
||||
tokensBefore: number;
|
||||
firstKeptEntryId?: string;
|
||||
tokensBefore?: number;
|
||||
tokensAfter?: number;
|
||||
};
|
||||
}) => unknown;
|
||||
};
|
||||
|
||||
@@ -2,9 +2,13 @@ import "./run.overflow-compaction.mocks.shared.js";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { isCompactionFailureError, isLikelyContextOverflowError } from "../pi-embedded-helpers.js";
|
||||
|
||||
// Partial mock of ../../utils.js, registered once: every real export is kept
// (via importOriginal) and only resolveUserPath is replaced with an identity
// stub that returns the path it was given.
vi.mock(import("../../utils.js"), async (importOriginal) => {
  const actual = await importOriginal();
  return {
    ...actual,
    resolveUserPath: vi.fn((p: string) => p),
  };
});
|
||||
|
||||
import { log } from "./logger.js";
|
||||
import { runEmbeddedPiAgent } from "./run.js";
|
||||
@@ -16,6 +20,7 @@ import {
|
||||
queueOverflowAttemptWithOversizedToolOutput,
|
||||
} from "./run.overflow-compaction.fixture.js";
|
||||
import {
|
||||
mockedContextEngine,
|
||||
mockedCompactDirect,
|
||||
mockedRunEmbeddedAttempt,
|
||||
mockedSessionLikelyHasOversizedToolResults,
|
||||
@@ -30,6 +35,11 @@ const mockedIsLikelyContextOverflowError = vi.mocked(isLikelyContextOverflowErro
|
||||
describe("overflow compaction in run loop", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockedRunEmbeddedAttempt.mockReset();
|
||||
mockedCompactDirect.mockReset();
|
||||
mockedSessionLikelyHasOversizedToolResults.mockReset();
|
||||
mockedTruncateOversizedToolResultsInSession.mockReset();
|
||||
mockedContextEngine.info.ownsCompaction = false;
|
||||
mockedIsCompactionFailureError.mockImplementation((msg?: string) => {
|
||||
if (!msg) {
|
||||
return false;
|
||||
@@ -72,7 +82,9 @@ describe("overflow compaction in run loop", () => {
|
||||
|
||||
expect(mockedCompactDirect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedCompactDirect).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ authProfileId: "test-profile" }),
|
||||
expect.objectContaining({
|
||||
runtimeContext: expect.objectContaining({ authProfileId: "test-profile" }),
|
||||
}),
|
||||
);
|
||||
expect(mockedRunEmbeddedAttempt).toHaveBeenCalledTimes(2);
|
||||
expect(log.warn).toHaveBeenCalledWith(
|
||||
|
||||
@@ -6,6 +6,25 @@ import type {
|
||||
PluginHookBeforePromptBuildResult,
|
||||
} from "../../plugins/types.js";
|
||||
|
||||
/** Successful outcome: `result` is required, `reason` is optional. */
type MockCompactionSuccess = {
  ok: true;
  compacted: true;
  result: {
    summary: string;
    firstKeptEntryId?: string;
    tokensBefore?: number;
    tokensAfter?: number;
  };
  reason?: string;
};

/** Failed outcome: `reason` is required and `result` may only be undefined. */
type MockCompactionFailure = {
  ok: false;
  compacted: false;
  reason: string;
  result?: undefined;
};

/** Result shape of the mocked compact() call, discriminated by `ok`. */
type MockCompactionResult = MockCompactionSuccess | MockCompactionFailure;
|
||||
|
||||
export const mockedGlobalHookRunner = {
|
||||
hasHooks: vi.fn((_hookName: string) => false),
|
||||
runBeforeAgentStart: vi.fn(
|
||||
@@ -26,12 +45,35 @@ export const mockedGlobalHookRunner = {
|
||||
_ctx: PluginHookAgentContext,
|
||||
): Promise<PluginHookBeforeModelResolveResult | undefined> => undefined,
|
||||
),
|
||||
runBeforeCompaction: vi.fn(async () => undefined),
|
||||
runAfterCompaction: vi.fn(async () => undefined),
|
||||
};
|
||||
|
||||
/** Signature of the mocked context-engine compact() call. */
type MockContextEngineCompact = (params: unknown) => Promise<MockCompactionResult>;

/**
 * Shared context-engine stub. `info.ownsCompaction` starts as false and is
 * typed as a plain boolean so it can be reassigned; `compact` resolves to a
 * "nothing to compact" failure unless a caller overrides it.
 */
export const mockedContextEngine = {
  info: { ownsCompaction: false as boolean },
  compact: vi.fn<MockContextEngineCompact>(async () => {
    const nothingCompacted: MockCompactionResult = {
      ok: false,
      compacted: false,
      reason: "nothing to compact",
    };
    return nothingCompacted;
  }),
};
|
||||
|
||||
export const mockedContextEngineCompact = vi.mocked(mockedContextEngine.compact);
|
||||
export const mockedEnsureRuntimePluginsLoaded: (...args: unknown[]) => void = vi.fn();
|
||||
|
||||
// Global hook runner: always hands back the shared mockedGlobalHookRunner.
vi.mock("../../plugins/hook-runner-global.js", () => {
  const getGlobalHookRunner = vi.fn(() => mockedGlobalHookRunner);
  return { getGlobalHookRunner };
});

// Context engine: initialization is a bare stub and resolution yields the
// shared mockedContextEngine.
vi.mock("../../context-engine/index.js", () => {
  return {
    ensureContextEnginesInitialized: vi.fn(),
    resolveContextEngine: vi.fn(async () => {
      return mockedContextEngine;
    }),
  };
});

// Runtime plugin loading is replaced by the exported mock.
vi.mock("../runtime-plugins.js", () => {
  return { ensureRuntimePluginsLoaded: mockedEnsureRuntimePluginsLoaded };
});
|
||||
|
||||
vi.mock("../auth-profiles.js", () => ({
|
||||
isProfileInCooldown: vi.fn(() => false),
|
||||
markAuthProfileFailure: vi.fn(async () => {}),
|
||||
@@ -141,9 +183,13 @@ vi.mock("../../process/command-queue.js", () => ({
|
||||
enqueueCommandInLane: vi.fn((_lane: string, task: () => unknown) => task()),
|
||||
}));
|
||||
|
||||
// Partial mock of ../../utils/message-channel.js, registered once: the real
// exports are preserved through importOriginal and only
// isMarkdownCapableMessageChannel is forced to return true.
vi.mock(import("../../utils/message-channel.js"), async (importOriginal) => {
  const actual = await importOriginal();
  return {
    ...actual,
    isMarkdownCapableMessageChannel: vi.fn(() => true),
  };
});
|
||||
|
||||
vi.mock("../agent-paths.js", () => ({
|
||||
resolveOpenClawAgentDir: vi.fn(() => "/tmp/agent-dir"),
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { vi } from "vitest";
|
||||
import { compactEmbeddedPiSessionDirect } from "./compact.js";
|
||||
import {
|
||||
mockedContextEngine,
|
||||
mockedContextEngineCompact,
|
||||
} from "./run.overflow-compaction.mocks.shared.js";
|
||||
import { runEmbeddedAttempt } from "./run/attempt.js";
|
||||
import {
|
||||
sessionLikelyHasOversizedToolResults,
|
||||
@@ -7,13 +10,14 @@ import {
|
||||
} from "./tool-result-truncation.js";
|
||||
|
||||
/** Typed mock handle for runEmbeddedAttempt. */
export const mockedRunEmbeddedAttempt = vi.mocked(runEmbeddedAttempt);

/**
 * Compaction mock used by the overflow suites. It is an alias of the shared
 * context-engine compact mock and is exported exactly once.
 */
export const mockedCompactDirect = mockedContextEngineCompact;

/** Typed mock handle for sessionLikelyHasOversizedToolResults. */
export const mockedSessionLikelyHasOversizedToolResults = vi.mocked(
  sessionLikelyHasOversizedToolResults,
);

/** Typed mock handle for truncateOversizedToolResultsInSession. */
export const mockedTruncateOversizedToolResultsInSession = vi.mocked(
  truncateOversizedToolResultsInSession,
);

// Re-exported so suites can toggle mockedContextEngine.info.ownsCompaction.
export { mockedContextEngine };
|
||||
|
||||
export const overflowBaseRunParams = {
|
||||
sessionId: "test-session",
|
||||
|
||||
@@ -11,6 +11,7 @@ import {
|
||||
} from "./run.overflow-compaction.fixture.js";
|
||||
import { mockedGlobalHookRunner } from "./run.overflow-compaction.mocks.shared.js";
|
||||
import {
|
||||
mockedContextEngine,
|
||||
mockedCompactDirect,
|
||||
mockedRunEmbeddedAttempt,
|
||||
mockedSessionLikelyHasOversizedToolResults,
|
||||
@@ -22,6 +23,25 @@ const mockedPickFallbackThinkingLevel = vi.mocked(pickFallbackThinkingLevel);
|
||||
describe("runEmbeddedPiAgent overflow compaction trigger routing", () => {
|
||||
beforeEach(() => {
  vi.clearAllMocks();

  // Drop any implementations or queued return values left by the previous case.
  const { runBeforeAgentStart, runBeforeCompaction, runAfterCompaction, hasHooks } =
    mockedGlobalHookRunner;
  mockedRunEmbeddedAttempt.mockReset();
  mockedCompactDirect.mockReset();
  mockedSessionLikelyHasOversizedToolResults.mockReset();
  mockedTruncateOversizedToolResultsInSession.mockReset();
  runBeforeAgentStart.mockReset();
  runBeforeCompaction.mockReset();
  runAfterCompaction.mockReset();

  // Baseline: the engine does not own compaction and nothing compacts,
  // nothing is oversized, and no hooks are registered.
  mockedContextEngine.info.ownsCompaction = false;
  mockedCompactDirect.mockResolvedValue({
    ok: false,
    compacted: false,
    reason: "nothing to compact",
  });
  mockedSessionLikelyHasOversizedToolResults.mockReturnValue(false);
  mockedTruncateOversizedToolResultsInSession.mockResolvedValue({
    truncated: false,
    truncatedCount: 0,
    reason: "no oversized tool results",
  });
  hasHooks.mockImplementation(() => false);
});
|
||||
|
||||
@@ -81,8 +101,12 @@ describe("runEmbeddedPiAgent overflow compaction trigger routing", () => {
|
||||
expect(mockedCompactDirect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedCompactDirect).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
trigger: "overflow",
|
||||
authProfileId: "test-profile",
|
||||
sessionId: "test-session",
|
||||
sessionFile: "/tmp/session.json",
|
||||
runtimeContext: expect.objectContaining({
|
||||
trigger: "overflow",
|
||||
authProfileId: "test-profile",
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
@@ -132,6 +156,63 @@ describe("runEmbeddedPiAgent overflow compaction trigger routing", () => {
|
||||
expect(result.meta.error?.kind).toBe("context_overflow");
|
||||
});
|
||||
|
||||
it("fires compaction hooks during overflow recovery for ownsCompaction engines", async () => {
  const compactionHookNames = new Set(["before_compaction", "after_compaction"]);
  const sessionFile = "/tmp/session.json";

  mockedContextEngine.info.ownsCompaction = true;
  mockedGlobalHookRunner.hasHooks.mockImplementation((hookName) =>
    compactionHookNames.has(hookName),
  );
  // First attempt overflows, the retry succeeds.
  mockedRunEmbeddedAttempt
    .mockResolvedValueOnce(makeAttemptResult({ promptError: makeOverflowError() }))
    .mockResolvedValueOnce(makeAttemptResult({ promptError: null }));
  mockedCompactDirect.mockResolvedValueOnce({
    ok: true,
    compacted: true,
    result: {
      summary: "engine-owned compaction",
      tokensAfter: 50,
    },
  });

  await runEmbeddedPiAgent(overflowBaseRunParams);

  const expectedCtx = expect.objectContaining({ sessionKey: "test-key" });
  expect(mockedGlobalHookRunner.runBeforeCompaction).toHaveBeenCalledWith(
    { messageCount: -1, sessionFile },
    expectedCtx,
  );
  expect(mockedGlobalHookRunner.runAfterCompaction).toHaveBeenCalledWith(
    {
      messageCount: -1,
      compactedCount: -1,
      tokenCount: 50,
      sessionFile,
    },
    expectedCtx,
  );
});
|
||||
|
||||
it("guards thrown engine-owned overflow compaction attempts", async () => {
  const compactionHookNames = new Set(["before_compaction", "after_compaction"]);

  mockedContextEngine.info.ownsCompaction = true;
  mockedGlobalHookRunner.hasHooks.mockImplementation((hookName) =>
    compactionHookNames.has(hookName),
  );
  const overflowAttempt = makeAttemptResult({ promptError: makeOverflowError() });
  mockedRunEmbeddedAttempt.mockResolvedValueOnce(overflowAttempt);
  // The engine's compact() rejects instead of returning a result.
  mockedCompactDirect.mockRejectedValueOnce(new Error("engine boom"));

  const outcome = await runEmbeddedPiAgent(overflowBaseRunParams);

  // One compaction attempt, the before-hook fired, the after-hook did not,
  // and the run surfaces a context_overflow error payload.
  expect(mockedCompactDirect).toHaveBeenCalledTimes(1);
  expect(mockedGlobalHookRunner.runBeforeCompaction).toHaveBeenCalledTimes(1);
  expect(mockedGlobalHookRunner.runAfterCompaction).not.toHaveBeenCalled();
  expect(outcome.meta.error?.kind).toBe("context_overflow");
  expect(outcome.payloads?.[0]?.isError).toBe(true);
});
|
||||
|
||||
it("returns retry_limit when repeated retries never converge", async () => {
|
||||
mockedRunEmbeddedAttempt.mockClear();
|
||||
mockedCompactDirect.mockClear();
|
||||
|
||||
@@ -553,7 +553,7 @@ export async function runEmbeddedPiAgent(
|
||||
resolveProfilesUnavailableReason({
|
||||
store: authStore,
|
||||
profileIds,
|
||||
}) ?? "rate_limit"
|
||||
}) ?? "unknown"
|
||||
);
|
||||
}
|
||||
const classified = classifyFailoverReason(params.message);
|
||||
@@ -669,14 +669,15 @@ export async function runEmbeddedPiAgent(
|
||||
? (resolveProfilesUnavailableReason({
|
||||
store: authStore,
|
||||
profileIds: autoProfileCandidates,
|
||||
}) ?? "rate_limit")
|
||||
}) ?? "unknown")
|
||||
: null;
|
||||
const allowTransientCooldownProbe =
|
||||
params.allowTransientCooldownProbe === true &&
|
||||
allAutoProfilesInCooldown &&
|
||||
(unavailableReason === "rate_limit" ||
|
||||
unavailableReason === "overloaded" ||
|
||||
unavailableReason === "billing");
|
||||
unavailableReason === "billing" ||
|
||||
unavailableReason === "unknown");
|
||||
let didTransientCooldownProbe = false;
|
||||
|
||||
while (profileIndex < profileCandidates.length) {
|
||||
@@ -1027,37 +1028,84 @@ export async function runEmbeddedPiAgent(
|
||||
log.warn(
|
||||
`context overflow detected (attempt ${overflowCompactionAttempts}/${MAX_OVERFLOW_COMPACTION_ATTEMPTS}); attempting auto-compaction for ${provider}/${modelId}`,
|
||||
);
|
||||
const compactResult = await contextEngine.compact({
|
||||
sessionId: params.sessionId,
|
||||
sessionFile: params.sessionFile,
|
||||
tokenBudget: ctxInfo.tokens,
|
||||
force: true,
|
||||
compactionTarget: "budget",
|
||||
runtimeContext: {
|
||||
sessionKey: params.sessionKey,
|
||||
messageChannel: params.messageChannel,
|
||||
messageProvider: params.messageProvider,
|
||||
agentAccountId: params.agentAccountId,
|
||||
authProfileId: lastProfileId,
|
||||
workspaceDir: resolvedWorkspace,
|
||||
agentDir,
|
||||
config: params.config,
|
||||
skillsSnapshot: params.skillsSnapshot,
|
||||
senderIsOwner: params.senderIsOwner,
|
||||
provider,
|
||||
model: modelId,
|
||||
runId: params.runId,
|
||||
thinkLevel,
|
||||
reasoningLevel: params.reasoningLevel,
|
||||
bashElevated: params.bashElevated,
|
||||
extraSystemPrompt: params.extraSystemPrompt,
|
||||
ownerNumbers: params.ownerNumbers,
|
||||
trigger: "overflow",
|
||||
diagId: overflowDiagId,
|
||||
attempt: overflowCompactionAttempts,
|
||||
maxAttempts: MAX_OVERFLOW_COMPACTION_ATTEMPTS,
|
||||
},
|
||||
});
|
||||
let compactResult: Awaited<ReturnType<typeof contextEngine.compact>>;
|
||||
// When the engine owns compaction, hooks are not fired inside
|
||||
// compactEmbeddedPiSessionDirect (which is bypassed). Fire them
|
||||
// here so subscribers (memory extensions, usage trackers) are
|
||||
// notified even on overflow-recovery compactions.
|
||||
const overflowEngineOwnsCompaction = contextEngine.info.ownsCompaction === true;
|
||||
const overflowHookRunner = overflowEngineOwnsCompaction ? hookRunner : null;
|
||||
if (overflowHookRunner?.hasHooks("before_compaction")) {
|
||||
try {
|
||||
await overflowHookRunner.runBeforeCompaction(
|
||||
{ messageCount: -1, sessionFile: params.sessionFile },
|
||||
hookCtx,
|
||||
);
|
||||
} catch (hookErr) {
|
||||
log.warn(
|
||||
`before_compaction hook failed during overflow recovery: ${String(hookErr)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
try {
|
||||
compactResult = await contextEngine.compact({
|
||||
sessionId: params.sessionId,
|
||||
sessionFile: params.sessionFile,
|
||||
tokenBudget: ctxInfo.tokens,
|
||||
force: true,
|
||||
compactionTarget: "budget",
|
||||
runtimeContext: {
|
||||
sessionKey: params.sessionKey,
|
||||
messageChannel: params.messageChannel,
|
||||
messageProvider: params.messageProvider,
|
||||
agentAccountId: params.agentAccountId,
|
||||
authProfileId: lastProfileId,
|
||||
workspaceDir: resolvedWorkspace,
|
||||
agentDir,
|
||||
config: params.config,
|
||||
skillsSnapshot: params.skillsSnapshot,
|
||||
senderIsOwner: params.senderIsOwner,
|
||||
provider,
|
||||
model: modelId,
|
||||
runId: params.runId,
|
||||
thinkLevel,
|
||||
reasoningLevel: params.reasoningLevel,
|
||||
bashElevated: params.bashElevated,
|
||||
extraSystemPrompt: params.extraSystemPrompt,
|
||||
ownerNumbers: params.ownerNumbers,
|
||||
trigger: "overflow",
|
||||
diagId: overflowDiagId,
|
||||
attempt: overflowCompactionAttempts,
|
||||
maxAttempts: MAX_OVERFLOW_COMPACTION_ATTEMPTS,
|
||||
},
|
||||
});
|
||||
} catch (compactErr) {
|
||||
log.warn(
|
||||
`contextEngine.compact() threw during overflow recovery for ${provider}/${modelId}: ${String(compactErr)}`,
|
||||
);
|
||||
compactResult = { ok: false, compacted: false, reason: String(compactErr) };
|
||||
}
|
||||
if (
|
||||
compactResult.ok &&
|
||||
compactResult.compacted &&
|
||||
overflowHookRunner?.hasHooks("after_compaction")
|
||||
) {
|
||||
try {
|
||||
await overflowHookRunner.runAfterCompaction(
|
||||
{
|
||||
messageCount: -1,
|
||||
compactedCount: -1,
|
||||
tokenCount: compactResult.result?.tokensAfter,
|
||||
sessionFile: params.sessionFile,
|
||||
},
|
||||
hookCtx,
|
||||
);
|
||||
} catch (hookErr) {
|
||||
log.warn(
|
||||
`after_compaction hook failed during overflow recovery: ${String(hookErr)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
if (compactResult.compacted) {
|
||||
autoCompactionCount += 1;
|
||||
log.info(`auto-compaction succeeded for ${provider}/${modelId}; retrying prompt`);
|
||||
|
||||
@@ -79,6 +79,7 @@ vi.mock("../../../infra/machine-name.js", () => ({
|
||||
}));
|
||||
|
||||
// Both undici global-dispatcher setup functions are replaced with no-ops.
vi.mock("../../../infra/net/undici-global-dispatcher.js", () => {
  return {
    ensureGlobalUndiciEnvProxyDispatcher: () => {},
    ensureGlobalUndiciStreamTimeouts: () => {},
  };
});
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
shouldInjectOllamaCompatNumCtx,
|
||||
decodeHtmlEntitiesInObject,
|
||||
wrapOllamaCompatNumCtx,
|
||||
wrapStreamFnRepairMalformedToolCallArguments,
|
||||
wrapStreamFnTrimToolCallNames,
|
||||
} from "./attempt.js";
|
||||
|
||||
@@ -430,6 +431,182 @@ describe("wrapStreamFnTrimToolCallNames", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("wrapStreamFnRepairMalformedToolCallArguments", () => {
|
||||
/**
 * Builds a minimal stream double: async-iterating it yields `params.events`
 * in order, and `result()` resolves to `params.resultMessage`. Each call to
 * the async iterator starts a fresh pass over the events.
 */
function createFakeStream(params: { events: unknown[]; resultMessage: unknown }): {
  result: () => Promise<unknown>;
  [Symbol.asyncIterator]: () => AsyncIterator<unknown>;
} {
  async function* replayEvents() {
    const queued = params.events;
    for (const queuedEvent of queued) {
      yield queuedEvent;
    }
  }

  return {
    result: async () => params.resultMessage,
    [Symbol.asyncIterator]: () => replayEvents(),
  };
}
|
||||
|
||||
/**
 * Wraps `baseFn` with wrapStreamFnRepairMalformedToolCallArguments and calls
 * the wrapper once with three empty placeholder arguments, returning whatever
 * the wrapper resolves to.
 */
async function invokeWrappedStream(baseFn: (...args: never[]) => unknown) {
  const repairingStreamFn = wrapStreamFnRepairMalformedToolCallArguments(baseFn as never);
  const wrappedStream = await repairingStreamFn({} as never, {} as never, {} as never);
  return wrappedStream;
}
|
||||
|
||||
it("repairs anthropic-compatible tool arguments when trailing junk follows valid JSON", async () => {
  // Fresh tool call with empty arguments; each location gets its own object
  // so the assertions can tell which ones were updated.
  const newReadCall = () => ({ type: "toolCall", name: "read", arguments: {} });
  const repairedArgs = { path: "/tmp/report.txt" };

  const partialToolCall = newReadCall();
  const streamedToolCall = newReadCall();
  const endMessageToolCall = newReadCall();
  const finalToolCall = newReadCall();
  const partialMessage = { role: "assistant", content: [partialToolCall] };
  const endMessage = { role: "assistant", content: [endMessageToolCall] };
  const finalMessage = { role: "assistant", content: [finalToolCall] };

  // Valid JSON arrives first, then two characters of trailing junk.
  const events = [
    {
      type: "toolcall_delta",
      contentIndex: 0,
      delta: '{"path":"/tmp/report.txt"}',
      partial: partialMessage,
    },
    {
      type: "toolcall_delta",
      contentIndex: 0,
      delta: "xx",
      partial: partialMessage,
    },
    {
      type: "toolcall_end",
      contentIndex: 0,
      toolCall: streamedToolCall,
      partial: partialMessage,
      message: endMessage,
    },
  ];
  const baseFn = vi.fn(() => createFakeStream({ events, resultMessage: finalMessage }));

  const stream = await invokeWrappedStream(baseFn);
  for await (const _item of stream) {
    // drain
  }
  const result = await stream.result();

  expect(partialToolCall.arguments).toEqual(repairedArgs);
  expect(streamedToolCall.arguments).toEqual(repairedArgs);
  expect(endMessageToolCall.arguments).toEqual(repairedArgs);
  expect(finalToolCall.arguments).toEqual(repairedArgs);
  expect(result).toBe(finalMessage);
});
|
||||
|
||||
it("keeps incomplete partial JSON unchanged until a complete object exists", async () => {
  const partialToolCall = { type: "toolCall", name: "read", arguments: {} };
  const partialMessage = { role: "assistant", content: [partialToolCall] };

  // A single delta that stops mid-string, so no complete JSON object exists.
  const truncatedDelta = {
    type: "toolcall_delta",
    contentIndex: 0,
    delta: '{"path":"/tmp',
    partial: partialMessage,
  };
  const baseFn = vi.fn(() =>
    createFakeStream({
      events: [truncatedDelta],
      resultMessage: { role: "assistant", content: [partialToolCall] },
    }),
  );

  const stream = await invokeWrappedStream(baseFn);
  for await (const _item of stream) {
    // drain
  }

  expect(partialToolCall.arguments).toEqual({});
});
|
||||
|
||||
it("does not repair tool arguments when trailing junk exceeds the Kimi-specific allowance", async () => {
|
||||
const partialToolCall = { type: "toolCall", name: "read", arguments: {} };
|
||||
const streamedToolCall = { type: "toolCall", name: "read", arguments: {} };
|
||||
const partialMessage = { role: "assistant", content: [partialToolCall] };
|
||||
const baseFn = vi.fn(() =>
|
||||
createFakeStream({
|
||||
events: [
|
||||
{
|
||||
type: "toolcall_delta",
|
||||
contentIndex: 0,
|
||||
delta: '{"path":"/tmp/report.txt"}oops',
|
||||
partial: partialMessage,
|
||||
},
|
||||
{
|
||||
type: "toolcall_end",
|
||||
contentIndex: 0,
|
||||
toolCall: streamedToolCall,
|
||||
partial: partialMessage,
|
||||
},
|
||||
],
|
||||
resultMessage: { role: "assistant", content: [partialToolCall] },
|
||||
}),
|
||||
);
|
||||
|
||||
const stream = await invokeWrappedStream(baseFn);
|
||||
for await (const _item of stream) {
|
||||
// drain
|
||||
}
|
||||
|
||||
expect(partialToolCall.arguments).toEqual({});
|
||||
expect(streamedToolCall.arguments).toEqual({});
|
||||
});
|
||||
|
||||
it("clears a cached repair when later deltas make the trailing suffix invalid", async () => {
|
||||
const partialToolCall = { type: "toolCall", name: "read", arguments: {} };
|
||||
const streamedToolCall = { type: "toolCall", name: "read", arguments: {} };
|
||||
const partialMessage = { role: "assistant", content: [partialToolCall] };
|
||||
const baseFn = vi.fn(() =>
|
||||
createFakeStream({
|
||||
events: [
|
||||
{
|
||||
type: "toolcall_delta",
|
||||
contentIndex: 0,
|
||||
delta: '{"path":"/tmp/report.txt"}',
|
||||
partial: partialMessage,
|
||||
},
|
||||
{
|
||||
type: "toolcall_delta",
|
||||
contentIndex: 0,
|
||||
delta: "x",
|
||||
partial: partialMessage,
|
||||
},
|
||||
{
|
||||
type: "toolcall_delta",
|
||||
contentIndex: 0,
|
||||
delta: "yzq",
|
||||
partial: partialMessage,
|
||||
},
|
||||
{
|
||||
type: "toolcall_end",
|
||||
contentIndex: 0,
|
||||
toolCall: streamedToolCall,
|
||||
partial: partialMessage,
|
||||
},
|
||||
],
|
||||
resultMessage: { role: "assistant", content: [partialToolCall] },
|
||||
}),
|
||||
);
|
||||
|
||||
const stream = await invokeWrappedStream(baseFn);
|
||||
for await (const _item of stream) {
|
||||
// drain
|
||||
}
|
||||
|
||||
expect(partialToolCall.arguments).toEqual({});
|
||||
expect(streamedToolCall.arguments).toEqual({});
|
||||
});
|
||||
});
|
||||
|
||||
describe("isOllamaCompatProvider", () => {
|
||||
it("detects native ollama provider id", () => {
|
||||
expect(
|
||||
|
||||
@@ -11,7 +11,10 @@ import { resolveHeartbeatPrompt } from "../../../auto-reply/heartbeat.js";
|
||||
import { resolveChannelCapabilities } from "../../../config/channel-capabilities.js";
|
||||
import type { OpenClawConfig } from "../../../config/config.js";
|
||||
import { getMachineDisplayName } from "../../../infra/machine-name.js";
|
||||
import { ensureGlobalUndiciStreamTimeouts } from "../../../infra/net/undici-global-dispatcher.js";
|
||||
import {
|
||||
ensureGlobalUndiciEnvProxyDispatcher,
|
||||
ensureGlobalUndiciStreamTimeouts,
|
||||
} from "../../../infra/net/undici-global-dispatcher.js";
|
||||
import { MAX_IMAGE_BYTES } from "../../../media/constants.js";
|
||||
import { getGlobalHookRunner } from "../../../plugins/hook-runner-global.js";
|
||||
import type {
|
||||
@@ -433,6 +436,281 @@ export function wrapStreamFnTrimToolCallNames(
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Extracts the first bracket-balanced JSON object/array literal at the start of `raw`.
 *
 * Leading whitespace is skipped. Scanning is string-aware: brackets inside
 * double-quoted strings (including after backslash escapes) are ignored.
 *
 * @param raw - Buffered text that may hold a JSON value followed by extra characters.
 * @returns The balanced `{...}` / `[...]` prefix (without the leading whitespace),
 *   or `null` when the text does not start with `{`/`[`, the value is still
 *   unterminated, or a closer does not match the bracket it closes.
 */
function extractBalancedJsonPrefix(raw: string): string | null {
  let start = 0;
  while (start < raw.length && /\s/.test(raw[start] ?? "")) {
    start += 1;
  }
  const startChar = raw[start];
  if (startChar !== "{" && startChar !== "[") {
    return null;
  }

  // Closers we still expect, innermost last. Tracking the bracket kind (rather
  // than a bare depth counter) rejects mismatched pairs such as `{"a":1]`.
  const expectedClosers: string[] = [];
  let inString = false;
  let escaped = false;
  for (let i = start; i < raw.length; i += 1) {
    const char = raw[i];
    if (char === undefined) {
      break;
    }
    if (inString) {
      if (escaped) {
        // The character after a backslash is consumed verbatim.
        escaped = false;
      } else if (char === "\\") {
        escaped = true;
      } else if (char === '"') {
        inString = false;
      }
      continue;
    }
    if (char === '"') {
      inString = true;
      continue;
    }
    if (char === "{") {
      expectedClosers.push("}");
      continue;
    }
    if (char === "[") {
      expectedClosers.push("]");
      continue;
    }
    if (char === "}" || char === "]") {
      if (expectedClosers.pop() !== char) {
        // Wrong closer kind: this can never be a valid JSON prefix.
        return null;
      }
      if (expectedClosers.length === 0) {
        return raw.slice(start, i + 1);
      }
    }
  }
  // Ran out of input before the outermost bracket was closed.
  return null;
}
|
||||
|
||||
// Upper bound on buffered argument text per tool call before repair is abandoned.
const MAX_TOOLCALL_REPAIR_BUFFER_CHARS = 64_000;
// Maximum number of junk characters tolerated after the balanced JSON value.
const MAX_TOOLCALL_REPAIR_TRAILING_CHARS = 3;
// Trailing junk must be 1-3 characters, none of which are whitespace or JSON punctuation.
const TOOLCALL_REPAIR_ALLOWED_TRAILING_RE = /^[^\s{}[\]":,\\]{1,3}$/;

/**
 * Cheap pre-check deciding whether a repair parse is worth running for this delta.
 *
 * @param partialJson - Argument text buffered so far, including `delta`.
 * @param delta - The chunk that was just appended.
 * @returns `true` when `delta` contains a closing bracket, or when `delta` is a
 *   short non-blank chunk (at most MAX_TOOLCALL_REPAIR_TRAILING_CHARS after
 *   trimming) and the buffer already contains a closing bracket.
 */
function shouldAttemptMalformedToolCallRepair(partialJson: string, delta: string): boolean {
  const closingBracket = /[}\]]/;
  if (closingBracket.test(delta)) {
    return true;
  }
  const junkLength = delta.trim().length;
  if (junkLength === 0 || junkLength > MAX_TOOLCALL_REPAIR_TRAILING_CHARS) {
    return false;
  }
  return closingBracket.test(partialJson);
}
|
||||
|
||||
/** Result of a successful repair: the recovered arguments plus the junk that was dropped. */
type ToolCallArgumentRepair = {
  args: Record<string, unknown>;
  trailingSuffix: string;
};

/**
 * Attempts to recover tool-call arguments from text that is a JSON object
 * followed by a short run of junk characters.
 *
 * @param raw - Buffered argument text for a single tool call.
 * @returns The parsed object and the trimmed trailing junk, or `undefined` when
 *   `raw` is blank, already parses as JSON (nothing to repair), has no balanced
 *   JSON prefix, has no trailing junk or junk outside the allowed shape, or the
 *   prefix does not parse to a plain (non-array) object.
 */
function tryParseMalformedToolCallArguments(raw: string): ToolCallArgumentRepair | undefined {
  if (raw.trim().length === 0) {
    return undefined;
  }

  // Only text that fails a strict parse is a repair candidate.
  let strictParseFailed = false;
  try {
    JSON.parse(raw);
  } catch {
    strictParseFailed = true;
  }
  if (!strictParseFailed) {
    return undefined;
  }

  const jsonPrefix = extractBalancedJsonPrefix(raw);
  if (!jsonPrefix) {
    return undefined;
  }

  const prefixEnd = raw.indexOf(jsonPrefix) + jsonPrefix.length;
  const trailingSuffix = raw.slice(prefixEnd).trim();
  const suffixIsTolerable =
    trailingSuffix.length > 0 &&
    trailingSuffix.length <= MAX_TOOLCALL_REPAIR_TRAILING_CHARS &&
    TOOLCALL_REPAIR_ALLOWED_TRAILING_RE.test(trailingSuffix);
  if (!suffixIsTolerable) {
    return undefined;
  }

  let candidate: unknown;
  try {
    candidate = JSON.parse(jsonPrefix);
  } catch {
    return undefined;
  }
  // Tool arguments must be a plain object; arrays and primitives are rejected.
  if (!candidate || typeof candidate !== "object" || Array.isArray(candidate)) {
    return undefined;
  }
  return { args: candidate as Record<string, unknown>, trailingSuffix };
}
|
||||
|
||||
/**
 * Overwrites `arguments` on the tool-call block at `contentIndex` of `message`.
 *
 * Does nothing when `message` is not an object, its `content` is not an array,
 * the indexed entry is not an object, or that entry's `type` is not accepted by
 * `isToolCallBlockType`.
 *
 * @param message - Candidate assistant message; mutated in place.
 * @param contentIndex - Position of the tool-call block within `message.content`.
 * @param repairedArgs - Arguments object to assign to the block.
 */
function repairToolCallArgumentsInMessage(
  message: unknown,
  contentIndex: number,
  repairedArgs: Record<string, unknown>,
): void {
  const content =
    message && typeof message === "object"
      ? (message as { content?: unknown }).content
      : undefined;
  if (!Array.isArray(content)) {
    return;
  }
  const candidate: unknown = content[contentIndex];
  if (candidate && typeof candidate === "object") {
    const toolCallBlock = candidate as { type?: unknown; arguments?: unknown };
    if (isToolCallBlockType(toolCallBlock.type)) {
      toolCallBlock.arguments = repairedArgs;
    }
  }
}
|
||||
|
||||
/**
 * Resets `arguments` to a fresh empty object on the tool-call block at
 * `contentIndex` of `message`.
 *
 * Delegates to `repairToolCallArgumentsInMessage` so both helpers share one
 * set of shape guards (non-object message, non-array content, missing or
 * non-tool-call block are all no-ops).
 *
 * @param message - Candidate assistant message; mutated in place.
 * @param contentIndex - Position of the tool-call block within `message.content`.
 */
function clearToolCallArgumentsInMessage(message: unknown, contentIndex: number): void {
  repairToolCallArgumentsInMessage(message, contentIndex, {});
}
|
||||
|
||||
/**
 * Applies every cached repair to the matching tool-call block of `message`.
 *
 * @param message - Candidate assistant message; mutated in place. Ignored when
 *   it is not an object with an array `content`.
 * @param repairedArgsByIndex - Repaired arguments keyed by content index.
 */
function repairMalformedToolCallArgumentsInMessage(
  message: unknown,
  repairedArgsByIndex: Map<number, Record<string, unknown>>,
): void {
  const hasContentArray =
    Boolean(message) &&
    typeof message === "object" &&
    Array.isArray((message as { content?: unknown }).content);
  if (!hasContentArray) {
    return;
  }
  repairedArgsByIndex.forEach((repairedArgs, index) => {
    repairToolCallArgumentsInMessage(message, index, repairedArgs);
  });
}
|
||||
|
||||
/**
 * Patches `stream` in place so that tool-call arguments followed by a few junk
 * characters are repaired as events flow through it.
 *
 * The wrapper replaces `stream.result` and the async iterator:
 * - `toolcall_delta`: the delta text is buffered per content index. When the
 *   pre-check passes and the buffer is "JSON object + short junk", the parsed
 *   object is written onto the event's `partial`/`message` and cached.
 * - `toolcall_end`: a cached repair is written onto `toolCall`, `partial` and
 *   `message`; the per-call buffer and flags are released.
 * - `result()`: cached repairs are applied to the final message, then all state
 *   is dropped.
 *
 * Arguments already present on the streamed message are only cleared when a
 * repair this wrapper previously wrote becomes invalid. Deltas that merely
 * contain a closing bracket on valid or still-incomplete JSON leave them alone.
 *
 * @param stream - Event stream to patch; mutated.
 * @returns The same `stream` instance.
 */
function wrapStreamRepairMalformedToolCallArguments(
  stream: ReturnType<typeof streamSimple>,
): ReturnType<typeof streamSimple> {
  // Raw argument text accumulated from toolcall_delta events, per content index.
  const partialJsonByIndex = new Map<number, string>();
  // Latest successful repair per content index; replayed on toolcall_end and result().
  const repairedArgsByIndex = new Map<number, Record<string, unknown>>();
  // Content indices whose buffer exceeded the cap; their deltas are passed through untouched.
  const disabledIndices = new Set<number>();
  // Content indices already warned about, so each tool call logs at most once.
  const loggedRepairIndices = new Set<number>();

  const originalResult = stream.result.bind(stream);
  stream.result = async () => {
    const message = await originalResult();
    repairMalformedToolCallArgumentsInMessage(message, repairedArgsByIndex);
    partialJsonByIndex.clear();
    repairedArgsByIndex.clear();
    disabledIndices.clear();
    loggedRepairIndices.clear();
    return message;
  };

  const originalAsyncIterator = stream[Symbol.asyncIterator].bind(stream);
  (stream as { [Symbol.asyncIterator]: typeof originalAsyncIterator })[Symbol.asyncIterator] =
    function () {
      const iterator = originalAsyncIterator();
      return {
        async next() {
          const result = await iterator.next();
          if (!result.done && result.value && typeof result.value === "object") {
            const event = result.value as {
              type?: unknown;
              contentIndex?: unknown;
              delta?: unknown;
              partial?: unknown;
              message?: unknown;
              toolCall?: unknown;
            };
            if (
              typeof event.contentIndex === "number" &&
              Number.isInteger(event.contentIndex) &&
              event.type === "toolcall_delta" &&
              typeof event.delta === "string"
            ) {
              if (disabledIndices.has(event.contentIndex)) {
                return result;
              }
              const nextPartialJson =
                (partialJsonByIndex.get(event.contentIndex) ?? "") + event.delta;
              if (nextPartialJson.length > MAX_TOOLCALL_REPAIR_BUFFER_CHARS) {
                // Too large to keep buffering: give up on this tool call.
                partialJsonByIndex.delete(event.contentIndex);
                repairedArgsByIndex.delete(event.contentIndex);
                disabledIndices.add(event.contentIndex);
                return result;
              }
              partialJsonByIndex.set(event.contentIndex, nextPartialJson);
              if (shouldAttemptMalformedToolCallRepair(nextPartialJson, event.delta)) {
                const repair = tryParseMalformedToolCallArguments(nextPartialJson);
                if (repair) {
                  repairedArgsByIndex.set(event.contentIndex, repair.args);
                  repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repair.args);
                  repairToolCallArgumentsInMessage(event.message, event.contentIndex, repair.args);
                  if (!loggedRepairIndices.has(event.contentIndex)) {
                    loggedRepairIndices.add(event.contentIndex);
                    log.warn(
                      `repairing kimi-coding tool call arguments after ${repair.trailingSuffix.length} trailing chars`,
                    );
                  }
                } else if (repairedArgsByIndex.has(event.contentIndex)) {
                  // A repair we wrote earlier is no longer valid (e.g. the junk
                  // suffix grew too long): drop it and undo what we wrote. When
                  // no repair was cached, the message holds the provider's own
                  // arguments and must not be wiped.
                  repairedArgsByIndex.delete(event.contentIndex);
                  clearToolCallArgumentsInMessage(event.partial, event.contentIndex);
                  clearToolCallArgumentsInMessage(event.message, event.contentIndex);
                }
              }
            }
            if (
              typeof event.contentIndex === "number" &&
              Number.isInteger(event.contentIndex) &&
              event.type === "toolcall_end"
            ) {
              const repairedArgs = repairedArgsByIndex.get(event.contentIndex);
              if (repairedArgs) {
                if (event.toolCall && typeof event.toolCall === "object") {
                  (event.toolCall as { arguments?: unknown }).arguments = repairedArgs;
                }
                repairToolCallArgumentsInMessage(event.partial, event.contentIndex, repairedArgs);
                repairToolCallArgumentsInMessage(event.message, event.contentIndex, repairedArgs);
              }
              // The cached repair itself is kept so result() can still apply it
              // to the final message; only per-call streaming state is released.
              partialJsonByIndex.delete(event.contentIndex);
              disabledIndices.delete(event.contentIndex);
              loggedRepairIndices.delete(event.contentIndex);
            }
          }
          return result;
        },
        async return(value?: unknown) {
          return iterator.return?.(value) ?? { done: true as const, value: undefined };
        },
        async throw(error?: unknown) {
          return iterator.throw?.(error) ?? { done: true as const, value: undefined };
        },
      };
    };

  return stream;
}
|
||||
|
||||
/**
 * Decorates a stream function so every stream it produces passes through
 * `wrapStreamRepairMalformedToolCallArguments`.
 *
 * @param baseFn - Stream function to decorate; may return a stream directly or a thenable of one.
 * @returns A stream function with the same call signature. A thenable result
 *   is resolved first and the wrapped stream is returned as a promise;
 *   otherwise the wrapped stream is returned synchronously.
 */
export function wrapStreamFnRepairMalformedToolCallArguments(baseFn: StreamFn): StreamFn {
  return (model, context, options) => {
    const maybeStream = baseFn(model, context, options);
    if (!(maybeStream && typeof maybeStream === "object" && "then" in maybeStream)) {
      // Synchronous stream: wrap it directly.
      return wrapStreamRepairMalformedToolCallArguments(maybeStream);
    }
    return Promise.resolve(maybeStream).then((stream) =>
      wrapStreamRepairMalformedToolCallArguments(stream),
    );
  };
}
|
||||
|
||||
/**
 * Reports whether the provider needs the malformed tool-call argument repair.
 *
 * @param provider - Provider id as configured; `undefined` is treated as an empty id.
 * @returns `true` only when the normalized provider id is `"kimi-coding"`.
 */
function shouldRepairMalformedAnthropicToolCallArguments(provider?: string): boolean {
  const providerId = normalizeProviderId(provider ?? "");
  return providerId === "kimi-coding";
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// xAI / Grok: decode HTML entities in tool call arguments
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -749,6 +1027,9 @@ export async function runEmbeddedAttempt(
|
||||
const resolvedWorkspace = resolveUserPath(params.workspaceDir);
|
||||
const prevCwd = process.cwd();
|
||||
const runAbortController = new AbortController();
|
||||
// Proxy bootstrap must happen before timeout tuning so the timeouts wrap the
|
||||
// active EnvHttpProxyAgent instead of being replaced by a bare proxy dispatcher.
|
||||
ensureGlobalUndiciEnvProxyDispatcher();
|
||||
ensureGlobalUndiciStreamTimeouts();
|
||||
|
||||
log.debug(
|
||||
@@ -1373,6 +1654,15 @@ export async function runEmbeddedAttempt(
|
||||
allowedToolNames,
|
||||
);
|
||||
|
||||
if (
|
||||
params.model.api === "anthropic-messages" &&
|
||||
shouldRepairMalformedAnthropicToolCallArguments(params.provider)
|
||||
) {
|
||||
activeSession.agent.streamFn = wrapStreamFnRepairMalformedToolCallArguments(
|
||||
activeSession.agent.streamFn,
|
||||
);
|
||||
}
|
||||
|
||||
if (isXaiProvider(params.provider, params.modelId)) {
|
||||
activeSession.agent.streamFn = wrapStreamFnDecodeXaiToolCallArguments(
|
||||
activeSession.agent.streamFn,
|
||||
@@ -1768,6 +2058,8 @@ export async function runEmbeddedAttempt(
|
||||
sessionId: params.sessionId,
|
||||
workspaceDir: params.workspaceDir,
|
||||
messageProvider: params.messageProvider ?? undefined,
|
||||
trigger: params.trigger,
|
||||
channelId: params.messageChannel ?? params.messageProvider ?? undefined,
|
||||
},
|
||||
)
|
||||
.catch((err) => {
|
||||
@@ -1976,6 +2268,8 @@ export async function runEmbeddedAttempt(
|
||||
sessionId: params.sessionId,
|
||||
workspaceDir: params.workspaceDir,
|
||||
messageProvider: params.messageProvider ?? undefined,
|
||||
trigger: params.trigger,
|
||||
channelId: params.messageChannel ?? params.messageProvider ?? undefined,
|
||||
},
|
||||
)
|
||||
.catch((err) => {
|
||||
@@ -2036,6 +2330,8 @@ export async function runEmbeddedAttempt(
|
||||
sessionId: params.sessionId,
|
||||
workspaceDir: params.workspaceDir,
|
||||
messageProvider: params.messageProvider ?? undefined,
|
||||
trigger: params.trigger,
|
||||
channelId: params.messageChannel ?? params.messageProvider ?? undefined,
|
||||
},
|
||||
)
|
||||
.catch((err) => {
|
||||
|
||||
@@ -49,6 +49,30 @@ describe("pruneProcessedHistoryImages", () => {
|
||||
expect(first.content[1]).toMatchObject({ type: "image", data: "abc" });
|
||||
});
|
||||
|
||||
it("prunes image blocks from toolResult messages that already have assistant replies", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
castAgentMessage({
|
||||
role: "toolResult",
|
||||
toolName: "read",
|
||||
content: [{ type: "text", text: "screenshot bytes" }, { ...image }],
|
||||
}),
|
||||
castAgentMessage({
|
||||
role: "assistant",
|
||||
content: "ack",
|
||||
}),
|
||||
];
|
||||
|
||||
const didMutate = pruneProcessedHistoryImages(messages);
|
||||
|
||||
expect(didMutate).toBe(true);
|
||||
const firstTool = messages[0] as Extract<AgentMessage, { role: "toolResult" }> | undefined;
|
||||
if (!firstTool || !Array.isArray(firstTool.content)) {
|
||||
throw new Error("expected toolResult array content");
|
||||
}
|
||||
expect(firstTool.content).toHaveLength(2);
|
||||
expect(firstTool.content[1]).toMatchObject({ type: "text", text: PRUNED_HISTORY_IMAGE_MARKER });
|
||||
});
|
||||
|
||||
it("does not change messages when no assistant turn exists", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
castAgentMessage({
|
||||
|
||||
@@ -21,7 +21,11 @@ export function pruneProcessedHistoryImages(messages: AgentMessage[]): boolean {
|
||||
let didMutate = false;
|
||||
for (let i = 0; i < lastAssistantIndex; i++) {
|
||||
const message = messages[i];
|
||||
if (!message || message.role !== "user" || !Array.isArray(message.content)) {
|
||||
if (
|
||||
!message ||
|
||||
(message.role !== "user" && message.role !== "toolResult") ||
|
||||
!Array.isArray(message.content)
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
for (let j = 0; j < message.content.length; j++) {
|
||||
|
||||
@@ -101,6 +101,18 @@ describe("buildEmbeddedRunPayloads", () => {
|
||||
expect(payloads[0]?.isError).toBe(true);
|
||||
});
|
||||
|
||||
it("does not emit a synthetic billing error for successful turns with stale errorMessage", () => {
|
||||
const payloads = buildPayloads({
|
||||
lastAssistant: makeAssistant({
|
||||
stopReason: "stop",
|
||||
errorMessage: "insufficient credits for embedding model",
|
||||
content: [{ type: "text", text: "Handle payment required errors in your API." }],
|
||||
}),
|
||||
});
|
||||
|
||||
expectSinglePayloadText(payloads, "Handle payment required errors in your API.");
|
||||
});
|
||||
|
||||
it("suppresses raw error JSON even when errorMessage is missing", () => {
|
||||
const payloads = buildPayloads({
|
||||
assistantTexts: [errorJsonPretty],
|
||||
|
||||
@@ -128,16 +128,17 @@ export function buildEmbeddedRunPayloads(params: {
|
||||
const useMarkdown = params.toolResultFormat === "markdown";
|
||||
const suppressAssistantArtifacts = params.didSendDeterministicApprovalPrompt === true;
|
||||
const lastAssistantErrored = params.lastAssistant?.stopReason === "error";
|
||||
const errorText = params.lastAssistant
|
||||
? suppressAssistantArtifacts
|
||||
? undefined
|
||||
: formatAssistantErrorText(params.lastAssistant, {
|
||||
cfg: params.config,
|
||||
sessionKey: params.sessionKey,
|
||||
provider: params.provider,
|
||||
model: params.model,
|
||||
})
|
||||
: undefined;
|
||||
const errorText =
|
||||
params.lastAssistant && lastAssistantErrored
|
||||
? suppressAssistantArtifacts
|
||||
? undefined
|
||||
: formatAssistantErrorText(params.lastAssistant, {
|
||||
cfg: params.config,
|
||||
sessionKey: params.sessionKey,
|
||||
provider: params.provider,
|
||||
model: params.model,
|
||||
})
|
||||
: undefined;
|
||||
const rawErrorMessage = lastAssistantErrored
|
||||
? params.lastAssistant?.errorMessage?.trim() || undefined
|
||||
: undefined;
|
||||
|
||||
@@ -134,6 +134,20 @@ describe("extractAssistantText", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves response when errorMessage set from background failure (#13935)", () => {
|
||||
const responseText = "Handle payment required errors in your API.";
|
||||
const msg = makeAssistantMessage({
|
||||
role: "assistant",
|
||||
errorMessage: "insufficient credits for embedding model",
|
||||
stopReason: "stop",
|
||||
content: [{ type: "text", text: responseText }],
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
const result = extractAssistantText(msg);
|
||||
expect(result).toBe(responseText);
|
||||
});
|
||||
|
||||
it("strips Minimax tool invocations with extra attributes", () => {
|
||||
const msg = makeAssistantMessage({
|
||||
role: "assistant",
|
||||
|
||||
@@ -245,7 +245,9 @@ export function extractAssistantText(msg: AssistantMessage): string {
|
||||
}) ?? "";
|
||||
// Only apply keyword-based error rewrites when the assistant message is actually an error.
|
||||
// Otherwise normal prose that *mentions* errors (e.g. "context overflow") can get clobbered.
|
||||
const errorContext = msg.stopReason === "error" || Boolean(msg.errorMessage?.trim());
|
||||
// Gate on stopReason only — a non-error response with an errorMessage set (e.g. from a
|
||||
// background tool failure) should not have its content rewritten (#13935).
|
||||
const errorContext = msg.stopReason === "error";
|
||||
return sanitizeUserFacingText(extracted, { errorContext });
|
||||
}
|
||||
|
||||
|
||||
@@ -358,21 +358,26 @@ describe("context-pruning", () => {
|
||||
expect(toolText(findToolResult(next, "t2"))).toContain("y".repeat(20_000));
|
||||
});
|
||||
|
||||
it("skips tool results that contain images (no soft trim, no hard clear)", () => {
|
||||
it("replaces image blocks in tool results during soft trim", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
makeUser("u1"),
|
||||
makeImageToolResult({
|
||||
toolCallId: "t1",
|
||||
toolName: "exec",
|
||||
text: "x".repeat(20_000),
|
||||
text: "visible tool text",
|
||||
}),
|
||||
];
|
||||
|
||||
const next = pruneWithAggressiveDefaults(messages);
|
||||
const next = pruneWithAggressiveDefaults(messages, {
|
||||
hardClearRatio: 10.0,
|
||||
hardClear: { enabled: false, placeholder: "[cleared]" },
|
||||
softTrim: { maxChars: 200, headChars: 100, tailChars: 100 },
|
||||
});
|
||||
|
||||
const tool = findToolResult(next, "t1");
|
||||
expect(tool.content.some((b) => b.type === "image")).toBe(true);
|
||||
expect(toolText(tool)).toContain("x".repeat(20_000));
|
||||
expect(tool.content.some((b) => b.type === "image")).toBe(false);
|
||||
expect(toolText(tool)).toContain("[image removed during context pruning]");
|
||||
expect(toolText(tool)).toContain("visible tool text");
|
||||
});
|
||||
|
||||
it("soft-trims across block boundaries", () => {
|
||||
|
||||
@@ -45,6 +45,19 @@ function makeAssistant(content: AssistantMessage["content"]): AgentMessage {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Test helper: builds a `toolResult` message for the `read` tool with the
 * given content blocks, stamped with the current time.
 */
function makeToolResult(
  content: Array<
    { type: "text"; text: string } | { type: "image"; data: string; mimeType: string }
  >,
): AgentMessage {
  const createdAt = Date.now();
  return {
    role: "toolResult",
    toolName: "read",
    content,
    timestamp: createdAt,
  } as AgentMessage;
}
|
||||
|
||||
describe("pruneContextMessages", () => {
|
||||
it("does not crash on assistant message with malformed thinking block (missing thinking string)", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
@@ -109,4 +122,119 @@ describe("pruneContextMessages", () => {
|
||||
});
|
||||
expect(result).toHaveLength(2);
|
||||
});
|
||||
|
||||
it("soft-trims image-containing tool results by replacing image blocks with placeholders", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
makeUser("summarize this"),
|
||||
makeToolResult([
|
||||
{ type: "text", text: "A".repeat(120) },
|
||||
{ type: "image", data: "img", mimeType: "image/png" },
|
||||
{ type: "text", text: "B".repeat(120) },
|
||||
]),
|
||||
makeAssistant([{ type: "text", text: "done" }]),
|
||||
];
|
||||
|
||||
const result = pruneContextMessages({
|
||||
messages,
|
||||
settings: {
|
||||
...DEFAULT_CONTEXT_PRUNING_SETTINGS,
|
||||
keepLastAssistants: 1,
|
||||
softTrimRatio: 0,
|
||||
hardClear: {
|
||||
...DEFAULT_CONTEXT_PRUNING_SETTINGS.hardClear,
|
||||
enabled: false,
|
||||
},
|
||||
softTrim: {
|
||||
maxChars: 200,
|
||||
headChars: 170,
|
||||
tailChars: 30,
|
||||
},
|
||||
},
|
||||
ctx: CONTEXT_WINDOW_1M,
|
||||
isToolPrunable: () => true,
|
||||
contextWindowTokensOverride: 16,
|
||||
});
|
||||
|
||||
const toolResult = result[1] as Extract<AgentMessage, { role: "toolResult" }>;
|
||||
expect(toolResult.content).toHaveLength(1);
|
||||
expect(toolResult.content[0]).toMatchObject({ type: "text" });
|
||||
const textBlock = toolResult.content[0] as { type: "text"; text: string };
|
||||
expect(textBlock.text).toContain("[image removed during context pruning]");
|
||||
expect(textBlock.text).toContain(
|
||||
"[Tool result trimmed: kept first 170 chars and last 30 chars",
|
||||
);
|
||||
});
|
||||
|
||||
it("replaces image-only tool results with placeholders even when text trimming is not needed", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
makeUser("summarize this"),
|
||||
makeToolResult([{ type: "image", data: "img", mimeType: "image/png" }]),
|
||||
makeAssistant([{ type: "text", text: "done" }]),
|
||||
];
|
||||
|
||||
const result = pruneContextMessages({
|
||||
messages,
|
||||
settings: {
|
||||
...DEFAULT_CONTEXT_PRUNING_SETTINGS,
|
||||
keepLastAssistants: 1,
|
||||
softTrimRatio: 0,
|
||||
hardClearRatio: 10,
|
||||
hardClear: {
|
||||
...DEFAULT_CONTEXT_PRUNING_SETTINGS.hardClear,
|
||||
enabled: false,
|
||||
},
|
||||
softTrim: {
|
||||
maxChars: 5_000,
|
||||
headChars: 2_000,
|
||||
tailChars: 2_000,
|
||||
},
|
||||
},
|
||||
ctx: CONTEXT_WINDOW_1M,
|
||||
isToolPrunable: () => true,
|
||||
contextWindowTokensOverride: 1,
|
||||
});
|
||||
|
||||
const toolResult = result[1] as Extract<AgentMessage, { role: "toolResult" }>;
|
||||
expect(toolResult.content).toEqual([
|
||||
{ type: "text", text: "[image removed during context pruning]" },
|
||||
]);
|
||||
});
|
||||
|
||||
it("hard-clears image-containing tool results once ratios require clearing", () => {
|
||||
const messages: AgentMessage[] = [
|
||||
makeUser("summarize this"),
|
||||
makeToolResult([
|
||||
{ type: "text", text: "small text" },
|
||||
{ type: "image", data: "img", mimeType: "image/png" },
|
||||
]),
|
||||
makeAssistant([{ type: "text", text: "done" }]),
|
||||
];
|
||||
|
||||
const placeholder = "[hard cleared test placeholder]";
|
||||
const result = pruneContextMessages({
|
||||
messages,
|
||||
settings: {
|
||||
...DEFAULT_CONTEXT_PRUNING_SETTINGS,
|
||||
keepLastAssistants: 1,
|
||||
softTrimRatio: 0,
|
||||
hardClearRatio: 0,
|
||||
minPrunableToolChars: 1,
|
||||
softTrim: {
|
||||
maxChars: 5_000,
|
||||
headChars: 2_000,
|
||||
tailChars: 2_000,
|
||||
},
|
||||
hardClear: {
|
||||
enabled: true,
|
||||
placeholder,
|
||||
},
|
||||
},
|
||||
ctx: CONTEXT_WINDOW_1M,
|
||||
isToolPrunable: () => true,
|
||||
contextWindowTokensOverride: 8,
|
||||
});
|
||||
|
||||
const toolResult = result[1] as Extract<AgentMessage, { role: "toolResult" }>;
|
||||
expect(toolResult.content).toEqual([{ type: "text", text: placeholder }]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,9 +5,8 @@ import type { EffectiveContextPruningSettings } from "./settings.js";
|
||||
import { makeToolPrunablePredicate } from "./tools.js";
|
||||
|
||||
const CHARS_PER_TOKEN_ESTIMATE = 4;
|
||||
// We currently skip pruning tool results that contain images. Still, we count them (approx.) so
|
||||
// we start trimming prunable tool results earlier when image-heavy context is consuming the window.
|
||||
const IMAGE_CHAR_ESTIMATE = 8_000;
|
||||
const PRUNED_CONTEXT_IMAGE_MARKER = "[image removed during context pruning]";
|
||||
|
||||
function asText(text: string): TextContent {
|
||||
return { type: "text", text };
|
||||
@@ -23,6 +22,22 @@ function collectTextSegments(content: ReadonlyArray<TextContent | ImageContent>)
|
||||
return parts;
|
||||
}
|
||||
|
||||
/**
 * Flattens tool-result content into text segments for pruning.
 *
 * Text blocks contribute their text; each image block contributes
 * `PRUNED_CONTEXT_IMAGE_MARKER`; any other block is skipped. Order is preserved.
 */
function collectPrunableToolResultSegments(
  content: ReadonlyArray<TextContent | ImageContent>,
): string[] {
  const segments: string[] = [];
  for (const block of content) {
    switch (block.type) {
      case "text":
        segments.push(block.text);
        break;
      case "image":
        segments.push(PRUNED_CONTEXT_IMAGE_MARKER);
        break;
      default:
        break;
    }
  }
  return segments;
}
|
||||
|
||||
function estimateJoinedTextLength(parts: string[]): number {
|
||||
if (parts.length === 0) {
|
||||
return 0;
|
||||
@@ -190,21 +205,25 @@ function softTrimToolResultMessage(params: {
|
||||
settings: EffectiveContextPruningSettings;
|
||||
}): ToolResultMessage | null {
|
||||
const { msg, settings } = params;
|
||||
// Ignore image tool results for now: these are often directly relevant and hard to partially prune safely.
|
||||
if (hasImageBlocks(msg.content)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const parts = collectTextSegments(msg.content);
|
||||
const hasImages = hasImageBlocks(msg.content);
|
||||
const parts = hasImages
|
||||
? collectPrunableToolResultSegments(msg.content)
|
||||
: collectTextSegments(msg.content);
|
||||
const rawLen = estimateJoinedTextLength(parts);
|
||||
if (rawLen <= settings.softTrim.maxChars) {
|
||||
return null;
|
||||
if (!hasImages) {
|
||||
return null;
|
||||
}
|
||||
return { ...msg, content: [asText(parts.join("\n"))] };
|
||||
}
|
||||
|
||||
const headChars = Math.max(0, settings.softTrim.headChars);
|
||||
const tailChars = Math.max(0, settings.softTrim.tailChars);
|
||||
if (headChars + tailChars >= rawLen) {
|
||||
return null;
|
||||
if (!hasImages) {
|
||||
return null;
|
||||
}
|
||||
return { ...msg, content: [asText(parts.join("\n"))] };
|
||||
}
|
||||
|
||||
const head = takeHeadFromJoinedText(parts, headChars);
|
||||
@@ -274,9 +293,6 @@ export function pruneContextMessages(params: {
|
||||
if (!isToolPrunable(msg.toolName)) {
|
||||
continue;
|
||||
}
|
||||
if (hasImageBlocks(msg.content)) {
|
||||
continue;
|
||||
}
|
||||
prunableToolIndexes.push(i);
|
||||
|
||||
const updated = softTrimToolResultMessage({
|
||||
|
||||
@@ -3,10 +3,14 @@ import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
vi.mock("@mariozechner/pi-ai/oauth", () => ({
|
||||
getOAuthApiKey: () => undefined,
|
||||
getOAuthProviders: () => [],
|
||||
}));
|
||||
vi.mock("@mariozechner/pi-ai", async (importOriginal) => {
|
||||
const original = await importOriginal<typeof import("@mariozechner/pi-ai")>();
|
||||
return {
|
||||
...original,
|
||||
getOAuthApiKey: () => undefined,
|
||||
getOAuthProviders: () => [],
|
||||
};
|
||||
});
|
||||
|
||||
import { createOpenClawCodingTools } from "./pi-tools.js";
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ describe("resolveProviderCapabilities", () => {
|
||||
it("flags providers that opt out of OpenAI-compatible turn validation", () => {
|
||||
expect(supportsOpenAiCompatTurnValidation("openrouter")).toBe(false);
|
||||
expect(supportsOpenAiCompatTurnValidation("opencode")).toBe(false);
|
||||
expect(supportsOpenAiCompatTurnValidation("opencode-go")).toBe(false);
|
||||
expect(supportsOpenAiCompatTurnValidation("moonshot")).toBe(true);
|
||||
});
|
||||
|
||||
@@ -63,6 +64,12 @@ describe("resolveProviderCapabilities", () => {
|
||||
modelId: "gemini-2.0-flash",
|
||||
}),
|
||||
).toBe(true);
|
||||
expect(
|
||||
shouldSanitizeGeminiThoughtSignaturesForModel({
|
||||
provider: "opencode-go",
|
||||
modelId: "google/gemini-2.5-pro-preview",
|
||||
}),
|
||||
).toBe(true);
|
||||
expect(resolveTranscriptToolCallIdMode("mistral", "mistral-large-latest")).toBe("strict9");
|
||||
});
|
||||
|
||||
|
||||
@@ -66,6 +66,11 @@ const PROVIDER_CAPABILITIES: Record<string, Partial<ProviderCapabilities>> = {
|
||||
geminiThoughtSignatureSanitization: true,
|
||||
geminiThoughtSignatureModelHints: ["gemini"],
|
||||
},
|
||||
"opencode-go": {
|
||||
openAiCompatTurnValidation: false,
|
||||
geminiThoughtSignatureSanitization: true,
|
||||
geminiThoughtSignatureModelHints: ["gemini"],
|
||||
},
|
||||
kilocode: {
|
||||
geminiThoughtSignatureSanitization: true,
|
||||
geminiThoughtSignatureModelHints: ["gemini"],
|
||||
|
||||
@@ -137,6 +137,33 @@ describe("buildSandboxCreateArgs", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves the OpenClaw exec marker when strict env sanitization is enabled", () => {
|
||||
const cfg = createSandboxConfig({
|
||||
env: {
|
||||
NODE_ENV: "test",
|
||||
},
|
||||
});
|
||||
|
||||
const args = buildSandboxCreateArgs({
|
||||
name: "openclaw-sbx-marker",
|
||||
cfg,
|
||||
scopeKey: "main",
|
||||
createdAtMs: 1700000000000,
|
||||
envSanitizationOptions: {
|
||||
strictMode: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(args).toEqual(
|
||||
expect.arrayContaining([
|
||||
"--env",
|
||||
"NODE_ENV=test",
|
||||
"--env",
|
||||
`OPENCLAW_CLI=${OPENCLAW_CLI_ENV_VALUE}`,
|
||||
]),
|
||||
);
|
||||
});
|
||||
|
||||
it("emits -v flags for safe custom binds", () => {
|
||||
const cfg: SandboxDockerConfig = {
|
||||
image: "openclaw-sandbox:bookworm-slim",
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
resolveWindowsSpawnProgram,
|
||||
} from "../../plugin-sdk/windows-spawn.js";
|
||||
import { sanitizeEnvVars } from "./sanitize-env-vars.js";
|
||||
import type { EnvSanitizationOptions } from "./sanitize-env-vars.js";
|
||||
|
||||
type ExecDockerRawOptions = {
|
||||
allowFailure?: boolean;
|
||||
@@ -52,7 +53,7 @@ export function resolveDockerSpawnInvocation(
|
||||
env: runtime.env,
|
||||
execPath: runtime.execPath,
|
||||
packageName: "docker",
|
||||
allowShellFallback: true,
|
||||
allowShellFallback: false,
|
||||
});
|
||||
const resolved = materializeWindowsSpawnProgram(program, args);
|
||||
return {
|
||||
@@ -325,6 +326,7 @@ export function buildSandboxCreateArgs(params: {
|
||||
allowSourcesOutsideAllowedRoots?: boolean;
|
||||
allowReservedContainerTargets?: boolean;
|
||||
allowContainerNamespaceJoin?: boolean;
|
||||
envSanitizationOptions?: EnvSanitizationOptions;
|
||||
}) {
|
||||
// Runtime security validation: blocks dangerous bind mounts, network modes, and profiles.
|
||||
validateSandboxSecurity({
|
||||
@@ -366,14 +368,14 @@ export function buildSandboxCreateArgs(params: {
|
||||
if (params.cfg.user) {
|
||||
args.push("--user", params.cfg.user);
|
||||
}
|
||||
const envSanitization = sanitizeEnvVars(markOpenClawExecEnv(params.cfg.env ?? {}));
|
||||
const envSanitization = sanitizeEnvVars(params.cfg.env ?? {}, params.envSanitizationOptions);
|
||||
if (envSanitization.blocked.length > 0) {
|
||||
log.warn(`Blocked sensitive environment variables: ${envSanitization.blocked.join(", ")}`);
|
||||
}
|
||||
if (envSanitization.warnings.length > 0) {
|
||||
log.warn(`Suspicious environment variables: ${envSanitization.warnings.join(", ")}`);
|
||||
}
|
||||
for (const [key, value] of Object.entries(envSanitization.allowed)) {
|
||||
for (const [key, value] of Object.entries(markOpenClawExecEnv(envSanitization.allowed))) {
|
||||
args.push("--env", `${key}=${value}`);
|
||||
}
|
||||
for (const cap of params.cfg.capDrop) {
|
||||
|
||||
@@ -47,22 +47,20 @@ describe("resolveDockerSpawnInvocation", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("falls back to shell mode when only unresolved docker.cmd wrapper exists", async () => {
|
||||
it("rejects unresolved docker.cmd wrappers instead of shelling out", async () => {
|
||||
const dir = await createTempDir();
|
||||
const cmdPath = path.join(dir, "docker.cmd");
|
||||
await mkdir(path.dirname(cmdPath), { recursive: true });
|
||||
await writeFile(cmdPath, "@ECHO off\r\necho docker\r\n", "utf8");
|
||||
|
||||
const resolved = resolveDockerSpawnInvocation(["ps"], {
|
||||
platform: "win32",
|
||||
env: { PATH: dir, PATHEXT: ".CMD;.EXE;.BAT" },
|
||||
execPath: "C:\\node\\node.exe",
|
||||
});
|
||||
expect(path.normalize(resolved.command).toLowerCase()).toBe(
|
||||
path.normalize(cmdPath).toLowerCase(),
|
||||
expect(() =>
|
||||
resolveDockerSpawnInvocation(["ps"], {
|
||||
platform: "win32",
|
||||
env: { PATH: dir, PATHEXT: ".CMD;.EXE;.BAT" },
|
||||
execPath: "C:\\node\\node.exe",
|
||||
}),
|
||||
).toThrow(
|
||||
/wrapper resolved, but no executable\/Node entrypoint could be resolved without shell execution\./i,
|
||||
);
|
||||
expect(resolved.args).toEqual(["ps"]);
|
||||
expect(resolved.shell).toBe(true);
|
||||
expect(resolved.windowsHide).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,7 +3,10 @@ import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { SANDBOX_PINNED_MUTATION_PYTHON } from "./fs-bridge-mutation-helper.js";
|
||||
import {
|
||||
buildPinnedWritePlan,
|
||||
SANDBOX_PINNED_MUTATION_PYTHON,
|
||||
} from "./fs-bridge-mutation-helper.js";
|
||||
|
||||
async function withTempRoot<T>(prefix: string, run: (root: string) => Promise<T>): Promise<T> {
|
||||
const root = await fs.mkdtemp(path.join(os.tmpdir(), prefix));
|
||||
@@ -22,6 +25,35 @@ function runMutation(args: string[], input?: string) {
|
||||
});
|
||||
}
|
||||
|
||||
function runWritePlan(args: string[], input?: string) {
|
||||
const plan = buildPinnedWritePlan({
|
||||
check: {
|
||||
target: {
|
||||
hostPath: args[1] ?? "",
|
||||
containerPath: args[1] ?? "",
|
||||
relativePath: path.posix.join(args[2] ?? "", args[3] ?? ""),
|
||||
writable: true,
|
||||
},
|
||||
options: {
|
||||
action: "write files",
|
||||
requireWritable: true,
|
||||
},
|
||||
},
|
||||
pinned: {
|
||||
mountRootPath: args[1] ?? "",
|
||||
relativeParentPath: args[2] ?? "",
|
||||
basename: args[3] ?? "",
|
||||
},
|
||||
mkdir: args[4] === "1",
|
||||
});
|
||||
|
||||
return spawnSync("sh", ["-c", plan.script, "moltbot-sandbox-fs", ...(plan.args ?? [])], {
|
||||
input,
|
||||
encoding: "utf8",
|
||||
stdio: ["pipe", "pipe", "pipe"],
|
||||
});
|
||||
}
|
||||
|
||||
describe("sandbox pinned mutation helper", () => {
|
||||
it("writes through a pinned directory fd", async () => {
|
||||
await withTempRoot("openclaw-mutation-helper-", async (root) => {
|
||||
@@ -37,6 +69,26 @@ describe("sandbox pinned mutation helper", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it.runIf(process.platform !== "win32")(
|
||||
"preserves stdin payload bytes when the pinned write plan runs through sh",
|
||||
async () => {
|
||||
await withTempRoot("openclaw-mutation-helper-", async (root) => {
|
||||
const workspace = path.join(root, "workspace");
|
||||
await fs.mkdir(workspace, { recursive: true });
|
||||
|
||||
const result = runWritePlan(
|
||||
["write", workspace, "nested/deeper", "note.txt", "1"],
|
||||
"hello",
|
||||
);
|
||||
|
||||
expect(result.status).toBe(0);
|
||||
await expect(
|
||||
fs.readFile(path.join(workspace, "nested", "deeper", "note.txt"), "utf8"),
|
||||
).resolves.toBe("hello");
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
it.runIf(process.platform !== "win32")(
|
||||
"rejects symlink-parent writes instead of materializing a temp file outside the mount",
|
||||
async () => {
|
||||
|
||||
@@ -257,7 +257,13 @@ function buildPinnedMutationPlan(params: {
|
||||
return {
|
||||
checks: params.checks,
|
||||
recheckBeforeCommand: true,
|
||||
script: ["set -eu", "python3 - \"$@\" <<'PY'", SANDBOX_PINNED_MUTATION_PYTHON, "PY"].join("\n"),
|
||||
// Feed the helper source over fd 3 so stdin stays available for write payload bytes.
|
||||
script: [
|
||||
"set -eu",
|
||||
"python3 /dev/fd/3 \"$@\" 3<<'PY'",
|
||||
SANDBOX_PINNED_MUTATION_PYTHON,
|
||||
"PY",
|
||||
].join("\n"),
|
||||
args: params.args,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -24,6 +24,11 @@ export type PinnedSandboxEntry = {
|
||||
basename: string;
|
||||
};
|
||||
|
||||
export type AnchoredSandboxEntry = {
|
||||
canonicalParentPath: string;
|
||||
basename: string;
|
||||
};
|
||||
|
||||
export type PinnedSandboxDirectoryEntry = {
|
||||
mountRootPath: string;
|
||||
relativePath: string;
|
||||
@@ -154,6 +159,48 @@ export class SandboxFsPathGuard {
|
||||
};
|
||||
}
|
||||
|
||||
async resolveAnchoredSandboxEntry(
|
||||
target: SandboxResolvedFsPath,
|
||||
action: string,
|
||||
): Promise<AnchoredSandboxEntry> {
|
||||
const basename = path.posix.basename(target.containerPath);
|
||||
if (!basename || basename === "." || basename === "/") {
|
||||
throw new Error(`Invalid sandbox entry target: ${target.containerPath}`);
|
||||
}
|
||||
const parentPath = normalizeContainerPath(path.posix.dirname(target.containerPath));
|
||||
const canonicalParentPath = await this.resolveCanonicalContainerPath({
|
||||
containerPath: parentPath,
|
||||
allowFinalSymlinkForUnlink: false,
|
||||
});
|
||||
this.resolveRequiredMount(canonicalParentPath, action);
|
||||
return {
|
||||
canonicalParentPath,
|
||||
basename,
|
||||
};
|
||||
}
|
||||
|
||||
async resolveAnchoredPinnedEntry(
|
||||
target: SandboxResolvedFsPath,
|
||||
action: string,
|
||||
): Promise<PinnedSandboxEntry> {
|
||||
const anchoredTarget = await this.resolveAnchoredSandboxEntry(target, action);
|
||||
const mount = this.resolveRequiredMount(anchoredTarget.canonicalParentPath, action);
|
||||
const relativeParentPath = path.posix.relative(
|
||||
mount.containerRoot,
|
||||
anchoredTarget.canonicalParentPath,
|
||||
);
|
||||
if (relativeParentPath.startsWith("..") || path.posix.isAbsolute(relativeParentPath)) {
|
||||
throw new Error(
|
||||
`Sandbox path escapes allowed mounts; cannot ${action}: ${target.containerPath}`,
|
||||
);
|
||||
}
|
||||
return {
|
||||
mountRootPath: mount.containerRoot,
|
||||
relativeParentPath: relativeParentPath === "." ? "" : relativeParentPath,
|
||||
basename: anchoredTarget.basename,
|
||||
};
|
||||
}
|
||||
|
||||
resolvePinnedDirectoryEntry(
|
||||
target: SandboxResolvedFsPath,
|
||||
action: string,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { PathSafetyCheck } from "./fs-bridge-path-safety.js";
|
||||
import type { AnchoredSandboxEntry, PathSafetyCheck } from "./fs-bridge-path-safety.js";
|
||||
import type { SandboxResolvedFsPath } from "./fs-paths.js";
|
||||
|
||||
export type SandboxFsCommandPlan = {
|
||||
@@ -10,11 +10,14 @@ export type SandboxFsCommandPlan = {
|
||||
allowFailure?: boolean;
|
||||
};
|
||||
|
||||
export function buildStatPlan(target: SandboxResolvedFsPath): SandboxFsCommandPlan {
|
||||
export function buildStatPlan(
|
||||
target: SandboxResolvedFsPath,
|
||||
anchoredTarget: AnchoredSandboxEntry,
|
||||
): SandboxFsCommandPlan {
|
||||
return {
|
||||
checks: [{ target, options: { action: "stat files" } }],
|
||||
script: 'set -eu; stat -c "%F|%s|%Y" -- "$1"',
|
||||
args: [target.containerPath],
|
||||
script: 'set -eu\ncd -- "$1"\nstat -c "%F|%s|%Y" -- "$2"',
|
||||
args: [anchoredTarget.canonicalParentPath, anchoredTarget.basename],
|
||||
allowFailure: true,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4,7 +4,12 @@ import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
createSandbox,
|
||||
createSandboxFsBridge,
|
||||
dockerExecResult,
|
||||
findCallsByScriptFragment,
|
||||
findCallByDockerArg,
|
||||
findCallByScriptFragment,
|
||||
getDockerArg,
|
||||
getDockerScript,
|
||||
installFsBridgeTestHarness,
|
||||
mockedExecDockerRaw,
|
||||
withTempDir,
|
||||
@@ -66,6 +71,13 @@ describe("sandbox fs bridge anchored ops", () => {
|
||||
});
|
||||
|
||||
const pinnedCases = [
|
||||
{
|
||||
name: "write pins canonical parent + basename",
|
||||
invoke: (bridge: ReturnType<typeof createSandboxFsBridge>) =>
|
||||
bridge.writeFile({ filePath: "nested/file.txt", data: "updated" }),
|
||||
expectedArgs: ["write", "/workspace", "nested", "file.txt", "1"],
|
||||
forbiddenArgs: ["/workspace/nested/file.txt"],
|
||||
},
|
||||
{
|
||||
name: "mkdirp pins mount root + relative path",
|
||||
invoke: (bridge: ReturnType<typeof createSandboxFsBridge>) =>
|
||||
@@ -108,7 +120,7 @@ describe("sandbox fs bridge anchored ops", () => {
|
||||
const opCall = mockedExecDockerRaw.mock.calls.find(
|
||||
([args]) =>
|
||||
typeof args[5] === "string" &&
|
||||
args[5].includes("python3 - \"$@\" <<'PY'") &&
|
||||
args[5].includes("python3 /dev/fd/3 \"$@\" 3<<'PY'") &&
|
||||
getDockerArg(args, 1) === testCase.expectedArgs[0],
|
||||
);
|
||||
expect(opCall).toBeDefined();
|
||||
@@ -121,4 +133,74 @@ describe("sandbox fs bridge anchored ops", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it.runIf(process.platform !== "win32")(
|
||||
"write resolves symlink parents to canonical pinned paths",
|
||||
async () => {
|
||||
await withTempDir("openclaw-fs-bridge-contract-write-", async (stateDir) => {
|
||||
const workspaceDir = path.join(stateDir, "workspace");
|
||||
const realDir = path.join(workspaceDir, "real");
|
||||
await fs.mkdir(realDir, { recursive: true });
|
||||
await fs.symlink(realDir, path.join(workspaceDir, "alias"));
|
||||
|
||||
mockedExecDockerRaw.mockImplementation(async (args) => {
|
||||
const script = getDockerScript(args);
|
||||
if (script.includes('readlink -f -- "$cursor"')) {
|
||||
const target = getDockerArg(args, 1);
|
||||
return dockerExecResult(`${target.replace("/workspace/alias", "/workspace/real")}\n`);
|
||||
}
|
||||
if (script.includes('stat -c "%F|%s|%Y"')) {
|
||||
return dockerExecResult("regular file|1|2");
|
||||
}
|
||||
return dockerExecResult("");
|
||||
});
|
||||
|
||||
const bridge = createSandboxFsBridge({
|
||||
sandbox: createSandbox({
|
||||
workspaceDir,
|
||||
agentWorkspaceDir: workspaceDir,
|
||||
}),
|
||||
});
|
||||
|
||||
await bridge.writeFile({ filePath: "alias/note.txt", data: "updated" });
|
||||
|
||||
const writeCall = findCallByDockerArg(1, "write");
|
||||
expect(writeCall).toBeDefined();
|
||||
const args = writeCall?.[0] ?? [];
|
||||
expect(getDockerArg(args, 2)).toBe("/workspace");
|
||||
expect(getDockerArg(args, 3)).toBe("real");
|
||||
expect(getDockerArg(args, 4)).toBe("note.txt");
|
||||
expect(args).not.toContain("alias");
|
||||
|
||||
const canonicalCalls = findCallsByScriptFragment('readlink -f -- "$cursor"');
|
||||
expect(
|
||||
canonicalCalls.some(([callArgs]) => getDockerArg(callArgs, 1) === "/workspace/alias"),
|
||||
).toBe(true);
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
it("stat anchors parent + basename", async () => {
|
||||
await withTempDir("openclaw-fs-bridge-contract-stat-", async (stateDir) => {
|
||||
const workspaceDir = path.join(stateDir, "workspace");
|
||||
await fs.mkdir(path.join(workspaceDir, "nested"), { recursive: true });
|
||||
await fs.writeFile(path.join(workspaceDir, "nested", "file.txt"), "bye", "utf8");
|
||||
|
||||
const bridge = createSandboxFsBridge({
|
||||
sandbox: createSandbox({
|
||||
workspaceDir,
|
||||
agentWorkspaceDir: workspaceDir,
|
||||
}),
|
||||
});
|
||||
|
||||
await bridge.stat({ filePath: "nested/file.txt" });
|
||||
|
||||
const statCall = findCallByScriptFragment('stat -c "%F|%s|%Y" -- "$2"');
|
||||
expect(statCall).toBeDefined();
|
||||
const args = statCall?.[0] ?? [];
|
||||
expect(getDockerArg(args, 1)).toBe("/workspace/nested");
|
||||
expect(getDockerArg(args, 2)).toBe("file.txt");
|
||||
expect(args).not.toContain("/workspace/nested/file.txt");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -129,6 +129,10 @@ describe("sandbox fs bridge shell compatibility", () => {
|
||||
await bridge.writeFile({ filePath: "b.txt", data: "hello" });
|
||||
|
||||
const scripts = getScriptsFromCalls();
|
||||
expect(scripts.some((script) => script.includes("python3 - \"$@\" <<'PY'"))).toBe(false);
|
||||
expect(scripts.some((script) => script.includes("python3 /dev/fd/3 \"$@\" 3<<'PY'"))).toBe(
|
||||
true,
|
||||
);
|
||||
expect(scripts.some((script) => script.includes('cat >"$1"'))).toBe(false);
|
||||
expect(scripts.some((script) => script.includes('cat >"$tmp"'))).toBe(false);
|
||||
expect(scripts.some((script) => script.includes("os.replace("))).toBe(true);
|
||||
|
||||
@@ -118,7 +118,10 @@ class SandboxFsBridgeImpl implements SandboxFsBridge {
|
||||
const buffer = Buffer.isBuffer(params.data)
|
||||
? params.data
|
||||
: Buffer.from(params.data, params.encoding ?? "utf8");
|
||||
const pinnedWriteTarget = this.pathGuard.resolvePinnedEntry(target, "write files");
|
||||
const pinnedWriteTarget = await this.pathGuard.resolveAnchoredPinnedEntry(
|
||||
target,
|
||||
"write files",
|
||||
);
|
||||
await this.runCheckedCommand({
|
||||
...buildPinnedWritePlan({
|
||||
check: writeCheck,
|
||||
@@ -218,7 +221,11 @@ class SandboxFsBridgeImpl implements SandboxFsBridge {
|
||||
signal?: AbortSignal;
|
||||
}): Promise<SandboxFsStat | null> {
|
||||
const target = this.resolveResolvedPath(params);
|
||||
const result = await this.runPlannedCommand(buildStatPlan(target), params.signal);
|
||||
const anchoredTarget = await this.pathGuard.resolveAnchoredSandboxEntry(target, "stat files");
|
||||
const result = await this.runPlannedCommand(
|
||||
buildStatPlan(target, anchoredTarget),
|
||||
params.signal,
|
||||
);
|
||||
if (result.code !== 0) {
|
||||
const stderr = result.stderr.toString("utf8");
|
||||
if (stderr.includes("No such file or directory")) {
|
||||
|
||||
@@ -380,4 +380,36 @@ describe("sessions_spawn subagent lifecycle hooks", () => {
|
||||
emitLifecycleHooks: true,
|
||||
});
|
||||
});
|
||||
|
||||
it("cleans up the provisional session when lineage patching fails after thread binding", async () => {
|
||||
const callGatewayMock = getCallGatewayMock();
|
||||
callGatewayMock.mockImplementation(async (opts: unknown) => {
|
||||
const request = opts as { method?: string; params?: Record<string, unknown> };
|
||||
if (request.method === "sessions.patch" && typeof request.params?.spawnedBy === "string") {
|
||||
throw new Error("lineage patch failed");
|
||||
}
|
||||
if (request.method === "sessions.delete") {
|
||||
return { ok: true };
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
const result = await executeDiscordThreadSessionSpawn("call9");
|
||||
|
||||
expect(result.details).toMatchObject({
|
||||
status: "error",
|
||||
error: "lineage patch failed",
|
||||
});
|
||||
expect(hookRunnerMocks.runSubagentSpawned).not.toHaveBeenCalled();
|
||||
expect(hookRunnerMocks.runSubagentEnded).not.toHaveBeenCalled();
|
||||
const methods = getGatewayMethods();
|
||||
expect(methods).toContain("sessions.delete");
|
||||
expect(methods).not.toContain("agent");
|
||||
const deleteCall = findGatewayRequest("sessions.delete");
|
||||
expect(deleteCall?.params).toMatchObject({
|
||||
key: (result.details as { childSessionKey?: string }).childSessionKey,
|
||||
deleteTranscript: true,
|
||||
emitLifecycleHooks: true,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import fs from "node:fs";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { resetSubagentRegistryForTests } from "./subagent-registry.js";
|
||||
import { decodeStrictBase64, spawnSubagentDirect } from "./subagent-spawn.js";
|
||||
|
||||
@@ -31,6 +32,7 @@ let configOverride: Record<string, unknown> = {
|
||||
},
|
||||
},
|
||||
};
|
||||
let workspaceDirOverride = "";
|
||||
|
||||
vi.mock("../config/config.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("../config/config.js")>();
|
||||
@@ -61,7 +63,7 @@ vi.mock("./agent-scope.js", async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import("./agent-scope.js")>();
|
||||
return {
|
||||
...actual,
|
||||
resolveAgentWorkspaceDir: () => path.join(os.tmpdir(), "agent-workspace"),
|
||||
resolveAgentWorkspaceDir: () => workspaceDirOverride,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -145,6 +147,16 @@ describe("spawnSubagentDirect filename validation", () => {
|
||||
resetSubagentRegistryForTests();
|
||||
callGatewayMock.mockClear();
|
||||
setupGatewayMock();
|
||||
workspaceDirOverride = fs.mkdtempSync(
|
||||
path.join(os.tmpdir(), `openclaw-subagent-attachments-${process.pid}-${Date.now()}-`),
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (workspaceDirOverride) {
|
||||
fs.rmSync(workspaceDirOverride, { recursive: true, force: true });
|
||||
workspaceDirOverride = "";
|
||||
}
|
||||
});
|
||||
|
||||
const ctx = {
|
||||
@@ -210,4 +222,43 @@ describe("spawnSubagentDirect filename validation", () => {
|
||||
expect(result.status).toBe("error");
|
||||
expect(result.error).toMatch(/attachments_invalid_name/);
|
||||
});
|
||||
|
||||
it("removes materialized attachments when lineage patching fails", async () => {
|
||||
const calls: Array<{ method?: string; params?: Record<string, unknown> }> = [];
|
||||
callGatewayMock.mockImplementation(async (opts: unknown) => {
|
||||
const request = opts as { method?: string; params?: Record<string, unknown> };
|
||||
calls.push(request);
|
||||
if (request.method === "sessions.patch" && typeof request.params?.spawnedBy === "string") {
|
||||
throw new Error("lineage patch failed");
|
||||
}
|
||||
if (request.method === "sessions.delete") {
|
||||
return { ok: true };
|
||||
}
|
||||
return {};
|
||||
});
|
||||
|
||||
const result = await spawnSubagentDirect(
|
||||
{
|
||||
task: "test",
|
||||
attachments: [{ name: "file.txt", content: validContent, encoding: "base64" }],
|
||||
},
|
||||
ctx,
|
||||
);
|
||||
|
||||
expect(result).toMatchObject({
|
||||
status: "error",
|
||||
error: "lineage patch failed",
|
||||
});
|
||||
const attachmentsRoot = path.join(workspaceDirOverride, ".openclaw", "attachments");
|
||||
const retainedDirs = fs.existsSync(attachmentsRoot)
|
||||
? fs.readdirSync(attachmentsRoot).filter((entry) => !entry.startsWith("."))
|
||||
: [];
|
||||
expect(retainedDirs).toHaveLength(0);
|
||||
const deleteCall = calls.find((entry) => entry.method === "sessions.delete");
|
||||
expect(deleteCall?.params).toMatchObject({
|
||||
key: expect.stringMatching(/^agent:main:subagent:/),
|
||||
deleteTranscript: true,
|
||||
emitLifecycleHooks: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -153,6 +153,25 @@ async function cleanupProvisionalSession(
|
||||
}
|
||||
}
|
||||
|
||||
async function cleanupFailedSpawnBeforeAgentStart(params: {
|
||||
childSessionKey: string;
|
||||
attachmentAbsDir?: string;
|
||||
emitLifecycleHooks?: boolean;
|
||||
deleteTranscript?: boolean;
|
||||
}): Promise<void> {
|
||||
if (params.attachmentAbsDir) {
|
||||
try {
|
||||
await fs.rm(params.attachmentAbsDir, { recursive: true, force: true });
|
||||
} catch {
|
||||
// Best-effort cleanup only.
|
||||
}
|
||||
}
|
||||
await cleanupProvisionalSession(params.childSessionKey, {
|
||||
emitLifecycleHooks: params.emitLifecycleHooks,
|
||||
deleteTranscript: params.deleteTranscript,
|
||||
});
|
||||
}
|
||||
|
||||
function resolveSpawnMode(params: {
|
||||
requestedMode?: SpawnSubagentMode;
|
||||
threadRequested: boolean;
|
||||
@@ -561,10 +580,32 @@ export async function spawnSubagentDirect(
|
||||
explicitWorkspaceDir: toolSpawnMetadata.workspaceDir,
|
||||
}),
|
||||
});
|
||||
const spawnLineagePatchError = await patchChildSession({
|
||||
spawnedBy: spawnedByKey,
|
||||
...(spawnedMetadata.workspaceDir ? { spawnedWorkspaceDir: spawnedMetadata.workspaceDir } : {}),
|
||||
});
|
||||
if (spawnLineagePatchError) {
|
||||
await cleanupFailedSpawnBeforeAgentStart({
|
||||
childSessionKey,
|
||||
attachmentAbsDir,
|
||||
emitLifecycleHooks: threadBindingReady,
|
||||
deleteTranscript: true,
|
||||
});
|
||||
return {
|
||||
status: "error",
|
||||
error: spawnLineagePatchError,
|
||||
childSessionKey,
|
||||
};
|
||||
}
|
||||
|
||||
const childIdem = crypto.randomUUID();
|
||||
let childRunId: string = childIdem;
|
||||
try {
|
||||
const {
|
||||
spawnedBy: _spawnedBy,
|
||||
workspaceDir: _workspaceDir,
|
||||
...publicSpawnedMetadata
|
||||
} = spawnedMetadata;
|
||||
const response = await callGateway<{ runId: string }>({
|
||||
method: "agent",
|
||||
params: {
|
||||
@@ -581,7 +622,7 @@ export async function spawnSubagentDirect(
|
||||
thinking: thinkingOverride,
|
||||
timeout: runTimeoutSeconds,
|
||||
label: label || undefined,
|
||||
...spawnedMetadata,
|
||||
...publicSpawnedMetadata,
|
||||
},
|
||||
timeoutMs: 10_000,
|
||||
});
|
||||
|
||||
11
src/agents/tool-catalog.test.ts
Normal file
11
src/agents/tool-catalog.test.ts
Normal file
@@ -0,0 +1,11 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { resolveCoreToolProfilePolicy } from "./tool-catalog.js";
|
||||
|
||||
describe("tool-catalog", () => {
|
||||
it("includes web_search and web_fetch in the coding profile policy", () => {
|
||||
const policy = resolveCoreToolProfilePolicy("coding");
|
||||
expect(policy).toBeDefined();
|
||||
expect(policy!.allow).toContain("web_search");
|
||||
expect(policy!.allow).toContain("web_fetch");
|
||||
});
|
||||
});
|
||||
@@ -86,7 +86,7 @@ const CORE_TOOL_DEFINITIONS: CoreToolDefinition[] = [
|
||||
label: "web_search",
|
||||
description: "Search the web",
|
||||
sectionId: "web",
|
||||
profiles: [],
|
||||
profiles: ["coding"],
|
||||
includeInOpenClawGroup: true,
|
||||
},
|
||||
{
|
||||
@@ -94,7 +94,7 @@ const CORE_TOOL_DEFINITIONS: CoreToolDefinition[] = [
|
||||
label: "web_fetch",
|
||||
description: "Fetch web content",
|
||||
sectionId: "web",
|
||||
profiles: [],
|
||||
profiles: ["coding"],
|
||||
includeInOpenClawGroup: true,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -80,6 +80,7 @@ describe("tool-policy", () => {
|
||||
expect(isOwnerOnlyToolName("whatsapp_login")).toBe(true);
|
||||
expect(isOwnerOnlyToolName("cron")).toBe(true);
|
||||
expect(isOwnerOnlyToolName("gateway")).toBe(true);
|
||||
expect(isOwnerOnlyToolName("nodes")).toBe(true);
|
||||
expect(isOwnerOnlyToolName("read")).toBe(false);
|
||||
});
|
||||
|
||||
@@ -107,6 +108,27 @@ describe("tool-policy", () => {
|
||||
expect(applyOwnerOnlyToolPolicy(tools, false)).toEqual([]);
|
||||
expect(applyOwnerOnlyToolPolicy(tools, true)).toHaveLength(1);
|
||||
});
|
||||
|
||||
it("strips nodes for non-owner senders via fallback policy", () => {
|
||||
const tools = [
|
||||
{
|
||||
name: "read",
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
execute: async () => ({ content: [], details: {} }) as any,
|
||||
},
|
||||
{
|
||||
name: "nodes",
|
||||
// oxlint-disable-next-line typescript/no-explicit-any
|
||||
execute: async () => ({ content: [], details: {} }) as any,
|
||||
},
|
||||
] as unknown as AnyAgentTool[];
|
||||
|
||||
expect(applyOwnerOnlyToolPolicy(tools, false).map((tool) => tool.name)).toEqual(["read"]);
|
||||
expect(applyOwnerOnlyToolPolicy(tools, true).map((tool) => tool.name)).toEqual([
|
||||
"read",
|
||||
"nodes",
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("TOOL_POLICY_CONFORMANCE", () => {
|
||||
|
||||
@@ -28,7 +28,12 @@ function wrapOwnerOnlyToolExecution(tool: AnyAgentTool, senderIsOwner: boolean):
|
||||
};
|
||||
}
|
||||
|
||||
const OWNER_ONLY_TOOL_NAME_FALLBACKS = new Set<string>(["whatsapp_login", "cron", "gateway"]);
|
||||
const OWNER_ONLY_TOOL_NAME_FALLBACKS = new Set<string>([
|
||||
"whatsapp_login",
|
||||
"cron",
|
||||
"gateway",
|
||||
"nodes",
|
||||
]);
|
||||
|
||||
export function isOwnerOnlyToolName(name: string) {
|
||||
return OWNER_ONLY_TOOL_NAME_FALLBACKS.has(normalizeToolName(name));
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
import {
|
||||
buildAgentMainSessionKey,
|
||||
DEFAULT_AGENT_ID,
|
||||
parseAgentSessionKey,
|
||||
resolveAgentIdFromSessionKey,
|
||||
} from "../../routing/session-key.js";
|
||||
import { applyModelOverrideToSessionEntry } from "../../sessions/model-overrides.js";
|
||||
@@ -36,10 +37,12 @@ import {
|
||||
import type { AnyAgentTool } from "./common.js";
|
||||
import { readStringParam } from "./common.js";
|
||||
import {
|
||||
createSessionVisibilityGuard,
|
||||
shouldResolveSessionIdInput,
|
||||
resolveInternalSessionKey,
|
||||
resolveMainSessionAlias,
|
||||
createAgentToAgentPolicy,
|
||||
resolveEffectiveSessionToolsVisibility,
|
||||
resolveInternalSessionKey,
|
||||
resolveSandboxedSessionToolContext,
|
||||
} from "./sessions-helpers.js";
|
||||
|
||||
const SessionStatusToolSchema = Type.Object({
|
||||
@@ -175,6 +178,7 @@ async function resolveModelOverride(params: {
|
||||
export function createSessionStatusTool(opts?: {
|
||||
agentSessionKey?: string;
|
||||
config?: OpenClawConfig;
|
||||
sandboxed?: boolean;
|
||||
}): AnyAgentTool {
|
||||
return {
|
||||
label: "Session Status",
|
||||
@@ -185,18 +189,70 @@ export function createSessionStatusTool(opts?: {
|
||||
execute: async (_toolCallId, args) => {
|
||||
const params = args as Record<string, unknown>;
|
||||
const cfg = opts?.config ?? loadConfig();
|
||||
const { mainKey, alias } = resolveMainSessionAlias(cfg);
|
||||
const { mainKey, alias, effectiveRequesterKey } = resolveSandboxedSessionToolContext({
|
||||
cfg,
|
||||
agentSessionKey: opts?.agentSessionKey,
|
||||
sandboxed: opts?.sandboxed,
|
||||
});
|
||||
const a2aPolicy = createAgentToAgentPolicy(cfg);
|
||||
const requesterAgentId = resolveAgentIdFromSessionKey(
|
||||
opts?.agentSessionKey ?? effectiveRequesterKey,
|
||||
);
|
||||
const visibilityRequesterKey = effectiveRequesterKey.trim();
|
||||
const usesLegacyMainAlias = alias === mainKey;
|
||||
const isLegacyMainVisibilityKey = (sessionKey: string) => {
|
||||
const trimmed = sessionKey.trim();
|
||||
return usesLegacyMainAlias && (trimmed === "main" || trimmed === mainKey);
|
||||
};
|
||||
const resolveVisibilityMainSessionKey = (sessionAgentId: string) => {
|
||||
const requesterParsed = parseAgentSessionKey(visibilityRequesterKey);
|
||||
if (
|
||||
resolveAgentIdFromSessionKey(visibilityRequesterKey) === sessionAgentId &&
|
||||
(requesterParsed?.rest === mainKey || isLegacyMainVisibilityKey(visibilityRequesterKey))
|
||||
) {
|
||||
return visibilityRequesterKey;
|
||||
}
|
||||
return buildAgentMainSessionKey({
|
||||
agentId: sessionAgentId,
|
||||
mainKey,
|
||||
});
|
||||
};
|
||||
const normalizeVisibilityTargetSessionKey = (sessionKey: string, sessionAgentId: string) => {
|
||||
const trimmed = sessionKey.trim();
|
||||
if (!trimmed) {
|
||||
return trimmed;
|
||||
}
|
||||
if (trimmed.startsWith("agent:")) {
|
||||
const parsed = parseAgentSessionKey(trimmed);
|
||||
if (parsed?.rest === mainKey) {
|
||||
return resolveVisibilityMainSessionKey(sessionAgentId);
|
||||
}
|
||||
return trimmed;
|
||||
}
|
||||
// Preserve legacy bare main keys for requester tree checks.
|
||||
if (isLegacyMainVisibilityKey(trimmed)) {
|
||||
return resolveVisibilityMainSessionKey(sessionAgentId);
|
||||
}
|
||||
return trimmed;
|
||||
};
|
||||
const visibilityGuard =
|
||||
opts?.sandboxed === true
|
||||
? await createSessionVisibilityGuard({
|
||||
action: "status",
|
||||
requesterSessionKey: visibilityRequesterKey,
|
||||
visibility: resolveEffectiveSessionToolsVisibility({
|
||||
cfg,
|
||||
sandboxed: true,
|
||||
}),
|
||||
a2aPolicy,
|
||||
})
|
||||
: null;
|
||||
|
||||
const requestedKeyParam = readStringParam(params, "sessionKey");
|
||||
let requestedKeyRaw = requestedKeyParam ?? opts?.agentSessionKey;
|
||||
if (!requestedKeyRaw?.trim()) {
|
||||
throw new Error("sessionKey required");
|
||||
}
|
||||
|
||||
const requesterAgentId = resolveAgentIdFromSessionKey(
|
||||
opts?.agentSessionKey ?? requestedKeyRaw,
|
||||
);
|
||||
const ensureAgentAccess = (targetAgentId: string) => {
|
||||
if (targetAgentId === requesterAgentId) {
|
||||
return;
|
||||
@@ -213,7 +269,14 @@ export function createSessionStatusTool(opts?: {
|
||||
};
|
||||
|
||||
if (requestedKeyRaw.startsWith("agent:")) {
|
||||
ensureAgentAccess(resolveAgentIdFromSessionKey(requestedKeyRaw));
|
||||
const requestedAgentId = resolveAgentIdFromSessionKey(requestedKeyRaw);
|
||||
ensureAgentAccess(requestedAgentId);
|
||||
const access = visibilityGuard?.check(
|
||||
normalizeVisibilityTargetSessionKey(requestedKeyRaw, requestedAgentId),
|
||||
);
|
||||
if (access && !access.allowed) {
|
||||
throw new Error(access.error);
|
||||
}
|
||||
}
|
||||
|
||||
const isExplicitAgentKey = requestedKeyRaw.startsWith("agent:");
|
||||
@@ -258,6 +321,15 @@ export function createSessionStatusTool(opts?: {
|
||||
throw new Error(`Unknown ${kind}: ${requestedKeyRaw}`);
|
||||
}
|
||||
|
||||
if (visibilityGuard && !requestedKeyRaw.startsWith("agent:")) {
|
||||
const access = visibilityGuard.check(
|
||||
normalizeVisibilityTargetSessionKey(resolved.key, agentId),
|
||||
);
|
||||
if (!access.allowed) {
|
||||
throw new Error(access.error);
|
||||
}
|
||||
}
|
||||
|
||||
const configured = resolveDefaultModelForAgent({ cfg, agentId });
|
||||
const modelRaw = readStringParam(params, "model");
|
||||
let changedModel = false;
|
||||
|
||||
@@ -14,7 +14,7 @@ export type AgentToAgentPolicy = {
|
||||
isAllowed: (requesterAgentId: string, targetAgentId: string) => boolean;
|
||||
};
|
||||
|
||||
export type SessionAccessAction = "history" | "send" | "list";
|
||||
export type SessionAccessAction = "history" | "send" | "list" | "status";
|
||||
|
||||
export type SessionAccessResult =
|
||||
| { allowed: true }
|
||||
@@ -130,6 +130,9 @@ function actionPrefix(action: SessionAccessAction): string {
|
||||
if (action === "send") {
|
||||
return "Session send";
|
||||
}
|
||||
if (action === "status") {
|
||||
return "Session status";
|
||||
}
|
||||
return "Session list";
|
||||
}
|
||||
|
||||
@@ -140,6 +143,9 @@ function a2aDisabledMessage(action: SessionAccessAction): string {
|
||||
if (action === "send") {
|
||||
return "Agent-to-agent messaging is disabled. Set tools.agentToAgent.enabled=true to allow cross-agent sends.";
|
||||
}
|
||||
if (action === "status") {
|
||||
return "Agent-to-agent status is disabled. Set tools.agentToAgent.enabled=true to allow cross-agent access.";
|
||||
}
|
||||
return "Agent-to-agent listing is disabled. Set tools.agentToAgent.enabled=true to allow cross-agent visibility.";
|
||||
}
|
||||
|
||||
@@ -150,6 +156,9 @@ function a2aDeniedMessage(action: SessionAccessAction): string {
|
||||
if (action === "send") {
|
||||
return "Agent-to-agent messaging denied by tools.agentToAgent.allow.";
|
||||
}
|
||||
if (action === "status") {
|
||||
return "Agent-to-agent status denied by tools.agentToAgent.allow.";
|
||||
}
|
||||
return "Agent-to-agent listing denied by tools.agentToAgent.allow.";
|
||||
}
|
||||
|
||||
@@ -160,6 +169,9 @@ function crossVisibilityMessage(action: SessionAccessAction): string {
|
||||
if (action === "send") {
|
||||
return "Session send visibility is restricted. Set tools.sessions.visibility=all to allow cross-agent access.";
|
||||
}
|
||||
if (action === "status") {
|
||||
return "Session status visibility is restricted. Set tools.sessions.visibility=all to allow cross-agent access.";
|
||||
}
|
||||
return "Session list visibility is restricted. Set tools.sessions.visibility=all to allow cross-agent access.";
|
||||
}
|
||||
|
||||
|
||||
@@ -166,9 +166,9 @@ export function extractAssistantText(message: unknown): string | undefined {
|
||||
normalizeText: (text) => text.trim(),
|
||||
}) ?? "";
|
||||
const stopReason = (message as { stopReason?: unknown }).stopReason;
|
||||
const errorMessage = (message as { errorMessage?: unknown }).errorMessage;
|
||||
const errorContext =
|
||||
stopReason === "error" || (typeof errorMessage === "string" && Boolean(errorMessage.trim()));
|
||||
// Gate on stopReason only — a non-error response with a stale/background errorMessage
|
||||
// should not have its content rewritten with error templates (#13935).
|
||||
const errorContext = stopReason === "error";
|
||||
|
||||
return joined ? sanitizeUserFacingText(joined, { errorContext }) : undefined;
|
||||
}
|
||||
|
||||
@@ -199,6 +199,16 @@ describe("extractAssistantText", () => {
|
||||
"Firebase downgraded us to the free Spark plan. Check whether billing should be re-enabled.",
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves successful turns with stale background errorMessage", () => {
|
||||
const message = {
|
||||
role: "assistant",
|
||||
stopReason: "end_turn",
|
||||
errorMessage: "insufficient credits for embedding model",
|
||||
content: [{ type: "text", text: "Handle payment required errors in your API." }],
|
||||
};
|
||||
expect(extractAssistantText(message)).toBe("Handle payment required errors in your API.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("resolveAnnounceTarget", () => {
|
||||
|
||||
@@ -114,7 +114,7 @@ describe("web_fetch Cloudflare Markdown for Agents", () => {
|
||||
sandboxed: false,
|
||||
runtimeFirecrawl: {
|
||||
active: false,
|
||||
apiKeySource: "secretRef",
|
||||
apiKeySource: "secretRef", // pragma: allowlist secret
|
||||
diagnostics: [],
|
||||
},
|
||||
});
|
||||
|
||||
@@ -652,7 +652,7 @@ describe("web_search Perplexity lazy resolution", () => {
|
||||
web: {
|
||||
search: {
|
||||
provider: "gemini",
|
||||
gemini: { apiKey: "gemini-config-test" },
|
||||
gemini: { apiKey: "gemini-config-test" }, // pragma: allowlist secret
|
||||
perplexity: perplexityConfig as { apiKey?: string; baseUrl?: string; model?: string },
|
||||
},
|
||||
},
|
||||
|
||||
@@ -6,8 +6,10 @@ import { getCliSessionId } from "../../agents/cli-session.js";
|
||||
import { runWithModelFallback } from "../../agents/model-fallback.js";
|
||||
import { isCliProvider } from "../../agents/model-selection.js";
|
||||
import {
|
||||
BILLING_ERROR_USER_MESSAGE,
|
||||
isCompactionFailureError,
|
||||
isContextOverflowError,
|
||||
isBillingErrorMessage,
|
||||
isLikelyContextOverflowError,
|
||||
isTransientHttpError,
|
||||
sanitizeUserFacingText,
|
||||
@@ -514,8 +516,9 @@ export async function runAgentTurnWithFallback(params: {
|
||||
break;
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const isContextOverflow = isLikelyContextOverflowError(message);
|
||||
const isCompactionFailure = isCompactionFailureError(message);
|
||||
const isBilling = isBillingErrorMessage(message);
|
||||
const isContextOverflow = !isBilling && isLikelyContextOverflowError(message);
|
||||
const isCompactionFailure = !isBilling && isCompactionFailureError(message);
|
||||
const isSessionCorruption = /function call turn comes immediately after/i.test(message);
|
||||
const isRoleOrderingError = /incorrect role information|roles must alternate/i.test(message);
|
||||
const isTransientHttp = isTransientHttpError(message);
|
||||
@@ -610,11 +613,13 @@ export async function runAgentTurnWithFallback(params: {
|
||||
? sanitizeUserFacingText(message, { errorContext: true })
|
||||
: message;
|
||||
const trimmedMessage = safeMessage.replace(/\.\s*$/, "");
|
||||
const fallbackText = isContextOverflow
|
||||
? "⚠️ Context overflow — prompt too large for this model. Try a shorter message or a larger-context model."
|
||||
: isRoleOrderingError
|
||||
? "⚠️ Message ordering conflict - please try again. If this persists, use /new to start a fresh session."
|
||||
: `⚠️ Agent failed before reply: ${trimmedMessage}.\nLogs: openclaw logs --follow`;
|
||||
const fallbackText = isBilling
|
||||
? BILLING_ERROR_USER_MESSAGE
|
||||
: isContextOverflow
|
||||
? "⚠️ Context overflow — prompt too large for this model. Try a shorter message or a larger-context model."
|
||||
: isRoleOrderingError
|
||||
? "⚠️ Message ordering conflict - please try again. If this persists, use /new to start a fresh session."
|
||||
: `⚠️ Agent failed before reply: ${trimmedMessage}.\nLogs: openclaw logs --follow`;
|
||||
|
||||
return {
|
||||
kind: "final",
|
||||
|
||||
@@ -169,6 +169,50 @@ describe("buildReplyPayloads media filter integration", () => {
|
||||
expect(replyPayloads).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("drops all final payloads when block pipeline streamed successfully", async () => {
|
||||
const pipeline: Parameters<typeof buildReplyPayloads>[0]["blockReplyPipeline"] = {
|
||||
didStream: () => true,
|
||||
isAborted: () => false,
|
||||
hasSentPayload: () => false,
|
||||
enqueue: () => {},
|
||||
flush: async () => {},
|
||||
stop: () => {},
|
||||
hasBuffered: () => false,
|
||||
};
|
||||
// shouldDropFinalPayloads short-circuits to [] when the pipeline streamed
|
||||
// without aborting, so hasSentPayload is never reached.
|
||||
const { replyPayloads } = await buildReplyPayloads({
|
||||
...baseParams,
|
||||
blockStreamingEnabled: true,
|
||||
blockReplyPipeline: pipeline,
|
||||
replyToMode: "all",
|
||||
payloads: [{ text: "response", replyToId: "post-123" }],
|
||||
});
|
||||
|
||||
expect(replyPayloads).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("deduplicates final payloads against directly sent block keys regardless of replyToId", async () => {
|
||||
// When block streaming is not active but directlySentBlockKeys has entries
|
||||
// (e.g. from pre-tool flush), the key should match even if replyToId differs.
|
||||
const { createBlockReplyContentKey } = await import("./block-reply-pipeline.js");
|
||||
const directlySentBlockKeys = new Set<string>();
|
||||
directlySentBlockKeys.add(
|
||||
createBlockReplyContentKey({ text: "response", replyToId: "post-1" }),
|
||||
);
|
||||
|
||||
const { replyPayloads } = await buildReplyPayloads({
|
||||
...baseParams,
|
||||
blockStreamingEnabled: false,
|
||||
blockReplyPipeline: null,
|
||||
directlySentBlockKeys,
|
||||
replyToMode: "off",
|
||||
payloads: [{ text: "response" }],
|
||||
});
|
||||
|
||||
expect(replyPayloads).toHaveLength(0);
|
||||
});
|
||||
|
||||
it("does not suppress same-target replies when accountId differs", async () => {
|
||||
const { replyPayloads } = await buildReplyPayloads({
|
||||
...baseParams,
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { OriginatingChannelType } from "../templating.js";
|
||||
import { SILENT_REPLY_TOKEN } from "../tokens.js";
|
||||
import type { ReplyPayload } from "../types.js";
|
||||
import { formatBunFetchSocketError, isBunFetchSocketError } from "./agent-runner-utils.js";
|
||||
import { createBlockReplyPayloadKey, type BlockReplyPipeline } from "./block-reply-pipeline.js";
|
||||
import { createBlockReplyContentKey, type BlockReplyPipeline } from "./block-reply-pipeline.js";
|
||||
import {
|
||||
resolveOriginAccountId,
|
||||
resolveOriginMessageProvider,
|
||||
@@ -213,7 +213,7 @@ export async function buildReplyPayloads(params: {
|
||||
)
|
||||
: params.directlySentBlockKeys?.size
|
||||
? mediaFilteredPayloads.filter(
|
||||
(payload) => !params.directlySentBlockKeys!.has(createBlockReplyPayloadKey(payload)),
|
||||
(payload) => !params.directlySentBlockKeys!.has(createBlockReplyContentKey(payload)),
|
||||
)
|
||||
: mediaFilteredPayloads;
|
||||
const replyPayloads = suppressMessagingToolReplies ? [] : filteredPayloads;
|
||||
|
||||
@@ -1628,3 +1628,72 @@ describe("runReplyAgent transient HTTP retry", () => {
|
||||
expect(payload?.text).toContain("Recovered response");
|
||||
});
|
||||
});
|
||||
|
||||
describe("runReplyAgent billing error classification", () => {
|
||||
// Regression guard for the runner-level catch block in runAgentTurnWithFallback.
|
||||
// Billing errors from providers like OpenRouter can contain token/size wording that
|
||||
// matches context overflow heuristics. This test verifies the final user-visible
|
||||
// message is the billing-specific one, not the "Context overflow" fallback.
|
||||
it("returns billing message for mixed-signal error (billing text + overflow patterns)", async () => {
|
||||
runEmbeddedPiAgentMock.mockRejectedValueOnce(
|
||||
new Error("402 Payment Required: request token limit exceeded for this billing plan"),
|
||||
);
|
||||
|
||||
const typing = createMockTypingController();
|
||||
const sessionCtx = {
|
||||
Provider: "telegram",
|
||||
MessageSid: "msg",
|
||||
} as unknown as TemplateContext;
|
||||
const resolvedQueue = { mode: "interrupt" } as unknown as QueueSettings;
|
||||
const followupRun = {
|
||||
prompt: "hello",
|
||||
summaryLine: "hello",
|
||||
enqueuedAt: Date.now(),
|
||||
run: {
|
||||
sessionId: "session",
|
||||
sessionKey: "main",
|
||||
messageProvider: "telegram",
|
||||
sessionFile: "/tmp/session.jsonl",
|
||||
workspaceDir: "/tmp",
|
||||
config: {},
|
||||
skillsSnapshot: {},
|
||||
provider: "anthropic",
|
||||
model: "claude",
|
||||
thinkLevel: "low",
|
||||
verboseLevel: "off",
|
||||
elevatedLevel: "off",
|
||||
bashElevated: {
|
||||
enabled: false,
|
||||
allowed: false,
|
||||
defaultLevel: "off",
|
||||
},
|
||||
timeoutMs: 1_000,
|
||||
blockReplyBreak: "message_end",
|
||||
},
|
||||
} as unknown as FollowupRun;
|
||||
|
||||
const result = await runReplyAgent({
|
||||
commandBody: "hello",
|
||||
followupRun,
|
||||
queueKey: "main",
|
||||
resolvedQueue,
|
||||
shouldSteer: false,
|
||||
shouldFollowup: false,
|
||||
isActive: false,
|
||||
isStreaming: false,
|
||||
typing,
|
||||
sessionCtx,
|
||||
defaultModel: "anthropic/claude",
|
||||
resolvedVerboseLevel: "off",
|
||||
isNewSession: false,
|
||||
blockStreamingEnabled: false,
|
||||
resolvedBlockStreamingBreak: "message_end",
|
||||
shouldInjectGroupIntro: false,
|
||||
typingMode: "instant",
|
||||
});
|
||||
|
||||
const payload = Array.isArray(result) ? result[0] : result;
|
||||
expect(payload?.text).toContain("billing error");
|
||||
expect(payload?.text).not.toContain("Context overflow");
|
||||
});
|
||||
});
|
||||
|
||||
79
src/auto-reply/reply/block-reply-pipeline.test.ts
Normal file
79
src/auto-reply/reply/block-reply-pipeline.test.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
createBlockReplyContentKey,
|
||||
createBlockReplyPayloadKey,
|
||||
createBlockReplyPipeline,
|
||||
} from "./block-reply-pipeline.js";
|
||||
|
||||
describe("createBlockReplyPayloadKey", () => {
|
||||
it("produces different keys for payloads differing only by replyToId", () => {
|
||||
const a = createBlockReplyPayloadKey({ text: "hello world", replyToId: "post-1" });
|
||||
const b = createBlockReplyPayloadKey({ text: "hello world", replyToId: "post-2" });
|
||||
const c = createBlockReplyPayloadKey({ text: "hello world" });
|
||||
expect(a).not.toBe(b);
|
||||
expect(a).not.toBe(c);
|
||||
});
|
||||
|
||||
it("produces different keys for payloads with different text", () => {
|
||||
const a = createBlockReplyPayloadKey({ text: "hello" });
|
||||
const b = createBlockReplyPayloadKey({ text: "world" });
|
||||
expect(a).not.toBe(b);
|
||||
});
|
||||
|
||||
it("produces different keys for payloads with different media", () => {
|
||||
const a = createBlockReplyPayloadKey({ text: "hello", mediaUrl: "file:///a.png" });
|
||||
const b = createBlockReplyPayloadKey({ text: "hello", mediaUrl: "file:///b.png" });
|
||||
expect(a).not.toBe(b);
|
||||
});
|
||||
|
||||
it("trims whitespace from text for key comparison", () => {
|
||||
const a = createBlockReplyPayloadKey({ text: " hello " });
|
||||
const b = createBlockReplyPayloadKey({ text: "hello" });
|
||||
expect(a).toBe(b);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createBlockReplyContentKey", () => {
|
||||
it("produces the same key for payloads differing only by replyToId", () => {
|
||||
const a = createBlockReplyContentKey({ text: "hello world", replyToId: "post-1" });
|
||||
const b = createBlockReplyContentKey({ text: "hello world", replyToId: "post-2" });
|
||||
const c = createBlockReplyContentKey({ text: "hello world" });
|
||||
expect(a).toBe(b);
|
||||
expect(a).toBe(c);
|
||||
});
|
||||
});
|
||||
|
||||
describe("createBlockReplyPipeline dedup with threading", () => {
|
||||
it("keeps separate deliveries for same text with different replyToId", async () => {
|
||||
const sent: Array<{ text?: string; replyToId?: string }> = [];
|
||||
const pipeline = createBlockReplyPipeline({
|
||||
onBlockReply: async (payload) => {
|
||||
sent.push({ text: payload.text, replyToId: payload.replyToId });
|
||||
},
|
||||
timeoutMs: 5000,
|
||||
});
|
||||
|
||||
pipeline.enqueue({ text: "response text", replyToId: "thread-root-1" });
|
||||
pipeline.enqueue({ text: "response text", replyToId: undefined });
|
||||
await pipeline.flush();
|
||||
|
||||
expect(sent).toEqual([
|
||||
{ text: "response text", replyToId: "thread-root-1" },
|
||||
{ text: "response text", replyToId: undefined },
|
||||
]);
|
||||
});
|
||||
|
||||
it("hasSentPayload matches regardless of replyToId", async () => {
|
||||
const pipeline = createBlockReplyPipeline({
|
||||
onBlockReply: async () => {},
|
||||
timeoutMs: 5000,
|
||||
});
|
||||
|
||||
pipeline.enqueue({ text: "response text", replyToId: "thread-root-1" });
|
||||
await pipeline.flush();
|
||||
|
||||
// Final payload with no replyToId should be recognized as already sent
|
||||
expect(pipeline.hasSentPayload({ text: "response text" })).toBe(true);
|
||||
expect(pipeline.hasSentPayload({ text: "response text", replyToId: "other-id" })).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -48,6 +48,19 @@ export function createBlockReplyPayloadKey(payload: ReplyPayload): string {
|
||||
});
|
||||
}
|
||||
|
||||
export function createBlockReplyContentKey(payload: ReplyPayload): string {
|
||||
const text = payload.text?.trim() ?? "";
|
||||
const mediaList = payload.mediaUrls?.length
|
||||
? payload.mediaUrls
|
||||
: payload.mediaUrl
|
||||
? [payload.mediaUrl]
|
||||
: [];
|
||||
// Content-only key used for final-payload suppression after block streaming.
|
||||
// This intentionally ignores replyToId so a streamed threaded payload and the
|
||||
// later final payload still collapse when they carry the same content.
|
||||
return JSON.stringify({ text, mediaList });
|
||||
}
|
||||
|
||||
const withTimeout = async <T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
@@ -80,6 +93,7 @@ export function createBlockReplyPipeline(params: {
|
||||
}): BlockReplyPipeline {
|
||||
const { onBlockReply, timeoutMs, coalescing, buffer } = params;
|
||||
const sentKeys = new Set<string>();
|
||||
const sentContentKeys = new Set<string>();
|
||||
const pendingKeys = new Set<string>();
|
||||
const seenKeys = new Set<string>();
|
||||
const bufferedKeys = new Set<string>();
|
||||
@@ -95,6 +109,7 @@ export function createBlockReplyPipeline(params: {
|
||||
return;
|
||||
}
|
||||
const payloadKey = createBlockReplyPayloadKey(payload);
|
||||
const contentKey = createBlockReplyContentKey(payload);
|
||||
if (!bypassSeenCheck) {
|
||||
if (seenKeys.has(payloadKey)) {
|
||||
return;
|
||||
@@ -130,6 +145,7 @@ export function createBlockReplyPipeline(params: {
|
||||
return;
|
||||
}
|
||||
sentKeys.add(payloadKey);
|
||||
sentContentKeys.add(contentKey);
|
||||
didStream = true;
|
||||
})
|
||||
.catch((err) => {
|
||||
@@ -238,8 +254,8 @@ export function createBlockReplyPipeline(params: {
|
||||
didStream: () => didStream,
|
||||
isAborted: () => aborted,
|
||||
hasSentPayload: (payload) => {
|
||||
const payloadKey = createBlockReplyPayloadKey(payload);
|
||||
return sentKeys.has(payloadKey);
|
||||
const payloadKey = createBlockReplyContentKey(payload);
|
||||
return sentContentKeys.has(payloadKey);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ const MODEL_PICK_PROVIDER_PREFERENCE = [
|
||||
"zai",
|
||||
"openrouter",
|
||||
"opencode",
|
||||
"opencode-go",
|
||||
"github-copilot",
|
||||
"groq",
|
||||
"cerebras",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user