diff --git a/src/browser/chrome-mcp.test.ts b/src/browser/chrome-mcp.test.ts index 5f97c533aec..03204cf3b87 100644 --- a/src/browser/chrome-mcp.test.ts +++ b/src/browser/chrome-mcp.test.ts @@ -263,6 +263,33 @@ describe("chrome MCP page parsing", () => { expect(tabs).toHaveLength(2); }); + it("creates a fresh session when userDataDir changes for the same profile", async () => { + const createdSessions: ChromeMcpSession[] = []; + const closeMocks: Array> = []; + const factoryCalls: Array<{ profileName: string; userDataDir?: string }> = []; + const factory: ChromeMcpSessionFactory = async (profileName, userDataDir) => { + factoryCalls.push({ profileName, userDataDir }); + const session = createFakeSession(); + const closeMock = vi.fn().mockResolvedValue(undefined); + session.client.close = closeMock as typeof session.client.close; + createdSessions.push(session); + closeMocks.push(closeMock); + return session; + }; + setChromeMcpSessionFactoryForTest(factory); + + await listChromeMcpTabs("chrome-live", "/tmp/brave-a"); + await listChromeMcpTabs("chrome-live", "/tmp/brave-b"); + + expect(factoryCalls).toEqual([ + { profileName: "chrome-live", userDataDir: "/tmp/brave-a" }, + { profileName: "chrome-live", userDataDir: "/tmp/brave-b" }, + ]); + expect(createdSessions).toHaveLength(2); + expect(closeMocks[0]).toHaveBeenCalledTimes(1); + expect(closeMocks[1]).not.toHaveBeenCalled(); + }); + it("clears failed pending sessions so the next call can retry", async () => { let factoryCalls = 0; const factory: ChromeMcpSessionFactory = async () => { diff --git a/src/browser/chrome-mcp.ts b/src/browser/chrome-mcp.ts index 0fba963d6a0..bc724d2eaea 100644 --- a/src/browser/chrome-mcp.ts +++ b/src/browser/chrome-mcp.ts @@ -176,6 +176,43 @@ function normalizeChromeMcpUserDataDir(userDataDir?: string): string | undefined return trimmed ? trimmed : undefined; } +function buildChromeMcpSessionCacheKey(profileName: string, userDataDir?: string): string { + return JSON.stringify([profileName, normalizeChromeMcpUserDataDir(userDataDir) ?? ""]); +} + +function cacheKeyMatchesProfileName(cacheKey: string, profileName: string): boolean { + try { + const parsed = JSON.parse(cacheKey); + return Array.isArray(parsed) && parsed[0] === profileName; + } catch { + return false; + } +} + +async function closeChromeMcpSessionsForProfile( + profileName: string, + keepKey?: string, +): Promise { + let closed = false; + + for (const key of Array.from(pendingSessions.keys())) { + if (key !== keepKey && cacheKeyMatchesProfileName(key, profileName)) { + pendingSessions.delete(key); + closed = true; + } + } + + for (const [key, session] of Array.from(sessions.entries())) { + if (key !== keepKey && cacheKeyMatchesProfileName(key, profileName)) { + sessions.delete(key); + closed = true; + await session.client.close().catch(() => {}); + } + } + + return closed; +} + export function buildChromeMcpArgs(userDataDir?: string): string[] { const normalizedUserDataDir = normalizeChromeMcpUserDataDir(userDataDir); return normalizedUserDataDir @@ -228,26 +265,33 @@ async function createRealSession( } async function getSession(profileName: string, userDataDir?: string): Promise { - let session = sessions.get(profileName); + const cacheKey = buildChromeMcpSessionCacheKey(profileName, userDataDir); + await closeChromeMcpSessionsForProfile(profileName, cacheKey); + + let session = sessions.get(cacheKey); if (session && session.transport.pid === null) { - sessions.delete(profileName); + sessions.delete(cacheKey); session = undefined; } if (!session) { - let pending = pendingSessions.get(profileName); + let pending = pendingSessions.get(cacheKey); if (!pending) { pending = (async () => { const created = await (sessionFactory ?? createRealSession)(profileName, userDataDir); - sessions.set(profileName, created); + if (pendingSessions.get(cacheKey) === pending) { + sessions.set(cacheKey, created); + } else { + await created.client.close().catch(() => {}); + } return created; })(); - pendingSessions.set(profileName, pending); + pendingSessions.set(cacheKey, pending); } try { session = await pending; } finally { - if (pendingSessions.get(profileName) === pending) { - pendingSessions.delete(profileName); + if (pendingSessions.get(cacheKey) === pending) { + pendingSessions.delete(cacheKey); } } } @@ -255,9 +299,9 @@ async function getSession(profileName: string, userDataDir?: string): Promise = {}, ): Promise { + const cacheKey = buildChromeMcpSessionCacheKey(profileName, userDataDir); const session = await getSession(profileName, userDataDir); let result: ChromeMcpToolResult; try { @@ -278,7 +323,7 @@ async function callTool( })) as ChromeMcpToolResult; } catch (err) { // Transport/connection error — tear down session so it reconnects on next call - sessions.delete(profileName); + sessions.delete(cacheKey); await session.client.close().catch(() => {}); throw err; } @@ -321,22 +366,20 @@ export async function ensureChromeMcpAvailable( } export function getChromeMcpPid(profileName: string): number | null { - return sessions.get(profileName)?.transport.pid ?? null; + for (const [key, session] of sessions.entries()) { + if (cacheKeyMatchesProfileName(key, profileName)) { + return session.transport.pid ?? null; + } + } + return null; } export async function closeChromeMcpSession(profileName: string): Promise { - pendingSessions.delete(profileName); - const session = sessions.get(profileName); - if (!session) { - return false; - } - sessions.delete(profileName); - await session.client.close().catch(() => {}); - return true; + return await closeChromeMcpSessionsForProfile(profileName); } export async function stopAllChromeMcpSessions(): Promise { - const names = [...sessions.keys()]; + const names = [...new Set([...sessions.keys()].map((key) => JSON.parse(key)[0] as string))]; for (const name of names) { await closeChromeMcpSession(name).catch(() => {}); }