"""Math answer extraction and equivalence checking.

This module provides utilities for extracting and comparing mathematical answers
from LLM predictions, supporting both numeric and LaTeX boxed formats.

References:
- OLMES evaluation: https://github.com/allenai/olmes/blob/main/oe_eval/configs/tasks.py
- lm-evaluation-harness: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py
- Math-Verify: https://github.com/huggingface/Math-Verify
"""

import re
from fractions import Fraction

# =============================================================================
# Answer Extraction Patterns
# =============================================================================

# LaTeX boxed patterns (OLMES style)
# Reference: https://github.com/allenai/olmes/blob/main/oe_eval/configs/tasks.py
# Each pattern captures the braced argument and tolerates one level of nested
# braces inside it (e.g. ``\boxed{\frac{1}{2}}``).
BOXED_PATTERN = r"\\boxed\s*\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"
BOXED_DOLLAR_PATTERN = r"\$\\boxed\s*\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}\$"
FBOX_PATTERN = r"\\fbox\s*\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}"

# "Final answer is" patterns (GSM8K / Minerva style)
# Reference: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/gsm8k
# The free-form capture ends at the first period that is not a decimal point
# (``\.(?!\d)``) so "3.14." yields "3.14" rather than "3", or at the end of the
# line (``$`` under the ``m`` flag) so an answer without a closing period is
# still found when more text follows on later lines.
FINAL_ANSWER_PATTERN = r"(?im)(?:the\s+)?final\s+answer\s+is[:\s]*(.+?)(?:\.(?!\d)|$)"
FINAL_ANSWER_NUMERIC_PATTERN = (
    r"(?i)The final answer is[:\s]*([*]*\s*-?[$0-9.,]*[0-9]+[$0-9.,]*\s*[*]*)"
)

# LaTeX cleanup patterns: (regex, replacement) pairs applied in order.
LATEX_CLEANUP_PATTERNS = [
    (r"\\text\{([^}]*)\}", r"\1"),
    (r"\\textbf\{([^}]*)\}", r"\1"),
    (r"\\mathbf\{([^}]*)\}", r"\1"),
    (r"\\mathrm\{([^}]*)\}", r"\1"),
    # The lookahead keeps longer commands such as \leftarrow / \rightarrow
    # intact; only the bare delimiter-sizing commands are dropped.
    (r"\\left(?![a-zA-Z])", ""),
    (r"\\right(?![a-zA-Z])", ""),
    (r"\\,", ""),
    (r"\\;", ""),
    (r"\\!", ""),
    (r"\\ ", " "),
    (r"\\%", "%"),
    (r"\\[$]", "$"),
]

# Numeric cleanup patterns (GSM8K style), applied in order. Emphasis markers
# are removed before the trailing-period rule so that "**42.**" ends up as
# "42" instead of "42.".
NUMERIC_IGNORE_PATTERNS = [r",", r"\$", r"(?s).*#### ", r"\*", r"\.$"]


# =============================================================================
# Answer Extraction Functions
# =============================================================================


def extract_boxed_answer(text: str) -> str | None:
    """Extract the last boxed answer from text.

    Tries patterns in order: \\boxed{}, $\\boxed{}$, \\fbox{}
    Returns the last match (assumed to be the final answer).
    """
    if not isinstance(text, str):
        return None

    for pattern in [BOXED_PATTERN, BOXED_DOLLAR_PATTERN, FBOX_PATTERN]:
        matches = re.findall(pattern, text)
        if matches:
            return clean_latex(matches[-1].strip())

    return None


def extract_final_answer(text: str) -> str | None:
    """Extract answer from 'the final answer is...' pattern."""
    if not isinstance(text, str):
        return None

    match = re.search(FINAL_ANSWER_PATTERN, text)
    if match:
        return clean_latex(match.group(1).strip())

    return None


def extract_final_answer_numeric(text: str) -> str | None:
    """Extract numeric answer from 'the final answer is...' pattern (GSM8K style)."""
    if not isinstance(text, str):
        return None

    match = re.search(FINAL_ANSWER_NUMERIC_PATTERN, text)
    if match:
        answer = match.group(1).strip()
        for pattern in NUMERIC_IGNORE_PATTERNS:
            answer = re.sub(pattern, "", answer)
        return answer.strip()

    return None


def extract_answer(text: str, prefer_boxed: bool = True) -> str | None:
    """Extract answer using multiple strategies.

    Args:
        text: The text to extract from
        prefer_boxed: If True, try boxed patterns first, then fallback to final answer
    """
    if not isinstance(text, str):
        return None

    if prefer_boxed:
        answer = extract_boxed_answer(text)
        if answer is not None:
            return answer
        return extract_final_answer(text)
    else:
        answer = extract_final_answer_numeric(text)
        if answer is not None:
            return answer
        return extract_boxed_answer(text)


def clean_latex(text: str) -> str:
    """Apply each (pattern, replacement) rule in LATEX_CLEANUP_PATTERNS, then strip."""
    cleaned = text
    for rule in LATEX_CLEANUP_PATTERNS:
        regex, substitute = rule
        cleaned = re.sub(regex, substitute, cleaned)
    return cleaned.strip()


# =============================================================================
# Math Equivalence Checking
# =============================================================================


def is_equiv(pred: str | None, gold: str | None) -> bool:
    """Check if two answers are mathematically equivalent.

    Tries comparison in order:
    1. Direct string match
    2. Numeric comparison (floats)
    3. Fraction comparison
    4. Normalized string comparison
    """
    if pred is None or gold is None:
        return False

    pred_str = str(pred).strip()
    gold_str = str(gold).strip()

    if pred_str == gold_str:
        return True

    if is_numeric_equiv(pred_str, gold_str):
        return True

    if is_fraction_equiv(pred_str, gold_str):
        return True

    return normalize_answer(pred_str) == normalize_answer(gold_str)


def is_numeric_equiv(pred: str, gold: str, rel_tol: float = 1e-6) -> bool:
    """Compare two strings as numbers, within ``rel_tol``.

    Returns False if either string is not parseable by ``parse_number``.
    When the gold value is zero, the prediction's magnitude must be below
    ``rel_tol``. Otherwise either the absolute difference or the difference
    relative to the gold value must be below ``rel_tol``.
    """
    actual = parse_number(pred)
    expected = parse_number(gold)

    if actual is None or expected is None:
        return False

    if expected == 0:
        return abs(actual) < rel_tol

    difference = actual - expected
    return abs(difference) < rel_tol or abs(difference / expected) < rel_tol


def is_fraction_equiv(pred: str, gold: str) -> bool:
    """Return True when both strings parse via ``parse_fraction`` to equal values."""
    left = parse_fraction(pred)
    right = parse_fraction(gold)
    return left is not None and right is not None and left == right


def parse_number(s: str) -> float | None:
    """Parse a string as a number."""
    s = s.strip().replace(",", "").replace("$", "")

    if s.endswith("%"):
        try:
            return float(s[:-1]) / 100
        except ValueError:
            return None

    try:
        return float(s)
    except ValueError:
        return None


def parse_fraction(s: str) -> Fraction | None:
    """Parse a string as a fraction."""
    s = s.strip()

    # Simple fraction: a/b
    match = re.match(r"^(-?\d+)\s*/\s*(-?\d+)$", s)
    if match:
        try:
            num, den = int(match.group(1)), int(match.group(2))
            if den != 0:
                return Fraction(num, den)
        except ValueError:
            pass

    # LaTeX fraction: \frac{a}{b}
    match = re.match(r"^\\frac\{(-?\d+)\}\{(-?\d+)\}$", s)
    if match:
        try:
            num, den = int(match.group(1)), int(match.group(2))
            if den != 0:
                return Fraction(num, den)
        except ValueError:
            pass

    # Try as decimal
    parsed = parse_number(s)
    if parsed is not None:
        try:
            return Fraction(parsed).limit_denominator(10000)
        except (ValueError, OverflowError):
            pass

    return None


def normalize_answer(s: str) -> str:
    """Normalize an answer for string comparison.

    Lower-cases the text, removes all whitespace and ``$`` signs, then drops
    trailing ``.,;:`` punctuation.

    Args:
        s: The answer text.

    Returns:
        The normalized string.
    """
    s = s.lower()
    s = re.sub(r"\s+", "", s)
    # Remove math-mode dollars before trimming punctuation, so that a period
    # placed inside the closing "$" (as in "$x+1.$") is trimmed as well.
    s = s.replace("$", "")
    return s.rstrip(".,;:")
