import torch
import math
import random
from PIL import Image
try:
    import accimage
except ImportError:
    accimage = None
import numpy as np
import numbers
import types
from collections.abc import Sequence, Iterable
import warnings
import torchvision.transforms.functional as F
from PIL import ImageFilter

class Low_pass_filter(object):
    """Greyscale conversion followed by a Gaussian low-pass filter.

    Args:
        std (float): blur radius handed to PIL's ``ImageFilter.GaussianBlur``
            (Pillow documents this radius as the standard deviation of the
            Gaussian kernel). It is a single value, not one per channel.
        n_channels (int): number of channels of the greyscale output, passed
            to ``torchvision.transforms.functional.to_grayscale``
            (torchvision accepts 1 or 3).
    """

    def __init__(self, std, n_channels=3):
        self.std = std
        # Background grey level. Not used by __call__; kept so that existing
        # code reading this attribute keeps working.
        self.bg_grey = 0.4423
        self.n_channels = n_channels

    def __call__(self, img):
        """Convert ``img`` to greyscale and blur it.

        Args:
            img (PIL.Image): input image. Expected to be a PIL image, because
                the greyscale result is blurred through its ``filter`` method.

        Returns:
            PIL.Image: the greyscale, low-pass-filtered image with
            ``n_channels`` channels.
        """
        grey = F.to_grayscale(img, self.n_channels)
        return grey.filter(ImageFilter.GaussianBlur(self.std))

    def __repr__(self):
        return self.__class__.__name__ + '(std={0})'.format(self.std)



class GaussianNoise(object):
    """Add uniform additive noise, optionally after a contrast reduction.

    Despite the class name, the noise is drawn from a *uniform* distribution.
    Images are handled on the 0-255 uint8 scale: the noise lies in
    ``[-125 * width, 125 * width]`` and the result is clipped to [0, 255].

    Args:
        width (float): scalar in [0, 1] controlling the noise amplitude.
        contrast_level (float): scalar in [0, 1]. 1 keeps full contrast;
            lower values blend the image towards mid-grey (127.5) before the
            noise is added, and 0 makes it uniformly mid-grey.
    """

    def __init__(self, width, contrast_level=1):
        assert(contrast_level >= 0.0), "contrast_level too low."
        assert(contrast_level <= 1.0), "contrast_level too high."
        assert(width >= 0.0), "width too low."
        assert(width <= 1.0), "width too high."

        self.width = width
        self.contrast_level = contrast_level

    def __call__(self, img):
        """Return a noisy copy of ``img``.

        Args:
            img: PIL image (or array-like) convertible to a uint8 array.

        Returns:
            PIL.Image: the perturbed image. Its mode is inferred from the
            array shape (e.g. RGB for HxWx3, L for HxW).
        """
        pixels = np.asarray(img, np.uint8)
        return Image.fromarray(self._perturb_array(pixels))

    def _perturb_array(self, pixels):
        """Apply contrast reduction and uniform noise to a uint8 array.

        Returns a new uint8 array with the same shape as ``pixels``.
        """
        # Work in float. Adding signed noise directly to uint8 wraps around
        # (e.g. 250 + 10 -> 4), and a uint8 array is never < 0 or > 255, so
        # clipping it afterwards would be a no-op.
        out = pixels.astype(np.float64)
        c = self.contrast_level
        # Blend towards mid-grey; this is the identity when c == 1.
        out = (1.0 - c) * 127.5 + c * out
        amplitude = self.width * 125
        noise = np.random.uniform(low=-amplitude, high=amplitude, size=out.shape)
        out = np.clip(out + noise, 0, 255)
        return np.rint(out).astype(np.uint8)

    def __repr__(self):
        return '{0}(width={1}, contrast_level={2})'.format(
            self.__class__.__name__, self.width, self.contrast_level)

def low_pass_filter(x, std=2, n_channels=3):
    """Convert an image to greyscale and apply a Gaussian low-pass filter.

    Args:
        x (PIL.Image): image to filter. Expected to be a PIL image, because
            the greyscale result is blurred through its ``filter`` method.
        std (float): blur radius passed to ``ImageFilter.GaussianBlur``.
            Defaults to 2, the same default ``GaussianBlur`` itself uses.
        n_channels (int): number of output channels for ``to_grayscale``
            (torchvision accepts 1 or 3).

    Returns:
        PIL.Image: the greyscale, low-pass-filtered image.
    """
    grey = F.to_grayscale(x, n_channels)
    return grey.filter(ImageFilter.GaussianBlur(std))


if __name__ == "__main__":
    # Visual smoke test: save the first CIFAR-10-C sample after applying
    # GaussianNoise at widths 0.0, 0.1, ..., 0.9.
    from cifar10_corrupted import CIFAR10_C
    from torchvision import transforms

    for width in np.arange(0, 1, 0.1):
        transform = transforms.Compose([
            GaussianNoise(width=width, contrast_level=1),
        ])
        dd = CIFAR10_C(corrupted_name="gaussian_noise2", transform=transform)
        # Format the width so float steps such as 0.30000000000000004
        # do not leak into the filename.
        dd[0][0].save("test-{:.1f}.png".format(width))
