from __future__ import print_function
from torchvision.datasets.vision import VisionDataset
import warnings
from PIL import Image
import os
import os.path
import numpy as np
import torch
import codecs
import json
from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, verify_str_arg
from .io_utils import matching_file_name, degree_file_name, distance_file_name, global_matching_file_name

def shuffle_along_axis(a, axis):
    idx = np.random.rand(*a.shape).argsort(axis=axis)
    return np.take_along_axis(a,idx,axis=axis)

class MNIST(VisionDataset):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
            and  ``MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    resources = [
        ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
    ]
    training_file = 'training.pt'
    test_file = 'test.pt'
    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets")
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets")
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data")
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data")
        return self.data

    def __init__(self, root, args, train=True, transform=None, target_transform=None,
                 download=False, np_array=False, dropping=False, training_time=False):
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.training_time = training_time
        self.train = train  # training set or test set
        self.np_array = np_array
        self.total_classes = 10

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            data_file = self.training_file
        else:
            data_file = self.test_file
        
        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
        self.data, self.targets, self.num_samples_per_class = self._multi_c_filter(args)


    
    def _multi_c_filter(self, args):
        if args.use_all_classes:
            class_list = list(range(self.total_classes))
        else:
            class_list = args.classes
        num_class = len(class_list)

        vs_all = ('all' in class_list)

        targets_arr = np.array(self.targets)

        X_c = []
        
        for i in range(num_class):
            c = class_list[i]
            if c == 'all':
                continue
            idx = np.where(targets_arr==c)[0]
            X_c.append(self.data[idx])

        if vs_all:
            all_dat = []
            all_classes = list(range(self.total_classes))
            for i in range(self.total_classes):
                if i in class_list:
                    continue
                c = all_classes[i]
                idx = np.where(targets_arr==c)[0]
                all_dat.append(self.data[idx])
            all_dat = np.concatenate(all_dat, axis=0)
            X_c.append(all_dat)

        X_lens = []
        for x in X_c:
            X_lens.append(len(x))

        if args.balanced:
            num_samples = min(*X_lens,args.num_samples)
        else:
            num_samples = args.num_samples

        num_samples_per_class = []
        for i in range(num_class):
            X_c[i] = X_c[i][:num_samples]
            num_samples_per_class.append(len(X_c[i]))

        curr_data = np.concatenate(X_c, axis=0)

        Y_curr = np.zeros(len(curr_data))
        start = 0
        for i in range(num_class):
            Y_curr[start:start+num_samples_per_class[i]] = i
            start += num_samples_per_class[i]

        curr_labels = Y_curr.tolist()
        curr_labels = [int(x) for x in curr_labels]
        return curr_data, curr_labels, num_samples_per_class

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], int(self.targets[index])
        
        #easy_indc = self.easy_idx[index]

        #matched_idx_curr = int(self.matched_idx[index])

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        if not self.np_array:
            img = Image.fromarray(img, mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index, None, None #easy_indc, matched_idx_curr

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        return (os.path.exists(os.path.join(self.processed_folder,
                                            self.training_file)) and
                os.path.exists(os.path.join(self.processed_folder,
                                            self.test_file)))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""

        if self._check_exists():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        # download files
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def extra_repr(self):
        return "Split: {}".format("Train" if self.train is True else "Test")




class FashionMNIST(MNIST):
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
    Args:
        root (string): Root directory of dataset where ``Fashion-MNIST/processed/training.pt``
            and  ``Fashion-MNIST/processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    resources = [
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
         "8d4fb7e6c68d591d4c3dfef9ec88bf0d"),
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
         "25c81989df183df01b3e8a0aad5dffbe"),
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
         "bef4ecab320f06d8554ea6380940ec79"),
        ("http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
         "bb300cfdad3c16e7a12a480ee83cd310")
    ]
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal',
               'Shirt', 'Sneaker', 'Bag', 'Ankle boot']



# class KMNIST(MNIST):
#     """`Kuzushiji-MNIST <https://github.com/rois-codh/kmnist>`_ Dataset.

#     Args:
#         root (string): Root directory of dataset where ``KMNIST/processed/training.pt``
#             and  ``KMNIST/processed/test.pt`` exist.
#         train (bool, optional): If True, creates dataset from ``training.pt``,
#             otherwise from ``test.pt``.
#         download (bool, optional): If true, downloads the dataset from the internet and
#             puts it in root directory. If dataset is already downloaded, it is not
#             downloaded again.
#         transform (callable, optional): A function/transform that  takes in an PIL image
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#     """
#     urls = [
#         'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz',
#         'http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz',
#         'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz',
#         'http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz',
#     ]
#     classes = ['o', 'ki', 'su', 'tsu', 'na', 'ha', 'ma', 'ya', 're', 'wo']


# class EMNIST(MNIST):
#     """`EMNIST <https://www.westernsydney.edu.au/bens/home/reproducible_research/emnist>`_ Dataset.

#     Args:
#         root (string): Root directory of dataset where ``EMNIST/processed/training.pt``
#             and  ``EMNIST/processed/test.pt`` exist.
#         split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``,
#             ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies
#             which one to use.
#         train (bool, optional): If True, creates dataset from ``training.pt``,
#             otherwise from ``test.pt``.
#         download (bool, optional): If true, downloads the dataset from the internet and
#             puts it in root directory. If dataset is already downloaded, it is not
#             downloaded again.
#         transform (callable, optional): A function/transform that  takes in an PIL image
#             and returns a transformed version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform that takes in the
#             target and transforms it.
#     """
#     # Updated URL from https://www.nist.gov/node/1298471/emnist-dataset since the
#     # _official_ download link
#     # https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download
#     # is (currently) unavailable
#     url = 'http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip'
#     splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist')

#     def __init__(self, root, split, **kwargs):
#         self.split = verify_str_arg(split, "split", self.splits)
#         self.training_file = self._training_file(split)
#         self.test_file = self._test_file(split)
#         super(EMNIST, self).__init__(root, **kwargs)

#     @staticmethod
#     def _training_file(split):
#         return 'training_{}.pt'.format(split)

#     @staticmethod
#     def _test_file(split):
#         return 'test_{}.pt'.format(split)

#     def download(self):
#         """Download the EMNIST data if it doesn't exist in processed_folder already."""
#         import shutil

#         if self._check_exists():
#             return

#         makedir_exist_ok(self.raw_folder)
#         makedir_exist_ok(self.processed_folder)

#         # download files
#         print('Downloading and extracting zip archive')
#         download_and_extract_archive(self.url, download_root=self.raw_folder, filename="emnist.zip",
#                                      remove_finished=True)
#         gzip_folder = os.path.join(self.raw_folder, 'gzip')
#         for gzip_file in os.listdir(gzip_folder):
#             if gzip_file.endswith('.gz'):
#                 extract_archive(os.path.join(gzip_folder, gzip_file), gzip_folder)

#         # process and save as torch files
#         for split in self.splits:
#             print('Processing ' + split)
#             training_set = (
#                 read_image_file(os.path.join(gzip_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))),
#                 read_label_file(os.path.join(gzip_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split)))
#             )
#             test_set = (
#                 read_image_file(os.path.join(gzip_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))),
#                 read_label_file(os.path.join(gzip_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split)))
#             )
#             with open(os.path.join(self.processed_folder, self._training_file(split)), 'wb') as f:
#                 torch.save(training_set, f)
#             with open(os.path.join(self.processed_folder, self._test_file(split)), 'wb') as f:
#                 torch.save(test_set, f)
#         shutil.rmtree(gzip_folder)

#         print('Done!')


# class QMNIST(MNIST):
#     """`QMNIST <https://github.com/facebookresearch/qmnist>`_ Dataset.

#     Args:
#         root (string): Root directory of dataset whose ``processed''
#             subdir contains torch binary files with the datasets.
#         what (string,optional): Can be 'train', 'test', 'test10k',
#             'test50k', or 'nist' for respectively the mnist compatible
#             training set, the 60k qmnist testing set, the 10k qmnist
#             examples that match the mnist testing set, the 50k
#             remaining qmnist testing examples, or all the nist
#             digits. The default is to select 'train' or 'test'
#             according to the compatibility argument 'train'.
#         compat (bool,optional): A boolean that says whether the target
#             for each example is class number (for compatibility with
#             the MNIST dataloader) or a torch vector containing the
#             full qmnist information. Default=True.
#         download (bool, optional): If true, downloads the dataset from
#             the internet and puts it in root directory. If dataset is
#             already downloaded, it is not downloaded again.
#         transform (callable, optional): A function/transform that
#             takes in an PIL image and returns a transformed
#             version. E.g, ``transforms.RandomCrop``
#         target_transform (callable, optional): A function/transform
#             that takes in the target and transforms it.
#         train (bool,optional,compatibility): When argument 'what' is
#             not specified, this boolean decides whether to load the
#             training set ot the testing set.  Default: True.

#     """

#     subsets = {
#         'train': 'train',
#         'test': 'test', 'test10k': 'test', 'test50k': 'test',
#         'nist': 'nist'
#     }
#     urls = {
#         'train': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
#                   'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz'],
#         'test': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz',
#                  'https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz'],
#         'nist': ['https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz',
#                  'https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz']
#     }
#     classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
#                '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

#     def __init__(self, root, what=None, compat=True, train=True, **kwargs):
#         if what is None:
#             what = 'train' if train else 'test'
#         self.what = verify_str_arg(what, "what", tuple(self.subsets.keys()))
#         self.compat = compat
#         self.data_file = what + '.pt'
#         self.training_file = self.data_file
#         self.test_file = self.data_file
#         super(QMNIST, self).__init__(root, train, **kwargs)

#     def download(self):
#         """Download the QMNIST data if it doesn't exist in processed_folder already.
#            Note that we only download what has been asked for (argument 'what').
#         """
#         if self._check_exists():
#             return
#         makedir_exist_ok(self.raw_folder)
#         makedir_exist_ok(self.processed_folder)
#         urls = self.urls[self.subsets[self.what]]
#         files = []

#         # download data files if not already there
#         for url in urls:
#             filename = url.rpartition('/')[2]
#             file_path = os.path.join(self.raw_folder, filename)
#             if not os.path.isfile(file_path):
#                 download_url(url, root=self.raw_folder, filename=filename, md5=None)
#             files.append(file_path)

#         # process and save as torch files
#         print('Processing...')
#         data = read_sn3_pascalvincent_tensor(files[0])
#         assert(data.dtype == torch.uint8)
#         assert(data.ndimension() == 3)
#         targets = read_sn3_pascalvincent_tensor(files[1]).long()
#         assert(targets.ndimension() == 2)
#         if self.what == 'test10k':
#             data = data[0:10000, :, :].clone()
#             targets = targets[0:10000, :].clone()
#         if self.what == 'test50k':
#             data = data[10000:, :, :].clone()
#             targets = targets[10000:, :].clone()
#         with open(os.path.join(self.processed_folder, self.data_file), 'wb') as f:
#             torch.save((data, targets), f)

#     def __getitem__(self, index):
#         # redefined to handle the compat flag
#         img, target = self.data[index], self.targets[index]
#         img = Image.fromarray(img.numpy(), mode='L')
#         if self.transform is not None:
#             img = self.transform(img)
#         if self.compat:
#             target = int(target[0])
#         if self.target_transform is not None:
#             target = self.target_transform(target)
#         return img, target

#     def extra_repr(self):
#         return "Split: {}".format(self.what)


def get_int(b):
    return int(codecs.encode(b, 'hex'), 16)


def open_maybe_compressed_file(path):
    """Return a file object that possibly decompresses 'path' on the fly.
       Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'.
    """
    if not isinstance(path, torch._six.string_classes):
        return path
    if path.endswith('.gz'):
        import gzip
        return gzip.open(path, 'rb')
    if path.endswith('.xz'):
        import lzma
        return lzma.open(path, 'rb')
    return open(path, 'rb')


def read_sn3_pascalvincent_tensor(path, strict=True):
    """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh').
       Argument may be a filename, compressed filename, or file object.
    """
    # typemap
    if not hasattr(read_sn3_pascalvincent_tensor, 'typemap'):
        read_sn3_pascalvincent_tensor.typemap = {
            8: (torch.uint8, np.uint8, np.uint8),
            9: (torch.int8, np.int8, np.int8),
            11: (torch.int16, np.dtype('>i2'), 'i2'),
            12: (torch.int32, np.dtype('>i4'), 'i4'),
            13: (torch.float32, np.dtype('>f4'), 'f4'),
            14: (torch.float64, np.dtype('>f8'), 'f8')}
    # read
    with open_maybe_compressed_file(path) as f:
        data = f.read()
    # parse
    magic = get_int(data[0:4])
    nd = magic % 256
    ty = magic // 256
    assert nd >= 1 and nd <= 3
    assert ty >= 8 and ty <= 14
    m = read_sn3_pascalvincent_tensor.typemap[ty]
    s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)]
    parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1)))
    assert parsed.shape[0] == np.prod(s) or not strict
    return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


def read_label_file(path):
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 1)
    return x.long()


def read_image_file(path):
    with open(path, 'rb') as f:
        x = read_sn3_pascalvincent_tensor(f, strict=False)
    assert(x.dtype == torch.uint8)
    assert(x.ndimension() == 3)
    return x
