import numpy as np
import pytorch_lightning as pl
import random
import torch
import torch.utils.data
from sklearn.model_selection import KFold, train_test_split

import kaplan_meier
import loss_function


class TorchDataset(torch.utils.data.Dataset):
    """Map-style dataset of (x, y) pairs stored as float tensors.

    Items are served through a random ordering that is drawn once, at
    construction time, with Python's ``random`` module. Position ``i`` of the
    dataset therefore does not, in general, map to row ``i`` of the inputs.
    """

    def __init__(self, x, y):
        """Convert ``x`` and ``y`` to float tensors and fix a random order."""
        features = torch.tensor(x, dtype=torch.float)
        targets = torch.tensor(y, dtype=torch.float)
        self.x = features
        self.y = targets

        # Lookup table from dataset position to underlying row, shuffled
        # in place exactly once.
        order = np.arange(len(targets))
        random.shuffle(order)
        self.indices = order

    def __len__(self):
        """Return the number of targets."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair stored at the shuffled position ``idx``."""
        row = self.indices[idx]
        sample = (self.x[row], self.y[row])
        return sample

class TorchDataset_Predict(torch.utils.data.Dataset):
    """Feature-only dataset for prediction; rows keep their input order."""

    def __init__(self, x):
        """Store ``x`` as a float tensor."""
        features = torch.tensor(x, dtype=torch.float)
        self.x = features

    def __len__(self):
        """Return the number of rows in the stored feature tensor."""
        n_rows = len(self.x)
        return n_rows

    def __getitem__(self, idx):
        """Return the feature row at position ``idx``."""
        row = self.x[idx]
        return row

class SurvivalAnalysisDataModule(pl.LightningDataModule):
    """Lightning data module that serves five-fold cross-validation splits.

    At construction time the samples of ``dataset`` are partitioned with a
    shuffled 5-fold ``KFold``. For every fold the held-out part becomes the
    test split, and the remaining samples are divided 75% / 25% into
    training and validation splits (roughly 60% / 20% / 20% overall).
    No ``random_state`` is passed to scikit-learn, so the splits depend on
    NumPy's global random state.

    The active fold is chosen with :meth:`set_fold`; every ``*_dataloader``
    method builds a fresh loader for that fold.

    Args:
        nn_param: Mapping of network/training options. Only the optional
            key ``'torch_dataset'`` is read here (default
            ``'Normal_dataset'``).
        dataset: Object exposing ``original_x`` and ``original_y``; both
            must support indexing with an integer index array
            (presumably NumPy arrays -- confirm against the caller).
    """

    # Batch size used wherever the whole split is meant to be delivered as a
    # single batch (true as long as a split has fewer samples than this).
    _FULL_BATCH_SIZE = 99999999

    def __init__(self, nn_param, dataset):
        super().__init__()

        # init
        self.nn_param = nn_param
        self.dataset = dataset
        self.fold = 0
        self.test_batch_size = 999999
        self.shuffle_train = False
        self.shuffle_val = False
        self.shuffle_test = False

        # prepare five-fold cross validation
        self.train_indices = []
        self.val_indices = []
        self.test_indices = []
        kf = KFold(n_splits=5, shuffle=True)
        split_result = kf.split(self.dataset.original_x,
                                self.dataset.original_y)
        for train_index, test_index in split_result:
            # Carve the validation split out of the non-test samples.
            fit_index, val_index = train_test_split(train_index,
                                                    test_size=0.25)
            self.train_indices.append(fit_index)
            self.val_indices.append(val_index)
            self.test_indices.append(test_index)

    def setup(self, stage=None):
        """Do nothing; all splits are prepared in ``__init__``."""
        pass

    def set_fold(self, fold):
        """Select which cross-validation fold the dataloaders serve."""
        self.fold = fold

    def set_test_batch_size(self, test_batch_size):
        """Set the batch size used by :meth:`test_dataloader`."""
        self.test_batch_size = test_batch_size

    def _labelled_dataset(self, indices):
        """Build a ``TorchDataset`` from the rows selected by ``indices``."""
        x = np.array(self.dataset.original_x[indices])
        y = np.array(self.dataset.original_y[indices])
        return TorchDataset(x, y)

    def _feature_dataset(self, indices):
        """Build a ``TorchDataset_Predict`` from the rows in ``indices``."""
        x = np.array(self.dataset.original_x[indices])
        return TorchDataset_Predict(x)

    @staticmethod
    def _loader(torch_dataset, batch_size, shuffle):
        """Wrap ``torch_dataset`` in a single-process ``DataLoader``."""
        return torch.utils.data.DataLoader(torch_dataset,
                                           num_workers=0,
                                           batch_size=batch_size,
                                           shuffle=shuffle,
                                           drop_last=False)

    def train_dataloader(self):
        """Return the training loader for the current fold.

        Raises:
            SystemExit: If ``nn_param['torch_dataset']`` names an unknown
                dataset type.
        """
        torch_dataset = self.nn_param.get('torch_dataset', 'Normal_dataset')
        if torch_dataset != 'Normal_dataset':
            # The original called sys.exit() without importing sys, which
            # raised NameError. Raising SystemExit directly terminates as
            # intended, with the message on stderr and a non-zero status.
            raise SystemExit('Unknown torch_dataset: %s' % torch_dataset)
        td = self._labelled_dataset(self.train_indices[self.fold])
        return self._loader(td, self._FULL_BATCH_SIZE, self.shuffle_train)

    def val_dataloader(self):
        """Return the validation loader for the current fold."""
        td = self._labelled_dataset(self.val_indices[self.fold])
        return self._loader(td, self._FULL_BATCH_SIZE, self.shuffle_val)

    def test_dataloader(self):
        """Return the test loader for the current fold.

        Uses ``self.test_batch_size`` rather than the full-batch size.
        """
        td = self._labelled_dataset(self.test_indices[self.fold])
        return self._loader(td, self.test_batch_size, self.shuffle_test)

    def predict_train_dataloader(self):
        """Return an unshuffled, feature-only loader over the training split."""
        td = self._feature_dataset(self.train_indices[self.fold])
        return self._loader(td, self._FULL_BATCH_SIZE, False)  # no shuffle

    def predict_dataloader(self):
        """Return an unshuffled, feature-only loader over the test split."""
        td = self._feature_dataset(self.test_indices[self.fold])
        return self._loader(td, self._FULL_BATCH_SIZE, False)  # no shuffle
