import pandas as pd
import numpy as np
import os
from collections import defaultdict
from typing import List, Dict, Optional

def process_benchmark_group(file_paths: List[str]) -> Optional[Dict[str, float]]:
    """Compute completion-length statistics for one group of prediction files.

    Every path is read as a JSON Lines file and the records of all readable
    files are pooled. The length of a record is taken from
    ``record["usage"]["completion_tokens"]``.

    Reading is best-effort: a file that cannot be read, or a record that has
    no usable ``usage.completion_tokens`` value, is skipped with a printed
    warning instead of aborting the whole analysis.

    Args:
        file_paths: Paths of the JSON Lines files belonging to one benchmark.

    Returns:
        A dict with ``"length_mean"`` and ``"length_std"`` (population
        standard deviation, NumPy's default ``ddof=0``), or ``None`` when no
        usable record was found.
    """
    if not file_paths:
        return None

    frames = []
    for path in file_paths:
        try:
            frames.append(pd.read_json(path, lines=True))
        except Exception as e:
            # Deliberately broad: one unreadable file must not stop the run.
            print(f"警告：读取文件失败 {path}。错误: {e}")
    if not frames:
        return None

    combined_df = pd.concat(frames, ignore_index=True)
    if 'usage' not in combined_df.columns:
        print("警告：所有记录均缺少 usage 字段，无法统计长度。")
        return None

    # Iterate over the column values directly. Records coming from a file
    # without a `usage` field show up as NaN after the concat, so anything
    # that is not a dict (or lacks a numeric token count) is skipped.
    lengths = []
    for usage in combined_df['usage']:
        tokens = usage.get('completion_tokens') if isinstance(usage, dict) else None
        if isinstance(tokens, (int, float)):
            lengths.append(tokens)

    skipped = len(combined_df) - len(lengths)
    if skipped:
        print(f"警告：跳过 {skipped} 条缺少 usage.completion_tokens 的记录。")
    if not lengths:
        # Avoid np.mean([]) -> NaN plus a RuntimeWarning.
        return None

    return {
        "length_mean": np.mean(lengths),
        "length_std": np.std(lengths),
    }

if __name__ == "__main__":
    # Script entry point: for every configured model directory, group the
    # prediction files by benchmark-name prefix, compute length statistics per
    # benchmark, and print a benchmark x model pivot table.

    # Prediction directories, one per model. The commented-out entries are
    # alternative experiment configurations kept for quick switching.
    data_paths = [
        # "ROOT/evalscope-main/True/predictions/baseline-lora",
        # "ROOT/evalscope-main/True/predictions/ours-lora"
        # "ROOT/evalscope-main/outputs/20250824_142001-openRLHF-llama-sft/predictions/openRLHF_llama",
        "ROOT/evalscope-main/True/predictions/Meta-Llama-3.1-8B-Instruct",
        "ROOT/evalscope-main/outputs/v20-0.1-meta-Llama31-8B-it/predictions/v20-0.1-meta-Llama31-8B-it-checkpoint-600",
        "ROOT/evalscope-main/outputs/v18-SK0.0-meta-Llama31-8B-it/predictions/v18-0.0-checkpoint-300",
        # "ROOT/evalscope-main/outputs/v22-0.0-OpenRLHF-Llama3-8B-SFT/predictions/v22-0.0-OpenRLHF-Llama3-8B-SFT-checkpoint-200",
        # "ROOT/evalscope-main/outputs/20250824_142001-openRLHF-llama-sft/predictions/openRLHF_llama",
        # "ROOT/evalscope-main/outputs/v21-ours1.0-OpenRLHF-Llama3-8B-SFT/consolidated_results/predictions/v21-1.0-OpenRLHF-Llama3-8B-SFT-checkpoint-300",
        # "ROOT/evalscope-main/outputs/v22-0.0-OpenRLHF-Llama3-8B-SFT/predictions/v22-0.0-OpenRLHF-Llama3-8B-SFT-checkpoint-200",
    ]

    # Display labels, paired positionally with `data_paths` via zip(); zip()
    # stops at the shorter list, so both lists must stay the same length.
    # NOTE(review): the first active path points at Meta-Llama-3.1-8B-Instruct
    # but is labelled "openrlhf_llama_baseline" - confirm the labels are current.
    model_names = [
        # "SK_Baseline",
        # "Vanilla_Baseline",
        # "Ours"
        # "llama31_baseline",
        "openrlhf_llama_baseline",
        "Ours",
        "SK_Baseline",
    ]

    # A file belongs to the first benchmark in this list whose name it starts with.
    bench_names = [
        "gsm8k", "hellaswag", "ifeval", "mmlu", "process", "race", "bbh", "humaneval", "trivia"
    ]

    accepted_suffixes = ('.jsonl', '.json')
    all_models_info = {}

    for data_path, model_name in zip(data_paths, model_names):
        print(f"\n--- 正在处理模型: {model_name} ---")

        try:
            dir_entries = os.listdir(data_path)
        except FileNotFoundError:
            print(f"错误：找不到目录 {data_path}，跳过该模型。")
            continue

        # Bucket the prediction files by benchmark prefix.
        files_by_bench = defaultdict(list)
        for entry in dir_entries:
            if not entry.endswith(accepted_suffixes):
                continue
            matched_bench = next(
                (name for name in bench_names if entry.startswith(name)),
                None,
            )
            if matched_bench is not None:
                files_by_bench[matched_bench].append(os.path.join(data_path, entry))

        # Benchmarks for which no statistics could be computed are left out.
        per_bench_stats = {}
        for bench, files in files_by_bench.items():
            print(f"  - 正在分析 benchmark: {bench} ({len(files)} 个文件)")
            bench_stats = process_benchmark_group(files)
            if bench_stats:
                per_bench_stats[bench] = bench_stats

        all_models_info[model_name] = per_bench_stats

    rule = "="*20
    print("\n\n" + rule + " 最终统计结果 " + rule)

    # Flatten {model: {benchmark: stats}} into one row per (model, benchmark).
    rows = [
        {
            'model': model_label,
            'benchmark': bench_label,
            'mean_length': bench_stats['length_mean'],
            'std_length': bench_stats['length_std'],
        }
        for model_label, stats_by_bench in all_models_info.items()
        for bench_label, bench_stats in stats_by_bench.items()
    ]

    if not rows:
        print("未能生成任何统计信息。")
    else:
        summary = pd.DataFrame(rows).pivot_table(
            index='benchmark',
            columns='model',
            values=['mean_length', 'std_length'],
        )
        # Extra row holding the column-wise mean across benchmarks.
        summary.loc['average'] = summary.mean()
        print(summary.round(2))