import numpy as np
from torch.utils.data import Dataset
from typing import Optional, Tuple
import logging
from scipy.ndimage import zoom as scizoom
from PIL import Image, ImageFilter
from io import BytesIO

import torch
import cv2
from torchvision.transforms import v2
from skimage.filters import gaussian
import skimage as sk
from pkg_resources import resource_filename
from scipy.ndimage.interpolation import map_coordinates


class FrozenNoiseDataset(Dataset):
    """Wrap a dataset so each sample always receives the same random noise.

    Before ``noise_transform`` runs, the global NumPy RNG is seeded with
    ``seed_base + idx``; afterwards the previous RNG state is put back, so
    the noise applied to sample ``idx`` is repeatable while code outside this
    class keeps its own random stream.

    Parameters
    ----------
    dataset : indexable with ``__len__``
        Underlying dataset; ``dataset[idx]`` must unpack into ``(x, y)``.
    noise_transform : callable, optional
        Applied to ``x`` under the per-sample seed. Skipped when ``None``.
    transforms : callable, optional
        Applied to ``x`` after the noise step, outside the seeded section.
    seed_base : int, optional
        Offset added to the sample index to form the seed. Default is 0.
    """

    def __init__(
        self,
        dataset,
        noise_transform=None,
        transforms=None,
        seed_base=0,
    ):
        self.dataset = dataset
        self.noise_transform = noise_transform
        self.transforms = transforms
        self.seed_base = seed_base

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(x, y)`` for ``idx`` with frozen noise applied to ``x``."""
        x, y = self.dataset[idx]
        if self.noise_transform is not None:
            saved_state = np.random.get_state()
            np.random.seed(self.seed_base + idx)
            try:
                x = self.noise_transform(x)
            finally:
                # Restore even when the transform raises; otherwise the whole
                # process would stay reseeded with a per-sample seed.
                np.random.set_state(saved_state)
        if self.transforms is not None:
            x = self.transforms(x)
        return x, y


def get_named_version(size, name, strength):
    """Build the corruption transform registered under ``name``.

    Parameters
    ----------
    size : int
        Image side length, forwarded to the corruptions that accept ``size``.
    name : str
        Lower-case corruption identifier, e.g. ``"gaussiannoise"``.
    strength : int
        Severity level forwarded as ``severity``.

    Returns
    -------
    torch.nn.Module or None
        A new corruption instance, or ``None`` when ``name`` matches no
        registered corruption.
    """
    # Constructors are wrapped in lambdas so only the selected corruption is
    # instantiated; the pairs are compared in order with ``==``.
    registry = (
        ("gaussiannoise", lambda: GaussianNoise(severity=strength)),
        ("shotnoise", lambda: ShotNoise(size=size, severity=strength)),
        ("impulsenoise", lambda: ImpulseNoise(severity=strength)),
        ("defocusblur", lambda: DefocusBlur(severity=strength)),
        ("glassblur", lambda: GlassBlur(size=size, severity=strength)),
        ("motionblur", lambda: MotionBlur(severity=strength)),
        ("zoomblur", lambda: ZoomBlur(size=size, severity=strength)),
        ("snow", lambda: Snow(size=size, severity=strength)),
        ("frost", lambda: Frost(size=size, severity=strength)),
        ("fog", lambda: Fog(size=size, severity=strength)),
        ("brightness", lambda: Brightness(severity=strength)),
        ("contrast", lambda: Contrast(severity=strength)),
        ("elastic", lambda: Elastic(severity=strength, size=size)),
        ("pixelate", lambda: Pixelate(size=size, severity=strength)),
        ("jpegcompression", lambda: JPEGCompression(severity=strength)),
        ("specklenoise", lambda: SpeckleNoise(size=size, severity=strength)),
        ("gaussianblur", lambda: GaussianBlur(severity=strength)),
        ("spatter", lambda: Spatter(severity=strength)),
        ("saturate", lambda: Saturate(severity=strength)),
    )
    for key, build in registry:
        if name == key:
            return build()
    return None


def plasma_fractal(mapsize=256, wibbledecay=3):
    """Generate a heightmap using the diamond-square algorithm.

    Returns a square 2D float64 array with side length ``mapsize``. The map is
    shifted and scaled before returning, so its minimum is 0 and its maximum
    is 1. Randomness is drawn from the global ``np.random`` generator.

    Parameters
    ----------
    mapsize : int, optional
        Size of the map (must be a power of two). Default is 256.
    wibbledecay : int or float, optional
        Divisor applied to the random amplitude after each refinement pass.
        Default is 3.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(mapsize, mapsize)`` with values in ``[0, 1]``.
    """
    assert mapsize & (mapsize - 1) == 0
    # np.float_ was removed in NumPy 2.0; float64 is the same dtype.
    maparray = np.empty((mapsize, mapsize), dtype=np.float64)
    maparray[0, 0] = 0
    stepsize = mapsize
    wibble = 100

    def wibbledmean(array):
        """Average four accumulated neighbours and add the random offset."""
        return array / 4 + wibble * np.random.uniform(-wibble, wibble, array.shape)

    def fillsquares():
        """Calculate middle value of squares as mean of points plus wibble."""
        cornerref = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        squareaccum = cornerref + np.roll(cornerref, shift=-1, axis=0)
        squareaccum += np.roll(squareaccum, shift=-1, axis=1)
        maparray[
            stepsize // 2 : mapsize : stepsize, stepsize // 2 : mapsize : stepsize
        ] = wibbledmean(squareaccum)

    def filldiamonds():
        """Calculate middle value of diamonds as mean of points plus wibble."""
        drgrid = maparray[
            stepsize // 2 : mapsize : stepsize, stepsize // 2 : mapsize : stepsize
        ]
        ulgrid = maparray[0:mapsize:stepsize, 0:mapsize:stepsize]
        # Diamonds on the rows that hold the corner points.
        ldrsum = drgrid + np.roll(drgrid, 1, axis=0)
        lulsum = ulgrid + np.roll(ulgrid, -1, axis=1)
        ltsum = ldrsum + lulsum
        maparray[0:mapsize:stepsize, stepsize // 2 : mapsize : stepsize] = wibbledmean(
            ltsum
        )
        # Diamonds on the columns that hold the corner points.
        tdrsum = drgrid + np.roll(drgrid, 1, axis=1)
        tulsum = ulgrid + np.roll(ulgrid, -1, axis=0)
        ttsum = tdrsum + tulsum
        maparray[stepsize // 2 : mapsize : stepsize, 0:mapsize:stepsize] = wibbledmean(
            ttsum
        )

    # Halve the grid spacing each pass while shrinking the random amplitude.
    while stepsize >= 2:
        fillsquares()
        filldiamonds()
        stepsize //= 2
        wibble /= wibbledecay

    maparray -= maparray.min()
    return maparray / maparray.max()


def clipped_zoom(img, zoom_factor):
    """Zoom into the centre of an image while keeping its height and width.

    A centred window of roughly ``1 / zoom_factor`` of each spatial side is
    cropped, enlarged with linear interpolation, and any surplus pixels from
    rounding are trimmed. Height and width are handled independently, so
    non-square images keep their own shape.

    Parameters
    ----------
    img : numpy.ndarray
        Input image array with three axes; the first two are zoomed and the
        third is left untouched.
    zoom_factor : float
        Factor by which to zoom the image.

    Returns
    -------
    numpy.ndarray
        The zoomed and clipped image.
    """
    h, w = img.shape[:2]
    # Ceil the crop so the enlarged window is never smaller than the target.
    ch = int(np.ceil(h / float(zoom_factor)))
    cw = int(np.ceil(w / float(zoom_factor)))

    top = (h - ch) // 2
    left = (w - cw) // 2
    zoomed = scizoom(
        img[top : top + ch, left : left + cw], (zoom_factor, zoom_factor, 1), order=1
    )
    # Trim off any extra pixels, centred on each axis.
    trim_top = (zoomed.shape[0] - h) // 2
    trim_left = (zoomed.shape[1] - w) // 2

    return zoomed[trim_top : trim_top + h, trim_left : trim_left + w]


def disk(radius, alias_blur=0.1, dtype=np.float32):
    """Build a normalised, slightly blurred disk-shaped kernel.

    Parameters
    ----------
    radius : number
        Radius of the disk. For ``radius <= 8`` the grid spans -8..8 and a
        3x3 blur is used; otherwise it spans ``-radius..radius`` with a 5x5
        blur.
    alias_blur : float, optional
        ``sigmaX`` of the Gaussian blur applied to the disk. Default is 0.1.
    dtype : numpy dtype, optional
        Element type of the kernel. Default is ``np.float32``.

    Returns
    -------
    numpy.ndarray
        The kernel returned by ``cv2.GaussianBlur``.
    """
    half_extent, blur_ksize = (8, (3, 3)) if radius <= 8 else (radius, (5, 5))
    axis = np.arange(-half_extent, half_extent + 1)
    xx, yy = np.meshgrid(axis, axis)

    # Hard-edged disk, scaled so its entries sum to one.
    kernel = (xx**2 + yy**2 <= radius**2).astype(dtype)
    kernel /= kernel.sum()

    # A small Gaussian blur softens the jagged edge of the disk.
    return cv2.GaussianBlur(kernel, ksize=blur_ksize, sigmaX=alias_blur)


# /////////////// End Corruption Helpers ///////////////


# /////////////// Corruptions ///////////////


class CustomGaussianNoise(torch.nn.Module):
    """Add zero-mean Gaussian noise to an image.

    Parameters
    ----------
    severity : int, optional
        Severity level of the noise; 0 disables the corruption and 1-5 select
        increasing noise scales. Default is 1.
    """

    # Noise standard deviation (on the 0-1 intensity scale) per severity 1..5.
    _NOISE_SCALES = (0.08, 0.12, 0.18, 0.26, 0.38)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return a noisy copy of ``x`` as a PIL image.

        ``x`` needs to be a PIL image in the range (0-255). It is returned
        unchanged when the severity is 0.
        """
        if self.severity == 0:
            return x
        scale = self._NOISE_SCALES[self.severity - 1]

        pixels = np.array(x) / 255.0
        noise = np.random.normal(size=pixels.shape, scale=scale)
        noisy = np.clip(pixels + noise, 0, 1) * 255
        return Image.fromarray(np.uint8(noisy))


class ShotNoise(torch.nn.Module):
    """Apply shot (Poisson) noise to an image.

    Parameters
    ----------
    size : int
        Accepted for a uniform constructor signature; not used here.
    severity : int, optional
        Severity level of the noise; 0 disables the corruption. Default is 1.
    """

    # Poisson rate multiplier per severity 1..5 (smaller means noisier).
    _RATES = (60, 25, 12, 5, 3)

    def __init__(self, size, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return ``x`` with Poisson noise as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        rate = self._RATES[self.severity - 1]
        pixels = np.array(x) / 255.0
        # Sample photon counts at the scaled intensity, then scale back.
        sampled = np.random.poisson(pixels * rate) / float(rate)
        noisy = np.clip(sampled, 0, 1) * 255
        return Image.fromarray(np.uint8(noisy))


class ImpulseNoise(torch.nn.Module):
    """Apply salt-and-pepper (impulse) noise to an image.

    Parameters
    ----------
    severity : int, optional
        Severity level of the noise; 0 disables the corruption. Default is 1.
    """

    # ``amount`` passed to skimage's "s&p" noise per severity 1..5.
    _AMOUNTS = (0.03, 0.06, 0.09, 0.17, 0.27)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return ``x`` with impulse noise as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        amount = self._AMOUNTS[self.severity - 1]
        pixels = np.array(x) / 255.0
        noisy = sk.util.random_noise(pixels, mode="s&p", amount=amount)
        scaled = np.clip(noisy, 0, 1) * 255
        return Image.fromarray(np.uint8(scaled))


class SpeckleNoise(torch.nn.Module):
    """Apply speckle (multiplicative Gaussian) noise to an image.

    Parameters
    ----------
    size : int
        Accepted for a uniform constructor signature; not used here.
    severity : int, optional
        Severity level of the noise; 0 disables the corruption. Default is 1.
    """

    # Standard deviation of the multiplicative noise per severity 1..5.
    _NOISE_SCALES = (0.15, 0.2, 0.35, 0.45, 0.6)

    def __init__(self, size, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return ``x`` with speckle noise as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        scale = self._NOISE_SCALES[self.severity - 1]
        pixels = np.array(x) / 255.0
        # The noise is proportional to the pixel value, so dark areas stay dark.
        speckle = pixels * np.random.normal(size=pixels.shape, scale=scale)
        noisy = np.clip(pixels + speckle, 0, 1) * 255
        return Image.fromarray(np.uint8(noisy))


class ZoomBlur(torch.nn.Module):
    """Apply zoom blur to an image by averaging progressively zoomed copies.

    Parameters
    ----------
    size : int
        Side length used for the centre crop of images that are not 32 wide.
    severity : int, optional
        Severity level of the blur; 0 disables the corruption. Default is 1.
    **kwargs
        Accepted and ignored.
    """

    # (stop, step) of the zoom-factor range ``np.arange(1, stop, step)`` per
    # severity 1..5.
    _ZOOM_RANGES = (
        (1.11, 0.01),
        (1.16, 0.01),
        (1.21, 0.02),
        (1.26, 0.02),
        (1.31, 0.03),
    )

    def __init__(self, size, severity=1, **kwargs):
        super().__init__()
        self.size = size
        self.severity = int(severity)

    def forward(self, x):
        """Return the zoom-blurred image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        stop, step = self._ZOOM_RANGES[self.severity - 1]
        zoom_factors = np.arange(1, stop, step)

        if x.size[0] != 32:
            # imagenet needs resize & center-crop before corruption
            x = v2.CenterCrop(self.size)(x.resize((256, 256)))

        base = (np.array(x) / 255.0).astype(np.float32)
        accumulated = np.zeros_like(base)
        for factor in zoom_factors:
            accumulated += clipped_zoom(base, factor)

        # Average the original together with every zoomed copy.
        blended = (base + accumulated) / (len(zoom_factors) + 1)
        blended = np.clip(blended, 0, 1) * 255
        return Image.fromarray(np.uint8(blended))


class Fog(torch.nn.Module):
    """Apply a fog effect to an image using a plasma-fractal haze layer.

    Parameters
    ----------
    size : int
        Stored on the instance; ``forward`` selects its working size from the
        input image width instead.
    severity : int, optional
        Severity level of the fog; 0 disables the corruption. Default is 1.
    """

    # (haze strength, plasma wibble decay) per severity 1..5.
    _SETTINGS = ((1.5, 2), (2.0, 2), (2.5, 1.7), (2.5, 1.5), (3.0, 1.4))

    def __init__(self, size, severity=1):
        super().__init__()
        self.size = size
        self.severity = int(severity)

    def forward(self, x):
        """Return the fogged image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        strength, decay = self._SETTINGS[self.severity - 1]

        if x.size[0] == 32:
            crop = mapsize = 32
        else:
            # Other inputs are resized to 256 and centre-cropped to 224; the
            # 256-wide fractal is then cut down to the cropped size.
            crop, mapsize = 224, 256
            x = v2.CenterCrop(224)(x.resize((256, 256)))

        pixels = np.array(x) / 255.0
        peak = pixels.max()
        fractal = plasma_fractal(mapsize=mapsize, wibbledecay=decay)
        haze = strength * fractal[:crop, :crop][..., np.newaxis]
        pixels = pixels + haze
        # Rescale so the brightest original value keeps its brightness.
        fogged = np.clip(pixels * peak / (peak + strength), 0, 1) * 255
        return Image.fromarray(np.uint8(fogged))


class Contrast(torch.nn.Module):
    """Reduce the contrast of an image.

    Parameters
    ----------
    severity : int, optional
        Severity level of the contrast adjustment; 0 disables the corruption.
        Default is 1.
    """

    # Factor applied to each pixel's deviation from the mean, per severity 1..5.
    _FACTORS = (0.4, 0.3, 0.2, 0.1, 0.05)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the low-contrast image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        factor = self._FACTORS[self.severity - 1]

        pixels = np.array(x) / 255.0
        # Mean over the first two axes, kept broadcastable against the image.
        mean = np.mean(pixels, axis=(0, 1), keepdims=True)
        squeezed = np.clip((pixels - mean) * factor + mean, 0, 1) * 255
        return Image.fromarray(np.uint8(squeezed))


class JPEGCompression(torch.nn.Module):
    """Degrade an image by re-encoding it as a low-quality JPEG.

    Parameters
    ----------
    severity : int, optional
        Severity level of the compression; 0 disables the corruption.
        Default is 1.
    """

    # JPEG ``quality`` setting per severity 1..5.
    _QUALITIES = (25, 18, 15, 10, 7)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return ``x`` after a JPEG round-trip through an in-memory buffer.

        ``x`` must provide ``save`` (a PIL image); it is returned unchanged
        when the severity is 0.
        """
        if self.severity == 0:
            return x
        quality = self._QUALITIES[self.severity - 1]

        buffer = BytesIO()
        x.save(buffer, "JPEG", quality=quality)
        return Image.open(buffer)


class Pixelate(torch.nn.Module):
    """Pixelate an image by shrinking it and scaling it back up.

    Parameters
    ----------
    size : int, optional
        Stored on the instance; ``forward`` chooses the output side from the
        input width (32 for 32-wide inputs, otherwise 224). Default is 32.
    severity : int, optional
        Severity level of the pixelation; 0 disables the corruption.
        Default is 1.
    """

    # Fraction of the side length kept when shrinking, per severity 1..5.
    _SHRINK_FACTORS = (0.6, 0.5, 0.4, 0.3, 0.25)

    def __init__(self, size=32, severity=1):
        super().__init__()
        self.severity = int(severity)
        self.size = size

    def forward(self, x):
        """Return the pixelated PIL image (``x`` itself at severity 0)."""
        if self.severity == 0:
            return x
        factor = self._SHRINK_FACTORS[self.severity - 1]
        side = 32 if x.size[0] == 32 else 224  # cifar or imagenet
        reduced = int(side * factor)
        shrunk = x.resize((reduced, reduced), Image.BOX)
        return shrunk.resize((side, side), Image.BOX)


class GaussianBlur(torch.nn.Module):
    """Apply Gaussian blur to an image.

    Parameters
    ----------
    severity : int, optional
        Severity level of the blur; 0 disables the corruption. Default is 1.
    """

    # Gaussian sigma per severity 1..5.
    _SIGMAS = (1, 2, 3, 4, 6)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the blurred image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x

        sigma = self._SIGMAS[self.severity - 1]

        pixels = np.array(x) / 255.0
        blurred = gaussian(pixels, sigma=sigma)
        scaled = np.clip(blurred, 0, 1) * 255
        return Image.fromarray(np.uint8(scaled))


class Spatter(torch.nn.Module):
    """Overlay spatter on an image: a water-like layer or mud-like blotches.

    Parameters
    ----------
    severity : int, optional
        Severity level of the spatter; 0 disables the corruption. Default is 1.
    """

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the spattered image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        # Per severity: (noise mean, noise std, blur sigma, threshold,
        # c[4], branch flag). c[4] scales the overlay when the flag is 0 and
        # is the mask blur sigma otherwise.
        c = [
            (0.65, 0.3, 4, 0.69, 0.6, 0),
            (0.65, 0.3, 3, 0.68, 0.6, 0),
            (0.65, 0.3, 2, 0.68, 0.5, 0),
            (0.65, 0.3, 1, 0.65, 1.5, 1),
            (0.67, 0.4, 1, 0.65, 1.5, 1),
        ][self.severity - 1]
        x = np.array(x, dtype=np.float32) / 255.0

        # Random field over the first two image axes, smoothed and thresholded.
        liquid_layer = np.random.normal(size=x.shape[:2], loc=c[0], scale=c[1])

        liquid_layer = gaussian(liquid_layer, sigma=c[2])
        liquid_layer[liquid_layer < c[3]] = 0
        if c[5] == 0:
            liquid_layer = (liquid_layer * 255).astype(np.uint8)
            # Distance from detected edges, truncated at 20, then blurred,
            # equalised and passed through a directional 3x3 kernel.
            dist = 255 - cv2.Canny(liquid_layer, 50, 150)
            dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
            _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
            dist = cv2.blur(dist, (3, 3)).astype(np.uint8)
            dist = cv2.equalizeHist(dist)
            ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
            dist = cv2.filter2D(dist, cv2.CV_8U, ker)
            dist = cv2.blur(dist, (3, 3)).astype(np.float32)

            # Four-channel mask, normalised per channel and scaled by c[4].
            m = cv2.cvtColor(liquid_layer * dist, cv2.COLOR_GRAY2BGRA)
            m /= np.max(m, axis=(0, 1))
            m *= c[4]

            # water is pale turquoise
            color = np.concatenate(
                (
                    175 / 255.0 * np.ones_like(m[..., :1]),
                    238 / 255.0 * np.ones_like(m[..., :1]),
                    238 / 255.0 * np.ones_like(m[..., :1]),
                ),
                axis=2,
            )

            # Bring colour and image to four channels so they combine with m.
            color = cv2.cvtColor(color, cv2.COLOR_BGR2BGRA)
            x = cv2.cvtColor(x, cv2.COLOR_BGR2BGRA)

            x = cv2.cvtColor(np.clip(x + m * color, 0, 1), cv2.COLOR_BGRA2BGR) * 255
            x = Image.fromarray(np.uint8(x))
            return x
        else:
            # Binary mask of the surviving blobs, blurred and re-thresholded.
            m = np.where(liquid_layer > c[3], 1, 0)
            m = gaussian(m.astype(np.float32), sigma=c[4])
            m[m < 0.8] = 0

            # mud brown
            color = np.concatenate(
                (
                    63 / 255.0 * np.ones_like(x[..., :1]),
                    42 / 255.0 * np.ones_like(x[..., :1]),
                    20 / 255.0 * np.ones_like(x[..., :1]),
                ),
                axis=2,
            )

            # Blend: colour where the mask is set, original image elsewhere.
            color *= m[..., np.newaxis]
            x *= 1 - m[..., np.newaxis]

            x = np.clip(x + color, 0, 1) * 255
            x = Image.fromarray(np.uint8(x))
            return x


class Brightness(torch.nn.Module):
    """Brighten an image by raising its HSV value channel.

    Parameters
    ----------
    severity : int, optional
        Severity level of the brightening; 0 disables the corruption.
        Default is 1.
    """

    # Amount added to the HSV value channel per severity 1..5.
    _OFFSETS = (0.1, 0.2, 0.3, 0.4, 0.5)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the brightened image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        offset = self._OFFSETS[self.severity - 1]

        hsv = sk.color.rgb2hsv(np.array(x) / 255.0)
        hsv[:, :, 2] = np.clip(hsv[:, :, 2] + offset, 0, 1)
        rgb = sk.color.hsv2rgb(hsv)
        scaled = np.clip(rgb, 0, 1) * 255
        return Image.fromarray(np.uint8(scaled))


class Saturate(torch.nn.Module):
    """Change the colour saturation of an image in HSV space.

    Parameters
    ----------
    severity : int, optional
        Severity level of the change; 0 disables the corruption. Default is 1.
    """

    # (multiplier, offset) applied to the HSV saturation channel, per
    # severity 1..5.
    _SETTINGS = ((0.3, 0), (0.1, 0), (2, 0), (5, 0.1), (20, 0.2))

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the re-saturated image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        gain, offset = self._SETTINGS[self.severity - 1]

        hsv = sk.color.rgb2hsv(np.array(x) / 255.0)
        hsv[:, :, 1] = np.clip(hsv[:, :, 1] * gain + offset, 0, 1)
        rgb = sk.color.hsv2rgb(hsv)
        scaled = np.clip(rgb, 0, 1) * 255
        return Image.fromarray(np.uint8(scaled))


class Elastic(torch.nn.Module):
    """Apply an elastic deformation: a random affine warp followed by a
    smooth random displacement field.

    Parameters
    ----------
    severity : int, optional
        Severity level of the deformation; 0 disables the corruption.
        Default is 1.
    size : int, optional
        Reference side length that scales every deformation parameter.
        Default is 224.
    """

    # Multiples of ``size`` per severity 1..5:
    # (displacement scale, smoothing sigma, affine jitter range).
    _SIZE_FACTORS = (
        (2, 0.7, 0.1),
        (2, 0.08, 0.2),
        (0.05, 0.01, 0.02),
        (0.07, 0.01, 0.02),
        (0.12, 0.01, 0.02),
    )

    def __init__(self, severity=1, size=224):
        super().__init__()
        self.severity = int(severity)
        self.size = size

    def forward(self, x):
        """Return the deformed image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        alpha, sigma, jitter_range = (
            self.size * factor for factor in self._SIZE_FACTORS[self.severity - 1]
        )

        image = np.array(x, dtype=np.float32) / 255.0
        shape = image.shape
        plane = shape[:2]

        # random affine: jitter three corners of a centred square
        centre = np.float32(plane) // 2
        half_side = min(plane) // 3
        src = np.float32(
            [
                centre + half_side,
                [centre[0] + half_side, centre[1] - half_side],
                centre - half_side,
            ]
        )
        jitter = np.random.uniform(-jitter_range, jitter_range, size=src.shape)
        dst = src + jitter.astype(np.float32)
        warp = cv2.getAffineTransform(src, dst)
        image = cv2.warpAffine(
            image, warp, plane[::-1], borderMode=cv2.BORDER_REFLECT_101
        )

        def smooth_displacement():
            """Draw uniform noise, smooth it and scale it by ``alpha``."""
            noise = np.random.uniform(-1, 1, size=shape[:2])
            field = gaussian(noise, sigma, mode="reflect", truncate=3) * alpha
            return field.astype(np.float32)[..., np.newaxis]

        # Drawn in this order (dx first, then dy) to keep the random stream.
        dx = smooth_displacement()
        dy = smooth_displacement()

        cols, rows, chans = np.meshgrid(
            np.arange(shape[1]), np.arange(shape[0]), np.arange(shape[2])
        )
        coords = (
            np.reshape(rows + dy, (-1, 1)),
            np.reshape(cols + dx, (-1, 1)),
            np.reshape(chans, (-1, 1)),
        )
        resampled = map_coordinates(image, coords, order=1, mode="reflect")
        warped = np.clip(resampled.reshape(shape), 0, 1) * 255
        return Image.fromarray(np.uint8(warped))


class DefocusBlur(torch.nn.Module):
    """Apply defocus blur by filtering each channel with a disk kernel.

    Parameters
    ----------
    severity : int, optional
        Severity level of the blur; 0 disables the corruption. Default is 1.
    """

    # (disk radius, anti-alias blur sigma) per severity 1..5.
    _SETTINGS = ((3, 0.1), (4, 0.5), (6, 0.5), (8, 0.5), (10, 0.5))

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the defocused image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        radius, alias_blur = self._SETTINGS[self.severity - 1]

        pixels = np.array(x) / 255.0
        kernel = disk(radius=radius, alias_blur=alias_blur)

        # Filter the first three channels one by one and restack them last.
        filtered = [cv2.filter2D(pixels[:, :, ch], -1, kernel) for ch in range(3)]
        blurred = np.stack(filtered, axis=-1)

        scaled = np.clip(blurred, 0, 1) * 255
        return Image.fromarray(np.uint8(scaled))


class Frost(torch.nn.Module):
    """Blend a randomly cropped frost texture into an image.

    Parameters
    ----------
    size : int
        Side length of the square cropped from the frost texture; it must
        match the input image for the blend to broadcast.
    severity : int, optional
        Severity level of the frost; 0 disables the corruption. Default is 1.
    """

    # (image weight, frost weight) per severity 1..5.
    _WEIGHTS = ((1, 0.4), (0.8, 0.6), (0.7, 0.7), (0.65, 0.7), (0.6, 0.75))

    # Texture files, resolved relative to this module's package.
    _FROST_FILES = (
        "frost/frost1.png",
        "frost/frost2.png",
        "frost/frost3.png",
        "frost/frost4.jpg",
        "frost/frost5.jpg",
        "frost/frost6.jpg",
    )

    def __init__(self, size, severity=1):
        super().__init__()
        self.size = size
        self.severity = int(severity)

    def forward(self, x):
        """Return the frosted image as a PIL image (``x`` itself at 0).

        Raises
        ------
        FileNotFoundError
            If the selected frost texture cannot be read.
        """
        if self.severity == 0:
            return x
        image_weight, frost_weight = self._WEIGHTS[self.severity - 1]

        # NOTE(review): randint(5) yields 0..4, so frost6.jpg is never chosen.
        # Kept unchanged so existing seeded outputs stay the same — confirm
        # whether excluding the sixth texture is intended.
        idx = np.random.randint(5)
        filename = resource_filename(__name__, self._FROST_FILES[idx])
        frost = cv2.imread(filename)
        if frost is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read frost texture: {filename}")

        # randomly crop and convert to rgb
        x_start = np.random.randint(0, frost.shape[0] - self.size)
        y_start = np.random.randint(0, frost.shape[1] - self.size)
        frost = frost[x_start : x_start + self.size, y_start : y_start + self.size][
            ..., [2, 1, 0]
        ]
        blended = np.clip(image_weight * np.array(x) + frost_weight * frost, 0, 255)
        return Image.fromarray(np.uint8(blended))


class GaussianNoise(torch.nn.Module):
    """Apply additive Gaussian noise to an image.

    Parameters
    ----------
    severity : int, optional
        Severity level of the noise; 0 disables the corruption. Default is 1.
    """

    # Standard deviation of the noise (0-1 intensity scale) per severity 1..5.
    _STD_BY_SEVERITY = (0.08, 0.12, 0.18, 0.26, 0.38)

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return ``x`` with Gaussian noise as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        std = self._STD_BY_SEVERITY[self.severity - 1]

        clean = np.array(x) / 255.0
        perturbed = clean + np.random.normal(size=clean.shape, scale=std)
        return Image.fromarray(np.uint8(np.clip(perturbed, 0, 1) * 255))


class GlassBlur(torch.nn.Module):
    """Apply glass blur: Gaussian blur, local pixel shuffling, then a second
    Gaussian blur.

    Parameters
    ----------
    size : int
        Side length used as the upper bound of the shuffling loops.
    severity : int, optional
        Severity level of the blur; 0 disables the corruption. Default is 1.
    """

    def __init__(self, size, severity=1):
        super().__init__()
        self.size = size
        self.severity = int(severity)

    def forward(self, x):
        """Return the glass-blurred image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        # sigma, max_delta, iterations
        c = [(0.7, 1, 2), (0.9, 2, 1), (1, 2, 3), (1.1, 3, 2), (1.5, 4, 2)][
            self.severity - 1
        ]

        # First blur, converted back to uint8 so pixels can be moved in place.
        x = np.uint8(gaussian(np.array(x) / 255.0, sigma=c[0]) * 255)

        # locally shuffle pixels
        # The loops run from size - max_delta down to max_delta + 1, which
        # keeps every offset target inside the image when it is size x size.
        for i in range(c[2]):
            for h in range(self.size - c[1], c[1], -1):
                for w in range(self.size - c[1], c[1], -1):
                    # randint excludes the upper bound: offsets lie in
                    # [-max_delta, max_delta - 1].
                    dx, dy = np.random.randint(-c[1], c[1], size=(2,))
                    h_prime, w_prime = h + dy, w + dx
                    # swap
                    # NOTE(review): if x is 3-D (H, W, C), x[h, w] and
                    # x[h_prime, w_prime] are NumPy views, so this tuple
                    # assignment may copy one pixel over the other instead of
                    # exchanging them — confirm whether that is intended.
                    x[h, w], x[h_prime, w_prime] = x[h_prime, w_prime], x[h, w]

        # Second blur smooths the shuffled pixels.
        x = np.clip(gaussian(x / 255.0, sigma=c[0]), 0, 1) * 255
        x = Image.fromarray(np.uint8(x))
        return x


# /////////////// Corruptions using Wand ///////////////

from wand.api import library as wandlibrary
from wand.image import Image as WandImage


class MotionImage(WandImage):
    """Wand image extended with a motion-blur operation."""

    def motion_blur(self, radius=0.0, sigma=0.0, angle=0.0):
        """Blur this image in place via ``MagickMotionBlurImage``.

        ``radius``, ``sigma`` and ``angle`` are passed unchanged to the
        MagickWand library call together with this image's wand handle.
        """
        wandlibrary.MagickMotionBlurImage(self.wand, radius, sigma, angle)


class Snow(torch.nn.Module):
    """Apply a snow effect to an image.

    A thresholded random layer is zoomed and motion-blurred into streaks,
    then added (together with a copy rotated by 180 degrees) onto a whitened
    version of the image.

    Parameters
    ----------
    size : int
        Stored on the instance; the working size is taken from the input.
    severity : int, optional
        Severity level of the noise; 0 disables the corruption. Default is 1.
    """

    # Per severity 1..5: (noise mean, noise std, zoom factor, threshold,
    # blur radius, blur sigma, weight of the original image in the whitening).
    _SETTINGS = (
        (0.1, 0.3, 3, 0.5, 10, 4, 0.8),
        (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
        (0.55, 0.3, 4, 0.9, 12, 8, 0.7),
        (0.55, 0.3, 4.5, 0.85, 12, 8, 0.65),
        (0.55, 0.3, 2.5, 0.85, 12, 12, 0.55),
    )

    def __init__(self, size, severity=1):
        super().__init__()
        self.size = size
        self.severity = int(severity)

    def forward(self, x):
        """Return the snowy image as a PIL image (``x`` itself at 0)."""
        if self.severity == 0:
            return x
        c = self._SETTINGS[self.severity - 1]

        x = np.array(x, dtype=np.float32) / 255.0
        snow_layer = np.random.normal(
            size=x.shape[:2], loc=c[0], scale=c[1]
        )  # [:2] for monochrome

        snow_layer = clipped_zoom(snow_layer[..., np.newaxis], c[2])
        snow_layer[snow_layer < c[3]] = 0

        snow_layer = Image.fromarray(
            (np.clip(snow_layer.squeeze(), 0, 1) * 255).astype(np.uint8), mode="L"
        )
        output = BytesIO()
        snow_layer.save(output, format="PNG")

        # The context manager releases the Wand image once the blob is read.
        with MotionImage(blob=output.getvalue()) as streaks:
            streaks.motion_blur(
                radius=c[4], sigma=c[5], angle=np.random.uniform(-135, -45)
            )
            blob = streaks.make_blob()

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        snow_layer = (
            cv2.imdecode(np.frombuffer(blob, np.uint8), cv2.IMREAD_UNCHANGED) / 255.0
        )
        snow_layer = snow_layer[..., np.newaxis]

        # Whiten the image; the grey channel axis is added from the image's
        # own shape rather than from ``self.size``.
        gray = cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)[..., np.newaxis]
        x = c[6] * x + (1 - c[6]) * np.maximum(x, gray * 1.5 + 0.5)
        x = np.clip(x + snow_layer + np.rot90(snow_layer, k=2), 0, 1) * 255
        return Image.fromarray(np.uint8(x))


class MotionBlur(torch.nn.Module):
    """Apply motion blur at a random angle to an image.

    Parameters
    ----------
    severity : int, optional
        Severity level of the blur; 0 disables the corruption. Default is 1.
    """

    # (blur radius, blur sigma) per severity 1..5.
    _SETTINGS = ((10, 3), (15, 5), (15, 8), (15, 12), (20, 15))

    def __init__(self, severity=1):
        super().__init__()
        self.severity = int(severity)

    def forward(self, x):
        """Return the motion-blurred image as a PIL image (``x`` itself at 0).

        ``x`` must provide ``save`` (a PIL image); the blur angle is drawn
        uniformly from [-45, 45) degrees.
        """
        if self.severity == 0:
            return x
        radius, sigma = self._SETTINGS[self.severity - 1]

        output = BytesIO()
        x.save(output, format="PNG")

        # The context manager releases the Wand image once the blob is read.
        with MotionImage(blob=output.getvalue()) as blurred:
            blurred.motion_blur(
                radius=radius, sigma=sigma, angle=np.random.uniform(-45, 45)
            )
            blob = blurred.make_blob()

        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        x = cv2.imdecode(np.frombuffer(blob, np.uint8), cv2.IMREAD_UNCHANGED)

        if x.ndim == 3:
            x = np.clip(x[..., [2, 1, 0]], 0, 255)  # BGR to RGB
        else:  # greyscale to RGB, whatever the image size
            x = np.clip(np.array([x, x, x]).transpose((1, 2, 0)), 0, 255)

        return Image.fromarray(np.uint8(x))
