import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import pandas as pd
import numpy as np
import os
import pickle as pkl
import cv2
from sklearn.model_selection import train_test_split
import nibabel as nib
import json
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from scipy.ndimage import zoom


class MIMICCXRDataset(Dataset):
    """MIMIC-CXR image survival dataset.

    Reads ``mimiccxr_metadata.csv`` from ``hparams['data_dir']`` and splits it
    60/20/20 into train/val/test by ``subject_id``. Every row of a subject lands
    in the same split. It then applies the categorical encodings from
    ``var_encoding.json`` and binarises ``age`` (``age > 60``) and ``race``
    (``race == 0``).

    Args:
        hparams: Mapping with at least ``'data_dir'`` and ``'sensitive_attribute'``
            (name of the metadata column returned as the 4th element of a sample).
        data_split: ``'train'``, ``'val'`` or ``'test'``. Any other value keeps
            every row of the metadata.
        transform: Optional callable, invoked as ``transform(hparams, img)``.
    """

    def __init__(self, hparams, data_split, transform=None):
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], 'mimiccxr_metadata.csv'))
        # Split at the subject level so no subject contributes rows to two splits.
        subjects = np.unique(self.mtdt['subject_id'])
        train_idx, test_val_idx = train_test_split(subjects, test_size=0.4, random_state=42)
        val_idx, test_idx = train_test_split(test_val_idx, test_size=0.5, random_state=42)
        split_ids = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_ids:
            self.mtdt = self.mtdt[self.mtdt['subject_id'].isin(split_ids[self.data_split])].copy()

        with open(os.path.join(self.hparams['data_dir'], 'var_encoding.json'), 'r', encoding='utf-8') as f:
            self.var_encoding = json.load(f)
        # Replace raw categorical values with their codes. Series.map turns values
        # absent from the mapping into NaN.
        for field, mapping in self.var_encoding.items():
            self.mtdt[field] = self.mtdt[field].map(mapping)
        # Binarise the attributes that may be used as sensitive attributes.
        self.mtdt['age'] = (self.mtdt['age'] > 60).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 0).astype(int)

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``.

        The train split returns the discretised labels written by
        :meth:`discretize_label`. Other splits return the raw ``indicator`` and
        ``time_to_event`` columns. The image is resized to 224x224 before
        ``transform`` is applied.

        Raises:
            KeyError: on the train split if :meth:`discretize_label` was not called.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            if 'indicator_discrete' not in item.index:
                raise KeyError("'indicator_discrete' is missing: call discretize_label() "
                               "on the train split before indexing it")
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img_path = os.path.join(self.hparams['data_dir'], 'images', item['dicom_id'] + '.jpg')
        img = cv2.imread(img_path)
        # cv2.imread reports a missing or unreadable file by returning None, not by raising.
        if img is None:
            raise FileNotFoundError(f'Image loading failed: {img_path}.')
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a discrete-time label transform on the train split and store its output.

        Only acts on the train split. There it fits a pycox ``LabTransDiscreteTime``
        on ``time_to_event`` / ``indicator`` and writes the results to
        ``time_to_event_discrete`` / ``indicator_discrete``. Other splits are left
        untouched.

        Args:
            num_time_steps: Number of quantile-based bins. ``0`` instead uses every
                distinct time at which an event (``indicator == 1``) was observed.
            label_transform: Unused; the transform is always fitted afresh. Kept for
                interface compatibility.

        Returns:
            The fitted transform on the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            # One cut per distinct time at which an event was observed.
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts. This may yield fewer bins than requested when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')

        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform


class ADNIDataset(Dataset):
    """ADNI volumetric-image survival dataset (volumes loaded with nibabel).

    Reads ``adni_metadata.csv`` from ``hparams['data_dir']`` and splits it 60/20/20
    into train/val/test by ``subject_id``. It then applies the categorical encodings
    from ``var_encoding.json`` and binarises ``age`` (``age > 80``) and ``race``
    (``race == 0``).

    Args:
        hparams: Mapping with at least ``'data_dir'`` and ``'sensitive_attribute'``
            (name of the metadata column returned as the 4th element of a sample).
        data_split: ``'train'``, ``'val'`` or ``'test'``. Any other value keeps
            every row of the metadata.
        transform: Optional callable, invoked as ``transform(hparams, img)``.
    """

    def __init__(self, hparams, data_split, transform=None):
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], 'adni_metadata.csv'))
        # Split at the subject level so no subject contributes rows to two splits.
        subjects = np.unique(self.mtdt['subject_id'])
        train_idx, test_val_idx = train_test_split(subjects, test_size=0.4, random_state=42)
        # NOTE: unlike the other datasets in this module, the first half returned here
        # becomes *test* and the second half *val*. Kept as-is so existing splits are
        # reproduced exactly.
        test_idx, val_idx = train_test_split(test_val_idx, test_size=0.5, random_state=42)
        split_ids = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_ids:
            self.mtdt = self.mtdt[self.mtdt['subject_id'].isin(split_ids[self.data_split])].copy()

        with open(os.path.join(self.hparams['data_dir'], 'var_encoding.json'), 'r', encoding='utf-8') as f:
            self.var_encoding = json.load(f)
        # Replace raw categorical values with their codes. Series.map turns values
        # absent from the mapping into NaN.
        for field, mapping in self.var_encoding.items():
            self.mtdt[field] = self.mtdt[field].map(mapping)
        # Binarise the attributes that may be used as sensitive attributes.
        self.mtdt['age'] = (self.mtdt['age'] > 80).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 0).astype(int)

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``.

        The train split returns the discretised labels written by
        :meth:`discretize_label`. Other splits return the raw ``indicator`` and
        ``time_to_event`` columns. The volume at ``data_dir/img_path`` is resampled
        with ``scipy.ndimage.zoom`` to 96x96x64 before ``transform`` is applied.

        Raises:
            KeyError: on the train split if :meth:`discretize_label` was not called.
        """
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            if 'indicator_discrete' not in item.index:
                raise KeyError("'indicator_discrete' is missing: call discretize_label() "
                               "on the train split before indexing it")
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img = nib.load(os.path.join(self.hparams['data_dir'], item['img_path'])).get_fdata()
        # get_fdata() already yields an ndarray; asarray avoids the extra copy np.array made.
        img = np.asarray(img)
        # Resample to a fixed grid. The three zoom factors assume a 3-D volume — TODO confirm.
        img = zoom(img, (96 / img.shape[0], 96 / img.shape[1], 64 / img.shape[2]))
        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a discrete-time label transform on the train split and store its output.

        Only acts on the train split. There it fits a pycox ``LabTransDiscreteTime``
        on ``time_to_event`` / ``indicator`` and writes the results to
        ``time_to_event_discrete`` / ``indicator_discrete``. Other splits are left
        untouched.

        Args:
            num_time_steps: Number of quantile-based bins. ``0`` instead uses every
                distinct time at which an event (``indicator == 1``) was observed.
            label_transform: Unused; the transform is always fitted afresh. Kept for
                interface compatibility.

        Returns:
            The fitted transform on the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            # One cut per distinct time at which an event was observed.
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts. This may yield fewer bins than requested when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')

        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform


class AREDSDataset(Dataset):
    """AREDS image survival dataset.

    Reads ``areds_metadata_new.csv`` from ``hparams['data_dir']``, with ``pid``
    read as a string. It splits the data 60/20/20 into train/val/test by ``pid``,
    applies the categorical encodings from ``var_encoding.json``, and binarises
    ``age`` (``age > 70``) and ``race`` (``race == 1``).

    Args:
        hparams: Mapping with at least ``'data_dir'`` and ``'sensitive_attribute'``
            (name of the metadata column returned as the 4th element of a sample).
        data_split: ``'train'``, ``'val'`` or ``'test'``. Any other value keeps
            every row of the metadata.
        transform: Optional callable, invoked as ``transform(hparams, img)``.
    """

    def __init__(self, hparams, data_split, transform=None):
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], 'areds_metadata_new.csv'), dtype={'pid': str})
        # Split at the patient level so no patient contributes rows to two splits.
        # NOTE: this dataset uses random_state=41, unlike the 42 used elsewhere in this module.
        patients = self.mtdt['pid'].unique()
        train_idx, test_val_idx = train_test_split(patients, test_size=0.4, random_state=41)
        val_idx, test_idx = train_test_split(test_val_idx, test_size=0.5, random_state=41)
        split_ids = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_ids:
            self.mtdt = self.mtdt[self.mtdt['pid'].isin(split_ids[self.data_split])].copy()

        with open(os.path.join(self.hparams['data_dir'], 'var_encoding.json'), 'r', encoding='utf-8') as f:
            self.var_encoding = json.load(f)
        # Replace raw categorical values with their codes. Series.map turns values
        # absent from the mapping into NaN.
        for field, mapping in self.var_encoding.items():
            self.mtdt[field] = self.mtdt[field].map(mapping)
        # Binarise the attributes that may be used as sensitive attributes.
        self.mtdt['age'] = (self.mtdt['age'] > 70).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 1).astype(int)

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``.

        The train split returns the discretised labels written by
        :meth:`discretize_label`. Other splits return the raw ``indicator`` and
        ``time_to_event`` columns. ``img_file`` is used as given (it is not joined
        with ``data_dir``). The image is resized to 224x224 before ``transform``
        is applied.

        Raises:
            KeyError: on the train split if :meth:`discretize_label` was not called.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            if 'indicator_discrete' not in item.index:
                raise KeyError("'indicator_discrete' is missing: call discretize_label() "
                               "on the train split before indexing it")
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img = cv2.imread(item['img_file'])
        # cv2.imread returns None on failure. Raise explicitly, because an assert
        # would be stripped under `python -O`.
        if img is None:
            raise FileNotFoundError(f'Image loading failed: {item["img_file"]}.')
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a discrete-time label transform on the train split and store its output.

        Only acts on the train split. There it fits a pycox ``LabTransDiscreteTime``
        on ``time_to_event`` / ``indicator`` and writes the results to
        ``time_to_event_discrete`` / ``indicator_discrete``. Other splits are left
        untouched.

        Args:
            num_time_steps: Number of quantile-based bins. ``0`` instead uses every
                distinct time at which an event (``indicator == 1``) was observed.
            label_transform: Unused; the transform is always fitted afresh. Kept for
                interface compatibility.

        Returns:
            The fitted transform on the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            # One cut per distinct time at which an event was observed.
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts. This may yield fewer bins than requested when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')

        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform


class UKBiobankDataset(Dataset):
    """Base UK Biobank image survival dataset; subclasses choose the metadata CSV.

    Reads ``meta_filename`` from ``hparams['data_dir']`` and splits it 60/20/20
    into train/val/test by ``eid``. It then binarises ``age`` (``age > 60``) and
    ``race`` (``race == 0``). Unlike the other datasets in this module, no
    ``var_encoding.json`` mapping is applied.

    Args:
        hparams: Mapping with at least ``'data_dir'`` and ``'sensitive_attribute'``
            (name of the metadata column returned as the 4th element of a sample).
        data_split: ``'train'``, ``'val'`` or ``'test'``. Any other value keeps
            every row of the metadata.
        meta_filename: CSV file name, relative to ``hparams['data_dir']``.
        transform: Optional callable, invoked as ``transform(hparams, img)``.
    """

    def __init__(self, hparams, data_split, meta_filename, transform=None):
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], meta_filename))
        # Split at the participant level so no eid contributes rows to two splits.
        eids = np.unique(self.mtdt['eid'])
        train_idx, test_val_idx = train_test_split(eids, test_size=0.4, random_state=42)
        val_idx, test_idx = train_test_split(test_val_idx, test_size=0.5, random_state=42)
        split_ids = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_ids:
            self.mtdt = self.mtdt[self.mtdt['eid'].isin(split_ids[self.data_split])].copy()
        # Binarise the attributes that may be used as sensitive attributes.
        self.mtdt['age'] = (self.mtdt['age'] > 60).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 0).astype(int)

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``.

        The train split returns the discretised labels written by
        :meth:`discretize_label`. Other splits return the raw ``indicator`` and
        ``time_to_event`` columns. The image at ``data_dir/file`` is resized to
        224x224 before ``transform`` is applied.

        Raises:
            KeyError: on the train split if :meth:`discretize_label` was not called.
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            if 'indicator_discrete' not in item.index:
                raise KeyError("'indicator_discrete' is missing: call discretize_label() "
                               "on the train split before indexing it")
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img_path = os.path.join(self.hparams['data_dir'], item['file'])
        img = cv2.imread(img_path)
        # cv2.imread reports a missing or unreadable file by returning None, not by raising.
        if img is None:
            raise FileNotFoundError(f'Image loading failed: {img_path}.')
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a discrete-time label transform on the train split and store its output.

        Only acts on the train split. There it fits a pycox ``LabTransDiscreteTime``
        on ``time_to_event`` / ``indicator`` and writes the results to
        ``time_to_event_discrete`` / ``indicator_discrete``. Other splits are left
        untouched.

        Args:
            num_time_steps: Number of quantile-based bins. ``0`` instead uses every
                distinct time at which an event (``indicator == 1``) was observed.
            label_transform: Unused; the transform is always fitted afresh. Kept for
                interface compatibility.

        Returns:
            The fitted transform on the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            # One cut per distinct time at which an event was observed.
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts. This may yield fewer bins than requested when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')

        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform
    
class UKBiobankMIDataset(UKBiobankDataset):
    """UK Biobank cohort whose metadata comes from ``mi_metadata.csv``."""

    def __init__(self, hparams, data_split, transform=None):
        super().__init__(hparams, data_split, meta_filename='mi_metadata.csv', transform=transform)

class UKBiobankPDDataset(UKBiobankDataset):
    """UK Biobank cohort whose metadata comes from ``pd_metadata.csv``."""

    def __init__(self, hparams, data_split, transform=None):
        super().__init__(hparams, data_split, meta_filename='pd_metadata.csv', transform=transform)

class UKBiobankStrokeDataset(UKBiobankDataset):
    """UK Biobank cohort whose metadata comes from ``stroke_metadata.csv``."""

    def __init__(self, hparams, data_split, transform=None):
        super().__init__(hparams, data_split, meta_filename='stroke_metadata.csv', transform=transform)


if __name__ == '__main__':
    from tqdm import tqdm

    # Smoke test: build the MIMIC-CXR validation split and print its first sample.
    smoke_hparams = {
        'dataset': 'mimiccxr',
        'data_dir': 'datasets/mimiccxr',
        'sensitive_attribute': 'sex',
    }
    dataset = MIMICCXRDataset(hparams=smoke_hparams, data_split='val')
    for sample in tqdm(dataset):
        img, indicator, time_to_event, sensitive_attribute = sample
        print(img.shape, indicator, time_to_event, sensitive_attribute)
        break