mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-06 17:20:45 +00:00
fix(lmstudio): promote bracketed tool calls
This commit is contained in:
167
extensions/lmstudio/src/plain-text-tool-calls.ts
Normal file
167
extensions/lmstudio/src/plain-text-tool-calls.ts
Normal file
@@ -0,0 +1,167 @@
|
||||
import { randomUUID } from "node:crypto";
|
||||
|
||||
export type LmstudioPlainTextToolCallBlock = {
|
||||
arguments: Record<string, unknown>;
|
||||
name: string;
|
||||
};
|
||||
|
||||
const END_TOOL_REQUEST = "[END_TOOL_REQUEST]";
|
||||
const MAX_PAYLOAD_CHARS = 256_000;
|
||||
|
||||
function isToolNameChar(char: string | undefined): boolean {
|
||||
return Boolean(char && /[A-Za-z0-9_-]/.test(char));
|
||||
}
|
||||
|
||||
function skipHorizontalWhitespace(text: string, start: number): number {
|
||||
let index = start;
|
||||
while (index < text.length && (text[index] === " " || text[index] === "\t")) {
|
||||
index += 1;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
function skipWhitespace(text: string, start: number): number {
|
||||
let index = start;
|
||||
while (index < text.length && /\s/.test(text[index] ?? "")) {
|
||||
index += 1;
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
function consumeLineBreak(text: string, start: number): number | null {
|
||||
if (text[start] === "\r") {
|
||||
return text[start + 1] === "\n" ? start + 2 : start + 1;
|
||||
}
|
||||
if (text[start] === "\n") {
|
||||
return start + 1;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseOpening(text: string, start: number): { end: number; name: string } | null {
|
||||
if (text[start] !== "[") {
|
||||
return null;
|
||||
}
|
||||
let cursor = start + 1;
|
||||
const nameStart = cursor;
|
||||
while (isToolNameChar(text[cursor])) {
|
||||
cursor += 1;
|
||||
}
|
||||
if (cursor === nameStart || text[cursor] !== "]") {
|
||||
return null;
|
||||
}
|
||||
const name = text.slice(nameStart, cursor);
|
||||
cursor += 1;
|
||||
cursor = skipHorizontalWhitespace(text, cursor);
|
||||
const afterLineBreak = consumeLineBreak(text, cursor);
|
||||
if (afterLineBreak === null) {
|
||||
return null;
|
||||
}
|
||||
return { end: afterLineBreak, name };
|
||||
}
|
||||
|
||||
function consumeJsonObject(
|
||||
text: string,
|
||||
start: number,
|
||||
): { end: number; value: Record<string, unknown> } | null {
|
||||
const cursor = skipWhitespace(text, start);
|
||||
if (text[cursor] !== "{") {
|
||||
return null;
|
||||
}
|
||||
let depth = 0;
|
||||
let inString = false;
|
||||
let escaped = false;
|
||||
for (let index = cursor; index < text.length; index += 1) {
|
||||
if (index + 1 - cursor > MAX_PAYLOAD_CHARS) {
|
||||
return null;
|
||||
}
|
||||
const char = text[index];
|
||||
if (inString) {
|
||||
if (escaped) {
|
||||
escaped = false;
|
||||
} else if (char === "\\") {
|
||||
escaped = true;
|
||||
} else if (char === '"') {
|
||||
inString = false;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (char === '"') {
|
||||
inString = true;
|
||||
continue;
|
||||
}
|
||||
if (char === "{") {
|
||||
depth += 1;
|
||||
} else if (char === "}") {
|
||||
depth -= 1;
|
||||
if (depth === 0) {
|
||||
try {
|
||||
const parsed = JSON.parse(text.slice(cursor, index + 1)) as unknown;
|
||||
if (!parsed || typeof parsed !== "object" || Array.isArray(parsed)) {
|
||||
return null;
|
||||
}
|
||||
return { end: index + 1, value: parsed as Record<string, unknown> };
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseClosing(text: string, start: number, name: string): number | null {
|
||||
const cursor = skipWhitespace(text, start);
|
||||
if (text.startsWith(END_TOOL_REQUEST, cursor)) {
|
||||
return cursor + END_TOOL_REQUEST.length;
|
||||
}
|
||||
const namedClosing = `[/${name}]`;
|
||||
if (text.startsWith(namedClosing, cursor)) {
|
||||
return cursor + namedClosing.length;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function parseBlockAt(
|
||||
text: string,
|
||||
start: number,
|
||||
allowedToolNames: Set<string>,
|
||||
): { block: LmstudioPlainTextToolCallBlock; end: number } | null {
|
||||
const opening = parseOpening(text, start);
|
||||
if (!opening || !allowedToolNames.has(opening.name)) {
|
||||
return null;
|
||||
}
|
||||
const payload = consumeJsonObject(text, opening.end);
|
||||
if (!payload) {
|
||||
return null;
|
||||
}
|
||||
const end = parseClosing(text, payload.end, opening.name);
|
||||
if (end === null) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
block: { arguments: payload.value, name: opening.name },
|
||||
end,
|
||||
};
|
||||
}
|
||||
|
||||
export function parseLmstudioPlainTextToolCalls(
|
||||
text: string,
|
||||
allowedToolNames: Set<string>,
|
||||
): LmstudioPlainTextToolCallBlock[] | null {
|
||||
const blocks: LmstudioPlainTextToolCallBlock[] = [];
|
||||
let cursor = skipWhitespace(text, 0);
|
||||
while (cursor < text.length) {
|
||||
const parsed = parseBlockAt(text, cursor, allowedToolNames);
|
||||
if (!parsed) {
|
||||
return null;
|
||||
}
|
||||
blocks.push(parsed.block);
|
||||
cursor = skipWhitespace(text, parsed.end);
|
||||
}
|
||||
return blocks.length > 0 ? blocks : null;
|
||||
}
|
||||
|
||||
export function createLmstudioSyntheticToolCallId(): string {
|
||||
return `call_${randomUUID().replace(/-/g, "").slice(0, 24)}`;
|
||||
}
|
||||
@@ -28,7 +28,7 @@ vi.mock("./runtime.js", async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
type StreamEvent = { type: string };
|
||||
type StreamEvent = { type: string } & Record<string, unknown>;
|
||||
|
||||
async function collectEvents(stream: ReturnType<StreamFn>): Promise<StreamEvent[]> {
|
||||
const resolved = stream instanceof Promise ? await stream : stream;
|
||||
@@ -50,6 +50,19 @@ function buildDoneStreamFn(): StreamFn {
|
||||
});
|
||||
}
|
||||
|
||||
function buildEventStreamFn(events: unknown[]): StreamFn {
|
||||
return vi.fn((_model, _context, _options) => {
|
||||
const stream = createAssistantMessageEventStream();
|
||||
queueMicrotask(() => {
|
||||
for (const event of events) {
|
||||
stream.push(event as never);
|
||||
}
|
||||
stream.end();
|
||||
});
|
||||
return stream;
|
||||
});
|
||||
}
|
||||
|
||||
function createWrappedLmstudioStream(
|
||||
baseStream: StreamFn,
|
||||
params?: { baseUrl?: string },
|
||||
@@ -75,6 +88,7 @@ function runWrappedLmstudioStream(
|
||||
wrapped: StreamFn,
|
||||
model: Record<string, unknown>,
|
||||
options?: Record<string, unknown>,
|
||||
context?: Record<string, unknown>,
|
||||
) {
|
||||
return wrapped(
|
||||
{
|
||||
@@ -83,7 +97,7 @@ function runWrappedLmstudioStream(
|
||||
id: "lmstudio/qwen3-8b-instruct",
|
||||
...model,
|
||||
} as never,
|
||||
{ messages: [] } as never,
|
||||
{ messages: [], ...context } as never,
|
||||
options as never,
|
||||
);
|
||||
}
|
||||
@@ -400,4 +414,99 @@ describe("lmstudio stream wrapper", () => {
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("promotes standalone bracketed local-model tool text to a structured tool call", async () => {
|
||||
const rawToolText = [
|
||||
"[mempalace_mempalace_search]",
|
||||
'{"query":"codename","wing":"personal","room":"identities"}',
|
||||
"[END_TOOL_REQUEST]",
|
||||
].join("\n");
|
||||
const baseStream = buildEventStreamFn([
|
||||
{ type: "start", partial: { content: [] } },
|
||||
{ type: "text_start", contentIndex: 0, partial: { content: [{ type: "text", text: "" }] } },
|
||||
{ type: "text_delta", contentIndex: 0, delta: rawToolText },
|
||||
{ type: "text_end", contentIndex: 0, content: rawToolText },
|
||||
{
|
||||
type: "done",
|
||||
reason: "stop",
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: rawToolText }],
|
||||
stopReason: "stop",
|
||||
},
|
||||
},
|
||||
]);
|
||||
const wrapped = createWrappedLmstudioStream(baseStream);
|
||||
const events = await collectEvents(
|
||||
runWrappedLmstudioStream(wrapped, {}, undefined, {
|
||||
tools: [
|
||||
{
|
||||
name: "mempalace_mempalace_search",
|
||||
description: "Search MemPalace",
|
||||
parameters: { type: "object", properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
|
||||
expect(events.map((event) => event.type)).toEqual([
|
||||
"start",
|
||||
"toolcall_start",
|
||||
"toolcall_delta",
|
||||
"done",
|
||||
]);
|
||||
expect(events.some((event) => event.type === "text_delta")).toBe(false);
|
||||
const done = events.find((event) => event.type === "done") as {
|
||||
message?: { content?: Array<Record<string, unknown>>; stopReason?: string };
|
||||
reason?: string;
|
||||
};
|
||||
expect(done.reason).toBe("toolUse");
|
||||
expect(done.message?.stopReason).toBe("toolUse");
|
||||
expect(done.message?.content?.[0]).toMatchObject({
|
||||
type: "toolCall",
|
||||
name: "mempalace_mempalace_search",
|
||||
arguments: { query: "codename", wing: "personal", room: "identities" },
|
||||
});
|
||||
expect(String(done.message?.content?.[0]?.id)).toMatch(/^call_[a-f0-9]{24}$/);
|
||||
});
|
||||
|
||||
it("passes through bracketed text when the tool is not registered", async () => {
|
||||
const rawToolText = [
|
||||
"[mempalace_mempalace_search]",
|
||||
'{"query":"codename"}',
|
||||
"[/mempalace_mempalace_search]",
|
||||
].join("\n");
|
||||
const baseStream = buildEventStreamFn([
|
||||
{ type: "start", partial: { content: [] } },
|
||||
{ type: "text_start", contentIndex: 0, partial: { content: [{ type: "text", text: "" }] } },
|
||||
{ type: "text_delta", contentIndex: 0, delta: rawToolText },
|
||||
{ type: "text_end", contentIndex: 0, content: rawToolText },
|
||||
{
|
||||
type: "done",
|
||||
reason: "stop",
|
||||
message: {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: rawToolText }],
|
||||
stopReason: "stop",
|
||||
},
|
||||
},
|
||||
]);
|
||||
const wrapped = createWrappedLmstudioStream(baseStream);
|
||||
const events = await collectEvents(
|
||||
runWrappedLmstudioStream(wrapped, {}, undefined, {
|
||||
tools: [{ name: "read", description: "Read", parameters: { type: "object" } }],
|
||||
}),
|
||||
);
|
||||
|
||||
expect(events.map((event) => event.type)).toEqual([
|
||||
"start",
|
||||
"text_start",
|
||||
"text_delta",
|
||||
"text_end",
|
||||
"done",
|
||||
]);
|
||||
expect(events.find((event) => event.type === "text_delta")).toMatchObject({
|
||||
delta: rawToolText,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
import type { StreamFn } from "@mariozechner/pi-agent-core";
|
||||
import { streamSimple } from "@mariozechner/pi-ai";
|
||||
import { createAssistantMessageEventStream, streamSimple } from "@mariozechner/pi-ai";
|
||||
import { createSubsystemLogger } from "openclaw/plugin-sdk/logging-core";
|
||||
import type { ProviderWrapStreamFnContext } from "openclaw/plugin-sdk/plugin-entry";
|
||||
import { ssrfPolicyFromHttpBaseUrlAllowedHostname } from "openclaw/plugin-sdk/ssrf-runtime";
|
||||
import { LMSTUDIO_PROVIDER_ID } from "./defaults.js";
|
||||
import { ensureLmstudioModelLoaded } from "./models.fetch.js";
|
||||
import { resolveLmstudioInferenceBase } from "./models.js";
|
||||
import {
|
||||
createLmstudioSyntheticToolCallId,
|
||||
parseLmstudioPlainTextToolCalls,
|
||||
} from "./plain-text-tool-calls.js";
|
||||
import { resolveLmstudioProviderHeaders, resolveLmstudioRuntimeApiKey } from "./runtime.js";
|
||||
|
||||
const log = createSubsystemLogger("extensions/lmstudio/stream");
|
||||
|
||||
type StreamOptions = Parameters<StreamFn>[2];
|
||||
type StreamModel = Parameters<StreamFn>[0];
|
||||
type StreamContext = Parameters<StreamFn>[1];
|
||||
|
||||
const preloadInFlight = new Map<string, Promise<void>>();
|
||||
|
||||
@@ -112,6 +117,215 @@ function resolveModelHeaders(model: StreamModel): Record<string, string> | undef
|
||||
return model.headers;
|
||||
}
|
||||
|
||||
function toRecord(value: unknown): Record<string, unknown> | undefined {
|
||||
return value && typeof value === "object" ? (value as Record<string, unknown>) : undefined;
|
||||
}
|
||||
|
||||
function resolveContextToolNames(context: StreamContext): Set<string> {
|
||||
const tools = (context as { tools?: unknown }).tools;
|
||||
if (!Array.isArray(tools)) {
|
||||
return new Set();
|
||||
}
|
||||
const names = tools
|
||||
.map((tool) => {
|
||||
const record = toRecord(tool);
|
||||
return typeof record?.name === "string" && record.name.trim() ? record.name : undefined;
|
||||
})
|
||||
.filter((name): name is string => Boolean(name));
|
||||
return new Set(names);
|
||||
}
|
||||
|
||||
function couldStillBePlainTextToolCall(text: string): boolean {
|
||||
if (text.length > 256_000) {
|
||||
return false;
|
||||
}
|
||||
const trimmed = text.trimStart();
|
||||
return trimmed.length === 0 || trimmed.startsWith("[");
|
||||
}
|
||||
|
||||
function createLmstudioToolCallBlock(parsed: {
|
||||
arguments: Record<string, unknown>;
|
||||
name: string;
|
||||
}): Record<string, unknown> {
|
||||
return {
|
||||
type: "toolCall",
|
||||
id: createLmstudioSyntheticToolCallId(),
|
||||
name: parsed.name,
|
||||
arguments: parsed.arguments,
|
||||
partialArgs: JSON.stringify(parsed.arguments),
|
||||
};
|
||||
}
|
||||
|
||||
function promoteLmstudioPlainTextToolCalls(
|
||||
message: unknown,
|
||||
toolNames: Set<string>,
|
||||
): Record<string, unknown> | undefined {
|
||||
const messageRecord = toRecord(message);
|
||||
if (!messageRecord) {
|
||||
return undefined;
|
||||
}
|
||||
if (!Array.isArray(messageRecord.content)) {
|
||||
if (typeof messageRecord.content !== "string" || !messageRecord.content.trim()) {
|
||||
return undefined;
|
||||
}
|
||||
const parsed = parseLmstudioPlainTextToolCalls(messageRecord.content, toolNames);
|
||||
if (!parsed) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
...messageRecord,
|
||||
content: parsed.map(createLmstudioToolCallBlock),
|
||||
stopReason: "toolUse",
|
||||
};
|
||||
}
|
||||
if (
|
||||
messageRecord.content.some((block) => toRecord(block)?.type === "toolCall") ||
|
||||
messageRecord.content.length === 0
|
||||
) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let promoted = false;
|
||||
const nextContent: Array<Record<string, unknown>> = [];
|
||||
for (const block of messageRecord.content) {
|
||||
const blockRecord = toRecord(block);
|
||||
if (!blockRecord) {
|
||||
return undefined;
|
||||
}
|
||||
if (blockRecord.type !== "text") {
|
||||
nextContent.push(blockRecord);
|
||||
continue;
|
||||
}
|
||||
const text = typeof blockRecord.text === "string" ? blockRecord.text : "";
|
||||
if (!text.trim()) {
|
||||
continue;
|
||||
}
|
||||
const parsed = parseLmstudioPlainTextToolCalls(text, toolNames);
|
||||
if (!parsed) {
|
||||
return undefined;
|
||||
}
|
||||
nextContent.push(...parsed.map(createLmstudioToolCallBlock));
|
||||
promoted = true;
|
||||
}
|
||||
|
||||
if (!promoted) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
...messageRecord,
|
||||
content: nextContent,
|
||||
stopReason: "toolUse",
|
||||
};
|
||||
}
|
||||
|
||||
function emitPromotedToolCallEvents(
|
||||
stream: { push(event: unknown): void },
|
||||
message: Record<string, unknown>,
|
||||
): void {
|
||||
const content = Array.isArray(message.content) ? message.content : [];
|
||||
content.forEach((block, contentIndex) => {
|
||||
const record = toRecord(block);
|
||||
if (record?.type !== "toolCall") {
|
||||
return;
|
||||
}
|
||||
stream.push({ type: "toolcall_start", contentIndex, partial: message });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex,
|
||||
delta: typeof record.partialArgs === "string" ? record.partialArgs : "{}",
|
||||
partial: message,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function wrapLmstudioPlainTextToolCalls(
|
||||
source: ReturnType<StreamFn>,
|
||||
context: StreamContext,
|
||||
): ReturnType<StreamFn> {
|
||||
const toolNames = resolveContextToolNames(context);
|
||||
if (toolNames.size === 0) {
|
||||
return source;
|
||||
}
|
||||
const output = createAssistantMessageEventStream();
|
||||
const stream = output as unknown as { push(event: unknown): void; end(): void };
|
||||
|
||||
void (async () => {
|
||||
const bufferedTextEvents: unknown[] = [];
|
||||
let bufferedText = "";
|
||||
let ended = false;
|
||||
const endStream = () => {
|
||||
if (!ended) {
|
||||
ended = true;
|
||||
stream.end();
|
||||
}
|
||||
};
|
||||
const flushBufferedTextEvents = () => {
|
||||
for (const event of bufferedTextEvents.splice(0)) {
|
||||
stream.push(event);
|
||||
}
|
||||
bufferedText = "";
|
||||
};
|
||||
|
||||
try {
|
||||
for await (const event of source as AsyncIterable<unknown>) {
|
||||
const record = toRecord(event);
|
||||
const type = typeof record?.type === "string" ? record.type : "";
|
||||
|
||||
if (type === "text_start" || type === "text_delta" || type === "text_end") {
|
||||
bufferedTextEvents.push(event);
|
||||
if (typeof record?.delta === "string") {
|
||||
bufferedText += record.delta;
|
||||
} else if (typeof record?.content === "string" && !bufferedText) {
|
||||
bufferedText = record.content;
|
||||
}
|
||||
if (!couldStillBePlainTextToolCall(bufferedText)) {
|
||||
flushBufferedTextEvents();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (type === "done") {
|
||||
const promotedMessage = promoteLmstudioPlainTextToolCalls(record?.message, toolNames);
|
||||
if (promotedMessage) {
|
||||
bufferedTextEvents.splice(0);
|
||||
bufferedText = "";
|
||||
emitPromotedToolCallEvents(stream, promotedMessage);
|
||||
stream.push({ ...record, reason: "toolUse", message: promotedMessage });
|
||||
} else {
|
||||
flushBufferedTextEvents();
|
||||
stream.push(event);
|
||||
}
|
||||
endStream();
|
||||
return;
|
||||
}
|
||||
|
||||
flushBufferedTextEvents();
|
||||
stream.push(event);
|
||||
if (type === "error") {
|
||||
endStream();
|
||||
return;
|
||||
}
|
||||
}
|
||||
flushBufferedTextEvents();
|
||||
} catch (error) {
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason: "error",
|
||||
error: {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
endStream();
|
||||
}
|
||||
})();
|
||||
|
||||
return output as ReturnType<StreamFn>;
|
||||
}
|
||||
|
||||
function createPreloadKey(params: {
|
||||
baseUrl: string;
|
||||
modelKey: string;
|
||||
@@ -248,7 +462,8 @@ export function wrapLmstudioInferencePreload(ctx: ProviderWrapStreamFnContext):
|
||||
},
|
||||
};
|
||||
const stream = underlying(modelWithUsageCompat, context, options);
|
||||
return stream instanceof Promise ? await stream : stream;
|
||||
const resolvedStream = stream instanceof Promise ? await stream : stream;
|
||||
return wrapLmstudioPlainTextToolCalls(resolvedStream, context);
|
||||
})();
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user