import json
import os
from typing import Any, Dict, List, Optional

import pandas as pd

# Default input paths handed to filter_data_by_difficulty when this file is run as a script.
parquet_file = "/mnt/shared-storage-user/p1-shared/wangfuting/codes/project_tts_extrapolation/data/dapo-math-17k_qwen3_polaris.parquet"
jsonl_file = "/mnt/shared-storage-user/p1-shared/wangfuting/shared/models/verl-041-result/verl-qwen3-8b-new/qwen3-30b-a3b-dapo-math-17k-n16/valid/0_16384.jsonl"

def load_jsonl(file_path: str) -> List[Dict]:
    """Read a JSON Lines file and return its records in file order.

    Every line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped and each remaining line is decoded with
    ``json.loads``.

    Args:
        file_path: Path of the JSONL file, opened as UTF-8 text.

    Returns:
        A list holding one decoded JSON value per non-blank line.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped_lines if text]

def calculate_accuracy_from_8_rollouts(jsonl_items: List[Dict]) -> float:
    """Return the fraction of 8 rollout records whose ``score`` is positive.

    Args:
        jsonl_items: Rollout records for one sample. Exactly 8 are expected
            and each must provide a ``score`` entry.

    Returns:
        The number of records with ``score > 0`` divided by 8. When the list
        does not contain exactly 8 records, a warning is printed and 0.0 is
        returned instead.
    """
    received = len(jsonl_items)
    if received != 8:
        print(f"Warning: 期望8条数据，但得到{received}条")
        return 0.0

    # A rollout counts as correct as soon as its score is strictly positive.
    hits = sum(1 for record in jsonl_items if record['score'] > 0)
    return hits / 8.0

def filter_data_by_difficulty(
    parquet_file: str,
    jsonl_file: str,
    output_dir: Optional[str] = None,
    lower_threshold: float = 0,
    upper_threshold: float = 1,
):
    """Split a parquet dataset into low/mid/high buckets by rollout accuracy.

    Row ``i`` of the parquet file is paired positionally with JSONL records
    ``[8 * i, 8 * i + 8)``. The accuracy of each row is computed with
    ``calculate_accuracy_from_8_rollouts`` and stored in a new
    ``rollout_accuracy`` column, which is kept in the written files.

    Buckets (mutually exclusive):
        * low:  ``accuracy <= lower_threshold``
        * mid:  ``lower_threshold < accuracy < upper_threshold``
        * high: ``accuracy >= upper_threshold`` (and not already in low)

    Non-empty buckets are written as parquet files into ``output_dir``;
    empty buckets are skipped. Progress and statistics are printed.

    Args:
        parquet_file: Path of the source parquet file.
        jsonl_file: Path of the JSONL file holding 8 records per parquet row.
        output_dir: Directory for the output files. When None, the directory
            containing ``parquet_file`` is used (the current directory if
            ``parquet_file`` has no directory component).
        lower_threshold: Inclusive upper bound of the low bucket. Default 0.
        upper_threshold: Inclusive lower bound of the high bucket. Default 1.

    Raises:
        ValueError: If ``lower_threshold`` is greater than ``upper_threshold``.
    """
    # Fail fast: inverted thresholds would make the low and high buckets overlap.
    if lower_threshold > upper_threshold:
        raise ValueError(
            f"lower_threshold ({lower_threshold}) 不能大于 "
            f"upper_threshold ({upper_threshold})"
        )

    # Fixed by calculate_accuracy_from_8_rollouts, which only accepts 8 records.
    rollouts_per_sample = 8

    # Load both inputs.
    print("正在加载parquet文件...")
    df_parquet = pd.read_parquet(parquet_file)
    print(f"Parquet文件包含 {len(df_parquet)} 行数据")

    print("正在加载jsonl文件...")
    jsonl_data = load_jsonl(jsonl_file)
    print(f"JSONL文件包含 {len(jsonl_data)} 行数据")

    # Only rows that have a complete group of rollouts can be scored.
    processed_count = min(len(jsonl_data) // rollouts_per_sample, len(df_parquet))

    expected_jsonl_count = len(df_parquet) * rollouts_per_sample
    if len(jsonl_data) != expected_jsonl_count:
        print(f"Warning: JSONL数据量 ({len(jsonl_data)}) 与期望值 ({expected_jsonl_count}) 不匹配")
        print(f"将处理前 {processed_count} 条parquet数据")

    # Nothing to score: stop here instead of dividing by zero in the statistics.
    if processed_count == 0:
        print("Warning: 没有可处理的样本，跳过分类与保存")
        return

    # Score every row from its group of rollout records.
    print("\n正在计算每个样本的准确率...")
    accuracies = []
    for i in range(processed_count):
        start_idx = i * rollouts_per_sample
        jsonl_items = jsonl_data[start_idx:start_idx + rollouts_per_sample]
        accuracies.append(calculate_accuracy_from_8_rollouts(jsonl_items))

        if (i + 1) % 100 == 0:
            print(f"已处理 {i + 1}/{processed_count} 个样本")

    # Keep only the rows that were scored.
    df_filtered = df_parquet.iloc[:processed_count].copy()
    df_filtered['rollout_accuracy'] = accuracies

    # Build disjoint bucket masks; `& ~low_mask` keeps a row out of the high
    # bucket when lower_threshold == upper_threshold and accuracy equals both.
    acc = df_filtered['rollout_accuracy']
    low_mask = acc <= lower_threshold
    mid_mask = (acc > lower_threshold) & (acc < upper_threshold)
    high_mask = (acc >= upper_threshold) & ~low_mask

    df_low = df_filtered[low_mask].copy()
    df_mid = df_filtered[mid_mask].copy()
    df_high = df_filtered[high_mask].copy()

    # Report bucket sizes and accuracy statistics.
    total = len(df_filtered)
    print("\n" + "="*60)
    print("数据分类统计")
    print("="*60)
    print(f"划分阈值: low<={lower_threshold} < mid < {upper_threshold} <= high")
    print(f"总样本数: {total}")
    print(f"低准确率样本 (<= {lower_threshold}): {len(df_low)} ({len(df_low)/total*100:.2f}%)")
    print(f"中准确率样本 ({lower_threshold}, {upper_threshold}): {len(df_mid)} ({len(df_mid)/total*100:.2f}%)")
    print(f"高准确率样本 (>= {upper_threshold}): {len(df_high)} ({len(df_high)/total*100:.2f}%)")
    print("\n准确率统计:")
    print(f"  平均准确率: {acc.mean():.4f}")
    for label, bucket in (("低", df_low), ("高", df_high)):
        if len(bucket) > 0:
            bucket_acc = bucket['rollout_accuracy']
            print(f"  {label}准确率样本平均准确率: {bucket_acc.mean():.4f}")
            print(f"  {label}准确率样本准确率范围: [{bucket_acc.min():.4f}, {bucket_acc.max():.4f}]")
        else:
            print(f"  {label}准确率样本: 无数据")
    print("="*60)

    # Resolve the output directory; dirname() is '' for a bare filename, and
    # os.makedirs('') would raise, so fall back to the current directory.
    if output_dir is None:
        output_dir = os.path.dirname(parquet_file) or "."

    os.makedirs(output_dir, exist_ok=True)

    # NOTE: the "_leq{upper}" / "_gt{upper}" suffixes are kept unchanged so
    # existing consumers still find the files, although the actual boundaries
    # are mid < upper_threshold and high >= upper_threshold.
    base_name = os.path.splitext(os.path.basename(parquet_file))[0]
    low_file = os.path.join(output_dir, f"{base_name}_low_leq{lower_threshold}.parquet")
    mid_file = os.path.join(
        output_dir, f"{base_name}_mid_gt{lower_threshold}_leq{upper_threshold}.parquet"
    )
    high_file = os.path.join(output_dir, f"{base_name}_high_gt{upper_threshold}.parquet")

    # Write each non-empty bucket (the rollout_accuracy column is retained).
    print("\n正在保存文件...")
    outputs = (
        ("低", df_low, low_file),
        ("中", df_mid, mid_file),
        ("高", df_high, high_file),
    )
    for label, bucket, path in outputs:
        if len(bucket) > 0:
            bucket.to_parquet(path, index=False)
            print(f"{label}准确率样本已保存到: {path} ({len(bucket)} 条)")
        else:
            print(f"没有{label}准确率样本，跳过保存")

    print("\n处理完成！")

if __name__ == "__main__":
    # Run the difficulty split on the module-level default paths.
    filter_data_by_difficulty(parquet_file=parquet_file, jsonl_file=jsonl_file)
