diff --git a/extensions/ollama/index.test.ts b/extensions/ollama/index.test.ts index 16efc6e59ad..43057ab8f68 100644 --- a/extensions/ollama/index.test.ts +++ b/extensions/ollama/index.test.ts @@ -19,6 +19,9 @@ const promptAndConfigureOllamaMock = vi.hoisted(() => ); const ensureOllamaModelPulledMock = vi.hoisted(() => vi.fn(async () => {})); const buildOllamaProviderMock = vi.hoisted(() => vi.fn()); +const createConfiguredOllamaStreamFnMock = vi.hoisted(() => + vi.fn((_params: { model: unknown; providerBaseUrl?: string }) => ({}) as never), +); vi.mock("./api.js", () => ({ promptAndConfigureOllama: promptAndConfigureOllamaMock, @@ -27,10 +30,19 @@ vi.mock("./api.js", () => ({ buildOllamaProvider: buildOllamaProviderMock, })); +vi.mock("./src/stream.js", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + createConfiguredOllamaStreamFn: createConfiguredOllamaStreamFnMock, + }; +}); + beforeEach(() => { promptAndConfigureOllamaMock.mockClear(); ensureOllamaModelPulledMock.mockClear(); buildOllamaProviderMock.mockReset(); + createConfiguredOllamaStreamFnMock.mockClear(); }); function registerProvider() { @@ -207,6 +219,60 @@ describe("ollama plugin", () => { ).toBeUndefined(); }); + it("routes createStreamFn to the correct provider baseUrl for ollama2", () => { + const provider = registerProvider(); + const config = { + models: { + providers: { + ollama: { + api: "ollama", + baseUrl: "http://127.0.0.1:11434", + models: [], + }, + ollama2: { + api: "ollama", + baseUrl: "http://127.0.0.1:11435", + models: [], + }, + }, + }, + }; + const model = { id: "llama3.2", provider: "ollama2", baseUrl: undefined }; + + provider.createStreamFn?.({ config, model, provider: "ollama2" } as never); + + expect(createConfiguredOllamaStreamFnMock).toHaveBeenCalledWith( + expect.objectContaining({ providerBaseUrl: "http://127.0.0.1:11435" }), + ); + }); + + it("uses ollama provider baseUrl when provider is ollama (backward compat)", () => { + const provider = registerProvider(); + const config = { + models: { + providers: { + ollama: { + api: "ollama", + baseUrl: "http://127.0.0.1:11434", + models: [], + }, + ollama2: { + api: "ollama", + baseUrl: "http://127.0.0.1:11435", + models: [], + }, + }, + }, + }; + const model = { id: "llama3.2", provider: "ollama", baseUrl: undefined }; + + provider.createStreamFn?.({ config, model, provider: "ollama" } as never); + + expect(createConfiguredOllamaStreamFnMock).toHaveBeenCalledWith( + expect.objectContaining({ providerBaseUrl: "http://127.0.0.1:11434" }), + ); + }); + it("wraps native Ollama payloads with top-level think=false when thinking is off", () => { const provider = registerProvider(); let payloadSeen: Record | undefined; diff --git a/extensions/ollama/index.ts b/extensions/ollama/index.ts index 3b915dcce77..e479183dc64 100644 --- a/extensions/ollama/index.ts +++ b/extensions/ollama/index.ts @@ -23,6 +23,7 @@ import { resolveOllamaApiBase } from "./src/provider-models.js"; import { createConfiguredOllamaCompatStreamWrapper, createConfiguredOllamaStreamFn, + resolveConfiguredOllamaProviderConfig, } from "./src/stream.js"; import { createOllamaWebSearchProvider } from "./src/web-search-provider.js"; @@ -183,10 +184,11 @@ export default definePluginEntry({ } await ensureOllamaModelPulled({ config, model, prompter }); }, - createStreamFn: ({ config, model }) => { + createStreamFn: ({ config, model, provider }) => { return createConfiguredOllamaStreamFn({ model, - providerBaseUrl: config?.models?.providers?.ollama?.baseUrl, + providerBaseUrl: resolveConfiguredOllamaProviderConfig({ config, providerId: provider }) + ?.baseUrl, }); }, ...OPENAI_COMPATIBLE_REPLAY_HOOKS, diff --git a/extensions/ollama/src/stream.ts b/extensions/ollama/src/stream.ts index d9bfb76f733..943183ee3e4 100644 --- a/extensions/ollama/src/stream.ts +++ b/extensions/ollama/src/stream.ts @@ -50,7 +50,7 @@ export function resolveOllamaBaseUrlForRun(params: { return OLLAMA_NATIVE_BASE_URL; } -function resolveConfiguredOllamaProviderConfig(params: { +export function resolveConfiguredOllamaProviderConfig(params: { config?: OpenClawConfig; providerId?: string; }) {