from PIL import Image
import os
import numpy as np
from torchvision.datasets import VisionDataset


class CorruptedCIFAR10(VisionDataset):
    """CIFAR-10-C: CIFAR-10 images with synthetic corruptions applied.

    Expects ``<root>/CIFAR-10-C/labels.npy`` plus one ``<corruption>.npy`` file
    per requested corruption. Each file is read in blocks of 10000 rows, one
    block per severity level (block ``severity - 1`` is used), and the same
    slice is taken from ``labels.npy``. Samples from all requested corruptions
    are concatenated in the order given by ``corruptions``.

    Args:
        root: Directory containing the ``CIFAR-10-C`` folder.
        severity: Corruption severity, expected in ``[1, 5]``.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the label.
        corruptions: Names of the corruption files (without ``.npy``) to load.

    Raises:
        ValueError: If ``severity`` is outside ``[1, 5]``.
    """

    # Number of rows per severity level in each corruption / labels file.
    _IMAGES_PER_SEVERITY = 10000

    def __init__(self, root, severity, transform=None, target_transform=None,
                 corruptions=("brightness", "contrast", "defocus_blur", "elastic_transform", "fog",
                              "frost", "gaussian_blur", "gaussian_noise", "glass_blur", "impulse_noise",
                              "jpeg_compression", "motion_blur", "pixelate", "saturate", "shot_noise",
                              "snow", "spatter", "speckle_noise", "zoom_blur")):
        super().__init__(root, transform=transform, target_transform=target_transform)
        if not 1 <= severity <= 5:
            # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
            raise ValueError(
                "Expected corruption severity to be 1 - 5 but got '{}' instead.".format(severity)
            )

        self.data = []
        self.targets = []

        data_dir = os.path.join(root, "CIFAR-10-C")
        start = (severity - 1) * self._IMAGES_PER_SEVERITY
        stop = start + self._IMAGES_PER_SEVERITY

        # The label slice is identical for every corruption, so compute it once.
        labels = np.load(os.path.join(data_dir, "labels.npy"))
        targets = labels[start:stop].astype(np.int64)

        for corruption in corruptions:
            # Memory-map the file so only the requested severity is read from disk,
            # then copy that slice: list.extend stores per-image views into the
            # array it is given, and extending with a view of the full file would
            # keep all five severities resident for the dataset's lifetime.
            images = np.load(os.path.join(data_dir, corruption + ".npy"), mmap_mode="r")
            self.data.extend(np.array(images[start:stop]))
            self.targets.extend(targets)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``.

        The stored array is converted with ``Image.fromarray``; ``transform`` and
        ``target_transform`` are applied when set.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the total number of samples across all loaded corruptions."""
        return len(self.data)


class CorruptedCIFAR100(VisionDataset):
    """CIFAR-100-C: CIFAR-100 images with synthetic corruptions applied.

    Expects ``<root>/CIFAR-100-C/labels.npy`` plus one ``<corruption>.npy`` file
    per requested corruption. Each file is read in blocks of 10000 rows, one
    block per severity level (block ``severity - 1`` is used), and the same
    slice is taken from ``labels.npy``. Samples from all requested corruptions
    are concatenated in the order given by ``corruptions``.

    Args:
        root: Directory containing the ``CIFAR-100-C`` folder.
        severity: Corruption severity, expected in ``[1, 5]``.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the label.
        corruptions: Names of the corruption files (without ``.npy``) to load.

    Raises:
        ValueError: If ``severity`` is outside ``[1, 5]``.
    """

    # Number of rows per severity level in each corruption / labels file.
    _IMAGES_PER_SEVERITY = 10000

    def __init__(self, root, severity, transform=None, target_transform=None,
                 corruptions=("brightness", "contrast", "defocus_blur", "elastic_transform", "fog",
                              "frost", "gaussian_blur", "gaussian_noise", "glass_blur", "impulse_noise",
                              "jpeg_compression", "motion_blur", "pixelate", "saturate", "shot_noise",
                              "snow", "spatter", "speckle_noise", "zoom_blur")):
        super().__init__(root, transform=transform, target_transform=target_transform)
        if not 1 <= severity <= 5:
            # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
            raise ValueError(
                "Expected corruption severity to be 1 - 5 but got '{}' instead.".format(severity)
            )

        self.data = []
        self.targets = []

        data_dir = os.path.join(root, "CIFAR-100-C")
        start = (severity - 1) * self._IMAGES_PER_SEVERITY
        stop = start + self._IMAGES_PER_SEVERITY

        # The label slice is identical for every corruption, so compute it once.
        labels = np.load(os.path.join(data_dir, "labels.npy"))
        targets = labels[start:stop].astype(np.int64)

        for corruption in corruptions:
            # Memory-map the file so only the requested severity is read from disk,
            # then copy that slice: list.extend stores per-image views into the
            # array it is given, and extending with a view of the full file would
            # keep all five severities resident for the dataset's lifetime.
            images = np.load(os.path.join(data_dir, corruption + ".npy"), mmap_mode="r")
            self.data.extend(np.array(images[start:stop]))
            self.targets.extend(targets)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``.

        The stored array is converted with ``Image.fromarray``; ``transform`` and
        ``target_transform`` are applied when set.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the total number of samples across all loaded corruptions."""
        return len(self.data)
