import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
# Root folder of the per-dataset explanation CSVs. The default is the original
# author's path; set XMA_EXPLANATIONS_DIR to point elsewhere without editing.
BASE_PATH = os.environ.get(
    "XMA_EXPLANATIONS_DIR",
    "/home/40430660@eeecs.qub.ac.uk/InterpretGatedNetwork/explanations",
)
# Architectures whose scores are averaged for every (dataset, method) pair.
ARCHITECTURES = ["FCN", "ResNet", "PatchTST"]
# Where the figure is written; set XMA_SUMMARY_PLOT_DIR to override.
OUTPUT_DIR = os.environ.get(
    "XMA_SUMMARY_PLOT_DIR", "./summary_plots/consolidated_results"
)
os.makedirs(OUTPUT_DIR, exist_ok=True)

UEA_DATASETS = [
    "UWaveGestureLibrary", "ERing", "RacketSports", "NATOPS",
    "CharacterTrajectories", "SelfRegulationSCP1", "ArticularyWordRecognition",
    "Libras", "BasicMotions", "Cricket", "Epilepsy",
]

# Method identifiers exactly as they appear in the CSV file names.
METHODS = ['LIME', 'KernelSHAP', 'InputXGradient', 'IG', 'Input_only_amortized_attr_v2', 'amortized_attr_v2']
# Display labels for the five plotted methods, in bar order. Its length also
# sets how far each dataset's x-tick is shifted to centre it under its group
# of bars. 'Input_only_amortized_attr_v2' is not plotted, so METHOD_LABELS
# (like COLORS) has one entry fewer than METHODS; keep METHOD_LABELS and
# COLORS the same length.
METHOD_LABELS = ['LIME', 'KernelSHAP', 'InputxGrad', 'IG', 'XMA (Ours)']

# --- COLOR SCHEME (Professional Science Style) ---
# One color per plotted method, in the same order as METHOD_LABELS.
COLORS = [
    "#D4AC0D",  # LIME: muted gold
    "#CA6F1E",  # KernelSHAP: bronze
    "#A93226",  # InputxGrad: deep brick red
    "#2471A3",  # IG: deep cobalt blue
    # Input Only XMA was drawn in "#ED19CE" (magenta); re-add it here if that
    # method is plotted again.
    "#1ABC9C",  # XMA (Ours): vibrant teal (spotlight)
]
FONT_SIZE = 25

# File-name suffix of the bottom-up ROC-AUC faithfulness CSVs, and the column
# holding the score.
FAITH_SUFFIX = "BottomUp_Roc_auc_Faithfulness.csv"
AUC_COLUMN = "faithfulness_auc"

def get_faithfulness_val(dataset, arch, method, *, seed=42):
    """Return the ROC-AUC faithfulness score for one (dataset, arch, method) run.

    The score is read from
    ``<BASE_PATH>/<dataset>/DNN/<arch>/<dataset>-<seed>-DNN-<arch>-test_<method>_<FAITH_SUFFIX>``
    as the first value of its ``AUC_COLUMN`` column.

    Args:
        dataset: Dataset name (sub-folder of BASE_PATH).
        arch: Model architecture name (sub-folder under ``DNN``).
        method: Explanation-method identifier used in the file name.
        seed: Seed embedded in the file name (default 42, as before).

    Returns:
        The score, or None when the file is missing, cannot be read or
        parsed, lacks the column, has no rows, or holds NaN.
    """
    file_name = f"{dataset}-{seed}-DNN-{arch}-test_{method}_{FAITH_SUFFIX}"
    file_path = os.path.join(BASE_PATH, dataset, "DNN", arch, file_name)

    try:
        df = pd.read_csv(file_path)
        value = df[AUC_COLUMN].iloc[0]
    except FileNotFoundError:
        # A missing file simply means this run was not produced.
        return None
    except (OSError, ValueError, KeyError, IndexError) as err:
        # ValueError covers pandas' EmptyDataError and ParserError. Warn
        # instead of silently hiding a corrupt or malformed result file.
        print(f"Warning: could not read '{AUC_COLUMN}' from {file_path}: {err}")
        return None

    # Treat NaN like a missing result so it cannot poison downstream means.
    return None if pd.isna(value) else value

# Raw method identifiers that are drawn, in bar order. They pair positionally
# with METHOD_LABELS (legend text) and COLORS. 'Input_only_amortized_attr_v2'
# is present in METHODS but deliberately not plotted.
_PLOTTED_METHODS = ['LIME', 'KernelSHAP', 'InputXGradient', 'IG', 'amortized_attr_v2']


def _aggregate_faithfulness():
    """Average each method's ROC-AUC faithfulness over ARCHITECTURES, per dataset.

    Returns:
        DataFrame with columns Dataset, Method, Mean, Std. It has one row per
        (dataset, method) pair with at least one usable score. ``Method`` is
        the raw identifier from METHODS. ``Std`` is the population standard
        deviation (``np.std`` default, ddof=0) across the available
        architectures.
    """
    rows = []
    for ds in UEA_DATASETS:
        for method in METHODS:
            aucs = []
            for arch in ARCHITECTURES:
                val = get_faithfulness_val(ds, arch, method)
                # Skip missing runs and NaN cells so a single bad file does
                # not turn the whole mean into NaN.
                if val is not None and pd.notna(val):
                    aucs.append(val)
            if aucs:
                rows.append({
                    "Dataset": ds,
                    "Method": method,
                    "Mean": np.mean(aucs),
                    "Std": np.std(aucs),
                })
    # Explicit columns keep the frame indexable even when nothing was found.
    return pd.DataFrame(rows, columns=["Dataset", "Method", "Mean", "Std"])


def _plot_faithfulness(df_faith):
    """Draw the grouped bar chart from *df_faith*, save it to OUTPUT_DIR and show it.

    Raises:
        ValueError: if _PLOTTED_METHODS, METHOD_LABELS and COLORS differ in length.
    """
    if not len(_PLOTTED_METHODS) == len(METHOD_LABELS) == len(COLORS):
        raise ValueError(
            "_PLOTTED_METHODS, METHOD_LABELS and COLORS must have the same length"
        )

    # 'DejaVu Sans' ships with matplotlib. Listing it as a fallback avoids a
    # findfont warning for every text element on machines without Arial.
    plt.rcParams['font.family'] = 'sans-serif'
    plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']

    fig, ax = plt.subplots(figsize=(18, 9))
    ax.set_facecolor('#fdfdfd')

    index = np.arange(len(UEA_DATASETS))
    n_methods = len(METHOD_LABELS)
    bar_width = 0.15

    for i, (method, label, color) in enumerate(zip(_PLOTTED_METHODS, METHOD_LABELS, COLORS)):
        # Reindex onto the full dataset list so bar positions line up with the
        # x-ticks. Datasets without data become NaN, and no bar is drawn.
        method_df = (
            df_faith[df_faith['Method'] == method]
            .set_index('Dataset')
            .reindex(UEA_DATASETS)
        )
        ax.bar(index + i * bar_width, method_df['Mean'].to_numpy(), bar_width,
               yerr=method_df['Std'].to_numpy(), label=label, color=color,
               capsize=4, error_kw={'elinewidth': 1.2, 'ecolor': '#333333'},
               edgecolor='white', linewidth=0.5, alpha=0.95)

    # --- Aesthetics ---
    ax.set_ylabel("ROC-AUC Faithfulness", fontsize=FONT_SIZE, labelpad=15, fontweight='600')

    # Centre each dataset label under its group of bars.
    ax.set_xticks(index + bar_width * (n_methods - 1) / 2)
    ax.set_xticklabels(UEA_DATASETS, rotation=35, ha='right', fontsize=FONT_SIZE)

    ax.set_yticks(np.arange(0, 1.1, 0.1))
    # The headroom above 1.0 leaves space for the in-axes legend.
    ax.set_ylim(0.70, 1.1)

    ax.yaxis.grid(True, linestyle='--', color='grey', alpha=0.3)
    ax.set_axisbelow(True)
    for spine in ['top', 'right']:
        ax.spines[spine].set_visible(False)
    for spine in ['left', 'bottom']:
        ax.spines[spine].set_color('#bdc3c7')

    legend = ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.0),
                       ncol=n_methods, frameon=False, fontsize=FONT_SIZE * 0.9)

    # Spotlight our method by bolding its legend entry.
    for text in legend.get_texts():
        if text.get_text().endswith('(Ours)'):
            text.set_fontweight('bold')

    fig.tight_layout()

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    save_path = os.path.join(OUTPUT_DIR, "Faithfulness_Main_Multi.png")
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Faithfulness Plot Saved: {save_path}")
    plt.show()
    plt.close(fig)


def generate_faithfulness_plot():
    """Aggregate ROC-AUC faithfulness across architectures and plot it per dataset.

    Scores for every dataset in UEA_DATASETS and method in METHODS are averaged
    over ARCHITECTURES. The methods in _PLOTTED_METHODS are drawn as grouped
    bars with error bars, saved as ``Faithfulness_Main_Multi.png`` in
    OUTPUT_DIR, and shown. If no scores are found at all, a message is printed
    and nothing is plotted.
    """
    df_faith = _aggregate_faithfulness()
    if df_faith.empty:
        print(f"No faithfulness results found under {BASE_PATH}; nothing to plot.")
        return
    _plot_faithfulness(df_faith)

if __name__ == "__main__":
    # Script entry point: running this file regenerates the consolidated
    # faithfulness figure; importing it has no plotting side effect.
    generate_faithfulness_plot()
