from __future__ import print_function
import os
import errno
import numpy as np
from PIL import Image
import torch.utils.data as data


import contextlib

@contextlib.contextmanager
def temp_seed(seed):
    """Run the ``with`` body under a fixed seed for NumPy's global RNG.

    The global ``np.random`` state is snapshotted before reseeding and put
    back when the body exits (normally or via an exception), so code outside
    the block sees the random stream it would have seen anyway.

    Args:
        seed: Any value accepted by ``np.random.seed``.
    """
    snapshot = np.random.get_state()
    np.random.seed(seed)
    with contextlib.ExitStack() as cleanup:
        # Registered callback runs on every exit path, restoring the old state.
        cleanup.callback(np.random.set_state, snapshot)
        yield

def pil_loader(path):
    """Open the image file at *path* with PIL and return it converted to RGB.

    The file is opened explicitly in binary mode and both the file handle and
    the PIL image are closed once the ``convert('RGB')`` result is obtained.
    """
    with open(path, 'rb') as handle, Image.open(handle) as opened:
        rgb = opened.convert('RGB')
    return rgb


def accimage_loader(path):
    """Load *path* with the accimage backend, falling back to PIL on I/O errors.

    Args:
        path (string): Path to an image file.

    Returns:
        An ``accimage.Image``, or an RGB PIL image if accimage raised IOError.
    """
    # accimage is a standalone package; there is no ``torchvision.datasets.accimage``.
    import accimage

    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; fall back to PIL.
        return pil_loader(path)


def default_loader(path):
    """Load an image using whichever backend torchvision is configured for.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'`` and ``pil_loader`` for any other backend.
    """
    from torchvision import get_image_backend

    use_accimage = get_image_backend() == 'accimage'
    chosen = accimage_loader if use_accimage else pil_loader
    return chosen(path)



def encode_onehot(labels, num_classes=200):
    """Turn a sequence of integer class indices into a one-hot float matrix.

    Args:
        labels (sequence of int): Class index for each sample.
        num_classes (int): Width of the one-hot encoding.

    Returns:
        numpy.ndarray: Array of shape ``(len(labels), num_classes)``, zeros
        everywhere except a 1.0 in column ``labels[row]`` of each row.
    """
    encoded = np.zeros((len(labels), num_classes))
    for row, cls in enumerate(labels):
        encoded[row, cls] = 1
    return encoded

class CUB200(data.Dataset):
    """`CUB-200 <http://www.vision.caltech.edu/visipedia/CUB-200.html>`_ /
    `CUB-200-2011 <http://www.vision.caltech.edu/visipedia/CUB-200-2011.html>`_
    dataset with a seen/unseen class protocol.

    All images (official train + test lists) are pooled. For fold ``nb_fold``
    a reproducible set of ``nb_unseen_classes`` (50) "unseen" classes is
    drawn out of ``nb_classes`` (200); the rest are "seen". The splits are:

    * ``'train'``   -- every other seen image (even positions); targets are
      re-indexed over the seen classes only (``0 .. 149``).
    * ``'gallery'`` -- the remaining seen images followed by every other
      unseen image (odd positions); targets are original class indices.
    * ``'query'``   -- the other unseen images (even positions); targets are
      original class indices, and ``label_mat[q, g]`` is 1.0 when query ``q``
      and gallery item ``g`` share a class.
    * ``'eval'``    -- the seen part of the gallery; each target is a one-hot
      row over the seen classes, and ``label_mat`` is the seen-vs-seen
      same-class matrix.

    Args:
        root (string): Root directory; extracted data is expected in
            ``<root>/raw``.
        year (int): Dataset version, 2010 or 2011. Only 2011 is currently
            implemented by :meth:`build_set`.
        split (string, optional): One of ``'train'``, ``'gallery'``,
            ``'query'`` or ``'eval'`` (see above).
        nb_fold (int, optional): Seed used to choose the unseen classes.
        transform (callable, optional): Applied to the loaded image,
            e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): Applied to the target.
        download (bool, optional): If True, download and extract the archives
            into ``<root>/raw`` unless the data is already there.
        loader (callable, optional): Maps an image path to an image.
        split_by, subsample, return_unsup, return_flags: Forwarded to
            :meth:`build_set`, which currently ignores them.

    Raises:
        ValueError: if ``split`` is not one of the supported names.
        RuntimeError: if the data is not found under ``<root>/raw``.
    """
    urls = []
    raw_folder = 'raw'
    splits = ('train', 'gallery', 'query', 'eval')
    nb_unseen_classes = 50

    def __init__(self, root, year, split='train', nb_fold=0, transform=None, target_transform=None, download=False,
                 loader=default_loader, split_by=1, subsample=1, return_unsup=False, return_flags=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.year = year
        self.loader = loader
        self.split_by = split_by
        self.subsample = subsample
        self.split = split
        self.nb_classes = 200
        self.nb_fold = nb_fold
        assert year == 2010 or year == 2011, "Invalid version of CUB200 dataset"
        # Fail fast: an unknown split would otherwise leave ``self.imgs`` unset.
        if split not in self.splits:
            raise ValueError('Invalid split {!r}; expected one of {}'.format(split, self.splits))

        if year == 2010:
            self.urls = ['http://www.vision.caltech.edu/visipedia-data/CUB-200/images.tgz',
                         'http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz']
        elif year == 2011:
            self.urls = ['http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz']

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.train_imgs, self.test_imgs, self.classes, self.class_to_idx = self.build_set(
            os.path.join(self.root, self.raw_folder), self.year, split_by, subsample,
            return_unsup, return_flags)
        self.label_mat = None

        # Pool the official train and test images; this protocol ignores that split.
        total_path, total_y = zip(*(self.train_imgs + self.test_imgs))
        total_path = np.array(total_path)
        total_y = np.array(total_y)
        total_y_onehot = encode_onehot(total_y, self.nb_classes)

        # Draw the unseen classes for this fold. temp_seed restores NumPy's global
        # RNG afterwards, so other random code is not affected.
        with temp_seed(nb_fold):
            unseen_cls_idxs = np.random.choice(self.nb_classes, self.nb_unseen_classes, replace=False)
        self.unseen_cls_idxs = unseen_cls_idxs

        # Unseen rows are grouped class by class in draw order (each class keeps
        # its pooled order); seen rows keep their pooled order.
        unseen_rows = np.concatenate([np.flatnonzero(total_y == cls_idx) for cls_idx in unseen_cls_idxs])
        seen_rows = np.flatnonzero(~np.isin(total_y, unseen_cls_idxs))

        unseen_img_paths = total_path[unseen_rows]
        unseen_img_labels = total_y_onehot[unseen_rows]
        seen_paths = total_path[seen_rows]
        seen_onehot = total_y_onehot[seen_rows]

        # Unseen images alternate: even positions -> query, odd positions -> gallery.
        self.query_x = unseen_img_paths[::2]
        self.query_y = unseen_img_labels[::2]
        unseen_gallery_x = unseen_img_paths[1::2]
        unseen_gallery_y = unseen_img_labels[1::2]

        # Drop the unseen-class columns so seen labels are indexed 0 .. nb_seen - 1.
        seen_reduced_onehot = np.delete(seen_onehot, unseen_cls_idxs, axis=1)

        # Seen images alternate: even positions -> training source, odd -> gallery.
        self.source_x = seen_paths[::2]
        self.source_y = np.where(seen_reduced_onehot[::2] != 0)[1]

        seen_gallery_x = seen_paths[1::2]
        seen_gallery_y = seen_onehot[1::2]

        self.gallery_x = np.concatenate((seen_gallery_x, unseen_gallery_x), axis=0)
        self.gallery_y = np.concatenate((seen_gallery_y, unseen_gallery_y), axis=0)

        seen_eval_y = np.delete(seen_gallery_y, unseen_cls_idxs, axis=1)

        if split == 'train':
            self.imgs = list(zip(self.source_x, self.source_y))
        elif split == 'gallery':
            self.gallery_y = np.where(self.gallery_y != 0)[1]
            self.imgs = list(zip(self.gallery_x, self.gallery_y))
        elif split == 'query':
            # Query-vs-gallery relevance: 1.0 where the two images share a class.
            self.label_mat = (np.matmul(self.query_y, np.transpose(self.gallery_y)) > 0).astype(np.float32)
            self.query_y = np.where(self.query_y != 0)[1]
            self.imgs = list(zip(self.query_x, self.query_y))
        else:  # 'eval'
            # Targets stay one-hot rows here; label_mat is seen-vs-seen relevance.
            self.imgs = list(zip(seen_gallery_x, seen_eval_y))
            self.label_mat = (np.matmul(seen_eval_y, np.transpose(seen_eval_y)) > 0).astype(np.float32)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: ``(image, target)`` where ``target`` is the sample's label
            (a one-hot row for the ``'eval'`` split). ``transform`` is applied
            to the image and ``target_transform`` to the target when set.
        """
        path, *targets = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None and targets:
            targets[0] = self.target_transform(targets[0])

        return (img, *targets)

    def _check_exists(self):
        """Return True if the extracted data for ``self.year`` is present under ``<root>/raw``."""
        pth = os.path.join(self.root, self.raw_folder)
        if self.year == 2010:
            return os.path.exists(os.path.join(pth, 'images/')) and os.path.exists(os.path.join(pth, 'lists/'))
        elif self.year == 2011:
            return os.path.exists(os.path.join(pth, 'CUB_200_2011/'))
        return False

    def __len__(self):
        return len(self.imgs)

    def download(self):
        """Download and extract every archive in ``self.urls`` into ``<root>/raw``.

        Does nothing if the data is already present. Each archive is deleted
        after extraction.

        Raises:
            RuntimeError: if an archive member would be written outside ``<root>/raw``.
        """
        import tarfile
        import shutil
        import urllib.request

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        try:
            os.makedirs(raw_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

        dest = os.path.realpath(raw_dir)
        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # Stream to disk instead of buffering the whole archive in memory.
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            with tarfile.open(file_path, 'r') as tar:
                members = tar.getmembers()
                # The archive comes from the network: refuse members that would
                # escape the destination directory (absolute paths, '..').
                for member in members:
                    target = os.path.realpath(os.path.join(dest, member.name))
                    if os.path.commonpath([dest, target]) != dest:
                        raise RuntimeError('Refusing to extract {!r} outside {}'.format(member.name, dest))
                tar.extractall(raw_dir, members=members)
            os.unlink(file_path)

        print('Done!')

    def build_set(self, root, year, split_by, subsample, return_unsup=False, return_flags=False):
        """Read the CUB-200-2011 image list and its official train/test split.

        Args:
            root (string): Directory containing the extracted ``CUB_200_2011`` folder.
            year (int): Dataset version; only 2011 is implemented.
            split_by, subsample, return_unsup, return_flags: Accepted for
                interface compatibility; currently unused.

        Returns:
            tuple: ``(train_imgs, test_imgs, classes, class_to_idx)``.
            ``train_imgs`` / ``test_imgs`` are lists of ``(full_path, label)``
            pairs where ``label`` is the 3-digit folder prefix minus one.
            ``classes`` lists class folder names (prefix removed) in
            first-seen order, and ``class_to_idx`` is the parallel list of
            their labels.

        Raises:
            NotImplementedError: for ``year == 2010``.
            ValueError: for any other unsupported year.
        """
        if year == 2010:
            raise NotImplementedError('Split construction for CUB-200 (2010) is not implemented')
        if year != 2011:
            raise ValueError('Invalid version of CUB200 dataset: {!r}'.format(year))

        images_file_path = os.path.join(root, 'CUB_200_2011/images/')

        all_images_list_path = os.path.join(root, 'CUB_200_2011/images.txt')
        all_images_list = np.genfromtxt(all_images_list_path, dtype=str)
        train_test_list_path = os.path.join(root, 'CUB_200_2011/train_test_split.txt')
        train_test_list = np.genfromtxt(train_test_list_path, dtype=int)

        train_imgs = []
        test_imgs = []
        classes = []
        class_to_idx = []
        seen_class_names = set()  # O(1) membership instead of scanning `classes`

        # Both files list one image per line in the same order; column 1 holds
        # the relative path and the is-training flag respectively.
        for fname, is_train in zip(all_images_list[:, 1], train_test_list[:, 1]):
            full_path = os.path.join(images_file_path, fname)
            label = int(fname[0:3]) - 1
            if is_train == 1:
                train_imgs.append((full_path, label))
            elif is_train == 0:
                test_imgs.append((full_path, label))
            class_name = os.path.split(fname)[0][4:]
            if class_name not in seen_class_names:
                seen_class_names.add(class_name)
                classes.append(class_name)
                class_to_idx.append(label)

        return train_imgs, test_imgs, classes, class_to_idx

    def rebuild_imgs(self, pseudo_labels):
        """Replace each sample's target with the matching entry of *pseudo_labels*.

        Image paths keep their order; pairing stops at the shorter of the two.
        """
        paths = [item[0] for item in self.imgs]
        self.imgs = list(zip(paths, pseudo_labels))


class CUB200_2(CUB200):
    """Two-view variant of :class:`CUB200`.

    Dataset construction, splits, attributes, ``download`` and
    ``rebuild_imgs`` are inherited unchanged from :class:`CUB200` (this class
    previously duplicated them verbatim). Only ``__getitem__`` differs: it
    returns two independently transformed views of each image.
    """

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: ``(view1, view2, target)``. Each view is a separate call to
            ``transform`` on the same loaded image (the untransformed image
            when ``transform`` is None). ``target_transform`` is applied to
            the target when set.
        """
        path, *targets = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            view1 = self.transform(img)
            view2 = self.transform(img)
        else:
            # Previously view1/view2 were unbound here (UnboundLocalError).
            view1 = view2 = img

        if self.target_transform is not None and targets:
            targets[0] = self.target_transform(targets[0])

        # NOTE(review): no supported split is named 'split', so this single-view
        # branch is never taken; confirm the intended condition (e.g. a non-train split).
        if self.split == 'split':
            return (view1, *targets)
        return (view1, view2, *targets)