################################################################################
# spectral/modules/spectral_regularizer.py
#
# 
# 
# 
# 2024
#
# Implements a spectral-norm regularizer. Its power-iteration estimate of the
# leading singular vectors is updated only when the regularizer is called in
# training mode.

import torch

from typing import Optional

Tensor = torch.Tensor
Module = torch.nn.Module

class _SpectralRegularizer(Module):
  """
  Estimates the spectral norm (largest singular value) of a module's weight,
  for use as a regularization term.

  For weights with more than one dimension, the weight is reshaped into a
  matrix and its leading singular vectors are tracked by a power iteration
  whose state lives in the ``_u`` and ``_v`` buffers. The power iteration is
  advanced only when ``forward`` is called while in training mode. For
  weights with at most one dimension the exact Euclidean norm is returned.

  Arguments:
    module (Module):
      The module whose weight is regularized.
    name (str):
      Attribute name of the weight on ``module``.

  Keyword Arguments:
    n_power_iterations (int):
      Number of power-iteration steps per training-mode call. Must be > 0.
    dim (int):
      Weight dimension that becomes the rows of the reshaped matrix.
      Negative values count from the last dimension.
    eps (float):
      Passed to ``torch.nn.functional.normalize`` to avoid division by zero
      when normalizing ``_u`` and ``_v``.
  """

  def __init__(self,
      # Arguments:
      module:           Module,
      name:             str,
      # Keyword Arguments:
      n_power_iterations: int   = 1,
      dim:                int   = 0,
      eps:                float = 1e-12
    ):
    super(_SpectralRegularizer, self).__init__()
    self.module = module
    self.name = name
    assert n_power_iterations > 0, \
      "n_power_iterations must be greater than 0."
    # Stored for every weight rank so the attribute always exists.
    self.n_power_iterations = n_power_iterations
    weight = self._get_weight()
    ndim = weight.ndim
    self.dim = dim if dim >= 0 else dim + ndim
    self.eps = eps
    if ndim > 1:
      weight_mat = self._reshape_weight_to_matrix(weight)
      h, w = weight_mat.size()
      u = weight_mat.new_empty(h).normal_(0, 1)
      v = weight_mat.new_empty(w).normal_(0, 1)
      self.register_buffer(
        "_u",
        torch.nn.functional.normalize(u, dim = 0, eps = self.eps)
      )
      self.register_buffer(
        "_v",
        torch.nn.functional.normalize(v, dim = 0, eps = self.eps)
      )
      # Run extra iterations up front so the random initial vectors are
      # moved toward the leading singular vectors before the first call.
      self._power_method(weight_mat, 15)

  def _reshape_weight_to_matrix(self,
      weight: Tensor
    ) -> Tensor:
    """
    Reshapes a weight with more than one dimension into a 2-D matrix.

    Dimension ``self.dim`` is moved to the front and all remaining
    dimensions are flattened, so the result has ``weight.shape[self.dim]``
    rows.

    Arguments:
      weight (Tensor):
        The weight to reshape; must have more than one dimension.

    Returns:
      Tensor:
        The 2-D matrix view/copy of ``weight``.
    """
    assert weight.ndim > 1
    if self.dim != 0:
      weight = weight.permute(
        self.dim, *(d for d in range(weight.dim()) if d != self.dim)
      )
    return weight.flatten(1)

  def _get_weight(self) -> Tensor:
    """
    Gets the weight from the module the regularizer acts for.

    Returns:
      Tensor:
        The weight of the module.
    """
    return getattr(self.module, self.name)

  @torch.autograd.no_grad()
  def _power_method(self,
      weight_mat:         Tensor,
      n_power_iterations: int
    ) -> None:
    """
    Advances the power iteration in place, without tracking gradients.

    Each step computes ``u = normalize(W v)`` followed by
    ``v = normalize(W^H u)``, writing the results into the ``_u`` and ``_v``
    buffers.

    Arguments:
      weight_mat (Tensor):
        The 2-D weight matrix ``W``.
      n_power_iterations (int):
        Number of steps to take.
    """
    assert weight_mat.ndim > 1
    for _ in range(n_power_iterations):
      self._u = torch.nn.functional.normalize(
        torch.mv(weight_mat, self._v), dim = 0, eps = self.eps, out = self._u
      )
      self._v = torch.nn.functional.normalize(
        torch.mv(weight_mat.H, self._u), dim = 0, eps = self.eps, out = self._v
      )

  def forward(self) -> Tensor:
    """
    Returns the current spectral-norm estimate of the weight.

    In training mode the power iteration is advanced by
    ``self.n_power_iterations`` steps first.

    Returns:
      Tensor:
        A 0-d tensor. For 0-d and 1-d weights it is the exact Euclidean norm;
        otherwise it is ``u^H W v`` for the tracked vectors ``u`` and ``v``.
    """
    weight = self._get_weight()
    if weight.ndim <= 1:
      # For scalars and vectors the spectral norm is the Euclidean norm.
      return torch.linalg.vector_norm(weight, ord = 2)
    weight_mat = self._reshape_weight_to_matrix(weight)
    if self.training:
      self._power_method(weight_mat, self.n_power_iterations)
    # Clone so that later in-place (``out=``) updates of the buffers in
    # ``_power_method`` do not modify tensors autograd saved for backward.
    u = self._u.clone(memory_format = torch.contiguous_format)
    v = self._v.clone(memory_format = torch.contiguous_format)
    # ``vdot`` conjugates ``u`` (u^H W v), consistent with the ``W^H`` used in
    # the power iteration; for real weights it is identical to ``dot``.
    sigma = torch.vdot(u, torch.mv(weight_mat, v))
    return sigma

def spectral_regularizer(
    # Arguments:
    module: Module,
    # Keyword Arguments:
    name:               str           = "weight",
    n_power_iterations: int           = 1,
    eps:                float         = 1e-12,
    dim:                Optional[int] = None
  ) -> Module:
  """
  Builds a spectral-norm regularizer for one weight of ``module``.

  Arguments:
    module (Module):
      The module that owns the weight.

  Keyword Arguments:
    name (str):
      Attribute name of the weight on ``module``. Defaults to ``"weight"``.
    n_power_iterations (int):
      Power-iteration steps taken per training-mode call. Must be > 0.
    eps (float):
      Epsilon used when normalizing the singular-vector estimates.
    dim (Optional[int]):
      Weight dimension used as the rows of the reshaped matrix. When
      ``None``, ``1`` is used for ``ConvTranspose{1,2,3}d`` modules and ``0``
      for everything else.

  Returns:
    Module:
      A ``_SpectralRegularizer`` whose ``forward()`` returns the current
      spectral-norm estimate.

  Raises:
    AssertionError:
      (When assertions are enabled) if ``module`` has no ``Tensor``
      attribute called ``name``.
  """
  weight = getattr(module, name, None)
  assert isinstance(weight, Tensor), \
    f"{module} does not have a parameter or buffer called '{name}'."
  if dim is None:
    # ConvTranspose weights are laid out as (in_channels,
    # out_channels / groups, ...), so dim 1 indexes the output channels.
    transposed_convs = (
      torch.nn.ConvTranspose1d,
      torch.nn.ConvTranspose2d,
      torch.nn.ConvTranspose3d,
    )
    dim = 1 if isinstance(module, transposed_convs) else 0
  # Keyword arguments guard against the constructor's (dim, eps) order
  # differing from this function's (eps, dim) order.
  return _SpectralRegularizer(
    module,
    name,
    n_power_iterations = n_power_iterations,
    dim = dim,
    eps = eps,
  )