mirror of
https://github.com/openclaw/openclaw.git
synced 2026-04-10 16:51:13 +00:00
fix(agents): close remaining prompt cache boundary gaps (#60691)
* fix(agents): route default stream fallbacks through boundary shapers
* fix(agents): close remaining cache boundary gaps
* chore(changelog): note cache prefix follow-up rollout
* fix(agents): preserve cache-safe fallback stream bases
This commit is contained in:
@@ -27,6 +27,7 @@ Docs: https://docs.openclaw.ai
|
||||
- Memory/dreaming (experimental): add opt-in weighted short-term recall promotion to `MEMORY.md`, managed dreaming modes (`off|core|rem|deep`), and a `/dreaming` command plus Dreams UI so durable memory promotion can run on background cadence without manual scheduling. (#60569) Thanks @vignesh07.
|
||||
- Agents/system prompts: add an internal cache-prefix boundary across Anthropic-family, OpenAI-family, Google, and CLI transport shaping so stable system-prompt prefixes stay reusable without leaking internal cache markers to provider payloads. (#59054) Thanks @coletebou and @vincentkoc.
|
||||
- Docs/memory: add a dedicated Dreaming concept page, expand Memory overview with the Dreaming model, and link Dreaming from further reading to document the experimental opt-in consolidation workflow. Thanks @vignesh07.
|
||||
- Agents/cache prefixes: route compaction, OpenAI WebSocket HTTP fallback, and later-turn embedded session reuse through the same cache-safe prompt shaping path so Anthropic-family and OpenAI-family requests keep stable prompt bytes across follow-up turns and fallback transport changes. (#60691) Thanks @vincentkoc.
|
||||
|
||||
### Fixes
|
||||
|
||||
|
||||
@@ -219,6 +219,7 @@ const mockStreamSimple = vi.fn((model: unknown, context: unknown, options?: unkn
|
||||
});
|
||||
return stream;
|
||||
});
|
||||
const mockCreateHttpFallbackStreamFn = vi.fn(() => mockStreamSimple as never);
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Helpers
|
||||
@@ -1180,8 +1181,11 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
beforeEach(() => {
|
||||
MockManager.reset();
|
||||
streamSimpleCalls.length = 0;
|
||||
mockCreateHttpFallbackStreamFn.mockReset();
|
||||
mockCreateHttpFallbackStreamFn.mockReturnValue(mockStreamSimple as never);
|
||||
openAIWsStreamTesting.setDepsForTest({
|
||||
createManager: ((options?: unknown) => new MockManager(options)) as never,
|
||||
createHttpFallbackStreamFn: mockCreateHttpFallbackStreamFn as never,
|
||||
streamSimple: mockStreamSimple,
|
||||
});
|
||||
});
|
||||
@@ -1195,6 +1199,7 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
releaseWsSession("sess-1");
|
||||
releaseWsSession("sess-2");
|
||||
releaseWsSession("sess-fallback");
|
||||
releaseWsSession("sess-boundary-http-fallback");
|
||||
releaseWsSession("sess-incremental");
|
||||
releaseWsSession("sess-full");
|
||||
releaseWsSession("sess-phase");
|
||||
@@ -1931,6 +1936,64 @@ describe("createOpenAIWebSocketStreamFn", () => {
|
||||
expect(streamSimpleCalls.length).toBeGreaterThan(callsBefore);
|
||||
});
|
||||
|
||||
it("routes websocket HTTP fallback through the configured HTTP fallback builder", async () => {
|
||||
const httpFallbackCalls: Array<{ model: unknown; context: unknown; options?: unknown }> = [];
|
||||
const httpFallbackStreamFn = vi.fn((model: unknown, context: unknown, options?: unknown) => {
|
||||
httpFallbackCalls.push({ model, context, options });
|
||||
const stream = createAssistantMessageEventStream();
|
||||
queueMicrotask(() => {
|
||||
const msg = makeFakeAssistantMessage("boundary-safe fallback");
|
||||
stream.push({ type: "done", reason: "stop", message: msg });
|
||||
stream.end();
|
||||
});
|
||||
return stream;
|
||||
});
|
||||
mockCreateHttpFallbackStreamFn.mockReturnValue(httpFallbackStreamFn as never);
|
||||
const sessionId = "sess-boundary-http-fallback";
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId);
|
||||
|
||||
const stream1 = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
contextStub as Parameters<typeof streamFn>[1],
|
||||
);
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
queueMicrotask(async () => {
|
||||
try {
|
||||
await new Promise((r) => setImmediate(r));
|
||||
MockManager.lastInstance!.simulateEvent({
|
||||
type: "response.completed",
|
||||
response: makeResponseObject("resp-ok", "OK"),
|
||||
});
|
||||
for await (const _ of await resolveStream(stream1)) {
|
||||
/* consume */
|
||||
}
|
||||
resolve();
|
||||
} catch (e) {
|
||||
reject(e);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
MockManager.globalSendFailuresRemaining = 2;
|
||||
const stream2 = streamFn(
|
||||
modelStub as Parameters<typeof streamFn>[0],
|
||||
{
|
||||
...contextStub,
|
||||
systemPrompt: `Stable prefix${SYSTEM_PROMPT_CACHE_BOUNDARY}Dynamic suffix`,
|
||||
} as Parameters<typeof streamFn>[1],
|
||||
);
|
||||
for await (const _ of await resolveStream(stream2)) {
|
||||
/* consume */
|
||||
}
|
||||
|
||||
expect(mockCreateHttpFallbackStreamFn).toHaveBeenCalled();
|
||||
expect(streamSimpleCalls).toHaveLength(0);
|
||||
expect(httpFallbackCalls).toHaveLength(1);
|
||||
expect(httpFallbackCalls[0]?.context).toMatchObject({
|
||||
systemPrompt: `Stable prefix${SYSTEM_PROMPT_CACHE_BOUNDARY}Dynamic suffix`,
|
||||
});
|
||||
});
|
||||
|
||||
it("forwards temperature and maxTokens to response.create", async () => {
|
||||
const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-temp");
|
||||
const opts = { temperature: 0.3, maxTokens: 256 };
|
||||
@@ -2288,6 +2351,7 @@ describe("releaseWsSession / hasWsSession", () => {
|
||||
MockManager.reset();
|
||||
openAIWsStreamTesting.setDepsForTest({
|
||||
createManager: (() => new MockManager()) as never,
|
||||
createHttpFallbackStreamFn: mockCreateHttpFallbackStreamFn as never,
|
||||
streamSimple: mockStreamSimple,
|
||||
});
|
||||
});
|
||||
|
||||
@@ -48,6 +48,7 @@ import {
|
||||
planTurnInput,
|
||||
} from "./openai-ws-message-conversion.js";
|
||||
import { buildOpenAIWebSocketResponseCreatePayload } from "./openai-ws-request.js";
|
||||
import { createBoundaryAwareStreamFnForModel } from "./provider-transport-stream.js";
|
||||
import { log } from "./pi-embedded-runner/logger.js";
|
||||
import {
|
||||
buildAssistantMessageWithZeroUsage,
|
||||
@@ -80,11 +81,13 @@ const wsRegistry = new Map<string, WsSession>();
|
||||
|
||||
type OpenAIWsStreamDeps = {
|
||||
createManager: (options?: OpenAIWebSocketManagerOptions) => OpenAIWebSocketManager;
|
||||
createHttpFallbackStreamFn: (model: ProviderRuntimeModel) => StreamFn | undefined;
|
||||
streamSimple: typeof piAi.streamSimple;
|
||||
};
|
||||
|
||||
const defaultOpenAIWsStreamDeps: OpenAIWsStreamDeps = {
|
||||
createManager: (options) => new OpenAIWebSocketManager(options),
|
||||
createHttpFallbackStreamFn: (model) => createBoundaryAwareStreamFnForModel(model),
|
||||
streamSimple: (...args) => piAi.streamSimple(...args),
|
||||
};
|
||||
|
||||
@@ -916,7 +919,7 @@ export function createOpenAIWebSocketStreamFn(
|
||||
}
|
||||
|
||||
/**
|
||||
* Fall back to HTTP (`streamSimple`) and pipe events into the existing stream.
|
||||
* Fall back to HTTP and pipe events into the existing stream.
|
||||
* This is called when the WebSocket is broken or unavailable.
|
||||
*/
|
||||
async function fallbackToHttp(
|
||||
@@ -957,7 +960,10 @@ async function fallbackToHttp(
|
||||
: {}),
|
||||
...(signal ? { signal } : {}),
|
||||
};
|
||||
const httpStream = openAIWsStreamDeps.streamSimple(model, context, mergedOptions);
|
||||
const httpStreamFn =
|
||||
openAIWsStreamDeps.createHttpFallbackStreamFn(model as ProviderRuntimeModel) ??
|
||||
openAIWsStreamDeps.streamSimple;
|
||||
const httpStream = httpStreamFn(model, context, mergedOptions);
|
||||
for await (const event of httpStream) {
|
||||
if (fallbackOptions?.suppressStart && event.type === "start") {
|
||||
continue;
|
||||
|
||||
@@ -81,6 +81,9 @@ export const sessionMessages: unknown[] = [
|
||||
];
|
||||
export const sessionAbortCompactionMock: Mock<(reason?: unknown) => void> = vi.fn();
|
||||
export const createOpenClawCodingToolsMock = vi.fn(() => []);
|
||||
export const resolveEmbeddedAgentStreamFnMock = vi.fn((_params?: unknown) => vi.fn());
|
||||
export const applyExtraParamsToAgentMock = vi.fn(() => ({ effectiveExtraParams: {} }));
|
||||
export const resolveAgentTransportOverrideMock = vi.fn(() => undefined);
|
||||
|
||||
export function resetCompactSessionStateMocks(): void {
|
||||
sanitizeSessionHistoryMock.mockReset();
|
||||
@@ -122,6 +125,12 @@ export function resetCompactSessionStateMocks(): void {
|
||||
},
|
||||
);
|
||||
sessionAbortCompactionMock.mockReset();
|
||||
resolveEmbeddedAgentStreamFnMock.mockReset();
|
||||
resolveEmbeddedAgentStreamFnMock.mockImplementation((_params?: unknown) => vi.fn());
|
||||
applyExtraParamsToAgentMock.mockReset();
|
||||
applyExtraParamsToAgentMock.mockReturnValue({ effectiveExtraParams: {} });
|
||||
resolveAgentTransportOverrideMock.mockReset();
|
||||
resolveAgentTransportOverrideMock.mockReturnValue(undefined);
|
||||
}
|
||||
|
||||
export function resetCompactHooksHarnessMocks(): void {
|
||||
@@ -223,6 +232,8 @@ export async function loadCompactHooksHarness(): Promise<{
|
||||
session.messages = [...(messages as typeof session.messages)];
|
||||
}),
|
||||
streamFn: vi.fn(),
|
||||
setTransport: vi.fn(),
|
||||
transport: "sse",
|
||||
},
|
||||
compact: vi.fn(async () => {
|
||||
session.messages.splice(1);
|
||||
@@ -340,6 +351,15 @@ export async function loadCompactHooksHarness(): Promise<{
|
||||
normalizeProviderToolSchemas: vi.fn(({ tools }: { tools: unknown[] }) => tools),
|
||||
}));
|
||||
|
||||
vi.doMock("./stream-resolution.js", () => ({
|
||||
resolveEmbeddedAgentStreamFn: resolveEmbeddedAgentStreamFnMock,
|
||||
}));
|
||||
|
||||
vi.doMock("./extra-params.js", () => ({
|
||||
applyExtraParamsToAgent: applyExtraParamsToAgentMock,
|
||||
resolveAgentTransportOverride: resolveAgentTransportOverrideMock,
|
||||
}));
|
||||
|
||||
vi.doMock("./tool-split.js", () => ({
|
||||
splitSdkTools: vi.fn(() => ({ builtInTools: [], customTools: [] })),
|
||||
}));
|
||||
|
||||
@@ -3,13 +3,16 @@ import { getApiProvider, unregisterApiProviders } from "@mariozechner/pi-ai";
|
||||
import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { getCustomApiRegistrySourceId } from "../custom-api-registry.js";
|
||||
import {
|
||||
applyExtraParamsToAgentMock,
|
||||
contextEngineCompactMock,
|
||||
ensureRuntimePluginsLoaded,
|
||||
estimateTokensMock,
|
||||
getMemorySearchManagerMock,
|
||||
hookRunner,
|
||||
loadCompactHooksHarness,
|
||||
resolveAgentTransportOverrideMock,
|
||||
resolveContextEngineMock,
|
||||
resolveEmbeddedAgentStreamFnMock,
|
||||
resolveMemorySearchConfigMock,
|
||||
resolveModelMock,
|
||||
resolveSessionAgentIdMock,
|
||||
@@ -209,6 +212,52 @@ describe("compactEmbeddedPiSessionDirect hooks", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("routes compaction through shared stream resolution and extra params", async () => {
|
||||
const resolvedStreamFn = vi.fn();
|
||||
resolveEmbeddedAgentStreamFnMock.mockReturnValue(resolvedStreamFn);
|
||||
applyExtraParamsToAgentMock.mockReturnValue({
|
||||
effectiveExtraParams: { transport: "websocket" },
|
||||
});
|
||||
resolveAgentTransportOverrideMock.mockReturnValue("websocket");
|
||||
|
||||
await compactEmbeddedPiSessionDirect({
|
||||
sessionId: "session-1",
|
||||
sessionFile: "/tmp/session.jsonl",
|
||||
workspaceDir: "/tmp/workspace",
|
||||
provider: "openai",
|
||||
model: "gpt-5.4",
|
||||
});
|
||||
|
||||
expect(resolveEmbeddedAgentStreamFnMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
currentStreamFn: expect.any(Function),
|
||||
sessionId: "session-1",
|
||||
}),
|
||||
);
|
||||
expect(applyExtraParamsToAgentMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
streamFn: resolvedStreamFn,
|
||||
}),
|
||||
undefined,
|
||||
"openai",
|
||||
"gpt-5.4",
|
||||
undefined,
|
||||
undefined,
|
||||
"main",
|
||||
"/tmp/workspace",
|
||||
expect.objectContaining({
|
||||
provider: "openai",
|
||||
id: "fake",
|
||||
}),
|
||||
"/tmp",
|
||||
);
|
||||
expect(resolveAgentTransportOverrideMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
effectiveExtraParams: { transport: "websocket" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("emits internal + plugin compaction hooks with counts", async () => {
|
||||
hookRunner.hasHooks.mockReturnValue(true);
|
||||
await runCompactionHooks({
|
||||
|
||||
@@ -80,6 +80,7 @@ import {
|
||||
type SkillSnapshot,
|
||||
} from "../skills.js";
|
||||
import { resolveTranscriptPolicy } from "../transcript-policy.js";
|
||||
import { applyExtraParamsToAgent, resolveAgentTransportOverride } from "./extra-params.js";
|
||||
import { classifyCompactionReason, resolveCompactionFailureReason } from "./compact-reasons.js";
|
||||
import {
|
||||
asCompactionHookRunner,
|
||||
@@ -108,6 +109,10 @@ import { buildModelAliasLines, resolveModelAsync } from "./model.js";
|
||||
import { sanitizeSessionHistory, validateReplayTurns } from "./replay-history.js";
|
||||
import { buildEmbeddedSandboxInfo } from "./sandbox-info.js";
|
||||
import { prewarmSessionFile, trackSessionManagerAccess } from "./session-manager-cache.js";
|
||||
import {
|
||||
resolveEmbeddedAgentBaseStreamFn,
|
||||
resolveEmbeddedAgentStreamFn,
|
||||
} from "./stream-resolution.js";
|
||||
import { truncateSessionAfterCompaction } from "./session-truncation.js";
|
||||
import { resolveEmbeddedRunSkillEntries } from "./skills-runtime.js";
|
||||
import {
|
||||
@@ -124,6 +129,7 @@ import { splitSdkTools } from "./tool-split.js";
|
||||
import type { EmbeddedPiCompactResult } from "./types.js";
|
||||
import { describeUnknownError, mapThinkingLevel } from "./utils.js";
|
||||
import { flushPendingToolResultsAfterIdle } from "./wait-for-idle-before-flush.js";
|
||||
import { shouldUseOpenAIWebSocketTransport } from "./run/attempt.thread-helpers.js";
|
||||
|
||||
export type CompactEmbeddedPiSessionParams = {
|
||||
sessionId: string;
|
||||
@@ -746,13 +752,61 @@ export async function compactEmbeddedPiSessionDirect(
|
||||
});
|
||||
applySystemPromptOverrideToSession(session, systemPromptOverride());
|
||||
const providerStreamFn = registerProviderStreamForModel({
|
||||
model,
|
||||
model: effectiveModel,
|
||||
cfg: params.config,
|
||||
agentDir,
|
||||
workspaceDir: effectiveWorkspace,
|
||||
});
|
||||
if (providerStreamFn) {
|
||||
session.agent.streamFn = providerStreamFn;
|
||||
const shouldUseWebSocketTransport = shouldUseOpenAIWebSocketTransport({
|
||||
provider,
|
||||
modelApi: effectiveModel.api,
|
||||
});
|
||||
const wsApiKey = shouldUseWebSocketTransport
|
||||
? await authStorage.getApiKey(provider)
|
||||
: undefined;
|
||||
if (shouldUseWebSocketTransport && !wsApiKey) {
|
||||
log.warn(
|
||||
`[ws-stream] no API key for provider=${provider}; keeping compaction HTTP transport`,
|
||||
);
|
||||
}
|
||||
// Compaction builds the same embedded system prompt, so it must flow
|
||||
// through the same transport/payload shaping stack as normal turns.
|
||||
session.agent.streamFn = resolveEmbeddedAgentStreamFn({
|
||||
currentStreamFn: resolveEmbeddedAgentBaseStreamFn({ session }),
|
||||
providerStreamFn,
|
||||
shouldUseWebSocketTransport,
|
||||
wsApiKey,
|
||||
sessionId: params.sessionId,
|
||||
signal: runAbortController.signal,
|
||||
model: effectiveModel,
|
||||
authStorage,
|
||||
});
|
||||
const { effectiveExtraParams } = applyExtraParamsToAgent(
|
||||
session.agent,
|
||||
params.config,
|
||||
provider,
|
||||
modelId,
|
||||
undefined,
|
||||
params.thinkLevel,
|
||||
sessionAgentId,
|
||||
effectiveWorkspace,
|
||||
effectiveModel,
|
||||
agentDir,
|
||||
);
|
||||
const agentTransportOverride = resolveAgentTransportOverride({
|
||||
settingsManager,
|
||||
effectiveExtraParams,
|
||||
});
|
||||
if (
|
||||
agentTransportOverride &&
|
||||
typeof (session.agent as { setTransport?: unknown }).setTransport === "function" &&
|
||||
(session.agent as { transport?: unknown }).transport !== agentTransportOverride
|
||||
) {
|
||||
(
|
||||
session.agent as {
|
||||
setTransport(nextTransport: string): void;
|
||||
}
|
||||
).setTransport(agentTransportOverride);
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -15,6 +15,8 @@ import {
|
||||
composeSystemPromptWithHookContext,
|
||||
decodeHtmlEntitiesInObject,
|
||||
prependSystemPromptAddition,
|
||||
resetEmbeddedAgentBaseStreamFnCacheForTest,
|
||||
resolveEmbeddedAgentBaseStreamFn,
|
||||
resolveAttemptFsWorkspaceOnly,
|
||||
resolveEmbeddedAgentStreamFn,
|
||||
resolvePromptBuildHookResult,
|
||||
@@ -240,6 +242,21 @@ describe("shouldWarnOnOrphanedUserRepair", () => {
|
||||
});
|
||||
|
||||
describe("resolveEmbeddedAgentStreamFn", () => {
|
||||
it("reuses the session's original base stream across later wrapper mutations", () => {
|
||||
resetEmbeddedAgentBaseStreamFnCacheForTest();
|
||||
const baseStreamFn = vi.fn();
|
||||
const wrapperStreamFn = vi.fn();
|
||||
const session = {
|
||||
agent: {
|
||||
streamFn: baseStreamFn,
|
||||
},
|
||||
};
|
||||
|
||||
expect(resolveEmbeddedAgentBaseStreamFn({ session })).toBe(baseStreamFn);
|
||||
session.agent.streamFn = wrapperStreamFn;
|
||||
expect(resolveEmbeddedAgentBaseStreamFn({ session })).toBe(baseStreamFn);
|
||||
});
|
||||
|
||||
it("injects authStorage api keys into provider-owned stream functions", async () => {
|
||||
const providerStreamFn = vi.fn(async (_model, _context, options) => options);
|
||||
const streamFn = resolveEmbeddedAgentStreamFn({
|
||||
@@ -292,7 +309,6 @@ describe("resolveEmbeddedAgentStreamFn", () => {
|
||||
});
|
||||
expect(providerStreamFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("routes supported default streamSimple fallbacks through boundary-aware transports", () => {
|
||||
const streamFn = resolveEmbeddedAgentStreamFn({
|
||||
currentStreamFn: undefined,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import {
|
||||
createAgentSession,
|
||||
DefaultResourceLoader,
|
||||
@@ -31,7 +30,6 @@ import { isReasoningTagProvider } from "../../../utils/provider-utils.js";
|
||||
import { resolveOpenClawAgentDir } from "../../agent-paths.js";
|
||||
import { resolveSessionAgentIds } from "../../agent-scope.js";
|
||||
import { createAnthropicPayloadLogger } from "../../anthropic-payload-log.js";
|
||||
import { createAnthropicVertexStreamFnForModel } from "../../anthropic-vertex-stream.js";
|
||||
import {
|
||||
analyzeBootstrapBudget,
|
||||
buildBootstrapPromptWarning,
|
||||
@@ -55,7 +53,7 @@ import { buildModelAliasLines } from "../../model-alias-lines.js";
|
||||
import { resolveModelAuthMode } from "../../model-auth.js";
|
||||
import { resolveDefaultModelForAgent } from "../../model-selection.js";
|
||||
import { supportsModelTools } from "../../model-tool-support.js";
|
||||
import { createOpenAIWebSocketStreamFn, releaseWsSession } from "../../openai-ws-stream.js";
|
||||
import { releaseWsSession } from "../../openai-ws-stream.js";
|
||||
import { resolveOwnerDisplaySetting } from "../../owner-display.js";
|
||||
import { createBundleLspToolRuntime } from "../../pi-bundle-lsp-runtime.js";
|
||||
import {
|
||||
@@ -75,7 +73,6 @@ import { applyPiAutoCompactionGuard } from "../../pi-settings.js";
|
||||
import { toClientToolDefinitions } from "../../pi-tool-definition-adapter.js";
|
||||
import { createOpenClawCodingTools, resolveToolLoopDetectionConfig } from "../../pi-tools.js";
|
||||
import { registerProviderStreamForModel } from "../../provider-stream.js";
|
||||
import { createBoundaryAwareStreamFnForModel } from "../../provider-transport-stream.js";
|
||||
import { resolveSandboxContext } from "../../sandbox.js";
|
||||
import { resolveSandboxRuntimeStatus } from "../../sandbox/runtime-status.js";
|
||||
import { repairSessionFileIfNeeded } from "../../session-file-repair.js";
|
||||
@@ -91,7 +88,6 @@ import {
|
||||
applySkillEnvOverridesFromSnapshot,
|
||||
resolveSkillsPromptForRun,
|
||||
} from "../../skills.js";
|
||||
import { stripSystemPromptCacheBoundary } from "../../system-prompt-cache-boundary.js";
|
||||
import { buildSystemPromptParams } from "../../system-prompt-params.js";
|
||||
import { buildSystemPromptReport } from "../../system-prompt-report.js";
|
||||
import { sanitizeToolCallIdsForCloudCodeAssist } from "../../tool-call-id.js";
|
||||
@@ -122,6 +118,11 @@ import {
|
||||
buildEmbeddedSystemPrompt,
|
||||
createSystemPromptOverride,
|
||||
} from "../system-prompt.js";
|
||||
import {
|
||||
resetEmbeddedAgentBaseStreamFnCacheForTest,
|
||||
resolveEmbeddedAgentBaseStreamFn,
|
||||
resolveEmbeddedAgentStreamFn,
|
||||
} from "../stream-resolution.js";
|
||||
import { dropThinkingBlocks } from "../thinking.js";
|
||||
import { collectAllowedToolNames } from "../tool-name-allowlist.js";
|
||||
import { installToolResultContextGuard } from "../tool-result-context-guard.js";
|
||||
@@ -216,67 +217,14 @@ export {
|
||||
wrapStreamFnSanitizeMalformedToolCalls,
|
||||
wrapStreamFnTrimToolCallNames,
|
||||
} from "./attempt.tool-call-normalization.js";
|
||||
export {
|
||||
resetEmbeddedAgentBaseStreamFnCacheForTest,
|
||||
resolveEmbeddedAgentBaseStreamFn,
|
||||
resolveEmbeddedAgentStreamFn,
|
||||
};
|
||||
|
||||
const MAX_BTW_SNAPSHOT_MESSAGES = 100;
|
||||
|
||||
/**
 * Picks the stream function an embedded agent should use.
 *
 * Resolution order (first match wins):
 * 1. `providerStreamFn`, wrapped so the context's system prompt is passed
 *    through `stripSystemPromptCacheBoundary` and, when `authStorage` is
 *    supplied, `options.apiKey` is filled from
 *    `authStorage.getApiKey(model.provider)` (falling back to the caller's
 *    `options.apiKey` when storage returns nothing).
 * 2. When WebSocket transport is requested: the OpenAI WebSocket stream
 *    function if `wsApiKey` is set, otherwise the current stream function
 *    (or `streamSimple` when none is set).
 * 3. The Anthropic Vertex stream function for provider `anthropic-vertex`.
 * 4. Whatever `createBoundaryAwareStreamFnForModel` returns, but only when the
 *    caller is still on the default base (`undefined` or `streamSimple`) and
 *    the factory yields a function.
 * 5. The current stream function, defaulting to `streamSimple`.
 */
export function resolveEmbeddedAgentStreamFn(params: {
  currentStreamFn: StreamFn | undefined;
  providerStreamFn?: StreamFn;
  shouldUseWebSocketTransport: boolean;
  wsApiKey?: string;
  sessionId: string;
  signal?: AbortSignal;
  model: EmbeddedRunAttemptParams["model"];
  authStorage?: { getApiKey(provider: string): Promise<string | undefined> };
}): StreamFn {
  const { providerStreamFn } = params;
  if (providerStreamFn) {
    // Only rewrite the context when there is a system prompt to strip.
    const withNormalizedPrompt = (context: Parameters<StreamFn>[1]) => {
      if (!context.systemPrompt) {
        return context;
      }
      return {
        ...context,
        systemPrompt: stripSystemPromptCacheBoundary(context.systemPrompt),
      };
    };

    if (!params.authStorage) {
      return (m, context, options) => providerStreamFn(m, withNormalizedPrompt(context), options);
    }

    // Provider-owned transports bypass pi-coding-agent's default auth lookup,
    // so keep injecting the resolved runtime apiKey for streamSimple-compatible
    // transports that still read credentials from options.apiKey.
    const { authStorage, model } = params;
    return async (m, context, options) => {
      const resolvedKey = await authStorage.getApiKey(model.provider);
      const shapedContext = withNormalizedPrompt(context);
      return providerStreamFn(m, shapedContext, {
        ...options,
        apiKey: resolvedKey ?? options?.apiKey,
      });
    };
  }

  const fallbackStreamFn = params.currentStreamFn ?? streamSimple;

  if (params.shouldUseWebSocketTransport) {
    if (!params.wsApiKey) {
      return fallbackStreamFn;
    }
    return createOpenAIWebSocketStreamFn(params.wsApiKey, params.sessionId, {
      signal: params.signal,
    });
  }

  if (params.model.provider === "anthropic-vertex") {
    return createAnthropicVertexStreamFnForModel(params.model);
  }

  // A caller-installed custom stream function is never replaced here.
  const isDefaultBase =
    params.currentStreamFn === undefined || params.currentStreamFn === streamSimple;
  const boundaryAwareStreamFn = isDefaultBase
    ? createBoundaryAwareStreamFnForModel(params.model)
    : undefined;
  return boundaryAwareStreamFn ? boundaryAwareStreamFn : fallbackStreamFn;
}
|
||||
|
||||
function summarizeMessagePayload(msg: AgentMessage): { textChars: number; imageBlocks: number } {
|
||||
const content = (msg as { content?: unknown }).content;
|
||||
if (typeof content === "string") {
|
||||
@@ -953,7 +901,11 @@ export async function runEmbeddedAttempt(
|
||||
workspaceDir: params.workspaceDir,
|
||||
});
|
||||
|
||||
const defaultSessionStreamFn = activeSession.agent.streamFn;
|
||||
// Rebuild each turn from the session's original stream base so prior-turn
|
||||
// wrappers do not pin us to stale provider/API transport behavior.
|
||||
const defaultSessionStreamFn = resolveEmbeddedAgentBaseStreamFn({
|
||||
session: activeSession,
|
||||
});
|
||||
const providerStreamFn = registerProviderStreamForModel({
|
||||
model: params.model,
|
||||
cfg: params.config,
|
||||
|
||||
83
src/agents/pi-embedded-runner/stream-resolution.ts
Normal file
83
src/agents/pi-embedded-runner/stream-resolution.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import { createAnthropicVertexStreamFnForModel } from "../anthropic-vertex-stream.js";
|
||||
import { createOpenAIWebSocketStreamFn } from "../openai-ws-stream.js";
|
||||
import { createBoundaryAwareStreamFnForModel } from "../provider-transport-stream.js";
|
||||
import { stripSystemPromptCacheBoundary } from "../system-prompt-cache-boundary.js";
|
||||
import type { EmbeddedRunAttemptParams } from "./run/types.js";
|
||||
|
||||
let embeddedAgentBaseStreamFnCache = new WeakMap<object, StreamFn | undefined>();
|
||||
|
||||
/**
 * Returns the stream function that `session.agent` carried the first time this
 * helper was called for that session object.
 *
 * The first observed value is remembered in `embeddedAgentBaseStreamFnCache`,
 * keyed by the session object, so later calls keep returning it even after
 * `session.agent.streamFn` has been reassigned. An `undefined` first value is
 * remembered as well.
 */
export function resolveEmbeddedAgentBaseStreamFn(params: {
  session: { agent: { streamFn?: StreamFn } };
}): StreamFn | undefined {
  const { session } = params;
  if (!embeddedAgentBaseStreamFnCache.has(session)) {
    // First sighting of this session: snapshot whatever the agent has now.
    embeddedAgentBaseStreamFnCache.set(session, session.agent.streamFn);
  }
  return embeddedAgentBaseStreamFnCache.get(session);
}
|
||||
|
||||
/**
 * Test-only hook: discards every remembered base stream function by swapping
 * the module-level cache for a fresh, empty `WeakMap`.
 */
export function resetEmbeddedAgentBaseStreamFnCacheForTest(): void {
  const emptyCache = new WeakMap<object, StreamFn | undefined>();
  embeddedAgentBaseStreamFnCache = emptyCache;
}
|
||||
|
||||
/**
 * Picks the stream function an embedded agent should use.
 *
 * Resolution order (first match wins):
 * 1. `providerStreamFn`, wrapped so the context's system prompt is passed
 *    through `stripSystemPromptCacheBoundary` and, when `authStorage` is
 *    supplied, `options.apiKey` is filled from
 *    `authStorage.getApiKey(model.provider)` (falling back to the caller's
 *    `options.apiKey` when storage returns nothing).
 * 2. When WebSocket transport is requested: the OpenAI WebSocket stream
 *    function if `wsApiKey` is set, otherwise the current stream function
 *    (or `streamSimple` when none is set).
 * 3. The Anthropic Vertex stream function for provider `anthropic-vertex`.
 * 4. Whatever `createBoundaryAwareStreamFnForModel` returns, but only when the
 *    caller is still on the default base (`undefined` or `streamSimple`) and
 *    the factory yields a function.
 * 5. The current stream function, defaulting to `streamSimple`.
 */
export function resolveEmbeddedAgentStreamFn(params: {
  currentStreamFn: StreamFn | undefined;
  providerStreamFn?: StreamFn;
  shouldUseWebSocketTransport: boolean;
  wsApiKey?: string;
  sessionId: string;
  signal?: AbortSignal;
  model: EmbeddedRunAttemptParams["model"];
  authStorage?: { getApiKey(provider: string): Promise<string | undefined> };
}): StreamFn {
  const { providerStreamFn } = params;
  if (providerStreamFn) {
    // Only rewrite the context when there is a system prompt to strip.
    const withNormalizedPrompt = (context: Parameters<StreamFn>[1]) => {
      if (!context.systemPrompt) {
        return context;
      }
      return {
        ...context,
        systemPrompt: stripSystemPromptCacheBoundary(context.systemPrompt),
      };
    };

    if (!params.authStorage) {
      return (m, context, options) => providerStreamFn(m, withNormalizedPrompt(context), options);
    }

    // Provider-owned transports bypass pi-coding-agent's default auth lookup,
    // so keep injecting the resolved runtime apiKey for streamSimple-compatible
    // transports that still read credentials from options.apiKey.
    const { authStorage, model } = params;
    return async (m, context, options) => {
      const resolvedKey = await authStorage.getApiKey(model.provider);
      const shapedContext = withNormalizedPrompt(context);
      return providerStreamFn(m, shapedContext, {
        ...options,
        apiKey: resolvedKey ?? options?.apiKey,
      });
    };
  }

  const fallbackStreamFn = params.currentStreamFn ?? streamSimple;

  if (params.shouldUseWebSocketTransport) {
    if (!params.wsApiKey) {
      return fallbackStreamFn;
    }
    return createOpenAIWebSocketStreamFn(params.wsApiKey, params.sessionId, {
      signal: params.signal,
    });
  }

  if (params.model.provider === "anthropic-vertex") {
    return createAnthropicVertexStreamFnForModel(params.model);
  }

  // A caller-installed custom stream function is never replaced here.
  const isDefaultBase =
    params.currentStreamFn === undefined || params.currentStreamFn === streamSimple;
  const boundaryAwareStreamFn = isDefaultBase
    ? createBoundaryAwareStreamFnForModel(params.model)
    : undefined;
  return boundaryAwareStreamFn ? boundaryAwareStreamFn : fallbackStreamFn;
}
|
||||
Reference in New Issue
Block a user