################################################################################
# spectral/modules/prefixsums/model.py
#
# 
# 
# 
# 2024
#

import torch

from typing   import Optional, Tuple, Type, Union
from warnings import warn

from deepthinking.models.deep_thinking_recall import DeepThinkingRecall
from experilog.logger                         import JSONType

# Short aliases used in annotations and as base classes throughout this file.
Module = torch.nn.Module
Tensor = torch.Tensor

# Defaults.
# Each value below is substituted when the matching keyword argument of a
# module constructor or of ``make_prefix_sums_model`` is left as ``None``.
# Activation class (instantiated by the modules, not an instance).
ACTIVATION:               Type[Module] = torch.nn.ReLU
# Whether the convolutional layers carry a bias term.
BIAS:                     bool         = False
# Number of channels of the problem input.
INPUT_CHANNELS:           int          = 1
# Maximum number of iterations handed to ``DeepThinkingRecall``.
MAX_ITERATIONS:           int          = 30
# Number of channels of the model output.
OUTPUT_CHANNELS:          int          = 2
# Whether ``BatchNorm1d`` layers are inserted.
USE_BATCHNORM:            bool         = False
# Whether incremental progress training (see [1]) is requested.
USE_INCREMENTAL_PROGRESS: bool         = False

class PrefixSumsInput(Module):
  """
  Input module for the prefix sums problem.

  Passes the input through an identity layer followed by an optional
  ``BatchNorm1d`` over the input channels.
  """

  def __init__(self,
      # Arguments:
      *args,
      # Keyword Arguments:
      in_channels:   Optional[int]  = None,
      use_batchnorm: Optional[bool] = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsInput``.

    Args:
      *args:
        Additional arguments for ``torch.nn.Module``.
      in_channels (int, optional):
        The number of channels for input. Needed for norm.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If ``in_channels`` is not an integer or ``None``, or if
        ``use_batchnorm`` is not a boolean or ``None``.
    """
    super(PrefixSumsInput, self).__init__(*args, **kwargs)
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    self.ident = torch.nn.Identity()
    # The no-op fallback is an ``Identity`` module rather than a lambda: a
    #  lambda stored on the module cannot be pickled (which breaks saving the
    #  whole model) and is not registered as a submodule. ``Identity`` has no
    #  parameters, so the state dict is unaffected.
    self.norm1 = torch.nn.BatchNorm1d(self.in_channels) if self.use_batchnorm \
                                                        else torch.nn.Identity()

  def forward(self,
      # Arguments:
      x: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsInput``.

    Args:
      x (Tensor):
        The input tensor.

    Returns:
      Tensor:
        The input module's output tensor.
    """
    x_tilde = self.norm1(self.ident(x))
    return x_tilde

class PrefixSumsPreprocessing(Module):
  """
  Preprocessing module for the prefix sums problem.

  Applies one width-preserving 1D convolution (kernel size 3, padding 1),
  an optional ``BatchNorm1d`` and the activation function.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:    Optional[Type[Module]] = None,
      bias:          Optional[bool]         = None,
      in_channels:   Optional[int]          = None,
      use_batchnorm: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsPreprocessing``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsThought``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If ``width`` is not an integer, or if ``bias``, ``in_channels`` or
        ``use_batchnorm`` has the wrong type.
    """
    super(PrefixSumsPreprocessing, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    self.conv1 = torch.nn.Conv1d(
      in_channels  = self.in_channels,
      out_channels = self.width,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    # The no-op fallback is an ``Identity`` module rather than a lambda: a
    #  lambda stored on the module cannot be pickled (which breaks saving the
    #  whole model) and is not registered as a submodule.
    self.norm1 = torch.nn.BatchNorm1d(self.width) if self.use_batchnorm \
                                                  else torch.nn.Identity()
    self.act1 = activation()

  def forward(self,
      # Arguments:
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsPreprocessing``.

    Args:
      x_tilde (Tensor):
        The input tensor provided from the input module.

    Returns:
      Tensor:
        The output tensor from the preprocessing module.
    """
    phi_0 = self.act1(self.norm1(self.conv1(x_tilde)))
    return phi_0

class PrefixSumsThought(Module):
  """
  Recurrent "thought" block of the prefix sums model.

  Every call concatenates the previous features with the (recalled) input
  along the channel axis and applies one convolution plus the activation.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:  Optional[Type[Module]] = None,
      bias:        Optional[bool]         = None,
      in_channels: Optional[int]          = None,
      **kwargs
    ) -> None:
    """
    Builds the recall convolution and the activation.

    Args:
      width (int):
        Network width as defined in [1]; must match the width given to
        ``PrefixSumsPreprocessing``.
      *args:
        Forwarded positionally to ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        Activation class to instantiate. ``None`` selects ``ACTIVATION``.
      bias (bool, optional):
        Whether the convolution has a bias. ``None`` selects ``BIAS``.
      in_channels (int, optional):
        Channel count of the recalled input. ``None`` selects
        ``INPUT_CHANNELS``.
      **kwargs:
        Forwarded as keywords to ``torch.nn.Module``.
    """
    super().__init__(*args, **kwargs)
    # Validate in the order width -> bias -> in_channels.
    assert isinstance(width, int), "width must be an integer."
    assert bias is None or isinstance(bias, bool), \
      "bias must be a boolean or None."
    assert in_channels is None or isinstance(in_channels, int), \
      "in_channels must be an integer or None."
    # Store the resolved hyperparameters.
    self.width = width
    if bias is None:
      bias = BIAS
    self.bias = bias
    if in_channels is None:
      in_channels = INPUT_CHANNELS
    self.in_channels = in_channels
    if activation is None:
      activation = ACTIVATION
    # The recall convolution sees the features and the input stacked together.
    stacked_channels = self.width + self.in_channels
    self.conv_recall = torch.nn.Conv1d(
      stacked_channels,
      self.width,
      kernel_size = 5,
      stride      = 1,
      padding     = 2,
      bias        = self.bias
    )
    self.act1 = activation()

  def forward(self,
      phi:     Tensor,
      x_tilde: Tensor
    ) -> Tensor:
    """
    Runs one recurrence of the thought block.

    Args:
      phi (Tensor):
        Features produced by the previous recurrence.
      x_tilde (Tensor):
        Output of the input module, recalled at every step.

    Returns:
      Tensor:
        Features produced by this recurrence.
    """
    stacked = torch.cat((phi, x_tilde), dim = 1)
    return self.act1(self.conv_recall(stacked))

class PrefixSumsOutput(Module):
  """
  Output module for the prefix sums problem.

  Maps the final thought features to the output through three stacked
  width-preserving 1D convolutions (kernel size 3, padding 1). The first two
  are each followed by an optional ``BatchNorm1d`` and the activation; the
  last one produces the output directly.
  """

  def __init__(self,
      # Arguments:
      width:    int,
      channels: Tuple[int, int],
      *args,
      # Keyword Arguments:
      activation:    Optional[Type[Module]] = None,
      bias:          Optional[bool]         = None,
      out_channels:  Optional[int]          = None,
      use_batchnorm: Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``PrefixSumsOutput``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``PrefixSumsThought``.
      channels (Tuple[int, int]):
        The channels of the two "h" layers (except the last).
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      out_channels (int, optional):
        The number of channels for output.
        Defaults to ``OUTPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.

    Raises:
      AssertionError:
        If ``width`` is not an integer, ``channels`` is not a pair of
        integers, or ``bias``, ``out_channels`` or ``use_batchnorm`` has the
        wrong type.
    """
    super(PrefixSumsOutput, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Channels.
    assert isinstance(channels[0], int) and isinstance(channels[1], int), \
      "channels must be a pair of integers."
    self.channels = channels
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # Out channels. The fallback is the *output* default, as documented.
    assert isinstance(out_channels, int) or out_channels is None, \
      "out_channels must be an integer or None."
    self.out_channels = OUTPUT_CHANNELS if out_channels is None \
                                        else out_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation. Must be resolved before it is instantiated below.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    self.conv1 = torch.nn.Conv1d(
      in_channels  = self.width,
      out_channels = self.channels[0],
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    # The no-op fallbacks are ``Identity`` modules rather than lambdas: a
    #  lambda stored on the module cannot be pickled (which breaks saving the
    #  whole model) and is not registered as a submodule.
    self.norm1 = torch.nn.BatchNorm1d(self.channels[0]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act1 = activation()
    # Model, convolution 2.
    self.conv2 = torch.nn.Conv1d(
      in_channels  = self.channels[0],
      out_channels = self.channels[1],
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )
    self.norm2 = torch.nn.BatchNorm1d(self.channels[1]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act2 = activation()
    # Model, convolution 3.
    self.conv3 = torch.nn.Conv1d(
      in_channels  = self.channels[1],
      out_channels = self.out_channels,
      kernel_size  = 3,
      stride       = 1,
      padding      = 1,
      bias         = self.bias
    )

  def forward(self,
      # Arguments:
      phi_M: Tensor
    ) -> Tensor:
    """
    Forward function for ``PrefixSumsOutput``.

    Args:
      phi_M (Tensor):
        The input tensor from the final recurrence of ``PrefixSumsThought``.

    Returns:
      Tensor:
        The final output tensor.
    """
    t = self.act1(self.norm1(self.conv1(phi_M)))
    # The second layer consumes the first layer's output, not ``phi_M``.
    t = self.act2(self.norm2(self.conv2(t)))
    y_hat = self.conv3(t)
    return y_hat
  
def make_prefix_sums_model(
    # Arguments:
    width:      int,
    h_channels: Tuple[int, int],
    *args,
    # Keyword Arguments:
    activation:               Optional[Type[Module]] = None,
    bias:                     Optional[bool]         = None,
    in_channels:              Optional[int]          = None,
    max_iterations:           Optional[int]          = None,
    out_channels:             Optional[int]          = None,
    use_batchnorm:            Optional[bool]         = None,
    use_incremental_progress: Optional[bool]         = None,
    get_json_dict:            bool                   = False,
    **kwargs
  ) -> Union[Module, Tuple[Module, JSONType]]:
  """
  Constructs the prefix sums deep thinking model.

  Args:
    width (int):
      The width of the network, as defined in [1]. It is shared by the
      preprocessing, thought and output modules.
    h_channels (Tuple[int, int]):
      The channels of the two "h" layers (except the last).
    *args:
      Additional arguments for every ``torch.nn.Module``.
    activation (torch.nn.Module (class), optional):
      The activation function.
      Defaults to ``ACTIVATION`` if ``None``.
    bias (bool, optional):
      Whether to use bias for convolutional layers.
      Defaults to ``BIAS`` if ``None``.
    in_channels (int, optional):
      The number of channels for input.
      Defaults to ``INPUT_CHANNELS`` if ``None``.
    max_iterations (int, optional):
      The maximum number of iterations of the thought network to perform. This
      can be changed safely inside the model, and can be changed during
      inference using the ``max_iterations`` argument of the forward function.
      Defaults to ``MAX_ITERATIONS`` if ``None``.
    out_channels (int, optional):
      The number of channels for output.
      Defaults to ``OUTPUT_CHANNELS`` if ``None``.
    use_batchnorm (bool, optional):
      Whether to use batchnorm for layers.
      Defaults to ``USE_BATCHNORM`` if ``None``.
    use_incremental_progress (bool, optional):
      Whether to use incremental progress training, as defined in [1]. This
      only applies during training.
      Defaults to ``USE_INCREMENTAL_PROGRESS`` if ``None``.
    get_json_dict (bool, optional):
      If ``True``, this returns a dictionary (in JSON-ready format) containing
      the above parameters used for the model. This can be read to construct the
      model again with ``load_from_json_dict``.
      Defaults to ``False``.
    **kwargs:
      Additional keyword arguments for ``torch.nn.Module``.

  Returns:
    (Module | (Module, JSONType)):
      The model if ``get_json_dict`` is ``False``, otherwise a tuple containing
      both the model and the JSON-ready dictionary describing the model's
      hyperparameters.

  Raises:
    AssertionError:
      If ``get_json_dict`` is not a boolean, or if a sub-module rejects one of
      the hyperparameters.
  """
  assert isinstance(get_json_dict, bool), \
    "get_json_dict must be a boolean."
  # Resolve every default up front so that the sub-modules are built from
  #  exactly the values that are recorded in the JSON dictionary.
  activation = ACTIVATION if activation is None else activation
  bias = BIAS if bias is None else bias
  in_channels = INPUT_CHANNELS if in_channels is None else in_channels
  max_iterations = MAX_ITERATIONS if max_iterations is None else max_iterations
  out_channels = OUTPUT_CHANNELS if out_channels is None else out_channels
  use_batchnorm = USE_BATCHNORM if use_batchnorm is None else use_batchnorm
  use_incremental_progress = USE_INCREMENTAL_PROGRESS \
                             if use_incremental_progress is None \
                             else use_incremental_progress
  model = DeepThinkingRecall(
    PrefixSumsInput(
      *args,
      in_channels   = in_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    PrefixSumsPreprocessing(
      width,
      *args,
      activation    = activation,
      bias          = bias,
      in_channels   = in_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    PrefixSumsThought(
      width,
      *args,
      activation  = activation,
      bias        = bias,
      in_channels = in_channels,
      **kwargs
    ),
    PrefixSumsOutput(
      width,
      h_channels,
      *args,
      activation    = activation,
      bias          = bias,
      out_channels  = out_channels,
      use_batchnorm = use_batchnorm,
      **kwargs
    ),
    *args,
    max_iterations           = max_iterations,
    use_incremental_progress = use_incremental_progress,
    **kwargs
  )
  if not get_json_dict:
    return model
  # For almost all activation functions, this shouldn't be an issue, but
  #  as a precaution a warning is given if it isn't possible to load the
  #  correct function from a string, i.e. if looking the name up in
  #  ``torch.nn`` does not give back this very class.
  activation_name = activation.__name__
  if getattr(torch.nn, activation_name, None) is not activation:
    warn(
      f"'{activation_name}' cannot be converted back into the "
      "torch.nn class it comes from. The name will still be recorded, "
      "but experiments reading this model from a file may need to have "
      "its activation manually set to the correct class."
    )
  # Default values are recorded too, so the dictionary is self-contained.
  json_dict = {
    "width":                    width,
    "h_channels":               list(h_channels),
    "activation":               activation_name,
    "bias":                     bias,
    "in_channels":              in_channels,
    "max_iterations":           max_iterations,
    "out_channels":             out_channels,
    "use_batchnorm":            use_batchnorm,
    "use_incremental_progress": use_incremental_progress
  }
  return (model, json_dict)

def fill_with_nones(
    # Arguments:
    json_dict: JSONType
  ) -> JSONType:
  """
  Takes a JSON-ready dictionary and replaces missing values with ``None``.
  This is needed as changes to the model classes have happened after some
  config files were saved, so older files may lack newer hyperparameters.

  The input dictionary is not modified; a shallow copy is returned.

  Args:
    json_dict (JSONType):
      The JSON-ready dictionary containing the construction hyperparameters.

  Returns:
    JSONType:
      The dictionary with missing values replaced with ``None``.
  """
  # The hyperparameter names written by ``make_prefix_sums_model``.
  expected_keys = (
    "width",
    "h_channels",
    "activation",
    "bias",
    "in_channels",
    "max_iterations",
    "out_channels",
    "use_batchnorm",
    "use_incremental_progress"
  )
  filled = dict(json_dict)
  for key in expected_keys:
    # Only absent keys are touched; existing values are kept as they are.
    filled.setdefault(key, None)
  return filled

def load_from_json_dict(
    # Arguments:
    json_dict: JSONType
  ) -> Module:
  """
  Constructs a prefix sums deep thinking model from a JSON-ready dictionary.

  ``width`` and ``h_channels`` are required. Every other hyperparameter may
  be missing or ``None`` (as in config files saved before it existed), in
  which case the default of ``make_prefix_sums_model`` is used.

  Args:
    json_dict (JSONType):
      The JSON-ready dictionary containing the construction hyperparameters.

  Returns:
    Module:
      The reconstructed model.

  Raises:
    KeyError:
      If ``width`` or ``h_channels`` is missing.
    AttributeError:
      If the recorded activation name is not a member of ``torch.nn``.
  """
  def _optional(key, convert):
    # Converts the entry if present; ``None`` lets the factory default apply.
    value = json_dict.get(key)
    return None if value is None else convert(value)

  # The values are passed one by one (rather than with ``**json_dict``) so
  #  that each is coerced to the type the model expects.
  return make_prefix_sums_model(
    width                    = int(json_dict["width"]),
    h_channels               = tuple(int(x) for x in json_dict["h_channels"]),
    activation               = _optional(
      "activation", lambda name: getattr(torch.nn, name)
    ),
    bias                     = _optional("bias", bool),
    in_channels              = _optional("in_channels", int),
    max_iterations           = _optional("max_iterations", int),
    out_channels             = _optional("out_channels", int),
    use_batchnorm            = _optional("use_batchnorm", bool),
    use_incremental_progress = _optional("use_incremental_progress", bool),
    get_json_dict            = False
  )

# REFERENCES:
#
# [1] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang,
#     Micah Goldblum, and Tom Goldstein.
#     "End-to-End Algorithm Synthesis with Recurrent Networks: Extrapolation
#      without Overthinking".
#     Advances in Neural Information Processing Systems. Oct 2022.
#     https://arxiv.org/abs/2202.05826