#!/usr/bin/env python3
"""
retry_failures.py

Identify failed/incomplete/missing batch items and optionally submit a retry batch.

Features
--------
- Detect failures from results JSONL (missing response/body, non-200, surface error messages).
- Detect items present in original batch-input but missing from results JSONL.
- Detect incomplete outputs by scanning extracted JSON blocks and flagging ids with < N blocks (including 0).
- Optionally override max_output_tokens / max_tokens for retries via --new-max-output-tokens.
- Sanity-check the first few retry lines to ensure the override is applied, else abort.
- Submit retry batch, poll for completion, download results, and optionally extract fenced JSON blocks.

Usage (examples)
----------------
# Print only (no submission), with blocks scan and reasons:
python retry_failures.py \
  --original-batch-input-path outputs/batch_input.jsonl \
  --batch-results-jsonl-path outputs/batch_output.jsonl \
  --json-blocks-dir outputs/json_blocks \
  --min-json-blocks 2 \
  --mode print

# Submit a retry with larger token limit, then extract json blocks:
python retry_failures.py \
  --original-batch-input-path outputs/batch_input.jsonl \
  --batch-results-jsonl-path outputs/batch_output.jsonl \
  --json-blocks-dir outputs/json_blocks \
  --mode submit \
  --retry-tag r1 \
  --endpoint /v1/responses \
  --new-max-output-tokens 4000 \
  --extract-json-blocks
"""

import argparse
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

from openai import OpenAI

# ---- Regexes ----
# custom_id layout: a lazily-matched prefix ending in "-", the numeric index,
# then zero or more "-r<digits>" suffixes (e.g. "prob-12-r1" -> idx "12").
CID_ID_RE = re.compile(r"^(?P<prefix>.+?-)(?P<idx>\d+)(?P<suffix>(?:-r\d+)*)$")
# A ```json fenced block whose payload is a {...} object. DOTALL lets the
# object span several lines; the lazy match ends at the first "}" that is
# followed (after optional whitespace) by the closing fence.
JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
# File names of extracted blocks: numeric id, "_v", block number, ".json".
BLOCK_FILE_RE = re.compile(r"^(?P<id>\d+)_v(?P<v>\d+)\.json$")  # <id>_v<k>.json


# =============================================================================
# Utilities: robust parsing
# =============================================================================

def _load_jsonl(path: Path) -> List[Any]:
    """
    Read a JSONL file into a list with one entry per non-blank line.

    A line that parses as JSON contributes the parsed value; a line that does
    not parse is kept as its stripped raw text so nothing is silently lost.
    """
    def _decode(text: str) -> Any:
        try:
            return json.loads(text)
        except Exception:
            # Malformed line: fall back to the raw (already stripped) string.
            return text

    with path.open("r", encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [_decode(text) for text in stripped_lines if text]


def _ensure_dict(x: Any) -> Dict[str, Any]:
    """
    Return *x* as a dict when that is possible, otherwise an empty dict.

    A dict is returned unchanged (same object). A string is parsed as JSON and
    returned only if the parsed value is a dict. Everything else yields {}.
    """
    if isinstance(x, dict):
        return x
    if not isinstance(x, str):
        return {}
    try:
        decoded = json.loads(x)
    except Exception:
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}


def _ensure_body_dict(line_obj: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of the line's 'body' coerced to a dict ({} if absent or unparseable)."""
    return {**_ensure_dict(line_obj.get("body", {}))}


def _parse_id_from_custom_id(custom_id: str) -> Optional[int]:
    """
    Extract the numeric index from a custom_id using CID_ID_RE.

    Returns None when the value is not a string, does not match the pattern,
    or the matched digits cannot be converted to int.
    """
    matched = CID_ID_RE.match(custom_id) if isinstance(custom_id, str) else None
    if matched is None:
        return None
    try:
        return int(matched["idx"])
    except Exception:
        return None


# =============================================================================
# Load original batch-input and build indexes
# =============================================================================

def _index_original_requests(original_input_path: Path) -> Tuple[Dict[str, Dict[str, Any]], List[str]]:
    """
    Load the original batch-input JSONL and index it by custom_id.

    Return:
      - index: custom_id -> line object (dict); a later duplicate replaces an earlier one
      - ordered_cids: custom_ids in file order (one entry per accepted line)

    Lines that cannot be coerced to a non-empty dict, or whose custom_id is
    not a string, are ignored.
    """
    index: Dict[str, Dict[str, Any]] = {}
    ordered_cids: List[str] = []
    for entry in map(_ensure_dict, _load_jsonl(original_input_path)):
        # An empty dict has no custom_id, so it falls out on the same check.
        cid = entry.get("custom_id")
        if not isinstance(cid, str):
            continue
        index[cid] = entry
        ordered_cids.append(cid)
    return index, ordered_cids


def _build_id_to_cid(ordered_cids: List[str]) -> Dict[int, str]:
    """
    Build numeric id -> custom_id, keeping the first custom_id seen for each id.

    custom_ids whose numeric id cannot be parsed are skipped.
    """
    id_to_cid: Dict[int, str] = {}
    for cid in ordered_cids:
        numeric = _parse_id_from_custom_id(cid)
        if numeric is not None:
            # setdefault keeps the earliest custom_id for a repeated id.
            id_to_cid.setdefault(numeric, cid)
    return id_to_cid


# =============================================================================
# Results JSONL analysis
# =============================================================================

def _collect_cids_from_results(results_path: Path) -> Set[str]:
    """
    Collect every string custom_id that appears in the results JSONL.

    Blank lines, lines that are not valid JSON, lines whose JSON value is not
    an object, and objects without a string custom_id are skipped. A missing
    file prints a warning and yields an empty set.
    """
    seen: Set[str] = set()
    try:
        with results_path.open("r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except (ValueError, RecursionError):
                    continue
                # Valid JSON is not necessarily an object (could be a list,
                # string, number); only objects can carry a custom_id.
                if not isinstance(obj, dict):
                    continue
                cid = obj.get("custom_id")
                if isinstance(cid, str):
                    seen.add(cid)
    except FileNotFoundError:
        print(f"Warning: results file not found: {results_path}")
    return seen


def _find_failures_from_results(results_path: Path) -> Dict[str, str]:
    """
    Return mapping {custom_id: reason} for failed lines:
      - no response object (the line-level error message is appended if present)
      - invalid/missing body
      - non-200 status (the body's error message is appended if present)

    Blank lines, invalid JSON, non-object JSON values and objects without a
    string custom_id are skipped. A missing file prints a warning and yields
    an empty mapping.
    """
    def _error_message(container: Any) -> Optional[str]:
        # Pull container["error"]["message"] when both levels have the expected shape.
        err = container.get("error") if isinstance(container, dict) else None
        if isinstance(err, dict):
            return err.get("message")
        return None

    reasons: Dict[str, str] = {}
    try:
        with results_path.open("r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except (ValueError, RecursionError):
                    continue
                # Valid JSON that is not an object has no fields to inspect.
                if not isinstance(obj, dict):
                    continue
                cid = obj.get("custom_id")
                if not isinstance(cid, str):
                    continue

                resp = obj.get("response")
                if not isinstance(resp, dict):
                    reason = "no response object"
                    # Without a response, any explanation sits in the
                    # line-level "error" field rather than in a body.
                    msg = _error_message(obj)
                    if msg:
                        reason += f" — {msg}"
                    reasons[cid] = reason
                    continue

                body = resp.get("body")
                if not isinstance(body, dict):
                    reasons[cid] = "invalid/missing body"
                    continue

                status = resp.get("status_code")
                if status != 200:
                    reason = f"status {status}"
                    msg = _error_message(body)
                    if msg:
                        reason += f" — {msg}"
                    reasons[cid] = reason
    except FileNotFoundError:
        print(f"Warning: results file not found: {results_path}")
    return reasons


# =============================================================================
# Incomplete JSON blocks detection
# =============================================================================

def _count_blocks(json_blocks_dir: Path) -> Dict[int, int]:
    """
    Tally extracted block files per numeric id.

    Only files directly inside *json_blocks_dir* whose name matches
    BLOCK_FILE_RE (<id>_v<k>.json) are counted. Returns {id: count}; the
    mapping is empty when the directory does not exist.
    """
    counts: Dict[int, int] = {}
    if json_blocks_dir.exists():
        name_matches = (BLOCK_FILE_RE.match(entry.name) for entry in json_blocks_dir.glob("*.json"))
        for hit in name_matches:
            if hit:
                block_id = int(hit["id"])
                counts[block_id] = counts.get(block_id, 0) + 1
    return counts


def _incomplete_from_blocks(expected_ids: Set[int], json_blocks_dir: Optional[Path], min_json_blocks: int) -> Dict[int, str]:
    """
    Report every expected id that has fewer than *min_json_blocks* block files.

    Ids with no files at all count as zero and are therefore included. The
    result is keyed in ascending id order. No directory given -> empty result.
    """
    if not json_blocks_dir:
        return {}
    counts = _count_blocks(json_blocks_dir)
    tallied = ((item_id, counts.get(item_id, 0)) for item_id in sorted(expected_ids))
    return {
        item_id: f"only {found} JSON block(s) < {min_json_blocks}"
        for item_id, found in tallied
        if found < min_json_blocks
    }


# =============================================================================
# Build retry lines
# =============================================================================

def _suffix_custom_id(cid: str, retry_tag: str) -> str:
    """Append "-<retry_tag>" to *cid*; an empty tag leaves *cid* unchanged."""
    if not retry_tag:
        return cid
    return "{}-{}".format(cid, retry_tag)


def _token_key_for_endpoint(endpoint_path: str) -> str:
    """
    Name the request-body field that caps output tokens for an endpoint.

    Any path containing "/chat/completions" uses 'max_tokens'; every other
    path uses 'max_output_tokens'.
    """
    return "max_tokens" if "/chat/completions" in endpoint_path else "max_output_tokens"


def _build_retry_lines(
    *,
    failed_cids: List[str],
    original_index: Dict[str, Dict[str, Any]],
    retry_tag: str,
    endpoint_path: Optional[str],
    new_max_output_tokens: Optional[int],
) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """
    Build retry JSONL lines from a list of failed custom_ids using the original input index.

    Args:
        failed_cids: custom_ids to retry, in the order the lines should be emitted.
        original_index: custom_id -> original request line.
        retry_tag: suffix appended to each custom_id (empty keeps the id as is).
        endpoint_path: explicit endpoint override; None or blank means "infer
            from the original lines".
        new_max_output_tokens: when not None, written into each body under the
            token key that matches the target endpoint (and mirrored onto the
            sibling key if the body already carries it).

    Returns:
        (retry_lines, inferred_endpoint). inferred_endpoint is the explicit
        endpoint when one was given and at least one line was usable,
        otherwise the first string 'url' found among the usable original
        lines, otherwise None.

    custom_ids that are missing from the index, or whose line cannot be
    coerced to a non-empty dict, are skipped with a printed warning.
    """
    explicit_endpoint: Optional[str] = (
        endpoint_path if isinstance(endpoint_path, str) and endpoint_path.strip() else None
    )

    # Pass 1: resolve each custom_id to its original line. Doing this first
    # lets the endpoint be settled once, before any body is touched.
    resolved: List[Tuple[str, Dict[str, Any]]] = []
    for cid in failed_cids:
        raw = original_index.get(cid)
        if raw is None:
            print(f"Warning: custom_id not found in original input: {cid}")
            continue

        line = _ensure_dict(raw)
        if not line:
            print(f"Warning: original line for {cid} is not a dict and could not be parsed; skipping.")
            continue
        resolved.append((cid, line))

    inferred_endpoint: Optional[str] = None
    if resolved:
        if explicit_endpoint is not None:
            inferred_endpoint = explicit_endpoint
        else:
            inferred_endpoint = next(
                (line["url"] for _, line in resolved if isinstance(line.get("url"), str)),
                None,
            )

    # One endpoint decides the token key for every line, so all lines agree
    # with each other and with whatever the caller checks afterwards.
    target_endpoint = explicit_endpoint or inferred_endpoint or "/v1/responses"
    token_key = _token_key_for_endpoint(target_endpoint)
    sibling_key = "max_tokens" if token_key == "max_output_tokens" else "max_output_tokens"

    # Pass 2: emit the retry lines.
    retry_lines: List[Dict[str, Any]] = []
    for cid, line in resolved:
        body = _ensure_body_dict(line)
        if new_max_output_tokens is not None:
            body[token_key] = new_max_output_tokens
            # Also mirror to the sibling key if present, to be extra safe.
            if sibling_key in body:
                body[sibling_key] = new_max_output_tokens

        # An explicit endpoint is what the batch will be submitted against,
        # so the per-line url must follow it. Without an override keep the
        # line's own url, falling back to the target when it has none.
        if explicit_endpoint is not None:
            url = explicit_endpoint
        else:
            url = line.get("url") or target_endpoint

        retry_lines.append(
            {
                "custom_id": _suffix_custom_id(cid, retry_tag),
                "method": line.get("method", "POST"),
                "url": url,
                "body": body,
            }
        )

    return retry_lines, inferred_endpoint


def _sanity_check_retry_lines(
    retry_lines: List[Dict[str, Any]],
    endpoint_path: str,
    new_max_output_tokens: Optional[int],
    check_first_n: int,
):
    """
    Verify that the token override landed in the first *check_first_n* retry lines.

    Does nothing when no override was requested. Raises ValueError if a
    checked line's body is not a dict, or if the endpoint's token key in that
    body does not equal *new_max_output_tokens*.
    """
    if new_max_output_tokens is None:
        return
    key = _token_key_for_endpoint(endpoint_path)
    for position, entry in enumerate(retry_lines[:check_first_n]):
        payload = entry.get("body", {})
        if not isinstance(payload, dict):
            raise ValueError(f"Sanity check failed: body is not dict in retry line {position}")
        actual = payload.get(key)
        if actual == new_max_output_tokens:
            continue
        raise ValueError(
            f"Sanity check failed on line {position}: expected {key}={new_max_output_tokens}, got {actual}"
        )


# =============================================================================
# Results text extraction (for block parsing)
# =============================================================================

def _extract_text_from_responses_body(body: Dict[str, Any]) -> str:
    """
    Collect the text carried by a responses-style result body.

    A string 'output_text' wins outright (stripped). Otherwise each dict in
    the 'output' list contributes its own 'text' and, for each dict in its
    'content' list, that part's 'text' or else the part's string 'content'.
    Blank pieces are dropped and the rest are joined with newlines. Anything
    with an unexpected shape yields "".
    """
    def _clean(value: Any) -> str:
        # Stripped text for strings, "" for every other type.
        return value.strip() if isinstance(value, str) else ""

    if not isinstance(body, dict):
        return ""
    direct = body.get("output_text")
    if isinstance(direct, str):
        return direct.strip()

    output_items = body.get("output", [])
    if not isinstance(output_items, list):
        return ""

    pieces: List[str] = []
    for item in output_items:
        if not isinstance(item, dict):
            continue
        item_text = _clean(item.get("text"))
        if item_text:
            pieces.append(item_text)
        parts = item.get("content")
        if not isinstance(parts, list):
            continue
        for part in parts:
            if not isinstance(part, dict):
                continue
            # Prefer the part's 'text'; fall back to a string 'content'.
            part_text = _clean(part.get("text")) or _clean(part.get("content"))
            if part_text:
                pieces.append(part_text)
    return "\n".join(pieces).strip()


def _extract_text_from_chat_body(body: Dict[str, Any]) -> str:
    """
    Return the stripped message content of the first choice in a chat-style body.

    If that path cannot be followed (missing keys, empty choices, non-string
    content, ...), the whole body is returned as a JSON string instead.
    """
    try:
        first_choice = body["choices"][0]
        message = first_choice["message"]
        return message["content"].strip()
    except Exception:
        return json.dumps(body)


def _extract_text_from_result_body(body: Dict[str, Any]) -> str:
    """Extract text from a result body: chat-style when it is a dict with 'choices', responses-style otherwise."""
    is_chat = isinstance(body, dict) and "choices" in body
    extractor = _extract_text_from_chat_body if is_chat else _extract_text_from_responses_body
    return extractor(body)


# =============================================================================
# Main flow
# =============================================================================

def create_retry_batch_from_failures(
    *,
    client: OpenAI,
    original_batch_input_path: str,
    batch_results_jsonl_path: str,
    workdir: str = "outputs",
    retry_tag: str = "r1",
    completion_window: str = "24h",
    endpoint: Optional[str] = None,
    max_retry_items: Optional[int] = None,
    json_blocks_dir: Optional[str] = None,
    min_json_blocks: int = 2,
    mode: str = "print",  # "print" or "submit"
    id_prefix: str = "prob-",
    new_max_output_tokens: Optional[int] = None,
    check_first_n: int = 5,
    poll_interval: float = 10.0,
    extract_json_blocks: bool = False,
) -> Optional[object]:
    """
    Identify and optionally submit a follow-up Batch for failed/incomplete/missing items.

    An item is selected for retry when its results line is a failure, when its
    custom_id is absent from the results file, or (if json_blocks_dir is
    given) when it has fewer than min_json_blocks extracted block files. The
    first reason found for a custom_id is the one reported.

    Args:
        client: API client; only used in "submit" mode.
        original_batch_input_path: original batch-input JSONL.
        batch_results_jsonl_path: results JSONL of the original batch.
        workdir: directory (created if needed) for the retry input, the
            downloaded results and, optionally, extracted blocks.
        retry_tag: suffix appended to custom_ids and used in the retry file name.
        completion_window: passed through to the batch creation call.
        endpoint: explicit endpoint; otherwise inferred from the original
            lines, defaulting to "/v1/responses".
        max_retry_items: keep only the first N items after sorting.
        json_blocks_dir: directory scanned for <id>_v<k>.json files.
        min_json_blocks: minimum block files for an id to count as complete.
        mode: "print" lists the items; "submit" creates and polls a batch.
        id_prefix: accepted for compatibility; not used by this function.
        new_max_output_tokens: token-limit override for the retry requests.
        check_first_n: number of retry lines sanity-checked for the override.
        poll_interval: seconds slept between status checks.
        extract_json_blocks: when True, fenced JSON blocks in the retry
            results are written to <workdir>/json_blocks.

    Returns:
        None in print mode, or when there is no valid original input, nothing
        to retry, or no line could be rebuilt. Otherwise the last retrieved
        batch job object, whatever its terminal status.

    Raises:
        ValueError: if mode is neither "print" nor "submit" (checked only once
            there is something to retry), or if the token-override sanity
            check fails.
    """
    from datetime import timezone  # local import keeps this block self-contained

    def _utc_stamp() -> str:
        # Naive-UTC ISO text (same shape utcnow() gave) without the deprecated call.
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    original_path = Path(original_batch_input_path)
    results_path = Path(batch_results_jsonl_path)
    out_dir = Path(workdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    blocks_dir = Path(json_blocks_dir) if json_blocks_dir else None

    # Load original batch input
    original_index, ordered_cids = _index_original_requests(original_path)
    if not original_index:
        print(f"Error: no valid lines found in original batch input: {original_path}")
        return None

    # Failures & missing from results
    failed_cid_reasons = _find_failures_from_results(results_path)
    results_cids = _collect_cids_from_results(results_path)
    original_cids = set(original_index.keys())
    missing_cids = original_cids - results_cids

    # Incomplete from blocks (including zero-file cases)
    parsed_ids = (_parse_id_from_custom_id(cid) for cid in original_cids)
    expected_ids: Set[int] = {i for i in parsed_ids if i is not None}
    incomplete_id_reasons = _incomplete_from_blocks(expected_ids, blocks_dir, min_json_blocks)

    # Map ids->cids to attach block reasons to the right cids
    id_to_cid = _build_id_to_cid(ordered_cids)
    incomplete_cid_reasons: Dict[str, str] = {}
    for i, reason in incomplete_id_reasons.items():
        cid = id_to_cid.get(i)
        if cid:
            incomplete_cid_reasons[cid] = reason
        else:
            print(f"Warning: could not map incomplete id {i} to a custom_id via {original_path}")

    # Build unified reason map; setdefault keeps the earliest-found reason.
    all_cid_reasons: Dict[str, str] = dict(failed_cid_reasons)
    for cid in missing_cids:
        all_cid_reasons.setdefault(cid, "missing in results JSONL")
    for cid, reason in incomplete_cid_reasons.items():
        all_cid_reasons.setdefault(cid, reason)

    # Final retry list (sorted by numeric id when possible, else by cid).
    # The cid is a tie-breaker so the order does not depend on set iteration
    # order when two custom_ids share a numeric id.
    def _sort_key(c: str):
        i = _parse_id_from_custom_id(c)
        return (0, i, c) if i is not None else (1, c)

    to_retry_cids = sorted(all_cid_reasons, key=_sort_key)

    if max_retry_items is not None and len(to_retry_cids) > max_retry_items:
        to_retry_cids = to_retry_cids[:max_retry_items]

    if not to_retry_cids:
        print("No failed, incomplete, or missing items detected. Nothing to retry.")
        return None

    # Print mode
    if mode == "print":
        print(f"Total to retry: {len(to_retry_cids)}")
        print("custom_id\treason")
        for cid in to_retry_cids:
            print(f"{cid}\t{all_cid_reasons.get(cid, 'unknown reason')}")
        print("\n[id]\treason")
        for cid in to_retry_cids:
            iid = _parse_id_from_custom_id(cid)
            if iid is not None:
                print(f"{iid}\t{all_cid_reasons.get(cid, 'unknown reason')}")
        return None

    if mode != "submit":
        raise ValueError("mode must be either 'print' or 'submit'")

    # Build retry JSONL lines
    retry_lines, inferred_endpoint = _build_retry_lines(
        failed_cids=to_retry_cids,
        original_index=original_index,
        retry_tag=retry_tag,
        endpoint_path=endpoint,
        new_max_output_tokens=new_max_output_tokens,
    )
    if not retry_lines:
        print("No retryable lines found (could not map items back to originals).")
        return None

    endpoint_path = endpoint or inferred_endpoint or "/v1/responses"

    # Sanity check on token override
    _sanity_check_retry_lines(retry_lines, endpoint_path, new_max_output_tokens, check_first_n)

    # Write retry input file
    retry_input_path = out_dir / f"batch_retry_{retry_tag}_{int(time.time())}.jsonl"
    with retry_input_path.open("w", encoding="utf-8") as f:
        for obj in retry_lines:
            f.write(json.dumps(obj, ensure_ascii=False) + "\n")

    # Submit. The context manager closes the upload handle even if the
    # upload call raises.
    with retry_input_path.open("rb") as upload:
        retry_file = client.files.create(file=upload, purpose="batch")
    job = client.batches.create(
        input_file_id=retry_file.id,
        endpoint=endpoint_path,
        completion_window=completion_window,
    )

    print(f"[{_utc_stamp()}Z] Created retry batch: {job.id} | status={job.status}")
    print(f"  - Retry lines: {len(retry_lines)}")
    print(f"  - Endpoint: {endpoint_path}")
    print(f"  - Input file: {retry_input_path}")

    # Poll until the job reaches a terminal status
    while True:
        job = client.batches.retrieve(job.id)
        print(f"[{_utc_stamp()}Z] status={job.status}")
        if job.status in ("completed", "failed", "cancelled", "expired"):
            break
        time.sleep(poll_interval)

    if job.status != "completed":
        print(f"Retry batch finished with status '{job.status}'. Partial or no results.")
        return job

    if not job.output_file_id:
        print("Retry batch completed but has no output_file_id.")
        return job

    # Download results
    result_bytes = client.files.content(job.output_file_id).content
    retry_result_path = out_dir / f"batch_output_{job.id}.jsonl"
    with retry_result_path.open("wb") as f:
        f.write(result_bytes)
    print(f"Wrote retry results: {retry_result_path}")

    # Optional: extract fenced JSON blocks
    if extract_json_blocks:
        # Extraction always targets <workdir>/json_blocks, independent of
        # the json_blocks_dir that was scanned above.
        extract_dir = out_dir / "json_blocks"
        extract_dir.mkdir(parents=True, exist_ok=True)

        written = 0
        with retry_result_path.open("r", encoding="utf-8") as f:
            for line in f:
                s = line.strip()
                if not s:
                    continue
                try:
                    obj = json.loads(s)
                except Exception:
                    continue
                if not isinstance(obj, dict):
                    continue
                cid = obj.get("custom_id")
                iid = _parse_id_from_custom_id(cid) if cid else None
                resp = obj.get("response") or {}
                body = resp.get("body") if isinstance(resp, dict) else None
                if iid is None or not isinstance(body, dict):
                    continue

                text = _extract_text_from_result_body(body)
                if not text:
                    continue

                blocks = JSON_BLOCK_RE.findall(text)
                for vi, block in enumerate(blocks, start=1):
                    try:
                        parsed = json.loads(block)
                        with (extract_dir / f"{iid}_v{vi}.json").open("w", encoding="utf-8") as outf:
                            json.dump(parsed, outf, ensure_ascii=False, indent=2)
                            written += 1
                    except Exception:
                        # Keep the raw text so an unparseable block is not lost.
                        with (extract_dir / f"{iid}_v{vi}.raw.txt").open("w", encoding="utf-8") as outf:
                            outf.write(block)
                            written += 1
        print(f"Extracted {written} JSON block file(s) to: {extract_dir}")

    # Show reason breakdown
    reason_counts: Dict[str, int] = {}
    for cid in to_retry_cids:
        r = all_cid_reasons.get(cid, "unknown reason")
        reason_counts[r] = reason_counts.get(r, 0) + 1
    print("Retry breakdown by reason:")
    for r, c in sorted(reason_counts.items(), key=lambda x: (-x[1], x[0])):
        print(f"  {c} × {r}")

    return job


def main():
    """
    Command-line entry point: parse arguments and run the retry workflow.

    OPENAI_API_KEY is mandatory only for --mode submit. Print mode works
    purely on local files, so it runs without credentials.

    Raises:
        ValueError: in submit mode when OPENAI_API_KEY is not set.
    """
    p = argparse.ArgumentParser(description="Identify failed/incomplete/missing items and optionally submit a retry batch.")
    p.add_argument("--original-batch-input-path", type=str, required=True, help="Path to original batch input JSONL")
    p.add_argument("--batch-results-jsonl-path", type=str, required=True, help="Path to batch results JSONL")
    p.add_argument("--workdir", type=str, default="outputs", help="Directory to write retry JSONL and outputs")
    p.add_argument("--retry-tag", type=str, default="r1", help="Suffix for custom_id in retry job")
    p.add_argument("--completion-window", type=str, default="24h", help="Batch completion window")
    p.add_argument("--endpoint", type=str, default=None, help="Endpoint for retry (e.g., /v1/responses or /v1/chat/completions)")
    p.add_argument("--max-retry-items", type=int, default=None, help="Max number of items to retry")

    # Blocks scanning + behavior
    p.add_argument("--json-blocks-dir", type=str, default=None, help="Directory of extracted JSON blocks (<id>_v*.json)")
    p.add_argument("--min-json-blocks", type=int, default=2, help="Minimum number of blocks required to consider complete")

    # Mode & id
    p.add_argument("--mode", choices=["print", "submit"], default="print", help="Only print IDs or also submit a retry batch")
    p.add_argument("--id-prefix", type=str, default="prob-", help="Prefix used in custom_id to encode numeric id (unused in logic but kept for compatibility)")

    # Token override + sanity check + polling
    p.add_argument("--new-max-output-tokens", type=int, default=None, help="Override max_output_tokens/max_tokens for retry requests")
    p.add_argument("--check-first-n", type=int, default=5, help="Sanity-check first N retry lines for token override")
    p.add_argument("--poll-interval", type=float, default=10.0, help="Seconds between status checks when submitting")

    # Extraction
    p.add_argument("--extract-json-blocks", action="store_true", help="If set, parse fenced JSON blocks from retry results to workdir/json_blocks/")

    args = p.parse_args()

    api_key = os.getenv("OPENAI_API_KEY")
    if args.mode == "submit" and not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")
    # Print mode never touches the client, so a missing key must not block
    # an offline analysis; the client is only built when a key exists.
    client = OpenAI(api_key=api_key) if api_key else None

    create_retry_batch_from_failures(
        client=client,
        original_batch_input_path=args.original_batch_input_path,
        batch_results_jsonl_path=args.batch_results_jsonl_path,
        workdir=args.workdir,
        retry_tag=args.retry_tag,
        completion_window=args.completion_window,
        endpoint=args.endpoint,
        max_retry_items=args.max_retry_items,
        json_blocks_dir=args.json_blocks_dir,
        min_json_blocks=args.min_json_blocks,
        mode=args.mode,
        id_prefix=args.id_prefix,
        new_max_output_tokens=args.new_max_output_tokens,
        check_first_n=args.check_first_n,
        poll_interval=args.poll_interval,
        extract_json_blocks=args.extract_json_blocks,
    )


# Script entry point: run the CLI and report Ctrl-C with a short message
# instead of a KeyboardInterrupt traceback.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nInterrupted by user, exiting…")
