import re


# Final answer wrapped in a LaTeX box; accepts the original ``\box{...}`` form
# as well as the standard ``\boxed{...}`` command.
_BOX_ANSWER_RE = re.compile(r"\\box(?:ed)?\{(\-?[0-9\.\,]+)\}")
# GSM8K-style final answer marker: ``#### <number>``.
_GSM8K_ANSWER_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
# Any run of digits / dots / commas, optionally preceded by a minus sign.
_NUMBER_RE = re.compile(r"(\-?[0-9\.\,]+)")


def _clean_number(text):
    """Remove comma separators (and any stray ``$``) from a matched number."""
    return text.replace(",", "").replace("$", "")


def extract_solution(solution_str, method="strict", use_box=False):
    """Extract the final numeric answer from a solution string.

    Args:
        solution_str: The solution text to search.
        method: "strict" requires an explicit answer marker (a LaTeX
            ``\\box{...}`` / ``\\boxed{...}`` or a GSM8K ``#### `` line);
            "flexible" takes the last number appearing anywhere in the text.
        use_box: Strict mode only. If True, the box format is tried before the
            GSM8K marker; otherwise the GSM8K marker is tried first and the box
            format is used as a fallback.

    Returns:
        The extracted answer as a string with commas removed, or None if no
        answer could be extracted.
    """
    assert method in ["strict", "flexible"]

    if method == "strict":
        if use_box:
            patterns = (_BOX_ANSWER_RE, _GSM8K_ANSWER_RE)
        else:
            patterns = (_GSM8K_ANSWER_RE, _BOX_ANSWER_RE)
        for pattern in patterns:
            matches = pattern.findall(solution_str)
            if matches:
                # The last marked answer wins (the model may restate it).
                return _clean_number(matches[-1])
        return None

    # Flexible mode: the last candidate that actually contains a digit.
    # Candidates such as "." or "," (punctuation picked up by the regex) are
    # skipped; if nothing numeric remains, extraction fails with None.
    for candidate in reversed(_NUMBER_RE.findall(solution_str)):
        if any(ch.isdigit() for ch in candidate):
            return _clean_number(candidate)
    return None


# Strings that use commas purely as thousands separators, e.g. "1,000" or
# "-12,345.5". Only this exact shape is normalised so that comma-separated
# lists such as "1, 2" are not merged into a single number.
_THOUSANDS_SEPARATED_RE = re.compile(r"-?\d{1,3}(?:,\d{3})+(?:\.\d+)?")


def _to_float(value):
    """Return ``value`` as a float, or None if it cannot be converted.

    Thousands-separated strings are normalised first, mirroring the comma
    stripping that ``extract_solution`` applies to model answers.
    """
    if isinstance(value, str):
        text = value.strip()
        if _THOUSANDS_SEPARATED_RE.fullmatch(text):
            value = text.replace(",", "")
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: e.g. None / list ground truth; OverflowError: huge ints.
        return None


def general_math_reward(solution_str, ground_truth, method="strict", use_box=False, format_score=0.0, score=1.0):
    """Generic math reward: compare the extracted answer with the ground truth.

    Args:
        solution_str: The model-generated solution text.
        ground_truth: The reference answer (typically a string).
        method: Answer extraction method, "strict" or "flexible".
        use_box: Whether to prefer the LaTeX box format during extraction.
        format_score: Score given when an answer was extracted but is wrong.
        score: Score given when the answer is correct.

    Returns:
        0 if no answer could be extracted, ``score`` if the answer matches the
        ground truth, otherwise ``format_score``.
    """
    answer = extract_solution(solution_str=solution_str, method=method, use_box=use_box)
    if answer is None:
        # Nothing extractable: no reward at all, not even the format score.
        return 0

    answer_float = _to_float(answer)
    ground_truth_float = _to_float(ground_truth)
    if answer_float is not None and ground_truth_float is not None:
        # Numeric comparison with a small absolute tolerance for float round-off.
        is_correct = abs(answer_float - ground_truth_float) < 1e-6
    else:
        # At least one side is not numeric: fall back to exact string equality.
        is_correct = answer == ground_truth
    return score if is_correct else format_score


def compute_score(data_source, solution_str, ground_truth, extra_info=None):
    """VERL-compatible reward entry point.

    Always uses strict extraction. For GSM8K data sources the "#### <n>"
    marker is tried first; for every other data source the LaTeX box format
    is tried first. ``extra_info`` is accepted for interface compatibility
    and is not used.
    """
    is_gsm8k = "gsm8k" in data_source.lower()
    return general_math_reward(
        solution_str,
        ground_truth,
        method="strict",
        use_box=not is_gsm8k,
    )


# Self-test: run sample cases through both the generic reward function and the
# VERL entry point, printing a check mark per path.
if __name__ == "__main__":
    # Each case: (solution text, ground truth, expected to score > 0, data source)
    test_cases = [
        # LaTeX box format
        ("解答过程...\n\\box{42}", "42", True, "box"),
        ("解答过程...\n\\box{0.73728}", "0.73728", True, "box"),
        ("解答过程...\n\\box{42}", "43", False, "box"),

        # GSM8K format
        ("解答过程...\n#### 42", "42", True, "gsm8k"),
        ("解答过程...\n#### 42.5", "42.5", True, "gsm8k"),
        ("解答过程...\n#### 42", "43", False, "gsm8k"),

        # Chinese answer-label format
        ("解答过程...\n答案：42", "42", True, "math"),
        ("解答过程...\n最终答案：42", "42", True, "math"),

        # Mixed formats
        ("解答过程...\n答案：42\n\\box{42}", "42", True, "box"),
        ("解答过程...\n#### 42\n\\box{42}", "42", True, "box"),

        # Floating-point answers
        ("答案：42.0", "42", True, "math"),
        ("\\box{42.000}", "42.0", True, "box"),

        # Failure cases
        ("没有数字", "42", False, "math"),
        ("解答过程", "42", False, "gsm8k"),
    ]

    print("测试通用数学奖励函数...")
    print("=" * 50)

    def _mark(value, expected_positive):
        # Check mark when "scored above zero" agrees with the expectation.
        return "✓" if (value > 0) == expected_positive else "✗"

    for case_no, (solution, ground_truth, expected, data_source) in enumerate(test_cases, start=1):
        # Pick extraction settings for the generic function from the data source tag.
        if "box" in data_source:
            reward_kwargs = {"method": "strict", "use_box": True}
        elif "gsm8k" in data_source:
            reward_kwargs = {"method": "strict", "use_box": False}
        else:
            reward_kwargs = {"method": "flexible", "use_box": True}

        score = general_math_reward(solution, ground_truth, **reward_kwargs)
        verl_score = compute_score(data_source, solution, ground_truth)

        report_lines = [
            f"测试 {case_no}: {_mark(score, expected)} {_mark(verl_score, expected)}",
            f"  数据源: {data_source}",
            f"  解答: {solution}",
            f"  标准答案: {ground_truth}",
            f"  通用函数得分: {score}",
            f"  VERL接口得分: {verl_score}",
            f"  预期结果: {'正确' if expected else '错误'}",
            "",
        ]
        print("\n".join(report_lines))