"""Evaluation script for math domain using vLLM and math-verify.

Evaluates models on a parquet dataset (e.g. validation.parquet) using baseline prompts, with math_verify
handling answer extraction and verification.
Supports API models via litellm, GCP model download, local SkyRL checkpoints, and multi-run evaluation
with confidence intervals.

Usage:
    python -m advisor_models.math.eval_math_direct \
        --model_name "gpt-4o-mini" \
        --dataset_path data/math/validation.parquet \
        --api_model \
        --num_runs 5 \
        --max_workers 100
"""

import argparse
import os
import sys
from typing import Dict, Any, Optional
from tqdm import tqdm
import numpy as np
import litellm
from openai import OpenAI
import pandas as pd

# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))

from utils.eval_utils import (
    download_model_if_needed,
    setup_vllm_server,
    cleanup_vllm_server,
    compute_multi_run_statistics,
    format_ci_string,
    add_common_eval_args,
)

from advisor_models.math.config import (
    BASELINE_SYSTEM_PROMPT,
    BASELINE_INSTRUCTION,
)

try:
    from math_verify import parse, verify
except ImportError:
    raise ImportError("math_verify is required. Install with: pip install math-verify")


def prepare_local_model(
    model_path: str, step_num: Optional[int] = None, is_hf_model: bool = False
) -> tuple[str, Optional[str]]:
    """Prepare local model for serving with vLLM.

    Args:
        model_path: Path to the model (HF format or SkyRL checkpoint) or HF model identifier
        step_num: Optional specific step number for SkyRL checkpoints
        is_hf_model: If True, model_path is a HuggingFace model identifier

    Returns:
        Tuple of (path_to_serve, temp_dir_to_cleanup)
        temp_dir_to_cleanup is None if no cleanup needed

    Raises:
        FileNotFoundError: If ``is_hf_model`` is False and the expanded
            ``model_path`` does not exist on disk.
    """
    # A HuggingFace identifier is returned untouched; no local inspection
    # (and therefore no checkpoint helper module) is needed for it.
    if is_hf_model:
        print(f"Using HuggingFace model identifier: {model_path}")
        print("vLLM will download the model if not already cached")
        return model_path, None

    # For local paths, expand "~" and make the path absolute before checking it.
    model_path = os.path.abspath(os.path.expanduser(model_path))
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path does not exist: {model_path}")

    # Imported only once a real local path has to be inspected, so the two
    # cases above do not depend on this helper module being importable.
    from utils.upload_model_to_gcp import process_skyrl_model, is_skyrl_checkpoint

    if not is_skyrl_checkpoint(model_path):
        print(f"Using local HuggingFace model at {model_path}")
        return model_path, None

    print(f"Detected SkyRL checkpoint at {model_path}")
    if step_num is not None:
        print(f"Processing FSDP shards for step {step_num} to HuggingFace format...")
    else:
        print("Processing FSDP shards to HuggingFace format...")

    # The converted model is written to a temporary directory; it is returned
    # twice so the caller both serves from it and knows to clean it up.
    temp_dir = process_skyrl_model(model_path, step_num=step_num)
    print(f"Processed model saved to temporary directory: {temp_dir}")
    return temp_dir, temp_dir


class MathEvaluator:
    """Evaluator for math problems via OpenAI-compatible endpoint.

    Each problem is sent to the model through ``litellm.completion`` with the
    baseline system prompt and instruction; the reply is parsed and checked
    against the ground truth with ``math_verify``.
    """

    def __init__(
        self,
        model_name: str,
        api_base: Optional[str] = "http://127.0.0.1:8000/v1",
    ):
        """Initialize evaluator to call a remote model endpoint.

        Args:
            model_name: Model identifier passed to ``litellm.completion``.
            api_base: Endpoint base URL. When None, no ``api_base`` is passed
                to litellm at all.
        """
        self.model_name = model_name
        self.api_base = api_base
        self.system_prompt = BASELINE_SYSTEM_PROMPT

        # No method of this class uses this client (all requests go through
        # litellm), so a failure to construct it - e.g. no OpenAI API key in
        # the environment when evaluating a local vLLM model - must not abort
        # the evaluation. The attribute is kept for backward compatibility.
        try:
            self.openai_client = OpenAI()
        except Exception as exc:
            print(f"Note: OpenAI client not initialized ({exc}); continuing without it")
            self.openai_client = None

    def extract_answer(self, response_str: str) -> Optional[Any]:
        """Extract the final answer from a math response using math_verify.parse.

        Returns:
            The parsed answer as returned by ``math_verify.parse``, or None
            when parsing raised or produced an empty result.
        """
        # NOTE(review): the parsing timeout is disabled, presumably because
        # this runs in worker threads where math_verify's timeout mechanism
        # is not usable - confirm.
        try:
            parsed = parse(response_str, parsing_timeout=None)
        except Exception:
            print(f"Error extracting answer: {response_str}")
            return None

        # An empty parse result means "no answer found". Normalizing it to
        # None keeps the `is not None` extraction metrics honest.
        if not parsed:
            return None
        return parsed

    def compute_score(
        self, extracted_answer: Optional[Any], ground_truth: str
    ) -> float:
        """Compute score by comparing extracted answer with ground truth using math_verify.verify.

        Returns:
            1.0 when ``verify`` accepts the answer, otherwise 0.0 (including
            when there is no extracted answer or parsing/verification raises).
        """
        if extracted_answer is None:
            return 0.0

        # Parsing the ground truth sits inside the try as well, so a bad
        # ground-truth string scores 0.0 instead of discarding the response.
        try:
            ground_truth_parsed = parse(ground_truth, parsing_timeout=None)
            is_match = verify(
                ground_truth_parsed, extracted_answer, timeout_seconds=None
            )
        except Exception:
            print(f"Error verifying answer: {extracted_answer}")
            return 0.0
        return 1.0 if is_match else 0.0

    def generate_response(self, question: str) -> str:
        """Generate response for a math question using baseline prompt.

        Returns:
            The stripped message content of the first choice, or an empty
            string when the content is missing or the request raised.
        """
        try:
            # Format question using baseline instruction
            formatted_question = BASELINE_INSTRUCTION.format(problem=question)

            # Build kwargs for litellm.completion with system prompt
            kwargs = {
                "model": self.model_name,
                "messages": [
                    {"role": "system", "content": self.system_prompt},
                    {"role": "user", "content": formatted_question},
                ],
                "temperature": 1.0,
            }

            # Only add api_base if it's not None (for vLLM server)
            if self.api_base is not None:
                kwargs["api_base"] = self.api_base

            response = litellm.completion(**kwargs)
            return (response.choices[0].message.content or "").strip()
        except Exception as e:
            # Best effort: a failed request becomes an empty response, which
            # downstream scoring treats as a wrong answer.
            print(f"Error generating response: {e}")
            return ""

    @staticmethod
    def _lookup_ground_truth(problem) -> Any:
        """Return the problem's ground truth without raising on a malformed record."""
        try:
            return problem["reward_spec"]["ground_truth"]
        except (KeyError, TypeError, IndexError):
            # Fall back to a top-level key, then to an empty string.
            return problem.get("ground_truth", "")

    def process_single_example(self, idx_problem_tuple):
        """Process a single math problem.

        Args:
            idx_problem_tuple: ``(index, problem)`` where ``problem`` is a
                mapping with ``original_question`` and
                ``reward_spec["ground_truth"]``.

        Returns:
            A result dict with index, question, ground_truth, response,
            extracted_answer and score. On failure the dict also carries an
            ``error`` message, with score 0.0 and an empty response.
        """
        idx, problem = idx_problem_tuple
        try:
            question = problem["original_question"]
            ground_truth = problem["reward_spec"]["ground_truth"]

            response = self.generate_response(question)
            extracted_answer = self.extract_answer(response)
            score = self.compute_score(extracted_answer, ground_truth)

            return {
                "index": idx,
                "question": question,
                "ground_truth": ground_truth,
                "response": response,
                "extracted_answer": extracted_answer,
                "score": score,
            }
        except Exception as e:
            print(f"Error processing problem {idx}: {e}")
            return {
                "index": idx,
                "question": problem.get("original_question", ""),
                # Read from the same place as the success path so error rows
                # keep their ground truth.
                "ground_truth": self._lookup_ground_truth(problem),
                "response": "",
                "extracted_answer": None,
                "score": 0.0,
                "error": str(e),
            }

    @staticmethod
    def _summarize_results(results) -> Dict[str, Any]:
        """Compute the aggregate metrics dict for a list of per-example results."""
        total = len(results)
        all_scores = [r["score"] for r in results]
        extracted = sum(1 for r in results if r["extracted_answer"] is not None)
        return {
            "total_examples": total,
            "overall_accuracy": float(np.mean(all_scores)) if all_scores else 0.0,
            "extraction_success": extracted,
            "extraction_rate": extracted / total if total else 0.0,
        }

    def evaluate_dataset(
        self,
        dataset_path: str,
        max_workers: int = 12,
        max_examples: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Evaluate the dataset.

        Args:
            dataset_path: Path to a parquet file of problems.
            max_workers: Size of the thread pool used to process problems.
            max_examples: If given, only the first ``max_examples`` problems
                are evaluated.

        Returns:
            Dict with ``metrics`` (aggregate numbers), ``detailed_results``
            (per-example dicts sorted by index) and ``all_scores``.
        """
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Load problems from parquet file
        problems = pd.read_parquet(dataset_path).to_dict("records")
        if max_examples is not None:
            problems = problems[:max_examples]

        print(f"Evaluating {len(problems)} problems...")

        results = []
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(self.process_single_example, (idx, problem))
                for idx, problem in enumerate(problems)
            ]

            with tqdm(total=len(problems)) as pbar:
                for future in as_completed(futures):
                    results.append(future.result())
                    pbar.update(1)

        # Futures complete out of order; sort by index to restore dataset order.
        results.sort(key=lambda r: r["index"])

        return {
            "metrics": self._summarize_results(results),
            "detailed_results": results,
            "all_scores": [r["score"] for r in results],
        }

    def print_evaluation_report(
        self, evaluation_results: Dict[str, Any], aggregate_stats: Dict[str, Any]
    ):
        """Print a formatted evaluation report.

        Args:
            evaluation_results: Output of ``evaluate_dataset`` for one run;
                supplies the example rows, example count and extraction rate.
            aggregate_stats: Multi-run statistics; its ``n`` entry is printed
                and the dict is passed to ``format_ci_string``.
        """
        metrics = evaluation_results["metrics"]
        detailed_results = evaluation_results["detailed_results"]
        heavy_rule = "=" * 70
        light_rule = "-" * 70

        print("\n" + heavy_rule)
        print("MATH CORRECTNESS EVALUATION REPORT")
        print(heavy_rule)

        # Print 2 example evaluations
        print("\n" + light_rule)
        print("EXAMPLE EVALUATIONS (2 samples)")
        print(light_rule)

        # Get one correct and one incorrect example if possible
        first_correct = next((r for r in detailed_results if r["score"] > 0), None)
        first_incorrect = next((r for r in detailed_results if r["score"] == 0), None)
        examples_to_show = [
            r for r in (first_correct, first_incorrect) if r is not None
        ]

        # If we don't have both, just take first 2
        if len(examples_to_show) < 2:
            examples_to_show = detailed_results[:2]

        for i, example in enumerate(examples_to_show[:2], 1):
            print(f"\nExample {i}:")
            print(f"Question: {example['question']}")
            print(f"Ground Truth: {example['ground_truth']}")
            print(f"Model Response: {example['response']}")
            print(f"Extracted Answer: {example['extracted_answer']}")
            print(f"Score: {example['score']:.1f}")
            print(light_rule)

        # Print aggregate statistics
        print("\n" + heavy_rule)
        print("AGGREGATE STATISTICS")
        print(heavy_rule)
        print(f"\nNumber of runs: {aggregate_stats['n']}")
        print(f"Total examples per run: {metrics['total_examples']}")
        print(f"{format_ci_string(aggregate_stats, 'Accuracy')}")
        print(f"Extraction Rate: {metrics['extraction_rate']:.4f}")

        print("\n" + heavy_rule)


def run_multi_evaluation(
    evaluator: MathEvaluator,
    dataset_path: str,
    num_runs: int = 5,
    max_examples: Optional[int] = None,
    max_workers: int = 12,
) -> Dict[str, Any]:
    """Evaluate the dataset ``num_runs`` times and aggregate the scores.

    Args:
        evaluator: Evaluator whose ``evaluate_dataset`` is called once per run.
        dataset_path: Dataset path forwarded to every run.
        num_runs: How many independent evaluation passes to perform.
        max_examples: Optional example cap forwarded to every run.
        max_workers: Worker count forwarded to every run.

    Returns:
        Dict with ``run_results`` (one ``evaluate_dataset`` result per run),
        ``run_scores`` (one list of individual scores per run) and
        ``aggregate_stats`` (output of ``compute_multi_run_statistics``).
    """
    per_run_outputs = []
    per_run_scores = []  # one inner list of individual scores per run

    for run_idx in range(num_runs):
        # Banner announcing the run
        print(f"\n{'=' * 70}")
        print(f"EVALUATION RUN {run_idx + 1}/{num_runs}")
        print(f"{'=' * 70}")

        run_output = evaluator.evaluate_dataset(
            dataset_path,
            max_workers=max_workers,
            max_examples=max_examples,
        )
        per_run_outputs.append(run_output)

        # Keep every individual score so the statistics see raw data,
        # not just per-run means.
        per_run_scores.append(run_output["all_scores"])

        run_mean = run_output["metrics"]["overall_accuracy"]
        print(f"Run {run_idx + 1} accuracy: {run_mean:.4f}")

    # Aggregate across runs using all individual scores
    summary = compute_multi_run_statistics(per_run_scores)

    return {
        "run_results": per_run_outputs,
        "run_scores": per_run_scores,
        "aggregate_stats": summary,
    }


def main():
    """Parse CLI arguments, set up the requested model backend, and run the evaluation.

    Three backends are supported, checked in this order:
      * ``--api_model``: call the named API model directly via litellm.
      * ``--is_skyrl``: prepare a local SkyRL checkpoint and serve it with vLLM.
      * otherwise: obtain the model via ``download_model_if_needed`` and serve
        it with vLLM.

    The vLLM server and any temporary model directory are cleaned up on exit,
    whether or not the evaluation succeeded.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate math correctness with GCP/vLLM support"
    )

    # Add common evaluation arguments
    add_common_eval_args(parser)

    # Add is_skyrl flag for local SkyRL checkpoint evaluation
    parser.add_argument(
        "--is_skyrl",
        action="store_true",
        help="Flag to indicate model_name is a local SkyRL checkpoint that needs processing",
    )
    parser.add_argument(
        "--step_num",
        type=int,
        default=None,
        help="Specific step number for SkyRL checkpoint (uses latest if not specified)",
    )

    # Add api_model flag for simple API model evaluation
    parser.add_argument(
        "--api_model",
        action="store_true",
        help="Flag to indicate model_name is an API model (e.g., gpt-4o-mini) - no vLLM server needed",
    )

    args = parser.parse_args()

    # Fail fast: the final report indexes the last run's results, so at least
    # one run is required. Rejecting this here avoids preparing a model and
    # launching a server only to crash afterwards.
    if args.num_runs < 1:
        parser.error("--num_runs must be at least 1")

    # Setup model and vLLM server
    vllm_process = None
    temp_dir = None

    try:
        if args.api_model:
            # Use API model directly without vLLM server
            print(f"\nUsing API model: {args.model_name}")
            print("No vLLM server needed - calling API directly via litellm")

            # Initialize evaluator with API model (no vLLM server)
            evaluator = MathEvaluator(
                model_name=args.model_name,
                api_base=None,  # Will use default litellm routing
            )

        else:
            if args.is_skyrl:
                # Process local SkyRL checkpoint
                print(f"\nDetected local SkyRL checkpoint: {args.model_name}")
                print("Preparing model for vLLM serving...")
                model_path, temp_dir = prepare_local_model(
                    args.model_name, step_num=args.step_num, is_hf_model=False
                )
            else:
                # Download model from GCP if needed
                model_path, temp_dir = download_model_if_needed(
                    model_name=args.model_name,
                    gcp=args.gcp,
                    bucket_name=args.bucket_name,
                )

            # Both local backends serve the model the same way. The process
            # handle is bound before the evaluator is built so the `finally`
            # below can still stop the server if construction fails.
            served_model_name = "math_model"
            print(f"\nStarting vLLM server for model: {model_path}")
            vllm_process = setup_vllm_server(
                model_path=model_path,
                served_model_name=served_model_name,
                tensor_parallel_size=args.tensor_parallel_size,
                max_model_len=args.max_model_len,
            )
            print(f"vLLM server started. Using model name: {served_model_name}")

            # Initialize evaluator
            evaluator = MathEvaluator(
                model_name="hosted_vllm/" + served_model_name,
                api_base="http://127.0.0.1:8000/v1",
            )

        # Run multi-evaluation
        multi_results = run_multi_evaluation(
            evaluator=evaluator,
            dataset_path=args.dataset_path,
            num_runs=args.num_runs,
            max_examples=args.max_examples,
            max_workers=args.max_workers,
        )

        # Print final report: examples and extraction rate come from the last
        # run, accuracy statistics from all runs.
        evaluator.print_evaluation_report(
            multi_results["run_results"][-1],
            aggregate_stats=multi_results["aggregate_stats"],
        )

    finally:
        # Cleanup vLLM server and temp directory
        cleanup_vllm_server(vllm_process, temp_dir)


if __name__ == "__main__":
    main()
