fix(google): cache managed tool config

This commit is contained in:
Neerav Makwana
2026-05-21 09:55:28 -04:00
parent f7be167430
commit 198a42bbc6
2 changed files with 112 additions and 5 deletions

View File

@@ -179,7 +179,7 @@ describe("google prompt cache", () => {
},
],
} as never,
{ temperature: 0.2 } as never,
{ temperature: 0.2, toolChoice: "auto" } as never,
),
);
@@ -200,11 +200,28 @@ describe("google prompt cache", () => {
systemInstruction: {
parts: [{ text: "Follow policy." }],
},
tools: [
{
functionDeclarations: [
{
name: "lookup",
description: "Look up a value",
parametersJsonSchema: { type: "object" },
},
],
},
],
toolConfig: {
functionCallingConfig: {
mode: "AUTO",
},
},
});
expect(innerStreamFn).toHaveBeenCalledTimes(1);
expect(streamContext(innerStreamFn).systemPrompt).toBeUndefined();
expect(Array.isArray(streamContext(innerStreamFn).tools)).toBe(true);
expect(streamContext(innerStreamFn).tools).toBeUndefined();
expect(streamOptions(innerStreamFn).temperature).toBe(0.2);
expect(streamOptions(innerStreamFn).toolChoice).toBe("auto");
expect(getCapturedPayload()?.cachedContent).toBe("cachedContents/system-cache-1");
expect(entries).toEqual([
{
@@ -221,6 +238,7 @@ describe("google prompt cache", () => {
modelApi: "google-generative-ai",
baseUrl: "https://generativelanguage.googleapis.com/v1beta",
systemPromptDigest,
cacheConfigDigest: expect.any(String),
cacheRetention: "long",
cachedContent: "cachedContents/system-cache-1",
expireTime,

View File

@@ -31,6 +31,8 @@ type GooglePromptCacheModel = Model<Api> & {
headers?: Record<string, string>;
provider: string;
};
type GooglePromptCacheContext = Parameters<StreamFn>[1];
type GooglePromptCacheOptions = Parameters<StreamFn>[2];
type GooglePromptCacheEntry = {
timestamp: number;
@@ -39,6 +41,7 @@ type GooglePromptCacheEntry = {
modelApi?: string | null;
baseUrl: string;
systemPromptDigest: string;
cacheConfigDigest?: string;
cacheRetention: CacheRetention;
} & (
| {
@@ -111,6 +114,7 @@ function buildGooglePromptCacheMatchKey(params: {
modelApi?: string | null;
baseUrl: string;
systemPromptDigest: string;
cacheConfigDigest?: string;
}) {
return stableStringify(params);
}
@@ -150,6 +154,8 @@ function readLatestGooglePromptCacheEntry(
: null,
baseUrl: stringifyGooglePromptCacheKeyPart(cacheData.baseUrl),
systemPromptDigest: stringifyGooglePromptCacheKeyPart(cacheData.systemPromptDigest),
cacheConfigDigest:
typeof cacheData.cacheConfigDigest === "string" ? cacheData.cacheConfigDigest : undefined,
});
if (candidateKey === matchKey) {
return data as GooglePromptCacheEntry;
@@ -183,13 +189,79 @@ function parseExpireTimeMs(expireTime: string | undefined): number | null {
return Number.isFinite(timestamp) ? timestamp : null;
}
function buildManagedContextWithoutSystemPrompt(context: Parameters<StreamFn>[1]) {
if (!context.systemPrompt) {
/**
 * Wraps the context's tool definitions in the `tools` payload shape that is
 * forwarded to the cached-content creation request.
 *
 * @param tools - Tool definitions taken from the stream context.
 * @returns A single-element array holding one `functionDeclarations` list
 *   (one declaration per tool, in input order), or `undefined` when `tools`
 *   is empty.
 */
function convertManagedGoogleTools(tools: NonNullable<GooglePromptCacheContext["tools"]>) {
  // An empty tool list yields no payload at all rather than an empty wrapper.
  if (tools.length === 0) {
    return undefined;
  }
  const functionDeclarations = tools.map(({ name, description, parameters }) => ({
    name,
    description,
    // The tool's `parameters` value is forwarded untouched under the new key.
    parametersJsonSchema: parameters,
  }));
  return [{ functionDeclarations }];
}
/**
 * Translates a caller-supplied tool choice into a function-calling config.
 *
 * Mapping:
 * - any falsy value (`undefined`, `null`, `""`, `0`, `false`) -> `undefined`
 * - object with `type === "function"` -> mode `"ANY"`, restricted to
 *   `function.name` when that name is a string
 * - `"none"` -> mode `"NONE"`
 * - `"any"` or `"required"` -> mode `"ANY"`
 * - everything else (including `"auto"` and unrecognised values) -> mode `"AUTO"`
 *
 * @param choice - Untyped tool choice value read from the stream options.
 * @returns The function-calling config, or `undefined` when no choice was given.
 */
function mapManagedGoogleToolChoice(
  choice: unknown,
): { mode: "AUTO" | "NONE" | "ANY"; allowedFunctionNames?: string[] } | undefined {
  if (!choice) {
    return undefined;
  }

  if (typeof choice === "object") {
    const candidate = choice as { type?: unknown; function?: { name?: unknown } };
    if (candidate.type !== "function") {
      // Objects of any other shape are not recognised and fall back to AUTO.
      return { mode: "AUTO" };
    }
    const requestedName = candidate.function?.name;
    if (typeof requestedName === "string") {
      return { mode: "ANY", allowedFunctionNames: [requestedName] };
    }
    return { mode: "ANY" };
  }

  if (choice === "none") {
    return { mode: "NONE" };
  }
  if (choice === "any" || choice === "required") {
    return { mode: "ANY" };
  }
  return { mode: "AUTO" };
}
/**
 * Derives the tool-related cache configuration for one stream call.
 *
 * @param context - Stream context; only `tools` is read.
 * @param options - Stream options; only `toolChoice` is read, and only when
 *   the context actually carries tools.
 * @returns `tools` and `toolConfig` as converted by the sibling helpers, plus
 *   `cacheConfigDigest`, the stable serialisation of both. All three are
 *   `undefined` when the context has no tools.
 */
function buildManagedGooglePromptCacheConfig(
  context: GooglePromptCacheContext,
  options: GooglePromptCacheOptions,
) {
  const contextTools = context.tools;
  const tools = contextTools?.length ? convertManagedGoogleTools(contextTools) : undefined;

  // A tool choice is meaningless without tools, so it is ignored in that case.
  const functionCallingConfig = tools
    ? mapManagedGoogleToolChoice((options as { toolChoice?: unknown } | undefined)?.toolChoice)
    : undefined;
  const toolConfig = functionCallingConfig ? { functionCallingConfig } : undefined;

  const hasCacheableConfig = Boolean(tools) || Boolean(toolConfig);
  const cacheConfigDigest = hasCacheableConfig
    ? stableStringify({
        tools,
        toolConfig,
      })
    : undefined;

  return { cacheConfigDigest, tools, toolConfig };
}
/**
 * Produces the context handed to the inner stream function once a cached
 * content reference is in use.
 *
 * @param context - The original stream context.
 * @returns The same object when it has neither a system prompt nor any tools;
 *   otherwise a shallow copy whose `systemPrompt` and `tools` are `undefined`.
 *   The input object is never mutated.
 */
function buildManagedContextForCachedContent(context: GooglePromptCacheContext) {
  const needsStripping = Boolean(context.systemPrompt) || Boolean(context.tools?.length);
  if (needsStripping) {
    return {
      ...context,
      systemPrompt: undefined,
      tools: undefined,
    };
  }
  return context;
}
@@ -229,6 +301,8 @@ async function createGooglePromptCache(params: {
modelId: string;
signal?: AbortSignal;
systemPrompt: string;
tools?: unknown;
toolConfig?: unknown;
}): Promise<{ cachedContent: string; expireTime?: string } | null> {
const response = await params.fetchImpl(`${params.baseUrl}/cachedContents`, {
method: "POST",
@@ -239,6 +313,8 @@ async function createGooglePromptCache(params: {
systemInstruction: {
parts: [{ text: params.systemPrompt }],
},
...(params.tools ? { tools: params.tools } : {}),
...(params.toolConfig ? { toolConfig: params.toolConfig } : {}),
}),
signal: params.signal,
});
@@ -256,9 +332,12 @@ async function ensureGooglePromptCache(
cacheRetention: CacheRetention;
model: GooglePromptCacheModel;
provider: string;
cacheConfigDigest?: string;
sessionManager: GooglePromptCacheSessionManager;
signal?: AbortSignal;
systemPrompt: string;
tools?: unknown;
toolConfig?: unknown;
},
deps: GooglePromptCacheDeps,
): Promise<string | null> {
@@ -271,6 +350,7 @@ async function ensureGooglePromptCache(
modelApi: params.model.api,
baseUrl,
systemPromptDigest,
cacheConfigDigest: params.cacheConfigDigest,
});
const latestEntry = readLatestGooglePromptCacheEntry(params.sessionManager, matchKey);
@@ -306,6 +386,7 @@ async function ensureGooglePromptCache(
modelApi: params.model.api,
baseUrl,
systemPromptDigest,
cacheConfigDigest: params.cacheConfigDigest,
cacheRetention: params.cacheRetention,
cachedContent: latestEntry.cachedContent,
expireTime: refreshed.expireTime ?? latestEntry.expireTime,
@@ -325,6 +406,8 @@ async function ensureGooglePromptCache(
modelId: params.model.id,
signal: params.signal,
systemPrompt: params.systemPrompt,
tools: params.tools,
toolConfig: params.toolConfig,
});
if (!created) {
await appendGooglePromptCacheEntry(params.sessionManager, {
@@ -335,6 +418,7 @@ async function ensureGooglePromptCache(
modelApi: params.model.api,
baseUrl,
systemPromptDigest,
cacheConfigDigest: params.cacheConfigDigest,
cacheRetention: params.cacheRetention,
retryAfter: now + GOOGLE_PROMPT_CACHE_RETRY_BACKOFF_MS,
});
@@ -349,6 +433,7 @@ async function ensureGooglePromptCache(
modelApi: params.model.api,
baseUrl,
systemPromptDigest,
cacheConfigDigest: params.cacheConfigDigest,
cacheRetention: params.cacheRetention,
cachedContent: created.cachedContent,
expireTime: created.expireTime,
@@ -386,15 +471,19 @@ export async function prepareGooglePromptCacheStreamFn(
const inner = params.streamFn;
return async (model, context, options) => {
const cacheConfig = buildManagedGooglePromptCacheConfig(context, options);
const cachedContent = await ensureGooglePromptCache(
{
apiKey,
cacheConfigDigest: cacheConfig.cacheConfigDigest,
cacheRetention: resolvedRetention,
model: params.model,
provider: params.provider,
sessionManager: params.sessionManager,
signal: params.signal,
systemPrompt,
tools: cacheConfig.tools,
toolConfig: cacheConfig.toolConfig,
},
deps,
);
@@ -408,7 +497,7 @@ export async function prepareGooglePromptCacheStreamFn(
return streamWithPayloadPatch(
inner,
model,
buildManagedContextWithoutSystemPrompt(context),
buildManagedContextForCachedContent(context),
options,
(payload) => {
payload.cachedContent = cachedContent;