#!/usr/bin/env python3
"""
Script to visualize evaluation results from all models and create leaderboards.
Creates separate leaderboards for each dataset (cisco and nsl_kdd).
"""

import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import numpy as np
from typing import Dict, List, Any, Optional


def load_results_from_outputs(outputs_dir: str = "outputs") -> Dict[str, Dict[str, Any]]:
    """Load every ``results.json`` found one level below ``outputs_dir``.

    Each immediate sub-directory of ``outputs_dir`` is treated as one run and
    is expected to contain a ``results.json`` file. Sub-directories are
    visited in sorted order so the returned mapping (and everything derived
    from its iteration order) is reproducible across filesystems.

    Args:
        outputs_dir: Directory whose sub-directories each hold one run.

    Returns:
        Mapping of sub-directory name to the parsed JSON object. Runs whose
        file is missing, unreadable, malformed, or does not contain a JSON
        object are reported on stdout and left out. The mapping is empty when
        ``outputs_dir`` does not exist or is not a directory.
    """
    results: Dict[str, Dict[str, Any]] = {}
    outputs_path = Path(outputs_dir)

    if not outputs_path.exists():
        print(f"Outputs directory {outputs_dir} does not exist!")
        return results
    if not outputs_path.is_dir():
        # iterdir() would raise NotADirectoryError on a regular file.
        print(f"Outputs path {outputs_dir} is not a directory!")
        return results

    for model_dir in sorted(outputs_path.iterdir()):
        if not model_dir.is_dir():
            continue

        results_file = model_dir / "results.json"
        if not results_file.exists():
            print(f"No results.json found in {model_dir}")
            continue

        try:
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(results_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading {results_file}: {e}")
            continue

        if not isinstance(data, dict):
            # Downstream code looks values up by key, so only objects are usable.
            print(f"Error loading {results_file}: top-level JSON value is not an object")
            continue

        results[model_dir.name] = data
        print(f"Loaded results for {model_dir.name}")

    return results


def extract_main_metrics(results: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
    """Build one summary row per run from its ``test_metrics``.

    The dataset name is taken from ``data['dataset']``. When that is absent or
    ``'unknown'`` it is inferred from the run name (substrings ``nsl_kdd`` or
    ``cisco``) and, failing that, from ``data['config']['dataset']``, which may
    be either a plain string or a mapping with a ``name`` key. The model name
    comes from ``data['model']`` or, as a fallback, the part of the run name
    before the first underscore.

    Args:
        results: Mapping of run name to its parsed ``results.json`` payload.

    Returns:
        DataFrame with columns ``Model``, ``Dataset``, ``Accuracy``,
        ``Precision``, ``Recall``, ``F1``, ``AUC`` and ``Training Time (s)``,
        plus ``Val_Accuracy`` / ``Val_F1`` for runs that recorded validation
        metrics. Missing or ``null`` metric values become ``0``. Runs without
        a ``test_metrics`` mapping are reported on stdout and skipped.
    """
    def _as_name(value: Any) -> Optional[str]:
        """Return a non-empty string name from a string or ``{'name': ...}``."""
        if isinstance(value, dict):
            value = value.get('name')
        return value if isinstance(value, str) and value else None

    def _metric(source: Dict[str, Any], key: str) -> Any:
        """Read ``key``; a missing key and JSON ``null`` both count as 0."""
        value = source.get(key, 0)
        return 0 if value is None else value

    rows = []

    for model_name, data in results.items():
        test_metrics = data.get('test_metrics') if isinstance(data, dict) else None
        if not isinstance(test_metrics, dict):
            print(f"No test_metrics found for {model_name}")
            continue

        dataset = _as_name(data.get('dataset')) or 'unknown'

        # Try to infer the dataset from the run name, then from the config.
        if dataset == 'unknown':
            lowered_name = model_name.lower()
            if 'nsl_kdd' in lowered_name:
                dataset = 'nsl_kdd'
            elif 'cisco' in lowered_name:
                dataset = 'cisco'
            else:
                config = data.get('config')
                if isinstance(config, dict):
                    dataset = _as_name(config.get('dataset')) or 'unknown'

        model_type = _as_name(data.get('model')) or model_name.split('_')[0]

        row = {
            'Model': model_type.upper(),
            'Dataset': dataset.upper(),
            'Accuracy': _metric(test_metrics, 'accuracy'),
            'Precision': _metric(test_metrics, 'precision'),
            'Recall': _metric(test_metrics, 'recall'),
            'F1': _metric(test_metrics, 'f1'),
            'AUC': _metric(test_metrics, 'auc'),
            'Training Time (s)': _metric(data, 'total_time'),
        }

        # Add validation metrics if available.
        train_results = data.get('train_results')
        val_metrics = train_results.get('val_metrics') if isinstance(train_results, dict) else None
        if isinstance(val_metrics, dict):
            row['Val_Accuracy'] = _metric(val_metrics, 'accuracy')
            row['Val_F1'] = _metric(val_metrics, 'f1')

        rows.append(row)

    return pd.DataFrame(rows)


def extract_class_metrics(results: Dict[str, Dict[str, Any]]) -> Dict[str, pd.DataFrame]:
    """Collect per-class test metrics, grouped by dataset.

    Only runs whose ``test_metrics`` contain a ``class_accuracies`` mapping
    contribute rows; its keys define the classes. ``class_precision``,
    ``class_recall`` and ``class_f1`` are optional and default to ``0`` per
    class. The dataset and model names are resolved the same way as for the
    main metrics: explicit ``dataset`` / ``model`` fields first, then the run
    name, then ``config['dataset']`` (string or mapping with ``name``).

    Args:
        results: Mapping of run name to its parsed ``results.json`` payload.

    Returns:
        Mapping of dataset name (as resolved, not upper-cased) to a DataFrame
        with columns ``Model``, ``Class``, ``Accuracy``, ``Precision``,
        ``Recall`` and ``F1``.
    """
    def _as_name(value: Any) -> Optional[str]:
        """Return a non-empty string name from a string or ``{'name': ...}``."""
        if isinstance(value, dict):
            value = value.get('name')
        return value if isinstance(value, str) and value else None

    def _mapping(value: Any) -> Dict[Any, Any]:
        """Return ``value`` when it is a dict, otherwise an empty dict."""
        return value if isinstance(value, dict) else {}

    def _score(source: Dict[Any, Any], key: Any) -> Any:
        """Read ``key``; a missing key and JSON ``null`` both count as 0."""
        value = source.get(key, 0)
        return 0 if value is None else value

    rows_by_dataset: Dict[str, List[Dict[str, Any]]] = {}

    for model_name, data in results.items():
        test_metrics = data.get('test_metrics') if isinstance(data, dict) else None
        if not isinstance(test_metrics, dict):
            continue

        # Runs without class-level metrics contribute nothing.
        class_accuracies = test_metrics.get('class_accuracies')
        if not isinstance(class_accuracies, dict):
            continue

        dataset = _as_name(data.get('dataset')) or 'unknown'

        # Try to infer the dataset from the run name, then from the config.
        if dataset == 'unknown':
            lowered_name = model_name.lower()
            if 'nsl_kdd' in lowered_name:
                dataset = 'nsl_kdd'
            elif 'cisco' in lowered_name:
                dataset = 'cisco'
            else:
                config = data.get('config')
                if isinstance(config, dict):
                    dataset = _as_name(config.get('dataset')) or 'unknown'

        model_type = _as_name(data.get('model')) or model_name.split('_')[0]

        class_precision = _mapping(test_metrics.get('class_precision'))
        class_recall = _mapping(test_metrics.get('class_recall'))
        class_f1 = _mapping(test_metrics.get('class_f1'))

        dataset_rows = rows_by_dataset.setdefault(dataset, [])
        for class_name in class_accuracies:
            dataset_rows.append({
                'Model': model_type.upper(),
                'Class': str(class_name).upper(),
                'Accuracy': _score(class_accuracies, class_name),
                'Precision': _score(class_precision, class_name),
                'Recall': _score(class_recall, class_name),
                'F1': _score(class_f1, class_name),
            })

    # Convert the accumulated row lists to DataFrames.
    return {dataset: pd.DataFrame(rows) for dataset, rows in rows_by_dataset.items()}


def create_main_leaderboard(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Render all headline figures for a single dataset.

    Rows whose ``Dataset`` equals ``dataset.upper()`` are selected, ordered by
    descending F1, and handed to the ranking, metrics-comparison and scatter
    plotters. The training-time figure is produced only when the
    ``Training Time (s)`` column exists and sums to more than zero.

    Args:
        df: Summary table holding at least ``Dataset`` and ``F1`` columns.
        dataset: Dataset name; matched case-insensitively via upper-casing.
        save_dir: Directory the plotters write their figures to.
    """
    wanted = dataset.upper()
    subset = df.loc[df['Dataset'] == wanted].copy()

    if subset.empty:
        print(f"No data found for dataset: {dataset}")
        return

    # F1 is the primary metric, so best-scoring models come first.
    ranked = subset.sort_values('F1', ascending=False)

    create_performance_ranking(ranked, dataset, save_dir)
    create_metrics_comparison(ranked, dataset, save_dir)
    create_performance_scatter(ranked, dataset, save_dir)

    time_column = 'Training Time (s)'
    has_timing = time_column in ranked.columns and ranked[time_column].sum() > 0
    if has_timing:
        create_training_time_analysis(ranked, dataset, save_dir)


def create_performance_ranking(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Draw a horizontal F1 bar chart with one bar per row of ``df``.

    Bars are drawn in the row order of ``df`` and labelled ``#1``, ``#2``, ...
    in that same order, so the caller is expected to pass rows already sorted
    by F1. Bar colour encodes F1 relative to the best F1 in ``df``. The figure
    is written to ``<save_dir>/<dataset>_performance_ranking.png``.

    Args:
        df: Table with ``Model`` and ``F1`` columns.
        dataset: Dataset name used in the title and the file name.
        save_dir: Existing directory to write the PNG into.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # Create horizontal bar chart for better readability
    y_pos = np.arange(len(df))
    f1_scores = df['F1']

    # Colour by F1 relative to the best model. When the best F1 is 0 (or the
    # table is empty) the ratio would be NaN, so fall back to the low end of
    # the colormap instead of feeding NaN to it.
    best_f1 = f1_scores.max()
    if best_f1 > 0:
        relative = (f1_scores / best_f1).fillna(0)
    else:
        relative = np.zeros(len(df))
    colors = plt.cm.RdYlGn(relative)

    bars = ax.barh(y_pos, f1_scores, color=colors, alpha=0.8, edgecolor='black')

    # Customize the chart
    ax.set_yticks(y_pos)
    ax.set_yticklabels(df['Model'], fontsize=12, fontweight='bold')
    ax.set_xlabel('F1 Score', fontsize=12, fontweight='bold')
    ax.set_title(f'{dataset.upper()} Dataset - Model Performance Ranking',
                 fontsize=16, fontweight='bold', pad=20)

    # Add value labels just past the end of each bar
    for bar, value in zip(bars, f1_scores):
        ax.text(bar.get_width() + 0.01, bar.get_y() + bar.get_height() / 2,
                f'{value:.3f}', va='center', fontweight='bold', fontsize=11)

    # Add grid for better readability
    ax.grid(True, axis='x', alpha=0.3)
    ax.set_xlim(0, 1)

    # Add rank numbers to the left of the axis, following row order
    for row_index in range(len(df)):
        ax.text(-0.05, row_index, f'#{row_index + 1}', ha='center', va='center',
                fontweight='bold', fontsize=12, color='darkblue')

    plt.tight_layout()
    save_path = Path(save_dir) / f"{dataset}_performance_ranking.png"
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved performance ranking for {dataset} to {save_path}")
    # Close this specific figure rather than whatever is current.
    plt.close(fig)


def create_metrics_comparison(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Plot a 2x2 grid of per-model bar charts, one panel per headline metric.

    The panels show Accuracy, Precision, Recall and F1 in that order, with
    each bar annotated by its value to three decimals. The figure is written
    to ``<save_dir>/<dataset>_metrics_comparison.png``.

    Args:
        df: Table with ``Model`` plus the four metric columns.
        dataset: Dataset name used in the title and the file name.
        save_dir: Directory to write the PNG into.
    """
    figure, panel_grid = plt.subplots(2, 2, figsize=(16, 12))
    figure.suptitle(f'{dataset.upper()} Dataset - Detailed Metrics Comparison',
                    fontsize=16, fontweight='bold', y=0.95)

    # One colour per model, shared by all four panels.
    palette = sns.color_palette("husl", len(df))
    model_names = df['Model']

    for panel, metric_name in zip(panel_grid.flat, ('Accuracy', 'Precision', 'Recall', 'F1')):
        scores = df[metric_name]
        rects = panel.bar(model_names, scores, color=palette, alpha=0.8, edgecolor='black')

        panel.set_title(f'{metric_name} Comparison', fontsize=14, fontweight='bold')
        panel.set_ylabel(metric_name, fontsize=12)
        panel.set_ylim(0, 1)
        plt.setp(panel.get_xticklabels(), rotation=45, ha='right')

        # Print each score just above its bar.
        for rect, score in zip(rects, scores):
            label_x = rect.get_x() + rect.get_width()/2
            label_y = rect.get_height() + 0.01
            panel.text(label_x, label_y, f'{score:.3f}',
                       ha='center', va='bottom', fontweight='bold', fontsize=10)

        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    target = Path(save_dir) / f"{dataset}_metrics_comparison.png"
    plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved metrics comparison for {dataset} to {target}")
    plt.close()


def create_performance_scatter(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Plot pairwise metric scatter panels plus a rank summary table.

    Three panels plot Accuracy vs F1, Precision vs Recall and F1 vs AUC, each
    point annotated with its model name. The fourth panel holds a table of
    rank, model and F1 in the row order of ``df``. The figure is written to
    ``<save_dir>/<dataset>_performance_correlations.png``.

    Args:
        df: Table with ``Model``, ``Accuracy``, ``Precision``, ``Recall``,
            ``F1`` and ``AUC`` columns.
        dataset: Dataset name used in the title and the file name.
        save_dir: Directory to write the PNG into.
    """
    figure, ((top_left, top_right), (bottom_left, table_ax)) = plt.subplots(2, 2, figsize=(16, 12))
    figure.suptitle(f'{dataset.upper()} Dataset - Performance Correlations',
                    fontsize=16, fontweight='bold', y=0.95)

    palette = sns.color_palette("husl", len(df))

    # (axes, x column, y column, x label, y label, panel title)
    scatter_panels = (
        (top_left, 'Accuracy', 'F1', 'Accuracy', 'F1 Score', 'Accuracy vs F1 Score'),
        (top_right, 'Precision', 'Recall', 'Precision', 'Recall', 'Precision vs Recall'),
        (bottom_left, 'F1', 'AUC', 'F1 Score', 'AUC', 'F1 Score vs AUC'),
    )
    for panel, x_col, y_col, x_label, y_label, panel_title in scatter_panels:
        panel.scatter(df[x_col], df[y_col], c=palette, s=150, alpha=0.7, edgecolors='black')
        panel.set_xlabel(x_label)
        panel.set_ylabel(y_label)
        panel.set_title(panel_title)
        panel.grid(True, alpha=0.3)
        # Label every point with its model name, slightly offset.
        for position, model_name in enumerate(df['Model']):
            record = df.iloc[position]
            panel.annotate(model_name, (record[x_col], record[y_col]),
                           xytext=(5, 5), textcoords='offset points', fontweight='bold')

    # The fourth panel carries a table, not a plot, so hide its axes.
    table_ax.axis('off')

    summary = (
        df[['Model', 'F1']]
        .assign(Rank=range(1, len(df) + 1))
        [['Rank', 'Model', 'F1']]
        .round(3)
    )

    summary_table = table_ax.table(cellText=summary.values,
                                   colLabels=['Rank', 'Model', 'F1 Score'],
                                   cellLoc='center',
                                   loc='center',
                                   colColours=['#E8E8E8'] * 3)

    summary_table.auto_set_font_size(False)
    summary_table.set_fontsize(12)
    summary_table.scale(1.2, 2)

    # Header row: white bold text on a blue background.
    for column in range(3):
        header_cell = summary_table[(0, column)]
        header_cell.set_text_props(weight='bold', color='white')
        header_cell.set_facecolor('#4472C4')

    table_ax.set_title('Performance Summary', fontsize=14, fontweight='bold')

    plt.tight_layout()
    target = Path(save_dir) / f"{dataset}_performance_correlations.png"
    plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved performance correlations for {dataset} to {target}")
    plt.close()


def create_training_time_analysis(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Plot training time per model and F1 against training time.

    The left panel is a bar chart of ``Training Time (s)`` per model with
    value labels; the right panel scatters F1 against training time with
    model-name annotations. The figure is written to
    ``<save_dir>/<dataset>_training_time_analysis.png``.

    Args:
        df: Table with ``Model``, ``F1`` and ``Training Time (s)`` columns.
        dataset: Dataset name used in the title and the file name.
        save_dir: Directory to write the PNG into.
    """
    time_col = 'Training Time (s)'

    figure, (bar_ax, scatter_ax) = plt.subplots(1, 2, figsize=(16, 6))
    figure.suptitle(f'{dataset.upper()} Dataset - Training Time Analysis',
                    fontsize=16, fontweight='bold', y=0.95)

    palette = sns.color_palette("husl", len(df))

    # Left panel: training time per model.
    rects = bar_ax.bar(df['Model'], df[time_col], color=palette, alpha=0.8, edgecolor='black')
    bar_ax.set_title('Training Time Comparison', fontsize=14, fontweight='bold')
    bar_ax.set_ylabel('Training Time (seconds)')
    plt.setp(bar_ax.get_xticklabels(), rotation=45, ha='right')
    bar_ax.grid(True, alpha=0.3)

    for rect, seconds in zip(rects, df[time_col]):
        # Lift the label by 1% of the tallest bar so it clears the bar top.
        lift = max(df[time_col]) * 0.01
        bar_ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + lift,
                    f'{seconds:.2f}s', ha='center', va='bottom', fontweight='bold')

    # Right panel: F1 against training time.
    scatter_ax.scatter(df[time_col], df['F1'], c=palette, s=150, alpha=0.7, edgecolors='black')
    scatter_ax.set_xlabel('Training Time (seconds)')
    scatter_ax.set_ylabel('F1 Score')
    scatter_ax.set_title('Performance vs Training Time')
    scatter_ax.grid(True, alpha=0.3)

    for position, model_name in enumerate(df['Model']):
        record = df.iloc[position]
        scatter_ax.annotate(model_name, (record[time_col], record['F1']),
                            xytext=(5, 5), textcoords='offset points', fontweight='bold')

    plt.tight_layout()
    target = Path(save_dir) / f"{dataset}_training_time_analysis.png"
    plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved training time analysis for {dataset} to {target}")
    plt.close()


def create_class_leaderboard(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Render every per-class figure for one dataset.

    Delegates to the heatmap, per-class ranking and per-model comparison
    plotters, in that order. An empty table is reported on stdout and
    produces no figures.

    Args:
        df: Per-class metrics table for ``dataset``.
        dataset: Dataset name passed through to the plotters.
        save_dir: Directory the plotters write their figures to.
    """
    if not df.empty:
        create_class_heatmaps(df, dataset, save_dir)
        create_class_rankings(df, dataset, save_dir)
        create_class_model_comparison(df, dataset, save_dir)
        return

    print(f"No class-level data found for dataset: {dataset}")


def create_class_heatmaps(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Write one Class-by-Model heatmap per metric.

    For each of Accuracy, Precision, Recall and F1 the table is pivoted to
    classes (rows) by models (columns) and drawn as an annotated heatmap,
    saved as ``<save_dir>/<dataset>_class_<metric>_heatmap.png``.

    When several rows share the same (Class, Model) pair - for example two
    runs of the same model type - their values are averaged, because a plain
    ``DataFrame.pivot`` raises ``ValueError`` on duplicate entries.

    Args:
        df: Per-class table with ``Class``, ``Model`` and the metric columns.
        dataset: Dataset name used in titles and file names.
        save_dir: Directory to write the PNGs into.
    """
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1']

    for metric in metrics:
        fig, ax = plt.subplots(figsize=(12, 8))

        # pivot_table aggregates duplicates; with unique pairs the mean of a
        # single value leaves the cell unchanged.
        pivot_df = df.pivot_table(index='Class', columns='Model',
                                  values=metric, aggfunc='mean')

        # Create heatmap
        sns.heatmap(pivot_df, annot=True, fmt='.3f', cmap='RdYlGn',
                    ax=ax, cbar_kws={'label': metric},
                    square=True, linewidths=0.5)

        ax.set_title(f'{dataset.upper()} Dataset - Per-Class {metric}',
                     fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel('Model', fontsize=12, fontweight='bold')
        ax.set_ylabel('Attack Class', fontsize=12, fontweight='bold')

        # Rotate labels for better readability
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)

        plt.tight_layout()
        save_path = Path(save_dir) / f"{dataset}_class_{metric.lower()}_heatmap.png"
        plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Saved class {metric} heatmap for {dataset} to {save_path}")
        # Close this specific figure rather than whatever is current.
        plt.close(fig)


def create_class_rankings(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Write one 2x2 ranking figure per class.

    For every distinct ``Class`` (in sorted order) the matching rows are
    ordered by descending F1 and drawn as horizontal bar charts of Accuracy,
    Precision, Recall and F1. Each figure is saved as
    ``<save_dir>/<dataset>_class_<class>_ranking.png``.

    Args:
        df: Per-class table with ``Class``, ``Model`` and the metric columns.
        dataset: Dataset name used in titles and file names.
        save_dir: Directory to write the PNGs into.
    """
    metric_names = ('Accuracy', 'Precision', 'Recall', 'F1')

    for class_name in sorted(df['Class'].unique()):
        ranked = df[df['Class'] == class_name].sort_values('F1', ascending=False)

        figure, panel_grid = plt.subplots(2, 2, figsize=(16, 12))
        figure.suptitle(f'{dataset.upper()} Dataset - {class_name.upper()} Class Performance',
                        fontsize=16, fontweight='bold', y=0.95)

        palette = sns.color_palette("husl", len(ranked))
        slots = np.arange(len(ranked))

        for panel, metric_name in zip(panel_grid.flat, metric_names):
            scores = ranked[metric_name]
            # Horizontal bars keep long model names readable.
            rects = panel.barh(slots, scores, color=palette, alpha=0.8, edgecolor='black')

            panel.set_yticks(slots)
            panel.set_yticklabels(ranked['Model'], fontsize=11, fontweight='bold')
            panel.set_xlabel(metric_name, fontsize=12, fontweight='bold')
            panel.set_title(f'{metric_name} for {class_name.upper()}', fontsize=14, fontweight='bold')
            panel.set_xlim(0, 1)
            panel.grid(True, axis='x', alpha=0.3)

            # Print each score just past the end of its bar.
            for rect, score in zip(rects, scores):
                label_x = rect.get_width() + 0.01
                label_y = rect.get_y() + rect.get_height()/2
                panel.text(label_x, label_y, f'{score:.3f}',
                           va='center', fontweight='bold', fontsize=10)

        plt.tight_layout()
        target = Path(save_dir) / f"{dataset}_class_{class_name.lower()}_ranking.png"
        plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Saved {class_name} class ranking for {dataset} to {target}")
        plt.close()


def create_class_model_comparison(df: pd.DataFrame, dataset: str, save_dir: str = "leaderboards"):
    """Write one 2x2 figure per model showing its metrics across classes.

    For every distinct ``Model`` (in sorted order) the matching rows are
    ordered by descending F1 and drawn as vertical bar charts of Accuracy,
    Precision, Recall and F1 with one bar per class. Each figure is saved as
    ``<save_dir>/<dataset>_model_<model>_class_comparison.png``.

    Args:
        df: Per-class table with ``Class``, ``Model`` and the metric columns.
        dataset: Dataset name used in titles and file names.
        save_dir: Directory to write the PNGs into.
    """
    metric_names = ('Accuracy', 'Precision', 'Recall', 'F1')

    for model in sorted(df['Model'].unique()):
        per_class = df[df['Model'] == model].sort_values('F1', ascending=False)

        figure, panel_grid = plt.subplots(2, 2, figsize=(16, 12))
        figure.suptitle(f'{dataset.upper()} Dataset - {model} Model Performance Across Classes',
                        fontsize=16, fontweight='bold', y=0.95)

        palette = sns.color_palette("Set2", len(per_class))

        for panel, metric_name in zip(panel_grid.flat, metric_names):
            scores = per_class[metric_name]
            rects = panel.bar(per_class['Class'], scores, color=palette, alpha=0.8, edgecolor='black')

            panel.set_title(f'{metric_name} Across Classes', fontsize=14, fontweight='bold')
            panel.set_ylabel(metric_name, fontsize=12)
            panel.set_ylim(0, 1)
            plt.setp(panel.get_xticklabels(), rotation=45, ha='right')
            panel.grid(True, alpha=0.3)

            # Print each score just above its bar.
            for rect, score in zip(rects, scores):
                label_x = rect.get_x() + rect.get_width()/2
                label_y = rect.get_height() + 0.01
                panel.text(label_x, label_y, f'{score:.3f}',
                           ha='center', va='bottom', fontweight='bold', fontsize=10)

        plt.tight_layout()
        target = Path(save_dir) / f"{dataset}_model_{model.lower()}_class_comparison.png"
        plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Saved {model} model class comparison for {dataset} to {target}")
        plt.close()




def create_summary_comparison(df: pd.DataFrame, save_dir: str = "leaderboards"):
    """Write a 2x2 cross-dataset summary figure.

    Panels: (1) F1 per model, one line per dataset; (2) accuracy per model,
    one line per dataset; (3) heatmap of each model's F1 rank within each
    dataset; (4) F1 against training time, or a placeholder message when no
    positive training time is recorded. The figure is saved as
    ``<save_dir>/summary_comparison.png``.

    A model appearing more than once in a dataset (for example several runs
    of the same model type) is shown in the heatmap with its best rank,
    because a plain ``DataFrame.pivot`` raises ``ValueError`` on duplicate
    (Model, Dataset) pairs.

    Args:
        df: Summary table with ``Dataset``, ``Model``, ``F1`` and ``Accuracy``
            columns and optionally ``Training Time (s)``.
        save_dir: Directory to write the PNG into.
    """
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20, 16))
    fig.suptitle('Cross-Dataset Model Performance Summary',
                 fontsize=20, fontweight='bold', y=0.95)

    datasets = df['Dataset'].unique()

    # 1 & 2. F1 and accuracy per model, one line per dataset.
    line_panels = (
        (ax1, 'F1', 'o', 'F1 Score'),
        (ax2, 'Accuracy', 's', 'Accuracy'),
    )
    for ax, column, marker, label in line_panels:
        for dataset in datasets:
            dataset_df = df[df['Dataset'] == dataset]
            ax.plot(dataset_df['Model'], dataset_df[column],
                    marker=marker, linewidth=2, markersize=8, label=dataset)

        ax.set_title(f'{label} Across Datasets', fontsize=14, fontweight='bold')
        ax.set_ylabel(label)
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

    # 3. Model ranking heatmap (by F1 score within each dataset).
    ranking_data = []
    for dataset in datasets:
        ordered = df[df['Dataset'] == dataset].sort_values('F1', ascending=False)
        for rank, model in enumerate(ordered['Model'], start=1):
            ranking_data.append({'Dataset': dataset, 'Model': model, 'Rank': rank})

    ranking_df = pd.DataFrame(ranking_data)
    # pivot_table tolerates duplicate (Model, Dataset) pairs; keep the best rank.
    pivot_ranking = ranking_df.pivot_table(index='Model', columns='Dataset',
                                           values='Rank', aggfunc='min')

    sns.heatmap(pivot_ranking, annot=True, fmt='.0f', cmap='RdYlGn_r',
                ax=ax3, cbar_kws={'label': 'Rank (1=Best)'})
    ax3.set_title('Model Rankings by F1 Score', fontsize=14, fontweight='bold')

    # 4. Performance vs Training Time scatter
    if 'Training Time (s)' in df.columns and df['Training Time (s)'].sum() > 0:
        for dataset in datasets:
            dataset_df = df[df['Dataset'] == dataset]
            ax4.scatter(dataset_df['Training Time (s)'], dataset_df['F1'],
                        label=dataset, s=100, alpha=0.7)

            # Add model labels
            for _, row in dataset_df.iterrows():
                ax4.annotate(row['Model'], (row['Training Time (s)'], row['F1']),
                             xytext=(5, 5), textcoords='offset points', fontsize=8)

        ax4.set_xlabel('Training Time (seconds)')
        ax4.set_ylabel('F1 Score')
        ax4.set_title('Performance vs Training Time', fontsize=14, fontweight='bold')
        ax4.legend()
        ax4.grid(True, alpha=0.3)
    else:
        ax4.text(0.5, 0.5, 'Training time data\nnot available',
                 ha='center', va='center', transform=ax4.transAxes, fontsize=12)
        ax4.set_title('Performance vs Training Time', fontsize=14, fontweight='bold')

    plt.tight_layout()

    # Save the plot (white background, matching the other figures).
    save_path = Path(save_dir) / "summary_comparison.png"
    plt.savefig(save_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Saved summary comparison to {save_path}")
    # Close this specific figure rather than whatever is current.
    plt.close(fig)


def main():
    """Run the whole leaderboard pipeline.

    Creates the ``leaderboards`` directory, loads all run results, builds the
    per-dataset headline figures and the per-class figures, adds the
    cross-dataset summary when more than one dataset is present, and finally
    dumps the summary table to ``leaderboards/all_results.csv``. Stops early,
    with a message, when no results or no usable metrics are found.
    """
    print("Starting leaderboard generation...")

    # All artefacts land in this directory.
    save_dir = "leaderboards"
    output_root = Path(save_dir)
    output_root.mkdir(exist_ok=True)

    results = load_results_from_outputs()
    if not results:
        print("No results found! Make sure the outputs directory contains results.json files.")
        return
    print(f"Found results for {len(results)} models")

    main_df = extract_main_metrics(results)
    if main_df.empty:
        print("No valid metrics found in results!")
        return

    dataset_names = main_df['Dataset'].unique()
    print(f"Extracted metrics for {len(main_df)} model-dataset combinations")
    print("Datasets found:", dataset_names)
    print("Models found:", main_df['Model'].unique())

    # Headline figures, one set per dataset.
    for dataset_name in dataset_names:
        print(f"\nCreating leaderboard for {dataset_name}...")
        create_main_leaderboard(main_df, dataset_name.lower(), save_dir)

    # Per-class figures for the datasets that recorded class-level metrics.
    per_class_tables = extract_class_metrics(results)
    for dataset_name, class_table in per_class_tables.items():
        print(f"\nCreating class-level analysis for {dataset_name}...")
        create_class_leaderboard(class_table, dataset_name, save_dir)

    # The cross-dataset view only makes sense with at least two datasets.
    if len(dataset_names) > 1:
        print("\nCreating cross-dataset summary...")
        create_summary_comparison(main_df, save_dir)

    # Keep the raw numbers next to the figures for reference.
    main_df.to_csv(output_root / "all_results.csv", index=False)
    print(f"\nSaved raw results to {save_dir}/all_results.csv")

    print(f"\nLeaderboard generation complete! Check the '{save_dir}' directory for all visualizations.")


# Run the leaderboard pipeline only when this file is executed as a script,
# so importing the module does not trigger any plotting or file output.
if __name__ == "__main__":
    main()
