import copy
import numbers
import os
import random
import subprocess
import time
from operator import itemgetter

import cv2
from PIL import Image
import numpy as np
import torch
from skimage.transform import resize
from skimage.util.shape import view_as_windows
from torchvision import transforms
from torch import nn

"""
Originally based on the data augmentations from RAD
(https://github.com/pokaxpoka/rad_procgen)
but modified to run with torch tensors and on the gpu.

This was done mostly for reasons that did not end up in
the final paper. After we noticed the benefits of CJ and NR
in DMCR, we experimented with Style Transfer as a way to
add texture. However, the heavy computational burden
kept it out of the final results.
"""


class AugmentationSequence:
    """Run a fixed chain of augmentations over one or more batches.

    Each call first re-samples the random parameters of every augmentation
    exactly once. It then pushes a clone of every batch through the chain.
    Corresponding samples in different batches therefore receive the same
    random transform, and the caller's tensors are never modified.

    Returns a tuple when several batches are given, or the single augmented
    batch otherwise.
    """

    def __init__(self, aug_list):
        self.aug_list = aug_list

    def __call__(self, *batches):
        for augmentation in self.aug_list:
            augmentation.change_randomization_params()

        def _apply_chain(source):
            current = source.clone()
            with torch.no_grad():
                for augmentation in self.aug_list:
                    current = augmentation(current)
            return current

        augmented = tuple(_apply_chain(source) for source in batches)
        if len(augmented) > 1:
            return augmented
        return augmented[0]


class _StyleTransfer:
    def get_gpu_memory_map(self):
        result = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
            encoding="utf-8",
        )
        # Convert lines into a dictionary
        gpu_memory = [int(x) for x in result.strip().split("\n")]
        gpu_memory_map = zip(range(len(gpu_memory)), gpu_memory)
        return list(gpu_memory_map)

    def resize_imgs(self, imgs, size):
        original_device = imgs.device
        result = np.empty((imgs.shape[0],) + size + (imgs.shape[-1],), dtype=np.uint8)
        for i in range(imgs.shape[0]):
            result[i] = cv2.resize(imgs[i].cpu().numpy(), size)
        return torch.from_numpy(result).to(original_device)

    def denorm(self, x):
        _min = x.min(-1).values.min(-1).values.min(-1).values
        _max = x.max(-1).values.max(-1).values.max(-1).values
        lower_bound = torch.ones_like(x) * _min.view(-1, 1, 1, 1)
        upper_bound = torch.ones_like(x) * _max.view(-1, 1, 1, 1)
        x = torch.max(torch.min(x, upper_bound), lower_bound)
        x.add_(-lower_bound).div_(upper_bound - lower_bound + 1e-5)
        return x

    def stylize(self, x):
        raise NotImplementedError()


class AINStyleTransfer(_StyleTransfer):
    """Style transfer backed by the ``ain`` package's encoder/decoder.

    Models are loaded from ``ain/models`` next to this file. They are placed
    on the GPU that nvidia-smi reports as having the least memory in use.

    Args:
        alpha: blending weight forwarded to ``ain.model.style_transfer``.
        resize: if True, inputs are resized to 256x256 before stylization and
            the result is resized to 64x64 afterwards.
    """

    def __init__(self, alpha=1.0, resize=False):
        from ain import model as ain_model

        # Keep a handle on the module. ``stylize`` needs ``style_transfer``,
        # and the import above is only bound inside this method's scope.
        self._ain_model = ain_model

        self.alpha = alpha
        self.resize = resize

        model_dir = os.path.dirname(__file__)
        self.encoder = ain_model.encoder_load(
            os.path.join(model_dir, "ain/models/encoder.pth")
        )
        self.decoder = ain_model.decoder_load(
            os.path.join(model_dir, "ain/models/decoder.pth")
        )

        # Pick the GPU with the least memory currently in use.
        self.gpu = min(self.get_gpu_memory_map(), key=itemgetter(1))[0]

        self.encoder.eval()
        self.encoder.cuda(self.gpu)

        self.decoder.eval()
        self.decoder.cuda(self.gpu)

    def stylize(self, content, style):
        """Render ``content`` in the style of ``style``.

        Both inputs are permuted with ``(0, 3, 1, 2)`` before inference, so
        they are expected channels-last (N, H, W, C) with values in [0, 255].
        Returns a uint8 channels-last tensor on ``content``'s original device.
        """
        original_device = content.device

        if self.resize:
            content = self.resize_imgs(content, (256, 256))
            style = self.resize_imgs(style, (256, 256))

        content = content.float() / 255.0
        content = content.cuda(self.gpu).permute(0, 3, 1, 2)

        style = style.float() / 255.0
        style = style.cuda(self.gpu).permute(0, 3, 1, 2)

        with torch.no_grad():
            result = self._ain_model.style_transfer(
                self.encoder, self.decoder, content, style, self.alpha
            )
        result = self.denorm(result)
        result = (result.permute(0, 2, 3, 1) * 255.0).type(torch.uint8)

        if self.resize:
            result = self.resize_imgs(result, (64, 64))

        return result.to(original_device)


class LinearStyleTransfer(_StyleTransfer):
    """Style transfer using the linear-transform models from the ``lst`` package.

    Weights are loaded from ``lst/models`` next to this file. ``stylize``
    feeds its inputs straight into the VGG encoder, so content and style are
    assumed channels-first, (N, 3, H, W), with values in [0, 255]. (This is
    inferred from the ``/ 255.0`` scaling; confirm against callers.)

    Args:
        layer: encoder/decoder depth, ``"r31"`` or ``"r41"``.
        resize: if True, inputs are resized to 256x256 before stylization and
            the result is resized back to the content's original (H, W).

    Raises:
        ValueError: for an unsupported ``layer``.
    """

    def __init__(self, layer="r31", resize=False):
        from lst.libs.Matrix import MulLayer
        from lst.libs.models import decoder3, decoder4, encoder3, encoder4

        self.layer = layer
        self.resize = resize

        if layer == "r31":
            self.vgg = encoder3()
            self.dec = decoder3()
        elif layer == "r41":
            self.vgg = encoder4()
            self.dec = decoder4()
        else:
            raise ValueError(f"Uncrecognized layer: {layer}")
        self.matrix = MulLayer(layer)

        # Decide the device first so checkpoints saved on a GPU can still be
        # loaded on a CPU-only machine.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        def _load(relative_path):
            return torch.load(
                os.path.join(os.path.dirname(__file__), relative_path),
                map_location=self.device,
            )

        self.vgg.load_state_dict(_load(f"lst/models/vgg_{layer}.pth"))
        self.dec.load_state_dict(_load(f"lst/models/dec_{layer}.pth"))
        self.matrix.load_state_dict(
            _load(f"lst/models/custom64x64_{layer}/{layer}.pth")
        )

        # ``.to`` (unlike ``.cuda``) also accepts "cpu".
        self.vgg = self.vgg.to(self.device)
        self.dec = self.dec.to(self.device)
        self.matrix = self.matrix.to(self.device)

    @staticmethod
    def _resize(imgs, size):
        """Bilinearly resize a channels-first float batch to ``size`` = (H, W)."""
        return nn.functional.interpolate(
            imgs, size=size, mode="bilinear", align_corners=False
        )

    def stylize(self, content, style):
        """Render ``content`` in the style of ``style``.

        Returns a float tensor with values in [0, 255] on ``content``'s
        original device, with the same spatial size as ``content``.
        """
        original_device = content.device
        original_size = tuple(content.shape[-2:])

        content = content.to(self.device).float() / 255.0
        style = style.to(self.device).float() / 255.0

        if self.resize:
            content = self._resize(content, (256, 256))
            style = self._resize(style, (256, 256))

        with torch.no_grad():
            sF = self.vgg(style)
            cF = self.vgg(content)
            # The r41 encoder output is indexed by layer name; r31 is passed
            # through as-is.
            if self.layer == "r41":
                feature, _ = self.matrix(cF[self.layer], sF[self.layer])
            else:
                feature, _ = self.matrix(cF, sF)
            transfer = torch.clamp(self.dec(feature), 0, 1)
            transfer = self.denorm(transfer)
            if self.resize:
                transfer = self._resize(transfer, original_size)
        result = transfer * 255.0
        return result.to(original_device)


class LightLinearStyleTransfer(_StyleTransfer):
    """Style transfer using the lightweight ``lst_light`` encoder/decoder/matrix.

    NOTE(review): the matrix weights are loaded from a path relative to the
    current working directory ("./lst_light/weights/matrix.pth"), unlike the
    other backends, which resolve paths relative to this file.
    """

    def __init__(self):
        from lst_light.style import Encoder, Decoder, Matrix, load_data

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.matrix = Matrix()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        self.matrix.to(self.device)

        self.encoder, self.decoder = load_data(self.encoder, self.decoder)
        self.matrix.load_state_dict(
            torch.load("./lst_light/weights/matrix.pth", map_location=self.device)
        )

    def stylize(self, content, style):
        """Render ``content`` in the style of ``style``.

        Inputs are scaled by 1/255, so they are presumably in [0, 255]. They
        are not modified. Callers such as ``IntraBatchStyleTransferAug`` pass
        the same tensors twice, with the roles swapped. Returns a float tensor
        in [0, 255] on ``content``'s original device.
        """
        original_device = content.device
        # Out-of-place scaling: in-place division would corrupt the caller's
        # tensors and break integer inputs.
        content = content.to(self.device).float() / 255.0
        style = style.to(self.device).float() / 255.0
        with torch.no_grad():
            # The encoder returns several feature maps; only the last is used.
            *_, sF = self.encoder(style)
            *_, cF = self.encoder(content)
            feature, _ = self.matrix(cF, sF)
            transfer = self.decoder(feature)
        result = torch.clamp(transfer, 0, 1) * 255.0
        return result.to(original_device)


class IntraBatchStyleTransferAug:
    """Swap styles between two random halves of the batch.

    ``change_randomization_params`` splits the batch indices into
    ``group1`` (``batch_size // 2`` indices) and ``group2`` (the rest). Each
    group1 image is rendered in the style of its group2 partner, and vice
    versa. With an odd batch size, the one unpaired group2 image is left
    unchanged. ``imgs`` is modified in place and also returned.

    Args:
        batch_size: number of samples in each batch passed to ``__call__``.
        style_transfer_net: object exposing ``stylize(content, style)``.
            Defaults to a freshly built ``LinearStyleTransfer``.
        st_batch_size: maximum number of pairs stylized per ``stylize`` call.
    """

    def __init__(
        self,
        batch_size,
        style_transfer_net=None,
        st_batch_size=32,
        *_args,
        **_kwargs,
    ):
        # Build the default lazily. A ``LinearStyleTransfer()`` default argument
        # would load model weights at import time for every importer.
        if style_transfer_net is None:
            style_transfer_net = LinearStyleTransfer()
        self._st = style_transfer_net
        self.st_batch_size = st_batch_size
        self.batch_size = batch_size
        self.change_randomization_params()

    def change_randomization_params(self):
        idxs = set(range(self.batch_size))
        group1 = set(random.sample(list(range(self.batch_size)), self.batch_size // 2))
        group2 = idxs - group1
        self.group1 = list(group1)
        self.group2 = list(group2)

    def __call__(self, imgs):
        original_device = imgs.device

        # Only complete pairs can exchange styles; both halves of every
        # ``stylize`` call must have the same length.
        n_pairs = min(len(self.group1), len(self.group2))
        if n_pairs == 0:
            return imgs.to(original_device)
        idx1 = self.group1[:n_pairs]
        idx2 = self.group2[:n_pairs]

        # Advanced indexing copies, so the originals stay intact while both
        # directions are stylized.
        group1 = imgs[idx1]
        group2 = imgs[idx2]

        group1_out = []
        group2_out = []
        for start in range(0, n_pairs, self.st_batch_size):
            stop = start + self.st_batch_size
            group1_batch = group1[start:stop]
            group2_batch = group2[start:stop]
            # group1 in the style of group2
            group1_out.append(
                self._st.stylize(group1_batch, group2_batch).to(original_device)
            )
            # group2 in the style of group1
            group2_out.append(
                self._st.stylize(group2_batch, group1_batch).to(original_device)
            )
        imgs[idx1] = torch.cat(group1_out, dim=0).to(imgs.dtype)
        imgs[idx2] = torch.cat(group2_out, dim=0).to(imgs.dtype)

        return imgs.to(original_device)


class RandomStyleTransferAug:
    """Stylize each sample with a randomly chosen image from a style folder.

    At construction, ``num_styles`` images are loaded from ``style_path``.
    Each is resized to 300x300, randomly cropped to 256 and stored as uint8
    (C, H, W). Each ``change_randomization_params`` call picks
    ``batch_size`` distinct styles. ``__call__`` stylizes every 3-channel
    frame of each sample with that sample's style.

    Args:
        batch_size: number of samples per batch (must not exceed
            ``num_styles``).
        style_transfer_net: object exposing ``stylize(content, style)``.
            Defaults to a freshly built ``LinearStyleTransfer``.
        st_batch_size: maximum number of samples per ``stylize`` call.
        style_path: directory of style images; ``~`` is expanded.
        num_styles: number of style images to pre-load.
    """

    def __init__(
        self,
        batch_size,
        style_transfer_net=None,
        st_batch_size=256,
        style_path="~/wikiart",
        num_styles=1000,
        *_args,
        **_kwargs,
    ):
        self.batch_size = batch_size
        self.st_batch_size = st_batch_size
        # Build the default lazily. A ``LinearStyleTransfer()`` default argument
        # would load model weights at import time for every importer.
        if style_transfer_net is None:
            style_transfer_net = LinearStyleTransfer()
        self._st = style_transfer_net

        def sample_int(n):
            # Endless stream of indices: one random permutation after another.
            i = n - 1
            order = np.random.permutation(n)
            while True:
                yield order[i]
                i += 1
                if i >= n:
                    np.random.seed()
                    order = np.random.permutation(n)
                    i = 0

        class LoopSampler(torch.utils.data.sampler.Sampler):
            def __init__(self, data_source):
                self.num_samples = len(data_source)

            def __iter__(self):
                return iter(sample_int(self.num_samples))

            def __len__(self):
                return 2 ** 31

        class FolderDataset(torch.utils.data.Dataset):
            def __init__(self, root, transform):
                super(FolderDataset, self).__init__()
                self.root = root
                self.paths = os.listdir(self.root)
                self.transform = transform
                self.default = None

            def __getitem__(self, index):
                path = self.paths[index]
                try:
                    img = Image.open(os.path.join(self.root, path)).convert("RGB")
                except Exception:
                    # Best effort: an unreadable file falls back to the last
                    # image that loaded successfully (None if none has yet).
                    img = self.default
                else:
                    img = self.transform(img)
                    self.default = img
                return img

            def __len__(self):
                return len(self.paths)

            def name(self):
                return "FolderDataset"

        T = transforms.Compose(
            [
                transforms.Resize(size=(300, 300)),
                transforms.RandomCrop(256),
                transforms.ToTensor(),
            ]
        )
        # os.listdir does not expand "~", so resolve it here.
        s_dataset = FolderDataset(os.path.expanduser(style_path), T)

        # NOTE(review): the sampler/dataset classes are defined locally, which
        # only works with DataLoader workers started via fork.
        self.all_styles = (
            next(
                iter(
                    torch.utils.data.DataLoader(
                        s_dataset,
                        batch_size=num_styles,
                        sampler=LoopSampler(s_dataset),
                        num_workers=5,
                    )
                )
            )
            * 255.0
        ).type(torch.uint8)

        self.change_randomization_params()

    def change_randomization_params(self):
        self.style_idxs = list(
            set(random.sample(list(range(len(self.all_styles))), self.batch_size))
        )

    def __call__(self, imgs):
        original_device = imgs.device
        _, c, _, _ = imgs.shape

        # Gather the chosen styles once, not once per chunk.
        styles = self.all_styles[self.style_idxs]

        out = []
        for start in range(0, self.batch_size, self.st_batch_size):
            stop = start + self.st_batch_size
            style_batch = styles[start:stop]
            img_chunk = imgs[start:stop]
            # Size the buffer from this chunk, so a short final chunk is not
            # padded with zero rows.
            batch_out = torch.zeros_like(img_chunk)
            for j in range(0, (c // 3) * 3, 3):
                batch_out[:, j : j + 3] = self._st.stylize(
                    img_chunk[:, j : j + 3], style_batch
                )
            out.append(batch_out)

        return torch.cat(out, dim=0).to(original_device)


class GrayscaleAug:
    """Convert a random subset of the batch to grayscale.

    Each sample is selected independently with probability ``p_rand``. Every
    3-channel frame of a selected sample is replaced by its luma, computed
    with weights 0.2989 / 0.587 / 0.114, and written to all three channels.
    ``images`` is modified in place and also returned.
    """

    def __init__(self, batch_size, p_rand=0.5, *_args, **_kwargs):
        self.p_gray = p_rand
        self.batch_size = batch_size
        self.change_randomization_params()

    def grayscale(self, imgs):
        """Replace each 3-channel frame of ``imgs`` (N, C, H, W) by its luma, in place."""
        _, c, _, _ = imgs.shape
        for idx in range(0, (c // 3) * 3, 3):
            luma = (
                imgs[:, idx] * 0.2989
                + imgs[:, idx + 1] * 0.587
                + imgs[:, idx + 2] * 0.114
            )
            # (N, H, W) -> (N, 1, H, W) so each sample's luma is broadcast
            # over its own three channels, not across the batch.
            imgs[:, idx : idx + 3] = luma.unsqueeze(1)
        return imgs

    def __call__(self, images):
        if self.random_inds.sum() > 0:
            mask = torch.as_tensor(self.random_inds, device=images.device)
            images[mask] = self.grayscale(images[mask])
        return images

    def change_randomization_params(self):
        self.random_inds = np.random.choice(
            [True, False], self.batch_size, p=[self.p_gray, 1 - self.p_gray]
        )


class CutoutAug:
    """Zero out one rectangular box per sample.

    Box sides ``w1[i]`` and ``h1[i]`` are drawn from ``[box_min, box_max)``.
    The box starts at ``(pivot_h + h1[i], pivot_w + w1[i])`` and spans
    ``h1[i]`` rows and ``w1[i]`` columns in every channel, following RAD's
    cutout. Boxes running past the image edge are truncated. ``imgs`` is
    modified in place and also returned.
    """

    def __init__(
        self,
        batch_size,
        box_min=7,
        box_max=22,
        pivot_h=12,
        pivot_w=24,
        *_args,
        **_kwargs,
    ):
        self.box_min = box_min
        self.box_max = box_max
        self.pivot_h = pivot_h
        self.pivot_w = pivot_w
        self.batch_size = batch_size
        self.change_randomization_params()

    def __call__(self, imgs):
        for img, box_w, box_h in zip(imgs, self.w1, self.h1):
            top = self.pivot_h + box_h
            left = self.pivot_w + box_w
            img[:, top : top + box_h, left : left + box_w] = 0
        return imgs

    def change_randomization_params(self, index_=None):
        """Resample box sizes for the whole batch, or only for ``index_``.

        ``index_`` may be any index accepted by tensor indexing (an int, a
        list of ints, a slice). When it is None, all samples are resampled.
        """
        if index_ is None:
            self.w1 = torch.randint(self.box_min, self.box_max, (self.batch_size,))
            self.h1 = torch.randint(self.box_min, self.box_max, (self.batch_size,))
            return
        self.w1[index_] = torch.randint(
            self.box_min, self.box_max, self.w1[index_].shape
        )
        self.h1[index_] = torch.randint(
            self.box_min, self.box_max, self.h1[index_].shape
        )



class RadAug:
    """
    Emulates the augmentation in Reinforcement Learning with
    Augmented Data by upscaling and randomly cropping.

    Each image is upscaled by ``crop`` pixels in both height and width. A
    window of the original size is then cut out at offsets ``(h[i], w[i])``,
    each drawn from ``[0, crop)``. ``imgs`` is modified in place and also
    returned.
    """

    def __init__(self, batch_size, crop=16, *_args, **kwargs):
        self.batch_size = batch_size
        self.crop = crop
        self.change_randomization_params()

    def change_randomization_params(self):
        self.h = torch.randint(self.crop, size=(self.batch_size,))
        self.w = torch.randint(self.crop, size=(self.batch_size,))

    def resize_imgs(self, imgs):
        """Upscale an (N, C, H, W) batch to (N, C, H + crop, W + crop) as float32."""
        original_device = imgs.device
        b, c, h, w = imgs.shape
        new_h, new_w = h + self.crop, w + self.crop
        result = np.empty((b, c, new_h, new_w), dtype=np.float32)
        for i in range(b):
            hwc = imgs[i].cpu().numpy().transpose(1, 2, 0)
            # cv2.resize takes dsize as (width, height), not (rows, cols).
            resized = cv2.resize(hwc, (new_w, new_h))
            # cv2 drops the channel axis for single-channel input; restore it.
            result[i] = resized.reshape(new_h, new_w, c).transpose(2, 0, 1)
        return torch.from_numpy(result).to(original_device)

    def __call__(self, imgs):
        _, _, h, w = imgs.shape
        upscaled_imgs = self.resize_imgs(imgs)
        for i, (crop_h, crop_w) in enumerate(zip(self.h, self.w)):
            imgs[i, ...] = upscaled_imgs[i, :, crop_h : h + crop_h, crop_w : w + crop_w]
        return imgs


class DrqAug:
    """
    Emulates the augmentation in "Image Augmentation Is All You Need" (DrQ):
    pad the image borders by ``pad`` pixels, then crop a window of the
    original size at a random offset. Padding here is reflection padding
    (``nn.ReflectionPad2d``). NOTE(review): an earlier version of this
    docstring described replicating boundary pixels; confirm which padding
    the experiments intended.

    We can't use kornia as the authors' repo does, because there is no easy
    way to share one random crop across both observation batches, so the
    crop is implemented manually. Kornia's random crop introduces many
    fractional pixel values; to get a similar effect, unit-variance Gaussian
    noise is optionally added. The output is clamped to [0, 255].
    """

    def __init__(self, batch_size, pad=4, noise=True, *_args, **kwargs):
        self.batch_size = batch_size
        self.pad = pad
        self.noise = noise
        self.pad_func = nn.ReflectionPad2d(pad)
        self.change_randomization_params()

    def change_randomization_params(self):
        # Offsets cover 0..2*pad inclusive, so shifts span -pad..+pad
        # symmetrically.
        self.w1 = torch.randint(0, self.pad * 2 + 1, (self.batch_size,))
        self.h1 = torch.randint(0, self.pad * 2 + 1, (self.batch_size,))

    def random_crop(self, imgs, out):
        """Crop each sample at its (h1[i], w1[i]) offset.

        Args:
            imgs: padded (N, C, H, W) batch.
            out: crop size, either an int (square) or an (height, width) pair.
        """
        b, c, _, _ = imgs.shape
        if isinstance(out, (tuple, list)):
            out_h, out_w = out
        else:
            out_h, out_w = out, out
        cropped = torch.zeros(
            (b, c, out_h, out_w), dtype=torch.float32, device=imgs.device
        )
        for i, (img, left, top) in enumerate(zip(imgs, self.w1, self.h1)):
            cropped[i] = img[:, top : top + out_h, left : left + out_w]
        return cropped

    def __call__(self, imgs):
        _, _, h, w = imgs.shape
        og_device = imgs.device
        padded = self.pad_func(imgs)
        cropped = self.random_crop(padded, out=(h, w)).to(og_device)
        if self.noise:
            cropped += torch.randn_like(cropped)
        return cropped.clamp(0, 255.0)



class DrqNoNoiseAug(DrqAug):
    """DrQ pad-and-crop with ``pad=4`` and the additive pixel noise turned off."""

    def __init__(self, batch_size, pad=4, noise=False, *_args, **kwargs):
        super().__init__(batch_size, pad=pad, noise=noise)


class LargeDrqNoNoiseAug(DrqAug):
    """DrQ pad-and-crop with a wider ``pad=12`` border and no additive noise."""

    def __init__(self, batch_size, pad=12, noise=False, *_args, **kwargs):
        super().__init__(batch_size, pad=pad, noise=noise)


class LargeDrqAug(DrqAug):
    """DrQ pad-and-crop with a wider ``pad=12`` border; noise uses DrqAug's default."""

    def __init__(self, batch_size, pad=12, *_args, **kwargs):
        super().__init__(batch_size, pad=pad)


class TranslateAug:
    """Shift each sample by a random integer offset and fill the gap with a colour.

    Offsets ``(dy, dx)`` are drawn per sample from
    ``[-translate_max, translate_max]``. Pixels uncovered by the shift take a
    random per-sample RGB colour with components in [0, 254], repeated for
    every 3-channel frame. Returns a new float32 tensor on ``imgs``'s device.
    """

    def __init__(
        self, batch_size, translate_max=4, *_args, **_kwargs,
    ):
        self.batch_size = batch_size
        self.translate_max = translate_max
        self.change_randomization_params()

    def change_randomization_params(self):
        # ``+ 1`` makes the range symmetric and lets translate_max=0 work.
        self.translation = (
            torch.randint(
                2 * self.translate_max + 1, (self.batch_size, 2), dtype=torch.int32
            )
            - self.translate_max
        )
        self.random_color = torch.randint(255, size=(self.batch_size, 3, 1, 1)).float()

    def __call__(self, imgs):
        b, c, h, w = imgs.shape
        pad = self.translate_max
        # Build the canvas directly on the input's device to avoid host copies.
        canvas = torch.ones(
            (b, c, h + 2 * pad, w + 2 * pad), dtype=torch.float32, device=imgs.device
        )
        canvas *= self.random_color.to(imgs.device).repeat(1, c // 3, 1, 1)

        for i in range(b):
            dy, dx = (int(v) for v in self.translation[i])
            top = pad + dy
            left = pad + dx
            canvas[i, :, top : top + h, left : left + w] = imgs[i]

        # Explicit end indices (not ``-pad``), so pad == 0 keeps the whole image.
        return canvas[:, :, pad : pad + h, pad : pad + w]


class LargeTranslateAug(TranslateAug):
    """TranslateAug with a larger default maximum shift of 8 pixels."""

    def __init__(self, batch_size, translate_max=8, *_args, **kwargs):
        super().__init__(batch_size, translate_max=translate_max)


class CutoutColorAug:
    """Paint one randomly sized box per sample with a random RGB colour.

    The geometry matches CutoutAug: sides are drawn from
    ``[box_min, box_max)``, and the box starts at
    ``(pivot_h + h1[i], pivot_w + w1[i])``. Every 3-channel frame of a sample
    gets the same colour, with components drawn from [0, 254]. ``imgs`` is
    modified in place and also returned.
    """

    def __init__(
        self,
        batch_size,
        box_min=7,
        box_max=22,
        pivot_h=12,
        pivot_w=24,
        *_args,
        **_kwargs,
    ):
        self.batch_size = batch_size
        self.box_min, self.box_max = box_min, box_max
        self.pivot_h, self.pivot_w = pivot_h, pivot_w
        self.change_randomization_params()

    def __call__(self, imgs):
        _, channels, _, _ = imgs.shape
        self.rand_box = self.rand_box.to(imgs.device)
        frame_starts = range(0, (channels // 3) * 3, 3)
        for sample_idx, (sample, side_w, side_h) in enumerate(
            zip(imgs, self.w1, self.h1)
        ):
            top = self.pivot_h + side_h
            left = self.pivot_w + side_w
            colour = self.rand_box[sample_idx]
            for first in frame_starts:
                sample[first : first + 3, top : top + side_h, left : left + side_w] = colour
        return imgs

    def change_randomization_params(self):
        sides = (self.box_min, self.box_max, (self.batch_size,))
        self.w1 = torch.randint(*sides)
        self.h1 = torch.randint(*sides)
        self.rand_box = torch.randint(
            0, 255, size=(self.batch_size, 3, 1, 1), dtype=torch.float32,
        )


class GammaAug:
    """Apply a random per-sample gamma curve to images in [0, 255].

    Gammas are drawn from N(gamma_mean, gamma_std). The output is
    ``255 * (x / 255) ** gamma``, clamped to [0, 255], and returned as a new
    tensor; the input is left untouched.

    NOTE(review): with these defaults a gamma <= 0 is drawn occasionally.
    Zero-valued pixels then map to inf (clamped to 255). Confirm whether that
    is intended.
    """

    gamma_mean = 1.0
    gamma_std = 0.45

    def __init__(self, batch_size, *_args, **kwargs):
        self.batch_size = batch_size
        self.change_randomization_params()

    def change_randomization_params(self):
        self.gamma = torch.from_numpy(
            np.random.normal(self.gamma_mean, self.gamma_std, size=(self.batch_size,))
        ).float()
        self.gamma = self.gamma.view(-1, 1, 1, 1)

    def __call__(self, imgs):
        # Out-of-place: the old in-place ``/=`` rescaled the caller's tensor
        # even though a different tensor was returned, and it failed on
        # integer inputs.
        normalized = imgs / 255.0
        corrected = normalized.pow(self.gamma.to(imgs.device)) * 255.0
        return corrected.clamp(0.0, 255.0)


class _FlipAug:
    def __init__(self, batch_size, p_rand=0.5, dim=None, *_args, **_kwargs):
        assert dim
        self.p_flip = p_rand
        self.batch_size = batch_size
        self.dim = dim
        self.change_randomization_params()

    def __call__(self, images):
        if self.random_inds.sum() > 0:
            images[self.random_inds] = torch.flip(images[self.random_inds], (self.dim,))
        return images

    def change_randomization_params(self):
        self.random_inds = np.random.choice(
            [True, False], self.batch_size, p=[self.p_flip, 1 - self.p_flip]
        )


class HorizontalFlipAug(_FlipAug):
    """Randomly mirror samples left-to-right (flip along dim 3, the width axis of NCHW)."""

    def __init__(self, batch_size, p_rand=0.5, *_args, **kwargs):
        super().__init__(batch_size, p_rand=p_rand, dim=3)


class VerticalFlipAug(_FlipAug):
    """Randomly mirror samples top-to-bottom (flip along dim 2, the height axis of NCHW)."""

    def __init__(self, batch_size, p_rand=0.5, *_args, **kwargs):
        super().__init__(batch_size, p_rand=p_rand, dim=2)


class RotateAug:
    """Rotate each sample by a random multiple of 90 degrees in the (H, W) plane.

    ``turns[i]`` in {0, 1, 2, 3} is the number of quarter turns applied to
    sample ``i``. ``random_inds`` keeps RAD's encoding
    (``turns * batch_size + arange(batch_size)``, an index into a 4x-stacked
    batch) for compatibility, but ``__call__`` only reads ``turns``.
    ``imgs`` is modified in place and also returned. Rotation swaps H and W,
    so inputs are assumed square.
    """

    def __init__(self, batch_size, *_args, **_kwargs):
        self.batch_size = batch_size
        self.change_randomization_params()

    def __call__(self, imgs):
        turns = self.turns.to(imgs.device)
        for k in range(1, 4):
            mask = turns == k
            if mask.any():
                imgs[mask] = torch.rot90(imgs[mask], k=k, dims=(2, 3))
        return imgs

    def change_randomization_params(self):
        self.turns = torch.randint(4, size=(self.batch_size,))
        self.random_inds = self.turns * self.batch_size + torch.arange(self.batch_size)


class IdentityAug:
    """No-op augmentation that hands its input back untouched."""

    def __init__(self, batch_size, *_args, **_kwargs):
        self.batch_size = batch_size

    def __call__(self, imgs):
        """Return ``imgs`` itself (the same object, not a copy)."""
        return imgs

    def change_randomization_params(self):
        """Nothing to resample."""
        pass



class WindowAug:
    """Keep a randomly placed 64x64 window per sample and zero everything else.

    The window's top-left corner is drawn from ``[0, 75 - 64)`` along each
    axis. ``imgs`` is multiplied in place by the window mask and also
    returned.
    """

    def __init__(self, batch_size, *_args, **_kwargs):
        self.batch_size = batch_size
        self.crop_size = 64
        self.crop_max = 75 - self.crop_size
        self.change_randomization_params()

    def __call__(self, imgs):
        window = torch.zeros((imgs.shape), dtype=torch.float32, device=imgs.device)
        side = self.crop_size
        for sample_idx, (top, left) in enumerate(zip(self.h1, self.w1)):
            window[sample_idx, :, top : top + side, left : left + side] = 1.0
        imgs *= window
        return imgs

    def change_randomization_params(self):
        corner_args = (0, self.crop_max, (self.batch_size,))
        self.w1 = torch.randint(*corner_args)
        self.h1 = torch.randint(*corner_args)



class ColorJitterAug(torch.nn.Module):
    """Randomly jitter brightness, contrast, saturation and hue of RGB frames.

    Factors are drawn once per ``change_randomization_params`` call: one per
    sample, repeated ``stack_size`` times. ``__call__`` takes an
    (N, C, H, W) batch in [0, 255], where C is a multiple of 3. It jitters
    every 3-channel frame with that sample's factors, and returns a new float
    tensor in [0, 255] on the input's device. Each sample is jittered with
    probability ``p_rand``, decided once per call so that all frames of a
    sample are treated alike.

    A jitter whose range collapses to its neutral value (e.g.
    ``brightness=0``) is disabled, and its factors are fixed at that neutral
    value.
    """

    def __init__(
        self,
        batch_size,
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.5,
        p_rand=1.0,
        stack_size=1,
        *_args,
        **_kwargs,
    ):
        super().__init__()
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(
            hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False
        )

        self.prob = p_rand
        self.batch_size = batch_size
        self.stack_size = stack_size
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.change_randomization_params()

    def _check_input(
        self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True
    ):
        """Normalise a jitter spec to ``[low, high]``, or None when disabled."""
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(
                    "If {} is a single number, it must be non negative.".format(name)
                )
            value = [center - value, center + value]
            if clip_first_on_zero:
                value[0] = max(value[0], 0)
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            if not bound[0] <= value[0] <= value[1] <= bound[1]:
                raise ValueError("{} values should be between {}".format(name, bound))
        else:
            raise TypeError(
                "{} should be a single number or a list/tuple with lenght 2.".format(
                    name
                )
            )
        # if value is 0 or (1., 1.) for brightness/contrast/saturation
        # or (0., 0.) for hue, do nothing
        if value[0] == value[1] == center:
            value = None
        return value

    def _sample_factors(self, bounds, neutral):
        """Draw one factor per sample in ``bounds``, repeated ``stack_size`` times.

        ``bounds`` is None for a disabled jitter; ``neutral`` is then used for
        every sample instead of crashing on ``uniform_(*None)``.
        """
        if bounds is None:
            factors = torch.full(
                (self.batch_size,), float(neutral), device=self._device
            )
        else:
            factors = torch.empty(self.batch_size, device=self._device).uniform_(
                *bounds
            )
        return factors.reshape(-1, 1).repeat(1, self.stack_size).reshape(-1)

    def adjust_contrast(self, x, factor=None):
        """Scale each channel's deviation from its spatial mean (RGB input)."""
        factor = self.factor_contrast if factor is None else factor
        means = torch.mean(x, dim=(2, 3), keepdim=True)
        return torch.clamp(
            (x - means) * factor.view(len(x), 1, 1, 1) + means, 0, 1
        )

    def adjust_hue(self, x, factor=None):
        """Shift channel 0 (hue) of an HSV batch, wrapping into [0, 1)."""
        factor = self.factor_hue if factor is None else factor
        h = x[:, 0, :, :]
        h += factor.view(len(x), 1, 1) * 255.0 / 360.0
        h = h % 1
        x[:, 0, :, :] = h
        return x

    def adjust_brightness(self, x, factor=None):
        """Scale channel 2 (value) of an HSV batch."""
        factor = self.factor_brightness if factor is None else factor
        x[:, 2, :, :] = torch.clamp(
            x[:, 2, :, :] * factor.view(len(x), 1, 1), 0, 1
        )
        return torch.clamp(x, 0, 1)

    def adjust_saturate(self, x, factor=None):
        """Scale channel 1 (saturation) of an HSV batch."""
        factor = self.factor_saturate if factor is None else factor
        x[:, 1, :, :] = torch.clamp(
            x[:, 1, :, :] * factor.view(len(x), 1, 1), 0, 1
        )
        return torch.clamp(x, 0, 1)

    def transform(self, inputs, inds=None):
        """Apply all jitters to ``inputs``, using factors ``[inds]`` if given.

        ``inds`` must select the same samples that ``inputs`` was taken from,
        so that factor and image counts line up.
        """

        def select(factors):
            return factors if inds is None else factors[inds]

        brightness = select(self.factor_brightness)
        hue = select(self.factor_hue)
        saturate = select(self.factor_saturate)
        contrast = select(self.factor_contrast)

        hsv_transform_list = [
            rgb2hsv,
            lambda x: self.adjust_brightness(x, brightness),
            lambda x: self.adjust_hue(x, hue),
            lambda x: self.adjust_saturate(x, saturate),
            hsv2rgb,
        ]
        rgb_transform_list = [lambda x: self.adjust_contrast(x, contrast)]

        # Shuffle transform
        if random.uniform(0, 1) >= 0.5:
            transform_list = rgb_transform_list + hsv_transform_list
        else:
            transform_list = hsv_transform_list + rgb_transform_list
        for t in transform_list:
            inputs = t(inputs)
        return inputs

    def __call__(self, imgs):
        _, c, _, _ = imgs.shape
        original_device = imgs.device
        imgs = imgs.float().to(self._device)
        imgs = imgs / 255.0

        # One on/off draw per call, so every frame of a sample is treated alike.
        random_inds = np.random.choice(
            [True, False], len(imgs), p=[self.prob, 1 - self.prob]
        )
        for idx in range(0, (c // 3) * 3, 3):
            imgs[:, idx : idx + 3] = self.forward(imgs[:, idx : idx + 3], random_inds)
        imgs = imgs.to(original_device) * 255.0
        return imgs

    def change_randomization_params(self):
        # Draw order (contrast, hue, brightness, saturation) is kept stable.
        self.factor_contrast = self._sample_factors(self.contrast, 1.0)
        self.factor_hue = self._sample_factors(self.hue, 0.0)
        self.factor_brightness = self._sample_factors(self.brightness, 1.0)
        self.factor_saturate = self._sample_factors(self.saturation, 1.0)

    def forward(self, inputs, random_inds=None):
        """Jitter the selected samples of an RGB batch in [0, 1], in place.

        Args:
            inputs: (N, 3, H, W) tensor on this module's device.
            random_inds: optional boolean array of length N choosing which
                samples to jitter. It is drawn with probability ``p_rand`` if
                omitted.
        """
        if random_inds is None:
            random_inds = np.random.choice(
                [True, False], len(inputs), p=[self.prob, 1 - self.prob]
            )
        if random_inds.sum() > 0:
            inds = torch.as_tensor(random_inds, device=self._device)
            # Pass ``inds`` so that only the matching factors are used.
            inputs[inds] = self.transform(inputs[inds], inds)
        return inputs


def rgb2hsv(rgb, eps=1e-8):
    """Convert a batch of RGB images to HSV.

    Args:
        rgb: tensor whose dim 1 holds the (R, G, B) channels, read as
            ``rgb[:, 0/1/2, :, :]``, e.g. (N, 3, H, W). Values are presumably
            in [0, 1]; the callers in this file divide by 255 first.
        eps: added to denominators to avoid division by zero.

    Returns:
        Tensor with channels (H, S, V) stacked on dim 1. The hue sector value
        (0-6) is divided by 6, giving roughly [0, 1). Where several channels
        tie for the maximum, blue takes precedence over green, and green over
        red.

    References:
        https://www.rapidtables.com/convert/color/rgb-to-hsv.html
        https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287
    """
    device = rgb.device
    red, green, blue = rgb[:, 0, :, :], rgb[:, 1, :, :], rgb[:, 2, :, :]

    c_max = rgb.max(1)[0]
    c_min = rgb.min(1)[0]
    chroma = c_max - c_min

    hue = torch.zeros((rgb.shape[0], rgb.shape[2], rgb.shape[3])).to(device)
    # Later entries overwrite earlier ones where channels tie for the max.
    sector_candidates = (
        (red, ((green - blue) / (chroma + eps)) % 6),
        (green, (blue - red) / (chroma + eps) + 2),
        (blue, (red - green) / (chroma + eps) + 4),
    )
    for channel, sector_hue in sector_candidates:
        is_max = c_max == channel
        hue[is_max] = sector_hue[is_max]
    hue[c_max == 0] = 0.0
    hue = (hue / 6.0).unsqueeze(dim=1)

    saturation = chroma / (c_max + eps)
    saturation[c_max == 0.0] = 0.0
    saturation = saturation.to(device).unsqueeze(dim=1)

    value = c_max.to(device).unsqueeze(dim=1)

    return torch.cat((hue, saturation, value), dim=1)
    # return hue, saturation, value


def hsv2rgb(hsv):
    """Convert a batch of HSV images (channels on dim 1, values in [0, 1]) to RGB.

    Inputs are clamped to [0, 1] first. Hue is scaled to degrees, and each
    pixel is assigned to one of six 60-degree sectors. A hue of exactly 1.0
    (360 degrees) falls into the last sector, which at 360 gives the same
    colour as hue 0. Returns RGB clamped to [0, 1].

    References:
        https://www.rapidtables.com/convert/color/hsv-to-rgb.html
        https://github.com/scikit-image/scikit-image/blob/master/skimage/color/colorconv.py#L287
    """
    _device = hsv.device

    hsv = torch.clamp(hsv, 0, 1)
    hue = hsv[:, 0, :, :] * 360.0
    saturation = hsv[:, 1, :, :]
    value = hsv[:, 2, :, :]

    c = value * saturation
    x = -c * (torch.abs((hue / 60.0) % 2 - 1) - 1)
    m = (value - c).unsqueeze(dim=1)

    rgb_prime = torch.zeros_like(hsv).to(_device)

    # (lower bound, upper bound or None, channel receiving c, channel receiving x)
    sectors = (
        (0.0, 60.0, 0, 1),
        (60.0, 120.0, 1, 0),
        (120.0, 180.0, 1, 2),
        (180.0, 240.0, 2, 1),
        (240.0, 300.0, 2, 0),
        # Open-ended so hue == 360 (input hue 1.0) is not left unassigned.
        (300.0, None, 0, 2),
    )
    for lower, upper, c_channel, x_channel in sectors:
        inds = hue >= lower
        if upper is not None:
            inds &= hue < upper
        rgb_prime[:, c_channel, :, :][inds] = c[inds]
        rgb_prime[:, x_channel, :, :][inds] = x[inds]

    # m (N, 1, H, W) broadcasts over the three colour channels.
    rgb = (rgb_prime + m).to(_device)

    return torch.clamp(rgb, 0, 1)


class NetworkRandomizationAug(torch.nn.Module):
    """Pass every RGB frame through a freshly randomised 3x3 convolution.

    ``change_randomization_params`` builds a new 3->3 channel, bias-free,
    padding-1 convolution with Xavier-normal weights. ``__call__`` takes an
    (N, C, H, W) batch in [0, 255], with C a multiple of 3. It scales the
    batch to [0, 1], convolves each 3-channel frame, and scales back.
    ``imgs`` is modified in place and also returned.
    """

    def __init__(self, batch_size, *_args, **_kwargs):
        super().__init__()
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self.change_randomization_params()

    def forward(self, x):
        return self.conv(x)

    def __call__(self, imgs):
        _, c, _, _ = imgs.shape
        imgs /= 255.0
        with torch.no_grad():
            for idx in range(0, (c // 3) * 3, 3):
                # The conv lives on ``self._device``. Move each frame there
                # and back, so CPU inputs also work when CUDA is available.
                frame = imgs[:, idx : idx + 3].to(self._device)
                imgs[:, idx : idx + 3] = self.forward(frame).to(imgs.device)
        imgs *= 255.0
        return imgs

    def change_randomization_params(self):
        self.conv = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False, padding=1).to(
            self._device
        )
        torch.nn.init.xavier_normal_(self.conv.weight.data)
