diff --git a/CHANGELOG.md b/CHANGELOG.md index 30cd4aa9a75..280bcf601d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ Docs: https://docs.openclaw.ai - Memory/dreaming (experimental): add opt-in weighted short-term recall promotion to `MEMORY.md`, managed dreaming modes (`off|core|rem|deep`), and a `/dreaming` command plus Dreams UI so durable memory promotion can run on background cadence without manual scheduling. (#60569) Thanks @vignesh07. - Agents/system prompts: add an internal cache-prefix boundary across Anthropic-family, OpenAI-family, Google, and CLI transport shaping so stable system-prompt prefixes stay reusable without leaking internal cache markers to provider payloads. (#59054) Thanks @coletebou and @vincentkoc. - Docs/memory: add a dedicated Dreaming concept page, expand Memory overview with the Dreaming model, and link Dreaming from further reading to document the experimental opt-in consolidation workflow. Thanks @vignesh07. +- Agents/cache prefixes: route compaction, OpenAI WebSocket HTTP fallback, and later-turn embedded session reuse through the same cache-safe prompt shaping path so Anthropic-family and OpenAI-family requests keep stable prompt bytes across follow-up turns and fallback transport changes. (#60691) Thanks @vincentkoc. ### Fixes diff --git a/src/agents/openai-ws-stream.test.ts b/src/agents/openai-ws-stream.test.ts index 00f6b3484df..80c0226e700 100644 --- a/src/agents/openai-ws-stream.test.ts +++ b/src/agents/openai-ws-stream.test.ts @@ -219,6 +219,7 @@ const mockStreamSimple = vi.fn((model: unknown, context: unknown, options?: unkn }); return stream; }); +const mockCreateHttpFallbackStreamFn = vi.fn(() => mockStreamSimple as never); // ───────────────────────────────────────────────────────────────────────────── // Helpers @@ -1180,8 +1181,11 @@ describe("createOpenAIWebSocketStreamFn", () => { beforeEach(() => { MockManager.reset(); streamSimpleCalls.length = 0; + mockCreateHttpFallbackStreamFn.mockReset(); + mockCreateHttpFallbackStreamFn.mockReturnValue(mockStreamSimple as never); openAIWsStreamTesting.setDepsForTest({ createManager: ((options?: unknown) => new MockManager(options)) as never, + createHttpFallbackStreamFn: mockCreateHttpFallbackStreamFn as never, streamSimple: mockStreamSimple, }); }); @@ -1195,6 +1199,7 @@ describe("createOpenAIWebSocketStreamFn", () => { releaseWsSession("sess-1"); releaseWsSession("sess-2"); releaseWsSession("sess-fallback"); + releaseWsSession("sess-boundary-http-fallback"); releaseWsSession("sess-incremental"); releaseWsSession("sess-full"); releaseWsSession("sess-phase"); @@ -1931,6 +1936,64 @@ describe("createOpenAIWebSocketStreamFn", () => { expect(streamSimpleCalls.length).toBeGreaterThan(callsBefore); }); + it("routes websocket HTTP fallback through the configured HTTP fallback builder", async () => { + const httpFallbackCalls: Array<{ model: unknown; context: unknown; options?: unknown }> = []; + const httpFallbackStreamFn = vi.fn((model: unknown, context: unknown, options?: unknown) => { + httpFallbackCalls.push({ model, context, options }); + const stream = createAssistantMessageEventStream(); + queueMicrotask(() => { + const msg = makeFakeAssistantMessage("boundary-safe fallback"); + stream.push({ type: "done", reason: "stop", message: msg }); + stream.end(); + }); + return stream; + }); + mockCreateHttpFallbackStreamFn.mockReturnValue(httpFallbackStreamFn as never); + const sessionId = "sess-boundary-http-fallback"; + const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId); + + const stream1 = streamFn( + modelStub as Parameters[0], + contextStub as Parameters[1], + ); + await new Promise((resolve, reject) => { + queueMicrotask(async () => { + try { + await new Promise((r) => setImmediate(r)); + MockManager.lastInstance!.simulateEvent({ + type: "response.completed", + response: makeResponseObject("resp-ok", "OK"), + }); + for await (const _ of await resolveStream(stream1)) { + /* consume */ + } + resolve(); + } catch (e) { + reject(e); + } + }); + }); + + MockManager.globalSendFailuresRemaining = 2; + const stream2 = streamFn( + modelStub as Parameters[0], + { + ...contextStub, + systemPrompt: `Stable prefix${SYSTEM_PROMPT_CACHE_BOUNDARY}Dynamic suffix`, + } as Parameters[1], + ); + for await (const _ of await resolveStream(stream2)) { + /* consume */ + } + + expect(mockCreateHttpFallbackStreamFn).toHaveBeenCalled(); + expect(streamSimpleCalls).toHaveLength(0); + expect(httpFallbackCalls).toHaveLength(1); + expect(httpFallbackCalls[0]?.context).toMatchObject({ + systemPrompt: `Stable prefix${SYSTEM_PROMPT_CACHE_BOUNDARY}Dynamic suffix`, + }); + }); + it("forwards temperature and maxTokens to response.create", async () => { const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-temp"); const opts = { temperature: 0.3, maxTokens: 256 }; @@ -2288,6 +2351,7 @@ describe("releaseWsSession / hasWsSession", () => { MockManager.reset(); openAIWsStreamTesting.setDepsForTest({ createManager: (() => new MockManager()) as never, + createHttpFallbackStreamFn: mockCreateHttpFallbackStreamFn as never, streamSimple: mockStreamSimple, }); }); diff --git a/src/agents/openai-ws-stream.ts b/src/agents/openai-ws-stream.ts index e765b0e78b0..f16feab74dd 100644 --- a/src/agents/openai-ws-stream.ts +++ b/src/agents/openai-ws-stream.ts @@ -48,6 +48,7 @@ import { planTurnInput, } from "./openai-ws-message-conversion.js"; import { buildOpenAIWebSocketResponseCreatePayload } from "./openai-ws-request.js"; +import { createBoundaryAwareStreamFnForModel } from "./provider-transport-stream.js"; import { log } from "./pi-embedded-runner/logger.js"; import { buildAssistantMessageWithZeroUsage, @@ -80,11 +81,13 @@ const wsRegistry = new Map(); type OpenAIWsStreamDeps = { createManager: (options?: OpenAIWebSocketManagerOptions) => OpenAIWebSocketManager; + createHttpFallbackStreamFn: (model: ProviderRuntimeModel) => StreamFn | undefined; streamSimple: typeof piAi.streamSimple; }; const defaultOpenAIWsStreamDeps: OpenAIWsStreamDeps = { createManager: (options) => new OpenAIWebSocketManager(options), + createHttpFallbackStreamFn: (model) => createBoundaryAwareStreamFnForModel(model), streamSimple: (...args) => piAi.streamSimple(...args), }; @@ -916,7 +919,7 @@ export function createOpenAIWebSocketStreamFn( } /** - * Fall back to HTTP (`streamSimple`) and pipe events into the existing stream. + * Fall back to HTTP and pipe events into the existing stream. * This is called when the WebSocket is broken or unavailable. */ async function fallbackToHttp( @@ -957,7 +960,10 @@ async function fallbackToHttp( : {}), ...(signal ? { signal } : {}), }; - const httpStream = openAIWsStreamDeps.streamSimple(model, context, mergedOptions); + const httpStreamFn = + openAIWsStreamDeps.createHttpFallbackStreamFn(model as ProviderRuntimeModel) ?? + openAIWsStreamDeps.streamSimple; + const httpStream = httpStreamFn(model, context, mergedOptions); for await (const event of httpStream) { if (fallbackOptions?.suppressStart && event.type === "start") { continue; diff --git a/src/agents/pi-embedded-runner/compact.hooks.harness.ts b/src/agents/pi-embedded-runner/compact.hooks.harness.ts index 207f382469e..ae27a1bd4b5 100644 --- a/src/agents/pi-embedded-runner/compact.hooks.harness.ts +++ b/src/agents/pi-embedded-runner/compact.hooks.harness.ts @@ -81,6 +81,9 @@ export const sessionMessages: unknown[] = [ ]; export const sessionAbortCompactionMock: Mock<(reason?: unknown) => void> = vi.fn(); export const createOpenClawCodingToolsMock = vi.fn(() => []); +export const resolveEmbeddedAgentStreamFnMock = vi.fn((_params?: unknown) => vi.fn()); +export const applyExtraParamsToAgentMock = vi.fn(() => ({ effectiveExtraParams: {} })); +export const resolveAgentTransportOverrideMock = vi.fn(() => undefined); export function resetCompactSessionStateMocks(): void { sanitizeSessionHistoryMock.mockReset(); @@ -122,6 +125,12 @@ export function resetCompactSessionStateMocks(): void { }, ); sessionAbortCompactionMock.mockReset(); + resolveEmbeddedAgentStreamFnMock.mockReset(); + resolveEmbeddedAgentStreamFnMock.mockImplementation((_params?: unknown) => vi.fn()); + applyExtraParamsToAgentMock.mockReset(); + applyExtraParamsToAgentMock.mockReturnValue({ effectiveExtraParams: {} }); + resolveAgentTransportOverrideMock.mockReset(); + resolveAgentTransportOverrideMock.mockReturnValue(undefined); } export function resetCompactHooksHarnessMocks(): void { @@ -223,6 +232,8 @@ export async function loadCompactHooksHarness(): Promise<{ session.messages = [...(messages as typeof session.messages)]; }), streamFn: vi.fn(), + setTransport: vi.fn(), + transport: "sse", }, compact: vi.fn(async () => { session.messages.splice(1); @@ -340,6 +351,15 @@ export async function loadCompactHooksHarness(): Promise<{ normalizeProviderToolSchemas: vi.fn(({ tools }: { tools: unknown[] }) => tools), })); + vi.doMock("./stream-resolution.js", () => ({ + resolveEmbeddedAgentStreamFn: resolveEmbeddedAgentStreamFnMock, + })); + + vi.doMock("./extra-params.js", () => ({ + applyExtraParamsToAgent: applyExtraParamsToAgentMock, + resolveAgentTransportOverride: resolveAgentTransportOverrideMock, + })); + vi.doMock("./tool-split.js", () => ({ splitSdkTools: vi.fn(() => ({ builtInTools: [], customTools: [] })), })); diff --git a/src/agents/pi-embedded-runner/compact.hooks.test.ts b/src/agents/pi-embedded-runner/compact.hooks.test.ts index ea67c68048d..301b5c5ad3d 100644 --- a/src/agents/pi-embedded-runner/compact.hooks.test.ts +++ b/src/agents/pi-embedded-runner/compact.hooks.test.ts @@ -3,13 +3,16 @@ import { getApiProvider, unregisterApiProviders } from "@mariozechner/pi-ai"; import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; import { getCustomApiRegistrySourceId } from "../custom-api-registry.js"; import { + applyExtraParamsToAgentMock, contextEngineCompactMock, ensureRuntimePluginsLoaded, estimateTokensMock, getMemorySearchManagerMock, hookRunner, loadCompactHooksHarness, + resolveAgentTransportOverrideMock, resolveContextEngineMock, + resolveEmbeddedAgentStreamFnMock, resolveMemorySearchConfigMock, resolveModelMock, resolveSessionAgentIdMock, @@ -209,6 +212,52 @@ describe("compactEmbeddedPiSessionDirect hooks", () => { }); }); + it("routes compaction through shared stream resolution and extra params", async () => { + const resolvedStreamFn = vi.fn(); + resolveEmbeddedAgentStreamFnMock.mockReturnValue(resolvedStreamFn); + applyExtraParamsToAgentMock.mockReturnValue({ + effectiveExtraParams: { transport: "websocket" }, + }); + resolveAgentTransportOverrideMock.mockReturnValue("websocket"); + + await compactEmbeddedPiSessionDirect({ + sessionId: "session-1", + sessionFile: "/tmp/session.jsonl", + workspaceDir: "/tmp/workspace", + provider: "openai", + model: "gpt-5.4", + }); + + expect(resolveEmbeddedAgentStreamFnMock).toHaveBeenCalledWith( + expect.objectContaining({ + currentStreamFn: expect.any(Function), + sessionId: "session-1", + }), + ); + expect(applyExtraParamsToAgentMock).toHaveBeenCalledWith( + expect.objectContaining({ + streamFn: resolvedStreamFn, + }), + undefined, + "openai", + "gpt-5.4", + undefined, + undefined, + "main", + "/tmp/workspace", + expect.objectContaining({ + provider: "openai", + id: "fake", + }), + "/tmp", + ); + expect(resolveAgentTransportOverrideMock).toHaveBeenCalledWith( + expect.objectContaining({ + effectiveExtraParams: { transport: "websocket" }, + }), + ); + }); + it("emits internal + plugin compaction hooks with counts", async () => { hookRunner.hasHooks.mockReturnValue(true); await runCompactionHooks({ diff --git a/src/agents/pi-embedded-runner/compact.ts b/src/agents/pi-embedded-runner/compact.ts index d8c23135bc9..071d181023b 100644 --- a/src/agents/pi-embedded-runner/compact.ts +++ b/src/agents/pi-embedded-runner/compact.ts @@ -80,6 +80,7 @@ import { type SkillSnapshot, } from "../skills.js"; import { resolveTranscriptPolicy } from "../transcript-policy.js"; +import { applyExtraParamsToAgent, resolveAgentTransportOverride } from "./extra-params.js"; import { classifyCompactionReason, resolveCompactionFailureReason } from "./compact-reasons.js"; import { asCompactionHookRunner, @@ -108,6 +109,10 @@ import { buildModelAliasLines, resolveModelAsync } from "./model.js"; import { sanitizeSessionHistory, validateReplayTurns } from "./replay-history.js"; import { buildEmbeddedSandboxInfo } from "./sandbox-info.js"; import { prewarmSessionFile, trackSessionManagerAccess } from "./session-manager-cache.js"; +import { + resolveEmbeddedAgentBaseStreamFn, + resolveEmbeddedAgentStreamFn, +} from "./stream-resolution.js"; import { truncateSessionAfterCompaction } from "./session-truncation.js"; import { resolveEmbeddedRunSkillEntries } from "./skills-runtime.js"; import { @@ -124,6 +129,7 @@ import { splitSdkTools } from "./tool-split.js"; import type { EmbeddedPiCompactResult } from "./types.js"; import { describeUnknownError, mapThinkingLevel } from "./utils.js"; import { flushPendingToolResultsAfterIdle } from "./wait-for-idle-before-flush.js"; +import { shouldUseOpenAIWebSocketTransport } from "./run/attempt.thread-helpers.js"; export type CompactEmbeddedPiSessionParams = { sessionId: string; @@ -746,13 +752,61 @@ export async function compactEmbeddedPiSessionDirect( }); applySystemPromptOverrideToSession(session, systemPromptOverride()); const providerStreamFn = registerProviderStreamForModel({ - model, + model: effectiveModel, cfg: params.config, agentDir, workspaceDir: effectiveWorkspace, }); - if (providerStreamFn) { - session.agent.streamFn = providerStreamFn; + const shouldUseWebSocketTransport = shouldUseOpenAIWebSocketTransport({ + provider, + modelApi: effectiveModel.api, + }); + const wsApiKey = shouldUseWebSocketTransport + ? await authStorage.getApiKey(provider) + : undefined; + if (shouldUseWebSocketTransport && !wsApiKey) { + log.warn( + `[ws-stream] no API key for provider=${provider}; keeping compaction HTTP transport`, + ); + } + // Compaction builds the same embedded system prompt, so it must flow + // through the same transport/payload shaping stack as normal turns. + session.agent.streamFn = resolveEmbeddedAgentStreamFn({ + currentStreamFn: resolveEmbeddedAgentBaseStreamFn({ session }), + providerStreamFn, + shouldUseWebSocketTransport, + wsApiKey, + sessionId: params.sessionId, + signal: runAbortController.signal, + model: effectiveModel, + authStorage, + }); + const { effectiveExtraParams } = applyExtraParamsToAgent( + session.agent, + params.config, + provider, + modelId, + undefined, + params.thinkLevel, + sessionAgentId, + effectiveWorkspace, + effectiveModel, + agentDir, + ); + const agentTransportOverride = resolveAgentTransportOverride({ + settingsManager, + effectiveExtraParams, + }); + if ( + agentTransportOverride && + typeof (session.agent as { setTransport?: unknown }).setTransport === "function" && + (session.agent as { transport?: unknown }).transport !== agentTransportOverride + ) { + ( + session.agent as { + setTransport(nextTransport: string): void; + } + ).setTransport(agentTransportOverride); } try { diff --git a/src/agents/pi-embedded-runner/run/attempt.test.ts b/src/agents/pi-embedded-runner/run/attempt.test.ts index 4155c48e4a7..54796418c0a 100644 --- a/src/agents/pi-embedded-runner/run/attempt.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.test.ts @@ -15,6 +15,8 @@ import { composeSystemPromptWithHookContext, decodeHtmlEntitiesInObject, prependSystemPromptAddition, + resetEmbeddedAgentBaseStreamFnCacheForTest, + resolveEmbeddedAgentBaseStreamFn, resolveAttemptFsWorkspaceOnly, resolveEmbeddedAgentStreamFn, resolvePromptBuildHookResult, @@ -240,6 +242,21 @@ describe("shouldWarnOnOrphanedUserRepair", () => { }); describe("resolveEmbeddedAgentStreamFn", () => { + it("reuses the session's original base stream across later wrapper mutations", () => { + resetEmbeddedAgentBaseStreamFnCacheForTest(); + const baseStreamFn = vi.fn(); + const wrapperStreamFn = vi.fn(); + const session = { + agent: { + streamFn: baseStreamFn, + }, + }; + + expect(resolveEmbeddedAgentBaseStreamFn({ session })).toBe(baseStreamFn); + session.agent.streamFn = wrapperStreamFn; + expect(resolveEmbeddedAgentBaseStreamFn({ session })).toBe(baseStreamFn); + }); + it("injects authStorage api keys into provider-owned stream functions", async () => { const providerStreamFn = vi.fn(async (_model, _context, options) => options); const streamFn = resolveEmbeddedAgentStreamFn({ @@ -292,7 +309,6 @@ describe("resolveEmbeddedAgentStreamFn", () => { }); expect(providerStreamFn).toHaveBeenCalledTimes(1); }); - it("routes supported default streamSimple fallbacks through boundary-aware transports", () => { const streamFn = resolveEmbeddedAgentStreamFn({ currentStreamFn: undefined, diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index d0770decdf6..8dbbd75aaf5 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -1,7 +1,6 @@ import fs from "node:fs/promises"; import os from "node:os"; import type { AgentMessage, StreamFn } from "@mariozechner/pi-agent-core"; -import { streamSimple } from "@mariozechner/pi-ai"; import { createAgentSession, DefaultResourceLoader, @@ -31,7 +30,6 @@ import { isReasoningTagProvider } from "../../../utils/provider-utils.js"; import { resolveOpenClawAgentDir } from "../../agent-paths.js"; import { resolveSessionAgentIds } from "../../agent-scope.js"; import { createAnthropicPayloadLogger } from "../../anthropic-payload-log.js"; -import { createAnthropicVertexStreamFnForModel } from "../../anthropic-vertex-stream.js"; import { analyzeBootstrapBudget, buildBootstrapPromptWarning, @@ -55,7 +53,7 @@ import { buildModelAliasLines } from "../../model-alias-lines.js"; import { resolveModelAuthMode } from "../../model-auth.js"; import { resolveDefaultModelForAgent } from "../../model-selection.js"; import { supportsModelTools } from "../../model-tool-support.js"; -import { createOpenAIWebSocketStreamFn, releaseWsSession } from "../../openai-ws-stream.js"; +import { releaseWsSession } from "../../openai-ws-stream.js"; import { resolveOwnerDisplaySetting } from "../../owner-display.js"; import { createBundleLspToolRuntime } from "../../pi-bundle-lsp-runtime.js"; import { @@ -75,7 +73,6 @@ import { applyPiAutoCompactionGuard } from "../../pi-settings.js"; import { toClientToolDefinitions } from "../../pi-tool-definition-adapter.js"; import { createOpenClawCodingTools, resolveToolLoopDetectionConfig } from "../../pi-tools.js"; import { registerProviderStreamForModel } from "../../provider-stream.js"; -import { createBoundaryAwareStreamFnForModel } from "../../provider-transport-stream.js"; import { resolveSandboxContext } from "../../sandbox.js"; import { resolveSandboxRuntimeStatus } from "../../sandbox/runtime-status.js"; import { repairSessionFileIfNeeded } from "../../session-file-repair.js"; @@ -91,7 +88,6 @@ import { applySkillEnvOverridesFromSnapshot, resolveSkillsPromptForRun, } from "../../skills.js"; -import { stripSystemPromptCacheBoundary } from "../../system-prompt-cache-boundary.js"; import { buildSystemPromptParams } from "../../system-prompt-params.js"; import { buildSystemPromptReport } from "../../system-prompt-report.js"; import { sanitizeToolCallIdsForCloudCodeAssist } from "../../tool-call-id.js"; @@ -122,6 +118,11 @@ import { buildEmbeddedSystemPrompt, createSystemPromptOverride, } from "../system-prompt.js"; +import { + resetEmbeddedAgentBaseStreamFnCacheForTest, + resolveEmbeddedAgentBaseStreamFn, + resolveEmbeddedAgentStreamFn, +} from "../stream-resolution.js"; import { dropThinkingBlocks } from "../thinking.js"; import { collectAllowedToolNames } from "../tool-name-allowlist.js"; import { installToolResultContextGuard } from "../tool-result-context-guard.js"; @@ -216,67 +217,14 @@ export { wrapStreamFnSanitizeMalformedToolCalls, wrapStreamFnTrimToolCallNames, } from "./attempt.tool-call-normalization.js"; +export { + resetEmbeddedAgentBaseStreamFnCacheForTest, + resolveEmbeddedAgentBaseStreamFn, + resolveEmbeddedAgentStreamFn, +}; const MAX_BTW_SNAPSHOT_MESSAGES = 100; -export function resolveEmbeddedAgentStreamFn(params: { - currentStreamFn: StreamFn | undefined; - providerStreamFn?: StreamFn; - shouldUseWebSocketTransport: boolean; - wsApiKey?: string; - sessionId: string; - signal?: AbortSignal; - model: EmbeddedRunAttemptParams["model"]; - authStorage?: { getApiKey(provider: string): Promise }; -}): StreamFn { - if (params.providerStreamFn) { - const inner = params.providerStreamFn; - const normalizeContext = (context: Parameters[1]) => - context.systemPrompt - ? { - ...context, - systemPrompt: stripSystemPromptCacheBoundary(context.systemPrompt), - } - : context; - // Provider-owned transports bypass pi-coding-agent's default auth lookup, - // so keep injecting the resolved runtime apiKey for streamSimple-compatible - // transports that still read credentials from options.apiKey. - if (params.authStorage) { - const { authStorage, model } = params; - return async (m, context, options) => { - const apiKey = await authStorage.getApiKey(model.provider); - return inner(m, normalizeContext(context), { - ...options, - apiKey: apiKey ?? options?.apiKey, - }); - }; - } - return (m, context, options) => inner(m, normalizeContext(context), options); - } - - const currentStreamFn = params.currentStreamFn ?? streamSimple; - if (params.shouldUseWebSocketTransport) { - return params.wsApiKey - ? createOpenAIWebSocketStreamFn(params.wsApiKey, params.sessionId, { - signal: params.signal, - }) - : currentStreamFn; - } - - if (params.model.provider === "anthropic-vertex") { - return createAnthropicVertexStreamFnForModel(params.model); - } - - if (params.currentStreamFn === undefined || params.currentStreamFn === streamSimple) { - const boundaryAwareStreamFn = createBoundaryAwareStreamFnForModel(params.model); - if (boundaryAwareStreamFn) { - return boundaryAwareStreamFn; - } - } - - return currentStreamFn; -} - function summarizeMessagePayload(msg: AgentMessage): { textChars: number; imageBlocks: number } { const content = (msg as { content?: unknown }).content; if (typeof content === "string") { @@ -953,7 +901,11 @@ export async function runEmbeddedAttempt( workspaceDir: params.workspaceDir, }); - const defaultSessionStreamFn = activeSession.agent.streamFn; + // Rebuild each turn from the session's original stream base so prior-turn + // wrappers do not pin us to stale provider/API transport behavior. + const defaultSessionStreamFn = resolveEmbeddedAgentBaseStreamFn({ + session: activeSession, + }); const providerStreamFn = registerProviderStreamForModel({ model: params.model, cfg: params.config, diff --git a/src/agents/pi-embedded-runner/stream-resolution.ts b/src/agents/pi-embedded-runner/stream-resolution.ts new file mode 100644 index 00000000000..2f36b6beed7 --- /dev/null +++ b/src/agents/pi-embedded-runner/stream-resolution.ts @@ -0,0 +1,83 @@ +import type { StreamFn } from "@mariozechner/pi-agent-core"; +import { streamSimple } from "@mariozechner/pi-ai"; +import { createAnthropicVertexStreamFnForModel } from "../anthropic-vertex-stream.js"; +import { createOpenAIWebSocketStreamFn } from "../openai-ws-stream.js"; +import { createBoundaryAwareStreamFnForModel } from "../provider-transport-stream.js"; +import { stripSystemPromptCacheBoundary } from "../system-prompt-cache-boundary.js"; +import type { EmbeddedRunAttemptParams } from "./run/types.js"; + +let embeddedAgentBaseStreamFnCache = new WeakMap(); + +export function resolveEmbeddedAgentBaseStreamFn(params: { + session: { agent: { streamFn?: StreamFn } }; +}): StreamFn | undefined { + const cached = embeddedAgentBaseStreamFnCache.get(params.session); + if (cached !== undefined || embeddedAgentBaseStreamFnCache.has(params.session)) { + return cached; + } + const baseStreamFn = params.session.agent.streamFn; + embeddedAgentBaseStreamFnCache.set(params.session, baseStreamFn); + return baseStreamFn; +} + +export function resetEmbeddedAgentBaseStreamFnCacheForTest(): void { + embeddedAgentBaseStreamFnCache = new WeakMap(); +} + +export function resolveEmbeddedAgentStreamFn(params: { + currentStreamFn: StreamFn | undefined; + providerStreamFn?: StreamFn; + shouldUseWebSocketTransport: boolean; + wsApiKey?: string; + sessionId: string; + signal?: AbortSignal; + model: EmbeddedRunAttemptParams["model"]; + authStorage?: { getApiKey(provider: string): Promise }; +}): StreamFn { + if (params.providerStreamFn) { + const inner = params.providerStreamFn; + const normalizeContext = (context: Parameters[1]) => + context.systemPrompt + ? { + ...context, + systemPrompt: stripSystemPromptCacheBoundary(context.systemPrompt), + } + : context; + // Provider-owned transports bypass pi-coding-agent's default auth lookup, + // so keep injecting the resolved runtime apiKey for streamSimple-compatible + // transports that still read credentials from options.apiKey. + if (params.authStorage) { + const { authStorage, model } = params; + return async (m, context, options) => { + const apiKey = await authStorage.getApiKey(model.provider); + return inner(m, normalizeContext(context), { + ...options, + apiKey: apiKey ?? options?.apiKey, + }); + }; + } + return (m, context, options) => inner(m, normalizeContext(context), options); + } + + const currentStreamFn = params.currentStreamFn ?? streamSimple; + if (params.shouldUseWebSocketTransport) { + return params.wsApiKey + ? createOpenAIWebSocketStreamFn(params.wsApiKey, params.sessionId, { + signal: params.signal, + }) + : currentStreamFn; + } + + if (params.model.provider === "anthropic-vertex") { + return createAnthropicVertexStreamFnForModel(params.model); + } + + if (params.currentStreamFn === undefined || params.currentStreamFn === streamSimple) { + const boundaryAwareStreamFn = createBoundaryAwareStreamFnForModel(params.model); + if (boundaryAwareStreamFn) { + return boundaryAwareStreamFn; + } + } + + return currentStreamFn; +}