import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# --- Configuration ---
# Root folder holding the per-dataset explanation/faithfulness CSVs.
# Set the FAITHFULNESS_BASE_PATH environment variable to point the script at
# another location; otherwise the original hard-coded path is used.
BASE_PATH = os.environ.get(
    "FAITHFULNESS_BASE_PATH",
    "/home/40430660@eeecs.qub.ac.uk/InterpretGatedNetwork/explanations",
)
# Model architecture; used as the sub-folder under <dataset>/DNN/ and in filenames.
ARCH = "PatchTST"
OUTPUT_DIR = f"./summary_plots/test_data_faithfulness/{ARCH}"
# Created at import time so the plots can be saved straight into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Datasets to summarise. Disabled entries stay commented out WITH their
# trailing commas, so re-enabling one can never silently concatenate it with
# the following string literal.
UEA_DATASETS = [
    "UWaveGestureLibrary",
    "ERing",
    "RacketSports",
    # "SpokenArabicDigits",
    "NATOPS",
    # "Heartbeat",
    # "PenDigits",
    "CharacterTrajectories",
    "SelfRegulationSCP1",
    "ArticularyWordRecognition",
    "Libras",
    "BasicMotions",
    # "PEMS-SF",
    "Cricket",
    "Epilepsy",
]

# Attribution methods to compare, in plotting (and therefore legend) order.
# Each name is substituted into the per-method CSV filename and looked up in
# COLORS.
METHODS = [
    # Classical attribution baselines
    'LIME',
    'KernelSHAP',
    'Saliency',
    'InputXGradient',
    'IG',
    # KeystoneIG at five values of its numeric suffix
    *(f'KeystoneIG_{ratio}' for ratio in ('0.05', '0.1', '0.15', '0.2', '0.25')),
    # Ensemble and amortized-attribution variants
    'BestEnsemble',
    'Input_only_amortized_attr_v2',
    'Distill_only_amortized_attr_v2',
    'amortized_attr_v2',
]

# Line colour for each method name. Methods without an entry are drawn in a
# fallback dark grey by the plotting code. Baselines use greys, slate, green
# and blue; the KeystoneIG variants form a lavender-to-deep-purple ramp; the
# remaining methods use high-contrast hues.
COLORS = {
    # Classical baselines
    "LIME": "#9E9E9E",            # grey
    "KernelSHAP": "#757575",      # darker grey
    "Saliency": "#607D8B",        # slate blue-grey
    "InputXGradient": "#27AE60",  # emerald green
    "IG": "#2980B9",              # royal blue
    # KeystoneIG variants: lightest shade for the smallest suffix
    **dict(zip(
        (f"KeystoneIG_{ratio}" for ratio in ("0.05", "0.1", "0.15", "0.2", "0.25")),
        ("#E1BEE7", "#CE93D8", "#BA68C8", "#9C27B0", "#7B1FA2"),
    )),
    # Ensemble
    "BestEnsemble": "#C62828",    # crimson red
    # Amortized-attribution variants
    "Input_only_amortized_attr_v2": "#EB0FB8",    # magenta-pink
    "Distill_only_amortized_attr_v2": "#2A19C0",  # indigo
    "amortized_attr_v2": "#11CC20",               # bright green
}

def _detect_faithfulness_columns(df):
    """Identify the (x, y, AUC) columns of a faithfulness CSV.

    Returns:
        A ``(x_col, y_col, auc_col)`` tuple, or ``None`` when ``df`` lacks a
        removal-ratio column (``removal_ratio`` or ``ratio``), a supported
        metric column (``accuracy`` or ``roc_auc``), or ``faithfulness_auc``.
    """
    # Accept either name for the removal-ratio axis.
    if 'removal_ratio' in df.columns:
        x_col = 'removal_ratio'
    elif 'ratio' in df.columns:
        x_col = 'ratio'
    else:
        return None

    y_col = next((c for c in ('accuracy', 'roc_auc') if c in df.columns), None)
    auc_col = 'faithfulness_auc'
    if y_col is None or auc_col not in df.columns:
        return None
    return x_col, y_col, auc_col


def plot_metric_results(dataset_name, metric_suffix, ylabel, title_prefix, save_suffix):
    """Plot one faithfulness metric for every method in METHODS and save a PNG.

    For each method the CSV
    ``{BASE_PATH}/{dataset_name}/DNN/{ARCH}/{dataset_name}-42-DNN-{ARCH}-test_{method}_{metric_suffix}.csv``
    is read and its metric curve is drawn against the removal ratio. The legend
    label carries the file's ``faithfulness_auc`` value. Files whose columns are
    not recognised, or that contain no rows, are skipped with a message. If no
    file could be plotted, nothing is saved.

    Args:
        dataset_name: Dataset folder name and filename prefix.
        metric_suffix: Trailing part of the CSV filename that selects the metric.
        ylabel: Y-axis label.
        title_prefix: Text placed before the dataset name in the plot title.
        save_suffix: Tag inserted into the output PNG filename (under OUTPUT_DIR).

    Raises:
        FileNotFoundError: If the CSV for any method in METHODS is missing.
    """
    folder_path = os.path.join(BASE_PATH, dataset_name, "DNN", ARCH)

    # Apply the style before creating the figure so figure-level settings take effect.
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        found_data = False
        for method in METHODS:
            # Filename follows the naming convention in experiment_explanation.py.
            file_name = f"{dataset_name}-42-DNN-{ARCH}-test_{method}_{metric_suffix}.csv"
            file_path = os.path.join(folder_path, file_name)
            try:
                df = pd.read_csv(file_path)
            except FileNotFoundError as err:
                # Fail fast: a missing method would make the comparison silently incomplete.
                raise FileNotFoundError(
                    f"File not found for {method} in {dataset_name}: {file_path}"
                ) from err

            columns = _detect_faithfulness_columns(df)
            if columns is None or df.empty:
                print(f"Skipping {method} for {dataset_name}: unrecognised or empty file {file_path}")
                continue
            x_col, y_col, auc_col = columns

            # The AUC is presumably repeated on every row; the first row's value is used.
            auc_val = df[auc_col].iloc[0]
            ax.plot(df[x_col].values, df[y_col].values,
                    label=f"{method} (AUC: {auc_val:.4f})",
                    color=COLORS.get(method, "#333333"), linewidth=2.0,
                    marker='o', markersize=3, alpha=0.85)
            found_data = True

        if not found_data:
            print(f"No valid data found for {dataset_name} to plot {save_suffix}.")
            return

        ax.set_title(f"{title_prefix}: {dataset_name} (Bottom-Up Removal)", fontsize=16, fontweight='bold')
        ax.set_xlabel("Removal Ratio (Least Important -> Most Important)", fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_xticks(np.arange(0, 1.1, 0.1))
        ax.set_ylim(0, 1.05)

        # Legend outside the plot area so it does not cover the curves.
        ax.legend(frameon=True, fontsize=9, loc='center left', bbox_to_anchor=(1, 0.5))

        save_path = os.path.join(OUTPUT_DIR, f"{dataset_name}_{ARCH}_{save_suffix}_Distill_Input.png")
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Successfully saved {save_suffix} plot for {dataset_name} to {save_path}")
    finally:
        # Release the figure even if reading or plotting a file raised.
        plt.close(fig)

if __name__ == "__main__":
    print("Starting visual verification of faithfulness results (Confidence data removed)...")

    # One entry per metric: (metric_suffix, ylabel, title_prefix, save_suffix).
    # metric_suffix is the tail of the CSV filenames read by plot_metric_results,
    # e.g. "..._{method}_BottomUp_Accuracy_Faithfulness.csv".
    metric_specs = (
        # 1. Accuracy faithfulness
        ("BottomUp_Accuracy_Faithfulness", "Model Accuracy", "Accuracy Decay", "accuracy_summary"),
        # 2. AUC-ROC faithfulness
        ("BottomUp_Roc_auc_Faithfulness", "AUC-ROC Score", "Fidelity (AUC-ROC)", "AUCROC_summary"),
    )

    for ds in UEA_DATASETS:
        for metric_suffix, ylabel, title_prefix, save_suffix in metric_specs:
            plot_metric_results(ds, metric_suffix, ylabel, title_prefix, save_suffix)

    print("\nProcessing complete. Check plots in:", OUTPUT_DIR)
