################################################################################
# spectral/modules/lipschitz_constraint.py
#
# 
# 
# 
# 2024
#
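# Implements a Lipschitz-constraint parametrization: a modified variant of
# spectral normalization that divides a weight by a power-iteration estimate of
# the operator norm of the linear map it defines.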

import torch

from typing import Callable, List, Optional, Tuple

Tensor = torch.Tensor
Module = torch.nn.Module

# For each convolution module class, its functional counterparts as a pair
# ``(convolution, transposed convolution)``. The transposed convolution is
# used as the adjoint operator during power iteration.
CONV_UPDATERS = {
  conv_module: (conv_fn, conv_transpose_fn)
  for conv_module, conv_fn, conv_transpose_fn in zip(
    (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d),
    (
      torch.nn.functional.conv1d,
      torch.nn.functional.conv2d,
      torch.nn.functional.conv3d
    ),
    (
      torch.nn.functional.conv_transpose1d,
      torch.nn.functional.conv_transpose2d,
      torch.nn.functional.conv_transpose3d
    )
  )
}

# Maps each transposed-convolution module class to the convolution module
# class of the same rank, so its functional pair can be looked up in
# ``CONV_UPDATERS`` (with the two roles swapped).
CONV_TRANSPOSE_MAP = dict(zip(
  (
    torch.nn.ConvTranspose1d,
    torch.nn.ConvTranspose2d,
    torch.nn.ConvTranspose3d
  ),
  (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)
))

class _LipschitzConstraint(Module):
  """
  Parametrization that divides a weight by a power-iteration estimate of the
  operator norm (largest singular value) of the linear map it defines.

  The map is described by two callables. ``forward_updater(w, v)`` applies the
  operator defined by ``w`` to ``v``. ``backward_updater(w, u)`` applies its
  adjoint to ``u``. The singular-vector estimates live in the ``_u`` (operator
  output) and ``_v`` (operator input) buffers.

  Args:
    weight (Tensor):
      Weight used to warm up the power iteration at construction time. The
      buffers are allocated on its device and with its dtype.
    forward_updater (Callable[[Tensor, Tensor], Tensor]):
      Applies the operator to a tensor of shape ``v_size``.
    backward_updater (Callable[[Tensor, Tensor], Tensor]):
      Applies the adjoint operator to a tensor of shape ``u_size``.
    u_size (Tuple[int, ...]):
      Shape of the left singular-vector estimate (operator output).
    v_size (Tuple[int, ...]):
      Shape of the right singular-vector estimate (operator input).
    n_power_iterations (int):
      Power iterations per training-mode update; must be greater than 0.
    eps (float):
      Small constant keeping norms that are used as divisors away from zero.
  """

  # Power iterations run once at construction to warm up ``_u`` / ``_v``.
  _N_WARMUP_ITERATIONS = 15

  def __init__(self,
      # Arguments:
      weight:           Tensor,
      forward_updater:  Callable[[Tensor, Tensor], Tensor],
      backward_updater: Callable[[Tensor, Tensor], Tensor],
      u_size:           Tuple[int, ...],
      v_size:           Tuple[int, ...],
      # Keyword Arguments:
      n_power_iterations: int   = 1,
      eps:                float = 1e-12,
    ):
    super().__init__()
    assert n_power_iterations > 0, \
      "n_power_iterations must be greater than 0."
    self.n_power_iterations = n_power_iterations
    self.eps = eps
    self._forward_updater = forward_updater
    self._backward_updater = backward_updater
    # Allocate the estimates alongside the weight. The CPU/float32 defaults of
    # ``torch.randn`` would make the updaters fail for CUDA or non-float32
    # weights.
    u = torch.randn(u_size, device = weight.device, dtype = weight.dtype)
    v = torch.randn(v_size, device = weight.device, dtype = weight.dtype)
    u, v = self._normalize(u, u), self._normalize(v, v)
    self.register_buffer("_u", u)
    self.register_buffer("_v", v)
    self._power_method(weight, self._N_WARMUP_ITERATIONS)
    # Cleared by ``forward`` after each training-mode update. Nothing in this
    # class sets it back to True.
    # NOTE(review): presumably re-armed by external training code — confirm.
    self._update_enabled = True

  @torch.autograd.no_grad()
  def _normalize(self,
      # Arguments:
      x:   Tensor,
      out: Tensor,
    ) -> Tensor:
    """
    Based on implementation of ``torch.nn.functional.normalize``, but allows for
    unspecified dims.

    Divides ``x`` by its norm over all elements (clamped below by ``eps``) and
    writes the result into ``out``. ``out`` may alias ``x``.

    Args:
      x (Tensor):
        Input tensor to normalize.
      out (Tensor):
        Output tensor.

    Returns:
      ``out``, holding the normalized values.
    """
    denominator = torch.linalg.norm(x, keepdim = True).clamp_min_(self.eps)
    return torch.div(x, denominator.expand_as(x), out = out)

  @torch.autograd.no_grad()
  def _power_method(self,
      weight:             Tensor,
      n_power_iterations: int
    ) -> None:
    """
    Run ``n_power_iterations`` rounds of power iteration. This refines the
    ``_u`` and ``_v`` buffers in place without recording gradients.
    """
    for _ in range(n_power_iterations):
      # u <- normalize(A v), then v <- normalize(A^T u).
      self._u = self._normalize(self._forward_updater(weight, self._v), self._u)
      self._v = self._normalize(self._backward_updater(weight, self._u), self._v)

  def forward(self,
      weight: Tensor
    ) -> Tensor:
    """
    Return ``weight`` divided by its estimated operator norm.

    A 1-D weight is divided by its Euclidean norm plus ``eps``. Otherwise, in
    training mode and while ``_update_enabled`` is set, the estimates are first
    refined by ``n_power_iterations`` steps. Afterwards ``_update_enabled`` is
    cleared.
    """
    if weight.ndim == 1:
      sigma = torch.linalg.norm(weight)
      return weight / (sigma + self.eps)
    if self.training and self._update_enabled:
      self._power_method(weight, self.n_power_iterations)
      self._update_enabled = False
    # Clone so that a later in-place (``out=``) update of ``_v`` cannot modify
    # the tensor autograd saved for this call's backward pass.
    v = self._v.clone(memory_format = torch.contiguous_format)
    sigma = torch.linalg.norm(self._forward_updater(weight, v))
    # Guard against a zero estimate (e.g. an all-zero weight), which would
    # otherwise produce NaNs. This is a no-op whenever sigma >= eps.
    return weight / sigma.clamp_min(self.eps)

def calculate_sizes(
    kernel_size: List[int],
    padding:     List[int]
  ) -> List[int]:
  """
  Compute the per-dimension spatial size of the probe input used for power
  iteration on a stride-1 convolution.

  Each size is ``2 * k + 1 - 2 * p`` for kernel extent ``k`` and padding ``p``.
  It is clamped to at least 1, because padding larger than the kernel would
  otherwise give a non-positive size, i.e. an empty or invalid probe tensor.

  Args:
    kernel_size (List[int]):
      Kernel extent per spatial dimension.
    padding (List[int]):
      Padding per spatial dimension.

  Returns:
    List of probe sizes, one per spatial dimension.
  """
  # Stride is deliberately ignored; only stride-1 convolutions are supported.
  return [max(1, 2 * k + 1 - 2 * p) for k, p in zip(kernel_size, padding)]

def lipschitz_constraint(
    # Arguments:
    module:      Module,
    # Keyword Arguments:
    name:                str           = "weight",
    n_power_iterations:  int           = 1,
    eps:                 float         = 1e-12,
    dim:                 Optional[int] = None
  ) -> Module:
  """
  Register a ``_LipschitzConstraint`` parametrization on ``module.<name>``.

  Supported modules:

  * ``torch.nn.Linear`` and its subclasses.
  * Exactly the classes ``torch.nn.Conv{1,2,3}d`` and
    ``torch.nn.ConvTranspose{1,2,3}d``.

  Convolutions must have stride 1. Their padding, dilation and groups are
  forwarded to the functional ops used for power iteration.

  Args:
    module (Module):
      Module whose tensor attribute ``name`` is constrained.
    name (str):
      Name of the parameter or buffer to constrain.
    n_power_iterations (int):
      Power iterations per training-mode update.
    eps (float):
      Numerical-stability constant passed to the parametrization.
    dim (Optional[int]):
      Currently unused.

  Returns:
    The registered ``_LipschitzConstraint`` instance.

  Raises:
    AssertionError: If ``module`` has no tensor attribute ``name``, or a
      convolution has a stride other than 1.
    TypeError: If ``module`` is not a supported type.
  """
  weight = getattr(module, name, None)
  assert isinstance(weight, Tensor), \
    f"{module} does not have a parameter or buffer called '{name}'."
  module_class = module.__class__
  if isinstance(module, torch.nn.Linear):
    forward_updater  = lambda w, x: torch.mv(w,     x)
    backward_updater = lambda w, x: torch.mv(w.t(), x)
    u_size, v_size = weight.size()
  else:
    if module_class in CONV_TRANSPOSE_MAP:
      # A transposed convolution's adjoint is the matching convolution, so
      # the functional pair is used with its roles swapped.
      bward, fward = CONV_UPDATERS[CONV_TRANSPOSE_MAP[module_class]]
    elif module_class in CONV_UPDATERS:
      fward, bward = CONV_UPDATERS[module_class]
    else:
      raise TypeError(f"Unsupported module '{module}'.")
    assert all(x == 1 for x in list(module.stride)), \
      "Lipschitz constraint has not been implemented for convolutions with " + \
      "stride != 1."
    # Hyper-parameters are read from ``module`` at call time. Dilation and
    # groups must be forwarded as well: without groups a grouped convolution
    # fails, and without dilation the norm of a different (undilated)
    # operator would be used.
    forward_updater = lambda w, x: fward(
      x, w,
      stride   = module.stride,
      padding  = module.padding,
      dilation = module.dilation,
      groups   = module.groups
    )
    backward_updater = lambda w, x: bward(
      x, w,
      stride   = module.stride,
      padding  = module.padding,
      dilation = module.dilation,
      groups   = module.groups
    )
    # Probe sizes are based on the effective (dilated) kernel extent. This is
    # identical to ``kernel_size`` when dilation is 1.
    extents = [
      d * (k - 1) + 1
      for k, d in zip(module.kernel_size, module.dilation)
    ]
    sizes = calculate_sizes(extents, list(module.padding))
    v_size = tuple([1, module.in_channels] + sizes)
    # Take the output shape from the operator itself. Output and input sizes
    # differ unless 2 * padding == extent - 1, and a mis-sized ``_u`` buffer
    # would otherwise be resized through ``out=`` during power iteration.
    with torch.no_grad():
      u_size = tuple(forward_updater(weight, weight.new_zeros(v_size)).shape)
  parametrization = _LipschitzConstraint(
    weight,
    forward_updater,
    backward_updater,
    u_size,
    v_size,
    n_power_iterations,
    eps
  )
  torch.nn.utils.parametrize.register_parametrization(
    module,
    name,
    parametrization
  )
  return parametrization