from __future__ import print_function
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# from .image_transforms import Low_pass_filter
from PIL import Image
import pdb




# Per-image array shape and number of classes.
# NOTE(review): (32, 32, 3) is presumably height x width x channels -- confirm.
_CIFAR_IMAGE_SIZE = (32, 32, 3)
_CIFAR_CLASSES = 10

# Corruption name -> name of the .npy file that stores the corrupted images.
# Two names differ from their file stem: 'frosted_glass_blur' and 'elastic'.
_CORRUPTIONS_TO_FILENAMES = dict(
    gaussian_noise='gaussian_noise.npy',
    shot_noise='shot_noise.npy',
    impulse_noise='impulse_noise.npy',
    defocus_blur='defocus_blur.npy',
    frosted_glass_blur='glass_blur.npy',
    motion_blur='motion_blur.npy',
    zoom_blur='zoom_blur.npy',
    snow='snow.npy',
    frost='frost.npy',
    fog='fog.npy',
    brightness='brightness.npy',
    contrast='contrast.npy',
    elastic='elastic_transform.npy',
    pixelate='pixelate.npy',
    jpeg_compression='jpeg_compression.npy',
    gaussian_blur='gaussian_blur.npy',
    saturate='saturate.npy',
    spatter='spatter.npy',
    speckle_noise='speckle_noise.npy',
)

# Corruption names in alphabetical order, and their file names in that same order.
_CORRUPTIONS = tuple(sorted(_CORRUPTIONS_TO_FILENAMES))
_FILENAMES = tuple(_CORRUPTIONS_TO_FILENAMES[name] for name in _CORRUPTIONS)

_DIRNAME = 'CIFAR-10-C'
_LABELS_FILENAME = 'labels.npy'

# The 15 corruption names listed as the benchmark set.
BENCHMARK_CORRUPTIONS = [
    'gaussian_noise', 'shot_noise', 'impulse_noise',
    'defocus_blur', 'frosted_glass_blur', 'motion_blur', 'zoom_blur',
    'snow', 'frost', 'fog',
    'brightness', 'contrast', 'elastic', 'pixelate', 'jpeg_compression',
]

# The 4 remaining corruption names that are not part of the benchmark list.
EXTRA_CORRUPTIONS = ['gaussian_blur', 'saturate', 'spatter', 'speckle_noise']

# Default directory holding the corrupted .npy files (note the unexpanded '~').
data_path = "~/dataset/CIFAR-10-C"

class CIFAR10_C(datasets.CIFAR10):
    """CIFAR-10 images, optionally replaced by a corrupted variant.

    The clean CIFAR-10 split is always loaded by the parent class from the
    fixed location ``~/dataset/CIFAR10``. Depending on ``corrupted_name`` the
    samples served by ``__getitem__`` are:

    * ``"clean"`` / ``"gaussian_noise2"``: the parent's data, unchanged.
      NOTE(review): presumably the noise for ``"gaussian_noise2"`` is added
      by a transform elsewhere -- confirm.
    * ``"adversarial_noise"``: the parent's data plus a pre-computed
      perturbation loaded from
      ``./noise_dataC1/<last component of work_dir>/adversarial_noise.npy``,
      scaled by a strength selected with ``severity`` (1..9).
    * any key of ``_CORRUPTIONS_TO_FILENAMES``: images and labels loaded from
      the ``.npy`` files under ``root``, sliced to the requested ``severity``
      (1..5).

    Args:
        root: Directory containing the corrupted ``.npy`` files and
            ``labels.npy``. A leading ``~`` is expanded.
        transform: Optional callable applied to the PIL image.
        target_transform: Optional callable applied to the label.
        corrupted_name: Which variant to serve (see above).
        severity: 1-based severity level.
        train: Forwarded to ``datasets.CIFAR10`` to choose the clean split.
        **kwargs: ``work_dir`` is required for ``"adversarial_noise"``.

    Raises:
        AssertionError: Unknown ``corrupted_name`` or corruption severity
            outside 1..5.
        ValueError: ``"adversarial_noise"`` without ``work_dir`` or with a
            severity outside the strength table.
    """

    # Multipliers applied to the stored adversarial perturbation; the entry
    # used is ``_ADVERSARIAL_NOISE_STRENGTHS[severity - 1]``.
    _ADVERSARIAL_NOISE_STRENGTHS = (0.0, 0.005, 0.01, 0.015, 0.02, 0.025, 0.03, 0.035, 0.04)

    def __init__(self, root=data_path, transform=None, target_transform=None, corrupted_name="clean",
                 severity=1, train=False, **kwargs):
        # ``root`` only locates the corrupted files; the clean data location
        # is hard-coded here.
        super(CIFAR10_C, self).__init__(root="~/dataset/CIFAR10", transform=transform,
                                        train=train,
                                        target_transform=target_transform,
                                        )

        self.transform = transform
        self.target_transform = target_transform

        assert corrupted_name in list(_CORRUPTIONS_TO_FILENAMES.keys())+['clean','gaussian_noise2','adversarial_noise'], \
           "The noise type is not existed!"

        if corrupted_name in ("clean", "gaussian_noise2"):
            self.test_data = self.data
            self.test_labels = self.targets
        elif corrupted_name == 'adversarial_noise':
            self.test_labels = self.targets
            self.test_data = self._add_adversarial_noise(
                self.data, severity, kwargs.get("work_dir"))
        else:
            assert severity in [1,2,3,4,5], "There are 5 severities, [1,2,3,4,5]."
            # np.load / open() do not expand '~', and the default root
            # starts with it, so expand before building file paths.
            root = os.path.expanduser(root)
            test_labels = np.load(os.path.join(root, "labels.npy"))
            test_data = np.load(os.path.join(root, _CORRUPTIONS_TO_FILENAMES[corrupted_name]))
            # Each file is treated as 5 consecutive blocks of 10000 samples,
            # one block per severity level.
            self.test_data = test_data.reshape(5, 10000, 32, 32, 3)[int(severity - 1)]
            self.test_labels = np.array(test_labels, np.int32).reshape(5, 10000)[int(severity - 1)]

    @classmethod
    def _add_adversarial_noise(cls, clean_data, severity, work_dir):
        """Return ``clean_data`` plus the scaled stored perturbation as uint8."""
        strengths = cls._ADVERSARIAL_NOISE_STRENGTHS
        if work_dir is None:
            raise ValueError("corrupted_name='adversarial_noise' requires a 'work_dir' keyword argument.")
        if not 1 <= severity <= len(strengths):
            # severity=0 would otherwise silently select the last (strongest) entry.
            raise ValueError(
                "adversarial_noise severity must be in [1, {}], got {}.".format(len(strengths), severity))

        run_name = work_dir.split('/')[-1]
        adversarial_noise = np.load("./noise_dataC1/{}/adversarial_noise.npy".format(run_name))
        # Move axis 1 to the end so the layout matches the clean data.
        # NOTE(review): presumably the file is stored as (N, C, H, W) -- confirm.
        adversarial_noise = adversarial_noise.transpose(0, 2, 3, 1)

        # Add in a signed, wide integer type. Doing this in uint8 wraps
        # around for negative noise and for sums above 255, which makes any
        # later clamping a no-op. The scaled noise is truncated to integers,
        # as before.
        scaled_noise = (adversarial_noise * strengths[severity - 1]).astype(np.int32)
        perturbed = np.asarray(clean_data).astype(np.int32) + scaled_noise
        return np.clip(perturbed, 0, 255).astype(np.uint8)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index`` after optional transforms."""
        img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the selected variant."""
        return len(self.test_labels)


class CIFAR10_C_twopass(CIFAR10_C):
    """``CIFAR10_C`` variant that returns two transformed views of each image.

    Each item is ``(fine_img, coarse_img, target)``: the same PIL image passed
    through ``fine_transform`` and ``coarse_transform`` respectively. A view
    whose transform is ``None`` is the untransformed PIL image.

    Args:
        coarse_transform: Optional callable producing the coarse view.
        fine_transform: Optional callable producing the fine view.
        target_transform: Optional callable applied to the label.
        corrupted_name: Forwarded to ``CIFAR10_C``.
        severity: Forwarded to ``CIFAR10_C``.
        train: Forwarded to ``CIFAR10_C``.
        **kwargs: Forwarded to ``CIFAR10_C``. An optional ``root`` entry
            overrides the default ``data_path``.
    """

    def __init__(self, coarse_transform=None, fine_transform=None, target_transform=None,
                 corrupted_name="clean", severity=1, train=False, **kwargs):
        # Accept an optional ``root`` instead of failing with a duplicate
        # keyword error; fall back to the module default otherwise.
        root = kwargs.pop("root", data_path)
        # ``target_transform`` must be forwarded: the parent assigns
        # ``self.target_transform`` from its own argument, so omitting it
        # here silently discarded the caller's transform.
        super(CIFAR10_C_twopass, self).__init__(root=root, train=train,
                                                target_transform=target_transform,
                                                severity=severity,
                                                corrupted_name=corrupted_name,
                                                **kwargs)
        self.coarse_transform = coarse_transform
        self.fine_transform = fine_transform

    def __getitem__(self, index):
        """Return ``(fine_img, coarse_img, target)`` for ``index``."""
        img, target = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img)

        if self.fine_transform is not None:
            fine_img = self.fine_transform(img)
        else:
            fine_img = img

        if self.coarse_transform is not None:
            coarse_img = self.coarse_transform(img)
        else:
            coarse_img = img

        if self.target_transform is not None:
            target = self.target_transform(target)

        return fine_img, coarse_img, target

if __name__ == "__main__":
    # Smoke test: build the clean training set once per severity value and
    # write the first sample's fine view (a PIL image, since no transform is
    # given) to test-<severity>.png in the current directory.
    for severity in range(1, 6):
        dataset = CIFAR10_C_twopass(corrupted_name="clean", severity=severity, train=True)
        fine_img = dataset[0][0]
        fine_img.save("test-{}.png".format(severity))













##
