import json
import argparse
from collections import defaultdict
import os
from tqdm import tqdm
import numpy as np
import pandas as pd

def parse_args():
    """Parse the command-line options of the metrics script.

    Returns:
        argparse.Namespace with two attributes:
        ``input``  -- path of the JSONL file to read (required);
        ``output`` -- path of the CSV file to write, or ``None`` when the
        option is omitted.
    """
    cli = argparse.ArgumentParser(description='计算每个ID及数据集的pass@k指标')
    cli.add_argument(
        '--input',
        required=True,
        type=str,
        help='输入的jsonl文件路径',
    )
    cli.add_argument(
        '--output',
        type=str,
        help='输出的CSV文件路径，如果不指定则在输入文件名后加上_metrics.csv',
    )
    namespace = cli.parse_args()
    return namespace

def calculate_pass_at_k(results, k):
    """Return the pass@k score of a single problem as 1.0 or 0.0.

    The problem counts as passed when at least one of the first ``k``
    entries of ``results`` is truthy. If fewer than ``k`` results are
    available the score is 0.0, regardless of their values.
    """
    # Not enough attempts recorded: scored as a failure.
    if not len(results) >= k:
        return 0.0

    for outcome in results[:k]:
        if outcome:
            return 1.0
    return 0.0

def _coerce_result(result):
    """Normalize a raw ``result`` field to a bool.

    Strings are true only when they spell "true" (case-insensitive). Every
    other value is reduced with ``bool()`` so that the later ``sum()`` over
    results never meets ``None`` or another non-numeric value.
    """
    if isinstance(result, str):
        return result.lower() == 'true'
    return bool(result)


def _load_samples(input_file):
    """Read a JSONL file and group its samples by dataset and problem.

    Each usable line is a JSON object with the keys ``id`` (a string of the
    form ``"<dataset>/<problem_id>"``), ``result`` and ``sample_index``.
    Blank lines are ignored; unparsable or malformed lines are reported on
    stdout and skipped.

    Returns:
        Nested mapping ``{dataset: {problem_id: [(sample_index, passed)]}}``.
    """
    grouped = defaultdict(lambda: defaultdict(list))

    # JSON text is UTF-8; do not depend on the platform's default encoding.
    # First pass only counts lines so the progress bar can show a total.
    with open(input_file, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    with open(input_file, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=total_lines, desc="处理数据"):
            line = line.strip()
            if not line:
                # A blank line (e.g. trailing newline) is not a record.
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"JSON解析错误: {str(e)}")
                continue

            # A valid JSON line may still be a list/number/string; only
            # objects carrying all three keys are usable.
            if (not isinstance(record, dict)
                    or 'id' not in record
                    or 'result' not in record
                    or 'sample_index' not in record):
                print("警告: 记录缺少必要字段")
                continue

            record_id = record['id']
            # The id must be a string shaped like "<dataset>/<problem_id>".
            if not isinstance(record_id, str) or '/' not in record_id:
                print(f"警告: ID格式错误，缺少'/': {record_id}")
                continue
            dataset, problem_id = record_id.split('/', 1)

            grouped[dataset][problem_id].append(
                (record['sample_index'], _coerce_result(record['result']))
            )

    return grouped


def _problem_metrics(datasets):
    """Compute one metrics row (a dict) per (dataset, problem) pair.

    Samples are ordered by ``sample_index`` before scoring, so the
    ``sample_index`` values of one problem must be mutually comparable.
    """
    rows = []
    for dataset, problems in datasets.items():
        for problem_id, samples in problems.items():
            # Sort on the index only, so ties keep their file order.
            ordered = [passed for _, passed in sorted(samples, key=lambda s: s[0])]
            rows.append({
                'dataset': dataset,
                'problem_id': problem_id,
                'pass@1': calculate_pass_at_k(ordered, 1),
                'pass@10': calculate_pass_at_k(ordered, 10),
                'pass@100': calculate_pass_at_k(ordered, 100),
                'success_rate': sum(ordered) / len(ordered) if ordered else 0,
                'total_samples': len(ordered),
            })
    return rows


def process_jsonl(input_file: str, output_file: str) -> None:
    """Compute pass@k metrics from a JSONL file and write them to a CSV.

    The CSV contains one row per problem, followed by one
    ``DATASET_AVERAGE`` row per dataset and a final ``OVERALL``/``AVERAGE``
    row. A short summary and a list of problems with fewer than 100 samples
    are printed to stdout.

    Args:
        input_file: Path of the JSONL file to read (decoded as UTF-8).
        output_file: Path of the CSV file to write.

    Returns:
        None. If no valid record is found, an error message is printed and
        no file is written.
    """
    print(f"正在读取文件: {input_file}")

    metrics = _problem_metrics(_load_samples(input_file))

    if not metrics:
        print("错误: 没有有效数据可处理")
        return
    df = pd.DataFrame(metrics)

    metric_cols = ['pass@1', 'pass@10', 'pass@100', 'success_rate', 'total_samples']

    # Per-dataset mean of every metric column.
    dataset_avg = df.groupby('dataset').agg(
        {col: 'mean' for col in metric_cols}
    ).reset_index()
    dataset_avg['problem_id'] = 'DATASET_AVERAGE'

    # Mean over all problems, irrespective of dataset.
    overall_avg = pd.DataFrame({
        'dataset': ['OVERALL'],
        'problem_id': ['AVERAGE'],
        **{col: [df[col].mean()] for col in metric_cols},
    })

    # Per-problem rows first, then the aggregate rows.
    final_df = pd.concat([df, dataset_avg, overall_avg], ignore_index=True)

    print(f"\n保存结果到: {output_file}")
    final_df.to_csv(output_file, index=False, float_format='%.4f')

    print("\n关键指标:")
    print(f"总体Pass@1: {overall_avg['pass@1'].values[0]:.4f}")
    print(f"总体Pass@10: {overall_avg['pass@10'].values[0]:.4f}")
    print(f"总体Pass@100: {overall_avg['pass@100'].values[0]:.4f}")
    print(f"总体平均成功率: {overall_avg['success_rate'].values[0]:.4f}")

    # Check the per-problem frame directly: filtering the merged frame by
    # sentinel names would hide a real problem or dataset that happens to
    # be called 'DATASET_AVERAGE' / 'OVERALL'.
    incomplete = df[df['total_samples'] < 100]
    if not incomplete.empty:
        print("\n警告: 以下问题的样本数不足100:")
        for _, row in incomplete.iterrows():
            print(f"数据集: {row['dataset']}, 问题ID: {row['problem_id']}, 样本数: {int(row['total_samples'])}")

def main():
    """Script entry point: resolve the output path and run the computation.

    When ``--output`` is not given, the CSV path is the input path with its
    extension replaced by ``_metrics.csv``. The output directory is created
    if it does not exist yet.
    """
    cli = parse_args()

    # Derive the default output path from the input file name.
    output_path = cli.output
    if output_path is None:
        stem = os.path.splitext(cli.input)[0]
        output_path = f"{stem}_metrics.csv"

    # A bare file name has no directory part; fall back to the current one.
    target_dir = os.path.dirname(output_path)
    if not target_dir:
        target_dir = '.'
    os.makedirs(target_dir, exist_ok=True)

    process_jsonl(cli.input, output_path)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()