fix(memory): bound embedding batch poll intervals

This commit is contained in:
Peter Steinberger
2026-05-30 18:44:54 -04:00
parent 68b5371fca
commit 33c44626d2
5 changed files with 101 additions and 12 deletions

View File

@@ -284,7 +284,7 @@ export async function runGeminiEmbeddingBatches(
maxRequests: GEMINI_BATCH_MAX_REQUESTS,
debugLabel: "memory embeddings: gemini batch submit",
}),
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
runGroup: async ({ group, groupIndex, groups, byCustomId, pollIntervalMs, timeoutMs }) => {
const batchInfo = await submitGeminiBatch({
gemini: params.gemini,
requests: group,
@@ -326,8 +326,8 @@ export async function runGeminiEmbeddingBatches(
gemini: params.gemini,
batchName,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
timeoutMs: params.timeoutMs,
pollIntervalMs,
timeoutMs,
debug: params.debug,
initial: batchInfo,
});

View File

@@ -211,7 +211,7 @@ export async function runOpenAiEmbeddingBatches(
maxRequests: OPENAI_BATCH_MAX_REQUESTS,
debugLabel: "memory embeddings: openai batch submit",
}),
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
runGroup: async ({ group, groupIndex, groups, byCustomId, pollIntervalMs, timeoutMs }) => {
const batchInfo = await submitOpenAiBatch({
openAi: params.openAi,
requests: group,
@@ -239,8 +239,8 @@ export async function runOpenAiEmbeddingBatches(
openAi: params.openAi,
batchId,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
timeoutMs: params.timeoutMs,
pollIntervalMs,
timeoutMs,
debug: params.debug,
initial: batchInfo,
}),

View File

@@ -228,7 +228,7 @@ export async function runVoyageEmbeddingBatches(
maxRequests: VOYAGE_BATCH_MAX_REQUESTS,
debugLabel: "memory embeddings: voyage batch submit",
}),
runGroup: async ({ group, groupIndex, groups, byCustomId }) => {
runGroup: async ({ group, groupIndex, groups, byCustomId, pollIntervalMs, timeoutMs }) => {
const batchInfo = await submitVoyageBatch({
client: params.client,
requests: group,
@@ -257,8 +257,8 @@ export async function runVoyageEmbeddingBatches(
client: params.client,
batchId,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
timeoutMs: params.timeoutMs,
pollIntervalMs,
timeoutMs,
debug: params.debug,
initial: batchInfo,
deps,

View File

@@ -0,0 +1,63 @@
import { describe, expect, it, vi } from "vitest";
import { MAX_SAFE_TIMEOUT_DELAY_MS } from "../../../gateway-client/src/timeouts.js";
import { buildEmbeddingBatchGroupOptions, runEmbeddingBatchGroups } from "./batch-runner.js";
// Covers how the embedding batch runner bounds the poll interval it hands out.
describe("buildEmbeddingBatchGroupOptions", () => {
  // A poll interval far larger than any timeout budget used below.
  const OVERSIZED_MS = Number.MAX_SAFE_INTEGER;
  const groupOptions = { maxRequests: 100, debugLabel: "embedding batch submit" };

  // Execution params with an oversized poll interval; only the timeout varies per test.
  const executionParams = (timeoutMs: number) => ({
    requests: ["request-1"],
    wait: true,
    pollIntervalMs: OVERSIZED_MS,
    timeoutMs,
    concurrency: 1,
  });

  it("clamps oversized embedding batch poll intervals to the timeout budget", () => {
    const { pollIntervalMs } = buildEmbeddingBatchGroupOptions(
      executionParams(60_000),
      groupOptions,
    );

    expect(pollIntervalMs).toBe(60_000);
  });

  it("passes clamped poll intervals into batch group runners", async () => {
    const runGroup = vi.fn(async () => {});

    await runEmbeddingBatchGroups({
      ...executionParams(60_000),
      ...groupOptions,
      runGroup,
    });

    // The runner must receive the bounded interval together with the original timeout.
    const expectedArgs = expect.objectContaining({
      pollIntervalMs: 60_000,
      timeoutMs: 60_000,
    });
    expect(runGroup).toHaveBeenCalledWith(expectedArgs);
  });

  it("keeps timeout-safe oversized embedding batch poll intervals bounded", () => {
    // Both values are oversized here, so the expected bound is the shared
    // MAX_SAFE_TIMEOUT_DELAY_MS constant rather than the timeout itself.
    const { pollIntervalMs } = buildEmbeddingBatchGroupOptions(
      executionParams(OVERSIZED_MS),
      groupOptions,
    );

    expect(pollIntervalMs).toBe(MAX_SAFE_TIMEOUT_DELAY_MS);
  });
});

View File

@@ -1,3 +1,4 @@
import { resolveSafeTimeoutDelayMs } from "../../../gateway-client/src/timeouts.js";
import { splitBatchRequests } from "./batch-utils.js";
import { runWithConcurrency } from "./internal.js";
@@ -9,6 +10,20 @@ export type EmbeddingBatchExecutionParams = {
debug?: (message: string, data?: Record<string, unknown>) => void;
};
/**
 * Resolves the poll interval to use while waiting on an embedding batch.
 *
 * The requested interval is normalized with `resolveSafeTimeoutDelayMs`. When
 * `timeoutMs` is a positive finite number it is normalized the same way and
 * the smaller of the two values is returned, so the interval never exceeds
 * the timeout. Otherwise only the normalized poll interval is returned.
 *
 * @param params.pollIntervalMs Requested delay between polls, in milliseconds.
 * @param params.timeoutMs Overall wait budget, in milliseconds.
 * @returns The bounded poll interval, in milliseconds.
 */
function resolveEmbeddingBatchPollIntervalMs(params: {
  pollIntervalMs: number;
  timeoutMs: number;
}): number {
  const { pollIntervalMs, timeoutMs } = params;
  const boundedPollMs = resolveSafeTimeoutDelayMs(pollIntervalMs);

  // Number.isFinite never coerces: NaN, +/-Infinity and non-number values all
  // yield false, so no separate typeof check is needed.
  const hasUsableTimeout = Number.isFinite(timeoutMs) && timeoutMs > 0;
  if (!hasUsableTimeout) {
    return boundedPollMs;
  }

  return Math.min(boundedPollMs, resolveSafeTimeoutDelayMs(timeoutMs));
}
export async function runEmbeddingBatchGroups<TRequest>(params: {
requests: TRequest[];
maxRequests: number;
@@ -23,6 +38,8 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
groupIndex: number;
groups: number;
byCustomId: Map<string, number[]>;
pollIntervalMs: number;
timeoutMs: number;
}) => Promise<void>;
}): Promise<Map<string, number[]>> {
if (params.requests.length === 0) {
@@ -30,8 +47,16 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
}
const groups = splitBatchRequests(params.requests, params.maxRequests);
const byCustomId = new Map<string, number[]>();
const pollIntervalMs = resolveEmbeddingBatchPollIntervalMs(params);
const tasks = groups.map((group, groupIndex) => async () => {
await params.runGroup({ group, groupIndex, groups: groups.length, byCustomId });
await params.runGroup({
group,
groupIndex,
groups: groups.length,
byCustomId,
pollIntervalMs,
timeoutMs: params.timeoutMs,
});
});
params.debug?.(params.debugLabel, {
@@ -39,7 +64,7 @@ export async function runEmbeddingBatchGroups<TRequest>(params: {
groups: groups.length,
wait: params.wait,
concurrency: params.concurrency,
pollIntervalMs: params.pollIntervalMs,
pollIntervalMs,
timeoutMs: params.timeoutMs,
});
@@ -51,11 +76,12 @@ export function buildEmbeddingBatchGroupOptions<TRequest>(
params: { requests: TRequest[] } & EmbeddingBatchExecutionParams,
options: { maxRequests: number; debugLabel: string },
) {
const pollIntervalMs = resolveEmbeddingBatchPollIntervalMs(params);
return {
requests: params.requests,
maxRequests: options.maxRequests,
wait: params.wait,
pollIntervalMs: params.pollIntervalMs,
pollIntervalMs,
timeoutMs: params.timeoutMs,
concurrency: params.concurrency,
debug: params.debug,