import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
BASE_PATH = "/home/40430660@eeecs.qub.ac.uk/InterpretGatedNetwork/explanations"
ARCH = "PatchTST"
OUTPUT_DIR = f"./summary_plots/robustness_comparison/{ARCH}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# UEA datasets to plot, one chart each. Commented-out entries are currently
# excluded. Every entry, enabled or not, keeps its trailing comma: re-enabling
# an entry that lacks one would make Python silently concatenate it with the
# next string literal (e.g. "Heartbeat" "CharacterTrajectories").
UEA_DATASETS = [
    "UWaveGestureLibrary",
    "ERing",
    "RacketSports",
    # "SpokenArabicDigits",
    "NATOPS",
    # "Heartbeat",
    # "PenDigits",
    "CharacterTrajectories",
    "SelfRegulationSCP1",
    "ArticularyWordRecognition",
    "Libras",
    "BasicMotions",
    # "PEMS-SF",
    "Cricket",
    "Epilepsy",
]

# Explanation methods to compare in the bar chart; bars appear in this order.
METHODS = [
    'LIME', 'KernelSHAP', 'Saliency', 'InputXGradient', 'IG',
    'KeystoneIG_0.05', 'KeystoneIG_0.1', 'KeystoneIG_0.15',
    'KeystoneIG_0.2', 'KeystoneIG_0.25', 'BestEnsemble',
    'Input_only_amortized_attr_v2', 'amortized_attr_v2',  # amortized variants
]

# Consistent color mapping, one entry per method in METHODS.
COLORS = {
    "LIME": "#9E9E9E", "KernelSHAP": "#757575", "Saliency": "#607D8B",
    "InputXGradient": "#27AE60", "IG": "#2980B9",
    # KeystoneIG variants share a purple ramp, darker for larger thresholds.
    "KeystoneIG_0.05": "#E1BEE7", "KeystoneIG_0.1": "#CE93D8",
    "KeystoneIG_0.15": "#BA68C8", "KeystoneIG_0.2": "#9C27B0",
    "KeystoneIG_0.25": "#7B1FA2",
    "BestEnsemble": "#C62828",
    "Input_only_amortized_attr_v2": "#725FDE",  # Violet for input-only amortized model
    "amortized_attr_v2": "#FF8F00",  # Amber for amortized model
}

def _read_iou(file_path):
    """Return the first mean-IoU value stored in a robustness CSV file.

    The ``avg_jaccard_iou`` column is preferred; an ``iou`` column is used as a
    fallback when the former is absent. Only the first data row is read.

    Raises:
        KeyError: neither column is present.
        IndexError: the file has a header but no data rows.
        OSError, ValueError: the file cannot be opened or parsed (pandas'
            ``EmptyDataError`` and ``ParserError`` are ValueError subclasses).
    """
    df = pd.read_csv(file_path)
    for column in ("avg_jaccard_iou", "iou"):
        if column in df.columns:
            return float(df[column].iloc[0])
    raise KeyError(
        f"expected an 'avg_jaccard_iou' or 'iou' column, found {list(df.columns)}"
    )


def _collect_robustness_scores(dataset_name):
    """Load the IoU of every method in METHODS for one dataset.

    Returns a list of ``{'Method': ..., 'IoU': ...}`` dicts in METHODS order.
    Methods whose CSV is missing or unreadable are reported and skipped.
    """
    folder_path = os.path.join(BASE_PATH, dataset_name, "DNN", ARCH)
    results = []
    for method in METHODS:
        # File format: Dataset-42-DNN-ARCH-test_METHOD_Robustness_Jaccard.csv
        file_name = f"{dataset_name}-42-DNN-{ARCH}-test_{method}_Robustness_Jaccard.csv"
        file_path = os.path.join(folder_path, file_name)

        if not os.path.exists(file_path):
            print(f"Missing robustness file for {method}: {file_path}")
            continue

        try:
            iou_val = _read_iou(file_path)
        except (OSError, KeyError, IndexError, ValueError) as e:
            # Best-effort: one bad file should not prevent plotting the others.
            print(f"Error reading {method} in {dataset_name}: {e}")
            continue
        results.append({'Method': method, 'IoU': iou_val})
    return results


def plot_robustness_bar_chart(dataset_name):
    """Create a bar chart comparing the Jaccard Index (Robustness) across methods.

    Reads one ``..._Robustness_Jaccard.csv`` per entry of METHODS from
    ``BASE_PATH/<dataset_name>/DNN/<ARCH>`` and saves a PNG into OUTPUT_DIR.
    Missing or unreadable files are reported and skipped. If no file can be
    read, nothing is plotted.

    Args:
        dataset_name: Dataset folder name under BASE_PATH.
    """
    results = _collect_robustness_scores(dataset_name)
    if not results:
        print(f"No robustness data found for {dataset_name}.")
        return

    # Convert to DataFrame for plotting
    plot_df = pd.DataFrame(results)

    # Sort by IoU for better visual comparison (optional)
    # plot_df = plot_df.sort_values(by='IoU', ascending=False)

    # Scope the style to this figure. Applying plt.style.use() after the figure
    # exists would leave the figure itself on the previous style and would
    # change the global rcParams for every later figure.
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, ax = plt.subplots(figsize=(14, 7))
        try:
            bars = ax.bar(plot_df['Method'], plot_df['IoU'],
                          color=[COLORS.get(m, "#333333") for m in plot_df['Method']],
                          edgecolor='black', alpha=0.85)

            # Add a value label on top of each bar.
            for bar in bars:
                height = bar.get_height()
                if not np.isfinite(height):
                    continue  # a NaN score has no meaningful label position
                ax.text(bar.get_x() + bar.get_width() / 2, height + 0.01, f'{height:.3f}',
                        ha='center', va='bottom', fontsize=10, fontweight='bold')

            # NOTE(review): "Top 20%" is hard-coded here, not read from the CSV;
            # confirm it matches how the robustness files were produced.
            ax.set_title(f"Robustness Comparison: {dataset_name} (Jaccard Index @ Top 20%)",
                         fontsize=16, fontweight='bold')
            ax.set_ylabel("Mean Jaccard IoU (Higher is More Robust)", fontsize=12)
            ax.set_xlabel("Explanation Method", fontsize=12)
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
            ax.set_ylim(0, 1.1)  # Max IoU is 1.0; headroom for the value labels

            # Grid formatting
            ax.yaxis.grid(True, linestyle='--', alpha=0.7)

            # NOTE(review): the "_Input_only" suffix is applied for every method
            # set; confirm this filename is intended.
            save_path = os.path.join(OUTPUT_DIR, f"{dataset_name}_{ARCH}_robustness_comparison_Input_only.png")
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Successfully saved robustness bar chart to {save_path}")
        finally:
            # Release the figure even if drawing or saving fails, so that
            # looping over many datasets does not accumulate open figures.
            plt.close(fig)

if __name__ == "__main__":
    # Script entry point: render one robustness chart per configured dataset.
    print("Starting robustness result aggregation...")
    for dataset_name in UEA_DATASETS:
        plot_robustness_bar_chart(dataset_name)
    print("\nProcessing complete.")
