"""
    CARs196 has 16185 images in /root/car_ims

    Train data has 8144 images in /root/car_train
    Test data has 8041 images in /root/car_test

    All train and test images are .jpg files.

    The annotations are located in /root/devkit.
    Please check README.txt for how to interpret the annotation information.
"""
from __future__ import print_function
import os
import errno
import numpy as np
from PIL import Image
import torch.utils.data as data
import scipy.io as sio
import torchvision.transforms as transforms


def pil_loader(path):
    """Read the image file at ``path`` with PIL and return an RGB copy.

    The file is opened explicitly in binary mode and handed to PIL, so both
    the file handle and the decoded source image are closed once the RGB
    conversion has been made.

    Args:
        path (string): Location of the image file.
    Returns:
        The result of ``convert('RGB')`` on the opened image.
    """
    with open(path, 'rb') as stream:
        with Image.open(stream) as source:
            rgb_image = source.convert('RGB')
    return rgb_image


def accimage_loader(path):
    """Load the image at ``path`` with accimage, falling back to PIL.

    Args:
        path (string): Location of the image file.
    Returns:
        An ``accimage.Image`` when accimage can read the file, otherwise the
        RGB PIL image produced by ``pil_loader``.
    """
    # accimage is a standalone package; torchvision.datasets has no
    # ``accimage`` submodule, so importing it from there always failed.
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # accimage could not read the file; let PIL try instead.
        return pil_loader(path)


def default_loader(path):
    """Load the image at ``path`` with the loader matching torchvision's backend.

    Uses ``accimage_loader`` when ``get_image_backend()`` reports
    ``'accimage'`` and ``pil_loader`` for every other backend.

    Args:
        path (string): Location of the image file.
    Returns:
        Whatever the selected loader returns for ``path``.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    chosen_loader = accimage_loader if use_accimage else pil_loader
    return chosen_loader(path)


class Car196(data.Dataset):
    """Stanford Cars (Cars196) dataset driven by ``cars_annos.mat``.

    The files are expected under ``<root>/raw``: the images in ``car_ims/``
    and the annotation files ``cars_annos.mat`` and ``cars_meta.mat`` in
    ``devkit/``.

    Args:
        root (string): Root directory of the dataset; the images and the
            annotation files live inside its ``raw`` folder.
        split (string or bool, optional): ``'train'`` (or a truthy non-string
            value) selects the images whose ``test`` flag is 0; any other
            string (e.g. ``'query'``) or falsy value selects the images whose
            ``test`` flag is 1.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version.
            E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in the ``raw`` folder. If the devkit folder
            already exists, nothing is downloaded.
        loader (callable, optional): Function that loads an image given its
            path.
        seen_rate (float, optional): Class split point. Only classes whose
            zero-based index is at least ``int(n_classes * seen_rate)`` are
            kept, and their labels are shifted so that they start at 0.
        train_portion (float, optional): Fraction of the images of every kept
            class to retain, drawn at random with ``np.random``. The default
            ``1.0`` keeps every image.

    """
    urls = []
    raw_folder = 'raw'

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False,
                 loader=default_loader, seen_rate=0.75, train_portion=1.0):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        self.loader = loader
        self.urls = ['http://imagenet.stanford.edu/internal/car196/car_ims.tgz',
                     'https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz',
                     'http://imagenet.stanford.edu/internal/car196/cars_annos.mat']
        self.seen_rate = seen_rate
        # Must be set before _build_set runs, which reads it for subsampling.
        self.train_portion = train_portion

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. You can use download=True to download it')

        self.imgs, self.classes, self.class_to_idx = self._build_set(os.path.join(self.root, self.raw_folder),
                                                                     self.split, seen_rate=self.seen_rate)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            # The target transform applies to the label, never to the image.
            target = self.target_transform(target)

        return img, target

    def _check_exists(self):
        """Return True when the ``devkit`` folder exists inside the raw folder."""
        pth = os.path.join(self.root, self.raw_folder)
        return os.path.exists(os.path.join(pth, 'devkit/'))

    def __len__(self):
        """Return the number of (path, label) entries in the selected split."""
        return len(self.imgs)

    @staticmethod
    def _makedirs(path):
        """Create ``path`` with its parents; an already existing path is fine."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    @staticmethod
    def _is_train_split(split):
        """Return True when ``split`` designates the training images."""
        if isinstance(split, str):
            # Any non-empty string is truthy, so compare the name explicitly.
            return split == 'train'
        return bool(split)

    def download(self):
        """Download every entry of ``self.urls`` into the raw folder.

        Returns immediately when ``_check_exists()`` is already true. Tar
        archives are extracted into the raw folder and then deleted; any
        other file (``cars_annos.mat``) is moved into ``devkit/``, which is
        where ``_build_set`` reads it from.
        """
        import shutil
        import tarfile
        try:
            from urllib.request import urlopen
        except ImportError:  # Python 2
            from six.moves.urllib.request import urlopen

        if self._check_exists():
            return

        raw_dir = os.path.join(self.root, self.raw_folder)
        self._makedirs(raw_dir)

        for url in self.urls:
            print('Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)

            # Stream to disk instead of holding the whole archive in memory.
            response = urlopen(url)
            try:
                with open(file_path, 'wb') as f:
                    shutil.copyfileobj(response, f)
            finally:
                response.close()

            if tarfile.is_tarfile(file_path):
                # NOTE(review): members are extracted as-is; the archives are
                # trusted to contain only paths below the raw folder.
                with tarfile.open(file_path, 'r') as tar:
                    tar.extractall(raw_dir)
                os.unlink(file_path)
            else:
                # Not an archive: keep the file next to the other annotations.
                devkit_dir = os.path.join(raw_dir, 'devkit')
                self._makedirs(devkit_dir)
                shutil.move(file_path, os.path.join(devkit_dir, filename))

        print('Done!')

    def _build_set(self, root, train, seen_rate=0.75):
        """
           Function to return the lists of paths with the corresponding labels for the images
        Args:
            root (string): Directory that contains ``car_ims/`` and ``devkit/``

            train (bool or string): ``True`` or ``'train'`` returns the list pertaining to
                training images and labels; anything else returns the test images
            seen_rate (float, optional): only classes with zero-based index of at least
                ``int(n_classes * seen_rate)`` are kept, relabelled to start at 0
        Returns:
            imgs: list of 2-tuples with 1st location specifying path and 2nd location specifying the class
            classes: list with the names of all classes found in ``cars_meta.mat``
                (not reduced by the ``seen_rate`` filter)
            class_to_idx: list ``[0, 1, ..., len(classes) - 1]``

        NOTE(review): with ``seen_rate=1.0`` the threshold equals the number of
        classes, so no image passes the filter and ``imgs`` is empty -- confirm
        that this is the intended meaning of ``seen_rate``.
        """

        images_file_path = os.path.join(root, 'car_ims/')
        devkit_path = os.path.join(root, 'devkit/')

        # load mat files which contain files' name and class informations.
        annotation_file = sio.loadmat(devkit_path + 'cars_annos.mat')
        classes_file = sio.loadmat(devkit_path + 'cars_meta.mat')

        annotations = annotation_file['annotations']
        # relative image paths, e.g. car_ims/016185.jpg
        all_images_list = annotations['relative_im_path'][0]
        # one-based class integer of every image
        classes_idx = annotations['class'][0]
        # 1 if the image is in the test dataset, otherwise it is a train image
        train_test_split = annotations['test'][0]
        # names of the classes
        classes_name = classes_file['class_names'][0]

        wanted_flag = 0 if self._is_train_split(train) else 1
        prefix_len = len('car_ims/')  # stored paths start with this folder name

        imgs = []
        for rel_path, cls, is_test in zip(all_images_list, classes_idx, train_test_split):
            if int(is_test) != wanted_flag:
                continue
            fname = rel_path[0][prefix_len:]
            # labels in the annotation file are one-based
            imgs.append((os.path.join(images_file_path, fname), int(cls) - 1))

        classes = [str(name[0]) for name in classes_name]
        class_to_idx = list(range(len(classes)))

        n_target_classes = int(len(classes) * seen_rate)

        # keep the classes at or above the split point and shift labels to start at 0
        imgs = [(path, label - n_target_classes) for path, label in imgs if label >= n_target_classes]

        train_portion = getattr(self, 'train_portion', 1.0)
        if train_portion < 1.0:
            # group once by label, then draw a random subset of every class
            by_label = {}
            for item in imgs:
                by_label.setdefault(item[1], []).append(item)

            subset_imgs = []
            for cls in range(len(classes)):
                sub_imgs = by_label.get(cls, [])
                idx = np.random.permutation(len(sub_imgs))[:int(len(sub_imgs) * train_portion)]
                subset_imgs.extend(sub_imgs[i] for i in idx)
            imgs = subset_imgs

        return imgs, classes, class_to_idx


if __name__ == '__main__':
    # Print the per-channel mean and standard deviation of the train split
    # after resizing and center-cropping every image to 224x224.
    crop_size = 224
    pixels_per_channel = crop_size * crop_size

    preprocessing = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
    ])

    # NOTE(review): seen_rate=1.0 makes _build_set keep only classes with an
    # index >= the number of classes -- confirm the datasets are not empty.
    train_set = Car196('../data/car196', split='train', transform=preprocessing, download=True, seen_rate=1.0)
    query_set = Car196('../data/car196', split='query', transform=preprocessing, download=True, seen_rate=1.0)

    print(len(train_set.classes))

    # running sums of the per-image channel means of x and of x**2
    mean = 0
    sq_mean = 0
    for image, _ in train_set:
        mean += image.sum(1).sum(1) / pixels_per_channel
        sq_mean += image.pow(2).sum(1).sum(1) / pixels_per_channel

    n_images = len(train_set)
    mean /= n_images
    sq_mean /= n_images

    # std = sqrt(E[x^2] - E[x]^2)
    variance = sq_mean - mean.pow(2)
    std = variance.pow(0.5)

    print(mean)
    print(std)