"""
MATH-500 specific reward function.
Uses a [0, 1] reward range (correct=1.0, incorrect or unextracted=0.0).
"""

import re
from typing import Optional


def _balanced_brace_span(string: str, start: int) -> Optional[str]:
    """Return ``string[start:end + 1]`` where ``end`` closes the brace group.

    Scans forward from ``start`` counting ``{`` and ``}``; the span ends at
    the first ``}`` that brings the running count back to zero. Returns
    ``None`` when no such closing brace exists.
    """
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]
    return None


def last_boxed_only_string(string: str) -> Optional[str]:
    """Locate a LaTeX boxed expression in ``string``.

    The formats are tried in a fixed order of precedence, not by position:

    1. The last ``\\boxed{...}`` with balanced braces.
    2. The last ``\\boxed content`` (space separated); the content runs up
       to the next ``$`` or the end of the string.
    3. The last ``\\fbox`` followed by a balanced brace group.

    Returns:
        The matched expression including its ``\\boxed`` / ``\\fbox``
        command, or ``None`` when nothing matches.
    """
    brace_start = string.rfind("\\boxed{")
    if brace_start >= 0:
        boxed = _balanced_brace_span(string, brace_start)
        if boxed is not None:
            return boxed

    # Space-separated form; also reached when the last \boxed{ is unbalanced.
    space_marker = "\\boxed "
    if space_marker in string:
        content = string.split(space_marker)[-1].split("$")[0]
        return space_marker + content

    fbox_start = string.rfind("\\fbox")
    if fbox_start >= 0:
        return _balanced_brace_span(string, fbox_start)

    return None


def remove_boxed(s: str) -> str:
    """Strip the surrounding ``\\boxed`` / ``\\fbox`` command from ``s``.

    ``\\boxed content`` loses only its prefix; ``\\boxed{...}`` and
    ``\\fbox{...}`` lose the prefix and the final ``}``. Any other input is
    returned unchanged.
    """
    space_prefix = "\\boxed "
    if s[: len(space_prefix)] == space_prefix:
        return s[len(space_prefix):]

    # Brace forms are only unwrapped when the closing brace is the last char.
    for opener in ("\\boxed{", "\\fbox{"):
        if s.startswith(opener) and s.endswith("}"):
            return s[len(opener):-1]

    return s


def normalize_number(s: str) -> str:
    """Return a canonical spelling for numeric strings.

    Examples: ``"42.0" -> "42"``, ``"1.50" -> "1.5"``, ``"1,000" -> "1000"``.

    Commas are removed only when they form valid thousands grouping
    (``1,234,567`` or ``1,234.5``). A string such as ``"1,2"`` is a list of
    values, not the number 12, and is returned unchanged.

    Integer literals are converted with ``int`` rather than ``float`` so that
    large integers are not rounded to the nearest representable float.

    Args:
        s: Candidate numeric string.

    Returns:
        The normalized number as a string, or ``s`` unchanged when it is not
        a finite number.
    """
    text = s
    if "," in text:
        grouped = re.fullmatch(r"[+-]?\d{1,3}(?:,\d{3})+(?:\.\d*)?", text.strip())
        if grouped is None:
            return s
        text = text.replace(",", "")

    # Exact path for integer literals: no float rounding.
    try:
        return str(int(text))
    except ValueError:
        pass

    try:
        val = float(text)
        # int() raises for inf (OverflowError) and nan (ValueError).
        if val == int(val):
            return str(int(val))
        return str(val)
    except (ValueError, OverflowError):
        return s


def fix_fracs(string: str) -> str:
    """Add the missing braces to shorthand ``\\frac`` arguments.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{2}`` becomes
    ``\\frac{1}{2}``. A ``\\frac`` whose first argument is already braced, or
    that is followed by fewer than two characters, is left as written.
    """
    head, *tails = string.split("\\frac")
    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if len(tail) < 2 or tail[0] == "{":
            # Nothing to rewrite: too short, or numerator already braced.
            pieces.append(tail)
            continue
        numerator, denominator, rest = tail[0], tail[1], tail[2:]
        if denominator != "{":
            denominator = "{" + denominator + "}"
        # When the second char is "{", it opens the denominator group that
        # continues in ``rest``, so it is emitted as-is.
        pieces.append("{" + numerator + "}" + denominator + rest)
    return "".join(pieces)


def fix_a_slash_b(string: str) -> str:
    """Rewrite a plain integer ratio ``a/b`` as ``\\frac{a}{b}``.

    The rewrite happens only when ``string`` contains exactly one ``/`` and
    both sides are integers in canonical form (no spaces, leading zeros or
    ``+`` sign). Anything else is returned unchanged.
    """
    if string.count("/") != 1:
        return string

    numerator, _, denominator = string.partition("/")
    try:
        # Round-tripping through int rejects spellings such as "01" or " 1".
        canonical = (
            numerator == str(int(numerator))
            and denominator == str(int(denominator))
        )
    except ValueError:
        return string

    if canonical:
        return "\\frac{" + numerator + "}{" + denominator + "}"
    return string


def fix_sqrt(string: str) -> str:
    """Brace the argument of shorthand ``\\sqrt`` commands.

    ``\\sqrt2`` becomes ``\\sqrt{2}``. An optional root index in square
    brackets is kept in place, so ``\\sqrt[3]{2}`` is unchanged and
    ``\\sqrt[3]2`` becomes ``\\sqrt[3]{2}``; without this the ``[`` itself
    would be wrapped as if it were the argument. A ``\\sqrt`` with nothing
    after it is left as is.

    Args:
        string: LaTeX text to normalize.

    Returns:
        The text with every single-character ``\\sqrt`` argument braced.
    """
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        pieces.append("\\sqrt")

        # Peel off an optional "[n]" index so it is not mistaken for the
        # radicand. An unterminated "[" is treated like any other character.
        if tail.startswith("["):
            close = tail.find("]")
            if close >= 0:
                pieces.append(tail[: close + 1])
                tail = tail[close + 1:]

        if tail and tail[0] != "{":
            pieces.append("{" + tail[0] + "}" + tail[1:])
        else:
            pieces.append(tail)
    return "".join(pieces)


def strip_string(string: str) -> str:
    """Normalize a LaTeX answer string for comparison.

    The steps run in a fixed order: drop line breaks, spacing commands,
    ``\\left``/``\\right``, degree marks, math delimiters, a single trailing
    ``\\text{ ...}`` unit, percent signs and Markdown asterisks; add the
    leading zero to bare decimals; drop a short ``x =`` prefix; remove
    spaces; brace shorthand ``\\sqrt`` and ``\\frac`` arguments; and rewrite
    ``0.5`` and integer ``a/b`` as ``\\frac``.

    Spaces are removed before ``fix_sqrt`` runs. Otherwise ``\\sqrt 2`` would
    have its space braced as the argument and end up as ``\\sqrt{}2`` instead
    of ``\\sqrt{2}``.

    Args:
        string: Raw answer text.

    Returns:
        The normalized string (possibly empty).
    """
    # Order matters: e.g. "\\$" must be removed before the bare "$".
    leading_rewrites = (
        ("\n", ""),            # line breaks
        ("\\!", ""),           # negative thin space
        ("\\\\", "\\"),        # doubled backslash -> single
        ("tfrac", "frac"),
        ("dfrac", "frac"),
        ("\\left", ""),
        ("\\right", ""),
        ("^{\\circ}", ""),     # degree marks
        ("^\\circ", ""),
        ("\\$", ""),           # dollar signs and inline-math delimiters
        ("$", ""),
        ("\\(", ""),
        ("\\)", ""),
    )
    for old, new in leading_rewrites:
        string = string.replace(old, new)

    # Drop a unit on the right, but only when "\text{ " occurs exactly once.
    unit_split = string.split("\\text{ ")
    if len(unit_split) == 2:
        string = unit_split[0]

    trailing_rewrites = (
        ("\\\\%", ""),         # percent signs
        ("\\%", ""),
        ("%", ""),
        ("**", ""),            # Markdown emphasis
        ("*", ""),
        (" .", " 0."),         # bare decimals gain a leading zero
        ("{.", "{0."),
    )
    for old, new in trailing_rewrites:
        string = string.replace(old, new)

    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # Remove variable assignments like "k = " or "q = " (left side of at
    # most two characters, exactly one "=").
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # Remove spaces first so "\sqrt 2" reaches fix_sqrt as "\sqrt2".
    string = string.replace(" ", "")
    string = fix_sqrt(string)
    string = fix_fracs(string)

    # Manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # Fix a/b to frac
    string = fix_a_slash_b(string)

    return string


def is_equiv(str1: Optional[str], str2: Optional[str]) -> bool:
    """Check whether two answer strings are equivalent after normalization.

    Both strings go through ``strip_string`` and ``normalize_number`` and are
    then compared for equality.

    ``strip_string`` rewrites the exact text ``0.5`` to ``\\frac{1}{2}``, but
    spellings such as ``0.50`` only become ``0.5`` later, in
    ``normalize_number``. The same rewrite is therefore applied after number
    normalization so that ``0.50`` and ``0.5`` compare equal.

    Args:
        str1: First answer, or ``None``.
        str2: Second answer, or ``None``.

    Returns:
        ``True`` when both are ``None`` or normalize to the same string;
        ``False`` when exactly one is ``None`` or they differ.
    """
    if str1 is None and str2 is None:
        return True
    if str1 is None or str2 is None:
        return False

    def canonical(text: str) -> str:
        normalized = normalize_number(strip_string(text))
        if normalized == "0.5":
            return "\\frac{1}{2}"
        return normalized

    try:
        return canonical(str1) == canonical(str2)
    except Exception:
        # Best-effort fallback: if normalization fails, compare raw inputs.
        return str1 == str2

def extract_answer(solution_str: str) -> Optional[str]:
    """Extract the answer from a solution string.

    Extraction methods, in order:

    1. The boxed expression found by ``last_boxed_only_string``, unwrapped
       with ``remove_boxed``.
    2. The first match of ``answer is`` / ``answer:`` (optionally preceded
       by ``final``).
    3. The first match of ``therefore ...``, then of ``thus ...``.
    4. The text after an ``=`` at the very end of the string.

    For methods 2-4 the captured answer stops at a newline, ``*``, ``.`` or
    ``,``. A ``.`` or ``,`` directly followed by a digit is kept, so decimals
    such as ``3.5`` and grouped numbers such as ``1,000`` are not cut short.

    Args:
        solution_str: The model's full solution text.

    Returns:
        The extracted answer, or ``None`` when no method yields a non-empty
        answer.
    """
    # Method 1: boxed answer
    boxed = last_boxed_only_string(solution_str)
    if boxed is not None:
        return remove_boxed(boxed)

    # Answer body: any run of allowed characters, where "." and "," are only
    # allowed when the next character is a digit.
    value = r"((?:[^\n.,*]|[.,](?=\d))+)"
    value_no_equals = r"((?:[^\n.,=*]|[.,](?=\d))+)"

    # Methods 2-4; \** on both sides tolerates Markdown bold markers.
    answer_patterns = [
        r"(?i)(?:final\s+)?answer\s*(?:is|:)\s*\**" + value + r"\**",
        r"(?i)therefore[,\s]+(?:the\s+)?(?:answer\s+is\s+)?\**" + value + r"\**",
        r"(?i)thus[,\s]+(?:the\s+)?(?:answer\s+is\s+)?\**" + value + r"\**",
        r"=\s*\**" + value_no_equals + r"\**$",
    ]

    for pattern in answer_patterns:
        match = re.search(pattern, solution_str)
        if match:
            answer = match.group(1).strip()
            # Clean up the answer - remove stray periods and Markdown markers
            answer = answer.strip(".")
            answer = answer.strip("*")
            answer = answer.strip()
            if answer:
                return answer

    return None


def extract_answer_flexible(solution_str: str) -> Optional[str]:
    """Flexible answer extraction over the tail of a solution.

    Only the last 500 characters are searched. Patterns are tried in order,
    and for the first pattern that matches, the last match is returned:

    1. ``answer is ...``
    2. ``therefore`` / ``thus`` / ``hence`` / ``so`` followed by the answer.
       The keywords must be whole words, so the ``so`` inside ``Also`` or
       ``solution`` does not count.
    3. The text after an ``=`` at the very end of the string.

    For patterns 1 and 2 the answer ends at a newline or at a ``.`` that is
    not followed by a digit, so ``3.5`` is captured whole while a
    sentence-ending period still terminates the answer.

    Args:
        solution_str: The model's full solution text.

    Returns:
        The extracted answer with surrounding whitespace removed, or ``None``
        when no pattern matches.
    """
    # Clip long strings for efficiency
    tail = solution_str[-500:]

    # Answer body: stop at "." unless it is a decimal point.
    value = r"((?:[^\n.]|\.(?=\d))+)"
    patterns = [
        r"answer\s+is\s*[:\s]*" + value,
        r"\b(?:therefore|thus|hence|so)\b\s*,?\s*(?:the\s+)?(?:answer\s+is\s*)?[:\s]*"
        + value,
        r"=\s*([^\n=]+)$",
    ]

    for pattern in patterns:
        matches = re.findall(pattern, tail, re.IGNORECASE)
        if matches:
            return matches[-1].strip()

    return None


def compute_score(solution_str: str, ground_truth: str) -> dict:
    """Score a MATH-500 solution against its ground-truth answer.

    The strict extractor (``extract_answer``) is tried first. If its result
    is missing or not equivalent to the ground truth, the flexible extractor
    (``extract_answer_flexible``) gets a second chance.

    Args:
        solution_str: The model's solution string
        ground_truth: The ground truth answer

    Returns:
        dict with:
            - score: 1.0 for correct, 0.0 otherwise
            - acc: same value as score, kept numeric for metrics
            - pred: the prediction as a string. This is the flexible
              prediction when it was the one that matched, otherwise the
              strict prediction, or "" when strict extraction found nothing.
    """
    strict_pred = extract_answer(solution_str)
    is_correct = strict_pred is not None and is_equiv(strict_pred, ground_truth)
    reported_pred = strict_pred

    if not is_correct:
        # Fallback: the flexible result is used only if it matches.
        flexible_pred = extract_answer_flexible(solution_str)
        if flexible_pred is not None and is_equiv(flexible_pred, ground_truth):
            is_correct = True
            reported_pred = flexible_pred

    # Rewards stay in [0, 1]: a missing or wrong answer scores 0.0, never a
    # negative value.
    reward = 1.0 if is_correct else 0.0

    return {
        "score": reward,
        "acc": 1.0 if is_correct else 0.0,
        "pred": "" if reported_pred is None else str(reported_pred),
    }
