import json
from collections import defaultdict
from pathlib import Path
import sys
# Path of the input file, taken from the first command-line argument.
# NOTE: this runs at import time, so it raises IndexError when the script is
# started (or the module is imported) without any argument.
file_path = sys.argv[1]

# Read a JSONL file
def read_jsonl(file_path):
    """Read a JSONL file and return its records as a list.

    Each non-blank line is parsed as one JSON document; blank or
    whitespace-only lines are skipped.

    Args:
        file_path: Path of the JSONL file to read.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a line is not valid JSON. The message names
            the offending line number of the file.
    """
    data = []
    # 'utf-8-sig' decodes plain UTF-8 and also drops a leading BOM, which
    # json.loads would otherwise reject on the first record.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, but say which line of the
                # file was bad (the parser only knows the position in the line).
                raise json.JSONDecodeError(
                    f"{err.msg} (JSONL record on line {line_no} of {file_path})",
                    err.doc,
                    err.pos,
                ) from err
    return data

# Group records by data source and compute the average correctness
def calculate_avg_correctness_by_source(data):
    """Compute per-source correctness statistics.

    The source label of a record is the first truthy value among its
    ``data_source``, ``source`` and ``dataset`` keys, or ``'unknown'`` when
    none is set. Its score is the first of ``correctness`` / ``correct`` that
    is present and not ``None``, defaulting to 0. A record counts as correct
    when its score is greater than 0.

    Args:
        data: Iterable of dict records.

    Returns:
        A dict mapping each source label to a dict with the keys
        ``avg_correctness`` (mean score), ``total_samples``,
        ``correct_samples`` and ``accuracy`` (correct_samples / total_samples).
        Empty input yields an empty dict.
    """
    source_stats = defaultdict(lambda: {'correct': 0, 'total': 0, 'correctness_sum': 0})

    for item in data:
        # Try the different field names a record may use for its source.
        data_source = item.get('data_source') or item.get('source') or item.get('dataset') or 'unknown'

        # Select the score by presence, not truthiness: a legitimate score of
        # 0 under 'correctness' must not fall through to the 'correct' field.
        correctness = item.get('correctness')
        if correctness is None:
            correctness = item.get('correct')
        if correctness is None:
            correctness = 0

        bucket = source_stats[data_source]
        bucket['total'] += 1
        bucket['correctness_sum'] += correctness
        if correctness > 0:
            bucket['correct'] += 1

    # Every bucket was created by counting at least one record, so 'total'
    # is never zero here and the divisions are safe.
    results = {}
    for source, stats in source_stats.items():
        total = stats['total']
        results[source] = {
            'avg_correctness': stats['correctness_sum'] / total,
            'total_samples': total,
            'correct_samples': stats['correct'],
            'accuracy': stats['correct'] / total,
        }

    return results

if __name__ == "__main__":
    # Script entry point: load the JSONL file named by file_path, print a
    # per-source statistics table sorted by average correctness, then print
    # overall totals. Any failure ends the process with exit status 1.
    try:
        # Check that the input file exists before trying to read it.
        if not Path(file_path).exists():
            print(f"错误: 文件不存在: {file_path}")
            # sys.exit instead of exit(): the latter is injected by the site
            # module and is not available in every interpreter setup.
            sys.exit(1)

        # Load the records.
        print(f"正在读取文件: {file_path}")
        data = read_jsonl(file_path)
        print(f"共读取 {len(data)} 条记录\n")

        # Dump the first record so the field layout is visible (debug aid).
        if data:
            print("第一条记录的字段:")
            print(json.dumps(data[0], indent=2, ensure_ascii=False))
            print("\n" + "="*80 + "\n")

        # Aggregate correctness per data source.
        results = calculate_avg_correctness_by_source(data)

        # Print the table, highest average correctness first.
        print("各数据集的统计结果:")
        print("-" * 80)
        print(f"{'数据集':<30} {'平均Correctness':<20} {'准确率':<15} {'样本数':<10}")
        print("-" * 80)

        ranked = sorted(
            results.items(),
            key=lambda entry: entry[1]['avg_correctness'],
            reverse=True,
        )
        for source, stats in ranked:
            print(f"{source:<30} {stats['avg_correctness']:<20.4f} {stats['accuracy']:<15.2%} {stats['total_samples']:<10}")

        print("-" * 80)

        # Overall figures: the sample-weighted mean of the per-source averages.
        total_samples = sum(stats['total_samples'] for stats in results.values())
        total_correctness = sum(stats['avg_correctness'] * stats['total_samples'] for stats in results.values())
        overall_avg = total_correctness / total_samples if total_samples > 0 else 0

        print("\n总体统计:")
        print(f"  总样本数: {total_samples}")
        print(f"  总体平均 Correctness: {overall_avg:.4f}")
        print(f"  数据集数量: {len(results)}")

    except Exception as e:
        # Top-level boundary: report the error with its traceback, then exit
        # non-zero so shells and pipelines can tell the run failed.
        # (SystemExit from the check above is not an Exception and passes through.)
        print(f"错误: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
