#!/usr/bin/env python3

import re
import json
import torch
import random
import argparse
import warnings
import time
import os
import glob
from typing import List, Dict, Tuple, Optional
from datetime import datetime

from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Import existing utilities from the MATH evaluation codebase
from data_loader import load_data
from parser import *
from utils import construct_prompt, set_seed
from trajectory import *
from evaluate import evaluate
from python_executor import PythonExecutor

# Set at import time, before any tokenizer is created below. The Hugging Face
# `tokenizers` library reads this variable to switch off its internal
# parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Silence every Python warning for the whole process (no category filter).
warnings.filterwarnings("ignore")

# --------------------------------------------------------------------------- #
#                                 CONSTANTS                                   #
# --------------------------------------------------------------------------- #

# System message used for chat-formatted (instruct) prompts. It asks the model
# to wrap its reasoning in <reasoning> tags and its final, boxed answer in
# <answer> tags; MATH500Evaluator.extract_xml_answer parses that <answer> span.
# The surrounding newlines of the literal are removed by .strip().
XML_SYSTEM_PROMPT = """
Respond in the following format, with only the final answer between the <answer> tags and always put your answer in boxed:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# --------------------------------------------------------------------------- #
#                               EVALUATOR                                     #
# --------------------------------------------------------------------------- #


class MATH500Evaluator:
    """Evaluator for MATH dataset using vLLM.

    One instance wraps a single model/checkpoint: it loads the tokenizer and
    a vLLM engine, generates one completion per problem, normalises the
    completion (XML ``<answer>`` tags, inline-math delimiters) and hands it to
    the project's ``run_execute`` / ``evaluate`` helpers for scoring.
    """

    def __init__(
        self,
        model_name: str,
        tensor_parallel_size: int = 1,
        prompt_type: str = "cot",
        temperature: float = 0.0,
        is_instruct: bool = True,
    ):
        """Load tokenizer, vLLM engine and the Python executor.

        Args:
            model_name: Model identifier or checkpoint path passed to both
                ``AutoTokenizer.from_pretrained`` and ``LLM``.
            tensor_parallel_size: Forwarded to vLLM's ``tensor_parallel_size``.
            prompt_type: Prompt style; selects stop words in
                ``generate_answer`` and is forwarded to ``run_execute`` and
                ``evaluate``.
            temperature: Sampling temperature (0 means greedy settings).
            is_instruct: When True, prompts are wrapped with the chat template
                and ``XML_SYSTEM_PROMPT``; otherwise ``construct_prompt`` is
                used to build a plain prompt.
        """
        print(f"Loading model with vLLM: {model_name}  (TP={tensor_parallel_size})")
        self.model_name = model_name
        self.prompt_type = prompt_type
        self.temperature = temperature
        self.is_instruct = is_instruct

        # ------ Tokenizer --------------------------------------------------- #
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, padding_side="left"
        )
        # Fall back to EOS as the pad token when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # ------ vLLM runtime ------------------------------------------------- #
        self.llm = LLM(model_name, tensor_parallel_size=tensor_parallel_size)

        # ------ Python executor for answer extraction ----------------------- #
        self.executor = PythonExecutor(get_answer_from_stdout=True)

        # ------ System prompt ----------------------------------------------- #
        self.system_prompt = XML_SYSTEM_PROMPT

    def cleanup(self):
        """Drop the vLLM engine reference and empty the CUDA cache.

        Safe to call more than once: a second call finds no engine attribute
        and only empties the cache again.
        """
        # Guarded delete so that repeated cleanup does not raise AttributeError.
        if hasattr(self, "llm"):
            del self.llm
        torch.cuda.empty_cache()

    # ------------------------------------------------------------------ #
    #                  Helper / extraction utilities                     #
    # ------------------------------------------------------------------ #

    def extract_xml_answer(self, text: str) -> Optional[str]:
        """Return the stripped content of the first ``<answer>`` element.

        Args:
            text: Raw model output.

        Returns:
            The text between the first ``<answer>`` and the next ``</answer>``
            with surrounding whitespace removed, or ``None`` when no complete
            tag pair is present.
        """
        # DOTALL lets the lazy group span line breaks inside the tags.
        m = re.search(
            r"<answer>[\s\n]*(.*?)[\s\n]*</answer>", text, flags=re.DOTALL
        )
        if m:
            return m.group(1).strip()
        return None

    def prepare_code_for_extraction(self, code: str) -> str:
        """Normalise a completion before it is passed to ``run_execute``.

        When the completion holds a non-empty ``<answer>`` element, only that
        answer is kept: inline-math delimiters ``\\( ... \\)`` are removed, an
        outer ``\\boxed{...}`` wrapper is unwrapped, and the result is
        re-emitted as ``The answer is \\boxed{...}``. Otherwise the whole
        completion is returned with inline-math delimiters removed.

        Args:
            code: Raw completion text.

        Returns:
            The normalised text.
        """
        # First check if there's an XML answer
        xml_answer = self.extract_xml_answer(code)

        if xml_answer:
            # DOTALL keeps this consistent with the non-XML branch below, so
            # inline math that spans several lines is unwrapped as well.
            cleaned_answer = re.sub(
                r'\\\((.*?)\\\)', r'\1', xml_answer, flags=re.DOTALL
            )
            cleaned_answer = re.sub(r'^\\boxed\{(.*)\}$', r'\1', cleaned_answer, flags=re.DOTALL)
            return f"The answer is \\boxed{{{cleaned_answer}}}"

        # For non-XML answers, strip inline math delimiters but don't add wrapper
        code = re.sub(r'\\\((.*?)\\\)', r'\1', code, flags=re.DOTALL)

        return code

    # ------------------------------------------------------------------ #
    #                    vLLM-based generation                           #
    # ------------------------------------------------------------------ #

    def generate_answer(self, prompt: str, max_new_tokens: int = 1024) -> str:
        """Generate one completion for ``prompt`` with vLLM.

        Args:
            prompt: The question (instruct mode) or a fully built prompt.
            max_new_tokens: Upper bound on generated tokens.

        Returns:
            The stripped completion, truncated right after the first
            ``</answer>`` when that tag occurs.
        """
        # Apply chat template if using instruct model
        if self.is_instruct:
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

        # Set up stop words based on prompt type
        stop_words = ["</s>"]
        if self.prompt_type in ['cot']:
            stop_words.extend(["\n\nQuestion:", "\n\nProblem:"])
        elif self.prompt_type in ['pal', 'tool-integrated', 'tora']:
            stop_words.extend(["\n\n---", "```output"])

        params = SamplingParams(
            temperature=self.temperature,
            # Nucleus sampling is only narrowed when actually sampling.
            top_p=1.0 if self.temperature == 0 else 0.95,
            max_tokens=max_new_tokens,
            stop=stop_words,
        )

        outputs = self.llm.generate([prompt], params)
        answer = outputs[0].outputs[0].text.strip()

        # For XML format, ensure we capture up to </answer> if present
        if "</answer>" in answer:
            answer = answer.split("</answer>")[0] + "</answer>"

        return answer

    # ------------------------------------------------------------------ #
    #                       Evaluation logic                             #
    # ------------------------------------------------------------------ #

    def evaluate_sample(
        self, example: dict, data_name: str = "math"
    ) -> dict:
        """Generate and extract a prediction for a single example.

        Args:
            example: One dataset record; must contain ``'idx'``. In
                non-instruct mode its ``'question'`` key is overwritten with
                the parsed question before ``construct_prompt`` is called.
            data_name: Dataset name forwarded to the parsing helpers.

        Returns:
            A sample dict with ``idx``, ``question``, ``gt_cot``, ``gt`` and
            single-element lists ``code``, ``pred``, ``report``, plus any of
            the optional metadata keys present on ``example``.
        """
        # Parse question and ground truth
        question = parse_question(example, data_name)
        gt_cot, gt_ans = parse_ground_truth(example, data_name)

        # For instruct models, we pass the question directly
        # For non-instruct models, we use construct_prompt
        if self.is_instruct:
            prompt_to_use = question
        else:
            # Construct prompt for non-instruct models
            example['question'] = question
            args = argparse.Namespace(
                prompt_type=self.prompt_type,
                temperature=self.temperature
            )
            prompt_to_use = construct_prompt(example, data_name, args)

        # Generate answer
        code = self.generate_answer(prompt_to_use)

        # Prepare code for extraction (handle XML tags)
        prepared_code = self.prepare_code_for_extraction(code)
        # Extract prediction using the existing logic
        pred, report = run_execute(self.executor, prepared_code, self.prompt_type, data_name)

        # Create sample result
        sample = {
            'idx': example['idx'],
            'question': question,
            'gt_cot': gt_cot,
            'gt': gt_ans,
            'code': [prepared_code],
            'pred': [pred],
            'report': [report],
        }

        # Add additional fields if present
        for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution',
                    'ques_type', 'ans_type', 'answer_type', 'dataset', 'subfield',
                    'filed', 'theorem', 'answer']:
            if key in example:
                sample[key] = example[key]

        return sample

    @staticmethod
    def _select_examples(examples: list, num_samples: int, start_idx: int) -> list:
        """Pick the examples to evaluate.

        * ``start_idx > 0``: contiguous window starting at ``start_idx``;
          ``num_samples <= 0`` means "everything from ``start_idx`` on".
        * ``start_idx == 0``: a random subset of ``num_samples`` examples
          (drawn with the global ``random`` state) when
          ``0 < num_samples < len(examples)``, otherwise all examples in
          their original order.
        """
        if start_idx > 0:
            if num_samples > 0:
                return examples[start_idx:start_idx + num_samples]
            # A non-positive count means "no limit", as in the random branch.
            return examples[start_idx:]

        if 0 < num_samples < len(examples):
            # Shuffle a copy: same permutation for a given seed, but the
            # caller's list is left untouched.
            pool = list(examples)
            random.shuffle(pool)
            return pool[:num_samples]

        return examples

    def evaluate_dataset(
        self,
        data_dir: str = "./data",
        split: str = "test",
        num_samples: int = 50,
        seed: int = 42,
        start_idx: int = 0,
    ) -> Tuple[Dict, List[dict]]:
        """Evaluate on MATH dataset.

        Args:
            data_dir: Directory passed to ``load_data``.
            split: Dataset split passed to ``load_data``.
            num_samples: Number of examples to evaluate; ``<= 0`` means all.
            seed: Seed passed to ``set_seed`` before sampling.
            start_idx: When positive, evaluate a contiguous window starting
                here instead of a random subset.

        Returns:
            ``(result_json, all_samples)``: the metrics dict returned by
            ``evaluate`` extended with ``time_use_in_second``,
            ``time_use_in_minute`` and ``num_samples``, and the per-sample
            records returned by ``evaluate``.
        """
        set_seed(seed)

        # Load MATH dataset
        print("Loading MATH dataset...")
        examples = load_data("math", split, data_dir)

        # Sample or slice examples
        examples = self._select_examples(examples, num_samples, start_idx)

        print(f"Evaluating on {len(examples)} samples...")

        # Evaluate all samples
        all_samples = []
        start_time = time.time()

        for position, example in enumerate(tqdm(examples, desc="Evaluating")):
            try:
                sample = self.evaluate_sample(example, "math")
            except Exception as e:
                print(f"\nError processing sample {example['idx']}: {e}")
                # Add failed sample with None prediction
                sample = {
                    'idx': example['idx'],
                    'question': parse_question(example, "math"),
                    'gt_cot': "",
                    'gt': parse_ground_truth(example, "math")[1],
                    'code': [""],
                    'pred': [None],
                    'report': [str(e)],
                }
            else:
                # Print first example (by position, so duplicated records
                # later in the list are not printed again).
                if position == 0:
                    print("\n" + "="*60)
                    print("First example:")
                    print(f"Question: {sample['question'][:200]}...")
                    print(f"Generated: {sample['code'][0][:200]}...")
                    print(f"Predicted: {sample['pred'][0]}")
                    print(f"Ground Truth: {sample['gt']}")
                    print("="*60 + "\n")

            # Exactly one record per example, success or failure.
            all_samples.append(sample)

        time_use = time.time() - start_time

        # Evaluate using the existing evaluation function
        all_samples, result_json = evaluate(
            samples=all_samples,
            data_name="math",
            prompt_type=self.prompt_type,
            execute=True
        )

        # Add timing information
        result_json['time_use_in_second'] = time_use
        result_json['time_use_in_minute'] = f"{int(time_use // 60)}:{int(time_use % 60):02d}"
        result_json['num_samples'] = len(all_samples)

        return result_json, all_samples


# --------------------------------------------------------------------------- #
#                           BATCH EVALUATION                                  #
# --------------------------------------------------------------------------- #

def find_checkpoints(base_dir: str) -> List[str]:
    """Find all checkpoint directories in a given base directory.

    Args:
        base_dir: Directory to scan (not recursive). It is treated as a
            literal path, even if it contains glob metacharacters.

    Returns:
        Paths of the sub-directories whose name starts with ``checkpoint``,
        ordered by the number in ``checkpoint-<N>``. Names without such a
        number count as 0 and therefore come first; ties are broken by name
        so the order does not depend on filesystem listing order. An empty
        list is returned when nothing matches.
    """
    def sort_key(path):
        name = os.path.basename(path)
        match = re.search(r'checkpoint-(\d+)', name)
        return (int(match.group(1)) if match else 0, name)

    # Escape the directory part so characters such as "[" or "*" in the path
    # are matched literally; only the trailing "*" acts as a wildcard.
    pattern = os.path.join(glob.escape(base_dir), "checkpoint*")

    # Keep directories only (plain files named "checkpoint..." are ignored).
    return sorted(
        (path for path in glob.glob(pattern) if os.path.isdir(path)),
        key=sort_key,
    )

def _intermediate_path(output_file: str) -> str:
    """Return the scratch path used for partial results of ``output_file``.

    The suffix is inserted before the extension, so the result always differs
    from ``output_file`` itself, whatever its extension (or lack of one).
    """
    root, ext = os.path.splitext(output_file)
    return f"{root}_intermediate{ext}"


def evaluate_all_checkpoints(args) -> None:
    """Evaluate all checkpoints in the specified folders.

    For every folder, each ``checkpoint*`` sub-directory is loaded into a
    fresh ``MATH500Evaluator``, evaluated, and released again. Results (or the
    error message of a failed checkpoint) are collected per folder name and
    written as JSON to ``args.output_file``. Partial results are saved to a
    sibling ``*_intermediate*`` file after each successful checkpoint; that
    file is removed once the final file has been written.

    Args:
        args: Namespace providing ``num_samples``, ``prompt_type``,
            ``temperature``, ``seed``, ``split``, ``instruct``,
            ``output_file`` and, when checkpoints are found, ``tp``,
            ``data_dir`` and ``start_idx``. An optional ``base_folders`` list
            names the folders to scan; without it the built-in placeholder
            list is used.
    """
    # Folders to scan. Fall back to the original hard-coded placeholder when
    # the caller did not supply any; duplicates are dropped, order is kept.
    base_folders = list(dict.fromkeys(getattr(args, "base_folders", None) or [""]))

    # Results storage
    all_results = {
        "evaluation_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "evaluation_params": {
            "num_samples": args.num_samples,
            "prompt_type": args.prompt_type,
            "temperature": args.temperature,
            "seed": args.seed,
            "split": args.split,
            "instruct": args.instruct
        },
        "results": {}
    }

    intermediate_file = _intermediate_path(args.output_file)

    # Scan each existing folder once; the listing is reused for both the
    # progress total and the evaluation loop.
    checkpoints_by_folder = {
        folder: find_checkpoints(folder)
        for folder in base_folders
        if os.path.exists(folder)
    }
    total_checkpoints = sum(len(found) for found in checkpoints_by_folder.values())
    evaluated_checkpoints = 0

    print(f"\nTotal checkpoints to evaluate: {total_checkpoints}")

    # Evaluate each folder
    for folder in base_folders:
        if folder not in checkpoints_by_folder:
            print(f"\nWarning: Folder not found: {folder}")
            continue

        # normpath drops a trailing separator, which would otherwise make
        # basename() return an empty folder name.
        folder_name = os.path.basename(os.path.normpath(folder))
        print(f"\n{'='*80}")
        print(f"Processing folder: {folder_name}")
        print(f"{'='*80}")

        checkpoints = checkpoints_by_folder[folder]
        print(f"Found {len(checkpoints)} checkpoints in {folder_name}")

        folder_results = all_results["results"].setdefault(folder_name, {})

        for checkpoint_path in checkpoints:
            checkpoint_name = os.path.basename(checkpoint_path)
            evaluated_checkpoints += 1

            print(f"\n[{evaluated_checkpoints}/{total_checkpoints}] Evaluating: {folder_name}/{checkpoint_name}")
            print("-" * 60)

            evaluator = None
            try:
                # Initialize evaluator for this checkpoint
                evaluator = MATH500Evaluator(
                    model_name=checkpoint_path,
                    tensor_parallel_size=args.tp,
                    prompt_type=args.prompt_type,
                    temperature=args.temperature,
                    is_instruct=args.instruct,
                )

                # Run evaluation (per-sample records are not kept here)
                results, _ = evaluator.evaluate_dataset(
                    data_dir=args.data_dir,
                    split=args.split,
                    num_samples=args.num_samples,
                    seed=args.seed,
                    start_idx=args.start_idx,
                )

                # Store results
                folder_results[checkpoint_name] = {
                    "accuracy": results['acc'],
                    "correct_count": results.get('count', 0),
                    "total_samples": results['num_samples'],
                    "time_used": results['time_use_in_minute'],
                    "type_accuracy": results.get('type_acc', {}),
                    "checkpoint_path": checkpoint_path
                }

                print(f"✓ Accuracy: {results['acc']:.2f}%")

                # Save intermediate results after each checkpoint
                with open(intermediate_file, "w") as f:
                    json.dump(all_results, f, indent=2)

            except Exception as e:
                print(f"✗ Error evaluating {checkpoint_name}: {str(e)}")
                folder_results[checkpoint_name] = {
                    "error": str(e),
                    "checkpoint_path": checkpoint_path
                }
            finally:
                # Release the engine on failure as well, so a broken
                # checkpoint does not keep GPU memory from the next one.
                if evaluator is not None:
                    evaluator.cleanup()
                    evaluator = None
                torch.cuda.empty_cache()

    # Save final results
    with open(args.output_file, "w") as f:
        json.dump(all_results, f, indent=2)

    # Print summary
    print(f"\n{'='*80}")
    print("EVALUATION SUMMARY")
    print(f"{'='*80}")

    for folder_name, folder_results in all_results["results"].items():
        print(f"\n{folder_name}:")
        for checkpoint_name, result in folder_results.items():
            if "accuracy" in result:
                print(f"  {checkpoint_name}: {result['accuracy']:.2f}%")
            else:
                print(f"  {checkpoint_name}: ERROR - {result.get('error', 'Unknown error')}")

    print(f"\nAll results saved to: {args.output_file}")

    # Clean up intermediate file (never the same path as the final file)
    if os.path.exists(intermediate_file):
        os.remove(intermediate_file)


# --------------------------------------------------------------------------- #
#                                    CLI                                      #
# --------------------------------------------------------------------------- #


def main() -> None:
    """Parse command-line options and run the batch checkpoint evaluation."""
    parser = argparse.ArgumentParser(description="Batch MATH dataset evaluation with vLLM")
    parser.add_argument("--data_dir", default="./data", type=str, help="Data directory")
    parser.add_argument("--tp", type=int, default=1, help="#GPUs for vLLM sharding")
    parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to evaluate per checkpoint")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--start_idx", type=int, default=0, help="Starting index")
    parser.add_argument("--split", default="test", type=str, help="Dataset split")
    parser.add_argument("--prompt_type", default="cot", type=str, help="Prompt type")
    parser.add_argument("--temperature", default=0.0, type=float, help="Sampling temperature")
    parser.add_argument(
        "--output_file",
        default="all_checkpoints_evaluation_results.json",
        type=str,
        help="Output file for all results"
    )
    parser.add_argument(
        "--base_folders",
        nargs="+",
        default=None,
        help="Folders whose checkpoint* sub-directories are evaluated",
    )
    # --instruct stays accepted (and is the default); --no_instruct is the
    # only way to actually switch the chat / system-prompt path off.
    parser.add_argument(
        "--instruct",
        action="store_true",
        default=True,
        help="set if model was trained with chat / system prompt (XML format)",
    )
    parser.add_argument(
        "--no_instruct",
        dest="instruct",
        action="store_false",
        help="evaluate as a base model: no chat template, no system prompt",
    )

    args = parser.parse_args()

    print("=" * 80)
    print("BATCH MATH DATASET EVALUATION")
    print(f"Prompt type:     {args.prompt_type}")
    print(f"Temperature:     {args.temperature}")
    print(f"Tensor parallel: {args.tp}")
    print(f"Samples per checkpoint: {args.num_samples}")
    print(f"Output file:     {args.output_file}")
    print("=" * 80)

    # Run batch evaluation
    evaluate_all_checkpoints(args)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()