"""
MATH-500 R1 task utilities for lm-evaluation-harness.

Provides answer extraction and comparison for MATH-500 benchmark problems.
Uses \boxed{} extraction from model outputs (R1-style reasoning).
Reuses normalization logic from the minerva_math task.

Key difference from minerva_math: MATH-500 dataset has a pre-extracted `answer`
field, so we use it directly instead of extracting from `solution`.
"""

import logging
import re
import signal
from typing import Optional

import datasets


eval_logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Sympy / math_verify imports (best-effort, fallback to string comparison)
# ---------------------------------------------------------------------------
# Availability flags for the optional dependencies. Each starts False and is
# flipped to True only when the matching import below succeeds, so the rest of
# the module can branch on the flag instead of re-attempting the import.
_HAS_SYMPY = False
_HAS_MATH_VERIFY = False

# sympy and its LaTeX parser: used by the symbolic-equivalence check.
try:
    import sympy
    from sympy.parsing.latex import parse_latex
    _HAS_SYMPY = True
except (ModuleNotFoundError, AssertionError):
    # Only these two exception types are treated as "dependency unavailable";
    # any other error raised while importing propagates to the importer.
    eval_logger.warning("sympy not available — sympy equivalence checking disabled.")

# math_verify: used as the last-resort comparison in process_results.
try:
    from math_verify import parse, verify
    _HAS_MATH_VERIFY = True
except (ModuleNotFoundError, AssertionError):
    # Logged at info (not warning) level: the module still works without it.
    eval_logger.info("math_verify not available — using sympy/string comparison only.")


# ---------------------------------------------------------------------------
# Dataset processing
# ---------------------------------------------------------------------------
def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    """Map every document to its problem, solution and normalized answer.

    The dataset's own ``answer`` field is used as the ground truth (run
    through ``normalize_final_answer``) rather than being re-extracted from
    ``solution``.

    Args:
        dataset: Dataset whose rows provide ``problem``, ``solution`` and
            ``answer`` keys.

    Returns:
        The result of ``dataset.map`` applied with the per-row transform.
    """

    def _with_normalized_answer(row: dict) -> dict:
        # Copy the pass-through fields first, then add the normalized answer
        # so the resulting key order is problem, solution, answer.
        converted = {field: row[field] for field in ("problem", "solution")}
        converted["answer"] = normalize_final_answer(row["answer"])
        return converted

    return dataset.map(_with_normalized_answer)


# ---------------------------------------------------------------------------
# Result processing — called by lm-eval after generation
# ---------------------------------------------------------------------------
def process_results(doc: dict, results: list[str]) -> dict[str, int]:
    """Score one model generation against the document's ground truth.

    The predicted answer is taken from the last ``\\boxed`` expression in the
    generation; if none can be extracted, the last number in the text is used.
    The prediction is then compared to ``doc["answer"]`` in up to three
    stages, stopping at the first success:

    1. exact string match after normalization,
    2. sympy equivalence (``_is_equiv``),
    3. ``math_verify`` on the raw solution and raw generation, when available.

    Args:
        doc: Document with ``answer`` and ``solution`` keys. ``answer`` is
            compared as-is, so it is expected to be normalized already
            (``process_docs`` does this).
        results: Model generations; only the first element is scored.

    Returns:
        ``{"exact_match": 1}`` if any stage accepts the prediction, otherwise
        ``{"exact_match": 0}``.
    """
    invalid_marker = "[invalidanswer]"
    model_output = results[0]

    # Extract predicted answer from model output.
    predicted = None
    predicted_boxed = last_boxed_only_string(model_output)
    if predicted_boxed is not None:
        try:
            predicted = normalize_final_answer(remove_boxed(predicted_boxed))
        except AssertionError:
            # The extractor can hand back a form remove_boxed rejects (for
            # example an ``\fbox{...}`` match). Treat that as "no boxed
            # answer" rather than letting one sample abort the evaluation.
            predicted = None
    if predicted is None:
        # Fallback: try to find the last number in the generation.
        predicted = _extract_last_number(model_output) or invalid_marker

    ground_truth = doc["answer"]
    has_prediction = predicted != invalid_marker

    # Primary: string comparison after normalization.
    exact = 1 if predicted == ground_truth else 0

    # Secondary: sympy equivalence check.
    if exact == 0 and has_prediction and _is_equiv(predicted, ground_truth):
        exact = 1

    # Tertiary: math_verify on the raw texts.
    if exact == 0 and has_prediction and _HAS_MATH_VERIFY:
        try:
            if verify(gold=parse(doc["solution"]), target=parse(model_output)):
                exact = 1
        except Exception as exc:
            # Best-effort stage: a failure here must not change the score,
            # but leave a trace so systematic failures can be diagnosed.
            eval_logger.debug("math_verify comparison failed: %s", exc)

    return {"exact_match": exact}


# ---------------------------------------------------------------------------
# Answer extraction helpers (from minerva_math built-in)
# ---------------------------------------------------------------------------
def last_boxed_only_string(string: str) -> Optional[str]:
    """Return the last ``\\boxed`` (or ``\\fbox``) expression in ``string``.

    Two spellings are recognised for the last ``\\boxed`` occurrence:

    * ``\\boxed <answer>`` (space form): everything after the space up to the
      next ``$`` (or the end of the text) is returned, prefixed by
      ``"\\boxed "``.
    * ``\\boxed{...}`` (brace form): the span up to and including the brace
      that closes the first opened group is returned.

    ``\\fbox`` is only considered when the text contains no ``\\boxed`` at all.

    The space form is applied only when the *last* ``\\boxed`` in the text is
    written that way. An earlier ``\\boxed `` (for instance inside reasoning
    prose) no longer overrides a later ``\\boxed{...}`` final answer.

    Args:
        string: Text to search.

    Returns:
        The matched expression including its ``\\boxed``/``\\fbox`` prefix, or
        ``None`` if there is no such expression or its braces never close.
    """
    space_form = "\\boxed "

    start = string.rfind("\\boxed")
    if start >= 0:
        if string.startswith(space_form, start):
            answer = string[start + len(space_form):].split("$")[0]
            return space_form + answer
    else:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    # Brace form: scan forward until the first opened group is closed again.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]

    # Ran off the end without balancing the braces (e.g. truncated output).
    return None


def remove_boxed(s: Optional[str]) -> str:
    """Strip the ``\\boxed``/``\\fbox`` wrapper from an extracted expression.

    Accepted forms are ``\\boxed <answer>``, ``\\boxed{<answer>}`` and
    ``\\fbox{<answer>}``; the wrapper must start the string and, for the
    brace forms, the string must end with ``}``.

    Args:
        s: Wrapped expression, or ``None``.

    Returns:
        The inner answer text, or ``""`` when ``s`` is ``None``.

    Raises:
        AssertionError: If ``s`` is not one of the accepted forms. Raised
            explicitly (not via ``assert``) so the check still runs under
            ``python -O``.
    """
    if s is None:
        return ""

    space_form = "\\boxed "
    if s.startswith(space_form):
        return s[len(space_form):]

    for opener in ("\\boxed{", "\\fbox{"):
        if s.startswith(opener) and s.endswith("}"):
            return s[len(opener) : -1]

    raise AssertionError(f"not a \\boxed/\\fbox expression: {s!r}")


def _extract_last_number(text: str) -> Optional[str]:
    """Return the last integer or decimal literal found in ``text``.

    A literal is an optional leading ``-``, one or more digits, and an
    optional fractional part. Returns ``None`` when ``text`` is empty/falsy
    or contains no such literal.
    """
    last_match = None
    # Walk all matches left to right and remember only the final one.
    for found in re.finditer(r"-?\d+(?:\.\d+)?", text or ""):
        last_match = found.group(0)
    return last_match


# ---------------------------------------------------------------------------
# Normalization (from Lewkowycz et al. 2022, Appendix D)
# ---------------------------------------------------------------------------
# Ordered (old, new) literal replacements applied first during normalization.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Literal fragments (mostly units and LaTeX spacing) deleted after the
# substitutions above, in this order.
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """Canonicalize a final answer string for string comparison.

    Steps, in order: keep only the text after the last ``=``; apply
    ``SUBSTITUTIONS`` and delete ``REMOVED_EXPRESSIONS``; reduce the text to
    its first ``$...$`` span if one exists; unwrap ``\\text``, ``\\textbf``,
    ``\\overline`` and ``\\boxed``; add braces to shorthand ``frac``/``sqrt``
    arguments; drop ``$``; and finally drop commas when what remains is a
    plain digit string.

    Args:
        final_answer: Raw answer text.

    Returns:
        The normalized answer.
    """
    # Everything up to and including the last "=" is discarded.
    answer = final_answer.rsplit("=", 1)[-1]

    # Literal replacements: substitutions first, then plain deletions.
    literal_rewrites = [
        *SUBSTITUTIONS,
        *((fragment, "") for fragment in REMOVED_EXPRESSIONS),
    ]
    for old, new in literal_rewrites:
        answer = answer.replace(old, new)

    # Regex rewrites; order matters because later patterns see earlier output.
    regex_rewrites = (
        # Keep only the first $...$ span.
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        # Unwrap formatting commands, keeping their contents.
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
        # Shorthand such as "frac12" / "sqrt2" gets explicit braces.
        (r"(frac)([^{])(.)", "frac{\\2}{\\3}"),
        (r"(sqrt)([^{])", "sqrt{\\2}"),
    )
    for pattern, replacement in regex_rewrites:
        answer = re.sub(pattern, replacement, answer)

    answer = answer.replace("$", "")

    # "1,000" -> "1000", but only when the commas are the sole non-digits.
    comma_free = answer.replace(",", "")
    return comma_free if comma_free.isdigit() else answer


# ---------------------------------------------------------------------------
# Sympy equivalence check
# ---------------------------------------------------------------------------
class _Timeout:
    """Context manager that raises ``TimeoutError`` when its body runs too long.

    Built on ``signal.alarm`` and ``SIGALRM``: entering installs a handler and
    arms the alarm, leaving cancels the alarm and puts back the handler that
    was installed before, so the process-wide ``SIGALRM`` handler is not left
    pointing at this object.

    Args:
        seconds: Whole seconds to allow before ``TimeoutError`` is raised.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before __enter__; restored by __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: turn the alarm into a ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces; remember it.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, exc_type, exc_value, traceback):
        # Cancel any pending alarm first so it cannot fire mid-restore.
        signal.alarm(0)
        # None means the earlier handler was not installed from Python and
        # cannot be re-installed through signal.signal; leave it alone then.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None


def _is_equiv(x1: str, x2: str) -> bool:
    """Report whether two LaTeX expressions are symbolically equal.

    Both strings are parsed with ``parse_latex`` and considered equal when
    ``sympy.simplify`` reduces their difference to ``0``. The whole check runs
    under a 5-second ``_Timeout``.

    Returns:
        ``False`` when sympy is unavailable, or when parsing, subtraction,
        simplification or the timeout raises any ``Exception``; otherwise the
        result of the comparison.
    """
    if not _HAS_SYMPY:
        return False

    try:
        with _Timeout(seconds=5):
            lhs = parse_latex(x1)
            rhs = parse_latex(x2)
            return sympy.simplify(lhs - rhs) == 0
    except Exception:
        # Any failure along the way (unparseable input, incompatible operand
        # types, simplify errors, timeout) counts as "not equivalent".
        return False
