import numpy as np
import scipy.io as sio
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import Dataset


class MultiViewDataset(Dataset):
    """In-memory multi-view classification dataset.

    On construction every view is min-max scaled with ``normalize`` (default
    range [0, 1]). Labels are squeezed, shifted down by one when their minimum
    is exactly 1, and cast to ``int64``.

    Args:
        data_name: Identifier kept on the instance as ``data_name``.
        data_X: Sequence of per-view 2-D arrays. Axis 0 is the sample axis
            (``__len__`` and ``__getitem__`` index it); axis 1 is features.
        data_Y: Label array; squeezed before use.
        names: Optional extra metadata, stored as-is on ``names``.
    """

    def __init__(self, data_name, data_X, data_Y, names=None):
        super().__init__()
        self.names = names
        self.data_name = data_name

        self.num_views = len(data_X)
        # Views are keyed by integer position (0 .. num_views-1).
        self.X = {v: self.normalize(view) for v, view in enumerate(data_X)}

        labels = np.squeeze(data_Y)
        # Shift labels that start at 1 down so they start at 0.
        if np.min(labels) == 1:
            labels = labels - 1
        self.Y = labels.astype(dtype=np.int64)

        self.num_classes = len(np.unique(self.Y))
        self.dims = self.get_dims()
        self.nums = self.get_nums()

    def __getitem__(self, index):
        """Return ``(views, label, index)``; ``views`` maps view id -> float32 row."""
        views = {v: self.X[v][index].astype(np.float32) for v in range(len(self.X))}
        return views, self.Y[index], index

    def __len__(self):
        """Number of samples, taken from the first view."""
        return len(self.X[0])

    def get_dims(self):
        """Return per-view feature counts as an array shaped ``(num_views, 1)``."""
        return np.array([[self.X[v].shape[1]] for v in range(self.num_views)])

    def get_nums(self):
        """Return the sample count of each class id in ``range(num_classes)``.

        NOTE(review): ``num_classes`` is the number of *distinct* labels, so
        if label ids are not contiguous from 0, the highest ids are never
        counted here -- confirm labels are always 0..K-1.
        """
        return np.array([np.sum(self.Y == c) for c in range(self.num_classes)])

    @staticmethod
    def normalize(x, min=0):
        """Min-max scale ``x`` per column to (0, 1), or to (-1, 1) if ``min`` != 0.

        The parameter name ``min`` shadows the builtin but is kept so existing
        keyword callers keep working.
        """
        feature_range = (0, 1) if min == 0 else (-1, 1)
        return MinMaxScaler(feature_range).fit_transform(x)


def Hand_train(data_path="Tailed Data/LT_train_handwritten0.mat", num_views=6):
    """Load the long-tailed handwritten training split as a ``MultiViewDataset``.

    Args:
        data_path: Path to the ``.mat`` file. The default is the original
            relative path, resolved against the current working directory.
        num_views: How many leading entries of the ``'X'`` cell array to
            read. Defaults to 6, the count the loader originally hard-coded.

    Returns:
        A ``MultiViewDataset`` named ``"Hand_train"`` with labels from ``'gt'``.

    Raises:
        IndexError: if the file holds fewer than ``num_views`` views.
    """
    data = sio.loadmat(data_path)
    # Row 0 of the loaded 'X' object array holds one matrix per view.
    views = data['X'][0]
    # Transpose each view so rows become samples (MultiViewDataset indexes
    # samples along axis 0); presumably the file stores (features, samples).
    data_X = [views[v].T for v in range(num_views)]
    data_Y = data['gt']
    return MultiViewDataset("Hand_train", data_X, data_Y)

def Hand_test(data_path="Tailed Data/LT_test_handwritten0.mat", num_views=6):
    """Load the long-tailed handwritten test split as a ``MultiViewDataset``.

    Args:
        data_path: Path to the ``.mat`` file. The default is the original
            relative path, resolved against the current working directory.
        num_views: How many leading entries of the ``'X'`` cell array to
            read. Defaults to 6, the count the loader originally hard-coded.

    Returns:
        A ``MultiViewDataset`` named ``"Hand_test"`` with labels from ``'gt'``.

    Raises:
        IndexError: if the file holds fewer than ``num_views`` views.
    """
    data = sio.loadmat(data_path)
    # Row 0 of the loaded 'X' object array holds one matrix per view.
    views = data['X'][0]
    # Transpose each view so rows become samples (MultiViewDataset indexes
    # samples along axis 0); presumably the file stores (features, samples).
    data_X = [views[v].T for v in range(num_views)]
    data_Y = data['gt']
    return MultiViewDataset("Hand_test", data_X, data_Y)