import numpy as np
from .base import Attack
import torch
import skimage as sk
from skimage.filters import gaussian
import cv2
from io import BytesIO

# from wand.image import Image as WandImage
# from wand.api import library as wandlibrary
import ctypes


class Att_gaussian_noise(Attack):
    """Attack wrapper that applies ``gaussian_noise`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``gaussian_noise`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(gaussian_noise(x, severity=self.severity)).to(device)

class Att_shot_noise(Attack):
    """Attack wrapper that applies ``shot_noise`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``shot_noise`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(shot_noise(x, severity=self.severity)).to(device)

class Att_impulse_noise(Attack):
    """Attack wrapper that applies ``impulse_noise`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``impulse_noise`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(impulse_noise(x, severity=self.severity)).to(device)


class Att_speckle_noise(Attack):
    """Attack wrapper that applies ``speckle_noise`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``speckle_noise`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(speckle_noise(x, severity=self.severity)).to(device)
    
class Att_glass_blur(Attack):
    """Attack wrapper that applies ``glass_blur`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``glass_blur`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(glass_blur(x, severity=self.severity)).to(device)

class Att_gaussian_blur(Attack):
    """Attack wrapper that applies ``gaussian_blur`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``gaussian_blur`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(gaussian_blur(x, severity=self.severity)).to(device)

#  --------------------------------------------------------------------------------

class Att_defocus_blur(Attack):
    """Attack wrapper that applies ``defocus_blur`` at a fixed severity.

    ``perturb`` does not call ``predict`` or ``loss_fn``; it only applies the
    corruption to its input.

    Keyword Args:
        severity: corruption level forwarded to ``defocus_blur`` (required).
        device: optional torch device for the returned tensor. When omitted,
            the result is placed on the same device as the input ``x``.
    """

    def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
        super().__init__(predict, loss_fn, clip_min, clip_max)
        self.severity = kwargs['severity']
        # Stored under a distinct name so it cannot clash with base-class state.
        self.output_device = kwargs.get('device')

    def perturb(self, x, label=None, **kwargs):
        """Return the corrupted ``x`` as a torch tensor.

        ``label`` and ``kwargs`` are accepted for API compatibility and unused.
        """
        # Follow the input's device instead of a hard-coded GPU index.
        device = self.output_device if self.output_device is not None else x.device
        return torch.as_tensor(defocus_blur(x, severity=self.severity)).to(device)

# class Att_motion_blur(Attack):
#     def __init__(self, predict, loss_fn, clip_min, clip_max, **kwargs):
#         super().__init__(predict, loss_fn, clip_min, clip_max)
#         self.severity = kwargs['severity']
        
    
#     def perturb(self, x, label=None, **kwargs):
#         return torch.as_tensor(motion_blur(x, severity=self.severity)).to(torch.device("cuda:3"))
#  --------------------------------------------------------------------------------

# def motion_blur(x, severity=1):
#     x = x.cpu()
#     c = [(6,1), (6,1.5), (6,2), (8,2), (9,2.5)][severity - 1]

#     output = BytesIO()
#     x.save(output, format='PNG')
#     x = MotionImage(blob=output.getvalue())

#     x.motion_blur(radius=c[0], sigma=c[1], angle=np.random.uniform(-45, 45))

#     x = cv2.imdecode(np.fromstring(x.make_blob(), np.uint8),
#                      cv2.IMREAD_UNCHANGED)

#     if x.shape != (32, 32):
#         return np.clip(x[..., [2, 1, 0]], 0, 255)  # BGR to RGB
#     else:  # greyscale to RGB
#         return np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)

# Tell Python about the C method
# wandlibrary.MagickMotionBlurImage.argtypes = (ctypes.c_void_p,  # wand
#                                               ctypes.c_double,  # radius
#                                               ctypes.c_double,  # sigma
#                                               ctypes.c_double)  # angle
# class MotionImage(WandImage):
#     def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
#         wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)

def defocus_blur(x, severity=1):
    """Convolve every channel of an image with an anti-aliased disk kernel.

    Args:
        x: torch tensor image indexed as ``x[h, w, channel]``. Values are
            divided by 255, so presumably in [0, 255] — confirm with callers.
        severity: 1-based index into the (radius, alias_blur) table below.

    Returns:
        numpy array laid out as (height, width, channels), clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..5``.
    """
    levels = [(0.3, 0.4), (0.4, 0.5), (0.5, 0.6), (1, 0.2), (1.5, 0.1)]
    if not 1 <= severity <= len(levels):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(levels)}, got {severity!r}")
    radius, alias_blur = levels[severity - 1]

    x = np.array(x.cpu()) / 255.
    kernel = disk(radius=radius, alias_blur=alias_blur)

    # Filter channel by channel; the channel count comes from the input so
    # every channel is kept, matching the other corruptions in this module.
    channels = [cv2.filter2D(x[:, :, d], -1, kernel) for d in range(x.shape[2])]
    blurred = np.stack(channels, axis=-1)  # list of HxW -> HxWxC

    return np.clip(blurred, 0, 1) * 255

def disk(radius, alias_blur=0.1, dtype=np.float32):
    """Build a normalized disk-shaped kernel, softened with a Gaussian blur.

    Args:
        radius: disk radius in pixels.
        alias_blur: ``sigmaX`` passed to ``cv2.GaussianBlur`` for anti-aliasing.
        dtype: element type of the binary disk mask before blurring.

    Returns:
        2-D kernel whose pre-blur weights sum to 1.
    """
    # Radii up to 8 share a fixed 17x17 support and a 3x3 smoothing window;
    # larger radii get a support sized to the radius and a 5x5 window.
    is_small = radius <= 8
    half_extent = 8 if is_small else radius
    blur_window = (3, 3) if is_small else (5, 5)

    coords = np.arange(-half_extent, half_extent + 1)
    grid_x, grid_y = np.meshgrid(coords, coords)
    mask = np.array(grid_x ** 2 + grid_y ** 2 <= radius ** 2, dtype=dtype)
    mask /= mask.sum()

    return cv2.GaussianBlur(mask, ksize=blur_window, sigmaX=alias_blur)
#  --------------------------------------------------------------------------------

def gaussian_blur(x, severity=1):
    """Blur an image with ``skimage.filters.gaussian``, treating the last axis as channels.

    Args:
        x: torch tensor image. Values are divided by 255, so presumably in
            [0, 255] — confirm with callers.
        severity: 1-based index into the sigma table below.

    Returns:
        numpy float array of the same shape, clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..5``.
    """
    sigmas = [.4, .6, 0.7, .8, 1]
    if not 1 <= severity <= len(sigmas):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(sigmas)}, got {severity!r}")
    sigma = sigmas[severity - 1]

    img = np.array(x.cpu()) / 255.
    try:
        blurred = gaussian(img, sigma=sigma, channel_axis=-1)
    except TypeError:
        # scikit-image < 0.19 has no ``channel_axis``; newer releases no longer
        # accept ``multichannel``, so try the modern keyword first.
        blurred = gaussian(img, sigma=sigma, multichannel=True)
    return np.clip(blurred, 0, 1) * 255

def glass_blur(x, severity=1):
    """Blur, locally shuffle neighbouring pixels, then blur again.

    Args:
        x: torch tensor image indexed as ``x[h, w, ...]``. Values are divided
            by 255, so presumably in [0, 255] — confirm with callers.
        severity: 1-based index into the (sigma, max_delta, iterations) table.

    Returns:
        numpy float array of the same shape, clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..7``.
    """
    # sigma, max_delta, iterations
    levels = [(0.05, 1, 1), (0.25, 1, 1), (0.4, 1, 1), (0.25, 1, 2),
              (0.4, 1, 2), (0.8, 1, 2), (1.0, 1, 2)]
    if not 1 <= severity <= len(levels):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(levels)}, got {severity!r}")
    sigma, max_delta, iterations = levels[severity - 1]

    def _blur(img):
        # Gaussian blur over spatial axes, last axis treated as channels.
        try:
            return gaussian(img, sigma=sigma, channel_axis=-1)
        except TypeError:
            # scikit-image < 0.19 has no ``channel_axis``; newer releases no
            # longer accept ``multichannel``.
            return gaussian(img, sigma=sigma, multichannel=True)

    x = np.uint8(_blur(np.array(x.cpu()) / 255.) * 255)
    height, width = x.shape[:2]

    # Locally shuffle pixels. Bounds generalize the original 32x32 loops:
    # h runs from height - max_delta down to max_delta + 1 (same for w).
    for _ in range(iterations):
        for h in range(height - max_delta, max_delta, -1):
            for w in range(width - max_delta, max_delta, -1):
                # randint's upper bound is exclusive: offsets are in [-max_delta, max_delta).
                dx, dy = np.random.randint(-max_delta, max_delta, size=(2,))
                h_prime, w_prime = h + dy, w + dx
                # Swap the two pixels. Advanced indexing on the right-hand side
                # yields a copy; the tuple form ``a[i], a[j] = a[j], a[i]`` does
                # not swap here because its right-hand side holds views into x,
                # so both pixels would end up with x[h_prime, w_prime].
                x[[h, h_prime], [w, w_prime]] = x[[h_prime, h], [w_prime, w]]

    return np.clip(_blur(x / 255.), 0, 1) * 255

def speckle_noise(x, severity=1):
    """Add multiplicative Gaussian noise: ``x + x * N(0, scale)``.

    Args:
        x: torch tensor image. Values are divided by 255, so presumably in
            [0, 255] — confirm with callers.
        severity: 1-based index into the noise-scale table below.

    Returns:
        numpy float array of the same shape, clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..7``.
    """
    scales = [.06, .1, .12, .16, .2, .3, .4]
    if not 1 <= severity <= len(scales):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(scales)}, got {severity!r}")
    scale = scales[severity - 1]

    x = np.array(x.cpu()) / 255.
    return np.clip(x + x * np.random.normal(size=x.shape, scale=scale), 0, 1) * 255
def gaussian_noise(x, severity=1):
    """Add additive Gaussian noise ``N(0, scale)`` to an image.

    Args:
        x: torch tensor image. Values are divided by 255, so presumably in
            [0, 255] — confirm with callers.
        severity: 1-based index into the noise-scale table below.

    Returns:
        numpy float array of the same shape, clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..7``.
    """
    scales = [0.04, 0.06, .08, .09, .10, .20, .30]
    if not 1 <= severity <= len(scales):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(scales)}, got {severity!r}")
    scale = scales[severity - 1]

    x = np.array(x.cpu()) / 255.
    return np.clip(x + np.random.normal(size=x.shape, scale=scale), 0, 1) * 255

def shot_noise(x, severity=1):
    """Apply Poisson (shot) noise: ``Poisson(x * c) / c`` with ``x`` scaled to [0, 1].

    Args:
        x: torch tensor image. Values are divided by 255, so presumably in
            [0, 255] — confirm with callers.
        severity: 1-based index into the photon-count table below; smaller
            counts give stronger noise.

    Returns:
        numpy float array of the same shape, clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..5``.
    """
    counts = [500, 250, 100, 75, 50]
    if not 1 <= severity <= len(counts):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(counts)}, got {severity!r}")
    c = counts[severity - 1]

    x = np.array(x.cpu()) / 255.
    return np.clip(np.random.poisson(x * c) / c, 0, 1) * 255

def impulse_noise(x, severity=1):
    """Add salt-and-pepper noise via ``skimage.util.random_noise(mode='s&p')``.

    Args:
        x: torch tensor image. Values are divided by 255, so presumably in
            [0, 255] — confirm with callers.
        severity: 1-based index into the ``amount`` table passed to
            ``random_noise``.

    Returns:
        numpy float array of the same shape, clipped to [0, 255].

    Raises:
        ValueError: if ``severity`` is outside ``1..7``.
    """
    amounts = [.01, .02, .03, .05, .07, 0.10, 0.20]
    if not 1 <= severity <= len(amounts):
        # Without this check, severity <= 0 silently wraps to the strongest level.
        raise ValueError(f"severity must be in 1..{len(amounts)}, got {severity!r}")
    amount = amounts[severity - 1]

    x = sk.util.random_noise(np.array(x.cpu()) / 255., mode='s&p', amount=amount)
    return np.clip(x, 0, 1) * 255