import torch
import numpy as np
import os
from torch.utils.data import Dataset
from data_provider.data_factory import data_provider

class BottomUpMaskDataset(Dataset):
    """Pair normalized attribution maps with "bottom-up" binary masks.

    For every explanation method in ``self.methods`` the saved attribution
    maps are loaded, min-max scaled to [0, 1] per map, and paired with a mask
    that zeroes the lowest-scoring ``k_percentage`` percent of entries.
    Every (method, map) pair becomes one meta-sample.

    Attributes set by ``__init__``:
        k: Fraction (``k_percentage / 100``) of entries to mask out.
        methods: Names of the explanation methods whose files are loaded.
        raw_dataset: First element of the tuple returned by ``data_provider``.
        max_seq_len: Copied from ``raw_dataset.max_seq_len``.
        num_raw: ``len(raw_dataset)``.
        all_inputs: Array of normalized attribution maps, one per meta-sample.
        all_masks: Array of float32 masks (1.0 = keep, 0.0 = masked out).
        raw_indices: Position of each meta-sample inside its method's file,
            used to look up the matching raw sample.
    """

    def __init__(self, args, tm_args, k_percentage=5, flag='train'):
        """Load the raw dataset and build all (attribution, mask) meta-samples.

        Args:
            args: Namespace-like object; ``dataset``, ``model``, ``dnn_type``
                and ``seed`` are read to locate the attribution files.
            tm_args: Passed unchanged to ``data_provider``.
            k_percentage: Percentage of entries to mask out in every map.
            flag: Split name; upper-cased for ``data_provider`` and
                lower-cased in the attribution file names.
        """
        super().__init__()
        self.args = args
        self.tm_args = tm_args
        self.k = k_percentage / 100.0
        self.save_dir = os.path.join(f"./explanations/{args.dataset}", args.model, args.dnn_type)
        self.methods = [
            'LIME', 'KernelSHAP', 'IG', 'Saliency', 'InputXGradient',
            'KeystoneIG_0.05', 'KeystoneIG_0.1', 'KeystoneIG_0.15',
            'KeystoneIG_0.2', 'KeystoneIG_0.25', 'BestEnsemble',
        ]

        # 1. Load the raw time series data (second return value is unused here).
        self.raw_dataset, _ = data_provider(tm_args, flag=flag.upper())
        self.max_seq_len = self.raw_dataset.max_seq_len
        self.num_raw = len(self.raw_dataset)

        inputs = []
        masks = []
        raw_indices = []

        # 2. Load and pre-normalize attribution maps, one file per method.
        for method_pos, method in enumerate(self.methods):
            file_name = f"{args.dataset}-{args.seed}-{args.model}-{args.dnn_type}-{flag.lower()}_{method}.npy"
            file_path = os.path.join(self.save_dir, file_name)

            print(f"Loading {method} for mask generation...")
            # NOTE(security): allow_pickle=True unpickles arbitrary objects;
            # only load attribution files produced by a trusted pipeline.
            data_dict = np.load(file_path, allow_pickle=True).item()
            maps = data_dict['attributions']

            # __getitem__ pairs map i with raw sample i; surface a mismatch
            # instead of letting the pairing drift silently.
            if len(maps) != self.num_raw:
                print(f"[WARNING] {method}: {len(maps)} attribution maps but "
                      f"{self.num_raw} raw samples; pairing maps to raw samples by position.")

            for i, attr_map in enumerate(maps):
                attr_map = np.asarray(attr_map)
                # Min-max normalization to [0, 1]; the epsilon keeps a
                # constant map from dividing by zero (it becomes all zeros).
                m_min, m_max = attr_map.min(), attr_map.max()
                norm_map = (attr_map - m_min) / (m_max - m_min + 1e-8)

                # Print the mask sanity check only once: first map, first method.
                is_debug = (method_pos == 0 and i == 0)
                mask = self._generate_bottom_up_mask(norm_map, debug=is_debug)

                inputs.append(norm_map)
                masks.append(mask)
                raw_indices.append(i)

        self.all_inputs = np.array(inputs)
        self.all_masks = np.array(masks)
        self.raw_indices = np.asarray(raw_indices, dtype=np.int64)
        print(f"Dataset Initialized: {len(self.all_inputs)} meta-samples generated.")

    def _generate_bottom_up_mask(self, attr_map, debug=False):
        """Build a mask that zeroes the ``self.k`` fraction of smallest entries.

        Args:
            attr_map: Array-like attribution map of any shape (the caller in
                this class passes min-max normalized maps).
            debug: When True, print the target/actual masked counts and
                assert that they agree.

        Returns:
            A float32 ``numpy`` array with the same shape as ``attr_map``:
            0.0 at the masked (least important) positions, 1.0 elsewhere.
        """
        # torch.from_numpy needs positive strides and a flat view needs
        # contiguous memory, so normalize the layout first (no copy when the
        # array is already C-contiguous).
        attr_np = np.ascontiguousarray(attr_map)
        flat_attr = torch.from_numpy(attr_np).float().reshape(-1)
        num_total = flat_attr.numel()

        # Floor of num_total * k. Rounding to 9 decimals first removes float
        # noise such as 100 * 0.29 == 28.999999999999996, which would
        # otherwise truncate to 28 instead of 29.
        num_to_mask = int(round(num_total * self.k, 9))
        # Keep the count inside what topk accepts (k outside [0, 1]).
        num_to_mask = max(0, min(num_to_mask, num_total))

        if num_to_mask == 0:
            return np.ones(attr_np.shape, dtype=np.float32)

        # largest=False selects the indices of the smallest values.
        _, low_importance_indices = torch.topk(flat_attr, k=num_to_mask, largest=False)

        # Start from "keep everything", then drop the low-importance entries.
        mask_flat = torch.ones_like(flat_attr)
        mask_flat[low_importance_indices] = 0.0

        if debug:
            actual_masked = (mask_flat == 0).sum().item()
            print(f"[DEBUG MASK] Total Features: {num_total}")
            print(f"[DEBUG MASK] Target Mask Count (k={self.k*100}%): {num_to_mask}")
            print(f"[DEBUG MASK] Actual Mask Count: {actual_masked}")
            assert actual_masked == num_to_mask, "Mask count mismatch!"

        # Restore the original dimensions.
        return mask_flat.reshape(attr_np.shape).numpy()

    def __len__(self):
        """Return the number of meta-samples (one per loaded attribution map)."""
        return len(self.all_inputs)

    def __getitem__(self, index):
        """Return one meta-sample together with its raw sample.

        Returns:
            Tuple ``(x_attr, y_mask, raw_x, label)``: the normalized
            attribution map and its mask as float tensors, plus the raw input
            (cast to float) and label taken from ``raw_dataset``.
        """
        # Use the map's position inside its own method file, wrapped into the
        # raw dataset; identical to index % num_raw when every file holds
        # exactly num_raw maps.
        raw_idx = int(self.raw_indices[index]) % self.num_raw
        raw_x, label = self.raw_dataset[raw_idx]

        # NOTE: from_numpy may share memory with the stored arrays; callers
        # should not modify the returned tensors in place.
        x_attr = torch.from_numpy(self.all_inputs[index]).float()
        y_mask = torch.from_numpy(self.all_masks[index]).float()

        return x_attr, y_mask, raw_x.float(), label
