import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, GaussianBlur
from torchvision.transforms.functional import adjust_sharpness
import pandas as pd
import numpy as np
import os
import pickle as pkl
import cv2
from sklearn.model_selection import train_test_split
import nibabel as nib
import json
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from scipy.ndimage import zoom
from PIL import Image
from scipy.ndimage import gaussian_filter


class MIMICCXRDatasetShift(Dataset):
    """MIMIC-CXR survival dataset with an optional synthetic, group-specific shift.

    Rows come from ``<data_dir>/mimiccxr_metadata.csv``. They are split 60/20/20
    into train/val/test by ``subject_id`` (``random_state=42``), so a subject's rows
    never cross splits. Any other ``data_split`` value keeps every row. The
    filtered frame keeps the CSV's original row labels as its index.

    In the train split ``__getitem__`` returns the ``*_discrete`` label columns.
    Those columns only exist after :meth:`discretize_label` has been called.

    Args:
        hparams: Mapping that must provide ``'data_dir'`` and
            ``'sensitive_attribute'`` (one of ``'age'``, ``'sex'``, ``'race'``).
        data_split: ``'train'``, ``'val'`` or ``'test'``.
        transform: Optional callable invoked as ``transform(hparams, img)``.
        shift: Shift injected into the *train* split only:

            - ``'x'`` blurs the images of group members.
            - ``'d'`` sets the indicator to 0 for the first 90% (row order) of the
              group's rows with ``indicator == 1``.
            - ``'y'`` adds integer noise in [-30, 30] to the discretised time and
              clips the result to [0, 126].
            - ``None`` disables shifting.
        group: Value of the sensitive attribute whose members are shifted. Age and
            race are binarised to 0/1 here.
    """

    # 'y' shift noise is drawn uniformly from the integers in [low, high).
    _Y_NOISE_LOW = -30
    _Y_NOISE_HIGH = 31
    # Noisy discretised times are clipped to [0, _MAX_TIME_BIN].
    _MAX_TIME_BIN = 126
    # Fraction of the group's rows with indicator == 1 that the 'd' shift flips.
    _D_FRACTION = 0.9

    def __init__(self, hparams, data_split, transform=None, shift=None, group=None):
        super().__init__()
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform
        self.shift = shift
        # With sigma this large, the 11x11 kernel is close to a uniform box blur.
        self.blur = GaussianBlur(kernel_size=11, sigma=100)
        self.group = group

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], 'mimiccxr_metadata.csv'))
        # Split at the subject level so no subject appears in more than one split.
        subjects = np.unique(self.mtdt['subject_id'])
        train_idx, test_val_idx = train_test_split(subjects, test_size=0.4, random_state=42)
        val_idx, test_idx = train_test_split(test_val_idx, test_size=0.5, random_state=42)
        split_subjects = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_subjects:
            # The index is not reset, so it keeps the CSV row labels.
            self.mtdt = self.mtdt[self.mtdt['subject_id'].isin(split_subjects[self.data_split])].copy()

        with open(os.path.join(self.hparams['data_dir'], 'var_encoding.json'), 'r') as f:
            self.var_encoding = json.load(f)
        for field, encoding in self.var_encoding.items():
            self.mtdt[field] = self.mtdt[field].map(encoding)
        # Binarise the sensitive attributes (applied after the categorical encoding).
        self.mtdt['age'] = (self.mtdt['age'] > 60).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 0).astype(int)

        self._build_shift_indices()

    def _build_shift_indices(self):
        """Precompute the rows affected by the 'd' and 'y' label shifts.

        Populates ``self.d_idx``, ``self.y_idx`` and ``self.y_noise``. Each is a dict
        keyed by ``'age'``, ``'sex'`` and ``'race'``. The lists hold ``self.mtdt``
        index *labels*, not positions.

        Also builds the O(1) lookup tables ``_d_lookup`` and ``_y_lookup`` that
        :meth:`_apply_label_shift` uses.
        """
        self.d_idx, self.y_idx, self.y_noise = {}, {}, {}
        for attr in ('age', 'sex', 'race'):
            in_group = self.mtdt[attr] == self.group
            event_rows = self.mtdt[in_group & (self.mtdt['indicator'] == 1)].index.to_list()
            self.d_idx[attr] = event_rows[:int(len(event_rows) * self._D_FRACTION)]
            self.y_idx[attr] = self.mtdt[in_group].index.to_list()
            # Drawn from NumPy's global RNG (not seeded here); seed it upstream for
            # reproducible noise.
            self.y_noise[attr] = np.random.randint(
                self._Y_NOISE_LOW, self._Y_NOISE_HIGH, size=len(self.y_idx[attr]))
        self._d_lookup = {attr: set(rows) for attr, rows in self.d_idx.items()}
        self._y_lookup = {attr: dict(zip(self.y_idx[attr], self.y_noise[attr])) for attr in self.y_idx}

    def _apply_label_shift(self, idx, indicator, time_to_event):
        """Apply the 'd' or 'y' train-label shift to the sample at position ``idx``.

        ``idx`` is the positional index that ``__getitem__`` receives. It is mapped
        to the ``self.mtdt`` index label, because the shift tables store labels and
        the filtered frame's labels are not ``0..len-1``.

        Returns:
            The (possibly modified) ``(indicator, time_to_event)`` pair.
        """
        if self.data_split != 'train' or self.shift not in ('d', 'y'):
            return indicator, time_to_event
        attr = self.hparams['sensitive_attribute']
        row_label = self.mtdt.index[idx]
        if self.shift == 'd':
            if row_label in self._d_lookup[attr]:
                indicator = 0
        elif row_label in self._y_lookup[attr]:
            time_to_event += self._y_lookup[attr][row_label]
            time_to_event = min(self._MAX_TIME_BIN, max(0, time_to_event))
        return indicator, time_to_event

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``."""
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img_path = os.path.join(self.hparams['data_dir'], 'images', item['dicom_id'] + '.jpg')
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread signals failure by returning None, not raising
            raise FileNotFoundError(f'Image loading failed: {img_path}')
        if self.shift == 'x' and self.data_split == 'train' and sensitive_attribute == self.group:
            img = np.array(self.blur(Image.fromarray(img)))
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
        indicator, time_to_event = self._apply_label_shift(idx, indicator, time_to_event)
        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a pycox ``LabTransDiscreteTime`` on the train split's labels.

        Adds ``indicator_discrete`` and ``time_to_event_discrete`` columns to
        ``self.mtdt``. Does nothing for non-train splits.

        Args:
            num_time_steps: Number of quantile-based time bins. ``0`` instead
                passes the unique times of rows with ``indicator == 1`` as the cuts.
            label_transform: Currently unused; a new transform is always fitted.

        Returns:
            The fitted transform for the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts may yield fewer than ``num_time_steps`` bins when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')
        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform


class AREDSDatasetShift(Dataset):
    """AREDS survival dataset with an optional synthetic, group-specific shift.

    Rows come from ``<data_dir>/areds_metadata.csv``, with ``pid`` read as a string.
    They are split 60/20/20 into train/val/test by ``pid``, so a patient's rows
    never cross splits. The split seed is 41 when the sensitive attribute is
    ``'race'`` and 42 otherwise. Any other ``data_split`` value keeps every row.
    The filtered frame keeps the CSV's original row labels as its index.

    In the train split ``__getitem__`` returns the ``*_discrete`` label columns.
    Those columns only exist after :meth:`discretize_label` has been called.

    Args:
        hparams: Mapping that must provide ``'data_dir'`` and
            ``'sensitive_attribute'`` (one of ``'age'``, ``'sex'``, ``'race'``).
        data_split: ``'train'``, ``'val'`` or ``'test'``.
        transform: Optional callable invoked as ``transform(hparams, img)``.
        shift: Shift injected into the *train* split only:

            - ``'x'`` blurs the images of group members.
            - ``'d'`` sets the indicator to 0 for the first 90% (row order) of the
              group's rows with ``indicator == 1``.
            - ``'y'`` adds integer noise in [-3, 3] to the discretised time and
              clips the result to [0, 12].
            - ``None`` disables shifting.
        group: Value of the sensitive attribute whose members are shifted. Age and
            race are binarised to 0/1 here.
    """

    # 'y' shift noise is drawn uniformly from the integers in [low, high).
    _Y_NOISE_LOW = -3
    _Y_NOISE_HIGH = 4
    # Noisy discretised times are clipped to [0, _MAX_TIME_BIN].
    _MAX_TIME_BIN = 12
    # Fraction of the group's rows with indicator == 1 that the 'd' shift flips.
    _D_FRACTION = 0.9

    def __init__(self, hparams, data_split, transform=None, shift=None, group=None):
        super().__init__()
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform
        self.shift = shift
        # With sigma this large, the 7x7 kernel is close to a uniform box blur.
        self.blur = GaussianBlur(kernel_size=7, sigma=100)
        self.group = group

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], 'areds_metadata.csv'), dtype={'pid': str})
        # Split at the patient level so no patient appears in more than one split.
        patients = self.mtdt['pid'].unique()
        seed = 41 if hparams['sensitive_attribute'] == 'race' else 42
        train_idx, test_val_idx = train_test_split(patients, test_size=0.4, random_state=seed)
        val_idx, test_idx = train_test_split(test_val_idx, test_size=0.5, random_state=seed)
        split_patients = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_patients:
            # The index is not reset, so it keeps the CSV row labels.
            self.mtdt = self.mtdt[self.mtdt['pid'].isin(split_patients[self.data_split])].copy()

        with open(os.path.join(self.hparams['data_dir'], 'var_encoding.json'), 'r') as f:
            self.var_encoding = json.load(f)
        for field, encoding in self.var_encoding.items():
            self.mtdt[field] = self.mtdt[field].map(encoding)
        # Binarise the sensitive attributes (applied after the categorical encoding).
        self.mtdt['age'] = (self.mtdt['age'] > 70).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 1).astype(int)

        self._build_shift_indices()

    def _build_shift_indices(self):
        """Precompute the rows affected by the 'd' and 'y' label shifts.

        Populates ``self.d_idx``, ``self.y_idx`` and ``self.y_noise``. Each is a dict
        keyed by ``'age'``, ``'sex'`` and ``'race'``. The lists hold ``self.mtdt``
        index *labels*, not positions.

        Also builds the O(1) lookup tables ``_d_lookup`` and ``_y_lookup`` that
        :meth:`_apply_label_shift` uses.
        """
        self.d_idx, self.y_idx, self.y_noise = {}, {}, {}
        for attr in ('age', 'sex', 'race'):
            in_group = self.mtdt[attr] == self.group
            event_rows = self.mtdt[in_group & (self.mtdt['indicator'] == 1)].index.to_list()
            self.d_idx[attr] = event_rows[:int(len(event_rows) * self._D_FRACTION)]
            self.y_idx[attr] = self.mtdt[in_group].index.to_list()
            # Drawn from NumPy's global RNG (not seeded here); seed it upstream for
            # reproducible noise.
            self.y_noise[attr] = np.random.randint(
                self._Y_NOISE_LOW, self._Y_NOISE_HIGH, size=len(self.y_idx[attr]))
        self._d_lookup = {attr: set(rows) for attr, rows in self.d_idx.items()}
        self._y_lookup = {attr: dict(zip(self.y_idx[attr], self.y_noise[attr])) for attr in self.y_idx}

    def _apply_label_shift(self, idx, indicator, time_to_event):
        """Apply the 'd' or 'y' train-label shift to the sample at position ``idx``.

        ``idx`` is the positional index that ``__getitem__`` receives. It is mapped
        to the ``self.mtdt`` index label, because the shift tables store labels and
        the filtered frame's labels are not ``0..len-1``.

        Returns:
            The (possibly modified) ``(indicator, time_to_event)`` pair.
        """
        if self.data_split != 'train' or self.shift not in ('d', 'y'):
            return indicator, time_to_event
        attr = self.hparams['sensitive_attribute']
        row_label = self.mtdt.index[idx]
        if self.shift == 'd':
            if row_label in self._d_lookup[attr]:
                indicator = 0
        elif row_label in self._y_lookup[attr]:
            time_to_event += self._y_lookup[attr][row_label]
            time_to_event = min(self._MAX_TIME_BIN, max(0, time_to_event))
        return indicator, time_to_event

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``."""
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img = cv2.imread(item['img_file'])
        # cv2.imread returns None on failure. Raise explicitly, because an assert is
        # stripped under ``python -O``.
        if img is None:
            raise FileNotFoundError(f'Image loading failed: {item["img_file"]}')
        if self.shift == 'x' and self.data_split == 'train' and sensitive_attribute == self.group:
            img = np.array(self.blur(Image.fromarray(img)))
        img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_LINEAR)
        indicator, time_to_event = self._apply_label_shift(idx, indicator, time_to_event)
        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a pycox ``LabTransDiscreteTime`` on the train split's labels.

        Adds ``indicator_discrete`` and ``time_to_event_discrete`` columns to
        ``self.mtdt``. Does nothing for non-train splits.

        Args:
            num_time_steps: Number of quantile-based time bins. ``0`` instead
                passes the unique times of rows with ``indicator == 1`` as the cuts.
            label_transform: Currently unused; a new transform is always fitted.

        Returns:
            The fitted transform for the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts may yield fewer than ``num_time_steps`` bins when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')
        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform


class ADNIDatasetShift(Dataset):
    """ADNI volumetric survival dataset with an optional synthetic, group-specific shift.

    Rows come from ``<data_dir>/adni_metadata.csv``. They are split 60/20/20 into
    train/test/val by ``subject_id`` (``random_state=42``), so a subject's rows never
    cross splits. Any other ``data_split`` value keeps every row. The filtered
    frame keeps the CSV's original row labels as its index.

    In the train split ``__getitem__`` returns the ``*_discrete`` label columns.
    Those columns only exist after :meth:`discretize_label` has been called.

    Args:
        hparams: Mapping that must provide ``'data_dir'`` and
            ``'sensitive_attribute'`` (one of ``'age'``, ``'sex'``, ``'race'``).
        data_split: ``'train'``, ``'val'`` or ``'test'``.
        transform: Optional callable invoked as ``transform(hparams, img)``.
        shift: Shift injected into the *train* split only:

            - ``'x'`` smooths the volumes of group members.
            - ``'d'`` sets the indicator to 0 for the first 90% (row order) of the
              group's rows with ``indicator == 1``.
            - ``'y'`` adds integer noise in [-7, 7] to the discretised time and
              clips the result to [0, 27].
            - ``None`` disables shifting.
        group: Value of the sensitive attribute whose members are shifted. Age and
            race are binarised to 0/1 here. The shift tables are built for every
            group value; previously they were only built for ``group == 1``, and
            the 'd'/'y' shifts crashed for other groups.
    """

    # 'y' shift noise is drawn uniformly from the integers in [low, high).
    _Y_NOISE_LOW = -7
    _Y_NOISE_HIGH = 8
    # Noisy discretised times are clipped to [0, _MAX_TIME_BIN].
    _MAX_TIME_BIN = 27
    # Fraction of the group's rows with indicator == 1 that the 'd' shift flips.
    _D_FRACTION = 0.9

    def __init__(self, hparams, data_split, transform=None, shift=None, group=None):
        super().__init__()
        self.hparams = hparams
        self.data_split = data_split
        self.transform = transform
        self.shift = shift
        self.group = group

        self.mtdt = pd.read_csv(os.path.join(self.hparams['data_dir'], 'adni_metadata.csv'))
        # Split at the subject level so no subject appears in more than one split.
        subjects = np.unique(self.mtdt['subject_id'])
        train_idx, test_val_idx = train_test_split(subjects, test_size=0.4, random_state=42)
        # NOTE: unlike the sibling datasets, the first half goes to *test*. This is
        # kept as-is so the existing val/test membership does not change.
        test_idx, val_idx = train_test_split(test_val_idx, test_size=0.5, random_state=42)
        split_subjects = {'train': train_idx, 'val': val_idx, 'test': test_idx}
        if self.data_split in split_subjects:
            # The index is not reset, so it keeps the CSV row labels.
            self.mtdt = self.mtdt[self.mtdt['subject_id'].isin(split_subjects[self.data_split])].copy()

        with open(os.path.join(self.hparams['data_dir'], 'var_encoding.json'), 'r') as f:
            self.var_encoding = json.load(f)
        for field, encoding in self.var_encoding.items():
            self.mtdt[field] = self.mtdt[field].map(encoding)
        # Binarise the sensitive attributes (applied after the categorical encoding).
        self.mtdt['age'] = (self.mtdt['age'] > 80).astype(int)
        self.mtdt['race'] = (self.mtdt['race'] == 0).astype(int)

        self._build_shift_indices()

    def _build_shift_indices(self):
        """Precompute the rows affected by the 'd' and 'y' label shifts.

        Populates ``self.d_idx``, ``self.y_idx`` and ``self.y_noise``. Each is a dict
        keyed by ``'age'``, ``'sex'`` and ``'race'``. The lists hold ``self.mtdt``
        index *labels*, not positions.

        Also builds the O(1) lookup tables ``_d_lookup`` and ``_y_lookup`` that
        :meth:`_apply_label_shift` uses.
        """
        self.d_idx, self.y_idx, self.y_noise = {}, {}, {}
        for attr in ('age', 'sex', 'race'):
            in_group = self.mtdt[attr] == self.group
            event_rows = self.mtdt[in_group & (self.mtdt['indicator'] == 1)].index.to_list()
            self.d_idx[attr] = event_rows[:int(len(event_rows) * self._D_FRACTION)]
            self.y_idx[attr] = self.mtdt[in_group].index.to_list()
            # Drawn from NumPy's global RNG (not seeded here); seed it upstream for
            # reproducible noise.
            self.y_noise[attr] = np.random.randint(
                self._Y_NOISE_LOW, self._Y_NOISE_HIGH, size=len(self.y_idx[attr]))
        self._d_lookup = {attr: set(rows) for attr, rows in self.d_idx.items()}
        self._y_lookup = {attr: dict(zip(self.y_idx[attr], self.y_noise[attr])) for attr in self.y_idx}

    def _apply_label_shift(self, idx, indicator, time_to_event):
        """Apply the 'd' or 'y' train-label shift to the sample at position ``idx``.

        ``idx`` is the positional index that ``__getitem__`` receives. It is mapped
        to the ``self.mtdt`` index label, because the shift tables store labels and
        the filtered frame's labels are not ``0..len-1``.

        Returns:
            The (possibly modified) ``(indicator, time_to_event)`` pair.
        """
        if self.data_split != 'train' or self.shift not in ('d', 'y'):
            return indicator, time_to_event
        attr = self.hparams['sensitive_attribute']
        row_label = self.mtdt.index[idx]
        if self.shift == 'd':
            if row_label in self._d_lookup[attr]:
                indicator = 0
        elif row_label in self._y_lookup[attr]:
            time_to_event += self._y_lookup[attr][row_label]
            time_to_event = min(self._MAX_TIME_BIN, max(0, time_to_event))
        return indicator, time_to_event

    def __len__(self):
        return len(self.mtdt)

    def __getitem__(self, idx):
        """Return ``(img, indicator, time_to_event, sensitive_attribute)`` for row ``idx``."""
        item = self.mtdt.iloc[idx]
        sensitive_attribute = item[self.hparams['sensitive_attribute']]
        if self.data_split == 'train':
            indicator = item['indicator_discrete']
            time_to_event = item['time_to_event_discrete']
        else:
            indicator = item['indicator']
            time_to_event = item['time_to_event']

        img = np.asarray(nib.load(os.path.join(self.hparams['data_dir'], item['img_path'])).get_fdata())
        # Resample the volume to a fixed 96 x 96 x 64 grid.
        img = zoom(img, (96 / img.shape[0], 96 / img.shape[1], 64 / img.shape[2]))

        if self.shift == 'x' and self.data_split == 'train' and sensitive_attribute == self.group:
            # radius=2 truncates the kernel to 5 voxels per axis. With sigma=100 the
            # weights are almost equal, so this is approximately a box blur.
            img = gaussian_filter(img, sigma=100, radius=2)
        indicator, time_to_event = self._apply_label_shift(idx, indicator, time_to_event)

        if self.transform:
            img = self.transform(self.hparams, img)
        return img, indicator, time_to_event, sensitive_attribute

    def discretize_label(self, num_time_steps=128, label_transform=None):
        """Fit a pycox ``LabTransDiscreteTime`` on the train split's labels.

        Adds ``indicator_discrete`` and ``time_to_event_discrete`` columns to
        ``self.mtdt``. Does nothing for non-train splits.

        Args:
            num_time_steps: Number of quantile-based time bins. ``0`` instead
                passes the unique times of rows with ``indicator == 1`` as the cuts.
            label_transform: Currently unused; a new transform is always fitted.

        Returns:
            The fitted transform for the train split, otherwise ``None``.
        """
        if self.data_split != 'train':
            return None
        times = self.mtdt['time_to_event'].to_numpy()
        events = self.mtdt['indicator'].to_numpy()
        if num_time_steps == 0:
            label_transform = LabTransDiscreteTime(np.unique(times[events == 1]))
        else:
            # Quantile-based cuts may yield fewer than ``num_time_steps`` bins when
            # many observed times are duplicated.
            label_transform = LabTransDiscreteTime(num_time_steps, scheme='quantiles')
        times_discrete, events_discrete = label_transform.fit_transform(times, events)
        self.mtdt['indicator_discrete'] = events_discrete
        self.mtdt['time_to_event_discrete'] = times_discrete
        return label_transform