from __future__ import print_function
import os
import sys
import errno
import numpy as np
from PIL import Image
import torch.utils.data as data
import contextlib
import pickle
from .base import *
import copy

@contextlib.contextmanager
def temp_seed(seed):
    """Run a ``with`` block under a fixed NumPy global seed.

    The current global RNG state is saved, NumPy is reseeded with ``seed`` for
    the body of the ``with`` statement, and the saved state is put back on
    exit, including when the body raises.
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore so callers' random streams are unaffected by the seeded block.
        np.random.set_state(saved_state)


def pil_loader(path):
    """Open the image file at ``path`` with PIL and return an RGB copy.

    Both the file handle and the PIL image opened from it are closed before
    the converted image is returned.
    """
    with open(path, 'rb') as handle, Image.open(handle) as opened:
        rgb_image = opened.convert('RGB')
    return rgb_image


def accimage_loader(path):
    """Load the image at ``path`` with the accimage backend.

    ``accimage`` is a standalone package (imported as ``accimage``, as
    torchvision's own loader does); it is not a submodule of
    ``torchvision.datasets``.

    Args:
        path (str): Path to the image file.

    Returns:
        An ``accimage.Image``, or a PIL RGB image if accimage raised IOError.
    """
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Fall back to PIL for files accimage could not decode.
        return pil_loader(path)


def default_loader(path):
    """Load ``path`` with the image backend torchvision is currently set to.

    Uses :func:`accimage_loader` when the backend is ``'accimage'`` and
    :func:`pil_loader` otherwise.
    """
    from torchvision import get_image_backend
    wants_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if wants_accimage else pil_loader(path)


def build_set(root, split, imgs, noise_type='pairflip', noise_rate=0.5):
    """Unfinished helper that was meant to build (path, class) lists.

    NOTE(review): the body only computes ``np.argsort(imgs)`` and discards the
    result, so this function always returns ``None``. ``root``, ``split``,
    ``noise_type`` and ``noise_rate`` are currently unused.

    Args:
        root (str): Root directory of the dataset (unused).
        split (str): ``'train'``, ``'gallery'`` or ``'query'`` (unused).
        imgs: Sequence that is argsorted (result discarded).
        noise_type (str): Label-noise type (unused).
        noise_rate (float): Label-noise rate (unused).

    Returns:
        None
    """
    # Kept so that inputs np.argsort rejects still raise as they did before.
    _ = np.argsort(imgs)
    return None





class cifar(BaseDataset2):
    """`cifar10 <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`_ split into seen/unseen classes.

    All 60,000 images (the five training batches plus the test batch) are
    pooled. The classes listed in ``fold_lst_dict['fold_<nb_fold>']`` are
    swapped into the top label slots, and labels ``>= int(10 * seen_rate)``
    are treated as unseen. The data are then partitioned into:

    * ``'train'``: the first ``int(6000 * seen_portion)`` images of every seen
      class;
    * ``'gallery'``: the remaining seen-class images plus the odd-positioned
      unseen-class images;
    * ``'query'``: the even-positioned unseen-class images. For this split
      ``label_mat`` is the (query x gallery) matrix with 1 where labels match.

    NOTE(review): the fold swap always targets the top ``len(fold_lst)`` slots
    (7, 8, 9 for the built-in folds), so it only lines up with the seen/unseen
    threshold when ``seen_rate == 0.7``.

    Args:
        root (string): Root directory; the archive lives under
            ``<root>/raw/cifar-10-batches-py``.
        split (str): One of ``'train'``, ``'gallery'``, ``'query'``.
        transform (callable, optional): Applied to the PIL image, e.g.
            ``transforms.RandomCrop``.
        target_transform (callable, optional): Applied to the label.
        download (bool, optional): Download and extract the archive if it is
            not already present.
        loader (callable, optional): Stored on the instance; images are held
            in memory, so ``__getitem__`` does not use it.
        nb_fold (int): Selects the entry of ``fold_lst_dict`` to make unseen.
        seen_portion (float): Fraction of each seen class used for training.
        seen_rate (float): Fraction of the 10 classes treated as seen.

    Raises:
        ValueError: If ``split`` is not a recognised split name.
        RuntimeError: If the dataset is missing and ``download`` is False.
    """
    urls = []
    raw_folder = 'raw'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    # Classes swapped into the top label slots (and so made unseen) per fold.
    fold_lst_dict = {
        'fold_0': [7, 8, 9],
        'fold_1': [4, 5, 6],
        'fold_2': [2, 3, 4],
        'fold_3': [0, 1, 2],
    }

    _SPLITS = ('train', 'gallery', 'query')

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0, seen_portion=0.5, seen_rate=0.7):
        # Fail fast: any other value used to yield a silently malformed dataset.
        if split not in self._SPLITS:
            raise ValueError('split must be one of %s, got %r' % (self._SPLITS, split))

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold
        # NOTE(review): fixed at 7 regardless of seen_rate; confirm what
        # BaseDataset2 expects here.
        self.classes = 7
        self.seen_rate = seen_rate

        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        self.urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        # Training and test batches are pooled into one 60,000-image set.
        images, targets = self._load_batches(self.train_list + self.test_list)
        self.load_meta()

        seen_classes = int(10 * self.seen_rate)

        # Group by original label so each class forms one contiguous run;
        # the per-class np.split in _build_partitions relies on this.
        order = np.argsort(targets)
        images = images[order]
        targets = targets[order]

        targets = self._relocate_labels(targets, nb_fold, self.fold_lst_dict)

        self._build_partitions(images, targets, seen_classes, seen_portion)
        self.imgs = self._select_split(split)

        for index, (img, label) in enumerate(self.imgs):
            self.ys.append(label)
            self.I.append(index)
            # Holds the in-memory image array, not a file path.
            self.im_paths.append(img)

    def _load_batches(self, file_list):
        """Load and pool the pickled CIFAR-10 batches named in ``file_list``.

        Returns:
            tuple: ``(images, labels)``: an ``(N, 32, 32, 3)`` HWC image array
            and a 1-D label array of length ``N``.
        """
        batch_dir = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py')
        chunks, labels = [], []
        for file_name, _checksum in file_list:  # checksums are not verified
            # pickle can execute arbitrary code: only load archives from a trusted source.
            with open(os.path.join(batch_dir, file_name), 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
            chunks.append(entry['data'])
            labels.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        images = np.vstack(chunks).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # CHW -> HWC
        return images, np.array(labels)

    def _build_partitions(self, images, targets, seen_classes, seen_portion):
        """Set ``Source_*`` (train), ``Gallery_*`` and ``query_*`` arrays.

        ``images``/``targets`` must be grouped so every class is contiguous.
        """
        seen_mask = targets < seen_classes
        seen_imgs, seen_targets = images[seen_mask], targets[seen_mask]
        unseen_imgs, unseen_targets = images[~seen_mask], targets[~seen_mask]

        # Unseen images alternate: even positions -> query, odd -> gallery.
        self.query_x = unseen_imgs[::2]
        self.query_y = unseen_targets[::2]
        unseen_gallery_x = unseen_imgs[1::2]
        unseen_gallery_y = unseen_targets[1::2]

        # Each seen class gives its first `portion` images to training and the
        # rest to the gallery (6000 = images per class assumed here).
        portion = int(6000 * seen_portion)
        source_x, source_y, target_x, target_y = [], [], [], []
        for cls_x, cls_y in zip(np.split(seen_imgs, seen_classes, 0),
                                np.split(seen_targets, seen_classes, 0)):
            source_x.append(cls_x[:portion])
            source_y.append(cls_y[:portion])
            target_x.append(cls_x[portion:])
            target_y.append(cls_y[portion:])

        # Concatenate once instead of growing arrays inside the loop.
        self.Source_x = np.concatenate(source_x, axis=0)
        self.Source_y = np.concatenate(source_y, axis=0)
        self.Gallery_x = np.concatenate(target_x + [unseen_gallery_x], axis=0)
        self.Gallery_y = np.concatenate(target_y + [unseen_gallery_y], axis=0)

    def _select_split(self, split):
        """Return the ``(image, label)`` pairs for ``split`` (sets ``label_mat`` for 'query')."""
        if split == 'train':
            return list(zip(self.Source_x, self.Source_y))
        if split == 'gallery':
            return list(zip(self.Gallery_x, self.Gallery_y))
        if split == 'query':
            # label_mat[q, g] == 1 exactly when query q and gallery g share a label.
            query_one_hot = np.eye(10)[self.query_y]
            gallery_one_hot = np.eye(10)[self.Gallery_y]
            self.label_mat = np.matmul(query_one_hot, gallery_one_hot.T)
            return list(zip(self.query_x, self.query_y))
        raise ValueError('split must be one of %s, got %r' % (self._SPLITS, split))

    def rebuild_imgs(self, pseudo_labels):
        """Replace each sample's label with the matching entry of ``pseudo_labels``.

        Note: ``self.ys`` is not updated.
        """
        images = [sample[0] for sample in self.imgs]
        self.imgs = list(zip(images, pseudo_labels))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, *extras = self.imgs[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # target_transform acts on the label, as documented.
        if self.target_transform is not None and extras:
            extras[0] = self.target_transform(extras[0])

        return (img, *extras)

    def download(self):
        """Download and extract the CIFAR-10 archive into ``<root>/<raw_folder>``.

        Does nothing if ``cifar-10-batches-py`` already exists there. The
        downloaded archive is deleted after extraction.
        """
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        os.makedirs(raw_dir, exist_ok=True)

        # Use the safe 'data' extraction filter where this Python provides it.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            response = urllib.request.urlopen(url)
            try:
                with open(file_path, 'wb') as f:
                    f.write(response.read())
            finally:
                response.close()
            with tarfile.open(file_path, 'r') as tar:
                for item in tar:
                    tar.extract(item, raw_dir, **extract_kwargs)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read the class names from the ``batches.meta`` pickle."""
        path = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py/', self.meta['filename'])
        with open(path, 'rb') as infile:
            data = pickle.load(infile, encoding='latin1')
        self.class_names = data[self.meta['key']]
        # NOTE(review): this is a dict_values view of the indices 0..N-1, not a
        # name -> index mapping as the name suggests; kept for compatibility.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()

    def _check_exists(self):
        """Return True if the extracted ``cifar-10-batches-py`` folder exists."""
        return os.path.exists(os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py'))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Swap the fold's classes with the top ``len(fold_lst)`` label slots.

        Class ``fold_lst[i]`` receives label ``10 - len(fold_lst) + i`` and
        vice versa (slots 7, 8, 9 for the built-in folds). ``targets`` must be
        a NumPy array; a relabelled copy is returned. Assumes each fold's
        classes are either disjoint from or identical to the destination slots,
        as in all built-in folds.
        """
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        first_unseen = 10 - len(fold_lst)
        relocated = copy.deepcopy(targets)

        # Masks come from the untouched `targets`, so each swap sees original labels.
        for offset, cls_idx in enumerate(fold_lst):
            relocated[np.where(targets == cls_idx)] = first_unseen + offset
            relocated[np.where(targets == first_unseen + offset)] = cls_idx

        return relocated

@contextlib.contextmanager
def temp_seed(seed):
    """Context manager that seeds NumPy's global RNG only for its body.

    On exit (normal or via exception) the RNG state captured on entry is
    reinstated, so code outside the ``with`` block sees an unchanged stream.
    """
    previous = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(previous)

class cifar_randomsample(BaseDataset2):
    """`cifar10 <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`_ with a seeded random choice of unseen classes.

    All 60,000 images (the five training batches plus the test batch) are
    pooled. A permutation of the 10 classes drawn under ``temp_seed(nb_fold)``
    picks ``10 - int(10 * seen_rate)`` unseen classes. These are relabelled
    ``seen_classes .. 9`` and the remaining classes ``0 .. seen_classes - 1``.
    The data are then partitioned into:

    * ``'train'``: the first ``int(6000 * seen_portion)`` images of every seen
      class;
    * ``'gallery'``: the remaining seen-class images plus the odd-positioned
      unseen-class images;
    * ``'query'``: the even-positioned unseen-class images. For this split
      ``label_mat`` is the (query x gallery) matrix with 1 where labels match.

    ``fold_lst_dict`` and ``_relocate_labels`` are kept for interface
    compatibility but are not used by this class.

    Args:
        root (string): Root directory; the archive lives under
            ``<root>/raw/cifar-10-batches-py``.
        split (str): One of ``'train'``, ``'gallery'``, ``'query'``.
        transform (callable, optional): Applied to the PIL image, e.g.
            ``transforms.RandomCrop``.
        target_transform (callable, optional): Applied to the label.
        download (bool, optional): Download and extract the archive if it is
            not already present.
        loader (callable, optional): Stored on the instance; images are held
            in memory, so ``__getitem__`` does not use it.
        nb_fold (int): Seed for choosing the unseen classes.
        seen_portion (float): Fraction of each seen class used for training.
        seen_rate (float): Fraction of the 10 classes treated as seen.

    Raises:
        ValueError: If ``split`` is not a recognised split name.
        RuntimeError: If the dataset is missing and ``download`` is False.
    """
    urls = []
    raw_folder = 'raw'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    fold_lst_dict = {
        'fold_0': [7, 8, 9],
        'fold_1': [4, 5, 6],
        'fold_2': [2, 3, 4],
        'fold_3': [0, 1, 2],
    }

    _SPLITS = ('train', 'gallery', 'query')

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0, seen_portion=0.5, seen_rate=0.7):
        # Fail fast: any other value used to yield a silently malformed dataset.
        if split not in self._SPLITS:
            raise ValueError('split must be one of %s, got %r' % (self._SPLITS, split))

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold
        # NOTE(review): fixed at 7 regardless of seen_rate; confirm what
        # BaseDataset2 expects here.
        self.classes = 7
        self.seen_rate = seen_rate

        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        self.urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        # Training and test batches are pooled into one 60,000-image set.
        images, targets = self._load_batches(self.train_list + self.test_list)
        self.load_meta()

        seen_classes = int(10 * self.seen_rate)
        images, targets = self._relabel_random(images, targets, seen_classes, nb_fold)

        # Sort by new label so each class is contiguous for the per-class split.
        order = np.argsort(targets)
        images = images[order]
        targets = targets[order]

        self._build_partitions(images, targets, seen_classes, seen_portion)
        self.imgs = self._select_split(split)

        for index, (img, label) in enumerate(self.imgs):
            self.ys.append(label)
            self.I.append(index)
            # Holds the in-memory image array, not a file path.
            self.im_paths.append(img)

    def _load_batches(self, file_list):
        """Load and pool the pickled CIFAR-10 batches named in ``file_list``.

        Returns:
            tuple: ``(images, labels)``: an ``(N, 32, 32, 3)`` HWC image array
            and a 1-D label array of length ``N``.
        """
        batch_dir = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py')
        chunks, labels = [], []
        for file_name, _checksum in file_list:  # checksums are not verified
            # pickle can execute arbitrary code: only load archives from a trusted source.
            with open(os.path.join(batch_dir, file_name), 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
            chunks.append(entry['data'])
            labels.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        images = np.vstack(chunks).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # CHW -> HWC
        return images, np.array(labels)

    def _relabel_random(self, images, targets, seen_classes, nb_fold):
        """Relabel so a seeded random subset of classes becomes unseen.

        The first ``10 - seen_classes`` classes of a permutation drawn under
        ``temp_seed(nb_fold)`` get labels ``seen_classes .. 9``; the rest get
        ``0 .. seen_classes - 1``. Images are regrouped class by class in that
        order (unseen first).
        """
        # Unseen count is the complement of seen_classes so the two always sum
        # to 10. int(10 * (1 - seen_rate)) truncates to 1 for seen_rate=0.8
        # (float rounding), which gave two different classes the same label.
        num_unseen = 10 - seen_classes
        with temp_seed(nb_fold):
            idx_permute = np.random.permutation(10)
        assignments = (list(zip(idx_permute[:num_unseen], range(seen_classes, 10)))
                       + list(zip(idx_permute[num_unseen:], range(seen_classes))))

        group_imgs, group_targets = [], []
        for cls_idx, new_label in assignments:
            members = np.where(targets == cls_idx)[0]
            group_imgs.append(images[members])
            group_targets.append(np.full(len(members), new_label, dtype=targets.dtype))
        return np.vstack(group_imgs), np.hstack(group_targets)

    def _build_partitions(self, images, targets, seen_classes, seen_portion):
        """Set ``Source_*`` (train), ``Gallery_*`` and ``query_*`` arrays.

        ``images``/``targets`` must be grouped so every class is contiguous.
        """
        seen_mask = targets < seen_classes
        seen_imgs, seen_targets = images[seen_mask], targets[seen_mask]
        unseen_imgs, unseen_targets = images[~seen_mask], targets[~seen_mask]

        # Unseen images alternate: even positions -> query, odd -> gallery.
        self.query_x = unseen_imgs[::2]
        self.query_y = unseen_targets[::2]
        unseen_gallery_x = unseen_imgs[1::2]
        unseen_gallery_y = unseen_targets[1::2]

        # Each seen class gives its first `portion` images to training and the
        # rest to the gallery (6000 = images per class assumed here).
        portion = int(6000 * seen_portion)
        source_x, source_y, target_x, target_y = [], [], [], []
        for cls_x, cls_y in zip(np.split(seen_imgs, seen_classes, 0),
                                np.split(seen_targets, seen_classes, 0)):
            source_x.append(cls_x[:portion])
            source_y.append(cls_y[:portion])
            target_x.append(cls_x[portion:])
            target_y.append(cls_y[portion:])

        # Concatenate once instead of growing arrays inside the loop.
        self.Source_x = np.concatenate(source_x, axis=0)
        self.Source_y = np.concatenate(source_y, axis=0)
        self.Gallery_x = np.concatenate(target_x + [unseen_gallery_x], axis=0)
        self.Gallery_y = np.concatenate(target_y + [unseen_gallery_y], axis=0)

    def _select_split(self, split):
        """Return the ``(image, label)`` pairs for ``split`` (sets ``label_mat`` for 'query')."""
        if split == 'train':
            return list(zip(self.Source_x, self.Source_y))
        if split == 'gallery':
            return list(zip(self.Gallery_x, self.Gallery_y))
        if split == 'query':
            # label_mat[q, g] == 1 exactly when query q and gallery g share a label.
            query_one_hot = np.eye(10)[self.query_y]
            gallery_one_hot = np.eye(10)[self.Gallery_y]
            self.label_mat = np.matmul(query_one_hot, gallery_one_hot.T)
            return list(zip(self.query_x, self.query_y))
        raise ValueError('split must be one of %s, got %r' % (self._SPLITS, split))

    def rebuild_imgs(self, pseudo_labels):
        """Replace each sample's label with the matching entry of ``pseudo_labels``.

        Note: ``self.ys`` is not updated.
        """
        images = [sample[0] for sample in self.imgs]
        self.imgs = list(zip(images, pseudo_labels))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, *extras = self.imgs[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # target_transform acts on the label, as documented.
        if self.target_transform is not None and extras:
            extras[0] = self.target_transform(extras[0])

        return (img, *extras)

    def download(self):
        """Download and extract the CIFAR-10 archive into ``<root>/<raw_folder>``.

        Does nothing if ``cifar-10-batches-py`` already exists there. The
        downloaded archive is deleted after extraction.
        """
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        os.makedirs(raw_dir, exist_ok=True)

        # Use the safe 'data' extraction filter where this Python provides it.
        extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            response = urllib.request.urlopen(url)
            try:
                with open(file_path, 'wb') as f:
                    f.write(response.read())
            finally:
                response.close()
            with tarfile.open(file_path, 'r') as tar:
                for item in tar:
                    tar.extract(item, raw_dir, **extract_kwargs)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read the class names from the ``batches.meta`` pickle."""
        path = os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py/', self.meta['filename'])
        with open(path, 'rb') as infile:
            data = pickle.load(infile, encoding='latin1')
        self.class_names = data[self.meta['key']]
        # NOTE(review): this is a dict_values view of the indices 0..N-1, not a
        # name -> index mapping as the name suggests; kept for compatibility.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()

    def _check_exists(self):
        """Return True if the extracted ``cifar-10-batches-py`` folder exists."""
        return os.path.exists(os.path.join(self.root, self.raw_folder, 'cifar-10-batches-py'))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Swap the fold's classes with the top ``len(fold_lst)`` label slots.

        Class ``fold_lst[i]`` receives label ``10 - len(fold_lst) + i`` and
        vice versa (slots 7, 8, 9 for the built-in folds). ``targets`` must be
        a NumPy array; a relabelled copy is returned. Not called by this class.
        """
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        first_unseen = 10 - len(fold_lst)
        relocated = copy.deepcopy(targets)

        # Masks come from the untouched `targets`, so each swap sees original labels.
        for offset, cls_idx in enumerate(fold_lst):
            relocated[np.where(targets == cls_idx)] = first_unseen + offset
            relocated[np.where(targets == first_unseen + offset)] = cls_idx

        return relocated



"""Changed"""
class cifar_randomsample2(BaseDataset2):
    """`cifar10 <https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz>`_ Dataset.
    Args:
        root (string): Root directory of dataset the images and corresponding lists exist
            inside raw folder
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = []
    raw_folder = 'raw'
    tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
    train_list = [
        ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
        ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
        ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
        ['data_batch_4', '634d18415352ddfa80567beed471001a'],
        ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
    ]

    test_list = [
        ['test_batch', '40351d587109b95175f43aff81a1287e'],
    ]
    meta = {
        'filename': 'batches.meta',
        'key': 'label_names',
        'md5': '5ff9c542aee3614f3951f8cda6e48888',
    }

    fold_lst_dict = {
        'fold_0': [7, 8, 9],
        'fold_1': [4, 5, 6],
        'fold_2': [2, 3, 4],
        'fold_3': [0, 1, 2],
    }

    def __init__(self, root,  split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0, seen_portion=0.5, seen_rate=0.7):

        #self.root = os.path.expanduser('../' + root)
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold
        self.classes=7
        self.seen_rate = seen_rate

        BaseDataset2.__init__(self, self.root, self.split, self.transform)


        # if self.train:
        #     downloaded_list = self.train_list
        # else:
        #     downloaded_list = self.test_list
        downloaded_list = self.train_list + self.test_list

        self.urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.imgs = []
        targets = []
        for file_name, checksum in downloaded_list:
            file_path = os.path.join(os.path.join(self.root, self.raw_folder), 'cifar-10-batches-py/', file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.imgs.append(entry['data'])
                if 'labels' in entry:
                    targets.extend(entry['labels'])
                else:
                    targets.extend(entry['fine_labels'])
        self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
        self.imgs = self.imgs.transpose((0, 2, 3, 1))  # convert to HWC
        self.load_meta()

        seen_classes = int(10 * self.seen_rate)

        targets = np.array(targets)

        imgs =  []
        relocated_targets = []
        with temp_seed(nb_fold):
            idx_permute = np.random.permutation(10)[:]
            unseen_cls_idx = idx_permute[:int(10*(1-seen_rate))]
            seen_cls_idx = idx_permute[int(10 * (1 - seen_rate)):]
            for i, cls_idx in enumerate(unseen_cls_idx):
                idxs = np.where(targets == cls_idx)
                tmp_targets = copy.deepcopy(targets[idxs])
                tmp_targets[:] = seen_classes + i
                tmp_imgs = copy.deepcopy(self.imgs[idxs])

                imgs.append(tmp_imgs)
                relocated_targets.append(tmp_targets)
            for i, cls_idx in enumerate(seen_cls_idx):

                idxs = np.where(targets == cls_idx)
                tmp_targets = copy.deepcopy(targets[idxs])
                tmp_targets[:] = i
                tmp_imgs = copy.deepcopy(self.imgs[idxs])

                imgs.append(tmp_imgs)
                relocated_targets.append(tmp_targets)


        #targets = self._relocate_labels(targets, nb_fold, self.fold_lst_dict)
        self.imgs = np.vstack(imgs)
        targets = np.hstack(relocated_targets)

        argidx = np.argsort(targets)
        self.imgs = self.imgs[argidx]
        targets = np.array(targets)[argidx]

        labeled_masks = targets < seen_classes
        unlabeled_masks = targets >= seen_classes

        seen_imgs = self.imgs[labeled_masks]
        seen_targets = targets[labeled_masks]

        unseen_imgs = self.imgs[unlabeled_masks]
        unseen_targets = targets[unlabeled_masks]

        self.query_x = unseen_imgs[::2]
        self.query_y = unseen_targets[::2]

        unseen_gallery_x = np.delete(unseen_imgs, np.arange(0, np.shape(unseen_imgs)[0], 2), axis=0)
        unseen_gallery_y = np.delete(unseen_targets, np.arange(0, np.shape(unseen_targets)[0], 2), axis=0)

        seen_gallery_x_split = np.split(seen_imgs, seen_classes, 0)
        seen_gallery_y_split = np.split(seen_targets, seen_classes, 0)

        portion = int(6000 * seen_portion)
        for i in range(seen_classes):
            Source_tmp_x = seen_gallery_x_split[i][:portion]
            Source_tmp_y = seen_gallery_y_split[i][:portion]
            Target_tmp_x = seen_gallery_x_split[i][portion:]
            Target_tmp_y = seen_gallery_y_split[i][portion:]

            if i == 0:
                self.Source_x = Source_tmp_x
                self.Source_y = Source_tmp_y
                Target_x = Target_tmp_x
                Target_y = Target_tmp_y
            else:
                self.Source_x = np.concatenate((self.Source_x, Source_tmp_x), axis=0)
                self.Source_y = np.concatenate((self.Source_y, Source_tmp_y), axis=0)
                Target_x = np.concatenate((Target_x, Target_tmp_x), axis=0)
                Target_y = np.concatenate((Target_y, Target_tmp_y), axis=0)

        self.Gallery_x = np.concatenate((Target_x, unseen_gallery_x), axis=0)
        self.Gallery_y = np.concatenate((Target_y, unseen_gallery_y), axis=0)

        if split == 'train':
            self.imgs = list(zip(self.Source_x, self.Source_y))
        elif split == 'gallery':
            self.imgs = list(zip(self.Gallery_x, self.Gallery_y))
        elif split == 'query':
            self.imgs = list(zip(self.query_x, self.query_y))

            # Construct label mat between gallary and query
            query_one_hot = np.eye(10)[self.query_y]
            gallery_one_hot = np.eye(10)[self.Gallery_y]
            self.label_mat = np.matmul(query_one_hot, np.transpose(gallery_one_hot))

        index = 0
        for i in self.imgs:
            # i[1]: label, i[0]: root
            y = i[1]

            # if y in self.classes and fn[:2] != '._':

            self.ys += [y]
            self.I += [index]
            self.im_paths.append(i[0])
            index += 1
    def rebuild_imgs(self, pseudo_labels):
        imgs, labels = zip(*list(self.imgs))
        self.imgs = list(zip(imgs, pseudo_labels))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # path, target = self.imgs[index]
        # img = self.loader(path)
        #
        # if self.transform is not None:
        #     img = self.transform(img)
        #
        # if self.target_transform is not None:
        #     img = self.target_transform(img)
        #
        # return img, target

        img = self.imgs[index][0]
        #img = self.loader(path)
        img = Image.fromarray(img)
        """Changed"""
        if self.transform is not None:
            img1 = self.transform(img)
            img2 = self.transform(img)

        if self.target_transform is not None:
            img = self.target_transform(img)


        if self.split == 'split':
            return (img1, *self.imgs[index][1:])
        return (img1, img2, *self.imgs[index][1:])


    def download(self):
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            tar = tarfile.open(file_path, 'r')
            for item in tar:
                tar.extract(item, file_path.replace(filename, ''))
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        path = os.path.join(self.root, self.raw_folder,'cifar-10-batches-py/',self.meta['filename'])
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.class_names = data[self.meta['key']]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()

    def _check_exists(self):
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, 'cifar-10-batches-py'))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        tmp_targets = copy.deepcopy(targets)

        for i, cls_idx in enumerate(fold_lst):
            idxs = np.where(targets == cls_idx)
            tmp_targets[idxs] = 7 + i

            idxs = np.where(targets == 7 + i)
            tmp_targets[idxs] = cls_idx

        return tmp_targets


def encode_onehot(labels, num_classes=75):
    """Turn integer class labels into a one-hot matrix.

    Args:
        labels (numpy.ndarray): labels, one per output row.
        num_classes (int): Number of classes (output columns).
    Returns:
        numpy.ndarray: float array of shape ``(len(labels), num_classes)`` with
        ``1`` at ``[row, labels[row]]`` and ``0`` everywhere else.
    """
    encoded = np.zeros((len(labels), num_classes))
    for row, label in enumerate(labels):
        encoded[row, label] = 1
    return encoded


# class CIFAR100(BaseDataset2):
#     """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
#
#     This is a subclass of the `CIFAR10` Dataset.
#     """
#     base_folder = 'cifar-100-python'
#     raw_folder='raw'
#     urls = ["https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"]
#     filename = "cifar-100-python.tar.gz"
#     tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
#     train_list = [
#         ['train', '16019d7e3df5f24257cddd939b257f8d'],
#     ]
#
#     test_list = [
#         ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
#     ]
#     meta = {
#         'filename': 'meta',
#         'key': 'fine_label_names',
#         'md5': '7973b15100ade9c7d40fb424638fde48',
#     }
#
#     def __init__(self, root,  split='train', transform=None, target_transform=None, download=True, loader=default_loader,
#                  nb_fold=0):
#
#         #self.root = os.path.expanduser('../' + root)
#         self.root = os.path.expanduser(root)
#         self.transform = transform
#         self.target_transform = target_transform
#         self.split = split
#
#         self.loader = loader
#         self.label_mat = None
#         self.nb_fold = nb_fold
#         self.classes = 75
#
#         BaseDataset2.__init__(self, self.root, self.split, self.transform)
#
#
#         # if self.train:
#         #     downloaded_list = self.train_list
#         # else:
#         #     downloaded_list = self.test_list
#         downloaded_list = self.train_list + self.test_list
#
#         if download:
#             self.download()
#
#         if not self._check_exists():
#             raise RuntimeError('Dataset not found. You can use download=True to download it')
#
#         self.imgs = []
#         targets = []
#         for file_name, checksum in downloaded_list:
#             file_path = os.path.join(os.path.join(self.root, self.raw_folder), self.base_folder, file_name)
#             with open(file_path, 'rb') as f:
#                 entry = pickle.load(f, encoding='latin1')
#                 self.imgs.append(entry['data'])
#                 if 'labels' in entry:
#                     targets.extend(entry['labels'])
#                 else:
#                     targets.extend(entry['fine_labels'])
#         self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
#         self.imgs = self.imgs.transpose((0, 2, 3, 1))  # convert to HWC
#         self.load_meta()
#
#         unseen_imgs = []
#         unseen_cls_idxs = []
#         switch_cls_idxs = []
#         total_y_onehot = encode_onehot(np.array(targets), num_classes=100)
#
#         with temp_seed(nb_fold):
#             # Choice Unseen class according to fold idx
#             unseen_cls_idxs = np.random.choice(100, 25, replace=False)
#
#             unseen_img_labels = []
#             unseen_img_paths = []
#             for cls_idx in unseen_cls_idxs:
#                 idxs = np.where(total_y_onehot[:, cls_idx] == 1)[0]
#                 sub_unseen_labels = total_y_onehot[idxs]
#                 sub_unseen_img_paths = self.imgs[idxs]
#
#                 total_y_onehot = np.delete(total_y_onehot, idxs, axis=0)
#                 self.imgs = np.delete(self.imgs, idxs, axis=0)
#
#                 unseen_img_labels.append(sub_unseen_labels)
#                 unseen_img_paths.append(sub_unseen_img_paths)
#
#             unseen_img_labels = np.concatenate(unseen_img_labels, axis=0)
#             unseen_img_paths = np.concatenate(unseen_img_paths, axis=0)
#
#             # Split unseen data into gallery and query set
#             self.query_x = unseen_img_paths[::2]
#             self.query_y = unseen_img_labels[::2]
#
#
#             unseen_gallery_x = np.delete(unseen_img_paths, np.arange(0, np.shape(unseen_img_paths)[0], 2), axis=0)
#             unseen_gallery_y = np.delete(unseen_img_labels, np.arange(0, np.shape(unseen_img_labels)[0], 2), axis=0)
#
#         # Delete one-hot gt for unseen class
#         extracted_img_labels = np.delete(total_y_onehot, unseen_cls_idxs, axis=1)
#
#         # Split seen data into source, gallery set
#         self.source_x = self.imgs[::2]
#         self.source_y = extracted_img_labels[::2]
#         self.source_y = np.where(self.source_y != 0)[1]
#
#         seen_gallery_x = np.delete(self.imgs, np.arange(0, np.shape(self.imgs)[0], 2), axis=0)
#         seen_gallery_y = np.delete(total_y_onehot, np.arange(0, np.shape(self.imgs)[0], 2), axis=0)
#
#         self.gallery_x = np.concatenate((seen_gallery_x, unseen_gallery_x), axis=0)
#         self.gallery_y = np.concatenate((seen_gallery_y, unseen_gallery_y), axis=0)
#
#         seen_eval_y = np.delete(seen_gallery_y, unseen_cls_idxs, axis=1)
#
#         if split == 'train':
#             self.imgs = list(zip(self.source_x, self.source_y))
#         elif split == 'gallery':
#             self.gallery_y = np.where(self.gallery_y != 0)[1]
#             self.imgs = list(zip(self.gallery_x, self.gallery_y))
#         elif split == 'query':
#             # Construct label mat between gallary and query
#             self.label_mat = (np.matmul(self.query_y, np.transpose(self.gallery_y)) > 0).astype(np.float32)
#             self.query_y = np.where(self.query_y != 0)[1]
#             self.imgs = list(zip(self.query_x, self.query_y))
#
#
#         elif split == 'eval':
#             seen_eval_y = np.where(seen_eval_y != 0)[1]
#             self.imgs = list(zip(seen_gallery_x, seen_eval_y))
#
#             self.label_mat = (np.matmul(seen_eval_y, np.transpose(seen_eval_y)) > 0).astype(np.float32)
#
#         # index = 0
#         # for i in self.imgs:
#         #     # i[1]: label, i[0]: root
#         #     y = np.where(i[1] != 0)[1]
#         #
#         #     # if y in self.classes and fn[:2] != '._':
#         #
#         #     self.ys += [y]
#         #     self.I += [index]
#         #     self.im_paths.append(i[0])
#         #     index += 1
#
#     def __len__(self):
#         return len(self.imgs)
#
#     def __getitem__(self, index):
#         """
#         Args:
#             index (int): Index
#         Returns:
#             tuple: (image, target) where target is index of the target class.
#         """
#         # path, target = self.imgs[index]
#         # img = self.loader(path)
#         #
#         # if self.transform is not None:
#         #     img = self.transform(img)
#         #
#         # if self.target_transform is not None:
#         #     img = self.target_transform(img)
#         #
#         # return img, target
#
#         img = self.imgs[index][0]
#         #img = self.loader(path)
#         img = Image.fromarray(img)
#
#         if self.transform is not None:
#             img = self.transform(img)
#
#         if self.target_transform is not None:
#             img = self.target_transform(img)
#
#         return (img, *self.imgs[index][1:])
#
#
#     def download(self):
#         from six.moves import urllib
#         import tarfile
#
#         if self._check_exists():
#             return
#
#         try:
#             os.makedirs(os.path.join(self.root, self.raw_folder))
#         except OSError as e:
#             if e.errno == errno.EEXIST:
#                 pass
#             else:
#                 raise
#
#         for url in self.urls:
#             print('Downloading ' + url)
#             data = urllib.request.urlopen(url)
#             filename = url.rpartition('/')[2]
#             file_path = os.path.join(self.root, self.raw_folder, filename)
#             with open(file_path, 'wb') as f:
#                 f.write(data.read())
#             tar = tarfile.open(file_path, 'r')
#             for item in tar:
#                 tar.extract(item, file_path.replace(filename, ''))
#             os.unlink(file_path)
#
#         print('Done!')
#
#     def load_meta(self):
#         path = os.path.join(self.root, self.raw_folder, self.base_folder, self.meta['filename'])
#         with open(path, 'rb') as infile:
#             if sys.version_info[0] == 2:
#                 data = pickle.load(infile)
#             else:
#                 data = pickle.load(infile, encoding='latin1')
#             self.class_names = data[self.meta['key']]
#         self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()
#
#     def _check_exists(self):
#         pth = os.path.join(self.root, self.raw_folder)
#         return os.path.exists(os.path.join(pth, self.base_folder))
#
#     def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
#         fold_lst = fold_lst_dict['fold_%d' % nb_fold]
#         tmp_targets = copy.deepcopy(targets)
#
#         for i, cls_idx in enumerate(fold_lst):
#             idxs = np.where(targets == cls_idx)
#             tmp_targets[idxs] = 7 + i
#
#             idxs = np.where(targets == 7 + i)
#             tmp_targets[idxs] = cls_idx
#
#         return tmp_targets


class CIFAR100(BaseDataset2):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset with a fold-based seen/unseen split.

    The official train and test batches are pooled. Labels are remapped with
    ``_relocate_labels`` so that the 25 classes of ``fold_lst_dict['fold_<nb_fold>']``
    end up on labels 75..99 (unseen) and the rest on 0..74 (seen). Splits:

    * ``'train'``   -- the first 300 images of each seen class,
    * ``'gallery'`` -- the remaining seen-class images plus the odd-position
      unseen-class images,
    * ``'query'``   -- the even-position unseen-class images; also sets
      ``self.label_mat`` (1 where a query and a gallery item share a label).

    Args:
        root (string): Directory containing ``<raw_folder>/cifar-100-python``.
        split (str): ``'train'``, ``'gallery'`` or ``'query'``.
        transform (callable, optional): Applied to the PIL image.
        target_transform (callable, optional): Applied to the target label.
        download (bool): Fetch the archive when it is not present.
        loader (callable): Stored as ``self.loader``; images are kept in memory.
        nb_fold (int): Selects the held-out classes in ``fold_lst_dict``.
    """
    base_folder = 'cifar-100-python'
    raw_folder='raw'
    urls = ["https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"]
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }

    # Classes moved to the unseen label range (75..99) for each fold.
    fold_lst_dict = {
        'fold_0': range(75, 100),
        'fold_1': range(50, 75),
        'fold_2': range(25, 50),
        'fold_3': range(0, 25),
    }
    def __init__(self, root,  split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold

        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        # Both official batches are pooled; the split below is by class.
        downloaded_list = self.train_list + self.test_list

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.imgs = []
        targets = []
        for file_name, _checksum in downloaded_list:
            file_path = os.path.join(self.root, self.raw_folder, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.imgs.append(entry['data'])
                if 'labels' in entry:
                    targets.extend(entry['labels'])
                else:
                    targets.extend(entry['fine_labels'])
        self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
        self.imgs = self.imgs.transpose((0, 2, 3, 1))  # convert to HWC
        self.load_meta()

        # Group samples by class so each class forms a contiguous run.
        # NOTE(review): np.argsort's default kind is not stable, so the order
        # within a class (and therefore the 300/300 split below) may depend on
        # the NumPy build; kind='stable' would pin it but changes the partition.
        argidx = np.argsort(targets)
        self.imgs = self.imgs[argidx]
        targets = np.array(targets)[argidx]

        targets = self._relocate_labels(targets, nb_fold, self.fold_lst_dict)

        labeled_masks = targets < 75
        unlabeled_masks = targets >= 75

        seen_imgs = self.imgs[labeled_masks]
        seen_targets = targets[labeled_masks]

        unseen_imgs = self.imgs[unlabeled_masks]
        unseen_targets = targets[unlabeled_masks]

        # Even positions of the unseen data form the query set, odd positions
        # go to the gallery.
        self.query_x = unseen_imgs[::2]
        self.query_y = unseen_targets[::2]
        unseen_gallery_x = unseen_imgs[1::2]
        unseen_gallery_y = unseen_targets[1::2]

        # Per seen class: first 300 images -> training source, rest -> gallery.
        # Built with a single concatenate instead of growing arrays in a loop.
        seen_x_parts = np.split(seen_imgs, 75, 0)
        seen_y_parts = np.split(seen_targets, 75, 0)
        self.Source_x = np.concatenate([part[:300] for part in seen_x_parts], axis=0)
        self.Source_y = np.concatenate([part[:300] for part in seen_y_parts], axis=0)
        Target_x = np.concatenate([part[300:] for part in seen_x_parts], axis=0)
        Target_y = np.concatenate([part[300:] for part in seen_y_parts], axis=0)

        self.Gallery_x = np.concatenate((Target_x, unseen_gallery_x), axis=0)
        self.Gallery_y = np.concatenate((Target_y, unseen_gallery_y), axis=0)

        if split == 'train':
            self.imgs = list(zip(self.Source_x, self.Source_y))
        elif split == 'gallery':
            self.imgs = list(zip(self.Gallery_x, self.Gallery_y))
        elif split == 'query':
            self.imgs = list(zip(self.query_x, self.query_y))

            # label_mat[q, g] == 1 where query q and gallery item g share a label.
            query_one_hot = np.eye(100)[self.query_y]
            gallery_one_hot = np.eye(100)[self.Gallery_y]
            self.label_mat = np.matmul(query_one_hot, np.transpose(gallery_one_hot))

        # Populate the bookkeeping lists inherited from BaseDataset2.
        for index, entry in enumerate(self.imgs):
            self.ys += [entry[1]]
            self.I += [index]
            self.im_paths.append(entry[0])

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the relocated class index,
            passed through ``target_transform`` when one is set.
        """
        img, *rest = self.imgs[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # target_transform applies to the label, not to the image.
        if self.target_transform is not None and rest:
            rest[0] = self.target_transform(rest[0])

        return (img, *rest)

    def download(self):
        """Download and unpack ``self.urls`` into ``<root>/<raw_folder>`` unless already present."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        try:
            os.makedirs(raw_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            with contextlib.closing(urllib.request.urlopen(url)) as response:
                payload = response.read()
            with open(file_path, 'wb') as f:
                f.write(payload)
            with tarfile.open(file_path, 'r') as tar:
                # NOTE(review): no path sanitisation; only use trusted URLs.
                for item in tar:
                    tar.extract(item, raw_dir)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read class names from the ``meta`` pickle into ``self.class_names``."""
        path = os.path.join(self.root, self.raw_folder, self.base_folder, self.meta['filename'])
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.class_names = data[self.meta['key']]
        # NOTE(review): a dict_values view of indices, not a name -> index dict.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()

    def _check_exists(self):
        """Return whether ``<root>/<raw_folder>/<base_folder>`` exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, self.base_folder))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Swap the i-th class of the fold with label ``75 + i`` (lookups use the original targets)."""
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        tmp_targets = copy.deepcopy(targets)

        for i, cls_idx in enumerate(fold_lst):
            idxs = np.where(targets == cls_idx)
            tmp_targets[idxs] = 75 + i

            idxs = np.where(targets == 75 + i)
            tmp_targets[idxs] = cls_idx

        return tmp_targets


class cifar100_randomsample(BaseDataset2):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset with a random seen/unseen class split.

    The official train and test batches are pooled. With ``nb_fold`` as the
    random seed, a permutation of the 100 classes selects
    ``100 - seen_classes`` unseen classes (relabelled ``seen_classes, ...``)
    and ``seen_classes = round(100 * seen_rate)`` seen classes (relabelled
    ``0, ...``). Splits:

    * ``'train'``   -- the first ``int(600 * seen_portion)`` images of each
      seen class,
    * ``'gallery'`` -- the remaining seen-class images plus the odd-position
      unseen-class images,
    * ``'query'``   -- the even-position unseen-class images; also sets
      ``self.label_mat`` (1 where a query and a gallery item share a label).

    Args:
        root (string): Directory containing ``<raw_folder>/cifar-100-python``.
        split (str): ``'train'``, ``'gallery'`` or ``'query'``.
        transform (callable, optional): Applied to the PIL image.
        target_transform (callable, optional): Applied to the target label.
        download (bool): Fetch the archive when it is not present.
        loader (callable): Stored as ``self.loader``; images are kept in memory.
        nb_fold (int): Random seed for the class permutation.
        seen_rate (float): Fraction of the 100 classes treated as seen.
        seen_portion (float): Fraction of each seen class's 600 images used for training.
    """
    base_folder = 'cifar-100-python'
    raw_folder='raw'
    urls = ["https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"]
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }

    # Only used by _relocate_labels, which this class does not call itself.
    fold_lst_dict = {
        'fold_0': range(75, 100),
        'fold_1': range(50, 75),
        'fold_2': range(25, 50),
        'fold_3': range(0, 25),
    }
    def __init__(self, root,  split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0, seen_rate=0.75, seen_portion=0.5):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold
        self.seen_rate = seen_rate
        self.seen_portion = seen_portion

        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        # Both official batches are pooled; the split below is by class.
        downloaded_list = self.train_list + self.test_list

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.imgs = []
        targets = []
        for file_name, _checksum in downloaded_list:
            file_path = os.path.join(self.root, self.raw_folder, self.base_folder, file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.imgs.append(entry['data'])
                if 'labels' in entry:
                    targets.extend(entry['labels'])
                else:
                    targets.extend(entry['fine_labels'])
        self.imgs = np.vstack(self.imgs).reshape(-1, 3, 32, 32)
        self.imgs = self.imgs.transpose((0, 2, 3, 1))  # convert to HWC
        self.load_meta()

        num_classes = 100
        # Derive the unseen count from the seen count so the two always sum to
        # 100. Truncating 100*seen_rate and 100*(1-seen_rate) separately fails
        # for e.g. seen_rate=0.8 (1 - 0.8 == 0.19999999999999996 -> 19 unseen),
        # which gave one seen class the same label as the first unseen class.
        seen_classes = int(round(num_classes * self.seen_rate))
        num_unseen = num_classes - seen_classes

        targets = np.array(targets)

        with temp_seed(nb_fold):
            idx_permute = np.random.permutation(num_classes)
        unseen_cls_idx = idx_permute[:num_unseen]
        seen_cls_idx = idx_permute[num_unseen:]

        # Relabel: unseen classes -> seen_classes.., seen classes -> 0..
        imgs = []
        relocated_targets = []
        for new_label, cls_idx in enumerate(unseen_cls_idx, start=seen_classes):
            idxs = np.where(targets == cls_idx)
            imgs.append(self.imgs[idxs])
            relocated_targets.append(np.full_like(targets[idxs], new_label))
        for new_label, cls_idx in enumerate(seen_cls_idx):
            idxs = np.where(targets == cls_idx)
            imgs.append(self.imgs[idxs])
            relocated_targets.append(np.full_like(targets[idxs], new_label))

        self.imgs = np.vstack(imgs)
        targets = np.hstack(relocated_targets)

        # Group samples by (relocated) class so each class is contiguous.
        # NOTE(review): np.argsort's default kind is not stable; within-class
        # order may depend on the NumPy build.
        argidx = np.argsort(targets)
        self.imgs = self.imgs[argidx]
        targets = targets[argidx]

        labeled_masks = targets < seen_classes
        unlabeled_masks = targets >= seen_classes

        seen_imgs = self.imgs[labeled_masks]
        seen_targets = targets[labeled_masks]

        unseen_imgs = self.imgs[unlabeled_masks]
        unseen_targets = targets[unlabeled_masks]

        # Even positions of the unseen data form the query set, odd positions
        # go to the gallery.
        self.query_x = unseen_imgs[::2]
        self.query_y = unseen_targets[::2]
        unseen_gallery_x = unseen_imgs[1::2]
        unseen_gallery_y = unseen_targets[1::2]

        # Per seen class: first `portion` images -> training source, the rest
        # -> gallery. np.split requires equally sized classes.
        seen_x_parts = np.split(seen_imgs, seen_classes, 0)
        seen_y_parts = np.split(seen_targets, seen_classes, 0)
        portion = int(600 * seen_portion)
        self.Source_x = np.concatenate([part[:portion] for part in seen_x_parts], axis=0)
        self.Source_y = np.concatenate([part[:portion] for part in seen_y_parts], axis=0)
        Target_x = np.concatenate([part[portion:] for part in seen_x_parts], axis=0)
        Target_y = np.concatenate([part[portion:] for part in seen_y_parts], axis=0)

        self.Gallery_x = np.concatenate((Target_x, unseen_gallery_x), axis=0)
        self.Gallery_y = np.concatenate((Target_y, unseen_gallery_y), axis=0)

        if split == 'train':
            self.imgs = list(zip(self.Source_x, self.Source_y))
        elif split == 'gallery':
            self.imgs = list(zip(self.Gallery_x, self.Gallery_y))
        elif split == 'query':
            self.imgs = list(zip(self.query_x, self.query_y))

            # label_mat[q, g] == 1 where query q and gallery item g share a label.
            query_one_hot = np.eye(num_classes)[self.query_y]
            gallery_one_hot = np.eye(num_classes)[self.Gallery_y]
            self.label_mat = np.matmul(query_one_hot, np.transpose(gallery_one_hot))

        # Populate the bookkeeping lists inherited from BaseDataset2.
        for index, entry in enumerate(self.imgs):
            self.ys += [entry[1]]
            self.I += [index]
            self.im_paths.append(entry[0])

    def rebuild_imgs(self, pseudo_labels):
        """Replace each sample's label with the positionally matching pseudo label.

        Pairs are formed with ``zip``, so surplus entries on either side are
        dropped; an empty dataset yields an empty list.
        NOTE(review): ``self.ys`` is not updated here.
        """
        self.imgs = [(entry[0], label) for entry, label in zip(self.imgs, pseudo_labels)]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the relocated class index
            (or pseudo label), passed through ``target_transform`` when set.
        """
        img, *rest = self.imgs[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        # target_transform applies to the label, not to the image.
        if self.target_transform is not None and rest:
            rest[0] = self.target_transform(rest[0])

        return (img, *rest)

    def download(self):
        """Download and unpack ``self.urls`` into ``<root>/<raw_folder>`` unless already present."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        try:
            os.makedirs(raw_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            with contextlib.closing(urllib.request.urlopen(url)) as response:
                payload = response.read()
            with open(file_path, 'wb') as f:
                f.write(payload)
            with tarfile.open(file_path, 'r') as tar:
                # NOTE(review): no path sanitisation; only use trusted URLs.
                for item in tar:
                    tar.extract(item, raw_dir)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read class names from the ``meta`` pickle into ``self.class_names``."""
        path = os.path.join(self.root, self.raw_folder, self.base_folder, self.meta['filename'])
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                data = pickle.load(infile, encoding='latin1')
            self.class_names = data[self.meta['key']]
        # NOTE(review): a dict_values view of indices, not a name -> index dict.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()

    def _check_exists(self):
        """Return whether ``<root>/<raw_folder>/<base_folder>`` exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, self.base_folder))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Swap the i-th class of the fold with label ``75 + i`` (lookups use the original targets)."""
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        tmp_targets = copy.deepcopy(targets)

        for i, cls_idx in enumerate(fold_lst):
            idxs = np.where(targets == cls_idx)
            tmp_targets[idxs] = 75 + i

            idxs = np.where(targets == 75 + i)
            tmp_targets[idxs] = cls_idx

        return tmp_targets



"""Changed"""
class cifar100_randomsample2(BaseDataset2):
    """CIFAR-100 re-split into seen and unseen classes for retrieval experiments.

    The train and test pickles of `CIFAR-100
    <https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz>`_ are merged and the
    100 fine classes are shuffled with ``np.random`` temporarily seeded by
    ``nb_fold``. With ``seen = int(100 * seen_rate)``, the first ``100 - seen``
    shuffled classes are "unseen" (relabelled ``seen .. 99``). The remaining
    classes are "seen" (relabelled ``0 .. seen - 1``).

    Splits:
        * ``'train'``: for every seen class, the first ``int(600 * seen_portion)``
          images of that class in label-sorted order.
        * ``'gallery'``: the remaining seen-class images, plus the unseen-class
          images at odd positions.
        * ``'query'``: the unseen-class images at even positions. Also builds
          ``self.label_mat``, a (query x gallery) matrix that is 1 where labels
          match and 0 otherwise.

    Args:
        root (string): Root directory; data lives in ``<root>/raw/cifar-100-python``.
        split (str): One of ``'train'``, ``'gallery'``, ``'query'``.
        transform (callable, optional): Applied twice, independently, to the PIL
            image in ``__getitem__`` to produce two views.
        target_transform (callable, optional): Stored; see the note in ``__getitem__``.
        download (bool, optional): If True, download and extract the archive when the
            extracted folder is missing.
        loader (callable, optional): Stored as ``self.loader``; ``__getitem__`` here
            does not call it.
        nb_fold (int, optional): Seed for the class shuffle.
        seen_rate (float, optional): Fraction of the 100 classes that are seen.
        seen_portion (float, optional): Fraction of each seen class's 600 images
            that go to ``'train'``.

    Raises:
        ValueError: if ``split`` is not one of the supported names.
        RuntimeError: if the dataset folder is missing after the optional download.
    """
    base_folder = 'cifar-100-python'
    raw_folder = 'raw'
    urls = ["https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"]
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]

    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }

    # Only used by the (currently unused) _relocate_labels helper.
    fold_lst_dict = {
        'fold_0': range(75, 100),
        'fold_1': range(50, 75),
        'fold_2': range(25, 50),
        'fold_3': range(0, 25),
    }

    _NUM_CLASSES = 100        # fine classes in CIFAR-100
    _IMAGES_PER_CLASS = 600   # images per fine class across train + test
    _SPLITS = ('train', 'gallery', 'query')

    def __init__(self, root, split='train', transform=None, target_transform=None, download=True, loader=default_loader,
                 nb_fold=0, seen_rate=0.75, seen_portion=0.5):
        if split not in self._SPLITS:
            # Any other value used to leave raw image rows in self.imgs/self.ys.
            raise ValueError('split must be one of %s, got %r' % (self._SPLITS, split))

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split

        self.loader = loader
        self.label_mat = None
        self.nb_fold = nb_fold
        self.seen_rate = seen_rate
        self.seen_portion = seen_portion

        BaseDataset2.__init__(self, self.root, self.split, self.transform)

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        images, targets = self._load_raw()
        self.load_meta()

        seen_classes = int(self._NUM_CLASSES * self.seen_rate)
        images, targets = self._relabel_classes(images, targets, nb_fold, seen_classes)
        self._build_splits(images, targets, seen_classes, seen_portion)

        if split == 'train':
            self.imgs = list(zip(self.Source_x, self.Source_y))
        elif split == 'gallery':
            self.imgs = list(zip(self.Gallery_x, self.Gallery_y))
        else:  # 'query'
            self.imgs = list(zip(self.query_x, self.query_y))
            # label_mat[q, g] == 1 exactly when query q and gallery item g share a label.
            query_one_hot = np.eye(self._NUM_CLASSES)[self.query_y]
            gallery_one_hot = np.eye(self._NUM_CLASSES)[self.Gallery_y]
            self.label_mat = np.matmul(query_one_hot, np.transpose(gallery_one_hot))

        # ys / I / im_paths are presumably initialised by BaseDataset2.__init__ (not visible here).
        for index, (img, label) in enumerate(self.imgs):
            self.ys += [label]
            self.I += [index]
            self.im_paths.append(img)

    def _load_raw(self):
        """Load and merge the train and test pickles.

        Returns:
            (images, targets): images reshaped to (N, 32, 32, 3) (HWC) and a 1-D
            integer array of fine labels, in file order (train first, then test).
        """
        chunks = []
        labels = []
        for file_name, _checksum in self.train_list + self.test_list:
            file_path = os.path.join(self.root, self.raw_folder, self.base_folder, file_name)
            # NOTE: pickle is only safe because these files come from the official archive.
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
            chunks.append(entry['data'])
            labels.extend(entry['labels'] if 'labels' in entry else entry['fine_labels'])
        images = np.vstack(chunks).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))  # CHW -> HWC
        return images, np.array(labels)

    def _relabel_classes(self, images, targets, nb_fold, seen_classes):
        """Shuffle the classes with seed ``nb_fold``, relabel them, and sort by new label.

        Unseen classes get labels ``seen_classes .. 99``; seen classes get
        ``0 .. seen_classes - 1``. Returns ``(images, targets)`` sorted by new label.
        """
        # Derive the unseen count from seen_classes so the two always sum to 100;
        # truncating 100*rate and 100*(1-rate) separately can lose a class and
        # make seen and unseen labels collide.
        n_unseen = self._NUM_CLASSES - seen_classes
        with temp_seed(nb_fold):
            idx_permute = np.random.permutation(self._NUM_CLASSES)
        unseen_cls_idx = idx_permute[:n_unseen]
        seen_cls_idx = idx_permute[n_unseen:]

        # (old label, new label) pairs, in the grouping order used before sorting.
        relabel_plan = [(cls, seen_classes + i) for i, cls in enumerate(unseen_cls_idx)]
        relabel_plan += [(cls, i) for i, cls in enumerate(seen_cls_idx)]

        grouped_imgs, grouped_targets = [], []
        for old_label, new_label in relabel_plan:
            members = np.where(targets == old_label)
            grouped_imgs.append(images[members])
            grouped_targets.append(np.full(len(members[0]), new_label, dtype=targets.dtype))
        images = np.vstack(grouped_imgs)
        targets = np.hstack(grouped_targets)

        # The default (non-stable) argsort is kept on purpose: the within-class
        # order it produces decides which images land in train vs. gallery, so
        # changing the sort kind would change the splits for a given nb_fold.
        order = np.argsort(targets)
        return images[order], targets[order]

    def _build_splits(self, images, targets, seen_classes, seen_portion):
        """Fill ``Source_*``, ``Gallery_*`` and ``query_*`` from label-sorted data."""
        seen_mask = targets < seen_classes
        seen_imgs, seen_targets = images[seen_mask], targets[seen_mask]
        unseen_imgs, unseen_targets = images[~seen_mask], targets[~seen_mask]

        # Unseen classes: even positions -> query, odd positions -> gallery.
        self.query_x = unseen_imgs[::2]
        self.query_y = unseen_targets[::2]
        unseen_gallery_x = unseen_imgs[1::2]
        unseen_gallery_y = unseen_targets[1::2]

        # Seen classes: the sorted data splits into seen_classes equal blocks
        # (np.split raises otherwise). The head of each block goes to train,
        # the tail goes to the gallery.
        portion = int(self._IMAGES_PER_CLASS * seen_portion)
        source_x, source_y, target_x, target_y = [], [], [], []
        for cls_x, cls_y in zip(np.split(seen_imgs, seen_classes, 0),
                                np.split(seen_targets, seen_classes, 0)):
            source_x.append(cls_x[:portion])
            source_y.append(cls_y[:portion])
            target_x.append(cls_x[portion:])
            target_y.append(cls_y[portion:])

        self.Source_x = np.concatenate(source_x, axis=0)
        self.Source_y = np.concatenate(source_y, axis=0)
        self.Gallery_x = np.concatenate(target_x + [unseen_gallery_x], axis=0)
        self.Gallery_y = np.concatenate(target_y + [unseen_gallery_y], axis=0)

    def rebuild_imgs(self, pseudo_labels):
        """Replace each sample's label in ``self.imgs`` with ``pseudo_labels`` (same order).

        ``self.ys`` is not touched. Extra labels are ignored, and missing labels
        truncate ``self.imgs`` (``zip`` semantics).
        """
        images = [sample[0] for sample in self.imgs]
        self.imgs = list(zip(images, pseudo_labels))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: ``(view1, view2, target, ...)``. The two views come from two
            independent calls of ``self.transform`` on the same image; both are
            the untransformed PIL image when ``transform`` is None.
        """
        sample = self.imgs[index]
        img = Image.fromarray(sample[0])
        if self.transform is not None:
            img1 = self.transform(img)
            img2 = self.transform(img)
        else:
            # Previously this path raised UnboundLocalError.
            img1 = img2 = img

        if self.target_transform is not None:
            # NOTE(review): applied to the image with the result discarded, exactly
            # as before (kept so random transforms consume RNG identically); the
            # documented contract says it should transform the target instead.
            self.target_transform(img)

        # NOTE(review): no constructor split is called 'split', so this single-view
        # branch is only reachable if self.split is reassigned after construction.
        if self.split == 'split':
            return (img1, *sample[1:])
        return (img1, img2, *sample[1:])

    def download(self):
        """Download and extract every archive in ``self.urls`` into ``<root>/<raw_folder>``.

        Returns immediately when :meth:`_check_exists` already finds the extracted
        dataset folder. Each downloaded archive is deleted after extraction.
        """
        from six.moves import urllib
        import shutil
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        try:
            os.makedirs(raw_dir)
        except OSError as e:
            # An already-existing directory is fine; anything else is a real error.
            if e.errno != errno.EEXIST:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream straight to disk; both the HTTP response and the file are closed.
            with contextlib.closing(urllib.request.urlopen(url)) as response, \
                    open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            with tarfile.open(file_path, 'r') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Newer Pythons: reject members that would land outside raw_dir.
                    tar.extractall(raw_dir, filter='data')
                else:
                    tar.extractall(raw_dir)
            os.unlink(file_path)

        print('Done!')

    def load_meta(self):
        """Read the class names from the pickled meta file of the extracted dataset."""
        path = os.path.join(self.root, self.raw_folder, self.base_folder, self.meta['filename'])
        with open(path, 'rb') as infile:
            if sys.version_info[0] == 2:
                data = pickle.load(infile)
            else:
                # The CIFAR pickles were written by Python 2; latin1 decodes them.
                data = pickle.load(infile, encoding='latin1')
            self.class_names = data[self.meta['key']]
        # NOTE(review): despite the name this is a dict *values view* (the indices),
        # not a name -> index mapping; kept as-is for existing callers.
        self.class_to_idx = {_class: i for i, _class in enumerate(self.class_names)}.values()

    def _check_exists(self):
        """Return True if ``<root>/<raw_folder>/<base_folder>`` exists on disk."""
        return os.path.exists(os.path.join(self.root, self.raw_folder, self.base_folder))

    def _relocate_labels(self, targets, nb_fold, fold_lst_dict):
        """Return a copy of ``targets`` with each fold class swapped with label ``75 + position``.

        Matches are always read from the original ``targets``, which is not modified.
        Currently unused by ``__init__``.
        """
        fold_lst = fold_lst_dict['fold_%d' % nb_fold]
        tmp_targets = copy.deepcopy(targets)

        for i, cls_idx in enumerate(fold_lst):
            idxs = np.where(targets == cls_idx)
            tmp_targets[idxs] = 75 + i

            idxs = np.where(targets == 75 + i)
            tmp_targets[idxs] = cls_idx

        return tmp_targets
