import os
import h5py
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import pdb
import warnings
from tqdm import tqdm
import gc

# Silence every warning whose message matches the sparse-CSR "beta state" notice
# (presumably raised by PyTorch when sparse CSR tensors are created — confirm),
# so that it is not repeated for each loaded sample.
warnings.filterwarnings("ignore", message="Sparse CSR tensor support is in beta state.*")


class BaseDataset(Dataset):
    """Slide-level dataset built from a two-column annotation CSV.

    The CSV is read, its string labels are converted with ``label_dict`` and
    the rows are restricted to the slide ids contained in ``split``.
    Subclasses are expected to override ``__getitem__``; the implementation
    here is only a placeholder.
    """

    def __init__(self, csv_path, split, label_dict, label_col='type', ignore=None):
        """
        Args:
            csv_path (str): Path to the csv file with annotations. The file
                must have exactly two columns; they are renamed to
                ``'slide_id'`` and ``'type'`` regardless of their header.
            split: Slide ids belonging to this train/val/test split. Must
                provide ``.tolist()`` (e.g. ``pd.Series`` or ``np.ndarray``)
                and must not be empty.
            label_dict (dict): Dictionary with key, value pairs for converting
                str labels to int.
            label_col (str, optional): Label column. Defaults to 'type'.
            ignore (list, optional): Labels whose rows are dropped before the
                conversion. Defaults to None, meaning nothing is ignored.
        """
        slide_data = pd.read_csv(csv_path)
        slide_data.columns = ['slide_id', 'type']
        slide_data = self._df_prep(slide_data, label_dict, ignore, label_col)
        assert len(split) > 0, "Split should not be empty!"
        in_split = slide_data['slide_id'].isin(split.tolist())
        self.slide_data = slide_data[in_split].reset_index(drop=True)
        # Number of distinct encoded labels, not number of keys: several str
        # labels may map to the same int.
        self.n_cls = len(set(label_dict.values()))
        self.slide_cls_ids = self._cls_ids_prep()
        self._print_info()

    def __len__(self):
        """Return the number of slides in this split."""
        return len(self.slide_data)

    def __getitem__(self, idx):
        """Placeholder; subclasses provide the actual sample loading."""
        return None

    def _print_info(self):
        """Print the class count and the per-label slide counts."""
        print("Number of classes: {}".format(self.n_cls))
        print("Slide-level counts: ", '\n', self.slide_data['label'].value_counts(sort = False))
        print("\n")

    def _cls_ids_prep(self):
        """Return, for each class ``0 .. n_cls - 1``, the row positions with that label."""
        labels = self.slide_data['label']
        return [np.where(labels == cls_idx)[0] for cls_idx in range(self.n_cls)]

    def get_label(self, ids):
        """Return the encoded label(s) stored at index/indices ``ids``."""
        return self.slide_data['label'][ids]

    @staticmethod
    def _df_prep(data, label_dict, ignore, label_col):
        """Drop ignored labels and encode the remaining ones with ``label_dict``.

        Args:
            data (pd.DataFrame): Annotation table. When ``label_col`` is not
                ``'label'`` a ``'label'`` column is added to it in place.
            label_dict (dict): Mapping from raw label to encoded label.
            ignore (list or None): Raw labels whose rows are removed.
            label_col (str): Column holding the raw labels.

        Returns:
            pd.DataFrame: A new, re-indexed frame whose ``'label'`` column
            holds the encoded labels.

        Raises:
            KeyError: If a remaining label is missing from ``label_dict``.
        """
        # None stands for "ignore nothing"; avoids a shared mutable default.
        ignore = [] if ignore is None else ignore

        if label_col != 'label':
            data['label'] = data[label_col].copy()

        keep = ~data['label'].isin(ignore)
        # reset_index returns a fresh frame, so the assignment below never
        # writes into a slice of the caller's DataFrame.
        data = data.loc[keep].reset_index(drop=True)
        # Encode all labels in one pass. Plain dict indexing keeps the
        # KeyError on unknown labels, and object dtype keeps the encoded
        # values exactly as they are stored in label_dict.
        data['label'] = pd.Series(
            [label_dict[key] for key in data['label']],
            index=data.index,
            dtype=object,
        )
        return data
    
# Example construction of PatchFeatureDataset (all arguments passed by keyword).
# NOTE: the constructor also requires `data_name` and `use_SignLapPE`, which
# this snippet does not show.
# PatchFeatureDataset(csv_path=args.csv_path, split=val_split, label_dict=label_dict,
#                     label_col=args.label_col, data_path=args.data_path,
#                     low_mag=args.low_mag, mid_mag=args.mid_mag, high_mag=args.high_mag,
#                     n_mag=args.n_mag, use_HF=args.use_HF, use_LF=args.use_LF,
#                     use_graph=args.use_graph, LapPE_k=args.LapPE_k, RWPE_length=args.RWPE_length,
#                     sigma_HF=args.sigma_HF, sigma_LF=args.sigma_LF,
#                     graph_threshold_HF=args.graph_threshold_HF, graph_threshold_LF=args.graph_threshold_LF,
#                     graph_top_k_HF=args.graph_top_k_HF, graph_top_k_LF=args.graph_top_k_LF)

class PatchFeatureDataset(BaseDataset):
    """Dataset that loads pre-extracted patch features of one slide per item.

    Every slide is stored in ``<data_path>/<slide_id>.h5``. ``__getitem__``
    returns a dict of tensors read from that file; which entries are present
    depends on the constructor flags.
    """

    def __init__(self, data_name, data_path, low_mag, mid_mag, high_mag, n_mag, use_HF, use_LF, use_graph, LapPE_k, use_SignLapPE, RWPE_length, sigma_HF, sigma_LF, graph_threshold_HF, graph_threshold_LF, graph_top_k_HF, graph_top_k_LF, **kwargs):
        """
        Args:
            data_name (str): Dataset name. ``'BRIGHT'`` triggers rewriting of
                the slide ids into paths relative to ``data_path``.
            data_path (str): Path to the data.
            low_mag (str): Low magnifications.
            mid_mag (str): Middle magnifications.
            high_mag (str): High magnifications.
            n_mag (int): Number of magnifications to load. Middle magnification
                features are loaded when ``n_mag > 1``, high magnification
                features only when ``n_mag == 3``.
            use_HF (bool): Load the ``*_feats_HF`` features (and the HF graph
                when ``use_graph`` is set).
            use_LF (bool): Load the ``*_feats_LF`` features (and the LF graph
                when ``use_graph`` is set).
            use_graph (bool): Load edge indices and edge attributes.
            LapPE_k (int): Number of leading ``LapPE`` columns to keep; nothing
                is loaded when it is not positive.
            use_SignLapPE (bool): Multiply every ``LapPE`` column by a random
                sign each time an item is loaded.
            RWPE_length (int): Number of leading ``RWPE`` columns to keep;
                nothing is loaded when it is not positive.
            sigma_HF, sigma_LF: Filter parameter. ``int(sigma / 10) - 1``
                selects the entry along axis 1 of the filtered features, and
                the raw value is part of the filtered-graph dataset names.
            graph_threshold_HF, graph_threshold_LF: Used only to build the
                names of the HDF5 datasets that are read.
            graph_top_k_HF, graph_top_k_LF: Used only to build the names of
                the HDF5 datasets that are read.
            **kwargs: Forwarded to ``BaseDataset.__init__``.
        """
        super(PatchFeatureDataset, self).__init__(**kwargs)
        self.data_name = data_name
        self.data_path = data_path
        self.low_mag = low_mag
        self.mid_mag = mid_mag
        self.high_mag = high_mag
        self.n_mag = n_mag
        self.use_HF = use_HF
        self.use_LF = use_LF
        self.use_graph = use_graph
        self.LapPE_k = LapPE_k
        self.RWPE_length = RWPE_length
        self.sigma_HF = sigma_HF
        self.sigma_LF = sigma_LF
        self.graph_threshold_HF = graph_threshold_HF
        self.graph_threshold_LF = graph_threshold_LF
        self.graph_top_k_HF = graph_top_k_HF
        self.graph_top_k_LF = graph_top_k_LF
        self.use_SignLapPE = use_SignLapPE

        if data_name == 'BRIGHT':
            self._resolve_bright_paths()

    def _resolve_bright_paths(self):
        """Rewrite slide ids to ``<split>/<class>/<slide_id>`` for the BRIGHT layout."""
        relative_paths = {}
        for split in ['Train', 'Validation']:
            for cls in ['Cancerous', 'Non-cancerous', 'Pre-cancerous']:
                for file_name in os.listdir(os.path.join(self.data_path, split, cls)):
                    if file_name.endswith('.h5'):
                        slide_id = file_name.split('.')[0]
                        # setdefault: the first directory containing an id wins.
                        relative_paths.setdefault(slide_id, os.path.join(split, cls, slide_id))
        # One pass over the column instead of one full-column replace per file.
        self.slide_data['slide_id'] = self.slide_data['slide_id'].map(
            lambda sid: relative_paths.get(sid, sid)
        )

    @staticmethod
    def _read_tensor(hdf5_file, key):
        """Read the complete HDF5 dataset ``key`` into a tensor."""
        return torch.from_numpy(hdf5_file[key][:])

    @staticmethod
    def _select_filtered(filtered, sigma, band):
        """Pick one (sigma, band) slice of the filtered features as a tensor.

        ``band`` is the index along axis 2: 1 for HF, 0 for LF.
        """
        sigma_idx = int(sigma / 10) - 1
        # Copy the slice so the tensor is contiguous and does not keep the
        # whole filtered array (all sigmas, both bands) alive.
        return torch.from_numpy(np.ascontiguousarray(filtered[:, sigma_idx, band, :]))

    def _read_graph(self, hdf5_file, data, name, stem, suffix):
        """Store ``<name>_edge_index`` / ``<name>_edge_attr`` read from the file in ``data``."""
        data[f'{name}_edge_index'] = self._read_tensor(hdf5_file, f'{stem}_edge_index_{suffix}')
        data[f'{name}_edge_attr'] = self._read_tensor(hdf5_file, f'{stem}_edge_attr_{suffix}')

    def __getitem__(self, idx):
        """Load all configured tensors of slide ``idx``.

        Returns:
            dict: ``'label'`` plus, depending on the configuration,
            ``'<low|mid|high>_mag_feats'`` / ``'_coords'``, their ``'_HF'`` /
            ``'_LF'`` variants, ``'<base|feat|HF|LF>_edge_index'`` /
            ``'_edge_attr'``, ``'RWPE'`` and ``'LapPE'``.
        """
        slide_id = self.slide_data['slide_id'][idx]
        label = self.slide_data['label'][idx]
        data = {'label': label}

        # (key prefix in the returned dict, magnification name in the file)
        mags = [('low', self.low_mag)]
        if self.n_mag > 1:
            mags.append(('mid', self.mid_mag))
        if self.n_mag == 3:
            mags.append(('high', self.high_mag))

        with h5py.File(os.path.join(self.data_path, '{}.h5'.format(slide_id)),'r') as hdf5_file:
            for prefix, mag in mags:
                data[f'{prefix}_mag_feats'] = self._read_tensor(hdf5_file, f'{mag}_patches')
                data[f'{prefix}_mag_coords'] = self._read_tensor(hdf5_file, f'{mag}_coords')

            if self.use_HF or self.use_LF:
                for prefix, mag in mags:
                    filtered = hdf5_file[f'{mag}_filtered_patches'][:]
                    if self.use_HF:
                        data[f'{prefix}_mag_feats_HF'] = self._select_filtered(filtered, self.sigma_HF, 1)
                    if self.use_LF:
                        data[f'{prefix}_mag_feats_LF'] = self._select_filtered(filtered, self.sigma_LF, 0)

            # NOTE(review): the base/feat graphs and both positional encodings
            # are always looked up with the *_HF threshold / top-k, also when
            # only LF is enabled — confirm this is intended.
            if self.use_graph:
                low = self.low_mag
                self._read_graph(hdf5_file, data, 'base', f'{low}_base_adj',
                                 f'{self.graph_threshold_HF}')
                self._read_graph(hdf5_file, data, 'feat', f'{low}_feat_adj',
                                 f'{self.graph_threshold_HF}_{self.graph_top_k_HF}')
                if self.use_HF:
                    self._read_graph(hdf5_file, data, 'HF', f'{low}_feat_filtered_adj',
                                     f'{self.sigma_HF}_HPF_{self.graph_threshold_HF}_{self.graph_top_k_HF}')
                if self.use_LF:
                    self._read_graph(hdf5_file, data, 'LF', f'{low}_feat_filtered_adj',
                                     f'{self.sigma_LF}_LPF_{self.graph_threshold_LF}_{self.graph_top_k_LF}')

            if self.RWPE_length>0:
                rwpe = hdf5_file[f'{self.low_mag}_base_RWPE_{self.graph_threshold_HF}'][:]
                data['RWPE'] = torch.from_numpy(rwpe[:, :self.RWPE_length]).float()

            if self.LapPE_k>0:
                lap_pe = hdf5_file[f'{self.low_mag}_base_LapPE_{self.graph_threshold_HF}'][:]
                LapPE = torch.from_numpy(lap_pe[:, :self.LapPE_k]).float()
                if self.use_SignLapPE:
                    # for sign ambiguity: flip each column by a random sign.
                    # Size the sign vector by the columns actually present so
                    # a file with fewer than LapPE_k columns still broadcasts.
                    sign = -1 + 2 * torch.randint(0, 2, (LapPE.shape[1], ))
                    LapPE = LapPE * sign
                data['LapPE'] = LapPE
        return data
