import re
from typing import Optional

# Maximum number of trailing characters of a model response that
# extract_solution searches; longer responses are clipped to this tail first.
_SOLUTION_CLIP_CHARS = 500


def normalize_answer(answer: Optional[str]) -> str:
    """Normalize a math answer for string comparison.

    Strips leading/trailing whitespace and collapses every internal run of
    whitespace (spaces, tabs, newlines) to a single space.

    Args:
        answer: The answer text. ``None`` normalizes to the empty string; a
            non-string value (e.g. a numeric ground truth) is converted with
            ``str()`` first instead of raising.

    Returns:
        The normalized answer string.
    """
    if answer is None:
        return ""
    if not isinstance(answer, str):
        answer = str(answer)
    # Trim the ends, then collapse internal whitespace runs to one space.
    return re.sub(r'\s+', ' ', answer.strip())


# Literal text that opens a boxed final answer.
_BOXED_MARKER = "\\boxed{"

# Capture group for a free-text answer: runs to the end of the line or to the
# first period that is not a decimal point, so "3.14." yields "3.14".
_FREE_ANSWER = r'((?:[^\n.]|\.(?=\d))+)'

# Fallback patterns for method="flexible", tried in order; the last match of
# the first pattern that matches anywhere is returned. The word boundaries
# keep "so" from firing inside words such as "also" or "solution".
_FLEXIBLE_PATTERNS = (
    re.compile(r'answer\s+is\s*[:\s]*' + _FREE_ANSWER, re.IGNORECASE),
    re.compile(
        r'\b(?:therefore|thus|hence|so)\b\s*,?\s*(?:the\s+)?(?:answer\s+is\s*)?[:\s]*'
        + _FREE_ANSWER,
        re.IGNORECASE,
    ),
    re.compile(r'=\s*([^\n=]+)$', re.IGNORECASE),
)


def _last_boxed_content(text: str) -> Optional[str]:
    """Return the contents of the last brace-balanced ``\\boxed{...}`` in *text*.

    Braces are counted rather than regex-matched, so arbitrarily nested LaTeX
    such as ``\\frac{\\sqrt{3}}{2}`` is captured whole. A backslash escapes the
    following character, so ``\\{`` and ``\\}`` count as literal braces. If the
    last occurrence is never closed (e.g. a cut-off response), earlier
    occurrences are tried. Returns None when no balanced occurrence exists.
    """
    start = text.rfind(_BOXED_MARKER)
    while start != -1:
        content_start = start + len(_BOXED_MARKER)
        depth = 1
        pos = content_start
        while pos < len(text):
            ch = text[pos]
            if ch == "\\":
                # Skip the escaped character: \{ and \} do not group.
                pos += 2
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[content_start:pos]
            pos += 1
        # Unclosed occurrence: fall back to the previous one.
        start = text.rfind(_BOXED_MARKER, 0, start)
    return None


def extract_solution(solution_str: str, method: str = "strict") -> Optional[str]:
    """Extract the final answer from a model response.

    Args:
        solution_str: The model's response string. ``None`` yields ``None``.
        method: ``'strict'`` returns only the contents of the last balanced
            ``\\boxed{...}``; ``'flexible'`` tries that first and then falls
            back to "answer is X", "therefore/thus/hence/so X", and a trailing
            "= X" pattern.

    Returns:
        The extracted answer string (not normalized), or ``None`` if nothing
        was found.

    Raises:
        ValueError: If ``method`` is not ``'strict'`` or ``'flexible'``.
    """
    if method not in ("strict", "flexible"):
        raise ValueError(f"method must be 'strict' or 'flexible', got {method!r}")
    if solution_str is None:
        return None

    # Optimization: only the last _SOLUTION_CLIP_CHARS characters are searched.
    # NOTE: an answer that starts before this window will not be found.
    if len(solution_str) > _SOLUTION_CLIP_CHARS:
        solution_str = solution_str[-_SOLUTION_CLIP_CHARS:]

    answer = _last_boxed_content(solution_str)
    if answer is not None or method == "strict":
        return answer

    for pattern in _FLEXIBLE_PATTERNS:
        matches = pattern.findall(solution_str)
        if matches:
            return matches[-1].strip()
    return None


# `\left` / `\right` sizing commands, but not when they are the prefix of a
# longer command such as `\leftarrow` or `\rightarrow`.
_LATEX_SIZING = re.compile(r'\\(?:left|right)(?![a-zA-Z])')

# LaTeX spacing commands that render as (positive) horizontal space.
_LATEX_SPACING = re.compile(r'\\(?:qquad|quad)(?![a-zA-Z])|\\[,;:]')


def _simplify_latex(s: str) -> str:
    """Drop LaTeX sizing/spacing commands that do not change an answer's value."""
    s = _LATEX_SIZING.sub("", s)
    s = _LATEX_SPACING.sub(" ", s)
    s = s.replace("\\!", "")  # negative thin space
    return re.sub(r'\s+', ' ', s).strip()


def _answers_match(a: str, b: str) -> bool:
    """Return True if *a* and *b* are equal, or equal once all spaces are removed."""
    return a == b or a.replace(" ", "") == b.replace(" ", "")


def compute_score(solution_str: str, ground_truth: str, method: str = "strict",
                  format_score: float = 0.0, score: float = 1.0) -> float:
    """Compute the score for a MATH problem.

    The extracted answer and the ground truth are both normalized. They are
    then compared verbatim and ignoring spaces, first as-is and then after
    removing LaTeX sizing (`\\left`, `\\right`) and spacing commands.

    Args:
        solution_str: The model's response
        ground_truth: The ground truth answer
        method: 'strict' or 'flexible' extraction
        format_score: Score when format is correct but answer is wrong
        score: Score for correct answer

    Returns:
        ``score`` on a match, ``format_score`` when an answer was extracted
        but does not match, and ``0.0`` when no answer could be extracted.
    """
    answer = extract_solution(solution_str, method=method)
    if answer is None:
        # Nothing extractable: format_score is deliberately not awarded.
        return 0.0

    norm_answer = normalize_answer(answer)
    norm_truth = normalize_answer(ground_truth)

    if _answers_match(norm_answer, norm_truth):
        return score
    if _answers_match(_simplify_latex(norm_answer), _simplify_latex(norm_truth)):
        return score
    return format_score

