#!/usr/bin/env python3
"""
Evaluate original + rewrites with `Decoy-for-the-Judge/unified_judge.py` via Batch APIs.

This script is modeled after:
- test_trained_model.py (Stage 2: score original + rewritten)
- data_generation/evaluate_rewrites_gptjudge_batch.py (OpenAI Batch submission + metadata + resume)
- unified_judge.py (prompt builders + JSON schema + OpenAI parameter conventions)

Supported input JSONL formats:
1) Nested format (recommended):
   {
     "prompt": "...",
     "original_response": "...",
     "rewrites": [{"rewritten_response": "...", ...}, ...]
   }
2) Flat format (one rewrite per line):
   {
     "prompt": "...",
     "original_response": "...",
     "rewritten_response": "..."
   }

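Output fields (default: scores only; one field per judge selected with --judges,
named after the judge exactly as given, e.g. "pair" -> original_pair_score / pair_score):
- original_cka_score, original_xteaming_score, original_generative_score (on the record)
- cka_score, xteaming_score, generative_score (on each rewrite; or on the flat record)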

Optional (enable via flags):
- *_reason, *_confidence, *_raw_output
"""

from __future__ import annotations

import argparse
import json
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import openai
from google import genai

# Ensure the repo root (the parent of this script's directory) is on sys.path so
# that `unified_judge.py` can be imported below; this must run before that import.
SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parent
if str(REPO_ROOT) not in sys.path:
    # Prepend so the repo-local module is found before other sys.path entries.
    sys.path.insert(0, str(REPO_ROOT))

# Import prompt builders from unified_judge.py
try:
    from unified_judge import (
        JudgeEvaluator,
        build_cka_evaluation_prompt,
        build_xteaming_evaluation_prompt,
        build_pair_evaluation_prompt,
    )
except Exception as e:  # pragma: no cover
    raise ImportError(
        "Failed to import prompt builders from unified_judge.py. "
        "Please ensure it exists and is importable."
    ) from e


# Judge names accepted by --judges. "generative" is a legacy alias that
# _canonical_judge_type() maps to unified_judge.py's "pair" judge; output field
# names still use the judge name as given (e.g. "generative_score").
JUDGE_TYPES_ALL = ("cka", "xteaming", "generative", "pair")


def _canonical_judge_type(judge_type: str) -> str:
    """
    Translate a user-facing judge name into the judge type used by unified_judge.py.

    The name is stripped and lower-cased (None/empty becomes ""). The legacy alias
    "generative" is reported as "pair"; any other name passes through unchanged.
    """
    legacy_aliases = {"generative": "pair"}
    normalized = (judge_type or "").strip().lower()
    return legacy_aliases.get(normalized, normalized)

def _detect_provider(model_name: str) -> str:
    """
    Guess the API provider from a judge model name.

    Returns "gemini" or "claude" when the lower-cased name starts with that
    prefix, and "openai" otherwise (including an empty or None name).
    """
    lowered = (model_name or "").lower()
    for prefix in ("gemini", "claude"):
        if lowered.startswith(prefix):
            return prefix
    return "openai"


def _parse_judges_arg(judges: str) -> List[str]:
    """
    Parse the comma-separated --judges value into an ordered, de-duplicated list.

    Entries are stripped and lower-cased, and blank entries are ignored. If no
    entries remain, the backward-compatible default ["cka", "xteaming", "generative"]
    is returned.

    Raises:
        ValueError: if any entry is not in JUDGE_TYPES_ALL.
    """
    requested = [token.strip().lower() for token in (judges or "").split(",") if token.strip()]
    if not requested:
        # Backward compatible default
        return ["cka", "xteaming", "generative"]
    unknown = [name for name in requested if name not in JUDGE_TYPES_ALL]
    if unknown:
        raise ValueError(f"Invalid --judges entries: {unknown}. Allowed: {list(JUDGE_TYPES_ALL)}")
    # dict keeps insertion order, so this de-dups while keeping first occurrences.
    return list(dict.fromkeys(requested))


def _build_prompt(
    judge_type: str,
    query: str,
    response: str,
    target_model_holder: str,
    include_reason: bool,
    include_confidence: bool,
) -> str:
    """
    Assemble the complete judge prompt for one (query, response) pair.

    Mirrors unified_judge.JudgeEvaluator.infer_single(): the judge-specific prompt
    body followed by JudgeEvaluator.format_query_response(query, response). The
    judge name is canonicalized first, so "generative" uses the pair prompt. Only
    the xteaming prompt receives ``target_model_holder``.

    Raises:
        ValueError: if the canonical judge type is not cka, xteaming, or pair.
    """
    flags = {"include_reason": include_reason, "include_confidence": include_confidence}
    builders = {
        "cka": lambda: build_cka_evaluation_prompt(**flags),
        "xteaming": lambda: build_xteaming_evaluation_prompt(
            **flags, target_model_holder=target_model_holder
        ),
        "pair": lambda: build_pair_evaluation_prompt(**flags),
    }
    builder = builders.get(_canonical_judge_type(judge_type))
    if builder is None:
        raise ValueError(f"Unknown judge_type: {judge_type}")
    return builder() + JudgeEvaluator.format_query_response(query, response)


def _judge_field(judge_type: str, is_original: bool, suffix: str) -> str:
    """Build "<judge>_<suffix>", prefixed with "original_" for original-response fields."""
    prefix = "original_" if is_original else ""
    return f"{prefix}{judge_type}_{suffix}"


def _score_field(judge_type: str, is_original: bool) -> str:
    """Name of the score field, e.g. "cka_score" or "original_cka_score"."""
    return _judge_field(judge_type, is_original, "score")


def _reason_field(judge_type: str, is_original: bool) -> str:
    """Name of the reason field, e.g. "cka_reason" or "original_cka_reason"."""
    return _judge_field(judge_type, is_original, "reason")


def _confidence_field(judge_type: str, is_original: bool) -> str:
    """Name of the confidence field, e.g. "cka_confidence" or "original_cka_confidence"."""
    return _judge_field(judge_type, is_original, "confidence")


def _raw_field(judge_type: str, is_original: bool) -> str:
    """Name of the raw-output field, e.g. "cka_raw_output" or "original_cka_raw_output"."""
    return _judge_field(judge_type, is_original, "raw_output")


def _extract_json_obj(text: str) -> Optional[Dict[str, Any]]:
    """
    Best-effort extraction of a JSON object from judge model output.

    Requests ask for ``response_format={"type": "json_object"}``, so the output
    should already be JSON. This function also tolerates common glitches:
      - surrounding prose or code fences: the span from the first "{" to the
        last "}" is parsed;
      - raw line breaks / tabs inside string values: accepted by
        ``json.loads(strict=False)`` and preserved in the result.

    Returns:
        The parsed dict, or None if ``text`` is empty, cannot be parsed, or
        does not decode to a JSON object.
    """
    if not text:
        return None
    candidate = text.strip()
    if not candidate:
        return None
    start = candidate.find("{")
    end = candidate.rfind("}")
    if start != -1 and end != -1:
        # Empty if the last "}" precedes the first "{"; json.loads then fails below.
        candidate = candidate[start : end + 1]
    try:
        # strict=False accepts literal control characters (e.g. newlines) inside
        # strings. The previous approach deleted every line break, which glued
        # words in multi-line "reason" values together.
        data = json.loads(candidate, strict=False)
    except (ValueError, RecursionError):  # json.JSONDecodeError subclasses ValueError
        return None
    return data if isinstance(data, dict) else None


def _parse_judge_output(content: str) -> Dict[str, Any]:
    """
    Normalize raw judge output into a unified dict.

    Returns ``{"score": int|None, "reason": str|None, "confidence": float|None,
    "raw_output": str}``. ``raw_output`` is always the given text (or "").
    A field stays None when it is absent from the parsed JSON object, is JSON
    ``null``, or cannot be converted to the expected type. Float scores are
    truncated by ``int()`` (e.g. 4.0 -> 4); string scores must be integer
    literals ("4" -> 4, "4.5" -> None).
    """
    out: Dict[str, Any] = {"score": None, "reason": None, "confidence": None, "raw_output": content or ""}
    data = _extract_json_obj(content)
    if not data:
        return out

    score = data.get("score")
    if score is not None:
        try:
            out["score"] = int(score)
        except (TypeError, ValueError, OverflowError):
            # e.g. "high", [3], NaN or infinity
            out["score"] = None

    reason = data.get("reason")
    if reason is not None:
        # Guarding None keeps a JSON null reason from becoming the string "None".
        out["reason"] = str(reason).strip()

    confidence = data.get("confidence")
    if confidence is not None:
        try:
            out["confidence"] = float(confidence)
        except (TypeError, ValueError, OverflowError):
            out["confidence"] = None
    return out


def _gemini_extract_text_from_result_line(obj: Dict[str, Any]) -> str:
    """
    Best-effort extraction of generated text from one Gemini batch result line.

    Shapes are tried in order. The first match wins, and its text is stripped:
      1. OpenAI-compatible:
         {"response": {"body": {"choices": [{"message": {"content": "..."}}]}}}
         (a string content is returned even if it is empty).
      2. Gemini native:
         {"response": {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}}
         The search runs on ``obj`` itself when "response" is not a dict. If no
         part text is found, a candidate-level "text" is used instead.
      3. A top-level "text" on the object searched in step 2.
    Returns "" when none of these shapes matches.
    """

    def _first(seq: Any) -> Any:
        # First element of a non-empty list, else None.
        return seq[0] if isinstance(seq, list) and seq else None

    response = obj.get("response")

    # 1) OpenAI-compatible chat-completions body.
    if isinstance(response, dict):
        body = response.get("body")
        first_choice = _first(body.get("choices")) if isinstance(body, dict) else None
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        text = message.get("content") if isinstance(message, dict) else None
        if isinstance(text, str):
            return text.strip()

    # 2) Gemini-native candidates, on the response dict or on obj itself.
    node = response if isinstance(response, dict) else obj
    first_candidate = _first(node.get("candidates"))
    if isinstance(first_candidate, dict):
        content = first_candidate.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if isinstance(parts, list):
            pieces = [p["text"] for p in parts if isinstance(p, dict) and isinstance(p.get("text"), str)]
            if pieces:
                return "".join(pieces).strip()
        if isinstance(first_candidate.get("text"), str):
            return first_candidate["text"].strip()

    # 3) Plain top-level text.
    if isinstance(node, dict) and isinstance(node.get("text"), str):
        return node["text"].strip()
    return ""


def _gemini_get_client() -> Any:
    """
    Create a google-genai client for the Gemini Batch API.

    The key is read from GEMINI_API_KEY, falling back to GOOGLE_API_KEY. If
    ``genai.Client`` rejects the ``api_key`` keyword (TypeError), a
    default-constructed client is returned instead.

    Raises:
        ValueError: if neither environment variable holds a non-empty value.
    """
    env = os.environ
    api_key = env.get("GEMINI_API_KEY") or env.get("GOOGLE_API_KEY")
    if api_key:
        try:
            return genai.Client(api_key=api_key)
        except TypeError:
            # SDK signature differs; presumably it resolves credentials itself — verify.
            return genai.Client()
    raise ValueError("GEMINI_API_KEY or GOOGLE_API_KEY environment variable is required for Gemini Batch API")


def _gemini_normalize_model(model_name: str) -> str:
    """
    Return ``model_name`` in the "models/<name>" form used by google-genai batch examples.

    Empty/None names and names that already start with "models/" are returned unchanged.
    """
    needs_prefix = bool(model_name) and not model_name.startswith("models/")
    return f"models/{model_name}" if needs_prefix else model_name


def _wait_for_gemini_batch_completion(client: Any, batch_name: str, check_interval: int = 60) -> Any:
    """
    Poll a Gemini batch job until it reaches a terminal state, then return it.

    The state name is read from ``batch.state.name`` (enum) or from ``batch.state``
    (plain string). Polling stops when the name contains SUCCEEDED or COMPLETED
    (success), or FAILED, CANCEL, or EXPIRED (failure). The caller must inspect the
    returned batch to tell which. Any other state, including a missing one (shown
    as "UNKNOWN"), keeps polling every ``check_interval`` seconds with no overall
    timeout.
    """
    start_time = time.time()
    print(f"Waiting for Gemini batch {batch_name} to complete...")
    while True:
        elapsed_time = time.time() - start_time
        hours = int(elapsed_time // 3600)
        minutes = int((elapsed_time % 3600) // 60)
        seconds = int(elapsed_time % 60)
        time_str = f"{hours:02d}:{minutes:02d}:{seconds:02d}"

        # Prefer the keyword form client.batches.get(name=...); fall back to a
        # positional argument if this SDK version's signature rejects it.
        try:
            batch = client.batches.get(name=batch_name)
        except TypeError:
            batch = client.batches.get(batch_name)

        # `state` may be an enum (use its .name) or a plain string.
        state = getattr(batch, "state", None)
        state_name = getattr(state, "name", None) if state is not None else None
        if not state_name and isinstance(state, str):
            state_name = state
        state_name = str(state_name or "UNKNOWN")
        print(f"[{time_str}] Batch state: {state_name}")

        if "SUCCEEDED" in state_name or "COMPLETED" in state_name:
            return batch
        # Terminal failure states. google-genai reports JOB_STATE_FAILED,
        # JOB_STATE_CANCELLED and JOB_STATE_EXPIRED. EXPIRED was previously not
        # matched, so an expired job was polled forever. "CANCEL" also matches the
        # transitional JOB_STATE_CANCELLING.
        if any(marker in state_name for marker in ("FAILED", "CANCEL", "EXPIRED")):
            return batch
        time.sleep(int(check_interval))


def _process_gemini_batch_results(
    output_jsonl: str,
    samples: List[Dict[str, Any]],
    metadata: List[Dict[str, Any]],
    request_custom_ids: List[str],
    overwrite: bool,
    save_reason: bool,
    save_confidence: bool,
    save_raw_output: bool,
) -> List[Dict[str, Any]]:
    """
    Apply Gemini batch output (JSONL text) to copies of ``samples``.

    Each non-blank output line is mapped to a request custom_id, in this order:
      1. its "custom_id" field, if it is a non-empty string;
      2. its "id" field, if it is one of ``request_custom_ids`` (or that list is
         empty);
      3. the entry at the same line position in ``request_custom_ids``
         (presumably the submission order — verify against the caller).
    Text is extracted from each line and parsed as judge JSON. It is then written
    to the sample/rewrite described by each ``metadata`` entry. Existing scores
    are kept unless ``overwrite`` is set, and unparseable results leave fields
    untouched.

    The input ``samples`` (including their nested "rewrites" dicts) are not
    mutated; updated copies are returned in the same order.
    """
    known_ids = set(request_custom_ids)
    content_map: Dict[str, str] = {}

    lines = [ln for ln in output_jsonl.splitlines() if ln.strip()]
    for i, ln in enumerate(lines):
        try:
            obj = json.loads(ln)
        except (ValueError, RecursionError):
            obj = {}
        if not isinstance(obj, dict):
            obj = {}

        custom_id = obj.get("custom_id")
        if not isinstance(custom_id, str) or not custom_id:
            custom_id = None
            line_id = obj.get("id")
            # "id" may be a provider-assigned response id rather than our custom_id;
            # only trust it when it matches a request, so the positional fallback
            # below still applies instead of silently dropping this result.
            if isinstance(line_id, str) and line_id and (not known_ids or line_id in known_ids):
                custom_id = line_id
        if not custom_id and i < len(request_custom_ids):
            custom_id = request_custom_ids[i]
        if not custom_id:
            continue
        content_map[custom_id] = _gemini_extract_text_from_result_line(obj) or ""

    # Copy each sample AND its nested rewrite dicts. A plain sample.copy() shares
    # the "rewrites" list, so results would be written into the caller's data.
    out_samples: List[Dict[str, Any]] = []
    for sample in samples:
        sample_copy = sample.copy()
        if isinstance(sample.get("rewrites"), list):
            sample_copy["rewrites"] = [r.copy() if isinstance(r, dict) else r for r in sample["rewrites"]]
        out_samples.append(sample_copy)

    for meta in metadata:
        local_idx = meta.get("local_idx")
        global_idx = meta.get("global_idx")
        judge_type = meta.get("judge")
        part = meta.get("part")
        rewrite_idx = meta.get("rewrite_idx")
        is_flat = bool(meta.get("is_flat", False))
        if local_idx is None or judge_type is None or part not in ("original", "rewrite"):
            continue
        if local_idx < 0 or local_idx >= len(out_samples):
            continue

        is_original = part == "original"
        rid = rewrite_idx if rewrite_idx is not None else "flat"
        custom_id = (
            f"sample_{global_idx}_original_{judge_type}"
            if is_original
            else f"sample_{global_idx}_rewrite_{rid}_{judge_type}"
        )
        parsed = _parse_judge_output(content_map.get(custom_id, ""))

        # Locate the object that holds this result's score field.
        sample_obj = out_samples[local_idx]
        if is_original or is_flat:
            target_obj = sample_obj
        else:
            rewrites = sample_obj.get("rewrites")
            if not isinstance(rewrites, list) or rewrite_idx is None or rewrite_idx >= len(rewrites):
                continue
            if not isinstance(rewrites[rewrite_idx], dict):
                continue
            target_obj = rewrites[rewrite_idx]

        score_field = _score_field(judge_type, is_original=is_original)
        if not overwrite and target_obj.get(score_field) is not None:
            continue

        _apply_parsed_to_sample(
            sample=sample_obj,
            parsed=parsed,
            judge_type=judge_type,
            is_original=is_original,
            rewrite_idx=rewrite_idx,
            is_flat=is_flat,
            save_reason=save_reason,
            save_confidence=save_confidence,
            save_raw_output=save_raw_output,
        )

    return out_samples


def _get_sample_fields(sample: Dict[str, Any]) -> Tuple[str, str]:
    """
    Extract ``(prompt, original_response)`` from a record, both as strings.

    The prompt is the first truthy value among "prompt", "behavior" and "query"
    (default ""). The response is "original_response", falling back to "response"
    only when that key is missing or None; a falsy result becomes "".
    """
    prompt_value = next(
        (sample.get(key) for key in ("prompt", "behavior", "query") if sample.get(key)),
        "",
    )
    response_value = sample.get("original_response")
    if response_value is None:
        response_value = sample.get("response")
    return str(prompt_value), str(response_value or "")


def _is_nested(sample: Dict[str, Any]) -> bool:
    """Return True when the record uses the nested format, i.e. its "rewrites" value is a list."""
    rewrites = sample.get("rewrites")
    return isinstance(rewrites, list)


def _get_rewrite_targets(sample: Dict[str, Any]) -> List[Tuple[Optional[int], Dict[str, Any], bool]]:
    """
    List the rewrite records to score, as ``(rewrite_idx, rewrite_obj, is_flat)`` tuples.

    - Nested format: one tuple per dict entry of sample["rewrites"], with its
      original list index (non-dict entries are skipped).
    - Flat format: a single ``(None, sample, True)``; the rewrite text is expected
      under "rewritten_response" on the sample itself.
    """
    if not _is_nested(sample):
        return [(None, sample, True)]
    return [
        (idx, entry, False)
        for idx, entry in enumerate(sample.get("rewrites") or [])
        if isinstance(entry, dict)
    ]


def build_batch_request(
    model_name: str,
    prompt: str,
    temperature: float,
    seed: Optional[int],
    max_completion_tokens: int,
    provider: str = "openai",  # "openai" or "gemini"
) -> Dict[str, Any]:
    """
    Build the chat-completions request body for one judge call in a batch JSONL line.

    The body asks for JSON-object output with top_p=1 and n=1. The output-length
    limit is ``max(50, int(max_completion_tokens))`` and is sent as:
      - OpenAI: ``max_completion_tokens`` for models flagged by
        JudgeEvaluator._openai_uses_max_completion_tokens() (some newer models
        reject ``max_tokens``), otherwise ``max_tokens``;
      - any other provider (e.g. Gemini Batch): ``max_tokens``.

    ``seed`` is added only for OpenAI, and only when it is not None. Gemini Batch
    validates lines against the OpenAI /v1/chat/completions schema and rejects
    "seed" (400 INVALID_ARGUMENT ... no such field: 'seed'). It also only accepts
    the /v1/chat/completions and /v1/embeddings URLs, so generationConfig.seed is
    not usable either.
    """
    is_openai = provider == "openai"
    limit = max(50, int(max_completion_tokens))

    if is_openai and JudgeEvaluator._openai_uses_max_completion_tokens(model_name):
        limit_key = "max_completion_tokens"
    else:
        limit_key = "max_tokens"

    # Key insertion order matches the serialized JSONL layout used previously.
    body: Dict[str, Any] = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "response_format": {"type": "json_object"},
        "temperature": temperature,
        "top_p": 1,
        "n": 1,
        limit_key: limit,
    }
    if is_openai and seed is not None:
        body["seed"] = seed
    return body


def prepare_batch_requests(
    samples: List[Dict[str, Any]],
    judge_model_name: str,
    target_model_holder: str,
    judges: List[str],
    temperature: float,
    seed: Optional[int],
    max_completion_tokens: int,
    overwrite: bool,
    start_idx: int,
    save_reason: bool,
    save_confidence: bool,
    provider: str = "openai",  # "openai" or "gemini"
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Build one batch request per (sample, part, judge) that still needs a score.

    For every sample with a non-empty original response, the original is queued
    once per judge. Each rewrite with non-empty text follows, once per judge
    (nested "rewrites" entries, or the flat record's "rewritten_response").
    Samples with an empty original response are skipped entirely, rewrites
    included. Without ``overwrite``, a part whose score field is already set is
    skipped.

    custom_ids look like ``sample_<global_idx>_original_<judge>`` or
    ``sample_<global_idx>_rewrite_<rewrite_idx|flat>_<judge>``, where
    ``global_idx = start_idx + local_idx``.

    Returns:
        (batch_requests, metadata): parallel lists. metadata[i] records where the
        result of batch_requests[i] belongs (local_idx, global_idx, part,
        rewrite_idx, judge, is_flat).
    """
    batch_requests: List[Dict[str, Any]] = []
    metadata: List[Dict[str, Any]] = []

    # (include_reason, include_confidence) defaults per canonical judge type,
    # matching the unified_judge.py prompt builders. Any other judge (pair)
    # defaults to neither. save_reason / save_confidence can only switch fields on.
    builder_defaults = {"cka": (True, True), "xteaming": (True, False)}

    def _enqueue(
        local_idx: int,
        global_idx: int,
        rewrite_idx: Optional[int],
        is_flat: bool,
        is_original: bool,
        judge_type: str,
        query: str,
        response: str,
    ) -> None:
        """Append one judge request plus the metadata needed to route its result back."""
        want_reason, want_conf = builder_defaults.get(_canonical_judge_type(judge_type), (False, False))
        judge_prompt = _build_prompt(
            judge_type=judge_type,
            query=query,
            response=response,
            target_model_holder=target_model_holder,
            include_reason=want_reason or bool(save_reason),
            include_confidence=want_conf or bool(save_confidence),
        )
        body = build_batch_request(
            model_name=judge_model_name,
            prompt=judge_prompt,
            temperature=temperature,
            seed=seed,
            max_completion_tokens=max_completion_tokens,
            provider=provider,
        )
        if is_original:
            custom_id = f"sample_{global_idx}_original_{judge_type}"
        else:
            rid = "flat" if rewrite_idx is None else rewrite_idx
            custom_id = f"sample_{global_idx}_rewrite_{rid}_{judge_type}"
        batch_requests.append({"custom_id": custom_id, "method": "POST", "url": "/v1/chat/completions", "body": body})
        metadata.append(
            {
                "local_idx": local_idx,
                "global_idx": global_idx,
                "part": "original" if is_original else "rewrite",
                "rewrite_idx": rewrite_idx,
                "judge": judge_type,
                "is_flat": bool(is_flat),
            }
        )

    for local_idx, sample in enumerate(samples):
        query, original_response = _get_sample_fields(sample)
        if not original_response:
            continue
        global_idx = start_idx + local_idx

        for judge_type in judges:
            already_scored = sample.get(_score_field(judge_type, is_original=True)) is not None
            if overwrite or not already_scored:
                _enqueue(
                    local_idx=local_idx,
                    global_idx=global_idx,
                    rewrite_idx=None,
                    is_flat=False,
                    is_original=True,
                    judge_type=judge_type,
                    query=query,
                    response=original_response,
                )

        for rewrite_idx, rewrite_obj, is_flat in _get_rewrite_targets(sample):
            # Flat records keep the rewrite text and its scores on the sample itself.
            holder = sample if is_flat else rewrite_obj
            rewritten_response = holder.get("rewritten_response") or ""
            if not rewritten_response:
                continue
            for judge_type in judges:
                already_scored = holder.get(_score_field(judge_type, is_original=False)) is not None
                if overwrite or not already_scored:
                    _enqueue(
                        local_idx=local_idx,
                        global_idx=global_idx,
                        rewrite_idx=rewrite_idx,
                        is_flat=is_flat,
                        is_original=False,
                        judge_type=judge_type,
                        query=query,
                        response=rewritten_response,
                    )

    return batch_requests, metadata


def wait_for_batch_completion(client: openai.OpenAI, batch_id: str, check_interval: int = 60):
    """
    Poll an OpenAI batch until it finishes, then return the final Batch object.

    Returns on "completed" or on a terminal failure ("failed", "expired",
    "cancelled"); the caller must check ``.status``. Any other status keeps polling
    every ``check_interval`` seconds, with no overall timeout. While the batch is
    "in_progress", request counts are printed if it exposes ``request_counts``.
    """
    terminal_failures = ("failed", "expired", "cancelled")
    started = time.time()
    print(f"Waiting for batch {batch_id} to complete...")
    print("This may take minutes to hours depending on the queue.")
    while True:
        hours, remainder = divmod(time.time() - started, 3600)
        minutes, secs = divmod(remainder, 60)
        stamp = f"{int(hours):02d}:{int(minutes):02d}:{int(secs):02d}"

        current = client.batches.retrieve(batch_id)
        status = current.status
        print(f"[{stamp}] Batch status: {status}")

        if status == "completed":
            print("Batch completed successfully!")
            return current
        if status in terminal_failures:
            print(f"Batch ended with status: {status}")
            return current
        if status == "in_progress" and hasattr(current, "request_counts"):
            counts = current.request_counts
            print(
                f"[{stamp}] Progress: {getattr(counts, 'total', 0)} total, "
                f"{getattr(counts, 'completed', 0)} completed, {getattr(counts, 'failed', 0)} failed"
            )
        time.sleep(int(check_interval))


def _apply_parsed_to_sample(
    sample: Dict[str, Any],
    parsed: Dict[str, Any],
    judge_type: str,
    is_original: bool,
    rewrite_idx: Optional[int],
    is_flat: bool,
    save_reason: bool,
    save_confidence: bool,
    save_raw_output: bool,
) -> None:
    """
    Write one parsed judge result into ``sample``, in place.

    Nothing is written when ``parsed["score"]`` is None. The destination is the
    sample itself for the original response and for flat-format rewrites, or
    ``sample["rewrites"][rewrite_idx]`` for nested rewrites. If that entry is
    missing or is not a dict, nothing is written.

    The score is always stored (as int). Reason ("" if missing), confidence (may
    be None) and raw output ("" if missing) are stored only when the matching
    save_* flag is set, in that order.
    """
    score = parsed.get("score")
    if score is None:
        return

    if is_original or is_flat:
        target = sample
    else:
        rewrites = sample.get("rewrites")
        in_range = isinstance(rewrites, list) and rewrite_idx is not None and rewrite_idx < len(rewrites)
        if not in_range or not isinstance(rewrites[rewrite_idx], dict):
            return
        target = rewrites[rewrite_idx]

    target[_score_field(judge_type, is_original=is_original)] = int(score)
    extras = (
        (save_reason, _reason_field, parsed.get("reason") or ""),
        (save_confidence, _confidence_field, parsed.get("confidence")),
        (save_raw_output, _raw_field, parsed.get("raw_output") or ""),
    )
    for enabled, field_name_for, value in extras:
        if enabled:
            target[field_name_for(judge_type, is_original=is_original)] = value


def process_batch_results(
    batch_results: List[Dict[str, Any]],
    samples: List[Dict[str, Any]],
    metadata: List[Dict[str, Any]],
    overwrite: bool,
    save_reason: bool,
    save_confidence: bool,
    save_raw_output: bool,
) -> List[Dict[str, Any]]:
    """
    Apply OpenAI Batch API results to copies of ``samples``.

    ``batch_results`` are parsed batch result lines. For each line, the assistant
    text at ``response.body.choices[0].message.content`` is keyed by its
    custom_id. A line without a usable text maps to "" and leaves fields
    untouched; this covers error lines whose "response" is null, missing choices,
    and non-string content. Each ``metadata`` entry (from prepare_batch_requests)
    locates the target sample/rewrite. Existing scores are kept unless
    ``overwrite`` is set.

    The input ``samples`` and their nested rewrite dicts are not mutated; updated
    copies are returned in the same order.
    """
    # Map custom_id -> raw text content ("" when unavailable).
    content_map: Dict[str, str] = {}
    for result in batch_results:
        custom_id = result.get("custom_id")
        if not custom_id:
            continue
        content: Any = ""
        response = result.get("response")
        # Error lines carry "response": null, so check the type before looking inside.
        body = response.get("body") if isinstance(response, dict) else None
        if isinstance(body, dict):
            try:
                content = body["choices"][0]["message"]["content"]
            except (KeyError, IndexError, TypeError):
                content = ""
        content_map[custom_id] = content if isinstance(content, str) else ""

    # Work on copies, including the nested rewrite dicts.
    out_samples: List[Dict[str, Any]] = []
    for s in samples:
        s2 = s.copy()
        if isinstance(s.get("rewrites"), list):
            s2["rewrites"] = [r.copy() if isinstance(r, dict) else r for r in s.get("rewrites", [])]
        out_samples.append(s2)

    for m in metadata:
        local_idx = int(m["local_idx"])
        if local_idx < 0 or local_idx >= len(out_samples):
            continue
        sample = out_samples[local_idx]
        judge_type = str(m["judge"])
        is_original = str(m["part"]) == "original"
        rewrite_idx = m.get("rewrite_idx")
        rewrite_idx = int(rewrite_idx) if rewrite_idx is not None else None
        is_flat = bool(m.get("is_flat", False))

        global_idx = int(m["global_idx"])
        if is_original:
            custom_id = f"sample_{global_idx}_original_{judge_type}"
        else:
            rid = rewrite_idx if rewrite_idx is not None else "flat"
            custom_id = f"sample_{global_idx}_rewrite_{rid}_{judge_type}"

        parsed = _parse_judge_output(content_map.get(custom_id, ""))
        if parsed.get("score") is None:
            continue

        # Locate the object holding this score. If not overwrite, keep existing values.
        target_obj: Dict[str, Any]
        if is_original or is_flat:
            target_obj = sample
        else:
            rewrites = sample.get("rewrites")
            if not isinstance(rewrites, list) or rewrite_idx is None or rewrite_idx >= len(rewrites):
                continue
            if not isinstance(rewrites[rewrite_idx], dict):
                continue
            target_obj = rewrites[rewrite_idx]
        score_field = _score_field(judge_type, is_original=is_original)
        if not overwrite and target_obj.get(score_field) is not None:
            continue

        _apply_parsed_to_sample(
            sample=sample,
            parsed=parsed,
            judge_type=judge_type,
            is_original=is_original,
            rewrite_idx=rewrite_idx,
            is_flat=is_flat,
            save_reason=save_reason,
            save_confidence=save_confidence,
            save_raw_output=save_raw_output,
        )

    return out_samples


def _read_jsonl(path: str) -> List[Dict[str, Any]]:
    """
    Load the JSON objects from a UTF-8 JSONL file.

    Whitespace-only lines are skipped. Lines that decode to anything other than a
    JSON object (list, number, string, ...) are dropped. Invalid JSON raises
    json.JSONDecodeError.
    """
    with open(path, "r", encoding="utf-8") as handle:
        decoded = [json.loads(text) for text in (raw.strip() for raw in handle) if text]
    return [record for record in decoded if isinstance(record, dict)]


def _write_jsonl(path: str, items: List[Dict[str, Any]]) -> None:
    """
    Write ``items`` to ``path`` as UTF-8 JSONL, replacing any existing file.

    Each item becomes one newline-terminated JSON line. Non-ASCII characters are
    written as-is (ensure_ascii=False).
    """
    with open(path, "w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in items)


def _cli_require_credentials(provider: str) -> None:
    """Exit the process unless *provider* can be used by this Batch script.

    Claude is rejected (no Batch support here). ``"openai"`` requires the
    ``OPENAI_API_KEY`` environment variable. Any other provider is treated as
    Gemini and must yield a client from ``_gemini_get_client()``.
    """
    if provider == "claude":
        print("⚠️  Error: Claude models are not supported by this Batch script.")
        print("Use evaluate_rewrites_unifiedjudge.py (non-batch) instead.")
        sys.exit(1)
    if provider == "openai":
        if not os.environ.get("OPENAI_API_KEY"):
            print("⚠️  Error: OPENAI_API_KEY not found in environment variables.")
            print("Please set it with: export OPENAI_API_KEY='your-api-key'")
            sys.exit(1)
        # Built only as an early sanity check; the instance is discarded.
        _ = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        return
    # Gemini Batch API
    try:
        _ = _gemini_get_client()
    except Exception as e:
        print(f"⚠️  Error: {e}")
        sys.exit(1)


def _cli_metadata_path(batch_requests_file: str) -> str:
    """Return the metadata JSON path paired with a batch-requests JSONL path.

    ``requests.jsonl`` maps to ``requests_metadata.json``. Only the file name's
    final suffix is replaced, so the result never equals the requests path.
    Plain ``str.replace(".jsonl", ...)`` returned the input unchanged for
    non-``.jsonl`` names, letting the requests file overwrite the metadata.
    """
    requests_path = Path(batch_requests_file)
    return str(requests_path.with_name(requests_path.stem + "_metadata.json"))


def _cli_save_metadata(metadata_file: str, metadata_info: Dict[str, Any]) -> None:
    """(Re)write the resume metadata JSON with the current *metadata_info*."""
    Path(metadata_file).write_text(json.dumps(metadata_info, ensure_ascii=False, indent=2), encoding="utf-8")


def _cli_print_resume_hint(batch_id: str, metadata_file: str, input_file: str, output_file: str) -> None:
    """Print a copy-pasteable, shell-quoted command that retrieves *batch_id* later."""
    import shlex

    command = [
        "python", __file__,
        "--batch_id", batch_id,
        "--metadata_file", metadata_file,
        "--input_file", input_file,
        "--output_file", output_file,
    ]
    print("\nTo retrieve results later, run:")
    print("  " + " ".join(shlex.quote(str(part)) for part in command))


def _cli_write_results(output_file: str, processed: List[Dict[str, Any]]) -> None:
    """Write the scored samples to *output_file* and report completion."""
    print(f"Saving results to {output_file} ...")
    _write_jsonl(output_file, processed)
    print("✅ Done.")


def _cli_gemini_state(batch: Any) -> Optional[str]:
    """Return the Gemini batch job's ``state.name`` as a string, or None when absent."""
    state = getattr(getattr(batch, "state", None), "name", None)
    return str(state) if state else None


def _cli_gemini_state_failed(state: Optional[str]) -> bool:
    """Return True when *state* names an unusable terminal job state (failed/cancelled/expired)."""
    if not state:
        return False
    return any(marker in state for marker in ("FAILED", "CANCEL", "EXPIRED"))


def _cli_download_gemini_output(client_g: Any, file_name: str) -> str:
    """Download a Gemini batch result file and return its contents as text."""
    print("Downloading batch results...")
    try:
        content_bytes = client_g.files.download(file=file_name)
    except TypeError:
        # Fallback for SDK versions whose download() does not take ``file=``.
        content_bytes = client_g.files.download(file_name)
    if isinstance(content_bytes, (bytes, bytearray)):
        return content_bytes.decode("utf-8")
    return str(content_bytes)


def _cli_download_openai_results(client_o: Any, batch_status: Any) -> Optional[List[Dict[str, Any]]]:
    """Download and parse the output JSONL of a completed OpenAI batch.

    Returns None (after printing why) when the batch exposes no output file,
    e.g. when every request failed and only an error file was produced.
    """
    output_file_id = getattr(batch_status, "output_file_id", None)
    if not output_file_id:
        print("Cannot proceed: completed batch has no output file (all requests may have failed).")
        error_file_id = getattr(batch_status, "error_file_id", None)
        if error_file_id:
            print(f"   Error file ID: {error_file_id}")
        return None

    print("Downloading batch results...")
    output_content = client_o.files.content(output_file_id).read()
    batch_results: List[Dict[str, Any]] = []
    # Split on "\n" only: str.splitlines() would also break on characters such
    # as U+2028 that may legally appear unescaped inside JSON strings.
    for line in output_content.decode("utf-8").strip().split("\n"):
        if line.strip():
            batch_results.append(json.loads(line))
    print(f"Retrieved {len(batch_results)} results")
    return batch_results


def _cli_retrieve_existing_batch(args: argparse.Namespace) -> None:
    """Mode B: fetch the results of an already-submitted batch and write the scored JSONL.

    The provider, sample range and ``overwrite`` / ``save_*`` flags come from
    the metadata file written at submission time. The corresponding CLI flags
    can only switch those options on, never off.

    Raises:
        ValueError: If ``--metadata_file`` or the input file is missing, or the
            Gemini result file name cannot be determined.
    """
    if not args.metadata_file:
        raise ValueError("--metadata_file is required when using --batch_id")

    metadata_info = json.loads(Path(args.metadata_file).read_text(encoding="utf-8"))
    metadata = metadata_info["metadata"]
    provider_meta = metadata_info.get("provider") or _detect_provider(str(metadata_info.get("judge_model_name", "")))
    # Validate credentials for the provider that owns this batch. Using
    # --judge_model_name (an OpenAI model by default) would demand an
    # OPENAI_API_KEY even when retrieving a Gemini batch.
    _cli_require_credentials(provider_meta)

    # Load input samples (keep original fields)
    input_file = args.input_file or metadata_info.get("input_file")
    if not input_file:
        raise ValueError("Input file not specified. Provide --input_file or include input_file in metadata.")
    all_samples = _read_jsonl(input_file)
    start_idx = int(metadata_info.get("start_idx", 0))
    end_idx = int(metadata_info.get("end_idx", len(all_samples)))
    samples = all_samples[start_idx:end_idx]
    print(f"Processing samples {start_idx} to {end_idx - 1} ({len(samples)} samples)")

    overwrite = bool(metadata_info.get("overwrite", False)) or bool(args.overwrite)
    save_reason = bool(metadata_info.get("save_reason", False)) or bool(args.save_reason)
    save_confidence = bool(metadata_info.get("save_confidence", False)) or bool(args.save_confidence)
    save_raw_output = bool(metadata_info.get("save_raw_output", False)) or bool(args.save_raw_output)

    if provider_meta == "gemini":
        batch_name = args.batch_id
        print(f"Retrieving Gemini batch: {batch_name}")
        # (Re)create client in case Mode B is run standalone
        client_g = _gemini_get_client()
        if args.skip_wait:
            batch = client_g.batches.get(name=batch_name)
        else:
            batch = _wait_for_gemini_batch_completion(client_g, batch_name, args.check_interval)
        state = _cli_gemini_state(batch)
        if _cli_gemini_state_failed(state):
            print(f"Cannot proceed: batch state is {state}")
            return
        if args.skip_wait and state and "SUCCEEDED" not in state:
            # Mirror the OpenAI branch instead of failing later on a missing result file.
            print(f"Batch state is {state}, not completed yet.")
            print("Skipping wait. Re-run later when the batch is completed.")
            return
        file_name = getattr(getattr(batch, "dest", None), "file_name", None) or metadata_info.get("dest_file_name")
        if not file_name:
            raise ValueError("Cannot find Gemini batch result file name (dest.file_name).")
        output_text = _cli_download_gemini_output(client_g, file_name)
        processed = _process_gemini_batch_results(
            output_jsonl=output_text,
            samples=samples,
            metadata=metadata,
            request_custom_ids=list(metadata_info.get("request_custom_ids") or []),
            overwrite=overwrite,
            save_reason=save_reason,
            save_confidence=save_confidence,
            save_raw_output=save_raw_output,
        )
    else:
        client_o = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        print(f"Retrieving OpenAI batch: {args.batch_id}")
        batch_status = client_o.batches.retrieve(args.batch_id)
        if batch_status.status != "completed":
            print(f"Batch status is {batch_status.status}, not completed yet.")
            if args.skip_wait:
                print("Skipping wait. Re-run later when the batch is completed.")
                return
            batch_status = wait_for_batch_completion(client_o, args.batch_id, args.check_interval)
        if batch_status.status != "completed":
            print(f"Cannot proceed: batch status is {batch_status.status}")
            return

        batch_results = _cli_download_openai_results(client_o, batch_status)
        if batch_results is None:
            return
        processed = process_batch_results(
            batch_results=batch_results,
            samples=samples,
            metadata=metadata,
            overwrite=overwrite,
            save_reason=save_reason,
            save_confidence=save_confidence,
            save_raw_output=save_raw_output,
        )

    _cli_write_results(args.output_file, processed)


def _cli_create_gemini_batch(client_g: Any, model_norm: str, src_name: str, display_name: str,
                             max_retries: int = 5, retry_delay: int = 60) -> Any:
    """Create a Gemini batch job, retrying quota errors with exponential backoff.

    Waits ``retry_delay * 2**attempt`` seconds between attempts. Errors that do
    not look like quota/429 errors are re-raised immediately, as is the last
    quota error once *max_retries* attempts are used up.
    """
    for attempt in range(max_retries):
        try:
            try:
                return client_g.batches.create(model=model_norm, src=src_name, config={"display_name": display_name})
            except TypeError:
                # Fallback for SDK versions whose create() does not take ``config``.
                return client_g.batches.create(model=model_norm, src=src_name)
        except Exception as e:
            error_str = str(e)
            is_quota_error = "429" in error_str or "RESOURCE_EXHAUSTED" in error_str or "quota" in error_str.lower()
            if not is_quota_error:
                # Non-quota error, re-raise immediately
                raise
            if attempt >= max_retries - 1:
                print(f"❌ Quota limit exceeded after {max_retries} attempts.")
                print(f"   Error: {error_str}")
                print("   Suggestion: Try submitting smaller batches or wait longer.")
                raise
            wait_time = retry_delay * (2 ** attempt)
            print(f"⚠️  Quota limit hit (attempt {attempt + 1}/{max_retries}). Waiting {wait_time} seconds before retry...")
            time.sleep(wait_time)
    raise RuntimeError("Gemini batch creation was not attempted (max_retries < 1).")


def _cli_submit_gemini_batch(
    args: argparse.Namespace,
    batch_requests: List[Dict[str, Any]],
    batch_requests_path: str,
    metadata_file: str,
    metadata_info: Dict[str, Any],
    samples: List[Dict[str, Any]],
    metadata: Any,
) -> None:
    """Mode A (Gemini): upload the request file, create the batch, and optionally wait and write results."""
    # Write request_custom_ids so we can map results by order if needed
    metadata_info["provider"] = "gemini"
    metadata_info["request_custom_ids"] = [
        r.get("custom_id") for r in batch_requests if isinstance(r, dict) and r.get("custom_id")
    ]
    _cli_save_metadata(metadata_file, metadata_info)

    client_g = _gemini_get_client()
    upload_config = {"mime_type": "application/x-ndjson"}
    try:
        uploaded = client_g.files.upload(file=batch_requests_path, config=upload_config)
    except TypeError:
        # Some SDK versions accept file handles
        with open(batch_requests_path, "rb") as fh:
            uploaded = client_g.files.upload(file=fh, config=upload_config)
    src_name = getattr(uploaded, "name", None) or getattr(uploaded, "file", None) or getattr(uploaded, "id", None)
    if not src_name:
        raise ValueError("Gemini upload did not return a file identifier (uploaded.name).")
    print(f"Uploaded batch file: {src_name}")

    model_norm = _gemini_normalize_model(args.judge_model_name)
    display_name = f"unifiedjudge_{int(time.time())}"
    print("Creating Gemini batch...")
    batch = _cli_create_gemini_batch(client_g, model_norm, src_name, display_name)
    batch_name = getattr(batch, "name", None) or getattr(batch, "id", None)
    if not batch_name:
        raise ValueError("Gemini batch create did not return a batch name/id.")
    print(f"Gemini batch created: {batch_name}")

    metadata_info["batch_id"] = batch_name
    metadata_info["provider"] = "gemini"
    metadata_info["src_file_name"] = src_name
    _cli_save_metadata(metadata_file, metadata_info)
    print(f"Updated metadata file with batch_id: {batch_name}")

    _cli_print_resume_hint(batch_name, metadata_file, args.input_file, args.output_file)

    if args.skip_wait:
        print("\nSkipping wait. Batch is processing in the background.")
        return

    batch_done = _wait_for_gemini_batch_completion(client_g, batch_name, args.check_interval)
    state = _cli_gemini_state(batch_done)
    if _cli_gemini_state_failed(state):
        print(f"Cannot proceed: batch state is {state}")
        return

    dest_file_name = getattr(getattr(batch_done, "dest", None), "file_name", None)
    if not dest_file_name:
        raise ValueError("Gemini batch did not expose dest.file_name; cannot download results.")
    # Persist so a later Mode B run can download without re-querying dest.
    metadata_info["dest_file_name"] = dest_file_name
    _cli_save_metadata(metadata_file, metadata_info)

    output_text = _cli_download_gemini_output(client_g, dest_file_name)
    processed = _process_gemini_batch_results(
        output_jsonl=output_text,
        samples=samples,
        metadata=metadata,
        request_custom_ids=list(metadata_info.get("request_custom_ids") or []),
        overwrite=bool(args.overwrite),
        save_reason=bool(args.save_reason),
        save_confidence=bool(args.save_confidence),
        save_raw_output=bool(args.save_raw_output),
    )
    _cli_write_results(args.output_file, processed)


def _cli_submit_openai_batch(
    args: argparse.Namespace,
    batch_requests_path: str,
    metadata_file: str,
    metadata_info: Dict[str, Any],
    samples: List[Dict[str, Any]],
    metadata: Any,
) -> None:
    """Mode A (OpenAI): upload the request file, create the batch, and optionally wait and write results."""
    metadata_info["provider"] = "openai"
    _cli_save_metadata(metadata_file, metadata_info)
    client_o = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    with open(batch_requests_path, "rb") as fh:
        uploaded = client_o.files.create(file=fh, purpose="batch")
    file_id = uploaded.id
    print(f"Uploaded file ID: {file_id}")

    print("Creating batch...")
    batch = client_o.batches.create(input_file_id=file_id, endpoint="/v1/chat/completions", completion_window=args.completion_window)
    batch_id = batch.id
    print(f"Batch created with ID: {batch_id}")

    # Update metadata with batch_id
    metadata_info["batch_id"] = batch_id
    _cli_save_metadata(metadata_file, metadata_info)
    print(f"Updated metadata file with batch_id: {batch_id}")

    _cli_print_resume_hint(batch_id, metadata_file, args.input_file, args.output_file)

    if args.skip_wait:
        print("\nSkipping wait. Batch is processing in the background.")
        return

    batch_status = wait_for_batch_completion(client_o, batch_id, args.check_interval)
    if batch_status.status != "completed":
        print(f"Cannot proceed: batch status is {batch_status.status}")
        return

    batch_results = _cli_download_openai_results(client_o, batch_status)
    if batch_results is None:
        return
    processed = process_batch_results(
        batch_results=batch_results,
        samples=samples,
        metadata=metadata,
        overwrite=bool(args.overwrite),
        save_reason=bool(args.save_reason),
        save_confidence=bool(args.save_confidence),
        save_raw_output=bool(args.save_raw_output),
    )
    _cli_write_results(args.output_file, processed)


def main() -> None:
    """CLI entry point: submit a new judge batch (Mode A) or retrieve an existing one (Mode B).

    Mode A builds per-judge requests for ``[start_idx, end_idx)`` of
    ``--input_file`` and writes a metadata JSON plus a requests JSONL. It then
    submits to the OpenAI or Gemini Batch API (chosen from
    ``--judge_model_name``) and, unless ``--skip_wait``, waits and writes the
    scored samples to ``--output_file``.

    Mode B (``--batch_id`` + ``--metadata_file``) uses the provider and
    settings recorded in the metadata to download and apply results.
    ``--input_file`` is optional there and defaults to the metadata's
    ``input_file``.
    """
    p = argparse.ArgumentParser(
        description="Evaluate original+rewrites with unified_judge.py judges via Batch APIs (OpenAI Batch / Gemini Batch)"
    )
    p.add_argument(
        "--input_file",
        type=str,
        default=None,
        help="Input JSONL file (required for new batches; with --batch_id defaults to the metadata's input_file)",
    )
    p.add_argument("--output_file", type=str, required=True, help="Output JSONL file")

    p.add_argument("--judge_model_name", type=str, default="gpt-4o-2024-11-20", help="Judge model (OpenAI or Gemini)")
    p.add_argument("--judges", type=str, default="cka,xteaming,generative", help="Comma-separated judges")
    p.add_argument("--target_model_holder", type=str, default="OpenAI", help="Used by xteaming judge prompt")

    p.add_argument("--temperature", type=float, default=0.0, help="Judge temperature (default: 0)")
    p.add_argument("--seed", type=int, default=123, help="Seed (set -1 to disable)")
    p.add_argument("--max_completion_tokens", type=int, default=1000, help="Max completion tokens per judge call")

    p.add_argument("--overwrite", action="store_true", help="Overwrite existing scores")
    p.add_argument("--start_idx", type=int, default=0, help="Start index in input file")
    p.add_argument("--end_idx", type=int, default=None, help="End index (exclusive)")

    p.add_argument("--batch_id", type=str, default=None, help="Existing batch ID (skip creation)")
    p.add_argument("--metadata_file", type=str, default=None, help="Metadata JSON file (required for --batch_id)")
    p.add_argument("--batch_requests_file", type=str, default=None, help="Path to save/load batch requests JSONL")
    p.add_argument("--completion_window", type=str, default="24h", help="Batch completion window")
    p.add_argument("--check_interval", type=int, default=60, help="Batch status poll interval (sec)")
    p.add_argument("--skip_wait", action="store_true", help="Submit batch and exit without waiting")

    p.add_argument("--save_reason", action="store_true", help="Save judge reason (if provided)")
    p.add_argument("--save_confidence", action="store_true", help="Save judge confidence (if provided)")
    p.add_argument("--save_raw_output", action="store_true", help="Save raw judge output (large!)")

    args = p.parse_args()
    if not args.batch_id and not args.input_file:
        p.error("--input_file is required when submitting a new batch (it is optional with --batch_id)")

    judges = _parse_judges_arg(args.judges)
    seed_value: Optional[int] = None if args.seed is None or int(args.seed) < 0 else int(args.seed)

    # ---------------------------------------------------------------------
    # Mode B: retrieve results for existing batch
    # ---------------------------------------------------------------------
    if args.batch_id:
        _cli_retrieve_existing_batch(args)
        return

    # ---------------------------------------------------------------------
    # Mode A: create and submit a new batch
    # ---------------------------------------------------------------------
    provider = _detect_provider(args.judge_model_name)
    _cli_require_credentials(provider)

    print(f"Loading samples from {args.input_file} ...")
    all_samples = _read_jsonl(args.input_file)
    start_idx = int(args.start_idx)
    end_idx = int(args.end_idx) if args.end_idx is not None else len(all_samples)
    samples = all_samples[start_idx:end_idx]
    print(f"Loaded {len(all_samples)} total samples; processing {start_idx}..{end_idx - 1} ({len(samples)} samples)")

    print(f"Preparing batch requests (judges={judges}, model={args.judge_model_name}) ...")
    batch_requests, metadata = prepare_batch_requests(
        samples=samples,
        judge_model_name=args.judge_model_name,
        target_model_holder=args.target_model_holder,
        judges=judges,
        temperature=float(args.temperature),
        seed=seed_value,
        max_completion_tokens=int(args.max_completion_tokens),
        overwrite=bool(args.overwrite),
        start_idx=start_idx,
        save_reason=bool(args.save_reason),
        save_confidence=bool(args.save_confidence),
        provider=provider,
    )
    print(f"Prepared {len(batch_requests)} requests")

    if len(batch_requests) == 0:
        print("No requests to submit (everything may already be scored; use --overwrite to force).")
        print(f"Writing passthrough output to {args.output_file} ...")
        _write_jsonl(args.output_file, samples)
        return

    # Save metadata
    if args.batch_requests_file:
        metadata_file = _cli_metadata_path(args.batch_requests_file)
        batch_requests_path = args.batch_requests_file
    else:
        metadata_file = f"batch_unifiedjudge_metadata_{start_idx}_{end_idx}.json"
        batch_requests_path = f"batch_unifiedjudge_requests_{start_idx}_{end_idx}.jsonl"

    metadata_info: Dict[str, Any] = {
        "metadata": metadata,
        "judge_model_name": args.judge_model_name,
        "judges": judges,
        "target_model_holder": args.target_model_holder,
        "temperature": float(args.temperature),
        "seed": seed_value,
        "max_completion_tokens": int(args.max_completion_tokens),
        "input_file": args.input_file,
        "overwrite": bool(args.overwrite),
        "start_idx": start_idx,
        "end_idx": end_idx,
        "save_reason": bool(args.save_reason),
        "save_confidence": bool(args.save_confidence),
        "save_raw_output": bool(args.save_raw_output),
    }
    _cli_save_metadata(metadata_file, metadata_info)
    print(f"Saved metadata to {metadata_file}")

    # Write batch requests JSONL
    with open(batch_requests_path, "w", encoding="utf-8") as f:
        for req in batch_requests:
            f.write(json.dumps(req) + "\n")
    print(f"Created batch requests file: {batch_requests_path}")

    # Upload + create batch
    print("Uploading batch requests file...")
    if provider == "gemini":
        _cli_submit_gemini_batch(args, batch_requests, batch_requests_path, metadata_file, metadata_info, samples, metadata)
    else:
        _cli_submit_openai_batch(args, batch_requests_path, metadata_file, metadata_info, samples, metadata)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


