fix(lmstudio): promote bracketed tool calls

This commit is contained in:
Peter Steinberger
2026-04-27 08:38:42 +01:00
parent d5e6abcb3d
commit da55212c6e
8 changed files with 740 additions and 4 deletions

View File

@@ -0,0 +1,167 @@
import { randomUUID } from "node:crypto";
// A tool invocation recovered from plain assistant text of the shape
// "[tool_name]\n{...json args...}\n[END_TOOL_REQUEST]" (or "[/tool_name]").
export type LmstudioPlainTextToolCallBlock = {
  arguments: Record<string, unknown>;
  name: string;
};
// Generic closing sentinel some local models emit instead of a named [/tool_name] tag.
const END_TOOL_REQUEST = "[END_TOOL_REQUEST]";
// Upper bound on a single JSON payload we are willing to scan.
const MAX_PAYLOAD_CHARS = 256_000;
/** True when `char` is a character permitted inside a bracketed tool name. */
function isToolNameChar(char: string | undefined): boolean {
  return char !== undefined && /[A-Za-z0-9_-]/.test(char);
}
/** Advances past spaces and tabs from `start`; returns the first non-horizontal-whitespace index. */
function skipHorizontalWhitespace(text: string, start: number): number {
  let cursor = start;
  for (; cursor < text.length; cursor += 1) {
    const ch = text[cursor];
    if (ch !== " " && ch !== "\t") {
      break;
    }
  }
  return cursor;
}
/** Advances past any whitespace (including newlines) from `start`. */
function skipWhitespace(text: string, start: number): number {
  let cursor = start;
  while (cursor < text.length && /\s/.test(text.charAt(cursor))) {
    cursor += 1;
  }
  return cursor;
}
/**
 * Consumes a single line break (LF, CR, or CRLF) at `start`.
 * Returns the index just past the break, or null when `start` is not a break.
 */
function consumeLineBreak(text: string, start: number): number | null {
  switch (text[start]) {
    case "\n":
      return start + 1;
    case "\r":
      // CRLF counts as one line break.
      return text[start + 1] === "\n" ? start + 2 : start + 1;
    default:
      return null;
  }
}
/**
 * Parses an opening tag of the form `[tool_name]` followed by optional
 * horizontal whitespace and a mandatory line break.
 * Returns the tool name plus the index just past the line break, or null.
 */
function parseOpening(text: string, start: number): { end: number; name: string } | null {
  if (text[start] !== "[") {
    return null;
  }
  const nameStart = start + 1;
  let nameEnd = nameStart;
  while (isToolNameChar(text[nameEnd])) {
    nameEnd += 1;
  }
  // Require a non-empty name terminated by the closing bracket.
  if (nameEnd === nameStart || text[nameEnd] !== "]") {
    return null;
  }
  const afterBracket = skipHorizontalWhitespace(text, nameEnd + 1);
  const end = consumeLineBreak(text, afterBracket);
  if (end === null) {
    return null;
  }
  return { end, name: text.slice(nameStart, nameEnd) };
}
/**
 * Scans a balanced JSON object literal starting at the first `{` at or
 * after `start` (leading whitespace is skipped).
 *
 * Tracks string/escape state so braces inside string values do not affect
 * the depth count, then validates the candidate slice with JSON.parse.
 * Returns the parsed object plus the index just past its closing `}`, or
 * null when the text holds no valid object, the payload would exceed
 * MAX_PAYLOAD_CHARS, or the JSON parses to a non-object (array, etc.).
 */
function consumeJsonObject(
  text: string,
  start: number,
): { end: number; value: Record<string, unknown> } | null {
  const cursor = skipWhitespace(text, start);
  if (text[cursor] !== "{") {
    return null;
  }
  let depth = 0;
  let inString = false;
  let escaped = false;
  for (let index = cursor; index < text.length; index += 1) {
    // Bail out rather than scan an unbounded payload.
    if (index + 1 - cursor > MAX_PAYLOAD_CHARS) {
      return null;
    }
    const char = text[index];
    if (inString) {
      // Inside a string: only track escape sequences and the closing quote.
      if (escaped) {
        escaped = false;
      } else if (char === "\\") {
        escaped = true;
      } else if (char === '"') {
        inString = false;
      }
      continue;
    }
    if (char === '"') {
      inString = true;
      continue;
    }
    if (char === "{") {
      depth += 1;
    } else if (char === "}") {
      depth -= 1;
      if (depth === 0) {
        // Braces balanced — confirm the slice is real JSON and a plain object.
        try {
          const parsed = JSON.parse(text.slice(cursor, index + 1)) as unknown;
          if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) {
            return null;
          }
          return { end: index + 1, value: parsed as Record<string, unknown> };
        } catch {
          return null;
        }
      }
    }
  }
  // Ran off the end of the text without closing the object.
  return null;
}
/**
 * Matches the closing marker after a payload: either the generic
 * `[END_TOOL_REQUEST]` sentinel or a named `[/tool_name]` tag.
 * Returns the index just past the marker, or null when neither matches.
 */
function parseClosing(text: string, start: number, name: string): number | null {
  const cursor = skipWhitespace(text, start);
  for (const marker of [END_TOOL_REQUEST, `[/${name}]`]) {
    if (text.startsWith(marker, cursor)) {
      return cursor + marker.length;
    }
  }
  return null;
}
/**
 * Parses one complete bracketed tool-call block at `start`:
 * opening tag, JSON arguments object, closing marker.
 * The tool name must be in `allowedToolNames`; otherwise null.
 */
function parseBlockAt(
  text: string,
  start: number,
  allowedToolNames: Set<string>,
): { block: LmstudioPlainTextToolCallBlock; end: number } | null {
  const opening = parseOpening(text, start);
  if (opening === null) {
    return null;
  }
  if (!allowedToolNames.has(opening.name)) {
    return null;
  }
  const payload = consumeJsonObject(text, opening.end);
  if (payload === null) {
    return null;
  }
  const closingEnd = parseClosing(text, payload.end, opening.name);
  if (closingEnd === null) {
    return null;
  }
  const block: LmstudioPlainTextToolCallBlock = {
    arguments: payload.value,
    name: opening.name,
  };
  return { block, end: closingEnd };
}
/**
 * Parses `text` as a sequence of bracketed tool-call blocks separated
 * only by whitespace. All-or-nothing: any stray content (or zero blocks)
 * yields null so ordinary prose passes through untouched.
 */
export function parseLmstudioPlainTextToolCalls(
  text: string,
  allowedToolNames: Set<string>,
): LmstudioPlainTextToolCallBlock[] | null {
  const results: LmstudioPlainTextToolCallBlock[] = [];
  let position = skipWhitespace(text, 0);
  while (position < text.length) {
    const next = parseBlockAt(text, position, allowedToolNames);
    if (next === null) {
      // Non-block content anywhere means the text is not a pure tool call.
      return null;
    }
    results.push(next.block);
    position = skipWhitespace(text, next.end);
  }
  return results.length === 0 ? null : results;
}
/** Generates an OpenAI-style synthetic tool-call id: `call_` plus 24 UUID hex chars. */
export function createLmstudioSyntheticToolCallId(): string {
  const hex = randomUUID().split("-").join("");
  return `call_${hex.slice(0, 24)}`;
}

View File

@@ -28,7 +28,7 @@ vi.mock("./runtime.js", async (importOriginal) => {
};
});
type StreamEvent = { type: string };
type StreamEvent = { type: string } & Record<string, unknown>;
async function collectEvents(stream: ReturnType<StreamFn>): Promise<StreamEvent[]> {
const resolved = stream instanceof Promise ? await stream : stream;
@@ -50,6 +50,19 @@ function buildDoneStreamFn(): StreamFn {
});
}
/** Builds a mocked StreamFn that replays `events` asynchronously, then closes the stream. */
function buildEventStreamFn(events: unknown[]): StreamFn {
  return vi.fn((_model, _context, _options) => {
    const stream = createAssistantMessageEventStream();
    // Emit after the current task so the consumer can attach first.
    queueMicrotask(() => {
      events.forEach((event) => stream.push(event as never));
      stream.end();
    });
    return stream;
  });
}
function createWrappedLmstudioStream(
baseStream: StreamFn,
params?: { baseUrl?: string },
@@ -75,6 +88,7 @@ function runWrappedLmstudioStream(
wrapped: StreamFn,
model: Record<string, unknown>,
options?: Record<string, unknown>,
context?: Record<string, unknown>,
) {
return wrapped(
{
@@ -83,7 +97,7 @@ function runWrappedLmstudioStream(
id: "lmstudio/qwen3-8b-instruct",
...model,
} as never,
{ messages: [] } as never,
{ messages: [], ...context } as never,
options as never,
);
}
@@ -400,4 +414,99 @@ describe("lmstudio stream wrapper", () => {
undefined,
);
});
// Happy path: a lone bracketed block closed by the END_TOOL_REQUEST
// sentinel is promoted into structured toolcall events, the raw text
// events are suppressed, and the done message carries a synthetic id.
it("promotes standalone bracketed local-model tool text to a structured tool call", async () => {
  const rawToolText = [
    "[mempalace_mempalace_search]",
    '{"query":"codename","wing":"personal","room":"identities"}',
    "[END_TOOL_REQUEST]",
  ].join("\n");
  // Base stream emits the tool text as ordinary assistant text.
  const baseStream = buildEventStreamFn([
    { type: "start", partial: { content: [] } },
    { type: "text_start", contentIndex: 0, partial: { content: [{ type: "text", text: "" }] } },
    { type: "text_delta", contentIndex: 0, delta: rawToolText },
    { type: "text_end", contentIndex: 0, content: rawToolText },
    {
      type: "done",
      reason: "stop",
      message: {
        role: "assistant",
        content: [{ type: "text", text: rawToolText }],
        stopReason: "stop",
      },
    },
  ]);
  const wrapped = createWrappedLmstudioStream(baseStream);
  // The context registers the tool, making promotion eligible.
  const events = await collectEvents(
    runWrappedLmstudioStream(wrapped, {}, undefined, {
      tools: [
        {
          name: "mempalace_mempalace_search",
          description: "Search MemPalace",
          parameters: { type: "object", properties: {} },
        },
      ],
    }),
  );
  // Text events must be replaced by toolcall events.
  expect(events.map((event) => event.type)).toEqual([
    "start",
    "toolcall_start",
    "toolcall_delta",
    "done",
  ]);
  expect(events.some((event) => event.type === "text_delta")).toBe(false);
  const done = events.find((event) => event.type === "done") as {
    message?: { content?: Array<Record<string, unknown>>; stopReason?: string };
    reason?: string;
  };
  // Done event and message are both rewritten to a tool-use stop.
  expect(done.reason).toBe("toolUse");
  expect(done.message?.stopReason).toBe("toolUse");
  expect(done.message?.content?.[0]).toMatchObject({
    type: "toolCall",
    name: "mempalace_mempalace_search",
    arguments: { query: "codename", wing: "personal", room: "identities" },
  });
  // Synthetic id shape: "call_" followed by 24 hex chars.
  expect(String(done.message?.content?.[0]?.id)).toMatch(/^call_[a-f0-9]{24}$/);
});
// Negative path: when the bracketed name is not among the registered
// tools, the text must stream through unchanged — no promotion occurs.
it("passes through bracketed text when the tool is not registered", async () => {
  const rawToolText = [
    "[mempalace_mempalace_search]",
    '{"query":"codename"}',
    "[/mempalace_mempalace_search]",
  ].join("\n");
  const baseStream = buildEventStreamFn([
    { type: "start", partial: { content: [] } },
    { type: "text_start", contentIndex: 0, partial: { content: [{ type: "text", text: "" }] } },
    { type: "text_delta", contentIndex: 0, delta: rawToolText },
    { type: "text_end", contentIndex: 0, content: rawToolText },
    {
      type: "done",
      reason: "stop",
      message: {
        role: "assistant",
        content: [{ type: "text", text: rawToolText }],
        stopReason: "stop",
      },
    },
  ]);
  const wrapped = createWrappedLmstudioStream(baseStream);
  // Only an unrelated tool ("read") is registered.
  const events = await collectEvents(
    runWrappedLmstudioStream(wrapped, {}, undefined, {
      tools: [{ name: "read", description: "Read", parameters: { type: "object" } }],
    }),
  );
  // Original text event sequence is preserved verbatim.
  expect(events.map((event) => event.type)).toEqual([
    "start",
    "text_start",
    "text_delta",
    "text_end",
    "done",
  ]);
  expect(events.find((event) => event.type === "text_delta")).toMatchObject({
    delta: rawToolText,
  });
});
});

View File

@@ -1,17 +1,22 @@
import type { StreamFn } from "@mariozechner/pi-agent-core";
import { streamSimple } from "@mariozechner/pi-ai";
import { createAssistantMessageEventStream, streamSimple } from "@mariozechner/pi-ai";
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
import { ssrfPolicyFromHttpBaseUrlAllowedHostname } from "openclaw/plugin-sdk/ssrf-runtime";
import { LMSTUDIO_PROVIDER_ID } from "./defaults.js";
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
import { resolveLmstudioInferenceBase } from "./models.js";
import {
createLmstudioSyntheticToolCallId,
parseLmstudioPlainTextToolCalls,
} from "./plain-text-tool-calls.js";
import { resolveLmstudioProviderHeaders, resolveLmstudioRuntimeApiKey } from "./runtime.js";
const log = createSubsystemLogger("extensions/lmstudio/stream");
type StreamOptions = Parameters<StreamFn>[2];
type StreamModel = Parameters<StreamFn>[0];
type StreamContext = Parameters<StreamFn>[1];
const preloadInFlight = new Map<string, Promise<void>>();
@@ -112,6 +117,215 @@ function resolveModelHeaders(model: StreamModel): Record<string, string> | undef
return model.headers;
}
/**
 * Narrows an unknown value to a record view for property access.
 * Non-objects and null yield undefined; arrays do pass the check.
 */
function toRecord(value: unknown): Record<string, unknown> | undefined {
  if (value === null || typeof value !== "object") {
    return undefined;
  }
  return value as Record<string, unknown>;
}
function resolveContextToolNames(context: StreamContext): Set<string> {
const tools = (context as { tools?: unknown }).tools;
if (!Array.isArray(tools)) {
return new Set();
}
const names = tools
.map((tool) => {
const record = toRecord(tool);
return typeof record?.name === "string" && record.name.trim() ? record.name : undefined;
})
.filter((name): name is string => Boolean(name));
return new Set(names);
}
/**
 * Cheap prefix check used while buffering: the accumulated text can
 * still become a bracketed tool-call block only while it is blank or
 * opens with "[". Oversized buffers (beyond 256k chars) never qualify.
 */
function couldStillBePlainTextToolCall(text: string): boolean {
  const MAX_BUFFERED_CHARS = 256_000;
  if (text.length > MAX_BUFFERED_CHARS) {
    return false;
  }
  const leading = text.trimStart();
  return leading.length === 0 ? true : leading.startsWith("[");
}
/**
 * Shapes a parsed plain-text block as a structured `toolCall` content
 * block with a freshly generated synthetic id. `partialArgs` mirrors the
 * serialized arguments so delta handling sees consistent text.
 */
function createLmstudioToolCallBlock(parsed: {
  arguments: Record<string, unknown>;
  name: string;
}): Record<string, unknown> {
  const serializedArgs = JSON.stringify(parsed.arguments);
  return {
    type: "toolCall",
    id: createLmstudioSyntheticToolCallId(),
    name: parsed.name,
    arguments: parsed.arguments,
    partialArgs: serializedArgs,
  };
}
/**
 * Rewrites an assistant `message` whose text consists entirely of
 * bracketed plain-text tool calls into a structured tool-call message.
 *
 * Handles both string-content and content-block-array shapes. Returns
 * the promoted message (content replaced by `toolCall` blocks, stopReason
 * forced to "toolUse"), or undefined when the message should pass through
 * untouched — e.g. it already contains a toolCall block, has empty
 * content, or any text block fails to parse as tool calls.
 */
function promoteLmstudioPlainTextToolCalls(
  message: unknown,
  toolNames: Set<string>,
): Record<string, unknown> | undefined {
  const messageRecord = toRecord(message);
  if (!messageRecord) {
    return undefined;
  }
  if (!Array.isArray(messageRecord.content)) {
    // String-content shape: the entire string must parse as tool calls.
    if (typeof messageRecord.content !== "string" || !messageRecord.content.trim()) {
      return undefined;
    }
    const parsed = parseLmstudioPlainTextToolCalls(messageRecord.content, toolNames);
    if (!parsed) {
      return undefined;
    }
    return {
      ...messageRecord,
      content: parsed.map(createLmstudioToolCallBlock),
      stopReason: "toolUse",
    };
  }
  // Already-structured tool calls (or no content at all) need no promotion.
  if (
    messageRecord.content.some((block) => toRecord(block)?.type === "toolCall") ||
    messageRecord.content.length === 0
  ) {
    return undefined;
  }
  let promoted = false;
  const nextContent: Array<Record<string, unknown>> = [];
  for (const block of messageRecord.content) {
    const blockRecord = toRecord(block);
    if (!blockRecord) {
      return undefined;
    }
    if (blockRecord.type !== "text") {
      // Non-text blocks are carried over unchanged.
      nextContent.push(blockRecord);
      continue;
    }
    const text = typeof blockRecord.text === "string" ? blockRecord.text : "";
    // Whitespace-only text blocks are dropped from the promoted message.
    if (!text.trim()) {
      continue;
    }
    // All-or-nothing: one unparseable text block vetoes the whole promotion.
    const parsed = parseLmstudioPlainTextToolCalls(text, toolNames);
    if (!parsed) {
      return undefined;
    }
    nextContent.push(...parsed.map(createLmstudioToolCallBlock));
    promoted = true;
  }
  if (!promoted) {
    return undefined;
  }
  return {
    ...messageRecord,
    content: nextContent,
    stopReason: "toolUse",
  };
}
/**
 * Replays synthetic toolcall_start/toolcall_delta event pairs for every
 * promoted `toolCall` block in `message`, preserving content indices so
 * consumers can correlate the events with the final message content.
 */
function emitPromotedToolCallEvents(
  stream: { push(event: unknown): void },
  message: Record<string, unknown>,
): void {
  const content = Array.isArray(message.content) ? message.content : [];
  for (let contentIndex = 0; contentIndex < content.length; contentIndex += 1) {
    const record = toRecord(content[contentIndex]);
    if (record?.type !== "toolCall") {
      continue;
    }
    const delta = typeof record.partialArgs === "string" ? record.partialArgs : "{}";
    stream.push({ type: "toolcall_start", contentIndex, partial: message });
    stream.push({ type: "toolcall_delta", contentIndex, delta, partial: message });
  }
}
/**
 * Wraps an assistant event stream and, when the model emitted its tool
 * calls as bracketed plain text, replaces the buffered text events with
 * synthetic toolcall events plus a promoted "done" message.
 *
 * Text events are withheld only while the accumulated text could still be
 * a tool-call block (blank or starting with "["); otherwise they are
 * flushed immediately so ordinary prose keeps streaming live.
 */
function wrapLmstudioPlainTextToolCalls(
  source: ReturnType<StreamFn>,
  context: StreamContext,
): ReturnType<StreamFn> {
  const toolNames = resolveContextToolNames(context);
  // No registered tools means nothing can be promoted — pass through as-is.
  if (toolNames.size === 0) {
    return source;
  }
  const output = createAssistantMessageEventStream();
  const stream = output as unknown as { push(event: unknown): void; end(): void };
  void (async () => {
    // Text events withheld while promotion is still possible.
    const bufferedTextEvents: unknown[] = [];
    let bufferedText = "";
    let ended = false;
    // Idempotent close so the done/error paths and finally can all call it.
    const endStream = () => {
      if (!ended) {
        ended = true;
        stream.end();
      }
    };
    const flushBufferedTextEvents = () => {
      for (const event of bufferedTextEvents.splice(0)) {
        stream.push(event);
      }
      bufferedText = "";
    };
    try {
      for await (const event of source as AsyncIterable<unknown>) {
        const record = toRecord(event);
        const type = typeof record?.type === "string" ? record.type : "";
        if (type === "text_start" || type === "text_delta" || type === "text_end") {
          bufferedTextEvents.push(event);
          if (typeof record?.delta === "string") {
            bufferedText += record.delta;
          } else if (typeof record?.content === "string" && !bufferedText) {
            // text_end carries the full text; use it only if no deltas arrived.
            bufferedText = record.content;
          }
          // Stop withholding as soon as the text can no longer be a tool call.
          if (!couldStillBePlainTextToolCall(bufferedText)) {
            flushBufferedTextEvents();
          }
          continue;
        }
        if (type === "done") {
          const promotedMessage = promoteLmstudioPlainTextToolCalls(record?.message, toolNames);
          if (promotedMessage) {
            // Discard withheld text and emit synthetic toolcall events instead.
            bufferedTextEvents.splice(0);
            bufferedText = "";
            emitPromotedToolCallEvents(stream, promotedMessage);
            stream.push({ ...record, reason: "toolUse", message: promotedMessage });
          } else {
            flushBufferedTextEvents();
            stream.push(event);
          }
          endStream();
          return;
        }
        // Any other event type interrupts text buffering.
        flushBufferedTextEvents();
        stream.push(event);
        if (type === "error") {
          endStream();
          return;
        }
      }
      // Source ended without a "done": release anything still buffered.
      flushBufferedTextEvents();
    } catch (error) {
      // Surface iteration failures as a terminal error event downstream.
      stream.push({
        type: "error",
        reason: "error",
        error: {
          role: "assistant",
          content: [],
          stopReason: "error",
          errorMessage: error instanceof Error ? error.message : String(error),
        },
      });
    } finally {
      endStream();
    }
  })();
  return output as ReturnType<StreamFn>;
}
function createPreloadKey(params: {
baseUrl: string;
modelKey: string;
@@ -248,7 +462,8 @@ export function wrapLmstudioInferencePreload(ctx: ProviderWrapStreamFnContext):
},
};
const stream = underlying(modelWithUsageCompat, context, options);
return stream instanceof Promise ? await stream : stream;
const resolvedStream = stream instanceof Promise ? await stream : stream;
return wrapLmstudioPlainTextToolCalls(resolvedStream, context);
})();
};
}