"""
OMEGA benchmark task utilities for lm-evaluation-harness.

Handles answer extraction and comparison for OMEGA-generated math problems.
OMEGA answers are exact numerical values (integers, fractions, decimals, or
occasionally list/matrix answers).

Reuses normalization patterns from math500/utils.py.
"""

import logging
import re
import signal
from fractions import Fraction
from typing import Optional

eval_logger = logging.getLogger(__name__)

# Best-effort math_verify (preferred) and sympy for equivalence checking
_HAS_MATH_VERIFY = False
_HAS_SYMPY = False

try:
    from math_verify import parse, verify
    _HAS_MATH_VERIFY = True
except (ModuleNotFoundError, AssertionError):
    eval_logger.info("math_verify not available — using sympy/string comparison only.")

try:
    import sympy
    from sympy.parsing.latex import parse_latex
    _HAS_SYMPY = True
except (ModuleNotFoundError, AssertionError):
    pass


# ---------------------------------------------------------------------------
# Result processing — called by lm-eval after generation
# ---------------------------------------------------------------------------
def process_results(doc: dict, results: list[str]) -> dict[str, int]:
    """Compare model output to ground truth.

    Strict scoring: only credit a response when an answer can be cleanly extracted
    from a ``\\boxed{}`` (the format the prompt explicitly requests). The previous
    fallback chain (``<answer>`` tag, ``the answer is X`` regex, last-number-in-text,
    and math_verify on the full unboxed response) inflated long/truncated responses
    by crediting coincidental digits and mid-thought claims, which biased
    cross-checkpoint comparisons in favour of more verbose models.

    A missing or empty response, an empty box, or a doc without a gold ``answer``
    always scores 0. Without this guard an empty ``\\boxed{}`` would string-match
    an empty gold, and a ``None`` gold would be compared as the literal ``"None"``.

    Args:
        doc: dataset row; the gold answer is read from ``doc["answer"]``.
        results: generations for this doc; only ``results[0]`` is scored.

    Returns:
        ``{"exact_match": 1}`` when the boxed answer matches the gold answer by
        normalized string, math_verify, numerical, or sympy equivalence;
        otherwise ``{"exact_match": 0}``.
    """
    model_output = results[0] if results else None
    if not model_output:
        return {"exact_match": 0}

    gold = doc.get("answer")
    ground_truth = "" if gold is None else str(gold)

    boxed = _extract_boxed(model_output)
    if boxed is None:
        return {"exact_match": 0}

    pred_norm = _normalize(boxed)
    gt_norm = _normalize(ground_truth)

    # Nothing to compare: never credit an empty box or an answer-less doc.
    if not pred_norm or not gt_norm:
        return {"exact_match": 0}

    if pred_norm == gt_norm:
        return {"exact_match": 1}

    # math_verify on the boxed expression vs gold (e.g. "\\frac{1}{2}" ↔ "0.5")
    if _HAS_MATH_VERIFY:
        try:
            if verify(gold=parse(ground_truth), target=parse(f"\\boxed{{{boxed}}}")):
                return {"exact_match": 1}
        except Exception:
            # Best-effort check: fall through to the other comparisons, but leave
            # a trace so systematic math_verify failures are discoverable.
            eval_logger.debug("math_verify failed on %r", boxed, exc_info=True)

    # Numerical equivalence on the boxed value (e.g. "68968/61" vs "1130.62...")
    if _numerical_equiv(pred_norm, gt_norm):
        return {"exact_match": 1}

    # Symbolic equivalence on the boxed value (e.g. "(1+1)/2" ↔ "1")
    if _HAS_SYMPY and _sympy_equiv(pred_norm, gt_norm):
        return {"exact_match": 1}

    return {"exact_match": 0}


def _score_single(model_output: str, doc: dict) -> int:
    """Return 1 if *model_output* is judged correct for *doc*, else 0.

    Thin adapter for compute_pass_at_k.py integration: wraps the single response
    in the ``results`` list shape that ``process_results`` expects and unwraps
    its ``exact_match`` score.
    """
    scored = process_results(doc, [model_output])
    return scored["exact_match"]


# ---------------------------------------------------------------------------
# Answer extraction
# ---------------------------------------------------------------------------
def _extract_answer(text: str) -> Optional[str]:
    """Extract an answer from model output using a chain of fallbacks.

    ``process_results`` does not use this lenient chain (it scores only boxed
    answers). It is presumably kept for other callers — confirm before removing.

    Order tried:
      1. the last ``\\boxed{}`` / ``\\fbox{}``;
      2. ``<answer>...</answer>`` tags;
      3. a "the (final) answer is X" phrase;
      4. the last number in the text.

    Returns None when nothing can be extracted.
    """
    if not text:
        return None

    # 1. Try \boxed{} (primary)
    boxed = _extract_boxed(text)
    if boxed is not None:
        return boxed

    # 2. Try <answer>...</answer> tags (RSFT format fallback)
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", text, re.DOTALL)
    if m:
        return m.group(1).strip()

    # 3. Try "the answer is X". Capture to the end of the line or to a
    # sentence-ending period (one followed by whitespace / end of line), so that
    # decimals such as "3.14" are not cut at their decimal point.
    m = re.search(
        r"[Tt]he\s+(?:final\s+)?answer\s+is[:\s]*(.+?)(?=\.(?:\s|$)|$)",
        text,
        re.MULTILINE,
    )
    if m:
        answer = m.group(1).strip().rstrip(".")
        # An empty capture ("The answer is .") is not an answer; keep looking.
        if answer:
            return answer

    # 4. Last number fallback
    return _extract_last_number(text)


def _extract_boxed(text: str) -> Optional[str]:
    """Return the content of the last ``\\boxed`` (or, if absent, ``\\fbox``).

    Two argument forms are supported:

    * braced: ``\\boxed{...}``, with nested-brace matching and optional whitespace
      before the brace (``\\boxed {42}`` -> ``"42"``); the content is stripped;
    * space-separated: ``\\boxed 42`` -> the next whitespace-delimited token,
      with trailing ``.,;`` removed.

    Returns None for empty input, when no box command exists, when the braces
    never balance (e.g. truncated output), or when the command is followed by
    neither whitespace nor ``{``.
    """
    if not text:
        return None

    idx = text.rfind("\\boxed")
    cmd_len = len("\\boxed")
    if idx < 0:
        idx = text.rfind("\\fbox")
        cmd_len = len("\\fbox")
        if idx < 0:
            return None

    # Skip whitespace between the command and its argument.
    arg_start = idx + cmd_len
    pos = arg_start
    while pos < len(text) and text[pos].isspace():
        pos += 1

    if pos < len(text) and text[pos] == "{":
        # Braced argument: return everything up to the matching close brace.
        depth = 0
        for end in range(pos, len(text)):
            ch = text[end]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[pos + 1:end].strip()
        return None  # unbalanced braces (e.g. truncated generation)

    if pos > arg_start and pos < len(text):
        # "\boxed X": the next whitespace-delimited token.
        token = text[pos:].split(maxsplit=1)[0]
        return token.rstrip(".,;")

    # Command glued to a non-brace character, or nothing after it: not a clean box.
    return None


def _extract_last_number(text: str) -> Optional[str]:
    """Return the last number in *text*, as written, or None if there is none.

    A number is an optional leading minus, digits with optional comma thousands
    separators (``1,000,000``), and an optional decimal part. Comma groups must
    be exactly three digits, so comma-separated lists like ``1,2,3`` still yield
    their last element.
    """
    if not text:
        return None
    numbers = re.findall(r"-?\d+(?:,\d{3}(?!\d))*(?:\.\d+)?", text)
    return numbers[-1] if numbers else None


# ---------------------------------------------------------------------------
# Normalization
# ---------------------------------------------------------------------------
def _normalize(answer: str) -> str:
    """Normalize an answer string for comparison.

    Steps:
      * strip ``$`` delimiters;
      * unwrap single-level ``\\text{}``, ``\\textbf{}``, ``\\mathrm{}`` and
        ``\\boxed{}``;
      * rewrite ``\\frac``/``\\dfrac``/``\\tfrac{a}{b}`` as ``a/b``;
      * drop thousands separators only (``1,234`` -> ``1234``), keeping list
        separators but removing the whitespace after them
        (``(1, 2)`` -> ``(1,2)``);
      * strip trailing periods;
      * canonicalize integral numbers (``"5.0"`` -> ``"5"``, ``"007"`` -> ``"7"``).

    Plain integers are canonicalized via ``int`` rather than ``float`` so that
    large integer answers keep every digit and distinct values never collapse.
    Returns ``""`` for empty input.
    """
    if not answer:
        return ""

    answer = answer.strip()

    # Remove LaTeX math delimiters
    answer = answer.replace("$", "")

    # Remove common LaTeX wrappers (single level; nested braces not handled)
    answer = re.sub(r"\\text\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\textbf\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\mathrm\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\boxed\{(.*)\}", r"\1", answer)

    # Convert \frac{a}{b}, \dfrac{a}{b}, \tfrac{a}{b} to a/b (separate passes,
    # in this order, so nested combinations are rewritten as before).
    answer = re.sub(r"\\frac\{([^}]*)\}\{([^}]*)\}", r"\1/\2", answer)
    answer = re.sub(r"\\dfrac\{([^}]*)\}\{([^}]*)\}", r"\1/\2", answer)
    answer = re.sub(r"\\tfrac\{([^}]*)\}\{([^}]*)\}", r"\1/\2", answer)

    # Remove thousands separators only; other commas separate list/tuple items
    # and must survive, otherwise "1,2" and "12" would compare equal.
    answer = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", answer)
    answer = re.sub(r",\s+", ",", answer)

    # Remove trailing periods
    answer = answer.rstrip(".")

    # Strip whitespace
    answer = answer.strip()

    # Exact path for plain integers: float() would round away digits beyond
    # ~2**53 and make different large integers normalize identically.
    if re.fullmatch(r"[+-]?[0-9]+", answer):
        try:
            return str(int(answer))
        except ValueError:
            # Beyond CPython's int/str conversion digit limit: keep as written.
            return answer

    # Convert float-that-is-integer (e.g. "5.0", "1e3") to an integer string.
    try:
        num = float(answer)
        if num == int(num):
            return str(int(num))
    except (ValueError, OverflowError):
        pass

    return answer


# ---------------------------------------------------------------------------
# Numerical equivalence
# ---------------------------------------------------------------------------
# Integer or integer ratio ("42", "-3/4", " 68968 / 61 "), compared exactly.
_EXACT_RATIONAL_RE = re.compile(r"\s*([+-]?[0-9]+)\s*(?:/\s*([+-]?[0-9]+)\s*)?")


def _parse_number(s: str) -> float:
    """Parse a number string into a float, handling fractions like '68968/61'.

    Raises:
        ValueError: if *s* is not a number or a ``num/den`` pair of numbers.
        ZeroDivisionError: if the denominator is zero.
    """
    s = s.strip()
    if "/" in s:
        parts = s.split("/")
        if len(parts) == 2:
            return float(parts[0]) / float(parts[1])
    return float(s)


def _parse_exact_rational(s: str) -> Optional[Fraction]:
    """Return *s* as an exact Fraction if it is an integer or ``int/int``, else None.

    Decimals, symbolic strings, zero denominators and digit strings too long for
    ``int()`` yield None, so the caller falls back to the float comparison.
    """
    m = _EXACT_RATIONAL_RE.fullmatch(s)
    if m is None:
        return None
    try:
        num = int(m.group(1))
        den = int(m.group(2)) if m.group(2) is not None else 1
    except ValueError:
        # Beyond CPython's int/str conversion digit limit.
        return None
    if den == 0:
        return None
    return Fraction(num, den)


def _numerical_equiv(a: str, b: str, tol: float = 1e-4) -> bool:
    """Check if two strings represent the same number.

    When both sides are exact (integers or integer ratios), they are compared
    exactly. A relative tolerance would otherwise credit wrong integer answers,
    e.g. ``100000`` vs ``100001`` differ by only 1e-5 relative. When either side
    is a decimal (e.g. a rounded gold value like ``1130.62`` vs ``68968/61``),
    the comparison uses relative tolerance *tol* against *b*, or absolute *tol*
    when *b* is zero. Unparseable input returns False.
    """
    exact_a = _parse_exact_rational(a)
    exact_b = _parse_exact_rational(b)
    if exact_a is not None and exact_b is not None:
        return exact_a == exact_b

    try:
        fa, fb = _parse_number(a), _parse_number(b)
        if fb == 0:
            return abs(fa) < tol
        return abs(fa - fb) / max(abs(fb), 1e-10) < tol
    except (ValueError, OverflowError, ZeroDivisionError):
        return False


# ---------------------------------------------------------------------------
# Sympy equivalence
# ---------------------------------------------------------------------------
class _Timeout:
    """Context manager that raises ``TimeoutError`` if its body runs too long.

    Implemented with ``SIGALRM`` / ``signal.alarm``, so it only works on POSIX
    and only in the main thread (``signal.signal`` raises ``ValueError`` in other
    threads). The ``SIGALRM`` handler that was installed before entering is
    restored on exit, so a handler set up by the caller is not clobbered.

    Args:
        seconds: whole seconds before the alarm fires (``signal.alarm`` takes
            an int).
    """

    def __init__(self, seconds=3):
        self.seconds = seconds
        # Handler active before __enter__; reinstalled in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded block."""
        raise TimeoutError

    def __enter__(self):
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, *args):
        signal.alarm(0)
        # signal.signal() reports None when the prior handler was not installed
        # from Python; there is nothing we can reinstall in that case.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        self._previous_handler = None


def _sympy_equiv(x1: str, x2: str) -> bool:
    """Return True when sympy shows two LaTeX strings are symbolically equal.

    Both strings are parsed with ``parse_latex``, and their difference is
    simplified under a 3-second ``_Timeout``. Any failure (sympy unavailable,
    parse error, non-subtractable expressions, timeout) yields False.
    """
    if _HAS_SYMPY:
        try:
            with _Timeout(seconds=3):
                lhs, rhs = parse_latex(x1), parse_latex(x2)
                difference = sympy.simplify(lhs - rhs)
                return difference == 0
        except Exception:
            return False
    return False
