import json
import re
from collections import defaultdict
from math_verify import parse

def extract_boxed_answer(text):
    """
    Extract the final answer (e.g. a ``\\boxed{}`` expression) from ``text``.

    Extraction is delegated to ``math_verify.parse``. When nothing can be
    extracted, ``None`` is returned, which callers treat as "unfinished".

    Args:
        text: Model-generated response text.

    Returns:
        The (non-empty) result of ``parse(text)``, or ``None`` if parsing
        raised an exception or produced an empty result.
    """
    try:
        result = parse(text)
    except Exception:
        # Best-effort: any parsing failure means "no answer extracted".
        # (Catching Exception rather than a bare except lets Ctrl-C and
        # SystemExit propagate.)
        return None
    # math_verify.parse signals "nothing extracted" with an empty list rather
    # than by raising; normalize that to None so the documented contract holds.
    if not result:
        return None
    return result

def is_incomplete(text):
    """
    Decide whether a model response is unfinished.

    A response counts as unfinished when it is empty/None, when it contains
    no ``\\boxed{`` followed later by a closing ``}``, or when no answer can
    be extracted from it.

    Args:
        text: Model-generated response text (may be None or empty).

    Returns:
        bool: True if the response is considered unfinished, else False.
    """
    if not text:
        return True

    # Require "\boxed{" followed (eventually) by a "}". A box that was opened
    # but never closed, e.g. output cut off mid-answer, does not match.
    boxed_pattern = r'\\boxed\{[^}]*\}'
    if not re.search(boxed_pattern, text):
        return True

    # Extraction failure may surface as None or as an empty result
    # (math_verify.parse returns an empty list), so test truthiness rather
    # than identity with None.
    extracted = extract_boxed_answer(text)
    return not extracted

def _ratio(part, whole):
    """Return ``part / whole``, or 0 when ``whole`` is not positive."""
    return part / whole if whole > 0 else 0


def _budget_sort_key(budget):
    """Sort key: numeric budgets in numeric order first, then any others by string."""
    try:
        return (0, int(budget), '')
    except (TypeError, ValueError):
        # Non-numeric budget labels would make key=int raise; sort them last.
        return (1, 0, str(budget))


def _collect_stats(data):
    """Tally total and unfinished responses per model and budget from loaded questions."""
    # Structure: {model: {budget: {'total': int, 'incomplete': int}}}
    stats = defaultdict(lambda: defaultdict(lambda: {'total': 0, 'incomplete': 0}))
    n_questions = len(data)
    for i, question in enumerate(data):
        if i % 1000 == 0:
            print(f"处理进度: {i}/{n_questions}")

        # Expected shape (from the access pattern below):
        # question['model_results'][model_name][budget]['generated_text']
        model_results = question.get('model_results', {})
        for model_name, model_data in model_results.items():
            for budget, budget_data in model_data.items():
                generated_text = budget_data.get('generated_text', '')
                cell = stats[model_name][budget]
                cell['total'] += 1
                if is_incomplete(generated_text):
                    cell['incomplete'] += 1
    return stats


def analyze_incomplete_ratio(file_path):
    """
    Report the share of unfinished responses in a results JSON file.

    The file is expected to hold a list of questions, each with a
    ``model_results`` mapping of model -> budget -> {'generated_text': ...}.
    Per-(model, budget), overall, per-model and per-budget ratios are printed
    to stdout.

    Args:
        file_path: Path to the JSON file (e.g. all_wrong_questions.json).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    print("正在加载数据...")

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"总共加载了 {len(data)} 个问题")

    stats = _collect_stats(data)

    # Per (model, budget) table.
    print("\n=== 未完成比例分析结果 ===")
    print("格式: 模型 | Budget | 总数 | 未完成数 | 未完成比例")
    print("-" * 60)
    for model_name in sorted(stats):
        for budget in sorted(stats[model_name], key=_budget_sort_key):
            cell = stats[model_name][budget]
            total = cell['total']
            incomplete = cell['incomplete']
            ratio = _ratio(incomplete, total)
            print(f"{model_name:15} | {budget:6} | {total:4} | {incomplete:6} | {ratio:.3f}")

    # Overall totals. Counters are summed field by field: dicts do not support
    # "+", so sum() over the per-budget dicts would raise TypeError.
    # NOTE: these count (model, budget) responses, not distinct questions.
    print("\n=== 总体统计 ===")
    cells = [cell for model_stats in stats.values() for cell in model_stats.values()]
    total_questions = sum(cell['total'] for cell in cells)
    total_incomplete = sum(cell['incomplete'] for cell in cells)
    overall_ratio = _ratio(total_incomplete, total_questions)

    print(f"总问题数: {total_questions}")
    print(f"总未完成数: {total_incomplete}")
    print(f"总体未完成比例: {overall_ratio:.3f}")

    # Per-model totals across all budgets.
    print("\n=== 按模型统计 ===")
    for model_name in sorted(stats):
        model_cells = stats[model_name].values()
        model_total = sum(cell['total'] for cell in model_cells)
        model_incomplete = sum(cell['incomplete'] for cell in model_cells)
        model_ratio = _ratio(model_incomplete, model_total)
        print(f"{model_name:15} | {model_total:4} | {model_incomplete:6} | {model_ratio:.3f}")

    # Per-budget totals across all models.
    print("\n=== 按 Budget 统计 ===")
    budget_stats = defaultdict(lambda: {'total': 0, 'incomplete': 0})
    for model_stats in stats.values():
        for budget, budget_data in model_stats.items():
            budget_stats[budget]['total'] += budget_data['total']
            budget_stats[budget]['incomplete'] += budget_data['incomplete']

    for budget in sorted(budget_stats, key=_budget_sort_key):
        total = budget_stats[budget]['total']
        incomplete = budget_stats[budget]['incomplete']
        ratio = _ratio(incomplete, total)
        print(f"{budget:6} | {total:4} | {incomplete:6} | {ratio:.3f}")

if __name__ == "__main__":
    import sys

    # The input path may be given as the first command-line argument; the
    # original hard-coded location remains the default.
    default_path = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/rl/LUFFY/old_version/eval_scripts/all_wrong_questions.json"
    file_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    analyze_incomplete_ratio(file_path)
