#from __future__ import print_function, division

import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
import random
from PIL import Image
import torch.utils.data as data
import os
import os.path
import accimage
import contextlib


@contextlib.contextmanager
def temp_seed(seed):
    """Temporarily reseed NumPy's global RNG for the duration of a ``with`` block.

    The global RNG state is captured before seeding and restored on exit,
    whether the block finishes normally or raises.

    Args:
        seed: Value passed to ``np.random.seed``.
    """
    previous_state = np.random.get_state()
    try:
        np.random.seed(seed)
        yield
    finally:
        # Put the caller's RNG stream back exactly where it was.
        np.random.set_state(previous_state)

def make_dataset(image_list, labels, root):
    """Parse list-file lines into ``(path, target)`` pairs.

    Args:
        image_list (list of str): Lines read from a list file.
        labels (numpy.ndarray or None): When given, row ``i`` (``labels[i, :]``)
            becomes the target of line ``i`` and the stripped line is used as
            the path. When ``None``, targets are parsed from the lines.
        root (str): Directory prepended to the path in the multi-label format.

    Line formats when ``labels`` is ``None`` (chosen from the first line):
        * ``<path> <l1> <l2> ...`` (more than two fields): the path is joined
          with ``root`` and the target is a numpy array of the integer labels.
        * ``<path> <label>``: the path is used as-is (not joined with ``root``)
          and the target is ``int(label)``.

    Returns:
        list of tuple: ``(path, target)`` pairs; empty when ``image_list`` is
        empty, so callers can report "no images" themselves.
    """
    if labels is not None:
        # ``if labels:`` would raise ValueError for a multi-element array and
        # misclassify a single-element zero array, hence the explicit None test.
        return [(line.strip(), labels[i, :]) for i, line in enumerate(image_list)]

    if not image_list:
        return []

    if len(image_list[0].split()) > 2:
        # Multi-label format: "<relative path> <int> <int> ...".
        images = []
        for line in image_list:
            fields = line.split()
            images.append((os.path.join(root, fields[0]),
                           np.array([int(la) for la in fields[1:]])))
        return images

    # Single-label format: "<path> <int>".
    images = []
    for line in image_list:
        fields = line.split()
        images.append((fields[0], int(fields[1])))
    return images


def encode_onehot(labels, num_classes=15):
    """
    Convert integer class indices into a one-hot matrix.
    Args:
        labels (numpy.ndarray): class index per sample.
        num_classes (int): width of each one-hot row.
    Returns:
        onehot_labels (numpy.ndarray): float array of shape
            (len(labels), num_classes) with a single 1 per row.
    """
    encoded = np.zeros((len(labels), num_classes))
    for row, class_index in enumerate(labels):
        encoded[row, class_index] = 1
    return encoded


def pil_loader(path):
    """Open the image at ``path`` with PIL and return an RGB copy.

    The file is opened explicitly so the handle is closed deterministically,
    avoiding the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        return picture.convert('RGB')


def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError.

    Any IOError raised by accimage (potentially a decoding problem) triggers
    the PIL fallback; other exceptions propagate.
    """
    import accimage
    try:
        decoded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem: retry with PIL.Image.
        decoded = pil_loader(path)
    return decoded


def default_loader(path):
    """Load ``path`` as an RGB image; always delegates to ``pil_loader``.

    The torchvision backend switch (accimage) is intentionally not consulted.
    """
    return pil_loader(path)


class ImageList(object):
    """Dataset whose image paths and targets come from per-fold list files.

    The list file is read from ``<root>/fold_<nb_fold>/`` and depends on
    ``split``:

    * ``'train'``   -> ``train_list``  (``nb_classes = 15``)
    * ``'query'``   -> ``query_list``  (``nb_classes = 6``); additionally loads
      ``Label_mat.tar`` via ``torch.load`` and stores it as a numpy array in
      ``label_mat``
    * ``'gallery'`` -> ``db_list``     (``nb_classes = 21``)

    Each line is turned into an ``(image path, target)`` pair by
    ``make_dataset``.

    Args:
        root (string): Dataset root directory (``~`` is expanded).
        split (string): One of ``'train'``, ``'query'`` or ``'gallery'``.
        labels (numpy.ndarray, optional): Per-line targets, indexed as
            ``labels[i, :]``; when ``None`` targets are parsed from the lines.
        transform (callable, optional): Applied to the loaded image.
        target_transform (callable, optional): Applied to the target.
        loader (callable, optional): Loads an image given its path.
        nb_fold (int): Index of the ``fold_<k>`` sub-directory to read.

    Attributes:
        imgs (list): List of ``(image path, target)`` tuples.
        nb_classes (int): Class count associated with the chosen split.
        label_mat (numpy.ndarray or None): Label matrix (``'query'`` only).

    Raises:
        ValueError: If ``split`` is not a supported value.
        RuntimeError: If the list file yields no images.
    """
    # List-file names inside each ``fold_<k>`` directory.
    train_list = 'train.txt'
    db_list = 'database.txt'
    query_list = 'test.txt'

    def __init__(self, root, split='train', labels=None, transform=None, target_transform=None,
                 loader=default_loader, nb_fold=0):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.label_mat = None

        if split == 'train':
            self.nb_classes = 15
            list_name = self.train_list
        elif split == 'query':
            self.nb_classes = 6
            list_name = self.query_list
        elif split == 'gallery':
            self.nb_classes = 21
            list_name = self.db_list
        else:
            # A real exception (not ``assert``) so validation survives ``python -O``.
            raise ValueError("split must be 'train', 'query' or 'gallery', got %r" % (split,))

        fold_dir = os.path.join(self.root, 'fold_%d' % (nb_fold))
        # ``with`` closes the list file instead of leaking the handle.
        with open(os.path.join(fold_dir, list_name)) as list_file:
            image_list = list_file.readlines()

        if split == 'query':
            label_mat = torch.load(os.path.join(fold_dir, 'Label_mat.tar'))
            self.label_mat = label_mat.numpy()

        imgs = make_dataset(image_list, labels, self.root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n")

        self.imgs = imgs

    def rebuild_imgs(self, pseudo_labels, num_classes=15):
        """Replace every target with a one-hot encoding of ``pseudo_labels``.

        Args:
            pseudo_labels: Integer class index per image, in ``self.imgs`` order.
            num_classes (int): Width of the one-hot vectors (default 15).
        """
        paths = [path for path, _ in self.imgs]
        pseudo_labels_one_hot = encode_onehot(pseudo_labels, num_classes=num_classes)
        self.imgs = list(zip(paths, pseudo_labels_one_hot))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) with ``transform`` / ``target_transform``
            applied when they are set.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (img, target)

    def __len__(self):
        """Return the number of ``(path, target)`` entries."""
        return len(self.imgs)


class ImageList_2(object):
    """Two-view dataset whose image paths and targets come from per-fold list files.

    Reads the same list files as ``ImageList`` from ``<root>/fold_<nb_fold>/``:

    * ``'train'``   -> ``train_list``  (``nb_classes = 15``)
    * ``'query'``   -> ``query_list``  (``nb_classes = 6``); additionally loads
      ``Label_mat.tar`` via ``torch.load`` into ``label_mat`` as a numpy array
    * ``'gallery'`` -> ``db_list``     (``nb_classes = 21``)

    ``__getitem__`` returns two independently transformed views of each image
    plus its target.

    Args:
        root (string): Dataset root directory (``~`` is expanded).
        split (string): One of ``'train'``, ``'query'`` or ``'gallery'``.
        labels (numpy.ndarray, optional): Per-line targets, indexed as
            ``labels[i, :]``; when ``None`` targets are parsed from the lines.
        transform (callable, optional): Applied (twice) to the loaded image.
        target_transform (callable, optional): Applied to the target.
        loader (callable, optional): Loads an image given its path.
        nb_fold (int): Index of the ``fold_<k>`` sub-directory to read.

    Attributes:
        imgs (list): List of ``(image path, target)`` tuples.
        split (string): The requested split.
        nb_classes (int): Class count associated with the chosen split.
        label_mat (numpy.ndarray or None): Label matrix (``'query'`` only).

    Raises:
        ValueError: If ``split`` is not a supported value.
        RuntimeError: If the list file yields no images.
    """
    # List-file names inside each ``fold_<k>`` directory.
    train_list = 'train.txt'
    db_list = 'database.txt'
    query_list = 'test.txt'

    def __init__(self, root, split='train', labels=None, transform=None, target_transform=None,
                 loader=default_loader, nb_fold=0):
        self.split = split
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.label_mat = None

        if split == 'train':
            self.nb_classes = 15
            list_name = self.train_list
        elif split == 'query':
            self.nb_classes = 6
            list_name = self.query_list
        elif split == 'gallery':
            self.nb_classes = 21
            list_name = self.db_list
        else:
            # A real exception (not ``assert``) so validation survives ``python -O``.
            raise ValueError("split must be 'train', 'query' or 'gallery', got %r" % (split,))

        fold_dir = os.path.join(self.root, 'fold_%d' % (nb_fold))
        # ``with`` closes the list file instead of leaking the handle.
        with open(os.path.join(fold_dir, list_name)) as list_file:
            image_list = list_file.readlines()

        if split == 'query':
            label_mat = torch.load(os.path.join(fold_dir, 'Label_mat.tar'))
            self.label_mat = label_mat.numpy()

        imgs = make_dataset(image_list, labels, self.root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n")

        self.imgs = imgs

    def rebuild_imgs(self, pseudo_labels, num_classes=21):
        """Replace every target with a one-hot encoding of ``pseudo_labels``.

        Args:
            pseudo_labels: Integer class index per image, in ``self.imgs`` order.
            num_classes (int): Width of the one-hot vectors (default 21).
        """
        paths = [path for path, _ in self.imgs]
        pseudo_labels_one_hot = encode_onehot(pseudo_labels, num_classes=num_classes)
        self.imgs = list(zip(paths, pseudo_labels_one_hot))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (view1, view2, target). ``transform`` is called separately
            for each view, so a random transform yields two different views;
            without a transform both views are the loaded image itself.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img1 = self.transform(img)
            img2 = self.transform(img)
        else:
            # Without this branch the views would be unbound (UnboundLocalError).
            img1 = img2 = img
        if self.target_transform is not None:
            target = self.target_transform(target)

        return (img1, img2, target)

    def __len__(self):
        """Return the number of ``(path, target)`` entries."""
        return len(self.imgs)


if __name__ == '__main__':
    # Smoke test: build the query split of one fold and report completion.
    root = '../data/nus_wide'
    nb_fold = 3
    split_dataset = ImageList(root, nb_fold=nb_fold, split='query')
    print('done')



