# this file is taken from https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/training/diffaug.py
import math

import torch
import torch.nn.functional as F


def load_png(file_name: str):
    """Load the image at ``file_name`` as a float tensor with a leading batch axis, scaled to [-1, 1]."""
    from torchvision.io import read_image

    pixels = read_image(file_name).float()
    # [0, 255] -> [0, 1] -> [0, 2] -> [-1, 1]
    pixels = pixels.div_(255).mul_(2).sub_(1)
    return pixels.unsqueeze(0)


def show(tensor):
    """Display an image tensor whose values are in [-1, 1] via PIL's ``Image.show``.

    A leading dimension of size 1 is dropped first. A tensor that is then 3-D is shown
    directly; anything else is tiled with ``make_grid`` before being shown.
    """
    from torchvision.transforms.functional import to_pil_image
    from torchvision.utils import make_grid

    if tensor.shape[0] == 1:
        tensor = tensor[0]
    # [-1, 1] -> [0, 1], detached and moved to CPU for PIL
    picture = tensor.add(1).div_(2).clamp_(0, 1).detach().cpu()
    if tensor.ndim != 3:
        picture = make_grid(picture)
    to_pil_image(picture).convert('RGB').show()


class DiffAug(object):
    """Differentiable translation / color / cutout augmentation with an optional warm-up blur.

    Every operation is built from differentiable tensor ops, so gradients flow through the
    augmented images.

    Args:
        prob: probability that each of the three augmentation groups (translation, color,
            cutout) is applied on a call; each group is drawn independently. Its absolute value
            is used, and a non-positive value additionally disables cutout.
        cutout: cutout patch size as a fraction of the image height / width.
    """

    def __init__(self, prob=1.0, cutout=0.2):  # todo: swin ratio = 0.5, T&XL = 0.2
        self.grids = {}
        self.prob = abs(prob)
        self.using_cutout = prob > 0
        self.cutout = cutout
        self.img_channels = -1
        self.last_blur_radius = -1
        self.last_blur_kernel_h = self.last_blur_kernel_w = None
        # (radius, sigma, channels, device) that the cached blur kernels were built for
        self._last_blur_key = None

    def get_grids(self, B, x, y, dev):
        """Return cached ``torch.meshgrid`` index tensors of shape (B, x, y) living on ``dev``.

        The cache key includes the device. This ensures index tensors are never handed to a
        call whose images live on a different device.
        """
        key = (B, x, y, dev)
        ret = self.grids.get(key)
        if ret is None:
            ret = self.grids[key] = torch.meshgrid(
                torch.arange(B, dtype=torch.long, device=dev),
                torch.arange(x, dtype=torch.long, device=dev),
                torch.arange(y, dtype=torch.long, device=dev),
                indexing='ij'
            )
        return ret

    def _blur(self, BCHW, warmup_blur_schedule):
        """Separable blur with sigma = sqrt(H / 2) * ``warmup_blur_schedule``; size is preserved."""
        C, H, W = BCHW.shape[1], BCHW.shape[-2], BCHW.shape[-1]
        self.img_channels = C
        sigma0 = (H * 0.5) ** 0.5
        sigma = sigma0 * warmup_blur_schedule
        blur_radius = math.floor(sigma * 3)  # 3-sigma is enough for Gaussian
        # 'reflect' padding requires pad < spatial size; tiny images get a truncated kernel
        blur_radius = min(blur_radius, H - 1, W - 1)
        if blur_radius < 1:
            return BCHW

        # Rebuild whenever anything the kernel depends on changes, not just the radius:
        # a stale kernel would have the wrong sigma, channel count (groups) or device.
        key = (blur_radius, sigma, C, BCHW.device)
        if getattr(self, '_last_blur_key', None) != key:
            self._last_blur_key = key
            self.last_blur_radius = blur_radius
            gaussian = torch.arange(-blur_radius, blur_radius + 1, dtype=torch.float32, device=BCHW.device)
            # 2 ** (-(x / sigma) ** 2): a Gaussian whose std is sigma / sqrt(2 ln 2) (~0.85 * sigma)
            gaussian = gaussian.mul_(1 / sigma).square_().neg_().exp2_()
            gaussian.div_(gaussian.sum())  # normalize
            k = 2 * blur_radius + 1
            self.last_blur_kernel_h = gaussian.view(1, 1, k, 1).repeat(C, 1, 1, 1).contiguous()
            self.last_blur_kernel_w = gaussian.view(1, 1, 1, k).repeat(C, 1, 1, 1).contiguous()

        BCHW = F.pad(BCHW, [blur_radius, blur_radius, blur_radius, blur_radius], mode='reflect')
        # depthwise (groups=C) vertical pass, then horizontal pass
        BCHW = F.conv2d(input=BCHW, weight=self.last_blur_kernel_h, bias=None, groups=C)
        BCHW = F.conv2d(input=BCHW, weight=self.last_blur_kernel_w, bias=None, groups=C)
        return BCHW

    def _translate(self, BCHW, rand_h, rand_w):
        """Shift each image by up to 1/8 of its size per axis, filling vacated pixels with zeros."""
        B, dev = BCHW.shape[0], BCHW.device
        raw_h, raw_w = BCHW.shape[-2:]
        ratio = 0.125
        delta_h = round(raw_h * ratio)
        delta_w = round(raw_w * ratio)
        # uniform integer shift in [-delta, delta] per image (same distribution as torch.randint)
        translation_h = rand_h.mul(delta_h + delta_h + 1).floor().long() - delta_h
        translation_w = rand_w.mul(delta_w + delta_w + 1).floor().long() - delta_w

        grid_B, grid_h, grid_w = self.get_grids(B, raw_h, raw_w, dev)
        # +1 indexes into the 1-pixel zero border added below; clamping maps out-of-range
        # source pixels onto that border, i.e. they become zero
        grid_h = (grid_h + translation_h).add_(1).clamp_(0, raw_h + 1)
        grid_w = (grid_w + translation_w).add_(1).clamp_(0, raw_w + 1)
        bchw_pad = F.pad(BCHW, [1, 1, 1, 1, 0, 0, 0, 0])
        return bchw_pad.permute(0, 2, 3, 1).contiguous()[grid_B, grid_h, grid_w].permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _color(BCHW, rand_brightness, rand_saturation, rand_contrast):
        """Per-image random brightness, saturation and contrast jitter."""
        # brightness: add an offset in [-0.5, 0.5)
        BCHW = BCHW.add(rand_brightness.unsqueeze(-1).sub(0.5))
        # saturation: scale the deviation from the per-pixel channel mean by a factor in [0, 2)
        bchw_mean = BCHW.mean(dim=1, keepdim=True)
        BCHW = BCHW.sub(bchw_mean).mul(rand_saturation.unsqueeze(-1).mul(2)).add_(bchw_mean)
        # contrast: scale the deviation from the per-image mean by a factor in [0.5, 1.5)
        bchw_mean = BCHW.mean(dim=(1, 2, 3), keepdim=True)
        return BCHW.sub(bchw_mean).mul(rand_contrast.unsqueeze(-1).add(0.5)).add_(bchw_mean)

    def _cutout(self, BCHW, rand_h, rand_w):
        """Zero one (cutout*H, cutout*W) patch per image at a random spot, clipped at the border."""
        B, dev = BCHW.shape[0], BCHW.device
        raw_h, raw_w = BCHW.shape[-2:]
        ratio = self.cutout  # todo: styleswin ratio = 0.5, T&XL = 0.2
        cutout_h = round(raw_h * ratio)
        cutout_w = round(raw_w * ratio)
        # patch-centre offset in [0, raw + (1 - cutout % 2)) per image (same distribution as torch.randint)
        offset_h = rand_h.mul(raw_h + (1 - cutout_h % 2)).floor().long()
        offset_w = rand_w.mul(raw_w + (1 - cutout_w % 2)).floor().long()

        grid_B, grid_h, grid_w = self.get_grids(B, cutout_h, cutout_w, dev)
        grid_h = (grid_h + offset_h).sub_(cutout_h // 2).clamp(min=0, max=raw_h - 1)
        grid_w = (grid_w + offset_w).sub_(cutout_w // 2).clamp(min=0, max=raw_w - 1)
        mask = torch.ones(B, raw_h, raw_w, dtype=BCHW.dtype, device=dev)
        mask[grid_B, grid_h, grid_w] = 0
        return BCHW.mul(mask.unsqueeze(1))

    def aug(self, BCHW: torch.Tensor, warmup_blur_schedule: float = 0) -> torch.Tensor:
        """Augment a 4-D (B, C, H, W) image batch.

        Args:
            BCHW: image batch; it is converted to float32 if it has another dtype.
            warmup_blur_schedule: when > 0, the batch is first blurred with
                sigma = sqrt(H / 2) * warmup_blur_schedule.

        Returns:
            A float32 tensor with the same shape as ``BCHW``.

        Random draws come from torch's global RNG: one ``torch.rand(3)`` call picks which groups
        run, then one ``torch.rand(7, B, 1, 1)`` call supplies their per-image parameters.
        """
        if BCHW.dtype != torch.float32:
            BCHW = BCHW.float()
        if warmup_blur_schedule > 0:
            BCHW = self._blur(BCHW, warmup_blur_schedule)

        if self.prob < 1e-6:
            return BCHW
        trans, color, cut = torch.rand(3) <= self.prob
        trans, color, cut = trans.item(), color.item(), cut.item()
        B, dev = BCHW.shape[0], BCHW.device
        rand01 = torch.rand(7, B, 1, 1, device=dev) if (trans or color or cut) else None

        if trans:
            BCHW = self._translate(BCHW, rand01[0], rand01[1])
        if color:
            BCHW = self._color(BCHW, rand01[2], rand01[3], rand01[4])
        if self.using_cutout and cut:
            BCHW = self._cutout(BCHW, rand01[5], rand01[6])
        return BCHW