feat(github-copilot): add embedding provider for memory search (#61718)

Merged via squash.

Prepared head SHA: 05a78ce7f2
Co-authored-by: feiskyer <676637+feiskyer@users.noreply.github.com>
Co-authored-by: vincentkoc <25068+vincentkoc@users.noreply.github.com>
Reviewed-by: @vincentkoc
This commit is contained in:
Pengfei Ni
2026-04-15 17:39:28 +08:00
committed by GitHub
parent 7821fae05d
commit 88d3620a85
14 changed files with 1094 additions and 69 deletions

View File

@@ -30,6 +30,10 @@ export {
createMistralEmbeddingProvider,
DEFAULT_MISTRAL_EMBEDDING_MODEL,
} from "./host/embeddings-mistral.js";
export {
createGitHubCopilotEmbeddingProvider,
type GitHubCopilotEmbeddingClient,
} from "./host/embeddings-github-copilot.js";
export {
createOllamaEmbeddingProvider,
DEFAULT_OLLAMA_EMBEDDING_MODEL,

View File

@@ -0,0 +1,178 @@
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const resolveCopilotApiTokenMock = vi.hoisted(() => vi.fn());
const fetchWithSsrFGuardMock = vi.hoisted(() => vi.fn());
vi.mock("../../agents/github-copilot-token.js", () => ({
DEFAULT_COPILOT_API_BASE_URL: "https://api.githubcopilot.test",
resolveCopilotApiToken: resolveCopilotApiTokenMock,
}));
vi.mock("../../infra/net/fetch-guard.js", () => ({
fetchWithSsrFGuard: fetchWithSsrFGuardMock,
}));
import { createGitHubCopilotEmbeddingProvider } from "./embeddings-github-copilot.js";
function mockFetchResponse(spec: { ok: boolean; status?: number; json?: unknown; text?: string }) {
fetchWithSsrFGuardMock.mockImplementationOnce(async () => ({
response: {
ok: spec.ok,
status: spec.status ?? (spec.ok ? 200 : 500),
json: async () => spec.json,
text: async () => spec.text ?? "",
},
release: vi.fn(async () => {}),
}));
}
describe("createGitHubCopilotEmbeddingProvider", () => {
beforeEach(() => {
resolveCopilotApiTokenMock.mockResolvedValue({
token: "copilot-token-a",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
});
});
afterEach(() => {
vi.restoreAllMocks();
resolveCopilotApiTokenMock.mockReset();
fetchWithSsrFGuardMock.mockReset();
});
it("normalizes embeddings returned for queries", async () => {
mockFetchResponse({
ok: true,
json: {
data: [{ index: 0, embedding: [3, 4] }],
},
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await expect(provider.embedQuery("hello")).resolves.toEqual([0.6, 0.8]);
expect(fetchWithSsrFGuardMock).toHaveBeenCalledWith(
expect.objectContaining({
url: "https://api.githubcopilot.test/embeddings",
}),
);
});
it("preserves input order by explicit response index", async () => {
mockFetchResponse({
ok: true,
json: {
data: [
{ index: 1, embedding: [0, 2] },
{ index: 0, embedding: [1, 0] },
],
},
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await expect(provider.embedBatch(["first", "second"])).resolves.toEqual([
[1, 0],
[0, 1],
]);
});
it("uses a fresh Copilot token for later requests", async () => {
resolveCopilotApiTokenMock
.mockResolvedValueOnce({
token: "copilot-token-create",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
})
.mockResolvedValueOnce({
token: "copilot-token-first",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
})
.mockResolvedValueOnce({
token: "copilot-token-second",
expiresAt: Date.now() + 3_600_000,
source: "test",
baseUrl: "https://api.githubcopilot.test",
});
mockFetchResponse({
ok: true,
json: { data: [{ index: 0, embedding: [1, 0] }] },
});
mockFetchResponse({
ok: true,
json: { data: [{ index: 0, embedding: [0, 1] }] },
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await provider.embedQuery("first");
await provider.embedQuery("second");
const firstHeaders = fetchWithSsrFGuardMock.mock.calls[0]?.[0]?.init?.headers as Record<
string,
string
>;
const secondHeaders = fetchWithSsrFGuardMock.mock.calls[1]?.[0]?.init?.headers as Record<
string,
string
>;
expect(firstHeaders.Authorization).toBe("Bearer copilot-token-first");
expect(secondHeaders.Authorization).toBe("Bearer copilot-token-second");
});
it("honors custom baseUrl and header overrides", async () => {
mockFetchResponse({
ok: true,
json: { data: [{ index: 0, embedding: [1, 0] }] },
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
baseUrl: "https://proxy.example/v1",
headers: { "X-Proxy-Token": "proxy" },
});
await provider.embedQuery("hello");
const call = fetchWithSsrFGuardMock.mock.calls[0]?.[0] as {
init: { headers: Record<string, string> };
url: string;
};
expect(call.url).toBe("https://proxy.example/v1/embeddings");
expect(call.init.headers["X-Proxy-Token"]).toBe("proxy");
expect(call.init.headers.Authorization).toBe("Bearer copilot-token-a");
});
it("fails fast on sparse or malformed embedding payloads", async () => {
mockFetchResponse({
ok: true,
json: {
data: [{ index: 1, embedding: [1, 0] }],
},
});
const { provider } = await createGitHubCopilotEmbeddingProvider({
githubToken: "gh_test",
model: "text-embedding-3-small",
});
await expect(provider.embedBatch(["first", "second"])).rejects.toThrow(
"GitHub Copilot embeddings response missing vectors for some inputs",
);
});
});

View File

@@ -0,0 +1,151 @@
import {
DEFAULT_COPILOT_API_BASE_URL,
resolveCopilotApiToken,
} from "../../agents/github-copilot-token.js";
import { sanitizeAndNormalizeEmbedding } from "./embedding-vectors.js";
import type { EmbeddingProvider } from "./embeddings.types.js";
import { buildRemoteBaseUrlPolicy, withRemoteHttpResponse } from "./remote-http.js";
export type GitHubCopilotEmbeddingClient = {
githubToken: string;
model: string;
baseUrl?: string;
headers?: Record<string, string>;
env?: NodeJS.ProcessEnv;
fetchImpl?: typeof fetch;
};
const COPILOT_EMBEDDING_PROVIDER_ID = "github-copilot";
const COPILOT_HEADERS_STATIC: Record<string, string> = {
"Content-Type": "application/json",
"Editor-Version": "vscode/1.96.2",
"User-Agent": "GitHubCopilotChat/0.26.7",
};
function resolveConfiguredBaseUrl(
configuredBaseUrl: string | undefined,
tokenBaseUrl: string | undefined,
): string {
const trimmed = configuredBaseUrl?.trim();
if (trimmed) {
return trimmed;
}
return tokenBaseUrl || DEFAULT_COPILOT_API_BASE_URL;
}
async function resolveGitHubCopilotEmbeddingSession(client: GitHubCopilotEmbeddingClient): Promise<{
baseUrl: string;
headers: Record<string, string>;
}> {
const token = await resolveCopilotApiToken({
githubToken: client.githubToken,
env: client.env,
fetchImpl: client.fetchImpl,
});
const baseUrl = resolveConfiguredBaseUrl(client.baseUrl, token.baseUrl);
return {
baseUrl,
headers: {
...COPILOT_HEADERS_STATIC,
...client.headers,
Authorization: `Bearer ${token.token}`,
},
};
}
function parseGitHubCopilotEmbeddingPayload(payload: unknown, expectedCount: number): number[][] {
if (!payload || typeof payload !== "object") {
throw new Error("GitHub Copilot embeddings response missing data[]");
}
const data = (payload as { data?: unknown }).data;
if (!Array.isArray(data)) {
throw new Error("GitHub Copilot embeddings response missing data[]");
}
const vectors = Array.from<number[] | undefined>({ length: expectedCount });
for (const entry of data) {
if (!entry || typeof entry !== "object") {
throw new Error("GitHub Copilot embeddings response contains an invalid entry");
}
const indexValue = (entry as { index?: unknown }).index;
const embedding = (entry as { embedding?: unknown }).embedding;
const index = typeof indexValue === "number" ? indexValue : Number.NaN;
if (!Number.isInteger(index)) {
throw new Error("GitHub Copilot embeddings response contains an invalid index");
}
if (index < 0 || index >= expectedCount) {
throw new Error("GitHub Copilot embeddings response contains an out-of-range index");
}
if (vectors[index] !== undefined) {
throw new Error("GitHub Copilot embeddings response contains duplicate indexes");
}
if (!Array.isArray(embedding) || !embedding.every((value) => typeof value === "number")) {
throw new Error("GitHub Copilot embeddings response contains an invalid embedding");
}
vectors[index] = sanitizeAndNormalizeEmbedding(embedding);
}
for (let index = 0; index < expectedCount; index += 1) {
if (vectors[index] === undefined) {
throw new Error("GitHub Copilot embeddings response missing vectors for some inputs");
}
}
return vectors as number[][];
}
export async function createGitHubCopilotEmbeddingProvider(
client: GitHubCopilotEmbeddingClient,
): Promise<{ provider: EmbeddingProvider; client: GitHubCopilotEmbeddingClient }> {
const initialSession = await resolveGitHubCopilotEmbeddingSession(client);
const embed = async (input: string[]): Promise<number[][]> => {
if (input.length === 0) {
return [];
}
const session = await resolveGitHubCopilotEmbeddingSession(client);
const url = `${session.baseUrl.replace(/\/$/, "")}/embeddings`;
return await withRemoteHttpResponse({
url,
fetchImpl: client.fetchImpl,
ssrfPolicy: buildRemoteBaseUrlPolicy(session.baseUrl),
init: {
method: "POST",
headers: session.headers,
body: JSON.stringify({ model: client.model, input }),
},
onResponse: async (response) => {
if (!response.ok) {
throw new Error(
`GitHub Copilot embeddings HTTP ${response.status}: ${await response.text()}`,
);
}
let payload: unknown;
try {
payload = await response.json();
} catch {
throw new Error("GitHub Copilot embeddings returned invalid JSON");
}
return parseGitHubCopilotEmbeddingPayload(payload, input.length);
},
});
};
return {
provider: {
id: COPILOT_EMBEDDING_PROVIDER_ID,
model: client.model,
embedQuery: async (text) => {
const [vector] = await embed([text]);
return vector ?? [];
},
embedBatch: embed,
},
client: {
...client,
baseUrl: initialSession.baseUrl,
},
};
}