#!/usr/bin/env python3

import re
import json
import torch
import random
import argparse
import warnings
import time
from typing import List, Dict, Tuple, Optional

from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Import existing utilities from the MATH evaluation codebase
from data_loader import load_data
from parser import *
from utils import construct_prompt, set_seed
from trajectory import *
from evaluate import evaluate
from python_executor import PythonExecutor
import os
# Set at import time, before any tokenizer is created further down.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# parallelism/fork warning — confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Suppress every Python warning for the rest of the process. This keeps the
# evaluation log short but also hides deprecation notices from dependencies.
warnings.filterwarnings("ignore")

# --------------------------------------------------------------------------- #
#                                 CONSTANTS                                   #
# --------------------------------------------------------------------------- #

# System prompt stored on MATH500Evaluator as ``system_prompt``. It asks the
# model for a <reasoning> block followed by an <answer> block that holds only
# the final (boxed) answer. ``.strip()`` removes the leading/trailing newlines
# introduced by the triple-quoted literal.
XML_SYSTEM_PROMPT = """
Respond in the following format, with only the final answer between the <answer> tags and always put your answer in boxed:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# --------------------------------------------------------------------------- #
#                               EVALUATOR                                     #
# --------------------------------------------------------------------------- #


class MATH500Evaluator:
    """Evaluator for MATH dataset using vLLM."""

    def __init__(
        self,
        model_name: str,
        tensor_parallel_size: int = 1,
        prompt_type: str = "cot",
        temperature: float = 0.0,
        is_instruct: bool = True,
    ):
        """Load the tokenizer and the vLLM engine and store the run settings.

        Args:
            model_name: Hugging Face repo id or local path; passed to both
                ``AutoTokenizer.from_pretrained`` and ``vllm.LLM``.
            tensor_parallel_size: Forwarded to ``vllm.LLM``.
            prompt_type: Selects the stop words in ``generate_answers`` and is
                forwarded to ``run_execute`` and ``evaluate``.
            temperature: Sampling temperature used for generation.
            is_instruct: Whether prompts are wrapped with the chat template
                and the system prompt.
        """
        print(f"Loading model with vLLM: {model_name}  (TP={tensor_parallel_size})")
        self.model_name = model_name
        self.prompt_type = prompt_type
        self.temperature = temperature
        self.is_instruct = is_instruct

        # ------ Tokenizer --------------------------------------------------- #
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, padding_side="left"
        )
        # Fall back to EOS when the tokenizer ships without a pad token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # ------ vLLM runtime ------------------------------------------------- #
        # Keep the engine consistent with the tokenizer above: a model that
        # needs remote code for its tokenizer needs it for the engine as well.
        self.llm = LLM(
            model_name,
            tensor_parallel_size=tensor_parallel_size,
            trust_remote_code=True,
        )

        # ------ Python executor for answer extraction ----------------------- #
        self.executor = PythonExecutor(get_answer_from_stdout=True)

        # ------ System prompt ----------------------------------------------- #
        self.system_prompt = XML_SYSTEM_PROMPT

    # ------------------------------------------------------------------ #
    #                  Helper / extraction utilities                     #
    # ------------------------------------------------------------------ #

    def extract_xml_answer(self, text: str) -> Optional[str]:
        """Return the stripped body of the first ``<answer>...</answer>`` pair.

        ``None`` is returned when ``text`` contains no complete pair.
        """
        # DOTALL lets the lazy group span line breaks between the tags.
        answer_pattern = r"<answer>[\s\n]*(.*?)[\s\n]*</answer>"
        found = re.search(answer_pattern, text, flags=re.DOTALL)
        return None if found is None else found.group(1).strip()

    def prepare_code_for_extraction(self, code: str) -> str:
        r"""Normalise a model response before it is handed to answer extraction.

        When the response contains a non-empty ``<answer>...</answer>`` block,
        only that answer is kept: inline-math delimiters ``\(`` / ``\)`` are
        removed, surrounding whitespace is trimmed, one enclosing
        ``\boxed{...}`` is unwrapped, and the result is returned as
        ``"The answer is \boxed{<answer>}"``.

        Otherwise the full response is returned with inline-math delimiters
        removed and no wrapper added.

        Args:
            code: Raw text generated by the model.

        Returns:
            The normalised text.
        """
        inline_math = r'\\\((.*?)\\\)'

        xml_answer = self.extract_xml_answer(code)
        if xml_answer:
            # DOTALL keeps this branch consistent with the non-XML branch so
            # inline math that spans several lines is unwrapped as well.
            cleaned_answer = re.sub(inline_math, r'\1', xml_answer, flags=re.DOTALL)
            # Removing "\( ... \)" can leave padding behind; without this strip
            # "\( \boxed{5} \)" would miss the anchored unwrap below and end up
            # as a nested "\boxed{ \boxed{5} }".
            cleaned_answer = cleaned_answer.strip()
            cleaned_answer = re.sub(
                r'^\\boxed\{(.*)\}$', r'\1', cleaned_answer, flags=re.DOTALL
            )
            return f"The answer is \\boxed{{{cleaned_answer}}}"

        # For non-XML answers, strip inline math delimiters but don't add wrapper
        return re.sub(inline_math, r'\1', code, flags=re.DOTALL)

    # ------------------------------------------------------------------ #
    #                    vLLM-based generation                           #
    # ------------------------------------------------------------------ #

    def generate_answers(self, prompts: List[str], max_new_tokens: int = 8192) -> List[str]:
        """Generate one completion per prompt with vLLM.

        For instruct models every prompt is wrapped in a system + user
        conversation and rendered with the tokenizer's chat template before
        generation. Completions that contain ``</answer>`` are cut right after
        the first closing tag.

        Args:
            prompts: Question texts (one per example).
            max_new_tokens: Upper bound on generated tokens per completion.

        Returns:
            The stripped completions, in the same order as ``prompts``.
        """
        if not prompts:
            return []

        # Build a new list of model inputs: rebinding a loop variable would
        # leave ``prompts`` untouched and the chat template would never reach
        # the model.
        if self.is_instruct:
            model_inputs = [
                self.tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": self.system_prompt},
                        {"role": "user", "content": prompt},
                    ],
                    tokenize=False,
                    add_generation_prompt=True,
                )
                for prompt in prompts
            ]
        else:
            model_inputs = list(prompts)

        # Set up stop words based on prompt type
        stop_words = ["</s>"]
        if self.prompt_type in ['cot']:
            stop_words.extend(["\n\nQuestion:", "\n\nProblem:"])
        elif self.prompt_type in ['pal', 'tool-integrated', 'tora']:
            stop_words.extend(["\n\n---", "```output"])

        params = SamplingParams(
            temperature=self.temperature,
            # Nucleus sampling is only enabled when actually sampling.
            top_p=1.0 if self.temperature == 0 else 0.95,
            max_tokens=max_new_tokens,
            stop=stop_words,
        )

        outputs = self.llm.generate(model_inputs, params)
        answers = [output.outputs[0].text.strip() for output in outputs]

        # For XML format, keep everything up to and including the first
        # </answer>; the truncated strings are collected into the result list.
        closing_tag = "</answer>"
        return [
            answer.split(closing_tag, 1)[0] + closing_tag
            if closing_tag in answer
            else answer
            for answer in answers
        ]

    # ------------------------------------------------------------------ #
    #                       Evaluation logic                             #
    # ------------------------------------------------------------------ #

    def evaluate_sample(
        self, examples: List[dict], data_name: str = "math"
    ) -> List[dict]:
        """Generate, extract and package predictions for a batch of examples.

        Args:
            examples: Dataset records. Each must provide ``'idx'`` plus
                whatever ``parse_question`` / ``parse_ground_truth`` read for
                ``data_name``.
            data_name: Dataset identifier forwarded to the parsing helpers and
                to ``run_execute``.

        Returns:
            One dict per example, in input order, with the keys ``idx``,
            ``question``, ``gt_cot``, ``gt``, ``code``, ``pred`` and ``report``
            (the last three are single-element lists), plus any optional
            metadata keys that are present on the example.

        Raises:
            NotImplementedError: If the evaluator was created with
                ``is_instruct=False``; prompt construction for non-instruct
                models is not implemented here.
        """
        # Fail before any parsing or generation. An ``assert`` would be
        # stripped under ``python -O`` and leave the prompts undefined.
        if not self.is_instruct:
            raise NotImplementedError(
                "Non-instruct models should not be used with this function"
            )

        # Parse question and ground truth (one parse_ground_truth call each).
        questions = [parse_question(example, data_name) for example in examples]
        gt_cot, gt_ans = [], []
        for example in examples:
            ground_truth = parse_ground_truth(example, data_name)
            gt_cot.append(ground_truth[0])
            gt_ans.append(ground_truth[1])

        # For instruct models the bare questions are the prompts.
        codes = self.generate_answers(questions)

        # Prepare code for extraction (handle XML tags)
        prepared_codes = [self.prepare_code_for_extraction(code) for code in codes]

        # Extract prediction using the existing logic
        preds, reports = [], []
        for prepared_code in prepared_codes:
            pred, report = run_execute(
                self.executor, prepared_code, self.prompt_type, data_name
            )
            preds.append(pred)
            reports.append(report)

        optional_keys = (
            'level', 'type', 'unit', 'solution_type', 'choices', 'solution',
            'ques_type', 'ans_type', 'answer_type', 'dataset', 'subfield',
            'filed', 'theorem', 'answer',
        )

        samples = []
        for example, question, cot, ans, prepared_code, pred, report in zip(
            examples, questions, gt_cot, gt_ans, prepared_codes, preds, reports
        ):
            sample = {
                'idx': example['idx'],
                'question': question,
                'gt_cot': cot,
                'gt': ans,
                # Stores the prepared text, not the raw generation.
                'code': [prepared_code],
                'pred': [pred],
                'report': [report],
            }
            # Add additional fields if present
            for key in optional_keys:
                if key in example:
                    sample[key] = example[key]
            samples.append(sample)

        return samples

    def evaluate_dataset(
        self,
        data_dir: str = "./data",
        split: str = "test",
        num_samples: int = 50,
        seed: int = 42,
        start_idx: int = 0,
    ) -> Tuple[Dict, List[dict]]:
        """Evaluate the model on a subset of the MATH dataset.

        Subset selection:
            * ``start_idx > 0``: the contiguous slice of ``num_samples``
              examples starting at ``start_idx`` (everything from
              ``start_idx`` on when ``num_samples`` is not positive).
            * otherwise, when ``0 < num_samples < len(dataset)``: the data is
              shuffled with the seeded ``random`` module and the first
              ``num_samples`` examples are kept.
            * otherwise the full split is used.

        Args:
            data_dir: Directory passed to ``load_data``.
            split: Dataset split passed to ``load_data``.
            num_samples: Number of examples to evaluate; see above.
            seed: Seed passed to ``set_seed`` before sampling.
            start_idx: First index of the slice; ``0`` selects random sampling.

        Returns:
            ``(result_json, all_samples)``: the metrics dict returned by
            ``evaluate`` extended with ``time_use_in_second``,
            ``time_use_in_minute`` and ``num_samples``, and the per-sample
            records returned by ``evaluate``.
        """
        set_seed(seed)

        # Load MATH dataset
        print("Loading MATH dataset...")
        examples = load_data("math", split, data_dir)

        # Sample or slice examples
        if start_idx > 0:
            # A non-positive num_samples means "no limit", matching the
            # sampling branch below, instead of producing an empty slice.
            end_idx = start_idx + num_samples if num_samples > 0 else None
            examples = examples[start_idx:end_idx]
        elif 0 < num_samples < len(examples):
            random.shuffle(examples)
            examples = examples[:num_samples]

        print(f"Evaluating on {len(examples)} samples...")

        # Only generation and extraction are timed, not the scoring below.
        start_time = time.time()
        all_samples = self.evaluate_sample(examples, "math")
        time_use = time.time() - start_time

        # Evaluate using the existing evaluation function
        all_samples, result_json = evaluate(
            samples=all_samples,
            data_name="math",
            prompt_type=self.prompt_type,
            execute=True
        )

        # Add timing information
        result_json['time_use_in_second'] = time_use
        result_json['time_use_in_minute'] = f"{int(time_use // 60)}:{int(time_use % 60):02d}"
        result_json['num_samples'] = len(all_samples)

        return result_json, all_samples


# --------------------------------------------------------------------------- #
#                                    CLI                                      #
# --------------------------------------------------------------------------- #


def main() -> None:
    """Command-line entry point: evaluate a model on MATH and save the results.

    Parses the CLI arguments, runs ``MATH500Evaluator.evaluate_dataset``,
    prints a summary, and writes two JSON files: the metrics to
    ``--output_file`` and the per-sample records next to it with a
    ``_rlpr_sum75`` suffix before the extension.
    """
    parser = argparse.ArgumentParser(description="MATH dataset evaluation with vLLM")
    parser.add_argument(
        "--model",
        required=True,
        help="HF repo or local path to model",
    )
    parser.add_argument("--data_dir", default="./data", type=str, help="Data directory")
    parser.add_argument("--tp", type=int, default=1, help="#GPUs for vLLM sharding")
    parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to evaluate")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--start_idx", type=int, default=0, help="Starting index")
    parser.add_argument("--split", default="test", type=str, help="Dataset split")
    parser.add_argument("--prompt_type", default="cot", type=str, help="Prompt type")
    parser.add_argument("--temperature", default=0.0, type=float, help="Sampling temperature")
    # Required so a missing path is reported immediately instead of crashing
    # in open() after the whole evaluation has already run.
    parser.add_argument("--output_file", type=str, required=True, help="Output file")
    # NOTE: store_true combined with default=True means this value is always
    # True; the flag is accepted for compatibility but cannot be switched off.
    parser.add_argument(
        "--instruct",
        action="store_true",
        default=True,
        help="set if model was trained with chat / system prompt (XML format)",
    )

    args = parser.parse_args()

    # Create the output directory up front so a bad path fails before the
    # model is loaded rather than after generation.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print("=" * 60)
    print("MATH Dataset Evaluation with vLLM")
    print(f"Model:           {args.model}")
    print(f"Instruct-tuned:  {args.instruct}")
    print(f"Prompt type:     {args.prompt_type}")
    print(f"Temperature:     {args.temperature}")
    print(f"Tensor parallel: {args.tp}")
    print(f"Samples:         {args.num_samples}")
    print("=" * 60)

    # Initialize evaluator
    evaluator = MATH500Evaluator(
        model_name=args.model,
        tensor_parallel_size=args.tp,
        prompt_type=args.prompt_type,
        temperature=args.temperature,
        is_instruct=args.instruct,
    )

    # Run evaluation
    results, samples = evaluator.evaluate_dataset(
        data_dir=args.data_dir,
        split=args.split,
        num_samples=args.num_samples,
        seed=args.seed,
        start_idx=args.start_idx,
    )

    # ---------------------- summary ----------------------------------- #
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    print("=" * 60)
    print(f"Total samples:   {results['num_samples']}")
    print(f"Accuracy:        {results['acc']:.2f}%")
    if 'count' in results:
        print(f"Correct:         {results['count']}")
    print(f"Time used:       {results['time_use_in_minute']}")

    # Print per-type accuracy if available
    if 'type_acc' in results:
        print("\nPer-type accuracy:")
        for type_name, acc in results['type_acc'].items():
            print(f"  {type_name}: {acc:.2f}%")

    # ---------------------- dump JSON --------------------------------- #
    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nDetailed results saved to: {args.output_file}")

    # Derive the samples path from the extension. A plain
    # str.replace(".json", ...) returns the same path when the output file
    # has no ".json" in it, which would overwrite the results just written.
    root, ext = os.path.splitext(args.output_file)
    samples_file = f"{root}_rlpr_sum75{ext or '.json'}"
    with open(samples_file, "w") as f:
        json.dump(samples, f, indent=2)
    print(f"Sample outputs saved to: {samples_file}")

    print("\n" + "=" * 60)
    print(f"FINAL ACCURACY: {results['acc']:.2f}%")
    print("=" * 60)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()