################################################################################
# spectral/modules/prefixsums/model.py
#
# 
# 
# 
# 2024
#

import torch

from typing   import Callable, Optional, Tuple, Type, Union
from warnings import warn

from deepthinking.models.deep_thinking_recall import DeepThinkingRecall
from experilog.logger                         import JSONType
from modules.lipschitz_constraint             import lipschitz_constraint
from modules.spectral_constraint              import spectral_norm_constraint

# Short aliases used for base classes and type annotations throughout this
# file.
Module = torch.nn.Module
Tensor = torch.Tensor

# Defaults.
# Module-level fallbacks for the optional keyword arguments documented on the
# classes and the factory function below (an argument left as ``None`` is
# documented to fall back to the constant of the matching name).
# ACTIVATION is an activation *class* (it is instantiated with no arguments),
# not an instance.
ACTIVATION:               Type[Module] = torch.nn.ReLU
BIAS:                     bool         = False
INPUT_CHANNELS:           int          = 12
MAX_ITERATIONS:           int          = 30
ONES_CHANNEL:             bool         = False
OUTPUT_CHANNELS:          int          = 2
PARAMETRIC_SKIP:          bool         = False
USE_BATCHNORM:            bool         = False
USE_INCREMENTAL_PROGRESS: bool         = False

# Registry of the constraint names accepted by the ``constraint`` argument of
# ``PrefixSumsThought``. Each value is called with a single convolution layer
# and its result is stored in a ``torch.nn.ModuleList``.
CONSTRAINTS = {
  "lipschitz": lipschitz_constraint,
  "spectral": spectral_norm_constraint
}

class PrefixSumsInput(Module):
  """
  Input module for the prefix sums problem.

  The input needs no transformation for this problem, so the module simply
  wraps a ``torch.nn.Identity`` layer.
  """

  def __init__(self, *args, **kwargs) -> None:
    """
    Initializes ``PrefixSumsInput``.

    Args:
      *args:
        Positional arguments forwarded to ``torch.nn.Module``.
      **kwargs:
        Keyword arguments forwarded to ``torch.nn.Module``.
    """
    super().__init__(*args, **kwargs)
    # The only layer of this module: a pass-through.
    self.ident = torch.nn.Identity()

  def forward(self, x: Tensor) -> Tensor:
    """
    Forward function for ``PrefixSumsInput``.

    Args:
      x (Tensor):
        The input tensor.

    Returns:
      Tensor:
        The result of passing ``x`` through the identity layer.
    """
    return self.ident(x)

class PrefixSumsPreprocessing(Module):
  """
  Preprocessing module for the prefix sums problem.

  Produces the initial feature tensor ``phi_0`` from the input module's output
  using one convolution (kernel size 3, padding 1), an optional batch norm
  and an activation.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:    Optional[Type[Module]] = None,
      bias:          Optional[bool]         = None,
      in_channels:   Optional[int]          = None,
      ones_channel:  Optional[bool]         = None,
      use_batchnorm: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsPreprocessing``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsThought``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      ones_channel (bool, optional):
        Whether to add an extra channel of ones to the input. This effectively
        adds bias to the first layer in the preprocessing module and the recall
        layer in the thought module.
        Defaults to ``ONES_CHANNEL`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has an unexpected type.
    """
    super(PrefixSumsPreprocessing, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Ones channel.
    assert isinstance(ones_channel, bool) or ones_channel is None, \
      "ones_channel must be a boolean or None."
    self.ones_channel = ONES_CHANNEL if ones_channel is None else ones_channel
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    # The "ones channel" is realised as a bias on this convolution. If batch
    # norm is used, then ``ones_channel`` is ignored as this'll have no
    # impact. ``bias`` is left enabled as an override.
    self.conv1 = torch.nn.Conv1d(
      in_channels  = self.in_channels,
      out_channels = self.width,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias or (self.ones_channel and not self.use_batchnorm)
    )
    # ``torch.nn.Identity`` (rather than a lambda) keeps the module picklable
    # and visible to module traversal; it adds no entries to the state dict.
    self.norm1 = torch.nn.BatchNorm1d(self.width) if self.use_batchnorm \
                                                  else torch.nn.Identity()
    self.act1 = activation()

  def forward(self,
      # Arguments:
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsPreprocessing``.

    Args:
      x_tilde (Tensor):
        The input tensor provided from the input module.

    Returns:
      Tensor:
        The output tensor from the preprocessing module.
    """
    phi_0 = self.act1(self.norm1(self.conv1(x_tilde)))
    return phi_0

class PrefixSumsThought(Module):
  """
  Thought module for the prefix sums problem.

  One application of this module is one recurrence step: the previous
  features ``phi`` are combined with the recalled input ``x_tilde`` and passed
  through two residual blocks of two convolutions each.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:      Optional[Type[Module]] = None,
      bias:            Optional[bool]         = None,
      constraint:      Optional[str]          = None,
      in_channels:     Optional[int]          = None,
      ones_channel:    Optional[bool]         = None,
      parametric_skip: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsThought``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsPreprocessing``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      constraint (str, optional):
        The type of constraint to use. Must be a key of ``CONSTRAINTS``.
        Defaults to ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      ones_channel (bool, optional):
        Whether to add an extra channel of ones to the input. This effectively
        adds bias to the first layer in the preprocessing module and the recall
        layer in the thought module.
        Defaults to ``ONES_CHANNEL`` if ``None``.
      parametric_skip (bool, optional):
        Whether the skip connection should have a parameterized weight.
        Defaults to ``PARAMETRIC_SKIP`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has an unexpected type or ``constraint`` is unknown.
    """
    super(PrefixSumsThought, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # Constraint.
    assert isinstance(constraint, str) or constraint is None, \
      "constraint must be a string or None."
    if constraint is not None:
      assert constraint in CONSTRAINTS, \
        f"'{constraint}' is not a valid constraint."
    self.constraint = constraint
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Ones channel.
    assert isinstance(ones_channel, bool) or ones_channel is None, \
      "ones_channel must be a boolean or None."
    self.ones_channel = ONES_CHANNEL if ones_channel is None else ones_channel
    # Parametric skip.
    assert isinstance(parametric_skip, bool) or parametric_skip is None, \
      "parametric_skip must be a boolean or None."
    self.parametric_skip = PARAMETRIC_SKIP if parametric_skip is None \
                                           else parametric_skip
    if self.parametric_skip:
      # One gate logit per channel, broadcast over batch and length. These
      # are created before the convolutions so that seeded initialisation
      # draws random numbers in the same order as before.
      self.skip1_param = torch.nn.Parameter(torch.randn((self.width, 1)))
      self.skip2_param = torch.nn.Parameter(torch.randn((self.width, 1)))
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, recall.
    # If a constraint is used, this needs to be split into two layers. One
    # with constraints applied, the other not.
    if self.constraint is not None:
      self.conv_recall = torch.nn.Conv1d(
        in_channels  = self.in_channels,
        out_channels = self.width,
        kernel_size  = 3,
        stride       = 1,
        padding      = 1,
        bias         = self.bias or self.ones_channel
      )
      self.conv0 = torch.nn.Conv1d(
        in_channels  = self.width,
        out_channels = self.width,
        kernel_size  = 3,
        stride       = 1,
        padding      = 1,
        bias         = self.bias
      )
    else:
      self.conv_recall = torch.nn.Conv1d(
        in_channels  = self.width + self.in_channels,
        out_channels = self.width,
        kernel_size  = 3,
        stride       = 1,
        padding      = 1,
        bias         = self.bias or self.ones_channel
      )
    # Model, convolutions.
    for i in range(1, 5):
      conv = torch.nn.Conv1d(
        in_channels  = self.width,
        out_channels = self.width,
        kernel_size  = 3,
        stride       = 1,
        padding      = 1,
        bias         = self.bias
      )
      setattr(self, f"conv{i}", conv)
      setattr(self, f"act{i}", activation())
    # Model, constraints.
    # Only the layers acting on the recurrent features are constrained; the
    # recall layer (acting on the input) is left unconstrained.
    if self.constraint is not None:
      to_constrain = [self.conv0] + \
                     [getattr(self, f"conv{i}") for i in range(1, 5)]
      self.constraints = torch.nn.ModuleList(
        [CONSTRAINTS[self.constraint](x) for x in to_constrain]
      )
    else:
      # Same container type as the constrained case, just empty.
      self.constraints = torch.nn.ModuleList()

  def get_parametric_skip(self,
      # Arguments:
      skip_weight: Tensor
    ) -> Callable[[Tensor, Tensor], Tensor]:
    """
    Retrieves a function to perform the parametric skip connection.

    Args:
      skip_weight (Tensor):
        The skip parameter weight (a logit; it is passed through a sigmoid).

    Returns:
      Callable[[Tensor, Tensor], Tensor]:
        The function to perform the skip. For inputs ``x`` and ``y`` it
        returns ``(1 - p) * x + p * y`` with ``p = sigmoid(skip_weight)``.
    """
    def gated_sum(x, y):
      p = torch.sigmoid(skip_weight)
      return (1 - p) * x + p * y
    return gated_sum

  def _skip(self,
      # Arguments:
      x:          Tensor,
      y:          Tensor,
      param_name: str
    ) -> Tensor:
    """
    Applies a skip connection, looking the gate parameter up at call time.

    Resolving the parameter on every call (instead of capturing it in a
    closure at construction) keeps copies of the module bound to their own
    parameters and keeps the module picklable.

    Args:
      x (Tensor):
        The skip (identity) branch.
      y (Tensor):
        The residual branch.
      param_name (str):
        Attribute name of the gate parameter to use when the skip is
        parametric.

    Returns:
      Tensor:
        ``x + y`` for a plain skip, otherwise the gated combination.
    """
    if not self.parametric_skip:
      return x + y
    return self.get_parametric_skip(getattr(self, param_name))(x, y)

  def skip1(self,
      # Arguments:
      x: Tensor,
      y: Tensor
    ) -> Tensor:
    """
    Skip connection of the first residual block.

    Args:
      x (Tensor):
        The skip (identity) branch.
      y (Tensor):
        The residual branch.

    Returns:
      Tensor:
        The combined tensor.
    """
    return self._skip(x, y, "skip1_param")

  def skip2(self,
      # Arguments:
      x: Tensor,
      y: Tensor
    ) -> Tensor:
    """
    Skip connection of the second residual block.

    Args:
      x (Tensor):
        The skip (identity) branch.
      y (Tensor):
        The residual branch.

    Returns:
      Tensor:
        The combined tensor.
    """
    return self._skip(x, y, "skip2_param")

  def forward(self,
      phi:     Tensor,
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsThought``.

    Args:
      phi (Tensor):
        Input tensor from the previous recurrence.
      x_tilde (Tensor):
        Input tensor from the input module.

    Returns:
      Tensor:
        The output tensor from the current recurrence of the thought module.
    """
    if self.constraint is not None:
      # Split recall: constrained layer on ``phi`` plus unconstrained layer
      # on ``x_tilde``.
      t = self.conv0(phi) + self.conv_recall(x_tilde)
    else:
      # Joint recall: a single convolution over the concatenated channels.
      t = torch.cat([phi, x_tilde], dim = 1)
      t = self.conv_recall(t)
    t = self.act2(self.skip1(t, self.conv2(self.act1(self.conv1(t)))))
    t = self.act4(self.skip2(t, self.conv4(self.act3(self.conv3(t)))))
    return t

class PrefixSumsOutput(Module):
  """
  Output module for the prefix sums problem.

  Maps the final thought features to the output through three convolutions
  (kernel size 3, padding 1); the first two are followed by an optional batch
  norm and an activation, the last one is left linear.
  """

  def __init__(self,
      # Arguments:
      width:    int,
      channels: Tuple[int, int],
      *args,
      # Keyword Arguments:
      activation:    Optional[Type[Module]] = None,
      bias:          Optional[bool]         = None,
      final_bias:    Optional[bool]         = None,
      out_channels:  Optional[int]          = None,
      use_batchnorm: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsOutput``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsThought``.
      channels (Tuple[int, int]):
        The channels of the two "h" layers (except the last).
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      final_bias (bool, optional):
        Whether to use bias for the final convolutional layer.
        Defaults to ``bias`` (argument) if ``None``.
      out_channels (int, optional):
        The number of channels for output.
        Defaults to ``OUTPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If an argument has an unexpected type.
    """
    super(PrefixSumsOutput, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Channels.
    assert isinstance(channels[0], int) and isinstance(channels[1], int), \
      "channels must be a pair of integers."
    self.channels = channels
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # Final bias.
    assert isinstance(final_bias, bool) or final_bias is None, \
      "final_bias must be a boolean or None."
    self.final_bias = self.bias if final_bias is None else final_bias
    # Out channels.
    assert isinstance(out_channels, int) or out_channels is None, \
      "out_channels must be an integer or None."
    # The fallback is the *output* default, as documented above.
    self.out_channels = OUTPUT_CHANNELS if out_channels is None \
                                        else out_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation. Resolved before use so that ``None`` is never called.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    self.conv1 = torch.nn.Conv1d(
      in_channels  = self.width,
      out_channels = self.channels[0],
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    # ``torch.nn.Identity`` (rather than a lambda) keeps the module picklable;
    # it adds no entries to the state dict.
    self.norm1 = torch.nn.BatchNorm1d(self.channels[0]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act1 = activation()
    # Model, convolution 2.
    self.conv2 = torch.nn.Conv1d(
      in_channels  = self.channels[0],
      out_channels = self.channels[1],
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    self.norm2 = torch.nn.BatchNorm1d(self.channels[1]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act2 = activation()
    # Model, convolution 3. No norm or activation follows this layer.
    self.conv3 = torch.nn.Conv1d(
      in_channels  = self.channels[1],
      out_channels = self.out_channels,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.final_bias
    )

  def forward(self,
      # Arguments:
      phi_M: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsOutput``.

    Args:
      phi_M (Tensor):
        The input tensor from the final recurrence of ``PrefixSumsThought``.

    Returns:
      Tensor:
        The final output tensor.
    """
    t = self.act1(self.norm1(self.conv1(phi_M)))
    t = self.act2(self.norm2(self.conv2(t)))
    y_hat = self.conv3(t)
    return y_hat
  
def make_prefix_sums_model(
    # Arguments:
    width:      int,
    h_channels: Tuple[int, int],
    *args,
    # Keyword Arguments:
    activation:               Optional[Type[Module]] = None,
    bias:                     Optional[bool]         = None,
    constraint:               Optional[str]          = None,
    final_bias:               Optional[bool]         = None,
    in_channels:              Optional[int]          = None,
    max_iterations:           Optional[int]          = None,
    ones_channel:             Optional[bool]         = None,
    out_channels:             Optional[int]          = None,
    parametric_skip:          Optional[bool]         = None,
    use_batchnorm:            Optional[bool]         = None,
    use_incremental_progress: Optional[bool]         = None,
    get_json_dict:            bool                   = False,
    **kwargs
  ) -> Union[Module, Tuple[Module, JSONType]]:
  """
  Constructs the prefix sums deep thinking model.

  Every optional hyperparameter left as ``None`` is replaced by its default
  before any sub-module is built, so the constructed model and the returned
  JSON dictionary always describe the same, fully resolved configuration.

  Args:
    width (int):
      The width of the network, as defined in [1]. This should be the same as
      the width for ``PrefixSumsThought``.
    h_channels (Tuple[int, int]):
      The channels of the two "h" layers (except the last).
    *args:
      Additional arguments for every ``torch.nn.Module``.
    activation (torch.nn.Module (class), optional):
      The activation function.
      Defaults to ``ACTIVATION`` if ``None``.
    bias (bool, optional):
      Whether to use bias for convolutional layers.
      Defaults to ``BIAS`` if ``None``.
    constraint (str, optional):
      The type of constraint to use.
      Defaults to ``None``.
    final_bias (bool, optional):
      Whether to use bias for the final convolutional layer.
      Defaults to ``bias`` (argument) if ``None``.
    in_channels (int, optional):
      The number of channels for input.
      Defaults to ``INPUT_CHANNELS`` if ``None``.
    max_iterations (int, optional):
      The maximum number of iterations of the thought network to perform. This
      can be changed safely inside the model, and can be changed during
      inference using the ``max_iterations`` argument of the forward function.
      Defaults to ``MAX_ITERATIONS`` if ``None``.
    ones_channel (bool, optional):
      Whether to add an extra channel of ones to the input. This effectively
      adds bias to the first layer in the preprocessing module and the recall
      layer in the thought module.
      Defaults to ``ONES_CHANNEL`` if ``None``.
    out_channels (int, optional):
      The number of channels for output.
      Defaults to ``OUTPUT_CHANNELS`` if ``None``.
    parametric_skip (bool, optional):
      Whether the skip connection should have a parameterized weight.
      Defaults to ``PARAMETRIC_SKIP`` if ``None``.
    use_batchnorm (bool, optional):
      Whether to use batchnorm for layers.
      Defaults to ``USE_BATCHNORM`` if ``None``.
    use_incremental_progress (bool, optional):
      Whether to use incremental progress training, as defined in [1]. This
      only applies during training.
      Defaults to ``USE_INCREMENTAL_PROGRESS`` if ``None``.
    get_json_dict (bool, optional):
      If ``True``, this returns a dictionary (in JSON-ready format) containing
      the above parameters used for the model. This can be read to construct the
      model again with ``load_from_json_dict``.
      Defaults to ``False``.
    **kwargs:
      Additional keyword arguments for ``torch.nn.Module``.

  Returns:
    (Module | (Module, JSONType)):
      The model if ``get_json_dict`` is ``False``, otherwise a tuple containing
      both the model and the JSON-ready dictionary describing the model's
      hyperparameters.

  Raises:
    AssertionError:
      If ``get_json_dict`` is not a boolean, or if a sub-module rejects one of
      its arguments.
  """
  assert isinstance(get_json_dict, bool), \
    "get_json_dict must be a boolean."
  # Resolve every default up front. The same values are then used both to
  # build the model and to fill the JSON dictionary, so the two cannot drift
  # apart (and the record stays valid even if a default changes later).
  activation = ACTIVATION if activation is None else activation
  bias = BIAS if bias is None else bias
  final_bias = bias if final_bias is None else final_bias
  in_channels = INPUT_CHANNELS if in_channels is None else in_channels
  max_iterations = MAX_ITERATIONS if max_iterations is None else max_iterations
  ones_channel = ONES_CHANNEL if ones_channel is None else ones_channel
  out_channels = OUTPUT_CHANNELS if out_channels is None else out_channels
  parametric_skip = PARAMETRIC_SKIP if parametric_skip is None \
                                    else parametric_skip
  use_batchnorm = USE_BATCHNORM if use_batchnorm is None else use_batchnorm
  use_incremental_progress = USE_INCREMENTAL_PROGRESS \
                             if use_incremental_progress is None \
                             else use_incremental_progress
  model = DeepThinkingRecall(
    PrefixSumsInput(
      *args,
      **kwargs
    ),
    PrefixSumsPreprocessing(
      width,
      *args,
      activation    = activation,
      bias          = bias,
      in_channels   = in_channels,
      ones_channel  = ones_channel,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    PrefixSumsThought(
      width,
      *args,
      activation      = activation,
      bias            = bias,
      constraint      = constraint,
      in_channels     = in_channels,
      ones_channel    = ones_channel,
      parametric_skip = parametric_skip,
      **kwargs
    ),
    PrefixSumsOutput(
      width,
      h_channels,
      *args,
      activation    = activation,
      bias          = bias,
      final_bias    = final_bias,
      out_channels  = out_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    *args,
    max_iterations           = max_iterations,
    use_incremental_progress = use_incremental_progress,
    **kwargs
  )
  if not get_json_dict:
    return model
  # The activation is stored by class name and looked up in ``torch.nn`` when
  # loading. For almost all activation functions, this shouldn't be an issue,
  # but as a precaution a warning is given if the name does not resolve back
  # to this exact class.
  if getattr(torch.nn, activation.__name__, None) is not activation:
    warn(
      f"'{activation.__name__}' cannot be converted back into the " + \
      "torch.nn class it comes from. The name will still be recorded, " + \
      "but experiments reading this model from a file may need to have " + \
      "its activation manually set to the correct class."
    )
  json_dict = {
    "width":                    width,
    "h_channels":               list(h_channels),
    "activation":               activation.__name__,
    "bias":                     bias,
    "constraint":               constraint,
    "final_bias":               final_bias,
    "in_channels":              in_channels,
    "max_iterations":           max_iterations,
    "ones_channel":             ones_channel,
    "out_channels":             out_channels,
    "parametric_skip":          parametric_skip,
    "use_batchnorm":            use_batchnorm,
    "use_incremental_progress": use_incremental_progress,
  }
  return (model, json_dict)

def load_from_json_dict(
    # Arguments:
    json_dict: JSONType
  ) -> Module:
  """
  Rebuilds a prefix sums deep thinking model from its JSON-ready description.

  ``width``, ``h_channels`` and ``activation`` must be present; every other
  hyperparameter is optional and is passed on as ``None`` when absent.

  Args:
    json_dict (JSONType):
      The JSON-ready dictionary containing the construction hyperparameters.

  Returns:
    Module:
      The reconstructed model.
  """
  # Required entries, converted back to the types the factory expects.
  width = int(json_dict["width"])
  h_channels = tuple(int(c) for c in json_dict["h_channels"])
  # The activation is stored by name and looked up in ``torch.nn``.
  activation = getattr(torch.nn, json_dict["activation"])
  # Optional entries are forwarded unchanged.
  optional_keys = (
    "bias",
    "constraint",
    "final_bias",
    "in_channels",
    "max_iterations",
    "ones_channel",
    "out_channels",
    "parametric_skip",
    "use_batchnorm",
    "use_incremental_progress",
  )
  optional = {key: json_dict.get(key) for key in optional_keys}
  return make_prefix_sums_model(
    width         = width,
    h_channels    = h_channels,
    activation    = activation,
    get_json_dict = False,
    **optional
  )

# REFERENCES:
#
# [1] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang,
#     Micah Goldblum, and Tom Goldstein.
#     "End-to-End Algorithm Synthesis with Recurrent Networks: Logical
#      Extrapolation without Overthinking".
#     Advances in Neural Information Processing Systems. Oct 2022.
#     https://arxiv.org/abs/2202.05826