import pandas as pd
import numpy as np
import os

# Statistics computed per (train_ratio, prediction_method) group for every metric.
_AGG_STATS = ('mean', 'std', 'count')


def _summary_columns(metrics):
    """Return the output column order produced for the given metric names."""
    return (
        ['train_ratio', 'prediction_method']
        + [f"{metric}_{stat}" for metric in metrics for stat in _AGG_STATS]
        + ['benchmark']
        + [f"{metric}_formatted" for metric in metrics]
    )


def _format_mean_std(mean, std, decimals):
    """Render paired mean/std Series as "mean ± std" strings.

    When the standard deviation is missing (NaN -- e.g. a group with a single
    run, since pandas' groupby ``std`` uses ddof=1), only the mean is shown.
    Returns an object-dtype Series aligned to ``mean.index``; an empty input
    yields an empty Series instead of raising (unlike a row-wise ``apply``,
    which returns a DataFrame for empty frames).
    """
    texts = [
        f"{m:.{decimals}f}" if pd.isna(s) else f"{m:.{decimals}f} ± {s:.{decimals}f}"
        for m, s in zip(mean, std)
    ]
    return pd.Series(texts, index=mean.index, dtype=object)


def _summarize_benchmark_frame(df, benchmark_name, exclude_methods, metrics, decimals):
    """Aggregate one benchmark's raw rows into per-(train_ratio, method) mean/std/count."""
    # Drop methods with no usable data (by default IRT-2PL).
    df = df[~df['prediction_method'].isin(list(exclude_methods))]

    # Named aggregation gives flat, explicit column names (e.g. "auc_mean")
    # in metric-then-statistic order, instead of a positional rename.
    agg_spec = {
        f"{metric}_{stat}": (metric, stat)
        for metric in metrics
        for stat in _AGG_STATS
    }
    grouped = (
        df.groupby(['train_ratio', 'prediction_method'])
        .agg(**agg_spec)
        .reset_index()
    )

    grouped['benchmark'] = benchmark_name
    for metric in metrics:
        grouped[f"{metric}_formatted"] = _format_mean_std(
            grouped[f"{metric}_mean"], grouped[f"{metric}_std"], decimals
        )
    return grouped


def summarize_benchmark_results(input_files, output_file, *,
                                exclude_methods=('irt_2pl_pred',),
                                metrics=('auc', 'accuracy'),
                                decimals=6):
    """
    Summarize repeated benchmark runs into "mean ± standard deviation" format.

    Each input CSV must contain 'train_ratio' and 'prediction_method' columns
    plus one column per metric. Rows are grouped by (train_ratio,
    prediction_method); for each metric the mean, sample std and count are
    computed and a formatted "mean ± std" string is added (mean only when the
    std is undefined, e.g. a single run).

    Parameters:
    input_files: Mapping of benchmark name -> CSV path (any number of entries)
    output_file: Output CSV path; its parent directory is created if needed
    exclude_methods: Prediction methods to drop before aggregating
    metrics: Metric columns to summarize
    decimals: Number of decimal places in the formatted strings

    Returns:
    DataFrame with one row per (benchmark, train_ratio, prediction_method),
    also written to ``output_file``. If ``input_files`` is empty, an empty
    table with the usual columns is written and returned.
    """
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the output path actually has one (not for e.g. "summary.csv").
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    all_summary_results = []
    for benchmark_name, file_path in input_files.items():
        print(f"Processing results for {benchmark_name}...")
        df = pd.read_csv(file_path)
        all_summary_results.append(
            _summarize_benchmark_frame(df, benchmark_name, exclude_methods, metrics, decimals)
        )

    if all_summary_results:
        final_summary = pd.concat(all_summary_results, ignore_index=True)
    else:
        # pd.concat raises on an empty list; emit an empty, correctly-headed table.
        final_summary = pd.DataFrame(columns=_summary_columns(metrics))

    final_summary.to_csv(output_file, index=False)
    print(f"Results saved to: {output_file}")

    print("\nSummary results:")
    print(f"Total processed {len(final_summary)} rows of data")
    print(f"Including {final_summary['benchmark'].nunique()} benchmarks")
    print(f"Including {final_summary['prediction_method'].nunique()} prediction methods")

    return final_summary

def main():
    """Summarize the CEVAL, CSQA and MMLU prediction metrics and preview the first rows."""
    benchmark_sources = dict(
        CEVAL="data/prediction_metrics_ceval.csv",
        CSQA="data/prediction_metrics_csqa.csv",
        MMLU="data/prediction_metrics_mmlu.csv",
    )
    destination = "results/benchmark_results_summary.csv"

    summary_df = summarize_benchmark_results(benchmark_sources, destination)

    # Only the identifying columns and the formatted metrics are shown in the preview.
    preview_columns = [
        'benchmark',
        'train_ratio',
        'prediction_method',
        'auc_formatted',
        'accuracy_formatted',
    ]
    print("\nFirst 20 rows of results:")
    print(summary_df.head(20)[preview_columns])

# Run the summary only when executed as a script, not when imported.
if "__main__" == __name__:
    main()