From 198a42bbc68107a47ee38e577f2a5aaa5e8bcf56 Mon Sep 17 00:00:00 2001 From: Neerav Makwana <261249544+neeravmakwana@users.noreply.github.com> Date: Thu, 21 May 2026 09:55:28 -0400 Subject: [PATCH] fix(google): cache managed tool config --- .../google-prompt-cache.test.ts | 22 ++++- .../pi-embedded-runner/google-prompt-cache.ts | 95 ++++++++++++++++++- 2 files changed, 112 insertions(+), 5 deletions(-) diff --git a/src/agents/pi-embedded-runner/google-prompt-cache.test.ts b/src/agents/pi-embedded-runner/google-prompt-cache.test.ts index 4fd63f38061..307d36e3727 100644 --- a/src/agents/pi-embedded-runner/google-prompt-cache.test.ts +++ b/src/agents/pi-embedded-runner/google-prompt-cache.test.ts @@ -179,7 +179,7 @@ describe("google prompt cache", () => { }, ], } as never, - { temperature: 0.2 } as never, + { temperature: 0.2, toolChoice: "auto" } as never, ), ); @@ -200,11 +200,28 @@ describe("google prompt cache", () => { systemInstruction: { parts: [{ text: "Follow policy." }], }, + tools: [ + { + functionDeclarations: [ + { + name: "lookup", + description: "Look up a value", + parametersJsonSchema: { type: "object" }, + }, + ], + }, + ], + toolConfig: { + functionCallingConfig: { + mode: "AUTO", + }, + }, }); expect(innerStreamFn).toHaveBeenCalledTimes(1); expect(streamContext(innerStreamFn).systemPrompt).toBeUndefined(); - expect(Array.isArray(streamContext(innerStreamFn).tools)).toBe(true); + expect(streamContext(innerStreamFn).tools).toBeUndefined(); expect(streamOptions(innerStreamFn).temperature).toBe(0.2); + expect(streamOptions(innerStreamFn).toolChoice).toBe("auto"); expect(getCapturedPayload()?.cachedContent).toBe("cachedContents/system-cache-1"); expect(entries).toEqual([ { @@ -221,6 +238,7 @@ describe("google prompt cache", () => { modelApi: "google-generative-ai", baseUrl: "https://generativelanguage.googleapis.com/v1beta", systemPromptDigest, + cacheConfigDigest: expect.any(String), cacheRetention: "long", cachedContent: "cachedContents/system-cache-1", expireTime, diff --git a/src/agents/pi-embedded-runner/google-prompt-cache.ts b/src/agents/pi-embedded-runner/google-prompt-cache.ts index 18fe5972c6c..a945beab7ca 100644 --- a/src/agents/pi-embedded-runner/google-prompt-cache.ts +++ b/src/agents/pi-embedded-runner/google-prompt-cache.ts @@ -31,6 +31,8 @@ type GooglePromptCacheModel = Model & { headers?: Record; provider: string; }; +type GooglePromptCacheContext = Parameters[1]; +type GooglePromptCacheOptions = Parameters[2]; type GooglePromptCacheEntry = { timestamp: number; @@ -39,6 +41,7 @@ type GooglePromptCacheEntry = { modelApi?: string | null; baseUrl: string; systemPromptDigest: string; + cacheConfigDigest?: string; cacheRetention: CacheRetention; } & ( | { @@ -111,6 +114,7 @@ function buildGooglePromptCacheMatchKey(params: { modelApi?: string | null; baseUrl: string; systemPromptDigest: string; + cacheConfigDigest?: string; }) { return stableStringify(params); } @@ -150,6 +154,8 @@ function readLatestGooglePromptCacheEntry( : null, baseUrl: stringifyGooglePromptCacheKeyPart(cacheData.baseUrl), systemPromptDigest: stringifyGooglePromptCacheKeyPart(cacheData.systemPromptDigest), + cacheConfigDigest: + typeof cacheData.cacheConfigDigest === "string" ? cacheData.cacheConfigDigest : undefined, }); if (candidateKey === matchKey) { return data as GooglePromptCacheEntry; @@ -183,13 +189,79 @@ function parseExpireTimeMs(expireTime: string | undefined): number | null { return Number.isFinite(timestamp) ? timestamp : null; } -function buildManagedContextWithoutSystemPrompt(context: Parameters[1]) { - if (!context.systemPrompt) { +function convertManagedGoogleTools(tools: NonNullable) { + if (tools.length === 0) { + return undefined; + } + return [ + { + functionDeclarations: tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parametersJsonSchema: tool.parameters, + })), + }, + ]; +} + +function mapManagedGoogleToolChoice( + choice: unknown, +): { mode: "AUTO" | "NONE" | "ANY"; allowedFunctionNames?: string[] } | undefined { + if (!choice) { + return undefined; + } + if ( + typeof choice === "object" && + choice !== null && + (choice as { type?: unknown }).type === "function" + ) { + const functionName = (choice as { function?: { name?: unknown } }).function?.name; + return typeof functionName === "string" + ? { mode: "ANY", allowedFunctionNames: [functionName] } + : { mode: "ANY" }; + } + switch (choice) { + case "none": + return { mode: "NONE" }; + case "any": + case "required": + return { mode: "ANY" }; + default: + return { mode: "AUTO" }; + } +} + +function buildManagedGooglePromptCacheConfig( + context: GooglePromptCacheContext, + options: GooglePromptCacheOptions, +) { + const tools = context.tools?.length ? convertManagedGoogleTools(context.tools) : undefined; + const toolChoice = tools + ? mapManagedGoogleToolChoice((options as { toolChoice?: unknown } | undefined)?.toolChoice) + : undefined; + const toolConfig = toolChoice ? { functionCallingConfig: toolChoice } : undefined; + const cacheConfigDigest = + tools || toolConfig + ? stableStringify({ + tools, + toolConfig, + }) + : undefined; + return { + cacheConfigDigest, + tools, + toolConfig, + }; +} + +function buildManagedContextForCachedContent(context: GooglePromptCacheContext) { + if (!context.systemPrompt && !context.tools?.length) { return context; } return { ...context, systemPrompt: undefined, + tools: undefined, }; } @@ -229,6 +301,8 @@ async function createGooglePromptCache(params: { modelId: string; signal?: AbortSignal; systemPrompt: string; + tools?: unknown; + toolConfig?: unknown; }): Promise<{ cachedContent: string; expireTime?: string } | null> { const response = await params.fetchImpl(`${params.baseUrl}/cachedContents`, { method: "POST", @@ -239,6 +313,8 @@ async function createGooglePromptCache(params: { systemInstruction: { parts: [{ text: params.systemPrompt }], }, + ...(params.tools ? { tools: params.tools } : {}), + ...(params.toolConfig ? { toolConfig: params.toolConfig } : {}), }), signal: params.signal, }); @@ -256,9 +332,12 @@ async function ensureGooglePromptCache( cacheRetention: CacheRetention; model: GooglePromptCacheModel; provider: string; + cacheConfigDigest?: string; sessionManager: GooglePromptCacheSessionManager; signal?: AbortSignal; systemPrompt: string; + tools?: unknown; + toolConfig?: unknown; }, deps: GooglePromptCacheDeps, ): Promise { @@ -271,6 +350,7 @@ async function ensureGooglePromptCache( modelApi: params.model.api, baseUrl, systemPromptDigest, + cacheConfigDigest: params.cacheConfigDigest, }); const latestEntry = readLatestGooglePromptCacheEntry(params.sessionManager, matchKey); @@ -306,6 +386,7 @@ async function ensureGooglePromptCache( modelApi: params.model.api, baseUrl, systemPromptDigest, + cacheConfigDigest: params.cacheConfigDigest, cacheRetention: params.cacheRetention, cachedContent: latestEntry.cachedContent, expireTime: refreshed.expireTime ?? latestEntry.expireTime, @@ -325,6 +406,8 @@ async function ensureGooglePromptCache( modelId: params.model.id, signal: params.signal, systemPrompt: params.systemPrompt, + tools: params.tools, + toolConfig: params.toolConfig, }); if (!created) { await appendGooglePromptCacheEntry(params.sessionManager, { @@ -335,6 +418,7 @@ async function ensureGooglePromptCache( modelApi: params.model.api, baseUrl, systemPromptDigest, + cacheConfigDigest: params.cacheConfigDigest, cacheRetention: params.cacheRetention, retryAfter: now + GOOGLE_PROMPT_CACHE_RETRY_BACKOFF_MS, }); @@ -349,6 +433,7 @@ async function ensureGooglePromptCache( modelApi: params.model.api, baseUrl, systemPromptDigest, + cacheConfigDigest: params.cacheConfigDigest, cacheRetention: params.cacheRetention, cachedContent: created.cachedContent, expireTime: created.expireTime, @@ -386,15 +471,19 @@ export async function prepareGooglePromptCacheStreamFn( const inner = params.streamFn; return async (model, context, options) => { + const cacheConfig = buildManagedGooglePromptCacheConfig(context, options); const cachedContent = await ensureGooglePromptCache( { apiKey, + cacheConfigDigest: cacheConfig.cacheConfigDigest, cacheRetention: resolvedRetention, model: params.model, provider: params.provider, sessionManager: params.sessionManager, signal: params.signal, systemPrompt, + tools: cacheConfig.tools, + toolConfig: cacheConfig.toolConfig, }, deps, ); @@ -408,7 +497,7 @@ export async function prepareGooglePromptCacheStreamFn( return streamWithPayloadPatch( inner, model, - buildManagedContextWithoutSystemPrompt(context), + buildManagedContextForCachedContent(context), options, (payload) => { payload.cachedContent = cachedContent;