diff --git a/extensions/ollama/src/setup.test.ts b/extensions/ollama/src/setup.test.ts
index 8a7a4d3b4ae..136c4ef8308 100644
--- a/extensions/ollama/src/setup.test.ts
+++ b/extensions/ollama/src/setup.test.ts
@@ -398,28 +398,77 @@ describe("ollama setup", () => {
 
   describe("ensureOllamaModelPulled", () => {
     it("pulls model when not available locally", async () => {
-      const progress = { update: vi.fn(), stop: vi.fn() };
-      const prompter = {
-        progress: vi.fn(() => progress),
-      } as unknown as WizardPrompter;
+      vi.useFakeTimers();
+      try {
+        const progress = { update: vi.fn(), stop: vi.fn() };
+        const prompter = {
+          progress: vi.fn(() => progress),
+        } as unknown as WizardPrompter;
 
-      const fetchMock = createOllamaFetchMock({
-        tags: ["llama3:8b"],
-        pullResponse: new Response('{"status":"success"}\n', { status: 200 }),
-      });
-      vi.stubGlobal("fetch", fetchMock);
+        const fetchMock = createOllamaFetchMock({
+          tags: ["llama3:8b"],
+          pullResponse: new Response('{"status":"success"}\n', { status: 200 }),
+        });
+        vi.stubGlobal("fetch", fetchMock);
 
-      await ensureOllamaModelPulled({
-        config: createDefaultOllamaConfig("ollama/gemma4"),
-        model: "ollama/gemma4",
-        prompter,
-      });
+        await ensureOllamaModelPulled({
+          config: createDefaultOllamaConfig("ollama/gemma4"),
+          model: "ollama/gemma4",
+          prompter,
+        });
 
-      expect(fetchMock).toHaveBeenCalledTimes(2);
-      expect(fetchMock.mock.calls[1][0]).toContain("/api/pull");
-      const pullInit = fetchMock.mock.calls[1][1];
-      expect(pullInit?.signal).toBeInstanceOf(AbortSignal);
-      expect(pullInit?.signal?.aborted).toBe(false);
+        expect(fetchMock).toHaveBeenCalledTimes(2);
+        expect(fetchMock.mock.calls[1][0]).toContain("/api/pull");
+        const pullInit = fetchMock.mock.calls[1][1];
+        expect(pullInit?.signal).toBeInstanceOf(AbortSignal);
+        expect(pullInit?.signal?.aborted).toBe(false);
+
+        await vi.advanceTimersByTimeAsync(30_000);
+        expect(pullInit?.signal?.aborted).toBe(false);
+      } finally {
+        vi.useRealTimers();
+      }
     });
+
+    it("fails stalled model pull streams after an idle timeout", async () => {
+      vi.useFakeTimers();
+      try {
+        const progress = { update: vi.fn(), stop: vi.fn() };
+        const prompter = {
+          progress: vi.fn(() => progress),
+        } as unknown as WizardPrompter;
+        const fetchMock = vi.fn(async (input: string | URL | Request) => {
+          const url = requestUrl(input);
+          if (url.endsWith("/api/tags")) {
+            return jsonResponse({ models: [] });
+          }
+          if (url.endsWith("/api/pull")) {
+            return new Response(new ReadableStream(), { status: 200 });
+          }
+          throw new Error(`Unexpected fetch: ${url}`);
+        });
+        vi.stubGlobal("fetch", fetchMock);
+
+        const pullPromise = ensureOllamaModelPulled({
+          config: createDefaultOllamaConfig("ollama/gemma4"),
+          model: "ollama/gemma4",
+          prompter,
+        }).catch((err: unknown) => err);
+
+        for (let attempts = 0; attempts < 50 && fetchMock.mock.calls.length < 2; attempts += 1) {
+          await vi.advanceTimersByTimeAsync(0);
+          await Promise.resolve();
+        }
+        expect(fetchMock.mock.calls[1]?.[0]).toContain("/api/pull");
+
+        await vi.advanceTimersByTimeAsync(300_000);
+        await expect(pullPromise).resolves.toEqual(
+          expect.objectContaining({ message: "Failed to download selected Ollama model" }),
+        );
+        expect(progress.stop).toHaveBeenCalledWith(expect.stringContaining("Ollama pull stalled"));
+      } finally {
+        vi.useRealTimers();
+      }
     });
 
     it("skips pull when model is already available", async () => {
diff --git a/extensions/ollama/src/setup.ts b/extensions/ollama/src/setup.ts
index 2c44fdad742..362d3292fce 100644
--- a/extensions/ollama/src/setup.ts
+++ b/extensions/ollama/src/setup.ts
@@ -42,7 +42,8 @@ const OLLAMA_SUGGESTED_MODELS_LOCAL = [OLLAMA_DEFAULT_MODEL];
 const OLLAMA_SUGGESTED_MODELS_CLOUD = ["kimi-k2.5:cloud", "minimax-m2.7:cloud", "glm-5.1:cloud"];
 const OLLAMA_CONTEXT_ENRICH_LIMIT = 200;
 const OLLAMA_CLOUD_MAX_DISCOVERED_MODELS = 500;
-const OLLAMA_PULL_REQUEST_TIMEOUT_MS = 30_000;
+const OLLAMA_PULL_RESPONSE_TIMEOUT_MS = 30_000;
+const OLLAMA_PULL_STREAM_IDLE_TIMEOUT_MS = 300_000;
 
 type OllamaSetupOptions = {
   customBaseUrl?: string;
@@ -158,6 +159,48 @@ type OllamaPullChunk = {
 
 type OllamaPullResult = { ok: true } | { ok: false; message: string };
 
+async function readOllamaPullChunkWithIdleTimeout(
+  reader: ReadableStreamDefaultReader<Uint8Array>,
+): Promise<ReadableStreamReadResult<Uint8Array>> {
+  let timeoutId: ReturnType<typeof setTimeout> | undefined;
+  let timedOut = false;
+
+  return await new Promise<ReadableStreamReadResult<Uint8Array>>((resolve, reject) => {
+    const clear = () => {
+      if (timeoutId !== undefined) {
+        clearTimeout(timeoutId);
+        timeoutId = undefined;
+      }
+    };
+
+    timeoutId = setTimeout(() => {
+      timedOut = true;
+      clear();
+      void reader.cancel().catch(() => undefined);
+      reject(
+        new Error(
+          `Ollama pull stalled: no data received for ${Math.round(OLLAMA_PULL_STREAM_IDLE_TIMEOUT_MS / 1000)}s`,
+        ),
+      );
+    }, OLLAMA_PULL_STREAM_IDLE_TIMEOUT_MS);
+
+    void reader.read().then(
+      (result) => {
+        clear();
+        if (!timedOut) {
+          resolve(result);
+        }
+      },
+      (err) => {
+        clear();
+        if (!timedOut) {
+          reject(err);
+        }
+      },
+    );
+  });
+}
+
 async function pullOllamaModelCore(params: {
   baseUrl: string;
   modelName: string;
@@ -165,6 +208,11 @@ async function pullOllamaModelCore(params: {
 }): Promise<OllamaPullResult> {
   const baseUrl = resolveOllamaApiBase(params.baseUrl);
   const modelName = normalizeOllamaModelName(params.modelName) ?? params.modelName.trim();
+  const responseController = new AbortController();
+  const responseTimeout = setTimeout(
+    responseController.abort.bind(responseController),
+    OLLAMA_PULL_RESPONSE_TIMEOUT_MS,
+  );
   try {
     const { response, release } = await fetchWithSsrFGuard({
       url: `${baseUrl}/api/pull`,
@@ -173,10 +221,11 @@
       method: "POST",
       headers: { "Content-Type": "application/json" },
       body: JSON.stringify({ name: modelName }),
-      timeoutMs: OLLAMA_PULL_REQUEST_TIMEOUT_MS,
+      signal: responseController.signal,
      policy: buildOllamaBaseUrlSsrFPolicy(baseUrl),
      auditContext: "ollama-setup.pull",
    });
+    clearTimeout(responseTimeout);
    try {
      if (!response.ok) {
        return { ok: false, message: `Failed to download ${modelName} (HTTP ${response.status})` };
@@ -225,7 +274,7 @@
      };
 
      for (;;) {
-        const { done, value } = await reader.read();
+        const { done, value } = await readOllamaPullChunkWithIdleTimeout(reader);
        if (done) {
          break;
        }
@@ -255,6 +304,8 @@
  } catch (err) {
    const reason = formatErrorMessage(err);
    return { ok: false, message: `Failed to download ${modelName}: ${reason}` };
+  } finally {
+    clearTimeout(responseTimeout);
  }
 }
 