import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import sys
import time
from datetime import datetime
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score


from models.Shapelet import ShapeBottleneckModel, DistThresholdSBM
from models.InterpGN import InterpGN, FullyConvNetwork, Transformer, TimesNet, PatchTST, ResNet

from utils.tools import EarlyStopping, convert_to_hms, gini_coefficient
from utils.shapelet_util import ClassificationResult
from data_provider.data_factory import data_provider

# Explanation libraries  
from captum.attr import IntegratedGradients
import captum.attr as attr_lib
from einops import rearrange

import copy

def compute_beta(epoch, max_epoch, schedule='cosine'):
    """Return the annealing weight ``beta`` for ``epoch`` out of ``max_epoch``.

    'cosine' follows half a cosine period from 1 (epoch 0) down to 0
    (epoch == max_epoch); 'linear' falls from 1 to 0 in a straight line.
    Any other schedule name disables annealing and yields the constant 1
    (``max_epoch`` is not used in that case).
    """
    if schedule == 'linear':
        return 1 - epoch / max_epoch
    if schedule == 'cosine':
        return 0.5 * (1 + np.cos(np.pi * epoch / max_epoch))
    # Unrecognised schedule: no annealing.
    return 1


def compute_shapelet_score(shapelet_distances, cls_weights, y_pred, y_true):
    """Mean shapelet evidence for the true class over correctly classified samples.

    Class scores are ``shapelet_distances @ relu(cls_weights.T)`` normalised by
    the number of shapelets (last dim of ``shapelet_distances``). For every
    sample where ``y_pred == y_true`` the score of its true class is taken, and
    the mean of those values is returned as a Python float.
    """
    num_shapelets = shapelet_distances.shape[-1]
    # Only positive class weights contribute evidence.
    positive_weights = nn.functional.relu(cls_weights.T)
    class_scores = shapelet_distances @ positive_weights / num_shapelets

    hit = y_pred == y_true
    true_class_scores = class_scores[hit].gather(-1, y_true[hit].unsqueeze(1))
    return true_class_scores.mean().item()


def get_dnn_model(configs):
    """Instantiate the deep-network backbone named by ``configs.dnn_type``.

    The backbone receives a deep copy of ``configs`` whose ``task_name`` is
    forced to 'classification', so the caller's namespace is left untouched.
    Raises KeyError for an unknown ``dnn_type``.
    """
    backbones = dict(
        FCN=FullyConvNetwork,
        Transformer=Transformer,
        TimesNet=TimesNet,
        PatchTST=PatchTST,
        ResNet=ResNet,
    )
    local_configs = copy.deepcopy(configs)
    local_configs.task_name = 'classification'
    return backbones[configs.dnn_type](local_configs)


class Explanation_Experiment(object):
    # Maps the ``args.model`` name to the callable that builds that model;
    # every entry is called with ``configs=args`` by ``_build_model``.
    model_dict = dict(
        InterpGN=InterpGN,
        SBM=ShapeBottleneckModel,
        LTS=DistThresholdSBM,
        DNN=get_dnn_model,
    )
    def __init__(self, args):
        """
        Build the data loaders and the model, then restore trained weights.

        Side effects: sets the data-derived fields ``args.seq_len``,
        ``args.pred_len``, ``args.enc_in`` and ``args.num_class`` before the
        model is built, and loads ``<checkpoint_dir>/checkpoint.pth`` into it.

        Args:
            args: experiment namespace; must provide the fields read below
                (model, dataset, dnn_type, seed, num_shapelet, lambda_div,
                lambda_reg, epsilon, beta_schedule, distance_func, sbm_cls,
                lr, train_epochs, multi_gpu) plus whatever the data provider
                and model constructors read.
        """
        self.train_data, self.train_loader = data_provider(args, flag="TRAIN")
        self.test_data, self.test_loader = data_provider(args, flag="TEST")
        # NOTE(review): the "validation" split is loaded with flag 'TEST', i.e.
        # it is the test split; confirm this is intended.
        self.val_data, self.val_loader = data_provider(args, flag='TEST')

        # Data-dependent model dimensions must be set before the model is built.
        args.seq_len = max(self.train_data.max_seq_len, self.test_data.max_seq_len)
        args.pred_len = 0
        args.enc_in = self.train_data.feature_df.shape[1]
        args.num_class = len(self.train_data.class_names)
        self.epoch_stop = 0

        # Build Model
        self.args = args
        # Fall back to CPU instead of failing when no GPU is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.loss_fn = nn.CrossEntropyLoss()
        self.model = self._build_model().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=self.args.train_epochs)
        self.checkpoint_dir = "./checkpoints/{}/{}/dnn-{}_seed-{}_k-{}_div-{}_reg-{}_eps-{}_beta-{}_dfunc-{}_cls-{}".format(
            self.args.model,
            self.args.dataset,
            self.args.dnn_type,
            self.args.seed,
            self.args.num_shapelet,
            self.args.lambda_div,
            self.args.lambda_reg,
            self.args.epsilon,
            self.args.beta_schedule,
            self.args.distance_func,
            self.args.sbm_cls
        )
        checkpoint_path = f"{self.checkpoint_dir}/checkpoint.pth"
        print("Loading model checkpoint...")
        # map_location lets a checkpoint saved on GPU be restored on the current device.
        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
        print(f"Model loaded from {checkpoint_path}")
        # Wrap only after loading so state_dict keys carry no 'module.' prefix.
        if self.args.multi_gpu:
            self.model = nn.DataParallel(self.model)

    def print_args(self):
        """Print every attribute of ``self.args`` as ``name: value``, one per line."""
        args = self.args
        lines = (f"{key}: {getattr(args, key)}" for key in vars(args))
        for line in lines:
            print(line)

    def _build_model(self):
        """
        Instantiate the model selected by ``self.args.model``.

        The constructor is looked up in ``model_dict`` and receives the whole
        argument namespace as ``configs``. Raises KeyError for an unknown
        model name.
        """
        # NOTE(review): earlier revisions built per-length shapelet settings
        # (lengths [0.05, 0.1, 0.2, 0.3, 0.5, 0.8] x args.num_shapelet) here but
        # never passed them to the constructor; those dead locals were removed.
        return self.model_dict[self.args.model](configs=self.args)

    def test(self, save_csv=True, result_dir=None):
        """
        Evaluate the loaded model on the test split.

        Args:
            save_csv: currently unused; kept for interface compatibility.
            result_dir: if given, the directory is created (best-effort);
                nothing is written to it by this method. ``None`` (the
                default) skips directory creation.

        Returns:
            ClassificationResult with the stacked test inputs, true labels,
            predicted labels, mean per-sample loss and accuracy.
        """
        self._ensure_result_dir(result_dir)
        return self._evaluate_loader(self.test_loader)

    def get_train_utility(self, save_csv=True, result_dir=None):
        """
        Evaluate the loaded model on the training split.

        Same arguments and return value as ``test``, but iterates
        ``self.train_loader``.
        """
        self._ensure_result_dir(result_dir)
        return self._evaluate_loader(self.train_loader)

    @staticmethod
    def _ensure_result_dir(result_dir):
        """Best-effort creation of ``result_dir``; a no-op when it is None."""
        if result_dir is None:
            return
        try:
            os.makedirs(result_dir, exist_ok=True)
        except OSError as err:
            # Callers never write into this directory, so failure is non-fatal.
            print(f"Warning: could not create result_dir {result_dir!r}: {err}")

    def _evaluate_loader(self, loader):
        """
        Run a gradient-free pass of the model over ``loader`` and summarise it.

        Per-sample loss is cross-entropy; for non-DNN models the batch mean of
        ``model_info.loss`` is added to every sample's loss.

        Returns:
            ClassificationResult(x_data, trues, preds, loss, accuracy), all
            tensors on CPU.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        inputs, trues, logits_per_batch, losses = [], [], [], []
        self.model.eval()
        with torch.no_grad():
            for batch_x, label, padding_mask in loader:
                batch_x = batch_x.float().to(self.device)
                label = label.long().squeeze(-1).to(self.device)
                padding_mask = padding_mask.float().to(self.device)

                with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=self.args.amp):
                    if self.args.model == 'DNN':
                        logits = self.model(batch_x, padding_mask, None, None)
                        loss = nn.functional.cross_entropy(logits, label, reduction='none')
                    else:
                        logits, model_info = self.model(batch_x, padding_mask, None, None, gating_value=self.args.gating_value)
                        loss = nn.functional.cross_entropy(logits, label, reduction='none') + model_info.loss.mean()

                losses.append(loss.flatten().float().cpu())
                inputs.append(batch_x.cpu())
                trues.append(label.cpu())
                # Up-cast so argmax is not affected by bf16 rounding under AMP.
                logits_per_batch.append(logits.float().cpu())

        if not trues:
            raise ValueError("Cannot evaluate an empty data loader.")

        y_true = torch.cat(trues, dim=0).flatten()
        # argmax of the logits equals argmax of their softmax.
        y_pred = torch.argmax(torch.cat(logits_per_batch, dim=0), dim=1)
        accuracy = accuracy_score(y_true.numpy(), y_pred.numpy())

        return ClassificationResult(
            x_data=torch.cat(inputs, dim=0),
            trues=y_true,
            preds=y_pred,
            loss=torch.cat(losses, dim=0).mean().item(),
            accuracy=accuracy,
        )
    
    def explain_IG(self, target_set='test', save_dir=None):
        """
        Generate Integrated Gradients (IG) attribution maps for one split.

        Each sample is attributed with respect to the model's own predicted
        class. The result is saved with ``np.save`` as a pickled dict with keys
        'attributions', 'labels', 'predictions' and 'metadata'.

        Args:
            target_set: 'test' uses the test loader; any other value uses the
                train loader.
            save_dir: output directory; defaults to
                ./explanations/<dataset>/<model>/<dnn_type>.

        Returns:
            np.ndarray: attributions concatenated over all batches (same shape
            as the stacked inputs).
        """
        if save_dir is None:
            save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)

        self.model.eval()
        loader = self.test_loader if target_set == 'test' else self.train_loader

        # Captum needs a callable that returns logits only; _forward_wrapper
        # drops the model_info returned by non-DNN models.
        ig = IntegratedGradients(self._forward_wrapper)

        all_attributions, all_labels, all_preds = [], [], []

        print(f"Generating IG explanations for {target_set} set...")
        for batch_x, label, padding_mask in loader:
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)
            label = label.long().squeeze(-1).to(self.device)

            # Attribution targets are the predicted classes, not the labels.
            with torch.no_grad():
                preds = torch.argmax(self._forward_wrapper(batch_x, padding_mask), dim=1)

            # n_steps=50 is the usual default; internal_batch_size=1 minimises
            # memory at the cost of speed.
            attributions = ig.attribute(
                inputs=batch_x,
                target=preds,
                additional_forward_args=(padding_mask,),
                n_steps=50,
                internal_batch_size=1,
            )

            all_attributions.append(attributions.detach().cpu().numpy())
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        attr_map = np.concatenate(all_attributions, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        preds = np.concatenate(all_preds, axis=0)

        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_IG.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': attr_map,
            'labels': labels,
            'predictions': preds,
            'metadata': {
                'model': self.args.model,
                'dnn_type': self.args.dnn_type,
                'dataset': self.args.dataset
            }
        })

        print(f"Explanation map saved at: {save_path}")
        return attr_map

    def _forward_wrapper(self, x, padding_mask):
        """
        Run the model on ``x`` and return only the class logits.

        Non-DNN models return ``(logits, model_info)`` and are called with
        ``gating_value=self.args.gating_value``; plain DNN backbones return the
        logits directly. This gives attribution libraries a logits-only callable.
        """
        if self.args.model != 'DNN':
            logits, _model_info = self.model(
                x, padding_mask, None, None, gating_value=self.args.gating_value
            )
            return logits
        return self.model(x, padding_mask, None, None)
    
    # def _PatchTST_forward_wrapper(self, x, padding_mask):
    #     """Wrapper with explicit cloning to prevent in-place modification errors."""
    #     # 1. Force a clone of the input to ensure it has its own memory space
    #     x_input = x.clone() 
        
    #     # 2. Ensure gradients are enabled for this specific pass
    #     with torch.enable_grad():
    #         if self.args.model == 'DNN':
    #             # PatchTST needs specific args
    #             return self.model(x_input, padding_mask, None, None)
    #         else:
    #             logits, _ = self.model(x_input, padding_mask, None, None, 
    #                                   gating_value=self.args.gating_value)
    #             return logits

    def explain_SegIG(self, n, target_set='test'):
        """
        Generate Segmented Integrated Gradients (SegIG) explanations.

        Standard IG attributions (target = the model's predicted class) are
        computed per batch. The time axis is then split into ``n`` contiguous,
        near-equal segments (``np.array_split``), and every time step in a
        segment receives that segment's summed attribution, per channel.

        The result is saved as a pickled dict ('attributions', 'labels',
        'predictions') under ./explanations/<dataset>/<model>/<dnn_type>/.

        Args:
            n: number of temporal segments (must be >= 1).
            target_set: 'test' uses the test loader; anything else the train loader.

        Returns:
            str: path of the saved .npy file.

        Raises:
            ValueError: if ``n`` < 1 (checked before any attribution work).
        """
        if n < 1:
            raise ValueError(f"SegIG needs at least one segment, got n={n}.")

        ig = IntegratedGradients(self._forward_wrapper)
        loader = self.test_loader if target_set == 'test' else self.train_loader
        self.model.eval()

        all_attr, all_labels, all_preds = [], [], []
        print(f"Generating SegIG_{n} for {target_set} set...")

        for batch_x, label, padding_mask in loader:
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)

            with torch.no_grad():
                preds = torch.argmax(self._forward_wrapper(batch_x, padding_mask), dim=1)

            # Standard IG: [Batch, Time, Channel]
            attr = ig.attribute(batch_x, target=preds, additional_forward_args=(padding_mask,), internal_batch_size=8)
            attr_np = attr.detach().cpu().numpy()
            _, T, _ = attr_np.shape

            segmented_attr = np.zeros_like(attr_np)
            # When n > T some segments are empty; their assignment is a no-op.
            for segment_range in np.array_split(np.arange(T), n):
                seg_sum = attr_np[:, segment_range, :].sum(axis=1, keepdims=True)
                # Broadcast the segment total to every step of the segment.
                segmented_attr[:, segment_range, :] = seg_sum

            all_attr.append(segmented_attr)
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_SegIG_{n}.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': np.concatenate(all_attr, axis=0),
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })
        print(f"SegIG_{n} map saved at: {save_path}")
        return save_path

    def evaluate_faithfulness_bottom_up(self, target_set='test', explanation_method='IG'):
        """
        Evaluate faithfulness by removing the LEAST important features.

        For each removal ratio r in {0, 0.05, ..., 1}, every sample keeps only
        the features whose |attribution| is strictly above that sample's own
        r-th percentile (r=0 keeps everything, r=1 removes everything).
        Removed features are replaced with the per-channel mean of the
        training ``feature_df``. Model accuracy on the masked inputs is
        recorded per ratio, and the trapezoidal area under accuracy vs. ratio
        is reported as 'faithfulness_auc'.

        The map is read from
        ./explanations/<dataset>/<model>/<dnn_type>/<dataset>-<seed>-<model>-<dnn_type>-<target_set>_<explanation_method>.npy
        and the results are written next to it as ``*_BottomUp_Mean_Faithfulness.csv``.

        Returns:
            pd.DataFrame with columns 'removal_ratio', 'accuracy',
            'faithfulness_auc'; None if the map file does not exist.

        Raises:
            ValueError: if the map holds fewer samples than the loader yields.
        """
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_method}.npy"
        map_path = os.path.join(save_dir, file_name)

        if not os.path.exists(map_path):
            print(f"Error: Explanation map not found at {map_path}. Run explain_IG first.")
            return

        # Per-channel training means, shape [num_features]; broadcasts over
        # [Batch, Seq_len, num_features].
        train_means = torch.from_numpy(self.train_data.feature_df.mean().values).float().to(self.device)

        data = np.load(map_path, allow_pickle=True).item()
        # Assumed [N, seq_len, num_features] and in loader order — see the
        # alignment check after the first pass.
        attributions = data['attributions']
        saved_labels = data.get('labels')  # absent in BestEnsemble maps

        self.model.eval()
        loader = self.test_loader if target_set == 'test' else self.train_loader

        thresholds = np.linspace(0, 1, 21)
        faith_results = []

        print(f"Starting Bottom-Up Evaluation (Mean-Replacement) on {target_set}...")
        print(f"Loading attributions from: {map_path}")

        for step, threshold in enumerate(thresholds):
            all_preds, all_trues = [], []
            attr_idx = 0
            percentile = threshold * 100

            with torch.no_grad():
                for batch_x, label, padding_mask in loader:
                    batch_size = batch_x.shape[0]
                    batch_x = batch_x.float().to(self.device)  # [Batch, Seq_len, Vars]
                    padding_mask = padding_mask.float().to(self.device)

                    if attr_idx + batch_size > len(attributions):
                        raise ValueError(
                            f"Explanation map {map_path} holds {len(attributions)} samples, "
                            f"but the {target_set} loader yields more; regenerate the map."
                        )
                    batch_attr_abs = np.abs(attributions[attr_idx:attr_idx + batch_size])
                    attr_idx += batch_size

                    if threshold == 0:
                        mask = torch.ones_like(batch_x)
                    elif threshold == 1:
                        mask = torch.zeros_like(batch_x)
                    else:
                        # Per-sample cut-off over all (time, channel) entries.
                        flat_attr = batch_attr_abs.reshape(batch_size, -1)
                        thresh_vals = np.percentile(flat_attr, percentile, axis=1, keepdims=True)
                        thresh_vals = thresh_vals.reshape(batch_size, 1, 1)
                        # 1 = keep (above cut-off), 0 = replace with training mean.
                        mask = torch.from_numpy(batch_attr_abs > thresh_vals).float().to(self.device)

                    masked_x = (batch_x * mask) + (train_means * (1 - mask))
                    logits = self._forward_wrapper(masked_x, padding_mask)

                    all_preds.append(torch.argmax(logits, dim=1).cpu())
                    all_trues.append(label.long().squeeze(-1).cpu())

            y_true = torch.cat(all_trues).numpy()
            if step == 0:
                # The map is matched to samples purely by position; warn when it
                # visibly disagrees with the loader (e.g. a shuffled loader).
                if attr_idx != len(attributions):
                    print(f"Warning: map has {len(attributions)} samples but the loader yielded {attr_idx}.")
                elif saved_labels is not None and not np.array_equal(np.asarray(saved_labels).reshape(-1), y_true):
                    print("Warning: labels stored in the explanation map do not match the loader order; "
                          "attributions may be misaligned with inputs.")

            acc = accuracy_score(y_true, torch.cat(all_preds).numpy())
            faith_results.append({
                'removal_ratio': threshold,
                'accuracy': acc
            })
            print(f"Removed {percentile:3.0f}% least important | Accuracy: {acc:.4f}")

        summary_df = pd.DataFrame(faith_results)
        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        auc_score = trapezoid(summary_df['accuracy'], summary_df['removal_ratio'])
        print(f"Faithfulness AUC (Accuracy vs. Removal Ratio): {auc_score:.4f}")
        summary_df['faithfulness_auc'] = auc_score

        csv_save_path = os.path.splitext(map_path)[0] + '_BottomUp_Mean_Faithfulness.csv'
        summary_df.to_csv(csv_save_path, index=False)
        print(f"Evaluation complete. Results saved at: {csv_save_path}")

        return summary_df
    
    def explanation_with_expert_name(self, explanation_name, target_set='test', **kwargs):
        """
        Unified interface for the supported Captum attribution methods and custom SegIG.

        Supported Captum names: 'IG', 'LIME', 'KernelSHAP', 'Saliency',
                                'InputXGradient', 'Occlusion'
        Supported custom: 'SegIG_n' (e.g. 'SegIG_10'); a bare 'SegIG' takes
        ``n`` from kwargs (default 10).

        Optional kwargs:
            n_steps, internal_batch_size: forwarded to IG (defaults 50 and 8).
            window: number of time steps per Occlusion patch (default 5).
            detect_anomaly: enable autograd anomaly detection while this call
                runs (default True). The previous global setting is restored
                on exit instead of being left switched on.

        Returns:
            str: path of the saved .npy file (dict with 'attributions',
            'labels', 'predictions').

        Raises:
            ValueError: if ``explanation_name`` is not supported.
        """
        # set_detect_anomaly doubles as a context manager that restores the
        # previous mode on exit.
        with torch.autograd.set_detect_anomaly(kwargs.get('detect_anomaly', True)):
            return self._explain_with_named_method(explanation_name, target_set, **kwargs)

    def _explain_with_named_method(self, explanation_name, target_set='test', **kwargs):
        """Dispatch to SegIG or run the named Captum method; see explanation_with_expert_name."""
        if explanation_name.startswith('SegIG'):
            n = int(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('n', 10)
            return self.explain_SegIG(n=n, target_set=target_set)

        # All of these accept a plain forward callable (no nn.Module required).
        captum_mapping = {
            'IG': attr_lib.IntegratedGradients,
            'LIME': attr_lib.Lime,
            'KernelSHAP': attr_lib.KernelShap,
            'Saliency': attr_lib.Saliency,
            'InputXGradient': attr_lib.InputXGradient,
            'Occlusion': attr_lib.Occlusion,
        }

        if explanation_name not in captum_mapping:
            raise ValueError(f"Method '{explanation_name}' not supported. Choose from {list(captum_mapping.keys())} or 'SegIG_n'.")

        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)
        self.model.eval()
        if target_set == 'test':
            print("Using test loader for explanations.")
            loader = self.test_loader
        else:
            print("Using train loader for explanations.")
            loader = self.train_loader

        explainer = captum_mapping[explanation_name](self._forward_wrapper)

        all_attributions, all_labels, all_preds = [], [], []
        print(f"Generating {explanation_name} explanations for {target_set}...")

        for i, (batch_x, label, padding_mask) in enumerate(loader):
            if i % 100 == 0:
                print(f"  Processing batch {i}...")
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)

            # Attribution targets are the predicted classes, not the labels.
            with torch.no_grad():
                preds = torch.argmax(self._forward_wrapper(batch_x, padding_mask), dim=1)

            attr_kwargs = {
                'inputs': batch_x,
                'target': preds,
                'additional_forward_args': (padding_mask,),
            }

            if explanation_name == 'Occlusion':
                # sliding_window_shapes excludes the batch dim: for [B, T, C]
                # inputs the patch covers `window` time steps of one channel.
                attr_kwargs['sliding_window_shapes'] = (kwargs.get('window', 5), 1)
            elif explanation_name == 'IG':
                attr_kwargs['n_steps'] = kwargs.get('n_steps', 50)
                attr_kwargs['internal_batch_size'] = kwargs.get('internal_batch_size', 8)

            attributions = explainer.attribute(**attr_kwargs)

            all_attributions.append(attributions.detach().cpu().numpy())
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_name}.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': np.concatenate(all_attributions, axis=0),
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })

        print(f"Explanation map ({explanation_name}) saved at: {save_path}")
        return save_path

    def calculate_instance_faithfulness(self, x, label, padding_mask, attribution, verbose=False):
        """
        Compute the bottom-up faithfulness AUC for a SINGLE instance.

        At removal ratios 0, 0.05, ..., 1 the features with the lowest
        |attribution| are replaced by the per-channel training mean, and the
        model's softmax probability of ``label`` on the masked input is
        recorded. Ratio 0 keeps every feature and ratio 1 removes every
        feature, matching ``evaluate_faithfulness_bottom_up``.

        Args:
            x: input tensor [1, T, C] on ``self.device``.
            label: ground-truth class as a one-element tensor.
            padding_mask: padding mask for ``x``, passed through to the model.
            attribution: numpy array [1, T, C] aligned with ``x``.
            verbose: print debug progress at 0/30/60/90% removal.

        Returns:
            float: trapezoidal area under confidence vs. removal ratio.
        """
        self.model.eval()
        train_means = self._cached_train_means()
        x = x.float()
        target = int(label.item())

        thresholds = np.linspace(0, 1, 21)
        conf_scores = []

        if verbose:
            print(f"DEBUG: Explaining instance with true label {label.item()}")

        attr_abs = np.abs(attribution)
        # Percentiles are taken across all (time, channel) entries of the sample.
        flat_attr = attr_abs.reshape(-1)

        with torch.no_grad():
            for step, threshold in enumerate(thresholds):
                if threshold == 0:
                    # Keep everything; a percentile-0 cut would drop the minimum
                    # entries (or all of them when attributions are constant).
                    mask = torch.ones_like(x)
                else:
                    percentile_val = np.percentile(flat_attr, threshold * 100)
                    # 1 = keep (high importance), 0 = replace with training mean.
                    mask = torch.from_numpy(attr_abs > percentile_val).float().to(self.device)

                masked_x = (x * mask) + (train_means * (1 - mask))
                probs = torch.softmax(self._forward_wrapper(masked_x, padding_mask), dim=1)

                # Confidence in the GROUND TRUTH label
                conf = probs[0, target].item()
                conf_scores.append(conf)

                # Integer step test: every 6th of 21 steps = 0/30/60/90% (float
                # modulo on linspace values is unreliable).
                if verbose and step % 6 == 0:
                    print(f"  > DEBUG: Removed {threshold*100:3.0f}% | Confidence: {conf:.4f}")

        # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        auc_score = trapezoid(conf_scores, thresholds)

        if verbose:
            print(f"DEBUG: Instance Faithfulness AUC: {auc_score:.4f}")

        return auc_score

    def _cached_train_means(self):
        """Per-channel training means as a [1, 1, C] tensor, computed once per instance."""
        means = getattr(self, '_train_means_cache', None)
        if means is None:
            means = torch.from_numpy(self.train_data.feature_df.mean().values).float().to(self.device).view(1, 1, -1)
            self._train_means_cache = means
        return means

    def select_best_instance_explanation(self, methods, target_set='train'):
        """
        Build a per-instance "best expert" attribution map.

        For every sample of ``target_set`` (default 'train'), each available
        method's saved attribution is scored with
        ``calculate_instance_faithfulness`` and the highest-AUC one is kept.
        Ties go to the method listed first in ``methods``; NaN scores rank
        lowest. The combined map is saved as ``..._BestEnsemble.npy``.

        Args:
            methods: iterable of explanation names whose maps were saved as
                ``<dataset>-<seed>-<model>-<dnn_type>-<target_set>_<name>.npy``.
            target_set: 'train' uses the train loader; anything else the test loader.

        Returns:
            str: path of the saved BestEnsemble file.

        Raises:
            ValueError: if none of ``methods`` has a saved map.
        """
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)

        loader = self.train_loader if target_set == 'train' else self.test_loader
        num_instances = len(loader.dataset)

        # Load every available expert map for this split.
        expert_data = {}
        for m in methods:
            path = os.path.join(save_dir, f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{m}.npy")
            if os.path.exists(path):
                expert_data[m] = np.load(path, allow_pickle=True).item()['attributions']
            else:
                print(f"DEBUG WARNING: File not found for {m} on {target_set} set at {path}")

        if not expert_data:
            raise ValueError(f"No explanation maps found for methods {list(methods)} on the {target_set} set in {save_dir}.")

        best_attr_map = []
        winner_names = []
        best_aucs = []

        print(f"\n--- DEBUG: STARTING ENSEMBLE SELECTION ON {target_set.upper()} SET ---")
        print(f"DEBUG: Evaluating {len(expert_data)} experts across {num_instances} samples.")

        # NOTE(review): maps are matched to samples purely by position
        # (global_idx), so the loader must yield samples in the same order as
        # when the maps were generated — confirm the train loader is not shuffled.
        global_idx = 0
        for batch_x, label, padding_mask in loader:
            for b in range(batch_x.shape[0]):
                # .float() matches the dtype every other code path feeds the model.
                x_i = batch_x[b:b + 1].float().to(self.device)
                label_i = label[b].to(self.device)
                mask_i = padding_mask[b:b + 1].float().to(self.device)

                scored = []
                for m_name, attr_full in expert_data.items():
                    attr_i = attr_full[global_idx:global_idx + 1]
                    auc = self.calculate_instance_faithfulness(x_i, label_i, mask_i, attr_i, verbose=False)
                    scored.append((auc, m_name, attr_i))

                # max() keeps the first of equal scores; NaN is treated as -inf
                # so an all-NaN instance still gets a (first-listed) winner.
                best_instance_auc, winning_expert, best_instance_attr = max(
                    scored, key=lambda item: -np.inf if np.isnan(item[0]) else item[0]
                )

                best_attr_map.append(best_instance_attr)
                winner_names.append(winning_expert)
                best_aucs.append(best_instance_auc)

                if global_idx % 20 == 0 or global_idx == num_instances - 1:
                    print(f"  > DEBUG: Instance {global_idx}/{num_instances-1} | Winner: {winning_expert:15} | AUC: {best_instance_auc:.4f}")

                global_idx += 1

        final_map_array = np.concatenate(best_attr_map, axis=0)
        save_path = os.path.join(save_dir, f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_BestEnsemble.npy")

        np.save(save_path, {
            'attributions': final_map_array,
            'winning_methods': winner_names,
            'auc_scores': best_aucs,
            'metadata': {
                'target_set': target_set,
                'methods_evaluated': list(expert_data.keys()),
                'average_auc': np.mean(best_aucs)
            }
        })

        print(f"--- DEBUG: SUCCESS! BestEnsemble for {target_set} saved to {save_path} ---")
        print(f"DEBUG: Final Mean Faithfulness AUC: {np.mean(best_aucs):.4f}")
        return save_path
    
    def evaluate_confidence_bottom_up(self, target_set='test', explanation_method='IG'):
        """
        Evaluates faithfulness by calculating the average CONFIDENCE score
        as the least important features are removed.

        For each removal ratio in ``np.linspace(0, 1, 21)``, every input value
        whose absolute attribution does not exceed the per-sample percentile is
        replaced with the channel-wise training mean, and the softmax
        probability assigned to the ground-truth label is averaged over the
        chosen split. Ratio 0 keeps the input untouched and ratio 1 replaces
        it entirely. The area under the resulting confidence curve is added
        as an extra column.

        Args:
            target_set: 'test' evaluates ``self.test_loader``; any other value
                evaluates ``self.train_loader``. Also used in the map file name.
            explanation_method: suffix of the saved attribution map to load
                (e.g. 'IG').

        Returns:
            A DataFrame with columns ``ratio``, ``avg_confidence`` and
            ``auc_confidence`` (the curve AUC repeated on every row), which is
            also written next to the map as
            ``*_BottomUp_Confidence_Faithfulness.csv``; or ``None`` if the
            attribution map file does not exist.
        """
        # 1. Automatic Pathfinding
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_method}.npy"
        map_path = os.path.join(save_dir, file_name)

        if not os.path.exists(map_path):
            print(f"Error: map not found at {map_path}")
            return

        # 2. Baseline & Data Setup
        # Channel-wise training means, viewed as (1, 1, C) so they broadcast over the batch.
        train_means = torch.from_numpy(self.train_data.feature_df.mean().values).float().to(self.device).view(1, 1, -1)
        # The .npy file stores a pickled dict, hence allow_pickle=True; only load maps you produced yourself.
        data = np.load(map_path, allow_pickle=True).item()
        attributions = np.abs(data['attributions'])

        # Disable dropout / batch-norm updates, matching the AUC-ROC variant of this evaluation.
        self.model.eval()
        # NOTE(review): attributions are matched to samples purely by loader order, which
        # assumes the loader does not shuffle -- verify this for train_loader.
        loader = self.test_loader if target_set == 'test' else self.train_loader
        thresholds = np.linspace(0, 1, 21)
        confidence_results = []

        print(f"--- DEBUG: Evaluating {explanation_method} Confidence Decay (Bottom-Up) ---")

        for threshold in thresholds:
            all_confs = []
            attr_idx = 0
            percentile = threshold * 100

            with torch.no_grad():
                for batch_x, label, padding_mask in loader:
                    batch_size = batch_x.shape[0]
                    batch_x = batch_x.float().to(self.device)
                    padding_mask = padding_mask.float().to(self.device)
                    # reshape(-1) rather than squeeze(-1): a 1-sample batch must stay 1-D
                    # so that unsqueeze(1) below still yields a (B, 1) index tensor.
                    label = label.long().reshape(-1).to(self.device)

                    batch_attr = attributions[attr_idx : attr_idx + batch_size]
                    attr_idx += batch_size

                    # Create Mask: 1 keeps the original value, 0 swaps in the training mean.
                    if threshold == 0:
                        # Remove nothing. A strict '>' against the 0th percentile would
                        # still drop every value tied with the per-sample minimum.
                        mask = torch.ones_like(batch_x)
                    elif threshold == 1:
                        mask = torch.zeros_like(batch_x)
                    else:
                        flat_attr = batch_attr.reshape(batch_size, -1)
                        thresh_vals = np.percentile(flat_attr, percentile, axis=1, keepdims=True).reshape(batch_size, 1, 1)
                        mask = torch.tensor(batch_attr > thresh_vals).float().to(self.device)

                    # Replace with training mean
                    masked_x = (batch_x * mask) + (train_means * (1 - mask))
                    logits = self._forward_wrapper(masked_x, padding_mask)

                    # Calculate Probability of the Ground Truth Label, shape (B, 1)
                    probs = torch.softmax(logits, dim=1)
                    true_class_probs = probs.gather(1, label.unsqueeze(1))
                    all_confs.append(true_class_probs.cpu().numpy())
            avg_conf = np.mean(np.concatenate(all_confs))
            confidence_results.append({'ratio': threshold, 'avg_confidence': avg_conf})
            print(f"  > Removed {percentile:3.0f}% | Avg Confidence: {avg_conf:.4f}")

        # 3. Calculate AUC of Confidence Curve
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; support both.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        summary_df = pd.DataFrame(confidence_results)
        auc_conf = trapezoid(summary_df['avg_confidence'], summary_df['ratio'])
        summary_df['auc_confidence'] = auc_conf

        # 4. Save
        csv_path = map_path.replace('.npy', '_BottomUp_Confidence_Faithfulness.csv')
        summary_df.to_csv(csv_path, index=False)
        print(f"DEBUG: Confidence AUC: {auc_conf:.4f} | Saved to: {csv_path}")
        return summary_df



    def evaluate_faithfulness_bottom_up_aucroc(self, target_set='test', explanation_method='IG'):
        """
        Evaluates faithfulness by removing the LEAST important features.
        Replaces removed features with training mean values.
        Calculates AUC-ROC at each threshold (0% to 100% removal).

        For each removal ratio in ``np.linspace(0, 1, 21)``, every input value
        whose absolute attribution does not exceed the per-sample percentile is
        replaced with the channel-wise training mean. Ratio 0 keeps the input
        untouched and ratio 1 replaces it entirely. The model's softmax scores
        over the whole split are then scored with ``roc_auc_score``. It uses the
        positive-class column when ``self.args.num_class == 2``, and macro
        one-vs-rest scoring otherwise.

        Args:
            target_set: 'test' evaluates ``self.test_loader``; any other value
                evaluates ``self.train_loader``. Also used in the map file name.
            explanation_method: suffix of the saved attribution map to load
                (e.g. 'IG').

        Returns:
            A DataFrame with columns ``removal_ratio``, ``auc_roc`` (NaN where
            AUC-ROC is undefined) and ``faithfulness_auc_roc`` (area under the
            non-NaN points, repeated on every row). It is also written next to
            the map as ``*_BottomUp_Mean_AUCROC_Faithfulness.csv``. Returns
            ``None`` if the attribution map file does not exist.
        """
        # 1. Path and Data Setup
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_method}.npy"
        map_path = os.path.join(save_dir, file_name)

        if not os.path.exists(map_path):
            print(f"Error: Explanation map not found at {map_path}.")
            return

        # 2. Baseline: Channel-wise training means (shape (C,), broadcasts over the last axis)
        train_means_np = self.train_data.feature_df.mean().values
        train_means = torch.from_numpy(train_means_np).float().to(self.device)

        # 3. Load Explanation Data
        # The .npy file stores a pickled dict, hence allow_pickle=True; only load maps you produced yourself.
        data = np.load(map_path, allow_pickle=True).item()
        attributions = data['attributions']

        self.model.eval()
        # NOTE(review): attributions are matched to samples purely by loader order, which
        # assumes the loader does not shuffle -- verify this for train_loader.
        loader = self.test_loader if target_set == 'test' else self.train_loader

        # 4. Define Thresholds
        thresholds = np.linspace(0, 1, 21)
        faith_results = []

        print(f"--- Starting Bottom-Up Evaluation (AUC-ROC) on {target_set} ---")

        for threshold in thresholds:
            all_probs = []
            all_trues = []
            attr_idx = 0
            percentile = threshold * 100

            with torch.no_grad():
                for batch_x, label, padding_mask in loader:
                    batch_size = batch_x.shape[0]
                    batch_x = batch_x.float().to(self.device)
                    padding_mask = padding_mask.float().to(self.device)

                    batch_attr_abs = np.abs(attributions[attr_idx : attr_idx + batch_size])
                    attr_idx += batch_size

                    # Mask: 1 keeps the original value, 0 swaps in the training mean.
                    if threshold == 0:
                        # Remove nothing (a strict '>' would still drop values tied with the minimum).
                        mask = torch.ones_like(batch_x)
                    elif threshold == 1:
                        mask = torch.zeros_like(batch_x)
                    else:
                        flat_attr = batch_attr_abs.reshape(batch_size, -1)
                        thresh_vals = np.percentile(flat_attr, percentile, axis=1, keepdims=True)
                        thresh_vals = thresh_vals.reshape(batch_size, 1, 1)
                        mask = torch.tensor(batch_attr_abs > thresh_vals).float().to(self.device)

                    # Mean Replacement
                    masked_x = (batch_x * mask) + (train_means * (1 - mask))

                    logits = self._forward_wrapper(masked_x, padding_mask)
                    probs = torch.softmax(logits, dim=1)

                    all_probs.append(probs.cpu())
                    # reshape(-1) rather than squeeze(-1): a 1-sample batch must stay 1-D,
                    # otherwise torch.cat below rejects the zero-dimensional tensor.
                    all_trues.append(label.long().reshape(-1).cpu())

            # 5. Calculate AUC-ROC for this threshold (over the whole split)
            y_true = torch.cat(all_trues).numpy()
            y_scores = torch.cat(all_probs).numpy()

            try:
                # Handle Multi-class vs Binary
                if self.args.num_class == 2:
                    # For binary, we use the probability of the positive class (column 1)
                    roc_auc = roc_auc_score(y_true, y_scores[:, 1])
                else:
                    # For multi-class, use One-vs-Rest (OvR)
                    roc_auc = roc_auc_score(y_true, y_scores, multi_class='ovr', average='macro')
            except ValueError:
                # roc_auc_score raises ValueError when AUC-ROC is undefined for this split
                # (e.g. some class is absent from y_true); record NaN and drop it from the curve.
                roc_auc = np.nan

            faith_results.append({
                'removal_ratio': threshold,
                'auc_roc': roc_auc
            })
            print(f"Removed {percentile:3.0f}% least important | AUC-ROC: {roc_auc:.4f}")

        # 6. Save results
        summary_df = pd.DataFrame(faith_results)
        csv_save_path = map_path.replace('.npy', '_BottomUp_Mean_AUCROC_Faithfulness.csv')

        # Calculate the AUC of the Faithfulness Curve (Area under the AUC-ROC points).
        # NaN points are dropped so the trapezoid rule only integrates defined values.
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; support both.
        trapezoid = getattr(np, 'trapezoid', None) or np.trapz
        clean_df = summary_df.dropna()
        faithfulness_auc_val = trapezoid(clean_df['auc_roc'], clean_df['removal_ratio'])

        summary_df['faithfulness_auc_roc'] = faithfulness_auc_val
        summary_df.to_csv(csv_save_path, index=False)

        print(f"Evaluation complete. Saved at: {csv_save_path}")
        print(f"Overall Faithfulness (Area under AUC-ROC curve): {faithfulness_auc_val:.4f}")

        return summary_df
