﻿#!/usr/bin/env python3
"""
Geometry answer extraction and scoring helpers for reward computation.

Derived from logic in evaluate_instantiation.py and latex_parser.py.
"""

import re
from typing import Any, Dict, Optional

import sympy as sp
from sympy.parsing.sympy_parser import (
    implicit_multiplication_application,
    parse_expr,
    standard_transformations,
)

# =========================
# LaTeX parsing helpers (adapted from latex_parser.py)
# =========================

# SymPy parsing: allow implicit multiplication (e.g., 3sqrt(2) -> 3*sqrt(2)).
# Passed as ``transformations`` to ``parse_expr`` in ``latex_to_float``.
TRANS = standard_transformations + (implicit_multiplication_application,)

# Regex patterns
#
# The brace-delimited patterns below capture their arguments with ``[^{}]+``,
# so each one only matches when the argument contains no nested braces.
# ``_normalize_latex_expr`` re-applies them in a loop until the string stops
# changing.

# \sqrt{x}: group 1 is the radicand.
INNER_BRACE_SQRT = re.compile(r"\\sqrt\{([^{}]+)\}")
# \sqrt followed (after optional whitespace) by a run of letters, digits or
# dots, e.g. \sqrt2: group 1 is that run.
SIMPLE_SQRT = re.compile(r"\\sqrt\s*([0-9A-Za-z\.]+)")
# \frac{a}{b}, \dfrac{a}{b} or \tfrac{a}{b}: groups 1 and 2 are the numerator
# and the denominator.
INNER_FRAC = re.compile(r"\\[dt]?frac\{([^{}]+)\}\{([^{}]+)\}")
# \arctan{x}, \arccos{x}, \arcsin{x}: group 1 is the argument.
INNER_ARCTAN = re.compile(r"\\arctan\{([^{}]+)\}")
INNER_ARCCOS = re.compile(r"\\arccos\{([^{}]+)\}")
INNER_ARCSIN = re.compile(r"\\arcsin\{([^{}]+)\}")
# \sqrt[n]{x} with a digits-only index: group 1 is n, group 2 is x.
NTH_ROOT = re.compile(r"\\sqrt\[(\d+)\]\{([^{}]+)\}")

# Trig function applied to a numeric literal followed by a degree marker,
# e.g. \sin 10^\circ or \sin 10^{\circ}: group 1 is the number.
SIN_DEG = re.compile(r"\\sin\s*([0-9\.]+)\^\{?\\circ\}?")
COS_DEG = re.compile(r"\\cos\s*([0-9\.]+)\^\{?\\circ\}?")
TAN_DEG = re.compile(r"\\tan\s*([0-9\.]+)\^\{?\\circ\}?")

# Trig function with a parenthesised argument that contains no nested
# parentheses, e.g. \sin(x): group 1 is the argument.
SIN_PAREN = re.compile(r"\\sin\(([^()]+)\)")
COS_PAREN = re.compile(r"\\cos\(([^()]+)\)")
TAN_PAREN = re.compile(r"\\tan\(([^()]+)\)")

# Trig function followed (after optional whitespace) by a run of letters,
# digits or dots, e.g. \sin x: group 1 is that run.
SIN_SIMPLE = re.compile(r"\\sin\s*([0-9A-Za-z\.]+)")
COS_SIMPLE = re.compile(r"\\cos\s*([0-9A-Za-z\.]+)")
TAN_SIMPLE = re.compile(r"\\tan\s*([0-9A-Za-z\.]+)")


def strip_latex_noise(s: str) -> str:
    r"""Delete presentational LaTeX commands that carry no mathematical value.

    The sizing commands ``\left``, ``\right``, ``\bigl``, ``\bigr``, ``\Bigl``
    and ``\Bigr`` and the spacing commands ``\,``, ``\!``, ``\;``, ``\:``,
    ``\quad`` and ``\qquad`` are removed one after another, in that order, as
    plain substrings.

    Args:
        s: LaTeX source text.

    Returns:
        ``s`` with every occurrence of the commands above removed.
    """
    noise_tokens = (
        r"\left",
        r"\right",
        r"\bigl",
        r"\bigr",
        r"\Bigl",
        r"\Bigr",
        r"\,",
        r"\!",
        r"\;",
        r"\:",
        r"\quad",
        r"\qquad",
    )
    cleaned = s
    for token in noise_tokens:
        # Splitting on the token and re-joining drops every occurrence of it.
        cleaned = "".join(cleaned.split(token))
    return cleaned


def handle_trig_degrees(s: str) -> str:
    r"""Rewrite degree-valued trig calls such as ``\sin 10^\circ`` in radians.

    Each match of ``SIN_DEG``, ``COS_DEG`` and ``TAN_DEG`` is replaced by the
    corresponding plain call on ``pi*<number>/180``, e.g. ``sin(pi*10/180)``.

    Args:
        s: LaTeX source text.

    Returns:
        ``s`` with the matched degree expressions rewritten.
    """
    degree_rules = (
        (SIN_DEG, r"sin(pi*\1/180)"),
        (COS_DEG, r"cos(pi*\1/180)"),
        (TAN_DEG, r"tan(pi*\1/180)"),
    )
    converted = s
    for pattern, template in degree_rules:
        converted = pattern.sub(template, converted)
    return converted


def handle_basic_trig(s: str) -> str:
    r"""Turn ``\sin(x)`` / ``\sin x`` (and cos, tan) into ``sin(x)``.

    The parenthesised forms of all three functions are rewritten first, then
    the bare-argument forms. The rules run sequentially on the evolving
    string, so their order is part of the behaviour.

    Args:
        s: LaTeX source text.

    Returns:
        ``s`` with the matched trig commands rewritten as plain calls.
    """
    ordered_rules = (
        (SIN_PAREN, r"sin(\1)"),
        (COS_PAREN, r"cos(\1)"),
        (TAN_PAREN, r"tan(\1)"),
        (SIN_SIMPLE, r"sin(\1)"),
        (COS_SIMPLE, r"cos(\1)"),
        (TAN_SIMPLE, r"tan(\1)"),
    )
    rewritten = s
    for pattern, template in ordered_rules:
        rewritten = pattern.sub(template, rewritten)
    return rewritten


def _skip_braced_group(s: str, open_idx: int) -> int:
    """Return the index just past the ``}`` closing the ``{`` at ``open_idx``.

    Returns -1 when ``open_idx`` is out of range, does not point at ``{``, or
    the group is never closed.
    """
    if open_idx >= len(s) or s[open_idx] != "{":
        return -1
    depth = 0
    for pos in range(open_idx, len(s)):
        ch = s[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return pos + 1
    return -1


def handle_arc_trig_frac(s: str) -> str:
    r"""Rewrite an inverse trig function applied directly to a fraction.

    ``\arccos\frac{a}{b}``, ``\arcsin\frac{a}{b}`` and ``\arctan\frac{a}{b}``
    become ``acos((a)/(b))``, ``asin((a)/(b))`` and ``atan((a)/(b))``. The
    ``\dfrac`` and ``\tfrac`` spellings are accepted as well; if they were left
    alone, a later generic fraction rewrite would turn
    ``\arccos\dfrac{1}{2}`` into ``\arccos(1)/(2)``, i.e. ``acos(1)/2``.

    Numerator and denominator are located by brace matching, so they may
    themselves contain nested braces (``\sqrt{3}``, another ``\frac`` ...).
    Their text is copied through unchanged.

    Args:
        s: LaTeX source text.

    Returns:
        ``s`` with every well-formed occurrence rewritten. When a malformed
        occurrence is met (unbalanced braces or a missing denominator group),
        rewriting stops for that function name and the rest is left as is.
    """
    for name, func in (("arccos", "acos"), ("arcsin", "asin"), ("arctan", "atan")):
        # The lookahead keeps the opening brace out of the match so that
        # ``found.end()`` points exactly at the numerator's ``{``.
        opener = re.compile(r"\\" + name + r"\\[dt]?frac(?=\{)")
        while True:
            found = opener.search(s)
            if found is None:
                break

            num_open = found.end()
            den_open = _skip_braced_group(s, num_open)
            if den_open == -1:
                break
            tail = _skip_braced_group(s, den_open)
            if tail == -1:
                break

            numerator = s[num_open + 1 : den_open - 1]
            denominator = s[den_open + 1 : tail - 1]
            s = f"{s[:found.start()]}{func}(({numerator})/({denominator})){s[tail:]}"
    return s


def handle_arc_trig_paren(s: str) -> str:
    r"""Rewrite ``\arccos(...)``, ``\arcsin(...)`` and ``\arctan(...)``.

    The commands become ``acos(...)``, ``asin(...)`` and ``atan(...)``. The
    argument is delimited by parenthesis matching, so it may contain nested
    parentheses; its text is copied through unchanged.

    Args:
        s: Expression text.

    Returns:
        ``s`` with every well-formed occurrence rewritten. If an occurrence
        has no matching closing parenthesis, rewriting stops for that
        function name and the remainder is left untouched.
    """
    for latex_name, sympy_name in (
        ("arccos", "acos"),
        ("arcsin", "asin"),
        ("arctan", "atan"),
    ):
        opener = "\\" + latex_name + "("
        begin = s.find(opener)
        while begin != -1:
            body_start = begin + len(opener)
            level = 1
            pos = body_start
            # Walk forward until the parenthesis opened by ``opener`` closes.
            while level > 0 and pos < len(s):
                ch = s[pos]
                if ch == "(":
                    level += 1
                elif ch == ")":
                    level -= 1
                pos += 1
            if level != 0:
                break

            body = s[body_start : pos - 1]
            s = "".join((s[:begin], sympy_name, "(", body, ")", s[pos:]))
            begin = s.find(opener)
    return s


def _normalize_latex_expr(s: str) -> str:
    r"""Normalize a LaTeX math expression into a SymPy-parsable string.

    Steps, in order:

    1. Remove all spaces.
    2. Rewrite degree-valued trig calls, then drop remaining ``^\circ`` marks.
    3. Replace ``\pi`` with ``pi`` and ``\cdot`` / ``\times`` with ``*``.
    4. Strip presentational commands and rewrite plain trig calls.
    5. Rewrite inverse trig applied directly to a fraction.
    6. Repeatedly rewrite innermost roots, fractions and braced inverse trig
       calls until the string stops changing (this handles nesting from the
       inside out).
    7. Rewrite parenthesised inverse trig calls and turn ``^`` into ``**``.

    Args:
        s: LaTeX (or plain text) expression.

    Returns:
        The rewritten expression text. Nothing is validated here; anything
        not recognised is passed through unchanged.
    """
    s = s.replace(" ", "")

    # Degree-valued trig must be handled before the bare degree marker is
    # dropped, otherwise ``\sin 30^\circ`` would be read as 30 radians.
    s = handle_trig_degrees(s)
    s = re.sub(r"\^\{?\\circ\}?", "", s)

    # The original code used r"\\pi" / r"\\cdot", which are raw strings with
    # TWO backslashes and therefore never matched ordinary ``\pi`` / ``\cdot``.
    # The doubled spelling is still accepted (and tried first, so that no
    # stray backslash is left behind) to stay backward-compatible.
    for latex, plain in (
        (r"\\pi", "pi"),
        (r"\pi", "pi"),
        (r"\\cdot", "*"),
        (r"\cdot", "*"),
        (r"\times", "*"),
    ):
        s = s.replace(latex, plain)

    s = strip_latex_noise(s)
    s = handle_basic_trig(s)
    s = handle_arc_trig_frac(s)

    # Each pattern only matches brace-free arguments, so every pass resolves
    # the innermost constructs; iterate until a fixed point is reached.
    while True:
        old = s
        s = NTH_ROOT.sub(r"(\2)**(1/(\1))", s)
        s = INNER_BRACE_SQRT.sub(r"sqrt(\1)", s)
        s = SIMPLE_SQRT.sub(r"sqrt(\1)", s)
        s = INNER_FRAC.sub(r"(\1)/(\2)", s)
        s = INNER_ARCTAN.sub(r"atan(\1)", s)
        s = INNER_ARCCOS.sub(r"acos(\1)", s)
        s = INNER_ARCSIN.sub(r"asin(\1)", s)

        if s == old:
            break

    s = handle_arc_trig_paren(s)
    s = s.replace("^", "**")

    return s


def _strip_math_delimiters(s: str) -> str:
    r"""Remove enclosing ``$...$``, ``$$...$$``, ``\(...\)`` or ``\[...\]``.

    Wrappers are peeled repeatedly, so nested wrappers are removed too. The
    text is only touched when it both starts and ends with a matching pair.
    """
    wrappers = (("$$", "$$"), ("$", "$"), (r"\(", r"\)"), (r"\[", r"\]"))
    peeled = True
    while peeled:
        peeled = False
        for left, right in wrappers:
            if (
                len(s) >= len(left) + len(right)
                and s.startswith(left)
                and s.endswith(right)
            ):
                s = s[len(left) : len(s) - len(right)].strip()
                peeled = True
                break
    return s


def latex_to_float(answer: Optional[str]) -> Optional[float]:
    r"""
    Convert a simple/medium-complexity LaTeX or text expression into a float.

    Enclosing math delimiters (``$...$``, ``$$...$$``, ``\(...\)``,
    ``\[...\]``) are removed, the expression is normalized with
    ``_normalize_latex_expr`` and then parsed and evaluated with SymPy.

    Args:
        answer: Expression text; ``None`` or an empty string is accepted.

    Returns:
        The numeric value, or None if the input is empty, still contains free
        symbols after parsing, or cannot be parsed/evaluated.
    """
    if not answer:
        return None

    s = _strip_math_delimiters(answer.strip())
    if not s:
        return None

    s = _normalize_latex_expr(s)

    try:
        # SECURITY NOTE(review): SymPy's parse_expr evaluates the generated
        # code with eval(). ``answer`` may originate from model output, so
        # this should only run where arbitrary code execution is acceptable
        # (e.g. a sandboxed reward worker). Deliberately left in place.
        expr = parse_expr(
            s,
            transformations=TRANS,
            local_dict={
                "pi": sp.pi,
                "sqrt": sp.sqrt,
                "e": sp.E,
                "E": sp.E,
                "atan": sp.atan,
                "acos": sp.acos,
                "asin": sp.asin,
                "sin": sp.sin,
                "cos": sp.cos,
                "tan": sp.tan,
            },
        )

        # A leftover symbol means the text is not a pure number.
        if expr.free_symbols:
            return None

        return float(expr.evalf())
    except Exception:
        # Parsing is best-effort: whatever SymPy raised, fall back to Python's
        # own float() parsing before giving up.
        try:
            return float(s)
        except (TypeError, ValueError):
            return None


# =========================
# Answer extraction and scoring logic (adapted from evaluate_instantiation.py)
# =========================


def _balanced_boxed_contents(text: str) -> list:
    r"""Return the contents of every brace-balanced ``\boxed{...}`` in order.

    Braces inside the box are matched, so ``\boxed{\frac{1}{2}}`` yields
    ``\frac{1}{2}``. Empty boxes and boxes that are never closed are skipped.
    """
    opener = "\\boxed{"
    contents = []
    cursor = text.find(opener)
    while cursor != -1:
        body_start = cursor + len(opener)
        depth = 1
        pos = body_start
        while pos < len(text) and depth > 0:
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
            pos += 1

        if depth == 0:
            if pos - 1 > body_start:
                contents.append(text[body_start : pos - 1])
            resume_at = pos
        else:
            # Unclosed box: keep looking for later, well-formed ones.
            resume_at = body_start
        cursor = text.find(opener, resume_at)
    return contents


def extract_answer_from_reasoning(text: str) -> Optional[str]:
    r"""Extract the final answer from model output across common formats.

    Lookup order:

    1. The content of the last ``<answer>...</answer>`` pair (tag names are
       matched case-insensitively, content may span lines).
    2. Otherwise the content of the last ``\boxed{...}``. Braces are matched,
       so nested LaTeX such as ``\boxed{\frac{1}{2}}`` is returned whole
       instead of being cut at the first ``}``.

    Args:
        text: Full model output.

    Returns:
        The extracted answer with surrounding whitespace removed, or None if
        neither format is present.
    """
    # 1) Prefer <answer>...</answer> tag content.
    tagged = re.findall(r"<answer>(.*?)</answer>", text, re.DOTALL | re.IGNORECASE)
    if tagged:
        return tagged[-1].strip()

    # 2) Fall back to \boxed{...}.
    boxed = _balanced_boxed_contents(text)
    if boxed:
        return boxed[-1].strip()

    return None


def normalize_answer(answer: str) -> str:
    """Normalize an answer string for robust comparison.

    Surrounding whitespace, a leading "Answer" / "Answer:" prefix and a
    leading multiple-choice label are removed. A choice label is a single
    letter A-J (either case) that is separated from the content by a colon
    and/or whitespace, e.g. "C: 145 deg", "C:145" or "C 145".

    The separator is mandatory. Without it every answer that merely starts
    with one of those letters would lose its first character ("AC" and "BC"
    would both collapse to "C" and compare equal).

    Args:
        answer: Raw answer text.

    Returns:
        The normalized answer text.
    """
    answer = answer.strip()
    # Remove common prefixes like "Answer:".
    answer = re.sub(r"^(Answer|answer):?\s*", "", answer, flags=re.IGNORECASE)

    # Handle "Choice: content" format, e.g., "C: 145 deg".
    match = re.match(r"^[A-J](?::\s*|\s+)(.+)$", answer, re.IGNORECASE)
    if match:
        answer = match.group(1).strip()

    return answer.strip()


# Unit spellings that are dropped before numeric comparison.
_UNIT_WORDS = r"(?:deg(?:rees?)?|cm|mm|m)"

# A unit token is either wrapped in \text{...} / \mathrm{...} or stands on its
# own (not glued to other letters and not part of a LaTeX command name). An
# exponent attached to the unit (cm^2, m^{3}, cm²) is removed together with it
# so that it cannot end up applied to the number.
_UNIT_TOKEN = re.compile(
    r"(?:"
    r"\\(?:text|mathrm)\s*\{\s*" + _UNIT_WORDS + r"\s*\}"
    r"|(?<![A-Za-z\\])" + _UNIT_WORDS + r"(?![A-Za-z])"
    r")"
    r"(?:\s*\^\s*\{?\s*\d+\s*\}?|[\u00b2\u00b3])?"
)


def _strip_unit_tokens(text: str) -> str:
    """Remove degree signs and stand-alone deg/cm/mm/m unit tokens."""
    return _UNIT_TOKEN.sub("", text.replace("\u00b0", "")).strip()


def compare_answers(
    pred: str,
    gold: str,
    tolerance: float = 1e-6,
) -> bool:
    """
    Compare predicted and gold answers.

    Both answers are normalized with ``normalize_answer``. They match when
    they are equal as case-insensitive strings, equal after removing spaces,
    or when both evaluate (after unit symbols are stripped) to numbers that
    differ by less than ``tolerance``.

    Units are removed as whole tokens only. Letters inside other words or
    LaTeX commands are left alone, and an exponent attached to a unit
    ("5 cm^2") is dropped with the unit instead of squaring the value.

    Args:
        pred: Predicted answer text.
        gold: Reference answer text.
        tolerance: Maximum absolute difference for a numeric match.

    Returns:
        True if the answers are considered equal, otherwise False.
    """
    pred = normalize_answer(pred)
    gold = normalize_answer(gold)

    # Exact string match first (answers may include LaTeX).
    if pred.lower().strip() == gold.lower().strip():
        return True

    # Compare without spaces.
    pred_clean = pred.replace(" ", "").lower().strip()
    gold_clean = gold.replace(" ", "").lower().strip()
    if pred_clean == gold_clean:
        return True

    # Try numeric comparison after removing common unit symbols.
    pred_float = latex_to_float(_strip_unit_tokens(pred))
    gold_float = latex_to_float(_strip_unit_tokens(gold))

    if pred_float is None or gold_float is None:
        return False

    return abs(pred_float - gold_float) < tolerance


def format_reward(solution_str: str) -> float:
    """
    Reward outputs that contain the three required tag pairs:
    <code>...</code>, <think>...</think> and <answer>...</answer>.

    Only the presence of each opening and closing tag is tested; their order
    and nesting are not checked.

    Args:
        solution_str: Full model output string.
    Returns:
        0.1 if all six tags occur in the output, otherwise 0.0.
    """
    required_tags = (
        "<code>",
        "</code>",
        "<think>",
        "</think>",
        "<answer>",
        "</answer>",
    )
    every_tag_present = all(tag in solution_str for tag in required_tags)
    return 0.1 if every_tag_present else 0.0


def compute_score(
    solution_str: str,
    ground_truth: str,
    extra_info: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> float:
    """
    Score a model response to a geometry problem.

    The score is the sum of two parts:

    - format part: 0.1 when ``format_reward`` accepts the output, else 0.0
    - answer part: 1.0 when the extracted answer matches the gold answer
      according to ``compare_answers``, else 0.0 (also 0.0 when no answer
      can be extracted)

    Args:
        solution_str: Full model output string.
        ground_truth: Gold answer; it is converted with ``str()`` before
            comparison.
        extra_info: Optional metadata (unused, kept for compatibility).
        **kwargs: Extra arguments for interface compatibility (unused).

    Returns:
        1.1 if format and answer are both correct, 1.0 if only the answer is
        correct, 0.1 if only the format is correct, otherwise 0.0.
    """
    format_part = format_reward(solution_str)

    prediction = extract_answer_from_reasoning(solution_str)
    answer_part = 0.0
    if prediction is not None and compare_answers(prediction, str(ground_truth)):
        answer_part = 1.0

    return format_part + answer_part
