from __future__ import print_function
import os
import sys
import errno
import numpy as np
from PIL import Image
import torch.utils.data as data
import contextlib
import pickle
from .base import *
import copy

@contextlib.contextmanager
def temp_seed(seed):
    """Run the ``with`` body with NumPy's global RNG seeded to ``seed``.

    The global RNG state that was active before entering is restored on exit,
    including when the body raises.

    Args:
        seed: value forwarded to ``np.random.seed``.
    """
    previous_state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        # put the caller's RNG stream back exactly where it was
        np.random.set_state(previous_state)


def pil_loader(path):
    """Open the image file at ``path`` with PIL and return it converted to RGB.

    The file handle and the source image are both closed before returning;
    ``convert`` produces a new, fully loaded image.
    """
    with open(path, 'rb') as handle, Image.open(handle) as source:
        rgb_image = source.convert('RGB')
    return rgb_image


def accimage_loader(path):
    """Load ``path`` with the ``accimage`` backend, falling back to PIL.

    ``accimage`` is a standalone package, not a submodule of
    ``torchvision.datasets``, so it is imported by its own name (the previous
    ``torchvision.datasets.accimage`` import always raised ImportError).

    Args:
        path (str): image file path.

    Returns:
        An ``accimage.Image``, or an RGB ``PIL.Image.Image`` when accimage
        raises ``IOError`` for this file.
    """
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # possibly a decoding problem accimage cannot handle; use PIL instead
        return pil_loader(path)


def default_loader(path):
    """Load an image using whichever backend torchvision is configured for.

    Dispatches to :func:`accimage_loader` when ``get_image_backend()`` reports
    ``'accimage'``; otherwise uses :func:`pil_loader`.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)


def build_set(root, split, imgs, noise_type='pairflip', noise_rate=0.5):
    """Unfinished helper: argsorts ``imgs`` and returns ``None``.

    The body only computes ``np.argsort(imgs)`` and discards the result, so
    callers always receive ``None``. ``root``, ``split``, ``noise_type`` and
    ``noise_rate`` are accepted but not used.

    NOTE(review): the intended contract was a list of 2-tuples
    ``(path, class)`` for the 'train' / 'gallery' / 'query' splits; that is
    not implemented — TODO implement or remove.

    Args:
        root (string): Root directory of dataset (unused).
        split (str): intended split name, e.g. 'train', 'gallery', 'query' (unused).
        imgs (array_like): values passed to ``np.argsort``.
        noise_type (str): unused.
        noise_rate (float): unused.

    Returns:
        None
    """
    _argidx = np.argsort(imgs)  # TODO: result currently unused (function incomplete)
    return None





class cifar(data.Dataset):
    """`CIFAR-10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset with class-based splits.

    The python-pickled batches are read from ``<root>/raw/cifar-10-batches-py``
    and kept in memory as an array of shape ``(N, 32, 32, 3)`` (HWC).

    Args:
        root (string): Root directory; the archive is extracted into ``<root>/raw``.
        split (str, optional): ``'train'``, ``'test'`` or ``'train+test'``; selects
            which pickled batches are loaded.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        loader (callable, optional): Stored as ``self.loader``; this class reads images
            from the pickled arrays and never calls it.
        target_list (container, optional): Used when ``seen_list`` is None: only samples
            whose label is in ``target_list`` are kept, and for ``split='train'`` only
            every other one of those (positions 0, 2, 4, ...).
        seen_list (container, optional): When given, builds an unlabeled set from the
            odd positions (1, 3, 5, ...) of the samples whose label is in ``seen_list``,
            followed by all samples whose label is not; every target is set to -1.

    Raises:
        ValueError: if ``split`` is not one of the values listed above.
        RuntimeError: if the extracted dataset cannot be found.
    """
    urls = []
    raw_folder = 'raw'
    # NOTE: the archive/batch checksums below are listed but never verified here.
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 target_list=range(5), seen_list=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None

        # Resolve the split up front so a typo fails fast, before any download.
        if split == 'train':
            downloaded_list = list(self.train_list)
        elif split == 'test':
            downloaded_list = list(self.test_list)
        elif split == 'train+test':
            downloaded_list = self.train_list + self.test_list
        else:
            raise ValueError("split must be 'train', 'test' or 'train+test', got %r" % (split,))

        self.urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        batch_dir = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py')
        batches = []
        targets = []
        for file_name, _checksum in downloaded_list:
            with open(os.path.join(batch_dir, file_name), 'rb') as f:
                # Unpickling is only safe because the archive comes from the fixed URL above.
                entry = pickle.load(f, encoding='latin1')
            batches.append(entry['data'])
            targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        # rows are flat CHW vectors; reshape then move channels last (HWC)
        self.imgs = np.vstack(batches).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        self.load_meta()
        self.targets = targets
        self.data = self.imgs

        # Membership tests run on the plain-int ``targets`` list so that ``range``
        # containers keep their O(1) ``in`` check.
        if seen_list is None:
            ind = [i for i, t in enumerate(targets) if t in target_list]
            self.data = self.data[ind]
            self.targets = np.array(targets)[ind].tolist()

            if split == 'train':
                # keep even positions; the seen_list branch takes the odd ones
                self.data = self.data[::2]
                self.targets = self.targets[::2]
        else:
            all_targets = np.array(targets)
            seen_ind = [i for i, t in enumerate(targets) if t in seen_list]
            unseen_ind = [i for i, t in enumerate(targets) if t not in seen_list]

            self.seen_data = self.data[seen_ind][1::2]
            self.seen_targets = all_targets[seen_ind].tolist()[1::2]
            self.unseen_data = self.data[unseen_ind]
            self.unseen_targets = all_targets[unseen_ind].tolist()

            self.data = np.concatenate((self.seen_data, self.unseen_data), 0)
            # labels are hidden: every sample is marked -1 (unlabeled)
            self.targets = np.ones_like(np.array(self.seen_targets + self.unseen_targets)) * -1

    def rebuild_imgs(self, pseudo_labels):
        """Replace the sample labels with ``pseudo_labels``.

        Pairs every image of ``self.data`` with the matching pseudo-label in
        ``self.imgs`` and also stores the pseudo-labels in ``self.targets`` so
        that ``__getitem__`` returns them.

        Args:
            pseudo_labels (iterable): one label per sample, in ``self.data`` order.

        Raises:
            ValueError: if the number of pseudo-labels differs from ``len(self)``.
        """
        labels = list(pseudo_labels)
        if len(labels) != len(self.data):
            raise ValueError('expected %d pseudo labels, got %d' % (len(self.data), len(labels)))
        self.imgs = list(zip(self.data, labels))
        self.targets = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is a PIL image (passed through
            ``transform`` when set) and target is ``self.targets[index]``
            (passed through ``target_transform`` when set).
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def download(self):
        """Download and extract the archive(s) in ``self.urls`` into ``<root>/raw``.

        Does nothing when :meth:`_check_exists` already finds the extracted data.
        Each downloaded archive is deleted after extraction.
        """
        from six.moves import urllib
        import shutil
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        try:
            os.makedirs(raw_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # stream to disk instead of buffering the whole archive, and close the response
            with contextlib.closing(urllib.request.urlopen(url)) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            with tarfile.open(file_path, 'r') as tar:
                for item in tar:
                    tar.extract(item, raw_dir)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read the class names and build the ``class_to_idx`` name -> index dict."""
        path = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py', self.meta['filename'])
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                meta = pickle.load(infile)
            else:
                meta = pickle.load(infile, encoding='latin1')
        self.class_names = meta[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}

    def _check_exists(self):
        """Return True when ``<root>/raw/cifar-10-batches-py`` exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, 'cifar-10-batches-py'))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Swap each class of fold ``nb_fold`` with label ``7 + i``.

        For the i-th class ``c`` in ``fold_lst_dict['fold_<nb_fold>']``, samples
        labelled ``c`` get ``7 + i`` and samples labelled ``7 + i`` get ``c``.
        Every mask is computed from the original ``targets``; a modified deep copy
        is returned and ``targets`` is left untouched.

        NOTE(review): relies on elementwise ``==``, so ``targets`` presumably is a
        NumPy array — confirm at call sites.
        """
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        tmp_targets = copy.deepcopy(targets)

        for i, cls_idx in enumerate(fold_lst):
            idxs = np.where(targets == cls_idx)
            tmp_targets[idxs] = 7 + i

            idxs = np.where(targets == 7 + i)
            tmp_targets[idxs] = cls_idx

        return tmp_targets


class cifar2(cifar):
    """CIFAR-10 dataset whose samples carry two independently transformed views.

    Construction (``root``, ``split``, ``transform``, ``target_transform``,
    ``download``, ``loader``, ``target_list``, ``seen_list``), class splitting,
    download and metadata handling are inherited unchanged from :class:`cifar`.
    Previously this class duplicated all of that code verbatim. Only
    ``__getitem__`` differs: it returns ``(view1, view2, target)``.
    """

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (img1, img2, target). With a ``transform``, img1 and img2 are
            two separate calls of it on the same PIL image. Without one, both are
            that PIL image. target is ``self.targets[index]`` (after
            ``target_transform`` when set).
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img1 = self.transform(img)
            img2 = self.transform(img)
        else:
            # previously img1/img2 were unbound here (UnboundLocalError)
            img1 = img2 = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img1, img2, target



def encode_onehot(labels, num_classes=75):
    """Turn class indices into a one-hot matrix.

    Args:
        labels (sequence of int): class index for each sample.
        num_classes (int): number of output columns.

    Returns:
        numpy.ndarray: float array of shape ``(len(labels), num_classes)`` whose
        row ``i`` has a 1 at column ``labels[i]`` and 0 elsewhere.
    """
    encoded = np.zeros((len(labels), num_classes))
    for row, label in enumerate(labels):
        encoded[row, label] = 1
    return encoded


class CIFAR100(BaseDataset2):
    """`CIFAR-100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ dataset with class-based splits.

    Built on ``BaseDataset2`` (from ``.base``), not on a CIFAR-10 class. The
    python-pickled ``train``/``test`` files are read from
    ``<root>/raw/cifar-100-python``, using ``labels`` if present and otherwise
    ``fine_labels``. Images are kept in memory as an array of shape
    ``(N, 32, 32, 3)`` (HWC).

    Args:
        root (string): Root directory; the archive is extracted into ``<root>/raw``.
        split (str, optional): ``'train'`` loads the train file; any other value
            loads the test file.
        transform (callable, optional): transform applied to the PIL image.
        target_transform (callable, optional): transform applied to the target.
        download (bool, optional): download and extract the archive if missing.
        loader (callable, optional): Stored as ``self.loader``; images come from the
            pickled arrays and this class never calls it.
        nb_fold (int, optional): Stored as ``self.nb_fold``.
        target_list (container, optional): Used when ``seen_list`` is None: only samples
            whose label is in ``target_list`` are kept, and for ``split='train'`` only
            every other one of those (positions 0, 2, 4, ...).
        seen_list (container, optional): When given, builds an unlabeled set from the
            odd positions of the samples whose label is in ``seen_list``, followed by
            all samples whose label is not; every target is set to -1.

    Raises:
        RuntimeError: if the extracted dataset cannot be found.
    """
    base_folder = 'cifar-100-python'
    raw_folder='raw'
    urls = ["https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"]
    filename = "cifar-100-python.tar.gz"
    # NOTE: the checksums below are listed but never verified here.
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0, target_list=range(75), seen_list=None):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold
        self.classes = 75  # hard-coded; matches the default target_list=range(75)

        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        downloaded_list = self.train_list if self.split == 'train' else self.test_list

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        batch_dir = os.path.join(self.root, self.raw_folder, self.base_folder)
        batches = []
        targets = []
        for file_name, _checksum in downloaded_list:
            with open(os.path.join(batch_dir, file_name), 'rb') as f:
                # Unpickling is only safe because the archive comes from the fixed URL above.
                entry = pickle.load(f, encoding='latin1')
            batches.append(entry['data'])
            targets.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        # rows are flat CHW vectors; reshape then move channels last (HWC)
        self.imgs = np.vstack(batches).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        self.load_meta()
        self.targets = targets
        self.data = self.imgs

        # Membership tests run on the plain-int ``targets`` list so that ``range``
        # containers keep their O(1) ``in`` check.
        if seen_list is None:
            ind = [i for i, t in enumerate(targets) if t in target_list]
            self.data = self.data[ind]
            self.targets = np.array(targets)[ind].tolist()

            if split == 'train':
                # keep even positions; the seen_list branch takes the odd ones
                self.data = self.data[::2]
                self.targets = self.targets[::2]
        else:
            all_targets = np.array(targets)
            seen_ind = [i for i, t in enumerate(targets) if t in seen_list]
            unseen_ind = [i for i, t in enumerate(targets) if t not in seen_list]

            self.seen_data = self.data[seen_ind][1::2]
            self.seen_targets = all_targets[seen_ind].tolist()[1::2]
            self.unseen_data = self.data[unseen_ind]
            self.unseen_targets = all_targets[unseen_ind].tolist()

            self.data = np.concatenate((self.seen_data, self.unseen_data), 0)
            # labels are hidden: every sample is marked -1 (unlabeled)
            self.targets = np.ones_like(np.array(self.seen_targets + self.unseen_targets)) * -1

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where image is a PIL image (passed through
            ``transform`` when set) and target is ``self.targets[index]``
            (passed through ``target_transform`` when set).
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def download(self):
        """Download and extract the archive(s) in ``self.urls`` into ``<root>/raw``.

        Does nothing when :meth:`_check_exists` already finds the extracted data.
        Each downloaded archive is deleted after extraction.
        """
        from six.moves import urllib
        import shutil
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        try:
            os.makedirs(raw_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # stream to disk instead of buffering the whole archive, and close the response
            with contextlib.closing(urllib.request.urlopen(url)) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            with tarfile.open(file_path, 'r') as tar:
                for item in tar:
                    tar.extract(item, raw_dir)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read the class names and build the ``class_to_idx`` name -> index dict."""
        path = os.path.join(self.root, self.raw_folder, self.base_folder, self.meta['filename'])
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                meta = pickle.load(infile)
            else:
                meta = pickle.load(infile, encoding='latin1')
        self.class_names = meta[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}

    def _check_exists(self):
        """Return True when ``<root>/raw/cifar-100-python`` exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, self.base_folder))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Swap each class of fold ``nb_fold`` with label ``7 + i``.

        For the i-th class ``c`` in ``fold_lst_dict['fold_<nb_fold>']``, samples
        labelled ``c`` get ``7 + i`` and samples labelled ``7 + i`` get ``c``.
        Every mask is computed from the original ``targets``; a modified deep copy
        is returned and ``targets`` is left untouched.

        NOTE(review): relies on elementwise ``==``, so ``targets`` presumably is a
        NumPy array — confirm at call sites.
        """
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        tmp_targets = copy.deepcopy(targets)

        for i, cls_idx in enumerate(fold_lst):
            idxs = np.where(targets == cls_idx)
            tmp_targets[idxs] = 7 + i

            idxs = np.where(targets == 7 + i)
            tmp_targets[idxs] = cls_idx

        return tmp_targets


class CIFAR100_2(CIFAR100):
    """CIFAR-100 dataset whose samples carry two independently transformed views.

    Construction (``root``, ``split``, ``transform``, ``target_transform``,
    ``download``, ``loader``, ``nb_fold``, ``target_list``, ``seen_list``),
    class splitting, download and metadata handling are inherited unchanged from
    :class:`CIFAR100`, which still derives from ``BaseDataset2``. Previously this
    class duplicated all of that code verbatim. Only ``__getitem__`` differs: it
    returns ``(view1, view2, target)``.
    """

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (img1, img2, target). With a ``transform``, img1 and img2 are
            two separate calls of it on the same PIL image. Without one, both are
            that PIL image. target is ``self.targets[index]`` (after
            ``target_transform`` when set).
        """
        img, target = self.data[index], self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img1 = self.transform(img)
            img2 = self.transform(img)
        else:
            # previously img1/img2 were unbound here (UnboundLocalError)
            img1 = img2 = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img1, img2, target
