import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

datasets = ['retinamnist', 'organmnist3d', 'mnist','tabular-framingham', 'tabular-metabric', 'tabular-flchain']

# Keys read from each result CSV, mapped to the display names used in the plots.
_COLUMN_LABELS = {
    'model': 'Model',
    'mixup_strategy': 'Mixup',
    'mixup_alpha': 'Alpha',
    'test_td_cindex_avg': 'Ctd',
    'test_ibs': 'IBS',
    'test_ece_avg': 'ECE',
}
# Display columns that must be real floats before sorting/plotting.
_NUMERIC_COLUMNS = ['Alpha', 'Ctd', 'IBS', 'ECE']


def _load_results(dataset):
    """Collect every ``output/<dataset>+*.csv`` file into one tidy DataFrame.

    Each CSV is read as key/value rows (first column is the key) and
    transposed into a single-row frame. The rows are stacked, reduced to the
    columns in ``_COLUMN_LABELS``, renamed, and the numeric columns are cast
    to float. A ``mixup`` column is added in which a mixup strategy of
    ``none`` (case-insensitive) is labelled ``ERM``; ERM rows get
    ``Alpha = 0.0``.

    Returns:
        The combined DataFrame, or ``None`` when no matching CSV exists.

    Raises:
        KeyError: if a required key is missing from the result files.
        ValueError: if a numeric column holds a non-numeric value.
    """
    files = sorted(
        f for f in os.listdir('output')
        if f.endswith('.csv') and f.startswith(f'{dataset}+')
    )
    if not files:
        return None

    frames = [pd.read_csv(f'output/{f}', index_col=0, header=None).T for f in files]

    # Every transposed frame carries the same row label, so reset the index
    # to keep row labels unique in the combined frame.
    stack = pd.concat(frames, axis=0, ignore_index=True)

    # Selecting the wanted keys also discards extras such as dataset/backbone.
    stack = stack[list(_COLUMN_LABELS)].rename(columns=_COLUMN_LABELS)

    # astype() on named columns yields float64 columns; assigning through
    # iloc would leave the original object dtype in place.
    stack = stack.astype({col: float for col in _NUMERIC_COLUMNS})

    mix = stack['Mixup'].astype(str).str.strip()
    stack['mixup'] = np.where(mix.str.lower().eq('none'), 'ERM', mix)
    stack.loc[stack['mixup'].eq('ERM'), 'Alpha'] = 0.0
    return stack


def _plot_results(df, dataset):
    """Draw the metric-by-model grid of box plots and save it as a PNG.

    Rows of the grid are metrics, columns are models. Within a panel, ERM is
    drawn as a single box and the other mixup strategies as boxes grouped by
    alpha. The figure is written to ``output/_fig_<dataset>.png``.
    """
    metrics = ['Ctd', 'IBS', 'ECE']
    models = ['DeepCox', 'DeepIBS', 'DeepHit', 'DeepMTLR']
    n_rows, n_cols = len(metrics), len(models)

    # Define the order for the x-axis and hue
    mixup_order = ['ERM', 'smix', 'chmix', 'hmix', 'omix']
    alpha_order = sorted(df['Alpha'].dropna().unique())

    # Define a color palette to use consistently
    palette = sns.color_palette('muted', n_colors=len(alpha_order))
    alpha_color_map = dict(zip(alpha_order, palette))

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), sharex=True, constrained_layout=True)
    fig.suptitle('Model Performance Metrics', fontsize=24)

    for i, metric in enumerate(metrics):
        for j, model in enumerate(models):
            ax = axes[i, j]

            model_df = df[df['Model'] == model]
            erm_df = model_df[model_df['mixup'] == 'ERM']
            other_df = model_df[model_df['mixup'] != 'ERM']

            # Single ERM box. ERM rows always have Alpha == 0.0, so the
            # colour lookup is safe whenever erm_df is non-empty.
            if not erm_df.empty:
                sns.boxplot(
                    data=erm_df,
                    x='mixup',
                    y=metric,
                    ax=ax,
                    width=0.6,  # Control the width of the single box
                    color=alpha_color_map[0.0],  # Use the color for alpha=0.0
                    order=mixup_order,
                )

            # Grouped boxes for the remaining strategies. The full category
            # order is passed here as well so that both calls agree on the
            # x positions and slot 0 stays reserved for ERM.
            if not other_df.empty:
                sns.boxplot(
                    data=other_df,
                    x='mixup',
                    y=metric,
                    hue='Alpha',
                    ax=ax,
                    palette=alpha_color_map,
                    order=mixup_order,
                    hue_order=alpha_order,
                )

            if i == 0:
                ax.set_title(model, fontsize=18)
            ax.set_ylabel(metric if j == 0 else '', fontsize=14)
            ax.set_xlabel('')

            # Keep x-tick labels, but remove the tick marks themselves
            ax.tick_params(axis='x', length=0, labelsize=14)
            ax.tick_params(axis='y', labelsize=12)

            # Per-panel legends are replaced by one figure-level legend below.
            if ax.get_legend() is not None:
                ax.get_legend().remove()

    # Add vertical lines to create visual separation between mixup strategies
    for ax in axes.flat:
        # Draw a line after each mixup method group, except the last one
        for k in range(len(mixup_order) - 1):
            ax.axvline(k + 0.5, color='grey', linestyle='--', linewidth=1.5)

    # Take legend entries from the first panel that has any, so a model with
    # no results in the top-left panel does not blank the legend.
    handles, labels = [], []
    for ax in axes.flat:
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            break
    if handles:
        fig.legend(handles, labels, title='Alpha', bbox_to_anchor=(0.05, 0.9), loc='upper left', fontsize=12, title_fontsize=14)

    fig.supxlabel('Mixup Method', fontsize=18)

    fig.savefig(f'output/_fig_{dataset}.png', bbox_inches='tight', dpi=300)
    plt.close(fig)


for dataset in datasets:
    print(dataset)

    df = _load_results(dataset)
    if df is None:
        print(f'no results found for {dataset}')
        continue

    _plot_results(df, dataset)


