diff --git a/scripts/docs-i18n/translator.go b/scripts/docs-i18n/translator.go index aac2afc5f80..8f7023c615b 100644 --- a/scripts/docs-i18n/translator.go +++ b/scripts/docs-i18n/translator.go @@ -2,7 +2,6 @@ package main import ( "context" - "encoding/json" "errors" "fmt" "strings" @@ -14,6 +13,7 @@ import ( const ( translateMaxAttempts = 3 translateBaseDelay = 15 * time.Second + translatePromptTimeout = 2 * time.Minute ) var errEmptyTranslation = errors.New("empty translation") @@ -145,96 +145,31 @@ func (t *PiTranslator) Close() { } } -type agentEndPayload struct { - Messages []agentMessage `json:"messages"` +type promptRunner interface { + Run(context.Context, string) (pi.RunResult, error) + Stderr() string } -type agentMessage struct { - Role string `json:"role"` - Content json.RawMessage `json:"content"` - StopReason string `json:"stopReason,omitempty"` - ErrorMessage string `json:"errorMessage,omitempty"` -} - -type contentBlock struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` -} - -func runPrompt(ctx context.Context, client *pi.OneShotClient, message string) (string, error) { - events, cancel := client.Subscribe(256) +func runPrompt(ctx context.Context, client promptRunner, message string) (string, error) { + promptCtx, cancel := context.WithTimeout(ctx, translatePromptTimeout) defer cancel() - if err := client.Prompt(ctx, message); err != nil { - return "", err - } - - for { - select { - case <-ctx.Done(): - return "", ctx.Err() - case event, ok := <-events: - if !ok { - return "", errors.New("event stream closed") - } - if event.Type == "agent_end" { - return extractTranslationResult(event.Raw) - } - } + result, err := client.Run(promptCtx, message) + if err != nil { + return "", decoratePromptError(err, client.Stderr()) } + return result.Text, nil } -func extractTranslationResult(raw json.RawMessage) (string, error) { - var payload agentEndPayload - if err := json.Unmarshal(raw, &payload); err != nil { - return "", err +func decoratePromptError(err error, stderr string) error { + if err == nil { + return nil } - for index := len(payload.Messages) - 1; index >= 0; index-- { - message := payload.Messages[index] - if message.Role != "assistant" { - continue - } - if message.ErrorMessage != "" || strings.EqualFold(message.StopReason, "error") { - msg := strings.TrimSpace(message.ErrorMessage) - if msg == "" { - msg = "unknown error" - } - return "", fmt.Errorf("pi error: %s", msg) - } - text, err := extractContentText(message.Content) - if err != nil { - return "", err - } - return text, nil - } - return "", errors.New("assistant message not found") -} - -func extractContentText(content json.RawMessage) (string, error) { - trimmed := strings.TrimSpace(string(content)) + trimmed := strings.TrimSpace(stderr) if trimmed == "" { - return "", nil + return err } - if strings.HasPrefix(trimmed, "\"") { - var text string - if err := json.Unmarshal(content, &text); err != nil { - return "", err - } - return text, nil - } - - var blocks []contentBlock - if err := json.Unmarshal(content, &blocks); err != nil { - return "", err - } - - var parts []string - for _, block := range blocks { - if block.Type == "text" && block.Text != "" { - parts = append(parts, block.Text) - } - } - return strings.Join(parts, ""), nil + return fmt.Errorf("%w (pi stderr: %s)", err, trimmed) } func normalizeThinking(value string) string { diff --git a/scripts/docs-i18n/translator_test.go b/scripts/docs-i18n/translator_test.go new file mode 100644 index 00000000000..a632e44e96e --- /dev/null +++ b/scripts/docs-i18n/translator_test.go @@ -0,0 +1,92 @@ +package main + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + pi "github.com/joshp123/pi-golang" +) + +type fakePromptRunner struct { + run func(context.Context, string) (pi.RunResult, error) + stderr string +} + +func (runner fakePromptRunner) Run(ctx context.Context, message string) (pi.RunResult, error) { + return runner.run(ctx, message) +} + +func (runner fakePromptRunner) Stderr() string { + return runner.stderr +} + +func TestRunPromptAddsTimeout(t *testing.T) { + t.Parallel() + + var deadline time.Time + client := fakePromptRunner{ + run: func(ctx context.Context, message string) (pi.RunResult, error) { + var ok bool + deadline, ok = ctx.Deadline() + if !ok { + t.Fatal("expected prompt deadline") + } + if message != "Translate me" { + t.Fatalf("unexpected message %q", message) + } + return pi.RunResult{Text: "translated"}, nil + }, + } + + got, err := runPrompt(context.Background(), client, "Translate me") + if err != nil { + t.Fatalf("runPrompt returned error: %v", err) + } + if got != "translated" { + t.Fatalf("unexpected translation %q", got) + } + + remaining := time.Until(deadline) + if remaining <= time.Minute || remaining > translatePromptTimeout { + t.Fatalf("unexpected timeout window %s", remaining) + } +} + +func TestRunPromptIncludesStderr(t *testing.T) { + t.Parallel() + + rootErr := errors.New("context deadline exceeded") + client := fakePromptRunner{ + run: func(context.Context, string) (pi.RunResult, error) { + return pi.RunResult{}, rootErr + }, + stderr: "boom", + } + + _, err := runPrompt(context.Background(), client, "Translate me") + if err == nil { + t.Fatal("expected error") + } + if !errors.Is(err, rootErr) { + t.Fatalf("expected wrapped root error, got %v", err) + } + if !strings.Contains(err.Error(), "pi stderr: boom") { + t.Fatalf("expected stderr in error, got %v", err) + } +} + +func TestDecoratePromptErrorLeavesCleanErrorsAlone(t *testing.T) { + t.Parallel() + + rootErr := errors.New("plain failure") + got := decoratePromptError(rootErr, " ") + if !errors.Is(got, rootErr) { + t.Fatalf("expected original error, got %v", got) + } + if got.Error() != rootErr.Error() { + t.Fatalf("expected unchanged message, got %v", got) + } +}