#!/usr/bin/env python3
"""
batch_fetch_results.py

Given an OpenAI Batch API job id, poll for completion, download the results
JSONL to disk, and optionally extract fenced JSON blocks (```json ... ```).

Works for both /v1/responses and /v1/chat/completions batch outputs.

Usage
-----
python batch_fetch_results.py \
  --job-id batch_abc123 \
  --workdir outputs \
  --poll-interval 10 \
  --extract-json-blocks \
  --json-blocks-dir outputs/json_blocks

Notes
-----
- Requires OPENAI_API_KEY to be set in the environment.
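- JSON block files are saved as <id>_v<k>.json, where <id> is the numeric
  index parsed from each request's custom_id (e.g., "prob-42-r1" -> 42) and
  <k> numbers the extracted blocks starting at 1. If no numeric id can be
  parsed, the custom_id itself is used in place of <id> (or "unknown" when a
  result line has no custom_id). Blocks that are not valid JSON are saved
  verbatim as <id>_v<k>.raw.txt.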
"""

import argparse
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from openai import OpenAI

# Fenced JSON code blocks: ```json ... ``` whose payload is a JSON object
# ({...}) or a JSON array ([...]). The lazy ".*?" stops at the first closing
# brace/bracket that is followed (after optional whitespace) by the closing
# fence, so several fences in one text are matched separately.
JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\}|\[.*?\])\s*```", re.DOTALL)

# Parse numeric id from common custom_id formats like "prob-42", "prob-42-r1":
# <prefix ending in '-'><digits><zero or more "-r<digits>" suffixes>.
CID_ID_RE = re.compile(r"^(?P<prefix>.+?-)(?P<idx>\d+)(?P<suffix>(?:-r\d+)*)$")


def _parse_id_from_custom_id(custom_id: Optional[str]) -> Optional[int]:
    """
    Return the numeric index embedded in a custom_id, or None.

    The index is the digit run captured by CID_ID_RE's ``idx`` group, so
    "prob-42" and "prob-42-r1" both yield 42. Non-string input, ids that do
    not fit the pattern, and digit runs that int() rejects all give None.
    """
    match = CID_ID_RE.match(custom_id) if isinstance(custom_id, str) else None
    if match is None:
        return None

    digits = match.group("idx")
    try:
        return int(digits)
    except ValueError:
        # int() can still refuse a valid digit run, e.g. one longer than the
        # interpreter's int max-str-digits limit (Python 3.11+).
        return None


def _extract_text_from_responses_body(body: Dict[str, Any]) -> str:
    """
    Extract text from a /v1/responses-style body.
    """
    if isinstance(body, dict) and isinstance(body.get("output_text"), str):
        return body["output_text"].strip()

    out_chunks: List[str] = []
    items = body.get("output", []) if isinstance(body, dict) else []
    if isinstance(items, list):
        for item in items:
            if not isinstance(item, dict):
                continue
            t = item.get("text")
            if isinstance(t, str) and t.strip():
                out_chunks.append(t.strip())
            content = item.get("content")
            if isinstance(content, list):
                for part in content:
                    if not isinstance(part, dict):
                        continue
                    pt = part.get("text")
                    if isinstance(pt, str) and pt.strip():
                        out_chunks.append(pt.strip())
                    else:
                        alt = part.get("content")
                        if isinstance(alt, str) and alt.strip():
                            out_chunks.append(alt.strip())
    return "\n".join(out_chunks).strip()


def _extract_text_from_chat_body(body: Dict[str, Any]) -> str:
    """
    Extract text from a /v1/chat/completions-style body.
    """
    try:
        return body["choices"][0]["message"]["content"].strip()
    except Exception:
        return json.dumps(body)


def _extract_text_from_result_body(body: Dict[str, Any]) -> str:
    """
    Route a result body to the matching text extractor.

    Dict bodies carrying a ``choices`` key are treated as chat-completions
    output; everything else goes through the responses-style extractor.
    """
    looks_like_chat = isinstance(body, dict) and "choices" in body
    extractor = (
        _extract_text_from_chat_body
        if looks_like_chat
        else _extract_text_from_responses_body
    )
    return extractor(body)


# Batch statuses after which the job no longer changes. "finalizing" is
# deliberately absent: in that state the output file is still being prepared.
_TERMINAL_BATCH_STATUSES = ("completed", "failed", "cancelled", "expired")

# Anything outside this set in a custom_id is replaced with "_" before it is
# used as a file name (e.g. "/" would otherwise point into a subdirectory).
_UNSAFE_FILENAME_CHARS_RE = re.compile(r"[^A-Za-z0-9._-]")


def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string ending in 'Z'."""
    return time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())


def _file_base_for(custom_id: Any) -> str:
    """Choose a file stem: the numeric id if parseable, else a sanitized custom_id, else 'unknown'."""
    iid = _parse_id_from_custom_id(custom_id)
    if iid is not None:
        return str(iid)
    if isinstance(custom_id, str) and custom_id:
        return _UNSAFE_FILENAME_CHARS_RE.sub("_", custom_id)
    return "unknown"


def _write_json_blocks(results_path: Path, blocks_dir: Path) -> int:
    """
    Save every fenced ```json block found in a batch results JSONL file.

    Valid blocks are written pretty-printed to ``<base>_v<k>.json``; blocks
    that fail to parse are written verbatim to ``<base>_v<k>.raw.txt``.
    Returns the number of files written.
    """
    blocks_dir.mkdir(parents=True, exist_ok=True)
    written = 0
    # Next version per file stem, so result lines that map to the same stem
    # (e.g. "prob-42-r1" and "prob-42-r2" -> "42") don't overwrite each other.
    next_version: Dict[str, int] = {}

    with results_path.open("r", encoding="utf-8") as f:
        for line in f:
            s = line.strip()
            if not s:
                continue
            try:
                obj = json.loads(s)
            except ValueError:
                continue  # skip malformed lines instead of aborting the file
            if not isinstance(obj, dict):
                continue

            resp = obj.get("response") or {}
            body = resp.get("body") if isinstance(resp, dict) else None
            if not isinstance(body, dict):
                continue

            text = _extract_text_from_result_body(body)
            if not text:
                continue

            base = _file_base_for(obj.get("custom_id"))
            for block in JSON_BLOCK_RE.findall(text):
                version = next_version.get(base, 0) + 1
                next_version[base] = version
                try:
                    parsed = json.loads(block)
                except ValueError:
                    # Not valid JSON: keep the raw text for manual inspection.
                    outp = blocks_dir / f"{base}_v{version}.raw.txt"
                    outp.write_text(block, encoding="utf-8")
                else:
                    outp = blocks_dir / f"{base}_v{version}.json"
                    with outp.open("w", encoding="utf-8") as outf:
                        json.dump(parsed, outf, ensure_ascii=False, indent=2)
                written += 1

    return written


def poll_and_fetch(
    *,
    client: OpenAI,
    job_id: str,
    workdir: Path,
    poll_interval: float = 10.0,
    extract_json_blocks: bool = False,
    json_blocks_dir: Optional[Path] = None,
) -> Optional[Path]:
    """
    Poll a batch job until it reaches a terminal state, then download results JSONL.

    Args:
        client: OpenAI client used for ``batches.retrieve`` and ``files.content``.
        job_id: Batch job id to poll.
        workdir: Directory for the results file (created if missing).
        poll_interval: Seconds to sleep between status checks.
        extract_json_blocks: If True, also save fenced ```json blocks found in
            each response to ``json_blocks_dir``.
        json_blocks_dir: Destination for extracted blocks; defaults to
            ``workdir / "json_blocks"``.

    Returns:
        Path to the results JSONL, or None if the job did not end in
        "completed" or has no output file.
    """
    workdir.mkdir(parents=True, exist_ok=True)

    while True:
        job = client.batches.retrieve(job_id)
        print(f"[{_utc_timestamp()}] status={job.status}")
        if job.status in _TERMINAL_BATCH_STATUSES:
            break
        time.sleep(poll_interval)

    if job.status != "completed":
        print(f"Batch finished with status '{job.status}'. Partial or no results.")
        return None

    if not job.output_file_id:
        print("Batch completed but has no output_file_id.")
        return None

    result_bytes = client.files.content(job.output_file_id).content
    results_path = workdir / f"batch_output_{job.id}.jsonl"
    results_path.write_bytes(result_bytes)
    print(f"Wrote results: {results_path}")

    if extract_json_blocks:
        blocks_dir = json_blocks_dir or (workdir / "json_blocks")
        written = _write_json_blocks(results_path, blocks_dir)
        print(f"Extracted {written} JSON block file(s) to: {blocks_dir}")

    return results_path


def main():
    """
    Command-line entry point.

    Parses the CLI flags, builds an OpenAI client from OPENAI_API_KEY, runs
    poll_and_fetch with the requested options, and prints a one-line summary.

    Raises:
        ValueError: if OPENAI_API_KEY is unset or empty.
    """
    parser = argparse.ArgumentParser(
        description="Poll an OpenAI Batch job, download results JSONL, and optionally extract fenced JSON blocks."
    )
    parser.add_argument(
        "--job-id",
        type=str,
        required=True,
        help="OpenAI Batch job id (e.g., batch_abc123)",
    )
    parser.add_argument(
        "--workdir",
        type=str,
        default="outputs",
        help="Directory to save outputs",
    )
    parser.add_argument(
        "--poll-interval",
        type=float,
        default=10.0,
        help="Seconds between status checks",
    )
    parser.add_argument(
        "--extract-json-blocks",
        action="store_true",
        help="If set, extract ```json blocks to --json-blocks-dir",
    )
    parser.add_argument(
        "--json-blocks-dir",
        type=str,
        default=None,
        help="Directory for extracted blocks (defaults to <workdir>/json_blocks)",
    )
    opts = parser.parse_args()

    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")

    outcome = poll_and_fetch(
        client=OpenAI(api_key=key),
        job_id=opts.job_id,
        workdir=Path(opts.workdir),
        poll_interval=opts.poll_interval,
        extract_json_blocks=opts.extract_json_blocks,
        json_blocks_dir=Path(opts.json_blocks_dir) if opts.json_blocks_dir else None,
    )

    summary = "No results file written." if outcome is None else f"Done. Results at: {outcome}"
    print(summary)


if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted by user, exiting…")
        # 128 + SIGINT(2): the conventional exit status for Ctrl-C, so calling
        # scripts can tell an interrupted run from a successful one.
        raise SystemExit(130)
