import json
import argparse
from collections import defaultdict
import os
from tqdm import tqdm
import numpy as np
import pandas as pd

def parse_args():
    """Parse the command-line options of the pass@k metrics script.

    Returns:
        argparse.Namespace with two attributes:
            input:  path of the JSONL file to read (required).
            output: path of the CSV file to write, or ``None`` when the
                    option was not given.
    """
    cli = argparse.ArgumentParser(description='计算每个ID的pass@k指标')

    # The input path is mandatory.
    cli.add_argument(
        '--input',
        required=True,
        type=str,
        help='输入的jsonl文件路径',
    )
    # The output path is optional and stays None when omitted.
    cli.add_argument(
        '--output',
        type=str,
        help='输出的CSV文件路径，如果不指定则在输入文件名后加上_metrics.csv',
    )

    namespace = cli.parse_args()
    return namespace

def calculate_pass_at_k(results, k):
    """Return 1.0 if any of the first ``k`` results is truthy, else 0.0.

    Args:
        results: sequence of per-attempt outcomes, in attempt order.
        k: number of leading attempts to consider.

    Returns:
        0.0 when fewer than ``k`` results are available, or when none of the
        first ``k`` results is truthy; 1.0 otherwise.
    """
    has_enough_samples = len(results) >= k
    if not has_enough_samples:
        return 0.0

    # Only the first k attempts count towards pass@k.
    for outcome in results[:k]:
        if outcome:
            return 1.0
    return 0.0

# k values for which pass@k is reported. The largest value is also the sample
# count below which an ID is reported as having too few samples.
_PASS_K_VALUES = (1, 10, 100)

# Fields every JSONL record must carry to be counted.
_REQUIRED_FIELDS = ('id', 'result', 'sample_index')


def _load_records(input_file):
    """Read the JSONL file and group ``(sample_index, passed)`` pairs by ID."""
    grouped = defaultdict(list)  # {id: [(sample_index, passed), ...]}

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(input_file, 'r', encoding='utf-8') as f:
        # First pass only counts lines so the progress bar has a total.
        total_lines = sum(1 for _ in f)
        f.seek(0)

        for line in tqdm(f, total=total_lines, desc="处理数据"):
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue

            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"错误: JSON解析失败 - {str(e)}")
                continue

            # A line can be valid JSON without being an object (number, null,
            # list, ...); treat it like a record with missing fields rather
            # than crashing on the membership test.
            if not isinstance(record, dict) or not all(
                field in record for field in _REQUIRED_FIELDS
            ):
                print("警告: 记录缺少必要字段")
                continue

            result = record['result']
            if isinstance(result, str):
                result = (result.lower() == 'true')
            else:
                # Normalise everything else (None, 0/1, ...) to a real bool so
                # the later sum() cannot fail or over-count.
                result = bool(result)

            grouped[record['id']].append((record['sample_index'], result))

    return grouped


def _compute_id_metrics(record_id, samples):
    """Build the metrics row for one ID from its ``(sample_index, passed)`` pairs."""
    # Order attempts by sample_index so "first k attempts" is well defined.
    ordered = [passed for _, passed in sorted(samples, key=lambda item: item[0])]

    row = {'id': record_id}
    for k in _PASS_K_VALUES:
        row[f'pass@{k}'] = calculate_pass_at_k(ordered, k)
    row['success_rate'] = sum(ordered) / len(ordered) if ordered else 0
    row['total_samples'] = len(ordered)
    return row


def process_jsonl(input_file: str, output_file: str) -> None:
    """Compute per-ID pass@k metrics from a JSONL file and save them as CSV.

    Each input line is expected to be a JSON object with the fields ``id``,
    ``result`` and ``sample_index``. Lines that are blank are skipped; lines
    that are not valid JSON, are not JSON objects, or lack a required field
    are reported on stdout and skipped. A string ``result`` counts as a pass
    only when it equals ``"true"`` case-insensitively; any other value is
    converted with ``bool()``.

    The CSV holds one row per ID (pass@k for every k in ``_PASS_K_VALUES``,
    success rate, sample count) followed by an ``AVERAGE`` row with the
    column means. A summary and a list of IDs with fewer samples than the
    largest k are printed to stdout.

    Args:
        input_file: path of the JSONL file to read (decoded as UTF-8).
        output_file: path of the CSV file to write.

    Returns:
        None. When the input contains no valid record a warning is printed
        and no CSV file is written.
    """
    print(f"正在读取文件: {input_file}")
    records = _load_records(input_file)

    if not records:
        # Without any row the metric columns do not exist and the averages
        # below would raise KeyError; stop with a clear message instead.
        print("警告: 未找到任何有效记录，未生成输出文件")
        return

    metrics = [
        _compute_id_metrics(record_id, samples)
        for record_id, samples in records.items()
    ]
    df = pd.DataFrame(metrics)

    # Overall metrics are the unweighted means over IDs.
    overall_metrics = {'id': 'AVERAGE'}
    for column in df.columns:
        if column != 'id':
            overall_metrics[column] = df[column].mean()

    df = pd.concat([df, pd.DataFrame([overall_metrics])], ignore_index=True)

    print(f"\n正在保存结果到: {output_file}")
    df.to_csv(output_file, index=False, float_format='%.4f')

    print("\n总体指标:")
    for k in _PASS_K_VALUES:
        value = overall_metrics[f'pass@{k}']
        print(f"Pass@{k}: {value:.4f}")
    print(f"平均成功率: {overall_metrics['success_rate']:.4f}")
    print(f"平均样本数: {overall_metrics['total_samples']:.1f}")

    # Report IDs with too few samples straight from the per-ID rows, so the
    # AVERAGE row never has to be filtered out and no per-ID DataFrame lookup
    # is needed.
    threshold = max(_PASS_K_VALUES)
    incomplete = [
        (row['id'], row['total_samples'])
        for row in metrics
        if row['total_samples'] < threshold
    ]
    if incomplete:
        print(f"\n警告：以下ID的样本数量不足{threshold}个 (共{len(incomplete)}个):")
        for record_id, count in incomplete:
            print(f"ID: {record_id}, 数量: {count}")

def main():
    """Script entry point: parse the CLI options and run the metrics job.

    When ``--output`` is not given, the output path is the input path with
    its extension replaced by ``_metrics.csv``. The output directory is
    created if it does not exist yet.
    """
    args = parse_args()

    # Derive the default output path from the input path.
    if args.output is None:
        stem = os.path.splitext(args.input)[0]
        args.output = f"{stem}_metrics.csv"

    # A bare filename has no directory component; use the current directory.
    target_dir = os.path.dirname(args.output)
    if not target_dir:
        target_dir = '.'
    os.makedirs(target_dir, exist_ok=True)

    process_jsonl(args.input, args.output)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
    
    

# Example invocation with an explicit output file:
# python calculate_pass_rate.py --input /home/superbench/xinzhang3/haoling/epicoder2/data/std_format/train_sample_1k_results_greedy.jsonl --output pass_rate_7Bins_1k_greedy.csv