# Adapted from: https://github.com/intel-isl/MultiObjectiveOptimization/blob/master/multi_task/loaders/multi_mnist_loader.py

import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import numpy as np
import torch
import codecs


def to_one_hot(x, size):
    """
    Converts a 1-D tensor of indexes `x` into a float one-hot matrix.

    Row i of the result has a single 1.0 in column x[i] and zeros elsewhere.
    :param x: 1-D tensor of indexes in ``[0, size)``.
    :param size: number of columns (classes) of the resulting tensor.
    :return: float32 tensor of shape ``(x.size(0), size)`` on the same device as `x`.
    """
    x_one_hot = torch.zeros(x.size(0), size, dtype=torch.float, device=x.device)
    # scatter_ works in place; the buffer is float from the start because chaining
    # ``.float()`` onto scatter_ (as done before) discards the converted tensor.
    x_one_hot.scatter_(1, x.unsqueeze(-1).long(), 1.0)

    return x_one_hot


from math import sqrt
from itertools import count, islice


def is_prime(x):
    """
    Element-wise primality test by trial division.

    Accepts a tensor (returns a bool tensor of the same shape) or a plain int (returns a
    bool). Divisors are tried up to floor(sqrt(max(x))), so the answer is exact for any
    integer-valued input; very large values simply make the loop longer.

    :param x: integer-valued tensor or int.
    :return: True where `x` is prime, False elsewhere (including every value below 2).
    """
    prime = x >= 2
    values = torch.as_tensor(x)
    largest = int(values.max()) if values.numel() > 0 else 0
    divisor = 2
    while divisor * divisor <= largest:
        # ``divisor >= x`` stops a small prime entry from being marked composite when the
        # divisor reaches the entry's own value (e.g. 3 % 3 == 0).
        prime &= (divisor >= x) | (x % divisor != 0)
        divisor += 1
    return prime


def foo(x):
    """
    Flags entries of `x` that are prime or whose repeated decimal digit sums hit a prime.

    Starting from `x`, the decimal digit sum of every entry is taken repeatedly while some
    entry still has two or more digits. An entry is flagged if it, or any digit sum
    produced along the way, is prime according to ``is_prime``.

    :param x: integer-valued tensor (converted to int32). Presumably non-negative — the
        dataset passes products of two digit labels; confirm for other callers.
    :return: bool tensor with the same shape as `x`.
    """
    x = x.int()
    result = is_prime(x)
    # Empty input: nothing to reduce (``max()`` would raise on an empty tensor).
    while x.numel() > 0 and x.max() > 9:
        digit_sum = torch.zeros_like(x)
        remaining = x
        while remaining.max() > 0:
            digit_sum += remaining % 10
            remaining = remaining // 10
        # Single-digit entries are fixed points of the digit sum, so re-testing is harmless.
        x = digit_sum
        result |= is_prime(x)
    return result



class MNIST(data.Dataset):
    """
    Multi-MNIST dataset: every sample overlays two MNIST digits into a single image.

    ``__getitem__`` returns the flattened float image and a list with one target per task.
    ``tasks`` maps each task name to ``[target index, loss/metric spec]``; ``self.target``
    holds the targets in that index order.
    """
    tasks = {
        'left': [0, ['nll', 'acc']],
        'right': [1, ['nll', 'acc']],
        'left2c': [2, ['nll', 'acc']],
        'right2c': [3, ['nll', 'acc']],
        'sum': [4, 'mse'],
        'multiply': [5, 'mse'],
        'density': [6, 'mse'],
        'product_prime': [7, ['bce', 'f1']],
        'number': [8, ['mse', 'mse']],
        'divide': [9, ['mse', 'mse']],
        'bigger_than': [10, ['bce', 'f1']],
        'binary_and': [11, ['nll', 'acc']],
        'odd': [12, ['bce', 'f1']]
    }

    # NOTE(review): direct downloads from yann.lecun.com have been unreliable; if they fail,
    # point these at a mirror (confirm before changing).
    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    multi_training_file = 'multi_training.pt'
    multi_test_file = 'multi_test.pt'

    def __init__(self, root, tag, transform=None):
        """
        :param root: directory holding the ``raw`` and ``processed`` sub-folders; the data is
            downloaded and processed there on first use.
        :param tag: ``'train'`` (first 50000 multi-MNIST training pairs), ``'val'`` (remaining
            training pairs) or ``'test'`` (test pairs). Any other value keeps the whole
            training set.
        :param transform: stored as ``self.transform``; ``__getitem__`` does not apply it.
        """
        self.root = str(root)
        self.transform = transform
        self.train = tag != 'test'

        self.download()

        filename = self.multi_training_file if self.train else self.multi_test_file
        self.data, self.labels_l, self.labels_r = torch.load(
            os.path.join(self.root, self.processed_folder, filename))

        split = 50000

        if tag == 'train':
            self.data, self.labels_l, self.labels_r = self.data[:split], self.labels_l[:split], self.labels_r[:split]
        elif tag == 'val':
            self.data, self.labels_l, self.labels_r = self.data[split:], self.labels_l[split:], self.labels_r[split:]

        self.input_size = self.data.flatten(start_dim=1).size(-1)
        labels_l, labels_r = self.labels_l, self.labels_r
        product = labels_l * labels_r
        self.target = (
            labels_l.float(),                                            # 0 left
            labels_r.float(),                                            # 1 right
            (torch.abs(10 - labels_l) % 10).float(),                     # 2 left2c: (10 - d) mod 10
            (torch.abs(10 - labels_r) % 10).float(),                     # 3 right2c
            torch.unsqueeze(labels_l + labels_r, dim=-1).float(),        # 4 sum
            torch.unsqueeze(product, dim=-1).float(),                    # 5 multiply
            # 6 density: fraction of pixels whose value exceeds 0.5
            (self.data > 0.5).flatten(start_dim=1).float().mean(dim=1, keepdim=True),
            foo(torch.unsqueeze(product, dim=-1)).float(),               # 7 product_prime
            torch.unsqueeze(10 * labels_l + labels_r, dim=-1).float(),   # 8 number: 10*left + right
            torch.unsqueeze(labels_l.float() / (1 + labels_r.float()), dim=-1),  # 9 divide
            torch.unsqueeze(product >= 25, dim=-1).float(),              # 10 bigger_than
            (labels_l & labels_r).float(),                               # 11 binary_and: bitwise AND
            torch.unsqueeze(product % 2 != 0, dim=-1).float(),           # 12 odd
        )

        self.data = self.data.flatten(start_dim=1).float()

    def __getitem__(self, index):
        """Return ``(flattened image, [target of every task])`` for sample `index`."""
        return self.data[index], [t[index] for t in self.target]

    def __len__(self):
        return self.data.size(0)

    def _check_multi_exists(self):
        """Return True when both processed multi-MNIST files are already on disk."""
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.multi_training_file)) and \
               os.path.exists(os.path.join(self.root, self.processed_folder, self.multi_test_file))

    def _make_folders(self):
        """Create the raw and processed folders under ``root``; existing folders are fine."""
        # Each folder is handled independently: an existing ``raw`` folder must not stop
        # ``processed`` from being created.
        for folder in (self.raw_folder, self.processed_folder):
            os.makedirs(os.path.join(self.root, folder), exist_ok=True)

    def download(self):
        """
        Download MNIST and build the multi-MNIST files unless both already exist.

        The ``.gz`` archives are fetched into ``root/raw``, decompressed there and deleted.
        ``read_image_file``/``read_label_file`` then build the plain and multi-MNIST sets,
        which are written with ``torch.save`` into ``root/processed``. Pairing is random, so
        every rebuild produces different pairs.
        """
        from six.moves import urllib
        import gzip

        if self._check_multi_exists():
            return

        self._make_folders()

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            # Close the HTTP response deterministically instead of leaving it to the GC.
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                f.write(response.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')
        raw = os.path.join(self.root, self.raw_folder)
        mnist_ims, multi_mnist_ims, extension = read_image_file(
            os.path.join(raw, 'train-images-idx3-ubyte'))
        mnist_labels, multi_mnist_labels_l, multi_mnist_labels_r = read_label_file(
            os.path.join(raw, 'train-labels-idx1-ubyte'), extension)

        tmnist_ims, tmulti_mnist_ims, textension = read_image_file(
            os.path.join(raw, 't10k-images-idx3-ubyte'))
        tmnist_labels, tmulti_mnist_labels_l, tmulti_mnist_labels_r = read_label_file(
            os.path.join(raw, 't10k-labels-idx1-ubyte'), textension)

        # The multi files are written last: _check_multi_exists only looks at them.
        outputs = (
            (self.training_file, (mnist_ims, mnist_labels)),
            (self.test_file, (tmnist_ims, tmnist_labels)),
            (self.multi_training_file, (multi_mnist_ims, multi_mnist_labels_l, multi_mnist_labels_r)),
            (self.multi_test_file, (tmulti_mnist_ims, tmulti_mnist_labels_l, tmulti_mnist_labels_r)),
        )
        processed = os.path.join(self.root, self.processed_folder)
        for filename, payload in outputs:
            with open(os.path.join(processed, filename), 'wb') as f:
                torch.save(payload, f)

        print('Done!')


def get_int(b):
    """
    Decode a big-endian unsigned integer from a bytes-like object.

    IDX file headers store their magic number and dimensions as 4-byte big-endian ints.
    :param b: bytes-like object; an empty input decodes to 0.
    :return: the decoded non-negative int.
    """
    return int.from_bytes(b, byteorder='big')


def read_label_file(path, extension):
    """
    Read an IDX1 label file and derive the labels of the multi-MNIST pairs.

    :param path: path to an uncompressed IDX1 label file (magic number 2049).
    :param extension: integer index array as returned by ``read_image_file``;
        ``extension[i]`` is the index of the partner image paired with image ``i``.
    :return: ``(labels, labels_l, labels_r)``, three int64 tensors of shape ``(length,)``:
        the plain labels, the label of each pair's own image and the partner's label.
    :raises ValueError: if the file does not start with the IDX1 magic number.
    """
    with open(path, 'rb') as f:
        data = f.read()
    magic = get_int(data[:4])
    if magic != 2049:
        raise ValueError('{} is not an IDX1 label file (magic number {})'.format(path, magic))
    length = get_int(data[4:8])
    # int64 (torch ``long``) instead of the ``np.long`` alias, which NumPy 1.24 removed;
    # astype also copies the read-only frombuffer view into a writable array for torch.
    labels = np.frombuffer(data, dtype=np.uint8, offset=8).astype(np.int64)
    labels_l = labels.copy()
    labels_r = labels[np.asarray(extension)[:length]]
    return (torch.from_numpy(labels).view(length),
            torch.from_numpy(labels_l).view(length),
            torch.from_numpy(labels_r).view(length))


def _overlay_pair(lim, rim):
    """
    Overlay `rim` shifted by (6, 6) onto `lim` on a 36x36 canvas and shrink it to 28x28.

    The overlapping region keeps the pixel-wise maximum of both images. Offsets and canvas
    size are hard-coded, so both inputs are assumed to be 28x28.
    """
    canvas = np.zeros((36, 36))
    canvas[0:28, 0:28] = lim
    canvas[6:34, 6:34] = rim
    canvas[6:28, 6:28] = np.maximum(lim[6:28, 6:28], rim[0:22, 0:22])
    return np.array(Image.fromarray(canvas).resize((28, 28), Image.NEAREST))


def read_image_file(path):
    """
    Read an IDX3 image file and build one overlaid multi-MNIST image per input image.

    Each image ``i`` is paired with a partner index drawn uniformly at random (with
    replacement, possibly ``i`` itself) and combined by ``_overlay_pair``.

    :param path: path to an uncompressed IDX3 image file (magic number 2051).
    :return: ``(images, multi_images, extension)``: the original images as a uint8 tensor
        of shape ``(length, num_rows, num_cols)``, the overlaid images as a float64 tensor
        of the same shape, and the int32 array of partner indexes.
    :raises ValueError: if the file does not start with the IDX3 magic number.
    """
    with open(path, 'rb') as f:
        data = f.read()
    magic = get_int(data[:4])
    if magic != 2051:
        raise ValueError('{} is not an IDX3 image file (magic number {})'.format(path, magic))
    length = get_int(data[4:8])
    num_rows = get_int(data[8:12])
    num_cols = get_int(data[12:16])
    pv = np.frombuffer(data, dtype=np.uint8, offset=16).reshape(length, num_rows, num_cols)

    # One draw for all images instead of a full permutation per image (which was O(n^2)).
    # max(length, 1) keeps randint valid for an empty file (size=0 then yields no draws).
    extension = np.random.randint(0, max(length, 1), size=length).astype(np.int32)
    multi_data = np.zeros((length, num_rows, num_cols))
    for left, right in enumerate(extension):
        multi_data[left] = _overlay_pair(pv[left], pv[right])
    # Copy: frombuffer gives a read-only array, which torch.from_numpy warns about.
    return torch.from_numpy(pv.copy()), torch.from_numpy(multi_data), extension

