"""
OlympiadBench specific reward function.
Handles complex answer types: Numerical, Expression, Tuple, Interval.
Uses math-verify for robust mathematical equivalence checking when available.
"""

import re
from typing import Optional


def last_boxed_only_string(string: str) -> Optional[str]:
    """Return the last boxed LaTeX expression found in ``string``.

    Two markers are tried in order:
    - \\boxed{content}
    - \\fbox{content}

    For each marker the right-most occurrence is located and braces are
    balanced from there. The returned substring includes the command and
    its closing brace. If the right-most ``\\boxed{`` is never closed, the
    ``\\fbox`` marker is tried next. Returns None when neither marker
    yields a balanced expression.
    """
    for marker in ("\\boxed{", "\\fbox"):
        start = string.rfind(marker)
        if start < 0:
            continue

        # Walk forward from the marker, tracking brace nesting depth; the
        # expression ends at the brace that brings the depth back to zero.
        depth = 0
        for pos in range(start, len(string)):
            ch = string[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return string[start:pos + 1]

    return None


def remove_boxed(s: str) -> str:
    """Strip a surrounding \\boxed{...} or \\fbox{...} wrapper from ``s``.

    The wrapper is removed only when ``s`` starts with the command plus an
    opening brace and ends with a closing brace; otherwise ``s`` is
    returned unchanged.
    """
    has_closing_brace = s.endswith("}")
    for opener in ("\\boxed{", "\\fbox{"):
        if has_closing_brace and s.startswith(opener):
            return s[len(opener):-1]
    return s


def normalize_latex(s: str) -> str:
    """Normalize a LaTeX string so that equivalent answers compare equal.

    The following transformations are applied, in order:
    - spacing commands (thin/medium/thick spaces, ``\\quad``, ``\\qquad``)
      are removed;
    - ``\\dfrac`` and ``\\tfrac`` become ``\\frac``;
    - the delimiter-sizing commands ``\\left`` and ``\\right`` are removed;
    - ``\\text{}``, ``\\textbf{}`` and ``\\mathrm{}`` wrappers are dropped
      while their (brace-free) content is kept;
    - degree notation is rewritten to the ``°`` character;
    - dollar signs and all whitespace are removed.

    Args:
        s: The string to normalize. ``None`` is accepted and any other
            value is converted with ``str()``.

    Returns:
        The normalized string; ``""`` when ``s`` is ``None``.
    """
    if s is None:
        return ""

    s = str(s)

    # Remove common LaTeX spacing commands.
    for spacing in ("\\!", "\\,", "\\;", "\\:", "\\ ", "\\quad", "\\qquad"):
        s = s.replace(spacing, "")

    # Normalize fraction variants.
    s = s.replace("\\dfrac", "\\frac")
    s = s.replace("\\tfrac", "\\frac")

    # Remove \left and \right, but only as standalone commands. A plain
    # substring replace would also mangle \leftarrow, \rightarrow,
    # \leftrightarrow, ... (turning them into "arrow" etc.).
    s = re.sub(r"\\(?:left|right)(?![A-Za-z])", "", s)

    # Remove text-style wrappers but keep their content.
    s = re.sub(r"\\text\{([^}]*)\}", r"\1", s)
    s = re.sub(r"\\textbf\{([^}]*)\}", r"\1", s)
    s = re.sub(r"\\mathrm\{([^}]*)\}", r"\1", s)

    # Normalize degrees (most specific spelling first).
    s = s.replace("^{\\circ}", "°")
    s = s.replace("^\\circ", "°")
    s = s.replace("\\circ", "°")

    # Remove dollar signs.
    s = s.replace("$", "")

    # Remove all whitespace; nothing is left to strip afterwards.
    return re.sub(r'\s+', '', s)


def normalize_number(s: str) -> str:
    """Return a canonical string for a numeric literal.

    Examples: ``"42.0" -> "42"``, ``"1.50" -> "1.5"``, ``"1,000" -> "1000"``.

    Spaces are ignored. Commas are removed only when they form proper
    thousands grouping (``1,000`` or ``12,345,678.5``); a string such as
    ``"69,84"`` is a list of answers, not the number 6984, and is returned
    unchanged. Integer literals are parsed exactly with ``int`` so that
    large values are not rounded through a float.

    Args:
        s: The candidate numeric string.

    Returns:
        The canonical numeric string, or ``s`` unchanged when it cannot be
        parsed as a finite number.
    """
    cleaned = s.replace(" ", "")

    # Only strip commas that are genuine thousands separators.
    if re.fullmatch(r"[+-]?\d{1,3}(?:,\d{3})+(?:\.\d*)?", cleaned):
        cleaned = cleaned.replace(",", "")

    # Exact path for integers: float() would silently round anything
    # beyond 2**53 and make distinct answers compare equal.
    try:
        return str(int(cleaned))
    except ValueError:
        pass

    try:
        val = float(cleaned)
        # int() raises OverflowError for inf and ValueError for nan; both
        # are handled below by returning the input unchanged.
        if val == int(val):
            return str(int(val))
        return str(val)
    except (ValueError, OverflowError):
        return s


def extract_answer(solution_str: str) -> Optional[str]:
    """Extract the final answer from a solution string.

    Extraction methods, in order of priority:
    1. The last boxed expression (most reliable).
    2. "the (final) answer is X"
    3. "Answer: X" / "answer = X"
    4. "Therefore, X" / "Thus, X" / "Hence, X"
    5. "= X"
    6. "So X"

    For the textual patterns the last occurrence in the solution is used.
    Candidates that are empty or 500+ characters long are rejected and the
    next pattern is tried.

    Args:
        solution_str: The full model solution.

    Returns:
        The extracted answer text, or None when no method produced one.
    """
    # Method 1: boxed answer.
    boxed = last_boxed_only_string(solution_str)
    if boxed is not None:
        return remove_boxed(boxed)

    # An answer ends at a newline, at end of input, or at a sentence-ending
    # period. A period followed by a digit is a decimal point, so it must
    # not terminate the answer ("the answer is 3.14" -> "3.14", not "3").
    answer_end = r"(?:\.(?!\d)|$|\n)"

    # Method 2: textual patterns. The leading \b keeps keywords from
    # matching inside longer words (e.g. "so" inside "also").
    answer_patterns = [
        # "the answer is X" or "the final answer is X"
        r"(?i)\b(?:the\s+)?(?:final\s+)?answer\s+is[:\s]+(.+?)" + answer_end,
        # "Answer: X"
        r"(?i)\banswer\s*[:=]\s*(.+?)" + answer_end,
        # "Therefore, X" or "Thus, X" or "Hence, X"
        r"(?i)\b(?:therefore|thus|hence)[,\s]+(?:the\s+answer\s+is\s+)?(.+?)" + answer_end,
        # "= X"
        r"=\s*([^\n=]+?)" + answer_end,
        # "So X"
        r"(?i)\bso\s+(?:the\s+answer\s+is\s+)?(.+?)" + answer_end,
    ]

    for pattern in answer_patterns:
        matches = re.findall(pattern, solution_str)
        if not matches:
            continue
        # Take the last match. Only trailing periods are punctuation; a
        # leading period can be a decimal point (".5") and is kept.
        answer = matches[-1].strip().rstrip(".").strip()
        if answer and len(answer) < 500:  # Sanity check
            return answer

    return None


def compare_tuples(pred: str, gt: str) -> bool:
    """Compare tuple answers, ignoring the order in which tuples are listed.

    Both strings are LaTeX-normalized first and accepted immediately if
    they are identical. Otherwise every innermost parenthesized group is
    collected from each side, split on commas, and each component is
    number-normalized. The answers match only when both sides contain at
    least one tuple and the two sets of tuples are exactly equal: extra or
    missing tuples in the prediction both count as a mismatch.
    """
    pred_clean = normalize_latex(pred)
    gt_clean = normalize_latex(gt)
    if pred_clean == gt_clean:
        return True

    def collect_tuples(text):
        # One normalized tuple of components per "(...)" group in the text.
        return {
            tuple(normalize_number(item.strip()) for item in body.split(','))
            for body in re.findall(r'\(([^()]+)\)', text)
        }

    pred_set = collect_tuples(pred_clean)
    gt_set = collect_tuples(gt_clean)

    # Without tuples on both sides there is nothing to compare.
    if not pred_set or not gt_set:
        return False
    return pred_set == gt_set


def compare_expressions(pred: str, gt: str) -> bool:
    """Compare two mathematical expressions.

    The expressions match when their LaTeX-normalized forms are identical
    (which already covers spacing differences such as "f(x)=2x" versus
    "f(x) = 2x", because normalization removes all whitespace), or when
    both normalize to the same numeric string.

    Args:
        pred: Predicted expression.
        gt: Ground-truth expression.

    Returns:
        True when the expressions are considered equivalent.
    """
    pred_norm = normalize_latex(pred)
    gt_norm = normalize_latex(gt)

    # Direct match after normalization.
    if pred_norm == gt_norm:
        return True

    # Numeric comparison, e.g. "42.0" versus "42". A further space-stripped
    # comparison used to follow here, but it could never succeed: the
    # normalized strings contain no whitespace, so it repeated the direct
    # match above.
    return normalize_number(pred_norm) == normalize_number(gt_norm)


def compare_intervals(pred: str, gt: str) -> bool:
    """Compare two interval answers.

    Both strings are LaTeX-normalized and stripped of trailing periods. If
    they still differ, LaTeX set/infinity commands are mapped to their
    Unicode symbols so that mixed notation compares equal, and an explicit
    plus sign on infinity is dropped, so "(0,+\\infty)" matches
    "(0,\\infty)".

    Args:
        pred: Predicted interval.
        gt: Ground-truth interval.

    Returns:
        True when the intervals are written equivalently.
    """
    pred_norm = normalize_latex(pred).rstrip(".")
    gt_norm = normalize_latex(gt).rstrip(".")

    # Direct match.
    if pred_norm == gt_norm:
        return True

    def canonical(text):
        # Replacing "\infty" also covers "-\infty" and "+\infty"; separate
        # entries for those would never fire once "\infty" is rewritten.
        for latex, symbol in (("\\cup", "∪"), ("\\cap", "∩"), ("\\infty", "∞")):
            text = text.replace(latex, symbol)
        # "+∞" and "∞" denote the same endpoint.
        return text.replace("+∞", "∞")

    return canonical(pred_norm) == canonical(gt_norm)


def is_equiv(pred: Optional[str], gt: Optional[str], answer_type: str = "Numerical") -> bool:
    """Decide whether a prediction matches the ground truth.

    Args:
        pred: Predicted answer; None never matches.
        gt: Ground truth answer; None never matches.
        answer_type: One of "Tuple", "Interval", "Expression"; any other
            value is handled as "Numerical".

    Returns:
        True when the LaTeX-normalized strings are identical or the
        type-specific comparison accepts the pair.
    """
    if pred is None or gt is None:
        return False

    pred_norm = normalize_latex(pred)
    gt_norm = normalize_latex(gt)
    if pred_norm == gt_norm:
        return True

    # Structured answer types delegate entirely to their own comparer,
    # which receives the raw (un-normalized) strings.
    typed_comparers = (
        ("Tuple", compare_tuples),
        ("Interval", compare_intervals),
        ("Expression", compare_expressions),
    )
    for label, comparer in typed_comparers:
        if answer_type == label:
            return comparer(pred, gt)

    # Numerical: compare canonical numeric strings.
    pred_num = normalize_number(pred_norm)
    gt_num = normalize_number(gt_norm)
    if pred_num == gt_num:
        return True

    # "k=1" on one side versus a bare "1" on the other: compare against the
    # text after the last "=".
    if "=" in gt_norm:
        gt_rhs = gt_norm.rpartition("=")[2].strip()
        if pred_num == normalize_number(gt_rhs):
            return True
    if "=" in pred_norm:
        pred_rhs = pred_norm.rpartition("=")[2].strip()
        if normalize_number(pred_rhs) == gt_num:
            return True

    # Several numerical answers such as "69,84": order does not matter.
    if "," in gt_norm:
        expected = {normalize_number(piece.strip()) for piece in gt_norm.split(",")}
        given = {normalize_number(piece.strip()) for piece in pred_norm.split(",")}
        return expected == given

    return False


# math-verify is an optional dependency. Record whether it could be imported
# so the scorer can skip it when it is missing; `math_verify_check` is only
# defined when the import succeeds.
_math_verify_warning_printed = False
try:
    from math_verify import verify as math_verify_check
except ImportError:
    _math_verify_available = False
else:
    _math_verify_available = True


def _print_math_verify_warning():
    """Print the math-verify installation hint at most once per process.

    Nothing is printed when math-verify was imported successfully or when
    the hint has already been shown.
    """
    global _math_verify_warning_printed
    if _math_verify_available or _math_verify_warning_printed:
        return
    print("To use Math-Verify, please install it first by running `pip install math-verify`.")
    _math_verify_warning_printed = True


def compute_score(solution_str: str, ground_truth: str, extra_info: Optional[dict] = None) -> dict:
    """Compute the reward score for an OlympiadBench solution.

    Args:
        solution_str: The model's solution string.
        ground_truth: The ground truth answer.
        extra_info: Optional dict containing 'answer_type'
            (Numerical, Expression, Tuple, Interval). Defaults to
            "Numerical" when absent.

    Returns:
        dict with numeric fields only, so they can be averaged in metrics:
            - score: 1.0 for correct, 0.0 for incorrect or unextracted
            - acc: 1.0 for correct, 0.0 otherwise
    """
    # Print warning about math-verify (only once).
    _print_math_verify_warning()

    # Get answer type from extra_info.
    answer_type = "Numerical"
    if extra_info and isinstance(extra_info, dict):
        answer_type = extra_info.get("answer_type", "Numerical")

    pred = extract_answer(solution_str)

    # If no answer was found, look for a trailing "= expression" in the
    # last 500 characters of the solution.
    if pred is None:
        last_part = solution_str[-500:]
        match = re.search(r'=\s*([^\n=]+?)(?:\s*$|\s*\.?\s*$)', last_part)
        if match:
            pred = match.group(1).strip()

    correct = False
    if pred is not None:
        # First try our custom comparison.
        correct = is_equiv(pred, ground_truth, answer_type)

        # If not correct and math-verify is available, try that too.
        if not correct and _math_verify_available:
            try:
                # verify() only compares symbolically when given parsed
                # expressions, and its first argument is the gold answer.
                # Raw strings would just be compared as strings, which
                # is_equiv has already done. Wrapping in \boxed{} gives the
                # parser an explicit LaTeX anchor.
                # NOTE(review): confirm against the installed math-verify
                # version.
                from math_verify import parse as math_verify_parse

                gold = math_verify_parse("\\boxed{" + str(ground_truth) + "}")
                answer = math_verify_parse("\\boxed{" + str(pred) + "}")
                correct = bool(math_verify_check(gold, answer))
            except Exception:
                # Best-effort fallback: any parser/verifier failure leaves
                # the result from the custom comparison in place.
                pass

    # Return [0, 1] reward range - avoid negative rewards for unextracted answers.
    # This prevents punishing the model when it hasn't learned the format yet.
    reward = 1.0 if correct else 0.0

    # Only return numeric fields that can be safely averaged in metrics;
    # pred and answer_type are non-numeric and would cause issues in np.mean().
    return {
        "score": reward,
        "acc": 1.0 if correct else 0.0,  # Use numeric values for metrics
    }
