import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Configuration
# Root folder holding the per-dataset explanation CSVs. The default is the
# original author's machine; set EXPLANATIONS_BASE_PATH to run elsewhere.
BASE_PATH = (
    os.environ.get("EXPLANATIONS_BASE_PATH")
    or "/home/40430660@eeecs.qub.ac.uk/InterpretGatedNetwork/explanations"
)
ARCH = "FCN"
OUTPUT_DIR = f"./summary_plots/train_data_confidence_accuracy/{ARCH}"
# Created at import time so every plotting function can save into it directly.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Full UEA benchmark list; assign it to UEA_DATASETS for a complete run.
ALL_UEA_DATASETS = [
    "UWaveGestureLibrary", "ERing", "RacketSports", "SpokenArabicDigits",
    "NATOPS", "Heartbeat", "PenDigits", "CharacterTrajectories",
    "SelfRegulationSCP1", "ArticularyWordRecognition", "Libras",
    "BasicMotions", "PEMS-SF", "Cricket", "Epilepsy",
]

# Subset of datasets currently being processed.
UEA_DATASETS = ["UWaveGestureLibrary", "ERing"]

# Attribution methods to visualize, in plot/legend order.
METHODS = ['Saliency', 'InputXGradient', 'IG', 'SegIG_8', 'SegIG_16', 'SegIG_32', "BestEnsemble"]

# Line colour per method (keys must match METHODS entries).
COLORS = {
    "Saliency": "#FF7F0E",       # Vibrant orange: high visibility for the baseline
    "InputXGradient": "#2CA02C", # Forest green: distinctive and calm
    "IG": "#1F77B4",             # Blue: conventional choice for Integrated Gradients

    # SegIG: shades of purple, darker for more segments
    "SegIG_8": "#D4B9DA",        # Light lavender
    "SegIG_16": "#9E9AC8",       # Medium purple
    "SegIG_32": "#6A51A3",       # Deep royal purple

    "BestEnsemble": "#D62728",   # Deep red: maximum contrast for the proposed method
}

# Matplotlib style shared by every summary plot in this script.
_PLOT_STYLE = 'seaborn-v0_8-whitegrid'


def _load_metric_curve(file_path, x_col, y_col, auc_col):
    """Read one faithfulness CSV and return ``(x, y, auc)``, or None if it has no rows.

    The AUC column is treated as a per-row broadcast of a single value, so only
    its first row is read.
    """
    try:
        df = pd.read_csv(file_path)
    except pd.errors.EmptyDataError:
        # Zero-byte file: nothing to plot.
        return None
    if df.empty:
        return None
    return df[x_col].to_numpy(), df[y_col].to_numpy(), df[auc_col].iloc[0]


def _plot_bottom_up_metric(dataset_name, *, file_suffix, x_col, y_col, auc_col,
                           title_prefix, ylabel, save_suffix):
    """Plot one bottom-up faithfulness metric for every method in METHODS.

    For each method, reads
    ``<BASE_PATH>/<dataset>/DNN/<ARCH>/<dataset>-42-DNN-<ARCH>-train_<method>_BottomUp_<file_suffix>.csv``
    and draws ``y_col`` against ``x_col``. The legend shows the value in ``auc_col``.

    Returns:
        The path of the saved PNG, or None when the dataset folder is missing
        or no method produced a usable CSV.

    Raises:
        KeyError: if a CSV lacks one of the expected columns.
    """
    folder_path = os.path.join(BASE_PATH, dataset_name, "DNN", ARCH)
    if not os.path.isdir(folder_path):
        print(f"Skipping {dataset_name}: Folder not found.")
        return None

    # style.context applies the style *before* the figure is created and
    # restores the global rcParams afterwards (plt.style.use leaked it).
    with plt.style.context(_PLOT_STYLE):
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            found_data = False
            for method in METHODS:
                file_name = f"{dataset_name}-42-DNN-{ARCH}-train_{method}_BottomUp_{file_suffix}.csv"
                file_path = os.path.join(folder_path, file_name)
                if not os.path.exists(file_path):
                    print(f"  Warning: File not found for {method} in {dataset_name}: {file_path}")
                    continue

                curve = _load_metric_curve(file_path, x_col, y_col, auc_col)
                if curve is None:
                    print(f"  Warning: Empty file for {method} in {dataset_name}: {file_path}")
                    continue

                ratios, values, auc_val = curve
                # COLORS.get -> None lets matplotlib fall back to its colour cycle
                # for methods that have no entry in COLORS.
                ax.plot(ratios, values, label=f"{method} (AUC: {auc_val:.4f})",
                        color=COLORS.get(method), linewidth=2.5, marker='o', markersize=4)
                found_data = True

            if not found_data:
                print(f"  No {save_suffix} data found for {dataset_name}; nothing saved.")
                return None

            ax.set_title(f"{title_prefix} Decay: {dataset_name} (Bottom-Up)",
                         fontsize=16, fontweight='bold')
            ax.set_xlabel("Removal Ratio (Least Important -> Most Important)", fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            # linspace avoids float-step drift at the endpoint that arange can have.
            ax.set_xticks(np.linspace(0, 1, 11))
            ax.set_ylim(0, 1.05)
            ax.legend(frameon=True, fontsize=10, loc='lower left')

            save_path = os.path.join(OUTPUT_DIR, f"{dataset_name}_{ARCH}_{save_suffix}_summary.png")
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Successfully saved plot for {dataset_name} to {save_path}")
            return save_path
        finally:
            # Release the figure even when a CSV is malformed.
            plt.close(fig)


def plot_dataset_confidence_results(dataset_name):
    """Plot bottom-up ground-truth-class confidence vs. removal ratio.

    Reads ``*_BottomUp_Confidence_Faithfulness.csv`` (columns ``ratio``,
    ``avg_confidence``, ``auc_confidence``) for each method. Saves
    ``<dataset>_<ARCH>_confidence_summary.png`` into OUTPUT_DIR.

    Returns the saved path, or None if nothing was plotted.
    """
    return _plot_bottom_up_metric(
        dataset_name,
        file_suffix="Confidence_Faithfulness",
        x_col="ratio",
        y_col="avg_confidence",
        auc_col="auc_confidence",
        title_prefix="Confidence",
        ylabel="Model Confidence (GT Class Probability)",
        save_suffix="confidence",
    )


def plot_dataset_accuracy_results(dataset_name):
    """Plot bottom-up accuracy vs. removal ratio.

    Reads ``*_BottomUp_Mean_Faithfulness.csv`` (columns ``removal_ratio``,
    ``accuracy``, ``faithfulness_auc``) for each method. Saves
    ``<dataset>_<ARCH>_accuracy_summary.png`` into OUTPUT_DIR.

    Returns the saved path, or None if nothing was plotted.
    """
    return _plot_bottom_up_metric(
        dataset_name,
        file_suffix="Mean_Faithfulness",
        x_col="removal_ratio",
        y_col="accuracy",
        auc_col="faithfulness_auc",
        title_prefix="Accuracy",
        ylabel="Accuracy",
        save_suffix="accuracy",
    )


def plot_dataset_accuracy_results_aucroc(dataset_name):
    """Plot bottom-up AUC-ROC vs. removal ratio.

    Reads ``*_BottomUp_Mean_AUCROC_Faithfulness.csv`` (columns ``removal_ratio``,
    ``auc_roc``, ``faithfulness_auc_roc``) for each method. Saves
    ``<dataset>_<ARCH>_AUCROC_summary.png`` into OUTPUT_DIR.

    Returns the saved path, or None if nothing was plotted.
    """
    return _plot_bottom_up_metric(
        dataset_name,
        file_suffix="Mean_AUCROC_Faithfulness",
        x_col="removal_ratio",
        y_col="auc_roc",
        auc_col="faithfulness_auc_roc",
        title_prefix="AUC-ROC",
        ylabel="AUCROC",
        save_suffix="AUCROC",
    )

if __name__ == "__main__":
    # Produce the accuracy, confidence and AUC-ROC summaries for each dataset,
    # in that order.
    print("Starting result aggregation and visualization...")
    summary_plotters = (
        plot_dataset_accuracy_results,
        plot_dataset_confidence_results,
        plot_dataset_accuracy_results_aucroc,
    )
    for dataset in UEA_DATASETS:
        for make_plot in summary_plotters:
            make_plot(dataset)
    print("\nProcessing complete. Check the 'summary_plots' directory.")
