################################################################################
# spectral/modules/spectral_lipschitz.py
#
# 
# 
# 
# 2024
#
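# Implements a modified version of the spectral normalization parametrization.
# Here the spectral norm is estimated by power iteration on the layer's actual
# linear operator (for convolutions, the convolution itself applied to a probe
# input rather than the reshaped kernel matrix), so that dividing the weight by
# it makes the layer approximately 1-Lipschitz.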

import torch

from typing import Callable, List, Optional, Tuple

Tensor = torch.Tensor
Module = torch.nn.Module

# Maps each convolution module class to the pair of functionals used by the
# power iteration: (the convolution itself, its adjoint ``conv_transpose``).
CONV_UPDATERS = {
  conv_cls: (conv_fn, adjoint_fn)
  for conv_cls, conv_fn, adjoint_fn in (
    (torch.nn.Conv1d,
     torch.nn.functional.conv1d,
     torch.nn.functional.conv_transpose1d),
    (torch.nn.Conv2d,
     torch.nn.functional.conv2d,
     torch.nn.functional.conv_transpose2d),
    (torch.nn.Conv3d,
     torch.nn.functional.conv3d,
     torch.nn.functional.conv_transpose3d),
  )
}

# Maps each transposed-convolution module class to the regular convolution
# class of the same dimensionality; its ``CONV_UPDATERS`` entry is reused with
# the two functionals swapped.
CONV_TRANSPOSE_MAP = dict(zip(
  (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d),
  (torch.nn.Conv1d,          torch.nn.Conv2d,          torch.nn.Conv3d),
))

class _LipschitzSpectralNorm(Module):

  def __init__(self,
      # Arguments:
      weight:           Tensor,
      forward_updater:  Callable[[Tensor, Tensor], Tensor],
      backward_updater: Callable[[Tensor, Tensor], Tensor],
      u_size:           Tuple[int, ...],
      v_size:           Tuple[int, ...],
      # Keyword Arguments:
      n_power_iterations: int   = 1,
      eps:                float = 1e-12
    ):
    super(_LipschitzSpectralNorm, self).__init__()
    assert n_power_iterations > 0, \
      "n_power_iterations must be greater than 0."
    self.n_power_iterations = n_power_iterations
    self.eps = eps
    self._forward_updater = forward_updater
    self._backward_updater = backward_updater
    u, v = torch.randn(u_size), torch.randn(v_size)
    u, v = self._normalize(u, u), self._normalize(v, v)
    self.register_buffer("_u", u)
    self.register_buffer("_v", v)
    self._power_method(weight, 15)
    self._enabled = True

  @torch.autograd.no_grad()
  def _normalize(self,
      # Arguments:
      x:   Tensor,
      out: Tensor,
    ) -> Tensor:
    """
    Based on implementation of ``torch.nn.functional.normalize``, but allows for
    unspecified dims.

    Args:
      x (Tensor):
        Input tensor to normalize.
      out (Tensor):
        Output tensor.
    """
    denominator = torch.linalg.norm(x, keepdim = True).clamp_min_(self.eps)
    return torch.div(x, denominator.expand_as(x), out = out)

  @torch.autograd.no_grad()
  def _power_method(self,
      weight:             Tensor,
      n_power_iterations: int
    ) -> None:
    for _ in range(n_power_iterations):
      self._u = self._normalize(
        self._forward_updater(weight, self._v),
        self._u
      )
      self._v = self._normalize(
        self._backward_updater(weight, self._u),
        self._v
      )

  def forward(self,
      weight: Tensor
    ) -> Tensor:
    if weight.ndim == 1:
      return torch.nn.functional.normalize(weight, dim = 0, eps = self.eps)
    else:
      if self.training and self._enabled:
        self._power_method(weight, self.n_power_iterations)
      v = self._v.clone(memory_format = torch.contiguous_format)
      sigma = torch.linalg.norm(self._forward_updater(weight, v))
      return weight / sigma

  def right_inverse(self,
      value: Tensor
    ) -> Tensor:
    return value

class _LipschitzSpectralRegularizer(_LipschitzSpectralNorm):
  """
  Spectral normalization followed by a learnable scalar gain.

  The parametrized weight is ``(1 + lipschitz_scale) * W / sigma``, where
  ``lipschitz_scale`` is a trainable parameter initialized to 0. The layer
  therefore starts out normalized but may learn to deviate from it. Any
  penalty on ``lipschitz_scale`` is not applied here.

  Arguments are the same as for ``_LipschitzSpectralNorm``.
  """

  def __init__(self,
      # Arguments:
      weight:           Tensor,
      forward_updater:  Callable[[Tensor, Tensor], Tensor],
      backward_updater: Callable[[Tensor, Tensor], Tensor],
      u_size:           Tuple[int, ...],
      v_size:           Tuple[int, ...],
      # Keyword Arguments:
      n_power_iterations: int   = 1,
      eps:                float = 1e-12
    ):
    super().__init__(
      weight             = weight,
      forward_updater    = forward_updater,
      backward_updater   = backward_updater,
      u_size             = u_size,
      v_size             = v_size,
      n_power_iterations = n_power_iterations,
      eps                = eps
    )
    # Zero-initialized so that ``get_scale()`` starts at exactly 1.0; created
    # alongside the weight (same device and dtype).
    self.lipschitz_scale = torch.nn.Parameter(
      torch.zeros((), device = weight.device, dtype = weight.dtype)
    )

  def get_scale(self) -> Tensor:
    """Returns the multiplicative gain ``1 + lipschitz_scale``."""
    return self.lipschitz_scale + 1.0

  def forward(self,
      weight: Tensor
    ) -> Tensor:
    """Returns the spectrally normalized ``weight`` multiplied by the gain."""
    # The former eval-mode debug ``print`` was removed: it wrote to stdout on
    # every evaluation step and forced a device sync via ``.cpu().item()``.
    norm_weight = super().forward(weight)
    return self.get_scale() * norm_weight

def calculate_sizes(
    kernel_size: List[int],
    padding:     List[int],
    stride:      List[int]
  ) -> List[int]:
  """
  Computes the spatial extent of the probe tensor used by the power iteration.

  Each dimension gets ``2 * (kernel - padding) + 1`` elements. ``stride`` is
  currently not taken into account. It is still zipped with the other two
  lists, so the result has as many entries as the shortest of the three.

  Args:
    kernel_size (List[int]): Kernel size per spatial dimension.
    padding (List[int]): Padding per spatial dimension.
    stride (List[int]): Stride per spatial dimension (unused for now).

  Returns:
    List[int]: One size per spatial dimension.
  """
  return [
    2 * (kernel - pad) + 1
    for kernel, pad, _unused_stride in zip(kernel_size, padding, stride)
  ]

def lipschitz_spectral_norm(
    # Arguments:
    module:      Module,
    # Keyword Arguments:
    name:                str   = "weight",
    n_power_iterations:  int   = 1,
    eps:                 float = 1e-12,
    get_parametrization: bool  = False,
    soft:                bool  = False
  ) -> Module:
  """
  Registers a Lipschitz spectral-norm parametrization on ``module.<name>``.

  Supported modules are ``torch.nn.Linear``, ``torch.nn.Conv{1,2,3}d`` and
  ``torch.nn.ConvTranspose{1,2,3}d`` (exact classes for the convolutions).
  For convolutions, the power iteration runs on the convolution operator
  itself (with the module's stride, padding, dilation and groups) applied to
  a probe tensor of batch size 1. It does not run on the reshaped kernel.

  Args:
    module (Module):
      Layer to parametrize.
    name (str):
      Name of the weight parameter or buffer.
    n_power_iterations (int):
      Power-iteration steps per training-mode forward pass.
    eps (float):
      Numerical floor used when normalizing.
    get_parametrization (bool):
      If True, also return the parametrization object.
    soft (bool):
      If True, use ``_LipschitzSpectralRegularizer`` (learnable gain) instead
      of ``_LipschitzSpectralNorm``.

  Returns:
    ``module``, or ``(module, parametrization)`` if ``get_parametrization``.

  Raises:
    AssertionError: ``module`` has no tensor attribute called ``name``.
    TypeError: ``module`` is of an unsupported type.
  """
  weight = getattr(module, name, None)
  assert isinstance(weight, Tensor), \
    f"{module} does not have a parameter or buffer called '{name}'."
  if isinstance(module, torch.nn.Linear):
    forward_updater  = lambda w, x: torch.mv(w,     x)
    backward_updater = lambda w, x: torch.mv(w.t(), x)
    u_size, v_size = weight.size()
  else:
    if module.__class__ in CONV_TRANSPOSE_MAP:
      # A transposed convolution's operator is the adjoint of the matching
      # convolution, so the two functionals swap roles.
      bward, fward = CONV_UPDATERS[CONV_TRANSPOSE_MAP[module.__class__]]
    elif module.__class__ in CONV_UPDATERS:
      fward, bward = CONV_UPDATERS[module.__class__]
    else:
      raise TypeError(f"Unsupported module '{module}'.")
    # Dilation spreads each kernel over ``dilation * (kernel - 1) + 1``
    # input elements; size the probe by that effective extent.
    effective_kernel = [
      d * (k - 1) + 1
      for k, d in zip(module.kernel_size, module.dilation)
    ]
    sizes = calculate_sizes(
      effective_kernel,
      list(module.padding),
      list(module.stride)
    )
    # Both operators must use the module's full configuration: without
    # ``groups`` grouped/depthwise weights are rejected, and without
    # ``dilation`` the wrong operator would be normalized.
    forward_updater = lambda w, x: fward(
      x, w,
      stride   = module.stride,
      padding  = module.padding,
      dilation = module.dilation,
      groups   = module.groups
    )
    backward_updater = lambda w, x: bward(
      x, w,
      stride   = module.stride,
      padding  = module.padding,
      dilation = module.dilation,
      groups   = module.groups
    )
    # The closed-form sizes do not in general match the operators' output
    # shapes (e.g. without "same" padding, or with stride > 1). Derive the
    # shapes from the operators themselves so the ``_u``/``_v`` buffers
    # always agree with what the power iteration produces.
    v_size = tuple([1, module.in_channels] + sizes)
    with torch.no_grad():
      u_size = tuple(forward_updater(weight, weight.new_zeros(v_size)).shape)
      v_size = tuple(backward_updater(weight, weight.new_zeros(u_size)).shape)
  parametrization_cls = (
    _LipschitzSpectralRegularizer if soft else _LipschitzSpectralNorm
  )
  parametrization = parametrization_cls(
    weight,
    forward_updater,
    backward_updater,
    u_size,
    v_size,
    n_power_iterations,
    eps
  )
  torch.nn.utils.parametrize.register_parametrization(
    module,
    name,
    parametrization
  )
  return (module, parametrization) if get_parametrization else module

# Related Keras implementation:
# https://github.com/henrygouk/keras-lipschitz-networks/blob/master/arch/lipschitz.py