diff --git a/src/agents/pi-embedded-runner/run/attempt.spawn-workspace.websocket.test.ts b/src/agents/pi-embedded-runner/run/attempt.spawn-workspace.websocket.test.ts index 906624dff9d..0b73dc6b4b6 100644 --- a/src/agents/pi-embedded-runner/run/attempt.spawn-workspace.websocket.test.ts +++ b/src/agents/pi-embedded-runner/run/attempt.spawn-workspace.websocket.test.ts @@ -1,5 +1,8 @@ import { describe, expect, it } from "vitest"; -import { shouldUseOpenAIWebSocketTransport } from "./attempt.thread-helpers.js"; +import { + shouldUseOpenAIWebSocketTransport, + shouldUseOpenAIWebSocketTransportForAttempt, +} from "./attempt.thread-helpers.js"; describe("openai websocket transport selection", () => { it("accepts direct OpenAI Responses endpoints", () => { @@ -76,4 +79,24 @@ describe("openai websocket transport selection", () => { }), ).toBe(false); }); + + it("honors prepared SSE transport params before selecting websocket", () => { + expect( + shouldUseOpenAIWebSocketTransportForAttempt({ + provider: "openai", + modelApi: "openai-responses", + modelBaseUrl: "https://api.openai.com/v1", + effectiveExtraParams: { transport: "sse" }, + }), + ).toBe(false); + + expect( + shouldUseOpenAIWebSocketTransportForAttempt({ + provider: "openai", + modelApi: "openai-responses", + modelBaseUrl: "https://api.openai.com/v1", + effectiveExtraParams: { transport: "auto" }, + }), + ).toBe(true); + }); }); diff --git a/src/agents/pi-embedded-runner/run/attempt.thread-helpers.ts b/src/agents/pi-embedded-runner/run/attempt.thread-helpers.ts index c1760c970f6..314e39f6b82 100644 --- a/src/agents/pi-embedded-runner/run/attempt.thread-helpers.ts +++ b/src/agents/pi-embedded-runner/run/attempt.thread-helpers.ts @@ -55,6 +55,29 @@ export function shouldUseOpenAIWebSocketTransport(params: { return endpointClass === "default" || endpointClass === "openai-public"; } +function hasExplicitSseTransport(sources: Array | undefined>): boolean { + return sources.some((source) => { + const transport = typeof source?.transport === "string" ? source.transport : ""; + return transport.trim().toLowerCase() === "sse"; + }); +} + +export function shouldUseOpenAIWebSocketTransportForAttempt(params: { + provider: string; + modelApi?: string | null; + modelBaseUrl?: string | null; + streamParams?: Record; + effectiveExtraParams?: Record; + modelParams?: Record; +}): boolean { + if ( + hasExplicitSseTransport([params.streamParams, params.effectiveExtraParams, params.modelParams]) + ) { + return false; + } + return shouldUseOpenAIWebSocketTransport(params); +} + function shouldAppendAttemptCacheTtl(params: { timedOutDuringCompaction: boolean; compactionOccurredThisAttempt: boolean; diff --git a/src/agents/pi-embedded-runner/run/attempt.ts b/src/agents/pi-embedded-runner/run/attempt.ts index 7f56c0b12c9..f116dffa970 100644 --- a/src/agents/pi-embedded-runner/run/attempt.ts +++ b/src/agents/pi-embedded-runner/run/attempt.ts @@ -172,6 +172,8 @@ import { applyExtraParamsToAgent, resolveAgentTransportOverride, resolveExplicitSettingsTransport, + resolveExtraParams, + resolvePreparedExtraParams, } from "../extra-params.js"; import { prepareGooglePromptCacheStreamFn } from "../google-prompt-cache.js"; import { getDmHistoryLimitFromSessionKey, limitHistoryTurns } from "../history.js"; @@ -292,7 +294,7 @@ import { composeSystemPromptWithHookContext, resolveAttemptSpawnWorkspaceDir, shouldPersistCompletedBootstrapTurn, - shouldUseOpenAIWebSocketTransport, + shouldUseOpenAIWebSocketTransportForAttempt, } from "./attempt.thread-helpers.js"; import { shouldRepairMalformedToolCallArguments, @@ -1769,25 +1771,57 @@ export async function runEmbeddedAttempt( const defaultSessionStreamFn = resolveEmbeddedAgentBaseStreamFn({ session: activeSession, }); + const resolvedTransport = resolveExplicitSettingsTransport({ + settingsManager, + sessionTransport: activeSession.agent.transport, + }); + const streamExtraParamsOverride = { + ...params.streamParams, + fastMode: params.fastMode, + }; + const preparedRuntimeExtraParams = params.runtimePlan?.transport.resolveExtraParams({ + extraParamsOverride: streamExtraParamsOverride, + thinkingLevel: params.thinkLevel, + agentId: sessionAgentId, + workspaceDir: effectiveWorkspace, + model: params.model, + resolvedTransport, + }); + const resolvedExtraParams = resolveExtraParams({ + cfg: params.config, + provider: params.provider, + modelId: params.modelId, + agentId: sessionAgentId, + }); + const effectiveExtraParams = + preparedRuntimeExtraParams ?? + resolvePreparedExtraParams({ + cfg: params.config, + provider: params.provider, + modelId: params.modelId, + extraParamsOverride: streamExtraParamsOverride, + thinkingLevel: params.thinkLevel, + agentId: sessionAgentId, + agentDir, + workspaceDir: effectiveWorkspace, + resolvedExtraParams, + model: params.model, + resolvedTransport, + }); const providerStreamFn = registerProviderStreamForModel({ model: params.model, cfg: params.config, agentDir, workspaceDir: effectiveWorkspace, }); - const hasExplicitSseTransport = [ - (params.streamParams as { transport?: unknown } | undefined)?.transport, - (params.model as { params?: { transport?: unknown } }).params?.transport, - ] - .map((value) => (typeof value === "string" ? value.trim().toLowerCase() : "")) - .includes("sse"); - const shouldUseWebSocketTransport = - !hasExplicitSseTransport && - shouldUseOpenAIWebSocketTransport({ - provider: params.provider, - modelApi: params.model.api, - modelBaseUrl: params.model.baseUrl, - }); + const shouldUseWebSocketTransport = shouldUseOpenAIWebSocketTransportForAttempt({ + provider: params.provider, + modelApi: params.model.api, + modelBaseUrl: params.model.baseUrl, + streamParams: params.streamParams, + effectiveExtraParams, + modelParams: (params.model as { params?: Record }).params, + }); const wsApiKey = shouldUseWebSocketTransport ? await resolveEmbeddedAgentApiKey({ provider: params.provider, @@ -1832,23 +1866,7 @@ export async function runEmbeddedAttempt( }); } - const resolvedTransport = resolveExplicitSettingsTransport({ - settingsManager, - sessionTransport: activeSession.agent.transport, - }); - const streamExtraParamsOverride = { - ...params.streamParams, - fastMode: params.fastMode, - }; - const preparedRuntimeExtraParams = params.runtimePlan?.transport.resolveExtraParams({ - extraParamsOverride: streamExtraParamsOverride, - thinkingLevel: params.thinkLevel, - agentId: sessionAgentId, - workspaceDir: effectiveWorkspace, - model: params.model, - resolvedTransport, - }); - const { effectiveExtraParams } = applyExtraParamsToAgent( + applyExtraParamsToAgent( activeSession.agent, params.config, params.provider, @@ -1860,9 +1878,7 @@ export async function runEmbeddedAttempt( params.model, agentDir, resolvedTransport, - preparedRuntimeExtraParams - ? { preparedExtraParams: preparedRuntimeExtraParams } - : undefined, + { preparedExtraParams: effectiveExtraParams }, ); const effectivePromptCacheRetention = resolveCacheRetention( effectiveExtraParams,