import re
import os
from typing import Union, Optional
from math_verify import parse, verify


def extract_answer_from_think_answer_format(text: str) -> Optional[str]:
    """
    Extract answer from the unified format: <think>...</think> followed by Answer:
    This format is used for all question types.
    
    Args:
        text: Response text containing <think></think> and Answer:
        
    Returns:
        Extracted answer string or None if not found
    """
    # First try to find content after <think></think>
    think_pattern = r'<think>(.*?)</think>'
    think_match = re.search(think_pattern, text, re.DOTALL | re.IGNORECASE)
    
    # Then try to find content after Answer:
    answer_pattern = r'Answer:\s*(.+?)(?:\n|$)'
    answer_match = re.search(answer_pattern, text, re.DOTALL | re.IGNORECASE)
    
    # Prioritize Answer: content if both exist
    if answer_match:
        return answer_match.group(1).strip()
    
    return None


def extract_math_answer(text: str) -> Optional[str]:
    """
    Extract a math answer from a response.

    The ``Answer:`` line of the unified format is used when present;
    otherwise the whole text is inspected. Within that candidate text the
    content of the first non-empty ``\\box{...}`` / ``\\boxed{...}`` wins.
    Braces are matched by depth so nested groups such as
    ``\\boxed{\\frac{1}{2}}`` are returned intact.

    Args:
        text: Model response text.

    Returns:
        The boxed content if any, else the plain ``Answer:`` content, else
        None when neither an ``Answer:`` line nor a box is found.
    """
    # First try the unified format
    unified_answer = extract_answer_from_think_answer_format(text)
    candidate = unified_answer if unified_answer else text
    if not candidate:
        return None

    # Try LaTeX box format (both the legacy \box and the standard \boxed)
    for box_open in re.finditer(r'\\box(?:ed)?\{', candidate):
        start = box_open.end()
        depth = 1
        for offset, char in enumerate(candidate[start:]):
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    boxed = candidate[start:start + offset].strip()
                    if boxed:
                        return boxed
                    break

    # No usable box: a plain "Answer: 0.5" is still a valid answer.
    return unified_answer if unified_answer else None


def extract_question_answer(text: str) -> Optional[str]:
    """
    Return the answer text for a free-form question.

    Uses the ``Answer:`` content from the unified format when that
    extraction yields a non-empty string; otherwise the whole response,
    stripped of surrounding whitespace, is returned for comparison.
    """
    return extract_answer_from_think_answer_format(text) or text.strip()


def normalize_for_comparison(text: str) -> str:
    """
    Normalize text for comparison.

    Removes common CJK and ASCII punctuation (including curly quotes),
    lowercases the result, collapses every run of whitespace to a single
    space and strips leading/trailing whitespace.

    Args:
        text: Text to normalize.

    Returns:
        The normalized string (possibly empty).
    """
    # Curly quotes are written as escapes so the literal cannot be silently
    # turned into straight quotes and terminate the string early.
    punctuation = (
        '，。！？；：（）'
        '\u201c\u201d\u2018\u2019'
        '()[]{}.,!?;:"\''
    )
    # Remove common punctuation
    text = text.translate(str.maketrans('', '', punctuation))
    # Convert to lowercase
    text = text.lower()
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def compare_math_answers(pred: str, target: str) -> bool:
    """
    Compare a predicted math answer with the ground truth.

    Both strings are parsed with ``math_verify.parse`` and checked with
    ``math_verify.verify``. If either side cannot be parsed, or parsing or
    verification raises ValueError/TypeError, the comparison falls back to
    a case-insensitive match of the stripped strings.

    Args:
        pred: Extracted model answer.
        target: Ground truth answer.

    Returns:
        True if the answers are considered equivalent, False otherwise.
    """
    try:
        pred_parsed = parse(pred)
        target_parsed = parse(target)
        # NOTE(review): math_verify is understood to return an empty result
        # when nothing could be parsed rather than raising - confirm against
        # the installed version.
        if pred_parsed and target_parsed:
            # verify() is order-sensitive: the gold/reference value goes first.
            return bool(verify(target_parsed, pred_parsed))
    except (ValueError, TypeError):
        # Fall through to the plain string comparison below.
        pass
    return pred.strip().lower() == target.strip().lower()


def compare_choice_answers(pred: str, target: str) -> bool:
    """
    Compare multiple-choice answers (case-insensitive).

    An exact match after stripping whitespace always succeeds. When the
    target is a single letter, the prediction may also wrap that letter in
    other text ("(B)", "B.", "答案是B"): the standalone letters in the
    prediction (letters not adjacent to another ASCII letter) are collected
    and must consist of exactly the target letter. A prediction that names
    several different options, e.g. "A or B", is rejected.

    Args:
        pred: Extracted model answer.
        target: Ground truth choice.

    Returns:
        True if the prediction selects the target choice, False otherwise.
    """
    if not pred or not target:
        return False

    pred_norm = pred.strip().upper()
    target_norm = target.strip().upper()
    if not pred_norm or not target_norm:
        return False
    if pred_norm == target_norm:
        return True

    # Letter extraction is only meaningful for single-letter targets.
    if not re.fullmatch(r'[A-Z]', target_norm):
        return False
    standalone_letters = set(re.findall(r'(?<![A-Z])[A-Z](?![A-Z])', pred_norm))
    return standalone_letters == {target_norm}


def compare_question_answers(pred: str, target: str) -> bool:
    """
    Compare free-form answers using bidirectional containment.

    Both strings are normalized with ``normalize_for_comparison``; the
    answers match when either normalized string contains the other.

    Args:
        pred: Extracted model answer.
        target: Ground truth answer.

    Returns:
        True if the answers match, False otherwise (including when either
        input is empty).
    """
    if not pred or not target:
        return False

    pred_norm = normalize_for_comparison(pred)
    target_norm = normalize_for_comparison(target)

    # An empty string is a substring of everything, so an answer that
    # normalizes to nothing (e.g. punctuation only) would otherwise always
    # match. In that case only an exact raw match counts.
    if not pred_norm or not target_norm:
        return pred.strip() == target.strip()

    # Bidirectional containment check
    # NOTE(review): a very short prediction that happens to occur inside the
    # target still matches; tighten if that proves too lenient.
    return (pred_norm in target_norm) or (target_norm in pred_norm)


def general_math_reward(
    response: str,
    prompt: str,
    data_source: str,
    answer: str,
    use_box: bool = False
) -> Union[int, float]:
    """
    Score a response against the ground truth for math, choice or question data.

    The question type is chosen from the lowercased basename of
    ``data_source``: the first of 'math', 'choice', 'question' that occurs
    in it wins, and anything else is treated as math.

    Args:
        response: Model response text
        prompt: Input prompt (not read by this function)
        data_source: Data source identifier used to pick the question type
        answer: Ground truth answer
        use_box: Accepted for backward compatibility; not read by this function

    Returns:
        1.0 if the extracted answer matches, 0.0 if it does not or if no
        answer could be extracted
    """
    source_tag = os.path.basename(data_source).lower()

    # Same precedence as an if/elif chain: math, then choice, then question.
    known_kinds = ('math', 'choice', 'question')
    kind = next((k for k in known_kinds if k in source_tag), 'math')

    if kind == 'choice':
        extractor = extract_answer_from_think_answer_format
        comparator = compare_choice_answers
    elif kind == 'question':
        extractor = extract_answer_from_think_answer_format
        comparator = compare_question_answers
    else:
        extractor = extract_math_answer
        comparator = compare_math_answers

    candidate = extractor(response)
    if candidate is None:
        return 0.0
    if comparator(candidate, answer):
        return 1.0
    return 0.0



# VERL framework compatible interface - the single unified entry point
def compute_score(data_source: str, solution_str: str, ground_truth: str, extra_info=None) -> Union[int, float]:
    """
    VERL-compatible reward function: the single unified entry point.

    Delegates to ``general_math_reward``, which handles every question type:
    - Math: answer extraction and comparison
    - Choice: option matching
    - Question: text content matching

    All questions use the unified format: <think></think> Answer:

    Args:
        data_source: Data source identifier, used to determine the question type
        solution_str: Model-generated response (containing <think></think> and Answer:)
        ground_truth: Reference answer
        extra_info: Extra information (optional, currently unused)

    Returns:
        Reward score: 1.0 for correct, 0.0 for incorrect
    """
    reward_kwargs = dict(
        response=solution_str,
        prompt="",  # the prompt is always passed as an empty string here
        data_source=data_source,
        answer=ground_truth,
        use_box=True,
    )
    return general_math_reward(**reward_kwargs)


# Example usage
if __name__ == "__main__":
    # Demo run: scores a fixed set of responses through both entry points
    # and prints PASS/FAIL for each against its expected reward.
    heavy_rule = "=" * 60
    light_rule = "-" * 40

    print("测试统一格式的奖励函数...")
    print(heavy_rule)

    # Sample responses in the unified format, one dict per case
    test_cases = [
        {
            "response": "<think>这是一个数学问题，需要计算。根据题目，答案是0.73728</think>\nAnswer: 0.73728",
            "answer": "0.73728",
            "data_source": "math_data.parquet",
            "expected": 1.0,
            "type": "数学题",
        },
        {
            "response": "<think>这是一个选择题，分析各个选项后</think>\nAnswer: A",
            "answer": "A",
            "data_source": "choice_data.parquet",
            "expected": 1.0,
            "type": "选择题",
        },
        {
            "response": "<think>这个问题需要详细分析</think>\nAnswer: 这是正确答案",
            "answer": "这是正确答案",
            "data_source": "question_data.parquet",
            "expected": 1.0,
            "type": "问答题",
        },
        {
            "response": "<think>计算过程：2+3=5</think>\nAnswer: \\box{5}",
            "answer": "5",
            "data_source": "math_data.parquet",
            "expected": 1.0,
            "type": "数学题(LaTeX)",
        },
        {
            "response": "<think>选择题分析：B是正确答案</think>\nAnswer: 答案是B",
            "answer": "B",
            "data_source": "choice_data.parquet",
            "expected": 1.0,
            "type": "选择题(中文)",
        },
        {
            "response": "<think>错误的计算</think>\nAnswer: 42",
            "answer": "24",
            "data_source": "math_data.parquet",
            "expected": 0.0,
            "type": "数学题(错误)",
        },
    ]

    print("测试 general_math_reward 函数：")
    print(light_rule)

    for number, case in enumerate(test_cases, start=1):
        score = general_math_reward(
            case["response"],
            "",
            case["data_source"],
            case["answer"],
        )
        if score == case["expected"]:
            verdict = "✓ PASS"
        else:
            verdict = "✗ FAIL"
        print(f"测试 {number}: {verdict} | 类型: {case['type']}")
        print(f"  回答: {case['response'][:50]}...")
        print(f"  标准答案: {case['answer']}")
        print(f"  预期: {case['expected']}, 实际: {score}")
        print()

    print(heavy_rule)
    print("测试统一的 compute_score 函数：")
    print(light_rule)

    # Same cases again, this time through the VERL-facing wrapper
    for number, case in enumerate(test_cases, start=1):
        score = compute_score(
            case["data_source"],
            case["response"],
            case["answer"],
        )
        if score == case["expected"]:
            verdict = "✓ PASS"
        else:
            verdict = "✗ FAIL"
        print(f"测试 {number}: {verdict} | 类型: {case['type']}")
        print(f"  数据源: {case['data_source']}")
        print(f"  VERL接口得分: {score}")
        print()

    # The summary lines below are printed unconditionally.
    summary_lines = (
        "测试总结：",
        "✓ 所有题目类型都可以通过 compute_score 统一调用",
        "✓ 支持 <think></think> + Answer: 统一格式",
        "✓ 自动识别数学题、选择题、问答题类型",
        "✓ 兼容多种答案格式（LaTeX box、GSM8K、中文等）",
        "✓ 适配VERL框架的训练脚本调用",
    )
    print(heavy_rule)
    for line in summary_lines:
        print(line)