#!/usr/bin/env python3
"""
Evaluate original + rewrites with `Decoy-for-the-Judge/unified_judge.py` judges.

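This script is similar to evaluate_rewrites_unifiedjudge_batch.py, but it sends
judge requests concurrently from a thread pool instead of using the OpenAI Batch API.

Modes:
- Local SGLang (OpenAI-compatible): pass --use_local; --server_url sets the server
  address (default: http://localhost:30000) and is ignored without --use_local.
- Remote provider: omit --use_local and pass --provider openai/gemini/claude, or
  --provider auto (the default) to detect the provider from --judge_model_name.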
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm

# Make the repository root importable. It is the parent of this script's
# directory (presumably where `unified_judge.py` lives). It is prepended to
# sys.path so it takes precedence, and only if it is not already listed.
SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parent
if all(entry != str(REPO_ROOT) for entry in sys.path):
    sys.path[:0] = [str(REPO_ROOT)]

from unified_judge import create_judge
from evaluate_rewrites_unifiedjudge_batch import (
    _parse_judges_arg,
    _get_sample_fields,
    _get_rewrite_targets,
    _score_field,
    _apply_parsed_to_sample,
    _read_jsonl,
    _write_jsonl,
)


def _canonical_judge_type(judge_type: str) -> str:
    jt = (judge_type or "").strip().lower()
    if jt == "generative":
        return "pair"
    return jt


def _unpack_judge_result(result: Any) -> Dict[str, Any]:
    """Normalize a judge's ``infer_single`` return value into a parsed dict.

    A list/tuple is read positionally as ``(score, reason, confidence,
    raw_output)``. Positions the judge did not return fall back to
    ``None`` / ``""`` / ``None`` / ``""`` respectively, so an empty sequence
    yields a ``None`` score. Any non-sequence value is treated as the score.

    Returns:
        Dict with keys ``score``, ``reason``, ``confidence`` and ``raw_output``.
    """
    if not isinstance(result, (list, tuple)):
        return {"score": result, "reason": "", "confidence": None, "raw_output": ""}
    defaults = (None, "", None, "")
    # Take at most 4 returned values and pad the missing tail from `defaults`.
    score, reason, confidence, raw_output = tuple(result[:4]) + defaults[len(result):]
    return {"score": score, "reason": reason, "confidence": confidence, "raw_output": raw_output}


def evaluate_sample(
    sample: Dict[str, Any],
    judges: List[str],
    judge_instances: Dict[str, Any],
    target_model_holder: str,
    save_reason: bool,
    save_confidence: bool,
    save_raw_output: bool,
    overwrite: bool,
) -> Dict[str, Any]:
    """Score a sample's original response and each of its rewrites with every judge.

    ``sample`` itself is not modified. A shallow copy is returned, and its
    ``rewrites`` list (when it is a list) is copied element-wise so that
    per-rewrite results are written into fresh dicts.

    Args:
        sample: One input record.
        judges: Judge names to run, in order; each must be a key of ``judge_instances``.
        judge_instances: Judge objects exposing ``infer_single(prompt, response)``.
        target_model_holder: Not used here; kept for interface compatibility
            (judges that need it receive it when they are constructed).
        save_reason / save_confidence / save_raw_output: Forwarded to
            ``_apply_parsed_to_sample``.
        overwrite: When False, a judge is skipped for any target whose score
            field is already non-None.

    Returns:
        The updated copy of ``sample``.

    A judge call that raises is reported on stderr and skipped (best-effort),
    so one failure does not discard the sample's other scores.
    """
    result_sample = sample.copy()
    if isinstance(sample.get("rewrites"), list):
        result_sample["rewrites"] = [r.copy() if isinstance(r, dict) else r for r in sample["rewrites"]]

    prompt, original_response = _get_sample_fields(sample)

    def _judge_into(
        judge_type: str,
        response: Any,
        *,
        is_original: bool,
        rewrite_idx: Optional[int],
        is_flat: bool,
        label: str,
    ) -> None:
        """Run one judge on *response* and record the parsed result in ``result_sample``."""
        judge = judge_instances[judge_type]
        try:
            parsed = _unpack_judge_result(judge.infer_single(prompt, response))
            _apply_parsed_to_sample(
                sample=result_sample,
                parsed=parsed,
                judge_type=judge_type,
                is_original=is_original,
                rewrite_idx=rewrite_idx,
                is_flat=is_flat,
                save_reason=save_reason,
                save_confidence=save_confidence,
                save_raw_output=save_raw_output,
            )
        except Exception as e:
            print(f"Error evaluating {label} with {judge_type}: {e}", file=sys.stderr)

    # Original response.
    for judge_type in judges:
        if not overwrite and result_sample.get(_score_field(judge_type, is_original=True)) is not None:
            continue
        _judge_into(
            judge_type,
            original_response,
            is_original=True,
            rewrite_idx=None,
            is_flat=False,
            label="original",
        )

    # Rewrites. A "flat" target stores its rewrite directly on the sample;
    # otherwise scores live on the per-rewrite dict.
    for rewrite_idx, rewrite_obj, is_flat in _get_rewrite_targets(sample):
        if is_flat:
            rewritten_response = sample.get("rewritten_response")
            target_obj = result_sample
        else:
            rewritten_response = rewrite_obj.get("rewritten_response")
            # The skip check reads the input rewrite dict; its copy in
            # result_sample started identical, so the existing scores agree.
            target_obj = rewrite_obj
        if not rewritten_response:
            continue

        for judge_type in judges:
            if not overwrite and target_obj.get(_score_field(judge_type, is_original=False)) is not None:
                continue
            _judge_into(
                judge_type,
                rewritten_response,
                is_original=False,
                rewrite_idx=rewrite_idx,
                is_flat=is_flat,
                label=f"rewrite {rewrite_idx}",
            )

    return result_sample


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    p = argparse.ArgumentParser(
        description="Evaluate original+rewrites with unified_judge.py judges (local SGLang or remote providers)"
    )
    p.add_argument("--input_file", type=str, required=True, help="Input JSONL file")
    p.add_argument("--output_file", type=str, required=True, help="Output JSONL file")

    p.add_argument("--judge_model_name", type=str, default="default", help="Judge model name")
    p.add_argument("--judges", type=str, default="cka,xteaming,generative", help="Comma-separated judges")
    p.add_argument("--target_model_holder", type=str, default="OpenAI", help="Used by xteaming judge prompt")

    p.add_argument("--provider", type=str, default="auto", choices=["auto", "openai", "gemini", "claude"], help="API provider (default: auto-detect from model name)")
    p.add_argument("--use_local", action="store_true", help="Use local OpenAI-compatible server (SGLang). If set, --server_url is used.")
    p.add_argument("--server_url", type=str, default="http://localhost:30000", help="Local SGLang server URL (only used with --use_local)")
    p.add_argument("--temperature", type=float, default=0.0, help="Judge temperature (default: 0)")
    p.add_argument("--seed", type=int, default=123, help="Seed (set -1 to disable)")
    p.add_argument("--max_completion_tokens", type=int, default=400, help="Max completion tokens per judge call")

    p.add_argument("--overwrite", action="store_true", help="Overwrite existing scores")
    p.add_argument("--start_idx", type=int, default=0, help="Start index in input file")
    p.add_argument("--end_idx", type=int, default=None, help="End index (exclusive)")
    p.add_argument("--batch_size", type=int, default=32, help="Batch size for concurrent processing (default: 32)")

    p.add_argument("--save_reason", action="store_true", help="Save judge reason (if provided)")
    p.add_argument("--save_confidence", action="store_true", help="Save judge confidence (if provided)")
    p.add_argument("--save_raw_output", action="store_true", help="Save raw judge output (large!)")
    return p


def _normalize_local_base_url(server_url: str) -> str:
    """Return *server_url* ending in exactly one ``/v1`` path segment.

    Trailing slashes are stripped first. As a result, ``http://host:30000/`` and
    ``http://host:30000/v1/`` both become ``http://host:30000/v1`` instead of
    producing a ``//v1`` path.
    """
    base = server_url.rstrip("/")
    return base if base.endswith("/v1") else f"{base}/v1"


def _create_judge_instances(
    args: argparse.Namespace,
    judges: List[str],
    *,
    seed: Optional[int],
    use_local: bool,
    local_base_url: str,
) -> Dict[str, Any]:
    """Create one judge per requested name, keyed by the name exactly as given.

    The type passed to ``create_judge`` is canonicalized (``generative`` ->
    ``pair``), but the dict key is not. This keeps the keys equal to the
    entries of ``judges`` that callers iterate over.
    """
    instances: Dict[str, Any] = {}
    for judge_type in judges:
        canon = _canonical_judge_type(judge_type)
        extra: Dict[str, Any] = {}
        # Only override prompt defaults if user explicitly wants to save the fields.
        if args.save_reason:
            extra["include_reason"] = True
        if args.save_confidence:
            extra["include_confidence"] = True
        if canon == "xteaming":
            extra["target_model_holder"] = args.target_model_holder
        instances[judge_type] = create_judge(
            judge_type=canon,
            model_name=args.judge_model_name,
            temperature=args.temperature,
            seed=seed,
            max_completion_tokens=args.max_completion_tokens,
            use_local=use_local,
            local_base_url=local_base_url,
            provider=None if args.provider == "auto" else args.provider,
            **extra,
        )
    return instances


def main() -> None:
    """CLI entry point: parse arguments, build judges, score the selected slice, write JSONL.

    Only samples in ``[start_idx, end_idx)`` are written to ``--output_file``,
    in input order. A sample whose evaluation raises is written back unchanged
    and its traceback is printed.
    """
    import traceback

    p = _build_arg_parser()
    args = p.parse_args()
    # ThreadPoolExecutor raises a bare ValueError for max_workers <= 0. Fail
    # early with a usage error instead of after judges and data are loaded.
    if args.batch_size < 1:
        p.error("--batch_size must be >= 1")

    judges = _parse_judges_arg(args.judges)
    seed_value: Optional[int] = None if args.seed < 0 else int(args.seed)

    use_local = bool(args.use_local)
    local_base_url: Optional[str] = None
    if use_local:
        local_base_url = _normalize_local_base_url(args.server_url)
        print(f"Using local SGLang server: {args.server_url}")
    else:
        print(f"Using remote provider: {args.provider} (model: {args.judge_model_name})")
    print(f"Judges: {judges}")
    print(f"Batch size: {args.batch_size}")

    judge_instances = _create_judge_instances(
        args,
        judges,
        seed=seed_value,
        use_local=use_local,
        local_base_url=local_base_url or "http://localhost:30000/v1",
    )

    # Load samples. slice.indices() clamps/normalizes the bounds exactly as
    # list slicing does, so the printed range matches what is processed.
    print(f"Loading samples from {args.input_file} ...")
    all_samples = _read_jsonl(args.input_file)
    start_idx, end_idx, _ = slice(args.start_idx, args.end_idx).indices(len(all_samples))
    samples = all_samples[start_idx:end_idx]
    print(f"Loaded {len(all_samples)} total samples; processing {start_idx}..{end_idx - 1} ({len(samples)} samples)")

    def process_single_sample(idx: int) -> Tuple[int, Dict[str, Any]]:
        """Evaluate ``samples[idx]``; on any error, return it unchanged."""
        sample = samples[idx]
        try:
            evaluated = evaluate_sample(
                sample=sample,
                judges=judges,
                judge_instances=judge_instances,
                target_model_holder=args.target_model_holder,
                save_reason=args.save_reason,
                save_confidence=args.save_confidence,
                save_raw_output=args.save_raw_output,
                overwrite=args.overwrite,
            )
            return (idx, evaluated)
        except Exception as e:
            print(f"\nError processing sample {idx}: {e}", file=sys.stderr)
            traceback.print_exc()
            return (idx, sample)  # Return original sample on error

    # Results are stored by index, so completion order does not affect output order.
    evaluated_samples: List[Dict[str, Any]] = list(samples)
    with ThreadPoolExecutor(max_workers=args.batch_size) as executor:
        futures = [executor.submit(process_single_sample, idx) for idx in range(len(samples))]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Evaluating samples"):
            idx, evaluated_sample = future.result()
            evaluated_samples[idx] = evaluated_sample

    print(f"Saving results to {args.output_file} ...")
    _write_jsonl(args.output_file, evaluated_samples)
    print("✅ Done.")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

