from __future__ import print_function
import os
import sys
import errno
import numpy as np
from PIL import Image
import torch.utils.data as data
import contextlib
import pickle
from .base import *

@contextlib.contextmanager
def temp_seed(seed):
    """Run the ``with`` body under NumPy's global RNG seeded with ``seed``.

    The global RNG state captured on entry is put back when the block exits,
    whether it finishes normally or raises, so code outside the block sees an
    undisturbed random stream.

    Args:
        seed: Any value accepted by ``np.random.seed``.
    """
    previous = np.random.get_state()
    np.random.seed(seed)
    with contextlib.ExitStack() as cleanup:
        # Registered callback runs on both normal exit and exception.
        cleanup.callback(np.random.set_state, previous)
        yield


def pil_loader(path):
    """Open the image at ``path`` with PIL and return ``convert('RGB')`` of it.

    Both the underlying file handle and the opened PIL image are closed by the
    context managers before the converted image is handed back.
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image


def accimage_loader(path):
    """Load ``path`` with the ``accimage`` backend, falling back to PIL.

    ``accimage`` is a standalone package imported as ``accimage``; there is no
    ``torchvision.datasets.accimage`` module, so the previous import always
    raised ImportError before any image was loaded.

    Args:
        path (str): Image file path.

    Returns:
        An ``accimage.Image``, or the result of ``pil_loader(path)`` if
        ``accimage`` raises ``IOError`` while decoding.
    """
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Possibly a decoding problem accimage cannot handle; let PIL try.
        return pil_loader(path)


def default_loader(path):
    """Load an image using whichever backend torchvision is configured for.

    Dispatches to ``accimage_loader`` when ``torchvision.get_image_backend()``
    reports ``'accimage'``, and to ``pil_loader`` for any other backend.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    chosen_loader = accimage_loader if use_accimage else pil_loader
    return chosen_loader(path)


def build_set(root, split, imgs, noise_type='pairflip', noise_rate=0.5):
    """Unfinished helper; currently returns ``None``.

    Intended to return a list of ``(path, class)`` tuples for the requested
    split, but not implemented. The body only computes ``np.argsort(imgs)``
    and discards the result, so the call still fails on inputs NumPy cannot
    sort.

    Args:
        root (string): Root directory of the dataset (unused).
        split (str): One of ``'train'``, ``'gallery'``, ``'query'`` (unused).
        imgs: Array-like passed to ``np.argsort``.
        noise_type (str): Unused.
        noise_rate (float): Unused.

    Returns:
        None

    TODO: implement or remove; no caller in this file uses it.
    """
    np.argsort(imgs)  # result discarded; kept so argument errors still surface
    return None





class cifar(BaseDataset2):
    """`CIFAR-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ re-split for retrieval.

    The five training batches and the test batch are pooled, stably sorted by
    label and partitioned into:

    * ``'query'``   -- every 60th image of the label-sorted pool;
    * ``'train'``   -- the first 500 non-query images of each class;
    * ``'gallery'`` -- the remaining non-query images of each class;
    * ``'db'``      -- ``'train'`` followed by ``'gallery'``.

    For ``'query'``, ``label_mat[i, j]`` is 1.0 when query ``i`` and gallery
    image ``j`` share a label and 0.0 otherwise.

    Args:
        root (string): Root directory; batches are read from
            ``<root>/raw/cifar-10-batches-py``.
        split (string, optional): One of ``'train'``, ``'gallery'``, ``'db'``,
            ``'query'``.
        transform (callable, optional): Applied to the PIL image in
            ``__getitem__``, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Applied to the integer label.
        download (bool, optional): If True, download and extract the archive
            unless it is already present.
        loader (callable, optional): Stored as ``self.loader``; not used by
            ``__getitem__`` because images come from in-memory arrays.

    Raises:
        ValueError: If ``split`` is not a supported value.
        RuntimeError: If the data is missing and ``download`` is False.
    """
    urls = []
    raw_folder = 'raw'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'  # NOTE: not currently verified
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    _SPLITS = ('train', 'gallery', 'db', 'query')
    _QUERY_STEP = 60          # every 60th label-sorted image becomes a query
    _SOURCE_PER_CLASS = 500   # non-query images per class assigned to 'train'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader):
        if split not in self._SPLITS:
            # Previously an unknown split left raw pixel arrays in self.imgs and
            # silently produced a nonsensical dataset.
            raise ValueError('Unknown split {!r}; expected one of {}'.format(
                split, ', '.join(self._SPLITS)))

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.classes = range(10)
        # Starts as None; presumably assigned by external code (TODO confirm).
        self.meta_labels = None
        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        self.urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        images, targets = self._load_batches(self.train_list + self.test_list)
        self.load_meta()

        # A stable sort keeps same-label images in file order. The default
        # (unstable) sort may order ties differently across NumPy builds/CPUs,
        # which would change which images land in query/train/gallery.
        order = np.argsort(targets, kind='stable')
        self._build_splits(images[order], targets[order])
        self.imgs = self._select_split(split)

        # Register samples with the BaseDataset2 bookkeeping lists.
        for index, sample in enumerate(self.imgs):
            self.ys += [sample[1]]
            self.I += [index]
            self.im_paths.append(sample[0])

    def _load_batches(self, batch_list):
        """Read the pickled batches and return ``(images_NHWC, labels)`` arrays."""
        batch_dir = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py')
        chunks, targets = [], []
        for file_name, _checksum in batch_list:
            with open(os.path.join(batch_dir, file_name), 'rb') as f:
                # NOTE(security): pickle can execute code; only load archives
                # obtained from a trusted source.
                entry = pickle.load(f, encoding='latin1')
            chunks.append(entry['data'])
            targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        # Rows are stored channel-first (3x32x32); convert to HWC for PIL.
        images = np.vstack(chunks).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        return images, np.array(targets)

    def _build_splits(self, images, targets):
        """Populate query_x/y, Source_x/y and Gallery_x/y from label-sorted data."""
        step = self._QUERY_STEP
        query_idx = np.arange(0, len(targets), step)
        self.query_x = images[::step]
        self.query_y = targets[::step]

        remaining_x = np.delete(images, query_idx, axis=0)
        remaining_y = np.delete(targets, query_idx, axis=0)

        # Assumes every class has the same number of images, so equal chunks of
        # the label-sorted data are per-class blocks (np.split raises if the
        # length is not divisible).
        num_classes = len(self.classes)
        x_blocks = np.split(remaining_x, num_classes, axis=0)
        y_blocks = np.split(remaining_y, num_classes, axis=0)

        n = self._SOURCE_PER_CLASS
        self.Source_x = np.concatenate([block[:n] for block in x_blocks], axis=0)
        self.Source_y = np.concatenate([block[:n] for block in y_blocks], axis=0)
        self.Gallery_x = np.concatenate([block[n:] for block in x_blocks], axis=0)
        self.Gallery_y = np.concatenate([block[n:] for block in y_blocks], axis=0)

    def _select_split(self, split):
        """Return the ``(image, label)`` list for ``split``; sets ``label_mat`` for 'query'."""
        if split == 'train':
            return list(zip(self.Source_x, self.Source_y))
        if split == 'gallery':
            return list(zip(self.Gallery_x, self.Gallery_y))
        if split == 'db':
            db_x = np.concatenate((self.Source_x, self.Gallery_x), axis=0)
            db_y = np.concatenate((self.Source_y, self.Gallery_y), axis=0)
            return list(zip(db_x, db_y))
        # split == 'query': relevance matrix between queries and the gallery.
        one_hot = np.eye(len(self.classes))
        self.label_mat = np.matmul(one_hot[self.query_y], one_hot[self.Gallery_y].T)
        return list(zip(self.query_x, self.query_y))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: ``(image, target)``, with ``meta_labels[index]`` appended when
            the split is ``'train'`` and ``meta_labels`` has been set.
        """
        pixels, target, *rest = self.imgs[index]
        img = Image.fromarray(pixels)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            # Previously this was (wrongly) applied to the image.
            target = self.target_transform(target)

        if self.split == 'train' and self.meta_labels is not None:
            return (img, target, *rest, self.meta_labels[index])
        return (img, target, *rest)

    def download(self):
        """Download and extract the archive into ``<root>/raw`` unless present."""
        import shutil
        import tarfile
        from six.moves import urllib

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        os.makedirs(raw_dir, exist_ok=True)

        # Use the 'data' extraction filter where the running Python supports it
        # to reject absolute paths / links escaping raw_dir.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk instead of holding the whole archive in memory, and
            # make sure the HTTP response and file are closed.
            with contextlib.closing(urllib.request.urlopen(url)) as response, \
                    open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            with tarfile.open(file_path, 'r') as tar:
                for item in tar:
                    tar.extract(item, raw_dir, **extract_kwargs)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Load class names and build the ``class_to_idx`` name -> index dict."""
        path = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py', self.meta['filename'])
        with open(path, 'rb') as infile:
            meta_data = pickle.load(infile, encoding='latin1')
        self.class_names = meta_data[self.meta['key']]
        # Previously `.values()` was taken, discarding the names entirely.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}

    def _check_exists(self):
        """Return True if the extracted ``cifar-10-batches-py`` folder exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, 'cifar-10-batches-py'))


class cifar2(BaseDataset2):
    """`CIFAR-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ re-split, two views per sample.

    The five training batches and the test batch are pooled, stably sorted by
    label and partitioned into:

    * ``'query'``   -- every 60th image of the label-sorted pool;
    * ``'train'``   -- the first 500 non-query images of each class;
    * ``'gallery'`` -- the remaining non-query images of each class;
    * ``'db'``      -- ``'train'`` followed by ``'gallery'``.

    For every split, ``__getitem__`` applies ``transform`` twice independently
    and returns both results.

    For ``'query'``, ``label_mat[i, j]`` is 1.0 when query ``i`` and gallery
    image ``j`` share a label and 0.0 otherwise.

    Args:
        root (string): Root directory; batches are read from
            ``<root>/raw/cifar-10-batches-py``.
        split (string, optional): One of ``'train'``, ``'gallery'``, ``'db'``,
            ``'query'``.
        transform (callable, optional): Applied (twice) to the PIL image in
            ``__getitem__``, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Applied to the integer label.
        download (bool, optional): If True, download and extract the archive
            unless it is already present.
        loader (callable, optional): Stored as ``self.loader``; not used by
            ``__getitem__`` because images come from in-memory arrays.

    Raises:
        ValueError: If ``split`` is not a supported value.
        RuntimeError: If the data is missing and ``download`` is False.
    """
    urls = []
    raw_folder = 'raw'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'  # NOTE: not currently verified
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    _SPLITS = ('train', 'gallery', 'db', 'query')
    _QUERY_STEP = 60          # every 60th label-sorted image becomes a query
    _SOURCE_PER_CLASS = 500   # non-query images per class assigned to 'train'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader):
        if split not in self._SPLITS:
            # Previously an unknown split left raw pixel arrays in self.imgs and
            # silently produced a nonsensical dataset.
            raise ValueError('Unknown split {!r}; expected one of {}'.format(
                split, ', '.join(self._SPLITS)))

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.classes = range(10)
        # Starts as None; presumably assigned by external code (TODO confirm).
        self.meta_labels = None
        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        self.urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        images, targets = self._load_batches(self.train_list + self.test_list)
        self.load_meta()

        # A stable sort keeps same-label images in file order. The default
        # (unstable) sort may order ties differently across NumPy builds/CPUs,
        # which would change which images land in query/train/gallery.
        order = np.argsort(targets, kind='stable')
        self._build_splits(images[order], targets[order])
        self.imgs = self._select_split(split)

        # Register samples with the BaseDataset2 bookkeeping lists.
        for index, sample in enumerate(self.imgs):
            self.ys += [sample[1]]
            self.I += [index]
            self.im_paths.append(sample[0])

    def _load_batches(self, batch_list):
        """Read the pickled batches and return ``(images_NHWC, labels)`` arrays."""
        batch_dir = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py')
        chunks, targets = [], []
        for file_name, _checksum in batch_list:
            with open(os.path.join(batch_dir, file_name), 'rb') as f:
                # NOTE(security): pickle can execute code; only load archives
                # obtained from a trusted source.
                entry = pickle.load(f, encoding='latin1')
            chunks.append(entry['data'])
            targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        # Rows are stored channel-first (3x32x32); convert to HWC for PIL.
        images = np.vstack(chunks).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        return images, np.array(targets)

    def _build_splits(self, images, targets):
        """Populate query_x/y, Source_x/y and Gallery_x/y from label-sorted data."""
        step = self._QUERY_STEP
        query_idx = np.arange(0, len(targets), step)
        self.query_x = images[::step]
        self.query_y = targets[::step]

        remaining_x = np.delete(images, query_idx, axis=0)
        remaining_y = np.delete(targets, query_idx, axis=0)

        # Assumes every class has the same number of images, so equal chunks of
        # the label-sorted data are per-class blocks (np.split raises if the
        # length is not divisible).
        num_classes = len(self.classes)
        x_blocks = np.split(remaining_x, num_classes, axis=0)
        y_blocks = np.split(remaining_y, num_classes, axis=0)

        n = self._SOURCE_PER_CLASS
        self.Source_x = np.concatenate([block[:n] for block in x_blocks], axis=0)
        self.Source_y = np.concatenate([block[:n] for block in y_blocks], axis=0)
        self.Gallery_x = np.concatenate([block[n:] for block in x_blocks], axis=0)
        self.Gallery_y = np.concatenate([block[n:] for block in y_blocks], axis=0)

    def _select_split(self, split):
        """Return the ``(image, label)`` list for ``split``; sets ``label_mat`` for 'query'."""
        if split == 'train':
            return list(zip(self.Source_x, self.Source_y))
        if split == 'gallery':
            return list(zip(self.Gallery_x, self.Gallery_y))
        if split == 'db':
            db_x = np.concatenate((self.Source_x, self.Gallery_x), axis=0)
            db_y = np.concatenate((self.Source_y, self.Gallery_y), axis=0)
            return list(zip(db_x, db_y))
        # split == 'query': relevance matrix between queries and the gallery.
        one_hot = np.eye(len(self.classes))
        self.label_mat = np.matmul(one_hot[self.query_y], one_hot[self.Gallery_y].T)
        return list(zip(self.query_x, self.query_y))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: ``(view1, view2, target)`` for every split, with
            ``meta_labels[index]`` appended when ``meta_labels`` has been set.
            Without a ``transform`` both views are the same PIL image.
        """
        pixels, target, *rest = self.imgs[index]
        img = Image.fromarray(pixels)

        if self.transform is not None:
            # Two independent calls, e.g. two random augmentations.
            view1 = self.transform(img)
            view2 = self.transform(img)
        else:
            # Previously view1/view2 were unbound here (UnboundLocalError).
            view1 = view2 = img

        if self.target_transform is not None:
            # Previously this was (wrongly) applied to the image.
            target = self.target_transform(target)

        # NOTE: the former `self.split == 'split'` single-view branch could never
        # match a valid split, so every split has always returned two views.
        if self.meta_labels is None:
            return (view1, view2, target, *rest)
        return (view1, view2, target, *rest, self.meta_labels[index])

    def download(self):
        """Download and extract the archive into ``<root>/raw`` unless present."""
        import shutil
        import tarfile
        from six.moves import urllib

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        os.makedirs(raw_dir, exist_ok=True)

        # Use the 'data' extraction filter where the running Python supports it
        # to reject absolute paths / links escaping raw_dir.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk instead of holding the whole archive in memory, and
            # make sure the HTTP response and file are closed.
            with contextlib.closing(urllib.request.urlopen(url)) as response, \
                    open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            with tarfile.open(file_path, 'r') as tar:
                for item in tar:
                    tar.extract(item, raw_dir, **extract_kwargs)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Load class names and build the ``class_to_idx`` name -> index dict."""
        path = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py', self.meta['filename'])
        with open(path, 'rb') as infile:
            meta_data = pickle.load(infile, encoding='latin1')
        self.class_names = meta_data[self.meta['key']]
        # Previously `.values()` was taken, discarding the names entirely.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}

    def _check_exists(self):
        """Return True if the extracted ``cifar-10-batches-py`` folder exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, 'cifar-10-batches-py'))