"""
Countdown task reward verification utilities.
Based on TinyZero: https://github.com/Jiayi-Pan/TinyZero
"""

import random
import re
from typing import Any, Dict, Optional


def extract_solution(solution_text: str) -> str:
    """
    Extract the equation from the solution text.

    Looks for content within <answer>...</answer> tags (case-insensitive).
    If several answer blocks are present, the LAST one is used: earlier
    occurrences may be format examples (e.g. the model restating the
    instructions), while the final one is the committed answer.

    If no complete answer block exists, falls back to the text after the
    last "Assistant:" marker, and finally to the whole text.

    Args:
        solution_text: The full solution text from the model

    Returns:
        The extracted equation string with surrounding whitespace stripped.
        This is an empty string only when the selected content is empty
        (e.g. "<answer></answer>" or blank input).
    """
    # Prefer the last answer block; findall keeps the matches in document order.
    answers = re.findall(r'<answer>\s*(.*?)\s*</answer>', solution_text, re.IGNORECASE | re.DOTALL)
    if answers:
        return answers[-1].strip()

    # Fall back to whatever follows the last "Assistant:" marker.
    if "Assistant:" in solution_text:
        return solution_text.rsplit("Assistant:", 1)[-1].strip()

    return solution_text.strip()


def validate_equation(equation: str, available_numbers: list) -> bool:
    """
    Check that the equation uses exactly the available numbers.

    Every numeric literal in ``equation`` is collected, and the resulting
    multiset must equal ``available_numbers``. That means each available
    number appears exactly once and no other number appears. Order is
    ignored. Literals containing a decimal point are parsed as floats, so
    ``3.0`` matches an available ``3``.

    Args:
        equation: The equation string to validate
        available_numbers: List of numbers that can be used

    Returns:
        True if the numbers in the equation match the available numbers
        exactly (as a multiset), False otherwise
    """
    literals = re.findall(r'\b\d+(?:\.\d+)?\b', equation)

    try:
        parsed = [float(lit) if '.' in lit else int(lit) for lit in literals]
    except ValueError:
        return False

    # Sorting both sides turns the multiset comparison into a list comparison.
    return sorted(available_numbers) == sorted(parsed)


def evaluate_equation(equation: str) -> Optional[float]:
    """
    Safely evaluate an arithmetic equation.

    Only allows numbers, the four basic operators (+, -, *, /), parentheses,
    and whitespace. The two-character Python operators ``**`` (power) and
    ``//`` (floor division) are explicitly rejected. They pass the character
    whitelist but are not basic operators. ``**`` is also dangerous: an input
    such as ``9**9**9`` would make ``eval`` run an arbitrarily expensive
    big-integer computation.

    Args:
        equation: The equation string to evaluate

    Returns:
        The result of the equation as a float, or None if the equation is
        rejected or evaluation fails
    """
    # Security check: only allow safe characters (whole-string match).
    if not re.fullmatch(r'[\d\s\+\-\*\/\(\)\.]+', equation):
        return None

    # Reject power and floor division. Operators separated by whitespace
    # ("2 * * 3") tokenize separately and are a SyntaxError in eval, so a
    # plain substring check is sufficient here.
    if '**' in equation or '//' in equation:
        return None

    try:
        # NOTE(security): eval runs on model output. It is acceptable only
        # because the checks above restrict the input to arithmetic tokens.
        result = eval(equation, {"__builtins__": {}}, {})
        return float(result)
    except (
        SyntaxError,
        ZeroDivisionError,
        TypeError,
        NameError,
        ValueError,
        OverflowError,
        # Pathologically deep nesting can exhaust the parser/compiler.
        RecursionError,
        MemoryError,
    ):
        return None


# Private RNG used only to sample debug traces, so that scoring does not
# advance the process-global `random` state that other code may have seeded.
_DEBUG_RNG = random.Random()


def compute_score(
    solution_text: str,
    ground_truth: Dict[str, Any],
    correct_score: float = 1.0,
    format_score: float = 0.1,
    debug: bool = False,
    debug_sample_rate: float = 1.0 / 64.0,
) -> float:
    """
    Compute the reward score for a countdown task solution.

    Scoring:
      * correct_score      -- the equation uses every available number exactly
                              once and evaluates to the target (within 1e-6)
      * format_score       -- the equation uses the right numbers and evaluates,
                              but to a different value
      * format_score * 0.5 -- the response contains an <answer> tag, but the
                              equation uses the wrong numbers or cannot be
                              evaluated
      * 0.0                -- no equation was extracted, or the equation uses
                              the wrong numbers / cannot be evaluated and the
                              response has no <answer> tag

    Args:
        solution_text: The model's generated solution text
        ground_truth: Dict containing 'target' (int) and 'numbers' (list of ints)
        correct_score: Score to return for correct answer (default: 1.0)
        format_score: Score to return for valid format but wrong answer (default: 0.1)
        debug: If True, always print debug information
        debug_sample_rate: Probability of printing debug information for a call
            when ``debug`` is False (default: 1/64; use 0.0 to disable)

    Returns:
        The reward score as described above
    """
    target = ground_truth['target']
    available_numbers = ground_truth['numbers']

    # Randomly sample a small fraction of calls for debug printing.
    should_debug = debug or (_DEBUG_RNG.random() < debug_sample_rate)

    # Extract the equation from the solution
    equation = extract_solution(solution_text)

    # Partial format credit is only granted when the response used answer tags.
    has_answer_tag = '<answer>' in solution_text.lower()
    partial_score = format_score * 0.5 if has_answer_tag else 0.0

    if should_debug:
        print(f"\n=== Countdown Verification Debug ===")
        print(f"Target: {target}")
        print(f"Available numbers: {available_numbers}")
        print(f"Extracted equation: '{equation}'")

    if not equation:
        if should_debug:
            print(f"No equation found!")
        return 0.0

    # Validate that the equation uses the correct numbers
    is_valid = validate_equation(equation, available_numbers)

    if should_debug:
        print(f"Equation valid (uses correct numbers): {is_valid}")

    if not is_valid:
        return partial_score

    # Evaluate the equation
    result = evaluate_equation(equation)

    if should_debug:
        print(f"Evaluated result: {result}")

    if result is None:
        if should_debug:
            print(f"Equation evaluation failed!")
        return partial_score

    # Check if the result matches the target (with floating point tolerance)
    is_correct = abs(result - target) < 1e-6
    score = correct_score if is_correct else format_score

    if should_debug:
        print(f"Matches target: {is_correct}")
        print(f"Score: {score}")
        print(f"=====================================\n")

    return score
