perf: extract memory provider state helpers

This commit is contained in:
Peter Steinberger
2026-04-06 20:51:03 +01:00
parent 37d7c716f4
commit 6e9382b5c8
4 changed files with 177 additions and 236 deletions

View File

@@ -0,0 +1,79 @@
import type {
OpenClawConfig,
ResolvedMemorySearchConfig,
} from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import {
resolveEmbeddingProviderFallbackModel,
type EmbeddingProvider,
type EmbeddingProviderResult,
type EmbeddingProviderRuntime,
} from "./embeddings.js";
export type MemoryResolvedProviderState = {
provider: EmbeddingProvider | null;
fallbackFrom?: string;
fallbackReason?: string;
providerUnavailableReason?: string;
providerRuntime?: EmbeddingProviderRuntime;
};
export function resolveMemoryProviderState(
result: Pick<
EmbeddingProviderResult,
"provider" | "fallbackFrom" | "fallbackReason" | "providerUnavailableReason" | "runtime"
>,
): MemoryResolvedProviderState {
return {
provider: result.provider,
fallbackFrom: result.fallbackFrom,
fallbackReason: result.fallbackReason,
providerUnavailableReason: result.providerUnavailableReason,
providerRuntime: result.runtime,
};
}
export function applyMemoryFallbackProviderState(params: {
current: MemoryResolvedProviderState;
fallbackFrom: string;
reason: string;
result: Pick<EmbeddingProviderResult, "provider" | "runtime">;
}): MemoryResolvedProviderState {
return {
...params.current,
fallbackFrom: params.fallbackFrom,
fallbackReason: params.reason,
provider: params.result.provider,
providerRuntime: params.result.runtime,
};
}
export function resolveMemoryFallbackProviderRequest(params: {
cfg: OpenClawConfig;
settings: ResolvedMemorySearchConfig;
currentProviderId: string | null;
}): {
provider: string;
model: string;
remote: ResolvedMemorySearchConfig["remote"];
outputDimensionality: ResolvedMemorySearchConfig["outputDimensionality"];
fallback: "none";
local: ResolvedMemorySearchConfig["local"];
} | null {
const fallback = params.settings.fallback;
if (
!fallback ||
fallback === "none" ||
!params.currentProviderId ||
fallback === params.currentProviderId
) {
return null;
}
return {
provider: fallback,
model: resolveEmbeddingProviderFallbackModel(fallback, params.settings.model, params.cfg),
remote: params.settings.remote,
outputDimensionality: params.settings.outputDimensionality,
fallback: "none",
local: params.settings.local,
};
}

View File

@@ -41,9 +41,12 @@ import {
type EmbeddingProvider,
type EmbeddingProviderId,
type EmbeddingProviderRuntime,
resolveEmbeddingProviderFallbackModel,
} from "./embeddings.js";
import { openMemoryDatabaseAtPath } from "./manager-db.js";
import {
applyMemoryFallbackProviderState,
resolveMemoryFallbackProviderRequest,
} from "./manager-provider-state.js";
import {
resolveConfiguredScopeHash,
resolveConfiguredSourcesForMeta,
@@ -1067,8 +1070,12 @@ export abstract class MemoryManagerSyncOps {
}
private async activateFallbackProvider(reason: string): Promise<boolean> {
const fallback = this.settings.fallback;
if (!fallback || fallback === "none" || !this.provider || fallback === this.provider.id) {
const fallbackRequest = resolveMemoryFallbackProviderRequest({
cfg: this.cfg,
settings: this.settings,
currentProviderId: this.provider?.id ?? null,
});
if (!fallbackRequest || !this.provider) {
return false;
}
if (this.fallbackFrom) {
@@ -1076,30 +1083,33 @@ export abstract class MemoryManagerSyncOps {
}
const fallbackFrom = this.provider.id;
const fallbackModel = resolveEmbeddingProviderFallbackModel(
fallback,
this.settings.model,
this.cfg,
);
const fallbackResult = await createEmbeddingProvider({
config: this.cfg,
agentDir: resolveAgentDir(this.cfg, this.agentId),
provider: fallback,
remote: this.settings.remote,
model: fallbackModel,
outputDimensionality: this.settings.outputDimensionality,
fallback: "none",
local: this.settings.local,
...fallbackRequest,
});
this.fallbackFrom = fallbackFrom;
this.fallbackReason = reason;
this.provider = fallbackResult.provider;
this.providerRuntime = fallbackResult.runtime;
const fallbackState = applyMemoryFallbackProviderState({
current: {
provider: this.provider,
fallbackFrom: this.fallbackFrom,
fallbackReason: this.fallbackReason,
providerUnavailableReason: undefined,
providerRuntime: this.providerRuntime,
},
fallbackFrom,
reason,
result: fallbackResult,
});
this.fallbackFrom = fallbackState.fallbackFrom;
this.fallbackReason = fallbackState.fallbackReason;
this.provider = fallbackState.provider;
this.providerRuntime = fallbackState.providerRuntime;
this.providerKey = this.computeProviderKey();
this.batch = this.resolveBatchConfig();
log.warn(`memory embeddings: switched to fallback provider (${fallback})`, { reason });
log.warn(`memory embeddings: switched to fallback provider (${fallbackRequest.provider})`, {
reason,
});
return true;
}

View File

@@ -1,15 +1,21 @@
import fs from "node:fs/promises";
import os from "node:os";
import path from "node:path";
import type { OpenClawConfig } from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
import type { MemoryIndexManager } from "./index.js";
type MemoryIndexModule = typeof import("./index.js");
type MemoryEmbeddingProvidersModule =
typeof import("../../../../src/plugins/memory-embedding-providers.js");
import type {
OpenClawConfig,
ResolvedMemorySearchConfig,
} from "openclaw/plugin-sdk/memory-core-host-engine-foundation";
import { describe, expect, it, vi } from "vitest";
import {
applyMemoryFallbackProviderState,
resolveMemoryFallbackProviderRequest,
resolveMemoryProviderState,
} from "./manager-provider-state.js";
const DEFAULT_OLLAMA_EMBEDDING_MODEL = "nomic-embed-text";
vi.mock("./embeddings.js", () => ({
resolveEmbeddingProviderFallbackModel: (providerId: string, fallbackSourceModel: string) =>
providerId === "ollama" ? DEFAULT_OLLAMA_EMBEDDING_MODEL : fallbackSourceModel,
}));
type EmbeddingProvider = {
id: string;
model: string;
@@ -22,40 +28,6 @@ type EmbeddingProviderRuntime = {
cacheKeyData: { provider: string; model: string };
};
type EmbeddingProviderResult = {
requestedProvider: string;
provider: EmbeddingProvider | null;
fallbackFrom?: string;
fallbackReason?: string;
providerUnavailableReason?: string;
runtime?: EmbeddingProviderRuntime;
};
const { createEmbeddingProviderMock } = vi.hoisted(() => ({
createEmbeddingProviderMock: vi.fn(),
}));
vi.mock("./embeddings.js", () => ({
createEmbeddingProvider: createEmbeddingProviderMock,
resolveEmbeddingProviderFallbackModel: (providerId: string, fallbackSourceModel: string) =>
providerId === "ollama" ? DEFAULT_OLLAMA_EMBEDDING_MODEL : fallbackSourceModel,
}));
vi.mock("./sqlite-vec.js", () => ({
loadSqliteVecExtension: async () => ({ ok: false, error: "sqlite-vec disabled in tests" }),
}));
let getMemorySearchManager: MemoryIndexModule["getMemorySearchManager"];
let closeAllMemorySearchManagers: MemoryIndexModule["closeAllMemorySearchManagers"];
async function ensureProviderInitialized(manager: MemoryIndexManager): Promise<void> {
await (
manager as unknown as {
ensureProviderInitialized: () => Promise<void>;
}
).ensureProviderInitialized();
}
function createProvider(id: string): EmbeddingProvider {
return {
id,
@@ -65,121 +37,40 @@ function createProvider(id: string): EmbeddingProvider {
};
}
function buildConfig(params: {
workspaceDir: string;
indexPath: string;
function createSettings(params: {
provider: "openai" | "mistral";
fallback?: "none" | "mistral" | "ollama";
}): OpenClawConfig {
}): ResolvedMemorySearchConfig {
return {
agents: {
defaults: {
workspace: params.workspaceDir,
memorySearch: {
provider: params.provider,
model: params.provider === "mistral" ? "mistral/mistral-embed" : "text-embedding-3-small",
fallback: params.fallback ?? "none",
store: { path: params.indexPath, vector: { enabled: false } },
sync: { watch: false, onSessionStart: false, onSearch: false },
query: { minScore: 0, hybrid: { enabled: false } },
},
},
list: [{ id: "main", default: true }],
},
} as OpenClawConfig;
provider: params.provider,
model: params.provider === "mistral" ? "mistral/mistral-embed" : "text-embedding-3-small",
fallback: params.fallback ?? "none",
remote: undefined,
outputDimensionality: undefined,
local: undefined,
} as unknown as ResolvedMemorySearchConfig;
}
describe("memory manager mistral provider wiring", () => {
let workspaceDir = "";
let indexPath = "";
let manager: MemoryIndexManager | null = null;
let clearRegistry: MemoryEmbeddingProvidersModule["clearMemoryEmbeddingProviders"];
let registerAdapter: MemoryEmbeddingProvidersModule["registerMemoryEmbeddingProvider"];
beforeAll(async () => {
vi.resetModules();
({ getMemorySearchManager, closeAllMemorySearchManagers } = await import("./index.js"));
({
clearMemoryEmbeddingProviders: clearRegistry,
registerMemoryEmbeddingProvider: registerAdapter,
} = await import("../../../../src/plugins/memory-embedding-providers.js"));
});
beforeEach(async () => {
vi.clearAllMocks();
createEmbeddingProviderMock.mockReset();
clearRegistry();
registerAdapter({
id: "openai",
defaultModel: "text-embedding-3-small",
transport: "remote",
create: async () => ({ provider: null }),
});
registerAdapter({
id: "mistral",
defaultModel: "mistral-embed",
transport: "remote",
create: async () => ({ provider: null }),
});
registerAdapter({
id: "ollama",
defaultModel: DEFAULT_OLLAMA_EMBEDDING_MODEL,
transport: "remote",
create: async () => ({ provider: null }),
});
workspaceDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-memory-mistral-"));
indexPath = path.join(workspaceDir, "index.sqlite");
await fs.mkdir(path.join(workspaceDir, "memory"), { recursive: true });
await fs.writeFile(path.join(workspaceDir, "MEMORY.md"), "test");
});
afterEach(async () => {
if (manager) {
await manager.close();
manager = null;
}
await closeAllMemorySearchManagers();
clearRegistry();
if (workspaceDir) {
await fs.rm(workspaceDir, { recursive: true, force: true });
workspaceDir = "";
indexPath = "";
}
});
afterAll(() => {
vi.resetModules();
});
it("stores mistral client when mistral provider is selected", async () => {
it("stores mistral client when mistral provider is selected", () => {
const mistralRuntime: EmbeddingProviderRuntime = {
id: "mistral",
cacheKeyData: { provider: "mistral", model: "mistral-embed" },
};
const providerResult: EmbeddingProviderResult = {
requestedProvider: "mistral",
const state = resolveMemoryProviderState({
provider: createProvider("mistral"),
runtime: mistralRuntime,
};
createEmbeddingProviderMock.mockResolvedValueOnce(providerResult);
fallbackFrom: undefined,
fallbackReason: undefined,
providerUnavailableReason: undefined,
});
const cfg = buildConfig({ workspaceDir, indexPath, provider: "mistral" });
const result = await getMemorySearchManager({ cfg, agentId: "main" });
if (!result.manager) {
throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
}
manager = result.manager as unknown as MemoryIndexManager;
await ensureProviderInitialized(manager);
const internal = manager as unknown as {
ensureProviderInitialized: () => Promise<void>;
providerRuntime?: EmbeddingProviderRuntime;
};
await internal.ensureProviderInitialized();
expect(internal.providerRuntime).toBe(mistralRuntime);
expect(state.provider?.id).toBe("mistral");
expect(state.providerRuntime).toBe(mistralRuntime);
});
it("stores mistral client after fallback activation", async () => {
it("stores mistral client after fallback activation", () => {
const openAiRuntime: EmbeddingProviderRuntime = {
id: "openai",
cacheKeyData: { provider: "openai", model: "text-embedding-3-small" },
@@ -188,80 +79,39 @@ describe("memory manager mistral provider wiring", () => {
id: "mistral",
cacheKeyData: { provider: "mistral", model: "mistral-embed" },
};
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "openai",
const current = resolveMemoryProviderState({
provider: createProvider("openai"),
runtime: openAiRuntime,
} as EmbeddingProviderResult);
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "mistral",
provider: createProvider("mistral"),
runtime: mistralRuntime,
} as EmbeddingProviderResult);
fallbackFrom: undefined,
fallbackReason: undefined,
providerUnavailableReason: undefined,
});
const cfg = buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "mistral" });
const result = await getMemorySearchManager({ cfg, agentId: "main" });
if (!result.manager) {
throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
}
manager = result.manager as unknown as MemoryIndexManager;
await ensureProviderInitialized(manager);
const internal = manager as unknown as {
ensureProviderInitialized: () => Promise<void>;
activateFallbackProvider: (reason: string) => Promise<boolean>;
providerRuntime?: EmbeddingProviderRuntime;
};
const fallbackState = applyMemoryFallbackProviderState({
current,
fallbackFrom: "openai",
reason: "forced test",
result: {
provider: createProvider("mistral"),
runtime: mistralRuntime,
},
});
await internal.ensureProviderInitialized();
expect(internal.providerRuntime?.id).toBe("openai");
const activated = await internal.activateFallbackProvider("forced test");
expect(activated).toBe(true);
expect(internal.providerRuntime).toBe(mistralRuntime);
expect(fallbackState.fallbackFrom).toBe("openai");
expect(fallbackState.fallbackReason).toBe("forced test");
expect(fallbackState.provider?.id).toBe("mistral");
expect(fallbackState.providerRuntime).toBe(mistralRuntime);
});
it("uses default ollama model when activating ollama fallback", async () => {
const openAiRuntime: EmbeddingProviderRuntime = {
id: "openai",
cacheKeyData: { provider: "openai", model: "text-embedding-3-small" },
};
const ollamaRuntime: EmbeddingProviderRuntime = {
id: "ollama",
cacheKeyData: { provider: "ollama", model: DEFAULT_OLLAMA_EMBEDDING_MODEL },
};
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "openai",
provider: createProvider("openai"),
runtime: openAiRuntime,
} as EmbeddingProviderResult);
createEmbeddingProviderMock.mockResolvedValueOnce({
requestedProvider: "ollama",
provider: createProvider("ollama"),
runtime: ollamaRuntime,
} as EmbeddingProviderResult);
it("uses default ollama model when activating ollama fallback", () => {
const request = resolveMemoryFallbackProviderRequest({
cfg: {} as OpenClawConfig,
settings: createSettings({ provider: "openai", fallback: "ollama" }),
currentProviderId: "openai",
});
const cfg = buildConfig({ workspaceDir, indexPath, provider: "openai", fallback: "ollama" });
const result = await getMemorySearchManager({ cfg, agentId: "main" });
if (!result.manager) {
throw new Error(`manager missing: ${result.error ?? "no error provided"}`);
}
manager = result.manager as unknown as MemoryIndexManager;
await ensureProviderInitialized(manager);
const internal = manager as unknown as {
ensureProviderInitialized: () => Promise<void>;
activateFallbackProvider: (reason: string) => Promise<boolean>;
providerRuntime?: EmbeddingProviderRuntime;
};
await internal.ensureProviderInitialized();
expect(internal.providerRuntime?.id).toBe("openai");
const activated = await internal.activateFallbackProvider("forced ollama fallback");
expect(activated).toBe(true);
expect(internal.providerRuntime).toBe(ollamaRuntime);
const fallbackCall = createEmbeddingProviderMock.mock.calls[1]?.[0] as
| { provider?: string; model?: string }
| undefined;
expect(fallbackCall?.provider).toBe("ollama");
expect(fallbackCall?.model).toBe(DEFAULT_OLLAMA_EMBEDDING_MODEL);
expect(request?.provider).toBe("ollama");
expect(request?.model).toBe(DEFAULT_OLLAMA_EMBEDDING_MODEL);
expect(request?.fallback).toBe("none");
});
});

View File

@@ -32,6 +32,7 @@ import {
resolveSingletonManagedCache,
} from "./manager-cache.js";
import { MemoryManagerEmbeddingOps } from "./manager-embedding-ops.js";
import { resolveMemoryProviderState } from "./manager-provider-state.js";
import { searchKeyword, searchVector } from "./manager-search.js";
import {
collectMemoryStatusAggregate,
@@ -231,11 +232,12 @@ export class MemoryIndexManager extends MemoryManagerEmbeddingOps implements Mem
}
private applyProviderResult(providerResult: EmbeddingProviderResult): void {
this.provider = providerResult.provider;
this.fallbackFrom = providerResult.fallbackFrom;
this.fallbackReason = providerResult.fallbackReason;
this.providerUnavailableReason = providerResult.providerUnavailableReason;
this.providerRuntime = providerResult.runtime;
const providerState = resolveMemoryProviderState(providerResult);
this.provider = providerState.provider;
this.fallbackFrom = providerState.fallbackFrom;
this.fallbackReason = providerState.fallbackReason;
this.providerUnavailableReason = providerState.providerUnavailableReason;
this.providerRuntime = providerState.providerRuntime;
this.providerInitialized = true;
}