import numpy as np
from PIL import Image
import os
import sys
from six.moves import cPickle
import pickle
from torchvision.datasets.utils import download_url, check_integrity
from torch.utils.data import Dataset


def _decode_key(key):
    """Return ``key`` as text: UTF-8-decode ``bytes``, ``str()`` anything else."""
    if isinstance(key, bytes):
        return key.decode("utf-8")
    return str(key)


def load_cifar10h_labels(type="aggregate", path="/data/cifar10h/"):
    """Load CIFAR-10H crowd labels from a pickle file under ``path``.

    Only ``type="raw"`` is implemented. It reads ``<path>/raw`` and, on
    Python 3, returns a nested dict ``{outer_key: {inner_key: label}}``:

    * every key is converted to text (``bytes`` keys, as produced by
      ``encoding="bytes"`` for Python-2 pickles, are UTF-8 decoded; other keys
      go through ``str()``);
    * every label is ``int(value[0])`` of the stored value.

    On Python 2 the unpickled object is returned unchanged.

    Args:
        type: Label format to load; only ``"raw"`` is supported. The name
            shadows the builtin but is kept for keyword-argument callers.
        path: Directory containing the ``raw`` pickle file.

    Returns:
        The decoded label mapping described above.

    Raises:
        ValueError: if ``type`` is anything other than ``"raw"`` (the
            ``"aggregate"``/``"sparse"`` formats are not implemented; this
            previously returned ``None`` silently).

    Warning:
        Unpickling can execute arbitrary code; only load trusted files.
    """
    if type != "raw":
        raise ValueError(
            "Unsupported CIFAR-10H label type {!r}; only 'raw' is "
            "implemented".format(type)
        )

    f_path = os.path.join(path, "raw")
    with open(f_path, "rb") as f:  # straight from keras
        if sys.version_info < (3,):
            return cPickle.load(f)
        # encoding="bytes" loads Python-2 ``str`` objects as ``bytes``.
        d = cPickle.load(f, encoding="bytes")

    # Decode keys to real text (str(b"x") would give "b'x'") and keep only the
    # first element of each stored value as an int label.
    return {
        _decode_key(k): {_decode_key(j): int(choice[0]) for j, choice in v.items()}
        for k, v in d.items()
    }


class CIFAR10H(Dataset):
    """CIFAR-10 image splits used for CIFAR-10H (human-label) experiments.

    The split mapping is intentionally inverted relative to plain CIFAR-10:

    * ``"train"`` -- CIFAR-10 ``test_batch`` minus its last ``val_size`` images;
    * ``"val"``   -- the last ``val_size`` images of CIFAR-10 ``test_batch``;
    * ``"test"``  -- the five CIFAR-10 ``data_batch_*`` files.

    Presumably the test batch is used for training because CIFAR-10H's human
    labels cover those images -- confirm against the data source. Images are
    kept in ``self.c10h_data`` as an ``(N, 32, 32, 3)`` (HWC) array, and the
    original CIFAR-10 labels in the list ``self.c10h_c10_targets``.
    """

    base_folder = "cifar-10-batches-py"
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = "cifar-10-python.tar.gz"
    tgz_md5 = "c58f30108f718f92721af3b95e74349a"
    c10_train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]
    c10_test_list = [["test_batch", "40351d587109b95175f43aff81a1287e"]]
    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }
    # Number of images at the end of the CIFAR-10 test batch held out as "val".
    val_size = 500
    # Valid values of ``self.set`` once ``__init__`` has normalized it.
    _splits = ("train", "val", "test")

    def __init__(
        self,
        root,
        which_set="train",
        train=True,
        transform=None,
        target_transform=None,
        download=False,
    ):
        """Load the requested split from ``root``.

        Args:
            root: Directory containing (or to receive) ``cifar-10-batches-py``.
            which_set: ``"train"``, ``"val"`` or ``"test"``.
            train: When true (the default), any ``which_set`` other than
                ``"val"`` is forced to ``"train"``; pass ``train=False`` to
                get the ``"test"`` split.
            transform: Optional callable applied to each PIL image.
            target_transform: Optional callable applied to each label.
            download: Download and extract the archive first if needed.

        Raises:
            ValueError: if the resolved split is not one of ``_splits``.
            RuntimeError: if a batch file is missing or fails its MD5 check.
        """
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # stored for introspection
        self.set = which_set  # resolved split name

        if train and self.set != "val":
            self.set = "train"
        # Fail fast (before any download) instead of leaving the instance
        # without data for an unknown split.
        if self.set not in self._splits:
            raise ValueError(
                "which_set must be one of {}, got {!r}".format(
                    self._splits, which_set
                )
            )

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

        # "train"/"val" come from the CIFAR-10 test batch, "test" from the
        # CIFAR-10 training batches (see class docstring).
        if self.set in ("train", "val"):
            batch_list = self.c10_test_list
        else:
            batch_list = self.c10_train_list
        data, targets = self._load_batches(batch_list)

        # Boundary between "train" and the held-out "val" tail; clamped at 0
        # so a split shorter than val_size gives an empty train / full val,
        # matching the behaviour of [:-val_size] / [-val_size:].
        split = max(len(data) - self.val_size, 0)
        if self.set == "train":
            data, targets = data[:split], targets[:split]
        elif self.set == "val":
            data, targets = data[split:], targets[split:]
        self.c10h_data = data
        self.c10h_c10_targets = targets

    def _load_batches(self, batch_list):
        """Unpickle CIFAR-10 batches; return (HWC image array, label list)."""
        data, targets = [], []
        for file_name, _checksum in batch_list:
            file_path = os.path.join(self.root, self.base_folder, file_name)
            # Files are MD5-verified by _check_integrity before this is called.
            with open(file_path, "rb") as f:
                if sys.version_info[0] == 2:
                    entry = pickle.load(f)
                else:
                    entry = pickle.load(f, encoding="latin1")
            data.append(entry["data"])
            if "labels" in entry:
                targets.extend(entry["labels"])
        # Each row is a flattened CHW image; reshape then convert to HWC.
        images = np.vstack(data).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        return images, targets

    def get_raw(self):
        """Load raw crowd labels from ``<module dir>/data/cifar10h`` into ``self.raw_crowd``."""
        labs = load_cifar10h_labels(
            type="raw",
            path=os.path.join(os.path.dirname(__file__), "data", "cifar10h"),
        )
        self.raw_crowd = labs

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index`` in any of the loaded splits.

        The array is converted to a PIL Image (consistent with other
        torchvision-style datasets) before ``transform`` is applied.
        """
        img = self.c10h_data[index]
        target = self.c10h_c10_targets[index]

        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of images in the loaded split."""
        return len(self.c10h_data)

    def _check_integrity(self):
        """Return True iff every CIFAR-10 batch file exists and matches its MD5."""
        return all(
            check_integrity(os.path.join(self.root, self.base_folder, name), md5)
            for name, md5 in self.c10_train_list + self.c10_test_list
        )

    def download(self):
        """Download and extract the CIFAR-10 archive unless it is already valid."""
        import tarfile

        if self._check_integrity():
            print("Files already downloaded and verified")
            return

        download_url(self.url, self.root, self.filename, self.tgz_md5)

        # extract file
        with tarfile.open(
            os.path.join(self.root, self.filename), "r:gz"
        ) as tar:
            # Where supported (Python 3.12+ and security backports) use the
            # "data" filter, which rejects members escaping ``self.root``.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=self.root, filter="data")
            else:
                tar.extractall(path=self.root)

    def __repr__(self):
        fmt_str = "Dataset " + self.__class__.__name__ + "\n"
        fmt_str += "    Number of datapoints: {}\n".format(self.__len__())
        fmt_str += "    Split: {}\n".format(self.set)
        fmt_str += "    Root Location: {}\n".format(self.root)
        tmp = "    Transforms (if any): "
        fmt_str += "{0}{1}\n".format(
            tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
        )
        tmp = "    Target Transforms (if any): "
        fmt_str += "{0}{1}".format(
            tmp,
            self.target_transform.__repr__().replace(
                "\n", "\n" + " " * len(tmp)
            ),
        )
        return fmt_str


class CrowdCIFAR(Dataset):
    """Pair the images of an already-loaded split with an external target list.

    Each item is ``(image, target, true_target, index)``:

    * ``image`` -- ``set.c10h_data[index]`` as an RGB PIL Image, passed
      through ``transform`` when one is given;
    * ``target`` -- ``targets[index]``;
    * ``true_target`` -- ``set.c10h_c10_targets[index]``;
    * ``index`` -- the requested index itself.

    ``set`` must expose ``c10h_data`` and ``c10h_c10_targets`` (e.g. a
    ``CIFAR10H`` instance).
    """

    def __init__(self, set, targets, transform=None):
        # ``set`` shadows the builtin but is part of the public signature.
        self.set = set
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        source = self.set
        # Fetch everything first so indexing errors surface before any
        # image conversion.
        pixels = source.c10h_data[index]
        given_label = self.targets[index]
        clean_label = source.c10h_c10_targets[index]

        picture = Image.fromarray(pixels, "RGB")
        if self.transform is not None:
            picture = self.transform(picture)

        return picture, given_label, clean_label, index

    def __len__(self):
        # Length follows the image array; ``targets`` is assumed to match.
        data = self.set.c10h_data
        return len(data)
