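"""PyTorch Dataset objects for rotated FashionMNIST. MnistRotated returns (x, y, d) where d is the index of x's rotation domain; MnistRotated_CF returns the same image in two rotation domains as (x1, x2, y, d1, d2)."""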

import os
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
from itertools import permutations, combinations_with_replacement
import random


# code modified from https://github.com/atuannguyen/DIRT/blob/main/domain_gen_rotatedmnist/mnist_loader.py
class MnistRotated(data_utils.Dataset):
    """Rotated FashionMNIST dataset for domain generalization.

    Every selected FashionMNIST image is replicated once per rotation domain in
    ``list_train_domains``. ``__getitem__`` returns ``(x, y, d)``: the image, its
    class label, and the index of its rotation domain in ``list_train_domains``.
    """

    def __init__(self,
                 list_train_domains,
                 root,
                 train=True,
                 mnist_subset='med',
                 transform=None,
                 download=True,
                 num_supervised=None):
        """
        :param list_train_domains: rotation angles in degrees, given as strings such as '0', '15', ..., '75'
            (any integer angle is accepted); the position of an angle in this list is its domain label
        :param root: data directory (``~`` is expanded)
        :param train: whether to load the FashionMNIST training split; for the test split mnist_subset is
            overridden to 'max' because no index files exist for it
        :param mnist_subset: 'max' - use every FashionMNIST image of the split; 'med' - use the indices stored in
            rotatedfmnist/supervised_inds_0.npy ... supervised_inds_9.npy; any other value X (e.g. 'min') - use
            the indices stored in rotatedfmnist/supervised_inds_X.npy
        :param transform: optional callable applied to each image in ``__getitem__``
        :param download: forwarded to torchvision's ``FashionMNIST`` (download the data if missing)
        :param num_supervised: if set, keep a random subset of this many selected images (the same images
            in every domain); by default all selected images are kept
        :raises ValueError: if a domain is not an integer angle, no domain is given, or num_supervised
            exceeds the number of selected images
        """
        self.list_train_domains = list_train_domains
        self.mnist_subset = mnist_subset
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.train = train
        self.download = download
        self.num_supervised = num_supervised

        self.train_data, self.train_labels, self.train_domain, self.train_angles = self._get_data()

    @staticmethod
    def _domain_to_angle(domain):
        """Parse a domain name such as '15' into an integer rotation angle in degrees.

        :raises ValueError: if ``domain`` does not represent an integer
        """
        try:
            return int(domain)
        except (TypeError, ValueError):
            raise ValueError(f"unknown rotation domain {domain!r}: expected an integer angle such as '15'") from None

    def load_inds(self):
        """Return the FashionMNIST indices that make up ``self.mnist_subset``.

        'med' concatenates rotatedfmnist/supervised_inds_0.npy ... supervised_inds_9.npy
        (in that order); any other value X loads rotatedfmnist/supervised_inds_X.npy.

        :return: numpy array of indices, always cast to int64 so it can index tensors
        """
        subset_dir = os.path.join(self.root, 'rotatedfmnist')
        if self.mnist_subset == 'med':
            parts = [np.load(os.path.join(subset_dir, f'supervised_inds_{i}.npy')) for i in range(10)]
            inds = np.concatenate(parts)
        else:
            inds = np.load(os.path.join(subset_dir, f'supervised_inds_{self.mnist_subset}.npy'))
        # Integer dtype is required: a float array cannot be used as a tensor index.
        return inds.astype(np.int64)

    def _get_data(self):
        """Load FashionMNIST, build every requested rotation domain and shuffle the result.

        :return: ``(images, labels, domains, angles)``; ``images`` is
            ``(num_domains * num_supervised, 1, H, W)``, ``labels`` and ``domains`` (int64)
            are aligned with it, and ``angles`` lists the integer angle of each domain
        """
        # Validate the domains before doing any expensive loading.
        train_angles = [self._domain_to_angle(domain) for domain in self.list_train_domains]
        if not train_angles:
            raise ValueError('list_train_domains must contain at least one rotation domain')

        if not self.train:
            self.mnist_subset = 'max'  # always use full set for test data as we don't have saved indices for MNIST test set

        dataset = datasets.FashionMNIST(self.root,
                                        train=self.train,
                                        download=self.download,
                                        transform=transforms.ToTensor())
        # One batch spanning the whole split yields every image/label as a single tensor.
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        mnist_imgs, mnist_labels = next(iter(loader))

        if self.mnist_subset != 'max':
            # Get labeled examples
            print(f'use MNIST subset {self.mnist_subset}!')
            sup_inds = torch.as_tensor(self.load_inds(), dtype=torch.long)
            mnist_labels = mnist_labels[sup_inds]
            mnist_imgs = mnist_imgs[sup_inds]
        else:
            print('use all MNIST data!')

        n_available = int(mnist_imgs.shape[0])
        if not self.num_supervised:
            self.num_supervised = n_available
        elif self.num_supervised > n_available:
            raise ValueError(f'num_supervised={self.num_supervised} exceeds the {n_available} available images')
        elif self.num_supervised < n_available:
            # Random rather than leading subset: the 'med' indices are concatenated file by file
            # (presumably one file per class), so a prefix could be class-unbalanced.
            keep = torch.as_tensor(np.sort(np.random.permutation(n_available)[:self.num_supervised]))
            mnist_imgs = mnist_imgs[keep]
            mnist_labels = mnist_labels[keep]

        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()
        height, width = mnist_imgs.shape[-2:]

        def rotate_all(angle):
            """Return every image rotated by ``angle`` degrees as an (N, H, W) tensor."""
            rotated = torch.zeros((len(mnist_imgs), height, width))
            for i, img in enumerate(mnist_imgs):
                pil_img = to_pil(img)
                if angle:  # the 0-degree domain is only round-tripped through PIL, not rotated
                    pil_img = transforms.functional.rotate(pil_img, angle)
                rotated[i] = to_tensor(pil_img).squeeze(0)
            return rotated

        # Rotate each distinct angle once, even if it is listed several times.
        rotated_by_angle = {}
        training_list_img = []
        for angle in train_angles:
            if angle not in rotated_by_angle:
                rotated_by_angle[angle] = rotate_all(angle)
            training_list_img.append(rotated_by_angle[angle])

        # Stack; before shuffling, domain i occupies rows [i * n, (i + 1) * n).
        num_domains = len(train_angles)
        train_imgs = torch.cat(training_list_img)
        train_labels = mnist_labels.repeat(num_domains)
        train_domains = torch.arange(num_domains).repeat_interleave(len(mnist_labels))

        # Shuffle everything one more time
        inds = np.arange(train_labels.size()[0])
        np.random.shuffle(inds)
        train_imgs = train_imgs[inds]
        train_labels = train_labels[inds]
        train_domains = train_domains[inds]

        return train_imgs.unsqueeze(1), train_labels, train_domains, train_angles

    def __len__(self):
        return len(self.train_labels)

    def __getitem__(self, index):
        """Return ``(image, label, domain_index)``; ``transform`` (if set) is applied to the image."""
        x = self.train_data[index]
        y = self.train_labels[index]
        d = self.train_domain[index]

        if self.transform is not None:
            x = self.transform(x)

        return x, y, d

class MnistRotated_CF(data_utils.Dataset):
    """Paired rotated FashionMNIST test set.

    Each item shows one FashionMNIST test image in two rotation domains drawn
    from ``list_train_domains``. ``__getitem__`` returns ``(x1, x2, y, d1, d2)``
    where ``d1``/``d2`` are the domains' indices in ``list_train_domains``.
    """

    def __init__(self,
                 list_train_domains,
                 root,
                 transform=None,
                 download=True,
                 n_samples=None):
        """
        :param list_train_domains: rotation angles in degrees, given as strings such as '0', '15', ..., '75'
            (any integer angle is accepted); the position of an angle in this list is its domain label
        :param root: data directory (``~`` is expanded)
        :param transform: optional callable applied to both images in ``__getitem__``
        :param download: forwarded to torchvision's ``FashionMNIST`` (download the data if missing)
        :param n_samples: if set, use a random subset of this many FashionMNIST test images;
            by default all test images are used
        :raises ValueError: if a domain is not an integer angle, no domain is given, or n_samples
            exceeds the number of test images
        """
        self.list_train_domains = list_train_domains
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.download = download
        self.n_samples = n_samples

        self.train_data1, self.train_data2, self.train_label, self.train_domain1, self.train_domain2 = self._get_data()

    @staticmethod
    def _domain_to_angle(domain):
        """Parse a domain name such as '15' into an integer rotation angle in degrees.

        :raises ValueError: if ``domain`` does not represent an integer
        """
        try:
            return int(domain)
        except (TypeError, ValueError):
            raise ValueError(f"unknown rotation domain {domain!r}: expected an integer angle such as '15'") from None

    def _get_data(self):
        """Load the FashionMNIST test split and pair every image with two random rotation domains.

        Randomness comes from both Python's ``random`` (domain pairs) and numpy's global RNG
        (subsampling and the final shuffle); seed both for reproducibility.

        :return: ``(images1, images2, labels, domains1, domains2)``; images are ``(N, 1, H, W)``,
            domains are int64 indices into ``list_train_domains``
        """
        # Validate the domains before doing any expensive loading.
        angles = {domain: self._domain_to_angle(domain) for domain in self.list_train_domains}
        if not angles:
            raise ValueError('list_train_domains must contain at least one rotation domain')

        dataset = datasets.FashionMNIST(self.root,
                                        train=False,
                                        download=self.download,
                                        transform=transforms.ToTensor())
        # One batch spanning the whole split yields every image/label as a single tensor.
        loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
        mnist_imgs, mnist_labels = next(iter(loader))

        n_available = int(mnist_imgs.shape[0])
        if not self.n_samples:
            self.n_samples = n_available
        elif self.n_samples > n_available:
            raise ValueError(f'n_samples={self.n_samples} exceeds the {n_available} available images')
        elif self.n_samples < n_available:
            keep = torch.as_tensor(np.sort(np.random.permutation(n_available)[:self.n_samples]))
            mnist_imgs = mnist_imgs[keep]
            mnist_labels = mnist_labels[keep]

        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        def render(img, angle):
            """Return ``img`` rotated by ``angle`` degrees as a (1, H, W) tensor."""
            pil_img = to_pil(img)
            if angle:  # the 0-degree domain is only round-tripped through PIL, not rotated
                pil_img = transforms.functional.rotate(pil_img, angle)
            return to_tensor(pil_img)

        domain_to_domainlabel = {domain: i for i, domain in enumerate(self.list_train_domains)}

        # Ordered domain pairs to sample from. NOTE: with distinct domains, every (d, d) pair appears
        # twice in this list (once per combinations_with_replacement call) while every (a, b) pair
        # with a != b appears once, so each same-domain pair is drawn twice as often as a mixed pair.
        all_domain_combs = list(combinations_with_replacement(self.list_train_domains, 2)) + \
            list(combinations_with_replacement(reversed(self.list_train_domains), 2))

        training_list_img1, training_list_img2 = [], []
        training_list_domains1, training_list_domains2 = [], []
        # Rotate lazily: only the two domains actually drawn for each image are rendered.
        for img in mnist_imgs:
            d1, d2 = random.choice(all_domain_combs)
            x1 = render(img, angles[d1])
            x2 = x1 if angles[d2] == angles[d1] else render(img, angles[d2])
            training_list_img1.append(x1)
            training_list_img2.append(x2)
            training_list_domains1.append(domain_to_domainlabel[d1])
            training_list_domains2.append(domain_to_domainlabel[d2])

        train_img1 = torch.cat(training_list_img1)
        train_img2 = torch.cat(training_list_img2)
        train_domain1 = torch.tensor(training_list_domains1, dtype=torch.long)
        train_domain2 = torch.tensor(training_list_domains2, dtype=torch.long)
        train_label = mnist_labels

        # Shuffle everything one more time
        inds = np.arange(train_label.size()[0])
        np.random.shuffle(inds)
        train_img1 = train_img1[inds]
        train_img2 = train_img2[inds]
        train_label = train_label[inds]
        train_domain1 = train_domain1[inds]
        train_domain2 = train_domain2[inds]

        return train_img1.unsqueeze(1), train_img2.unsqueeze(1), train_label, train_domain1, train_domain2

    def __len__(self):
        return len(self.train_label)

    def __getitem__(self, index):
        """Return ``(image1, image2, label, domain1, domain2)``; ``transform`` (if set) is applied to both images."""
        x1 = self.train_data1[index]
        x2 = self.train_data2[index]
        y = self.train_label[index]
        d1 = self.train_domain1[index]
        d2 = self.train_domain2[index]

        if self.transform is not None:
            x1 = self.transform(x1)
            x2 = self.transform(x2)

        return x1, x2, y, d1, d2