mirror of
https://github.com/openclaw/openclaw.git
synced 2026-05-06 06:20:43 +00:00
refactor: share provider polling helper
This commit is contained in:
@@ -4,10 +4,10 @@ import {
|
||||
assertOkOrThrowHttpError,
|
||||
createProviderOperationDeadline,
|
||||
fetchWithTimeout,
|
||||
pollProviderOperationJson,
|
||||
postJsonRequest,
|
||||
resolveProviderOperationTimeoutMs,
|
||||
resolveProviderHttpRequestConfig,
|
||||
waitProviderOperationPollInterval,
|
||||
} from "openclaw/plugin-sdk/provider-http";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
import type {
|
||||
@@ -128,29 +128,22 @@ async function pollOpenAIVideo(params: {
|
||||
timeoutMs: params.timeoutMs,
|
||||
label: `OpenAI video generation task ${params.videoId}`,
|
||||
});
|
||||
for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) {
|
||||
const response = await fetchWithTimeout(
|
||||
`${params.baseUrl}/videos/${params.videoId}`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: params.headers,
|
||||
},
|
||||
resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }),
|
||||
params.fetchFn,
|
||||
);
|
||||
await assertOkOrThrowHttpError(response, "OpenAI video status request failed");
|
||||
const payload = (await response.json()) as OpenAIVideoResponse;
|
||||
if (payload.status === "completed") {
|
||||
return payload;
|
||||
}
|
||||
if (payload.status === "failed") {
|
||||
throw new Error(
|
||||
normalizeOptionalString(payload.error?.message) || "OpenAI video generation failed",
|
||||
);
|
||||
}
|
||||
await waitProviderOperationPollInterval({ deadline, pollIntervalMs: POLL_INTERVAL_MS });
|
||||
}
|
||||
throw new Error(`OpenAI video generation task ${params.videoId} did not finish in time`);
|
||||
return await pollProviderOperationJson<OpenAIVideoResponse>({
|
||||
url: `${params.baseUrl}/videos/${params.videoId}`,
|
||||
headers: params.headers,
|
||||
deadline,
|
||||
defaultTimeoutMs: DEFAULT_TIMEOUT_MS,
|
||||
fetchFn: params.fetchFn,
|
||||
maxAttempts: MAX_POLL_ATTEMPTS,
|
||||
pollIntervalMs: POLL_INTERVAL_MS,
|
||||
requestFailedMessage: "OpenAI video status request failed",
|
||||
timeoutMessage: `OpenAI video generation task ${params.videoId} did not finish in time`,
|
||||
isComplete: (payload) => payload.status === "completed",
|
||||
getFailureMessage: (payload) =>
|
||||
payload.status === "failed"
|
||||
? normalizeOptionalString(payload.error?.message) || "OpenAI video generation failed"
|
||||
: undefined,
|
||||
});
|
||||
}
|
||||
|
||||
async function downloadOpenAIVideo(params: {
|
||||
|
||||
@@ -4,10 +4,10 @@ import {
|
||||
assertOkOrThrowHttpError,
|
||||
createProviderOperationDeadline,
|
||||
fetchWithTimeout,
|
||||
pollProviderOperationJson,
|
||||
postJsonRequest,
|
||||
resolveProviderOperationTimeoutMs,
|
||||
resolveProviderHttpRequestConfig,
|
||||
waitProviderOperationPollInterval,
|
||||
} from "openclaw/plugin-sdk/provider-http";
|
||||
import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime";
|
||||
import type {
|
||||
@@ -78,29 +78,22 @@ async function pollTogetherVideo(params: {
|
||||
timeoutMs: params.timeoutMs,
|
||||
label: `Together video generation task ${params.videoId}`,
|
||||
});
|
||||
for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) {
|
||||
const response = await fetchWithTimeout(
|
||||
`${params.baseUrl}/videos/${params.videoId}`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: params.headers,
|
||||
},
|
||||
resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }),
|
||||
params.fetchFn,
|
||||
);
|
||||
await assertOkOrThrowHttpError(response, "Together video status request failed");
|
||||
const payload = (await response.json()) as TogetherVideoResponse;
|
||||
if (payload.status === "completed") {
|
||||
return payload;
|
||||
}
|
||||
if (payload.status === "failed") {
|
||||
throw new Error(
|
||||
normalizeOptionalString(payload.error?.message) ?? "Together video generation failed",
|
||||
);
|
||||
}
|
||||
await waitProviderOperationPollInterval({ deadline, pollIntervalMs: POLL_INTERVAL_MS });
|
||||
}
|
||||
throw new Error(`Together video generation task ${params.videoId} did not finish in time`);
|
||||
return await pollProviderOperationJson<TogetherVideoResponse>({
|
||||
url: `${params.baseUrl}/videos/${params.videoId}`,
|
||||
headers: params.headers,
|
||||
deadline,
|
||||
defaultTimeoutMs: DEFAULT_TIMEOUT_MS,
|
||||
fetchFn: params.fetchFn,
|
||||
maxAttempts: MAX_POLL_ATTEMPTS,
|
||||
pollIntervalMs: POLL_INTERVAL_MS,
|
||||
requestFailedMessage: "Together video status request failed",
|
||||
timeoutMessage: `Together video generation task ${params.videoId} did not finish in time`,
|
||||
isComplete: (payload) => payload.status === "completed",
|
||||
getFailureMessage: (payload) =>
|
||||
payload.status === "failed"
|
||||
? (normalizeOptionalString(payload.error?.message) ?? "Together video generation failed")
|
||||
: undefined,
|
||||
});
|
||||
}
|
||||
|
||||
async function downloadTogetherVideo(params: {
|
||||
|
||||
@@ -32,6 +32,7 @@ vi.mock("../infra/net/proxy-env.js", async () => {
|
||||
import {
|
||||
createProviderOperationDeadline,
|
||||
fetchWithTimeoutGuarded,
|
||||
pollProviderOperationJson,
|
||||
postJsonRequest,
|
||||
postTranscriptionRequest,
|
||||
readErrorResponse,
|
||||
@@ -113,6 +114,63 @@ describe("provider operation deadlines", () => {
|
||||
await vi.advanceTimersByTimeAsync(1);
|
||||
await expect(wait).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it("polls provider status JSON until a payload is complete", async () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(1_000);
|
||||
const fetchFn = vi
|
||||
.fn<typeof fetch>()
|
||||
.mockResolvedValueOnce(new Response(JSON.stringify({ status: "in_progress" })))
|
||||
.mockResolvedValueOnce(new Response(JSON.stringify({ status: "completed" })));
|
||||
|
||||
const result = pollProviderOperationJson<{ status?: string }>({
|
||||
url: "https://api.example.com/v1/videos/task-1",
|
||||
headers: new Headers({ authorization: "Bearer test" }),
|
||||
deadline: createProviderOperationDeadline({
|
||||
label: "video generation task task-1",
|
||||
timeoutMs: 10_000,
|
||||
}),
|
||||
defaultTimeoutMs: 5_000,
|
||||
fetchFn,
|
||||
maxAttempts: 3,
|
||||
pollIntervalMs: 1_000,
|
||||
requestFailedMessage: "status failed",
|
||||
timeoutMessage: "task timed out",
|
||||
isComplete: (payload) => payload.status === "completed",
|
||||
});
|
||||
|
||||
await vi.advanceTimersByTimeAsync(1_000);
|
||||
|
||||
await expect(result).resolves.toEqual({ status: "completed" });
|
||||
expect(fetchFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("throws provider failure messages while polling status JSON", async () => {
|
||||
const fetchFn = vi
|
||||
.fn<typeof fetch>()
|
||||
.mockResolvedValueOnce(
|
||||
new Response(JSON.stringify({ status: "failed", error: { message: "model rejected" } })),
|
||||
);
|
||||
|
||||
await expect(
|
||||
pollProviderOperationJson<{ status?: string; error?: { message?: string } }>({
|
||||
url: "https://api.example.com/v1/videos/task-1",
|
||||
headers: new Headers(),
|
||||
deadline: createProviderOperationDeadline({
|
||||
label: "video generation task task-1",
|
||||
}),
|
||||
defaultTimeoutMs: 5_000,
|
||||
fetchFn,
|
||||
maxAttempts: 3,
|
||||
pollIntervalMs: 1_000,
|
||||
requestFailedMessage: "status failed",
|
||||
timeoutMessage: "task timed out",
|
||||
isComplete: (payload) => payload.status === "completed",
|
||||
getFailureMessage: (payload) =>
|
||||
payload.status === "failed" ? payload.error?.message : undefined,
|
||||
}),
|
||||
).rejects.toThrow("model rejected");
|
||||
});
|
||||
});
|
||||
|
||||
describe("resolveProviderHttpRequestConfig", () => {
|
||||
|
||||
@@ -13,7 +13,8 @@ import type { GuardedFetchMode, GuardedFetchResult } from "../infra/net/fetch-gu
|
||||
import { fetchWithSsrFGuard, GUARDED_FETCH_MODE } from "../infra/net/fetch-guard.js";
|
||||
import { hasEnvHttpProxyConfigured, matchesNoProxy } from "../infra/net/proxy-env.js";
|
||||
import type { LookupFn, PinnedDispatcherPolicy, SsrFPolicy } from "../infra/net/ssrf.js";
|
||||
export { fetchWithTimeout } from "../utils/fetch-timeout.js";
|
||||
import { fetchWithTimeout } from "../utils/fetch-timeout.js";
|
||||
export { fetchWithTimeout };
|
||||
export { normalizeBaseUrl } from "../agents/provider-request-config.js";
|
||||
|
||||
const MAX_ERROR_CHARS = 300;
|
||||
@@ -77,6 +78,49 @@ export async function waitProviderOperationPollInterval(params: {
|
||||
await new Promise((resolve) => setTimeout(resolve, Math.min(params.pollIntervalMs, remainingMs)));
|
||||
}
|
||||
|
||||
export async function pollProviderOperationJson<TPayload>(params: {
|
||||
url: string;
|
||||
headers: Headers;
|
||||
deadline: ProviderOperationDeadline;
|
||||
defaultTimeoutMs: number;
|
||||
fetchFn: typeof fetch;
|
||||
maxAttempts: number;
|
||||
pollIntervalMs: number;
|
||||
requestFailedMessage: string;
|
||||
timeoutMessage: string;
|
||||
isComplete: (payload: TPayload) => boolean;
|
||||
getFailureMessage?: (payload: TPayload) => string | undefined;
|
||||
}): Promise<TPayload> {
|
||||
for (let attempt = 0; attempt < params.maxAttempts; attempt += 1) {
|
||||
const response = await fetchWithTimeout(
|
||||
params.url,
|
||||
{
|
||||
method: "GET",
|
||||
headers: params.headers,
|
||||
},
|
||||
resolveProviderOperationTimeoutMs({
|
||||
deadline: params.deadline,
|
||||
defaultTimeoutMs: params.defaultTimeoutMs,
|
||||
}),
|
||||
params.fetchFn,
|
||||
);
|
||||
await assertOkOrThrowHttpError(response, params.requestFailedMessage);
|
||||
const payload = (await response.json()) as TPayload;
|
||||
if (params.isComplete(payload)) {
|
||||
return payload;
|
||||
}
|
||||
const failureMessage = params.getFailureMessage?.(payload);
|
||||
if (failureMessage) {
|
||||
throw new Error(failureMessage);
|
||||
}
|
||||
await waitProviderOperationPollInterval({
|
||||
deadline: params.deadline,
|
||||
pollIntervalMs: params.pollIntervalMs,
|
||||
});
|
||||
}
|
||||
throw new Error(params.timeoutMessage);
|
||||
}
|
||||
|
||||
function resolveGuardedHttpTimeoutMs(timeoutMs: number | undefined): number {
|
||||
if (typeof timeoutMs !== "number" || !Number.isFinite(timeoutMs) || timeoutMs <= 0) {
|
||||
return DEFAULT_GUARDED_HTTP_TIMEOUT_MS;
|
||||
|
||||
@@ -7,6 +7,7 @@ export {
|
||||
fetchWithTimeout,
|
||||
fetchWithTimeoutGuarded,
|
||||
normalizeBaseUrl,
|
||||
pollProviderOperationJson,
|
||||
postJsonRequest,
|
||||
postTranscriptionRequest,
|
||||
resolveProviderOperationTimeoutMs,
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
import type { resolveProviderHttpRequestConfig } from "openclaw/plugin-sdk/provider-http";
|
||||
import type {
|
||||
pollProviderOperationJson,
|
||||
resolveProviderHttpRequestConfig,
|
||||
} from "openclaw/plugin-sdk/provider-http";
|
||||
import { afterEach, vi } from "vitest";
|
||||
|
||||
type ResolveProviderHttpRequestConfigParams = Parameters<
|
||||
typeof resolveProviderHttpRequestConfig
|
||||
>[0];
|
||||
type PollProviderOperationJsonParams = Parameters<typeof pollProviderOperationJson>[0];
|
||||
|
||||
const providerHttpMocks = vi.hoisted(() => ({
|
||||
resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "provider-key" })),
|
||||
postJsonRequestMock: vi.fn(),
|
||||
fetchWithTimeoutMock: vi.fn(),
|
||||
assertOkOrThrowHttpErrorMock: vi.fn(async () => {}),
|
||||
pollProviderOperationJsonMock: vi.fn(),
|
||||
assertOkOrThrowHttpErrorMock: vi.fn(async (_response: Response, _label: string) => {}),
|
||||
resolveProviderHttpRequestConfigMock: vi.fn((params: ResolveProviderHttpRequestConfigParams) => ({
|
||||
baseUrl: params.baseUrl ?? params.defaultBaseUrl,
|
||||
allowPrivateNetwork: false,
|
||||
@@ -18,6 +23,32 @@ const providerHttpMocks = vi.hoisted(() => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
providerHttpMocks.pollProviderOperationJsonMock.mockImplementation(
|
||||
async (params: PollProviderOperationJsonParams) => {
|
||||
for (let attempt = 0; attempt < params.maxAttempts; attempt += 1) {
|
||||
const response = await providerHttpMocks.fetchWithTimeoutMock(
|
||||
params.url,
|
||||
{
|
||||
method: "GET",
|
||||
headers: params.headers,
|
||||
},
|
||||
params.defaultTimeoutMs,
|
||||
params.fetchFn,
|
||||
);
|
||||
await providerHttpMocks.assertOkOrThrowHttpErrorMock(response, params.requestFailedMessage);
|
||||
const payload = await response.json();
|
||||
if (params.isComplete(payload)) {
|
||||
return payload;
|
||||
}
|
||||
const failureMessage = params.getFailureMessage?.(payload);
|
||||
if (failureMessage) {
|
||||
throw new Error(failureMessage);
|
||||
}
|
||||
}
|
||||
throw new Error(params.timeoutMessage);
|
||||
},
|
||||
);
|
||||
|
||||
vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({
|
||||
resolveApiKeyForProvider: providerHttpMocks.resolveApiKeyForProviderMock,
|
||||
}));
|
||||
@@ -35,6 +66,7 @@ vi.mock("openclaw/plugin-sdk/provider-http", () => ({
|
||||
timeoutMs,
|
||||
}),
|
||||
fetchWithTimeout: providerHttpMocks.fetchWithTimeoutMock,
|
||||
pollProviderOperationJson: providerHttpMocks.pollProviderOperationJsonMock,
|
||||
postJsonRequest: providerHttpMocks.postJsonRequestMock,
|
||||
resolveProviderOperationTimeoutMs: ({ defaultTimeoutMs }: { defaultTimeoutMs: number }) =>
|
||||
defaultTimeoutMs,
|
||||
@@ -51,6 +83,7 @@ export function installProviderHttpMockCleanup(): void {
|
||||
providerHttpMocks.resolveApiKeyForProviderMock.mockClear();
|
||||
providerHttpMocks.postJsonRequestMock.mockReset();
|
||||
providerHttpMocks.fetchWithTimeoutMock.mockReset();
|
||||
providerHttpMocks.pollProviderOperationJsonMock.mockClear();
|
||||
providerHttpMocks.assertOkOrThrowHttpErrorMock.mockClear();
|
||||
providerHttpMocks.resolveProviderHttpRequestConfigMock.mockClear();
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user