From 2a999bf9c96714fdbec3cdf09c1bf8e6f81ddd56 Mon Sep 17 00:00:00 2001 From: Vincent Koc Date: Sun, 5 Apr 2026 07:33:14 +0100 Subject: [PATCH] refactor(memory): invert memory host sdk dependency --- .../memory-host-sdk/src/engine-embeddings.ts | 57 +- .../memory-host-sdk/src/engine-foundation.ts | 46 +- packages/memory-host-sdk/src/engine-qmd.ts | 21 +- .../memory-host-sdk/src/engine-storage.ts | 37 +- packages/memory-host-sdk/src/engine.ts | 8 +- packages/memory-host-sdk/src/multimodal.ts | 7 +- packages/memory-host-sdk/src/query.ts | 2 +- packages/memory-host-sdk/src/runtime-cli.ts | 12 +- packages/memory-host-sdk/src/runtime-core.ts | 25 +- packages/memory-host-sdk/src/runtime-files.ts | 7 +- packages/memory-host-sdk/src/runtime.ts | 7 +- packages/memory-host-sdk/src/secret.ts | 5 +- packages/memory-host-sdk/src/status.ts | 7 +- src/memory-host-sdk/engine-embeddings.ts | 57 +- src/memory-host-sdk/engine-foundation.ts | 45 + src/memory-host-sdk/engine-qmd.ts | 21 +- src/memory-host-sdk/engine-storage.ts | 37 +- src/memory-host-sdk/engine.ts | 7 + .../host/backend-config.test.ts | 506 +++++++++++ src/memory-host-sdk/host/backend-config.ts | 429 +++++++++ .../host/batch-embedding-common.ts | 22 + .../host/batch-error-utils.test.ts | 32 + src/memory-host-sdk/host/batch-error-utils.ts | 31 + src/memory-host-sdk/host/batch-gemini.test.ts | 116 +++ src/memory-host-sdk/host/batch-gemini.ts | 368 ++++++++ src/memory-host-sdk/host/batch-http.test.ts | 86 ++ src/memory-host-sdk/host/batch-http.ts | 35 + src/memory-host-sdk/host/batch-openai.ts | 259 ++++++ src/memory-host-sdk/host/batch-output.test.ts | 82 ++ src/memory-host-sdk/host/batch-output.ts | 55 ++ .../host/batch-provider-common.ts | 12 + src/memory-host-sdk/host/batch-runner.ts | 64 ++ src/memory-host-sdk/host/batch-status.test.ts | 60 ++ src/memory-host-sdk/host/batch-status.ts | 69 ++ src/memory-host-sdk/host/batch-upload.ts | 44 + src/memory-host-sdk/host/batch-utils.ts | 38 + 
src/memory-host-sdk/host/batch-voyage.test.ts | 176 ++++ src/memory-host-sdk/host/batch-voyage.ts | 315 +++++++ .../host/embedding-chunk-limits.test.ts | 102 +++ .../host/embedding-chunk-limits.ts | 41 + .../host/embedding-input-limits.ts | 85 ++ src/memory-host-sdk/host/embedding-inputs.ts | 34 + .../host/embedding-model-limits.ts | 41 + src/memory-host-sdk/host/embedding-vectors.ts | 8 + src/memory-host-sdk/host/embeddings-debug.ts | 13 + .../host/embeddings-gemini.test.ts | 592 +++++++++++++ src/memory-host-sdk/host/embeddings-gemini.ts | 336 +++++++ .../host/embeddings-mistral.test.ts | 19 + .../host/embeddings-mistral.ts | 51 ++ .../host/embeddings-model-normalize.test.ts | 34 + .../host/embeddings-model-normalize.ts | 16 + .../host/embeddings-ollama.test.ts | 146 +++ src/memory-host-sdk/host/embeddings-ollama.ts | 5 + src/memory-host-sdk/host/embeddings-openai.ts | 58 ++ .../host/embeddings-remote-client.ts | 39 + .../host/embeddings-remote-fetch.test.ts | 59 ++ .../host/embeddings-remote-fetch.ts | 25 + .../host/embeddings-remote-provider.ts | 63 ++ .../host/embeddings-voyage.test.ts | 177 ++++ src/memory-host-sdk/host/embeddings-voyage.ts | 82 ++ src/memory-host-sdk/host/embeddings.test.ts | 752 ++++++++++++++++ src/memory-host-sdk/host/embeddings.ts | 324 +++++++ src/memory-host-sdk/host/fs-utils.ts | 31 + src/memory-host-sdk/host/internal.test.ts | 423 +++++++++ src/memory-host-sdk/host/internal.ts | 504 +++++++++++ src/memory-host-sdk/host/memory-schema.ts | 102 +++ src/memory-host-sdk/host/multimodal.ts | 118 +++ src/memory-host-sdk/host/node-llama.ts | 3 + src/memory-host-sdk/host/post-json.test.ts | 60 ++ src/memory-host-sdk/host/post-json.ts | 35 + src/memory-host-sdk/host/qmd-process.test.ts | 154 ++++ src/memory-host-sdk/host/qmd-process.ts | 184 ++++ .../host/qmd-query-parser.test.ts | 64 ++ src/memory-host-sdk/host/qmd-query-parser.ts | 151 ++++ src/memory-host-sdk/host/qmd-scope.test.ts | 54 ++ src/memory-host-sdk/host/qmd-scope.ts | 106 +++ 
.../host/query-expansion.test.ts | 244 ++++++ src/memory-host-sdk/host/query-expansion.ts | 828 ++++++++++++++++++ src/memory-host-sdk/host/read-file.ts | 96 ++ src/memory-host-sdk/host/remote-http.ts | 40 + src/memory-host-sdk/host/secret-input.ts | 18 + .../host/session-files.test.ts | 123 +++ src/memory-host-sdk/host/session-files.ts | 132 +++ src/memory-host-sdk/host/sqlite-vec.ts | 24 + src/memory-host-sdk/host/sqlite.ts | 19 + src/memory-host-sdk/host/status-format.ts | 45 + src/memory-host-sdk/host/test-helpers/ssrf.ts | 14 + src/memory-host-sdk/host/types.ts | 81 ++ src/memory-host-sdk/multimodal.ts | 7 +- src/memory-host-sdk/query.ts | 2 +- src/memory-host-sdk/runtime-cli.ts | 11 + src/memory-host-sdk/runtime-core.ts | 24 + src/memory-host-sdk/runtime-files.ts | 6 + src/memory-host-sdk/runtime.ts | 6 + src/memory-host-sdk/secret.ts | 5 +- src/memory-host-sdk/status.ts | 7 +- .../memory-core-host-engine-embeddings.ts | 2 +- .../memory-core-host-engine-foundation.ts | 2 +- src/plugin-sdk/memory-core-host-engine-qmd.ts | 2 +- .../memory-core-host-engine-storage.ts | 2 +- src/plugin-sdk/memory-core-host-multimodal.ts | 2 +- src/plugin-sdk/memory-core-host-query.ts | 2 +- .../memory-core-host-runtime-cli.ts | 2 +- .../memory-core-host-runtime-core.ts | 2 +- .../memory-core-host-runtime-files.ts | 2 +- src/plugin-sdk/memory-core-host-secret.ts | 2 +- src/plugin-sdk/memory-core-host-status.ts | 2 +- .../contracts/plugin-sdk-subpaths.test.ts | 6 +- src/plugins/memory-embedding-providers.ts | 2 +- test/helpers/memory-tool-manager-mock.ts | 2 +- 110 files changed, 9811 insertions(+), 251 deletions(-) create mode 100644 src/memory-host-sdk/engine-foundation.ts create mode 100644 src/memory-host-sdk/engine.ts create mode 100644 src/memory-host-sdk/host/backend-config.test.ts create mode 100644 src/memory-host-sdk/host/backend-config.ts create mode 100644 src/memory-host-sdk/host/batch-embedding-common.ts create mode 100644 
src/memory-host-sdk/host/batch-error-utils.test.ts create mode 100644 src/memory-host-sdk/host/batch-error-utils.ts create mode 100644 src/memory-host-sdk/host/batch-gemini.test.ts create mode 100644 src/memory-host-sdk/host/batch-gemini.ts create mode 100644 src/memory-host-sdk/host/batch-http.test.ts create mode 100644 src/memory-host-sdk/host/batch-http.ts create mode 100644 src/memory-host-sdk/host/batch-openai.ts create mode 100644 src/memory-host-sdk/host/batch-output.test.ts create mode 100644 src/memory-host-sdk/host/batch-output.ts create mode 100644 src/memory-host-sdk/host/batch-provider-common.ts create mode 100644 src/memory-host-sdk/host/batch-runner.ts create mode 100644 src/memory-host-sdk/host/batch-status.test.ts create mode 100644 src/memory-host-sdk/host/batch-status.ts create mode 100644 src/memory-host-sdk/host/batch-upload.ts create mode 100644 src/memory-host-sdk/host/batch-utils.ts create mode 100644 src/memory-host-sdk/host/batch-voyage.test.ts create mode 100644 src/memory-host-sdk/host/batch-voyage.ts create mode 100644 src/memory-host-sdk/host/embedding-chunk-limits.test.ts create mode 100644 src/memory-host-sdk/host/embedding-chunk-limits.ts create mode 100644 src/memory-host-sdk/host/embedding-input-limits.ts create mode 100644 src/memory-host-sdk/host/embedding-inputs.ts create mode 100644 src/memory-host-sdk/host/embedding-model-limits.ts create mode 100644 src/memory-host-sdk/host/embedding-vectors.ts create mode 100644 src/memory-host-sdk/host/embeddings-debug.ts create mode 100644 src/memory-host-sdk/host/embeddings-gemini.test.ts create mode 100644 src/memory-host-sdk/host/embeddings-gemini.ts create mode 100644 src/memory-host-sdk/host/embeddings-mistral.test.ts create mode 100644 src/memory-host-sdk/host/embeddings-mistral.ts create mode 100644 src/memory-host-sdk/host/embeddings-model-normalize.test.ts create mode 100644 src/memory-host-sdk/host/embeddings-model-normalize.ts create mode 100644 
src/memory-host-sdk/host/embeddings-ollama.test.ts create mode 100644 src/memory-host-sdk/host/embeddings-ollama.ts create mode 100644 src/memory-host-sdk/host/embeddings-openai.ts create mode 100644 src/memory-host-sdk/host/embeddings-remote-client.ts create mode 100644 src/memory-host-sdk/host/embeddings-remote-fetch.test.ts create mode 100644 src/memory-host-sdk/host/embeddings-remote-fetch.ts create mode 100644 src/memory-host-sdk/host/embeddings-remote-provider.ts create mode 100644 src/memory-host-sdk/host/embeddings-voyage.test.ts create mode 100644 src/memory-host-sdk/host/embeddings-voyage.ts create mode 100644 src/memory-host-sdk/host/embeddings.test.ts create mode 100644 src/memory-host-sdk/host/embeddings.ts create mode 100644 src/memory-host-sdk/host/fs-utils.ts create mode 100644 src/memory-host-sdk/host/internal.test.ts create mode 100644 src/memory-host-sdk/host/internal.ts create mode 100644 src/memory-host-sdk/host/memory-schema.ts create mode 100644 src/memory-host-sdk/host/multimodal.ts create mode 100644 src/memory-host-sdk/host/node-llama.ts create mode 100644 src/memory-host-sdk/host/post-json.test.ts create mode 100644 src/memory-host-sdk/host/post-json.ts create mode 100644 src/memory-host-sdk/host/qmd-process.test.ts create mode 100644 src/memory-host-sdk/host/qmd-process.ts create mode 100644 src/memory-host-sdk/host/qmd-query-parser.test.ts create mode 100644 src/memory-host-sdk/host/qmd-query-parser.ts create mode 100644 src/memory-host-sdk/host/qmd-scope.test.ts create mode 100644 src/memory-host-sdk/host/qmd-scope.ts create mode 100644 src/memory-host-sdk/host/query-expansion.test.ts create mode 100644 src/memory-host-sdk/host/query-expansion.ts create mode 100644 src/memory-host-sdk/host/read-file.ts create mode 100644 src/memory-host-sdk/host/remote-http.ts create mode 100644 src/memory-host-sdk/host/secret-input.ts create mode 100644 src/memory-host-sdk/host/session-files.test.ts create mode 100644 
src/memory-host-sdk/host/session-files.ts create mode 100644 src/memory-host-sdk/host/sqlite-vec.ts create mode 100644 src/memory-host-sdk/host/sqlite.ts create mode 100644 src/memory-host-sdk/host/status-format.ts create mode 100644 src/memory-host-sdk/host/test-helpers/ssrf.ts create mode 100644 src/memory-host-sdk/host/types.ts create mode 100644 src/memory-host-sdk/runtime-cli.ts create mode 100644 src/memory-host-sdk/runtime-core.ts create mode 100644 src/memory-host-sdk/runtime-files.ts create mode 100644 src/memory-host-sdk/runtime.ts diff --git a/packages/memory-host-sdk/src/engine-embeddings.ts b/packages/memory-host-sdk/src/engine-embeddings.ts index c76cd508e2a..9a2916d1926 100644 --- a/packages/memory-host-sdk/src/engine-embeddings.ts +++ b/packages/memory-host-sdk/src/engine-embeddings.ts @@ -1,56 +1 @@ -// Real workspace contract for memory embedding providers and batch helpers. - -export { - getMemoryEmbeddingProvider, - listMemoryEmbeddingProviders, -} from "../../../src/plugins/memory-embedding-providers.js"; -export type { - MemoryEmbeddingBatchChunk, - MemoryEmbeddingBatchOptions, - MemoryEmbeddingProvider, - MemoryEmbeddingProviderAdapter, - MemoryEmbeddingProviderCreateOptions, - MemoryEmbeddingProviderCreateResult, - MemoryEmbeddingProviderRuntime, -} from "../../../src/plugins/memory-embedding-providers.js"; -export { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./host/embeddings.js"; -export { - createGeminiEmbeddingProvider, - DEFAULT_GEMINI_EMBEDDING_MODEL, - buildGeminiEmbeddingRequest, -} from "./host/embeddings-gemini.js"; -export { - createMistralEmbeddingProvider, - DEFAULT_MISTRAL_EMBEDDING_MODEL, -} from "./host/embeddings-mistral.js"; -export { - createOllamaEmbeddingProvider, - DEFAULT_OLLAMA_EMBEDDING_MODEL, -} from "./host/embeddings-ollama.js"; -export type { OllamaEmbeddingClient } from "./host/embeddings-ollama.js"; -export { - createOpenAiEmbeddingProvider, - DEFAULT_OPENAI_EMBEDDING_MODEL, -} from 
"./host/embeddings-openai.js"; -export { - createVoyageEmbeddingProvider, - DEFAULT_VOYAGE_EMBEDDING_MODEL, -} from "./host/embeddings-voyage.js"; -export { runGeminiEmbeddingBatches, type GeminiBatchRequest } from "./host/batch-gemini.js"; -export { - OPENAI_BATCH_ENDPOINT, - runOpenAiEmbeddingBatches, - type OpenAiBatchRequest, -} from "./host/batch-openai.js"; -export { runVoyageEmbeddingBatches, type VoyageBatchRequest } from "./host/batch-voyage.js"; -export { enforceEmbeddingMaxInputTokens } from "./host/embedding-chunk-limits.js"; -export { - estimateStructuredEmbeddingInputBytes, - estimateUtf8Bytes, -} from "./host/embedding-input-limits.js"; -export { hasNonTextEmbeddingParts, type EmbeddingInput } from "./host/embedding-inputs.js"; -export { - buildCaseInsensitiveExtensionGlob, - classifyMemoryMultimodalPath, - getMemoryMultimodalExtensions, -} from "./host/multimodal.js"; +export * from "../../../src/memory-host-sdk/engine-embeddings.js"; diff --git a/packages/memory-host-sdk/src/engine-foundation.ts b/packages/memory-host-sdk/src/engine-foundation.ts index eac16c61ce3..c048f59890d 100644 --- a/packages/memory-host-sdk/src/engine-foundation.ts +++ b/packages/memory-host-sdk/src/engine-foundation.ts @@ -1,45 +1 @@ -// Real workspace contract for memory engine foundation concerns. 
- -export { - resolveAgentDir, - resolveAgentWorkspaceDir, - resolveDefaultAgentId, - resolveSessionAgentId, -} from "../../../src/agents/agent-scope.js"; -export { - resolveMemorySearchConfig, - type ResolvedMemorySearchConfig, -} from "../../../src/agents/memory-search.js"; -export { parseDurationMs } from "../../../src/cli/parse-duration.js"; -export { loadConfig } from "../../../src/config/config.js"; -export { resolveStateDir } from "../../../src/config/paths.js"; -export { resolveSessionTranscriptsDirForAgent } from "../../../src/config/sessions/paths.js"; -export { - hasConfiguredSecretInput, - normalizeResolvedSecretInputString, -} from "../../../src/config/types.secrets.js"; -export { writeFileWithinRoot } from "../../../src/infra/fs-safe.js"; -export { createSubsystemLogger } from "../../../src/logging/subsystem.js"; -export { detectMime } from "../../../src/media/mime.js"; -export { resolveGlobalSingleton } from "../../../src/shared/global-singleton.js"; -export { onSessionTranscriptUpdate } from "../../../src/sessions/transcript-events.js"; -export { splitShellArgs } from "../../../src/utils/shell-argv.js"; -export { runTasksWithConcurrency } from "../../../src/utils/run-with-concurrency.js"; -export { - shortenHomeInString, - shortenHomePath, - resolveUserPath, - truncateUtf16Safe, -} from "../../../src/utils.js"; -export type { OpenClawConfig } from "../../../src/config/config.js"; -export type { SessionSendPolicyConfig } from "../../../src/config/types.base.js"; -export type { SecretInput } from "../../../src/config/types.secrets.js"; -export type { - MemoryBackend, - MemoryCitationsMode, - MemoryQmdConfig, - MemoryQmdIndexPath, - MemoryQmdMcporterConfig, - MemoryQmdSearchMode, -} from "../../../src/config/types.memory.js"; -export type { MemorySearchConfig } from "../../../src/config/types.tools.js"; +export * from "../../../src/memory-host-sdk/engine-foundation.js"; diff --git a/packages/memory-host-sdk/src/engine-qmd.ts 
b/packages/memory-host-sdk/src/engine-qmd.ts index 2b5cc128141..41f161481b2 100644 --- a/packages/memory-host-sdk/src/engine-qmd.ts +++ b/packages/memory-host-sdk/src/engine-qmd.ts @@ -1,20 +1 @@ -// Real workspace contract for QMD/session/query helpers used by the memory engine. - -export { extractKeywords, isQueryStopWordToken } from "./host/query-expansion.js"; -export { - buildSessionEntry, - listSessionFilesForAgent, - sessionPathForFile, - type SessionFileEntry, -} from "./host/session-files.js"; -export { parseQmdQueryJson, type QmdQueryResult } from "./host/qmd-query-parser.js"; -export { - deriveQmdScopeChannel, - deriveQmdScopeChatType, - isQmdScopeAllowed, -} from "./host/qmd-scope.js"; -export { - checkQmdBinaryAvailability, - resolveCliSpawnInvocation, - runCliCommand, -} from "./host/qmd-process.js"; +export * from "../../../src/memory-host-sdk/engine-qmd.js"; diff --git a/packages/memory-host-sdk/src/engine-storage.ts b/packages/memory-host-sdk/src/engine-storage.ts index a1dc489d6fc..7b6a5606c86 100644 --- a/packages/memory-host-sdk/src/engine-storage.ts +++ b/packages/memory-host-sdk/src/engine-storage.ts @@ -1,36 +1 @@ -// Real workspace contract for memory engine storage/index helpers. 
- -export { - buildFileEntry, - buildMultimodalChunkForIndexing, - chunkMarkdown, - cosineSimilarity, - ensureDir, - hashText, - listMemoryFiles, - normalizeExtraMemoryPaths, - parseEmbedding, - remapChunkLines, - runWithConcurrency, - type MemoryChunk, - type MemoryFileEntry, -} from "./host/internal.js"; -export { readMemoryFile } from "./host/read-file.js"; -export { resolveMemoryBackendConfig } from "./host/backend-config.js"; -export type { - ResolvedMemoryBackendConfig, - ResolvedQmdConfig, - ResolvedQmdMcporterConfig, -} from "./host/backend-config.js"; -export type { - MemoryEmbeddingProbeResult, - MemoryProviderStatus, - MemorySearchManager, - MemorySearchResult, - MemorySource, - MemorySyncProgressUpdate, -} from "./host/types.js"; -export { ensureMemoryIndexSchema } from "./host/memory-schema.js"; -export { loadSqliteVecExtension } from "./host/sqlite-vec.js"; -export { requireNodeSqlite } from "./host/sqlite.js"; -export { isFileMissingError, statRegularFile } from "./host/fs-utils.js"; +export * from "../../../src/memory-host-sdk/engine-storage.js"; diff --git a/packages/memory-host-sdk/src/engine.ts b/packages/memory-host-sdk/src/engine.ts index a18fef9e8ba..25269114848 100644 --- a/packages/memory-host-sdk/src/engine.ts +++ b/packages/memory-host-sdk/src/engine.ts @@ -1,7 +1 @@ -// Aggregate workspace contract for the memory engine surface. -// Keep focused subpaths preferred for new code. 
- -export * from "./engine-foundation.js"; -export * from "./engine-storage.js"; -export * from "./engine-embeddings.js"; -export * from "./engine-qmd.js"; +export * from "../../../src/memory-host-sdk/engine.js"; diff --git a/packages/memory-host-sdk/src/multimodal.ts b/packages/memory-host-sdk/src/multimodal.ts index 5c62de35490..af483ef2422 100644 --- a/packages/memory-host-sdk/src/multimodal.ts +++ b/packages/memory-host-sdk/src/multimodal.ts @@ -1,6 +1 @@ -export { - isMemoryMultimodalEnabled, - normalizeMemoryMultimodalSettings, - supportsMemoryMultimodalEmbeddings, - type MemoryMultimodalSettings, -} from "./host/multimodal.js"; +export * from "../../../src/memory-host-sdk/multimodal.js"; diff --git a/packages/memory-host-sdk/src/query.ts b/packages/memory-host-sdk/src/query.ts index bb945afaa65..dd2605d656b 100644 --- a/packages/memory-host-sdk/src/query.ts +++ b/packages/memory-host-sdk/src/query.ts @@ -1 +1 @@ -export { extractKeywords, isQueryStopWordToken } from "./host/query-expansion.js"; +export * from "../../../src/memory-host-sdk/query.js"; diff --git a/packages/memory-host-sdk/src/runtime-cli.ts b/packages/memory-host-sdk/src/runtime-cli.ts index 9a1b858cd0d..3f2651422ee 100644 --- a/packages/memory-host-sdk/src/runtime-cli.ts +++ b/packages/memory-host-sdk/src/runtime-cli.ts @@ -1,11 +1 @@ -// Focused runtime contract for memory CLI/UI helpers. 
- -export { formatErrorMessage, withManager } from "../../../src/cli/cli-utils.js"; -export { formatHelpExamples } from "../../../src/cli/help-format.js"; -export { resolveCommandSecretRefsViaGateway } from "../../../src/cli/command-secret-gateway.js"; -export { withProgress, withProgressTotals } from "../../../src/cli/progress.js"; -export { defaultRuntime } from "../../../src/runtime.js"; -export { formatDocsLink } from "../../../src/terminal/links.js"; -export { colorize, isRich, theme } from "../../../src/terminal/theme.js"; -export { isVerbose, setVerbose } from "../../../src/globals.js"; -export { shortenHomeInString, shortenHomePath } from "../../../src/utils.js"; +export * from "../../../src/memory-host-sdk/runtime-cli.js"; diff --git a/packages/memory-host-sdk/src/runtime-core.ts b/packages/memory-host-sdk/src/runtime-core.ts index 00720f42896..fa9bf04d6d8 100644 --- a/packages/memory-host-sdk/src/runtime-core.ts +++ b/packages/memory-host-sdk/src/runtime-core.ts @@ -1,24 +1 @@ -// Focused runtime contract for memory plugin config/state/helpers. 
- -export type { AnyAgentTool } from "../../../src/agents/tools/common.js"; -export { resolveCronStyleNow } from "../../../src/agents/current-time.js"; -export { DEFAULT_PI_COMPACTION_RESERVE_TOKENS_FLOOR } from "../../../src/agents/pi-settings.js"; -export { resolveDefaultAgentId, resolveSessionAgentId } from "../../../src/agents/agent-scope.js"; -export { resolveMemorySearchConfig } from "../../../src/agents/memory-search.js"; -export { jsonResult, readNumberParam, readStringParam } from "../../../src/agents/tools/common.js"; -export { SILENT_REPLY_TOKEN } from "../../../src/auto-reply/tokens.js"; -export { parseNonNegativeByteSize } from "../../../src/config/byte-size.js"; -export { loadConfig } from "../../../src/config/config.js"; -export { resolveStateDir } from "../../../src/config/paths.js"; -export { resolveSessionTranscriptsDirForAgent } from "../../../src/config/sessions/paths.js"; -export { emptyPluginConfigSchema } from "../../../src/plugins/config-schema.js"; -export { parseAgentSessionKey } from "../../../src/routing/session-key.js"; -export type { OpenClawConfig } from "../../../src/config/config.js"; -export type { MemoryCitationsMode } from "../../../src/config/types.memory.js"; -export type { - MemoryFlushPlan, - MemoryFlushPlanResolver, - MemoryPluginRuntime, - MemoryPromptSectionBuilder, -} from "../../../src/plugins/memory-state.js"; -export type { OpenClawPluginApi } from "../../../src/plugins/types.js"; +export * from "../../../src/memory-host-sdk/runtime-core.js"; diff --git a/packages/memory-host-sdk/src/runtime-files.ts b/packages/memory-host-sdk/src/runtime-files.ts index dd50c31eb46..5c8aa4b2ae0 100644 --- a/packages/memory-host-sdk/src/runtime-files.ts +++ b/packages/memory-host-sdk/src/runtime-files.ts @@ -1,6 +1 @@ -// Focused runtime contract for memory file/backend access. 
- -export { listMemoryFiles, normalizeExtraMemoryPaths } from "./host/internal.js"; -export { readAgentMemoryFile } from "./host/read-file.js"; -export { resolveMemoryBackendConfig } from "./host/backend-config.js"; -export type { MemorySearchResult } from "./host/types.js"; +export * from "../../../src/memory-host-sdk/runtime-files.js"; diff --git a/packages/memory-host-sdk/src/runtime.ts b/packages/memory-host-sdk/src/runtime.ts index 6e152ea0dcb..0b3bef94c6c 100644 --- a/packages/memory-host-sdk/src/runtime.ts +++ b/packages/memory-host-sdk/src/runtime.ts @@ -1,6 +1 @@ -// Aggregate workspace contract for memory runtime/helper seams. -// Keep focused subpaths preferred for new code. - -export * from "./runtime-core.js"; -export * from "./runtime-cli.js"; -export * from "./runtime-files.js"; +export * from "../../../src/memory-host-sdk/runtime.js"; diff --git a/packages/memory-host-sdk/src/secret.ts b/packages/memory-host-sdk/src/secret.ts index b2b6b94ab47..c8ac074db0c 100644 --- a/packages/memory-host-sdk/src/secret.ts +++ b/packages/memory-host-sdk/src/secret.ts @@ -1,4 +1 @@ -export { - hasConfiguredMemorySecretInput, - resolveMemorySecretInputString, -} from "./host/secret-input.js"; +export * from "../../../src/memory-host-sdk/secret.js"; diff --git a/packages/memory-host-sdk/src/status.ts b/packages/memory-host-sdk/src/status.ts index dc718abd96b..3e46beba8d6 100644 --- a/packages/memory-host-sdk/src/status.ts +++ b/packages/memory-host-sdk/src/status.ts @@ -1,6 +1 @@ -export { - resolveMemoryCacheSummary, - resolveMemoryFtsState, - resolveMemoryVectorState, - type Tone, -} from "./host/status-format.js"; +export * from "../../../src/memory-host-sdk/status.js"; diff --git a/src/memory-host-sdk/engine-embeddings.ts b/src/memory-host-sdk/engine-embeddings.ts index a5b7e2b91a9..bdd388529b6 100644 --- a/src/memory-host-sdk/engine-embeddings.ts +++ b/src/memory-host-sdk/engine-embeddings.ts @@ -1 +1,56 @@ -export * from 
"../../packages/memory-host-sdk/src/engine-embeddings.js"; +// Real workspace contract for memory embedding providers and batch helpers. + +export { + getMemoryEmbeddingProvider, + listMemoryEmbeddingProviders, +} from "../plugins/memory-embedding-providers.js"; +export type { + MemoryEmbeddingBatchChunk, + MemoryEmbeddingBatchOptions, + MemoryEmbeddingProvider, + MemoryEmbeddingProviderAdapter, + MemoryEmbeddingProviderCreateOptions, + MemoryEmbeddingProviderCreateResult, + MemoryEmbeddingProviderRuntime, +} from "../plugins/memory-embedding-providers.js"; +export { createLocalEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./host/embeddings.js"; +export { + createGeminiEmbeddingProvider, + DEFAULT_GEMINI_EMBEDDING_MODEL, + buildGeminiEmbeddingRequest, +} from "./host/embeddings-gemini.js"; +export { + createMistralEmbeddingProvider, + DEFAULT_MISTRAL_EMBEDDING_MODEL, +} from "./host/embeddings-mistral.js"; +export { + createOllamaEmbeddingProvider, + DEFAULT_OLLAMA_EMBEDDING_MODEL, +} from "./host/embeddings-ollama.js"; +export type { OllamaEmbeddingClient } from "./host/embeddings-ollama.js"; +export { + createOpenAiEmbeddingProvider, + DEFAULT_OPENAI_EMBEDDING_MODEL, +} from "./host/embeddings-openai.js"; +export { + createVoyageEmbeddingProvider, + DEFAULT_VOYAGE_EMBEDDING_MODEL, +} from "./host/embeddings-voyage.js"; +export { runGeminiEmbeddingBatches, type GeminiBatchRequest } from "./host/batch-gemini.js"; +export { + OPENAI_BATCH_ENDPOINT, + runOpenAiEmbeddingBatches, + type OpenAiBatchRequest, +} from "./host/batch-openai.js"; +export { runVoyageEmbeddingBatches, type VoyageBatchRequest } from "./host/batch-voyage.js"; +export { enforceEmbeddingMaxInputTokens } from "./host/embedding-chunk-limits.js"; +export { + estimateStructuredEmbeddingInputBytes, + estimateUtf8Bytes, +} from "./host/embedding-input-limits.js"; +export { hasNonTextEmbeddingParts, type EmbeddingInput } from "./host/embedding-inputs.js"; +export { + buildCaseInsensitiveExtensionGlob, + 
classifyMemoryMultimodalPath, + getMemoryMultimodalExtensions, +} from "./host/multimodal.js"; diff --git a/src/memory-host-sdk/engine-foundation.ts b/src/memory-host-sdk/engine-foundation.ts new file mode 100644 index 00000000000..42dad0f1b4a --- /dev/null +++ b/src/memory-host-sdk/engine-foundation.ts @@ -0,0 +1,45 @@ +// Real workspace contract for memory engine foundation concerns. + +export { + resolveAgentDir, + resolveAgentWorkspaceDir, + resolveDefaultAgentId, + resolveSessionAgentId, +} from "../agents/agent-scope.js"; +export { + resolveMemorySearchConfig, + type ResolvedMemorySearchConfig, +} from "../agents/memory-search.js"; +export { parseDurationMs } from "../cli/parse-duration.js"; +export { loadConfig } from "../config/config.js"; +export { resolveStateDir } from "../config/paths.js"; +export { resolveSessionTranscriptsDirForAgent } from "../config/sessions/paths.js"; +export { + hasConfiguredSecretInput, + normalizeResolvedSecretInputString, +} from "../config/types.secrets.js"; +export { writeFileWithinRoot } from "../infra/fs-safe.js"; +export { createSubsystemLogger } from "../logging/subsystem.js"; +export { detectMime } from "../media/mime.js"; +export { resolveGlobalSingleton } from "../shared/global-singleton.js"; +export { onSessionTranscriptUpdate } from "../sessions/transcript-events.js"; +export { splitShellArgs } from "../utils/shell-argv.js"; +export { runTasksWithConcurrency } from "../utils/run-with-concurrency.js"; +export { + shortenHomeInString, + shortenHomePath, + resolveUserPath, + truncateUtf16Safe, +} from "../utils.js"; +export type { OpenClawConfig } from "../config/config.js"; +export type { SessionSendPolicyConfig } from "../config/types.base.js"; +export type { SecretInput } from "../config/types.secrets.js"; +export type { + MemoryBackend, + MemoryCitationsMode, + MemoryQmdConfig, + MemoryQmdIndexPath, + MemoryQmdMcporterConfig, + MemoryQmdSearchMode, +} from "../config/types.memory.js"; +export type { 
MemorySearchConfig } from "../config/types.tools.js"; diff --git a/src/memory-host-sdk/engine-qmd.ts b/src/memory-host-sdk/engine-qmd.ts index 21a0be44873..2b5cc128141 100644 --- a/src/memory-host-sdk/engine-qmd.ts +++ b/src/memory-host-sdk/engine-qmd.ts @@ -1 +1,20 @@ -export * from "../../packages/memory-host-sdk/src/engine-qmd.js"; +// Real workspace contract for QMD/session/query helpers used by the memory engine. + +export { extractKeywords, isQueryStopWordToken } from "./host/query-expansion.js"; +export { + buildSessionEntry, + listSessionFilesForAgent, + sessionPathForFile, + type SessionFileEntry, +} from "./host/session-files.js"; +export { parseQmdQueryJson, type QmdQueryResult } from "./host/qmd-query-parser.js"; +export { + deriveQmdScopeChannel, + deriveQmdScopeChatType, + isQmdScopeAllowed, +} from "./host/qmd-scope.js"; +export { + checkQmdBinaryAvailability, + resolveCliSpawnInvocation, + runCliCommand, +} from "./host/qmd-process.js"; diff --git a/src/memory-host-sdk/engine-storage.ts b/src/memory-host-sdk/engine-storage.ts index ee3f3a4e410..a1dc489d6fc 100644 --- a/src/memory-host-sdk/engine-storage.ts +++ b/src/memory-host-sdk/engine-storage.ts @@ -1 +1,36 @@ -export * from "../../packages/memory-host-sdk/src/engine-storage.js"; +// Real workspace contract for memory engine storage/index helpers. 
+ +export { + buildFileEntry, + buildMultimodalChunkForIndexing, + chunkMarkdown, + cosineSimilarity, + ensureDir, + hashText, + listMemoryFiles, + normalizeExtraMemoryPaths, + parseEmbedding, + remapChunkLines, + runWithConcurrency, + type MemoryChunk, + type MemoryFileEntry, +} from "./host/internal.js"; +export { readMemoryFile } from "./host/read-file.js"; +export { resolveMemoryBackendConfig } from "./host/backend-config.js"; +export type { + ResolvedMemoryBackendConfig, + ResolvedQmdConfig, + ResolvedQmdMcporterConfig, +} from "./host/backend-config.js"; +export type { + MemoryEmbeddingProbeResult, + MemoryProviderStatus, + MemorySearchManager, + MemorySearchResult, + MemorySource, + MemorySyncProgressUpdate, +} from "./host/types.js"; +export { ensureMemoryIndexSchema } from "./host/memory-schema.js"; +export { loadSqliteVecExtension } from "./host/sqlite-vec.js"; +export { requireNodeSqlite } from "./host/sqlite.js"; +export { isFileMissingError, statRegularFile } from "./host/fs-utils.js"; diff --git a/src/memory-host-sdk/engine.ts b/src/memory-host-sdk/engine.ts new file mode 100644 index 00000000000..a18fef9e8ba --- /dev/null +++ b/src/memory-host-sdk/engine.ts @@ -0,0 +1,7 @@ +// Aggregate workspace contract for the memory engine surface. +// Keep focused subpaths preferred for new code. 
+ +export * from "./engine-foundation.js"; +export * from "./engine-storage.js"; +export * from "./engine-embeddings.js"; +export * from "./engine-qmd.js"; diff --git a/src/memory-host-sdk/host/backend-config.test.ts b/src/memory-host-sdk/host/backend-config.test.ts new file mode 100644 index 00000000000..eaae5b36885 --- /dev/null +++ b/src/memory-host-sdk/host/backend-config.test.ts @@ -0,0 +1,506 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { describe, expect, it } from "vitest"; +import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { resolveMemoryBackendConfig } from "./backend-config.js"; + +const resolveComparablePath = (value: string, workspaceDir = "/workspace/root"): string => + path.isAbsolute(value) ? path.resolve(value) : path.resolve(workspaceDir, value); + +describe("resolveMemoryBackendConfig", () => { + it("defaults to builtin backend when config missing", () => { + const cfg = { agents: { defaults: { workspace: "/tmp/memory-test" } } } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + expect(resolved.backend).toBe("builtin"); + expect(resolved.citations).toBe("auto"); + expect(resolved.qmd).toBeUndefined(); + }); + + it("resolves qmd backend with default collections", () => { + const cfg = { + agents: { defaults: { workspace: "/tmp/memory-test" } }, + memory: { + backend: "qmd", + qmd: {}, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + expect(resolved.backend).toBe("qmd"); + expect(resolved.qmd?.collections.length).toBeGreaterThanOrEqual(3); + expect(resolved.qmd?.command).toBe("qmd"); + expect(resolved.qmd?.searchMode).toBe("search"); + expect(resolved.qmd?.update.intervalMs).toBeGreaterThan(0); + expect(resolved.qmd?.update.waitForBootSync).toBe(false); + 
expect(resolved.qmd?.update.commandTimeoutMs).toBe(30_000); + expect(resolved.qmd?.update.updateTimeoutMs).toBe(120_000); + expect(resolved.qmd?.update.embedTimeoutMs).toBe(120_000); + const names = new Set((resolved.qmd?.collections ?? []).map((collection) => collection.name)); + expect(names.has("memory-root-main")).toBe(true); + expect(names.has("memory-alt-main")).toBe(true); + expect(names.has("memory-dir-main")).toBe(true); + }); + + it("parses quoted qmd command paths", () => { + const cfg = { + agents: { defaults: { workspace: "/tmp/memory-test" } }, + memory: { + backend: "qmd", + qmd: { + command: '"/Applications/QMD Tools/qmd" --flag', + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + expect(resolved.qmd?.command).toBe("/Applications/QMD Tools/qmd"); + }); + + it("resolves custom paths relative to workspace", () => { + const cfg = { + agents: { + defaults: { workspace: "/workspace/root" }, + list: [{ id: "main", workspace: "/workspace/root" }], + }, + memory: { + backend: "qmd", + qmd: { + paths: [ + { + path: "notes", + name: "custom-notes", + pattern: "**/*.md", + }, + ], + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const custom = resolved.qmd?.collections.find((c) => c.name.startsWith("custom-notes")); + expect(custom).toBeDefined(); + const workspaceRoot = resolveAgentWorkspaceDir(cfg, "main"); + expect(custom?.path).toBe(path.resolve(workspaceRoot, "notes")); + }); + + it("scopes qmd collection names per agent", () => { + const cfg = { + agents: { + defaults: { workspace: "/workspace/root" }, + list: [ + { id: "main", default: true, workspace: "/workspace/root" }, + { id: "dev", workspace: "/workspace/dev" }, + ], + }, + memory: { + backend: "qmd", + qmd: { + includeDefaultMemory: true, + paths: [{ path: "notes", name: "workspace", pattern: "**/*.md" }], + }, + }, + } as OpenClawConfig; + const mainResolved = 
resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const devResolved = resolveMemoryBackendConfig({ cfg, agentId: "dev" }); + const mainNames = new Set( + (mainResolved.qmd?.collections ?? []).map((collection) => collection.name), + ); + const devNames = new Set( + (devResolved.qmd?.collections ?? []).map((collection) => collection.name), + ); + expect(mainNames.has("memory-dir-main")).toBe(true); + expect(devNames.has("memory-dir-dev")).toBe(true); + expect(mainNames.has("workspace-main")).toBe(true); + expect(devNames.has("workspace-dev")).toBe(true); + }); + + it("merges default and per-agent qmd extra collections", () => { + const cfg = { + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + qmd: { + extraCollections: [ + { + path: "/shared/team-notes", + name: "team-notes", + pattern: "**/*.md", + }, + ], + }, + }, + }, + list: [ + { + id: "main", + default: true, + workspace: "/workspace/root", + memorySearch: { + qmd: { + extraCollections: [ + { + path: "notes", + name: "notes", + pattern: "**/*.md", + }, + ], + }, + }, + }, + ], + }, + memory: { + backend: "qmd", + qmd: { + includeDefaultMemory: false, + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const names = new Set((resolved.qmd?.collections ?? 
[]).map((collection) => collection.name)); + expect(names.has("team-notes")).toBe(true); + expect(names.has("notes-main")).toBe(true); + }); + + it("preserves explicit custom collection names for paths outside the workspace", () => { + const cfg = { + agents: { + defaults: { workspace: "/workspace/root" }, + list: [ + { id: "main", default: true, workspace: "/workspace/root" }, + { id: "dev", workspace: "/workspace/dev" }, + ], + }, + memory: { + backend: "qmd", + qmd: { + includeDefaultMemory: true, + paths: [{ path: "/shared/notion-mirror", name: "notion-mirror", pattern: "**/*.md" }], + }, + }, + } as OpenClawConfig; + const mainResolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const devResolved = resolveMemoryBackendConfig({ cfg, agentId: "dev" }); + const mainNames = new Set( + (mainResolved.qmd?.collections ?? []).map((collection) => collection.name), + ); + const devNames = new Set( + (devResolved.qmd?.collections ?? []).map((collection) => collection.name), + ); + expect(mainNames.has("memory-dir-main")).toBe(true); + expect(devNames.has("memory-dir-dev")).toBe(true); + expect(mainNames.has("notion-mirror")).toBe(true); + expect(devNames.has("notion-mirror")).toBe(true); + }); + + it("keeps symlinked workspace paths agent-scoped when deciding custom collection names", async () => { + const tmpRoot = await fs.mkdtemp(path.join(os.tmpdir(), "qmd-backend-config-")); + const workspaceDir = path.join(tmpRoot, "workspace"); + const workspaceAliasDir = path.join(tmpRoot, "workspace-alias"); + try { + await fs.mkdir(workspaceDir, { recursive: true }); + await fs.symlink(workspaceDir, workspaceAliasDir); + const cfg = { + agents: { + defaults: { workspace: workspaceDir }, + list: [{ id: "main", default: true, workspace: workspaceDir }], + }, + memory: { + backend: "qmd", + qmd: { + includeDefaultMemory: false, + paths: [{ path: workspaceAliasDir, name: "workspace", pattern: "**/*.md" }], + }, + }, + } as OpenClawConfig; + const resolved = 
resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const names = new Set((resolved.qmd?.collections ?? []).map((collection) => collection.name)); + expect(names.has("workspace-main")).toBe(true); + expect(names.has("workspace")).toBe(false); + } finally { + await fs.rm(tmpRoot, { recursive: true, force: true }); + } + }); + + it("keeps unresolved child paths under a symlinked workspace agent-scoped", async () => { + const tmpRoot = await fs.mkdtemp(path.join(os.tmpdir(), "qmd-backend-config-")); + const realRootDir = path.join(tmpRoot, "real-root"); + const aliasRootDir = path.join(tmpRoot, "alias-root"); + const workspaceDir = path.join(realRootDir, "workspace"); + const workspaceAliasDir = path.join(aliasRootDir, "workspace"); + try { + await fs.mkdir(workspaceDir, { recursive: true }); + await fs.symlink(realRootDir, aliasRootDir); + const cfg = { + agents: { + defaults: { workspace: workspaceDir }, + list: [{ id: "main", default: true, workspace: workspaceDir }], + }, + memory: { + backend: "qmd", + qmd: { + includeDefaultMemory: false, + paths: [ + { path: path.join(workspaceAliasDir, "notes"), name: "notes", pattern: "**/*.md" }, + ], + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const names = new Set((resolved.qmd?.collections ?? 
[]).map((collection) => collection.name)); + expect(names.has("notes-main")).toBe(true); + expect(names.has("notes")).toBe(false); + } finally { + await fs.rm(tmpRoot, { recursive: true, force: true }); + } + }); + + it("resolves qmd update timeout overrides", () => { + const cfg = { + agents: { defaults: { workspace: "/tmp/memory-test" } }, + memory: { + backend: "qmd", + qmd: { + update: { + waitForBootSync: true, + commandTimeoutMs: 12_000, + updateTimeoutMs: 480_000, + embedTimeoutMs: 360_000, + }, + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + expect(resolved.qmd?.update.waitForBootSync).toBe(true); + expect(resolved.qmd?.update.commandTimeoutMs).toBe(12_000); + expect(resolved.qmd?.update.updateTimeoutMs).toBe(480_000); + expect(resolved.qmd?.update.embedTimeoutMs).toBe(360_000); + }); + + it("resolves qmd search mode override", () => { + const cfg = { + agents: { defaults: { workspace: "/tmp/memory-test" } }, + memory: { + backend: "qmd", + qmd: { + searchMode: "vsearch", + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + expect(resolved.qmd?.searchMode).toBe("vsearch"); + }); + + it("resolves qmd mcporter search tool override", () => { + const cfg = { + agents: { defaults: { workspace: "/tmp/memory-test" } }, + memory: { + backend: "qmd", + qmd: { + searchMode: "query", + searchTool: " hybrid_search ", + }, + }, + } as OpenClawConfig; + const resolved = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + expect(resolved.qmd?.searchMode).toBe("query"); + expect(resolved.qmd?.searchTool).toBe("hybrid_search"); + }); +}); + +describe("memorySearch.extraPaths integration", () => { + it("maps agents.defaults.memorySearch.extraPaths to QMD collections", () => { + const cfg = { + memory: { backend: "qmd" }, + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + extraPaths: ["/home/user/docs", "/home/user/vault"], + }, + 
}, + }, + } as OpenClawConfig; + const result = resolveMemoryBackendConfig({ cfg, agentId: "test-agent" }); + expect(result.backend).toBe("qmd"); + const customCollections = (result.qmd?.collections ?? []).filter( + (collection) => collection.kind === "custom", + ); + expect(customCollections.length).toBeGreaterThanOrEqual(2); + expect(customCollections.map((collection) => collection.path)).toEqual( + expect.arrayContaining([ + resolveComparablePath("/home/user/docs"), + resolveComparablePath("/home/user/vault"), + ]), + ); + }); + + it("merges default and per-agent memorySearch.extraPaths for QMD collections", () => { + const cfg = { + memory: { backend: "qmd" }, + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + extraPaths: ["/default/path"], + }, + }, + list: [ + { + id: "my-agent", + memorySearch: { + extraPaths: ["/agent/specific/path"], + }, + }, + ], + }, + } as OpenClawConfig; + const result = resolveMemoryBackendConfig({ cfg, agentId: "my-agent" }); + expect(result.backend).toBe("qmd"); + const customCollections = (result.qmd?.collections ?? []).filter( + (collection) => collection.kind === "custom", + ); + const paths = customCollections.map((collection) => collection.path); + expect(paths).toContain(resolveComparablePath("/agent/specific/path")); + expect(paths).toContain(resolveComparablePath("/default/path")); + }); + + it("falls back to defaults when agent has no overrides", () => { + const cfg = { + memory: { backend: "qmd" }, + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + extraPaths: ["/default/path"], + }, + }, + list: [ + { + id: "other-agent", + memorySearch: { + extraPaths: ["/other/path"], + }, + }, + ], + }, + } as OpenClawConfig; + const result = resolveMemoryBackendConfig({ cfg, agentId: "my-agent" }); + expect(result.backend).toBe("qmd"); + const customCollections = (result.qmd?.collections ?? 
[]).filter( + (collection) => collection.kind === "custom", + ); + const paths = customCollections.map((collection) => collection.path); + expect(paths).toContain(resolveComparablePath("/default/path")); + }); + + it("deduplicates merged memorySearch.extraPaths for QMD collections", () => { + const cfg = { + memory: { backend: "qmd" }, + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + extraPaths: ["/shared/path", " /shared/path "], + }, + }, + list: [ + { + id: "my-agent", + memorySearch: { + extraPaths: ["/shared/path", "/agent-only"], + }, + }, + ], + }, + } as OpenClawConfig; + + const result = resolveMemoryBackendConfig({ cfg, agentId: "my-agent" }); + const customCollections = (result.qmd?.collections ?? []).filter( + (collection) => collection.kind === "custom", + ); + const paths = customCollections.map((collection) => collection.path); + + expect( + paths.filter((collectionPath) => collectionPath === resolveComparablePath("/shared/path")), + ).toHaveLength(1); + expect(paths).toContain(resolveComparablePath("/agent-only")); + }); + + it("keeps unnamed extra paths agent-scoped even when they resolve outside the workspace", () => { + const cfg = { + memory: { backend: "qmd" }, + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + extraPaths: ["/shared/path"], + }, + }, + }, + } as OpenClawConfig; + const result = resolveMemoryBackendConfig({ cfg, agentId: "my-agent" }); + const customCollections = (result.qmd?.collections ?? 
[]).filter( + (collection) => collection.kind === "custom", + ); + expect(customCollections.map((collection) => collection.name)).toContain("custom-1-my-agent"); + }); + + it("matches per-agent memorySearch.extraPaths using normalized agent ids", () => { + const cfg = { + memory: { backend: "qmd" }, + agents: { + defaults: { + workspace: "/workspace/root", + }, + list: [ + { + id: "My-Agent", + memorySearch: { + extraPaths: ["/agent/mixed-case"], + }, + }, + ], + }, + } as OpenClawConfig; + + const result = resolveMemoryBackendConfig({ cfg, agentId: "my-agent" }); + const customCollections = (result.qmd?.collections ?? []).filter( + (collection) => collection.kind === "custom", + ); + + expect(customCollections.map((collection) => collection.path)).toContain( + resolveComparablePath("/agent/mixed-case"), + ); + }); + + it("deduplicates identical roots shared by memory.qmd.paths and memorySearch.extraPaths", () => { + const cfg = { + memory: { + backend: "qmd", + qmd: { + paths: [{ path: "docs", pattern: "**/*.md", name: "workspace-docs" }], + }, + }, + agents: { + defaults: { + workspace: "/workspace/root", + memorySearch: { + extraPaths: ["./docs"], + }, + }, + }, + } as OpenClawConfig; + + const result = resolveMemoryBackendConfig({ cfg, agentId: "main" }); + const customCollections = (result.qmd?.collections ?? 
[]).filter( + (collection) => collection.kind === "custom", + ); + const docsCollections = customCollections.filter( + (collection) => + collection.path === resolveComparablePath("./docs") && collection.pattern === "**/*.md", + ); + + expect(docsCollections).toHaveLength(1); + }); +}); diff --git a/src/memory-host-sdk/host/backend-config.ts b/src/memory-host-sdk/host/backend-config.ts new file mode 100644 index 00000000000..b5a9db26741 --- /dev/null +++ b/src/memory-host-sdk/host/backend-config.ts @@ -0,0 +1,429 @@ +import fs from "node:fs"; +import path from "node:path"; +import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js"; +import { parseDurationMs } from "../../cli/parse-duration.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { SessionSendPolicyConfig } from "../../config/types.base.js"; +import type { + MemoryBackend, + MemoryCitationsMode, + MemoryQmdConfig, + MemoryQmdIndexPath, + MemoryQmdMcporterConfig, + MemoryQmdSearchMode, +} from "../../config/types.memory.js"; +import { normalizeAgentId } from "../../routing/session-key.js"; +import { resolveUserPath } from "../../utils.js"; +import { splitShellArgs } from "../../utils/shell-argv.js"; + +export type ResolvedMemoryBackendConfig = { + backend: MemoryBackend; + citations: MemoryCitationsMode; + qmd?: ResolvedQmdConfig; +}; + +export type ResolvedQmdCollection = { + name: string; + path: string; + pattern: string; + kind: "memory" | "custom" | "sessions"; +}; + +export type ResolvedQmdUpdateConfig = { + intervalMs: number; + debounceMs: number; + onBoot: boolean; + waitForBootSync: boolean; + embedIntervalMs: number; + commandTimeoutMs: number; + updateTimeoutMs: number; + embedTimeoutMs: number; +}; + +export type ResolvedQmdLimitsConfig = { + maxResults: number; + maxSnippetChars: number; + maxInjectedChars: number; + timeoutMs: number; +}; + +export type ResolvedQmdSessionConfig = { + enabled: boolean; + exportDir?: string; + retentionDays?: number; 
+}; + +export type ResolvedQmdMcporterConfig = { + enabled: boolean; + serverName: string; + startDaemon: boolean; +}; + +export type ResolvedQmdConfig = { + command: string; + mcporter: ResolvedQmdMcporterConfig; + searchMode: MemoryQmdSearchMode; + searchTool?: string; + collections: ResolvedQmdCollection[]; + sessions: ResolvedQmdSessionConfig; + update: ResolvedQmdUpdateConfig; + limits: ResolvedQmdLimitsConfig; + includeDefaultMemory: boolean; + scope?: SessionSendPolicyConfig; +}; + +const DEFAULT_BACKEND: MemoryBackend = "builtin"; +const DEFAULT_CITATIONS: MemoryCitationsMode = "auto"; +const DEFAULT_QMD_INTERVAL = "5m"; +const DEFAULT_QMD_DEBOUNCE_MS = 15_000; +const DEFAULT_QMD_TIMEOUT_MS = 4_000; +// Defaulting to `query` can be extremely slow on CPU-only systems (query expansion + rerank). +// Prefer a faster mode for interactive use; users can opt into `query` for best recall. +const DEFAULT_QMD_SEARCH_MODE: MemoryQmdSearchMode = "search"; +const DEFAULT_QMD_EMBED_INTERVAL = "60m"; +const DEFAULT_QMD_COMMAND_TIMEOUT_MS = 30_000; +const DEFAULT_QMD_UPDATE_TIMEOUT_MS = 120_000; +const DEFAULT_QMD_EMBED_TIMEOUT_MS = 120_000; +const DEFAULT_QMD_LIMITS: ResolvedQmdLimitsConfig = { + maxResults: 6, + maxSnippetChars: 700, + maxInjectedChars: 4_000, + timeoutMs: DEFAULT_QMD_TIMEOUT_MS, +}; +const DEFAULT_QMD_MCPORTER: ResolvedQmdMcporterConfig = { + enabled: false, + serverName: "qmd", + startDaemon: true, +}; + +const DEFAULT_QMD_SCOPE: SessionSendPolicyConfig = { + default: "deny", + rules: [ + { + action: "allow", + match: { chatType: "direct" }, + }, + ], +}; + +function sanitizeName(input: string): string { + const lower = input.toLowerCase().replace(/[^a-z0-9-]+/g, "-"); + const trimmed = lower.replace(/^-+|-+$/g, ""); + return trimmed || "collection"; +} + +function scopeCollectionBase(base: string, agentId: string): string { + return `${base}-${sanitizeName(agentId)}`; +} + +function canonicalizePathForContainment(rawPath: string): string { + const 
resolved = path.resolve(rawPath); + let current = resolved; + const suffix: string[] = []; + while (true) { + try { + const canonical = path.normalize(fs.realpathSync.native(current)); + return path.normalize(path.join(canonical, ...suffix)); + } catch { + const parent = path.dirname(current); + if (parent === current) { + return path.normalize(resolved); + } + suffix.unshift(path.basename(current)); + current = parent; + } + } +} + +function isPathInsideRoot(candidatePath: string, rootPath: string): boolean { + const relative = path.relative( + canonicalizePathForContainment(rootPath), + canonicalizePathForContainment(candidatePath), + ); + return relative === "" || (!relative.startsWith("..") && !path.isAbsolute(relative)); +} + +function ensureUniqueName(base: string, existing: Set): string { + let name = sanitizeName(base); + if (!existing.has(name)) { + existing.add(name); + return name; + } + let suffix = 2; + while (existing.has(`${name}-${suffix}`)) { + suffix += 1; + } + const unique = `${name}-${suffix}`; + existing.add(unique); + return unique; +} + +function resolvePath(raw: string, workspaceDir: string): string { + const trimmed = raw.trim(); + if (!trimmed) { + throw new Error("path required"); + } + if (trimmed.startsWith("~") || path.isAbsolute(trimmed)) { + return path.normalize(resolveUserPath(trimmed)); + } + return path.normalize(path.resolve(workspaceDir, trimmed)); +} + +function resolveIntervalMs(raw: string | undefined): number { + const value = raw?.trim(); + if (!value) { + return parseDurationMs(DEFAULT_QMD_INTERVAL, { defaultUnit: "m" }); + } + try { + return parseDurationMs(value, { defaultUnit: "m" }); + } catch { + return parseDurationMs(DEFAULT_QMD_INTERVAL, { defaultUnit: "m" }); + } +} + +function resolveEmbedIntervalMs(raw: string | undefined): number { + const value = raw?.trim(); + if (!value) { + return parseDurationMs(DEFAULT_QMD_EMBED_INTERVAL, { defaultUnit: "m" }); + } + try { + return parseDurationMs(value, { defaultUnit: 
"m" }); + } catch { + return parseDurationMs(DEFAULT_QMD_EMBED_INTERVAL, { defaultUnit: "m" }); + } +} + +function resolveDebounceMs(raw: number | undefined): number { + if (typeof raw === "number" && Number.isFinite(raw) && raw >= 0) { + return Math.floor(raw); + } + return DEFAULT_QMD_DEBOUNCE_MS; +} + +function resolveTimeoutMs(raw: number | undefined, fallback: number): number { + if (typeof raw === "number" && Number.isFinite(raw) && raw > 0) { + return Math.floor(raw); + } + return fallback; +} + +function resolveLimits(raw?: MemoryQmdConfig["limits"]): ResolvedQmdLimitsConfig { + const parsed: ResolvedQmdLimitsConfig = { ...DEFAULT_QMD_LIMITS }; + if (raw?.maxResults && raw.maxResults > 0) { + parsed.maxResults = Math.floor(raw.maxResults); + } + if (raw?.maxSnippetChars && raw.maxSnippetChars > 0) { + parsed.maxSnippetChars = Math.floor(raw.maxSnippetChars); + } + if (raw?.maxInjectedChars && raw.maxInjectedChars > 0) { + parsed.maxInjectedChars = Math.floor(raw.maxInjectedChars); + } + if (raw?.timeoutMs && raw.timeoutMs > 0) { + parsed.timeoutMs = Math.floor(raw.timeoutMs); + } + return parsed; +} + +function resolveSearchMode(raw?: MemoryQmdConfig["searchMode"]): MemoryQmdSearchMode { + if (raw === "search" || raw === "vsearch" || raw === "query") { + return raw; + } + return DEFAULT_QMD_SEARCH_MODE; +} + +function resolveSearchTool(raw?: MemoryQmdConfig["searchTool"]): string | undefined { + const value = raw?.trim(); + return value ? value : undefined; +} + +function resolveSessionConfig( + cfg: MemoryQmdConfig["sessions"], + workspaceDir: string, +): ResolvedQmdSessionConfig { + const enabled = Boolean(cfg?.enabled); + const exportDirRaw = cfg?.exportDir?.trim(); + const exportDir = exportDirRaw ? resolvePath(exportDirRaw, workspaceDir) : undefined; + const retentionDays = + cfg?.retentionDays && cfg.retentionDays > 0 ? 
Math.floor(cfg.retentionDays) : undefined; + return { + enabled, + exportDir, + retentionDays, + }; +} + +function resolveCustomPaths( + rawPaths: MemoryQmdIndexPath[] | undefined, + workspaceDir: string, + existing: Set, + agentId: string, +): ResolvedQmdCollection[] { + if (!rawPaths?.length) { + return []; + } + const collections: ResolvedQmdCollection[] = []; + const seenRoots = new Set(); + rawPaths.forEach((entry, index) => { + const trimmedPath = entry?.path?.trim(); + if (!trimmedPath) { + return; + } + let resolved: string; + try { + resolved = resolvePath(trimmedPath, workspaceDir); + } catch { + return; + } + const pattern = entry.pattern?.trim() || "**/*.md"; + const dedupeKey = `${resolved}\u0000${pattern}`; + if (seenRoots.has(dedupeKey)) { + return; + } + seenRoots.add(dedupeKey); + const explicitName = entry.name?.trim(); + const baseName = + explicitName && !isPathInsideRoot(resolved, workspaceDir) + ? explicitName + : scopeCollectionBase(explicitName || `custom-${index + 1}`, agentId); + const name = ensureUniqueName(baseName, existing); + collections.push({ + name, + path: resolved, + pattern, + kind: "custom", + }); + }); + return collections; +} + +function resolveMcporterConfig(raw?: MemoryQmdMcporterConfig): ResolvedQmdMcporterConfig { + const parsed: ResolvedQmdMcporterConfig = { ...DEFAULT_QMD_MCPORTER }; + if (!raw) { + return parsed; + } + if (raw.enabled !== undefined) { + parsed.enabled = raw.enabled; + } + if (typeof raw.serverName === "string" && raw.serverName.trim()) { + parsed.serverName = raw.serverName.trim(); + } + if (raw.startDaemon !== undefined) { + parsed.startDaemon = raw.startDaemon; + } + // When enabled, default startDaemon to true. 
+ if (parsed.enabled && raw.startDaemon === undefined) { + parsed.startDaemon = true; + } + return parsed; +} + +function resolveDefaultCollections( + include: boolean, + workspaceDir: string, + existing: Set, + agentId: string, +): ResolvedQmdCollection[] { + if (!include) { + return []; + } + const entries: Array<{ path: string; pattern: string; base: string }> = [ + { path: workspaceDir, pattern: "MEMORY.md", base: "memory-root" }, + { path: workspaceDir, pattern: "memory.md", base: "memory-alt" }, + { path: path.join(workspaceDir, "memory"), pattern: "**/*.md", base: "memory-dir" }, + ]; + return entries.map((entry) => ({ + name: ensureUniqueName(scopeCollectionBase(entry.base, agentId), existing), + path: entry.path, + pattern: entry.pattern, + kind: "memory", + })); +} + +export function resolveMemoryBackendConfig(params: { + cfg: OpenClawConfig; + agentId: string; +}): ResolvedMemoryBackendConfig { + const normalizedAgentId = normalizeAgentId(params.agentId); + const backend = params.cfg.memory?.backend ?? DEFAULT_BACKEND; + const citations = params.cfg.memory?.citations ?? DEFAULT_CITATIONS; + if (backend !== "qmd") { + return { backend: "builtin", citations }; + } + + const workspaceDir = resolveAgentWorkspaceDir(params.cfg, normalizedAgentId); + const qmdCfg = params.cfg.memory?.qmd; + const includeDefaultMemory = qmdCfg?.includeDefaultMemory !== false; + const nameSet = new Set(); + const agentEntry = params.cfg.agents?.list?.find( + (entry) => normalizeAgentId(entry?.id) === normalizedAgentId, + ); + const mergedExtraPaths = [ + ...(params.cfg.agents?.defaults?.memorySearch?.extraPaths ?? []), + ...(agentEntry?.memorySearch?.extraPaths ?? 
[]), + ] + .filter((value): value is string => typeof value === "string") + .map((value) => value.trim()) + .filter(Boolean); + const dedupedExtraPaths = Array.from(new Set(mergedExtraPaths)); + const searchExtraPaths = dedupedExtraPaths.map( + (pathValue): { path: string; pattern?: string; name?: string } => ({ path: pathValue }), + ); + const mergedExtraCollections = [ + ...(params.cfg.agents?.defaults?.memorySearch?.qmd?.extraCollections ?? []), + ...(agentEntry?.memorySearch?.qmd?.extraCollections ?? []), + ].filter((value): value is MemoryQmdIndexPath => + Boolean(value && typeof value === "object" && typeof value.path === "string"), + ); + + // Combine QMD-specific paths with extraPaths and per-agent cross-agent collections. + const allQmdPaths: MemoryQmdIndexPath[] = [ + ...(qmdCfg?.paths ?? []), + ...searchExtraPaths, + ...mergedExtraCollections, + ]; + + const collections = [ + ...resolveDefaultCollections(includeDefaultMemory, workspaceDir, nameSet, normalizedAgentId), + ...resolveCustomPaths(allQmdPaths, workspaceDir, nameSet, normalizedAgentId), + ]; + + const rawCommand = qmdCfg?.command?.trim() || "qmd"; + const parsedCommand = splitShellArgs(rawCommand); + const command = parsedCommand?.[0] || rawCommand.split(/\s+/)[0] || "qmd"; + const resolved: ResolvedQmdConfig = { + command, + mcporter: resolveMcporterConfig(qmdCfg?.mcporter), + searchMode: resolveSearchMode(qmdCfg?.searchMode), + searchTool: resolveSearchTool(qmdCfg?.searchTool), + collections, + includeDefaultMemory, + sessions: resolveSessionConfig(qmdCfg?.sessions, workspaceDir), + update: { + intervalMs: resolveIntervalMs(qmdCfg?.update?.interval), + debounceMs: resolveDebounceMs(qmdCfg?.update?.debounceMs), + onBoot: qmdCfg?.update?.onBoot !== false, + waitForBootSync: qmdCfg?.update?.waitForBootSync === true, + embedIntervalMs: resolveEmbedIntervalMs(qmdCfg?.update?.embedInterval), + commandTimeoutMs: resolveTimeoutMs( + qmdCfg?.update?.commandTimeoutMs, + DEFAULT_QMD_COMMAND_TIMEOUT_MS, 
+ ), + updateTimeoutMs: resolveTimeoutMs( + qmdCfg?.update?.updateTimeoutMs, + DEFAULT_QMD_UPDATE_TIMEOUT_MS, + ), + embedTimeoutMs: resolveTimeoutMs( + qmdCfg?.update?.embedTimeoutMs, + DEFAULT_QMD_EMBED_TIMEOUT_MS, + ), + }, + limits: resolveLimits(qmdCfg?.limits), + scope: qmdCfg?.scope ?? DEFAULT_QMD_SCOPE, + }; + + return { + backend: "qmd", + citations, + qmd: resolved, + }; +} diff --git a/src/memory-host-sdk/host/batch-embedding-common.ts b/src/memory-host-sdk/host/batch-embedding-common.ts new file mode 100644 index 00000000000..2aa3351150f --- /dev/null +++ b/src/memory-host-sdk/host/batch-embedding-common.ts @@ -0,0 +1,22 @@ +export { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js"; +export { postJsonWithRetry } from "./batch-http.js"; +export { applyEmbeddingBatchOutputLine } from "./batch-output.js"; +export { + resolveBatchCompletionFromStatus, + resolveCompletedBatchResult, + throwIfBatchTerminalFailure, + type BatchCompletionResult, +} from "./batch-status.js"; +export { + EMBEDDING_BATCH_ENDPOINT, + type EmbeddingBatchStatus, + type ProviderBatchOutputLine, +} from "./batch-provider-common.js"; +export { + buildEmbeddingBatchGroupOptions, + runEmbeddingBatchGroups, + type EmbeddingBatchExecutionParams, +} from "./batch-runner.js"; +export { uploadBatchJsonlFile } from "./batch-upload.js"; +export { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; +export { withRemoteHttpResponse } from "./remote-http.js"; diff --git a/src/memory-host-sdk/host/batch-error-utils.test.ts b/src/memory-host-sdk/host/batch-error-utils.test.ts new file mode 100644 index 00000000000..c92c9cbac39 --- /dev/null +++ b/src/memory-host-sdk/host/batch-error-utils.test.ts @@ -0,0 +1,32 @@ +import { describe, expect, it } from "vitest"; +import { extractBatchErrorMessage, formatUnavailableBatchError } from "./batch-error-utils.js"; + +describe("extractBatchErrorMessage", () => { + it("returns the first top-level error 
message", () => { + expect( + extractBatchErrorMessage([ + { response: { body: { error: { message: "nested" } } } }, + { error: { message: "top-level" } }, + ]), + ).toBe("nested"); + }); + + it("falls back to nested response error message", () => { + expect( + extractBatchErrorMessage([{ response: { body: { error: { message: "nested-only" } } } }, {}]), + ).toBe("nested-only"); + }); + + it("accepts plain string response bodies", () => { + expect(extractBatchErrorMessage([{ response: { body: "provider plain-text error" } }])).toBe( + "provider plain-text error", + ); + }); +}); + +describe("formatUnavailableBatchError", () => { + it("formats errors and non-error values", () => { + expect(formatUnavailableBatchError(new Error("boom"))).toBe("error file unavailable: boom"); + expect(formatUnavailableBatchError("unreachable")).toBe("error file unavailable: unreachable"); + }); +}); diff --git a/src/memory-host-sdk/host/batch-error-utils.ts b/src/memory-host-sdk/host/batch-error-utils.ts new file mode 100644 index 00000000000..215b0672a8c --- /dev/null +++ b/src/memory-host-sdk/host/batch-error-utils.ts @@ -0,0 +1,31 @@ +type BatchOutputErrorLike = { + error?: { message?: string }; + response?: { + body?: + | string + | { + error?: { message?: string }; + }; + }; +}; + +function getResponseErrorMessage(line: BatchOutputErrorLike | undefined): string | undefined { + const body = line?.response?.body; + if (typeof body === "string") { + return body || undefined; + } + if (!body || typeof body !== "object") { + return undefined; + } + return typeof body.error?.message === "string" ? body.error.message : undefined; +} + +export function extractBatchErrorMessage(lines: BatchOutputErrorLike[]): string | undefined { + const first = lines.find((line) => line.error?.message || getResponseErrorMessage(line)); + return first?.error?.message ?? 
getResponseErrorMessage(first); +} + +export function formatUnavailableBatchError(err: unknown): string | undefined { + const message = err instanceof Error ? err.message : String(err); + return message ? `error file unavailable: ${message}` : undefined; +} diff --git a/src/memory-host-sdk/host/batch-gemini.test.ts b/src/memory-host-sdk/host/batch-gemini.test.ts new file mode 100644 index 00000000000..095ebe008b9 --- /dev/null +++ b/src/memory-host-sdk/host/batch-gemini.test.ts @@ -0,0 +1,116 @@ +import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import type { GeminiEmbeddingClient } from "./embeddings-gemini.js"; + +vi.mock("./remote-http.js", () => ({ + withRemoteHttpResponse: vi.fn(), +})); + +function magnitude(values: number[]) { + return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0)); +} + +describe("runGeminiEmbeddingBatches", () => { + let runGeminiEmbeddingBatches: typeof import("./batch-gemini.js").runGeminiEmbeddingBatches; + let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse; + let remoteHttpMock: ReturnType>; + + beforeAll(async () => { + ({ runGeminiEmbeddingBatches } = await import("./batch-gemini.js")); + ({ withRemoteHttpResponse } = await import("./remote-http.js")); + remoteHttpMock = vi.mocked(withRemoteHttpResponse); + }); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterEach(() => { + vi.resetAllMocks(); + vi.unstubAllGlobals(); + }); + + const mockClient: GeminiEmbeddingClient = { + baseUrl: "https://generativelanguage.googleapis.com/v1beta", + headers: {}, + model: "gemini-embedding-2-preview", + modelPath: "models/gemini-embedding-2-preview", + apiKeys: ["test-key"], + outputDimensionality: 1536, + }; + + it("includes outputDimensionality in batch upload requests", async () => { + remoteHttpMock.mockImplementationOnce(async (params) => { + expect(params.url).toContain("/upload/v1beta/files?uploadType=multipart"); + const body = 
params.init?.body; + if (!(body instanceof Blob)) { + throw new Error("expected multipart blob body"); + } + const text = await body.text(); + expect(text).toContain('"taskType":"RETRIEVAL_DOCUMENT"'); + expect(text).toContain('"outputDimensionality":1536'); + return await params.onResponse( + new Response(JSON.stringify({ name: "files/file-123" }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + }); + remoteHttpMock.mockImplementationOnce(async (params) => { + expect(params.url).toMatch(/:asyncBatchEmbedContent$/u); + return await params.onResponse( + new Response( + JSON.stringify({ + name: "batches/batch-1", + state: "COMPLETED", + outputConfig: { file: "files/output-1" }, + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + }); + remoteHttpMock.mockImplementationOnce(async (params) => { + expect(params.url).toMatch(/\/files\/output-1:download$/u); + return await params.onResponse( + new Response( + JSON.stringify({ + key: "req-1", + response: { embedding: { values: [3, 4] } }, + }), + { + status: 200, + headers: { "Content-Type": "application/jsonl" }, + }, + ), + ); + }); + + const results = await runGeminiEmbeddingBatches({ + gemini: mockClient, + agentId: "main", + requests: [ + { + custom_id: "req-1", + request: { + content: { parts: [{ text: "hello world" }] }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 1536, + }, + }, + ], + wait: true, + pollIntervalMs: 1, + timeoutMs: 1000, + concurrency: 1, + }); + + const embedding = results.get("req-1"); + expect(embedding).toBeDefined(); + expect(embedding?.[0]).toBeCloseTo(0.6, 5); + expect(embedding?.[1]).toBeCloseTo(0.8, 5); + expect(magnitude(embedding ?? 
[])).toBeCloseTo(1, 5); + expect(remoteHttpMock).toHaveBeenCalledTimes(3); + }); +}); diff --git a/src/memory-host-sdk/host/batch-gemini.ts b/src/memory-host-sdk/host/batch-gemini.ts new file mode 100644 index 00000000000..4bdc9fa055e --- /dev/null +++ b/src/memory-host-sdk/host/batch-gemini.ts @@ -0,0 +1,368 @@ +import { + buildEmbeddingBatchGroupOptions, + runEmbeddingBatchGroups, + type EmbeddingBatchExecutionParams, +} from "./batch-runner.js"; +import { buildBatchHeaders, normalizeBatchBaseUrl } from "./batch-utils.js"; +import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; +import { debugEmbeddingsLog } from "./embeddings-debug.js"; +import type { GeminiEmbeddingClient, GeminiTextEmbeddingRequest } from "./embeddings-gemini.js"; +import { hashText } from "./internal.js"; +import { withRemoteHttpResponse } from "./remote-http.js"; + +export type GeminiBatchRequest = { + custom_id: string; + request: GeminiTextEmbeddingRequest; +}; + +export type GeminiBatchStatus = { + name?: string; + state?: string; + outputConfig?: { file?: string; fileId?: string }; + metadata?: { + output?: { + responsesFile?: string; + }; + }; + error?: { message?: string }; +}; + +export type GeminiBatchOutputLine = { + key?: string; + custom_id?: string; + request_id?: string; + embedding?: { values?: number[] }; + response?: { + embedding?: { values?: number[] }; + error?: { message?: string }; + }; + error?: { message?: string }; +}; + +const GEMINI_BATCH_MAX_REQUESTS = 50000; +function getGeminiUploadUrl(baseUrl: string): string { + if (baseUrl.includes("/v1beta")) { + return baseUrl.replace(/\/v1beta\/?$/, "/upload/v1beta"); + } + return `${baseUrl.replace(/\/$/, "")}/upload`; +} + +function buildGeminiUploadBody(params: { jsonl: string; displayName: string }): { + body: Blob; + contentType: string; +} { + const boundary = `openclaw-${hashText(params.displayName)}`; + const jsonPart = JSON.stringify({ + file: { + displayName: params.displayName, + mimeType: 
"application/jsonl", + }, + }); + const delimiter = `--${boundary}\r\n`; + const closeDelimiter = `--${boundary}--\r\n`; + const parts = [ + `${delimiter}Content-Type: application/json; charset=UTF-8\r\n\r\n${jsonPart}\r\n`, + `${delimiter}Content-Type: application/jsonl; charset=UTF-8\r\n\r\n${params.jsonl}\r\n`, + closeDelimiter, + ]; + const body = new Blob([parts.join("")], { type: "multipart/related" }); + return { + body, + contentType: `multipart/related; boundary=${boundary}`, + }; +} + +async function submitGeminiBatch(params: { + gemini: GeminiEmbeddingClient; + requests: GeminiBatchRequest[]; + agentId: string; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.gemini); + const jsonl = params.requests + .map((request) => + JSON.stringify({ + key: request.custom_id, + request: request.request, + }), + ) + .join("\n"); + const displayName = `memory-embeddings-${hashText(String(Date.now()))}`; + const uploadPayload = buildGeminiUploadBody({ jsonl, displayName }); + + const uploadUrl = `${getGeminiUploadUrl(baseUrl)}/files?uploadType=multipart`; + debugEmbeddingsLog("memory embeddings: gemini batch upload", { + uploadUrl, + baseUrl, + requests: params.requests.length, + }); + const filePayload = await withRemoteHttpResponse({ + url: uploadUrl, + ssrfPolicy: params.gemini.ssrfPolicy, + init: { + method: "POST", + headers: { + ...buildBatchHeaders(params.gemini, { json: false }), + "Content-Type": uploadPayload.contentType, + }, + body: uploadPayload.body, + }, + onResponse: async (fileRes) => { + if (!fileRes.ok) { + const text = await fileRes.text(); + throw new Error(`gemini batch file upload failed: ${fileRes.status} ${text}`); + } + return (await fileRes.json()) as { name?: string; file?: { name?: string } }; + }, + }); + const fileId = filePayload.name ?? 
filePayload.file?.name; + if (!fileId) { + throw new Error("gemini batch file upload failed: missing file id"); + } + + const batchBody = { + batch: { + displayName: `memory-embeddings-${params.agentId}`, + inputConfig: { + file_name: fileId, + }, + }, + }; + + const batchEndpoint = `${baseUrl}/${params.gemini.modelPath}:asyncBatchEmbedContent`; + debugEmbeddingsLog("memory embeddings: gemini batch create", { + batchEndpoint, + fileId, + }); + return await withRemoteHttpResponse({ + url: batchEndpoint, + ssrfPolicy: params.gemini.ssrfPolicy, + init: { + method: "POST", + headers: buildBatchHeaders(params.gemini, { json: true }), + body: JSON.stringify(batchBody), + }, + onResponse: async (batchRes) => { + if (batchRes.ok) { + return (await batchRes.json()) as GeminiBatchStatus; + } + const text = await batchRes.text(); + if (batchRes.status === 404) { + throw new Error( + "gemini batch create failed: 404 (asyncBatchEmbedContent not available for this model/baseUrl). Disable remote.batch.enabled or switch providers.", + ); + } + throw new Error(`gemini batch create failed: ${batchRes.status} ${text}`); + }, + }); +} + +async function fetchGeminiBatchStatus(params: { + gemini: GeminiEmbeddingClient; + batchName: string; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.gemini); + const name = params.batchName.startsWith("batches/") + ? 
params.batchName + : `batches/${params.batchName}`; + const statusUrl = `${baseUrl}/${name}`; + debugEmbeddingsLog("memory embeddings: gemini batch status", { statusUrl }); + return await withRemoteHttpResponse({ + url: statusUrl, + ssrfPolicy: params.gemini.ssrfPolicy, + init: { + headers: buildBatchHeaders(params.gemini, { json: true }), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini batch status failed: ${res.status} ${text}`); + } + return (await res.json()) as GeminiBatchStatus; + }, + }); +} + +async function fetchGeminiFileContent(params: { + gemini: GeminiEmbeddingClient; + fileId: string; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.gemini); + const file = params.fileId.startsWith("files/") ? params.fileId : `files/${params.fileId}`; + const downloadUrl = `${baseUrl}/${file}:download`; + debugEmbeddingsLog("memory embeddings: gemini batch download", { downloadUrl }); + return await withRemoteHttpResponse({ + url: downloadUrl, + ssrfPolicy: params.gemini.ssrfPolicy, + init: { + headers: buildBatchHeaders(params.gemini, { json: true }), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini batch file content failed: ${res.status} ${text}`); + } + return await res.text(); + }, + }); +} + +function parseGeminiBatchOutput(text: string): GeminiBatchOutputLine[] { + if (!text.trim()) { + return []; + } + return text + .split("\n") + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as GeminiBatchOutputLine); +} + +async function waitForGeminiBatch(params: { + gemini: GeminiEmbeddingClient; + batchName: string; + wait: boolean; + pollIntervalMs: number; + timeoutMs: number; + debug?: (message: string, data?: Record) => void; + initial?: GeminiBatchStatus; +}): Promise<{ outputFileId: string }> { + const start = Date.now(); + let current: GeminiBatchStatus | undefined = params.initial; + while 
(true) { + const status = + current ?? + (await fetchGeminiBatchStatus({ + gemini: params.gemini, + batchName: params.batchName, + })); + const state = status.state ?? "UNKNOWN"; + if (["SUCCEEDED", "COMPLETED", "DONE"].includes(state)) { + const outputFileId = + status.outputConfig?.file ?? + status.outputConfig?.fileId ?? + status.metadata?.output?.responsesFile; + if (!outputFileId) { + throw new Error(`gemini batch ${params.batchName} completed without output file`); + } + return { outputFileId }; + } + if (["FAILED", "CANCELLED", "CANCELED", "EXPIRED"].includes(state)) { + const message = status.error?.message ?? "unknown error"; + throw new Error(`gemini batch ${params.batchName} ${state}: ${message}`); + } + if (!params.wait) { + throw new Error(`gemini batch ${params.batchName} still ${state}; wait disabled`); + } + if (Date.now() - start > params.timeoutMs) { + throw new Error(`gemini batch ${params.batchName} timed out after ${params.timeoutMs}ms`); + } + params.debug?.(`gemini batch ${params.batchName} ${state}; waiting ${params.pollIntervalMs}ms`); + await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs)); + current = undefined; + } +} + +export async function runGeminiEmbeddingBatches( + params: { + gemini: GeminiEmbeddingClient; + agentId: string; + requests: GeminiBatchRequest[]; + } & EmbeddingBatchExecutionParams, +): Promise> { + return await runEmbeddingBatchGroups({ + ...buildEmbeddingBatchGroupOptions(params, { + maxRequests: GEMINI_BATCH_MAX_REQUESTS, + debugLabel: "memory embeddings: gemini batch submit", + }), + runGroup: async ({ group, groupIndex, groups, byCustomId }) => { + const batchInfo = await submitGeminiBatch({ + gemini: params.gemini, + requests: group, + agentId: params.agentId, + }); + const batchName = batchInfo.name ?? 
""; + if (!batchName) { + throw new Error("gemini batch create failed: missing batch name"); + } + + params.debug?.("memory embeddings: gemini batch created", { + batchName, + state: batchInfo.state, + group: groupIndex + 1, + groups, + requests: group.length, + }); + + if ( + !params.wait && + batchInfo.state && + !["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state) + ) { + throw new Error( + `gemini batch ${batchName} submitted; enable remote.batch.wait to await completion`, + ); + } + + const completed = + batchInfo.state && ["SUCCEEDED", "COMPLETED", "DONE"].includes(batchInfo.state) + ? { + outputFileId: + batchInfo.outputConfig?.file ?? + batchInfo.outputConfig?.fileId ?? + batchInfo.metadata?.output?.responsesFile ?? + "", + } + : await waitForGeminiBatch({ + gemini: params.gemini, + batchName, + wait: params.wait, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + debug: params.debug, + initial: batchInfo, + }); + if (!completed.outputFileId) { + throw new Error(`gemini batch ${batchName} completed without output file`); + } + + const content = await fetchGeminiFileContent({ + gemini: params.gemini, + fileId: completed.outputFileId, + }); + const outputLines = parseGeminiBatchOutput(content); + const errors: string[] = []; + const remaining = new Set(group.map((request) => request.custom_id)); + + for (const line of outputLines) { + const customId = line.key ?? line.custom_id ?? line.request_id; + if (!customId) { + continue; + } + remaining.delete(customId); + if (line.error?.message) { + errors.push(`${customId}: ${line.error.message}`); + continue; + } + if (line.response?.error?.message) { + errors.push(`${customId}: ${line.response.error.message}`); + continue; + } + const embedding = sanitizeAndNormalizeEmbedding( + line.embedding?.values ?? line.response?.embedding?.values ?? 
[], + ); + if (embedding.length === 0) { + errors.push(`${customId}: empty embedding`); + continue; + } + byCustomId.set(customId, embedding); + } + + if (errors.length > 0) { + throw new Error(`gemini batch ${batchName} failed: ${errors.join("; ")}`); + } + if (remaining.size > 0) { + throw new Error(`gemini batch ${batchName} missing ${remaining.size} embedding responses`); + } + }, + }); +} diff --git a/src/memory-host-sdk/host/batch-http.test.ts b/src/memory-host-sdk/host/batch-http.test.ts new file mode 100644 index 00000000000..3519a80f038 --- /dev/null +++ b/src/memory-host-sdk/host/batch-http.test.ts @@ -0,0 +1,86 @@ +import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; + +vi.mock("../../infra/retry.js", () => ({ + retryAsync: vi.fn(async (run: () => Promise) => await run()), +})); + +vi.mock("./post-json.js", () => ({ + postJson: vi.fn(), +})); + +describe("postJsonWithRetry", () => { + let retryAsyncMock: ReturnType< + typeof vi.mocked + >; + let postJsonMock: ReturnType>; + let postJsonWithRetry: typeof import("./batch-http.js").postJsonWithRetry; + + beforeAll(async () => { + ({ postJsonWithRetry } = await import("./batch-http.js")); + const retryModule = await import("../../infra/retry.js"); + const postJsonModule = await import("./post-json.js"); + retryAsyncMock = vi.mocked(retryModule.retryAsync); + postJsonMock = vi.mocked(postJsonModule.postJson); + }); + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("posts JSON and returns parsed response payload", async () => { + postJsonMock.mockImplementationOnce(async (params) => { + return await params.parse({ ok: true, ids: [1, 2] }); + }); + + const result = await postJsonWithRetry<{ ok: boolean; ids: number[] }>({ + url: "https://memory.example/v1/batch", + headers: { Authorization: "Bearer test" }, + body: { chunks: ["a", "b"] }, + errorPrefix: "memory batch failed", + }); + + expect(result).toEqual({ ok: true, ids: [1, 2] }); + expect(postJsonMock).toHaveBeenCalledWith( + 
expect.objectContaining({ + url: "https://memory.example/v1/batch", + headers: { Authorization: "Bearer test" }, + body: { chunks: ["a", "b"] }, + errorPrefix: "memory batch failed", + attachStatus: true, + }), + ); + + const retryOptions = retryAsyncMock.mock.calls[0]?.[1] as + | { + attempts: number; + minDelayMs: number; + maxDelayMs: number; + shouldRetry: (err: unknown) => boolean; + } + | undefined; + expect(retryOptions?.attempts).toBe(3); + expect(retryOptions?.minDelayMs).toBe(300); + expect(retryOptions?.maxDelayMs).toBe(2000); + expect(retryOptions?.shouldRetry({ status: 429 })).toBe(true); + expect(retryOptions?.shouldRetry({ status: 503 })).toBe(true); + expect(retryOptions?.shouldRetry({ status: 400 })).toBe(false); + }); + + it("attaches status to non-ok errors", async () => { + postJsonMock.mockRejectedValueOnce( + Object.assign(new Error("memory batch failed: 503 backend down"), { status: 503 }), + ); + + await expect( + postJsonWithRetry({ + url: "https://memory.example/v1/batch", + headers: {}, + body: { chunks: [] }, + errorPrefix: "memory batch failed", + }), + ).rejects.toMatchObject({ + message: expect.stringContaining("memory batch failed: 503 backend down"), + status: 503, + }); + }); +}); diff --git a/src/memory-host-sdk/host/batch-http.ts b/src/memory-host-sdk/host/batch-http.ts new file mode 100644 index 00000000000..b098d382f16 --- /dev/null +++ b/src/memory-host-sdk/host/batch-http.ts @@ -0,0 +1,35 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { retryAsync } from "../../infra/retry.js"; +import { postJson } from "./post-json.js"; + +export async function postJsonWithRetry(params: { + url: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + body: unknown; + errorPrefix: string; +}): Promise { + return await retryAsync( + async () => { + return await postJson({ + url: params.url, + headers: params.headers, + ssrfPolicy: params.ssrfPolicy, + body: params.body, + errorPrefix: params.errorPrefix, + attachStatus: 
true, + parse: async (payload) => payload as T, + }); + }, + { + attempts: 3, + minDelayMs: 300, + maxDelayMs: 2000, + jitter: 0.2, + shouldRetry: (err) => { + const status = (err as { status?: number }).status; + return status === 429 || (typeof status === "number" && status >= 500); + }, + }, + ); +} diff --git a/src/memory-host-sdk/host/batch-openai.ts b/src/memory-host-sdk/host/batch-openai.ts new file mode 100644 index 00000000000..e17a420812c --- /dev/null +++ b/src/memory-host-sdk/host/batch-openai.ts @@ -0,0 +1,259 @@ +import { + applyEmbeddingBatchOutputLine, + buildBatchHeaders, + buildEmbeddingBatchGroupOptions, + EMBEDDING_BATCH_ENDPOINT, + extractBatchErrorMessage, + formatUnavailableBatchError, + normalizeBatchBaseUrl, + postJsonWithRetry, + resolveBatchCompletionFromStatus, + resolveCompletedBatchResult, + runEmbeddingBatchGroups, + throwIfBatchTerminalFailure, + type EmbeddingBatchExecutionParams, + type EmbeddingBatchStatus, + type BatchCompletionResult, + type ProviderBatchOutputLine, + uploadBatchJsonlFile, + withRemoteHttpResponse, +} from "./batch-embedding-common.js"; +import type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; + +export type OpenAiBatchRequest = { + custom_id: string; + method: "POST"; + url: "/v1/embeddings"; + body: { + model: string; + input: string; + }; +}; + +export type OpenAiBatchStatus = EmbeddingBatchStatus; +export type OpenAiBatchOutputLine = ProviderBatchOutputLine; + +export const OPENAI_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT; +const OPENAI_BATCH_COMPLETION_WINDOW = "24h"; +const OPENAI_BATCH_MAX_REQUESTS = 50000; + +async function submitOpenAiBatch(params: { + openAi: OpenAiEmbeddingClient; + requests: OpenAiBatchRequest[]; + agentId: string; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.openAi); + const inputFileId = await uploadBatchJsonlFile({ + client: params.openAi, + requests: params.requests, + errorPrefix: "openai batch file upload failed", + }); + + return await 
postJsonWithRetry({ + url: `${baseUrl}/batches`, + headers: buildBatchHeaders(params.openAi, { json: true }), + ssrfPolicy: params.openAi.ssrfPolicy, + body: { + input_file_id: inputFileId, + endpoint: OPENAI_BATCH_ENDPOINT, + completion_window: OPENAI_BATCH_COMPLETION_WINDOW, + metadata: { + source: "openclaw-memory", + agent: params.agentId, + }, + }, + errorPrefix: "openai batch create failed", + }); +} + +async function fetchOpenAiBatchStatus(params: { + openAi: OpenAiEmbeddingClient; + batchId: string; +}): Promise { + return await fetchOpenAiBatchResource({ + openAi: params.openAi, + path: `/batches/${params.batchId}`, + errorPrefix: "openai batch status", + parse: async (res) => (await res.json()) as OpenAiBatchStatus, + }); +} + +async function fetchOpenAiFileContent(params: { + openAi: OpenAiEmbeddingClient; + fileId: string; +}): Promise { + return await fetchOpenAiBatchResource({ + openAi: params.openAi, + path: `/files/${params.fileId}/content`, + errorPrefix: "openai batch file content", + parse: async (res) => await res.text(), + }); +} + +async function fetchOpenAiBatchResource(params: { + openAi: OpenAiEmbeddingClient; + path: string; + errorPrefix: string; + parse: (res: Response) => Promise; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.openAi); + return await withRemoteHttpResponse({ + url: `${baseUrl}${params.path}`, + ssrfPolicy: params.openAi.ssrfPolicy, + init: { + headers: buildBatchHeaders(params.openAi, { json: true }), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + throw new Error(`${params.errorPrefix} failed: ${res.status} ${text}`); + } + return await params.parse(res); + }, + }); +} + +function parseOpenAiBatchOutput(text: string): OpenAiBatchOutputLine[] { + if (!text.trim()) { + return []; + } + return text + .split("\n") + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as OpenAiBatchOutputLine); +} + +async function 
readOpenAiBatchError(params: { + openAi: OpenAiEmbeddingClient; + errorFileId: string; +}): Promise { + try { + const content = await fetchOpenAiFileContent({ + openAi: params.openAi, + fileId: params.errorFileId, + }); + const lines = parseOpenAiBatchOutput(content); + return extractBatchErrorMessage(lines); + } catch (err) { + return formatUnavailableBatchError(err); + } +} + +async function waitForOpenAiBatch(params: { + openAi: OpenAiEmbeddingClient; + batchId: string; + wait: boolean; + pollIntervalMs: number; + timeoutMs: number; + debug?: (message: string, data?: Record) => void; + initial?: OpenAiBatchStatus; +}): Promise { + const start = Date.now(); + let current: OpenAiBatchStatus | undefined = params.initial; + while (true) { + const status = + current ?? + (await fetchOpenAiBatchStatus({ + openAi: params.openAi, + batchId: params.batchId, + })); + const state = status.status ?? "unknown"; + if (state === "completed") { + return resolveBatchCompletionFromStatus({ + provider: "openai", + batchId: params.batchId, + status, + }); + } + await throwIfBatchTerminalFailure({ + provider: "openai", + status: { ...status, id: params.batchId }, + readError: async (errorFileId) => + await readOpenAiBatchError({ + openAi: params.openAi, + errorFileId, + }), + }); + if (!params.wait) { + throw new Error(`openai batch ${params.batchId} still ${state}; wait disabled`); + } + if (Date.now() - start > params.timeoutMs) { + throw new Error(`openai batch ${params.batchId} timed out after ${params.timeoutMs}ms`); + } + params.debug?.(`openai batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`); + await new Promise((resolve) => setTimeout(resolve, params.pollIntervalMs)); + current = undefined; + } +} + +export async function runOpenAiEmbeddingBatches( + params: { + openAi: OpenAiEmbeddingClient; + agentId: string; + requests: OpenAiBatchRequest[]; + } & EmbeddingBatchExecutionParams, +): Promise> { + return await runEmbeddingBatchGroups({ + 
...buildEmbeddingBatchGroupOptions(params, { + maxRequests: OPENAI_BATCH_MAX_REQUESTS, + debugLabel: "memory embeddings: openai batch submit", + }), + runGroup: async ({ group, groupIndex, groups, byCustomId }) => { + const batchInfo = await submitOpenAiBatch({ + openAi: params.openAi, + requests: group, + agentId: params.agentId, + }); + if (!batchInfo.id) { + throw new Error("openai batch create failed: missing batch id"); + } + const batchId = batchInfo.id; + + params.debug?.("memory embeddings: openai batch created", { + batchId: batchInfo.id, + status: batchInfo.status, + group: groupIndex + 1, + groups, + requests: group.length, + }); + + const completed = await resolveCompletedBatchResult({ + provider: "openai", + status: batchInfo, + wait: params.wait, + waitForBatch: async () => + await waitForOpenAiBatch({ + openAi: params.openAi, + batchId, + wait: params.wait, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + debug: params.debug, + initial: batchInfo, + }), + }); + + const content = await fetchOpenAiFileContent({ + openAi: params.openAi, + fileId: completed.outputFileId, + }); + const outputLines = parseOpenAiBatchOutput(content); + const errors: string[] = []; + const remaining = new Set(group.map((request) => request.custom_id)); + + for (const line of outputLines) { + applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId }); + } + + if (errors.length > 0) { + throw new Error(`openai batch ${batchInfo.id} failed: ${errors.join("; ")}`); + } + if (remaining.size > 0) { + throw new Error( + `openai batch ${batchInfo.id} missing ${remaining.size} embedding responses`, + ); + } + }, + }); +} diff --git a/src/memory-host-sdk/host/batch-output.test.ts b/src/memory-host-sdk/host/batch-output.test.ts new file mode 100644 index 00000000000..b5aa9334238 --- /dev/null +++ b/src/memory-host-sdk/host/batch-output.test.ts @@ -0,0 +1,82 @@ +import { describe, expect, it } from "vitest"; +import { applyEmbeddingBatchOutputLine } 
from "./batch-output.js"; + +describe("applyEmbeddingBatchOutputLine", () => { + it("stores embedding for successful response", () => { + const remaining = new Set(["req-1"]); + const errors: string[] = []; + const byCustomId = new Map(); + + applyEmbeddingBatchOutputLine({ + line: { + custom_id: "req-1", + response: { + status_code: 200, + body: { data: [{ embedding: [0.1, 0.2] }] }, + }, + }, + remaining, + errors, + byCustomId, + }); + + expect(remaining.has("req-1")).toBe(false); + expect(errors).toEqual([]); + expect(byCustomId.get("req-1")).toEqual([0.1, 0.2]); + }); + + it("records provider error from line.error", () => { + const remaining = new Set(["req-2"]); + const errors: string[] = []; + const byCustomId = new Map(); + + applyEmbeddingBatchOutputLine({ + line: { + custom_id: "req-2", + error: { message: "provider failed" }, + }, + remaining, + errors, + byCustomId, + }); + + expect(remaining.has("req-2")).toBe(false); + expect(errors).toEqual(["req-2: provider failed"]); + expect(byCustomId.size).toBe(0); + }); + + it("records non-2xx response errors and empty embedding errors", () => { + const remaining = new Set(["req-3", "req-4"]); + const errors: string[] = []; + const byCustomId = new Map(); + + applyEmbeddingBatchOutputLine({ + line: { + custom_id: "req-3", + response: { + status_code: 500, + body: { error: { message: "internal" } }, + }, + }, + remaining, + errors, + byCustomId, + }); + + applyEmbeddingBatchOutputLine({ + line: { + custom_id: "req-4", + response: { + status_code: 200, + body: { data: [] }, + }, + }, + remaining, + errors, + byCustomId, + }); + + expect(errors).toEqual(["req-3: internal", "req-4: empty embedding"]); + expect(byCustomId.size).toBe(0); + }); +}); diff --git a/src/memory-host-sdk/host/batch-output.ts b/src/memory-host-sdk/host/batch-output.ts new file mode 100644 index 00000000000..e2a75a878da --- /dev/null +++ b/src/memory-host-sdk/host/batch-output.ts @@ -0,0 +1,55 @@ +export type EmbeddingBatchOutputLine = { + 
custom_id?: string; + error?: { message?: string }; + response?: { + status_code?: number; + body?: + | { + data?: Array<{ + embedding?: number[]; + }>; + error?: { message?: string }; + } + | string; + }; +}; + +export function applyEmbeddingBatchOutputLine(params: { + line: EmbeddingBatchOutputLine; + remaining: Set; + errors: string[]; + byCustomId: Map; +}) { + const customId = params.line.custom_id; + if (!customId) { + return; + } + params.remaining.delete(customId); + + const errorMessage = params.line.error?.message; + if (errorMessage) { + params.errors.push(`${customId}: ${errorMessage}`); + return; + } + + const response = params.line.response; + const statusCode = response?.status_code ?? 0; + if (statusCode >= 400) { + const messageFromObject = + response?.body && typeof response.body === "object" + ? (response.body as { error?: { message?: string } }).error?.message + : undefined; + const messageFromString = typeof response?.body === "string" ? response.body : undefined; + params.errors.push(`${customId}: ${messageFromObject ?? messageFromString ?? "unknown error"}`); + return; + } + + const data = + response?.body && typeof response.body === "object" ? (response.body.data ?? []) : []; + const embedding = data[0]?.embedding ?? 
[]; + if (embedding.length === 0) { + params.errors.push(`${customId}: empty embedding`); + return; + } + params.byCustomId.set(customId, embedding); +} diff --git a/src/memory-host-sdk/host/batch-provider-common.ts b/src/memory-host-sdk/host/batch-provider-common.ts new file mode 100644 index 00000000000..878387ffd6d --- /dev/null +++ b/src/memory-host-sdk/host/batch-provider-common.ts @@ -0,0 +1,12 @@ +import type { EmbeddingBatchOutputLine } from "./batch-output.js"; + +export type EmbeddingBatchStatus = { + id?: string; + status?: string; + output_file_id?: string | null; + error_file_id?: string | null; +}; + +export type ProviderBatchOutputLine = EmbeddingBatchOutputLine; + +export const EMBEDDING_BATCH_ENDPOINT = "/v1/embeddings"; diff --git a/src/memory-host-sdk/host/batch-runner.ts b/src/memory-host-sdk/host/batch-runner.ts new file mode 100644 index 00000000000..aa1785095bb --- /dev/null +++ b/src/memory-host-sdk/host/batch-runner.ts @@ -0,0 +1,64 @@ +import { splitBatchRequests } from "./batch-utils.js"; +import { runWithConcurrency } from "./internal.js"; + +export type EmbeddingBatchExecutionParams = { + wait: boolean; + pollIntervalMs: number; + timeoutMs: number; + concurrency: number; + debug?: (message: string, data?: Record) => void; +}; + +export async function runEmbeddingBatchGroups(params: { + requests: TRequest[]; + maxRequests: number; + wait: EmbeddingBatchExecutionParams["wait"]; + pollIntervalMs: EmbeddingBatchExecutionParams["pollIntervalMs"]; + timeoutMs: EmbeddingBatchExecutionParams["timeoutMs"]; + concurrency: EmbeddingBatchExecutionParams["concurrency"]; + debugLabel: string; + debug?: EmbeddingBatchExecutionParams["debug"]; + runGroup: (args: { + group: TRequest[]; + groupIndex: number; + groups: number; + byCustomId: Map; + }) => Promise; +}): Promise> { + if (params.requests.length === 0) { + return new Map(); + } + const groups = splitBatchRequests(params.requests, params.maxRequests); + const byCustomId = new Map(); + const 
tasks = groups.map((group, groupIndex) => async () => { + await params.runGroup({ group, groupIndex, groups: groups.length, byCustomId }); + }); + + params.debug?.(params.debugLabel, { + requests: params.requests.length, + groups: groups.length, + wait: params.wait, + concurrency: params.concurrency, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + }); + + await runWithConcurrency(tasks, params.concurrency); + return byCustomId; +} + +export function buildEmbeddingBatchGroupOptions( + params: { requests: TRequest[] } & EmbeddingBatchExecutionParams, + options: { maxRequests: number; debugLabel: string }, +) { + return { + requests: params.requests, + maxRequests: options.maxRequests, + wait: params.wait, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + concurrency: params.concurrency, + debug: params.debug, + debugLabel: options.debugLabel, + }; +} diff --git a/src/memory-host-sdk/host/batch-status.test.ts b/src/memory-host-sdk/host/batch-status.test.ts new file mode 100644 index 00000000000..82a992556af --- /dev/null +++ b/src/memory-host-sdk/host/batch-status.test.ts @@ -0,0 +1,60 @@ +import { describe, expect, it } from "vitest"; +import { + resolveBatchCompletionFromStatus, + resolveCompletedBatchResult, + throwIfBatchTerminalFailure, +} from "./batch-status.js"; + +describe("batch-status helpers", () => { + it("resolves completion payload from completed status", () => { + expect( + resolveBatchCompletionFromStatus({ + provider: "openai", + batchId: "b1", + status: { + output_file_id: "out-1", + error_file_id: "err-1", + }, + }), + ).toEqual({ + outputFileId: "out-1", + errorFileId: "err-1", + }); + }); + + it("throws for terminal failure states", async () => { + await expect( + throwIfBatchTerminalFailure({ + provider: "voyage", + status: { id: "b2", status: "failed", error_file_id: "err-file" }, + readError: async () => "bad input", + }), + ).rejects.toThrow("voyage batch b2 failed: bad input"); + }); + + 
it("returns completed result directly without waiting", async () => { + const waitForBatch = async () => ({ outputFileId: "out-2" }); + const result = await resolveCompletedBatchResult({ + provider: "openai", + status: { + id: "b3", + status: "completed", + output_file_id: "out-3", + }, + wait: false, + waitForBatch, + }); + expect(result).toEqual({ outputFileId: "out-3", errorFileId: undefined }); + }); + + it("throws when wait disabled and batch is not complete", async () => { + await expect( + resolveCompletedBatchResult({ + provider: "openai", + status: { id: "b4", status: "pending" }, + wait: false, + waitForBatch: async () => ({ outputFileId: "out" }), + }), + ).rejects.toThrow("openai batch b4 submitted; enable remote.batch.wait to await completion"); + }); +}); diff --git a/src/memory-host-sdk/host/batch-status.ts b/src/memory-host-sdk/host/batch-status.ts new file mode 100644 index 00000000000..96e8da62894 --- /dev/null +++ b/src/memory-host-sdk/host/batch-status.ts @@ -0,0 +1,69 @@ +const TERMINAL_FAILURE_STATES = new Set(["failed", "expired", "cancelled", "canceled"]); + +type BatchStatusLike = { + id?: string; + status?: string; + output_file_id?: string | null; + error_file_id?: string | null; +}; + +export type BatchCompletionResult = { + outputFileId: string; + errorFileId?: string; +}; + +export function resolveBatchCompletionFromStatus(params: { + provider: string; + batchId: string; + status: BatchStatusLike; +}): BatchCompletionResult { + if (!params.status.output_file_id) { + throw new Error(`${params.provider} batch ${params.batchId} completed without output file`); + } + return { + outputFileId: params.status.output_file_id, + errorFileId: params.status.error_file_id ?? undefined, + }; +} + +export async function throwIfBatchTerminalFailure(params: { + provider: string; + status: BatchStatusLike; + readError: (errorFileId: string) => Promise; +}): Promise { + const state = params.status.status ?? 
"unknown"; + if (!TERMINAL_FAILURE_STATES.has(state)) { + return; + } + const detail = params.status.error_file_id + ? await params.readError(params.status.error_file_id) + : undefined; + const suffix = detail ? `: ${detail}` : ""; + throw new Error(`${params.provider} batch ${params.status.id ?? ""} ${state}${suffix}`); +} + +export async function resolveCompletedBatchResult(params: { + provider: string; + status: BatchStatusLike; + wait: boolean; + waitForBatch: () => Promise; +}): Promise { + const batchId = params.status.id ?? ""; + if (!params.wait && params.status.status !== "completed") { + throw new Error( + `${params.provider} batch ${batchId} submitted; enable remote.batch.wait to await completion`, + ); + } + const completed = + params.status.status === "completed" + ? resolveBatchCompletionFromStatus({ + provider: params.provider, + batchId, + status: params.status, + }) + : await params.waitForBatch(); + if (!completed.outputFileId) { + throw new Error(`${params.provider} batch ${batchId} completed without output file`); + } + return completed; +} diff --git a/src/memory-host-sdk/host/batch-upload.ts b/src/memory-host-sdk/host/batch-upload.ts new file mode 100644 index 00000000000..efe4aa7000a --- /dev/null +++ b/src/memory-host-sdk/host/batch-upload.ts @@ -0,0 +1,44 @@ +import { + buildBatchHeaders, + normalizeBatchBaseUrl, + type BatchHttpClientConfig, +} from "./batch-utils.js"; +import { hashText } from "./internal.js"; +import { withRemoteHttpResponse } from "./remote-http.js"; + +export async function uploadBatchJsonlFile(params: { + client: BatchHttpClientConfig; + requests: unknown[]; + errorPrefix: string; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.client); + const jsonl = params.requests.map((request) => JSON.stringify(request)).join("\n"); + const form = new FormData(); + form.append("purpose", "batch"); + form.append( + "file", + new Blob([jsonl], { type: "application/jsonl" }), + 
`memory-embeddings.${hashText(String(Date.now()))}.jsonl`, + ); + + const filePayload = await withRemoteHttpResponse({ + url: `${baseUrl}/files`, + ssrfPolicy: params.client.ssrfPolicy, + init: { + method: "POST", + headers: buildBatchHeaders(params.client, { json: false }), + body: form, + }, + onResponse: async (fileRes) => { + if (!fileRes.ok) { + const text = await fileRes.text(); + throw new Error(`${params.errorPrefix}: ${fileRes.status} ${text}`); + } + return (await fileRes.json()) as { id?: string }; + }, + }); + if (!filePayload.id) { + throw new Error(`${params.errorPrefix}: missing file id`); + } + return filePayload.id; +} diff --git a/src/memory-host-sdk/host/batch-utils.ts b/src/memory-host-sdk/host/batch-utils.ts new file mode 100644 index 00000000000..c44dace3f8a --- /dev/null +++ b/src/memory-host-sdk/host/batch-utils.ts @@ -0,0 +1,38 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; + +export type BatchHttpClientConfig = { + baseUrl?: string; + headers?: Record; + ssrfPolicy?: SsrFPolicy; +}; + +export function normalizeBatchBaseUrl(client: BatchHttpClientConfig): string { + return client.baseUrl?.replace(/\/$/, "") ?? ""; +} + +export function buildBatchHeaders( + client: Pick, + params: { json: boolean }, +): Record { + const headers = client.headers ? 
{ ...client.headers } : {}; + if (params.json) { + if (!headers["Content-Type"] && !headers["content-type"]) { + headers["Content-Type"] = "application/json"; + } + } else { + delete headers["Content-Type"]; + delete headers["content-type"]; + } + return headers; +} + +export function splitBatchRequests(requests: T[], maxRequests: number): T[][] { + if (requests.length <= maxRequests) { + return [requests]; + } + const groups: T[][] = []; + for (let i = 0; i < requests.length; i += maxRequests) { + groups.push(requests.slice(i, i + maxRequests)); + } + return groups; +} diff --git a/src/memory-host-sdk/host/batch-voyage.test.ts b/src/memory-host-sdk/host/batch-voyage.test.ts new file mode 100644 index 00000000000..2fcdb9ec7c0 --- /dev/null +++ b/src/memory-host-sdk/host/batch-voyage.test.ts @@ -0,0 +1,176 @@ +import { ReadableStream } from "node:stream/web"; +import { setTimeout as nativeSleep } from "node:timers/promises"; +import { describe, expect, it, vi } from "vitest"; +import { + runVoyageEmbeddingBatches, + type VoyageBatchOutputLine, + type VoyageBatchRequest, +} from "./batch-voyage.js"; +import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; + +const realNow = Date.now.bind(Date); + +describe("runVoyageEmbeddingBatches", () => { + const mockClient: VoyageEmbeddingClient = { + baseUrl: "https://api.voyageai.com/v1", + headers: { Authorization: "Bearer test-key" }, + model: "voyage-4-large", + }; + + const mockRequests: VoyageBatchRequest[] = [ + { custom_id: "req-1", body: { input: "text1" } }, + { custom_id: "req-2", body: { input: "text2" } }, + ]; + + it("successfully submits batch, waits, and streams results", async () => { + const outputLines: VoyageBatchOutputLine[] = [ + { + custom_id: "req-1", + response: { status_code: 200, body: { data: [{ embedding: [0.1, 0.1] }] } }, + }, + { + custom_id: "req-2", + response: { status_code: 200, body: { data: [{ embedding: [0.2, 0.2] }] } }, + }, + ]; + const withRemoteHttpResponse = vi.fn(); + 
const postJsonWithRetry = vi.fn(); + const uploadBatchJsonlFile = vi.fn(); + + // Create a stream that emits the NDJSON lines + const stream = new ReadableStream({ + start(controller) { + const text = outputLines.map((l) => JSON.stringify(l)).join("\n"); + controller.enqueue(new TextEncoder().encode(text)); + controller.close(); + }, + }); + uploadBatchJsonlFile.mockImplementationOnce(async (params) => { + expect(params.errorPrefix).toBe("voyage batch file upload failed"); + expect(params.requests).toEqual(mockRequests); + return "file-123"; + }); + postJsonWithRetry.mockImplementationOnce(async (params) => { + expect(params.url).toContain("/batches"); + expect(params.body).toMatchObject({ + input_file_id: "file-123", + completion_window: "12h", + request_params: { + model: "voyage-4-large", + input_type: "document", + }, + }); + return { + id: "batch-abc", + status: "pending", + }; + }); + withRemoteHttpResponse.mockImplementationOnce(async (params) => { + expect(params.url).toContain("/batches/batch-abc"); + return await params.onResponse( + new Response( + JSON.stringify({ + id: "batch-abc", + status: "completed", + output_file_id: "file-out-999", + }), + { + status: 200, + headers: { "Content-Type": "application/json" }, + }, + ), + ); + }); + withRemoteHttpResponse.mockImplementationOnce(async (params) => { + expect(params.url).toContain("/files/file-out-999/content"); + return await params.onResponse( + new Response(stream as unknown as BodyInit, { + status: 200, + headers: { "Content-Type": "application/x-ndjson" }, + }), + ); + }); + + const results = await runVoyageEmbeddingBatches({ + client: mockClient, + agentId: "agent-1", + requests: mockRequests, + wait: true, + pollIntervalMs: 1, // fast poll + timeoutMs: 1000, + concurrency: 1, + deps: { + now: realNow, + sleep: async (ms) => { + await nativeSleep(ms); + }, + postJsonWithRetry, + uploadBatchJsonlFile, + withRemoteHttpResponse, + }, + }); + + expect(results.size).toBe(2); + 
expect(results.get("req-1")).toEqual([0.1, 0.1]); + expect(results.get("req-2")).toEqual([0.2, 0.2]); + expect(uploadBatchJsonlFile).toHaveBeenCalledTimes(1); + expect(postJsonWithRetry).toHaveBeenCalledTimes(1); + expect(withRemoteHttpResponse).toHaveBeenCalledTimes(2); + }); + + it("handles empty lines and stream chunks correctly", async () => { + const withRemoteHttpResponse = vi.fn(); + const postJsonWithRetry = vi.fn(); + const uploadBatchJsonlFile = vi.fn(); + const stream = new ReadableStream({ + start(controller) { + const line1 = JSON.stringify({ + custom_id: "req-1", + response: { body: { data: [{ embedding: [1] }] } }, + }); + const line2 = JSON.stringify({ + custom_id: "req-2", + response: { body: { data: [{ embedding: [2] }] } }, + }); + + // Split across chunks + controller.enqueue(new TextEncoder().encode(line1 + "\n")); + controller.enqueue(new TextEncoder().encode("\n")); // empty line + controller.enqueue(new TextEncoder().encode(line2)); // no newline at EOF + controller.close(); + }, + }); + uploadBatchJsonlFile.mockResolvedValueOnce("f1"); + postJsonWithRetry.mockResolvedValueOnce({ + id: "b1", + status: "completed", + output_file_id: "out1", + }); + withRemoteHttpResponse.mockImplementationOnce(async (params) => { + expect(params.url).toContain("/files/out1/content"); + return await params.onResponse(new Response(stream as unknown as BodyInit, { status: 200 })); + }); + + const results = await runVoyageEmbeddingBatches({ + client: mockClient, + agentId: "a1", + requests: mockRequests, + wait: true, + pollIntervalMs: 1, + timeoutMs: 1000, + concurrency: 1, + deps: { + now: realNow, + sleep: async (ms) => { + await nativeSleep(ms); + }, + postJsonWithRetry, + uploadBatchJsonlFile, + withRemoteHttpResponse, + }, + }); + + expect(results.get("req-1")).toEqual([1]); + expect(results.get("req-2")).toEqual([2]); + }); +}); diff --git a/src/memory-host-sdk/host/batch-voyage.ts b/src/memory-host-sdk/host/batch-voyage.ts new file mode 100644 index 
00000000000..fcb257a4d7d --- /dev/null +++ b/src/memory-host-sdk/host/batch-voyage.ts @@ -0,0 +1,315 @@ +import { createInterface } from "node:readline"; +import { Readable } from "node:stream"; +import { + applyEmbeddingBatchOutputLine, + buildBatchHeaders, + buildEmbeddingBatchGroupOptions, + EMBEDDING_BATCH_ENDPOINT, + extractBatchErrorMessage, + formatUnavailableBatchError, + normalizeBatchBaseUrl, + postJsonWithRetry, + resolveBatchCompletionFromStatus, + resolveCompletedBatchResult, + runEmbeddingBatchGroups, + throwIfBatchTerminalFailure, + type EmbeddingBatchExecutionParams, + type EmbeddingBatchStatus, + type BatchCompletionResult, + type ProviderBatchOutputLine, + uploadBatchJsonlFile, + withRemoteHttpResponse, +} from "./batch-embedding-common.js"; +import type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; + +/** + * Voyage Batch API Input Line format. + * See: https://docs.voyageai.com/docs/batch-inference + */ +export type VoyageBatchRequest = { + custom_id: string; + body: { + input: string | string[]; + }; +}; + +export type VoyageBatchStatus = EmbeddingBatchStatus; +export type VoyageBatchOutputLine = ProviderBatchOutputLine; + +export const VOYAGE_BATCH_ENDPOINT = EMBEDDING_BATCH_ENDPOINT; +const VOYAGE_BATCH_COMPLETION_WINDOW = "12h"; +const VOYAGE_BATCH_MAX_REQUESTS = 50000; + +type VoyageBatchDeps = { + now: () => number; + sleep: (ms: number) => Promise; + postJsonWithRetry: typeof postJsonWithRetry; + uploadBatchJsonlFile: typeof uploadBatchJsonlFile; + withRemoteHttpResponse: typeof withRemoteHttpResponse; +}; + +function resolveVoyageBatchDeps(overrides: Partial | undefined): VoyageBatchDeps { + return { + now: overrides?.now ?? Date.now, + sleep: + overrides?.sleep ?? + (async (ms: number) => await new Promise((resolve) => setTimeout(resolve, ms))), + postJsonWithRetry: overrides?.postJsonWithRetry ?? postJsonWithRetry, + uploadBatchJsonlFile: overrides?.uploadBatchJsonlFile ?? 
uploadBatchJsonlFile, + withRemoteHttpResponse: overrides?.withRemoteHttpResponse ?? withRemoteHttpResponse, + }; +} + +async function assertVoyageResponseOk(res: Response, context: string): Promise { + if (!res.ok) { + const text = await res.text(); + throw new Error(`${context}: ${res.status} ${text}`); + } +} + +function buildVoyageBatchRequest(params: { + client: VoyageEmbeddingClient; + path: string; + onResponse: (res: Response) => Promise; +}) { + const baseUrl = normalizeBatchBaseUrl(params.client); + return { + url: `${baseUrl}/${params.path}`, + ssrfPolicy: params.client.ssrfPolicy, + init: { + headers: buildBatchHeaders(params.client, { json: true }), + }, + onResponse: params.onResponse, + }; +} + +async function submitVoyageBatch(params: { + client: VoyageEmbeddingClient; + requests: VoyageBatchRequest[]; + agentId: string; + deps: VoyageBatchDeps; +}): Promise { + const baseUrl = normalizeBatchBaseUrl(params.client); + const inputFileId = await params.deps.uploadBatchJsonlFile({ + client: params.client, + requests: params.requests, + errorPrefix: "voyage batch file upload failed", + }); + + // 2. 
Create batch job using Voyage Batches API + return await params.deps.postJsonWithRetry({ + url: `${baseUrl}/batches`, + headers: buildBatchHeaders(params.client, { json: true }), + ssrfPolicy: params.client.ssrfPolicy, + body: { + input_file_id: inputFileId, + endpoint: VOYAGE_BATCH_ENDPOINT, + completion_window: VOYAGE_BATCH_COMPLETION_WINDOW, + request_params: { + model: params.client.model, + input_type: "document", + }, + metadata: { + source: "clawdbot-memory", + agent: params.agentId, + }, + }, + errorPrefix: "voyage batch create failed", + }); +} + +async function fetchVoyageBatchStatus(params: { + client: VoyageEmbeddingClient; + batchId: string; + deps: VoyageBatchDeps; +}): Promise { + return await params.deps.withRemoteHttpResponse( + buildVoyageBatchRequest({ + client: params.client, + path: `batches/${params.batchId}`, + onResponse: async (res) => { + await assertVoyageResponseOk(res, "voyage batch status failed"); + return (await res.json()) as VoyageBatchStatus; + }, + }), + ); +} + +async function readVoyageBatchError(params: { + client: VoyageEmbeddingClient; + errorFileId: string; + deps: VoyageBatchDeps; +}): Promise { + try { + return await params.deps.withRemoteHttpResponse( + buildVoyageBatchRequest({ + client: params.client, + path: `files/${params.errorFileId}/content`, + onResponse: async (res) => { + await assertVoyageResponseOk(res, "voyage batch error file content failed"); + const text = await res.text(); + if (!text.trim()) { + return undefined; + } + const lines = text + .split("\n") + .map((line) => line.trim()) + .filter(Boolean) + .map((line) => JSON.parse(line) as VoyageBatchOutputLine); + return extractBatchErrorMessage(lines); + }, + }), + ); + } catch (err) { + return formatUnavailableBatchError(err); + } +} + +async function waitForVoyageBatch(params: { + client: VoyageEmbeddingClient; + batchId: string; + wait: boolean; + pollIntervalMs: number; + timeoutMs: number; + debug?: (message: string, data?: Record) => void; + 
initial?: VoyageBatchStatus; + deps: VoyageBatchDeps; +}): Promise { + const start = params.deps.now(); + let current: VoyageBatchStatus | undefined = params.initial; + while (true) { + const status = + current ?? + (await fetchVoyageBatchStatus({ + client: params.client, + batchId: params.batchId, + deps: params.deps, + })); + const state = status.status ?? "unknown"; + if (state === "completed") { + return resolveBatchCompletionFromStatus({ + provider: "voyage", + batchId: params.batchId, + status, + }); + } + await throwIfBatchTerminalFailure({ + provider: "voyage", + status: { ...status, id: params.batchId }, + readError: async (errorFileId) => + await readVoyageBatchError({ + client: params.client, + errorFileId, + deps: params.deps, + }), + }); + if (!params.wait) { + throw new Error(`voyage batch ${params.batchId} still ${state}; wait disabled`); + } + if (params.deps.now() - start > params.timeoutMs) { + throw new Error(`voyage batch ${params.batchId} timed out after ${params.timeoutMs}ms`); + } + params.debug?.(`voyage batch ${params.batchId} ${state}; waiting ${params.pollIntervalMs}ms`); + await params.deps.sleep(params.pollIntervalMs); + current = undefined; + } +} + +export async function runVoyageEmbeddingBatches( + params: { + client: VoyageEmbeddingClient; + agentId: string; + requests: VoyageBatchRequest[]; + deps?: Partial; + } & EmbeddingBatchExecutionParams, +): Promise> { + const deps = resolveVoyageBatchDeps(params.deps); + return await runEmbeddingBatchGroups({ + ...buildEmbeddingBatchGroupOptions(params, { + maxRequests: VOYAGE_BATCH_MAX_REQUESTS, + debugLabel: "memory embeddings: voyage batch submit", + }), + runGroup: async ({ group, groupIndex, groups, byCustomId }) => { + const batchInfo = await submitVoyageBatch({ + client: params.client, + requests: group, + agentId: params.agentId, + deps, + }); + if (!batchInfo.id) { + throw new Error("voyage batch create failed: missing batch id"); + } + const batchId = batchInfo.id; + + 
params.debug?.("memory embeddings: voyage batch created", { + batchId: batchInfo.id, + status: batchInfo.status, + group: groupIndex + 1, + groups, + requests: group.length, + }); + + const completed = await resolveCompletedBatchResult({ + provider: "voyage", + status: batchInfo, + wait: params.wait, + waitForBatch: async () => + await waitForVoyageBatch({ + client: params.client, + batchId, + wait: params.wait, + pollIntervalMs: params.pollIntervalMs, + timeoutMs: params.timeoutMs, + debug: params.debug, + initial: batchInfo, + deps, + }), + }); + + const baseUrl = normalizeBatchBaseUrl(params.client); + const errors: string[] = []; + const remaining = new Set(group.map((request) => request.custom_id)); + + await deps.withRemoteHttpResponse({ + url: `${baseUrl}/files/${completed.outputFileId}/content`, + ssrfPolicy: params.client.ssrfPolicy, + init: { + headers: buildBatchHeaders(params.client, { json: true }), + }, + onResponse: async (contentRes) => { + if (!contentRes.ok) { + const text = await contentRes.text(); + throw new Error(`voyage batch file content failed: ${contentRes.status} ${text}`); + } + + if (!contentRes.body) { + return; + } + const reader = createInterface({ + input: Readable.fromWeb( + contentRes.body as unknown as import("stream/web").ReadableStream, + ), + terminal: false, + }); + + for await (const rawLine of reader) { + if (!rawLine.trim()) { + continue; + } + const line = JSON.parse(rawLine) as VoyageBatchOutputLine; + applyEmbeddingBatchOutputLine({ line, remaining, errors, byCustomId }); + } + }, + }); + + if (errors.length > 0) { + throw new Error(`voyage batch ${batchInfo.id} failed: ${errors.join("; ")}`); + } + if (remaining.size > 0) { + throw new Error( + `voyage batch ${batchInfo.id} missing ${remaining.size} embedding responses`, + ); + } + }, + }); +} diff --git a/src/memory-host-sdk/host/embedding-chunk-limits.test.ts b/src/memory-host-sdk/host/embedding-chunk-limits.test.ts new file mode 100644 index 00000000000..733f98fe7b2 
--- /dev/null +++ b/src/memory-host-sdk/host/embedding-chunk-limits.test.ts @@ -0,0 +1,102 @@ +import { describe, expect, it } from "vitest"; +import { enforceEmbeddingMaxInputTokens } from "./embedding-chunk-limits.js"; +import { estimateUtf8Bytes } from "./embedding-input-limits.js"; +import type { EmbeddingProvider } from "./embeddings.js"; + +function createProvider(maxInputTokens: number): EmbeddingProvider { + return { + id: "mock", + model: "mock-embed", + maxInputTokens, + embedQuery: async () => [0], + embedBatch: async () => [[0]], + }; +} + +function createProviderWithoutMaxInputTokens(params: { + id: string; + model: string; +}): EmbeddingProvider { + return { + id: params.id, + model: params.model, + embedQuery: async () => [0], + embedBatch: async () => [[0]], + }; +} + +describe("embedding chunk limits", () => { + it("splits oversized chunks so each embedding input stays <= maxInputTokens bytes", () => { + const provider = createProvider(8192); + const input = { + startLine: 1, + endLine: 1, + text: "x".repeat(9000), + hash: "ignored", + }; + + const out = enforceEmbeddingMaxInputTokens(provider, [input]); + expect(out.length).toBeGreaterThan(1); + expect(out.map((chunk) => chunk.text).join("")).toBe(input.text); + expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8192)).toBe(true); + expect(out.every((chunk) => chunk.startLine === 1 && chunk.endLine === 1)).toBe(true); + expect(out.every((chunk) => typeof chunk.hash === "string" && chunk.hash.length > 0)).toBe( + true, + ); + }); + + it("does not split inside surrogate pairs (emoji)", () => { + const provider = createProvider(8192); + const emoji = "😀"; + const inputText = `${emoji.repeat(2100)}\n${emoji.repeat(2100)}`; + + const out = enforceEmbeddingMaxInputTokens(provider, [ + { startLine: 1, endLine: 2, text: inputText, hash: "ignored" }, + ]); + + expect(out.length).toBeGreaterThan(1); + expect(out.map((chunk) => chunk.text).join("")).toBe(inputText); + expect(out.every((chunk) => 
estimateUtf8Bytes(chunk.text) <= 8192)).toBe(true); + + // If we split inside surrogate pairs we'd likely end up with replacement chars. + expect(out.map((chunk) => chunk.text).join("")).not.toContain("\uFFFD"); + }); + + it("uses conservative fallback limits for local providers without declared maxInputTokens", () => { + const provider = createProviderWithoutMaxInputTokens({ + id: "local", + model: "unknown-local-embedding", + }); + + const out = enforceEmbeddingMaxInputTokens(provider, [ + { + startLine: 1, + endLine: 1, + text: "x".repeat(3000), + hash: "ignored", + }, + ]); + + expect(out.length).toBeGreaterThan(1); + expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 2048)).toBe(true); + }); + + it("honors hard safety caps lower than provider maxInputTokens", () => { + const provider = createProvider(8192); + const out = enforceEmbeddingMaxInputTokens( + provider, + [ + { + startLine: 1, + endLine: 1, + text: "x".repeat(8100), + hash: "ignored", + }, + ], + 8000, + ); + + expect(out.length).toBeGreaterThan(1); + expect(out.every((chunk) => estimateUtf8Bytes(chunk.text) <= 8000)).toBe(true); + }); +}); diff --git a/src/memory-host-sdk/host/embedding-chunk-limits.ts b/src/memory-host-sdk/host/embedding-chunk-limits.ts new file mode 100644 index 00000000000..5c8cf9020f3 --- /dev/null +++ b/src/memory-host-sdk/host/embedding-chunk-limits.ts @@ -0,0 +1,41 @@ +import { estimateUtf8Bytes, splitTextToUtf8ByteLimit } from "./embedding-input-limits.js"; +import { hasNonTextEmbeddingParts } from "./embedding-inputs.js"; +import { resolveEmbeddingMaxInputTokens } from "./embedding-model-limits.js"; +import type { EmbeddingProvider } from "./embeddings.js"; +import { hashText, type MemoryChunk } from "./internal.js"; + +export function enforceEmbeddingMaxInputTokens( + provider: EmbeddingProvider, + chunks: MemoryChunk[], + hardMaxInputTokens?: number, +): MemoryChunk[] { + const providerMaxInputTokens = resolveEmbeddingMaxInputTokens(provider); + const 
maxInputTokens = + typeof hardMaxInputTokens === "number" && hardMaxInputTokens > 0 + ? Math.min(providerMaxInputTokens, hardMaxInputTokens) + : providerMaxInputTokens; + const out: MemoryChunk[] = []; + + for (const chunk of chunks) { + if (hasNonTextEmbeddingParts(chunk.embeddingInput)) { + out.push(chunk); + continue; + } + if (estimateUtf8Bytes(chunk.text) <= maxInputTokens) { + out.push(chunk); + continue; + } + + for (const text of splitTextToUtf8ByteLimit(chunk.text, maxInputTokens)) { + out.push({ + startLine: chunk.startLine, + endLine: chunk.endLine, + text, + hash: hashText(text), + embeddingInput: { text }, + }); + } + } + + return out; +} diff --git a/src/memory-host-sdk/host/embedding-input-limits.ts b/src/memory-host-sdk/host/embedding-input-limits.ts new file mode 100644 index 00000000000..4eadf1bf48d --- /dev/null +++ b/src/memory-host-sdk/host/embedding-input-limits.ts @@ -0,0 +1,85 @@ +import type { EmbeddingInput } from "./embedding-inputs.js"; + +// Helpers for enforcing embedding model input size limits. +// +// We use UTF-8 byte length as a conservative upper bound for tokenizer output. +// Tokenizers operate over bytes; a token must contain at least one byte, so +// token_count <= utf8_byte_length. 
+ +export function estimateUtf8Bytes(text: string): number { + if (!text) { + return 0; + } + return Buffer.byteLength(text, "utf8"); +} + +export function estimateStructuredEmbeddingInputBytes(input: EmbeddingInput): number { + if (!input.parts?.length) { + return estimateUtf8Bytes(input.text); + } + let total = 0; + for (const part of input.parts) { + if (part.type === "text") { + total += estimateUtf8Bytes(part.text); + continue; + } + total += estimateUtf8Bytes(part.mimeType); + total += estimateUtf8Bytes(part.data); + } + return total; +} + +export function splitTextToUtf8ByteLimit(text: string, maxUtf8Bytes: number): string[] { + if (maxUtf8Bytes <= 0) { + return [text]; + } + if (estimateUtf8Bytes(text) <= maxUtf8Bytes) { + return [text]; + } + + const parts: string[] = []; + let cursor = 0; + while (cursor < text.length) { + // The number of UTF-16 code units is always <= the number of UTF-8 bytes. + // This makes `cursor + maxUtf8Bytes` a safe upper bound on the next split point. + let low = cursor + 1; + let high = Math.min(text.length, cursor + maxUtf8Bytes); + let best = cursor; + + while (low <= high) { + const mid = Math.floor((low + high) / 2); + const bytes = estimateUtf8Bytes(text.slice(cursor, mid)); + if (bytes <= maxUtf8Bytes) { + best = mid; + low = mid + 1; + } else { + high = mid - 1; + } + } + + if (best <= cursor) { + best = Math.min(text.length, cursor + 1); + } + + // Avoid splitting inside a surrogate pair. 
+ if ( + best < text.length && + best > cursor && + text.charCodeAt(best - 1) >= 0xd800 && + text.charCodeAt(best - 1) <= 0xdbff && + text.charCodeAt(best) >= 0xdc00 && + text.charCodeAt(best) <= 0xdfff + ) { + best -= 1; + } + + const part = text.slice(cursor, best); + if (!part) { + break; + } + parts.push(part); + cursor = best; + } + + return parts; +} diff --git a/src/memory-host-sdk/host/embedding-inputs.ts b/src/memory-host-sdk/host/embedding-inputs.ts new file mode 100644 index 00000000000..767a463f740 --- /dev/null +++ b/src/memory-host-sdk/host/embedding-inputs.ts @@ -0,0 +1,34 @@ +export type EmbeddingInputTextPart = { + type: "text"; + text: string; +}; + +export type EmbeddingInputInlineDataPart = { + type: "inline-data"; + mimeType: string; + data: string; +}; + +export type EmbeddingInputPart = EmbeddingInputTextPart | EmbeddingInputInlineDataPart; + +export type EmbeddingInput = { + text: string; + parts?: EmbeddingInputPart[]; +}; + +export function buildTextEmbeddingInput(text: string): EmbeddingInput { + return { text }; +} + +export function isInlineDataEmbeddingInputPart( + part: EmbeddingInputPart, +): part is EmbeddingInputInlineDataPart { + return part.type === "inline-data"; +} + +export function hasNonTextEmbeddingParts(input: EmbeddingInput | undefined): boolean { + if (!input?.parts?.length) { + return false; + } + return input.parts.some((part) => isInlineDataEmbeddingInputPart(part)); +} diff --git a/src/memory-host-sdk/host/embedding-model-limits.ts b/src/memory-host-sdk/host/embedding-model-limits.ts new file mode 100644 index 00000000000..0819686b905 --- /dev/null +++ b/src/memory-host-sdk/host/embedding-model-limits.ts @@ -0,0 +1,41 @@ +import type { EmbeddingProvider } from "./embeddings.js"; + +const DEFAULT_EMBEDDING_MAX_INPUT_TOKENS = 8192; +const DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS = 2048; + +const KNOWN_EMBEDDING_MAX_INPUT_TOKENS: Record = { + "openai:text-embedding-3-small": 8192, + "openai:text-embedding-3-large": 
8192, + "openai:text-embedding-ada-002": 8191, + "gemini:text-embedding-004": 2048, + "gemini:gemini-embedding-001": 2048, + "gemini:gemini-embedding-2-preview": 8192, + "voyage:voyage-3": 32000, + "voyage:voyage-3-lite": 16000, + "voyage:voyage-code-3": 32000, +}; + +export function resolveEmbeddingMaxInputTokens(provider: EmbeddingProvider): number { + if (typeof provider.maxInputTokens === "number") { + return provider.maxInputTokens; + } + + // Provider/model mapping is best-effort; different providers use different + // limits and we prefer to be conservative when we don't know. + const key = `${provider.id}:${provider.model}`.toLowerCase(); + const known = KNOWN_EMBEDDING_MAX_INPUT_TOKENS[key]; + if (typeof known === "number") { + return known; + } + + // Provider-specific conservative fallbacks. This prevents us from accidentally + // using the OpenAI default for providers with much smaller limits. + if (provider.id.toLowerCase() === "gemini") { + return 2048; + } + if (provider.id.toLowerCase() === "local") { + return DEFAULT_LOCAL_EMBEDDING_MAX_INPUT_TOKENS; + } + + return DEFAULT_EMBEDDING_MAX_INPUT_TOKENS; +} diff --git a/src/memory-host-sdk/host/embedding-vectors.ts b/src/memory-host-sdk/host/embedding-vectors.ts new file mode 100644 index 00000000000..d589f61390d --- /dev/null +++ b/src/memory-host-sdk/host/embedding-vectors.ts @@ -0,0 +1,8 @@ +export function sanitizeAndNormalizeEmbedding(vec: number[]): number[] { + const sanitized = vec.map((value) => (Number.isFinite(value) ? 
value : 0)); + const magnitude = Math.sqrt(sanitized.reduce((sum, value) => sum + value * value, 0)); + if (magnitude < 1e-10) { + return sanitized; + } + return sanitized.map((value) => value / magnitude); +} diff --git a/src/memory-host-sdk/host/embeddings-debug.ts b/src/memory-host-sdk/host/embeddings-debug.ts new file mode 100644 index 00000000000..a9f20d55e8a --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-debug.ts @@ -0,0 +1,13 @@ +import { isTruthyEnvValue } from "../../infra/env.js"; +import { createSubsystemLogger } from "../../logging/subsystem.js"; + +const debugEmbeddings = isTruthyEnvValue(process.env.OPENCLAW_DEBUG_MEMORY_EMBEDDINGS); +const log = createSubsystemLogger("memory/embeddings"); + +export function debugEmbeddingsLog(message: string, meta?: Record): void { + if (!debugEmbeddings) { + return; + } + const suffix = meta ? ` ${JSON.stringify(meta)}` : ""; + log.raw(`${message}${suffix}`); +} diff --git a/src/memory-host-sdk/host/embeddings-gemini.test.ts b/src/memory-host-sdk/host/embeddings-gemini.test.ts new file mode 100644 index 00000000000..a1f4ef028ef --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-gemini.test.ts @@ -0,0 +1,592 @@ +import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import * as authModule from "../../agents/model-auth.js"; +import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; + +vi.mock("../../infra/net/fetch-guard.js", () => ({ + fetchWithSsrFGuard: async (params: { + url: string; + init?: RequestInit; + fetchImpl?: typeof fetch; + }) => { + const fetchImpl = params.fetchImpl ?? 
globalThis.fetch; + if (!fetchImpl) { + throw new Error("fetch is not available"); + } + const response = await fetchImpl(params.url, params.init); + return { + response, + finalUrl: params.url, + release: async () => {}, + }; + }, +})); + +vi.mock("../../agents/model-auth.js", async () => { + const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js"); + return createModelAuthMockModule(); +}); + +const createGeminiFetchMock = (embeddingValues = [1, 2, 3]) => + vi.fn(async (_input?: unknown, _init?: unknown) => ({ + ok: true, + status: 200, + json: async () => ({ embedding: { values: embeddingValues } }), + })); + +const createGeminiBatchFetchMock = (count: number, embeddingValues = [1, 2, 3]) => + vi.fn(async (_input?: unknown, _init?: unknown) => ({ + ok: true, + status: 200, + json: async () => ({ + embeddings: Array.from({ length: count }, () => ({ values: embeddingValues })), + }), + })); + +function installFetchMock(fetchMock: typeof globalThis.fetch) { + vi.stubGlobal("fetch", fetchMock); +} + +function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) { + const [url, init] = fetchMock.mock.calls[0] ?? []; + return { url, init: init as RequestInit | undefined }; +} + +function parseFetchBody(fetchMock: { mock: { calls: unknown[][] } }, callIndex = 0) { + const init = fetchMock.mock.calls[callIndex]?.[1] as RequestInit | undefined; + return JSON.parse((init?.body as string) ?? 
"{}") as Record; +} + +function magnitude(values: number[]) { + return Math.sqrt(values.reduce((sum, value) => sum + value * value, 0)); +} + +let buildGeminiEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiEmbeddingRequest; +let buildGeminiTextEmbeddingRequest: typeof import("./embeddings-gemini.js").buildGeminiTextEmbeddingRequest; +let createGeminiEmbeddingProvider: typeof import("./embeddings-gemini.js").createGeminiEmbeddingProvider; +let DEFAULT_GEMINI_EMBEDDING_MODEL: typeof import("./embeddings-gemini.js").DEFAULT_GEMINI_EMBEDDING_MODEL; +let GEMINI_EMBEDDING_2_MODELS: typeof import("./embeddings-gemini.js").GEMINI_EMBEDDING_2_MODELS; +let isGeminiEmbedding2Model: typeof import("./embeddings-gemini.js").isGeminiEmbedding2Model; +let resolveGeminiOutputDimensionality: typeof import("./embeddings-gemini.js").resolveGeminiOutputDimensionality; + +beforeAll(async () => { + vi.doUnmock("undici"); + ({ + buildGeminiEmbeddingRequest, + buildGeminiTextEmbeddingRequest, + createGeminiEmbeddingProvider, + DEFAULT_GEMINI_EMBEDDING_MODEL, + GEMINI_EMBEDDING_2_MODELS, + isGeminiEmbedding2Model, + resolveGeminiOutputDimensionality, + } = await import("./embeddings-gemini.js")); +}); + +beforeEach(() => { + vi.useRealTimers(); + vi.doUnmock("undici"); +}); + +afterEach(() => { + vi.doUnmock("undici"); + vi.resetAllMocks(); + vi.unstubAllGlobals(); +}); + +function mockResolvedProviderKey(apiKey = "test-key") { + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey, + mode: "api-key", + source: "test", + }); +} + +type GeminiFetchMock = + | ReturnType + | ReturnType; + +async function createProviderWithFetch( + fetchMock: GeminiFetchMock, + options: Partial[0]> & { model: string }, +) { + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey(); + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + fallback: 
"none", + ...options, + }); + return provider; +} + +function expectNormalizedThreeFourVector(embedding: number[]) { + expect(embedding[0]).toBeCloseTo(0.6, 5); + expect(embedding[1]).toBeCloseTo(0.8, 5); + expect(magnitude(embedding)).toBeCloseTo(1, 5); +} + +describe("buildGeminiTextEmbeddingRequest", () => { + it("builds a text embedding request with optional model and dimensions", () => { + expect( + buildGeminiTextEmbeddingRequest({ + text: "hello", + taskType: "RETRIEVAL_DOCUMENT", + modelPath: "models/gemini-embedding-2-preview", + outputDimensionality: 1536, + }), + ).toEqual({ + model: "models/gemini-embedding-2-preview", + content: { parts: [{ text: "hello" }] }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 1536, + }); + }); +}); + +describe("buildGeminiEmbeddingRequest", () => { + it("builds a multimodal request from structured input parts", () => { + expect( + buildGeminiEmbeddingRequest({ + input: { + text: "Image file: diagram.png", + parts: [ + { type: "text", text: "Image file: diagram.png" }, + { type: "inline-data", mimeType: "image/png", data: "abc123" }, + ], + }, + taskType: "RETRIEVAL_DOCUMENT", + modelPath: "models/gemini-embedding-2-preview", + outputDimensionality: 1536, + }), + ).toEqual({ + model: "models/gemini-embedding-2-preview", + content: { + parts: [ + { text: "Image file: diagram.png" }, + { inlineData: { mimeType: "image/png", data: "abc123" } }, + ], + }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 1536, + }); + }); +}); + +// ---------- Model detection ---------- + +describe("isGeminiEmbedding2Model", () => { + it("returns true for gemini-embedding-2-preview", () => { + expect(isGeminiEmbedding2Model("gemini-embedding-2-preview")).toBe(true); + }); + + it("returns false for gemini-embedding-001", () => { + expect(isGeminiEmbedding2Model("gemini-embedding-001")).toBe(false); + }); + + it("returns false for text-embedding-004", () => { + expect(isGeminiEmbedding2Model("text-embedding-004")).toBe(false); 
+ }); +}); + +describe("GEMINI_EMBEDDING_2_MODELS", () => { + it("contains gemini-embedding-2-preview", () => { + expect(GEMINI_EMBEDDING_2_MODELS.has("gemini-embedding-2-preview")).toBe(true); + }); +}); + +// ---------- Dimension resolution ---------- + +describe("resolveGeminiOutputDimensionality", () => { + it("returns undefined for non-v2 models", () => { + expect(resolveGeminiOutputDimensionality("gemini-embedding-001")).toBeUndefined(); + expect(resolveGeminiOutputDimensionality("text-embedding-004")).toBeUndefined(); + }); + + it("returns 3072 by default for v2 models", () => { + expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview")).toBe(3072); + }); + + it("accepts valid dimension values", () => { + expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 768)).toBe(768); + expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1536)).toBe(1536); + expect(resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 3072)).toBe(3072); + }); + + it("throws for invalid dimension values", () => { + expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 512)).toThrow( + /Invalid outputDimensionality 512/, + ); + expect(() => resolveGeminiOutputDimensionality("gemini-embedding-2-preview", 1024)).toThrow( + /Valid values: 768, 1536, 3072/, + ); + }); +}); + +// ---------- Provider: gemini-embedding-001 (backward compat) ---------- + +describe("gemini-embedding-001 provider (backward compat)", () => { + it("does NOT include outputDimensionality in embedQuery", async () => { + const fetchMock = createGeminiFetchMock(); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-001", + }); + + await provider.embedQuery("test query"); + + const body = parseFetchBody(fetchMock); + expect(body).not.toHaveProperty("outputDimensionality"); + expect(body.taskType).toBe("RETRIEVAL_QUERY"); + expect(body.content).toEqual({ parts: [{ text: "test query" }] }); + }); + + 
it("does NOT include outputDimensionality in embedBatch", async () => { + const fetchMock = createGeminiBatchFetchMock(2); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-001", + }); + + await provider.embedBatch(["text1", "text2"]); + + const body = parseFetchBody(fetchMock); + expect(body).not.toHaveProperty("outputDimensionality"); + }); +}); + +// ---------- Provider: gemini-embedding-2-preview ---------- + +describe("gemini-embedding-2-preview provider", () => { + it("includes outputDimensionality in embedQuery request", async () => { + const fetchMock = createGeminiFetchMock(); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + await provider.embedQuery("test query"); + + const body = parseFetchBody(fetchMock); + expect(body.outputDimensionality).toBe(3072); + expect(body.taskType).toBe("RETRIEVAL_QUERY"); + expect(body.content).toEqual({ parts: [{ text: "test query" }] }); + }); + + it("normalizes embedQuery response vectors", async () => { + const fetchMock = createGeminiFetchMock([3, 4]); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + const embedding = await provider.embedQuery("test query"); + + expectNormalizedThreeFourVector(embedding); + }); + + it("includes outputDimensionality in embedBatch request", async () => { + const fetchMock = createGeminiBatchFetchMock(2); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + await provider.embedBatch(["text1", "text2"]); + + const body = parseFetchBody(fetchMock); + expect(body.requests).toEqual([ + { + model: "models/gemini-embedding-2-preview", + content: { parts: [{ text: "text1" }] }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 3072, + }, + { + model: "models/gemini-embedding-2-preview", + content: { parts: [{ text: "text2" }] }, + taskType: "RETRIEVAL_DOCUMENT", + 
outputDimensionality: 3072, + }, + ]); + }); + + it("normalizes embedBatch response vectors", async () => { + const fetchMock = createGeminiBatchFetchMock(2, [3, 4]); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + const embeddings = await provider.embedBatch(["text1", "text2"]); + + expect(embeddings).toHaveLength(2); + for (const embedding of embeddings) { + expectNormalizedThreeFourVector(embedding); + } + }); + + it("respects custom outputDimensionality", async () => { + const fetchMock = createGeminiFetchMock(); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + outputDimensionality: 768, + }); + + await provider.embedQuery("test"); + + const body = parseFetchBody(fetchMock); + expect(body.outputDimensionality).toBe(768); + }); + + it("sanitizes and normalizes embedQuery responses", async () => { + const fetchMock = createGeminiFetchMock([3, 4, Number.NaN]); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + await expect(provider.embedQuery("test")).resolves.toEqual([0.6, 0.8, 0]); + }); + + it("uses custom outputDimensionality for each embedBatch request", async () => { + const fetchMock = createGeminiBatchFetchMock(2); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + outputDimensionality: 768, + }); + + await provider.embedBatch(["text1", "text2"]); + + const body = parseFetchBody(fetchMock); + expect(body.requests).toEqual([ + expect.objectContaining({ outputDimensionality: 768 }), + expect.objectContaining({ outputDimensionality: 768 }), + ]); + }); + + it("sanitizes and normalizes structured batch responses", async () => { + const fetchMock = createGeminiBatchFetchMock(1, [0, Number.POSITIVE_INFINITY, 5]); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + await expect( + 
provider.embedBatchInputs?.([ + { + text: "Image file: diagram.png", + parts: [ + { type: "text", text: "Image file: diagram.png" }, + { type: "inline-data", mimeType: "image/png", data: "img" }, + ], + }, + ]), + ).resolves.toEqual([[0, 0, 1]]); + }); + + it("supports multimodal embedBatchInputs requests", async () => { + const fetchMock = createGeminiBatchFetchMock(2); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + expect(provider.embedBatchInputs).toBeDefined(); + await provider.embedBatchInputs?.([ + { + text: "Image file: diagram.png", + parts: [ + { type: "text", text: "Image file: diagram.png" }, + { type: "inline-data", mimeType: "image/png", data: "img" }, + ], + }, + { + text: "Audio file: note.wav", + parts: [ + { type: "text", text: "Audio file: note.wav" }, + { type: "inline-data", mimeType: "audio/wav", data: "aud" }, + ], + }, + ]); + + const body = parseFetchBody(fetchMock); + expect(body.requests).toEqual([ + { + model: "models/gemini-embedding-2-preview", + content: { + parts: [ + { text: "Image file: diagram.png" }, + { inlineData: { mimeType: "image/png", data: "img" } }, + ], + }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 3072, + }, + { + model: "models/gemini-embedding-2-preview", + content: { + parts: [ + { text: "Audio file: note.wav" }, + { inlineData: { mimeType: "audio/wav", data: "aud" } }, + ], + }, + taskType: "RETRIEVAL_DOCUMENT", + outputDimensionality: 3072, + }, + ]); + }); + + it("throws for invalid outputDimensionality", async () => { + mockResolvedProviderKey(); + + await expect( + createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "gemini-embedding-2-preview", + fallback: "none", + outputDimensionality: 512, + }), + ).rejects.toThrow(/Invalid outputDimensionality 512/); + }); + + it("sanitizes non-finite values before normalization", async () => { + const fetchMock = createGeminiFetchMock([ + 1, + Number.NaN, + 
Number.POSITIVE_INFINITY, + Number.NEGATIVE_INFINITY, + ]); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + const embedding = await provider.embedQuery("test"); + + expect(embedding).toEqual([1, 0, 0, 0]); + }); + + it("uses correct endpoint URL", async () => { + const fetchMock = createGeminiFetchMock(); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + }); + + await provider.embedQuery("test"); + + const { url } = readFirstFetchRequest(fetchMock); + expect(url).toBe( + "https://generativelanguage.googleapis.com/v1beta/models/gemini-embedding-2-preview:embedContent", + ); + }); + + it("allows taskType override via options", async () => { + const fetchMock = createGeminiFetchMock(); + const provider = await createProviderWithFetch(fetchMock, { + model: "gemini-embedding-2-preview", + taskType: "SEMANTIC_SIMILARITY", + }); + + await provider.embedQuery("test"); + + const body = parseFetchBody(fetchMock); + expect(body.taskType).toBe("SEMANTIC_SIMILARITY"); + }); +}); + +// ---------- Model normalization ---------- + +describe("gemini model normalization", () => { + it("handles models/ prefix for v2 model", async () => { + const fetchMock = createGeminiFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey(); + + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "models/gemini-embedding-2-preview", + fallback: "none", + }); + + await provider.embedQuery("test"); + + const body = parseFetchBody(fetchMock); + expect(body.outputDimensionality).toBe(3072); + }); + + it("handles gemini/ prefix for v2 model", async () => { + const fetchMock = createGeminiFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey(); + + const { provider } 
= await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "gemini/gemini-embedding-2-preview", + fallback: "none", + }); + + await provider.embedQuery("test"); + + const body = parseFetchBody(fetchMock); + expect(body.outputDimensionality).toBe(3072); + }); + + it("handles google/ prefix for v2 model", async () => { + const fetchMock = createGeminiFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey(); + + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "google/gemini-embedding-2-preview", + fallback: "none", + }); + + await provider.embedQuery("test"); + + const body = parseFetchBody(fetchMock); + expect(body.outputDimensionality).toBe(3072); + }); + + it("defaults to gemini-embedding-001 when model is empty", async () => { + const fetchMock = createGeminiFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockResolvedProviderKey(); + + const { provider, client } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "", + fallback: "none", + }); + + expect(client.model).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL); + expect(provider.model).toBe(DEFAULT_GEMINI_EMBEDDING_MODEL); + }); + + it("returns empty array for blank query text", async () => { + mockResolvedProviderKey(); + + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "gemini-embedding-2-preview", + fallback: "none", + }); + + const result = await provider.embedQuery(" "); + expect(result).toEqual([]); + }); + + it("returns empty array for empty batch", async () => { + mockResolvedProviderKey(); + + const { provider } = await createGeminiEmbeddingProvider({ + config: {} as never, + provider: "gemini", + model: "gemini-embedding-2-preview", + fallback: "none", + }); + + const result = await 
provider.embedBatch([]); + expect(result).toEqual([]); + }); +}); diff --git a/src/memory-host-sdk/host/embeddings-gemini.ts b/src/memory-host-sdk/host/embeddings-gemini.ts new file mode 100644 index 00000000000..3826398a371 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-gemini.ts @@ -0,0 +1,336 @@ +import { + collectProviderApiKeysForExecution, + executeWithApiKeyRotation, +} from "../../agents/api-key-rotation.js"; +import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js"; +import { parseGeminiAuth } from "../../infra/gemini-auth.js"; +import { + DEFAULT_GOOGLE_API_BASE_URL, + normalizeGoogleApiBaseUrl, +} from "../../infra/google-api-base-url.js"; +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import type { EmbeddingInput } from "./embedding-inputs.js"; +import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; +import { debugEmbeddingsLog } from "./embeddings-debug.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; +import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js"; +import { resolveMemorySecretInputString } from "./secret-input.js"; + +export type GeminiEmbeddingClient = { + baseUrl: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + model: string; + modelPath: string; + apiKeys: string[]; + outputDimensionality?: number; +}; + +export const DEFAULT_GEMINI_EMBEDDING_MODEL = "gemini-embedding-001"; +const GEMINI_MAX_INPUT_TOKENS: Record = { + "text-embedding-004": 2048, +}; + +// --- gemini-embedding-2-preview support --- + +export const GEMINI_EMBEDDING_2_MODELS = new Set([ + "gemini-embedding-2-preview", + // Add the GA model name here once released. 
+]); + +const GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS = 3072; +const GEMINI_EMBEDDING_2_VALID_DIMENSIONS = [768, 1536, 3072] as const; + +export type GeminiTaskType = + | "RETRIEVAL_QUERY" + | "RETRIEVAL_DOCUMENT" + | "SEMANTIC_SIMILARITY" + | "CLASSIFICATION" + | "CLUSTERING" + | "QUESTION_ANSWERING" + | "FACT_VERIFICATION"; + +export type GeminiTextPart = { text: string }; +export type GeminiInlinePart = { + inlineData: { mimeType: string; data: string }; +}; +export type GeminiPart = GeminiTextPart | GeminiInlinePart; +export type GeminiEmbeddingRequest = { + content: { parts: GeminiPart[] }; + taskType: GeminiTaskType; + outputDimensionality?: number; + model?: string; +}; +export type GeminiTextEmbeddingRequest = GeminiEmbeddingRequest; + +/** Builds the text-only Gemini embedding request shape used across direct and batch APIs. */ +export function buildGeminiTextEmbeddingRequest(params: { + text: string; + taskType: GeminiTaskType; + outputDimensionality?: number; + modelPath?: string; +}): GeminiTextEmbeddingRequest { + return buildGeminiEmbeddingRequest({ + input: { text: params.text }, + taskType: params.taskType, + outputDimensionality: params.outputDimensionality, + modelPath: params.modelPath, + }); +} + +export function buildGeminiEmbeddingRequest(params: { + input: EmbeddingInput; + taskType: GeminiTaskType; + outputDimensionality?: number; + modelPath?: string; +}): GeminiEmbeddingRequest { + const request: GeminiEmbeddingRequest = { + content: { + parts: params.input.parts?.map((part) => + part.type === "text" + ? ({ text: part.text } satisfies GeminiTextPart) + : ({ + inlineData: { mimeType: part.mimeType, data: part.data }, + } satisfies GeminiInlinePart), + ) ?? 
[{ text: params.input.text }], + }, + taskType: params.taskType, + }; + if (params.modelPath) { + request.model = params.modelPath; + } + if (params.outputDimensionality != null) { + request.outputDimensionality = params.outputDimensionality; + } + return request; +} + +/** + * Returns true if the given model name is a gemini-embedding-2 variant that + * supports `outputDimensionality` and extended task types. + */ +export function isGeminiEmbedding2Model(model: string): boolean { + return GEMINI_EMBEDDING_2_MODELS.has(model); +} + +/** + * Validate and return the `outputDimensionality` for gemini-embedding-2 models. + * Returns `undefined` for older models (they don't support the param). + */ +export function resolveGeminiOutputDimensionality( + model: string, + requested?: number, +): number | undefined { + if (!isGeminiEmbedding2Model(model)) { + return undefined; + } + if (requested == null) { + return GEMINI_EMBEDDING_2_DEFAULT_DIMENSIONS; + } + const valid: readonly number[] = GEMINI_EMBEDDING_2_VALID_DIMENSIONS; + if (!valid.includes(requested)) { + throw new Error( + `Invalid outputDimensionality ${requested} for ${model}. 
Valid values: ${valid.join(", ")}`, + ); + } + return requested; +} +function resolveRemoteApiKey(remoteApiKey: unknown): string | undefined { + const trimmed = resolveMemorySecretInputString({ + value: remoteApiKey, + path: "agents.*.memorySearch.remote.apiKey", + }); + if (!trimmed) { + return undefined; + } + if (trimmed === "GOOGLE_API_KEY" || trimmed === "GEMINI_API_KEY") { + return process.env[trimmed]?.trim(); + } + return trimmed; +} + +export function normalizeGeminiModel(model: string): string { + const trimmed = model.trim(); + if (!trimmed) { + return DEFAULT_GEMINI_EMBEDDING_MODEL; + } + const withoutPrefix = trimmed.replace(/^models\//, ""); + if (withoutPrefix.startsWith("gemini/")) { + return withoutPrefix.slice("gemini/".length); + } + if (withoutPrefix.startsWith("google/")) { + return withoutPrefix.slice("google/".length); + } + return withoutPrefix; +} + +async function fetchGeminiEmbeddingPayload(params: { + client: GeminiEmbeddingClient; + endpoint: string; + body: unknown; +}): Promise<{ + embedding?: { values?: number[] }; + embeddings?: Array<{ values?: number[] }>; +}> { + return await executeWithApiKeyRotation({ + provider: "google", + apiKeys: params.client.apiKeys, + execute: async (apiKey) => { + const authHeaders = parseGeminiAuth(apiKey); + const headers = { + ...authHeaders.headers, + ...params.client.headers, + }; + return await withRemoteHttpResponse({ + url: params.endpoint, + ssrfPolicy: params.client.ssrfPolicy, + init: { + method: "POST", + headers, + body: JSON.stringify(params.body), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + throw new Error(`gemini embeddings failed: ${res.status} ${text}`); + } + return (await res.json()) as { + embedding?: { values?: number[] }; + embeddings?: Array<{ values?: number[] }>; + }; + }, + }); + }, + }); +} + +function normalizeGeminiBaseUrl(raw: string): string { + const trimmed = raw.replace(/\/+$/, ""); + const openAiIndex = 
trimmed.indexOf("/openai"); + if (openAiIndex > -1) { + return normalizeGoogleApiBaseUrl(trimmed.slice(0, openAiIndex)); + } + return normalizeGoogleApiBaseUrl(trimmed); +} + +function buildGeminiModelPath(model: string): string { + return model.startsWith("models/") ? model : `models/${model}`; +} + +export async function createGeminiEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: GeminiEmbeddingClient }> { + const client = await resolveGeminiEmbeddingClient(options); + const baseUrl = client.baseUrl.replace(/\/$/, ""); + const embedUrl = `${baseUrl}/${client.modelPath}:embedContent`; + const batchUrl = `${baseUrl}/${client.modelPath}:batchEmbedContents`; + const isV2 = isGeminiEmbedding2Model(client.model); + const outputDimensionality = client.outputDimensionality; + + const embedQuery = async (text: string): Promise => { + if (!text.trim()) { + return []; + } + const payload = await fetchGeminiEmbeddingPayload({ + client, + endpoint: embedUrl, + body: buildGeminiTextEmbeddingRequest({ + text, + taskType: options.taskType ?? "RETRIEVAL_QUERY", + outputDimensionality: isV2 ? outputDimensionality : undefined, + }), + }); + return sanitizeAndNormalizeEmbedding(payload.embedding?.values ?? []); + }; + + const embedBatchInputs = async (inputs: EmbeddingInput[]): Promise => { + if (inputs.length === 0) { + return []; + } + const payload = await fetchGeminiEmbeddingPayload({ + client, + endpoint: batchUrl, + body: { + requests: inputs.map((input) => + buildGeminiEmbeddingRequest({ + input, + modelPath: client.modelPath, + taskType: options.taskType ?? "RETRIEVAL_DOCUMENT", + outputDimensionality: isV2 ? outputDimensionality : undefined, + }), + ), + }, + }); + const embeddings = Array.isArray(payload.embeddings) ? payload.embeddings : []; + return inputs.map((_, index) => sanitizeAndNormalizeEmbedding(embeddings[index]?.values ?? 
[])); + }; + + const embedBatch = async (texts: string[]): Promise => { + return await embedBatchInputs( + texts.map((text) => ({ + text, + })), + ); + }; + + return { + provider: { + id: "gemini", + model: client.model, + maxInputTokens: GEMINI_MAX_INPUT_TOKENS[client.model], + embedQuery, + embedBatch, + embedBatchInputs, + }, + client, + }; +} + +export async function resolveGeminiEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise { + const remote = options.remote; + const remoteApiKey = resolveRemoteApiKey(remote?.apiKey); + const remoteBaseUrl = remote?.baseUrl?.trim(); + + const apiKey = remoteApiKey + ? remoteApiKey + : requireApiKey( + await resolveApiKeyForProvider({ + provider: "google", + cfg: options.config, + agentDir: options.agentDir, + }), + "google", + ); + + const providerConfig = options.config.models?.providers?.google; + const rawBaseUrl = + remoteBaseUrl || providerConfig?.baseUrl?.trim() || DEFAULT_GOOGLE_API_BASE_URL; + const baseUrl = normalizeGeminiBaseUrl(rawBaseUrl); + const ssrfPolicy = buildRemoteBaseUrlPolicy(baseUrl); + const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); + const headers: Record = { + ...headerOverrides, + }; + const apiKeys = collectProviderApiKeysForExecution({ + provider: "google", + primaryApiKey: apiKey, + }); + const model = normalizeGeminiModel(options.model); + const modelPath = buildGeminiModelPath(model); + const outputDimensionality = resolveGeminiOutputDimensionality( + model, + options.outputDimensionality, + ); + debugEmbeddingsLog("memory embeddings: gemini client", { + rawBaseUrl, + baseUrl, + model, + modelPath, + outputDimensionality, + embedEndpoint: `${baseUrl}/${modelPath}:embedContent`, + batchEndpoint: `${baseUrl}/${modelPath}:batchEmbedContents`, + }); + return { baseUrl, headers, ssrfPolicy, model, modelPath, apiKeys, outputDimensionality }; +} diff --git a/src/memory-host-sdk/host/embeddings-mistral.test.ts 
b/src/memory-host-sdk/host/embeddings-mistral.test.ts new file mode 100644 index 00000000000..7826cd35467 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-mistral.test.ts @@ -0,0 +1,19 @@ +import { describe, expect, it } from "vitest"; +import { DEFAULT_MISTRAL_EMBEDDING_MODEL, normalizeMistralModel } from "./embeddings-mistral.js"; + +describe("normalizeMistralModel", () => { + it("returns the default model for empty values", () => { + expect(normalizeMistralModel("")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL); + expect(normalizeMistralModel(" ")).toBe(DEFAULT_MISTRAL_EMBEDDING_MODEL); + }); + + it("strips the mistral/ prefix", () => { + expect(normalizeMistralModel("mistral/mistral-embed")).toBe("mistral-embed"); + expect(normalizeMistralModel(" mistral/custom-embed ")).toBe("custom-embed"); + }); + + it("keeps explicit non-prefixed models", () => { + expect(normalizeMistralModel("mistral-embed")).toBe("mistral-embed"); + expect(normalizeMistralModel("custom-embed-v2")).toBe("custom-embed-v2"); + }); +}); diff --git a/src/memory-host-sdk/host/embeddings-mistral.ts b/src/memory-host-sdk/host/embeddings-mistral.ts new file mode 100644 index 00000000000..90e9799414f --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-mistral.ts @@ -0,0 +1,51 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; +import { + createRemoteEmbeddingProvider, + resolveRemoteEmbeddingClient, +} from "./embeddings-remote-provider.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type MistralEmbeddingClient = { + baseUrl: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + model: string; +}; + +export const DEFAULT_MISTRAL_EMBEDDING_MODEL = "mistral-embed"; +const DEFAULT_MISTRAL_BASE_URL = "https://api.mistral.ai/v1"; + +export function normalizeMistralModel(model: string): string { + return normalizeEmbeddingModelWithPrefixes({ + 
model, + defaultModel: DEFAULT_MISTRAL_EMBEDDING_MODEL, + prefixes: ["mistral/"], + }); +} + +export async function createMistralEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: MistralEmbeddingClient }> { + const client = await resolveMistralEmbeddingClient(options); + + return { + provider: createRemoteEmbeddingProvider({ + id: "mistral", + client, + errorPrefix: "mistral embeddings failed", + }), + client, + }; +} + +export async function resolveMistralEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise { + return await resolveRemoteEmbeddingClient({ + provider: "mistral", + options, + defaultBaseUrl: DEFAULT_MISTRAL_BASE_URL, + normalizeModel: normalizeMistralModel, + }); +} diff --git a/src/memory-host-sdk/host/embeddings-model-normalize.test.ts b/src/memory-host-sdk/host/embeddings-model-normalize.test.ts new file mode 100644 index 00000000000..dc0581b82fe --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-model-normalize.test.ts @@ -0,0 +1,34 @@ +import { describe, expect, it } from "vitest"; +import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; + +describe("normalizeEmbeddingModelWithPrefixes", () => { + it("returns default model when input is blank", () => { + expect( + normalizeEmbeddingModelWithPrefixes({ + model: " ", + defaultModel: "fallback-model", + prefixes: ["openai/"], + }), + ).toBe("fallback-model"); + }); + + it("strips the first matching prefix", () => { + expect( + normalizeEmbeddingModelWithPrefixes({ + model: "openai/text-embedding-3-small", + defaultModel: "fallback-model", + prefixes: ["openai/"], + }), + ).toBe("text-embedding-3-small"); + }); + + it("keeps explicit model names when no prefix matches", () => { + expect( + normalizeEmbeddingModelWithPrefixes({ + model: "voyage-4-large", + defaultModel: "fallback-model", + prefixes: ["voyage/"], + }), + ).toBe("voyage-4-large"); + }); +}); diff --git 
a/src/memory-host-sdk/host/embeddings-model-normalize.ts b/src/memory-host-sdk/host/embeddings-model-normalize.ts new file mode 100644 index 00000000000..85fcf5b16ce --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-model-normalize.ts @@ -0,0 +1,16 @@ +export function normalizeEmbeddingModelWithPrefixes(params: { + model: string; + defaultModel: string; + prefixes: string[]; +}): string { + const trimmed = params.model.trim(); + if (!trimmed) { + return params.defaultModel; + } + for (const prefix of params.prefixes) { + if (trimmed.startsWith(prefix)) { + return trimmed.slice(prefix.length); + } + } + return trimmed; +} diff --git a/src/memory-host-sdk/host/embeddings-ollama.test.ts b/src/memory-host-sdk/host/embeddings-ollama.test.ts new file mode 100644 index 00000000000..6e425a7d1fd --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-ollama.test.ts @@ -0,0 +1,146 @@ +import { afterEach, beforeAll, beforeEach, describe, it, expect, vi } from "vitest"; +import type { OpenClawConfig } from "../../config/config.js"; + +let createOllamaEmbeddingProvider: typeof import("./embeddings-ollama.js").createOllamaEmbeddingProvider; + +beforeAll(async () => { + ({ createOllamaEmbeddingProvider } = await import("./embeddings-ollama.js")); +}); + +beforeEach(() => { + vi.useRealTimers(); + vi.doUnmock("undici"); +}); + +afterEach(() => { + vi.doUnmock("undici"); + vi.unstubAllGlobals(); + vi.unstubAllEnvs(); + vi.resetAllMocks(); +}); + +describe("embeddings-ollama", () => { + it("calls /api/embeddings and returns normalized vectors", async () => { + const fetchMock = vi.fn( + async () => + new Response(JSON.stringify({ embedding: [3, 4] }), { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + globalThis.fetch = fetchMock as unknown as typeof fetch; + + const { provider } = await createOllamaEmbeddingProvider({ + config: {} as OpenClawConfig, + provider: "ollama", + model: "nomic-embed-text", + fallback: "none", + remote: { baseUrl: 
"http://127.0.0.1:11434" }, + }); + + const v = await provider.embedQuery("hi"); + expect(fetchMock).toHaveBeenCalledTimes(1); + // normalized [3,4] => [0.6,0.8] + expect(v[0]).toBeCloseTo(0.6, 5); + expect(v[1]).toBeCloseTo(0.8, 5); + }); + + it("resolves baseUrl/apiKey/headers from models.providers.ollama and strips /v1", async () => { + const fetchMock = vi.fn( + async () => + new Response(JSON.stringify({ embedding: [1, 0] }), { + status: 200, + headers: { "content-type": "application/json" }, + }), + ); + globalThis.fetch = fetchMock as unknown as typeof fetch; + + const { provider } = await createOllamaEmbeddingProvider({ + config: { + models: { + providers: { + ollama: { + baseUrl: "http://127.0.0.1:11434/v1", + apiKey: "ollama-\nlocal\r\n", // pragma: allowlist secret + headers: { + "X-Provider-Header": "provider", + }, + }, + }, + }, + } as unknown as OpenClawConfig, + provider: "ollama", + model: "", + fallback: "none", + }); + + await provider.embedQuery("hello"); + + expect(fetchMock).toHaveBeenCalledWith( + "http://127.0.0.1:11434/api/embeddings", + expect.objectContaining({ + method: "POST", + headers: expect.objectContaining({ + "Content-Type": "application/json", + Authorization: "Bearer ollama-local", + "X-Provider-Header": "provider", + }), + }), + ); + }); + + it("fails fast when memory-search remote apiKey is an unresolved SecretRef", async () => { + await expect( + createOllamaEmbeddingProvider({ + config: {} as OpenClawConfig, + provider: "ollama", + model: "nomic-embed-text", + fallback: "none", + remote: { + baseUrl: "http://127.0.0.1:11434", + apiKey: { source: "env", provider: "default", id: "OLLAMA_API_KEY" }, + }, + }), + ).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey: unresolved SecretRef/i); + }); + + it("falls back to env key when models.providers.ollama.apiKey is an unresolved SecretRef", async () => { + const fetchMock = vi.fn( + async () => + new Response(JSON.stringify({ embedding: [1, 0] }), { + status: 200, + 
headers: { "content-type": "application/json" }, + }), + ); + globalThis.fetch = fetchMock as unknown as typeof fetch; + vi.stubEnv("OLLAMA_API_KEY", "ollama-env"); + + const { provider } = await createOllamaEmbeddingProvider({ + config: { + models: { + providers: { + ollama: { + baseUrl: "http://127.0.0.1:11434/v1", + apiKey: { source: "env", provider: "default", id: "OLLAMA_API_KEY" }, + models: [], + }, + }, + }, + } as unknown as OpenClawConfig, + provider: "ollama", + model: "nomic-embed-text", + fallback: "none", + }); + + await provider.embedQuery("hello"); + + expect(fetchMock).toHaveBeenCalledWith( + "http://127.0.0.1:11434/api/embeddings", + expect.objectContaining({ + headers: expect.objectContaining({ + Authorization: "Bearer ollama-env", + }), + }), + ); + }); +}); diff --git a/src/memory-host-sdk/host/embeddings-ollama.ts b/src/memory-host-sdk/host/embeddings-ollama.ts new file mode 100644 index 00000000000..d3dd7090381 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-ollama.ts @@ -0,0 +1,5 @@ +export type { OllamaEmbeddingClient } from "../../plugin-sdk/ollama.js"; +export { + createOllamaEmbeddingProvider, + DEFAULT_OLLAMA_EMBEDDING_MODEL, +} from "../../plugin-sdk/ollama.js"; diff --git a/src/memory-host-sdk/host/embeddings-openai.ts b/src/memory-host-sdk/host/embeddings-openai.ts new file mode 100644 index 00000000000..867767acaf5 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-openai.ts @@ -0,0 +1,58 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { OPENAI_DEFAULT_EMBEDDING_MODEL } from "../../plugins/provider-model-defaults.js"; +import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; +import { + createRemoteEmbeddingProvider, + resolveRemoteEmbeddingClient, +} from "./embeddings-remote-provider.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type OpenAiEmbeddingClient = { + baseUrl: string; + headers: Record; + ssrfPolicy?: 
SsrFPolicy; + model: string; +}; + +const DEFAULT_OPENAI_BASE_URL = "https://api.openai.com/v1"; +export const DEFAULT_OPENAI_EMBEDDING_MODEL = OPENAI_DEFAULT_EMBEDDING_MODEL; +const OPENAI_MAX_INPUT_TOKENS: Record = { + "text-embedding-3-small": 8192, + "text-embedding-3-large": 8192, + "text-embedding-ada-002": 8191, +}; + +export function normalizeOpenAiModel(model: string): string { + return normalizeEmbeddingModelWithPrefixes({ + model, + defaultModel: DEFAULT_OPENAI_EMBEDDING_MODEL, + prefixes: ["openai/"], + }); +} + +export async function createOpenAiEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: OpenAiEmbeddingClient }> { + const client = await resolveOpenAiEmbeddingClient(options); + + return { + provider: createRemoteEmbeddingProvider({ + id: "openai", + client, + errorPrefix: "openai embeddings failed", + maxInputTokens: OPENAI_MAX_INPUT_TOKENS[client.model], + }), + client, + }; +} + +export async function resolveOpenAiEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise { + return await resolveRemoteEmbeddingClient({ + provider: "openai", + options, + defaultBaseUrl: DEFAULT_OPENAI_BASE_URL, + normalizeModel: normalizeOpenAiModel, + }); +} diff --git a/src/memory-host-sdk/host/embeddings-remote-client.ts b/src/memory-host-sdk/host/embeddings-remote-client.ts new file mode 100644 index 00000000000..154b886cdf2 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-remote-client.ts @@ -0,0 +1,39 @@ +import { requireApiKey, resolveApiKeyForProvider } from "../../agents/model-auth.js"; +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import type { EmbeddingProviderOptions } from "./embeddings.js"; +import { buildRemoteBaseUrlPolicy } from "./remote-http.js"; +import { resolveMemorySecretInputString } from "./secret-input.js"; + +export type RemoteEmbeddingProviderId = "openai" | "voyage" | "mistral"; + +export async function resolveRemoteEmbeddingBearerClient(params: { + 
provider: RemoteEmbeddingProviderId; + options: EmbeddingProviderOptions; + defaultBaseUrl: string; +}): Promise<{ baseUrl: string; headers: Record; ssrfPolicy?: SsrFPolicy }> { + const remote = params.options.remote; + const remoteApiKey = resolveMemorySecretInputString({ + value: remote?.apiKey, + path: "agents.*.memorySearch.remote.apiKey", + }); + const remoteBaseUrl = remote?.baseUrl?.trim(); + const providerConfig = params.options.config.models?.providers?.[params.provider]; + const apiKey = remoteApiKey + ? remoteApiKey + : requireApiKey( + await resolveApiKeyForProvider({ + provider: params.provider, + cfg: params.options.config, + agentDir: params.options.agentDir, + }), + params.provider, + ); + const baseUrl = remoteBaseUrl || providerConfig?.baseUrl?.trim() || params.defaultBaseUrl; + const headerOverrides = Object.assign({}, providerConfig?.headers, remote?.headers); + const headers: Record = { + "Content-Type": "application/json", + Authorization: `Bearer ${apiKey}`, + ...headerOverrides, + }; + return { baseUrl, headers, ssrfPolicy: buildRemoteBaseUrlPolicy(baseUrl) }; +} diff --git a/src/memory-host-sdk/host/embeddings-remote-fetch.test.ts b/src/memory-host-sdk/host/embeddings-remote-fetch.test.ts new file mode 100644 index 00000000000..3ddddc708f5 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-remote-fetch.test.ts @@ -0,0 +1,59 @@ +import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; + +const postJsonMock = vi.hoisted(() => vi.fn()); + +vi.mock("./post-json.js", () => ({ + postJson: postJsonMock, +})); + +type EmbeddingsRemoteFetchModule = typeof import("./embeddings-remote-fetch.js"); + +let fetchRemoteEmbeddingVectors: EmbeddingsRemoteFetchModule["fetchRemoteEmbeddingVectors"]; + +describe("fetchRemoteEmbeddingVectors", () => { + beforeAll(async () => { + ({ fetchRemoteEmbeddingVectors } = await import("./embeddings-remote-fetch.js")); + }); + + beforeEach(() => { + postJsonMock.mockReset(); + }); + + it("maps remote 
embedding response data to vectors", async () => { + postJsonMock.mockImplementationOnce(async (params) => { + return await params.parse({ + data: [{ embedding: [0.1, 0.2] }, {}, { embedding: [0.3] }], + }); + }); + + const vectors = await fetchRemoteEmbeddingVectors({ + url: "https://memory.example/v1/embeddings", + headers: { Authorization: "Bearer test" }, + body: { input: ["one", "two", "three"] }, + errorPrefix: "embedding fetch failed", + }); + + expect(vectors).toEqual([[0.1, 0.2], [], [0.3]]); + expect(postJsonMock).toHaveBeenCalledWith( + expect.objectContaining({ + url: "https://memory.example/v1/embeddings", + headers: { Authorization: "Bearer test" }, + body: { input: ["one", "two", "three"] }, + errorPrefix: "embedding fetch failed", + }), + ); + }); + + it("throws a status-rich error on non-ok responses", async () => { + postJsonMock.mockRejectedValueOnce(new Error("embedding fetch failed: 403 forbidden")); + + await expect( + fetchRemoteEmbeddingVectors({ + url: "https://memory.example/v1/embeddings", + headers: {}, + body: { input: ["one"] }, + errorPrefix: "embedding fetch failed", + }), + ).rejects.toThrow("embedding fetch failed: 403 forbidden"); + }); +}); diff --git a/src/memory-host-sdk/host/embeddings-remote-fetch.ts b/src/memory-host-sdk/host/embeddings-remote-fetch.ts new file mode 100644 index 00000000000..a45acb37456 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-remote-fetch.ts @@ -0,0 +1,25 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { postJson } from "./post-json.js"; + +export async function fetchRemoteEmbeddingVectors(params: { + url: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + body: unknown; + errorPrefix: string; +}): Promise { + return await postJson({ + url: params.url, + headers: params.headers, + ssrfPolicy: params.ssrfPolicy, + body: params.body, + errorPrefix: params.errorPrefix, + parse: (payload) => { + const typedPayload = payload as { + data?: Array<{ embedding?: number[] 
}>; + }; + const data = typedPayload.data ?? []; + return data.map((entry) => entry.embedding ?? []); + }, + }); +} diff --git a/src/memory-host-sdk/host/embeddings-remote-provider.ts b/src/memory-host-sdk/host/embeddings-remote-provider.ts new file mode 100644 index 00000000000..c0c9a0cb2dd --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-remote-provider.ts @@ -0,0 +1,63 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { + resolveRemoteEmbeddingBearerClient, + type RemoteEmbeddingProviderId, +} from "./embeddings-remote-client.js"; +import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type RemoteEmbeddingClient = { + baseUrl: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + model: string; +}; + +export function createRemoteEmbeddingProvider(params: { + id: string; + client: RemoteEmbeddingClient; + errorPrefix: string; + maxInputTokens?: number; +}): EmbeddingProvider { + const { client } = params; + const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; + + const embed = async (input: string[]): Promise => { + if (input.length === 0) { + return []; + } + return await fetchRemoteEmbeddingVectors({ + url, + headers: client.headers, + ssrfPolicy: client.ssrfPolicy, + body: { model: client.model, input }, + errorPrefix: params.errorPrefix, + }); + }; + + return { + id: params.id, + model: client.model, + ...(typeof params.maxInputTokens === "number" ? { maxInputTokens: params.maxInputTokens } : {}), + embedQuery: async (text) => { + const [vec] = await embed([text]); + return vec ?? 
[]; + }, + embedBatch: embed, + }; +} + +export async function resolveRemoteEmbeddingClient(params: { + provider: RemoteEmbeddingProviderId; + options: EmbeddingProviderOptions; + defaultBaseUrl: string; + normalizeModel: (model: string) => string; +}): Promise { + const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ + provider: params.provider, + options: params.options, + defaultBaseUrl: params.defaultBaseUrl, + }); + const model = params.normalizeModel(params.options.model); + return { baseUrl, headers, ssrfPolicy, model }; +} diff --git a/src/memory-host-sdk/host/embeddings-voyage.test.ts b/src/memory-host-sdk/host/embeddings-voyage.test.ts new file mode 100644 index 00000000000..c5834796299 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-voyage.test.ts @@ -0,0 +1,177 @@ +import { afterEach, beforeAll, beforeEach, describe, expect, it, vi } from "vitest"; +import * as authModule from "../../agents/model-auth.js"; +import { type FetchMock, withFetchPreconnect } from "../../test-utils/fetch-mock.js"; +import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; + +vi.mock("../../infra/net/fetch-guard.js", () => ({ + fetchWithSsrFGuard: async (params: { + url: string; + init?: RequestInit; + fetchImpl?: typeof fetch; + }) => { + const fetchImpl = params.fetchImpl ?? 
globalThis.fetch; + if (!fetchImpl) { + throw new Error("fetch is not available"); + } + const response = await fetchImpl(params.url, params.init); + return { + response, + finalUrl: params.url, + release: async () => {}, + }; + }, +})); + +vi.mock("../../agents/model-auth.js", async () => { + const { createModelAuthMockModule } = await import("../../test-utils/model-auth-mock.js"); + return createModelAuthMockModule(); +}); + +const createFetchMock = () => { + const fetchMock = vi.fn( + async (_input: RequestInfo | URL, _init?: RequestInit) => + new Response(JSON.stringify({ data: [{ embedding: [0.1, 0.2, 0.3] }] }), { + status: 200, + headers: { "Content-Type": "application/json" }, + }), + ); + return withFetchPreconnect(fetchMock); +}; + +function installFetchMock(fetchMock: typeof globalThis.fetch) { + vi.stubGlobal("fetch", fetchMock); +} + +let createVoyageEmbeddingProvider: typeof import("./embeddings-voyage.js").createVoyageEmbeddingProvider; +let normalizeVoyageModel: typeof import("./embeddings-voyage.js").normalizeVoyageModel; + +beforeAll(async () => { + ({ createVoyageEmbeddingProvider, normalizeVoyageModel } = + await import("./embeddings-voyage.js")); +}); + +beforeEach(() => { + vi.useRealTimers(); + vi.doUnmock("undici"); +}); + +function mockVoyageApiKey() { + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey: "voyage-key-123", + mode: "api-key", + source: "test", + }); +} + +async function createDefaultVoyageProvider( + model: string, + fetchMock: ReturnType, +) { + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockVoyageApiKey(); + return createVoyageEmbeddingProvider({ + config: {} as never, + provider: "voyage", + model, + fallback: "none", + }); +} + +describe("voyage embedding provider", () => { + afterEach(() => { + vi.doUnmock("undici"); + vi.resetAllMocks(); + vi.unstubAllGlobals(); + }); + + it("configures client with correct defaults and headers", async () 
=> { + const fetchMock = createFetchMock(); + const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); + + await result.provider.embedQuery("test query"); + + expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledWith( + expect.objectContaining({ provider: "voyage" }), + ); + + const call = fetchMock.mock.calls[0]; + expect(call).toBeDefined(); + const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; + expect(url).toBe("https://api.voyageai.com/v1/embeddings"); + + const headers = (init?.headers ?? {}) as Record; + expect(headers.Authorization).toBe("Bearer voyage-key-123"); + expect(headers["Content-Type"]).toBe("application/json"); + + const body = JSON.parse(init?.body as string); + expect(body).toEqual({ + model: "voyage-4-large", + input: ["test query"], + input_type: "query", + }); + }); + + it("respects remote overrides for baseUrl and apiKey", async () => { + const fetchMock = createFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + + const result = await createVoyageEmbeddingProvider({ + config: {} as never, + provider: "voyage", + model: "voyage-4-lite", + fallback: "none", + remote: { + baseUrl: "https://example.com", + apiKey: "remote-override-key", + headers: { "X-Custom": "123" }, + }, + }); + + await result.provider.embedQuery("test"); + + const call = fetchMock.mock.calls[0]; + expect(call).toBeDefined(); + const [url, init] = call as [RequestInfo | URL, RequestInit | undefined]; + expect(url).toBe("https://example.com/embeddings"); + + const headers = (init?.headers ?? 
{}) as Record; + expect(headers.Authorization).toBe("Bearer remote-override-key"); + expect(headers["X-Custom"]).toBe("123"); + }); + + it("passes input_type=document for embedBatch", async () => { + const fetchMock = withFetchPreconnect( + vi.fn( + async (_input: RequestInfo | URL, _init?: RequestInit) => + new Response( + JSON.stringify({ + data: [{ embedding: [0.1, 0.2] }, { embedding: [0.3, 0.4] }], + }), + { status: 200, headers: { "Content-Type": "application/json" } }, + ), + ), + ); + const result = await createDefaultVoyageProvider("voyage-4-large", fetchMock); + + await result.provider.embedBatch(["doc1", "doc2"]); + + const call = fetchMock.mock.calls[0]; + expect(call).toBeDefined(); + const [, init] = call as [RequestInfo | URL, RequestInit | undefined]; + const body = JSON.parse(init?.body as string); + expect(body).toEqual({ + model: "voyage-4-large", + input: ["doc1", "doc2"], + input_type: "document", + }); + }); + + it("normalizes model names", async () => { + expect(normalizeVoyageModel("voyage/voyage-large-2")).toBe("voyage-large-2"); + expect(normalizeVoyageModel("voyage-4-large")).toBe("voyage-4-large"); + expect(normalizeVoyageModel(" voyage-lite ")).toBe("voyage-lite"); + expect(normalizeVoyageModel("")).toBe("voyage-4-large"); // Default + }); +}); diff --git a/src/memory-host-sdk/host/embeddings-voyage.ts b/src/memory-host-sdk/host/embeddings-voyage.ts new file mode 100644 index 00000000000..caf4165d1f1 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings-voyage.ts @@ -0,0 +1,82 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { normalizeEmbeddingModelWithPrefixes } from "./embeddings-model-normalize.js"; +import { resolveRemoteEmbeddingBearerClient } from "./embeddings-remote-client.js"; +import { fetchRemoteEmbeddingVectors } from "./embeddings-remote-fetch.js"; +import type { EmbeddingProvider, EmbeddingProviderOptions } from "./embeddings.js"; + +export type VoyageEmbeddingClient = { + baseUrl: string; + 
headers: Record; + ssrfPolicy?: SsrFPolicy; + model: string; +}; + +export const DEFAULT_VOYAGE_EMBEDDING_MODEL = "voyage-4-large"; +const DEFAULT_VOYAGE_BASE_URL = "https://api.voyageai.com/v1"; +const VOYAGE_MAX_INPUT_TOKENS: Record = { + "voyage-3": 32000, + "voyage-3-lite": 16000, + "voyage-code-3": 32000, +}; + +export function normalizeVoyageModel(model: string): string { + return normalizeEmbeddingModelWithPrefixes({ + model, + defaultModel: DEFAULT_VOYAGE_EMBEDDING_MODEL, + prefixes: ["voyage/"], + }); +} + +export async function createVoyageEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise<{ provider: EmbeddingProvider; client: VoyageEmbeddingClient }> { + const client = await resolveVoyageEmbeddingClient(options); + const url = `${client.baseUrl.replace(/\/$/, "")}/embeddings`; + + const embed = async (input: string[], input_type?: "query" | "document"): Promise => { + if (input.length === 0) { + return []; + } + const body: { model: string; input: string[]; input_type?: "query" | "document" } = { + model: client.model, + input, + }; + if (input_type) { + body.input_type = input_type; + } + + return await fetchRemoteEmbeddingVectors({ + url, + headers: client.headers, + ssrfPolicy: client.ssrfPolicy, + body, + errorPrefix: "voyage embeddings failed", + }); + }; + + return { + provider: { + id: "voyage", + model: client.model, + maxInputTokens: VOYAGE_MAX_INPUT_TOKENS[client.model], + embedQuery: async (text) => { + const [vec] = await embed([text], "query"); + return vec ?? 
[]; + }, + embedBatch: async (texts) => embed(texts, "document"), + }, + client, + }; +} + +export async function resolveVoyageEmbeddingClient( + options: EmbeddingProviderOptions, +): Promise { + const { baseUrl, headers, ssrfPolicy } = await resolveRemoteEmbeddingBearerClient({ + provider: "voyage", + options, + defaultBaseUrl: DEFAULT_VOYAGE_BASE_URL, + }); + const model = normalizeVoyageModel(options.model); + return { baseUrl, headers, ssrfPolicy, model }; +} diff --git a/src/memory-host-sdk/host/embeddings.test.ts b/src/memory-host-sdk/host/embeddings.test.ts new file mode 100644 index 00000000000..c4fb6545d30 --- /dev/null +++ b/src/memory-host-sdk/host/embeddings.test.ts @@ -0,0 +1,752 @@ +import { setTimeout as sleep } from "node:timers/promises"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import * as authModule from "../../agents/model-auth.js"; +import { DEFAULT_GEMINI_EMBEDDING_MODEL } from "./embeddings-gemini.js"; +import { createEmbeddingProvider, DEFAULT_LOCAL_MODEL } from "./embeddings.js"; +import * as nodeLlamaModule from "./node-llama.js"; +import { mockPublicPinnedHostname } from "./test-helpers/ssrf.js"; + +const { createOllamaEmbeddingProviderMock } = vi.hoisted(() => ({ + createOllamaEmbeddingProviderMock: vi.fn(async () => { + throw new Error("Unexpected ollama provider in embeddings.test.ts"); + }), +})); + +vi.mock("../../infra/net/fetch-guard.js", () => ({ + fetchWithSsrFGuard: async (params: { + url: string; + init?: RequestInit; + fetchImpl?: typeof fetch; + }) => { + const fetchImpl = params.fetchImpl ?? 
globalThis.fetch; + if (!fetchImpl) { + throw new Error("fetch is not available"); + } + const response = await fetchImpl(params.url, params.init); + return { + response, + finalUrl: params.url, + release: async () => {}, + }; + }, +})); + +vi.mock("./embeddings-ollama.js", () => ({ + createOllamaEmbeddingProvider: createOllamaEmbeddingProviderMock, +})); + +const createFetchMock = () => + vi.fn(async (_input?: unknown, _init?: unknown) => ({ + ok: true, + status: 200, + json: async () => ({ data: [{ embedding: [1, 2, 3] }] }), + })); + +const createGeminiFetchMock = () => + vi.fn(async (_input?: unknown, _init?: unknown) => ({ + ok: true, + status: 200, + json: async () => ({ embedding: { values: [1, 2, 3] } }), + })); + +function installFetchMock(fetchMock: typeof globalThis.fetch) { + vi.stubGlobal("fetch", fetchMock); +} + +function readFirstFetchRequest(fetchMock: { mock: { calls: unknown[][] } }) { + const [url, init] = fetchMock.mock.calls[0] ?? []; + return { url, init: init as RequestInit | undefined }; +} + +type ResolvedProviderAuth = Awaited>; + +beforeEach(() => { + vi.spyOn(authModule, "resolveApiKeyForProvider"); + vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp"); +}); + +beforeEach(() => { + vi.useRealTimers(); +}); + +afterEach(() => { + vi.resetAllMocks(); + vi.unstubAllGlobals(); +}); + +function requireProvider(result: Awaited>) { + if (!result.provider) { + throw new Error("Expected embedding provider"); + } + return result.provider; +} + +function mockResolvedProviderKey(apiKey = "provider-key") { + vi.mocked(authModule.resolveApiKeyForProvider).mockResolvedValue({ + apiKey, + mode: "api-key", + source: "test", + }); +} + +function mockMissingLocalEmbeddingDependency() { + vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockRejectedValue( + Object.assign(new Error("Cannot find package 'node-llama-cpp'"), { + code: "ERR_MODULE_NOT_FOUND", + }), + ); +} + +function createLocalProvider(options?: { fallback?: "none" | "openai" }) { + return 
createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "text-embedding-3-small", + fallback: options?.fallback ?? "none", + }); +} + +function expectAutoSelectedProvider( + result: Awaited>, + expectedId: "openai" | "gemini" | "mistral", +) { + expect(result.requestedProvider).toBe("auto"); + const provider = requireProvider(result); + expect(provider.id).toBe(expectedId); + return provider; +} + +function createAutoProvider(model = "") { + return createEmbeddingProvider({ + config: {} as never, + provider: "auto", + model, + fallback: "none", + }); +} + +describe("embedding provider remote overrides", () => { + it("uses remote baseUrl/apiKey and merges headers", async () => { + const fetchMock = createFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey("provider-key"); + + const cfg = { + models: { + providers: { + openai: { + baseUrl: "https://api.openai.com/v1", + headers: { + "X-Provider": "p", + "X-Shared": "provider", + }, + }, + }, + }, + }; + + const result = await createEmbeddingProvider({ + config: cfg as never, + provider: "openai", + remote: { + baseUrl: "https://example.com/v1", + apiKey: " remote-key ", + headers: { + "X-Shared": "remote", + "X-Remote": "r", + }, + }, + model: "text-embedding-3-small", + fallback: "openai", + }); + + const provider = requireProvider(result); + await provider.embedQuery("hello"); + + expect(authModule.resolveApiKeyForProvider).not.toHaveBeenCalled(); + const url = fetchMock.mock.calls[0]?.[0]; + const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; + expect(url).toBe("https://example.com/v1/embeddings"); + const headers = (init?.headers ?? 
{}) as Record; + expect(headers.Authorization).toBe("Bearer remote-key"); + expect(headers["Content-Type"]).toBe("application/json"); + expect(headers["X-Provider"]).toBe("p"); + expect(headers["X-Shared"]).toBe("remote"); + expect(headers["X-Remote"]).toBe("r"); + }); + + it("falls back to resolved api key when remote apiKey is blank", async () => { + const fetchMock = createFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey("provider-key"); + + const cfg = { + models: { + providers: { + openai: { + baseUrl: "https://api.openai.com/v1", + }, + }, + }, + }; + + const result = await createEmbeddingProvider({ + config: cfg as never, + provider: "openai", + remote: { + baseUrl: "https://example.com/v1", + apiKey: " ", + }, + model: "text-embedding-3-small", + fallback: "openai", + }); + + const provider = requireProvider(result); + await provider.embedQuery("hello"); + + expect(authModule.resolveApiKeyForProvider).toHaveBeenCalledTimes(1); + const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; + const headers = (init?.headers as Record) ?? 
{}; + expect(headers.Authorization).toBe("Bearer provider-key"); + }); + + it("builds Gemini embeddings requests with api key header", async () => { + const fetchMock = createGeminiFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey("provider-key"); + + const cfg = { + models: { + providers: { + google: { + baseUrl: "https://generativelanguage.googleapis.com/v1beta", + }, + }, + }, + }; + + const result = await createEmbeddingProvider({ + config: cfg as never, + provider: "gemini", + remote: { + apiKey: "gemini-key", + }, + model: "text-embedding-004", + fallback: "openai", + }); + + const provider = requireProvider(result); + await provider.embedQuery("hello"); + + const { url, init } = readFirstFetchRequest(fetchMock); + expect(url).toBe( + "https://generativelanguage.googleapis.com/v1beta/models/text-embedding-004:embedContent", + ); + const headers = (init?.headers ?? {}) as Record; + expect(headers["x-goog-api-key"]).toBe("gemini-key"); + expect(headers["Content-Type"]).toBe("application/json"); + }); + + it("fails fast when Gemini remote apiKey is an unresolved SecretRef", async () => { + await expect( + createEmbeddingProvider({ + config: {} as never, + provider: "gemini", + remote: { + apiKey: { source: "env", provider: "default", id: "GEMINI_API_KEY" }, + }, + model: "text-embedding-004", + fallback: "openai", + }), + ).rejects.toThrow(/agents\.\*\.memorySearch\.remote\.apiKey:/i); + }); + + it("uses GEMINI_API_KEY env indirection for Gemini remote apiKey", async () => { + const fetchMock = createGeminiFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + vi.stubEnv("GEMINI_API_KEY", "env-gemini-key"); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "gemini", + remote: { + apiKey: "GEMINI_API_KEY", // pragma: allowlist secret + }, + model: "text-embedding-004", + fallback: 
"openai", + }); + + const provider = requireProvider(result); + await provider.embedQuery("hello"); + + const { init } = readFirstFetchRequest(fetchMock); + const headers = (init?.headers ?? {}) as Record; + expect(headers["x-goog-api-key"]).toBe("env-gemini-key"); + }); + + it("builds Mistral embeddings requests with bearer auth", async () => { + const fetchMock = createFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + mockResolvedProviderKey("provider-key"); + + const cfg = { + models: { + providers: { + mistral: { + baseUrl: "https://api.mistral.ai/v1", + }, + }, + }, + }; + + const result = await createEmbeddingProvider({ + config: cfg as never, + provider: "mistral", + remote: { + apiKey: "mistral-key", // pragma: allowlist secret + }, + model: "mistral/mistral-embed", + fallback: "none", + }); + + const provider = requireProvider(result); + await provider.embedQuery("hello"); + + const { url, init } = readFirstFetchRequest(fetchMock); + expect(url).toBe("https://api.mistral.ai/v1/embeddings"); + const headers = (init?.headers ?? {}) as Record; + expect(headers.Authorization).toBe("Bearer mistral-key"); + const payload = JSON.parse((init?.body as string | undefined) ?? 
"{}") as { model?: string }; + expect(payload.model).toBe("mistral-embed"); + }); +}); + +describe("embedding provider auto selection", () => { + it("keeps explicit model when openai is selected", async () => { + const fetchMock = vi.fn(async (_input?: unknown, _init?: unknown) => ({ + ok: true, + status: 200, + json: async () => ({ data: [{ embedding: [1, 2, 3] }] }), + })); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => { + if (provider === "openai") { + return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: "api-key" }; + } + throw new Error(`Unexpected provider ${provider}`); + }); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "auto", + model: "text-embedding-3-small", + fallback: "none", + }); + + expect(result.requestedProvider).toBe("auto"); + const provider = requireProvider(result); + expect(provider.id).toBe("openai"); + await provider.embedQuery("hello"); + const url = fetchMock.mock.calls[0]?.[0]; + const init = fetchMock.mock.calls[0]?.[1] as RequestInit | undefined; + expect(url).toBe("https://api.openai.com/v1/embeddings"); + const payload = JSON.parse(init?.body as string) as { model?: string }; + expect(payload.model).toBe("text-embedding-3-small"); + }); + + it("selects the first available remote provider in auto mode", async () => { + const cases: Array<{ + name: string; + expectedProvider: "openai" | "gemini" | "mistral"; + fetchMockFactory: typeof createFetchMock | typeof createGeminiFetchMock; + resolveApiKey: (provider: string) => ResolvedProviderAuth; + expectedUrl: string; + }> = [ + { + name: "openai first", + expectedProvider: "openai" as const, + fetchMockFactory: createFetchMock, + resolveApiKey(provider: string): ResolvedProviderAuth { + if (provider === "openai") { + return { apiKey: "openai-key", source: "env: OPENAI_API_KEY", mode: 
"api-key" }; + } + throw new Error(`No API key found for provider "${provider}".`); + }, + expectedUrl: "https://api.openai.com/v1/embeddings", + }, + { + name: "gemini fallback", + expectedProvider: "gemini" as const, + fetchMockFactory: createGeminiFetchMock, + resolveApiKey(provider: string): ResolvedProviderAuth { + if (provider === "openai") { + throw new Error('No API key found for provider "openai".'); + } + if (provider === "google") { + return { + apiKey: "gemini-key", + source: "env: GEMINI_API_KEY", + mode: "api-key" as const, + }; + } + throw new Error(`Unexpected provider ${provider}`); + }, + expectedUrl: `https://generativelanguage.googleapis.com/v1beta/models/${DEFAULT_GEMINI_EMBEDDING_MODEL}:embedContent`, + }, + { + name: "mistral after earlier misses", + expectedProvider: "mistral" as const, + fetchMockFactory: createFetchMock, + resolveApiKey(provider: string): ResolvedProviderAuth { + if (provider === "mistral") { + return { + apiKey: "mistral-key", + source: "env: MISTRAL_API_KEY", + mode: "api-key" as const, + }; + } + throw new Error(`No API key found for provider "${provider}".`); + }, + expectedUrl: "https://api.mistral.ai/v1/embeddings", + }, + ]; + + for (const testCase of cases) { + vi.resetAllMocks(); + vi.unstubAllGlobals(); + const fetchMock = testCase.fetchMockFactory(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + mockPublicPinnedHostname(); + vi.mocked(authModule.resolveApiKeyForProvider).mockImplementation(async ({ provider }) => + testCase.resolveApiKey(provider), + ); + + const result = await createAutoProvider(); + const provider = expectAutoSelectedProvider(result, testCase.expectedProvider); + await provider.embedQuery("hello"); + const [url] = fetchMock.mock.calls[0] ?? 
[]; + expect(url, testCase.name).toBe(testCase.expectedUrl); + } + }); +}); + +describe("embedding provider local fallback", () => { + it("falls back to openai when node-llama-cpp is missing", async () => { + mockMissingLocalEmbeddingDependency(); + + const fetchMock = createFetchMock(); + installFetchMock(fetchMock as unknown as typeof globalThis.fetch); + + mockResolvedProviderKey("provider-key"); + + const result = await createLocalProvider({ fallback: "openai" }); + + const provider = requireProvider(result); + expect(provider.id).toBe("openai"); + expect(result.fallbackFrom).toBe("local"); + expect(result.fallbackReason).toContain("node-llama-cpp"); + }); + + it("throws a helpful error when local is requested and fallback is none", async () => { + mockMissingLocalEmbeddingDependency(); + await expect(createLocalProvider()).rejects.toThrow(/optional dependency node-llama-cpp/i); + }); + + it("mentions every remote provider in local setup guidance", async () => { + mockMissingLocalEmbeddingDependency(); + await expect(createLocalProvider()).rejects.toThrow(/provider = "gemini"/i); + await expect(createLocalProvider()).rejects.toThrow(/provider = "mistral"/i); + }); +}); + +describe("local embedding normalization", () => { + async function createLocalProviderForTest() { + return createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "", + fallback: "none", + }); + } + + function mockSingleLocalEmbeddingVector( + vector: number[], + resolveModelFile: (modelPath: string, modelDirectory?: string) => Promise = async () => + "/fake/model.gguf", + ): void { + vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({ + getLlama: async () => ({ + loadModel: vi.fn().mockResolvedValue({ + createEmbeddingContext: vi.fn().mockResolvedValue({ + getEmbeddingFor: vi.fn().mockResolvedValue({ + vector: new Float32Array(vector), + }), + }), + }), + }), + resolveModelFile, + LlamaLogLevel: { error: 0 }, + } as never); + } + + it("normalizes local 
embeddings to magnitude ~1.0", async () => { + const unnormalizedVector = [2.35, 3.45, 0.63, 4.3, 1.2, 5.1, 2.8, 3.9]; + const resolveModelFileMock = vi.fn(async () => "/fake/model.gguf"); + + mockSingleLocalEmbeddingVector(unnormalizedVector, resolveModelFileMock); + + const result = await createLocalProviderForTest(); + + const provider = requireProvider(result); + const embedding = await provider.embedQuery("test query"); + + const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0)); + + expect(magnitude).toBeCloseTo(1.0, 5); + expect(resolveModelFileMock).toHaveBeenCalledWith(DEFAULT_LOCAL_MODEL, undefined); + }); + + it("handles zero vector without division by zero", async () => { + const zeroVector = [0, 0, 0, 0]; + + mockSingleLocalEmbeddingVector(zeroVector); + + const result = await createLocalProviderForTest(); + + const provider = requireProvider(result); + const embedding = await provider.embedQuery("test"); + + expect(embedding).toEqual([0, 0, 0, 0]); + expect(embedding.every((value) => Number.isFinite(value))).toBe(true); + }); + + it("sanitizes non-finite values before normalization", async () => { + const nonFiniteVector = [1, Number.NaN, Number.POSITIVE_INFINITY, Number.NEGATIVE_INFINITY]; + + mockSingleLocalEmbeddingVector(nonFiniteVector); + + const result = await createLocalProviderForTest(); + + const provider = requireProvider(result); + const embedding = await provider.embedQuery("test"); + + expect(embedding).toEqual([1, 0, 0, 0]); + expect(embedding.every((value) => Number.isFinite(value))).toBe(true); + }); + + it("normalizes batch embeddings to magnitude ~1.0", async () => { + const unnormalizedVectors = [ + [2.35, 3.45, 0.63, 4.3], + [10.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0], + ]; + + vi.mocked(nodeLlamaModule.importNodeLlamaCpp).mockResolvedValue({ + getLlama: async () => ({ + loadModel: vi.fn().mockResolvedValue({ + createEmbeddingContext: vi.fn().mockResolvedValue({ + getEmbeddingFor: vi + .fn() + 
.mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[0]) }) + .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[1]) }) + .mockResolvedValueOnce({ vector: new Float32Array(unnormalizedVectors[2]) }), + }), + }), + }), + resolveModelFile: async () => "/fake/model.gguf", + LlamaLogLevel: { error: 0 }, + } as never); + + const result = await createLocalProviderForTest(); + + const provider = requireProvider(result); + const embeddings = await provider.embedBatch(["text1", "text2", "text3"]); + + for (const embedding of embeddings) { + const magnitude = Math.sqrt(embedding.reduce((sum, x) => sum + x * x, 0)); + expect(magnitude).toBeCloseTo(1.0, 5); + } + }); +}); + +describe("local embedding ensureContext concurrency", () => { + async function setupLocalProviderWithMockedInit(params?: { + initializationDelayMs?: number; + failFirstGetLlama?: boolean; + }) { + const getLlamaSpy = vi.fn(); + const loadModelSpy = vi.fn(); + const createContextSpy = vi.fn(); + let shouldFail = params?.failFirstGetLlama ?? 
false; + + vi.spyOn(nodeLlamaModule, "importNodeLlamaCpp").mockResolvedValue({ + getLlama: async (...args: unknown[]) => { + getLlamaSpy(...args); + if (shouldFail) { + shouldFail = false; + throw new Error("transient init failure"); + } + if (params?.initializationDelayMs) { + await sleep(params.initializationDelayMs); + } + return { + loadModel: async (...modelArgs: unknown[]) => { + loadModelSpy(...modelArgs); + if (params?.initializationDelayMs) { + await sleep(params.initializationDelayMs); + } + return { + createEmbeddingContext: async () => { + createContextSpy(); + return { + getEmbeddingFor: vi.fn().mockResolvedValue({ + vector: new Float32Array([1, 0, 0, 0]), + }), + }; + }, + }; + }, + }; + }, + resolveModelFile: async () => "/fake/model.gguf", + LlamaLogLevel: { error: 0 }, + } as never); + + const result = await createEmbeddingProvider({ + config: {} as never, + provider: "local", + model: "", + fallback: "none", + }); + + return { + provider: requireProvider(result), + getLlamaSpy, + loadModelSpy, + createContextSpy, + }; + } + + it("loads the model only once when embedBatch is called concurrently", async () => { + const { provider, getLlamaSpy, loadModelSpy, createContextSpy } = + await setupLocalProviderWithMockedInit({ + initializationDelayMs: 50, + }); + + const results = await Promise.all([ + provider.embedBatch(["text1"]), + provider.embedBatch(["text2"]), + provider.embedBatch(["text3"]), + provider.embedBatch(["text4"]), + ]); + + expect(results).toHaveLength(4); + for (const embeddings of results) { + expect(embeddings).toHaveLength(1); + expect(embeddings[0]).toHaveLength(4); + } + + expect(getLlamaSpy).toHaveBeenCalledTimes(1); + expect(loadModelSpy).toHaveBeenCalledTimes(1); + expect(createContextSpy).toHaveBeenCalledTimes(1); + }); + + it("retries initialization after a transient ensureContext failure", async () => { + const { provider, getLlamaSpy, loadModelSpy, createContextSpy } = + await setupLocalProviderWithMockedInit({ + 
failFirstGetLlama: true, + }); + + await expect(provider.embedBatch(["first"])).rejects.toThrow("transient init failure"); + + const recovered = await provider.embedBatch(["second"]); + expect(recovered).toHaveLength(1); + expect(recovered[0]).toHaveLength(4); + + expect(getLlamaSpy).toHaveBeenCalledTimes(2); + expect(loadModelSpy).toHaveBeenCalledTimes(1); + expect(createContextSpy).toHaveBeenCalledTimes(1); + }); + + it("shares initialization when embedQuery and embedBatch start concurrently", async () => { + const { provider, getLlamaSpy, loadModelSpy, createContextSpy } = + await setupLocalProviderWithMockedInit({ + initializationDelayMs: 50, + }); + + const [queryA, batch, queryB] = await Promise.all([ + provider.embedQuery("query-a"), + provider.embedBatch(["batch-a", "batch-b"]), + provider.embedQuery("query-b"), + ]); + + expect(queryA).toHaveLength(4); + expect(batch).toHaveLength(2); + expect(queryB).toHaveLength(4); + expect(batch[0]).toHaveLength(4); + expect(batch[1]).toHaveLength(4); + + expect(getLlamaSpy).toHaveBeenCalledTimes(1); + expect(loadModelSpy).toHaveBeenCalledTimes(1); + expect(createContextSpy).toHaveBeenCalledTimes(1); + }); +}); + +describe("FTS-only fallback when no provider available", () => { + it("returns null provider when all requested auth paths fail", async () => { + vi.mocked(authModule.resolveApiKeyForProvider).mockRejectedValue( + new Error("No API key found for provider"), + ); + + for (const testCase of [ + { + name: "auto mode", + options: { + config: {} as never, + provider: "auto" as const, + model: "", + fallback: "none" as const, + }, + requestedProvider: "auto", + fallbackFrom: undefined, + reasonIncludes: "No API key", + }, + { + name: "explicit provider only", + options: { + config: {} as never, + provider: "openai" as const, + model: "text-embedding-3-small", + fallback: "none" as const, + }, + requestedProvider: "openai", + fallbackFrom: undefined, + reasonIncludes: "No API key", + }, + { + name: "primary and 
fallback", + options: { + config: {} as never, + provider: "openai" as const, + model: "text-embedding-3-small", + fallback: "gemini" as const, + }, + requestedProvider: "openai", + fallbackFrom: "openai", + reasonIncludes: "Fallback to gemini failed", + }, + ]) { + const result = await createEmbeddingProvider(testCase.options); + expect(result.provider, testCase.name).toBeNull(); + expect(result.requestedProvider, testCase.name).toBe(testCase.requestedProvider); + expect(result.fallbackFrom, testCase.name).toBe(testCase.fallbackFrom); + expect(result.providerUnavailableReason, testCase.name).toContain(testCase.reasonIncludes); + } + }); +}); diff --git a/src/memory-host-sdk/host/embeddings.ts b/src/memory-host-sdk/host/embeddings.ts new file mode 100644 index 00000000000..ee18eb34c7f --- /dev/null +++ b/src/memory-host-sdk/host/embeddings.ts @@ -0,0 +1,324 @@ +import fsSync from "node:fs"; +import type { Llama, LlamaEmbeddingContext, LlamaModel } from "node-llama-cpp"; +import type { OpenClawConfig } from "../../config/config.js"; +import type { SecretInput } from "../../config/types.secrets.js"; +import { formatErrorMessage } from "../../infra/errors.js"; +import { resolveUserPath } from "../../utils.js"; +import type { EmbeddingInput } from "./embedding-inputs.js"; +import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js"; +import { + createGeminiEmbeddingProvider, + type GeminiEmbeddingClient, + type GeminiTaskType, +} from "./embeddings-gemini.js"; +import { + createMistralEmbeddingProvider, + type MistralEmbeddingClient, +} from "./embeddings-mistral.js"; +import { createOllamaEmbeddingProvider, type OllamaEmbeddingClient } from "./embeddings-ollama.js"; +import { createOpenAiEmbeddingProvider, type OpenAiEmbeddingClient } from "./embeddings-openai.js"; +import { createVoyageEmbeddingProvider, type VoyageEmbeddingClient } from "./embeddings-voyage.js"; +import { importNodeLlamaCpp } from "./node-llama.js"; + +export type { GeminiEmbeddingClient 
} from "./embeddings-gemini.js"; +export type { MistralEmbeddingClient } from "./embeddings-mistral.js"; +export type { OpenAiEmbeddingClient } from "./embeddings-openai.js"; +export type { VoyageEmbeddingClient } from "./embeddings-voyage.js"; +export type { OllamaEmbeddingClient } from "./embeddings-ollama.js"; + +export type EmbeddingProvider = { + id: string; + model: string; + maxInputTokens?: number; + embedQuery: (text: string) => Promise; + embedBatch: (texts: string[]) => Promise; + embedBatchInputs?: (inputs: EmbeddingInput[]) => Promise; +}; + +export type EmbeddingProviderId = "openai" | "local" | "gemini" | "voyage" | "mistral" | "ollama"; +export type EmbeddingProviderRequest = EmbeddingProviderId | "auto"; +export type EmbeddingProviderFallback = EmbeddingProviderId | "none"; + +// Remote providers considered for auto-selection when provider === "auto". +// Ollama is intentionally excluded here so that "auto" mode does not +// implicitly assume a local Ollama instance is available. +const REMOTE_EMBEDDING_PROVIDER_IDS = ["openai", "gemini", "voyage", "mistral"] as const; + +export type EmbeddingProviderResult = { + provider: EmbeddingProvider | null; + requestedProvider: EmbeddingProviderRequest; + fallbackFrom?: EmbeddingProviderId; + fallbackReason?: string; + providerUnavailableReason?: string; + openAi?: OpenAiEmbeddingClient; + gemini?: GeminiEmbeddingClient; + voyage?: VoyageEmbeddingClient; + mistral?: MistralEmbeddingClient; + ollama?: OllamaEmbeddingClient; +}; + +export type EmbeddingProviderOptions = { + config: OpenClawConfig; + agentDir?: string; + provider: EmbeddingProviderRequest; + remote?: { + baseUrl?: string; + apiKey?: SecretInput; + headers?: Record; + }; + model: string; + fallback: EmbeddingProviderFallback; + local?: { + modelPath?: string; + modelCacheDir?: string; + }; + /** Gemini embedding-2: output vector dimensions (768, 1536, or 3072). 
*/ + outputDimensionality?: number; + /** Gemini: override the default task type sent with embedding requests. */ + taskType?: GeminiTaskType; +}; + +export const DEFAULT_LOCAL_MODEL = + "hf:ggml-org/embeddinggemma-300m-qat-q8_0-GGUF/embeddinggemma-300m-qat-Q8_0.gguf"; + +function canAutoSelectLocal(options: EmbeddingProviderOptions): boolean { + const modelPath = options.local?.modelPath?.trim(); + if (!modelPath) { + return false; + } + if (/^(hf:|https?:)/i.test(modelPath)) { + return false; + } + const resolved = resolveUserPath(modelPath); + try { + return fsSync.statSync(resolved).isFile(); + } catch { + return false; + } +} + +function isMissingApiKeyError(err: unknown): boolean { + const message = formatErrorMessage(err); + return message.includes("No API key found for provider"); +} + +export async function createLocalEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise { + const modelPath = options.local?.modelPath?.trim() || DEFAULT_LOCAL_MODEL; + const modelCacheDir = options.local?.modelCacheDir?.trim(); + + // Lazy-load node-llama-cpp to keep startup light unless local is enabled. 
+ const { getLlama, resolveModelFile, LlamaLogLevel } = await importNodeLlamaCpp(); + + let llama: Llama | null = null; + let embeddingModel: LlamaModel | null = null; + let embeddingContext: LlamaEmbeddingContext | null = null; + let initPromise: Promise | null = null; + + const ensureContext = async (): Promise => { + if (embeddingContext) { + return embeddingContext; + } + if (initPromise) { + return initPromise; + } + initPromise = (async () => { + try { + if (!llama) { + llama = await getLlama({ logLevel: LlamaLogLevel.error }); + } + if (!embeddingModel) { + const resolved = await resolveModelFile(modelPath, modelCacheDir || undefined); + embeddingModel = await llama.loadModel({ modelPath: resolved }); + } + if (!embeddingContext) { + embeddingContext = await embeddingModel.createEmbeddingContext(); + } + return embeddingContext; + } catch (err) { + initPromise = null; + throw err; + } + })(); + return initPromise; + }; + + return { + id: "local", + model: modelPath, + embedQuery: async (text) => { + const ctx = await ensureContext(); + const embedding = await ctx.getEmbeddingFor(text); + return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); + }, + embedBatch: async (texts) => { + const ctx = await ensureContext(); + const embeddings = await Promise.all( + texts.map(async (text) => { + const embedding = await ctx.getEmbeddingFor(text); + return sanitizeAndNormalizeEmbedding(Array.from(embedding.vector)); + }), + ); + return embeddings; + }, + }; +} + +export async function createEmbeddingProvider( + options: EmbeddingProviderOptions, +): Promise { + const requestedProvider = options.provider; + const fallback = options.fallback; + + const createProvider = async (id: EmbeddingProviderId) => { + if (id === "local") { + const provider = await createLocalEmbeddingProvider(options); + return { provider }; + } + if (id === "ollama") { + const { provider, client } = await createOllamaEmbeddingProvider(options); + return { provider, ollama: client }; + 
} + if (id === "gemini") { + const { provider, client } = await createGeminiEmbeddingProvider(options); + return { provider, gemini: client }; + } + if (id === "voyage") { + const { provider, client } = await createVoyageEmbeddingProvider(options); + return { provider, voyage: client }; + } + if (id === "mistral") { + const { provider, client } = await createMistralEmbeddingProvider(options); + return { provider, mistral: client }; + } + const { provider, client } = await createOpenAiEmbeddingProvider(options); + return { provider, openAi: client }; + }; + + const formatPrimaryError = (err: unknown, provider: EmbeddingProviderId) => + provider === "local" ? formatLocalSetupError(err) : formatErrorMessage(err); + + if (requestedProvider === "auto") { + const missingKeyErrors: string[] = []; + let localError: string | null = null; + + if (canAutoSelectLocal(options)) { + try { + const local = await createProvider("local"); + return { ...local, requestedProvider }; + } catch (err) { + localError = formatLocalSetupError(err); + } + } + + for (const provider of REMOTE_EMBEDDING_PROVIDER_IDS) { + try { + const result = await createProvider(provider); + return { ...result, requestedProvider }; + } catch (err) { + const message = formatPrimaryError(err, provider); + if (isMissingApiKeyError(err)) { + missingKeyErrors.push(message); + continue; + } + // Non-auth errors (e.g., network) are still fatal + const wrapped = new Error(message) as Error & { cause?: unknown }; + wrapped.cause = err; + throw wrapped; + } + } + + // All providers failed due to missing API keys - return null provider for FTS-only mode + const details = [...missingKeyErrors, localError].filter(Boolean) as string[]; + const reason = details.length > 0 ? 
details.join("\n\n") : "No embeddings provider available."; + return { + provider: null, + requestedProvider, + providerUnavailableReason: reason, + }; + } + + try { + const primary = await createProvider(requestedProvider); + return { ...primary, requestedProvider }; + } catch (primaryErr) { + const reason = formatPrimaryError(primaryErr, requestedProvider); + if (fallback && fallback !== "none" && fallback !== requestedProvider) { + try { + const fallbackResult = await createProvider(fallback); + return { + ...fallbackResult, + requestedProvider, + fallbackFrom: requestedProvider, + fallbackReason: reason, + }; + } catch (fallbackErr) { + // Both primary and fallback failed - check if it's auth-related + const fallbackReason = formatErrorMessage(fallbackErr); + const combinedReason = `${reason}\n\nFallback to ${fallback} failed: ${fallbackReason}`; + if (isMissingApiKeyError(primaryErr) && isMissingApiKeyError(fallbackErr)) { + // Both failed due to missing API keys - return null for FTS-only mode + return { + provider: null, + requestedProvider, + fallbackFrom: requestedProvider, + fallbackReason: reason, + providerUnavailableReason: combinedReason, + }; + } + // Non-auth errors are still fatal + const wrapped = new Error(combinedReason) as Error & { cause?: unknown }; + wrapped.cause = fallbackErr; + throw wrapped; + } + } + // No fallback configured - check if we should degrade to FTS-only + if (isMissingApiKeyError(primaryErr)) { + return { + provider: null, + requestedProvider, + providerUnavailableReason: reason, + }; + } + const wrapped = new Error(reason) as Error & { cause?: unknown }; + wrapped.cause = primaryErr; + throw wrapped; + } +} + +function isNodeLlamaCppMissing(err: unknown): boolean { + if (!(err instanceof Error)) { + return false; + } + const code = (err as Error & { code?: unknown }).code; + if (code === "ERR_MODULE_NOT_FOUND") { + return err.message.includes("node-llama-cpp"); + } + return false; +} + +function formatLocalSetupError(err: 
unknown): string { + const detail = formatErrorMessage(err); + const missing = isNodeLlamaCppMissing(err); + return [ + "Local embeddings unavailable.", + missing + ? "Reason: optional dependency node-llama-cpp is missing (or failed to install)." + : detail + ? `Reason: ${detail}` + : undefined, + missing && detail ? `Detail: ${detail}` : null, + "To enable local embeddings:", + "1) Use Node 24 (recommended for installs/updates; Node 22 LTS, currently 22.14+, remains supported)", + missing + ? "2) Reinstall OpenClaw (this should install node-llama-cpp): npm i -g openclaw@latest" + : null, + "3) If you use pnpm: pnpm approve-builds (select node-llama-cpp), then pnpm rebuild node-llama-cpp", + ...REMOTE_EMBEDDING_PROVIDER_IDS.map( + (provider) => `Or set agents.defaults.memorySearch.provider = "${provider}" (remote).`, + ), + ] + .filter(Boolean) + .join("\n"); +} diff --git a/src/memory-host-sdk/host/fs-utils.ts b/src/memory-host-sdk/host/fs-utils.ts new file mode 100644 index 00000000000..81107c7ef3d --- /dev/null +++ b/src/memory-host-sdk/host/fs-utils.ts @@ -0,0 +1,31 @@ +import type { Stats } from "node:fs"; +import fs from "node:fs/promises"; + +export type RegularFileStatResult = { missing: true } | { missing: false; stat: Stats }; + +export function isFileMissingError( + err: unknown, +): err is NodeJS.ErrnoException & { code: "ENOENT" } { + return Boolean( + err && + typeof err === "object" && + "code" in err && + (err as Partial).code === "ENOENT", + ); +} + +export async function statRegularFile(absPath: string): Promise { + let stat: Stats; + try { + stat = await fs.lstat(absPath); + } catch (err) { + if (isFileMissingError(err)) { + return { missing: true }; + } + throw err; + } + if (stat.isSymbolicLink() || !stat.isFile()) { + throw new Error("path required"); + } + return { missing: false, stat }; +} diff --git a/src/memory-host-sdk/host/internal.test.ts b/src/memory-host-sdk/host/internal.test.ts new file mode 100644 index 00000000000..8e7a748d740 
--- /dev/null +++ b/src/memory-host-sdk/host/internal.test.ts @@ -0,0 +1,423 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { + buildMultimodalChunkForIndexing, + buildFileEntry, + chunkMarkdown, + listMemoryFiles, + normalizeExtraMemoryPaths, + remapChunkLines, +} from "./internal.js"; +import { + DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES, + type MemoryMultimodalSettings, +} from "./multimodal.js"; + +function setupTempDirLifecycle(prefix: string): () => string { + let tmpDir = ""; + beforeEach(async () => { + tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), prefix)); + }); + afterEach(async () => { + await fs.rm(tmpDir, { recursive: true, force: true }); + }); + return () => tmpDir; +} + +describe("normalizeExtraMemoryPaths", () => { + it("trims, resolves, and dedupes paths", () => { + const workspaceDir = path.join(os.tmpdir(), "memory-test-workspace"); + const absPath = path.resolve(path.sep, "shared-notes"); + const result = normalizeExtraMemoryPaths(workspaceDir, [ + " notes ", + "./notes", + absPath, + absPath, + "", + ]); + expect(result).toEqual([path.resolve(workspaceDir, "notes"), absPath]); + }); +}); + +describe("listMemoryFiles", () => { + const getTmpDir = setupTempDirLifecycle("memory-test-"); + const multimodal: MemoryMultimodalSettings = { + enabled: true, + modalities: ["image", "audio"], + maxFileBytes: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES, + }; + + it("includes files from additional paths (directory)", async () => { + const tmpDir = getTmpDir(); + await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory"); + const extraDir = path.join(tmpDir, "extra-notes"); + await fs.mkdir(extraDir, { recursive: true }); + await fs.writeFile(path.join(extraDir, "note1.md"), "# Note 1"); + await fs.writeFile(path.join(extraDir, "note2.md"), "# Note 2"); + await fs.writeFile(path.join(extraDir, "ignore.txt"), "Not a 
markdown file"); + + const files = await listMemoryFiles(tmpDir, [extraDir]); + expect(files).toHaveLength(3); + expect(files.some((file) => file.endsWith("MEMORY.md"))).toBe(true); + expect(files.some((file) => file.endsWith("note1.md"))).toBe(true); + expect(files.some((file) => file.endsWith("note2.md"))).toBe(true); + expect(files.some((file) => file.endsWith("ignore.txt"))).toBe(false); + }); + + it("includes files from additional paths (single file)", async () => { + const tmpDir = getTmpDir(); + await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory"); + const singleFile = path.join(tmpDir, "standalone.md"); + await fs.writeFile(singleFile, "# Standalone"); + + const files = await listMemoryFiles(tmpDir, [singleFile]); + expect(files).toHaveLength(2); + expect(files.some((file) => file.endsWith("standalone.md"))).toBe(true); + }); + + it("handles relative paths in additional paths", async () => { + const tmpDir = getTmpDir(); + await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory"); + const extraDir = path.join(tmpDir, "subdir"); + await fs.mkdir(extraDir, { recursive: true }); + await fs.writeFile(path.join(extraDir, "nested.md"), "# Nested"); + + const files = await listMemoryFiles(tmpDir, ["subdir"]); + expect(files).toHaveLength(2); + expect(files.some((file) => file.endsWith("nested.md"))).toBe(true); + }); + + it("ignores non-existent additional paths", async () => { + const tmpDir = getTmpDir(); + await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory"); + + const files = await listMemoryFiles(tmpDir, ["/does/not/exist"]); + expect(files).toHaveLength(1); + }); + + it("ignores symlinked files and directories", async () => { + const tmpDir = getTmpDir(); + await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory"); + const extraDir = path.join(tmpDir, "extra"); + await fs.mkdir(extraDir, { recursive: true }); + await fs.writeFile(path.join(extraDir, "note.md"), "# Note"); + + const targetFile = 
path.join(tmpDir, "target.md"); + await fs.writeFile(targetFile, "# Target"); + const linkFile = path.join(extraDir, "linked.md"); + + const targetDir = path.join(tmpDir, "target-dir"); + await fs.mkdir(targetDir, { recursive: true }); + await fs.writeFile(path.join(targetDir, "nested.md"), "# Nested"); + const linkDir = path.join(tmpDir, "linked-dir"); + + let symlinksOk = true; + try { + await fs.symlink(targetFile, linkFile, "file"); + await fs.symlink(targetDir, linkDir, "dir"); + } catch (err) { + const code = (err as NodeJS.ErrnoException).code; + if (code === "EPERM" || code === "EACCES") { + symlinksOk = false; + } else { + throw err; + } + } + + const files = await listMemoryFiles(tmpDir, [extraDir, linkDir]); + expect(files.some((file) => file.endsWith("note.md"))).toBe(true); + if (symlinksOk) { + expect(files.some((file) => file.endsWith("linked.md"))).toBe(false); + expect(files.some((file) => file.endsWith("nested.md"))).toBe(false); + } + }); + + it("dedupes overlapping extra paths that resolve to the same file", async () => { + const tmpDir = getTmpDir(); + await fs.writeFile(path.join(tmpDir, "MEMORY.md"), "# Default memory"); + const files = await listMemoryFiles(tmpDir, [tmpDir, ".", path.join(tmpDir, "MEMORY.md")]); + const memoryMatches = files.filter((file) => file.endsWith("MEMORY.md")); + expect(memoryMatches).toHaveLength(1); + }); + + it("includes image and audio files from extra paths when multimodal is enabled", async () => { + const tmpDir = getTmpDir(); + const extraDir = path.join(tmpDir, "media"); + await fs.mkdir(extraDir, { recursive: true }); + await fs.writeFile(path.join(extraDir, "diagram.png"), Buffer.from("png")); + await fs.writeFile(path.join(extraDir, "note.wav"), Buffer.from("wav")); + await fs.writeFile(path.join(extraDir, "ignore.bin"), Buffer.from("bin")); + + const files = await listMemoryFiles(tmpDir, [extraDir], multimodal); + expect(files.some((file) => file.endsWith("diagram.png"))).toBe(true); + 
expect(files.some((file) => file.endsWith("note.wav"))).toBe(true); + expect(files.some((file) => file.endsWith("ignore.bin"))).toBe(false); + }); +}); + +describe("buildFileEntry", () => { + const getTmpDir = setupTempDirLifecycle("memory-build-entry-"); + const multimodal: MemoryMultimodalSettings = { + enabled: true, + modalities: ["image", "audio"], + maxFileBytes: DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES, + }; + + it("returns null when the file disappears before reading", async () => { + const tmpDir = getTmpDir(); + const target = path.join(tmpDir, "ghost.md"); + await fs.writeFile(target, "ghost", "utf-8"); + await fs.rm(target); + const entry = await buildFileEntry(target, tmpDir); + expect(entry).toBeNull(); + }); + + it("returns metadata when the file exists", async () => { + const tmpDir = getTmpDir(); + const target = path.join(tmpDir, "note.md"); + await fs.writeFile(target, "hello", "utf-8"); + const entry = await buildFileEntry(target, tmpDir); + expect(entry).not.toBeNull(); + expect(entry?.path).toBe("note.md"); + expect(entry?.size).toBeGreaterThan(0); + }); + + it("returns multimodal metadata for eligible image files", async () => { + const tmpDir = getTmpDir(); + const target = path.join(tmpDir, "diagram.png"); + await fs.writeFile(target, Buffer.from("png")); + + const entry = await buildFileEntry(target, tmpDir, multimodal); + + expect(entry).toMatchObject({ + path: "diagram.png", + kind: "multimodal", + modality: "image", + mimeType: "image/png", + contentText: "Image file: diagram.png", + }); + }); + + it("builds a multimodal chunk lazily for indexing", async () => { + const tmpDir = getTmpDir(); + const target = path.join(tmpDir, "diagram.png"); + await fs.writeFile(target, Buffer.from("png")); + + const entry = await buildFileEntry(target, tmpDir, multimodal); + const built = await buildMultimodalChunkForIndexing(entry!); + + expect(built?.chunk.embeddingInput?.parts).toEqual([ + { type: "text", text: "Image file: diagram.png" }, + 
expect.objectContaining({ type: "inline-data", mimeType: "image/png" }), + ]); + expect(built?.structuredInputBytes).toBeGreaterThan(0); + }); + + it("skips lazy multimodal indexing when the file grows after discovery", async () => { + const tmpDir = getTmpDir(); + const target = path.join(tmpDir, "diagram.png"); + await fs.writeFile(target, Buffer.from("png")); + + const entry = await buildFileEntry(target, tmpDir, multimodal); + await fs.writeFile(target, Buffer.alloc(entry!.size + 32, 1)); + + await expect(buildMultimodalChunkForIndexing(entry!)).resolves.toBeNull(); + }); + + it("skips lazy multimodal indexing when file bytes change after discovery", async () => { + const tmpDir = getTmpDir(); + const target = path.join(tmpDir, "diagram.png"); + await fs.writeFile(target, Buffer.from("png")); + + const entry = await buildFileEntry(target, tmpDir, multimodal); + await fs.writeFile(target, Buffer.from("gif")); + + await expect(buildMultimodalChunkForIndexing(entry!)).resolves.toBeNull(); + }); +}); + +describe("chunkMarkdown", () => { + it("splits overly long lines into max-sized chunks", () => { + const chunkTokens = 400; + const maxChars = chunkTokens * 4; + const content = "a".repeat(maxChars * 3 + 25); + const chunks = chunkMarkdown(content, { tokens: chunkTokens, overlap: 0 }); + expect(chunks.length).toBeGreaterThan(1); + for (const chunk of chunks) { + expect(chunk.text.length).toBeLessThanOrEqual(maxChars); + } + }); + + it("produces more chunks for CJK text than for equal-length ASCII text", () => { + // CJK chars ≈ 1 token each; ASCII chars ≈ 0.25 tokens each. + // For the same raw character count, CJK content should produce more chunks + // because each character "weighs" ~4× more in token estimation. 
+ const chunkTokens = 50; + + // 400 ASCII chars → ~100 tokens → fits in ~2 chunks + const asciiLines = Array.from({ length: 20 }, () => "a".repeat(20)).join("\n"); + const asciiChunks = chunkMarkdown(asciiLines, { tokens: chunkTokens, overlap: 0 }); + + // 400 CJK chars → ~400 tokens → needs ~8 chunks + const cjkLines = Array.from({ length: 20 }, () => "你".repeat(20)).join("\n"); + const cjkChunks = chunkMarkdown(cjkLines, { tokens: chunkTokens, overlap: 0 }); + + expect(cjkChunks.length).toBeGreaterThan(asciiChunks.length); + }); + + it("respects token budget for Chinese text", () => { + // With tokens=100, each CJK char ≈ 1 token, so chunks should hold ~100 CJK chars. + const chunkTokens = 100; + const lines: string[] = []; + for (let i = 0; i < 50; i++) { + lines.push("这是一个测试句子用来验证分块逻辑是否正确处理中文文本内容"); + } + const content = lines.join("\n"); + const chunks = chunkMarkdown(content, { tokens: chunkTokens, overlap: 0 }); + + expect(chunks.length).toBeGreaterThan(1); + // Each chunk's CJK content should not vastly exceed the token budget. + // With CJK-aware estimation, each char ≈ 1 token, so chunk text length + // (in CJK chars) should be roughly <= tokens budget (with some tolerance + // for line boundaries). + for (const chunk of chunks) { + // Count actual CJK characters in the chunk + const cjkCount = (chunk.text.match(/[\u4e00-\u9fff]/g) ?? []).length; + // Allow 2× tolerance for line-boundary rounding + expect(cjkCount).toBeLessThanOrEqual(chunkTokens * 2); + } + }); + + it("keeps English chunking behavior unchanged", () => { + const chunkTokens = 100; + const maxChars = chunkTokens * 4; // 400 chars + const content = "hello world this is a test. 
".repeat(50); + const chunks = chunkMarkdown(content, { tokens: chunkTokens, overlap: 0 }); + expect(chunks.length).toBeGreaterThan(1); + for (const chunk of chunks) { + expect(chunk.text.length).toBeLessThanOrEqual(maxChars); + } + }); + + it("handles mixed CJK and ASCII content correctly", () => { + const chunkTokens = 50; + const lines: string[] = []; + for (let i = 0; i < 30; i++) { + lines.push(`Line ${i}: 这是中英文混合的测试内容 with some English text`); + } + const content = lines.join("\n"); + const chunks = chunkMarkdown(content, { tokens: chunkTokens, overlap: 0 }); + // Should produce multiple chunks and not crash + expect(chunks.length).toBeGreaterThan(1); + // Verify all content is preserved + const reconstructed = chunks.map((c) => c.text).join("\n"); + // Due to overlap=0, the concatenated chunks should cover all lines + expect(reconstructed).toContain("Line 0"); + expect(reconstructed).toContain("Line 29"); + }); + + it("splits very long CJK lines into budget-sized segments", () => { + // A single line of 2000 CJK characters (no newlines). + // With tokens=200, each CJK char ≈ 1 token. + const longCjkLine = "中".repeat(2000); + const chunks = chunkMarkdown(longCjkLine, { tokens: 200, overlap: 0 }); + expect(chunks.length).toBeGreaterThanOrEqual(8); + for (const chunk of chunks) { + const cjkCount = (chunk.text.match(/[\u4E00-\u9FFF]/g) ?? []).length; + expect(cjkCount).toBeLessThanOrEqual(200 * 2); + } + }); + it("does not break surrogate pairs when splitting long CJK lines", () => { + // "𠀀" (U+20000) is a surrogate pair: 2 UTF-16 code units per character. + // A line of 500 such characters = 1000 UTF-16 code units. + // With tokens=99 (odd), the fine-split must not cut inside a pair. 
+ const surrogateChar = "\u{20000}"; // 𠀀 + const longLine = surrogateChar.repeat(500); + const chunks = chunkMarkdown(longLine, { tokens: 99, overlap: 0 }); + for (const chunk of chunks) { + // No chunk should contain the Unicode replacement character U+FFFD, + // which would indicate a broken surrogate pair. + expect(chunk.text).not.toContain("\uFFFD"); + // Every character in the chunk should be a valid string (no lone surrogates). + for (let i = 0; i < chunk.text.length; i += 1) { + const code = chunk.text.charCodeAt(i); + if (code >= 0xd800 && code <= 0xdbff) { + // High surrogate must be followed by a low surrogate + const next = chunk.text.charCodeAt(i + 1); + expect(next).toBeGreaterThanOrEqual(0xdc00); + expect(next).toBeLessThanOrEqual(0xdfff); + } + } + } + }); + it("does not over-split long Latin lines (backward compat)", () => { + // 2000 ASCII chars / 800 maxChars -> about 3 segments, not 10 tiny ones. + const longLatinLine = "a".repeat(2000); + const chunks = chunkMarkdown(longLatinLine, { tokens: 200, overlap: 0 }); + expect(chunks.length).toBeLessThanOrEqual(5); + }); +}); + +describe("remapChunkLines", () => { + it("remaps chunk line numbers using a lineMap", () => { + // Simulate 5 content lines that came from JSONL lines [4, 6, 7, 10, 13] (1-indexed) + const lineMap = [4, 6, 7, 10, 13]; + + // Create chunks from content that has 5 lines + const content = "User: Hello\nAssistant: Hi\nUser: Question\nAssistant: Answer\nUser: Thanks"; + const chunks = chunkMarkdown(content, { tokens: 400, overlap: 0 }); + expect(chunks.length).toBeGreaterThan(0); + + // Before remapping, startLine/endLine reference content line numbers (1-indexed) + expect(chunks[0].startLine).toBe(1); + + // Remap + remapChunkLines(chunks, lineMap); + + // After remapping, line numbers should reference original JSONL lines + // Content line 1 → JSONL line 4, content line 5 → JSONL line 13 + expect(chunks[0].startLine).toBe(4); + const lastChunk = chunks[chunks.length - 1]; + 
expect(lastChunk.endLine).toBe(13); + }); + + it("preserves original line numbers when lineMap is undefined", () => { + const content = "Line one\nLine two\nLine three"; + const chunks = chunkMarkdown(content, { tokens: 400, overlap: 0 }); + const originalStart = chunks[0].startLine; + const originalEnd = chunks[chunks.length - 1].endLine; + + remapChunkLines(chunks, undefined); + + expect(chunks[0].startLine).toBe(originalStart); + expect(chunks[chunks.length - 1].endLine).toBe(originalEnd); + }); + + it("handles multi-chunk content with correct remapping", () => { + // Use small chunk size to force multiple chunks + // lineMap: 10 content lines from JSONL lines [2, 5, 8, 11, 14, 17, 20, 23, 26, 29] + const lineMap = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29]; + const contentLines = lineMap.map((_, i) => + i % 2 === 0 ? `User: Message ${i}` : `Assistant: Reply ${i}`, + ); + const content = contentLines.join("\n"); + + // Use very small chunk size to force splitting + const chunks = chunkMarkdown(content, { tokens: 10, overlap: 0 }); + expect(chunks.length).toBeGreaterThan(1); + + remapChunkLines(chunks, lineMap); + + // First chunk should start at JSONL line 2 + expect(chunks[0].startLine).toBe(2); + // Last chunk should end at JSONL line 29 + expect(chunks[chunks.length - 1].endLine).toBe(29); + + // Each chunk's startLine should be ≤ its endLine + for (const chunk of chunks) { + expect(chunk.startLine).toBeLessThanOrEqual(chunk.endLine); + } + }); +}); diff --git a/src/memory-host-sdk/host/internal.ts b/src/memory-host-sdk/host/internal.ts new file mode 100644 index 00000000000..557e1a49c31 --- /dev/null +++ b/src/memory-host-sdk/host/internal.ts @@ -0,0 +1,504 @@ +import crypto from "node:crypto"; +import fsSync from "node:fs"; +import fs from "node:fs/promises"; +import path from "node:path"; +import { detectMime } from "../../media/mime.js"; +import { CHARS_PER_TOKEN_ESTIMATE, estimateStringChars } from "../../utils/cjk-chars.js"; +import { runTasksWithConcurrency 
} from "../../utils/run-with-concurrency.js";
+import { estimateStructuredEmbeddingInputBytes } from "./embedding-input-limits.js";
+import { buildTextEmbeddingInput, type EmbeddingInput } from "./embedding-inputs.js";
+import { isFileMissingError } from "./fs-utils.js";
+import {
+  buildMemoryMultimodalLabel,
+  classifyMemoryMultimodalPath,
+  type MemoryMultimodalModality,
+  type MemoryMultimodalSettings,
+} from "./multimodal.js";
+
+export type MemoryFileEntry = {
+  path: string;
+  absPath: string;
+  mtimeMs: number;
+  size: number;
+  hash: string;
+  dataHash?: string;
+  kind?: "markdown" | "multimodal";
+  contentText?: string;
+  modality?: MemoryMultimodalModality;
+  mimeType?: string;
+};
+
+export type MemoryChunk = {
+  startLine: number;
+  endLine: number;
+  text: string;
+  hash: string;
+  embeddingInput?: EmbeddingInput;
+};
+
+export type MultimodalMemoryChunk = {
+  chunk: MemoryChunk;
+  structuredInputBytes: number;
+};
+
+const DISABLED_MULTIMODAL_SETTINGS: MemoryMultimodalSettings = {
+  enabled: false,
+  modalities: [],
+  maxFileBytes: 0,
+};
+
+export function ensureDir(dir: string): string {
+  try {
+    fsSync.mkdirSync(dir, { recursive: true });
+  } catch {}
+  return dir;
+}
+
+export function normalizeRelPath(value: string): string {
+  const unified = value.trim().replace(/\\/g, "/"); // normalize Windows separators first…
+  return unified.replace(/^[./]+/, ""); // …so leading ".\" and "\" prefixes are stripped too
+}
+
+export function normalizeExtraMemoryPaths(workspaceDir: string, extraPaths?: string[]): string[] {
+  if (!extraPaths?.length) {
+    return [];
+  }
+  const resolved = extraPaths
+    .map((value) => value.trim())
+    .filter(Boolean)
+    .map((value) =>
+      path.isAbsolute(value) ?
path.resolve(value) : path.resolve(workspaceDir, value), + ); + return Array.from(new Set(resolved)); +} + +export function isMemoryPath(relPath: string): boolean { + const normalized = normalizeRelPath(relPath); + if (!normalized) { + return false; + } + if (normalized === "MEMORY.md" || normalized === "memory.md") { + return true; + } + return normalized.startsWith("memory/"); +} + +function isAllowedMemoryFilePath(filePath: string, multimodal?: MemoryMultimodalSettings): boolean { + if (filePath.endsWith(".md")) { + return true; + } + return ( + classifyMemoryMultimodalPath(filePath, multimodal ?? DISABLED_MULTIMODAL_SETTINGS) !== null + ); +} + +async function walkDir(dir: string, files: string[], multimodal?: MemoryMultimodalSettings) { + const entries = await fs.readdir(dir, { withFileTypes: true }); + for (const entry of entries) { + const full = path.join(dir, entry.name); + if (entry.isSymbolicLink()) { + continue; + } + if (entry.isDirectory()) { + await walkDir(full, files, multimodal); + continue; + } + if (!entry.isFile()) { + continue; + } + if (!isAllowedMemoryFilePath(full, multimodal)) { + continue; + } + files.push(full); + } +} + +export async function listMemoryFiles( + workspaceDir: string, + extraPaths?: string[], + multimodal?: MemoryMultimodalSettings, +): Promise { + const result: string[] = []; + const memoryFile = path.join(workspaceDir, "MEMORY.md"); + const altMemoryFile = path.join(workspaceDir, "memory.md"); + const memoryDir = path.join(workspaceDir, "memory"); + + const addMarkdownFile = async (absPath: string) => { + try { + const stat = await fs.lstat(absPath); + if (stat.isSymbolicLink() || !stat.isFile()) { + return; + } + if (!absPath.endsWith(".md")) { + return; + } + result.push(absPath); + } catch {} + }; + + await addMarkdownFile(memoryFile); + await addMarkdownFile(altMemoryFile); + try { + const dirStat = await fs.lstat(memoryDir); + if (!dirStat.isSymbolicLink() && dirStat.isDirectory()) { + await walkDir(memoryDir, 
result); // NOTE(review): multimodal is not forwarded here — memory/ dir indexes only .md; confirm intended
+    }
+  } catch {}
+
+  const normalizedExtraPaths = normalizeExtraMemoryPaths(workspaceDir, extraPaths);
+  if (normalizedExtraPaths.length > 0) {
+    for (const inputPath of normalizedExtraPaths) {
+      try {
+        const stat = await fs.lstat(inputPath);
+        if (stat.isSymbolicLink()) {
+          continue;
+        }
+        if (stat.isDirectory()) {
+          await walkDir(inputPath, result, multimodal);
+          continue;
+        }
+        if (stat.isFile() && isAllowedMemoryFilePath(inputPath, multimodal)) {
+          result.push(inputPath);
+        }
+      } catch {}
+    }
+  }
+  if (result.length <= 1) {
+    return result;
+  }
+  const seen = new Set<string>();
+  const deduped: string[] = [];
+  for (const entry of result) {
+    let key = entry;
+    try {
+      key = await fs.realpath(entry);
+    } catch {}
+    if (seen.has(key)) {
+      continue;
+    }
+    seen.add(key);
+    deduped.push(entry);
+  }
+  return deduped;
+}
+
+export function hashText(value: string): string {
+  return crypto.createHash("sha256").update(value).digest("hex");
+}
+
+export async function buildFileEntry(
+  absPath: string,
+  workspaceDir: string,
+  multimodal?: MemoryMultimodalSettings,
+): Promise<MemoryFileEntry | null> {
+  let stat;
+  try {
+    stat = await fs.stat(absPath);
+  } catch (err) {
+    if (isFileMissingError(err)) {
+      return null;
+    }
+    throw err;
+  }
+  const normalizedPath = path.relative(workspaceDir, absPath).replace(/\\/g, "/");
+  const multimodalSettings = multimodal ?? DISABLED_MULTIMODAL_SETTINGS;
+  const modality = classifyMemoryMultimodalPath(absPath, multimodalSettings);
+  if (modality) {
+    if (stat.size > multimodalSettings.maxFileBytes) {
+      return null;
+    }
+    let buffer: Buffer;
+    try {
+      buffer = await fs.readFile(absPath);
+    } catch (err) {
+      if (isFileMissingError(err)) {
+        return null;
+      }
+      throw err;
+    }
+    const mimeType = await detectMime({ buffer: buffer.subarray(0, 512), filePath: absPath });
+    if (!mimeType || !mimeType.startsWith(`${modality}/`)) {
+      return null;
+    }
+    const contentText = buildMemoryMultimodalLabel(modality, normalizedPath);
+    const dataHash = crypto.createHash("sha256").update(buffer).digest("hex");
+    const chunkHash = hashText(
+      JSON.stringify({
+        path: normalizedPath,
+        contentText,
+        mimeType,
+        dataHash,
+      }),
+    );
+    return {
+      path: normalizedPath,
+      absPath,
+      mtimeMs: stat.mtimeMs,
+      size: stat.size,
+      hash: chunkHash,
+      dataHash,
+      kind: "multimodal",
+      contentText,
+      modality,
+      mimeType,
+    };
+  }
+  let content: string;
+  try {
+    content = await fs.readFile(absPath, "utf-8");
+  } catch (err) {
+    if (isFileMissingError(err)) {
+      return null;
+    }
+    throw err;
+  }
+  const hash = hashText(content);
+  return {
+    path: normalizedPath,
+    absPath,
+    mtimeMs: stat.mtimeMs,
+    size: stat.size,
+    hash,
+    kind: "markdown",
+  };
+}
+
+async function loadMultimodalEmbeddingInput(
+  entry: Pick<
+    MemoryFileEntry,
+    "absPath" | "contentText" | "mimeType" | "kind" | "size" | "dataHash"
+  >,
+): Promise<EmbeddingInput | null> {
+  if (entry.kind !== "multimodal" || !entry.contentText || !entry.mimeType) {
+    return null;
+  }
+  let stat;
+  try {
+    stat = await fs.stat(entry.absPath);
+  } catch (err) {
+    if (isFileMissingError(err)) {
+      return null;
+    }
+    throw err;
+  }
+  if (stat.size !== entry.size) {
+    return null;
+  }
+  let buffer: Buffer;
+  try {
+    buffer = await fs.readFile(entry.absPath);
+  } catch (err) {
+    if (isFileMissingError(err)) {
+      return null;
+    }
+    throw err;
+  }
+  const dataHash =
crypto.createHash("sha256").update(buffer).digest("hex"); + if (entry.dataHash && entry.dataHash !== dataHash) { + return null; + } + return { + text: entry.contentText, + parts: [ + { type: "text", text: entry.contentText }, + { + type: "inline-data", + mimeType: entry.mimeType, + data: buffer.toString("base64"), + }, + ], + }; +} + +export async function buildMultimodalChunkForIndexing( + entry: Pick< + MemoryFileEntry, + "absPath" | "contentText" | "mimeType" | "kind" | "hash" | "size" | "dataHash" + >, +): Promise { + const embeddingInput = await loadMultimodalEmbeddingInput(entry); + if (!embeddingInput) { + return null; + } + return { + chunk: { + startLine: 1, + endLine: 1, + text: entry.contentText ?? embeddingInput.text, + hash: entry.hash, + embeddingInput, + }, + structuredInputBytes: estimateStructuredEmbeddingInputBytes(embeddingInput), + }; +} + +export function chunkMarkdown( + content: string, + chunking: { tokens: number; overlap: number }, +): MemoryChunk[] { + const lines = content.split("\n"); + if (lines.length === 0) { + return []; + } + const maxChars = Math.max(32, chunking.tokens * CHARS_PER_TOKEN_ESTIMATE); + const overlapChars = Math.max(0, chunking.overlap * CHARS_PER_TOKEN_ESTIMATE); + const chunks: MemoryChunk[] = []; + + let current: Array<{ line: string; lineNo: number }> = []; + let currentChars = 0; + + const flush = () => { + if (current.length === 0) { + return; + } + const firstEntry = current[0]; + const lastEntry = current[current.length - 1]; + if (!firstEntry || !lastEntry) { + return; + } + const text = current.map((entry) => entry.line).join("\n"); + const startLine = firstEntry.lineNo; + const endLine = lastEntry.lineNo; + chunks.push({ + startLine, + endLine, + text, + hash: hashText(text), + embeddingInput: buildTextEmbeddingInput(text), + }); + }; + + const carryOverlap = () => { + if (overlapChars <= 0 || current.length === 0) { + current = []; + currentChars = 0; + return; + } + let acc = 0; + const kept: Array<{ 
line: string; lineNo: number }> = []; + for (let i = current.length - 1; i >= 0; i -= 1) { + const entry = current[i]; + if (!entry) { + continue; + } + acc += estimateStringChars(entry.line) + 1; + kept.unshift(entry); + if (acc >= overlapChars) { + break; + } + } + current = kept; + currentChars = kept.reduce((sum, entry) => sum + estimateStringChars(entry.line) + 1, 0); + }; + + for (let i = 0; i < lines.length; i += 1) { + const line = lines[i] ?? ""; + const lineNo = i + 1; + const segments: string[] = []; + if (line.length === 0) { + segments.push(""); + } else { + // First pass: slice at maxChars (preserves original behaviour for Latin). + // Second pass: if a segment's *weighted* size still exceeds the budget + // (happens for CJK-heavy text where 1 char ≈ 1 token), re-split it at + // chunking.tokens so the chunk stays within the token budget. + for (let start = 0; start < line.length; start += maxChars) { + const coarse = line.slice(start, start + maxChars); + if (estimateStringChars(coarse) > maxChars) { + const fineStep = Math.max(1, chunking.tokens); + for (let j = 0; j < coarse.length; ) { + let end = Math.min(j + fineStep, coarse.length); + // Avoid splitting inside a UTF-16 surrogate pair (CJK Extension B+). + if (end < coarse.length) { + const code = coarse.charCodeAt(end - 1); + if (code >= 0xd800 && code <= 0xdbff) { + end += 1; // include the low surrogate + } + } + segments.push(coarse.slice(j, end)); + j = end; // advance cursor to the adjusted boundary + } + } else { + segments.push(coarse); + } + } + } + for (const segment of segments) { + const lineSize = estimateStringChars(segment) + 1; + if (currentChars + lineSize > maxChars && current.length > 0) { + flush(); + carryOverlap(); + } + current.push({ line: segment, lineNo }); + currentChars += lineSize; + } + } + flush(); + return chunks; +} + +/** + * Remap chunk startLine/endLine from content-relative positions to original + * source file positions using a lineMap. 
Each entry in lineMap gives the + * 1-indexed source line for the corresponding 0-indexed content line. + * + * This is used for session JSONL files where buildSessionEntry() flattens + * messages into a plain-text string before chunking. Without remapping the + * stored line numbers would reference positions in the flattened text rather + * than the original JSONL file. + */ +export function remapChunkLines(chunks: MemoryChunk[], lineMap: number[] | undefined): void { + if (!lineMap || lineMap.length === 0) { + return; + } + for (const chunk of chunks) { + // startLine/endLine are 1-indexed; lineMap is 0-indexed by content line + chunk.startLine = lineMap[chunk.startLine - 1] ?? chunk.startLine; + chunk.endLine = lineMap[chunk.endLine - 1] ?? chunk.endLine; + } +} + +export function parseEmbedding(raw: string): number[] { + try { + const parsed = JSON.parse(raw) as number[]; + return Array.isArray(parsed) ? parsed : []; + } catch { + return []; + } +} + +export function cosineSimilarity(a: number[], b: number[]): number { + if (a.length === 0 || b.length === 0) { + return 0; + } + const len = Math.min(a.length, b.length); + let dot = 0; + let normA = 0; + let normB = 0; + for (let i = 0; i < len; i += 1) { + const av = a[i] ?? 0; + const bv = b[i] ?? 
0; + dot += av * bv; + normA += av * av; + normB += bv * bv; + } + if (normA === 0 || normB === 0) { + return 0; + } + return dot / (Math.sqrt(normA) * Math.sqrt(normB)); +} + +export async function runWithConcurrency( + tasks: Array<() => Promise>, + limit: number, +): Promise { + const { results, firstError, hasError } = await runTasksWithConcurrency({ + tasks, + limit, + errorMode: "stop", + }); + if (hasError) { + throw firstError; + } + return results; +} diff --git a/src/memory-host-sdk/host/memory-schema.ts b/src/memory-host-sdk/host/memory-schema.ts new file mode 100644 index 00000000000..582cddeee1c --- /dev/null +++ b/src/memory-host-sdk/host/memory-schema.ts @@ -0,0 +1,102 @@ +import type { DatabaseSync } from "node:sqlite"; + +export function ensureMemoryIndexSchema(params: { + db: DatabaseSync; + embeddingCacheTable: string; + cacheEnabled: boolean; + ftsTable: string; + ftsEnabled: boolean; + ftsTokenizer?: "unicode61" | "trigram"; +}): { ftsAvailable: boolean; ftsError?: string } { + params.db.exec(` + CREATE TABLE IF NOT EXISTS meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + `); + params.db.exec(` + CREATE TABLE IF NOT EXISTS files ( + path TEXT PRIMARY KEY, + source TEXT NOT NULL DEFAULT 'memory', + hash TEXT NOT NULL, + mtime INTEGER NOT NULL, + size INTEGER NOT NULL + ); + `); + params.db.exec(` + CREATE TABLE IF NOT EXISTS chunks ( + id TEXT PRIMARY KEY, + path TEXT NOT NULL, + source TEXT NOT NULL DEFAULT 'memory', + start_line INTEGER NOT NULL, + end_line INTEGER NOT NULL, + hash TEXT NOT NULL, + model TEXT NOT NULL, + text TEXT NOT NULL, + embedding TEXT NOT NULL, + updated_at INTEGER NOT NULL + ); + `); + if (params.cacheEnabled) { + params.db.exec(` + CREATE TABLE IF NOT EXISTS ${params.embeddingCacheTable} ( + provider TEXT NOT NULL, + model TEXT NOT NULL, + provider_key TEXT NOT NULL, + hash TEXT NOT NULL, + embedding TEXT NOT NULL, + dims INTEGER, + updated_at INTEGER NOT NULL, + PRIMARY KEY (provider, model, provider_key, 
hash)
+      );
+    `);
+    params.db.exec(
+      `CREATE INDEX IF NOT EXISTS idx_embedding_cache_updated_at ON ${params.embeddingCacheTable}(updated_at);`,
+    );
+  }
+
+  let ftsAvailable = false;
+  let ftsError: string | undefined;
+  if (params.ftsEnabled) {
+    try {
+      const tokenizer = params.ftsTokenizer ?? "unicode61";
+      const tokenizeClause = tokenizer === "trigram" ? `, tokenize='trigram case_sensitive 0'` : "";
+      params.db.exec(
+        `CREATE VIRTUAL TABLE IF NOT EXISTS ${params.ftsTable} USING fts5(\n` +
+          `  text,\n` +
+          `  id UNINDEXED,\n` +
+          `  path UNINDEXED,\n` +
+          `  source UNINDEXED,\n` +
+          `  model UNINDEXED,\n` +
+          `  start_line UNINDEXED,\n` +
+          `  end_line UNINDEXED\n` +
+          `${tokenizeClause});`,
+      );
+      ftsAvailable = true;
+    } catch (err) {
+      const message = err instanceof Error ? err.message : String(err);
+      ftsAvailable = false;
+      ftsError = message;
+    }
+  }
+
+  ensureColumn(params.db, "files", "source", "TEXT NOT NULL DEFAULT 'memory'");
+  ensureColumn(params.db, "chunks", "source", "TEXT NOT NULL DEFAULT 'memory'");
+  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_path ON chunks(path);`);
+  params.db.exec(`CREATE INDEX IF NOT EXISTS idx_chunks_source ON chunks(source);`);
+
+  return { ftsAvailable, ...(ftsError ? { ftsError } : {}) };
+}
+
+function ensureColumn(
+  db: DatabaseSync,
+  table: "files" | "chunks",
+  column: string,
+  definition: string,
+): void {
+  const rows = db.prepare(`PRAGMA table_info(${table})`).all() as Array<{ name: string }>;
+  if (rows.some((row) => row.name === column)) {
+    return;
+  }
+  db.exec(`ALTER TABLE ${table} ADD COLUMN ${column} ${definition}`);
+}
diff --git a/src/memory-host-sdk/host/multimodal.ts b/src/memory-host-sdk/host/multimodal.ts
new file mode 100644
index 00000000000..df72ed8c495
--- /dev/null
+++ b/src/memory-host-sdk/host/multimodal.ts
@@ -0,0 +1,118 @@
+const MEMORY_MULTIMODAL_SPECS = {
+  image: {
+    labelPrefix: "Image file",
+    extensions: [".jpg", ".jpeg", ".png", ".webp", ".gif", ".heic", ".heif"],
+  },
+  audio: {
+    labelPrefix: "Audio file",
+    extensions: [".mp3", ".wav", ".ogg", ".opus", ".m4a", ".aac", ".flac"],
+  },
+} as const;
+
+export type MemoryMultimodalModality = keyof typeof MEMORY_MULTIMODAL_SPECS;
+export const MEMORY_MULTIMODAL_MODALITIES = Object.keys(
+  MEMORY_MULTIMODAL_SPECS,
+) as MemoryMultimodalModality[];
+export type MemoryMultimodalSelection = MemoryMultimodalModality | "all";
+
+export type MemoryMultimodalSettings = {
+  enabled: boolean;
+  modalities: MemoryMultimodalModality[];
+  maxFileBytes: number;
+};
+
+export const DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES = 10 * 1024 * 1024;
+
+export function normalizeMemoryMultimodalModalities(
+  raw: MemoryMultimodalSelection[] | undefined,
+): MemoryMultimodalModality[] {
+  if (raw === undefined || raw.includes("all")) {
+    return [...MEMORY_MULTIMODAL_MODALITIES];
+  }
+  const normalized = new Set<MemoryMultimodalModality>();
+  for (const value of raw) {
+    if (value === "image" || value === "audio") {
+      normalized.add(value);
+    }
+  }
+  return Array.from(normalized);
+}
+
+export function normalizeMemoryMultimodalSettings(raw: {
+  enabled?: boolean;
+  modalities?: MemoryMultimodalSelection[];
+  maxFileBytes?: number;
+}): MemoryMultimodalSettings {
+  const enabled = raw.enabled
=== true;
+  const maxFileBytes =
+    typeof raw.maxFileBytes === "number" && Number.isFinite(raw.maxFileBytes)
+      ? Math.max(1, Math.floor(raw.maxFileBytes))
+      : DEFAULT_MEMORY_MULTIMODAL_MAX_FILE_BYTES;
+  return {
+    enabled,
+    modalities: enabled ? normalizeMemoryMultimodalModalities(raw.modalities) : [],
+    maxFileBytes,
+  };
+}
+
+export function isMemoryMultimodalEnabled(settings: MemoryMultimodalSettings): boolean {
+  return settings.enabled && settings.modalities.length > 0;
+}
+
+export function getMemoryMultimodalExtensions(
+  modality: MemoryMultimodalModality,
+): readonly string[] {
+  return MEMORY_MULTIMODAL_SPECS[modality].extensions;
+}
+
+export function buildMemoryMultimodalLabel(
+  modality: MemoryMultimodalModality,
+  normalizedPath: string,
+): string {
+  return `${MEMORY_MULTIMODAL_SPECS[modality].labelPrefix}: ${normalizedPath}`;
+}
+
+export function buildCaseInsensitiveExtensionGlob(extension: string): string {
+  const normalized = extension.trim().replace(/^\./, "").toLowerCase();
+  if (!normalized) {
+    return "*";
+  }
+  const parts = Array.from(normalized, (char) => `[${char.toLowerCase()}${char.toUpperCase()}]`);
+  return `*.${parts.join("")}`;
+}
+
+export function classifyMemoryMultimodalPath(
+  filePath: string,
+  settings: MemoryMultimodalSettings,
+): MemoryMultimodalModality | null {
+  if (!isMemoryMultimodalEnabled(settings)) {
+    return null;
+  }
+  const lower = filePath.trim().toLowerCase();
+  for (const modality of settings.modalities) {
+    for (const extension of getMemoryMultimodalExtensions(modality)) {
+      if (lower.endsWith(extension)) {
+        return modality;
+      }
+    }
+  }
+  return null;
+}
+
+export function normalizeGeminiEmbeddingModelForMemory(model: string): string {
+  const trimmed = model.trim();
+  if (!trimmed) {
+    return "";
+  }
+  return trimmed.replace(/^models\//, "").replace(/^(gemini|google)\//, "");
+}
+
+export function supportsMemoryMultimodalEmbeddings(params: {
+  provider: string;
+  model: string;
+}): boolean {
+  if (params.provider !== "gemini") {
+    return false;
+  }
+  return normalizeGeminiEmbeddingModelForMemory(params.model) === "gemini-embedding-2-preview";
+}
diff --git a/src/memory-host-sdk/host/node-llama.ts b/src/memory-host-sdk/host/node-llama.ts
new file mode 100644
index 00000000000..9327a1c4503
--- /dev/null
+++ b/src/memory-host-sdk/host/node-llama.ts
@@ -0,0 +1,3 @@
+export async function importNodeLlamaCpp() {
+  return import("node-llama-cpp");
+}
diff --git a/src/memory-host-sdk/host/post-json.test.ts b/src/memory-host-sdk/host/post-json.test.ts
new file mode 100644
index 00000000000..d09fa694782
--- /dev/null
+++ b/src/memory-host-sdk/host/post-json.test.ts
@@ -0,0 +1,60 @@
+import { beforeAll, beforeEach, describe, expect, it, vi } from "vitest";
+
+vi.mock("./remote-http.js", () => ({
+  withRemoteHttpResponse: vi.fn(),
+}));
+
+let postJson: typeof import("./post-json.js").postJson;
+let withRemoteHttpResponse: typeof import("./remote-http.js").withRemoteHttpResponse;
+
+describe("postJson", () => {
+  let remoteHttpMock: ReturnType<typeof vi.mocked<typeof withRemoteHttpResponse>>;
+
+  beforeAll(async () => {
+    ({ postJson } = await import("./post-json.js"));
+    ({ withRemoteHttpResponse } = await import("./remote-http.js"));
+    remoteHttpMock = vi.mocked(withRemoteHttpResponse);
+  });
+
+  beforeEach(() => {
+    vi.clearAllMocks();
+  });
+
+  it("parses JSON payload on successful response", async () => {
+    remoteHttpMock.mockImplementationOnce(async (params) => {
+      return await params.onResponse(
+        new Response(JSON.stringify({ data: [{ embedding: [1, 2] }] }), { status: 200 }),
+      );
+    });
+
+    const result = await postJson({
+      url: "https://memory.example/v1/post",
+      headers: { Authorization: "Bearer test" },
+      body: { input: ["x"] },
+      errorPrefix: "post failed",
+      parse: (payload) => payload,
+    });
+
+    expect(result).toEqual({ data: [{ embedding: [1, 2] }] });
+  });
+
+  it("attaches status to thrown error when requested", async () => {
+    remoteHttpMock.mockImplementationOnce(async (params) => {
+      return
await params.onResponse(new Response("bad gateway", { status: 502 })); + }); + + await expect( + postJson({ + url: "https://memory.example/v1/post", + headers: {}, + body: {}, + errorPrefix: "post failed", + attachStatus: true, + parse: () => ({}), + }), + ).rejects.toMatchObject({ + message: expect.stringContaining("post failed: 502 bad gateway"), + status: 502, + }); + }); +}); diff --git a/src/memory-host-sdk/host/post-json.ts b/src/memory-host-sdk/host/post-json.ts new file mode 100644 index 00000000000..8eaee669cac --- /dev/null +++ b/src/memory-host-sdk/host/post-json.ts @@ -0,0 +1,35 @@ +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; +import { withRemoteHttpResponse } from "./remote-http.js"; + +export async function postJson(params: { + url: string; + headers: Record; + ssrfPolicy?: SsrFPolicy; + body: unknown; + errorPrefix: string; + attachStatus?: boolean; + parse: (payload: unknown) => T | Promise; +}): Promise { + return await withRemoteHttpResponse({ + url: params.url, + ssrfPolicy: params.ssrfPolicy, + init: { + method: "POST", + headers: params.headers, + body: JSON.stringify(params.body), + }, + onResponse: async (res) => { + if (!res.ok) { + const text = await res.text(); + const err = new Error(`${params.errorPrefix}: ${res.status} ${text}`) as Error & { + status?: number; + }; + if (params.attachStatus) { + err.status = res.status; + } + throw err; + } + return await params.parse(await res.json()); + }, + }); +} diff --git a/src/memory-host-sdk/host/qmd-process.test.ts b/src/memory-host-sdk/host/qmd-process.test.ts new file mode 100644 index 00000000000..f706f525af3 --- /dev/null +++ b/src/memory-host-sdk/host/qmd-process.test.ts @@ -0,0 +1,154 @@ +import { EventEmitter } from "node:events"; +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +const spawnMock = vi.hoisted(() => vi.fn()); + 
+vi.mock("node:child_process", async () => {
+  const actual = await vi.importActual("node:child_process");
+  return {
+    ...actual,
+    spawn: spawnMock,
+  };
+});
+
+import { checkQmdBinaryAvailability, resolveCliSpawnInvocation } from "./qmd-process.js";
+
+function createMockChild() {
+  const child = new EventEmitter() as EventEmitter & {
+    kill: ReturnType<typeof vi.fn>;
+  };
+  child.kill = vi.fn();
+  return child;
+}
+
+let tempDir = "";
+let platformSpy: { mockRestore(): void } | null = null;
+const originalPath = process.env.PATH;
+const originalPathExt = process.env.PATHEXT;
+
+beforeEach(async () => {
+  tempDir = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-qmd-win-spawn-"));
+  platformSpy = vi.spyOn(process, "platform", "get").mockReturnValue("win32");
+});
+
+afterEach(async () => {
+  platformSpy?.mockRestore();
+  process.env.PATH = originalPath;
+  process.env.PATHEXT = originalPathExt;
+  spawnMock.mockReset();
+  if (tempDir) {
+    await fs.rm(tempDir, { recursive: true, force: true });
+    tempDir = "";
+  }
+});
+
+describe("resolveCliSpawnInvocation", () => {
+  it("unwraps npm cmd shims to a direct node entrypoint", async () => {
+    const binDir = path.join(tempDir, "node_modules", ".bin");
+    const packageDir = path.join(tempDir, "node_modules", "qmd");
+    const scriptPath = path.join(packageDir, "dist", "cli.js");
+    await fs.mkdir(path.dirname(scriptPath), { recursive: true });
+    await fs.mkdir(binDir, { recursive: true });
+    await fs.writeFile(path.join(binDir, "qmd.cmd"), "@echo off\r\n", "utf8");
+    await fs.writeFile(
+      path.join(packageDir, "package.json"),
+      JSON.stringify({ name: "qmd", version: "0.0.0", bin: { qmd: "dist/cli.js" } }),
+      "utf8",
+    );
+    await fs.writeFile(scriptPath, "module.exports = {};\n", "utf8");
+
+    process.env.PATH = `${binDir};${originalPath ?? ""}`;
+    process.env.PATHEXT = ".CMD;.EXE";
+
+    const invocation = resolveCliSpawnInvocation({
+      command: "qmd",
+      args: ["query", "hello"],
+      env: process.env,
+      packageName: "qmd",
+    });
+
+    expect(invocation.command).toBe(process.execPath);
+    expect(invocation.argv).toEqual([scriptPath, "query", "hello"]);
+    expect(invocation.shell).not.toBe(true);
+    expect(invocation.windowsHide).toBe(true);
+  });
+
+  it("fails closed when a Windows cmd shim cannot be resolved without shell execution", async () => {
+    const binDir = path.join(tempDir, "bad-bin");
+    await fs.mkdir(binDir, { recursive: true });
+    await fs.writeFile(path.join(binDir, "qmd.cmd"), "@echo off\r\nREM no entrypoint\r\n", "utf8");
+
+    process.env.PATH = `${binDir};${originalPath ?? ""}`;
+    process.env.PATHEXT = ".CMD;.EXE";
+
+    expect(() =>
+      resolveCliSpawnInvocation({
+        command: "qmd",
+        args: ["query", "hello"],
+        env: process.env,
+        packageName: "qmd",
+      }),
+    ).toThrow(/without shell execution/);
+  });
+
+  it("keeps bare commands bare when no Windows wrapper exists on PATH", () => {
+    process.env.PATH = originalPath ?? "";
+    process.env.PATHEXT = ".CMD;.EXE";
+
+    const invocation = resolveCliSpawnInvocation({
+      command: "qmd",
+      args: ["query", "hello"],
+      env: process.env,
+      packageName: "qmd",
+    });
+
+    expect(invocation.command).toBe("qmd");
+    expect(invocation.argv).toEqual(["query", "hello"]);
+    expect(invocation.shell).not.toBe(true);
+  });
+});
+
+describe("checkQmdBinaryAvailability", () => {
+  it("returns available when the qmd process spawns successfully", async () => {
+    const child = createMockChild();
+    spawnMock.mockImplementationOnce(() => {
+      queueMicrotask(() => child.emit("spawn"));
+      return child;
+    });
+
+    await expect(
+      checkQmdBinaryAvailability({ command: "qmd", env: process.env, cwd: tempDir }),
+    ).resolves.toEqual({ available: true });
+    expect(child.kill).toHaveBeenCalled();
+  });
+
+  it("returns unavailable when the qmd process cannot be spawned", async () => {
+    const child = createMockChild();
+    const err = Object.assign(new Error("spawn qmd ENOENT"), { code: "ENOENT" });
+    spawnMock.mockImplementationOnce(() => {
+      queueMicrotask(() => child.emit("error", err));
+      return child;
+    });
+
+    await expect(
+      checkQmdBinaryAvailability({ command: "qmd", env: process.env, cwd: tempDir }),
+    ).resolves.toEqual({ available: false, error: "spawn qmd ENOENT" });
+  });
+
+  it("does not treat close-before-spawn as a successful availability probe", async () => {
+    const child = createMockChild();
+    const err = Object.assign(new Error("spawn qmd ENOENT"), { code: "ENOENT" });
+    spawnMock.mockImplementationOnce(() => {
+      queueMicrotask(() => child.emit("close"));
+      queueMicrotask(() => child.emit("error", err));
+      return child;
+    });
+
+    await expect(
+      checkQmdBinaryAvailability({ command: "qmd", env: process.env, cwd: tempDir }),
+    ).resolves.toEqual({ available: false, error: "spawn qmd ENOENT" });
+  });
+});
diff --git a/src/memory-host-sdk/host/qmd-process.ts b/src/memory-host-sdk/host/qmd-process.ts
new file mode 100644
index 00000000000..eeccae9654e
---
import { spawn } from "node:child_process";
import {
  materializeWindowsSpawnProgram,
  resolveWindowsSpawnProgram,
} from "../../plugin-sdk/windows-spawn.js";

/** Concrete program + argv (and spawn flags) used to launch a CLI child process. */
export type CliSpawnInvocation = {
  command: string;
  argv: string[];
  shell?: boolean;
  windowsHide?: boolean;
};

/** Outcome of probing whether the qmd binary can actually be executed. */
export type QmdBinaryAvailability = {
  available: boolean;
  error?: string;
};

/**
 * Resolve how a CLI command must be spawned on the current platform.
 *
 * Windows command shims (e.g. `.cmd` wrappers) are resolved to a direct
 * executable invocation; shell fallback is disabled via
 * `allowShellFallback: false`, so an unresolvable shim fails closed instead of
 * being executed through a shell.
 */
export function resolveCliSpawnInvocation(params: {
  command: string;
  args: string[];
  env: NodeJS.ProcessEnv;
  packageName: string;
}): CliSpawnInvocation {
  const resolved = resolveWindowsSpawnProgram({
    command: params.command,
    platform: process.platform,
    env: params.env,
    execPath: process.execPath,
    packageName: params.packageName,
    allowShellFallback: false,
  });
  return materializeWindowsSpawnProgram(resolved, params.args);
}

/**
 * Probe whether the qmd binary can be spawned at all.
 *
 * Resolves `{ available: true }` once the child emits "spawn" (the probe
 * process is killed immediately afterwards) and `{ available: false, error }`
 * on resolution failure, spawn error, or timeout (default 2s). A "close" that
 * arrives before "spawn" is deliberately not treated as success so that
 * ENOENT-style failures surface via the "error" event.
 */
export async function checkQmdBinaryAvailability(params: {
  command: string;
  env: NodeJS.ProcessEnv;
  cwd?: string;
  timeoutMs?: number;
}): Promise<QmdBinaryAvailability> {
  let invocation: CliSpawnInvocation;
  try {
    invocation = resolveCliSpawnInvocation({
      command: params.command,
      args: [],
      env: params.env,
      packageName: "qmd",
    });
  } catch (err) {
    return { available: false, error: formatQmdAvailabilityError(err) };
  }

  const timeoutMs = params.timeoutMs ?? 2_000;
  return await new Promise((resolve) => {
    let done = false;
    let sawSpawn = false;

    const settle = (result: QmdBinaryAvailability) => {
      if (done) {
        return;
      }
      done = true;
      if (timer) {
        clearTimeout(timer);
      }
      resolve(result);
    };

    const child = spawn(invocation.command, invocation.argv, {
      env: params.env,
      cwd: params.cwd ?? process.cwd(),
      shell: invocation.shell,
      windowsHide: invocation.windowsHide,
      stdio: "ignore",
    });
    const timer = setTimeout(() => {
      child.kill("SIGKILL");
      settle({
        available: false,
        error: `spawn ${params.command} timed out after ${timeoutMs}ms`,
      });
    }, timeoutMs);

    child.once("error", (err) => {
      settle({ available: false, error: formatQmdAvailabilityError(err) });
    });
    child.once("spawn", () => {
      sawSpawn = true;
      child.kill();
      settle({ available: true });
    });
    child.once("close", () => {
      // Only treat "close" as success after a confirmed "spawn"; otherwise
      // keep waiting for the "error" event (close-before-spawn case).
      if (sawSpawn) {
        settle({ available: true });
      }
    });
  });
}

/**
 * Run a resolved CLI invocation to completion, capturing stdout/stderr.
 *
 * Output on each stream is capped at `maxOutputChars`, keeping the newest
 * output; exceeding the cap rejects after the process exits unless stdout is
 * being discarded. Rejects on non-zero exit, spawn error, or timeout
 * (a falsy `timeoutMs` disables the timer entirely).
 */
export async function runCliCommand(params: {
  commandSummary: string;
  spawnInvocation: CliSpawnInvocation;
  env: NodeJS.ProcessEnv;
  cwd: string;
  timeoutMs?: number;
  maxOutputChars: number;
  discardStdout?: boolean;
}): Promise<{ stdout: string; stderr: string }> {
  const { spawnInvocation, commandSummary, maxOutputChars } = params;
  const discardStdout = params.discardStdout === true;

  return await new Promise((resolve, reject) => {
    const child = spawn(spawnInvocation.command, spawnInvocation.argv, {
      env: params.env,
      cwd: params.cwd,
      shell: spawnInvocation.shell,
      windowsHide: spawnInvocation.windowsHide,
    });

    let outBuf = "";
    let errBuf = "";
    let outTruncated = false;
    let errTruncated = false;

    const timer = params.timeoutMs
      ? setTimeout(() => {
          child.kill("SIGKILL");
          reject(new Error(`${commandSummary} timed out after ${params.timeoutMs}ms`));
        }, params.timeoutMs)
      : null;

    child.stdout.on("data", (data) => {
      if (discardStdout) {
        return;
      }
      const capped = appendOutputWithCap(outBuf, data.toString("utf8"), maxOutputChars);
      outBuf = capped.text;
      outTruncated = outTruncated || capped.truncated;
    });
    child.stderr.on("data", (data) => {
      const capped = appendOutputWithCap(errBuf, data.toString("utf8"), maxOutputChars);
      errBuf = capped.text;
      errTruncated = errTruncated || capped.truncated;
    });

    child.on("error", (err) => {
      if (timer) {
        clearTimeout(timer);
      }
      reject(err);
    });
    child.on("close", (code) => {
      if (timer) {
        clearTimeout(timer);
      }
      if (!discardStdout && (outTruncated || errTruncated)) {
        reject(
          new Error(
            `${commandSummary} produced too much output (limit ${maxOutputChars} chars)`,
          ),
        );
        return;
      }
      if (code === 0) {
        resolve({ stdout: outBuf, stderr: errBuf });
      } else {
        reject(new Error(`${commandSummary} failed (code ${code}): ${errBuf || outBuf}`));
      }
    });
  });
}

/** Append a chunk to a capped buffer, keeping only the most recent `maxChars`. */
function appendOutputWithCap(
  existing: string,
  chunk: string,
  maxChars: number,
): { text: string; truncated: boolean } {
  const combined = existing + chunk;
  if (combined.length > maxChars) {
    return { text: combined.slice(-maxChars), truncated: true };
  }
  return { text: combined, truncated: false };
}

/** Best-effort human-readable message for an availability-probe failure. */
function formatQmdAvailabilityError(err: unknown): string {
  return err instanceof Error && err.message ? err.message : String(err);
}

import { describe, expect, it } from "vitest";
import { parseQmdQueryJson } from
"./qmd-query-parser.js";

// Covers the qmd query JSON parser: clean output, noisy-stdout recovery,
// no-results markers on stdout/stderr, and invalid-JSON error reporting.
describe("parseQmdQueryJson", () => {
  it("parses clean qmd JSON output", () => {
    const results = parseQmdQueryJson('[{"docid":"abc","score":1,"snippet":"@@ -1,1\\none"}]', "");
    expect(results).toEqual([
      {
        docid: "abc",
        score: 1,
        snippet: "@@ -1,1\none",
      },
    ]);
  });

  // qmd may interleave progress lines and unrelated JSON objects with the
  // result array; the parser must pick out the first JSON array.
  it("extracts embedded result arrays from noisy stdout", () => {
    const results = parseQmdQueryJson(
      `initializing
{"payload":"ok"}
[{"docid":"abc","score":0.5}]
complete`,
      "",
    );
    expect(results).toEqual([{ docid: "abc", score: 0.5 }]);
  });

  it("preserves explicit qmd line metadata when present", () => {
    const results = parseQmdQueryJson(
      '[{"docid":"abc","score":0.5,"start_line":4,"end_line":6,"snippet":"@@ -10,1\\nignored"}]',
      "",
    );
    expect(results).toEqual([
      {
        docid: "abc",
        score: 0.5,
        snippet: "@@ -10,1\nignored",
        startLine: 4,
        endLine: 6,
      },
    ]);
  });

  it("treats plain-text no-results from stderr as an empty result set", () => {
    const results = parseQmdQueryJson("", "No results found\n");
    expect(results).toEqual([]);
  });

  it("treats prefixed no-results marker output as an empty result set", () => {
    expect(parseQmdQueryJson("warning: no results found", "")).toEqual([]);
    expect(parseQmdQueryJson("", "[qmd] warning: no results found\n")).toEqual([]);
  });

  // The marker must match the whole line, not just appear as a substring.
  it("does not treat arbitrary non-marker text as no-results output", () => {
    expect(() =>
      parseQmdQueryJson("warning: search completed; no results found for this query", ""),
    ).toThrow(/qmd query returned invalid JSON/i);
  });

  it("throws when stdout cannot be interpreted as qmd JSON", () => {
    expect(() => parseQmdQueryJson("this is not json", "")).toThrow(
      /qmd query returned invalid JSON/i,
    );
  });
});
@@ -0,0 +1,151 @@ +import { createSubsystemLogger } from "../../logging/subsystem.js"; + +const log = createSubsystemLogger("memory"); + +export type QmdQueryResult = { + docid?: string; + score?: number; + collection?: string; + file?: string; + snippet?: string; + body?: string; + startLine?: number; + endLine?: number; +}; + +export function parseQmdQueryJson(stdout: string, stderr: string): QmdQueryResult[] { + const trimmedStdout = stdout.trim(); + const trimmedStderr = stderr.trim(); + const stdoutIsMarker = trimmedStdout.length > 0 && isQmdNoResultsOutput(trimmedStdout); + const stderrIsMarker = trimmedStderr.length > 0 && isQmdNoResultsOutput(trimmedStderr); + if (stdoutIsMarker || (!trimmedStdout && stderrIsMarker)) { + return []; + } + if (!trimmedStdout) { + const context = trimmedStderr ? ` (stderr: ${summarizeQmdStderr(trimmedStderr)})` : ""; + const message = `stdout empty${context}`; + log.warn(`qmd query returned invalid JSON: ${message}`); + throw new Error(`qmd query returned invalid JSON: ${message}`); + } + try { + const parsed = parseQmdQueryResultArray(trimmedStdout); + if (parsed !== null) { + return parsed; + } + const noisyPayload = extractFirstJsonArray(trimmedStdout); + if (!noisyPayload) { + throw new Error("qmd query JSON response was not an array"); + } + const fallback = parseQmdQueryResultArray(noisyPayload); + if (fallback !== null) { + return fallback; + } + throw new Error("qmd query JSON response was not an array"); + } catch (err) { + const message = err instanceof Error ? 
err.message : String(err); + log.warn(`qmd query returned invalid JSON: ${message}`); + throw new Error(`qmd query returned invalid JSON: ${message}`, { cause: err }); + } +} + +function isQmdNoResultsOutput(raw: string): boolean { + const lines = raw + .split(/\r?\n/) + .map((line) => line.trim().toLowerCase().replace(/\s+/g, " ")) + .filter((line) => line.length > 0); + return lines.some((line) => isQmdNoResultsLine(line)); +} + +function isQmdNoResultsLine(line: string): boolean { + if (line === "no results found" || line === "no results found.") { + return true; + } + return /^(?:\[[^\]]+\]\s*)?(?:(?:warn(?:ing)?|info|error|qmd)\s*:\s*)+no results found\.?$/.test( + line, + ); +} + +function summarizeQmdStderr(raw: string): string { + return raw.length <= 120 ? raw : `${raw.slice(0, 117)}...`; +} + +function parseQmdQueryResultArray(raw: string): QmdQueryResult[] | null { + try { + const parsed = JSON.parse(raw) as unknown; + if (!Array.isArray(parsed)) { + return null; + } + return parsed.map((item) => { + if (typeof item !== "object" || item === null) { + return item as QmdQueryResult; + } + const record = item as Record; + const docid = typeof record.docid === "string" ? record.docid : undefined; + const score = + typeof record.score === "number" && Number.isFinite(record.score) + ? record.score + : undefined; + const collection = typeof record.collection === "string" ? record.collection : undefined; + const file = typeof record.file === "string" ? record.file : undefined; + const snippet = typeof record.snippet === "string" ? record.snippet : undefined; + const body = typeof record.body === "string" ? record.body : undefined; + return { + docid, + score, + collection, + file, + snippet, + body, + startLine: parseQmdLineNumber(record.start_line ?? record.startLine), + endLine: parseQmdLineNumber(record.end_line ?? 
record.endLine), + } as QmdQueryResult; + }); + } catch { + return null; + } +} + +function parseQmdLineNumber(value: unknown): number | undefined { + return typeof value === "number" && Number.isFinite(value) && value > 0 ? value : undefined; +} + +function extractFirstJsonArray(raw: string): string | null { + const start = raw.indexOf("["); + if (start < 0) { + return null; + } + let depth = 0; + let inString = false; + let escaped = false; + for (let i = start; i < raw.length; i += 1) { + const char = raw[i]; + if (char === undefined) { + break; + } + if (inString) { + if (escaped) { + escaped = false; + continue; + } + if (char === "\\") { + escaped = true; + } else if (char === '"') { + inString = false; + } + continue; + } + if (char === '"') { + inString = true; + continue; + } + if (char === "[") { + depth += 1; + } else if (char === "]") { + depth -= 1; + if (depth === 0) { + return raw.slice(start, i + 1); + } + } + } + return null; +} diff --git a/src/memory-host-sdk/host/qmd-scope.test.ts b/src/memory-host-sdk/host/qmd-scope.test.ts new file mode 100644 index 00000000000..5a826e9c9b3 --- /dev/null +++ b/src/memory-host-sdk/host/qmd-scope.test.ts @@ -0,0 +1,54 @@ +import { describe, expect, it } from "vitest"; +import type { ResolvedQmdConfig } from "./backend-config.js"; +import { deriveQmdScopeChannel, deriveQmdScopeChatType, isQmdScopeAllowed } from "./qmd-scope.js"; + +describe("qmd scope", () => { + const allowDirect: ResolvedQmdConfig["scope"] = { + default: "deny", + rules: [{ action: "allow", match: { chatType: "direct" } }], + }; + + it("derives channel and chat type from canonical keys once", () => { + expect(deriveQmdScopeChannel("Workspace:group:123")).toBe("workspace"); + expect(deriveQmdScopeChatType("Workspace:group:123")).toBe("group"); + }); + + it("derives channel and chat type from stored key suffixes", () => { + expect(deriveQmdScopeChannel("agent:agent-1:workspace:channel:chan-123")).toBe("workspace"); + 
    expect(deriveQmdScopeChatType("agent:agent-1:workspace:channel:chan-123")).toBe("channel");
  });

  it("treats parsed keys with no chat prefix as direct", () => {
    expect(deriveQmdScopeChannel("agent:agent-1:peer-direct")).toBeUndefined();
    expect(deriveQmdScopeChatType("agent:agent-1:peer-direct")).toBe("direct");
    expect(isQmdScopeAllowed(allowDirect, "agent:agent-1:peer-direct")).toBe(true);
    expect(isQmdScopeAllowed(allowDirect, "agent:agent-1:peer:group:abc")).toBe(false);
  });

  // keyPrefix matches against the agent-stripped, lowercased key.
  it("applies scoped key-prefix checks against normalized key", () => {
    const scope: ResolvedQmdConfig["scope"] = {
      default: "deny",
      rules: [{ action: "allow", match: { keyPrefix: "workspace:" } }],
    };
    expect(isQmdScopeAllowed(scope, "agent:agent-1:workspace:group:123")).toBe(true);
    expect(isQmdScopeAllowed(scope, "agent:agent-1:other:group:123")).toBe(false);
  });

  // rawKeyPrefix is the explicit way to match the full agent-prefixed key.
  it("supports rawKeyPrefix matches for agent-prefixed keys", () => {
    const scope: ResolvedQmdConfig["scope"] = {
      default: "allow",
      rules: [{ action: "deny", match: { rawKeyPrefix: "agent:main:discord:" } }],
    };
    expect(isQmdScopeAllowed(scope, "agent:main:discord:channel:c123")).toBe(false);
    expect(isQmdScopeAllowed(scope, "agent:main:slack:channel:c123")).toBe(true);
  });

  // Older configs used keyPrefix with an "agent:" prefix to match raw keys;
  // that behavior must be preserved.
  it("keeps legacy agent-prefixed keyPrefix rules working", () => {
    const scope: ResolvedQmdConfig["scope"] = {
      default: "allow",
      rules: [{ action: "deny", match: { keyPrefix: "agent:main:discord:" } }],
    };
    expect(isQmdScopeAllowed(scope, "agent:main:discord:channel:c123")).toBe(false);
    expect(isQmdScopeAllowed(scope, "agent:main:slack:channel:c123")).toBe(true);
  });
});

import { parseAgentSessionKey } from "../../sessions/session-key-utils.js";
import type { ResolvedQmdConfig } from
"./backend-config.js";

type ParsedQmdSessionScope = {
  channel?: string;
  chatType?: "channel" | "group" | "direct";
  normalizedKey?: string;
};

/**
 * Evaluate qmd scope rules against a session key.
 *
 * Rules run in order; the first rule whose match clauses all hold decides the
 * outcome (`action === "allow"`). When no rule matches, `scope.default`
 * (default "allow") applies. A missing scope allows everything.
 */
export function isQmdScopeAllowed(scope: ResolvedQmdConfig["scope"], sessionKey?: string): boolean {
  if (!scope) {
    return true;
  }
  const { channel, chatType, normalizedKey = "" } = parseQmdSessionScope(sessionKey);
  const rawKey = sessionKey?.trim().toLowerCase() ?? "";

  for (const rule of scope.rules ?? []) {
    if (!rule) {
      continue;
    }
    const match = rule.match ?? {};
    if (match.channel && match.channel !== channel) {
      continue;
    }
    if (match.chatType && match.chatType !== chatType) {
      continue;
    }
    const keyPrefix = match.keyPrefix?.trim().toLowerCase() || undefined;
    const rawKeyPrefix = match.rawKeyPrefix?.trim().toLowerCase() || undefined;

    if (rawKeyPrefix && !rawKey.startsWith(rawKeyPrefix)) {
      continue;
    }
    if (keyPrefix) {
      // Backward compat: older configs used an "agent:"-prefixed keyPrefix to
      // match raw keys; such prefixes are checked against the raw key instead
      // of the normalized (agent-stripped) key.
      const matchesKey = keyPrefix.startsWith("agent:")
        ? rawKey.startsWith(keyPrefix)
        : normalizedKey.startsWith(keyPrefix);
      if (!matchesKey) {
        continue;
      }
    }
    return rule.action === "allow";
  }
  return (scope.default ?? "allow") === "allow";
}

/** Channel component (e.g. "workspace") derived from a session key, if any. */
export function deriveQmdScopeChannel(key?: string): string | undefined {
  return parseQmdSessionScope(key).channel;
}

/** Chat type ("channel" | "group" | "direct") derived from a session key. */
export function deriveQmdScopeChatType(key?: string): "channel" | "group" | "direct" | undefined {
  return parseQmdSessionScope(key).chatType;
}

/**
 * Derive channel/chat-type information from a normalized session key.
 * Canonical keys ("workspace:group:123") carry the channel in the first
 * segment; stored keys fall back to ":group:" / ":channel:" substring checks,
 * defaulting to "direct".
 */
function parseQmdSessionScope(key?: string): ParsedQmdSessionScope {
  const normalizedKey = normalizeQmdSessionKey(key);
  if (!normalizedKey) {
    return {};
  }
  const segments = normalizedKey.split(":").filter(Boolean);
  const marker = segments[1];
  if (
    segments.length >= 2 &&
    (marker === "group" || marker === "channel" || marker === "direct" || marker === "dm")
  ) {
    let chatType: ParsedQmdSessionScope["chatType"] = "direct";
    if (segments.includes("group")) {
      chatType = "group";
    } else if (segments.includes("channel")) {
      chatType = "channel";
    }
    return { normalizedKey, channel: segments[0]?.toLowerCase(), chatType };
  }
  if (normalizedKey.includes(":group:")) {
    return { normalizedKey, chatType: "group" };
  }
  if (normalizedKey.includes(":channel:")) {
    return { normalizedKey, chatType: "channel" };
  }
  return { normalizedKey, chatType: "direct" };
}

/**
 * Trim, strip any "agent:<id>:" prefix, and lowercase a session key.
 * Subagent keys are excluded entirely (returns undefined).
 */
function normalizeQmdSessionKey(key?: string): string | undefined {
  const trimmed = key?.trim();
  if (!trimmed) {
    return undefined;
  }
  const parsed = parseAgentSessionKey(trimmed);
  const normalized = (parsed?.rest ?? trimmed).toLowerCase();
  return normalized.startsWith("subagent:") ? undefined : normalized;
}

import { describe, expect, it } from "vitest";
import { expandQueryForFts, extractKeywords } from "./query-expansion.js";

describe("extractKeywords", () => {
  it("extracts keywords from English conversational query", () => {
    const keywords = extractKeywords("that thing we discussed about the API");
    expect(keywords).toContain("discussed");
    expect(keywords).toContain("api");
    // Should not include stop words
    expect(keywords).not.toContain("that");
    expect(keywords).not.toContain("thing");
    expect(keywords).not.toContain("we");
    expect(keywords).not.toContain("about");
    expect(keywords).not.toContain("the");
  });

  it("extracts keywords from Chinese conversational query", () => {
    const keywords = extractKeywords("之前讨论的那个方案");
    expect(keywords).toContain("讨论");
    expect(keywords).toContain("方案");
    // Should not include stop words
    expect(keywords).not.toContain("之前");
    expect(keywords).not.toContain("的");
    expect(keywords).not.toContain("那个");
  });

  it("extracts keywords from mixed language query", () => {
    const keywords = extractKeywords("昨天讨论的 API design");
    expect(keywords).toContain("讨论");
    expect(keywords).toContain("api");
    expect(keywords).toContain("design");
  });

  it("returns specific technical terms", () => {
    const keywords = extractKeywords("what was the solution for the CFR bug");
    expect(keywords).toContain("solution");
    expect(keywords).toContain("cfr");
    expect(keywords).toContain("bug");
  });

  it("extracts keywords from Korean conversational query", () => {
    const keywords = extractKeywords("어제 논의한 배포 전략");
expect(keywords).toContain("논의한"); + expect(keywords).toContain("배포"); + expect(keywords).toContain("전략"); + // Should not include stop words + expect(keywords).not.toContain("어제"); + }); + + it("strips Korean particles to extract stems", () => { + const keywords = extractKeywords("서버에서 발생한 에러를 확인"); + expect(keywords).toContain("서버"); + expect(keywords).toContain("에러"); + expect(keywords).toContain("확인"); + }); + + it("filters Korean stop words including inflected forms", () => { + const keywords = extractKeywords("나는 그리고 그래서"); + expect(keywords).not.toContain("나"); + expect(keywords).not.toContain("나는"); + expect(keywords).not.toContain("그리고"); + expect(keywords).not.toContain("그래서"); + }); + + it("filters inflected Korean stop words not explicitly listed", () => { + const keywords = extractKeywords("그녀는 우리는"); + expect(keywords).not.toContain("그녀는"); + expect(keywords).not.toContain("우리는"); + expect(keywords).not.toContain("그녀"); + expect(keywords).not.toContain("우리"); + }); + + it("does not produce bogus single-char stems from particle stripping", () => { + const keywords = extractKeywords("논의"); + expect(keywords).toContain("논의"); + expect(keywords).not.toContain("논"); + }); + + it("strips longest Korean trailing particles first", () => { + const keywords = extractKeywords("기능으로 설명"); + expect(keywords).toContain("기능"); + expect(keywords).not.toContain("기능으"); + }); + + it("keeps stripped ASCII stems for mixed Korean tokens", () => { + const keywords = extractKeywords("API를 배포했다"); + expect(keywords).toContain("api"); + expect(keywords).toContain("배포했다"); + }); + + it("handles mixed Korean and English query", () => { + const keywords = extractKeywords("API 배포에 대한 논의"); + expect(keywords).toContain("api"); + expect(keywords).toContain("배포"); + expect(keywords).toContain("논의"); + }); + + it("extracts keywords from Japanese conversational query", () => { + const keywords = extractKeywords("昨日話したデプロイ戦略"); + expect(keywords).toContain("デプロイ"); + 
expect(keywords).toContain("戦略"); + expect(keywords).not.toContain("昨日"); + }); + + it("handles mixed Japanese and English query", () => { + const keywords = extractKeywords("昨日話したAPIのバグ"); + expect(keywords).toContain("api"); + expect(keywords).toContain("バグ"); + expect(keywords).not.toContain("した"); + }); + + it("filters Japanese stop words", () => { + const keywords = extractKeywords("これ それ そして どう"); + expect(keywords).not.toContain("これ"); + expect(keywords).not.toContain("それ"); + expect(keywords).not.toContain("そして"); + expect(keywords).not.toContain("どう"); + }); + + it("extracts keywords from Spanish conversational query", () => { + const keywords = extractKeywords("ayer hablamos sobre la estrategia de despliegue"); + expect(keywords).toContain("estrategia"); + expect(keywords).toContain("despliegue"); + expect(keywords).not.toContain("ayer"); + expect(keywords).not.toContain("sobre"); + }); + + it("extracts keywords from Portuguese conversational query", () => { + const keywords = extractKeywords("ontem falamos sobre a estratégia de implantação"); + expect(keywords).toContain("estratégia"); + expect(keywords).toContain("implantação"); + expect(keywords).not.toContain("ontem"); + expect(keywords).not.toContain("sobre"); + }); + + it("filters Spanish and Portuguese question stop words", () => { + const keywords = extractKeywords("cómo cuando donde porquê quando onde"); + expect(keywords).not.toContain("cómo"); + expect(keywords).not.toContain("cuando"); + expect(keywords).not.toContain("donde"); + expect(keywords).not.toContain("porquê"); + expect(keywords).not.toContain("quando"); + expect(keywords).not.toContain("onde"); + }); + + it("extracts keywords from Arabic conversational query", () => { + const keywords = extractKeywords("بالأمس ناقشنا استراتيجية النشر"); + expect(keywords).toContain("ناقشنا"); + expect(keywords).toContain("استراتيجية"); + expect(keywords).toContain("النشر"); + expect(keywords).not.toContain("بالأمس"); + }); + + it("filters Arabic 
question stop words", () => { + const keywords = extractKeywords("كيف متى أين ماذا"); + expect(keywords).not.toContain("كيف"); + expect(keywords).not.toContain("متى"); + expect(keywords).not.toContain("أين"); + expect(keywords).not.toContain("ماذا"); + }); + + it("handles empty query", () => { + expect(extractKeywords("")).toEqual([]); + expect(extractKeywords(" ")).toEqual([]); + }); + + it("handles query with only stop words", () => { + const keywords = extractKeywords("the a an is are"); + expect(keywords.length).toBe(0); + }); + + it("removes duplicate keywords", () => { + const keywords = extractKeywords("test test testing"); + const testCount = keywords.filter((k) => k === "test").length; + expect(testCount).toBe(1); + }); + + describe("with trigram tokenizer", () => { + const trigramOpts = { ftsTokenizer: "trigram" as const }; + + it("emits whole CJK block instead of unigrams in trigram mode", () => { + const defaultKeywords = extractKeywords("之前讨论的那个方案"); + const trigramKeywords = extractKeywords("之前讨论的那个方案", trigramOpts); + // Default mode produces bigrams + expect(defaultKeywords).toContain("讨论"); + expect(defaultKeywords).toContain("方案"); + // Trigram mode emits the whole contiguous CJK block (FTS5 trigram + // requires >= 3 chars per term; individual characters return no results) + expect(trigramKeywords).toContain("之前讨论的那个方案"); + expect(trigramKeywords).not.toContain("讨论"); + expect(trigramKeywords).not.toContain("方案"); + }); + + it("skips Japanese kanji bigrams in trigram mode", () => { + const defaultKeywords = extractKeywords("経済政策について"); + const trigramKeywords = extractKeywords("経済政策について", trigramOpts); + // Default mode adds kanji bigrams: 経済, 済政, 政策 + expect(defaultKeywords).toContain("経済"); + expect(defaultKeywords).toContain("済政"); + expect(defaultKeywords).toContain("政策"); + // Trigram mode keeps the full kanji block but skips bigram splitting + expect(trigramKeywords).toContain("経済政策"); + expect(trigramKeywords).not.toContain("済政"); + }); + 
+ it("still filters stop words in trigram mode", () => { + const keywords = extractKeywords("これ それ そして どう", trigramOpts); + expect(keywords).not.toContain("これ"); + expect(keywords).not.toContain("それ"); + expect(keywords).not.toContain("そして"); + expect(keywords).not.toContain("どう"); + }); + + it("does not affect English keyword extraction", () => { + const keywords = extractKeywords("that thing we discussed about the API", trigramOpts); + expect(keywords).toContain("discussed"); + expect(keywords).toContain("api"); + expect(keywords).not.toContain("that"); + expect(keywords).not.toContain("the"); + }); + }); +}); + +describe("expandQueryForFts", () => { + it("returns original query and extracted keywords", () => { + const result = expandQueryForFts("that API we discussed"); + expect(result.original).toBe("that API we discussed"); + expect(result.keywords).toContain("api"); + expect(result.keywords).toContain("discussed"); + }); + + it("builds expanded OR query for FTS", () => { + const result = expandQueryForFts("the solution for bugs"); + expect(result.expanded).toContain("OR"); + expect(result.expanded).toContain("solution"); + expect(result.expanded).toContain("bugs"); + }); + + it("returns original query when no keywords extracted", () => { + const result = expandQueryForFts("the"); + expect(result.keywords.length).toBe(0); + expect(result.expanded).toBe("the"); + }); +}); diff --git a/src/memory-host-sdk/host/query-expansion.ts b/src/memory-host-sdk/host/query-expansion.ts new file mode 100644 index 00000000000..5ce120f1453 --- /dev/null +++ b/src/memory-host-sdk/host/query-expansion.ts @@ -0,0 +1,828 @@ +/** + * Query expansion for FTS-only search mode. + * + * When no embedding provider is available, we fall back to FTS (full-text search). + * FTS works best with specific keywords, but users often ask conversational queries + * like "that thing we discussed yesterday" or "之前讨论的那个方案". 
+ * + * This module extracts meaningful keywords from such queries to improve FTS results. + */ + +// Common stop words that don't add search value +const STOP_WORDS_EN = new Set([ + // Articles and determiners + "a", + "an", + "the", + "this", + "that", + "these", + "those", + // Pronouns + "i", + "me", + "my", + "we", + "our", + "you", + "your", + "he", + "she", + "it", + "they", + "them", + // Common verbs + "is", + "are", + "was", + "were", + "be", + "been", + "being", + "have", + "has", + "had", + "do", + "does", + "did", + "will", + "would", + "could", + "should", + "can", + "may", + "might", + // Prepositions + "in", + "on", + "at", + "to", + "for", + "of", + "with", + "by", + "from", + "about", + "into", + "through", + "during", + "before", + "after", + "above", + "below", + "between", + "under", + "over", + // Conjunctions + "and", + "or", + "but", + "if", + "then", + "because", + "as", + "while", + "when", + "where", + "what", + "which", + "who", + "how", + "why", + // Time references (vague, not useful for FTS) + "yesterday", + "today", + "tomorrow", + "earlier", + "later", + "recently", + "before", + "ago", + "just", + "now", + // Vague references + "thing", + "things", + "stuff", + "something", + "anything", + "everything", + "nothing", + // Question words + "please", + "help", + "find", + "show", + "get", + "tell", + "give", +]); + +const STOP_WORDS_ES = new Set([ + // Articles and determiners + "el", + "la", + "los", + "las", + "un", + "una", + "unos", + "unas", + "este", + "esta", + "ese", + "esa", + // Pronouns + "yo", + "me", + "mi", + "nosotros", + "nosotras", + "tu", + "tus", + "usted", + "ustedes", + "ellos", + "ellas", + // Prepositions and conjunctions + "de", + "del", + "a", + "en", + "con", + "por", + "para", + "sobre", + "entre", + "y", + "o", + "pero", + "si", + "porque", + "como", + // Common verbs / auxiliaries + "es", + "son", + "fue", + "fueron", + "ser", + "estar", + "haber", + "tener", + "hacer", + // Time references (vague) + 
"ayer", + "hoy", + "mañana", + "antes", + "despues", + "después", + "ahora", + "recientemente", + // Question/request words + "que", + "qué", + "cómo", + "cuando", + "cuándo", + "donde", + "dónde", + "porqué", + "favor", + "ayuda", +]); + +const STOP_WORDS_PT = new Set([ + // Articles and determiners + "o", + "a", + "os", + "as", + "um", + "uma", + "uns", + "umas", + "este", + "esta", + "esse", + "essa", + // Pronouns + "eu", + "me", + "meu", + "minha", + "nos", + "nós", + "você", + "vocês", + "ele", + "ela", + "eles", + "elas", + // Prepositions and conjunctions + "de", + "do", + "da", + "em", + "com", + "por", + "para", + "sobre", + "entre", + "e", + "ou", + "mas", + "se", + "porque", + "como", + // Common verbs / auxiliaries + "é", + "são", + "foi", + "foram", + "ser", + "estar", + "ter", + "fazer", + // Time references (vague) + "ontem", + "hoje", + "amanhã", + "antes", + "depois", + "agora", + "recentemente", + // Question/request words + "que", + "quê", + "quando", + "onde", + "porquê", + "favor", + "ajuda", +]); + +const STOP_WORDS_AR = new Set([ + // Articles and connectors + "ال", + "و", + "أو", + "لكن", + "ثم", + "بل", + // Pronouns / references + "أنا", + "نحن", + "هو", + "هي", + "هم", + "هذا", + "هذه", + "ذلك", + "تلك", + "هنا", + "هناك", + // Common prepositions + "من", + "إلى", + "الى", + "في", + "على", + "عن", + "مع", + "بين", + "ل", + "ب", + "ك", + // Common auxiliaries / vague verbs + "كان", + "كانت", + "يكون", + "تكون", + "صار", + "أصبح", + "يمكن", + "ممكن", + // Time references (vague) + "بالأمس", + "امس", + "اليوم", + "غدا", + "الآن", + "قبل", + "بعد", + "مؤخرا", + // Question/request words + "لماذا", + "كيف", + "ماذا", + "متى", + "أين", + "هل", + "من فضلك", + "فضلا", + "ساعد", +]); + +const STOP_WORDS_KO = new Set([ + // Particles (조사) + "은", + "는", + "이", + "가", + "을", + "를", + "의", + "에", + "에서", + "로", + "으로", + "와", + "과", + "도", + "만", + "까지", + "부터", + "한테", + "에게", + "께", + "처럼", + "같이", + "보다", + "마다", + "밖에", + "대로", + // Pronouns 
(대명사)
  "나", "나는", "내가", "나를", "너", "우리", "저", "저희", "그", "그녀", "그들",
  "이것", "저것", "그것", "여기", "저기", "거기",
  // Common verbs / auxiliaries (일반 동사/보조 동사)
  "있다", "없다", "하다", "되다", "이다", "아니다", "보다", "주다", "오다", "가다",
  // Nouns (의존 명사 / vague)
  "것", "거", "등", "수", "때", "곳", "중", "분",
  // Adverbs
  "잘", "더", "또", "매우", "정말", "아주", "많이", "너무", "좀",
  // Conjunctions
  "그리고", "하지만", "그래서", "그런데", "그러나", "또는", "그러면",
  // Question words
  "왜", "어떻게", "뭐", "언제", "어디", "누구", "무엇", "어떤",
  // Time (vague)
  "어제", "오늘", "내일", "최근", "지금", "아까", "나중", "전에",
  // Request words
  "제발", "부탁",
]);

// Common Korean trailing particles to strip from words for tokenization
// Sorted by descending length so longest-match-first is guaranteed.
const KO_TRAILING_PARTICLES = [
  "에서", "으로", "에게", "한테", "처럼", "같이", "보다", "까지", "부터", "마다", "밖에", "대로",
  "은", "는", "이", "가", "을", "를", "의", "에", "로", "와", "과", "도", "만",
].toSorted((a, b) => b.length - a.length);

// Strip the longest matching trailing particle from a token; returns null when
// no particle matches or stripping would consume the whole token.
function stripKoreanTrailingParticle(token: string): string | null {
  for (const particle of KO_TRAILING_PARTICLES) {
    if (token.length > particle.length && token.endsWith(particle)) {
      return token.slice(0, -particle.length);
    }
  }
  return null;
}

// A stripped stem is only useful when it is a multi-syllable Hangul word or a
// plain ASCII identifier.
function isUsefulKoreanStem(stem: string): boolean {
  // Prevent bogus one-syllable stems from words like "논의" -> "논".
  if (/[\uac00-\ud7af]/.test(stem)) {
    return stem.length >= 2;
  }
  // Keep stripped ASCII stems for mixed tokens like "API를" -> "api".
  return /^[a-z0-9_]+$/i.test(stem);
}

const STOP_WORDS_JA = new Set([
  // Pronouns and references
  "これ", "それ", "あれ", "この", "その", "あの", "ここ", "そこ", "あそこ",
  // Common auxiliaries / vague verbs
  "する", "した", "して", "です", "ます", "いる", "ある", "なる", "できる",
  // Particles / connectors
  "の", "こと", "もの", "ため", "そして", "しかし", "また", "でも", "から", "まで", "より", "だけ",
  // Question words
  "なぜ", "どう", "何", "いつ", "どこ", "誰", "どれ",
  // Time (vague)
  "昨日", "今日", "明日", "最近", "今", "さっき", "前", "後",
]);

const STOP_WORDS_ZH = new Set([
  // Pronouns
  "我", "我们", "你", "你们", "他", "她", "它", "他们", "这", "那", "这个", "那个", "这些", "那些",
  // Auxiliary words
  "的", "了", "着", "过", "得", "地", "吗", "呢", "吧", "啊", "呀", "嘛", "啦",
  // Verbs (common, vague)
  "是", "有", "在", "被", "把", "给", "让", "用", "到", "去", "来", "做", "说", "看", "找", "想", "要", "能", "会", "可以",
  // Prepositions and conjunctions
  "和", "与", "或", "但", "但是", "因为", "所以", "如果", "虽然", "而", "也", "都", "就", "还", "又", "再", "才", "只",
  // Time (vague)
  "之前", "以前", "之后", "以后", "刚才", "现在", "昨天", "今天", "明天", "最近",
  // Vague references
  "东西", "事情", "事", "什么", "哪个", "哪些", "怎么", "为什么", "多少",
  // Question/request words
  "请", "帮", "帮忙", "告诉",
]);

// True when the token is a stop word in any supported language.
export function isQueryStopWordToken(token: string): boolean {
  return (
    STOP_WORDS_EN.has(token) ||
    STOP_WORDS_ES.has(token) ||
    STOP_WORDS_PT.has(token) ||
    STOP_WORDS_AR.has(token) ||
    STOP_WORDS_ZH.has(token) ||
    STOP_WORDS_KO.has(token) ||
    STOP_WORDS_JA.has(token)
  );
}

/**
 * Check if a token looks like a meaningful keyword.
 * Returns false for short tokens, numbers-only, etc.
+ */ +function isValidKeyword(token: string): boolean { + if (!token || token.length === 0) { + return false; + } + // Skip very short English words (likely stop words or fragments) + if (/^[a-zA-Z]+$/.test(token) && token.length < 3) { + return false; + } + // Skip pure numbers (not useful for semantic search) + if (/^\d+$/.test(token)) { + return false; + } + // Skip tokens that are all punctuation + if (/^[\p{P}\p{S}]+$/u.test(token)) { + return false; + } + return true; +} + +/** + * Simple tokenizer that handles English, Chinese, Korean, and Japanese text. + * For Chinese, we do character-based splitting since we don't have a proper segmenter. + * For English, we split on whitespace and punctuation. + */ +function tokenize(text: string, opts?: { ftsTokenizer?: "unicode61" | "trigram" }): string[] { + const useTrigram = opts?.ftsTokenizer === "trigram"; + const tokens: string[] = []; + const normalized = text.toLowerCase().trim(); + + // Split into segments (English words, Chinese character sequences, etc.) + const segments = normalized.split(/[\s\p{P}]+/u).filter(Boolean); + + for (const segment of segments) { + // Japanese text often mixes scripts (kanji/kana/ASCII) without spaces. + // Extract script-specific chunks so technical terms like "API" / "バグ" are retained. + if (/[\u3040-\u30ff]/.test(segment)) { + const jpParts = + segment.match(/[a-z0-9_]+|[\u30a0-\u30ffー]+|[\u4e00-\u9fff]+|[\u3040-\u309f]{2,}/g) ?? []; + for (const part of jpParts) { + if (/^[\u4e00-\u9fff]+$/.test(part)) { + tokens.push(part); + if (!useTrigram) { + for (let i = 0; i < part.length - 1; i++) { + tokens.push(part[i] + part[i + 1]); + } + } + } else { + tokens.push(part); + } + } + } else if (/[\u4e00-\u9fff]/.test(segment)) { + // Check if segment contains CJK characters (Chinese) + const chars = Array.from(segment).filter((c) => /[\u4e00-\u9fff]/.test(c)); + if (useTrigram) { + // In trigram mode, push the whole contiguous CJK block (mirroring the + // Japanese kanji path). 
SQLite's trigram FTS requires at least 3 characters + // per query term — individual characters silently return no results. + const block = chars.join(""); + if (block.length > 0) { + tokens.push(block); + } + } else { + // Default mode: unigrams + bigrams for phrase matching + tokens.push(...chars); + for (let i = 0; i < chars.length - 1; i++) { + tokens.push(chars[i] + chars[i + 1]); + } + } + } else if (/[\uac00-\ud7af\u3131-\u3163]/.test(segment)) { + // For Korean (Hangul syllables and jamo), keep the word as-is unless it is + // effectively a stop word once trailing particles are removed. + const stem = stripKoreanTrailingParticle(segment); + const stemIsStopWord = stem !== null && STOP_WORDS_KO.has(stem); + if (!STOP_WORDS_KO.has(segment) && !stemIsStopWord) { + tokens.push(segment); + } + // Also emit particle-stripped stems when they are useful keywords. + if (stem && !STOP_WORDS_KO.has(stem) && isUsefulKoreanStem(stem)) { + tokens.push(stem); + } + } else { + // For non-CJK, keep as single token + tokens.push(segment); + } + } + + return tokens; +} + +/** + * Extract keywords from a conversational query for FTS search. + * + * Examples: + * - "that thing we discussed about the API" → ["discussed", "API"] + * - "之前讨论的那个方案" → ["讨论", "方案"] + * - "what was the solution for the bug" → ["solution", "bug"] + */ +export function extractKeywords( + query: string, + opts?: { ftsTokenizer?: "unicode61" | "trigram" }, +): string[] { + const tokens = tokenize(query, opts); + const keywords: string[] = []; + const seen = new Set(); + + for (const token of tokens) { + // Skip stop words + if (isQueryStopWordToken(token)) { + continue; + } + // Skip invalid keywords + if (!isValidKeyword(token)) { + continue; + } + // Skip duplicates + if (seen.has(token)) { + continue; + } + seen.add(token); + keywords.push(token); + } + + return keywords; +} + +/** + * Expand a query for FTS search. + * Returns both the original query and extracted keywords for OR-matching. 
+ * + * @param query - User's original query + * @returns Object with original query and extracted keywords + */ +export function expandQueryForFts( + query: string, + opts?: { ftsTokenizer?: "unicode61" | "trigram" }, +): { + original: string; + keywords: string[]; + expanded: string; +} { + const original = query.trim(); + const keywords = extractKeywords(original, opts); + + // Build expanded query: original terms OR extracted keywords + // This ensures both exact matches and keyword matches are found + const expanded = keywords.length > 0 ? `${original} OR ${keywords.join(" OR ")}` : original; + + return { original, keywords, expanded }; +} + +/** + * Type for an optional LLM-based query expander. + * Can be provided to enhance keyword extraction with semantic understanding. + */ +export type LlmQueryExpander = (query: string) => Promise; + +/** + * Expand query with optional LLM assistance. + * Falls back to local extraction if LLM is unavailable or fails. + */ +export async function expandQueryWithLlm( + query: string, + llmExpander?: LlmQueryExpander, + opts?: { ftsTokenizer?: "unicode61" | "trigram" }, +): Promise { + // If LLM expander is provided, try it first + if (llmExpander) { + try { + const llmKeywords = await llmExpander(query); + if (llmKeywords.length > 0) { + return llmKeywords; + } + } catch { + // LLM failed, fall back to local extraction + } + } + + // Fall back to local keyword extraction + return extractKeywords(query, opts); +} diff --git a/src/memory-host-sdk/host/read-file.ts b/src/memory-host-sdk/host/read-file.ts new file mode 100644 index 00000000000..d9e6bbc3ce8 --- /dev/null +++ b/src/memory-host-sdk/host/read-file.ts @@ -0,0 +1,96 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import { resolveAgentWorkspaceDir } from "../../agents/agent-scope.js"; +import { resolveMemorySearchConfig } from "../../agents/memory-search.js"; +import type { OpenClawConfig } from "../../config/config.js"; +import { 
isFileMissingError, statRegularFile } from "./fs-utils.js"; +import { isMemoryPath, normalizeExtraMemoryPaths } from "./internal.js"; + +export async function readMemoryFile(params: { + workspaceDir: string; + extraPaths?: string[]; + relPath: string; + from?: number; + lines?: number; +}): Promise<{ text: string; path: string }> { + const rawPath = params.relPath.trim(); + if (!rawPath) { + throw new Error("path required"); + } + const absPath = path.isAbsolute(rawPath) + ? path.resolve(rawPath) + : path.resolve(params.workspaceDir, rawPath); + const relPath = path.relative(params.workspaceDir, absPath).replace(/\\/g, "/"); + const inWorkspace = relPath.length > 0 && !relPath.startsWith("..") && !path.isAbsolute(relPath); + const allowedWorkspace = inWorkspace && isMemoryPath(relPath); + let allowedAdditional = false; + if (!allowedWorkspace && (params.extraPaths?.length ?? 0) > 0) { + const additionalPaths = normalizeExtraMemoryPaths(params.workspaceDir, params.extraPaths); + for (const additionalPath of additionalPaths) { + try { + const stat = await fs.lstat(additionalPath); + if (stat.isSymbolicLink()) { + continue; + } + if (stat.isDirectory()) { + if (absPath === additionalPath || absPath.startsWith(`${additionalPath}${path.sep}`)) { + allowedAdditional = true; + break; + } + continue; + } + if (stat.isFile() && absPath === additionalPath && absPath.endsWith(".md")) { + allowedAdditional = true; + break; + } + } catch {} + } + } + if (!allowedWorkspace && !allowedAdditional) { + throw new Error("path required"); + } + if (!absPath.endsWith(".md")) { + throw new Error("path required"); + } + const statResult = await statRegularFile(absPath); + if (statResult.missing) { + return { text: "", path: relPath }; + } + let content: string; + try { + content = await fs.readFile(absPath, "utf-8"); + } catch (err) { + if (isFileMissingError(err)) { + return { text: "", path: relPath }; + } + throw err; + } + if (!params.from && !params.lines) { + return { text: 
content, path: relPath }; + } + const fileLines = content.split("\n"); + const start = Math.max(1, params.from ?? 1); + const count = Math.max(1, params.lines ?? fileLines.length); + const slice = fileLines.slice(start - 1, start - 1 + count); + return { text: slice.join("\n"), path: relPath }; +} + +export async function readAgentMemoryFile(params: { + cfg: OpenClawConfig; + agentId: string; + relPath: string; + from?: number; + lines?: number; +}): Promise<{ text: string; path: string }> { + const settings = resolveMemorySearchConfig(params.cfg, params.agentId); + if (!settings) { + throw new Error("memory search disabled"); + } + return await readMemoryFile({ + workspaceDir: resolveAgentWorkspaceDir(params.cfg, params.agentId), + extraPaths: settings.extraPaths, + relPath: params.relPath, + from: params.from, + lines: params.lines, + }); +} diff --git a/src/memory-host-sdk/host/remote-http.ts b/src/memory-host-sdk/host/remote-http.ts new file mode 100644 index 00000000000..132e92a7548 --- /dev/null +++ b/src/memory-host-sdk/host/remote-http.ts @@ -0,0 +1,40 @@ +import { fetchWithSsrFGuard } from "../../infra/net/fetch-guard.js"; +import type { SsrFPolicy } from "../../infra/net/ssrf.js"; + +export function buildRemoteBaseUrlPolicy(baseUrl: string): SsrFPolicy | undefined { + const trimmed = baseUrl.trim(); + if (!trimmed) { + return undefined; + } + try { + const parsed = new URL(trimmed); + if (parsed.protocol !== "http:" && parsed.protocol !== "https:") { + return undefined; + } + // Keep policy tied to the configured host so private operator endpoints + // continue to work, while cross-host redirects stay blocked. 
+ return { allowedHostnames: [parsed.hostname] }; + } catch { + return undefined; + } +} + +export async function withRemoteHttpResponse(params: { + url: string; + init?: RequestInit; + ssrfPolicy?: SsrFPolicy; + auditContext?: string; + onResponse: (response: Response) => Promise; +}): Promise { + const { response, release } = await fetchWithSsrFGuard({ + url: params.url, + init: params.init, + policy: params.ssrfPolicy, + auditContext: params.auditContext ?? "memory-remote", + }); + try { + return await params.onResponse(response); + } finally { + await release(); + } +} diff --git a/src/memory-host-sdk/host/secret-input.ts b/src/memory-host-sdk/host/secret-input.ts new file mode 100644 index 00000000000..98dd0c87084 --- /dev/null +++ b/src/memory-host-sdk/host/secret-input.ts @@ -0,0 +1,18 @@ +import { + hasConfiguredSecretInput, + normalizeResolvedSecretInputString, +} from "../../config/types.secrets.js"; + +export function hasConfiguredMemorySecretInput(value: unknown): boolean { + return hasConfiguredSecretInput(value); +} + +export function resolveMemorySecretInputString(params: { + value: unknown; + path: string; +}): string | undefined { + return normalizeResolvedSecretInputString({ + value: params.value, + path: params.path, + }); +} diff --git a/src/memory-host-sdk/host/session-files.test.ts b/src/memory-host-sdk/host/session-files.test.ts new file mode 100644 index 00000000000..476aa35644b --- /dev/null +++ b/src/memory-host-sdk/host/session-files.test.ts @@ -0,0 +1,123 @@ +import fs from "node:fs/promises"; +import os from "node:os"; +import path from "node:path"; +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { buildSessionEntry, listSessionFilesForAgent } from "./session-files.js"; + +let tmpDir: string; +let originalStateDir: string | undefined; + +beforeEach(async () => { + tmpDir = await fs.mkdtemp(path.join(os.tmpdir(), "session-entry-test-")); + originalStateDir = process.env.OPENCLAW_STATE_DIR; + 
process.env.OPENCLAW_STATE_DIR = tmpDir; +}); + +afterEach(async () => { + if (originalStateDir === undefined) { + delete process.env.OPENCLAW_STATE_DIR; + } else { + process.env.OPENCLAW_STATE_DIR = originalStateDir; + } + await fs.rm(tmpDir, { recursive: true, force: true }); +}); + +describe("listSessionFilesForAgent", () => { + it("includes reset and deleted transcripts in session file listing", async () => { + const sessionsDir = path.join(tmpDir, "agents", "main", "sessions"); + await fs.mkdir(path.join(sessionsDir, "archive"), { recursive: true }); + + const included = [ + "active.jsonl", + "active.jsonl.reset.2026-02-16T22-26-33.000Z", + "active.jsonl.deleted.2026-02-16T22-27-33.000Z", + ]; + const excluded = ["active.jsonl.bak.2026-02-16T22-28-33.000Z", "sessions.json", "notes.md"]; + + for (const fileName of [...included, ...excluded]) { + await fs.writeFile(path.join(sessionsDir, fileName), ""); + } + await fs.writeFile( + path.join(sessionsDir, "archive", "nested.jsonl.deleted.2026-02-16T22-29-33.000Z"), + "", + ); + + const files = await listSessionFilesForAgent("main"); + + expect(files.map((filePath) => path.basename(filePath)).toSorted()).toEqual( + included.toSorted(), + ); + }); +}); + +describe("buildSessionEntry", () => { + it("returns lineMap tracking original JSONL line numbers", async () => { + // Simulate a real session JSONL file with metadata records interspersed + // Lines 1-3: non-message metadata records + // Line 4: user message + // Line 5: metadata + // Line 6: assistant message + // Line 7: user message + const jsonlLines = [ + JSON.stringify({ type: "custom", customType: "model-snapshot", data: {} }), + JSON.stringify({ type: "custom", customType: "openclaw.cache-ttl", data: {} }), + JSON.stringify({ type: "session-meta", agentId: "test" }), + JSON.stringify({ type: "message", message: { role: "user", content: "Hello world" } }), + JSON.stringify({ type: "custom", customType: "tool-result", data: {} }), + JSON.stringify({ + type: 
"message", + message: { role: "assistant", content: "Hi there, how can I help?" }, + }), + JSON.stringify({ type: "message", message: { role: "user", content: "Tell me a joke" } }), + ]; + const filePath = path.join(tmpDir, "session.jsonl"); + await fs.writeFile(filePath, jsonlLines.join("\n")); + + const entry = await buildSessionEntry(filePath); + expect(entry).not.toBeNull(); + + // The content should have 3 lines (3 message records) + const contentLines = entry!.content.split("\n"); + expect(contentLines).toHaveLength(3); + expect(contentLines[0]).toContain("User: Hello world"); + expect(contentLines[1]).toContain("Assistant: Hi there"); + expect(contentLines[2]).toContain("User: Tell me a joke"); + + // lineMap should map each content line to its original JSONL line (1-indexed) + // Content line 0 → JSONL line 4 (the first user message) + // Content line 1 → JSONL line 6 (the assistant message) + // Content line 2 → JSONL line 7 (the second user message) + expect(entry!.lineMap).toBeDefined(); + expect(entry!.lineMap).toEqual([4, 6, 7]); + }); + + it("returns empty lineMap when no messages are found", async () => { + const jsonlLines = [ + JSON.stringify({ type: "custom", customType: "model-snapshot", data: {} }), + JSON.stringify({ type: "session-meta", agentId: "test" }), + ]; + const filePath = path.join(tmpDir, "empty-session.jsonl"); + await fs.writeFile(filePath, jsonlLines.join("\n")); + + const entry = await buildSessionEntry(filePath); + expect(entry).not.toBeNull(); + expect(entry!.content).toBe(""); + expect(entry!.lineMap).toEqual([]); + }); + + it("skips blank lines and invalid JSON without breaking lineMap", async () => { + const jsonlLines = [ + "", + "not valid json", + JSON.stringify({ type: "message", message: { role: "user", content: "First" } }), + "", + JSON.stringify({ type: "message", message: { role: "assistant", content: "Second" } }), + ]; + const filePath = path.join(tmpDir, "gaps.jsonl"); + await fs.writeFile(filePath, 
jsonlLines.join("\n")); + + const entry = await buildSessionEntry(filePath); + expect(entry).not.toBeNull(); + expect(entry!.lineMap).toEqual([3, 5]); + }); +}); diff --git a/src/memory-host-sdk/host/session-files.ts b/src/memory-host-sdk/host/session-files.ts new file mode 100644 index 00000000000..8db6a46692c --- /dev/null +++ b/src/memory-host-sdk/host/session-files.ts @@ -0,0 +1,132 @@ +import fs from "node:fs/promises"; +import path from "node:path"; +import { isUsageCountedSessionTranscriptFileName } from "../../config/sessions/artifacts.js"; +import { resolveSessionTranscriptsDirForAgent } from "../../config/sessions/paths.js"; +import { redactSensitiveText } from "../../logging/redact.js"; +import { createSubsystemLogger } from "../../logging/subsystem.js"; +import { hashText } from "./internal.js"; + +const log = createSubsystemLogger("memory"); + +export type SessionFileEntry = { + path: string; + absPath: string; + mtimeMs: number; + size: number; + hash: string; + content: string; + /** Maps each content line (0-indexed) to its 1-indexed JSONL source line. 
*/ + lineMap: number[]; +}; + +export async function listSessionFilesForAgent(agentId: string): Promise { + const dir = resolveSessionTranscriptsDirForAgent(agentId); + try { + const entries = await fs.readdir(dir, { withFileTypes: true }); + return entries + .filter((entry) => entry.isFile()) + .map((entry) => entry.name) + .filter((name) => isUsageCountedSessionTranscriptFileName(name)) + .map((name) => path.join(dir, name)); + } catch { + return []; + } +} + +export function sessionPathForFile(absPath: string): string { + return path.join("sessions", path.basename(absPath)).replace(/\\/g, "/"); +} + +function normalizeSessionText(value: string): string { + return value + .replace(/\s*\n+\s*/g, " ") + .replace(/\s+/g, " ") + .trim(); +} + +export function extractSessionText(content: unknown): string | null { + if (typeof content === "string") { + const normalized = normalizeSessionText(content); + return normalized ? normalized : null; + } + if (!Array.isArray(content)) { + return null; + } + const parts: string[] = []; + for (const block of content) { + if (!block || typeof block !== "object") { + continue; + } + const record = block as { type?: unknown; text?: unknown }; + if (record.type !== "text" || typeof record.text !== "string") { + continue; + } + const normalized = normalizeSessionText(record.text); + if (normalized) { + parts.push(normalized); + } + } + if (parts.length === 0) { + return null; + } + return parts.join(" "); +} + +export async function buildSessionEntry(absPath: string): Promise { + try { + const stat = await fs.stat(absPath); + const raw = await fs.readFile(absPath, "utf-8"); + const lines = raw.split("\n"); + const collected: string[] = []; + const lineMap: number[] = []; + for (let jsonlIdx = 0; jsonlIdx < lines.length; jsonlIdx++) { + const line = lines[jsonlIdx]; + if (!line.trim()) { + continue; + } + let record: unknown; + try { + record = JSON.parse(line); + } catch { + continue; + } + if ( + !record || + typeof record !== 
"object" || + (record as { type?: unknown }).type !== "message" + ) { + continue; + } + const message = (record as { message?: unknown }).message as + | { role?: unknown; content?: unknown } + | undefined; + if (!message || typeof message.role !== "string") { + continue; + } + if (message.role !== "user" && message.role !== "assistant") { + continue; + } + const text = extractSessionText(message.content); + if (!text) { + continue; + } + const safe = redactSensitiveText(text, { mode: "tools" }); + const label = message.role === "user" ? "User" : "Assistant"; + collected.push(`${label}: ${safe}`); + lineMap.push(jsonlIdx + 1); + } + const content = collected.join("\n"); + return { + path: sessionPathForFile(absPath), + absPath, + mtimeMs: stat.mtimeMs, + size: stat.size, + hash: hashText(content + "\n" + lineMap.join(",")), + content, + lineMap, + }; + } catch (err) { + log.debug(`Failed reading session file ${absPath}: ${String(err)}`); + return null; + } +} diff --git a/src/memory-host-sdk/host/sqlite-vec.ts b/src/memory-host-sdk/host/sqlite-vec.ts new file mode 100644 index 00000000000..4769634669e --- /dev/null +++ b/src/memory-host-sdk/host/sqlite-vec.ts @@ -0,0 +1,24 @@ +import type { DatabaseSync } from "node:sqlite"; + +export async function loadSqliteVecExtension(params: { + db: DatabaseSync; + extensionPath?: string; +}): Promise<{ ok: boolean; extensionPath?: string; error?: string }> { + try { + const sqliteVec = await import("sqlite-vec"); + const resolvedPath = params.extensionPath?.trim() ? params.extensionPath.trim() : undefined; + const extensionPath = resolvedPath ?? sqliteVec.getLoadablePath(); + + params.db.enableLoadExtension(true); + if (resolvedPath) { + params.db.loadExtension(extensionPath); + } else { + sqliteVec.load(params.db); + } + + return { ok: true, extensionPath }; + } catch (err) { + const message = err instanceof Error ? 
err.message : String(err); + return { ok: false, error: message }; + } +} diff --git a/src/memory-host-sdk/host/sqlite.ts b/src/memory-host-sdk/host/sqlite.ts new file mode 100644 index 00000000000..fabb16d983a --- /dev/null +++ b/src/memory-host-sdk/host/sqlite.ts @@ -0,0 +1,19 @@ +import { createRequire } from "node:module"; +import { installProcessWarningFilter } from "../../infra/warning-filter.js"; + +const require = createRequire(import.meta.url); + +export function requireNodeSqlite(): typeof import("node:sqlite") { + installProcessWarningFilter(); + try { + return require("node:sqlite") as typeof import("node:sqlite"); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + // Node distributions can ship without the experimental builtin SQLite module. + // Surface an actionable error instead of the generic "unknown builtin module". + throw new Error( + `SQLite support is unavailable in this Node runtime (missing node:sqlite). ${message}`, + { cause: err }, + ); + } +} diff --git a/src/memory-host-sdk/host/status-format.ts b/src/memory-host-sdk/host/status-format.ts new file mode 100644 index 00000000000..0fd70136be0 --- /dev/null +++ b/src/memory-host-sdk/host/status-format.ts @@ -0,0 +1,45 @@ +export type Tone = "ok" | "warn" | "muted"; + +export function resolveMemoryVectorState(vector: { enabled: boolean; available?: boolean }): { + tone: Tone; + state: "ready" | "unavailable" | "disabled" | "unknown"; +} { + if (!vector.enabled) { + return { tone: "muted", state: "disabled" }; + } + if (vector.available === true) { + return { tone: "ok", state: "ready" }; + } + if (vector.available === false) { + return { tone: "warn", state: "unavailable" }; + } + return { tone: "muted", state: "unknown" }; +} + +export function resolveMemoryFtsState(fts: { enabled: boolean; available: boolean }): { + tone: Tone; + state: "ready" | "unavailable" | "disabled"; +} { + if (!fts.enabled) { + return { tone: "muted", state: "disabled" }; + } + 
return fts.available ? { tone: "ok", state: "ready" } : { tone: "warn", state: "unavailable" }; +} + +export function resolveMemoryCacheSummary(cache: { enabled: boolean; entries?: number }): { + tone: Tone; + text: string; +} { + if (!cache.enabled) { + return { tone: "muted", text: "cache off" }; + } + const suffix = typeof cache.entries === "number" ? ` (${cache.entries})` : ""; + return { tone: "ok", text: `cache on${suffix}` }; +} + +export function resolveMemoryCacheState(cache: { enabled: boolean }): { + tone: Tone; + state: "enabled" | "disabled"; +} { + return cache.enabled ? { tone: "ok", state: "enabled" } : { tone: "muted", state: "disabled" }; +} diff --git a/src/memory-host-sdk/host/test-helpers/ssrf.ts b/src/memory-host-sdk/host/test-helpers/ssrf.ts new file mode 100644 index 00000000000..e8b6f99d553 --- /dev/null +++ b/src/memory-host-sdk/host/test-helpers/ssrf.ts @@ -0,0 +1,14 @@ +import { vi } from "vitest"; +import * as ssrf from "../../../infra/net/ssrf.js"; + +export function mockPublicPinnedHostname() { + return vi.spyOn(ssrf, "resolvePinnedHostnameWithPolicy").mockImplementation(async (hostname) => { + const normalized = hostname.trim().toLowerCase().replace(/\.$/, ""); + const addresses = ["93.184.216.34"]; + return { + hostname: normalized, + addresses, + lookup: ssrf.createPinnedLookup({ hostname: normalized, addresses }), + }; + }); +} diff --git a/src/memory-host-sdk/host/types.ts b/src/memory-host-sdk/host/types.ts new file mode 100644 index 00000000000..880384df71a --- /dev/null +++ b/src/memory-host-sdk/host/types.ts @@ -0,0 +1,81 @@ +export type MemorySource = "memory" | "sessions"; + +export type MemorySearchResult = { + path: string; + startLine: number; + endLine: number; + score: number; + snippet: string; + source: MemorySource; + citation?: string; +}; + +export type MemoryEmbeddingProbeResult = { + ok: boolean; + error?: string; +}; + +export type MemorySyncProgressUpdate = { + completed: number; + total: number; + label?: 
string; +}; + +export type MemoryProviderStatus = { + backend: "builtin" | "qmd"; + provider: string; + model?: string; + requestedProvider?: string; + files?: number; + chunks?: number; + dirty?: boolean; + workspaceDir?: string; + dbPath?: string; + extraPaths?: string[]; + sources?: MemorySource[]; + sourceCounts?: Array<{ source: MemorySource; files: number; chunks: number }>; + cache?: { enabled: boolean; entries?: number; maxEntries?: number }; + fts?: { enabled: boolean; available: boolean; error?: string }; + fallback?: { from: string; reason?: string }; + vector?: { + enabled: boolean; + available?: boolean; + extensionPath?: string; + loadError?: string; + dims?: number; + }; + batch?: { + enabled: boolean; + failures: number; + limit: number; + wait: boolean; + concurrency: number; + pollIntervalMs: number; + timeoutMs: number; + lastError?: string; + lastProvider?: string; + }; + custom?: Record; +}; + +export interface MemorySearchManager { + search( + query: string, + opts?: { maxResults?: number; minScore?: number; sessionKey?: string }, + ): Promise; + readFile(params: { + relPath: string; + from?: number; + lines?: number; + }): Promise<{ text: string; path: string }>; + status(): MemoryProviderStatus; + sync?(params?: { + reason?: string; + force?: boolean; + sessionFiles?: string[]; + progress?: (update: MemorySyncProgressUpdate) => void; + }): Promise; + probeEmbeddingAvailability(): Promise; + probeVectorAvailability(): Promise; + close?(): Promise; +} diff --git a/src/memory-host-sdk/multimodal.ts b/src/memory-host-sdk/multimodal.ts index 36b50cbbf4b..5c62de35490 100644 --- a/src/memory-host-sdk/multimodal.ts +++ b/src/memory-host-sdk/multimodal.ts @@ -1 +1,6 @@ -export * from "../../packages/memory-host-sdk/src/multimodal.js"; +export { + isMemoryMultimodalEnabled, + normalizeMemoryMultimodalSettings, + supportsMemoryMultimodalEmbeddings, + type MemoryMultimodalSettings, +} from "./host/multimodal.js"; diff --git 
a/src/memory-host-sdk/query.ts b/src/memory-host-sdk/query.ts index 2a2ef6bbed4..bb945afaa65 100644 --- a/src/memory-host-sdk/query.ts +++ b/src/memory-host-sdk/query.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/query.js"; +export { extractKeywords, isQueryStopWordToken } from "./host/query-expansion.js"; diff --git a/src/memory-host-sdk/runtime-cli.ts b/src/memory-host-sdk/runtime-cli.ts new file mode 100644 index 00000000000..63b918ad6b8 --- /dev/null +++ b/src/memory-host-sdk/runtime-cli.ts @@ -0,0 +1,11 @@ +// Focused runtime contract for memory CLI/UI helpers. + +export { formatErrorMessage, withManager } from "../cli/cli-utils.js"; +export { formatHelpExamples } from "../cli/help-format.js"; +export { resolveCommandSecretRefsViaGateway } from "../cli/command-secret-gateway.js"; +export { withProgress, withProgressTotals } from "../cli/progress.js"; +export { defaultRuntime } from "../runtime.js"; +export { formatDocsLink } from "../terminal/links.js"; +export { colorize, isRich, theme } from "../terminal/theme.js"; +export { isVerbose, setVerbose } from "../globals.js"; +export { shortenHomeInString, shortenHomePath } from "../utils.js"; diff --git a/src/memory-host-sdk/runtime-core.ts b/src/memory-host-sdk/runtime-core.ts new file mode 100644 index 00000000000..b18782c0ce9 --- /dev/null +++ b/src/memory-host-sdk/runtime-core.ts @@ -0,0 +1,24 @@ +// Focused runtime contract for memory plugin config/state/helpers. 
+ +export type { AnyAgentTool } from "../agents/tools/common.js"; +export { resolveCronStyleNow } from "../agents/current-time.js"; +export { DEFAULT_PI_COMPACTION_RESERVE_TOKENS_FLOOR } from "../agents/pi-settings.js"; +export { resolveDefaultAgentId, resolveSessionAgentId } from "../agents/agent-scope.js"; +export { resolveMemorySearchConfig } from "../agents/memory-search.js"; +export { jsonResult, readNumberParam, readStringParam } from "../agents/tools/common.js"; +export { SILENT_REPLY_TOKEN } from "../auto-reply/tokens.js"; +export { parseNonNegativeByteSize } from "../config/byte-size.js"; +export { loadConfig } from "../config/config.js"; +export { resolveStateDir } from "../config/paths.js"; +export { resolveSessionTranscriptsDirForAgent } from "../config/sessions/paths.js"; +export { emptyPluginConfigSchema } from "../plugins/config-schema.js"; +export { parseAgentSessionKey } from "../routing/session-key.js"; +export type { OpenClawConfig } from "../config/config.js"; +export type { MemoryCitationsMode } from "../config/types.memory.js"; +export type { + MemoryFlushPlan, + MemoryFlushPlanResolver, + MemoryPluginRuntime, + MemoryPromptSectionBuilder, +} from "../plugins/memory-state.js"; +export type { OpenClawPluginApi } from "../plugins/types.js"; diff --git a/src/memory-host-sdk/runtime-files.ts b/src/memory-host-sdk/runtime-files.ts new file mode 100644 index 00000000000..dd50c31eb46 --- /dev/null +++ b/src/memory-host-sdk/runtime-files.ts @@ -0,0 +1,6 @@ +// Focused runtime contract for memory file/backend access. 
+ +export { listMemoryFiles, normalizeExtraMemoryPaths } from "./host/internal.js"; +export { readAgentMemoryFile } from "./host/read-file.js"; +export { resolveMemoryBackendConfig } from "./host/backend-config.js"; +export type { MemorySearchResult } from "./host/types.js"; diff --git a/src/memory-host-sdk/runtime.ts b/src/memory-host-sdk/runtime.ts new file mode 100644 index 00000000000..6e152ea0dcb --- /dev/null +++ b/src/memory-host-sdk/runtime.ts @@ -0,0 +1,6 @@ +// Aggregate workspace contract for memory runtime/helper seams. +// Keep focused subpaths preferred for new code. + +export * from "./runtime-core.js"; +export * from "./runtime-cli.js"; +export * from "./runtime-files.js"; diff --git a/src/memory-host-sdk/secret.ts b/src/memory-host-sdk/secret.ts index f293730b357..b2b6b94ab47 100644 --- a/src/memory-host-sdk/secret.ts +++ b/src/memory-host-sdk/secret.ts @@ -1 +1,4 @@ -export * from "../../packages/memory-host-sdk/src/secret.js"; +export { + hasConfiguredMemorySecretInput, + resolveMemorySecretInputString, +} from "./host/secret-input.js"; diff --git a/src/memory-host-sdk/status.ts b/src/memory-host-sdk/status.ts index 704b37737b4..dc718abd96b 100644 --- a/src/memory-host-sdk/status.ts +++ b/src/memory-host-sdk/status.ts @@ -1 +1,6 @@ -export * from "../../packages/memory-host-sdk/src/status.js"; +export { + resolveMemoryCacheSummary, + resolveMemoryFtsState, + resolveMemoryVectorState, + type Tone, +} from "./host/status-format.js"; diff --git a/src/plugin-sdk/memory-core-host-engine-embeddings.ts b/src/plugin-sdk/memory-core-host-engine-embeddings.ts index a5b7e2b91a9..06a545bd682 100644 --- a/src/plugin-sdk/memory-core-host-engine-embeddings.ts +++ b/src/plugin-sdk/memory-core-host-engine-embeddings.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/engine-embeddings.js"; +export * from "../memory-host-sdk/engine-embeddings.js"; diff --git a/src/plugin-sdk/memory-core-host-engine-foundation.ts 
b/src/plugin-sdk/memory-core-host-engine-foundation.ts index 84cd0cbdc5e..c71fb39406c 100644 --- a/src/plugin-sdk/memory-core-host-engine-foundation.ts +++ b/src/plugin-sdk/memory-core-host-engine-foundation.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/engine-foundation.js"; +export * from "../memory-host-sdk/engine-foundation.js"; diff --git a/src/plugin-sdk/memory-core-host-engine-qmd.ts b/src/plugin-sdk/memory-core-host-engine-qmd.ts index 21a0be44873..7356e906e68 100644 --- a/src/plugin-sdk/memory-core-host-engine-qmd.ts +++ b/src/plugin-sdk/memory-core-host-engine-qmd.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/engine-qmd.js"; +export * from "../memory-host-sdk/engine-qmd.js"; diff --git a/src/plugin-sdk/memory-core-host-engine-storage.ts b/src/plugin-sdk/memory-core-host-engine-storage.ts index ee3f3a4e410..3d0836052f7 100644 --- a/src/plugin-sdk/memory-core-host-engine-storage.ts +++ b/src/plugin-sdk/memory-core-host-engine-storage.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/engine-storage.js"; +export * from "../memory-host-sdk/engine-storage.js"; diff --git a/src/plugin-sdk/memory-core-host-multimodal.ts b/src/plugin-sdk/memory-core-host-multimodal.ts index 36b50cbbf4b..448c4ae530e 100644 --- a/src/plugin-sdk/memory-core-host-multimodal.ts +++ b/src/plugin-sdk/memory-core-host-multimodal.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/multimodal.js"; +export * from "../memory-host-sdk/multimodal.js"; diff --git a/src/plugin-sdk/memory-core-host-query.ts b/src/plugin-sdk/memory-core-host-query.ts index 2a2ef6bbed4..48fd920ea32 100644 --- a/src/plugin-sdk/memory-core-host-query.ts +++ b/src/plugin-sdk/memory-core-host-query.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/query.js"; +export * from "../memory-host-sdk/query.js"; diff --git a/src/plugin-sdk/memory-core-host-runtime-cli.ts b/src/plugin-sdk/memory-core-host-runtime-cli.ts index 
69ea0ceaad6..d0267ce2676 100644 --- a/src/plugin-sdk/memory-core-host-runtime-cli.ts +++ b/src/plugin-sdk/memory-core-host-runtime-cli.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/runtime-cli.js"; +export * from "../memory-host-sdk/runtime-cli.js"; diff --git a/src/plugin-sdk/memory-core-host-runtime-core.ts b/src/plugin-sdk/memory-core-host-runtime-core.ts index f45bde01cb0..4cfbc2cee7c 100644 --- a/src/plugin-sdk/memory-core-host-runtime-core.ts +++ b/src/plugin-sdk/memory-core-host-runtime-core.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/runtime-core.js"; +export * from "../memory-host-sdk/runtime-core.js"; diff --git a/src/plugin-sdk/memory-core-host-runtime-files.ts b/src/plugin-sdk/memory-core-host-runtime-files.ts index f4daa70a22d..78929e9cb5d 100644 --- a/src/plugin-sdk/memory-core-host-runtime-files.ts +++ b/src/plugin-sdk/memory-core-host-runtime-files.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/runtime-files.js"; +export * from "../memory-host-sdk/runtime-files.js"; diff --git a/src/plugin-sdk/memory-core-host-secret.ts b/src/plugin-sdk/memory-core-host-secret.ts index f293730b357..8d5257e823f 100644 --- a/src/plugin-sdk/memory-core-host-secret.ts +++ b/src/plugin-sdk/memory-core-host-secret.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/secret.js"; +export * from "../memory-host-sdk/secret.js"; diff --git a/src/plugin-sdk/memory-core-host-status.ts b/src/plugin-sdk/memory-core-host-status.ts index 704b37737b4..fe626a778fa 100644 --- a/src/plugin-sdk/memory-core-host-status.ts +++ b/src/plugin-sdk/memory-core-host-status.ts @@ -1 +1 @@ -export * from "../../packages/memory-host-sdk/src/status.js"; +export * from "../memory-host-sdk/status.js"; diff --git a/src/plugins/contracts/plugin-sdk-subpaths.test.ts b/src/plugins/contracts/plugin-sdk-subpaths.test.ts index 3f6d84e63ed..67736517675 100644 --- a/src/plugins/contracts/plugin-sdk-subpaths.test.ts +++ 
b/src/plugins/contracts/plugin-sdk-subpaths.test.ts @@ -363,15 +363,15 @@ describe("plugin-sdk subpath exports", () => { ]); expectSourceContains( "memory-core-host-runtime-core", - 'export * from "../../packages/memory-host-sdk/src/runtime-core.js";', + 'export * from "../memory-host-sdk/runtime-core.js";', ); expectSourceContains( "memory-core-host-runtime-cli", - 'export * from "../../packages/memory-host-sdk/src/runtime-cli.js";', + 'export * from "../memory-host-sdk/runtime-cli.js";', ); expectSourceContains( "memory-core-host-runtime-files", - 'export * from "../../packages/memory-host-sdk/src/runtime-files.js";', + 'export * from "../memory-host-sdk/runtime-files.js";', ); }); diff --git a/src/plugins/memory-embedding-providers.ts b/src/plugins/memory-embedding-providers.ts index 167773ae0df..ae6ea2dd1b0 100644 --- a/src/plugins/memory-embedding-providers.ts +++ b/src/plugins/memory-embedding-providers.ts @@ -1,6 +1,6 @@ -import type { EmbeddingInput } from "../../packages/memory-host-sdk/src/host/embedding-inputs.js"; import type { OpenClawConfig } from "../config/config.js"; import type { SecretInput } from "../config/types.secrets.js"; +import type { EmbeddingInput } from "../memory-host-sdk/engine-embeddings.js"; export type MemoryEmbeddingBatchChunk = { text: string; diff --git a/test/helpers/memory-tool-manager-mock.ts b/test/helpers/memory-tool-manager-mock.ts index cd32a7e06a2..12a5e8aeebb 100644 --- a/test/helpers/memory-tool-manager-mock.ts +++ b/test/helpers/memory-tool-manager-mock.ts @@ -48,7 +48,7 @@ vi.mock(memoryIndexModuleId, () => ({ getMemorySearchManager: getMemorySearchManagerMock, })); -vi.mock("../../packages/memory-host-sdk/src/host/read-file.js", () => ({ +vi.mock("../../src/memory-host-sdk/host/read-file.js", () => ({ readAgentMemoryFile: readAgentMemoryFileMock, }));