#from .base import *
import scipy.io
import torchvision.transforms as transforms
import torch.utils.data as data
import errno
import contextlib
import numpy as np
import os
from PIL import Image

@contextlib.contextmanager
def temp_seed(seed):
    """Seed NumPy's global RNG for the duration of a ``with`` block.

    The global RNG state captured on entry is put back on exit, whether the
    block finishes normally or raises, so code outside the block does not see
    the temporary seed.

    Args:
        seed: value passed to ``np.random.seed``.
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Restore the caller's RNG stream exactly as it was before entry.
        np.random.set_state(saved_state)


def pil_loader(path):
    """Read the image file at *path* with PIL and return it converted to RGB.

    Both the file handle and the decoded source image are closed on exit; the
    returned object is the result of ``Image.convert('RGB')``.
    """
    with open(path, 'rb') as stream, Image.open(stream) as source:
        rgb_image = source.convert('RGB')
    return rgb_image


def accimage_loader(path):
    """Load *path* with the accimage backend, falling back to PIL on failure.

    Args:
        path (str): image file path.
    Returns:
        An ``accimage.Image``, or the RGB image produced by ``pil_loader`` when
        accimage raises IOError.
    """
    # accimage ships as the standalone ``accimage`` package (torchvision's own
    # loader does ``import accimage``); ``torchvision.datasets.accimage`` is not
    # an importable module, so the previous import always failed.
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; fall back to PIL.
        return pil_loader(path)

def default_loader(path):
    """Load *path* using whichever image backend torchvision is configured for.

    Dispatches to ``accimage_loader`` when torchvision reports the
    ``'accimage'`` backend, and to ``pil_loader`` otherwise.
    """
    from torchvision import get_image_backend
    backend_is_accimage = get_image_backend() == 'accimage'
    if not backend_is_accimage:
        return pil_loader(path)
    return accimage_loader(path)

def encode_onehot(labels, num_classes=196):
    """
    one-hot labels
    Args:
        labels (array-like of int): class index for each sample. Indices are
            used directly as column indices (negative values wrap, as in
            NumPy indexing; out-of-range values raise IndexError).
        num_classes (int): Number of classes (number of output columns).
    Returns:
        onehot_labels (numpy.ndarray): float64 array of shape
            (len(labels), num_classes) with a 1 at column labels[i] of row i.
    """
    onehot_labels = np.zeros((len(labels), num_classes))
    if len(labels) == 0:
        # An empty list converts to a float array, which cannot be used as an index.
        return onehot_labels

    # Vectorised scatter: row i gets a 1 in column labels[i].
    onehot_labels[np.arange(len(labels)), np.asarray(labels)] = 1
    return onehot_labels

class Cars(data.Dataset):
    """Stanford Cars-196 with a seen/unseen class split for image retrieval.

    49 of the 196 classes are drawn as "unseen" with NumPy's global RNG
    temporarily seeded by ``nb_fold``, so each fold index yields a fixed split.
    Images are then partitioned as follows:

    * unseen classes: every other image (even positions) -> query set, the
      remaining (odd positions) -> gallery;
    * seen classes: even positions -> training ("source") set, odd positions
      -> gallery.

    The gallery therefore mixes seen and unseen images. What ``self.imgs``
    holds depends on ``split``:

    * ``'train'``: (path, index), indices in the 147-column seen label space;
    * ``'gallery'``: (path, index), indices in the full 196-class space;
    * ``'query'``: (path, index) in the 196-class space, plus ``label_mat``
      (query x gallery, 1.0 where the two images share a class);
    * ``'eval'``: (path, one-hot row over the seen classes) for the seen
      gallery images, plus ``label_mat`` (seen gallery x seen gallery).

    NOTE(review): ``seen_rate`` only sets ``self.classes`` (for 'train' and
    'eval'); it does not affect the 49-class unseen draw, and
    ``self.classes`` is left unset for 'gallery' and 'query'.
    """

    raw_folder = 'raw'
    splits = ('train', 'gallery', 'query', 'eval')

    def __init__(self, root, split='train', transform=None, seen_rate=0.75, download=False, loader=default_loader,
                 nb_fold=0):
        """
        Args:
            root (str): dataset directory (``~`` is expanded). Data is read from
                ``<root>/raw``, annotations from ``raw/devkit/cars_annos.mat``.
            split (str): one of 'train', 'gallery', 'query', 'eval'.
            transform (callable, optional): applied to every loaded image.
            seen_rate (float): fraction used only to build ``self.classes``.
            download (bool): fetch the data first when ``raw/devkit`` is missing.
            loader (callable): maps an image path to an image object.
            nb_fold (int): seed that selects which 49 classes are unseen.
        Raises:
            ValueError: ``split`` is not one of the supported values.
            RuntimeError: the data is missing (and was not downloaded).
        """
        if split not in self.splits:
            raise ValueError('Unknown split {!r}; expected one of {}'.format(split, self.splits))

        self.root = os.path.expanduser(root)
        self.split = split
        self.transform = transform
        self.loader = loader
        self.label_mat = None
        self.urls = ['http://imagenet.stanford.edu/internal/car196/car_ims.tgz',
                     'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz',
                     'http://imagenet.stanford.edu/internal/car196/cars_annos.mat']
        if self.split == 'train':
            self.classes = range(0, int(196 * seen_rate))
        elif self.split == 'eval':
            self.classes = range(int(196 * seen_rate), 196)

        # Download/verify BEFORE touching the annotation file; otherwise
        # loadmat fails on a fresh root before download() ever runs.
        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        # Use the expanded root for image paths too, so '~' roots work.
        raw_dir = os.path.join(self.root, self.raw_folder)
        cars = scipy.io.loadmat(os.path.join(raw_dir, 'devkit', 'cars_annos.mat'))
        annotations = cars['annotations'][0]
        # Field 5 holds the class id (shifted down by one here); field 0 holds
        # the image path relative to raw/.
        total_y = np.array([int(a[5][0] - 1) for a in annotations])
        total_path = np.array([os.path.join(raw_dir, a[0][0]) for a in annotations])
        total_y_onehot = encode_onehot(total_y, num_classes=196)

        with temp_seed(nb_fold):
            # Choose the unseen classes according to the fold index.
            unseen_cls_idxs = np.random.choice(196, 49, replace=False)

        # Move every image of each unseen class out of the pool, class by class
        # in draw order (this order determines the query/gallery alternation).
        unseen_img_labels = []
        unseen_img_paths = []
        for cls_idx in unseen_cls_idxs:
            idxs = np.where(total_y_onehot[:, cls_idx] == 1)[0]
            unseen_img_labels.append(total_y_onehot[idxs])
            unseen_img_paths.append(total_path[idxs])
            total_y_onehot = np.delete(total_y_onehot, idxs, axis=0)
            total_path = np.delete(total_path, idxs, axis=0)
        unseen_img_labels = np.concatenate(unseen_img_labels, axis=0)
        unseen_img_paths = np.concatenate(unseen_img_paths, axis=0)

        # Unseen data: even positions -> query, odd positions -> gallery.
        self.query_x = unseen_img_paths[::2]
        self.query_y = unseen_img_labels[::2]
        unseen_gallery_x = unseen_img_paths[1::2]
        unseen_gallery_y = unseen_img_labels[1::2]

        # Drop the unseen-class columns so training labels index seen classes only.
        extracted_img_labels = np.delete(total_y_onehot, unseen_cls_idxs, axis=1)

        # Seen data: even positions -> source (train), odd positions -> gallery.
        self.source_x = total_path[::2]
        self.source_y = np.where(extracted_img_labels[::2] != 0)[1]
        seen_gallery_x = total_path[1::2]
        seen_gallery_y = total_y_onehot[1::2]

        self.gallery_x = np.concatenate((seen_gallery_x, unseen_gallery_x), axis=0)
        self.gallery_y = np.concatenate((seen_gallery_y, unseen_gallery_y), axis=0)

        if split == 'train':
            self.imgs = list(zip(self.source_x, self.source_y))
        elif split == 'gallery':
            self.gallery_y = np.where(self.gallery_y != 0)[1]
            self.imgs = list(zip(self.gallery_x, self.gallery_y))
        elif split == 'query':
            # Relevance between query and gallery: 1.0 where the one-hot rows share a class.
            self.label_mat = (np.matmul(self.query_y, np.transpose(self.gallery_y)) > 0).astype(np.float32)
            self.query_y = np.where(self.query_y != 0)[1]
            self.imgs = list(zip(self.query_x, self.query_y))
        else:  # 'eval'
            # One-hot targets restricted to the seen-class columns.
            seen_eval_y = np.delete(seen_gallery_y, unseen_cls_idxs, axis=1)
            self.imgs = list(zip(seen_gallery_x, seen_eval_y))
            self.label_mat = (np.matmul(seen_eval_y, np.transpose(seen_eval_y)) > 0).astype(np.float32)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target). ``target`` is a class index for the
            'train', 'gallery' and 'query' splits, and a one-hot row over the
            seen classes for 'eval'.
        """
        path, target = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        return img, target

    @staticmethod
    def _makedirs(path):
        """Create *path* (and parents), ignoring the error if it already exists."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def download(self):
        """Fetch every URL in ``self.urls`` into ``<root>/raw``.

        ``.tgz``/``.tar`` archives are extracted into ``raw/`` and then deleted.
        Any other file (``cars_annos.mat``) is moved into ``raw/devkit/``, which
        is where ``__init__`` reads the annotations from. Does nothing when
        ``raw/devkit/`` already exists.
        """
        from six.moves import urllib
        import shutil
        import tarfile

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        devkit_dir = os.path.join(raw_dir, 'devkit')
        self._makedirs(raw_dir)

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            with contextlib.closing(urllib.request.urlopen(url)) as response, open(file_path, 'wb') as f:
                f.write(response.read())

            if filename.endswith(('.tgz', '.tar.gz', '.tar')):
                # NOTE(review): extractall trusts member paths from a file fetched
                # over the network; consider an extraction filter (filter='data',
                # Python 3.12+) or validating member names.
                with tarfile.open(file_path, 'r') as tar:
                    tar.extractall(raw_dir)
                os.unlink(file_path)
            else:
                # Plain files (the annotation .mat) are not archives; place
                # them where __init__ looks for them.
                self._makedirs(devkit_dir)
                shutil.move(file_path, os.path.join(devkit_dir, filename))

        print('Done!')

    def _check_exists(self):
        """Return True when ``<root>/raw/devkit/`` exists."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, 'devkit/'))

    def __len__(self):
        """Number of (path, target) pairs in the selected split."""
        return len(self.imgs)

if __name__ == '__main__':
    # Build the train and query splits, then estimate the per-channel mean and
    # std of the training images after resize + 224x224 center crop.
    pixels_per_channel = 224 * 224
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    train_set = Cars('../data/car196', split='train', transform=preprocess, download=True, seen_rate=1.0)
    query_set = Cars('../data/car196', split='query', transform=preprocess, download=True, seen_rate=1.0)

    print(len(train_set.classes))

    channel_mean = 0
    channel_sq_mean = 0
    for image, _target in train_set:
        channel_mean += image.sum(1).sum(1) / pixels_per_channel
        channel_sq_mean += image.pow(2).sum(1).sum(1) / pixels_per_channel

    n_images = len(train_set)
    channel_mean /= n_images
    channel_sq_mean /= n_images

    # std = sqrt(E[x^2] - E[x]^2)
    channel_std = (channel_sq_mean - channel_mean.pow(2)).pow(0.5)

    print(channel_mean)
    print(channel_std)