"""
HMMT task utilities for lm-evaluation-harness.

Reuses the OlymMATH-canonical scorer (math_verify primary + exact-match
fallback). HMMT answers cover the same types as OlymMATH (integers,
fractions, surds, symbolic expressions). Doc schema `{problem, answer}`
is identical to OlymMATH, so no extraction changes are needed.
"""

import logging
import re
from typing import Dict, List, Optional


eval_logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# math_verify import (the official OlymMATH scorer uses this)
# ---------------------------------------------------------------------------
# `_HAS_MATH_VERIFY` records whether `parse` / `verify` were imported; code
# that wants symbolic comparison must check it before calling either name.
try:
    from math_verify import parse, verify
except (ModuleNotFoundError, AssertionError):
    # NOTE(review): AssertionError is presumably raised by math_verify (or one
    # of its dependencies) at import time in some environments — confirm.
    _HAS_MATH_VERIFY = False
    eval_logger.warning("math_verify not available — OlymMATH scoring will use string comparison only (UNRELIABLE for non-integer answers).")
else:
    _HAS_MATH_VERIFY = True


# ---------------------------------------------------------------------------
# Answer extraction
# ---------------------------------------------------------------------------
def extract_boxed_answer(text: str) -> Optional[str]:
    """Return the contents of the last top-level ``\\boxed{...}`` in *text*.

    A brace-depth scan collects every ``\\boxed{...}`` that opens while no
    other brace group is open; nested braces inside the box are kept verbatim.
    A brace (or ``\\boxed{``) directly preceded by a backslash is treated as
    escaped and ignored by the scan. If the scan captures nothing, a regex
    that tolerates two levels of nested braces is tried instead.

    Returns ``None`` for empty input or when no boxed expression is found.
    Intended to match the official OlymMATH extraction logic.
    """
    if not text:
        return None

    marker = "\\boxed{"
    marker_len = len(marker)
    text_len = len(text)

    captured = []
    depth = 0            # number of currently open, unescaped braces
    content_start = -1   # index just past the opening brace of the box being captured
    pos = 0

    while pos < text_len:
        escaped = pos > 0 and text[pos - 1] == "\\"

        if not escaped and text.startswith(marker, pos):
            # Only a box opened at depth 0 is captured; a box that opens
            # inside another brace group merely deepens the nesting.
            if depth == 0:
                content_start = pos + marker_len
            depth += 1
            # Jump past the marker so its own "{" is not counted twice.
            pos += marker_len
            continue

        if not escaped:
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}" and depth > 0:
                depth -= 1
                if depth == 0 and content_start != -1:
                    captured.append(text[content_start:pos])
                    content_start = -1
        pos += 1

    if captured:
        return captured[-1]

    # Regex fallback: reached when no box closed at depth 0, e.g. the box sits
    # inside another brace group or an earlier brace was left unbalanced.
    fallback_hits = re.findall(
        r"\\boxed{((?:[^{}]|{(?:[^{}]|{[^{}]*})*})*?)}", text
    )
    if fallback_hits:
        return fallback_hits[-1]
    return None


def extract_last_number(text: str) -> Optional[str]:
    """Return the last integer or decimal literal found in *text*.

    A number is an optional leading ``-``, one or more digits, and an optional
    ``.digits`` fraction. Thousands separators are not understood, so
    ``"1,234"`` yields ``"234"``. Returns ``None`` for empty input or when the
    text contains no digits.
    """
    if not text:
        return None

    last_seen = None
    for hit in re.finditer(r"-?\d+(?:\.\d+)?", text):
        last_seen = hit.group(0)
    return last_seen


# ---------------------------------------------------------------------------
# Scoring — official OlymMATH pattern: math_verify → string fallback
# ---------------------------------------------------------------------------
def _format_for_math_verify(answer: str) -> str:
    """Return *answer* wrapped in a single pair of ``$`` delimiters.

    Surrounding whitespace is trimmed and at most one ``$`` is removed from
    each end before re-wrapping, so already-delimited input is not wrapped
    twice. An empty answer (before or after trimming) becomes the placeholder
    ``"$.$"``.
    """
    placeholder = "$.$"
    if not answer:
        return placeholder

    inner = answer.strip()
    # Peel off one delimiter per side; "$$...$$" keeps its inner pair.
    if inner.startswith("$"):
        inner = inner[1:]
    if inner.endswith("$"):
        inner = inner[:-1]
    inner = inner.strip()

    return f"${inner}$" if inner else placeholder


def _normalize_for_string_match(text: str) -> str:
    """Reduce a LaTeX answer to a crude canonical string for exact comparison.

    Removes all whitespace, drops ``\\frac``, maps ``\\cdot`` / ``\\times`` to
    ``*`` and deletes every remaining LaTeX command name, so only the command
    arguments and plain characters survive.
    """
    if not text:
        return ""
    text = re.sub(r"\s+", "", text)
    text = text.replace("\\frac", "")
    text = text.replace("\\cdot", "*")
    text = text.replace("\\times", "*")
    return re.sub(r"\\[a-zA-Z]+", "", text)


def _score(predicted: Optional[str], ground_truth: str) -> int:
    """Score a single prediction against the ground truth.

    Two checks are tried in order:

    1. ``math_verify`` symbolic equivalence, when the package was imported.
    2. Exact equality of the normalized strings (no substring matching).

    The string check also runs when ``math_verify`` reports a mismatch or
    raises, so it can only turn a 0 into a 1, never the reverse.

    Args:
        predicted: Extracted model answer; ``None`` or empty scores 0.
        ground_truth: Reference answer. Non-string values are converted with
            ``str()`` and ``None`` is treated as an empty answer.

    Returns:
        1 if either check accepts the prediction, otherwise 0.
    """
    if not predicted:
        return 0

    # Coerce non-string references up front: both the formatter and the regex
    # normalizer call string-only operations and would otherwise raise.
    if ground_truth is None:
        ground_truth = ""
    elif not isinstance(ground_truth, str):
        ground_truth = str(ground_truth)

    # Primary: math_verify symbolic equivalence (reference first, then prediction).
    if _HAS_MATH_VERIFY:
        try:
            gold = parse(_format_for_math_verify(ground_truth))
            candidate = parse(_format_for_math_verify(predicted))
            if verify(gold, candidate):
                return 1
        except Exception:
            # Best-effort: any math_verify failure falls through to the string
            # comparison, but leave a trace for debugging.
            eval_logger.debug(
                "math_verify failed for prediction %r vs ground truth %r; "
                "falling back to string comparison",
                predicted,
                ground_truth,
                exc_info=True,
            )

    # Fallback: exact match of normalized strings; an empty reference never matches.
    if ground_truth and (
        _normalize_for_string_match(predicted)
        == _normalize_for_string_match(ground_truth)
    ):
        return 1

    return 0


# ---------------------------------------------------------------------------
# lm-eval interface functions
# ---------------------------------------------------------------------------
def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """Score the first model response for *doc* (lm-eval ``process_results`` hook).

    Args:
        doc: Dataset row; the reference answer is read from ``doc["answer"]``
            (a missing key is treated as an empty answer).
        results: Model responses for this document; only ``results[0]`` is
            scored.

    Returns:
        ``{"exact_match": 1}`` if the first response is judged correct,
        otherwise ``{"exact_match": 0}``. An empty *results* list scores 0,
        the same as ``process_results_avg``, instead of raising ``IndexError``.
    """
    if not results:
        return {"exact_match": 0}

    # Delegate to the shared helper so the single-response and avg@k paths
    # cannot drift apart in how they extract and score an answer.
    return {"exact_match": _score_single(results[0], doc)}


def _score_single(model_output: str, doc: dict) -> int:
    """Extract an answer from one model output and score it against *doc*.

    The last ``\\boxed{...}`` is preferred; the last number in the output is
    used only when no boxed expression exists at all (an empty box is kept
    as-is rather than replaced). The reference comes from ``doc["answer"]``,
    defaulting to an empty string.

    Used by process_results_avg and by compute_pass_at_k.py.
    """
    boxed = extract_boxed_answer(model_output)
    candidate = boxed if boxed is not None else extract_last_number(model_output)
    return _score(candidate, doc.get("answer", ""))


def process_results_avg(doc: dict, results: List) -> Dict[str, float]:
    """Score every response in *results* and return the mean as avg@k.

    Each entry may be a response string or a list wrapping one; for a list,
    only its first element is scored. An empty *results* yields
    ``{"exact_match": 0}``.
    """
    if not results:
        return {"exact_match": 0}

    correct = 0
    scored = 0
    for entry in results:
        response = entry[0] if isinstance(entry, list) else entry
        correct += _score_single(response, doc)
        scored += 1
    return {"exact_match": correct / scored}
