"""
AIME24 task utilities for lm-evaluation-harness.

Provides answer extraction and comparison for AIME math competition problems.
AIME answers are integers from 0-999.
"""

import logging
import re
from typing import Dict, List, Optional


# Module-level logger; process_results emits per-example debug output through it.
eval_logger = logging.getLogger(__name__)


def process_results(doc: dict, results: List[str]) -> Dict[str, int]:
    """
    Score one model output against the reference AIME answer.

    The prediction is the content of the last ``\\boxed{}`` in the output,
    falling back to the last number in the text when no usable box exists.
    The reference is ``doc["answer"]``; if that field is missing or blank,
    it is extracted from ``doc["solution"]`` via ``extract_aime_answer``.
    Both sides are normalized with ``normalize_answer`` before comparing.

    Args:
        doc: Dataset row with an ``answer`` field (str or int) and/or a
            ``solution`` field containing the worked solution.
        results: Model generations; only the first one is scored.

    Returns:
        ``{"exact_match": 1}`` if the normalized prediction equals the
        normalized reference, else ``{"exact_match": 0}``. A missing or
        unparseable reference always scores 0.
    """
    model_output = results[0] if results else ""

    # An empty \boxed{} counts as "no box", so fall back to the last number.
    predicted = extract_boxed_answer(model_output)
    if not predicted:
        predicted = extract_last_number(model_output)

    # Reference answer: prefer the dedicated field. It may be stored as an int,
    # so test for None/blank explicitly rather than truthiness (0 is valid).
    ground_truth = doc.get("answer")
    if ground_truth is None or not str(ground_truth).strip():
        ground_truth = extract_aime_answer(doc.get("solution") or "")

    pred_normalized = normalize_answer(predicted) if predicted else ""
    gt_normalized = (
        normalize_answer(str(ground_truth)) if ground_truth is not None else ""
    )

    # Require a non-empty reference so that "no answer" vs "no answer"
    # is never counted as a match.
    exact_match = 1 if gt_normalized and pred_normalized == gt_normalized else 0

    eval_logger.debug(
        "Predicted: %s, Ground Truth: %s, Match: %s",
        pred_normalized,
        gt_normalized,
        exact_match,
    )

    return {"exact_match": exact_match}


def extract_boxed_answer(text: str) -> Optional[str]:
    """
    Extract the content of the last ``\\boxed`` (or, failing that, ``\\fbox``).

    Supported forms:
    - ``\\boxed{...}`` with arbitrarily nested braces,
      e.g. ``\\boxed{\\frac{1}{2}}`` -> ``\\frac{1}{2}``
    - ``\\boxed {...}`` or ``\\boxed`` + newline + ``{...}``
      (whitespace between the command and the brace)
    - ``\\boxed 123`` (no braces): the next whitespace-delimited token,
      with trailing ``.``, ``,`` or ``;`` removed

    Args:
        text: Model output to search.

    Returns:
        The stripped boxed content, or None if there is no box command,
        the command is followed by neither a brace nor whitespace + token,
        or the braces are unbalanced (e.g. truncated output).
    """
    if not text:
        return None

    # Prefer the last \boxed; only consider \fbox when no \boxed exists at all.
    for command in ("\\boxed", "\\fbox"):
        idx = text.rfind(command)
        if idx >= 0:
            break
    else:
        return None

    # Skip any whitespace between the command and its argument. The argument
    # must start right here; never search further ahead for an unrelated "{".
    after_command = idx + len(command)
    arg_start = after_command
    while arg_start < len(text) and text[arg_start].isspace():
        arg_start += 1

    if arg_start < len(text) and text[arg_start] == "{":
        # Braced argument: scan forward tracking nesting depth.
        depth = 0
        for end in range(arg_start, len(text)):
            ch = text[end]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return text[arg_start + 1:end].strip()
        # Unbalanced braces (e.g. the generation was cut off).
        return None

    if arg_start > after_command:
        # Unbraced form such as "\boxed 123": take the next token.
        match = re.match(r"\S+", text[arg_start:])
        if match:
            return match.group(0).rstrip(".,;")

    return None


def extract_last_number(text: str) -> Optional[str]:
    """
    Extract the last number from text as a fallback answer.

    Recognizes optional leading minus signs, decimals, and comma-grouped
    thousands (``1,234`` is returned whole rather than as ``234``;
    ``normalize_answer`` later strips the commas). A comma not followed by
    exactly three digits is treated as a separator, so ``"1, 2, 3"`` yields
    ``"3"``.

    Args:
        text: Text to scan.

    Returns:
        The matched number as it appears in the text, or None if the text
        is empty or contains no digits.
    """
    if not text:
        return None

    # Either a properly comma-grouped integer part (1,234 / 12,345,678) or a
    # plain digit run, followed by an optional decimal part. The (?!\d)
    # prevents mis-grouping inputs like "12,3456".
    numbers = re.findall(r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?", text)
    return numbers[-1] if numbers else None


def extract_aime_answer(solution: str) -> Optional[str]:
    """
    Pull the final AIME answer out of a reference solution string.

    Strategies are tried in order and the first hit wins:
    1. the last ``\\boxed{...}`` / ``\\fbox{...}`` (via extract_boxed_answer);
       an empty box is ignored
    2. a "The answer is X" / "the answer is: X" phrase (first occurrence)
    3. an "Answer: X" / "answer X" phrase (first occurrence)
    4. the whole stripped string, if it consists only of digits
    5. the last number anywhere in the text (via extract_last_number)

    Returns None for empty input, or when no strategy finds anything.
    """
    if not solution:
        return None

    boxed = extract_boxed_answer(solution)
    if boxed:
        return boxed

    phrase_patterns = (
        r"[Tt]he\s+answer\s+is\s+[:\s]*(\d+)",
        r"[Aa]nswer[:\s]+(\d+)",
    )
    for pattern in phrase_patterns:
        found = re.search(pattern, solution)
        if found is not None:
            return found.group(1)

    candidate = solution.strip()
    return candidate if candidate.isdigit() else extract_last_number(solution)


def normalize_answer(answer: object) -> str:
    """
    Normalize an answer for comparison.

    Steps:
    - Coerce non-string input (e.g. an int ground truth) with ``str``;
      ``None`` and blank strings become ``""``.
    - Strip whitespace and ``$`` math delimiters.
    - Unwrap ``\\text{}``, ``\\textbf{}`` and ``\\mathrm{}``.
    - Drop LaTeX thin spaces (``\\,``, ``\\!``) and commas used as thousands
      separators (``1,000`` and ``1\\,000`` both become ``1000``).
    - Remove trailing periods.
    - Render integer-valued numbers canonically (``"007"`` -> ``"7"``,
      ``"42.0"`` -> ``"42"``).

    Args:
        answer: Raw answer text (or a number).

    Returns:
        The normalized string. Non-numeric answers are returned cleaned but
        otherwise unchanged.
    """
    if answer is None:
        return ""

    # str() so int ground truths work; 0 must normalize to "0", not "".
    answer = str(answer).strip()
    if not answer:
        return ""

    # Remove LaTeX math delimiters
    answer = answer.replace("$", "")

    # Remove common LaTeX commands
    answer = re.sub(r"\\text\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\textbf\{([^}]*)\}", r"\1", answer)
    answer = re.sub(r"\\mathrm\{([^}]*)\}", r"\1", answer)

    # Remove LaTeX thin spaces before dropping commas; otherwise "1\,000"
    # would turn into "1\000".
    answer = answer.replace("\\,", "").replace("\\!", "")

    # Remove commas from numbers (e.g., 1,000 -> 1000)
    answer = answer.replace(",", "")

    # Remove trailing periods, then strip again after all substitutions
    answer = answer.rstrip(".").strip()

    # Parse as int first: exact for arbitrarily large integers, whereas a
    # float round-trip loses precision beyond 2**53.
    try:
        return str(int(answer))
    except ValueError:
        pass

    # Integer-valued float strings ("42.0", "1e3") collapse to the integer.
    try:
        num = float(answer)
        if num == int(num):
            return str(int(num))
    except (ValueError, OverflowError):
        # Non-numeric text, or inf/nan that cannot become an int.
        pass

    return answer
