import torch
import numpy as np

from copy import deepcopy
from scipy.sparse import csgraph

class MILDataset(torch.utils.data.Dataset):
    """Base class for multiple-instance-learning (MIL) datasets made of bags.

    Subclasses must implement the hooks ``_init_data_dict``,
    ``_compute_data_shape`` and ``_loader``; ``_build_edge_index`` is an
    additional hook that this base class declares but never calls itself.

    ``self.data_dict`` maps each bag name to a per-bag dict. The keys read by
    this class are:

    * ``'inst_paths'``  - iterable of values handed one by one to ``_loader``
    * ``'bag_label'``   - bag-level label, converted with ``torch.as_tensor``
    * ``'inst_labels'`` - per-instance labels (array-like)
    * ``'edge_index'``  - graph connectivity for the bag (array-like)

    Subclasses may store further keys; this class ignores them.
    """

    def __init__(self, **kwargs):
        """Store ``kwargs`` as attributes, then build the bag index.

        Every keyword argument is set as an instance attribute *before*
        ``_init_data_dict`` runs, so subclasses can read their configuration
        (presumably things like a data path or CSV path - defined by the
        subclass, not by this class) from ``self``.
        """
        super().__init__()

        # Only fall back to the generic name when nothing (e.g. a subclass)
        # has already defined ``dataset_name``.
        if not hasattr(self, 'dataset_name'):
            self.dataset_name = 'MILDataset'

        for k, v in kwargs.items():
            setattr(self, k, v)

        self.data_dict = self._init_data_dict()
        self.bag_names = list(self.data_dict)

        # Computed last so the hook can rely on data_dict / bag_names.
        self.data_shape = self._compute_data_shape()

    def _compute_data_shape(self):
        """Hook: return the value stored in ``self.data_shape``."""
        raise NotImplementedError

    def _loader(self, *args, **kwargs):
        """Hook: load one instance; called as ``self._loader(inst_path)``."""
        raise NotImplementedError

    def _load_bag_feat(self, bag_name, *args, **kwargs):
        """Load every instance of ``bag_name`` and stack them with ``np.array``.

        Raises:
            ValueError: if the bag entry has no ``'inst_paths'`` key.

        NOTE: an instance whose ``_loader`` call raises is reported with
        ``print`` and skipped, so the result can hold fewer rows than
        ``'inst_paths'`` has entries. ``__getitem__`` returns
        ``'inst_labels'`` and ``'edge_index'`` unchanged, so after a skip they
        may no longer line up with the returned features.
        """
        bag_entry = self.data_dict[bag_name]
        if 'inst_paths' not in bag_entry:
            raise ValueError(f'[{self.dataset_name}] Instance paths not found for bag {bag_name}')

        feat_list = []
        for inst_path in bag_entry['inst_paths']:
            try:
                data = self._loader(inst_path)
            except Exception as e:
                # Deliberate best-effort: report the failure and keep going.
                print(f'[{self.dataset_name}] Error loading instance {inst_path}: {e}')
                continue
            feat_list.append(data)

        return np.array(feat_list)

    def _init_data_dict(self):
        """Hook: return the ``{bag_name: {...}}`` dict described on the class."""
        raise NotImplementedError

    def _build_edge_index(self, *args, **kwargs):
        """Hook for building a bag's edge index; not called by this class."""
        raise NotImplementedError

    def __getitem__(self, index):
        """Return ``(bag_data, bag_label, inst_labels, edge_index)`` as tensors.

        ``bag_data`` comes from ``_load_bag_feat``; the other three are read
        from the bag's entry in ``self.data_dict``.
        """
        bag_name = self.bag_names[index]
        bag_entry = self.data_dict[bag_name]

        bag_data = self._load_bag_feat(bag_name)

        # torch.as_tensor behaves like torch.from_numpy for ndarrays (no copy)
        # but also accepts lists and scalars, which from_numpy rejects with a
        # TypeError.
        return (
            torch.from_numpy(bag_data),
            torch.as_tensor(bag_entry['bag_label']),
            torch.as_tensor(bag_entry['inst_labels']),
            torch.as_tensor(bag_entry['edge_index']),
        )

    def __len__(self):
        """Return the number of bags."""
        return len(self.bag_names)

    def get_bag_labels(self):
        """Return the ``'bag_label'`` of every bag, in ``bag_names`` order."""
        return [self.data_dict[bag_name]['bag_label'] for bag_name in self.bag_names]

    def subset(self, idx):
        """Return a new dataset restricted to the bags at positions ``idx``.

        The new object starts as a deep copy of ``self``, so all other
        attributes are independent copies. Its ``bag_names`` and ``data_dict``
        are then replaced; the per-bag entries in the new ``data_dict`` are the
        *same objects* as in ``self.data_dict`` (shared, not copied).

        NOTE: the deep copy also duplicates the full ``data_dict`` before it is
        replaced, which can be costly for large datasets.
        """
        new_dataset = deepcopy(self)
        new_dataset.bag_names = [self.bag_names[i] for i in idx]
        new_dataset.data_dict = {
            bag_name: self.data_dict[bag_name] for bag_name in new_dataset.bag_names
        }

        return new_dataset