fix(ollama): normalize greedy top_p (#87049)

This commit is contained in:
Vincent Koc
2026-05-27 02:41:30 +01:00
committed by GitHub
parent 1d2bf82461
commit dfadc7b704
3 changed files with 125 additions and 1 deletion

View File

@@ -258,7 +258,7 @@ describe.skipIf(!LIVE)("ollama live", () => {
expect(events.map((event) => (event as { type?: string }).type)).toContain("done");
expect(payload?.model).toBe(CHAT_MODEL);
expect(payload?.options?.num_ctx).toBe(4096);
expect(payload?.options?.top_p).toBe(0.9);
expect(payload?.options?.top_p).toBe(1);
expect(payload?.think).toBe(false);
expect(payload?.keep_alive).toBe("5m");
const properties = payload?.tools?.[0]?.function?.parameters?.properties;

View File

@@ -2319,6 +2319,117 @@ describe("createOllamaStreamFn", () => {
);
});
// A configured top_p of 0.9 must be replaced by 1 in the outgoing request
// when the caller asks for temperature 0.
it("sets top_p=1 for native Ollama greedy sampling requests", async () => {
  // Only the request fields this test inspects.
  type SentChatBody = { options: { temperature?: number; top_p?: number } };

  // Mocked NDJSON reply: one content record followed by the terminal done record.
  const ndjsonLines = [
    '{"model":"m","created_at":"t","message":{"role":"assistant","content":"ok"},"done":false}',
    '{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":1,"eval_count":1}',
  ];

  await withMockNdjsonFetch(ndjsonLines, async (fetchMock) => {
    const testStream = await createOllamaTestStream({
      baseUrl: "http://ollama-host:11434",
      model: {
        params: {
          num_ctx: 4096,
          top_p: 0.9,
          thinking: false,
        },
      },
      options: { temperature: 0 },
    });

    // Drain the stream first so the request has definitely been issued.
    const streamed = await collectStreamEvents(testStream);
    expect(streamed.at(-1)?.type).toBe("done");

    const init = getGuardedFetchCall(fetchMock).init ?? {};
    if (typeof init.body !== "string") {
      throw new Error("Expected string request body");
    }

    const { options: sentOptions } = JSON.parse(init.body) as SentChatBody;
    expect(sentOptions.temperature).toBe(0);
    expect(sentOptions.top_p).toBe(1);
  });
});
// With no top_p in the model params, a temperature-0 request must still be
// sent with an explicit top_p of 1.
it("sets top_p=1 for native Ollama greedy requests without configured top_p", async () => {
  // Only the request fields this test inspects.
  type SentChatBody = { options: { temperature?: number; top_p?: number } };

  // Mocked NDJSON reply: one content record followed by the terminal done record.
  const ndjsonLines = [
    '{"model":"m","created_at":"t","message":{"role":"assistant","content":"ok"},"done":false}',
    '{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":1,"eval_count":1}',
  ];

  await withMockNdjsonFetch(ndjsonLines, async (fetchMock) => {
    const testStream = await createOllamaTestStream({
      baseUrl: "http://ollama-host:11434",
      model: {
        params: {
          num_ctx: 4096,
          thinking: false,
        },
      },
      options: { temperature: 0 },
    });

    // Drain the stream first so the request has definitely been issued.
    const streamed = await collectStreamEvents(testStream);
    expect(streamed.at(-1)?.type).toBe("done");

    const init = getGuardedFetchCall(fetchMock).init ?? {};
    if (typeof init.body !== "string") {
      throw new Error("Expected string request body");
    }

    const { options: sentOptions } = JSON.parse(init.body) as SentChatBody;
    expect(sentOptions.temperature).toBe(0);
    expect(sentOptions.top_p).toBe(1);
  });
});
// A non-zero temperature must leave the configured top_p (0.9) untouched in
// the outgoing request.
it("preserves configured top_p for native Ollama non-greedy sampling requests", async () => {
  // Only the request fields this test inspects.
  type SentChatBody = { options: { temperature?: number; top_p?: number } };

  // Mocked NDJSON reply: one content record followed by the terminal done record.
  const ndjsonLines = [
    '{"model":"m","created_at":"t","message":{"role":"assistant","content":"ok"},"done":false}',
    '{"model":"m","created_at":"t","message":{"role":"assistant","content":""},"done":true,"prompt_eval_count":1,"eval_count":1}',
  ];

  await withMockNdjsonFetch(ndjsonLines, async (fetchMock) => {
    const testStream = await createOllamaTestStream({
      baseUrl: "http://ollama-host:11434",
      model: {
        params: {
          top_p: 0.9,
        },
      },
      options: { temperature: 0.2 },
    });

    // Drain the stream first so the request has definitely been issued.
    const streamed = await collectStreamEvents(testStream);
    expect(streamed.at(-1)?.type).toBe("done");

    const init = getGuardedFetchCall(fetchMock).init ?? {};
    if (typeof init.body !== "string") {
      throw new Error("Expected string request body");
    }

    const { options: sentOptions } = JSON.parse(init.body) as SentChatBody;
    expect(sentOptions.temperature).toBe(0.2);
    expect(sentOptions.top_p).toBe(0.9);
  });
});
it("omits num_ctx when the model has no params.num_ctx and no catalog window", async () => {
await withMockNdjsonFetch(
[

View File

@@ -343,6 +343,18 @@ function resolveOllamaModelOptions(model: ProviderRuntimeModel): Record<string,
return options;
}
/**
 * Normalizes `top_p` in place for greedy requests.
 *
 * Acts only when `options.temperature` is strictly `0`. In that case `top_p`
 * is set to `1` if it is missing (`undefined`) or is a finite number other
 * than `1`. Any other `top_p` value (non-number, `null`, `NaN`, `±Infinity`)
 * is left exactly as provided.
 *
 * @param options - Mutable Ollama options bag; modified in place.
 */
function normalizeOllamaGreedySamplingOptions(options: Record<string, unknown>): void {
  const isGreedy = options.temperature === 0;
  if (!isGreedy) {
    return;
  }

  const topP = options.top_p;

  // Missing top_p: make the greedy setting explicit.
  if (topP === undefined) {
    options.top_p = 1;
    return;
  }

  // Values that are not finite numbers are passed through unchanged.
  if (typeof topP !== "number" || !Number.isFinite(topP)) {
    return;
  }

  if (topP !== 1) {
    options.top_p = 1;
  }
}
function resolveOllamaTopLevelParams(
model: ProviderRuntimeModel,
): Record<string, unknown> | undefined {
@@ -1098,6 +1110,7 @@ export function createOllamaStreamFn(
if (typeof options?.maxTokens === "number") {
ollamaOptions.num_predict = options.maxTokens;
}
normalizeOllamaGreedySamplingOptions(ollamaOptions);
const body = buildOllamaChatRequest({
modelId: model.id,