import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import os
import sys
import time
from datetime import datetime
import numpy as np
import pandas as pd
from dataclasses import dataclass, field
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score


from models.Shapelet import ShapeBottleneckModel, DistThresholdSBM
from models.InterpGN import InterpGN, FullyConvNetwork, Transformer, TimesNet, PatchTST, ResNet

from utils.tools import EarlyStopping, convert_to_hms, gini_coefficient
from utils.shapelet_util import ClassificationResult
from data_provider.data_factory import data_provider, noisy_data_provider

# Explanation libraries  
from captum.attr import IntegratedGradients
import captum.attr as attr_lib
from einops import rearrange

import copy

def compute_beta(epoch, max_epoch, schedule='cosine'):
    """Return the annealing coefficient for ``epoch`` under ``schedule``.

    ``'cosine'`` follows half a cosine period, going from 1 at epoch 0 to 0 at
    ``epoch == max_epoch``; ``'linear'`` decays from 1 to 0 along a straight
    line; any other schedule name yields the constant 1.
    """
    if schedule == 'linear':
        return 1 - epoch/max_epoch
    if schedule == 'cosine':
        phase = np.pi*epoch/max_epoch
        return 1/2 * (1 + np.cos(phase))
    # Unrecognised schedule names keep the coefficient fixed.
    return 1


def compute_shapelet_score(shapelet_distances, cls_weights, y_pred, y_true):
    """Average true-class score over the correctly classified samples.

    The per-class score is ``shapelet_distances @ relu(cls_weights.T)`` divided
    by the size of the last axis of ``shapelet_distances``. Only rows where
    ``y_pred == y_true`` contribute, each through the column of its true class.
    Returns a Python float (NaN when no sample is classified correctly, since
    the mean is then taken over an empty tensor).
    """
    hit = y_pred == y_true
    positive_weights = nn.functional.relu(cls_weights.T)
    num_columns = shapelet_distances.shape[-1]
    score = shapelet_distances @ positive_weights / num_columns
    true_class = y_true[hit].unsqueeze(1)
    picked = score[hit].gather(-1, true_class)
    return picked.mean().item()


def get_dnn_model(configs):
    """Instantiate the deep network named by ``configs.dnn_type``.

    The network is built from a deep copy of ``configs`` whose ``task_name`` is
    set to ``'classification'``, so the caller's object is not modified.
    A ``dnn_type`` outside the supported names raises ``KeyError``.
    """
    backbone_configs = copy.deepcopy(configs)
    backbone_configs.task_name = 'classification'
    backbones = {
        'FCN': FullyConvNetwork,
        'Transformer': Transformer,
        'TimesNet': TimesNet,
        'PatchTST': PatchTST,
        'ResNet': ResNet,
    }
    backbone_cls = backbones[configs.dnn_type]
    return backbone_cls(backbone_configs)


class Explanation_Experiment(object):
    """Experiment harness that restores a trained checkpoint and generates or
    evaluates attribution-based explanations for a classification model."""
    # Maps the ``args.model`` string to the callable used by ``_build_model``.
    # 'DNN' resolves to the ``get_dnn_model`` factory function rather than a class.
    model_dict = {
        'InterpGN': InterpGN,
        'SBM': ShapeBottleneckModel,
        'LTS': DistThresholdSBM,
        'DNN': get_dnn_model
    }
    def __init__(self, args):
        """Load the datasets, build the model and restore its saved checkpoint.

        ``args`` is mutated in place: ``seq_len``, ``pred_len``, ``enc_in`` and
        ``num_class`` are filled in from the loaded training/test data before
        the model is built. The checkpoint path is assembled from the
        experiment hyper-parameters stored on ``args``.

        Args:
            args: Experiment configuration namespace.

        Raises:
            Whatever ``torch.load`` / ``load_state_dict`` raise when the
            checkpoint is missing or does not match the model.
        """
        self.train_data, self.train_loader = data_provider(args, flag="TRAIN")
        self.test_data, self.test_loader = data_provider(args, flag="TEST")
        # NOTE(review): the validation loaders are built with the TEST flag,
        # i.e. they mirror the test split — confirm this is intended.
        self.val_data, self.val_loader = data_provider(args, flag='TEST')
        self.noisy_train_data, self.noisy_train_loader = noisy_data_provider(args, flag="TRAIN")
        self.noisy_test_data, self.noisy_test_loader = noisy_data_provider(args, flag="TEST")
        self.noisy_val_data, self.noisy_val_loader = noisy_data_provider(args, flag='TEST')

        # Derive the model input/output sizes from the data itself.
        args.seq_len = max(self.train_data.max_seq_len, self.test_data.max_seq_len)
        args.pred_len = 0
        args.enc_in = self.train_data.feature_df.shape[1]
        args.num_class = len(self.train_data.class_names)
        self.epoch_stop = 0

        # Build Model
        self.args = args
        # Fall back to CPU so the experiment still runs on hosts without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.loss_fn = nn.CrossEntropyLoss()
        self.model = self._build_model().to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=self.args.train_epochs)
        self.checkpoint_dir = "./checkpoints/{}/{}/dnn-{}_seed-{}_k-{}_div-{}_reg-{}_eps-{}_beta-{}_dfunc-{}_cls-{}".format(
            self.args.model,
            self.args.dataset,
            self.args.dnn_type,
            self.args.seed,
            self.args.num_shapelet,
            self.args.lambda_div,
            self.args.lambda_reg,
            self.args.epsilon,
            self.args.beta_schedule,
            self.args.distance_func,
            self.args.sbm_cls
        )
        checkpoint_path = f"{self.checkpoint_dir}/checkpoint.pth"
        print(f"Model loaded from {self.checkpoint_dir}/checkpoint.pth")
        print("Loading model checkpoint...")
        # map_location remaps the stored tensors onto the active device, so a
        # checkpoint saved on a GPU can still be restored on a CPU-only host.
        state_dict = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        # Wrap only after loading so the state-dict keys match the bare model.
        if self.args.multi_gpu:
            self.model = nn.DataParallel(self.model)

    def print_args(self):
        """Print every attribute stored on ``self.args`` as ``name: value``, one per line."""
        namespace = self.args
        for attr_name in vars(namespace):
            attr_value = getattr(namespace, attr_name)
            print(f"{attr_name}: {attr_value}")

    def _build_model(self):
        """Instantiate the model selected by ``self.args.model``.

        The name is looked up in ``model_dict`` and the resulting class or
        factory is called with ``configs=self.args``. An unknown model name
        raises ``KeyError``.
        """
        model_factory = self.model_dict[self.args.model]
        return model_factory(configs=self.args)

    def test(self, save_csv=True, result_dir=None):
        """Evaluate the model on the test loader and summarise its predictions.

        Args:
            save_csv: Accepted for interface compatibility; not used here.
            result_dir: Optional directory to create on a best-effort basis.
                ``None`` (the default) skips directory creation.

        Returns:
            ClassificationResult holding the concatenated inputs, true labels,
            predicted class indices, mean loss and accuracy.
        """
        if result_dir is not None:
            try:
                os.makedirs(result_dir, exist_ok=True)
            except OSError as err:
                # Best effort: nothing in this method writes into the directory.
                print(f"Warning: could not create result directory {result_dir}: {err}")

        inputs, labels, batch_logits, losses = [], [], [], []
        is_dnn = self.args.model == 'DNN'

        self.model.eval()
        with torch.no_grad():
            for batch_x, label, padding_mask in self.test_loader:
                batch_x = batch_x.float().to(self.device)
                label = label.long().squeeze(-1).to(self.device)
                padding_mask = padding_mask.float().to(self.device)

                with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=self.args.amp):
                    if is_dnn:
                        logits = self.model(batch_x, padding_mask, None, None)
                        loss = nn.functional.cross_entropy(logits, label, reduction='none')
                    else:
                        # Non-DNN models also return an info object whose
                        # ``loss`` term is added to the classification loss.
                        logits, model_info = self.model(batch_x, padding_mask, None, None, gating_value=self.args.gating_value)
                        loss = nn.functional.cross_entropy(logits, label, reduction='none') + model_info.loss.mean()

                losses.append(loss.flatten())
                inputs.append(batch_x.cpu())
                labels.append(label.cpu())
                batch_logits.append(logits.cpu())

        probs = torch.nn.functional.softmax(torch.cat(batch_logits, dim=0), dim=1)  # (total_samples, num_classes)
        predictions = torch.argmax(probs, dim=1)  # (total_samples,) predicted class index
        trues = torch.cat(labels, dim=0).flatten()
        # sklearn's signature is (y_true, y_pred); accuracy itself is symmetric.
        accuracy = accuracy_score(trues.numpy(), predictions.numpy())

        return ClassificationResult(
            x_data=torch.cat(inputs, dim=0),
            trues=trues,
            preds=predictions,
            loss=torch.cat(losses, dim=0).mean().item(),
            accuracy=accuracy
        )

    def get_train_utility(self, save_csv=True, result_dir=None):
        """Evaluate the model on the training loader and summarise its predictions.

        Args:
            save_csv: Accepted for interface compatibility; not used here.
            result_dir: Optional directory to create on a best-effort basis.
                ``None`` (the default) skips directory creation.

        Returns:
            ClassificationResult holding the concatenated inputs, true labels,
            predicted class indices, mean loss and accuracy.
        """
        if result_dir is not None:
            try:
                os.makedirs(result_dir, exist_ok=True)
            except OSError as err:
                # Best effort: nothing in this method writes into the directory.
                print(f"Warning: could not create result directory {result_dir}: {err}")

        inputs, labels, batch_logits, losses = [], [], [], []
        is_dnn = self.args.model == 'DNN'

        self.model.eval()
        with torch.no_grad():
            for batch_x, label, padding_mask in self.train_loader:
                batch_x = batch_x.float().to(self.device)
                label = label.long().squeeze(-1).to(self.device)
                padding_mask = padding_mask.float().to(self.device)

                with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=self.args.amp):
                    if is_dnn:
                        logits = self.model(batch_x, padding_mask, None, None)
                        loss = nn.functional.cross_entropy(logits, label, reduction='none')
                    else:
                        # Non-DNN models also return an info object whose
                        # ``loss`` term is added to the classification loss.
                        logits, model_info = self.model(batch_x, padding_mask, None, None, gating_value=self.args.gating_value)
                        loss = nn.functional.cross_entropy(logits, label, reduction='none') + model_info.loss.mean()

                losses.append(loss.flatten())
                inputs.append(batch_x.cpu())
                labels.append(label.cpu())
                batch_logits.append(logits.cpu())

        probs = torch.nn.functional.softmax(torch.cat(batch_logits, dim=0), dim=1)  # (total_samples, num_classes)
        predictions = torch.argmax(probs, dim=1)  # (total_samples,) predicted class index
        trues = torch.cat(labels, dim=0).flatten()
        # sklearn's signature is (y_true, y_pred); accuracy itself is symmetric.
        accuracy = accuracy_score(trues.numpy(), predictions.numpy())

        return ClassificationResult(
            x_data=torch.cat(inputs, dim=0),
            trues=trues,
            preds=predictions,
            loss=torch.cat(losses, dim=0).mean().item(),
            accuracy=accuracy
        )
    
    def explain_IG(self, target_set='test', save_dir=None):
        """
        Generates Integrated Gradients attribution maps for the specified dataset.
        Saves a dictionary containing attributions, labels, and predictions as a .npy file.

        Args:
            target_set: 'test' selects the test loader; any other value selects
                the training loader.
            save_dir: Output directory. Defaults to
                ``./explanations/<dataset>/<model>/<dnn_type>``.

        Returns:
            The concatenated attribution array for all samples of the loader.
        """
        # Define saving path
        if save_dir is None:
            save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)

        self.model.eval()
        loader = self.test_loader if target_set == 'test' else self.train_loader

        # Captum needs a forward function that returns logits only; the shared
        # wrapper already hides the (logits, model_info) return of non-DNN models.
        ig = IntegratedGradients(self._forward_wrapper)

        all_attributions = []
        all_labels = []
        all_preds = []

        print(f"Generating IG explanations for {target_set} set...")
        for batch_x, label, padding_mask in loader:
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)
            label = label.long().squeeze(-1).to(self.device)

            # Obtain predicted classes to use as attribution targets
            with torch.no_grad():
                logits = self._forward_wrapper(batch_x, padding_mask)
                preds = torch.argmax(logits, dim=1)

            # Calculate IG attributions
            # n_steps=50 is standard; higher is more accurate but slower
            attributions = ig.attribute(
                inputs=batch_x,
                target=preds,
                additional_forward_args=(padding_mask,),
                n_steps=50,
                internal_batch_size=1
            )

            all_attributions.append(attributions.detach().cpu().numpy())
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        # Concatenate results across all batches
        attr_map = np.concatenate(all_attributions, axis=0)
        labels = np.concatenate(all_labels, axis=0)
        preds = np.concatenate(all_preds, axis=0)

        # Define file name using metadata for consistency
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_IG.npy"
        save_path = os.path.join(save_dir, file_name)

        # Save results (a dict, so np.save stores it as a pickled object array)
        np.save(save_path, {
            'attributions': attr_map,
            'labels': labels,
            'predictions': preds,
            'metadata': {
                'model': self.args.model,
                'dnn_type': self.args.dnn_type,
                'dataset': self.args.dataset
            }
        })

        print(f"Explanation map saved at: {save_path}")
        return attr_map

    def _forward_wrapper(self, x, padding_mask):
        """Run the model and return logits only, whatever the model type.

        'DNN' models return the logits directly; every other model returns a
        two-element result of which only the first element (the logits) is kept.
        """
        if self.args.model != 'DNN':
            outputs = self.model(x, padding_mask, None, None, gating_value=self.args.gating_value)
            logits, _ = outputs
            return logits
        return self.model(x, padding_mask, None, None)
    
    # def _PatchTST_forward_wrapper(self, x, padding_mask):
    #     """Wrapper with explicit cloning to prevent in-place modification errors."""
    #     # 1. Force a clone of the input to ensure it has its own memory space
    #     x_input = x.clone() 
        
    #     # 2. Ensure gradients are enabled for this specific pass
    #     with torch.enable_grad():
    #         if self.args.model == 'DNN':
    #             # PatchTST needs specific args
    #             return self.model(x_input, padding_mask, None, None)
    #         else:
    #             logits, _ = self.model(x_input, padding_mask, None, None, 
    #                                   gating_value=self.args.gating_value)
    #             return logits

    def explain_SegIG(self, n, target_set='test', noise=None):
        """Generates Segmented Integrated Gradients (SegIG) explanations.

        Integrated Gradients is computed per batch, the time axis is split into
        ``n`` nearly equal segments, and every time step of a segment is
        assigned the per-channel sum of the attributions inside that segment.

        Args:
            n: Number of segments the time axis is divided into.
            target_set: 'test' selects the test loader; any other value selects
                the training loader.
            noise: When truthy, the noisy loaders are used and the output file
                name gets a ``_noisy`` suffix.

        Returns:
            Path of the saved ``.npy`` file.
        """
        ig = IntegratedGradients(self._forward_wrapper)
        if noise:
            loader = self.noisy_test_loader if target_set == 'test' else self.noisy_train_loader
        else:
            loader = self.test_loader if target_set == 'test' else self.train_loader

        self.model.eval()

        all_attr, all_labels, all_preds = [], [], []
        print(f"Generating SegIG_{n} for {target_set} set...")

        for batch_x, label, padding_mask in loader:
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)
            # Same label normalisation as explain_IG / explain_keystoneIG so
            # every saved explanation file stores labels in the same form.
            label = label.long().squeeze(-1)

            with torch.no_grad():
                logits = self._forward_wrapper(batch_x, padding_mask)
                preds = torch.argmax(logits, dim=1)

            # Standard IG: [Batch, Time, Channel]
            attr = ig.attribute(batch_x, target=preds, additional_forward_args=(padding_mask,), internal_batch_size=8)
            attr_np = attr.detach().cpu().numpy()
            B, T, C = attr_np.shape

            # Divide Time (T) into n segments
            segmented_attr = np.zeros_like(attr_np)
            for segment_range in np.array_split(np.arange(T), n):
                # Sum IG scores within the segment for each channel
                seg_sum = attr_np[:, segment_range, :].sum(axis=1, keepdims=True)
                # Broadcast the sum to the whole segment
                segmented_attr[:, segment_range, :] = seg_sum

            all_attr.append(segmented_attr)
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        # Save logic
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)
        suffix = "_noisy" if noise else ""
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_SegIG_{n}{suffix}.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': np.concatenate(all_attr, axis=0),
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })
        print(f"SegIG_{n} map saved at: {save_path}")
        return save_path

    def explain_keystoneIG(self, keystone_ratio, target_set='test', noise=None):
        """
        Generates Keystone Integrated Gradients explanations.
        1. Search pass to identify top salient features (Keystones).
        2. Refined pass (attr1) from the training-mean baseline to the keystone-only input.
        3. Refined pass (attr2) from the keystone-only input to the full input,
           explaining the remaining background features.

        Args:
            keystone_ratio: Fraction of the T*C features per sample treated as
                keystones (the ones with the largest absolute attribution).
            target_set: 'test' selects the test loader; any other value selects
                the training loader.
            noise: When truthy, the noisy loaders are used and the output file
                name gets a ``_noisy`` suffix.

        Returns:
            Path of the saved ``.npy`` file.
        """
        ig = IntegratedGradients(self._forward_wrapper)

        # Select appropriate loader based on noise flag
        if noise:
            loader = self.noisy_test_loader if target_set == 'test' else self.noisy_train_loader
        else:
            loader = self.test_loader if target_set == 'test' else self.train_loader

        # Per-column means of the training features, reshaped to broadcast
        # against [Batch, Time, Channel] inputs.
        train_means = torch.from_numpy(self.train_data.feature_df.mean().values).float().to(self.device).view(1, 1, -1)

        self.model.eval()
        all_attr, all_labels, all_preds = [], [], []
        print(f"Generating Keystone_IG_{keystone_ratio} for {target_set} set...")

        for batch_x, label, padding_mask in loader:
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)
            label = label.long().squeeze(-1).to(self.device)

            # Step 1: Initial pass to find Keystone locations
            with torch.no_grad():
                logits = self._forward_wrapper(batch_x, padding_mask)
                preds = torch.argmax(logits, dim=1)

            # The search pass uses the same n_steps as the refined passes below.
            attr_search = ig.attribute(batch_x, baselines=train_means, target=preds,
                                       additional_forward_args=(padding_mask,), n_steps=50, internal_batch_size=8)

            # Step 2: Identify exact Keystone indices (top %)
            B, T, C = attr_search.shape
            num_top = int(T * C * keystone_ratio)
            # reshape (not view) so a non-contiguous attribution tensor cannot fail here.
            flat_search = attr_search.detach().abs().reshape(B, -1)
            _, top_indices = torch.topk(flat_search, k=num_top, dim=1)

            keystone_mask_flat = torch.zeros_like(flat_search)
            keystone_mask_flat.scatter_(1, top_indices, 1.0)
            keystone_mask = keystone_mask_flat.reshape(B, T, C)

            # Step 3: Create Keystone-only Data (Top k from batch_x, rest from means).
            # Detached so this derived input does not drag batch_x's autograd
            # graph into the two attribution passes.
            batch_x_keystone = ((batch_x * keystone_mask) + (train_means * (1 - keystone_mask))).detach()

            # Step 4: Attribution Passes
            # attr1: Transition from Neutral -> Keystone (Exposing the primary signal)
            attr1 = ig.attribute(batch_x_keystone, baselines=train_means, target=preds,
                                 additional_forward_args=(padding_mask,), n_steps=50, internal_batch_size=8)

            # attr2: Transition from Keystone -> Full Data (Explaining the Background rest)
            attr2 = ig.attribute(batch_x, baselines=batch_x_keystone, target=preds,
                                 additional_forward_args=(padding_mask,), n_steps=50, internal_batch_size=8)

            # Step 5: Total Attribution = attr1 (Keystone) + attr2 (Background)
            combined_attr = attr1 + attr2

            all_attr.append(combined_attr.detach().cpu().numpy())
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        # Save Logic
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)

        suffix = "_noisy" if noise else ""
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_KeystoneIG_{keystone_ratio}{suffix}.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': np.concatenate(all_attr, axis=0),
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })

        print(f"Keystone_IG_{keystone_ratio} map saved at: {save_path}")
        return save_path

    def evaluate_faithfulness_bottom_up(self, target_set='test', explanation_method='IG', metric=None):
        """
        Evaluates faithfulness by removing the LEAST important features.
        Replaces removed features with the mean values obtained from the training dataset.
        Thresholds range from 0% to 100% with a 5% step.

        Args:
            target_set: 'test' selects the test loader; any other value selects
                the training loader.
            explanation_method: Suffix of the saved explanation file to load
                (e.g. 'IG').
            metric: 'accuracy' or 'roc_auc'.

        Returns:
            DataFrame with one row per removal ratio plus a constant
            ``faithfulness_auc`` column, or ``None`` when the explanation file
            does not exist.

        Raises:
            ValueError: If ``metric`` is not supported, or if the stored
                attributions do not line up with a batch from the loader.
        """
        # 1. Automatically determine the path based on current experiment settings
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_method}.npy"
        map_path = os.path.join(save_dir, file_name)

        if not os.path.exists(map_path):
            print(f"Error: Explanation map not found at {map_path}. Run explain_IG first.")
            return

        # Validate before any forward pass instead of after a full sweep.
        if metric not in ('accuracy', 'roc_auc'):
            raise ValueError(
                "No valid metric specified for faithfulness evaluation. "
                f"Expected 'accuracy' or 'roc_auc', got {metric!r}."
            )

        # 2. Compute column-wise means of the training features; the resulting
        # 1-D tensor broadcasts over the [Batch, Seq_len, Vars] inputs.
        train_means_np = self.train_data.feature_df.mean().values
        train_means = torch.from_numpy(train_means_np).float().to(self.device)

        # 3. Load the explanation data (saved as a dictionary).
        # NOTE: allow_pickle executes pickle on load — only open trusted files.
        data = np.load(map_path, allow_pickle=True).item()
        attributions = data['attributions']  # Shape: [N, seq_len, num_features]

        self.model.eval()
        loader = self.test_loader if target_set == 'test' else self.train_loader

        # 4. Define Thresholds (0% to 100% with 5% increments)
        thresholds = np.linspace(0, 1, 21)
        faith_results = []

        print(f"Starting Bottom-Up Evaluation (Mean-Replacement) on {target_set}...")
        print(f"Loading attributions from: {map_path}")

        for threshold in thresholds:
            all_probs = []
            all_preds = []
            all_trues = []
            attr_idx = 0

            with torch.no_grad():
                for batch_x, label, padding_mask in loader:
                    batch_size = batch_x.shape[0]
                    batch_x = batch_x.float().to(self.device)  # [Batch, Seq_len, Vars]
                    padding_mask = padding_mask.float().to(self.device)

                    # Attributions are matched to samples purely by position.
                    # NOTE(review): this assumes the loader yields samples in the
                    # same order as when the map was generated — confirm the
                    # loader is not shuffled.
                    batch_attr = attributions[attr_idx:attr_idx + batch_size]
                    attr_idx += batch_size
                    if batch_attr.shape != tuple(batch_x.shape):
                        raise ValueError(
                            f"Attribution slice of shape {batch_attr.shape} does not match "
                            f"input batch of shape {tuple(batch_x.shape)}."
                        )

                    if threshold == 0:
                        mask = torch.ones_like(batch_x)
                    elif threshold == 1:
                        mask = torch.zeros_like(batch_x)
                    else:
                        B, T, C = batch_x.shape
                        flat_attr = torch.from_numpy(np.abs(batch_attr)).to(self.device).reshape(B, -1)

                        # Calculate exact number of features to KEEP
                        num_total = T * C
                        num_keep = int(num_total * (1 - threshold))

                        # Get indices of the top-k most important features per sample
                        _, top_indices = torch.topk(flat_attr, k=num_keep, dim=1)

                        # Binary mask in the input's dtype so the masked input
                        # keeps that dtype whatever dtype the map was saved in.
                        mask_flat = torch.zeros(B, num_total, dtype=batch_x.dtype, device=self.device)
                        mask_flat.scatter_(1, top_indices, 1.0)
                        mask = mask_flat.reshape(B, T, C)

                    # --- Mean Replacement Logic ---
                    # Replace removed features (1-mask) with training means
                    masked_x = (batch_x * mask) + (train_means * (1 - mask))

                    # The shared wrapper returns logits for every model type;
                    # previously logits were only computed for 'DNN'.
                    logits = self._forward_wrapper(masked_x, padding_mask)

                    probs = torch.softmax(logits, dim=1)

                    all_probs.append(probs.cpu())
                    all_preds.append(torch.argmax(logits, dim=1).cpu())
                    all_trues.append(label.long().squeeze(-1).cpu())

            # 5. Calculate the requested metric for this removal ratio
            y_true = torch.cat(all_trues).numpy()
            if metric == 'accuracy':
                acc = accuracy_score(y_true, torch.cat(all_preds).numpy())
                faith_results.append({
                    'removal_ratio': threshold,
                    'accuracy': acc
                })
            else:  # metric == 'roc_auc'
                if self.args.num_class == 2:
                    roc_auc = roc_auc_score(y_true, torch.cat(all_probs)[:, 1].numpy())
                else:
                    roc_auc = roc_auc_score(y_true, torch.cat(all_probs).numpy(), multi_class='ovr', average='macro')
                faith_results.append({
                    'removal_ratio': threshold,
                    'roc_auc': roc_auc
                })

        # 6. Save results to a CSV in the same directory as the map
        summary_df = pd.DataFrame(faith_results)
        csv_save_path = map_path.replace('.npy', f'_BottomUp_{metric.capitalize()}_Faithfulness.csv')
        print(f"Evaluation complete. Results saved at: {csv_save_path}")
        # np.trapz was renamed to np.trapezoid in NumPy 2.0; support both.
        integrate = getattr(np, 'trapezoid', None) or np.trapz
        auc_score = integrate(summary_df[metric], summary_df['removal_ratio'])
        print(f"Faithfulness AUC ({metric.capitalize()} vs. Removal Ratio): {auc_score:.4f}")
        summary_df['faithfulness_auc'] = auc_score
        summary_df.to_csv(csv_save_path, index=False)

        return summary_df

    def robustness_val(self, target_set='test', explanation_method='IG', metric=None):
        """Measure explanation robustness as top-20% Jaccard overlap under noise.

        Loads the attribution map of ``explanation_method`` and its ``_noisy``
        counterpart, selects for every sample the 20% of features with the
        largest attribution value in each map, and averages the Jaccard index
        (IoU) of the two index sets over all samples. The result is also
        written to a CSV next to the clean map.

        Args:
            target_set: Name of the split used in the explanation file names.
            explanation_method: Suffix of the saved explanation files.
            metric: Accepted for interface compatibility; not used here.

        Returns:
            Mean Jaccard index, or ``None`` when either explanation file is
            missing.

        Raises:
            ValueError: If the clean and noisy attribution arrays differ in shape.
        """
        # 1. Automatically determine the paths based on current experiment settings
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_method}.npy"
        noisy_file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_method}_noisy.npy"
        map_path = os.path.join(save_dir, file_name)
        noisy_map_path = os.path.join(save_dir, noisy_file_name)
        if not os.path.exists(map_path):
            print(f"Error: Explanation map not found at {map_path}. Run explain_IG first.")
            return
        # Treat a missing noisy map the same way as a missing clean map.
        if not os.path.exists(noisy_map_path):
            print(f"Error: Noisy explanation map not found at {noisy_map_path}. Generate the noisy explanation first.")
            return

        # 2. Load the explanation data (saved as dictionaries).
        # NOTE: allow_pickle executes pickle on load — only open trusted files.
        data = np.load(map_path, allow_pickle=True).item()
        attributions = data['attributions']  # Shape: [N, seq_len, num_features]
        noisy_data = np.load(noisy_map_path, allow_pickle=True).item()
        noisy_attributions = noisy_data['attributions']  # Shape: [N, seq_len, num_features]
        if attributions.shape != noisy_attributions.shape:
            raise ValueError(
                f"Attribution shapes differ: {attributions.shape} (clean) vs "
                f"{noisy_attributions.shape} (noisy)."
            )

        # 3. Define Threshold (Top 20%)
        ratio = 0.2
        N, T, C = attributions.shape
        num_top = int(T * C * ratio)

        print(f"Starting Robustness Evaluation (Jaccard Index @ Top 20%) on {target_set}...")

        # Flatten across temporal and feature dimensions (reshape tolerates
        # non-contiguous arrays, unlike view).
        flat1 = torch.from_numpy(attributions).to(self.device).reshape(N, -1)
        flat2 = torch.from_numpy(noisy_attributions).to(self.device).reshape(N, -1)

        # Get top-k indices for original and noisy attributions.
        # NOTE(review): ranking uses signed attribution values here, whereas the
        # keystone and faithfulness code rank by absolute value — confirm intent.
        _, indices1 = torch.topk(flat1, k=num_top, dim=1)
        _, indices2 = torch.topk(flat2, k=num_top, dim=1)

        # Transfer all indices to the host once instead of once per sample.
        rows1 = indices1.tolist()
        rows2 = indices2.tolist()

        jaccard_scores = []
        for row1, row2 in zip(rows1, rows2):
            set1 = set(row1)
            set2 = set(row2)
            union = len(set1 | set2)
            jaccard_scores.append(len(set1 & set2) / union if union > 0 else 0)

        # 4. Report and Save results
        avg_iou = np.mean(jaccard_scores)
        print(f"Robustness Result | Jaccard Index (IoU) @ Top 20%: {avg_iou:.4f}")

        summary_df = pd.DataFrame({
            'explanation_method': [explanation_method],
            'top_k_ratio': [ratio],
            'avg_jaccard_iou': [avg_iou]
        })

        csv_save_path = map_path.replace('.npy', '_Robustness_Jaccard.csv')
        summary_df.to_csv(csv_save_path, index=False)
        print(f"Robustness evaluation complete. Results saved at: {csv_save_path}")

        return avg_iou
        
    
    def explanation_with_expert_name(self, explanation_name, target_set='test', **kwargs):
        """
        Unified interface for the Captum attribution methods and the custom IG variants.

        Supported Captum names: 'IG', 'LIME', 'KernelSHAP', 'Saliency', 'InputXGradient'
        Supported Custom: 'SegIG_n' (e.g., 'SegIG_10'), 'KeystoneIG_r' (e.g., 'KeystoneIG_0.2')

        explanation_name: one of the names above
        target_set: 'test' uses self.test_loader; any other value uses self.train_loader
        kwargs: 'n' (SegIG without suffix), 'keystone_ratio' (KeystoneIG without suffix),
                'n_steps' and 'internal_batch_size' (IG only)
        Returns: path of the saved .npy file for Captum methods; for the custom
                 variants, whatever explain_SegIG / explain_keystoneIG returns
        Raises: ValueError if explanation_name is not supported
        """
        # NOTE(review): this switches anomaly detection on globally and never off again;
        # PyTorch documents it as a debugging aid that slows execution - confirm it is still wanted
        torch.autograd.set_detect_anomaly(True)

        # 1. Dispatch to the custom IG variants if requested
        if explanation_name.startswith('SegIG'):
            n = int(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('n', 10)
            return self.explain_SegIG(n=n, target_set=target_set)

        if explanation_name.startswith('KeystoneIG'):
            keystone_ratio = float(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('keystone_ratio', 0.2)
            return self.explain_keystoneIG(keystone_ratio=keystone_ratio, target_set=target_set)

        # 2. Map explanation names to Captum classes
        captum_mapping = {
            'IG': attr_lib.IntegratedGradients,
            'LIME': attr_lib.Lime,
            'KernelSHAP': attr_lib.KernelShap,
            'Saliency': attr_lib.Saliency,
            'InputXGradient': attr_lib.InputXGradient,
        }

        if explanation_name not in captum_mapping:
            raise ValueError(f"Method '{explanation_name}' not supported. Choose from {list(captum_mapping.keys())} or 'SegIG_n'.")

        # 3. Setup paths and model state
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)
        self.model.eval()
        if target_set == 'test':
            print("Using test loader for explanations.")
            loader = self.test_loader
        else:
            print("Using train loader for explanations.")
            loader = self.train_loader

        explainer = captum_mapping[explanation_name](self._forward_wrapper)

        all_attributions, all_labels, all_preds = [], [], []
        print(f"Generating {explanation_name} explanations for {target_set}...")

        for i, (batch_x, label, padding_mask) in enumerate(loader):
            if i % 100 == 0:
                print(f"  Processing batch {i}...")
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)

            # The predicted class (not the true label) is the attribution target
            with torch.no_grad():
                logits = self._forward_wrapper(batch_x, padding_mask)
                preds = torch.argmax(logits, dim=1)

            # 4. Generate Attributions
            attr_kwargs = {
                'inputs': batch_x,
                'target': preds,
                # one-element tuple: "(padding_mask)" without the comma is just the tensor
                'additional_forward_args': (padding_mask,),
            }

            # Method-specific adjustments
            if explanation_name == 'IG':
                attr_kwargs['n_steps'] = kwargs.get('n_steps', 50)
                attr_kwargs['internal_batch_size'] = kwargs.get('internal_batch_size', 8)

            attributions = explainer.attribute(**attr_kwargs)

            all_attributions.append(attributions.detach().cpu().numpy())
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        # 5. Save results
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_name}.npy"
        save_path = os.path.join(save_dir, file_name)

        # Concatenate once and reuse the array for both the file and the log line
        stacked_attributions = np.concatenate(all_attributions, axis=0)
        np.save(save_path, {
            'attributions': stacked_attributions,
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })

        print(f"Explanation map ({explanation_name}) saved at: {save_path}, np.concatenate(all_attributions, axis=0).shape: {stacked_attributions.shape}")
        return save_path
    
    def time_explanation_with_expert_name(self, explanation_name, target_set='test', **kwargs):
        """
        Timed variant of the Captum explanation run.

        Runs the full explanation pass over the loader several times, prints the
        total and per-sample wall-clock time of every run plus mean and standard
        deviation over the runs, and saves the attributions of the last run.
        The custom variants ('SegIG_n', 'KeystoneIG_r') are dispatched to their
        own methods and are not timed here.

        Supported Captum names: 'IG', 'LIME', 'KernelSHAP', 'Saliency', 'InputXGradient'
        Supported Custom: 'SegIG_n' (e.g., 'SegIG_10'), 'KeystoneIG_r' (e.g., 'KeystoneIG_0.2')

        explanation_name: one of the names above
        target_set: 'test' uses self.test_loader; any other value uses self.train_loader
        kwargs: 'num_runs' (timed repetitions, default 5), 'n' (SegIG without suffix),
                'keystone_ratio' (KeystoneIG without suffix),
                'n_steps' and 'internal_batch_size' (IG only)
        Returns: path of the saved .npy file for Captum methods; for the custom
                 variants, whatever explain_SegIG / explain_keystoneIG returns
        Raises: ValueError if explanation_name is not supported
        """
        # NOTE(review): anomaly detection is a debugging aid that slows autograd, so it
        # inflates every timing reported below - confirm it should stay on while benchmarking
        torch.autograd.set_detect_anomaly(True)

        # 1. Dispatch to the custom IG variants if requested
        if explanation_name.startswith('SegIG'):
            n = int(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('n', 10)
            return self.explain_SegIG(n=n, target_set=target_set)

        if explanation_name.startswith('KeystoneIG'):
            keystone_ratio = float(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('keystone_ratio', 0.2)
            return self.explain_keystoneIG(keystone_ratio=keystone_ratio, target_set=target_set)

        # 2. Map explanation names to Captum classes
        captum_mapping = {
            'IG': attr_lib.IntegratedGradients,
            'LIME': attr_lib.Lime,
            'KernelSHAP': attr_lib.KernelShap,
            'Saliency': attr_lib.Saliency,
            'InputXGradient': attr_lib.InputXGradient,
        }

        if explanation_name not in captum_mapping:
            raise ValueError(f"Method '{explanation_name}' not supported. Choose from {list(captum_mapping.keys())} or 'SegIG_n'.")

        # 3. Setup paths and model state
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)
        self.model.eval()
        if target_set == 'test':
            print("Using test loader for explanations.")
            loader = self.test_loader
        else:
            print("Using train loader for explanations.")
            loader = self.train_loader

        explainer = captum_mapping[explanation_name](self._forward_wrapper)

        all_attributions, all_labels, all_preds = [], [], []
        print(f"Generating {explanation_name} explanations for {target_set}...")

        num_runs = kwargs.get('num_runs', 5)
        run_times = []
        for time_count in range(num_runs):
            # Start every run with empty buffers; otherwise the saved map would
            # contain each sample once per run instead of once
            all_attributions, all_labels, all_preds = [], [], []

            start_time = time.perf_counter()
            for batch_x, label, padding_mask in loader:
                batch_x = batch_x.float().to(self.device).requires_grad_()
                padding_mask = padding_mask.float().to(self.device)

                # The predicted class (not the true label) is the attribution target
                with torch.no_grad():
                    logits = self._forward_wrapper(batch_x, padding_mask)
                    preds = torch.argmax(logits, dim=1)

                # 4. Generate Attributions
                attr_kwargs = {
                    'inputs': batch_x,
                    'target': preds,
                    # one-element tuple: "(padding_mask)" without the comma is just the tensor
                    'additional_forward_args': (padding_mask,),
                }

                # Method-specific adjustments
                if explanation_name == 'IG':
                    attr_kwargs['n_steps'] = kwargs.get('n_steps', 50)
                    attr_kwargs['internal_batch_size'] = kwargs.get('internal_batch_size', 8)

                attributions = explainer.attribute(**attr_kwargs)

                # The copy to host memory is part of the measured time
                all_attributions.append(attributions.detach().cpu().numpy())
                all_labels.append(label.cpu().numpy())
                all_preds.append(preds.cpu().numpy())
            total_run_time = time.perf_counter() - start_time
            average_time = total_run_time / len(loader.dataset)
            run_times.append(average_time)
            print(f"Run {time_count+1}/{num_runs}: Total time = {total_run_time:.2f} seconds | Average time per sample = {average_time:.6f} seconds")
        mean_time = np.mean(run_times)
        std_time = np.std(run_times)
        print(f"Average time per sample over {num_runs} runs: {mean_time:.6f} ± {std_time:.6f} seconds")

        # 5. Save results (attributions of the last timed run)
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_name}.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': np.concatenate(all_attributions, axis=0),
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })

        print(f"Explanation map ({explanation_name}) saved at: {save_path}")
        return save_path

    def explanation_with_expert_name_withnoise(self, explanation_name, target_set='test', **kwargs):
        """
        Noisy-input counterpart of the unified Captum / custom IG explanation interface.

        Explanations are computed on the noisy loaders and saved with a
        '_noisy' suffix in the file name.

        Supported Captum names: 'IG', 'LIME', 'KernelSHAP', 'Saliency', 'InputXGradient'
        Supported Custom: 'SegIG_n' (e.g., 'SegIG_10'), 'KeystoneIG_r' (e.g., 'KeystoneIG_0.2'),
                          both dispatched with noise=True

        explanation_name: one of the names above
        target_set: 'test' uses self.noisy_test_loader; any other value uses self.noisy_train_loader
        kwargs: 'n' (SegIG without suffix), 'keystone_ratio' (KeystoneIG without suffix),
                'n_steps' and 'internal_batch_size' (IG only)
        Returns: path of the saved .npy file for Captum methods; for the custom
                 variants, whatever explain_SegIG / explain_keystoneIG returns
        Raises: ValueError if explanation_name is not supported
        """
        # NOTE(review): this switches anomaly detection on globally and never off again;
        # PyTorch documents it as a debugging aid that slows execution - confirm it is still wanted
        torch.autograd.set_detect_anomaly(True)

        # 1. Dispatch to the custom IG variants if requested
        if explanation_name.startswith('SegIG'):
            n = int(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('n', 10)
            return self.explain_SegIG(n=n, target_set=target_set, noise=True)
        if explanation_name.startswith('KeystoneIG'):
            keystone_ratio = float(explanation_name.split('_')[1]) if '_' in explanation_name else kwargs.get('keystone_ratio', 0.2)
            return self.explain_keystoneIG(keystone_ratio=keystone_ratio, target_set=target_set, noise=True)

        # 2. Map explanation names to Captum classes
        captum_mapping = {
            'IG': attr_lib.IntegratedGradients,
            'LIME': attr_lib.Lime,
            'KernelSHAP': attr_lib.KernelShap,
            'Saliency': attr_lib.Saliency,
            'InputXGradient': attr_lib.InputXGradient,
        }

        if explanation_name not in captum_mapping:
            raise ValueError(f"Method '{explanation_name}' not supported. Choose from {list(captum_mapping.keys())} or 'SegIG_n'.")

        # 3. Setup paths and model state
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        os.makedirs(save_dir, exist_ok=True)
        self.model.eval()
        if target_set == 'test':
            print("Using test loader for explanations.")
            loader = self.noisy_test_loader
        else:
            print("Using train loader for explanations.")
            loader = self.noisy_train_loader

        explainer = captum_mapping[explanation_name](self._forward_wrapper)

        all_attributions, all_labels, all_preds = [], [], []
        print(f"Generating {explanation_name} explanations for {target_set}...")

        for i, (batch_x, label, padding_mask) in enumerate(loader):
            if i % 100 == 0:
                print(f"  Processing batch {i}...")
            batch_x = batch_x.float().to(self.device).requires_grad_()
            padding_mask = padding_mask.float().to(self.device)

            # The class predicted on the noisy input is the attribution target
            with torch.no_grad():
                logits = self._forward_wrapper(batch_x, padding_mask)
                preds = torch.argmax(logits, dim=1)

            # 4. Generate Attributions
            attr_kwargs = {
                'inputs': batch_x,
                'target': preds,
                # one-element tuple: "(padding_mask)" without the comma is just the tensor
                'additional_forward_args': (padding_mask,),
            }

            # Method-specific adjustments
            if explanation_name == 'IG':
                attr_kwargs['n_steps'] = kwargs.get('n_steps', 50)
                attr_kwargs['internal_batch_size'] = kwargs.get('internal_batch_size', 8)

            attributions = explainer.attribute(**attr_kwargs)

            all_attributions.append(attributions.detach().cpu().numpy())
            all_labels.append(label.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

        # 5. Save results
        file_name = f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{explanation_name}_noisy.npy"
        save_path = os.path.join(save_dir, file_name)

        np.save(save_path, {
            'attributions': np.concatenate(all_attributions, axis=0),
            'labels': np.concatenate(all_labels, axis=0),
            'predictions': np.concatenate(all_preds, axis=0)
        })

        print(f"Explanation map ({explanation_name}) saved at: {save_path}")
        return save_path

    def calculate_instance_faithfulness(self, x, label, padding_mask, attribution, verbose=False):
        """
        Calculates faithfulness for a SINGLE instance using top-k indexing.
        x: [1, T, C]
        attribution: [1, T, C]
        Returns: AUC score (float)
        """
        # --- STEP 1: INITIALIZATION ---
        # Ensure the model is in evaluation mode and move baseline means to the current device
        self.model.eval()
        train_means = torch.from_numpy(self.train_data.feature_df.mean().values).float().to(self.device).view(1, 1, -1)
        
        # --- STEP 2: SETUP PARAMETERS ---
        # Define thresholds for removal (0% to 100% with 5% increments)
        thresholds = np.linspace(0, 1, 21)
        # print(thresholds)
        # exit()
        conf_scores = []
        
        # Determine total features (T * C) to calculate exact top-k counts
        _, T, C = x.shape
        num_total = T * C

        if verbose:
            print(f"DEBUG: Explaining instance with true label {label.item()}")

        # Work with absolute attributions for importance ranking
        attr_abs_torch = torch.from_numpy(np.abs(attribution)).to(self.device)
        flat_attr = attr_abs_torch.view(1, -1)  # Flatten for top-k selection

        # --- STEP 3: ITERATIVE FEATURE REMOVAL ---
        for threshold in thresholds:
            with torch.no_grad():
                # --- STEP 3A: MASK GENERATION (TOP-K) ---
                # Calculate how many features to KEEP (1.0 - threshold)
                num_keep = int(num_total * (1.0 - threshold))
                
                if threshold == 0:
                    mask = torch.ones_like(x)
                elif threshold == 1:
                    mask = torch.zeros_like(x)
                else:
                    # Identify indices of the most important features
                    _, top_indices = torch.topk(flat_attr, k=num_keep, dim=1)
                    
                    # Create binary mask: 1 to KEEP (high importance), 0 to REMOVE (low importance)
                    mask_flat = torch.zeros_like(flat_attr)
                    mask_flat.scatter_(1, top_indices, 1.0)
                    mask = mask_flat.view(1, T, C)

                # --- STEP 3B: DEBUG MASK INTEGRITY ---
                # Check if the mask contains exactly the expected number of features
                if verbose:
                    actual_keep = mask.sum().item()
                    # Comment: This verifies if our top-k logic perfectly matches the target ratio
                    if actual_keep != num_keep:
                        print(f"  [DEBUG MASK ERROR] Threshold {threshold:.2f}: Expected {num_keep}, got {actual_keep}")
                    else:
                        print(f"  [DEBUG MASK OK] Threshold {threshold:.2f}: Exact {num_keep} features kept.")

                # --- STEP 3C: FEATURE REPLACEMENT ---
                # Bottom-up logic: Replace LEAST important features with training means
                # Features with mask=0 are replaced by train_means
                masked_x = (x * mask) + (train_means * (1 - mask))
                
                # --- STEP 3D: MODEL INFERENCE ---
                # Get model prediction and calculate confidence in the ground truth label
                logits = self._forward_wrapper(masked_x, padding_mask)
                probs = torch.softmax(logits, dim=1)
                
                # Confidence in the actual label
                conf = probs[0, label.item()].item()
                conf_scores.append(conf)

                if verbose and (threshold * 100) % 20 == 0:
                    print(f"    > Step Info: Removed {threshold*100:3.0f}% | Confidence: {conf:.4f}")
        # exit()
        # --- STEP 4: CALCULATE FINAL AUC ---
        # Compute the Area Under the Confidence Curve relative to the removal ratio
        auc_score = np.trapz(conf_scores, thresholds)
        
        if verbose:
            print(f"DEBUG: Final Instance Faithfulness AUC: {auc_score:.4f}")
            
        return auc_score

    def select_best_instance_explanation(self, methods, target_set='train'):
        """
        Iterates through all instances in the specified set (defaults to 'train') 
        and selects the best explanation expert using per-instance faithfulness.
        """
        # --- STEP 1: INITIALIZE DIRECTORIES AND LOADERS ---
        # Define the base directory where pre-calculated explanation maps are stored
        save_dir = os.path.join(f"./explanations/{self.args.dataset}", self.args.model, self.args.dnn_type)
        
        # Select the appropriate data loader based on the requested target_set (e.g., 'train' or 'test')
        loader = self.train_loader if target_set == 'train' else self.test_loader
        num_instances = len(loader.dataset)
        
        # --- STEP 2: PRE-LOAD EXPERT ATTRIBUTION MAPS ---
        # We load attributions for all specified methods into a dictionary for quick access
        expert_data = {}
        for m in methods:
            # Construct the path for each method's .npy file following project naming conventions
            path = os.path.join(save_dir, f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{m}.npy")
            noisy_path = os.path.join(save_dir, f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_{m}_noisy.npy")
            
            # Load the expert's attribution data into memory if the file exists
            # if os.path.exists(path):
            expert_data[m] = np.load(path, allow_pickle=True).item()['attributions']
            noisy_expert_data = np.load(noisy_path, allow_pickle=True).item()['attributions']
            expert_data[m+'_noisy'] = noisy_expert_data
            # else:
                # Log a warning if a file is missing to help debug experiment configuration issues
                # print(f"DEBUG WARNING: File not found for {m} on {target_set} set at {path}")

        # --- STEP 3: INITIALIZE CONTAINERS FOR ENSEMBLE RESULTS ---
        best_attr_map = []  # Stores the attribution map of the "winning" expert per instance
        winner_names = []   # Stores the name of the method that won for each instance
        best_scores = []      # Stores the maximum faithfulness (AUC) score found per instance
        
        print(f"\n--- DEBUG: STARTING ENSEMBLE SELECTION ON {target_set.upper()} SET ---")
        print(f"DEBUG: Evaluating {len(expert_data)} experts across {num_instances} samples.")

        # --- STEP 4: PER-INSTANCE EVALUATION LOOP ---
        # global_idx tracks the absolute sample index to synchronize with pre-loaded expert_data arrays
        global_idx = 0
        for i, (batch_x, label, padding_mask) in enumerate(loader):
            # Iterate through each individual sample within the current batch
            for b in range(batch_x.shape[0]):
                # Extract individual sample components and move to the computation device
                x_i = batch_x[b:b+1].to(self.device)
                label_i = label[b].to(self.device)
                mask_i = padding_mask[b:b+1].to(self.device)

                # Initialize variables to track the best performing method for this specific instance
                # best_instance_auc = -1.0
                best_instance_attr = None
                winning_expert = None
                best_instance_score = -1.0  # Combined score for faithfulness + robustness

                # --- STEP 5: COMPARE EXPERTS FOR THE CURRENT SAMPLE ---
                for m_name, attr_full in expert_data.items():
                    # Slice the full method array to retrieve the attribution map for the current global index
                    # attr_i = attr_full[global_idx:global_idx+1]
                    if "noisy" in m_name:
                        continue
                    attr_orig_i = expert_data[m_name][global_idx:global_idx+1]
                    attr_noisy_i = expert_data[m_name + '_noisy'][global_idx:global_idx+1]
                    
                    # Calculate the faithfulness score (AUC) for this specific attribution map
                    # This measures how well the explanation aligns with the model's actual decision process
                    # A) Calculate Faithfulness (AUC)
                    faith_auc = self.calculate_instance_faithfulness(
                        x_i, label_i, mask_i, attr_orig_i, verbose=False
                    )
                    
                    # B) Calculate Robustness (IoU)
                    robust_iou = self.calculate_instance_robustness(
                        attr_orig_i, attr_noisy_i, ratio=0.2, verbose=False
                    )
                    
                    # C) Combine Scores (Summation)
                    total_score = faith_auc + robust_iou
                    
                    # Update winner if this method has a higher combined score
                    if total_score > best_instance_score:
                        best_instance_score = total_score
                        best_instance_attr = attr_orig_i
                        winning_expert = m_name
                
                # --- STEP 6: STORE THE RESULTS FOR THIS INSTANCE ---
                best_attr_map.append(best_instance_attr)
                winner_names.append(winning_expert)
                best_scores.append(best_instance_score)

                # Periodically log progress for debugging long-running evaluations
                if global_idx % 20 == 0 or global_idx == num_instances - 1:
                    print(f"  > DEBUG: Instance {global_idx}/{num_instances-1} | Winner: {winning_expert:15} | AUC: {best_instance_score:.4f}")
                
                # Move to the next instance
                global_idx += 1

        # --- STEP 7: SAVE THE ENSEMBLE RESULTS ---
        # Concatenate all individual winning maps into a single global attribution array
        final_map_array = np.concatenate(best_attr_map, axis=0)
        
        # Define the save path for the "BestEnsemble" result file
        save_path = os.path.join(save_dir, f"{self.args.dataset}-{self.args.seed}-{self.args.model}-{self.args.dnn_type}-{target_set}_BestEnsemble.npy")
        
        # Save a comprehensive results dictionary containing attributions and metadata
        np.save(save_path, {
            'attributions': final_map_array,  # The final ensemble attribution map
            'winning_methods': winner_names,  # List of which method was chosen for each instance
            'best_scores': best_scores,           # The faithfulness scores for the ensemble selections
            'metadata': {
                'target_set': target_set,
                'methods_evaluated': list(expert_data.keys()),
                'average_score': np.mean(best_scores)  # The average faithfulness of the final ensemble
            }
        })
        
        # Final success confirmation
        print(f"--- DEBUG: SUCCESS! BestEnsemble for {target_set} saved to {save_path} ---")
        print(f"DEBUG: Final Mean Faithfulness AUC: {np.mean(best_scores):.4f}")
        
        return save_path
    
    def calculate_instance_robustness(self, attr_original, attr_noisy, ratio=0.2, verbose=False):
        """
        Calculates the Jaccard Index (IoU) for a single instance between original and noisy attributions.
        attr_original: [1, T, C]
        attr_noisy: [1, T, C]
        Returns: IoU score (float)
        """
        # --- STEP 1: PREPARE DATA ---
        # Work with absolute values to focus on magnitude of importance
        attr1 = torch.from_numpy(np.abs(attr_original)).to(self.device)
        attr2 = torch.from_numpy(np.abs(attr_noisy)).to(self.device)
        
        # Flatten to [1, T*C] for top-k selection
        flat1 = attr1.view(1, -1)
        flat2 = attr2.view(1, -1)
        
        # --- STEP 2: IDENTIFY TOP-K SALIENT FEATURES ---
        # Calculate exact number of features to compare (top 20%)
        num_total = flat1.shape[1]
        num_top = int(num_total * ratio)
        
        # Get indices of the top-k most important features
        _, indices1 = torch.topk(flat1, k=num_top, dim=1)
        _, indices2 = torch.topk(flat2, k=num_top, dim=1)

        # --- STEP 3: DEBUG MASK CHECK (Commented for later check) ---
        # if verbose:
        #     print(f"[DEBUG] Robustness Mask | Total: {num_total} | Target K: {num_top}")
        #     print(f"[DEBUG] Indices1 count: {len(indices1[0])} | Indices2 count: {len(indices2[0])}")

        # --- STEP 4: CALCULATE JACCARD INDEX (IoU) ---
        # Convert indices to sets to calculate intersection and union
        set1 = set(indices1[0].tolist())
        set2 = set(indices2[0].tolist())
        
        intersection = len(set1.intersection(set2))
        union = len(set1.union(set2))
        
        # IoU = Intersection / Union
        iou = intersection / union if union > 0 else 0.0
        
        if verbose:
            print(f"  > Robustness IoU: {iou:.4f} (Int: {intersection}, Union: {union})")
            
        return iou
