import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union

import numpy as np
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import (
    check_integrity,
    download_and_extract_archive,
)

class CIFAR10C(VisionDataset):
    """The corrupted CIFAR-10-C Dataset.

    Args:
        root (str or Path): Root directory of the datasets. A leading ``~``
            is expanded. The files are read from (and downloaded to)
            ``root / "CIFAR-10-C"``.
        transform (callable, optional): A function/transform applied to each
            sample. Samples are passed as NumPy arrays, exactly as sliced
            from the ``.npy`` files (they are not converted to PIL images),
            so the transform must accept NumPy input, e.g.
            ``transforms.ToTensor``. Defaults to None.
        target_transform (callable, optional): A function/transform that
            takes in the target and transforms it. Defaults to None.
        subset (str, optional): The subset to use, one of ``all`` or the
            entries of ``cifarc_subsets``. ``all`` concatenates every entry
            of ``cifarc_subsets``. Defaults to ``all``.
        severity (int, optional): The severity of the corruption, between 1
            and 5 included. Defaults to 1.
        download (bool, optional): If True, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again. Defaults to False.

    Raises:
        ValueError: If ``subset`` or ``severity`` is not a valid choice.
        RuntimeError: If the expected files are missing or fail the MD5
            check after the optional download.

    References:
        Benchmarking neural network robustness to common corruptions and
        perturbations. Dan Hendrycks and Thomas Dietterich. In ICLR, 2019.
    """

    base_folder = "CIFAR-10-C"
    tgz_md5 = "56bf5dcef84df0e2308c6dcbcbbd8499"
    # Corruption types selectable through the ``subset`` argument.
    cifarc_subsets = [
        "brightness",
        "contrast",
        "defocus_blur",
        "elastic_transform",
        "fog",
        "frost",
        "gaussian_noise",
        "glass_blur",
        "impulse_noise",
        "jpeg_compression",
        "motion_blur",
        "pixelate",
        "shot_noise",
        "snow",
        "zoom_blur",
    ]

    # (filename, md5) of every file verified by _check_integrity(). This also
    # covers speckle_noise, spatter, gaussian_blur and saturate, which are
    # checked but are not selectable through ``subset``.
    ctest_list = [
        ["fog.npy", "7b397314b5670f825465fbcd1f6e9ccd"],
        ["jpeg_compression.npy", "2b9cc4c864e0193bb64db8d7728f8187"],
        ["zoom_blur.npy", "6ea8e63f1c5cdee1517533840641641b"],
        ["speckle_noise.npy", "ef00b87611792b00df09c0b0237a1e30"],
        ["glass_blur.npy", "7361fb4019269e02dbf6925f083e8629"],
        ["spatter.npy", "8a5a3903a7f8f65b59501a6093b4311e"],
        ["shot_noise.npy", "3a7239bb118894f013d9bf1984be7f11"],
        ["defocus_blur.npy", "7d1322666342a0702b1957e92f6254bc"],
        ["elastic_transform.npy", "9421657c6cd452429cf6ce96cc412b5f"],
        ["gaussian_blur.npy", "c33370155bc9b055fb4a89113d3c559d"],
        ["frost.npy", "31f6ab3bce1d9934abfb0cc13656f141"],
        ["saturate.npy", "1cfae0964219c5102abbb883e538cc56"],
        ["brightness.npy", "0a81ef75e0b523c3383219c330a85d48"],
        ["snow.npy", "bb238de8555123da9c282dea23bd6e55"],
        ["gaussian_noise.npy", "ecaf8b9a2399ffeda7680934c33405fd"],
        ["motion_blur.npy", "fffa5f852ff7ad299cfe8a7643f090f4"],
        ["contrast.npy", "3c8262171c51307f916c30a3308235a8"],
        ["impulse_noise.npy", "2090e01c83519ec51427e65116af6b1a"],
        ["labels.npy", "c439b113295ed5254878798ffe28fd54"],
        ["pixelate.npy", "0f14f7e2db14288304e1de10df16832f"],
    ]
    url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar"
    filename = "CIFAR-10-C.tar"
    # Rows per severity level: for severity ``s`` the loader reads rows
    # [(s - 1) * samples_per_severity, s * samples_per_severity) of each
    # ``.npy`` file.
    samples_per_severity = 10000

    def __init__(
        self,
        root: Union[str, Path],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        subset: str = "all",
        severity: int = 1,
        download: bool = False,
    ) -> None:
        root = Path(root).expanduser()

        # Validate the arguments first so that a bad value fails before any
        # download or disk access takes place.
        if subset not in ("all", *self.cifarc_subsets):
            raise ValueError(
                f"The subset '{subset}' does not exist in CIFAR-C."
            )
        if severity not in range(1, 6):
            raise ValueError(
                "Corruptions severity should be chosen between 1 and 5 "
                "included."
            )

        # VisionDataset.__init__ (below) rebinds ``self.root`` to the
        # extracted folder. The parent directory used by download() and
        # _check_integrity() is therefore stored separately, so both methods
        # keep working after construction.
        self._archive_root = root
        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found. You can use download=True to download it."
            )

        super().__init__(
            root=root / self.base_folder,
            transform=transform,
            target_transform=target_transform,
        )
        self.subset = subset
        self.severity = severity

        samples, labels = self.make_dataset(
            self.root, self.subset, self.severity
        )
        self.samples = samples
        self.labels = labels.astype(np.int64)

    def _load_severity(self, path: Path, severity: int) -> np.ndarray:
        """Read the rows of the ``.npy`` file at ``path`` for ``severity``."""
        stop = severity * self.samples_per_severity
        start = stop - self.samples_per_severity
        # Memory-map the file so that only the requested slice is read from
        # disk. Then copy it into a regular in-memory array so the result is
        # not tied to the open file.
        return np.array(np.load(path, mmap_mode="r")[start:stop])

    def make_dataset(
        self, root: Path, subset: str, severity: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        r"""Build the corrupted dataset for the chosen subset and severity.

        If the subset is 'all', the samples of every corruption type in
        ``cifarc_subsets`` are concatenated (in that order), and the labels
        are repeated once per corruption type to match.

        Args:
            root (Path): The directory containing the ``.npy`` files.
            subset (str): The name of the corruption subset to be used. Choose
                `all` for the dataset to contain all subsets.
            severity (int): The severity of the corruption applied to the
                images.

        Returns:
            Tuple[np.ndarray, np.ndarray]: The samples and labels of the chosen
                subset and severity.
        """
        labels = self._load_severity(root / "labels.npy", severity)
        if subset == "all":
            samples = np.concatenate(
                [
                    self._load_severity(root / f"{name}.npy", severity)
                    for name in self.cifarc_subsets
                ],
                axis=0,
            )
            labels = np.tile(labels, len(self.cifarc_subsets))
        else:
            samples = self._load_severity(root / f"{subset}.npy", severity)
        return samples, labels

    def __len__(self) -> int:
        """The number of samples in the dataset."""
        return self.labels.shape[0]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(sample, target)`` pair at ``index``.

        ``transform`` (if set) receives the raw NumPy sample, and
        ``target_transform`` (if set) receives the ``int64`` label.
        """
        sample = self.samples[index]
        target = self.labels[index]

        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def _check_integrity(self) -> bool:
        """Return True if every file of ``ctest_list`` passes its MD5 check."""
        dataset_dir = self._archive_root / self.base_folder
        return all(
            check_integrity(os.path.join(dataset_dir, name), md5)
            for name, md5 in self.ctest_list
        )

    def download(self) -> None:
        """Download and extract the archive unless verified files exist."""
        if self._check_integrity():
            print("Files already downloaded and verified.")
            return
        download_and_extract_archive(
            self.url,
            self._archive_root,
            filename=self.filename,
            md5=self.tgz_md5,
        )