import json
import os

# --- Configuration ---
# Path of the evaluation-results file that the script analyses when it is
# run directly (passed to analyze_results_grouped in the __main__ guard).
RESULTS_FILE: str = "mmau_results_base_llm.jsonl"

def _tally_results(lines):
    """Tally per-group sample and correct counts from an iterable of JSONL lines.

    Args:
        lines: Iterable of text lines, each expected to hold one JSON object.

    Returns:
        A ``(stats, skipped_lines)`` tuple. ``stats`` maps
        ``'truncated_true'`` / ``'truncated_false'`` to a dict with
        ``'total'`` and ``'correct'`` counters; ``skipped_lines`` is the
        number of parsed lines that were not usable records.
    """
    stats = {
        'truncated_true': {'total': 0, 'correct': 0},
        'truncated_false': {'total': 0, 'correct': 0},
    }
    skipped_lines = 0

    for line in lines:
        if not line.strip():
            continue

        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            # Malformed lines are reported but not counted in skipped_lines.
            print(f"警告: 发现并跳过一个格式错误的JSON行: {line.strip()}")
            continue

        # A line can be valid JSON without being an object (e.g. `42` or
        # `null`). The membership test below would raise TypeError on such
        # values, so they are treated as records missing the required fields.
        if (
            not isinstance(data, dict)
            or 'was_truncated' not in data
            or 'llm_judge_is_correct' not in data
        ):
            skipped_lines += 1
            continue

        # Strict identity checks: only the JSON literal `true` selects the
        # truncated group or counts as correct; every other value falls into
        # the non-truncated group / is counted as incorrect.
        group_key = (
            'truncated_true' if data['was_truncated'] is True else 'truncated_false'
        )
        group = stats[group_key]
        group['total'] += 1
        if data['llm_judge_is_correct'] is True:
            group['correct'] += 1

    return stats, skipped_lines


def _accuracy_percent(correct, total):
    """Return ``correct / total`` as a percentage, or 0.0 when ``total`` is 0."""
    return (correct / total * 100) if total > 0 else 0.0


def _print_group_report(label, group, rule):
    """Print the sample count, correct count and accuracy of one group.

    Args:
        label: Text shown after ``was_truncated = `` in the group header.
        group: Dict with ``'total'`` and ``'correct'`` counters.
        rule: Separator line printed after the group's figures.
    """
    total = group['total']
    correct = group['correct']
    print(f"\n--- 组: was_truncated = {label} ---")
    print(f"总样本数: {total}")
    print(f"正确数:    {correct}")
    print(f"正确率:    {_accuracy_percent(correct, total):.2f}%")
    print(rule)


def analyze_results_grouped(file_path: str) -> None:
    """
    Read an evaluation-results file and print accuracy grouped by 'was_truncated'.

    Each non-blank line is parsed as JSON. Records whose 'was_truncated' is
    the literal ``true`` go to one group and all other records go to the
    other; a record counts as correct only when 'llm_judge_is_correct' is the
    literal ``true``. Lines that are not JSON objects, or that lack either
    field, are skipped and their number is reported at the end. Malformed
    JSON lines trigger a warning and are ignored.

    Nothing is returned; the per-group and overall reports are printed to
    standard output. If the file is missing or cannot be read, an error
    message is printed instead and no report is produced.

    Args:
        file_path (str): Path of the evaluation-results file.
    """
    # Keep the dedicated "file not found" message for the common mistake of
    # running the script from the wrong directory.
    if not os.path.exists(file_path):
        print(f"错误: 结果文件未找到: '{file_path}'")
        print("请确保脚本与结果文件在同一个目录下，或者提供正确的文件路径。")
        return

    print(f"正在从 '{file_path}' 文件中读取结果并分组统计...")

    # Only reading the file can fail for environmental reasons, so the
    # handler wraps the read alone and catches only I/O and decoding errors.
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            stats, skipped_lines = _tally_results(f)
    except (OSError, UnicodeError) as e:
        print(f"\n处理文件时发生意外错误: {e}")
        return

    # --- Report ---
    print("\n" + "="*40)
    print("      评测结果分组统计报告")
    print("="*40)

    group_true = stats['truncated_true']
    group_false = stats['truncated_false']

    _print_group_report("True", group_true, "---------------------------------")
    _print_group_report("False", group_false, "----------------------------------")

    # Overall summary across both groups.
    overall_total = group_true['total'] + group_false['total']
    overall_correct = group_true['correct'] + group_false['correct']
    overall_accuracy = _accuracy_percent(overall_correct, overall_total)

    print("\n--- 总体概要 ---")
    print(f"总计分析样本数: {overall_total}")
    print(f"总计正确数:      {overall_correct}")
    print(f"总体正确率:      {overall_accuracy:.2f}%")
    print("------------------")

    if skipped_lines > 0:
        print(f"\n注意: 有 {skipped_lines} 行因缺少关键字段而被跳过。")


if __name__ == "__main__":
    # Script entry point: analyse the default results file.
    analyze_results_grouped(file_path=RESULTS_FILE)