from sklearn.preprocessing import MinMaxScaler
import numpy as np
from torch.utils.data import Dataset
import scipy.io
import torch

class BDGP_Caltech7(Dataset):
    """Five-view dataset read from a MATLAB ``.mat`` file.

    The file must contain the keys ``X1`` .. ``X5`` (one matrix per view,
    indexed by sample along the first axis) and ``Y`` (labels, stored
    transposed). Each view is cast to float32 and passed through
    ``MinMaxScaler.fit_transform``.
    """

    # Keys of the per-view matrices inside the .mat file, in view order.
    _VIEW_KEYS = ('X1', 'X2', 'X3', 'X4', 'X5')

    def __init__(self, path):
        """Load and scale all views and the labels from ``path``.

        Args:
            path: location of the ``.mat`` file handed to ``scipy.io.loadmat``.
        """
        # Parse the file once; views and labels both come from this dict.
        data = scipy.io.loadmat(path)
        scaler = MinMaxScaler()

        # fit_transform refits the scaler on every call, so a single
        # instance can be reused while each view is still scaled on its own.
        self.views = [
            scaler.fit_transform(data[key].astype(np.float32))
            for key in self._VIEW_KEYS
        ]
        # Per-view attributes are kept for callers that access them directly.
        (self.view1, self.view2, self.view3,
         self.view4, self.view5) = self.views

        self.labels = data['Y'].T

    def __len__(self):
        """Return the number of samples actually loaded."""
        return len(self.view1)

    def __getitem__(self, idx):
        """Return ``(views, label, index)`` for sample ``idx``.

        Returns:
            A tuple of: a list with one tensor per view, the label tensor
            built from ``labels[idx]``, and ``idx`` as a long tensor.
        """
        sample_views = [torch.from_numpy(view[idx]) for view in self.views]
        label = torch.from_numpy(self.labels[idx])
        index = torch.from_numpy(np.array(idx)).long()
        return sample_views, label, index

    def get_view(self, idx):
        """Return the full array of view ``idx``; ``idx == -1`` returns the labels."""
        if idx == -1:
            return self.labels
        return self.views[idx]

def load_data(dataset):
    """Load a named dataset and wrap it as view-wise shuffled data.

    Args:
        dataset: name of the dataset; only ``"Caltech7"`` is supported.

    Returns:
        A tuple ``(dataset, dims, view, data_size, class_num, ts)`` where
        ``dataset`` and ``ts`` are the two values returned by
        ``Form_Unaligned_Data`` and the remaining entries are the constants
        hard-coded below for the selected dataset.

    Raises:
        NotImplementedError: if ``dataset`` is not a supported name.
    """
    if dataset != "Caltech7":
        # Name the rejected value so a typo is easy to spot.
        raise NotImplementedError(
            "Unsupported dataset: {!r}".format(dataset))

    multi_view = BDGP_Caltech7('./data/Caltech-5V-7.mat')
    # Constants for the Caltech-5V-7 file; presumably the per-view feature
    # sizes, view count, sample count and class count — verify against data.
    dims = [40, 254, 1984, 512, 928]
    view = 5
    data_size = 1400
    class_num = 7

    unaligned, ts = Form_Unaligned_Data(multi_view, view)
    return unaligned, dims, view, data_size, class_num, ts

import random
def Form_Unaligned_Data(dataset, view):
    """Shuffle every view with its own random permutation.

    Args:
        dataset: object exposing ``get_view(i)`` (array of view ``i``) and
            ``get_view(-1)`` (label array); both must support NumPy
            integer-array indexing along the first axis.
        view: number of views to read from ``dataset``.

    Returns:
        A tuple ``(result, ts)``: ``result`` is a ``GN_Dataset`` holding the
        shuffled views together with the identically shuffled labels of each
        view, and ``ts[v]`` is the permutation applied to view ``v``, i.e.
        ``shuffled_v[k] == original_v[ts[v][k]]``.
    """
    views = [dataset.get_view(i) for i in range(view)]
    labels = dataset.get_view(-1)

    size = len(labels)
    t = np.linspace(0, size - 1, size, dtype=int)

    # Kept so the sequence of draws from `random` (and therefore the result
    # under a fixed seed) matches the previous implementation.
    random.shuffle(t)

    X = []
    Y = []
    ts = []
    for v in range(view):
        random.shuffle(t)
        # Snapshot the permutation: `t` is shuffled in place again on the
        # next iteration, so storing `t` itself would leave every entry of
        # `ts` aliasing the last view's permutation.
        perm = t.copy()
        ts.append(perm)
        # Integer-array indexing already returns new arrays.
        X.append(views[v][perm])
        Y.append(labels[perm])

    result = GN_Dataset(X, Y, view)
    return result, ts


class GN_Dataset(Dataset):
    """Multi-view dataset holding one sample array and one label array per view."""

    def __init__(self, dataset, y, view_num):
        """Store the per-view arrays.

        Args:
            dataset: sequence of per-view sample arrays.
            y: sequence of per-view label arrays.
            view_num: number of views read by ``__getitem__``.
        """
        self.view_num = view_num
        self.labels = y
        self.dataset = dataset

    def __len__(self):
        """Return the length of the first view's label array."""
        first_view_labels = self.labels[0]
        return len(first_view_labels)

    def __getitem__(self, index):
        """Return ``(X, Y, index)`` with one sample and one label tensor per view."""
        # Build (sample, label) per view so both are created in view order.
        pairs = [
            (
                torch.from_numpy(
                    self.dataset[v][index].astype(np.float32).transpose()
                ),
                torch.tensor(self.labels[v][index]),
            )
            for v in range(self.view_num)
        ]
        samples = [sample for sample, _ in pairs]
        targets = [target for _, target in pairs]
        return samples, targets, index