import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
# Root directory of the explanation results. Robustness CSVs are looked up
# under BASE_PATH/<dataset>/DNN/<architecture>/.
BASE_PATH = "/home/40430660@eeecs.qub.ac.uk/InterpretGatedNetwork/explanations"
# Architectures whose scores are averaged for every (dataset, method) pair.
ARCHITECTURES = ["FCN", "ResNet", "PatchTST"]
# Directory the figure is saved to. NOTE: it is created as a side effect of
# importing this module, not only when the plot is generated.
OUTPUT_DIR = "./summary_plots/consolidated_results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Datasets shown along the x-axis, in plotting order.
UEA_DATASETS = [
    "UWaveGestureLibrary", "ERing", "RacketSports", "NATOPS",
    "CharacterTrajectories", "SelfRegulationSCP1", "ArticularyWordRecognition",
    "Libras", "BasicMotions", "Cricket", "Epilepsy"
]

# Method identifiers exactly as they appear in the robustness CSV file names.
METHODS = ['LIME', 'KernelSHAP', 'InputXGradient', 'IG', 'Input_only_amortized_attr_v2', 'amortized_attr_v2']
# Short display labels. NOTE(review): this list has five entries while METHODS
# has six, and in this file it is only used for its length (to centre the
# x ticks); the legend text itself comes from the plotting loop.
METHOD_LABELS = ['LIME', 'KernelSHAP', 'InputxGrad', 'IG', 'XMA (Ours)']

# --- Colour scheme ---
# One colour per plotted method; the plotting loop indexes this list by the
# method's position, so the order here must match the order methods are drawn.
COLORS = [
    "#D4AC0D", # LIME: Muted Gold
    "#CA6F1E", # KernelSHAP: Bronze
    "#A93226", # InputxGrad: Deep Brick Red
    "#2471A3", # IG: Deep Cobalt Blue
    # "#ED19CE",  # Input Only XMA (disabled: that method is not drawn)
    "#1ABC9C"  # XMA (Ours): Vibrant Teal (Spotlight)
]
# Base font size for the axis label and tick labels; the legend uses 0.9x this.
FONT_SIZE = 25

# Trailing part of every Jaccard robustness CSV file name.
ROBUST_SUFFIX = "Robustness_Jaccard.csv"

def get_robustness_val(dataset, arch, method):
    """Return the Jaccard IoU stored in one robustness CSV, or None.

    The file is expected at
    ``BASE_PATH/<dataset>/DNN/<arch>/<dataset>-42-DNN-<arch>-test_<method>_<ROBUST_SUFFIX>``
    and the value is read from the first row of its ``avg_jaccard_iou`` column.

    Args:
        dataset: Dataset name used in the folder and file name.
        arch: Architecture name used in the folder and file name.
        method: Explanation-method identifier used in the file name.

    Returns:
        The first ``avg_jaccard_iou`` value, or ``None`` when the file does
        not exist or cannot be used (unreadable, unparsable, missing column,
        or no data rows).
    """
    import warnings

    folder_path = os.path.join(BASE_PATH, dataset, "DNN", arch)
    # File naming pattern: Dataset-42-DNN-ARCH-test_METHOD_Robustness_Jaccard.csv
    file_name = f"{dataset}-42-DNN-{arch}-test_{method}_{ROBUST_SUFFIX}"
    file_path = os.path.join(folder_path, file_name)

    try:
        df = pd.read_csv(file_path)
        return df['avg_jaccard_iou'].iloc[0]
    except FileNotFoundError:
        # A missing result is an expected situation: not every
        # (dataset, arch, method) combination has been computed.
        return None
    except (OSError, ValueError, KeyError, IndexError) as exc:
        # ValueError covers pandas' EmptyDataError/ParserError and decoding
        # errors; KeyError is a missing column; IndexError is a header-only
        # file. Unlike a bare ``except``, this lets KeyboardInterrupt and
        # SystemExit propagate, and a broken file is reported, not hidden.
        warnings.warn(
            f"Ignoring unusable robustness file {file_path}: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return None

def _aggregate_robustness():
    """Build the per-(dataset, method) table of Jaccard IoU mean and std.

    For every dataset in UEA_DATASETS and method in METHODS, the available
    values across ARCHITECTURES are averaged; pairs with no value at all are
    skipped. The two amortized method identifiers are renamed to their
    display labels; every other method keeps its identifier.

    Returns:
        A DataFrame with columns ``Dataset``, ``Method``, ``Mean``, ``Std``.
        The columns exist even when no row was collected.
    """
    display_names = {
        'Input_only_amortized_attr_v2': 'Input Only XMA',
        'amortized_attr_v2': 'XMA (Ours)',
    }

    rows = []
    for ds in UEA_DATASETS:
        for method in METHODS:
            candidates = (get_robustness_val(ds, arch, method) for arch in ARCHITECTURES)
            ious = [val for val in candidates if val is not None]
            if not ious:
                continue
            rows.append({
                "Dataset": ds,
                "Method": display_names.get(method, method),
                "Mean": np.mean(ious),
                "Std": np.std(ious)
            })

    # Explicit columns keep the frame well-formed when ``rows`` is empty.
    return pd.DataFrame(rows, columns=["Dataset", "Method", "Mean", "Std"])


def generate_robustness_plot():
    """Draw, save and show the grouped bar chart of explanation robustness.

    Each dataset gets one group of bars, one bar per plotted method, whose
    height is the mean Jaccard IoU across architectures and whose error bar
    is the standard deviation. The figure is written to
    ``OUTPUT_DIR/Robustness_Main.png`` and then displayed.

    Side effects: sets the global matplotlib font rcParams, writes the PNG,
    prints the save path and opens the figure with ``plt.show()``. When no
    robustness result is found, a message is printed and nothing is drawn.
    """
    # 1. Aggregate Robustness Data
    df_robust = _aggregate_robustness()
    if df_robust.empty:
        # Without this guard an empty result set failed later with an
        # unhelpful KeyError on the 'Method' column.
        print(f"No robustness results found under {BASE_PATH}; nothing to plot.")
        return

    # 2. Plotting Configuration
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial']

    fig, ax = plt.subplots(figsize=(18, 9))
    ax.set_facecolor('#fdfdfd')

    # Single source of truth for what is drawn: bar order, legend order,
    # colour lookup (COLORS[i]) and tick centring all follow this list.
    plotted_methods = ['LIME', 'KernelSHAP', 'InputXGradient', 'IG', 'XMA (Ours)']
    n_methods = len(plotted_methods)
    index = np.arange(len(UEA_DATASETS))
    bar_width = 0.15

    for i, method in enumerate(plotted_methods):
        method_df = df_robust[df_robust['Method'] == method]
        # Align rows to the x-axis order; datasets without data become NaN
        # and are simply not drawn.
        method_df = method_df.set_index('Dataset').reindex(UEA_DATASETS).reset_index()

        # Using white edges to make the colors "pop" against each other
        ax.bar(index + (i * bar_width), method_df['Mean'], bar_width,
               yerr=method_df['Std'], label=method, color=COLORS[i],
               capsize=4, error_kw={'elinewidth': 1.2, 'ecolor': '#333333'},
               edgecolor='white', linewidth=0.5, alpha=0.95)

    # --- Aesthetics ---
    ax.set_ylabel("Mean Jaccard Index (IoU)", fontsize=FONT_SIZE, labelpad=15, fontweight='600')

    # Centre each dataset label under its group of bars
    ax.set_xticks(index + bar_width * (n_methods - 1) / 2)
    ax.set_xticklabels(UEA_DATASETS, rotation=35, ha='right', fontsize=FONT_SIZE)

    ax.set_yticks(np.arange(0, 1.1, 0.1))
    # Headroom above 1.0 leaves space for the legend inside the axes
    ax.set_ylim(0.0, 1.1)

    # Styling grids and spines
    ax.yaxis.grid(True, linestyle='--', color='grey', alpha=0.3)
    ax.set_axisbelow(True)
    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)
    for spine in ['left', 'bottom']:
        ax.spines[spine].set_color('#bdc3c7')

    # --- Legend ---
    # Anchored at the top edge of the axes, one column per plotted method
    legend = ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.0),
                       ncol=n_methods, frameon=False, fontsize=FONT_SIZE*0.9)

    # Bold the last legend entry (the spotlighted method)
    plt.setp(legend.get_texts()[-1], fontweight='bold')

    plt.tight_layout()

    # Save results
    save_path = os.path.join(OUTPUT_DIR, "Robustness_Main.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Robustness Plot Saved: {save_path}")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# Generate the figure only when this file is run as a script, not on import.
if __name__ == "__main__":
    generate_robustness_plot()
