################################################################################
# spectral/modules/prefixsums/model.py
#
# 
# 
# 
# 2024
#

import torch

from typing   import Callable, Optional, Tuple, Type, Union
from warnings import warn

from deepthinking.models.deep_thinking_recall import DeepThinkingRecall
from experilog.logger                         import JSONType
from modules.spectral_set                     import set_spectral_norm

# Short aliases used by the annotations in this file.
Module = torch.nn.Module
Tensor = torch.Tensor

# Fallback hyperparameters. Each one is substituted when the argument of the
#  same (lowercase) name is passed as ``None``.

# Activation class; it is instantiated (called with no arguments) per layer.
ACTIVATION: Type[Module] = torch.nn.ReLU
# Whether convolutions carry a bias term.
BIAS: bool = False
# Channel count of the model input.
INPUT_CHANNELS: int = 1
# Iteration cap handed to the recurrent wrapper.
MAX_ITERATIONS: int = 30
# Channel count of the model output.
OUTPUT_CHANNELS: int = 2
# Whether skip connections use a learned mixing weight.
PARAMETRIC_SKIP: bool = False
# Whether batch normalisation layers are inserted.
USE_BATCHNORM: bool = False
# Whether incremental progress training is enabled.
USE_INCREMENTAL_PROGRESS: bool = False

class PrefixSumsInput(Module):
  """
  Input module for the prefix sums problem.

  The input is passed through an identity layer and then through an optional
  ``BatchNorm1d`` over ``in_channels``.
  """

  def __init__(self,
      # Arguments:
      *args,
      # Keyword Arguments:
      in_channels:   Optional[int]  = None,
      use_batchnorm: Optional[bool] = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsInput``.

    Args:
      *args:
        Additional arguments for ``torch.nn.Module``.
      in_channels (int, optional):
        The number of channels for input. Needed for norm.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has the wrong type.
    """
    super().__init__(*args, **kwargs)
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    self.ident = torch.nn.Identity()
    # A disabled norm is an ``Identity`` module rather than a lambda: it is a
    #  registered submodule, can be pickled, and adds no state dict entries.
    self.norm1 = torch.nn.BatchNorm1d(self.in_channels) if self.use_batchnorm \
                                                        else torch.nn.Identity()

  def forward(self,
      # Arguments:
      x: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsInput``.

    Args:
      x (Tensor):
        The input tensor.

    Returns:
      Tensor:
        The input module's output tensor.
    """
    x_tilde = self.norm1(self.ident(x))
    return x_tilde

class PrefixSumsPreprocessing(Module):
  """
  Preprocessing module for the prefix sums problem.

  Applies one width-3 convolution from ``in_channels`` to ``width`` channels,
  an optional ``BatchNorm1d``, and the activation.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:    Optional[Type[Module]] = None,
      bias:          Optional[bool]         = None,
      in_channels:   Optional[int]          = None,
      use_batchnorm: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsPreprocessing``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsThought``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has the wrong type.
    """
    super().__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    self.conv1 = torch.nn.Conv1d(
      in_channels  = self.in_channels,
      out_channels = self.width,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    # A disabled norm is an ``Identity`` module rather than a lambda: it is a
    #  registered submodule, can be pickled, and adds no state dict entries.
    self.norm1 = torch.nn.BatchNorm1d(self.width) if self.use_batchnorm \
                                                  else torch.nn.Identity()
    self.act1 = activation()

  def forward(self,
      # Arguments:
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsPreprocessing``.

    Args:
      x_tilde (Tensor):
        The input tensor provided from the input module.

    Returns:
      Tensor:
        The output tensor from the preprocessing module.
    """
    phi_0 = self.act1(self.norm1(self.conv1(x_tilde)))
    return phi_0

class PrefixSumsThought(Module):
  """
  Thought module for the prefix sums problem.

  Each call concatenates the previous features with the input along dimension
  1, applies the recall convolution, and then applies two skip-connected pairs
  of width-3 convolutions.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:      Optional[Type[Module]] = None,
      bias:            Optional[bool]         = None,
      in_channels:     Optional[int]          = None,
      parametric_skip: Optional[bool]         = None,
      recall_norm:     Optional[float]        = None,
      spectral_norm:   Optional[float]        = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsThought``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsPreprocessing``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      parametric_skip (bool, optional):
        Whether the skip connection should have a parameterized weight.
        Defaults to ``PARAMETRIC_SKIP`` if ``None``.
      recall_norm (float, optional):
        The value to set the spectral norm of the recall layer to.
        Defaults to ``None``, which matches the choice of ``spectral_norm``.
      spectral_norm (float, optional):
        The value to set the spectral norm to.
        Defaults to ``None``, which disables any spectral norm setting.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has the wrong type.
    """
    super().__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Parametric skip. The default is resolved here so that the attribute is
    #  always a boolean, as documented.
    assert isinstance(parametric_skip, bool) or parametric_skip is None, \
      "parametric_skip must be a boolean or None."
    self.parametric_skip = PARAMETRIC_SKIP if parametric_skip is None \
                                           else parametric_skip
    if self.parametric_skip:
      # ``skip1`` and ``skip2`` look these up on ``self`` at call time, so a
      #  copy of this module mixes with its own parameters rather than with
      #  the ones captured when the original was constructed.
      self.skip1_weight = torch.nn.Parameter(torch.randn((1,)))
      self.skip2_weight = torch.nn.Parameter(torch.randn((1,)))
    # Spectral norm setting.
    assert isinstance(spectral_norm, float) or spectral_norm is None, \
      "spectral_norm must be a float or None."
    self.spectral_norm = spectral_norm
    assert isinstance(recall_norm, float) or recall_norm is None, \
      "recall_norm must be a float or None."
    self.recall_norm = recall_norm
    if self.spectral_norm is None:
      spectral_fn = lambda x: x
    else:
      # The target value is read from ``self`` whenever the inner function is
      #  called, not frozen at construction.
      spectral_fn = lambda x: set_spectral_norm(
        x,
        lambda weight, sigma: self.spectral_norm * (weight / sigma)
      )
    if self.recall_norm is None:
      # No separate recall value: the recall layer follows ``spectral_norm``.
      recall_fn = spectral_fn
    else:
      recall_fn = lambda x: set_spectral_norm(
        x,
        lambda weight, sigma: self.recall_norm * (weight / sigma)
      )
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, recall. Takes the features concatenated with the input.
    self.conv_recall = recall_fn(torch.nn.Conv1d(
      in_channels  = self.width + self.in_channels,
      out_channels = self.width,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    ))
    # Model, convolutions. Registers ``conv1``..``conv4`` and
    #  ``act1``..``act4``.
    for i in range(1, 5):
      setattr(self, f"conv{i}", spectral_fn(torch.nn.Conv1d(
        in_channels  = self.width,
        out_channels = self.width,
        kernel_size  = 3,
        stride       = 1,
        padding      = 1,
        bias         = self.bias
      )))
      setattr(self, f"act{i}", activation())

  def get_parametric_skip(self,
      # Arguments:
      skip_weight: Tensor
    ) -> Callable[[Tensor, Tensor], Tensor]:
    """
    Retrieves a function to perform the parametric skip connection.

    Args:
      skip_weight (Tensor):
        The skip parameter weight.

    Returns:
      Callable[[Tensor, Tensor], Tensor]:
        The function to perform the skip. With ``p = sigmoid(skip_weight)`` it
        returns ``(1 - p) * x + p * y``.
    """
    def temp(x, y):
      p = torch.sigmoid(skip_weight)
      return (1 - p) * x + p * y
    return temp

  def skip1(self,
      # Arguments:
      x: Tensor,
      y: Tensor
    ) -> Tensor:
    """
    Performs the first skip connection.

    Args:
      x (Tensor):
        The tensor carried around the skipped layers.
      y (Tensor):
        The output of the skipped layers.

    Returns:
      Tensor:
        The mix weighted by ``skip1_weight`` if ``parametric_skip`` is set,
        otherwise ``x + y``.
    """
    if self.parametric_skip:
      return self.get_parametric_skip(self.skip1_weight)(x, y)
    return x + y

  def skip2(self,
      # Arguments:
      x: Tensor,
      y: Tensor
    ) -> Tensor:
    """
    Performs the second skip connection.

    Args:
      x (Tensor):
        The tensor carried around the skipped layers.
      y (Tensor):
        The output of the skipped layers.

    Returns:
      Tensor:
        The mix weighted by ``skip2_weight`` if ``parametric_skip`` is set,
        otherwise ``x + y``.
    """
    if self.parametric_skip:
      return self.get_parametric_skip(self.skip2_weight)(x, y)
    return x + y

  def forward(self,
      phi:     Tensor,
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsThought``.

    Args:
      phi (Tensor):
        Input tensor from the previous recurrence.
      x_tilde (Tensor):
        Input tensor from the input module.

    Returns:
      Tensor:
        The output tensor from the current recurrence of the thought module.
    """
    # Recall: the input is concatenated back in on every recurrence.
    t = torch.cat([phi, x_tilde], dim = 1)
    t = self.conv_recall(t)
    # Two blocks of conv -> act -> conv, each wrapped in a skip connection.
    t = self.act2(self.skip1(t, self.conv2(self.act1(self.conv1(t)))))
    t = self.act4(self.skip2(t, self.conv4(self.act3(self.conv3(t)))))
    return t

class PrefixSumsOutput(Module):
  """
  Output module for the prefix sums problem.

  Applies three width-3 convolutions going from ``width`` through the two
  ``channels`` sizes to ``out_channels``. The first two are each followed by
  an optional ``BatchNorm1d`` and the activation; the last one is not.
  """

  def __init__(self,
      # Arguments:
      width:    int,
      channels: Tuple[int, int],
      *args,
      # Keyword Arguments:
      activation:    Optional[Type[Module]] = None,
      bias:          Optional[bool]         = None,
      final_bias:    Optional[bool]         = None,
      out_channels:  Optional[int]          = None,
      use_batchnorm: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsOutput``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsThought``.
      channels (Tuple[int, int]):
        The channels of the two "h" layers (except the last).
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      final_bias (bool, optional):
        Whether to use bias for the final convolutional layer.
        Defaults to ``bias`` (argument) if ``None``.
      out_channels (int, optional):
        The number of channels for output.
        Defaults to ``OUTPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has the wrong type.
    """
    super().__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Channels.
    assert isinstance(channels[0], int) and isinstance(channels[1], int), \
      "channels must be a pair of integers."
    self.channels = channels
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # Final bias.
    assert isinstance(final_bias, bool) or final_bias is None, \
      "final_bias must be a boolean or None."
    self.final_bias = self.bias if final_bias is None else final_bias
    # Out channels. The fallback is the output default, not the input one.
    assert isinstance(out_channels, int) or out_channels is None, \
      "out_channels must be an integer or None."
    self.out_channels = OUTPUT_CHANNELS if out_channels is None \
                                        else out_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation. Without this fallback, ``activation()`` below would be
    #  called on ``None``.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    self.conv1 = torch.nn.Conv1d(
      in_channels  = self.width,
      out_channels = self.channels[0],
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    # A disabled norm is an ``Identity`` module rather than a lambda: it is a
    #  registered submodule, can be pickled, and adds no state dict entries.
    self.norm1 = torch.nn.BatchNorm1d(self.channels[0]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act1 = activation()
    # Model, convolution 2.
    self.conv2 = torch.nn.Conv1d(
      in_channels  = self.channels[0],
      out_channels = self.channels[1],
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    self.norm2 = torch.nn.BatchNorm1d(self.channels[1]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act2 = activation()
    # Model, convolution 3. No norm or activation follows this layer.
    self.conv3 = torch.nn.Conv1d(
      in_channels  = self.channels[1],
      out_channels = self.out_channels,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.final_bias
    )

  def forward(self,
      # Arguments:
      phi_M: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsOutput``.

    Args:
      phi_M (Tensor):
        The input tensor from the final recurrence of ``PrefixSumsThought``.

    Returns:
      Tensor:
        The final output tensor.
    """
    t = self.act1(self.norm1(self.conv1(phi_M)))
    t = self.act2(self.norm2(self.conv2(t)))
    y_hat = self.conv3(t)
    return y_hat
  
def make_prefix_sums_model(
    # Arguments:
    width:      int,
    h_channels: Tuple[int, int],
    *args,
    # Keyword Arguments:
    activation:               Optional[Type[Module]] = None,
    bias:                     Optional[bool]         = None,
    final_bias:               Optional[bool]         = None,
    in_channels:              Optional[int]          = None,
    max_iterations:           Optional[int]          = None,
    out_channels:             Optional[int]          = None,
    parametric_skip:          Optional[bool]         = None,
    recall_norm:              Optional[float]        = None,
    spectral_norm:            Optional[float]        = None,
    use_batchnorm:            Optional[bool]         = None,
    use_incremental_progress: Optional[bool]         = None,
    get_json_dict:            bool                   = False,
    **kwargs
  ) -> Union[Module, Tuple[Module, JSONType]]:
  """
  Constructs the prefix sums deep thinking model.

  The input, preprocessing, thought and output modules are built and handed
  to ``DeepThinkingRecall``. Every hyperparameter that has a default is
  resolved once at the top, so the modules and the optional JSON record are
  built from the same values.

  Args:
    width (int):
      The width of the network, as defined in [1]. It is shared by the
      preprocessing, thought and output modules.
    h_channels (Tuple[int, int]):
      The channels of the two "h" layers (except the last).
    *args:
      Additional arguments for every ``torch.nn.Module``.
    activation (torch.nn.Module (class), optional):
      The activation function.
      Defaults to ``ACTIVATION`` if ``None``.
    bias (bool, optional):
      Whether to use bias for convolutional layers.
      Defaults to ``BIAS`` if ``None``.
    final_bias (bool, optional):
      Whether to use bias for the final convolutional layer.
      Defaults to ``bias`` (argument) if ``None``.
    in_channels (int, optional):
      The number of channels for input.
      Defaults to ``INPUT_CHANNELS`` if ``None``.
    max_iterations (int, optional):
      The maximum number of iterations of the thought network to perform. This
      can be changed safely inside the model, and can be changed during
      inference using the ``max_iterations`` argument of the forward function.
      Defaults to ``MAX_ITERATIONS`` if ``None``.
    out_channels (int, optional):
      The number of channels for output.
      Defaults to ``OUTPUT_CHANNELS`` if ``None``.
    parametric_skip (bool, optional):
      Whether the skip connection should have a parameterized weight.
      Defaults to ``PARAMETRIC_SKIP`` if ``None``.
    recall_norm (float, optional):
      The value to set the spectral norm of the recall layer to.
      Defaults to ``None``, which matches the choice of ``spectral_norm``.
    spectral_norm (float, optional):
      The value to set the spectral norm to.
      Defaults to ``None``, which disables any spectral norm setting.
    use_batchnorm (bool, optional):
      Whether to use batchnorm for layers.
      Defaults to ``USE_BATCHNORM`` if ``None``.
    use_incremental_progress (bool, optional):
      Whether to use incremental progress training, as defined in [1]. This
      only applies during training.
      Defaults to ``USE_INCREMENTAL_PROGRESS`` if ``None``.
    get_json_dict (bool, optional):
      If ``True``, this returns a dictionary (in JSON-ready format) containing
      the above parameters used for the model. This can be read to construct the
      model again with ``load_from_json_dict``.
      Defaults to ``False``.
    **kwargs:
      Additional keyword arguments for ``torch.nn.Module``.

  Returns:
    (Module | (Module, JSONType)):
      The model if ``get_json_dict`` is ``False``, otherwise a tuple containing
      both the model and the JSON-ready dictionary describing the model's
      hyperparameters.

  Raises:
    AssertionError:
      If ``get_json_dict`` is not a boolean, or if a module rejects the type
      of one of its arguments.
  """
  assert isinstance(get_json_dict, bool), \
    "get_json_dict must be a boolean."
  # Resolve defaults up front. ``recall_norm`` and ``spectral_norm`` are left
  #  alone because ``None`` is itself a meaningful setting for them.
  activation = ACTIVATION if activation is None else activation
  bias = BIAS if bias is None else bias
  final_bias = bias if final_bias is None else final_bias
  in_channels = INPUT_CHANNELS if in_channels is None else in_channels
  max_iterations = MAX_ITERATIONS if max_iterations is None else max_iterations
  out_channels = OUTPUT_CHANNELS if out_channels is None else out_channels
  parametric_skip = PARAMETRIC_SKIP if parametric_skip is None \
                                    else parametric_skip
  use_batchnorm = USE_BATCHNORM if use_batchnorm is None else use_batchnorm
  use_incremental_progress = USE_INCREMENTAL_PROGRESS \
                             if use_incremental_progress is None \
                             else use_incremental_progress
  model = DeepThinkingRecall(
    PrefixSumsInput(
      *args,
      in_channels   = in_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    PrefixSumsPreprocessing(
      width,
      *args,
      activation    = activation,
      bias          = bias,
      in_channels   = in_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    PrefixSumsThought(
      width,
      *args,
      activation      = activation,
      bias            = bias,
      in_channels     = in_channels,
      parametric_skip = parametric_skip,
      recall_norm     = recall_norm,
      spectral_norm   = spectral_norm,
      **kwargs
    ),
    PrefixSumsOutput(
      width,
      h_channels,
      *args,
      activation    = activation,
      bias          = bias,
      final_bias    = final_bias,
      out_channels  = out_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    *args,
    max_iterations           = max_iterations,
    use_incremental_progress = use_incremental_progress,
    **kwargs
  )
  if not get_json_dict:
    return model
  # ``load_from_json_dict`` looks the activation up by name in ``torch.nn``,
  #  so warn unless that lookup yields exactly this class. This also covers a
  #  custom class that merely shares its name with a ``torch.nn`` one.
  activation_name = activation.__name__
  if getattr(torch.nn, activation_name, None) is not activation:
    warn(
      f"'{activation_name}' cannot be converted back into the " + \
      "torch.nn class it comes from. The name will still be recorded, " + \
      "but experiments reading this model from a file may need to have " + \
      "its activation manually set to the correct class."
    )
  # Resolved values are recorded, so the record does not depend on what the
  #  module-level defaults are when it is read back.
  json_dict = {
    "width":                    width,
    "h_channels":               list(h_channels),
    "activation":               activation_name,
    "bias":                     bias,
    "final_bias":               final_bias,
    "in_channels":              in_channels,
    "max_iterations":           max_iterations,
    "out_channels":             out_channels,
    "parametric_skip":          parametric_skip,
    "recall_norm":              recall_norm,
    "spectral_norm":            spectral_norm,
    "use_batchnorm":            use_batchnorm,
    "use_incremental_progress": use_incremental_progress
  }
  return (model, json_dict)

def load_from_json_dict(
    # Arguments:
    json_dict: JSONType
  ) -> Module:
  """
  Rebuilds a prefix sums deep thinking model from its JSON-ready description.

  Args:
    json_dict (JSONType):
      The hyperparameter dictionary. ``"width"``, ``"h_channels"`` and
      ``"activation"`` must be present; any other entry may be absent, in
      which case ``None`` is passed on for it.

  Returns:
    Module:
      The reconstructed model.
  """
  # Entries that are read unconditionally and converted before use.
  net_width = int(json_dict["width"])
  hidden_channels = tuple(int(x) for x in json_dict["h_channels"])
  # The activation is stored by name and looked up in ``torch.nn``.
  activation_class = getattr(torch.nn, json_dict["activation"])
  # Entries that are forwarded as they are, or as ``None`` when missing.
  forwarded_keys = (
    "bias",
    "final_bias",
    "in_channels",
    "max_iterations",
    "out_channels",
    "parametric_skip",
    "recall_norm",
    "spectral_norm",
    "use_batchnorm",
    "use_incremental_progress",
  )
  forwarded = {key: json_dict.get(key) for key in forwarded_keys}
  return make_prefix_sums_model(
    width         = net_width,
    h_channels    = hidden_channels,
    activation    = activation_class,
    get_json_dict = False,
    **forwarded
  )

# REFERENCES:
#
# [1] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang,
#     Micah Goldblum, and Tom Goldstein.
#     "End-to-End Algorithm Synthesis with Recurrent Networks: Logical
#      Extrapolation without Overthinking".
#     Advances in Neural Information Processing Systems. Oct 2022.
#     https://arxiv.org/abs/2202.05826