import json
import numpy as np
import random
import sklearn.preprocessing as sp
import sys
import torch
from scipy.stats import wasserstein_distance
from sklearn.model_selection import KFold, train_test_split

import loss_function
import read_data


class TorchDataset(torch.utils.data.Dataset):
    """Supervised dataset holding features and targets as float32 tensors.

    Both ``x`` and ``y`` are converted with ``torch.tensor`` (which copies
    the data) at construction time. Indexing returns the pair
    ``(x[idx], y[idx])`` and the length is taken from ``y``.
    """

    def __init__(self, x, y):
        # Convert features first, then targets, both to 32-bit floats.
        self.x, self.y = (
            torch.tensor(values, dtype=torch.float32) for values in (x, y)
        )

    def __len__(self):
        n_targets = len(self.y)
        return n_targets

    def __getitem__(self, idx):
        features = self.x[idx]
        target = self.y[idx]
        return features, target

class TorchDataset_Predict(torch.utils.data.Dataset):
    """Feature-only dataset used for prediction.

    ``x`` is copied into a float32 tensor; indexing returns ``x[idx]``
    alone (no target) and the length is the size of ``x``'s first dimension.
    """

    def __init__(self, x):
        features = torch.tensor(x, dtype=torch.float32)
        self.x = features

    def __len__(self):
        n_rows = len(self.x)
        return n_rows

    def __getitem__(self, idx):
        row = self.x[idx]
        return row

class DataModule(object):
    """Load a dataset, scale its features and serve 5-fold CV data loaders.

    Construction reads the dataset selected by ``args.dataset_name`` through
    the matching function in ``read_data``, min-max scales the features,
    and draws a shuffled 5-fold split. Within each fold the non-test
    indices are split again 75/25 into training and validation indices,
    i.e. roughly 60% / 20% / 20% train / val / test of the whole dataset.

    Attributes read from ``args``: ``dataset_name``, ``loss_function``,
    ``batch_size`` and (only for the ``ProperRPS_teacher_pwl`` loss)
    ``num_bin``.

    Note: neither ``KFold`` nor ``train_test_split`` is given a
    ``random_state`` here, so the split depends on the random state that
    is current at construction time.
    """

    # Names of the ``read_data`` loader functions that may be selected.
    _DATASET_NAMES = ('flchain', 'prostateSurvival', 'support')

    # Batch size large enough to deliver an evaluation split in one batch.
    _FULL_BATCH_SIZE = 99999999

    def __init__(self, args):
        """Read the dataset, preprocess it and prepare the fold indices.

        Prints a message and terminates the process with exit status 1
        (``SystemExit``) when ``args.dataset_name`` is not a known dataset.
        """
        self.args = args
        self.dataset_name = args.dataset_name

        # read dataset
        if args.dataset_name not in self._DATASET_NAMES:
            print('Unknown dataset name: %s' % args.dataset_name)
            # Non-zero status so that callers/shells see the failure.
            sys.exit(1)
        # Each whitelisted name is also the name of its loader function.
        loader = getattr(read_data, args.dataset_name)
        self.original_x, self.original_y = loader()

        # transform dataframes into numpy arrays
        self.__normalize_x()
        self.__set_y()

        # data split for five-fold cross validation
        self.train_indices = []
        self.val_indices = []
        self.test_indices = []
        self.fold = -1
        kf = KFold(n_splits=5, shuffle=True)
        for train_index, test_index in kf.split(self.original_x,
                                                self.original_y):
            # Hold out a quarter of the non-test indices for validation.
            it, iv = train_test_split(train_index, test_size=0.25)
            self.train_indices.append(it)
            self.val_indices.append(iv)
            self.test_indices.append(test_index)
        self.set_fold(0)

    def __normalize_x(self):
        """Store the features in ``self.x``, min-max scaled per column.

        The scaler is fitted on all rows of ``self.original_x``.
        """
        x = self.original_x.values
        self.x = sp.MinMaxScaler().fit_transform(x)

    def __set_y(self):
        """Store the targets as an array and record ``self.y_max``."""
        # self.y_max must be slightly larger than actual max
        self.y_max = self.original_y.iloc[:, 0].max() * 1.001
        self.y = self.original_y.values

    def set_fold(self, fold):
        """Select the active cross-validation fold.

        Does nothing when ``fold`` is already active. Otherwise sets
        ``self.fold`` and ``self.index_list`` (a one-element list holding
        the fold's training indices). For the ``ProperRPS_teacher_pwl``
        loss, ``self.y`` is rebuilt as the original targets followed by
        ``args.num_bin + 1`` extra columns filled from the teacher
        prediction file ``prediction/<dataset>_logarithmic_<fold>.json``.
        """
        if self.fold == fold:
            return
        self.fold = fold

        # set self.index_list
        self.index_list = [self.train_indices[fold]]

        if self.args.loss_function == 'ProperRPS_teacher_pwl':
            base_y = self.original_y.values
            zeros = np.zeros((self.y.shape[0], self.args.num_bin + 1))
            y = np.concatenate([base_y, zeros], 1)
            filename = 'prediction/%s' % self.dataset_name
            filename += '_logarithmic_%d.json' % self.fold
            print('Reading '+filename)
            with open(filename, encoding='utf-8') as f:
                pred = json.load(f)
            # The file maps a row index (as a string key) to a list of
            # values; their running sum is stored from column 3 onwards.
            # NOTE(review): the offset 3 presumably equals the number of
            # columns in original_y -- TODO confirm against read_data.
            for key, value in pred.items():
                y[int(key), 3:] = np.cumsum(np.array(value))
            self.y = y

    def train_dataloader(self):
        """Return a list with one shuffled loader per ``index_list`` entry."""
        return [
            torch.utils.data.DataLoader(
                TorchDataset(self.x[index], self.y[index]),
                batch_size=self.args.batch_size,
                shuffle=True)
            for index in self.index_list
        ]

    def val_dataloader(self):
        """Return a one-element list with the validation loader (unshuffled)."""
        index = self.val_indices[self.fold]
        td = TorchDataset(self.x[index], self.y[index])
        return [torch.utils.data.DataLoader(td,
                                            batch_size=self.args.batch_size,
                                            shuffle=False)]

    def test_dataloader(self):
        """Return a one-element list with the test loader (unshuffled)."""
        index = self.test_indices[self.fold]
        td = TorchDataset(self.x[index], self.y[index])
        return [torch.utils.data.DataLoader(td,
                                            batch_size=self._FULL_BATCH_SIZE,
                                            shuffle=False)]

    def predict_dataloader(self, index=None):
        """Return a one-element list with a feature-only loader.

        ``index`` selects the rows of ``self.x`` to serve; when it is
        ``None`` the test indices of the active fold are used.
        """
        if index is None:
            index = self.test_indices[self.fold]

        td = TorchDataset_Predict(self.x[index])
        return [torch.utils.data.DataLoader(td,
                                            batch_size=self._FULL_BATCH_SIZE,
                                            shuffle=False)]
