import os
import argparse
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

def _read_labelled_csv(path, label):
    """Read one benchmark CSV and tag every row with ``label``.

    Returns the DataFrame, or ``None`` when the file is missing or holds no
    rows. A status line is printed in every case.
    """
    if not os.path.exists(path):
        print(f"Missing: {path}")
        return None

    try:
        df = pd.read_csv(path)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse; treat it like a file
        # without rows instead of aborting the whole load.
        print(f"Empty (skipped): {path}")
        return None

    if df.empty:
        print(f"Empty (skipped): {path}")
        return None

    df["model"] = label  # Overwrites any existing "model" column
    print(f"Loaded: {path}")
    return df


def load_benchmark_data(dataset, results_dir):
    """Load benchmark CSVs for the base models and their P^2 variants.

    Files are looked up as ``{results_dir}/benchmark_{dataset}_{suffix}.csv``.
    For each of CAM, SCORE, NoGAM, DiffAN and DAS both the base file and the
    ``_P2`` file are read; the latter is labelled ``"<model> w/ PEP"``.
    ``CaPS`` is read as-is, and the ``OURS`` file is labelled
    ``"CaPS w/ PEP"``.

    Args:
        dataset: Dataset name used in the file names.
        results_dir: Directory that contains the benchmark CSV files.

    Returns:
        A single DataFrame with all loaded rows concatenated (fresh index)
        and a ``model`` column holding the display label of each row.

    Raises:
        ValueError: If no file yielded any rows.
    """
    models_with_p2 = ["CAM", "SCORE", "NoGAM", "DiffAN", "DAS"]

    # (file suffix, display label) pairs, in the order rows are concatenated.
    sources = []
    for model in models_with_p2:
        sources.append((model, model))
        sources.append((f"{model}_P2", f"{model} w/ PEP"))
    # Models without a separate P^2 file; "OURS" is shown as the PEP variant
    # of CaPS.
    sources.append(("CaPS", "CaPS"))
    sources.append(("OURS", r'CaPS w/ PEP'))

    all_data = []
    for suffix, label in sources:
        filename = f"{results_dir}/benchmark_{dataset}_{suffix}.csv"
        df = _read_labelled_csv(filename, label)
        if df is not None:
            all_data.append(df)

    if not all_data:
        raise ValueError(f"No benchmark data found for {dataset}")

    return pd.concat(all_data, ignore_index=True)

def plot_results(args):
    """Draw one box plot per metric comparing each model with its PEP variant.

    Loads the benchmark data for ``args.dataset`` from ``args.results_dir``,
    draws a row of box plots (SHD, SID, F1) and writes the figure to
    ``{args.output_dir}/{args.dataset}_metrics_comparison.png``.

    Args:
        args: Namespace-like object with ``results_dir``, ``output_dir`` and
            ``dataset`` attributes.

    Returns:
        None. If no benchmark data is found, the error is printed and the
        function returns without producing a figure (the output directory
        has already been created at that point).
    """
    results_dir = args.results_dir
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    try:
        all_data_df = load_benchmark_data(args.dataset, results_dir)
    except ValueError as e:
        print(f"Error: {e}")
        return

    metrics_to_plot = ["shd", "sid", "F1"]
    # Arrow appended to each y-axis label; presumably marks whether lower
    # or higher values are better for that metric.
    metric_arrows = {
        'shd': r'$\downarrow$',
        'sid': r'$\downarrow$',
        'precision': r'$\uparrow$',
        'recall': r'$\uparrow$',
        'F1': r'$\uparrow$',
    }
    metric_display_names = {
        'shd': 'SHD',
        'sid': 'SID',
        'F1': 'F1',
    }

    # Left-to-right order on the x axis: each base model followed by its
    # "w/ PEP" counterpart.
    model_order = [
        "CAM", "CAM w/ PEP",
        "SCORE", "SCORE w/ PEP",
        "DAS", "DAS w/ PEP",
        "NoGAM", "NoGAM w/ PEP",
        "DiffAN", "DiffAN w/ PEP",
        "CaPS", 'CaPS w/ PEP',
    ]

    colors_2 = sns.color_palette("tab20c", 20)
    colors = sns.color_palette("Paired", 12)

    # Each base/PEP pair takes two adjacent entries of the same palette.
    color_map = {
        "CAM": colors_2[19],
        "CAM w/ PEP": colors_2[17],
        "SCORE": colors[6],
        "SCORE w/ PEP": colors[7],
        "DAS": colors[0],
        "DAS w/ PEP": colors[1],
        "NoGAM": colors[2],
        "NoGAM w/ PEP": colors[3],
        "DiffAN": colors[8],
        "DiffAN w/ PEP": colors[9],
        "CaPS": colors[4],
        'CaPS w/ PEP': colors[5],
    }

    # Colours listed in the same order as the x-axis categories.
    custom_palette = [color_map[model] for model in model_order]

    # Wider and taller for readability
    fig, axes = plt.subplots(1, len(metrics_to_plot), figsize=(20, 6))

    for ax, metric in zip(axes, metrics_to_plot):
        # NOTE(review): recent seaborn releases warn when `palette` is passed
        # without `hue`; confirm the pinned seaborn version before changing
        # this call.
        sns.boxplot(
            data=all_data_df,
            x="model",
            y=metric,
            palette=custom_palette,
            ax=ax,
            order=model_order,
        )
        # Overlay a black marker at seaborn's point estimate for each model,
        # without error bars or connecting lines.
        sns.pointplot(
            data=all_data_df,
            x="model",
            y=metric,
            color="black",
            markers=".",
            linestyles="",
            errorbar=None,
            ax=ax,
            order=model_order,
        )

        arrow = metric_arrows.get(metric, "")
        # Fall back to the upper-cased key when no display name is defined.
        display_name = metric_display_names.get(metric, metric.upper())
        ax.set_ylabel(f'{display_name}{arrow}', fontsize=16, fontweight='bold')

        ax.set_xlabel(None)  # Remove xlabel
        ax.tick_params(axis="x", rotation=30, labelsize=16)  # Rotate x-axis labels
        ax.tick_params(axis="y", labelsize=14)  # Adjust y-axis label size

        # Right-align the rotated tick labels and embolden the PEP variants.
        for label in ax.get_xticklabels():
            label.set_horizontalalignment('right')
            if "PEP" in label.get_text():
                label.set_fontweight("bold")

    # Reserve extra space at the bottom for the rotated x-axis labels.
    fig.tight_layout(rect=[0, 0.1, 1, 1])

    # Save the figure
    fig_filename = f"{args.dataset}_metrics_comparison.png"
    fig_path = os.path.join(output_dir, fig_filename)
    fig.savefig(fig_path, dpi=300, bbox_inches='tight')

    print(f"Plot saved to {fig_path}")
    plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: parse the options and render the plot.
    cli = argparse.ArgumentParser(description="Generate box plots for model comparison")

    # (flag, keyword arguments) for every supported option, registered in order.
    option_specs = (
        (
            "--results_dir",
            {
                "type": str,
                "default": "experiments/plot_data/benchmark",
                "help": "Directory containing result CSV files",
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "required": True,
                "help": "Dataset name (e.g., sachs, SynER1)",
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": "fig/benchmark",
                "help": "Directory to save the plot",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()
    plot_results(args)
