#!/usr/bin/env python3

import logging
import re
import signal
import json
import torch
import random
import argparse
import warnings
from importlib.metadata import version
from typing import Dict, List, Optional, Tuple
from time import sleep
from tqdm import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
import datasets
import gc
import time
import json
from types import SimpleNamespace

from math500_grader import grade_answer

warnings.filterwarnings("ignore")

eval_logger = logging.getLogger(__name__)

# Hard dependencies of the answer checkers below: `is_equiv` uses sympy's
# LaTeX parser and `MATH500Evaluator.evaluate_sample` uses math_verify.
# Fail fast at import time with an actionable message instead of erroring
# deep inside an evaluation run.
try:
    import antlr4  # noqa: F401  -- presence check only; sympy's LaTeX parser needs it
    import sympy
    from math_verify import parse, verify
    from sympy.parsing.latex import parse_latex

    # sympy's generated LaTeX parser is tied to a specific ANTLR runtime; this
    # script expects 4.11.x. Explicit raise (not `assert`) so the check also
    # runs under `python -O`.
    if not version("antlr4-python3-runtime").startswith("4.11"):
        raise AssertionError("antlr4-python3-runtime 4.11.x is required")
except (ModuleNotFoundError, AssertionError) as e:
    raise type(e)(
        "`sympy`, `math_verify` and `antlr4-python3-runtime==4.11` are required for checking answer equivalence. "
        "Please install them, e.g. via pip install sympy math-verify antlr4-python3-runtime==4.11"
    ) from e

# --------------------------------------------------------------------------- #
#                                 CONSTANTS                                   #
# --------------------------------------------------------------------------- #

# System prompts for instruct/chat models. MATH500Evaluator.__init__ picks one
# of these from its `llama_system_prompt` / `boxed_system_prompt` flags (plain
# XML_SYSTEM_PROMPT is the fallback). All of them ask the model to put its final
# answer inside <answer>...</answer>, which `_extract_xml_answer` parses.

# Default prompt: numerical answer or formula inside <answer> tags.
XML_SYSTEM_PROMPT = """
Respond in the following format, with only the numerical answer or formula between the <answer> tags:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# Like XML_SYSTEM_PROMPT, but also asks for the answer to be wrapped in \boxed{}
# (the extractor unwraps a box found inside the <answer> block).
XML_SYSTEM_PROMPT_BOXED = """
Always respond in the following format, with only the final answer between the <answer> tags and always put your answer in boxed:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# Llama-flavoured variant with an explicit "helpful assistant" persona.
XML_SYSTEM_PROMPT_LLAMA = """
You are a helpful assistant that solves math problems step by step. Always show your work clearly and provide your final numerical answer at the end. Always respond in the following format, with only the final answer between the <answer> tags:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# Llama-flavoured variant that additionally asks for a boxed answer.
XML_SYSTEM_PROMPT_LLAMA_BOXED = """
You are a helpful assistant that solves math problems step by step. Always show your work clearly and provide your final numerical answer at the end. Always respond in the following format, with only the final answer between the <answer> tags and always put your answer in boxed:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
""".strip()

# --------------------------------------------------------------------------- #
#                               EVALUATOR                                     #
# --------------------------------------------------------------------------- #

class MATH500Evaluator:
    """Generate answers to math problems with vLLM and grade them.

    Pipeline: ``create_prompt`` formats each problem as a chat prompt,
    ``generate_answers`` samples completions for every configured temperature,
    ``extract_answer_from_text`` pulls a final answer out of each completion and
    ``evaluate_sample`` grades it against the reference answer with
    ``is_equiv``, ``math_verify`` and ``math500_grader.grade_answer``.
    ``evaluate_dataset`` runs the whole pipeline and aggregates the metrics.
    """

    # Signed integer or decimal number, e.g. "-3" or "2.50".
    _NUMBER_RE = re.compile(r"([+-]?\d+(?:\.\d+)?)")

    def __init__(
        self,
        model_name: str = "",
        model=None,
        tokenizer=None,
        tensor_parallel_size: int = 1,
        is_instruct: bool = False,
        temperatures: Optional[List[float]] = None,
        num_generations: int = 1,
        seed: int = 42,
        boxed_system_prompt: bool = False,
        llama_system_prompt: bool = False,
    ):
        """Load (or wrap) the tokenizer and vLLM engine and pick a system prompt.

        Args:
            model_name: HF repo id or local path, used to load the tokenizer
                and/or engine when ``tokenizer`` / ``model`` are not given.
            model: an already constructed vLLM engine to reuse.
            tokenizer: an already constructed tokenizer to reuse.
            tensor_parallel_size: number of GPUs vLLM shards the model over.
            is_instruct: format prompts with the chat template (the only mode
                ``create_prompt`` supports).
            temperatures: sampling temperatures to run; ``None`` means
                ``[0.0]`` (greedy decoding).
            num_generations: completions per problem and temperature.
            seed: seed passed to the vLLM engine.
            boxed_system_prompt: ask the model to also box its answer
                (instruct mode only).
            llama_system_prompt: use the Llama-flavoured system prompt
                (instruct mode only).

        If both ``model_name`` and ``model`` are empty, only ``is_instruct``
        and an empty ``system_prompt`` are set: such an instance can extract
        and grade answers but cannot generate.
        """
        if model_name == "" and model is None:
            self.is_instruct = is_instruct
            self.system_prompt = ""
            return

        print(f"Loading model with vLLM: {model_name}  (TP={tensor_parallel_size})")
        self.model_name = model_name
        self.is_instruct = is_instruct

        # ------ Tokenizer --------------------------------------------------- #
        if tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name, trust_remote_code=True, padding_side="left"
            )
        else:
            self.tokenizer = tokenizer

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # ------ vLLM runtime ------------------------------------------------- #
        if model is None:
            self.llm = LLM(
                model_name,
                tensor_parallel_size=tensor_parallel_size,
                gpu_memory_utilization=0.6,
                seed=seed,
                enforce_eager=True,
            )
        else:
            self.llm = model

        # ------ Prompt template --------------------------------------------- #
        if is_instruct and llama_system_prompt and boxed_system_prompt:
            self.system_prompt = XML_SYSTEM_PROMPT_LLAMA_BOXED
        elif is_instruct and llama_system_prompt:
            self.system_prompt = XML_SYSTEM_PROMPT_LLAMA
        elif is_instruct and boxed_system_prompt:
            self.system_prompt = XML_SYSTEM_PROMPT_BOXED
        else:
            self.system_prompt = XML_SYSTEM_PROMPT

        # Copy so neither a shared default nor later caller mutation leaks in.
        self.temperatures = [0.0] if temperatures is None else list(temperatures)
        self.num_generations = num_generations

        # How many more formatted prompts to echo to stdout (debug aid).
        self.num_message_examples = 1

    # ------------------------------------------------------------------ #
    #                  Helper / extraction utilities                     #
    # ------------------------------------------------------------------ #

    # 1) Primary path: XML tags
    @staticmethod
    def _extract_xml_answer(text: str) -> Optional[str]:
        """Return the stripped body of the first ``<answer>...</answer>`` block
        whose body contains no ``<``, or None if there is no such block."""
        m = re.search(r"<answer>\s*([^<]+)\s*</answer>", text, flags=re.DOTALL)
        return m.group(1).strip() if m else None

    @staticmethod
    def _strip_surrounding_words(chunk: str) -> str:
        """Drop leading/trailing runs of ASCII letters, whitespace and ``.:,!?``."""
        chunk = re.sub(r'^[a-zA-Z\s.:,!?]+', '', chunk)
        chunk = re.sub(r'[a-zA-Z\s.:,!?]+$', '', chunk)
        return chunk.strip()

    @staticmethod
    def _unbox(text: str) -> Optional[str]:
        r"""Return the contents of the last ``\boxed`` expression in *text*, or None.

        ``remove_boxed`` signals "no usable box" by raising (AssertionError for
        a malformed/unsupported box, TypeError when no box was found), so those
        are mapped to None here.
        """
        try:
            return remove_boxed(last_boxed_only_string(text))
        except (AssertionError, AttributeError, IndexError, TypeError, ValueError):
            return None

    # 2) Fallbacks – keep the old heuristics around just in case
    def _legacy_extract_answer(self, text: str) -> Optional[str]:
        """Heuristically pull an answer out of free-form text (commas removed).

        Tried in order: "The answer is: N"; the number after the last "####"
        (instruct mode only); the first number, else the de-worded text, after
        the last occurrence of "answer"; the last number, else the de-worded
        text, on the last line. Numbers are returned as ``str(float(n))``.
        """
        text = text.replace(",", "")

        if m := re.search(r"[Tt]he answer is:?\s*([+-]?\d+(?:\.\d+)?)", text):
            return str(float(m.group(1)))

        if self.is_instruct and "####" in text:
            tail = text.split("####")[-1].strip()
            if m := self._NUMBER_RE.search(tail):
                return str(float(m.group(1)))

        # last-chunk heuristics
        parts = re.split(r"answer", text, flags=re.IGNORECASE)
        if len(parts) > 1:
            numbers = self._NUMBER_RE.findall(parts[-1])
            if numbers:
                return str(float(numbers[0]))
            cleaned = self._strip_surrounding_words(parts[-1])
            if cleaned:
                return cleaned

        lines = text.strip().splitlines()
        if lines:
            numbers = self._NUMBER_RE.findall(lines[-1])
            if numbers:
                return str(float(numbers[-1]))
            cleaned = self._strip_surrounding_words(lines[-1])
            if cleaned:
                return cleaned

        return None

    def extract_answer_from_text(self, text: str) -> Optional[str]:
        """Extract the final answer from a completion.

        Preference order: a box inside the ``<answer>`` block, the raw
        ``<answer>`` body, the last box anywhere in the text, then the legacy
        heuristics.
        """
        answer = self._extract_xml_answer(text)
        if answer is not None:
            boxed = self._unbox(answer)
            return boxed if boxed is not None else answer
        boxed = self._unbox(text)
        if boxed is not None:
            return boxed
        return self._legacy_extract_answer(text)

    # ------------------------------------------------------------------ #
    #                       Prompt construction                          #
    # ------------------------------------------------------------------ #

    def create_prompt(self, problem: str, add_intruction_prefix: bool = True) -> str:
        """Return the prompt string fed to vLLM for *problem*.

        The system prompt and the (optionally instruction-prefixed) problem are
        rendered through the tokenizer's chat template with a generation prompt
        appended. The first ``num_message_examples`` calls also print the
        messages for debugging.
        """
        assert self.is_instruct, "Only instruct mode is supported"
        if self.is_instruct:
            user_content = (
                f"Solve the following math problem step by step:\n\n{problem}"
                if add_intruction_prefix
                else problem
            )
            messages = [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_content},
            ]

            if self.num_message_examples > 0:
                print(f"Messages: {messages}")
                self.num_message_examples -= 1

            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        # Non-chat baseline; only reachable when asserts are disabled (python -O).
        return f"Solve the following math problem step by step:\n{problem}"

    # ------------------------------------------------------------------ #
    #                    vLLM-based generation                           #
    # ------------------------------------------------------------------ #

    def _trim_completion(self, answer: str) -> str:
        """Cut trailing noise from a raw completion and strip whitespace."""
        # Stop once we hit </answer> to avoid trailing noise.
        if self.is_instruct and "</answer>" in answer:
            answer = answer.split("</answer>")[0] + "</answer>"
        # For non-instruct fall back to previous trimming logic
        elif not self.is_instruct:
            for stop in ["\n\n", "\nProblem:", "Problem:", "Solve the following math problem step by step:"]:
                if stop in answer:
                    answer = answer.split(stop)[0]
                    break
        return answer.strip()

    def generate_answers(
        self,
        problems: List[str],
        add_intruction_prefix: bool = True,
        max_new_tokens: int = 2048,
        seed: int = 42,
        top_p: float = 0.95,
        top_k: int = 40,
    ) -> List[List[str]]:
        """Sample completions for every problem at every configured temperature.

        Returns one list per problem, grouped by temperature in the order of
        ``self.temperatures``. For a (near-)zero temperature vLLM is asked for a
        single completion, which is repeated ``num_generations`` times (at
        least once) so every temperature contributes the same number of entries.
        """
        prompts = [
            self.create_prompt(problem, add_intruction_prefix=add_intruction_prefix)
            for problem in problems
        ]
        answers: List[List[str]] = [[] for _ in problems]

        for temperature in self.temperatures:
            greedy = abs(temperature) < 1e-6
            params = SamplingParams(
                temperature=temperature,
                max_tokens=max_new_tokens,
                n=(1 if greedy else self.num_generations),
                seed=seed,
                top_p=top_p,
                top_k=top_k,
            )

            outputs = self.llm.generate(prompts, params)

            copies = max(self.num_generations, 1) if greedy else 1
            for i, request_output in enumerate(outputs):
                for completion in request_output.outputs:
                    answers[i].extend([self._trim_completion(completion.text)] * copies)

        return answers

    # ------------------------------------------------------------------ #
    #                       Evaluation logic                             #
    # ------------------------------------------------------------------ #

    def evaluate_sample(
        self, problem: str, gt_answer: str, gen_texts: List[str]
    ) -> Tuple[List[bool], str, List[Optional[str]], List[str], List[bool]]:
        """Grade every generation of one problem.

        Returns ``(exact_matches, gt_answer, preds, gen_texts,
        math_verifications)``. A prediction counts as an exact match if any of
        ``is_equiv`` (on the normalized prediction), math_verify or
        ``grade_answer`` (on the raw prediction) accepts it. ``problem`` is
        unused and kept for interface compatibility.
        """
        preds = [self.extract_answer_from_text(gen_text) for gen_text in gen_texts]

        exact_matches: List[bool] = []
        math_verifications: List[bool] = []

        for pred in preds:
            if pred is None or pred == "[invalidanswer]":
                exact_matches.append(False)
                math_verifications.append(False)
                continue

            normalized_pred = normalize_final_answer(pred)
            try:
                equivalent = bool(is_equiv(normalized_pred, gt_answer))
            except TimeoutError:
                # is_equiv may re-raise on a sympy timeout; treat it as "not
                # equivalent" instead of dropping the whole problem upstream.
                equivalent = False

            # Best-effort checkers: any failure simply means "not verified".
            try:
                verified = bool(verify(parse(gt_answer), parse(normalized_pred)))
            except Exception:
                verified = False
            try:
                graded = bool(grade_answer(pred, gt_answer))
            except Exception:
                graded = False

            math_verifications.append(verified)
            exact_matches.append(equivalent or verified or graded)

        return exact_matches, gt_answer, preds, gen_texts, math_verifications

    def evaluate_dataset(
        self,
        dataset,
        num_samples: Optional[int] = None,
        seed: int = 42,
        start_idx: int = 0,
        top_p: float = 0.95,
        top_k: int = 40,
        max_new_tokens: int = 2048,
        add_intruction_prefix: bool = True,
    ) -> Dict:
        """Generate and grade answers for (a subset of) *dataset*.

        With ``start_idx > 0`` a contiguous slice starting at ``start_idx`` is
        used; otherwise ``num_samples`` indices are drawn at random (seeded by
        ``seed``). Raw generations are dumped to
        ``generated_answers_<timestamp>.json`` in the working directory.

        Returns a dict of counters (``total``, ``exact_match_correct``,
        ``math_verify_correct``, ``no_answer``, ...), per-generation
        ``samples`` and the derived accuracy / answer-rate / pass@n metrics.
        Samples whose grading raises are reported and skipped.
        """
        random.seed(seed)
        torch.manual_seed(seed)

        if num_samples is None:
            num_samples = len(dataset)

        if start_idx > 0:
            indices = list(
                range(start_idx, min(start_idx + num_samples, len(dataset)))
            )
        else:
            indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

        results = {
            "exact_match_correct": 0,
            "math_verify_correct": 0,
            "total": 0,
            "no_answer": 0,
            "samples": [],
            "num_samples": 0,
            "num_any_exact_match": 0,
            "num_any_math_verify": 0,
        }

        problems = [dataset[idx]["problem"] for idx in indices]
        answers = (
            self.generate_answers(
                problems,
                add_intruction_prefix=add_intruction_prefix,
                max_new_tokens=max_new_tokens,
                seed=seed,
                top_p=top_p,
                top_k=top_k,
            )
            if problems
            else []
        )

        filename = f"generated_answers_{time.time()}.json"
        with open(filename, "w") as f:
            json.dump(answers, f, indent=4)
            print(f"Generated answers saved to {filename}")

        for idx, sample_answers in tqdm(zip(indices, answers), total=len(indices), desc="Evaluating"):
            sample = dataset[idx]
            try:
                exact_matches, gt, preds, gen_texts, math_verifications = self.evaluate_sample(
                    sample["problem"], sample["answer"], sample_answers
                )

                results["num_samples"] += 1
                if any(exact_matches):
                    results["num_any_exact_match"] += 1
                if any(math_verifications):
                    results["num_any_math_verify"] += 1

                for pred, exact, verified, gen_text in zip(preds, exact_matches, math_verifications, gen_texts):
                    results["total"] += 1
                    if pred == "[invalidanswer]" or pred is None:
                        results["no_answer"] += 1
                    elif exact:
                        results["exact_match_correct"] += 1
                    if verified:
                        results["math_verify_correct"] += 1

                    results["samples"].append(
                        {
                            "index": idx,
                            "problem": sample["problem"],
                            "ground_truth": gt,
                            "predicted": pred,
                            "exact_match_correct": exact,
                            "math_verify_correct": verified,
                            "generated_text": gen_text,
                        }
                    )

            except Exception as e:
                print(f"\nError processing sample {idx}: {e}")

        if results["total"] > 0:
            results["exact_match_accuracy"] = results["exact_match_correct"] / results["total"]
            results["math_verify_accuracy"] = results["math_verify_correct"] / results["total"]
            results["answer_rate"] = (
                (results["total"] - results["no_answer"]) / results["total"]
            )
            results["exact_match_pass_at_n"] = results["num_any_exact_match"] / results["num_samples"]
            results["math_verify_pass_at_n"] = results["num_any_math_verify"] / results["num_samples"]
        else:
            results["exact_match_accuracy"] = results["math_verify_accuracy"] = 0.0
            results["answer_rate"] = 0.0
            results["exact_match_pass_at_n"] = results["math_verify_pass_at_n"] = 0.0

        return results


def evaluate_model_on_datasets(model_name, datasets, tp, num_samples_by_dataset, seed, start_idx, instruct, temperatures, num_generations, top_p, top_k, boxed_system_prompt, llama_system_prompt, max_new_tokens):
    """Load *model_name* once with vLLM and evaluate it on every dataset.

    Args:
        model_name: HF repo id or local path passed to ``MATH500Evaluator``.
        datasets: mapping of dataset name -> indexable samples, each with at
            least ``problem`` and ``answer`` keys. (Shadows the ``datasets``
            module inside this function only.)
        tp: tensor-parallel size for vLLM.
        num_samples_by_dataset: dataset name -> number of samples to evaluate.
        seed, start_idx, top_p, top_k, max_new_tokens: forwarded to
            ``MATH500Evaluator.evaluate_dataset``.
        instruct, temperatures, num_generations, boxed_system_prompt,
        llama_system_prompt: forwarded to ``MATH500Evaluator``.

    Returns:
        dict mapping dataset name -> result dict from ``evaluate_dataset``.

    The evaluator is released in a ``finally`` so that a failed attempt (which
    ``main`` retries) does not keep the engine referenced from this frame.
    """
    evaluator = MATH500Evaluator(
        model_name=model_name,
        tensor_parallel_size=tp,
        is_instruct=instruct,
        temperatures=temperatures,
        num_generations=num_generations,
        seed=seed,
        boxed_system_prompt=boxed_system_prompt,
        llama_system_prompt=llama_system_prompt,
    )
    results = {}
    try:
        for dataset_name, dataset in datasets.items():
            print(f"\nEvaluating {model_name} on {dataset_name} with {num_samples_by_dataset[dataset_name]} samples …")
            # File-based datasets carry fully formed prompts: no instruction prefix.
            add_intruction_prefix = not dataset_name.endswith((".json", ".jsonl"))
            print(f"Adding instruction prefix: {add_intruction_prefix}")
            results[dataset_name] = evaluator.evaluate_dataset(dataset, num_samples=num_samples_by_dataset[dataset_name], seed=seed, start_idx=start_idx, top_p=top_p, top_k=top_k, max_new_tokens=max_new_tokens, add_intruction_prefix=add_intruction_prefix)
            print(f"Results for {model_name} on {dataset_name}:")
            print(f"Total samples:   {results[dataset_name]['total']}")
            print(f"Exact match correct: {results[dataset_name]['exact_match_correct']}")
            print(f"Math verify correct: {results[dataset_name]['math_verify_correct']}")
            print(f"No answer:       {results[dataset_name]['no_answer']}")
            print(f"Exact match accuracy: {results[dataset_name]['exact_match_accuracy']*100:.2f}%")
            print(f"Answer rate:     {results[dataset_name]['answer_rate']*100:.2f}%")
            print(f"Exact match pass @ {num_generations * len(temperatures)}: {results[dataset_name]['exact_match_pass_at_n']*100:.2f}%")
    finally:
        # Drop our reference and try to return cached GPU memory even when
        # evaluation raised, so a retry in main() starts from a clean state.
        del evaluator
        gc.collect()
        torch.cuda.empty_cache()

    return results

# --------------------------------------------------------------------------- #
#                               DATASET LOADING                              #
# --------------------------------------------------------------------------- #

def load_math500_dataset():
    """Download the MATH-500 test split and reshape it for zero-shot evaluation.

    Each returned dict carries ``problem``, ``solution``, ``answer``,
    ``subject``, ``level`` and ``unique_id``. ``answer`` is rebuilt from the
    last boxed expression of the reference solution and passed through
    ``normalize_final_answer``.
    """
    print("Loading MATH-500 dataset from Hugging Face...")
    hf_split = datasets.load_dataset("HuggingFaceH4/MATH-500", split="test")

    records = [
        {
            "problem": row["problem"],
            "solution": row["solution"],
            "answer": normalize_final_answer(
                remove_boxed(last_boxed_only_string(row["solution"]))
            ),
            "subject": row["subject"],
            "level": row["level"],
            "unique_id": row["unique_id"],
        }
        for row in hf_split
    ]

    print(f"Loaded {len(records)} samples from MATH-500 for zero-shot evaluation")
    return records

def load_amc_dataset():
    """Download the AMC-12 2024 problems and reshape them for zero-shot evaluation.

    Each returned dict carries ``problem``, ``answer``, ``exam``,
    ``problem_number`` and a synthetic ``unique_id`` of the form
    ``"<exam>_<problem_number>"``. Answers are kept exactly as stored.
    """
    print("Loading AMC-12 2024 dataset from Hugging Face...")
    hf_split = datasets.load_dataset("rawsh/2024_AMC12", split="train")

    records = [
        {
            "problem": row["problem"],
            "answer": row["answer"],
            "exam": row["exam"],
            "problem_number": row["problem_number"],
            "unique_id": f"{row['exam']}_{row['problem_number']}",
        }
        for row in hf_split
    ]

    print(f"Loaded {len(records)} samples from AMC-12 2024 for zero-shot evaluation")
    return records

def load_math_train_dataset():
    """Read the local Hendrycks MATH train split (JSON Lines) into a list of dicts.

    Reads ``datasets/hendrycks_math_train_all_with_answers.jsonl`` relative to
    the current working directory and returns the records unchanged
    (presumably each already has ``problem`` and ``answer`` keys — confirm
    against the file). Blank lines are skipped instead of raising
    ``json.JSONDecodeError``; the file is decoded as UTF-8 regardless of locale.
    """
    path = "datasets/hendrycks_math_train_all_with_answers.jsonl"
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def load_math_test_dataset():
    """Read the local Hendrycks MATH test split (JSON Lines) into a list of dicts.

    Reads ``datasets/hendrycks_math_test_all_with_answers.jsonl`` relative to
    the current working directory and returns the records unchanged
    (presumably each already has ``problem`` and ``answer`` keys — confirm
    against the file). Blank lines are skipped instead of raising
    ``json.JSONDecodeError``; the file is decoded as UTF-8 regardless of locale.
    """
    path = "datasets/hendrycks_math_test_all_with_answers.jsonl"
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

def load_general_dataset(dataset_name: str):
    """Load a local dataset file with ``prompt`` / ``final_answer`` records.

    *dataset_name* is a file path. The file may be JSON Lines (one object per
    line; blank lines are skipped) or a single JSON array of objects. Each
    record becomes ``{"problem": record["prompt"], "answer":
    str(record["final_answer"])}``.
    """
    with open(dataset_name, "r", encoding="utf-8") as f:
        text = f.read()

    if text.lstrip().startswith("["):
        # A whole-file JSON array (e.g. a pretty-printed ``.json`` file).
        records = json.loads(text)
    else:
        # Split on "\n" only (like file iteration), not str.splitlines(),
        # which would also break on characters such as U+2028 inside strings.
        records = [json.loads(line) for line in text.split("\n") if line.strip()]

    return [
        {"problem": record["prompt"], "answer": str(record["final_answer"])}
        for record in records
    ]

def load_dataset(dataset_name: str):
    """Load a dataset by (case-insensitive) alias, or treat the name as a file path.

    Aliases: ``math500``/``math-500``, ``amc``/``amc12``, ``math_train`` and
    ``math_test``. Anything else is handed, with its original casing, to
    ``load_general_dataset``.
    """
    loaders_by_alias = {
        "math500": load_math500_dataset,
        "math-500": load_math500_dataset,
        "amc": load_amc_dataset,
        "amc12": load_amc_dataset,
        "math_train": load_math_train_dataset,
        "math_test": load_math_test_dataset,
    }
    loader = loaders_by_alias.get(dataset_name.lower())
    if loader is None:
        return load_general_dataset(dataset_name)
    return loader()

# --------------------------------------------------------------------------- #
#                               UTILITY FUNCTIONS                            #
# --------------------------------------------------------------------------- #

# Zero-shot text formatting
def doc_to_text(doc: dict) -> str:
    """Build the plain-text zero-shot prompt for *doc* from its ``problem`` field."""
    instruction = "Solve the following math problem step by step:\n"
    return f"{instruction}{doc['problem']}"


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Map raw MATH rows to ``{problem, solution, answer}`` records.

    ``answer`` is the normalized content of the last boxed expression in the
    row's ``solution``.
    """
    def _to_eval_doc(row: dict) -> dict:
        return {
            "problem": row["problem"],
            "solution": row["solution"],
            "answer": normalize_final_answer(
                remove_boxed(last_boxed_only_string(row["solution"]))
            ),
        }

    return dataset.map(_to_eval_doc)


def last_boxed_only_string(string: str) -> Optional[str]:
    r"""Return the last boxed expression in *string*, including its wrapper.

    Looks for the last ``\boxed``; only if there is none, for the last
    ``\fbox``. A brace-less ``\boxed <x>`` is returned as ``"\boxed " + x``,
    where *x* runs up to the next ``$``. Otherwise the span up to the matching
    closing brace is returned, e.g. ``"\boxed{\frac{1}{2}}"``.

    Returns None when there is no box or its braces never balance.
    """
    idx = string.rfind("\\boxed")
    # Only use the brace-less form when the *last* \boxed is brace-less; an
    # earlier "\boxed 3" must not shadow a later "\boxed{5}".
    if idx >= 0 and string.startswith("\\boxed ", idx):
        return "\\boxed " + string[idx + len("\\boxed "):].split("$")[0]
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    depth = 0
    for pos in range(idx, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return string[idx : pos + 1]

    return None


def remove_boxed(s: str) -> str:
    r"""Strip the ``\boxed`` wrapper from a string returned by ``last_boxed_only_string``.

    ``"\boxed 5"`` -> ``"5"`` and ``"\boxed{5}"`` -> ``"5"``.

    Raises:
        TypeError: if *s* is not a string (e.g. None when no box was found).
        AssertionError: if *s* is neither ``\boxed <x>`` nor ``\boxed{...}``
            (callers catch AssertionError, so the type is kept).
    """
    if not isinstance(s, str):
        raise TypeError(f"remove_boxed expects a str, got {type(s).__name__}")

    space_form = "\\boxed "
    # Check the prefix, not mere containment: "\boxed{a \boxed b}" is a valid
    # brace-form box that happens to contain "\boxed ".
    if s.startswith(space_form):
        return s[len(space_form):]

    brace_form = "\\boxed{"
    # Explicit raise rather than `assert` so validation survives `python -O`.
    if not (s.startswith(brace_form) and s.endswith("}")):
        raise AssertionError(f"not a \\boxed expression: {s!r}")

    return s[len(brace_form):-1]


class timeout:
    """Context manager that raises ``TimeoutError`` if its body outlives *seconds*.

    Implemented with ``SIGALRM``/``signal.alarm``: Unix only, whole seconds
    only (0 disables the alarm), and ``signal.signal`` may only be called from
    the main thread. The previously installed SIGALRM handler is restored on
    exit.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        raise TimeoutError(self.error_message)

    def __enter__(self):
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        signal.alarm(0)
        # signal.signal returns None when the old handler was not installed
        # from Python; in that case there is nothing we can restore.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        self._previous_handler = None
        return False  # never swallow exceptions from the body


def is_equiv(x1: str, x2: str) -> bool:
    """
    x1 and x2 are normalized latex string

    Return True if the strings are equal after stripping, or if sympy parses
    both and simplifies their difference to 0 within 5 seconds. Any parse,
    subtraction, simplification or timeout failure yields False; only
    ImportError is logged and re-raised.
    """
    if x1.strip() == x2.strip():
        return True
    try:
        with timeout(seconds=5):
            try:
                parsed_x1 = parse_latex(x1)
                parsed_x2 = parse_latex(x2)
            except (
                sympy.parsing.latex.errors.LaTeXParsingError,
                sympy.SympifyError,
                TypeError,
            ):
                eval_logger.debug(f"couldn't parse one of {x1} or {x2}")
                return False

            try:
                diff = parsed_x1 - parsed_x2
            except TypeError:
                eval_logger.debug(f"couldn't subtract {x1} and {x2}")
                return False

            try:
                return bool(sympy.simplify(diff) == 0)
            except ValueError:
                eval_logger.debug(
                    f"Had some trouble simplifying when comparing {x1} and {x2}"
                )
                # Previously fell through and returned None; keep the bool contract.
                return False
    except TimeoutError:
        # A slow sympy comparison is "not shown equivalent". Re-raising here
        # made callers drop every generation of the affected problem.
        eval_logger.debug(f"Timed out comparing {x1} and {x2}")
        return False
    except ImportError as e:
        eval_logger.error(e)
        raise
    except Exception as e:
        eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}")
        return False


# (before, after) literal string replacements applied in order by
# normalize_final_answer, before REMOVED_EXPRESSIONS are stripped.
# NOTE(review): these are plain substring replacements, so e.g. ("an ", "")
# also fires inside words such as "tan " or "mean " — presumably inherited from
# the reference normalization; confirm before changing.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings deleted outright by normalize_final_answer (units, filler words,
# degree marks, spacing commands, ...), applied after SUBSTITUTIONS.
# NOTE(review): also plain substring matching, so short entries such as "ft"
# can hit inside LaTeX commands (e.g. "\left"). Entries written as "\\text{\ns}"
# and "\\text{\n}" contain a real newline character, not a backslash-n.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Based on appendix D of Lewkowycz et al. (2022). This version additionally
    unwraps a surrounding ``\\(...\\)`` and keeps only the part after the last
    standalone ``\\in`` / ``\\approx`` command.
    """

    if final_answer.startswith(r"\(") and final_answer.endswith(r"\)"):
        final_answer = final_answer[2:-2]

    # Keep only the right-hand side of the last "=".
    final_answer = final_answer.split("=")[-1]

    # Split only on the standalone commands: a plain substring split on "\in"
    # also cut "\infty"/"\int" (e.g. "(-\infty, 3)" became "fty, 3)"), which
    # mangled interval answers and could make different answers compare equal.
    final_answer = re.split(r"\\in(?![a-zA-Z])", final_answer)[-1]
    final_answer = re.split(r"\\approx(?![a-zA-Z])", final_answer)[-1]

    for before, after in SUBSTITUTIONS:
        final_answer = final_answer.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        final_answer = final_answer.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer)
    final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer)
    final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer)

    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer)
    final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer)
    final_answer = final_answer.replace("$", "")

    # Normalize 100,000 -> 100000
    if final_answer.replace(",", "").isdigit():
        final_answer = final_answer.replace(",", "")

    return final_answer

# --------------------------------------------------------------------------- #
#                                    CLI                                      #
# --------------------------------------------------------------------------- #

def _print_dataset_results(model_name: str, dataset_name: str, res: Dict, pass_n: int) -> None:
    """Print the headline metrics of one (model, dataset) result dict."""
    print(f"Results for {model_name} on {dataset_name}:")
    print(f"Total samples:   {res['total']}")
    print(f"Exact match correct: {res['exact_match_correct']}")
    print(f"Math verify correct: {res['math_verify_correct']}")
    print(f"No answer:       {res['no_answer']}")
    print(f"Exact match accuracy: {res['exact_match_accuracy']*100:.2f}%")
    print(f"Answer rate:     {res['answer_rate']*100:.2f}%")
    print(f"Exact match pass @ {pass_n}: {res['exact_match_pass_at_n']*100:.2f}%")


def main() -> None:
    """CLI entry point: evaluate every --models entry on every --datasets entry.

    Each model is attempted up to 10 times (30 s pause between attempts).
    With --out_file, a JSON mapping model -> dataset -> exact-match accuracy
    (percent, as a string with two decimals) is written.
    """
    parser = argparse.ArgumentParser(description="Math evaluation with vLLM")
    parser.add_argument("--models", help="HF repo or local paths", nargs="+", required=True)
    parser.add_argument("--datasets", help="Dataset to evaluate on: math500 or amc or math_train or math_test or path to dataset", nargs="+", required=True)
    parser.add_argument("--out_file", default=None, help="Path to output file")
    parser.add_argument("--tp", type=int, default=1, help="#GPUs for vLLM sharding")
    parser.add_argument("--num_samples", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--start_idx", type=int, default=0)
    parser.add_argument("--temperatures", type=float, default=[0.0], nargs="+")
    parser.add_argument("--num_generations", type=int, default=1)
    # NOTE: store_true with default=True means this is always True; that is
    # consistent with create_prompt, which only supports instruct mode.
    parser.add_argument(
        "--instruct",
        action="store_true",
        default=True,
        help="set if model was trained with chat / system prompt",
    )
    parser.add_argument("--max_new_tokens", type=int, default=8192, help="Max new tokens for vLLM")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p for vLLM")
    parser.add_argument("--top_k", type=int, default=0, help="Top-k for vLLM")
    parser.add_argument("--boxed_system_prompt", action="store_true", default=False, help="Use boxed system prompt")
    parser.add_argument("--llama_system_prompt", action="store_true", default=False, help="Use llama system prompt")
    args = parser.parse_args()

    if ("7b" in args.models[0] or "7B" in args.models[0]) and not args.boxed_system_prompt:
        print("Error: Use --boxed_system_prompt for 7B models")
        return

    if "llama" in args.models[0].lower() and not args.llama_system_prompt:
        print("Error: Use --llama_system_prompt for llama models")
        return

    # Load every requested dataset up front so a bad name fails fast.
    dataset_by_name = {}
    for dataset_name in args.datasets:
        print(f"\nLoading {dataset_name} dataset …")
        try:
            test_set = load_dataset(dataset_name)
        except Exception as e:
            print(f"Error loading dataset: {e}")
            return
        print(f"Total samples: {len(test_set)}")
        dataset_by_name[dataset_name] = test_set
    print(f"Loaded {len(dataset_by_name)} datasets")

    num_samples_by_dataset = {
        dataset_name: args.num_samples if args.num_samples is not None else len(dataset_by_name[dataset_name])
        for dataset_name in args.datasets
    }
    print(f"Number of samples per dataset: {num_samples_by_dataset}")

    pass_n = args.num_generations * len(args.temperatures)
    max_attempts = 10
    results = {}
    for model_name in args.models:
        print("=" * 60)
        print(f"Model:           {model_name}")
        print(f"Instruct-tuned:  {args.instruct}")
        print(f"Tensor parallel: {args.tp}")
        print("=" * 60)

        for attempt in range(1, max_attempts + 1):
            try:
                current_results = evaluate_model_on_datasets(
                    model_name, dataset_by_name, args.tp, num_samples_by_dataset,
                    args.seed, args.start_idx, args.instruct, args.temperatures, args.num_generations, args.top_p, args.top_k, args.boxed_system_prompt, args.llama_system_prompt, args.max_new_tokens
                )
            except Exception as e:
                print(f"Error evaluating {model_name} (attempt {attempt}/{max_attempts}): {e}")
                if attempt < max_attempts:
                    sleep(30)
                continue
            results[model_name] = current_results
            for dataset_name, dataset_results in current_results.items():
                _print_dataset_results(model_name, dataset_name, dataset_results, pass_n)
            break
        else:
            print(f"Giving up on {model_name} after {max_attempts} failed attempts")

    # ---------------------- summary ----------------------------------- #
    print("\n" + "=" * 60)
    print("EVALUATION RESULTS")
    for model_name, per_dataset in results.items():
        for dataset_name, dataset_results in per_dataset.items():
            print("=" * 60)
            _print_dataset_results(model_name, dataset_name, dataset_results, pass_n)

    # ---------------------- dump JSON --------------------------------- #
    if args.out_file is not None:
        print(f"Dumping results to: {args.out_file}")
        # Single quotes inside the f-string: reusing the outer double quote is
        # a SyntaxError before Python 3.12 and broke importing the whole file.
        summary = {
            model_name: {
                dataset_name: f"{dataset_results['exact_match_accuracy']*100:.2f}"
                for dataset_name, dataset_results in per_dataset.items()
            }
            for model_name, per_dataset in results.items()
        }
        try:
            with open(args.out_file, "w") as f:
                json.dump(summary, f, indent=2)
            print(f"\nDetailed results saved to: {args.out_file}")
        except Exception as e:
            print(f"Error dumping results: {e}")


# Run the CLI only when executed as a script; importing this module (e.g. to
# reuse the grading helpers) does not start an evaluation.
if __name__ == "__main__":
    main()