"""Pytorch Dataset object that loads MNIST and SVHN. It returns x,y,s where s=0 when x,y is taken from MNIST."""

import os
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from itertools import permutations, combinations_with_replacement
import random
import h5py

from sklearn import preprocessing


class Shape3D(data_utils.Dataset):
    """3D Shapes dataset yielding ``(image, label, domain)`` triples.

    The label is the encoded object hue (label column 2) and the domain is
    the object shape (label column 4).
    """

    def __init__(self,
                 list_train_domains,
                 root,
                 train=True,
                 transform=None,
                 download=True):
        """
        :param list_train_domains: domains observed during training; stored on
            the instance but not used to filter the samples in this class
        :param root: data directory; must contain ``3dshape/3dshapes.h5``,
            ``3dshape/images_3.npy`` and the ``train_idx3.npy`` /
            ``test_idx3.npy`` index files
        :param train: if True load the rows listed in ``train_idx3.npy``,
            otherwise those listed in ``test_idx3.npy``
        :param transform: optional callable applied to every image in
            ``__getitem__``
        :param download: stored on the instance; nothing is downloaded here
        """

        self.list_train_domains = list_train_domains
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train
        self.download = download

        self.data, self.label, self.domain = self._get_data()

    def _get_data(self):
        """Load the selected split and return ``(images, labels, domains)``.

        :return: float image tensor, ``LongTensor`` of encoded object hues and
            ``LongTensor`` of shape ids, all aligned on the first dimension
        """
        # Read the label table fully into memory and close the file right
        # away: this avoids leaking the HDF5 handle and lifts h5py's
        # "increasing indices only" restriction on fancy indexing.
        with h5py.File(f'{self.root}/3dshape/3dshapes.h5', 'r') as dataset:
            all_labels = dataset['labels'][:]  # array shape [480000,6], float64
        images = np.load(f'{self.root}/3dshape/images_3.npy')

        split_file = 'train_idx3.npy' if self.train else 'test_idx3.npy'
        split_idx = np.load(f'{self.root}/3dshape/{split_file}')
        images = images[split_idx]
        all_labels = all_labels[split_idx]

        # use shape as domains
        domain = torch.as_tensor(all_labels[:, 4]).long()

        # use object hue as labels
        # NOTE(review): the encoder is fitted on the selected split only, so
        # class ids match between train and test only if both splits contain
        # the same set of hue values -- confirm against the index files.
        le = preprocessing.LabelEncoder()
        label = le.fit_transform(all_labels[:, 2])

        return torch.Tensor(images), torch.as_tensor(label).long(), domain

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return len(self.domain)

    def __getitem__(self, index):
        """Return ``(x, y, d)``: the (optionally transformed) image, its
        label and its domain."""
        x = self.data[index]
        y = self.label[index]
        d = self.domain[index]

        if self.transform is not None:
            x = self.transform(x)
        return x, y, d


class Shape3D_CF(data_utils.Dataset):
    """3D Shapes dataset yielding image pairs ``(x1, x2, y, d1, d2)``.

    For every position ``i`` a random ordered pair of domains ``(d1, d2)`` is
    drawn and the ``i``-th image of each of the two domains is returned,
    together with the label of the first image and the two domain indices.
    """

    def __init__(self,
                 list_train_domains,
                 root,
                 transform=None,
                 download=True):
        """
        :param list_train_domains: domains (shape ids) to draw pairs from; the
            position of a domain in this list is the domain index returned as
            ``d1`` / ``d2``
        :param root: data directory; must contain ``3dshape/3dshapes.h5``,
            ``3dshape/images_3.npy`` and the ``test_idx3.npy`` index file
        :param transform: optional callable applied to both images in
            ``__getitem__``
        :param download: stored on the instance; nothing is downloaded here
        """

        self.list_train_domains = list_train_domains
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.download = download
        # This class always reads the test split.
        self.train = False

        self.data1, self.data2, self.label, self.domain1, self.domain2 = self._get_data()

    def _get_data(self):
        """Load the selected split from disk and build the paired samples.

        :return: ``(images1, images2, labels, domains1, domains2)``
        """
        # Read the label table fully into memory and close the file right
        # away: this avoids leaking the HDF5 handle and lifts h5py's
        # "increasing indices only" restriction on fancy indexing.
        with h5py.File(f'{self.root}/3dshape/3dshapes.h5', 'r') as dataset:
            all_labels = dataset['labels'][:]  # array shape [480000,6], float64
        images = np.load(f'{self.root}/3dshape/images_3.npy')

        split_file = 'train_idx3.npy' if self.train else 'test_idx3.npy'
        split_idx = np.load(f'{self.root}/3dshape/{split_file}')
        # Select the split before converting so only that part is copied.
        images = torch.Tensor(images[split_idx])
        all_labels = all_labels[split_idx]

        # use shape as domains
        domain_og = torch.as_tensor(all_labels[:, 4]).long()

        # use object hue as labels
        le = preprocessing.LabelEncoder()
        label_og = torch.as_tensor(le.fit_transform(all_labels[:, 2])).long()

        return self._build_pairs(images, label_og, domain_og)

    def _build_pairs(self, images, labels, domains):
        """Pair up same-position samples from randomly chosen domains.

        :param images: image tensor, first dimension indexes samples
        :param labels: ``LongTensor`` of labels aligned with ``images``
        :param domains: ``LongTensor`` of domain ids aligned with ``images``
        :return: ``(images1, images2, labels, domains1, domains2)``, shuffled
            consistently; ``labels`` belongs to ``images1``
        :raises ValueError: if no domain is given or a domain has no samples
        """
        train_domains = list(self.list_train_domains)
        if not train_domains:
            raise ValueError('list_train_domains must contain at least one domain')

        masks = {dom: domains == dom for dom in train_domains}
        imgs_by_domain = {dom: images[mask] for dom, mask in masks.items()}
        labels_by_domain = {dom: labels[mask] for dom, mask in masks.items()}

        # Use the smallest domain so position i exists in every domain.
        self.n_samples = min(len(imgs) for imgs in imgs_by_domain.values())
        if self.n_samples == 0:
            raise ValueError('every domain in list_train_domains needs at least one sample')

        domain_to_domainlabel = {dom: i for i, dom in enumerate(train_domains)}

        # Candidate ordered pairs. Same-domain pairs appear in both halves and
        # are therefore drawn twice as often as mixed pairs (kept as is).
        all_domain_combs = list(combinations_with_replacement(train_domains, 2)) + \
                           list(combinations_with_replacement(reversed(train_domains), 2))
        pairs = [random.choice(all_domain_combs) for _ in range(self.n_samples)]

        train_img1 = torch.stack([imgs_by_domain[d1][i] for i, (d1, _) in enumerate(pairs)])
        train_img2 = torch.stack([imgs_by_domain[d2][i] for i, (_, d2) in enumerate(pairs)])
        # Take the label from the same domain subset and position as the first
        # image, so that label and image stay aligned.
        train_label = torch.stack([labels_by_domain[d1][i] for i, (d1, _) in enumerate(pairs)])
        train_domain1 = torch.tensor([domain_to_domainlabel[d1] for d1, _ in pairs],
                                     dtype=torch.long)
        train_domain2 = torch.tensor([domain_to_domainlabel[d2] for _, d2 in pairs],
                                     dtype=torch.long)

        # Shuffle everything one more time
        inds = np.arange(self.n_samples)
        np.random.shuffle(inds)

        return (train_img1[inds], train_img2[inds], train_label[inds],
                train_domain1[inds], train_domain2[inds])

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.label)

    def __getitem__(self, index):
        """Return ``(x1, x2, y, d1, d2)`` for the pair at ``index``; the
        transform, if any, is applied to both images."""
        x1 = self.data1[index]
        x2 = self.data2[index]
        y = self.label[index]
        d1 = self.domain1[index]
        d2 = self.domain2[index]

        if self.transform is not None:
            x1 = self.transform(x1)
            x2 = self.transform(x2)

        return x1, x2, y, d1, d2