diff --git a/extensions/memory-core/src/memory/manager-provider-state.ts b/extensions/memory-core/src/memory/manager-provider-state.ts new file mode 100644 index 00000000000..2aa01a6ad28 --- /dev/null +++ b/extensions/memory-core/src/memory/manager-provider-state.ts @@ -0,0 +1,79 @@ +import type { + OpenClawConfig, + ResolvedMemorySearchConfig, +} from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; +import { + resolveEmbeddingProviderFallbackModel, + type EmbeddingProvider, + type EmbeddingProviderResult, + type EmbeddingProviderRuntime, +} from "./embeddings.js"; + +export type MemoryResolvedProviderState = { + provider: EmbeddingProvider | null; + fallbackFrom?: string; + fallbackReason?: string; + providerUnavailableReason?: string; + providerRuntime?: EmbeddingProviderRuntime; +}; + +export function resolveMemoryProviderState( + result: Pick< + EmbeddingProviderResult, + "provider" | "fallbackFrom" | "fallbackReason" | "providerUnavailableReason" | "runtime" + >, +): MemoryResolvedProviderState { + return { + provider: result.provider, + fallbackFrom: result.fallbackFrom, + fallbackReason: result.fallbackReason, + providerUnavailableReason: result.providerUnavailableReason, + providerRuntime: result.runtime, + }; +} + +export function applyMemoryFallbackProviderState(params: { + current: MemoryResolvedProviderState; + fallbackFrom: string; + reason: string; + result: Pick; +}): MemoryResolvedProviderState { + return { + ...params.current, + fallbackFrom: params.fallbackFrom, + fallbackReason: params.reason, + provider: params.result.provider, + providerRuntime: params.result.runtime, + }; +} + +export function resolveMemoryFallbackProviderRequest(params: { + cfg: OpenClawConfig; + settings: ResolvedMemorySearchConfig; + currentProviderId: string | null; +}): { + provider: string; + model: string; + remote: ResolvedMemorySearchConfig["remote"]; + outputDimensionality: ResolvedMemorySearchConfig["outputDimensionality"]; + fallback: "none"; + local: ResolvedMemorySearchConfig["local"]; +} | null { + const fallback = params.settings.fallback; + if ( + !fallback || + fallback === "none" || + !params.currentProviderId || + fallback === params.currentProviderId + ) { + return null; + } + return { + provider: fallback, + model: resolveEmbeddingProviderFallbackModel(fallback, params.settings.model, params.cfg), + remote: params.settings.remote, + outputDimensionality: params.settings.outputDimensionality, + fallback: "none", + local: params.settings.local, + }; +} diff --git a/extensions/memory-core/src/memory/manager-sync-ops.ts b/extensions/memory-core/src/memory/manager-sync-ops.ts index 05bdbbd09b5..31c22b8cda2 100644 --- a/extensions/memory-core/src/memory/manager-sync-ops.ts +++ b/extensions/memory-core/src/memory/manager-sync-ops.ts @@ -41,9 +41,12 @@ import { type EmbeddingProvider, type EmbeddingProviderId, type EmbeddingProviderRuntime, - resolveEmbeddingProviderFallbackModel, } from "./embeddings.js"; import { openMemoryDatabaseAtPath } from "./manager-db.js"; +import { + applyMemoryFallbackProviderState, + resolveMemoryFallbackProviderRequest, +} from "./manager-provider-state.js"; import { resolveConfiguredScopeHash, resolveConfiguredSourcesForMeta, @@ -1067,8 +1070,12 @@ export abstract class MemoryManagerSyncOps { } private async activateFallbackProvider(reason: string): Promise { - const fallback = this.settings.fallback; - if (!fallback || fallback === "none" || !this.provider || fallback === this.provider.id) { + const fallbackRequest = resolveMemoryFallbackProviderRequest({ + cfg: this.cfg, + settings: this.settings, + currentProviderId: this.provider?.id ?? null, + }); + if (!fallbackRequest || !this.provider) { return false; } if (this.fallbackFrom) { @@ -1076,30 +1083,33 @@ export abstract class MemoryManagerSyncOps { } const fallbackFrom = this.provider.id; - const fallbackModel = resolveEmbeddingProviderFallbackModel( - fallback, - this.settings.model, - this.cfg, - ); - const fallbackResult = await createEmbeddingProvider({ config: this.cfg, agentDir: resolveAgentDir(this.cfg, this.agentId), - provider: fallback, - remote: this.settings.remote, - model: fallbackModel, - outputDimensionality: this.settings.outputDimensionality, - fallback: "none", - local: this.settings.local, + ...fallbackRequest, }); - this.fallbackFrom = fallbackFrom; - this.fallbackReason = reason; - this.provider = fallbackResult.provider; - this.providerRuntime = fallbackResult.runtime; + const fallbackState = applyMemoryFallbackProviderState({ + current: { + provider: this.provider, + fallbackFrom: this.fallbackFrom, + fallbackReason: this.fallbackReason, + providerUnavailableReason: undefined, + providerRuntime: this.providerRuntime, + }, + fallbackFrom, + reason, + result: fallbackResult, + }); + this.fallbackFrom = fallbackState.fallbackFrom; + this.fallbackReason = fallbackState.fallbackReason; + this.provider = fallbackState.provider; + this.providerRuntime = fallbackState.providerRuntime; this.providerKey = this.computeProviderKey(); this.batch = this.resolveBatchConfig(); - log.warn(`memory embeddings: switched to fallback provider (${fallback})`, { reason }); + log.warn(`memory embeddings: switched to fallback provider (${fallbackRequest.provider})`, { + reason, + }); return true; } diff --git a/extensions/memory-core/src/memory/manager.mistral-provider.test.ts b/extensions/memory-core/src/memory/manager.mistral-provider.test.ts index 862654b3eb8..1554b7718e6 100644 --- a/extensions/memory-core/src/memory/manager.mistral-provider.test.ts +++ b/extensions/memory-core/src/memory/manager.mistral-provider.test.ts @@ -1,15 +1,21 @@ -import fs from "node:fs/promises"; -import os from "node:os"; -import path from "node:path"; -import type { OpenClawConfig } from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; -import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; -import type { MemoryIndexManager } from "./index.js"; -type MemoryIndexModule = typeof import("./index.js"); -type MemoryEmbeddingProvidersModule = - typeof import("../../../../src/plugins/memory-embedding-providers.js"); +import type { + OpenClawConfig, + ResolvedMemorySearchConfig, +} from "openclaw/plugin-sdk/memory-core-host-engine-foundation"; +import { describe, expect, it, vi } from "vitest"; +import { + applyMemoryFallbackProviderState, + resolveMemoryFallbackProviderRequest, + resolveMemoryProviderState, +} from "./manager-provider-state.js"; const DEFAULT_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text"; +vi.mock("./embeddings.js", () => ({ + resolveEmbeddingProviderFallbackModel: (providerId: string, fallbackSourceModel: string) => + providerId === "ollama" ? DEFAULT_OLLAMA_EMBEDDING_MODEL : fallbackSourceModel, +})); + type EmbeddingProvider = { id: string; model: string; @@ -22,40 +28,6 @@ type EmbeddingProviderRuntime = { cacheKeyData: { provider: string; model: string }; }; -type EmbeddingProviderResult = { - requestedProvider: string; - provider: EmbeddingProvider | null; - fallbackFrom?: string; - fallbackReason?: string; - providerUnavailableReason?: string; - runtime?: EmbeddingProviderRuntime; -}; - -const { createEmbeddingProviderMock } = vi.hoisted(() => ({ - createEmbeddingProviderMock: vi.fn(), -})); - -vi.mock("./embeddings.js", () => ({ - createEmbeddingProvider: createEmbeddingProviderMock, - resolveEmbeddingProviderFallbackModel: (providerId: string, fallbackSourceModel: string) => - providerId === "ollama" ? DEFAULT_OLLAMA_EMBEDDING_MODEL : fallbackSourceModel, -})); - -vi.mock("./sqlite-vec.js", () => ({ - loadSqliteVecExtension: async () => ({ ok: false, error: "sqlite-vec disabled in tests" }), -})); - -let getMemorySearchManager: MemoryIndexModule["getMemorySearchManager"]; -let closeAllMemorySearchManagers: MemoryIndexModule["closeAllMemorySearchManagers"]; - -async function ensureProviderInitialized(manager: MemoryIndexManager): Promise { - await ( - manager as unknown as { - ensureProviderInitialized: () => Promise; - } - ).ensureProviderInitialized(); -} - function createProvider(id: string): EmbeddingProvider { return { id, @@ -65,121 +37,40 @@ function createProvider(id: string): EmbeddingProvider { }; } -function buildConfig(params: { - workspaceDir: string; - indexPath: string; +function createSettings(params: { provider: "openai" | "mistral"; fallback?: "none" | "mistral" | "ollama"; -}): OpenClawConfig { +}): ResolvedMemorySearchConfig { return { - agents: { - defaults: { - workspace: params.workspaceDir, - memorySearch: { - provider: params.provider, - model: params.provider === "mistral" ? "mistral/mistral-embed" : "text-embedding-3-small", - fallback: params.fallback ?? "none", - store: { path: params.indexPath, vector: { enabled: false } }, - sync: { watch: false, onSessionStart: false, onSearch: false }, - query: { minScore: 0, hybrid: { enabled: false } }, - }, - }, - list: [{ id: "main", default: true }], - }, - } as OpenClawConfig; + provider: params.provider, + model: params.provider === "mistral" ? "mistral/mistral-embed" : "text-embedding-3-small", + fallback: params.fallback ?? "none", + remote: undefined, + outputDimensionality: undefined, + local: undefined, + } as unknown as ResolvedMemorySearchConfig; } describe("memory manager mistral provider wiring", () => { - let workspaceDir = ""; - let indexPath = ""; - let manager: MemoryIndexManager | null = null; - let clearRegistry: MemoryEmbeddingProvidersModule["clearMemoryEmbeddingProviders"]; - let registerAdapter: MemoryEmbeddingProvidersModule["registerMemoryEmbeddingProvider"]; - - beforeAll(async () => { - vi.resetModules(); - ({ getMemorySearchManager, closeAllMemorySearchManagers } = await import("./index.js")); - ({ - clearMemoryEmbeddingProviders: clearRegistry, - registerMemoryEmbeddingProvider: registerAdapter, - } = await import("../../../../src/plugins/memory-embedding-providers.js")); - }); - - beforeEach(async () => { - vi.clearAllMocks(); - createEmbeddingProviderMock.mockReset(); - clearRegistry(); - registerAdapter({ - id: "openai", - defaultModel: "text-embedding-3-small", - transport: "remote", - create: async () => ({ provider: null }), - }); - registerAdapter({ - id: "mistral", - defaultModel: "mistral-embed", - transport: "remote", - create: async () => ({ provider: null }), - }); - registerAdapter({ - id: "ollama", - defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL, - transport: "remote", - create: async () => ({ provider: null }), - }); - workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-memory-mistral-")); - indexPath = path.join(workspaceDir, "index.sqlite"); - await fs.mkdir(path.join(workspaceDir, "memory"), { recursive: true }); - await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "test"); - }); - - afterEach(async () => { - if (manager) { - await manager.close(); - manager = null; - } - await closeAllMemorySearchManagers(); - clearRegistry(); - if (workspaceDir) { - await fs.rm(workspaceDir, { recursive: true, force: true }); - workspaceDir = ""; - indexPath = ""; - } - }); - - afterAll(() => { - vi.resetModules(); - }); - - it("stores mistral client when mistral provider is selected", async () => { + it("stores mistral client when mistral provider is selected", () => { const mistralRuntime: EmbeddingProviderRuntime = { id: "mistral", cacheKeyData: { provider: "mistral", model: "mistral-embed" }, }; - const providerResult: EmbeddingProviderResult = { - requestedProvider: "mistral", + + const state = resolveMemoryProviderState({ provider: createProvider("mistral"), runtime: mistralRuntime, - }; - createEmbeddingProviderMock.mockResolvedValueOnce(providerResult); + fallbackFrom: undefined, + fallbackReason: undefined, + providerUnavailableReason: undefined, + }); - const cfg = buildConfig({ workspaceDir, indexPath, provider: "mistral" }); - const result = await getMemorySearchManager({ cfg, agentId: "main" }); - if (!result.manager) { - throw new Error(`manager missing: ${result.error ?? "no error provided"}`); - } - manager = result.manager as unknown as MemoryIndexManager; - await ensureProviderInitialized(manager); - - const internal = manager as unknown as { - ensureProviderInitialized: () => Promise; - providerRuntime?: EmbeddingProviderRuntime; - }; - await internal.ensureProviderInitialized(); - expect(internal.providerRuntime).toBe(mistralRuntime); + expect(state.provider?.id).toBe("mistral"); + expect(state.providerRuntime).toBe(mistralRuntime); }); - it("stores mistral client after fallback activation", async () => { + it("stores mistral client after fallback activation", () => { const openAiRuntime: EmbeddingProviderRuntime = { id: "openai", cacheKeyData: { provider: "openai", model: "text-embedding-3-small" }, @@ -188,80 +79,39 @@ describe("memory manager mistral provider wiring", () => { id: "mistral", cacheKeyData: { provider: "mistral", model: "mistral-embed" }, }; - createEmbeddingProviderMock.mockResolvedValueOnce({ - requestedProvider: "openai", + const current = resolveMemoryProviderState({ provider: createProvider("openai"), runtime: openAiRuntime, - } as EmbeddingProviderResult); - createEmbeddingProviderMock.mockResolvedValueOnce({ - requestedProvider: "mistral", - provider: createProvider("mistral"), - runtime: mistralRuntime, - } as EmbeddingProviderResult); + fallbackFrom: undefined, + fallbackReason: undefined, + providerUnavailableReason: undefined, + }); - const cfg = buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "mistral" }); - const result = await getMemorySearchManager({ cfg, agentId: "main" }); - if (!result.manager) { - throw new Error(`manager missing: ${result.error ?? "no error provided"}`); - } - manager = result.manager as unknown as MemoryIndexManager; - await ensureProviderInitialized(manager); - const internal = manager as unknown as { - ensureProviderInitialized: () => Promise; - activateFallbackProvider: (reason: string) => Promise; - providerRuntime?: EmbeddingProviderRuntime; - }; + const fallbackState = applyMemoryFallbackProviderState({ + current, + fallbackFrom: "openai", + reason: "forced test", + result: { + provider: createProvider("mistral"), + runtime: mistralRuntime, + }, + }); - await internal.ensureProviderInitialized(); - expect(internal.providerRuntime?.id).toBe("openai"); - const activated = await internal.activateFallbackProvider("forced test"); - expect(activated).toBe(true); - expect(internal.providerRuntime).toBe(mistralRuntime); + expect(fallbackState.fallbackFrom).toBe("openai"); + expect(fallbackState.fallbackReason).toBe("forced test"); + expect(fallbackState.provider?.id).toBe("mistral"); + expect(fallbackState.providerRuntime).toBe(mistralRuntime); }); - it("uses default ollama model when activating ollama fallback", async () => { - const openAiRuntime: EmbeddingProviderRuntime = { - id: "openai", - cacheKeyData: { provider: "openai", model: "text-embedding-3-small" }, - }; - const ollamaRuntime: EmbeddingProviderRuntime = { - id: "ollama", - cacheKeyData: { provider: "ollama", model: DEFAULT_OLLAMA_EMBEDDING_MODEL }, - }; - createEmbeddingProviderMock.mockResolvedValueOnce({ - requestedProvider: "openai", - provider: createProvider("openai"), - runtime: openAiRuntime, - } as EmbeddingProviderResult); - createEmbeddingProviderMock.mockResolvedValueOnce({ - requestedProvider: "ollama", - provider: createProvider("ollama"), - runtime: ollamaRuntime, - } as EmbeddingProviderResult); + it("uses default ollama model when activating ollama fallback", () => { + const request = resolveMemoryFallbackProviderRequest({ + cfg: {} as OpenClawConfig, + settings: createSettings({ provider: "openai", fallback: "ollama" }), + currentProviderId: "openai", + }); - const cfg = buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "ollama" }); - const result = await getMemorySearchManager({ cfg, agentId: "main" }); - if (!result.manager) { - throw new Error(`manager missing: ${result.error ?? "no error provided"}`); - } - manager = result.manager as unknown as MemoryIndexManager; - await ensureProviderInitialized(manager); - const internal = manager as unknown as { - ensureProviderInitialized: () => Promise; - activateFallbackProvider: (reason: string) => Promise; - providerRuntime?: EmbeddingProviderRuntime; - }; - - await internal.ensureProviderInitialized(); - expect(internal.providerRuntime?.id).toBe("openai"); - const activated = await internal.activateFallbackProvider("forced ollama fallback"); - expect(activated).toBe(true); - expect(internal.providerRuntime).toBe(ollamaRuntime); - - const fallbackCall = createEmbeddingProviderMock.mock.calls[1]?.[0] as - | { provider?: string; model?: string } - | undefined; - expect(fallbackCall?.provider).toBe("ollama"); - expect(fallbackCall?.model).toBe(DEFAULT_OLLAMA_EMBEDDING_MODEL); + expect(request?.provider).toBe("ollama"); + expect(request?.model).toBe(DEFAULT_OLLAMA_EMBEDDING_MODEL); + expect(request?.fallback).toBe("none"); }); }); diff --git a/extensions/memory-core/src/memory/manager.ts b/extensions/memory-core/src/memory/manager.ts index 50df6f28e71..dc60c122805 100644 --- a/extensions/memory-core/src/memory/manager.ts +++ b/extensions/memory-core/src/memory/manager.ts @@ -32,6 +32,7 @@ import { resolveSingletonManagedCache, } from "./manager-cache.js"; import { MemoryManagerEmbeddingOps } from "./manager-embedding-ops.js"; +import { resolveMemoryProviderState } from "./manager-provider-state.js"; import { searchKeyword, searchVector } from "./manager-search.js"; import { collectMemoryStatusAggregate, @@ -231,11 +232,12 @@ export class MemoryIndexManager extends MemoryManagerEmbeddingOps implements Mem } private applyProviderResult(providerResult: EmbeddingProviderResult): void { - this.provider = providerResult.provider; - this.fallbackFrom = providerResult.fallbackFrom; - this.fallbackReason = providerResult.fallbackReason; - this.providerUnavailableReason = providerResult.providerUnavailableReason; - this.providerRuntime = providerResult.runtime; + const providerState = resolveMemoryProviderState(providerResult); + this.provider = providerState.provider; + this.fallbackFrom = providerState.fallbackFrom; + this.fallbackReason = providerState.fallbackReason; + this.providerUnavailableReason = providerState.providerUnavailableReason; + this.providerRuntime = providerState.providerRuntime; this.providerInitialized = true; }