import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
import glob
import PIL


# Multiview Dataset
class ViewDataset(Dataset):
    """Paired two-view dataset yielding ``(view1_sample, view2_sample, index)``.

    Both views are copied into tensors and given an extra singleton axis at
    position 1 (presumably a channel axis for convolutional encoders -- confirm
    against the model code).
    """

    def __init__(self, v1, v2):
        """Build the dataset from two array-likes aligned sample-for-sample.

        Args:
            v1: First view; anything ``torch.tensor`` accepts (numpy array,
                nested list, ...). Sample ``i`` is ``v1[i]``.
            v2: Second view; sample ``i`` must correspond to ``v1[i]``.

        Raises:
            ValueError: if ``v1`` and ``v2`` hold different numbers of samples.
        """
        # torch.tensor always copies, so later edits to the caller's arrays
        # do not leak into the dataset.
        self.v1 = torch.tensor(v1).unsqueeze(1)
        self.v2 = torch.tensor(v2).unsqueeze(1)
        # Measure the tensor rather than ``v1.shape`` so plain Python
        # sequences (which have no ``.shape``) are accepted as well.
        self.data_len = len(self.v1)
        if len(self.v2) != self.data_len:
            # Without this check extra v2 samples are silently ignored, or
            # indexing fails mid-epoch when v2 is the shorter view.
            raise ValueError(
                f"views have different lengths: {self.data_len} vs {len(self.v2)}"
            )

    def __getitem__(self, index):
        """Return the paired samples at ``index`` together with the index itself."""
        return self.v1[index], self.v2[index], index

    def __len__(self):
        """Return the number of paired samples."""
        return self.data_len


# Get a dataloader
def get_dataloader(view1, view2, batchsize, shuffle):
    """Wrap two aligned views in a ``ViewDataset`` and return a DataLoader over it.

    Args:
        view1: First view, passed unchanged to ``ViewDataset``.
        view2: Second view, passed unchanged to ``ViewDataset``.
        batchsize: Number of samples per batch.
        shuffle: Whether the DataLoader reshuffles the samples each epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the paired dataset.
    """
    return torch.utils.data.DataLoader(
        dataset=ViewDataset(view1, view2),
        batch_size=batchsize,
        shuffle=shuffle,
    )


# Get the two-view MNIST dataset
def _mat_to_images(flat, side=28):
    """Turn an ``(n, side*side)`` matrix of flattened images into ``(n, side, side)``.

    A C-order reshape followed by swapping the last two axes equals reading
    each row in column-major order. Presumably the images were flattened in
    MATLAB (column-major); confirm against the data source.
    """
    images = np.reshape(flat, (flat.shape[0], side, side))
    return np.transpose(images, (0, 2, 1))


def get_mnist(filedir='./'):
    """Load the two-view MNIST splits from ``MNIST.mat`` inside ``filedir``.

    The .mat file must contain flattened 28x28 images under ``X1``/``X2``
    (train), ``XV1``/``XV2`` (validation) and ``XTe1``/``XTe2`` (test), plus
    the label arrays ``trainLabel``, ``tuneLabel`` and ``testLabel``.

    Args:
        filedir: Directory containing ``MNIST.mat``. It may be given with or
            without a trailing path separator.

    Returns:
        Tuple ``(view1, view2, view1_valid, view2_valid, view1_test,
        view2_test, labels)``. Each view is an array of shape
        ``(n_split, 28, 28)``. ``labels`` maps ``'train'``, ``'valid'`` and
        ``'test'`` to 1-D label arrays.
    """
    import os  # local import keeps this block self-contained

    # os.path.join tolerates a missing trailing separator, unlike plain
    # string concatenation ('data' + 'MNIST.mat' -> 'dataMNIST.mat').
    data = sio.loadmat(os.path.join(filedir, 'MNIST.mat'))

    view1 = _mat_to_images(data['X1'])
    view2 = _mat_to_images(data['X2'])
    view1_valid = _mat_to_images(data['XV1'])
    view2_valid = _mat_to_images(data['XV2'])
    view1_test = _mat_to_images(data['XTe1'])
    view2_test = _mat_to_images(data['XTe2'])

    # loadmat returns 2-D arrays even for vectors, so flatten the labels to 1-D.
    labels = {
        'train': data['trainLabel'].flatten(),
        'valid': data['tuneLabel'].flatten(),
        'test': data['testLabel'].flatten(),
    }

    return view1, view2, view1_valid, view2_valid, view1_test, view2_test, labels
