import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import argparse
import numpy as np 


def _overlay_group_means(ax, df, metric, model_order, hue_order, group_width=0.8):
    """Mark the mean of ``metric`` on top of every (model, pruning method) box.

    Args:
        ax: Matplotlib axes that already holds the grouped box plot.
        df: Frame with 'model', 'pruning_method' and ``metric`` columns.
        metric: Name of the column whose per-group mean is drawn.
        model_order: Models in the order they appear along the x-axis.
        hue_order: Pruning methods in the order they are dodged within a model.
        group_width: Total width shared by the boxes of one model; 0.8 mirrors
            the default width the box plot is drawn with.
    """
    # Single pass over the frame. Groups whose values are all NaN produce a
    # NaN mean and are dropped, so no marker is drawn for them.
    means = (
        df.groupby(['model', 'pruning_method'], observed=True)[metric]
        .mean()
        .dropna()
        .to_dict()
    )

    # Boxes of one model are spread evenly and centred on the model's tick.
    box_width = group_width / len(hue_order)
    center_shift = (len(hue_order) - 1) / 2

    xs, ys = [], []
    for model_idx, model in enumerate(model_order):
        for method_idx, method in enumerate(hue_order):
            mean_val = means.get((model, method))
            if mean_val is None:
                continue
            xs.append(model_idx + (method_idx - center_shift) * box_width)
            ys.append(mean_val)

    if xs:
        ax.scatter(xs, ys, color='black', s=30, zorder=10, alpha=0.8)


def plot_pruning_comparison(df, dataset_name, output_dir, *, show_legend=False):
    """
    Generate a one-row grid of grouped box plots comparing pruning methods across models.

    One panel is drawn per metric ('shd', 'sid', 'F1'). Within a panel the
    x-axis lists the models and each model holds one box per pruning method,
    with a black dot marking the group mean. The figure is saved as
    ``<output_dir>/<dataset_name>_box_pruning_comparison.png``.

    Args:
        df: Frame with 'model' and 'pruning_method' columns plus one column
            per plotted metric.
        dataset_name: Used as the prefix of the output file name.
        output_dir: Directory the PNG is written to; created if missing.
        show_legend: When True, draw one shared legend above the panels.
            Defaults to False, which produces a figure without any legend.
    """
    metrics = ['shd', 'sid', 'F1']
    metric_display_names = {
        'shd': 'SHD',
        'sid': 'SID',
        'F1': 'F1'
    }
    # Arrow indicates whether lower or higher values are better.
    metric_arrows = {'shd': r'$\downarrow$', 'sid': r'$\downarrow$', 'F1': r'$\uparrow$'}
    model_order = ['CAM', 'SCORE', 'DAS', 'NoGAM', 'DiffAN', 'CaPS']
    hue_order = ['cam', 'rf', 'xgb', 'tabpfn']

    # Display labels used for the pruning methods in the shared legend.
    label_map = {
        'cam': 'CAM pruning (Original)',
        'rf': 'PEP (w/ Random Forest) ',
        'xgb': 'PEP (w/ XGBoost) ',
        'tabpfn': 'PEP (w/ TabPFN, Ours)'
    }

    # The baseline gets a neutral tab20c colour, the PEP variants a red ramp.
    colors_2 = sns.color_palette("tab20c", 20)
    colors = sns.color_palette("Reds", 6)
    color_map = {
        'cam': colors_2[19],
        'rf': colors[0],
        'xgb': colors[2],
        'tabpfn': colors[4],
    }
    custom_palette = [color_map[method] for method in hue_order]

    # --- Plotting ---
    sns.set_style("whitegrid")

    # squeeze=False keeps `axes` a 2-D array regardless of the panel count.
    fig, axes = plt.subplots(1, len(metrics), figsize=(20, 6), squeeze=False)
    axes = axes.flatten()

    legend_handles, legend_keys = [], []
    for ax, metric in zip(axes, metrics):
        sns.boxplot(
            data=df, x='model', y=metric, hue='pruning_method',
            order=model_order, hue_order=hue_order,
            ax=ax, palette=custom_palette
        )
        _overlay_group_means(ax, df, metric, model_order, hue_order)

        display_name = metric_display_names.get(metric, metric.upper())
        arrow = metric_arrows.get(metric, '')

        ax.set_title('')  # No per-panel titles
        ax.set_xlabel('')
        ax.set_ylabel(f'{display_name} {arrow}', fontsize=16, fontweight='bold')
        ax.tick_params(axis='x', labelsize=16)
        for label in ax.get_xticklabels():
            label.set_fontweight('normal')
        ax.tick_params(axis='y', labelsize=14)

        # Keep the entries of the first panel for the optional shared legend.
        if not legend_handles:
            legend_handles, legend_keys = ax.get_legend_handles_labels()

        # Per-panel legends are redundant. There may be none at all (e.g. no
        # rows matched the requested pruning methods), so guard the removal.
        panel_legend = ax.get_legend()
        if panel_legend is not None:
            panel_legend.remove()

    # --- Final Touches ---
    if show_legend and legend_handles:
        fig.legend(
            legend_handles,
            [label_map.get(key, key) for key in legend_keys],
            loc='upper center', bbox_to_anchor=(0.5, 1.05),
            ncol=len(hue_order), fontsize=16,
        )

    # Leave the top 10% of the figure free of panels.
    fig.tight_layout(rect=[0, 0, 1, 0.9])

    # --- Save Figure ---
    os.makedirs(output_dir, exist_ok=True)
    plot_filename = os.path.join(output_dir, f"{dataset_name}_box_pruning_comparison.png")
    fig.savefig(plot_filename, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Generated plot: {plot_filename}")


def load_pruning_data(dataset, results_dir, models=None):
    """Load and concatenate the per-model pruning metrics of one dataset.

    For every model the file ``<results_dir>/pruning_<dataset>_<model>.csv``
    is read if it exists. Files that are missing or contain no rows are
    reported on stdout and skipped.

    Args:
        dataset: Dataset name used in the CSV file names.
        results_dir: Directory containing the CSV files.
        models: Model names to look for. Defaults to
            CAM, SCORE, DAS, NoGAM, DiffAN and CaPS.

    Returns:
        A single DataFrame with the rows of all loaded files and an added
        'model' column naming the file each row came from.

    Raises:
        ValueError: If no file provided any rows.
    """
    if models is None:
        models = ["CAM", "SCORE", "DAS", "NoGAM", "DiffAN", "CaPS"]

    frames = []
    for model in models:
        filename = os.path.join(results_dir, f"pruning_{dataset}_{model}.csv")

        if not os.path.isfile(filename):
            print(f"Missing: {filename}")
            continue

        try:
            df = pd.read_csv(filename)
        except pd.errors.EmptyDataError:
            # A zero-byte file has no header to parse; treat it like a file
            # without rows instead of aborting the whole load.
            df = None

        if df is None or df.empty:
            print(f"Empty: {filename}")
            continue

        df['model'] = model
        frames.append(df)
        print(f"Loaded: {filename}")

    if not frames:
        raise ValueError(f"No pruning data found for {dataset}")

    return pd.concat(frames, ignore_index=True)


def main(args):
    """Load the pruning results for ``args.dataset`` and render the comparison plot.

    Reads ``args.dataset``, ``args.results_dir`` and ``args.output_dir``.
    A ``ValueError`` raised while loading or plotting is reported on stdout
    instead of propagating.
    """
    try:
        plot_pruning_comparison(
            load_pruning_data(args.dataset, args.results_dir),
            args.dataset,
            args.output_dir,
        )
    except ValueError as err:
        print(f"Error: {err}")


if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, register
    # them, then hand the parsed namespace to main().
    cli = argparse.ArgumentParser(
        description="Generate box plots for pruning method comparison"
    )
    option_specs = (
        (
            "--results_dir",
            {
                "type": str,
                "default": "experiments/plot_data/pruning",
                "help": "Directory containing CSV files",
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "required": True,
                "help": "Dataset name (e.g., sachs, SynER1)",
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": "fig/pruning",
                "help": "Output directory for plots",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    main(cli.parse_args())
