#!/usr/bin/env python3

import re
import os
import torch
from typing import List, Optional, Tuple

# Import existing utilities from the MATH evaluation codebase
from grader import math_equal_process
from parser import run_execute
from python_executor import PythonExecutor

os.environ["TOKENIZERS_PARALLELISM"] = "false"


class AnswerComparator:
    """Normalize a model response and extract its final answer.

    XML-tagged responses (``<answer>...</answer>``) are reduced to a single
    ``\\boxed{...}`` sentence before being handed to ``run_execute`` from the
    MATH evaluation codebase, which performs the actual answer extraction.
    """

    # Compiled once at class creation; these run on every prediction.
    # ``\s`` already matches newlines, so no separate ``\n`` is needed.
    _XML_ANSWER_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", flags=re.DOTALL)
    # Inline math delimiters ``\( ... \)``. DOTALL so a span that crosses a
    # newline is unwrapped on both the XML and the non-XML path.
    _INLINE_MATH_RE = re.compile(r"\\\((.*?)\\\)", flags=re.DOTALL)
    _BOXED_PREFIX = "\\boxed{"

    def __init__(self, prompt_type: str = "cot"):
        """Create a comparator.

        Args:
            prompt_type: Prompt type forwarded to ``run_execute``.
        """
        self.prompt_type = prompt_type
        self.executor = PythonExecutor(get_answer_from_stdout=True)

        # True only when CUDA is available AND torch reports it already
        # initialized. NOTE(review): not read anywhere in this file section.
        self.cuda_initialized = torch.cuda.is_initialized() if torch.cuda.is_available() else False

    def extract_xml_answer(self, text: str) -> Optional[str]:
        """Return the stripped content of the first ``<answer>`` element.

        Returns:
            The text between ``<answer>`` and ``</answer>`` (surrounding
            whitespace removed), or ``None`` when no such element exists.
        """
        m = self._XML_ANSWER_RE.search(text)
        if m:
            return m.group(1).strip()
        return None

    @classmethod
    def _strip_outer_boxed(cls, text: str) -> str:
        """Remove one ``\\boxed{...}`` wrapper only if it encloses all of *text*.

        The opening brace must close at the very last character. Strings like
        ``\\boxed{1} and \\boxed{2}`` are returned unchanged instead of being
        cut into ``1} and \\boxed{2``. Backslash-escaped braces (``\\{``,
        ``\\}``) are not counted.
        """
        if not (text.startswith(cls._BOXED_PREFIX) and text.endswith("}")):
            return text
        inner = text[len(cls._BOXED_PREFIX):-1]
        depth = 0
        escaped = False
        for ch in inner:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth < 0:
                    # The wrapper's brace closes before the end of the string.
                    return text
        return inner if depth == 0 else text

    def prepare_code_for_extraction(self, code: str) -> str:
        """Normalize a response before it is passed to the extraction logic.

        If the response contains an ``<answer>`` element with non-empty
        content, inline-math delimiters and a single outer ``\\boxed{}`` are
        removed from that content. The result is returned as
        ``"The answer is \\boxed{<answer>}"``.

        Otherwise the whole response is returned with inline-math delimiters
        ``\\( ... \\)`` removed and no wrapper added.
        """
        xml_answer = self.extract_xml_answer(code)

        if xml_answer:
            # Strip after unwrapping so "\( \boxed{x} \)" still loses its box.
            cleaned_answer = self._INLINE_MATH_RE.sub(r"\1", xml_answer).strip()
            cleaned_answer = self._strip_outer_boxed(cleaned_answer)
            return f"The answer is \\boxed{{{cleaned_answer}}}"

        return self._INLINE_MATH_RE.sub(r"\1", code)

    def process_prediction(self, predicted_answer: str, data_name: str = "math") -> Tuple[str, str]:
        """Normalize *predicted_answer* and run it through ``run_execute``.

        Args:
            predicted_answer: Raw model response.
            data_name: Dataset name forwarded to ``run_execute``.

        Returns:
            The ``(pred, report)`` pair returned by ``run_execute``.
        """
        prepared_code = self.prepare_code_for_extraction(predicted_answer)
        pred, report = run_execute(self.executor, prepared_code, self.prompt_type, data_name)
        return pred, report


def compare_answers_simple_batch(
    predicted_answers: List[str],
    ground_truths: List[str],
    prompt_type: str = "cot",
    data_name: str = "math",
) -> List[bool]:
    """
    Batch version of compare_answers_simple that reuses one comparator.

    Args:
        predicted_answers: Predicted answers (may include XML tags, boxed notation, etc.)
        ground_truths: Ground truth answers, aligned index-by-index with predicted_answers
        prompt_type: Type of prompt used (default: "cot")
        data_name: Dataset name (default: "math")

    Returns:
        List[bool]: One result per prediction. An entry is False when the
        answers differ or when extraction/comparison raised an exception,
        which is printed.

    Raises:
        ValueError: If the two lists have different lengths. Previously zip()
            silently dropped the unmatched tail.
    """
    if len(predicted_answers) != len(ground_truths):
        raise ValueError(
            f"Got {len(predicted_answers)} predicted answers but "
            f"{len(ground_truths)} ground truths"
        )
    if not predicted_answers:
        # Nothing to compare; skip constructing the executor.
        return []

    # Create the comparator (and its executor) once for all comparisons.
    comparator = AnswerComparator(prompt_type=prompt_type)

    results: List[bool] = []
    for predicted_answer, ground_truth in zip(predicted_answers, ground_truths):
        try:
            pred, _ = comparator.process_prediction(predicted_answer, data_name)
            # Compare directly without multiprocessing.
            results.append(math_equal_process((0, pred, ground_truth)))
        except Exception as e:
            # Best-effort: a failing item counts as a mismatch.
            print(f"Error during comparison: {e}")
            results.append(False)

    return results


# Keep the original simple function for backward compatibility
def compare_answers_simple(predicted_answer: str, ground_truth: str, prompt_type: str = "cot", data_name: str = "math") -> bool:
    """
    Simple single-process version for use in CUDA training environments.

    A new AnswerComparator is built on every call. Prefer
    compare_answers_simple_batch when comparing many answers.

    Args:
        predicted_answer: The predicted answer (can include XML tags, boxed notation, etc.)
        ground_truth: The ground truth answer
        prompt_type: Type of prompt used (default: "cot")
        data_name: Dataset name (default: "math")

    Returns:
        bool: True if the answers are mathematically equal. False if they
        differ, or if extraction or comparison raised an exception; the
        exception is printed.
    """
    comparator = AnswerComparator(prompt_type=prompt_type)

    try:
        # Extraction is inside the try so an extraction failure yields False,
        # matching the per-item handling in compare_answers_simple_batch.
        pred, _ = comparator.process_prediction(predicted_answer, data_name)
        # Call math_equal_process directly without multiprocessing.
        return math_equal_process((0, pred, ground_truth))
    except Exception as e:
        print(f"Error during comparison: {e}")
        return False


# Example usage
if __name__ == "__main__":
    import time

    # Test batch comparison.
    # Backslashes are escaped explicitly; the old literals used invalid escape
    # sequences ("\[", "\]", "\ ", "\e") that warn on Python 3.12+. The runtime
    # values are unchanged.
    # NOTE(review): the third example keeps its original quirks (a stray space
    # in "\ begin" and single backslashes between matrix rows).
    predictions = [
        """<reasoning>
        Let me solve this step by step...
        </reasoning>
        <answer>
        \\boxed{42}
        </answer>""",
        "<answer>\n\\(\\frac{48}{95}\\)\n</answer>",
        """<answer>
        \\[
        \\ begin{pmatrix} 12 \\ -21 \\ 3 \\end{pmatrix}
        \\]
        </answer> """,
    ]

    ground_truths = [
        "42",
        "\\frac{48}{95}",
        "\\ begin{pmatrix}12\\-21\\3\\end{pmatrix} ",
    ]

    results = compare_answers_simple_batch(predictions, ground_truths)
    print(f"Batch results: {results}")  # Should print [True, True, False]

    # Show timing difference. perf_counter is monotonic and high-resolution,
    # unlike time.time().

    # Individual calls (a new comparator per call)
    start = time.perf_counter()
    individual_results = [
        compare_answers_simple(pred, gt) for pred, gt in zip(predictions, ground_truths)
    ]
    individual_time = time.perf_counter() - start

    # Batch call (one shared comparator)
    start = time.perf_counter()
    batch_results = compare_answers_simple_batch(predictions, ground_truths)
    batch_time = time.perf_counter() - start

    print(f"\nIndividual calls time: {individual_time:.4f}s")
    print(f"Batch call time: {batch_time:.4f}s")
    # Guard against a zero-duration measurement to avoid ZeroDivisionError.
    if batch_time > 0:
        print(f"Speedup: {individual_time/batch_time:.2f}x")
    else:
        print("Speedup: n/a (batch call too fast to measure)")