import pandas as pd
import numpy as np
import os
import glob
from pathlib import Path

def parse_filename(filename):
    """Return the ``(dataset, strategy)`` pair encoded in a file name.

    The stem of ``filename`` (directories and extension removed) is split on
    underscores. The second token is the dataset; the tokens after it, except
    the final one, are re-joined with underscores to form the strategy. The
    first and last tokens are ignored.
    """
    tokens = Path(filename).stem.split('_')
    dataset_name = tokens[1]
    # Everything after the dataset token, minus the trailing token.
    trailing = tokens[2:]
    strategy_name = '_'.join(trailing[:-1])
    return dataset_name, strategy_name

def aggregate_summary_logs(summary_logs_dir):
    """Aggregate summary CSVs per (dataset, strategy) and compute statistics.

    Every ``summary_*.csv`` file directly inside ``summary_logs_dir`` is
    assigned to a ``(dataset, strategy)`` group by ``parse_filename``. The
    files of each group are concatenated and, per ``config_name``, the mean,
    standard deviation, maximum and standard error of the mean (SEM) of the
    ``avg_exp_miscov`` and ``avg_exp_len`` columns are computed. The combined
    table is written to ``aggregated_summary_stats.csv`` in the same directory
    and a confirmation line is printed.

    Args:
        summary_logs_dir: Directory containing the ``summary_*.csv`` files.
            Presumably each file holds the results of one seed (the files of
            a group are labelled ``seed`` when concatenated) -- confirm
            against the code that writes them.

    Returns:
        pandas.DataFrame indexed by ``config_name`` with the columns
        ``<metric>_mean``, ``<metric>_std``, ``<metric>_max``, ``dataset``,
        ``strategy`` and ``<metric>_sem`` for both metrics.

    Raises:
        ValueError: If no ``summary_*.csv`` file is found in the directory.
    """
    metrics = ('avg_exp_miscov', 'avg_exp_len')

    # glob returns files in arbitrary order; sorting keeps the grouping order,
    # and therefore the row order of the output, reproducible across runs.
    summary_files = sorted(
        glob.glob(os.path.join(summary_logs_dir, "summary_*.csv"))
    )
    if not summary_files:
        # Fail with a clear message instead of pandas' generic
        # "No objects to concatenate" further down.
        raise ValueError(
            f"No summary_*.csv files found in {summary_logs_dir!r}"
        )

    grouped_files = {}
    for file in summary_files:
        grouped_files.setdefault(parse_filename(file), []).append(file)

    results = []
    for (dataset, strategy), files in grouped_files.items():
        dfs = [pd.read_csv(f) for f in files]
        combined_df = pd.concat(dfs, keys=range(len(dfs)), names=['seed'])

        by_config = combined_df.groupby('config_name')
        stats = by_config.agg(
            {metric: ['mean', 'std', 'max'] for metric in metrics}
        )

        # Flatten the (metric, statistic) column MultiIndex to "metric_stat".
        stats.columns = ['_'.join(col).strip() for col in stats.columns.values]
        stats['dataset'] = dataset
        stats['strategy'] = strategy

        for metric in metrics:
            # Divide by the number of non-null observations that actually
            # entered the std of each config rather than by the file count:
            # the two differ when a config is missing from some files or a
            # file has several rows for the same config.
            n_obs = by_config[metric].count()
            stats[f'{metric}_sem'] = stats[f'{metric}_std'] / np.sqrt(n_obs)

        results.append(stats)

    final_results = pd.concat(results)
    output_file = os.path.join(summary_logs_dir, 'aggregated_summary_stats.csv')
    final_results.to_csv(output_file)
    print(f"Results saved to {output_file}")

    return final_results

if __name__ == "__main__":
    # Script entry point: aggregate the default log directory, then echo the
    # resulting table under a heading.
    summary_logs_dir = "experiment_results/summary_logs"
    results = aggregate_summary_logs(summary_logs_dir)
    print("\nSummary of aggregated results:", results, sep="\n")