diff --git a/extensions/openai/video-generation-provider.ts b/extensions/openai/video-generation-provider.ts index ecc289f4959..c47e4931a13 100644 --- a/extensions/openai/video-generation-provider.ts +++ b/extensions/openai/video-generation-provider.ts @@ -4,10 +4,10 @@ import { assertOkOrThrowHttpError, createProviderOperationDeadline, fetchWithTimeout, + pollProviderOperationJson, postJsonRequest, resolveProviderOperationTimeoutMs, resolveProviderHttpRequestConfig, - waitProviderOperationPollInterval, } from "openclaw/plugin-sdk/provider-http"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; import type { @@ -128,29 +128,22 @@ async function pollOpenAIVideo(params: { timeoutMs: params.timeoutMs, label: `OpenAI video generation task ${params.videoId}`, }); - for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { - const response = await fetchWithTimeout( - `${params.baseUrl}/videos/${params.videoId}`, - { - method: "GET", - headers: params.headers, - }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "OpenAI video status request failed"); - const payload = (await response.json()) as OpenAIVideoResponse; - if (payload.status === "completed") { - return payload; - } - if (payload.status === "failed") { - throw new Error( - normalizeOptionalString(payload.error?.message) || "OpenAI video generation failed", - ); - } - await waitProviderOperationPollInterval({ deadline, pollIntervalMs: POLL_INTERVAL_MS }); - } - throw new Error(`OpenAI video generation task ${params.videoId} did not finish in time`); + return await pollProviderOperationJson({ + url: `${params.baseUrl}/videos/${params.videoId}`, + headers: params.headers, + deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + maxAttempts: MAX_POLL_ATTEMPTS, + pollIntervalMs: POLL_INTERVAL_MS, + requestFailedMessage: "OpenAI video status request failed", + timeoutMessage: `OpenAI video generation task ${params.videoId} did not finish in time`, + isComplete: (payload) => payload.status === "completed", + getFailureMessage: (payload) => + payload.status === "failed" + ? normalizeOptionalString(payload.error?.message) || "OpenAI video generation failed" + : undefined, + }); } async function downloadOpenAIVideo(params: { diff --git a/extensions/together/video-generation-provider.ts b/extensions/together/video-generation-provider.ts index 9fbeb4a143a..856c39935d6 100644 --- a/extensions/together/video-generation-provider.ts +++ b/extensions/together/video-generation-provider.ts @@ -4,10 +4,10 @@ import { assertOkOrThrowHttpError, createProviderOperationDeadline, fetchWithTimeout, + pollProviderOperationJson, postJsonRequest, resolveProviderOperationTimeoutMs, resolveProviderHttpRequestConfig, - waitProviderOperationPollInterval, } from "openclaw/plugin-sdk/provider-http"; import { normalizeOptionalString } from "openclaw/plugin-sdk/text-runtime"; import type { @@ -78,29 +78,22 @@ async function pollTogetherVideo(params: { timeoutMs: params.timeoutMs, label: `Together video generation task ${params.videoId}`, }); - for (let attempt = 0; attempt < MAX_POLL_ATTEMPTS; attempt += 1) { - const response = await fetchWithTimeout( - `${params.baseUrl}/videos/${params.videoId}`, - { - method: "GET", - headers: params.headers, - }, - resolveProviderOperationTimeoutMs({ deadline, defaultTimeoutMs: DEFAULT_TIMEOUT_MS }), - params.fetchFn, - ); - await assertOkOrThrowHttpError(response, "Together video status request failed"); - const payload = (await response.json()) as TogetherVideoResponse; - if (payload.status === "completed") { - return payload; - } - if (payload.status === "failed") { - throw new Error( - normalizeOptionalString(payload.error?.message) ?? "Together video generation failed", - ); - } - await waitProviderOperationPollInterval({ deadline, pollIntervalMs: POLL_INTERVAL_MS }); - } - throw new Error(`Together video generation task ${params.videoId} did not finish in time`); + return await pollProviderOperationJson({ + url: `${params.baseUrl}/videos/${params.videoId}`, + headers: params.headers, + deadline, + defaultTimeoutMs: DEFAULT_TIMEOUT_MS, + fetchFn: params.fetchFn, + maxAttempts: MAX_POLL_ATTEMPTS, + pollIntervalMs: POLL_INTERVAL_MS, + requestFailedMessage: "Together video status request failed", + timeoutMessage: `Together video generation task ${params.videoId} did not finish in time`, + isComplete: (payload) => payload.status === "completed", + getFailureMessage: (payload) => + payload.status === "failed" + ? (normalizeOptionalString(payload.error?.message) ?? "Together video generation failed") + : undefined, + }); } async function downloadTogetherVideo(params: { diff --git a/src/media-understanding/shared.test.ts b/src/media-understanding/shared.test.ts index a065b3f546c..6b5b214c657 100644 --- a/src/media-understanding/shared.test.ts +++ b/src/media-understanding/shared.test.ts @@ -32,6 +32,7 @@ vi.mock("../infra/net/proxy-env.js", async () => { import { createProviderOperationDeadline, fetchWithTimeoutGuarded, + pollProviderOperationJson, postJsonRequest, postTranscriptionRequest, readErrorResponse, @@ -113,6 +114,63 @@ describe("provider operation deadlines", () => { await vi.advanceTimersByTimeAsync(1); await expect(wait).resolves.toBeUndefined(); }); + + it("polls provider status JSON until a payload is complete", async () => { + vi.useFakeTimers(); + vi.setSystemTime(1_000); + const fetchFn = vi + .fn() + .mockResolvedValueOnce(new Response(JSON.stringify({ status: "in_progress" }))) + .mockResolvedValueOnce(new Response(JSON.stringify({ status: "completed" }))); + + const result = pollProviderOperationJson<{ status?: string }>({ + url: "https://api.example.com/v1/videos/task-1", + headers: new Headers({ authorization: "Bearer test" }), + deadline: createProviderOperationDeadline({ + label: "video generation task task-1", + timeoutMs: 10_000, + }), + defaultTimeoutMs: 5_000, + fetchFn, + maxAttempts: 3, + pollIntervalMs: 1_000, + requestFailedMessage: "status failed", + timeoutMessage: "task timed out", + isComplete: (payload) => payload.status === "completed", + }); + + await vi.advanceTimersByTimeAsync(1_000); + + await expect(result).resolves.toEqual({ status: "completed" }); + expect(fetchFn).toHaveBeenCalledTimes(2); + }); + + it("throws provider failure messages while polling status JSON", async () => { + const fetchFn = vi + .fn() + .mockResolvedValueOnce( + new Response(JSON.stringify({ status: "failed", error: { message: "model rejected" } })), + ); + + await expect( + pollProviderOperationJson<{ status?: string; error?: { message?: string } }>({ + url: "https://api.example.com/v1/videos/task-1", + headers: new Headers(), + deadline: createProviderOperationDeadline({ + label: "video generation task task-1", + }), + defaultTimeoutMs: 5_000, + fetchFn, + maxAttempts: 3, + pollIntervalMs: 1_000, + requestFailedMessage: "status failed", + timeoutMessage: "task timed out", + isComplete: (payload) => payload.status === "completed", + getFailureMessage: (payload) => + payload.status === "failed" ? payload.error?.message : undefined, + }), + ).rejects.toThrow("model rejected"); + }); }); describe("resolveProviderHttpRequestConfig", () => { diff --git a/src/media-understanding/shared.ts b/src/media-understanding/shared.ts index f0c6f2c50b6..6e1df75ad8f 100644 --- a/src/media-understanding/shared.ts +++ b/src/media-understanding/shared.ts @@ -13,7 +13,8 @@ import type { GuardedFetchMode, GuardedFetchResult } from "../infra/net/fetch-gu import { fetchWithSsrFGuard, GUARDED_FETCH_MODE } from "../infra/net/fetch-guard.js"; import { hasEnvHttpProxyConfigured, matchesNoProxy } from "../infra/net/proxy-env.js"; import type { LookupFn, PinnedDispatcherPolicy, SsrFPolicy } from "../infra/net/ssrf.js"; -export { fetchWithTimeout } from "../utils/fetch-timeout.js"; +import { fetchWithTimeout } from "../utils/fetch-timeout.js"; +export { fetchWithTimeout }; export { normalizeBaseUrl } from "../agents/provider-request-config.js"; const MAX_ERROR_CHARS = 300; @@ -77,6 +78,49 @@ export async function waitProviderOperationPollInterval(params: { await new Promise((resolve) => setTimeout(resolve, Math.min(params.pollIntervalMs, remainingMs))); } +export async function pollProviderOperationJson(params: { + url: string; + headers: Headers; + deadline: ProviderOperationDeadline; + defaultTimeoutMs: number; + fetchFn: typeof fetch; + maxAttempts: number; + pollIntervalMs: number; + requestFailedMessage: string; + timeoutMessage: string; + isComplete: (payload: TPayload) => boolean; + getFailureMessage?: (payload: TPayload) => string | undefined; +}): Promise { + for (let attempt = 0; attempt < params.maxAttempts; attempt += 1) { + const response = await fetchWithTimeout( + params.url, + { + method: "GET", + headers: params.headers, + }, + resolveProviderOperationTimeoutMs({ + deadline: params.deadline, + defaultTimeoutMs: params.defaultTimeoutMs, + }), + params.fetchFn, + ); + await assertOkOrThrowHttpError(response, params.requestFailedMessage); + const payload = (await response.json()) as TPayload; + if (params.isComplete(payload)) { + return payload; + } + const failureMessage = params.getFailureMessage?.(payload); + if (failureMessage) { + throw new Error(failureMessage); + } + await waitProviderOperationPollInterval({ + deadline: params.deadline, + pollIntervalMs: params.pollIntervalMs, + }); + } + throw new Error(params.timeoutMessage); +} + function resolveGuardedHttpTimeoutMs(timeoutMs: number | undefined): number { if (typeof timeoutMs !== "number" || !Number.isFinite(timeoutMs) || timeoutMs <= 0) { return DEFAULT_GUARDED_HTTP_TIMEOUT_MS; diff --git a/src/plugin-sdk/provider-http.ts b/src/plugin-sdk/provider-http.ts index 71844827a4e..60c3841f62e 100644 --- a/src/plugin-sdk/provider-http.ts +++ b/src/plugin-sdk/provider-http.ts @@ -7,6 +7,7 @@ export { fetchWithTimeout, fetchWithTimeoutGuarded, normalizeBaseUrl, + pollProviderOperationJson, postJsonRequest, postTranscriptionRequest, resolveProviderOperationTimeoutMs, diff --git a/test/helpers/media-generation/provider-http-mocks.ts b/test/helpers/media-generation/provider-http-mocks.ts index 3ab7fb4d923..271e993f394 100644 --- a/test/helpers/media-generation/provider-http-mocks.ts +++ b/test/helpers/media-generation/provider-http-mocks.ts @@ -1,15 +1,20 @@ -import type { resolveProviderHttpRequestConfig } from "openclaw/plugin-sdk/provider-http"; +import type { + pollProviderOperationJson, + resolveProviderHttpRequestConfig, +} from "openclaw/plugin-sdk/provider-http"; import { afterEach, vi } from "vitest"; type ResolveProviderHttpRequestConfigParams = Parameters< typeof resolveProviderHttpRequestConfig >[0]; +type PollProviderOperationJsonParams = Parameters[0]; const providerHttpMocks = vi.hoisted(() => ({ resolveApiKeyForProviderMock: vi.fn(async () => ({ apiKey: "provider-key" })), postJsonRequestMock: vi.fn(), fetchWithTimeoutMock: vi.fn(), - assertOkOrThrowHttpErrorMock: vi.fn(async () => {}), + pollProviderOperationJsonMock: vi.fn(), + assertOkOrThrowHttpErrorMock: vi.fn(async (_response: Response, _label: string) => {}), resolveProviderHttpRequestConfigMock: vi.fn((params: ResolveProviderHttpRequestConfigParams) => ({ baseUrl: params.baseUrl ?? params.defaultBaseUrl, allowPrivateNetwork: false, @@ -18,6 +23,32 @@ const providerHttpMocks = vi.hoisted(() => ({ })), })); +providerHttpMocks.pollProviderOperationJsonMock.mockImplementation( + async (params: PollProviderOperationJsonParams) => { + for (let attempt = 0; attempt < params.maxAttempts; attempt += 1) { + const response = await providerHttpMocks.fetchWithTimeoutMock( + params.url, + { + method: "GET", + headers: params.headers, + }, + params.defaultTimeoutMs, + params.fetchFn, + ); + await providerHttpMocks.assertOkOrThrowHttpErrorMock(response, params.requestFailedMessage); + const payload = await response.json(); + if (params.isComplete(payload)) { + return payload; + } + const failureMessage = params.getFailureMessage?.(payload); + if (failureMessage) { + throw new Error(failureMessage); + } + } + throw new Error(params.timeoutMessage); + }, +); + vi.mock("openclaw/plugin-sdk/provider-auth-runtime", () => ({ resolveApiKeyForProvider: providerHttpMocks.resolveApiKeyForProviderMock, })); @@ -35,6 +66,7 @@ vi.mock("openclaw/plugin-sdk/provider-http", () => ({ timeoutMs, }), fetchWithTimeout: providerHttpMocks.fetchWithTimeoutMock, + pollProviderOperationJson: providerHttpMocks.pollProviderOperationJsonMock, postJsonRequest: providerHttpMocks.postJsonRequestMock, resolveProviderOperationTimeoutMs: ({ defaultTimeoutMs }: { defaultTimeoutMs: number }) => defaultTimeoutMs, @@ -51,6 +83,7 @@ export function installProviderHttpMockCleanup(): void { providerHttpMocks.resolveApiKeyForProviderMock.mockClear(); providerHttpMocks.postJsonRequestMock.mockReset(); providerHttpMocks.fetchWithTimeoutMock.mockReset(); + providerHttpMocks.pollProviderOperationJsonMock.mockClear(); providerHttpMocks.assertOkOrThrowHttpErrorMock.mockClear(); providerHttpMocks.resolveProviderHttpRequestConfigMock.mockClear(); });