mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-06 15:30:47 +00:00
feat(github-copilot): add embedding provider for memory search (#61718)
Merged via squash.
Prepared head SHA: 05a78ce7f2
Co-authored-by: feiskyer <676637+feiskyer@users.noreply.github.com>
Co-authored-by: vincentkoc <25068+vincentkoc@users.noreply.github.com>
Reviewed-by: @vincentkoc
This commit is contained in:
@@ -30,6 +30,10 @@ export {
|
||||
createMistralEmbeddingProvider,
|
||||
DEFAULT_MISTRAL_EMBEDDING_MODEL,
|
||||
} from "./host/embeddings-mistral.js";
|
||||
export {
|
||||
createGitHubCopilotEmbeddingProvider,
|
||||
type GitHubCopilotEmbeddingClient,
|
||||
} from "./host/embeddings-github-copilot.js";
|
||||
export {
|
||||
createOllamaEmbeddingProvider,
|
||||
DEFAULT_OLLAMA_EMBEDDING_MODEL,
|
||||
|
||||
178
src/memory-host-sdk/host/embeddings-github-copilot.test.ts
Normal file
178
src/memory-host-sdk/host/embeddings-github-copilot.test.ts
Normal file
@@ -0,0 +1,178 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
|
||||
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("../../agents/github-copilot-token.js", () => ({
|
||||
DEFAULT_COPILOT_API_BASE_URL: "https://api.githubcopilot.test",
|
||||
resolveCopilotApiToken: resolveCopilotApiTokenMock,
|
||||
}));
|
||||
|
||||
vi.mock("../../infra/net/fetch-guard.js", () => ({
|
||||
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
|
||||
}));
|
||||
|
||||
import { createGitHubCopilotEmbeddingProvider } from "./embeddings-github-copilot.js";
|
||||
|
||||
function mockFetchResponse(spec: { ok: boolean; status?: number; json?: unknown; text?: string }) {
|
||||
fetchWithSsrFGuardMock.mockImplementationOnce(async () => ({
|
||||
response: {
|
||||
ok: spec.ok,
|
||||
status: spec.status ?? (spec.ok ? 200 : 500),
|
||||
json: async () => spec.json,
|
||||
text: async () => spec.text ?? "",
|
||||
},
|
||||
release: vi.fn(async () => {}),
|
||||
}));
|
||||
}
|
||||
|
||||
describe("createGitHubCopilotEmbeddingProvider", () => {
|
||||
beforeEach(() => {
|
||||
resolveCopilotApiTokenMock.mockResolvedValue({
|
||||
token: "copilot-token-a",
|
||||
expiresAt: Date.now() + 3_600_000,
|
||||
source: "test",
|
||||
baseUrl: "https://api.githubcopilot.test",
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
resolveCopilotApiTokenMock.mockReset();
|
||||
fetchWithSsrFGuardMock.mockReset();
|
||||
});
|
||||
|
||||
it("normalizes embeddings returned for queries", async () => {
|
||||
mockFetchResponse({
|
||||
ok: true,
|
||||
json: {
|
||||
data: [{ index: 0, embedding: [3, 4] }],
|
||||
},
|
||||
});
|
||||
|
||||
const { provider } = await createGitHubCopilotEmbeddingProvider({
|
||||
githubToken: "gh_test",
|
||||
model: "text-embedding-3-small",
|
||||
});
|
||||
|
||||
await expect(provider.embedQuery("hello")).resolves.toEqual([0.6, 0.8]);
|
||||
expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
url: "https://api.githubcopilot.test/embeddings",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves input order by explicit response index", async () => {
|
||||
mockFetchResponse({
|
||||
ok: true,
|
||||
json: {
|
||||
data: [
|
||||
{ index: 1, embedding: [0, 2] },
|
||||
{ index: 0, embedding: [1, 0] },
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const { provider } = await createGitHubCopilotEmbeddingProvider({
|
||||
githubToken: "gh_test",
|
||||
model: "text-embedding-3-small",
|
||||
});
|
||||
|
||||
await expect(provider.embedBatch(["first", "second"])).resolves.toEqual([
|
||||
[1, 0],
|
||||
[0, 1],
|
||||
]);
|
||||
});
|
||||
|
||||
it("uses a fresh Copilot token for later requests", async () => {
|
||||
resolveCopilotApiTokenMock
|
||||
.mockResolvedValueOnce({
|
||||
token: "copilot-token-create",
|
||||
expiresAt: Date.now() + 3_600_000,
|
||||
source: "test",
|
||||
baseUrl: "https://api.githubcopilot.test",
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
token: "copilot-token-first",
|
||||
expiresAt: Date.now() + 3_600_000,
|
||||
source: "test",
|
||||
baseUrl: "https://api.githubcopilot.test",
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
token: "copilot-token-second",
|
||||
expiresAt: Date.now() + 3_600_000,
|
||||
source: "test",
|
||||
baseUrl: "https://api.githubcopilot.test",
|
||||
});
|
||||
mockFetchResponse({
|
||||
ok: true,
|
||||
json: { data: [{ index: 0, embedding: [1, 0] }] },
|
||||
});
|
||||
mockFetchResponse({
|
||||
ok: true,
|
||||
json: { data: [{ index: 0, embedding: [0, 1] }] },
|
||||
});
|
||||
|
||||
const { provider } = await createGitHubCopilotEmbeddingProvider({
|
||||
githubToken: "gh_test",
|
||||
model: "text-embedding-3-small",
|
||||
});
|
||||
|
||||
await provider.embedQuery("first");
|
||||
await provider.embedQuery("second");
|
||||
|
||||
const firstHeaders = fetchWithSsrFGuardMock.mock.calls[0]?.[0]?.init?.headers as Record<
|
||||
string,
|
||||
string
|
||||
>;
|
||||
const secondHeaders = fetchWithSsrFGuardMock.mock.calls[1]?.[0]?.init?.headers as Record<
|
||||
string,
|
||||
string
|
||||
>;
|
||||
expect(firstHeaders.Authorization).toBe("Bearer copilot-token-first");
|
||||
expect(secondHeaders.Authorization).toBe("Bearer copilot-token-second");
|
||||
});
|
||||
|
||||
it("honors custom baseUrl and header overrides", async () => {
|
||||
mockFetchResponse({
|
||||
ok: true,
|
||||
json: { data: [{ index: 0, embedding: [1, 0] }] },
|
||||
});
|
||||
|
||||
const { provider } = await createGitHubCopilotEmbeddingProvider({
|
||||
githubToken: "gh_test",
|
||||
model: "text-embedding-3-small",
|
||||
baseUrl: "https://proxy.example/v1",
|
||||
headers: { "X-Proxy-Token": "proxy" },
|
||||
});
|
||||
|
||||
await provider.embedQuery("hello");
|
||||
|
||||
const call = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
|
||||
init: { headers: Record<string, string> };
|
||||
url: string;
|
||||
};
|
||||
expect(call.url).toBe("https://proxy.example/v1/embeddings");
|
||||
expect(call.init.headers["X-Proxy-Token"]).toBe("proxy");
|
||||
expect(call.init.headers.Authorization).toBe("Bearer copilot-token-a");
|
||||
});
|
||||
|
||||
it("fails fast on sparse or malformed embedding payloads", async () => {
|
||||
mockFetchResponse({
|
||||
ok: true,
|
||||
json: {
|
||||
data: [{ index: 1, embedding: [1, 0] }],
|
||||
},
|
||||
});
|
||||
|
||||
const { provider } = await createGitHubCopilotEmbeddingProvider({
|
||||
githubToken: "gh_test",
|
||||
model: "text-embedding-3-small",
|
||||
});
|
||||
|
||||
await expect(provider.embedBatch(["first", "second"])).rejects.toThrow(
|
||||
"GitHub Copilot embeddings response missing vectors for some inputs",
|
||||
);
|
||||
});
|
||||
});
|
||||
151
src/memory-host-sdk/host/embeddings-github-copilot.ts
Normal file
151
src/memory-host-sdk/host/embeddings-github-copilot.ts
Normal file
@@ -0,0 +1,151 @@
|
||||
import {
|
||||
DEFAULT_COPILOT_API_BASE_URL,
|
||||
resolveCopilotApiToken,
|
||||
} from "../../agents/github-copilot-token.js";
|
||||
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
|
||||
import type { EmbeddingProvider } from "./embeddings.types.js";
|
||||
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
|
||||
|
||||
export type GitHubCopilotEmbeddingClient = {
|
||||
githubToken: string;
|
||||
model: string;
|
||||
baseUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
env?: NodeJS.ProcessEnv;
|
||||
fetchImpl?: typeof fetch;
|
||||
};
|
||||
|
||||
const COPILOT_EMBEDDING_PROVIDER_ID = "github-copilot";
|
||||
|
||||
const COPILOT_HEADERS_STATIC: Record<string, string> = {
|
||||
"Content-Type": "application/json",
|
||||
"Editor-Version": "vscode/1.96.2",
|
||||
"User-Agent": "GitHubCopilotChat/0.26.7",
|
||||
};
|
||||
|
||||
function resolveConfiguredBaseUrl(
|
||||
configuredBaseUrl: string | undefined,
|
||||
tokenBaseUrl: string | undefined,
|
||||
): string {
|
||||
const trimmed = configuredBaseUrl?.trim();
|
||||
if (trimmed) {
|
||||
return trimmed;
|
||||
}
|
||||
return tokenBaseUrl || DEFAULT_COPILOT_API_BASE_URL;
|
||||
}
|
||||
|
||||
async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
|
||||
baseUrl: string;
|
||||
headers: Record<string, string>;
|
||||
}> {
|
||||
const token = await resolveCopilotApiToken({
|
||||
githubToken: client.githubToken,
|
||||
env: client.env,
|
||||
fetchImpl: client.fetchImpl,
|
||||
});
|
||||
const baseUrl = resolveConfiguredBaseUrl(client.baseUrl, token.baseUrl);
|
||||
return {
|
||||
baseUrl,
|
||||
headers: {
|
||||
...COPILOT_HEADERS_STATIC,
|
||||
...client.headers,
|
||||
Authorization: `Bearer ${token.token}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
|
||||
if (!payload || typeof payload !== "object") {
|
||||
throw new Error("GitHub Copilot embeddings response missing data[]");
|
||||
}
|
||||
const data = (payload as { data?: unknown }).data;
|
||||
if (!Array.isArray(data)) {
|
||||
throw new Error("GitHub Copilot embeddings response missing data[]");
|
||||
}
|
||||
|
||||
const vectors = Array.from<number[] | undefined>({ length: expectedCount });
|
||||
for (const entry of data) {
|
||||
if (!entry || typeof entry !== "object") {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid entry");
|
||||
}
|
||||
const indexValue = (entry as { index?: unknown }).index;
|
||||
const embedding = (entry as { embedding?: unknown }).embedding;
|
||||
const index = typeof indexValue === "number" ? indexValue : Number.NaN;
|
||||
if (!Number.isInteger(index)) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid index");
|
||||
}
|
||||
if (index < 0 || index >= expectedCount) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
|
||||
}
|
||||
if (vectors[index] !== undefined) {
|
||||
throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
|
||||
}
|
||||
if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
|
||||
throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
|
||||
}
|
||||
vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
|
||||
}
|
||||
|
||||
for (let index = 0; index < expectedCount; index += 1) {
|
||||
if (vectors[index] === undefined) {
|
||||
throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
|
||||
}
|
||||
}
|
||||
return vectors as number[][];
|
||||
}
|
||||
|
||||
export async function createGitHubCopilotEmbeddingProvider(
|
||||
client: GitHubCopilotEmbeddingClient,
|
||||
): Promise<{ provider: EmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
|
||||
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
|
||||
const embed = async (input: string[]): Promise<number[][]> => {
|
||||
if (input.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const session = await resolveGitHubCopilotEmbeddingSession(client);
|
||||
const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
|
||||
return await withRemoteHttpResponse({
|
||||
url,
|
||||
fetchImpl: client.fetchImpl,
|
||||
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
|
||||
init: {
|
||||
method: "POST",
|
||||
headers: session.headers,
|
||||
body: JSON.stringify({ model: client.model, input }),
|
||||
},
|
||||
onResponse: async (response) => {
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
|
||||
);
|
||||
}
|
||||
|
||||
let payload: unknown;
|
||||
try {
|
||||
payload = await response.json();
|
||||
} catch {
|
||||
throw new Error("GitHub Copilot embeddings returned invalid JSON");
|
||||
}
|
||||
return parseGitHubCopilotEmbeddingPayload(payload, input.length);
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: COPILOT_EMBEDDING_PROVIDER_ID,
|
||||
model: client.model,
|
||||
embedQuery: async (text) => {
|
||||
const [vector] = await embed([text]);
|
||||
return vector ?? [];
|
||||
},
|
||||
embedBatch: embed,
|
||||
},
|
||||
client: {
|
||||
...client,
|
||||
baseUrl: initialSession.baseUrl,
|
||||
},
|
||||
};
|
||||
}
|
||||
Reference in New Issue
Block a user