"""
This code was adapted from
https://github.com/omegafragger/DDU/blob/main/net/spectral_normalization/spectral_norm_conv_inplace.py
whose header comment is reproduced below:

From: https://github.com/jhjacobsen/invertible-resnet
Which is based on: https://arxiv.org/abs/1811.00995
Soft Spectral Normalization (not enforced, only <= coeff) for Conv2D layers
Based on: Regularisation of Neural Networks by Enforcing Lipschitz Continuity
    (Gouk et al. 2018)
    https://arxiv.org/abs/1804.04368
"""

from typing import Tuple

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.functional import conv2d, conv_transpose2d, normalize

# Shorthand alias for torch.Tensor, intended for use in type annotations.
T = torch.Tensor


class SpectralNormConv(nn.Module):
    """Soft spectral normalization for a 2D convolution layer.

    The largest singular value ``sigma`` of the convolution operator acting on
    inputs of shape ``dim`` is estimated with power iteration. ``conv2d`` is the
    forward map and ``conv_transpose2d`` (with matching stride, padding,
    dilation and groups) is its adjoint. The wrapped layer's weight is divided
    by ``max(1, sigma / c)``, so the operator norm is only reduced when it
    exceeds ``c`` (a soft constraint, not an equality).

    Args:
        base_layer: the convolution layer to wrap. Its weight parameter is
            moved to ``<weight_name>_orig``, and the buffers ``<weight_name>_u``,
            ``<weight_name>_v`` and ``<weight_name>_sigma`` are registered on it.
        dim: the input shape used for power iteration (4 dims, e.g. ``(1, C, H, W)``).
        ctype: how the bound ``c`` is represented:
            ``"none"``   - a fixed buffer holding ``c``;
            ``"scalar"`` - a learnable scalar mapped through ``0.01 + 0.99 * softplus``;
            ``"vector"`` - a learnable per-output-channel bound (``nn.Conv2d`` only),
            mapped the same way.
        c: the fixed bound used when ``ctype == "none"`` (ignored otherwise).
        weight_name: name of the weight parameter on ``base_layer``.
        n_power_iterations: power-iteration steps per training-mode forward.
        eps: epsilon passed to ``normalize`` to avoid division by zero.

    Raises:
        ValueError: if ``base_layer``'s class name does not contain "Conv",
            ``n_power_iterations`` is not positive, or the weight is uninitialized.
        NotImplementedError: for an unsupported ``ctype``.
    """

    def __init__(
        self,
        base_layer: nn.Module,
        dim: Tuple[int, ...],
        ctype: str = "none",
        c: float = 1,
        weight_name: str = "weight",
        n_power_iterations: int = 1,
        eps: float = 1e-12
    ):
        super().__init__()

        self.base_layer = base_layer
        self.dim = dim
        self.weight_name = weight_name
        self.n_power_iterations = n_power_iterations
        self.eps = eps
        self.ctype = ctype

        if "Conv" not in base_layer.__class__.__name__:
            raise ValueError(f"the base layer needs to be a convolutional layer: got {base_layer.__class__.__name__}")

        if n_power_iterations <= 0:
            raise ValueError("Expected n_power_iterations to be positive, but got n_power_iterations={}".format(n_power_iterations))

        # Validate the weight before anything below reads its shape or device.
        weight = self.base_layer._parameters[weight_name]
        if isinstance(weight, torch.nn.parameter.UninitializedParameter):
            raise ValueError(
                'The module passed to `SpectralNorm` can\'t have uninitialized parameters. '
                'Make sure to run the dummy forward before applying spectral normalization')

        # Create all new state next to the weight. This lets a layer that already
        # lives on an accelerator (or uses a non-default dtype) be wrapped directly.
        device, dtype = weight.device, weight.dtype

        self.c: torch.Tensor
        if ctype == "none":
            self.register_buffer("c", torch.tensor(float(c), device=device))
        elif ctype == "scalar":
            self.c = nn.Parameter(torch.tensor(0.0, device=device))
        elif ctype == "vector" and isinstance(base_layer, nn.Conv2d):
            # One bound per output channel, broadcastable against the weight,
            # whose first dimension is the output channel.
            self.c = nn.Parameter(torch.zeros((weight.size(0), 1, 1, 1), device=device))
        else:
            raise NotImplementedError(f"got an unknown combination of ctype: {ctype}")

        settings = self._conv_settings()
        with torch.no_grad():
            v = normalize(torch.randn(torch.Size(dim).numel(), device=device, dtype=dtype), dim=0, eps=eps)
            # Forward call to infer the convolution's output shape for `dim`-shaped inputs.
            u = conv2d(v.view(*dim), weight, bias=None, **settings)
            self.out_shape = u.shape
            # Overwrite u with a random unit vector of the output size.
            u = normalize(torch.randn(u.numel(), device=device, dtype=dtype), dim=0, eps=eps)

        # The trainable tensor becomes `<name>_orig`. `<name>` becomes a plain
        # attribute that `compute_weight` overwrites with the rescaled weight.
        delattr(base_layer, weight_name)
        base_layer.register_parameter(weight_name + "_orig", weight)
        setattr(base_layer, weight_name, weight.data)
        base_layer.register_buffer(weight_name + "_u", u)
        base_layer.register_buffer(weight_name + "_v", v)
        base_layer.register_buffer(weight_name + "_sigma", torch.ones(1, device=device))

    @staticmethod
    def _pair(value) -> Tuple[int, int]:
        """Expand an int into an ``(h, w)`` pair; pass sequences through as a tuple."""
        if isinstance(value, int):
            return (value, value)
        return tuple(value)  # type: ignore

    def _conv_settings(self) -> dict:
        """Return the stride/padding/dilation/groups the base layer convolves with."""
        layer = self.base_layer
        return {
            "stride": layer.stride,
            "padding": layer.padding,
            # Layers without these attributes behave like the defaults (1, 1).
            "dilation": self._pair(getattr(layer, "dilation", 1)),
            "groups": getattr(layer, "groups", 1),
        }

    def _output_padding(self, stride, padding, dilation, kernel_size) -> Tuple[int, ...]:
        """Return the output padding that makes ``conv_transpose2d`` map back to ``dim``.

        A strided convolution maps several input sizes ``L`` to the same output
        size. The transposed convolution must add the per-axis remainder
        ``(L + 2p - d(k-1) - 1) mod s`` to recover exactly ``L``. The result is
        always in ``[0, s)``, as ``conv_transpose2d`` requires.
        """
        return tuple(
            (size + 2 * p - d * (k - 1) - 1) % s
            for size, s, p, d, k in zip(self.dim[-2:], stride, padding, dilation, kernel_size)
        )

    def compute_weight(self) -> None:
        """Rescale the base layer's weight so its estimated spectral norm is at most ``c``.

        In training mode, ``n_power_iterations`` steps of power iteration refine the
        ``u``/``v`` buffers in place. In eval mode the stored vectors are used as-is.
        The estimate ``sigma = u . conv(v)`` is copied into ``<weight_name>_sigma``.
        ``<weight_name>`` on the base layer is then set to
        ``<weight_name>_orig / (max(1, sigma / c) + 1e-5)``.

        Raises:
            ValueError: if the base layer's ``stride`` or ``padding`` is not a tuple
                (e.g. string padding such as ``"same"``).
        """
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on the correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #    However, after we update `u` and `v` in-place, we need to **clone**
        #    them before using them to normalize the weight. This is to support
        #    backproping through two forward passes, e.g., the common pattern in
        #    GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #    complain that variables needed to do backward for the first forward
        #    (i.e., the `u` and `v` vectors) are changed in the second forward.

        weight = getattr(self.base_layer, self.weight_name + "_orig")
        u = getattr(self.base_layer, self.weight_name + "_u")
        v = getattr(self.base_layer, self.weight_name + "_v")
        sigma_log = getattr(self.base_layer, self.weight_name + "_sigma")  # for logging

        settings = self._conv_settings()
        stride, padding = settings["stride"], settings["padding"]
        if not isinstance(stride, tuple) or not isinstance(padding, tuple):
            raise ValueError(f"stride: ({type(stride)}) and padding: ({type(padding)})")

        if self.training:
            with torch.no_grad():
                output_padding = self._output_padding(
                    stride, padding, settings["dilation"], weight.shape[-2:]
                )
                for _ in range(self.n_power_iterations):
                    # v <- normalize(W^T u): the transposed conv is the adjoint of conv2d.
                    v_s = conv_transpose2d(
                        u.view(self.out_shape),
                        weight,
                        output_padding=output_padding,
                        **settings,
                    )
                    # Note: out flag for in-place changes (see NB above).
                    v = normalize(v_s.view(-1), dim=0, eps=self.eps, out=v)
                    # u <- normalize(W v)
                    u_s = conv2d(v.view(self.dim), weight, bias=None, **settings)
                    u = normalize(u_s.view(-1), dim=0, eps=self.eps, out=u)
                # See above on why we need to clone (n_power_iterations > 0 is validated).
                u = u.clone()
                v = v.clone()

        weight_v = conv2d(v.view(self.dim), weight, bias=None, **settings).view(-1)
        sigma = torch.dot(u.view(-1), weight_v)
        # Enforce the spectral norm only as an upper bound.
        c = (0.01 + 0.99 * F.softplus(self.c)) if self.ctype != "none" else self.c
        factor_reverse = torch.max(torch.ones(1, device=weight.device), sigma / c)

        # Rescaling; the small constant keeps the division stable.
        weight = weight / (factor_reverse + 1e-5)

        sigma_log.copy_(sigma.detach())
        setattr(self.base_layer, self.weight_name, weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the base layer with its spectrally normalized weight."""
        # `compute_weight` rebuilds `<weight_name>` from `<weight_name>_orig` (a
        # registered parameter that follows the module's device) on every call,
        # so no separate device fix-up of the previous weight attribute is needed.
        self.compute_weight()
        return self.base_layer(x)  # type: ignore
