import logging
import os
from os.path import join
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image
import torch.utils.data as data
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive

# Module-scoped logger, named after this module so handlers/levels can be configured per module.
log = logging.getLogger(name=__name__)


class ImageDatasetBase(VisionDataset):
    """
    Base Class for Downloading Image related Datasets

    Subclasses configure the class attributes below. On construction the archive
    ``root/filename`` is (optionally) downloaded and extracted into ``root``, its
    MD5 checksum is verified, and every regular file inside ``root/base_folder``
    becomes one sample of the dataset.

    Code inspired by : https://pytorch.org/vision/0.8/_modules/torchvision/datasets/cifar.html#CIFAR10
    """

    # Directory (relative to ``root``) that holds the extracted samples.
    base_folder = None
    # Location the archive is downloaded from.
    url = None
    # File name of the archive inside ``root``.
    filename = None
    # Expected MD5 checksum of the archive.
    md5hash = None

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """
        :param root: directory containing (or receiving) the archive.
        :param transform: optional callable applied to each loaded image.
        :param target_transform: optional callable applied to each target.
        :param download: if ``True``, download and extract the archive first
            (skipped when a verified copy already exists).
        :raises RuntimeError: if the archive is missing or fails the checksum.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self.basedir = os.path.join(self.root, self.base_folder)
        # os.listdir returns entries in arbitrary, platform-dependent order; sort so
        # the index -> file mapping is reproducible. Sub-directories are skipped
        # because they cannot be opened as images.
        self.files = sorted(
            name
            for name in os.listdir(self.basedir)
            if os.path.isfile(os.path.join(self.basedir, name))
        )

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target). This dataset carries no labels, so the
            target is always ``-1`` (before ``target_transform`` is applied).
        """
        file, target = self.files[index], -1

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image.
        # Image.open() alone is lazy and keeps the file handle open until the
        # image is processed or garbage collected; load the pixel data eagerly
        # so the handle is released as soon as the ``with`` block exits.
        path = os.path.join(self.basedir, file)
        with open(path, "rb") as fh:
            img = Image.open(fh)
            img.load()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Number of sample files found in ``base_folder``."""
        return len(self.files)

    def _check_integrity(self) -> bool:
        """Return ``True`` if ``root/filename`` exists and matches ``md5hash``."""
        fpath = os.path.join(self.root, self.filename)
        log.debug("Checking integrity of %s", fpath)
        return check_integrity(fpath, self.md5hash)

    def download(self) -> None:
        """Download and extract the archive into ``root`` unless a verified copy exists."""
        if self._check_integrity():
            log.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.md5hash)



class CIFAR10C(ImageDatasetBase):
    """
    Corrupted version of the CIFAR10 from the paper *Benchmarking Neural
    Network Robustness to Common Corruptions and Perturbations.*

    Each corruption is read from ``<name>.npy`` inside :attr:`base_folder`, and
    all corruptions share the labels in ``labels.npy``. Pass ``subset="all"`` to
    concatenate every corruption in :attr:`subsets` order.

    :see Website: `Zenodo <https://zenodo.org/record/2535967>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1903.12261>`__
    """

    subsets = [
        "brightness",
        "contrast",
        "defocus_blur",
        "elastic_transform",
        "fog",
        "frost",
        "gaussian_blur",
        "gaussian_noise",
        "glass_blur",
        "impulse_noise",
        "jpeg_compression",
        "motion_blur",
        "pixelate",
        "saturate",
        "shot_noise",
        "snow",
        "spatter",
        "speckle_noise",
        "zoom_blur",
    ]

    base_folder = "CIFAR-10-C"
    url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar"
    filename = "CIFAR-10-C.tar"
    md5hash = "56bf5dcef84df0e2308c6dcbcbbd8499"

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        """
        :param root: directory containing (or receiving) the archive.
        :param subset: one of :attr:`subsets`, or ``"all"`` for every corruption.
        :param transform: optional callable applied to each PIL image.
        :param target_transform: optional callable applied to each label.
        :param download: if ``True``, download and extract the archive first.
        :raises ValueError: if ``subset`` is not a known corruption or ``"all"``.
        :raises RuntimeError: (from the base class) if the archive is missing or corrupted.
        """
        # Validate before the base class runs: it may download and MD5-hash a
        # multi-gigabyte archive, which is wasted work for a mistyped subset.
        if subset not in self.subsets and subset != "all":
            raise ValueError(f"Unknown Subset: {subset}")

        super().__init__(root, transform, target_transform, download)

        self.subset = subset

        # Use self.root (as the base class does for its integrity check) rather
        # than the raw argument, so all paths resolve the same way.
        folder = join(self.root, self.base_folder)
        labels = np.load(join(folder, "labels.npy"))

        if subset == "all":
            self.data = np.concatenate(
                [np.load(join(folder, f"{s}.npy")) for s in self.subsets]
            )
            # Every corruption file uses the same labels.npy, so read it once
            # and repeat it once per subset instead of re-loading it each time.
            self.targets = np.concatenate([labels] * len(self.subsets))
        else:
            self.data = np.load(join(folder, f"{subset}.npy"))
            self.targets = labels

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        :param index: sample index.
        :return: ``(image, target)`` where ``image`` is built with
            ``Image.fromarray`` from the stored array and ``target`` is the
            matching entry from ``labels.npy`` (both after their transforms).
        """
        img = self.data[index]
        target = self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image.
        # NOTE(review): assumes each sample is an HxWxC uint8 array — confirm.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


class CIFAR100C(CIFAR10C):
    """
    Corrupted version of the CIFAR100 from the paper *Benchmarking Neural Network
    Robustness to Common Corruptions and Perturbations.*

    Reuses the corruption names and loading logic of :class:`CIFAR10C`, so the
    extracted folder is expected to contain ``<corruption>.npy`` files plus a
    shared ``labels.npy``; only the archive location and checksum differ.

    :see Website: `Zenodo <https://zenodo.org/record/3555552>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1903.12261>`__
    """

    # No trailing slash, consistent with CIFAR10C (os.path.join adds separators).
    base_folder = "CIFAR-100-C"
    url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar"
    filename = "CIFAR-100-C.tar"
    md5hash = "11f0ed0f1191edbf9fa23466ae6021d3"
