################################################################################
# spectral/modules/tsp/model.py
#
# 
# 
# 
# 2024
#

import torch

from math           import sqrt
from numpy          import ones_like
from scipy.optimize import linear_sum_assignment
from scipy.sparse   import coo_matrix
from typing         import Any, Callable, Dict, List, Optional, Tuple, Type, \
                           Union
from warnings       import warn

from deepthinking.models.deep_thinking_recall import DeepThinkingRecall
from experilog.logger                         import JSONType
from modules.lipschitz_constraint             import lipschitz_constraint
from modules.nonlocalblock                    import NonLocalBlock
from modules.spectral_constraint              import spectral_norm_constraint

# Short aliases for the torch types used in annotations throughout this file.
Module = torch.nn.Module
Tensor = torch.Tensor

# Defaults.
# Most of these are substituted for ``None`` constructor arguments by the
# module classes and ``make_tsp_model`` below.
ACTIVATION:               Type[Module] = torch.nn.ReLU
BIAS:                     bool         = False
INPUT_CHANNELS:           int          = 3
MAX_ITERATIONS:           int          = 30
ONES_CHANNEL:             bool         = False
ORTHOGONAL:               int          = 0
ORTHO_PHI_CHANNELS:       int          = 4
OUTPUT_CHANNELS:          int          = 1
# Documented default for ``TSPThought``'s ``parametric_skip`` argument.
PARAMETRIC_SKIP:          bool         = False
# Temperature and iteration count passed to ``gumbel_sinkhorn`` by
# ``TSPOutput.forward`` in training mode.
SINKHORN_TAU:             float        = 0.1
SINKHORN_N:               int          = 10
# NOTE(review): not referenced by ``TSPOutput.forward``, which calls
# ``gumbel_sinkhorn`` without ``noise`` (so the default of 1.0 is used) —
# confirm whether this was meant to be passed through.
SINKHORN_NOISE:           float        = 2.0
USE_BATCHNORM:            bool         = False
USE_INCREMENTAL_PROGRESS: bool         = False

# Maps the constraint names accepted by ``TSPThought`` to factories that wrap
# one of its convolution layers; the results are stored in a ModuleList.
CONSTRAINTS: Dict[str, Callable[[Module], Module]] = {
  "lipschitz": lipschitz_constraint,
  "spectral": spectral_norm_constraint
}

# Orthogonal enums.
# The enum value doubles as the number of orthogonal channels carried at the
# front of the recurrent state (see the channel arithmetic in the modules).
NONE_ORTHOGONAL: int = 0
LEFT_ORTHOGONAL: int = 1
BOTH_ORTHOGONAL: int = 2

def log_sinkhorn(
    # Arguments:
    log_p: Tensor,
    # Keyword Arguments:
    n_iter: int = 1
  ) -> Tensor:
  """
  Applies the Sinkhorn normalization operator in log space.

  Each iteration first normalizes the last dimension (rows) and then the
  second-to-last dimension (columns), so after the final iteration the
  columns of the returned matrix sum exactly to one.

  Args:
    log_p (Tensor):
      Log-probabilities (unnormalized scores) over the path matrix.
    n_iter (int, optional):
      How many row/column normalization rounds to run.
      Defaults to ``1``.

  Returns:
    Tensor:
      ``exp`` of the normalized log-matrix.
  """
  scores = log_p
  for _round in range(n_iter):
    row_lse = scores.logsumexp(dim = -1, keepdim = True)
    scores = scores - row_lse
    col_lse = scores.logsumexp(dim = -2, keepdim = True)
    scores = scores - col_lse
  return scores.exp()

def gumbel_sinkhorn(
    # Arguments:
    log_p:  Tensor,
    tau:    float,
    n_iter: int,
    # Keyword Arguments:
    eps:   float = 1e-20,
    noise: float = 1.0
  ) -> Tensor:
  """
  Samples a soft permutation matrix with the Gumbel-Sinkhorn operator.

  Gumbel noise (scaled by ``noise``) is added to ``log_p``, the result is
  divided by the temperature ``tau``, and ``log_sinkhorn`` is applied.

  Args:
    log_p (Tensor):
      Log-probabilities (unnormalized scores) for each path.
    tau (float):
      Temperature; smaller values push the output toward a hard permutation.
    n_iter (int):
      Number of Sinkhorn iterations.
    eps (float, optional):
      Small constant keeping both logarithms finite.
      Defaults to ``1e-20``.
    noise (float, optional):
      Multiplier applied to the Gumbel noise.
      Defaults to ``1.0``.

  Returns:
    Tensor:
      The (approximately doubly-stochastic) matrix from Gumbel-Sinkhorn.
  """
  # Uniform samples in [0, 1), drawn as float32 on the same device as log_p.
  uniform = torch.rand(log_p.shape, device = log_p.device)
  # Standard Gumbel sample: -log(-log(U)), with eps guarding both logs.
  inner = -torch.log(uniform + eps) + eps
  gumbel = -torch.log(inner)
  perturbed = log_p + noise * gumbel
  return log_sinkhorn(perturbed / tau, n_iter)

def linear_assignment(
    log_p: Tensor
  ) -> Tensor:
  """
  Performs linear assignment on the log probabilities.

  Finds the assignment of rows to columns that maximizes the summed
  log-probabilities (by minimizing ``-log_p`` with
  ``scipy.optimize.linear_sum_assignment``) and returns it as a 0/1 matrix.

  Args:
    log_p (Tensor):
      The log of the probabilities, as a 2-D tensor or NumPy array. Tensors
      are detached and moved to the CPU before solving.

  Returns:
    Tensor:
      An integer matrix with the same shape as ``log_p`` that contains ``1``
      at every assigned ``(row, column)`` pair and ``0`` elsewhere.
  """
  if isinstance(log_p, torch.Tensor):
    # SciPy needs a plain CPU array; grad-tracking or CUDA tensors would fail.
    log_p = log_p.detach().cpu().numpy()
  rows, cols = linear_sum_assignment(-log_p)
  # Pass the shape explicitly: without it coo_matrix infers the shape from the
  # largest indices, which truncates rectangular inputs.
  perm = coo_matrix(
    (ones_like(rows), (rows, cols)),
    shape = log_p.shape
  ).toarray()
  return torch.from_numpy(perm)

def make_orthogonal(
    # Arguments:
    x: Tensor
  ) -> Tensor:
  """
  Projects matrices onto the nearest orthogonal matrices.

  Computes the SVD ``x = U S V^T`` over the last two dimensions and returns
  ``U V^T``, the orthogonal polar factor of ``x``. Leading dimensions are
  treated as batch dimensions.

  Args:
    x (Tensor):
      A tensor of shape ``(..., m, n)``.

  Returns:
    Tensor:
      A tensor of shape ``(..., m, n)`` with orthonormal columns (if
      ``m >= n``) or orthonormal rows (if ``m < n``).
  """
  # Reduced SVD: identical result for square inputs, and keeps U @ V^T
  # well-defined (and cheaper) for rectangular ones.
  u, _, vt = torch.linalg.svd(x, full_matrices = False)
  return torch.matmul(u, vt)

class TSPInput(Module):
  """
  Input module for TSP.

  Returns the input with one extra channel appended: the transpose (over the
  last two dimensions) of the input's first channel.
  """

  def __init__(self,
      # Arguments:
      *args,
      # Keyword Arguments:
      **kwargs
    ) -> None:
    """
    Initializes ``TSPInput``.

    Args:
      *args:
        Extra positional arguments forwarded to ``torch.nn.Module``.
      **kwargs:
        Extra keyword arguments forwarded to ``torch.nn.Module``.
    """
    super().__init__(*args, **kwargs)
    # Pass-through submodule applied to the input before augmentation.
    self.ident = torch.nn.Identity()

  def forward(self,
      # Arguments:
      x: Tensor
    ) -> Tensor:
    """
    Forward function for ``TSPInput``.

    Args:
      x (Tensor):
        The input tensor; presumably ``(batch, channels, N, N)`` — TODO
        confirm against the data pipeline.

    Returns:
      Tensor:
        ``x`` concatenated along dim 1 with the transpose of channel 0.
    """
    passthrough = self.ident(x)
    first_channel = passthrough[:, [0]]
    transposed = first_channel.transpose(-2, -1)
    return torch.cat([passthrough, transposed], dim = 1)

class TSPPreprocessing(Module):
  """
  Preprocessing module for TSP.

  Maps the output of ``TSPInput`` to the initial recurrent state ``phi_0``
  with a single 3x3 convolution. When orthogonal transforms are enabled, the
  first ``orthogonal`` output channels are projected to orthogonal matrices
  and kept at the front of ``phi_0``.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:         Optional[Type[Module]] = None,
      bias:               Optional[bool]         = None,
      in_channels:        Optional[int]          = None,
      ones_channel:       Optional[bool]         = None,
      orthogonal:         Optional[int]          = None,
      ortho_phi_channels: Optional[int]          = None,
      use_batchnorm:      Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``TSPPreprocessing``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``TSPThought``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      ones_channel (bool, optional):
        Emulates an extra input channel of ones. In this module it enables the
        bias of the convolution (unless batchnorm is used, since a bias before
        batchnorm is redundant).
        Defaults to ``ONES_CHANNEL`` if ``None``.
      orthogonal (int, optional):
        The orthogonal enum (0 = none, 1 = left, 2 = both) for orthogonal
        transforms. The value is also the number of orthogonal channels.
        Defaults to ``ORTHOGONAL`` if ``None``.
      ortho_phi_channels (int, optional):
        The number of (additional) channels in phi where the orthogonal
        transformations are applied to.
        Defaults to ``ORTHO_PHI_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.
    """
    super(TSPPreprocessing, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Ones channel.
    assert isinstance(ones_channel, bool) or ones_channel is None, \
      "ones_channel must be a boolean or None."
    self.ones_channel = ONES_CHANNEL if ones_channel is None else ones_channel
    # Orthogonal.
    assert isinstance(orthogonal, int) or orthogonal is None, \
      "orthogonal must be an integer or None."
    self.orthogonal = ORTHOGONAL if orthogonal is None else orthogonal
    assert NONE_ORTHOGONAL <= self.orthogonal <= BOTH_ORTHOGONAL, \
      "orthogonal must be in the enum range [0, 2]."
    # Orthogonal phi channels.
    assert isinstance(ortho_phi_channels, int) or ortho_phi_channels is None, \
      "ortho_phi_channels must be an integer or None."
    self.ortho_phi_channels = ORTHO_PHI_CHANNELS if ortho_phi_channels is None \
                                                 else ortho_phi_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    # Output = orthogonal channels (if any) followed by the width +
    # ortho_phi_channels state channels.
    self.conv1 = torch.nn.Conv2d(
      in_channels  = self.in_channels,
      out_channels = self.width + self.orthogonal + self.ortho_phi_channels,
      kernel_size  = (3, 3),
      stride       = (1, 1),
      padding      = 1,
      bias         = self.bias or (self.ones_channel and not self.use_batchnorm)
    )
    # Normalization only covers the non-orthogonal channels. A real Identity
    # module (rather than a lambda) keeps the module picklable.
    self.norm1 = torch.nn.BatchNorm2d(self.width + self.ortho_phi_channels) \
                 if self.use_batchnorm else torch.nn.Identity()
    self.act1 = activation()

  def forward(self,
      # Arguments:
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``TSPPreprocessing``.

    Args:
      x_tilde (Tensor):
        The input tensor provided from the input module.

    Returns:
      Tensor:
        The initial state ``phi_0``: the orthogonal channels (if enabled),
        followed by the normalized and activated remaining channels.
    """
    t = self.conv1(x_tilde)
    if self.orthogonal > NONE_ORTHOGONAL:
      # Orthogonal channels bypass normalization and activation.
      q, phi_0 = t[:, :self.orthogonal], t[:, self.orthogonal:]
      q        = make_orthogonal(q)
      phi_0    = torch.cat([q, self.act1(self.norm1(phi_0))], dim = 1)
    else:
      phi_0 = self.act1(self.norm1(t))
    return phi_0

class TSPThought(Module):
  """
  Thought module for TSP.

  One recurrence of the network: the previous state ``phi`` is combined with
  the input ``x_tilde`` by a "recall" convolution, then two residual blocks
  of two 3x3 convolutions each are applied. With orthogonal transforms
  enabled, orthogonal channels carried at the front of the state are used to
  mix part of the state and are updated at the end of each recurrence.
  """

  def __init__(self,
      # Arguments:
      width: int,
      *args,
      # Keyword Arguments:
      activation:         Optional[Type[Module]]   = None,
      bias:               Optional[bool]           = None,
      constraint:         Optional[str]            = None,
      in_channels:        Optional[int]            = None,
      ones_channel:       Optional[bool]           = None,
      orthogonal:         Optional[int]            = None,
      ortho_phi_channels: Optional[int]            = None,
      parametric_skip:    Optional[bool]           = None,
      **kwargs
    ) -> None:
    """
    Initializes ``TSPThought``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``TSPPreprocessing``.
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      constraint (str, optional):
        The name of a constraint in ``CONSTRAINTS`` applied to ``conv0`` through
        ``conv4``. ``None`` disables constraints.
        Defaults to ``None``.
      in_channels (int, optional):
        The number of channels for input.
        Defaults to ``INPUT_CHANNELS`` if ``None``.
      ones_channel (bool, optional):
        Emulates an extra input channel of ones. In this module it enables the
        bias of the recall convolution.
        Defaults to ``ONES_CHANNEL`` if ``None``.
      orthogonal (int, optional):
        The orthogonal enum (0 = none, 1 = left, 2 = both) for orthogonal
        transforms. The value is also the number of orthogonal channels.
        Defaults to ``ORTHOGONAL`` if ``None``.
      ortho_phi_channels (int, optional):
        The number of (additional) channels in phi where the orthogonal
        transformations are applied to.
        Defaults to ``ORTHO_PHI_CHANNELS`` if ``None``.
      parametric_skip (bool, optional):
        Whether the skip connection should have a parameterized weight.
        Defaults to ``PARAMETRIC_SKIP`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.
    """
    super(TSPThought, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # Constraint.
    assert isinstance(constraint, str) or constraint is None, \
      "constraint must be a string or None."
    if constraint is not None:
      assert constraint in CONSTRAINTS, \
        f"'{constraint}' is not a valid constraint."
    self.constraint = constraint
    # In channels.
    assert isinstance(in_channels, int) or in_channels is None, \
      "in_channels must be an integer or None."
    self.in_channels = INPUT_CHANNELS if in_channels is None else in_channels
    # Ones channel.
    assert isinstance(ones_channel, bool) or ones_channel is None, \
      "ones_channel must be a boolean or None."
    self.ones_channel = ONES_CHANNEL if ones_channel is None else ones_channel
    # Orthogonal.
    assert isinstance(orthogonal, int) or orthogonal is None, \
      "orthogonal must be an integer or None."
    self.orthogonal = ORTHOGONAL if orthogonal is None else orthogonal
    assert NONE_ORTHOGONAL <= self.orthogonal <= BOTH_ORTHOGONAL, \
      "orthogonal must be in the enum range [0, 2]."
    # Orthogonal phi channels.
    assert isinstance(ortho_phi_channels, int) or ortho_phi_channels is None, \
      "ortho_phi_channels must be an integer or None."
    self.ortho_phi_channels = ORTHO_PHI_CHANNELS if ortho_phi_channels is None \
                                                 else ortho_phi_channels
    # Number of non-orthogonal state channels.
    full_width = self.width + self.ortho_phi_channels
    # Parametric skip.
    assert isinstance(parametric_skip, bool) or parametric_skip is None, \
      "parametric_skip must be a boolean or None."
    # Resolve the documented default rather than storing ``None``.
    self.parametric_skip = PARAMETRIC_SKIP if parametric_skip is None \
                                           else parametric_skip
    if self.parametric_skip:
      # Per-channel mixing logits; the skips themselves are the ``skip1`` /
      # ``skip2`` methods, which read these parameters at call time.
      self.skip1_param = torch.nn.Parameter(torch.randn((full_width, 1, 1)))
      self.skip2_param = torch.nn.Parameter(torch.randn((full_width, 1, 1)))
    # Activation.
    activation = ACTIVATION if activation is None else activation
    # Model, recall.
    # If a constraint is used, this needs to be split into two layers. One
    # with constraints applied, the other not.
    if self.constraint is not None:
      self.conv_recall = torch.nn.Conv2d(
        in_channels  = self.in_channels,
        out_channels = full_width,
        kernel_size  = (3, 3),
        stride       = (1, 1),
        padding      = 1,
        bias         = self.bias or self.ones_channel
      )
      self.conv0 = torch.nn.Conv2d(
        in_channels  = full_width + self.orthogonal,
        out_channels = full_width,
        kernel_size  = (3, 3),
        stride       = (1, 1),
        padding      = 1,
        bias         = self.bias
      )
    else:
      self.conv_recall = torch.nn.Conv2d(
        in_channels  = full_width + self.orthogonal + self.in_channels,
        out_channels = full_width,
        kernel_size  = (3, 3),
        stride       = (1, 1),
        padding      = 1,
        bias         = self.bias or self.ones_channel
      )
    # Model, convolutions.
    for i in range(1, 5):
      conv = torch.nn.Conv2d(
        in_channels  = full_width,
        out_channels = full_width,
        kernel_size  = (3, 3),
        stride       = (1, 1),
        padding      = 1,
        bias         = self.bias
      )
      setattr(self, f"conv{i}", conv)
      setattr(self, f"act{i}", activation())
    # Model, constraints.
    if self.constraint is not None:
      to_constrain = [self.conv0] + \
                     [getattr(self, f"conv{i}") for i in range(1, 5)]
      self.constraints = torch.nn.ModuleList(
        [CONSTRAINTS[self.constraint](x) for x in to_constrain]
      )
    else:
      self.constraints = []
    # Model, orthogonal:
    if self.orthogonal > NONE_ORTHOGONAL:
      self.conv_ortho = torch.nn.Conv2d(
        in_channels  = full_width,
        out_channels = self.orthogonal,
        kernel_size  = (1, 1),
        stride       = (1, 1),
        padding      = 0,
        bias         = True
      )

  @staticmethod
  def _blend_skip(
      # Arguments:
      x:           Tensor,
      y:           Tensor,
      skip_weight: Tensor
    ) -> Tensor:
    """
    Blends ``x`` and ``y`` as ``(1 - sigmoid(w)) * x + sigmoid(w) * y``.
    """
    p = torch.sigmoid(skip_weight)
    return (1 - p) * x + p * y

  def get_parametric_skip(self,
      # Arguments:
      skip_weight: Tensor
    ) -> Callable[[Tensor, Tensor], Tensor]:
    """
    Retrieves a function to perform the parametric skip connection.

    Kept for backward compatibility; ``forward`` uses ``skip1``/``skip2``.

    Args:
      skip_weight (Tensor):
        The skip parameter weight.

    Returns:
      Callable[[Tensor, Tensor], Tensor]:
        The function to perform the skip.
    """
    def temp(x, y):
      return TSPThought._blend_skip(x, y, skip_weight)
    return temp

  def skip1(self,
      # Arguments:
      x: Tensor,
      y: Tensor
    ) -> Tensor:
    """
    First residual connection: a plain sum, or a learned per-channel blend
    using ``skip1_param`` when ``parametric_skip`` is enabled.

    Implemented as a method (not a stored closure) so that copies and replicas
    of the module use their own parameters and the module stays picklable.
    """
    if self.parametric_skip:
      return self._blend_skip(x, y, self.skip1_param)
    return x + y

  def skip2(self,
      # Arguments:
      x: Tensor,
      y: Tensor
    ) -> Tensor:
    """
    Second residual connection: a plain sum, or a learned per-channel blend
    using ``skip2_param`` when ``parametric_skip`` is enabled.
    """
    if self.parametric_skip:
      return self._blend_skip(x, y, self.skip2_param)
    return x + y

  def forward(self,
      phi:     Tensor,
      x_tilde: Tensor
    ) -> Tensor:
    """
    Forward function for ``TSPThought``.

    Args:
      phi (Tensor):
        Input tensor from the previous recurrence (orthogonal channels first,
        when enabled).
      x_tilde (Tensor):
        Input tensor from the input module.

    Returns:
      Tensor:
        The output tensor from the current recurrence of the thought module.
    """
    if self.orthogonal > NONE_ORTHOGONAL:
      q, phi = phi[:, :self.orthogonal], phi[:, self.orthogonal:]
      phi_a, phi_b = phi[:, :self.ortho_phi_channels], \
                     phi[:,  self.ortho_phi_channels:]
      if self.orthogonal == LEFT_ORTHOGONAL:
        # Left-multiply the ortho_phi channels by q (matmul over the last two
        # dims, broadcast over channels).
        phi = torch.cat([phi_b, q, torch.matmul(q, phi_a)], dim = 1)
      else:
        # Two-sided transform: q1 @ phi_a @ q2^T.
        q1, q2 = q[:, [0]], q[:, [1]]
        phi = torch.cat(
          [phi_b, q, torch.matmul(q1, torch.matmul(phi_a, q2.mT))],
          dim = 1
        )
    if self.constraint is not None:
      # Only conv0 is constrained; the recall path stays unconstrained.
      t = self.conv0(phi) + self.conv_recall(x_tilde)
    else:
      t = torch.cat([phi, x_tilde], dim = 1)
      t = self.conv_recall(t)
    t = self.act2(self.skip1(t, self.conv2(self.act1(self.conv1(t)))))
    t = self.act4(self.skip2(t, self.conv4(self.act3(self.conv3(t)))))
    if self.orthogonal > NONE_ORTHOGONAL:
      # Update the orthogonal channels from the new state and re-project.
      q = torch.tanh(q + self.conv_ortho(t))
      q = make_orthogonal(q)
      t = torch.cat([q, t], dim = 1)
    return t

class TSPOutput(Module):
  """
  Output module for TSP.

  Maps the final recurrent state to a single-channel score matrix. In
  training mode it returns a Gumbel-Sinkhorn soft permutation; in eval mode
  it returns a hard assignment from linear sum assignment.
  """

  def __init__(self,
      # Arguments:
      width:    int,
      channels: Tuple[int, int],
      *args,
      # Keyword Arguments:
      activation:         Optional[Type[Module]] = None,
      bias:               Optional[bool]         = None,
      final_bias:         Optional[bool]         = None,
      orthogonal:         Optional[int]          = None,
      ortho_phi_channels: Optional[int]          = None,
      out_channels:       Optional[int]          = None,
      use_batchnorm:      Optional[bool]         = None,
      **kwargs
    ) -> None:
    """
    Initializes ``TSPOutput``.

    Args:
      width (int):
        The width of the network, as defined in [1]. This should be the same as
        the width for ``TSPThought``.
      channels (Tuple[int, int]):
        The channels of the two "h" layers (except the last).
      *args:
        Additional arguments for ``torch.nn.Module``.
      activation (torch.nn.Module (class), optional):
        The activation function.
        Defaults to ``ACTIVATION`` if ``None``.
      bias (bool, optional):
        Whether to use bias for convolutional layers.
        Defaults to ``BIAS`` if ``None``.
      final_bias (bool, optional):
        Whether to use bias for the final convolutional layer.
        Defaults to ``bias`` (argument) if ``None``.
      orthogonal (int, optional):
        The orthogonal enum (0 = none, 1 = left, 2 = both) for orthogonal
        transforms. The value is also the number of orthogonal channels that
        are dropped from the input state.
        Defaults to ``ORTHOGONAL`` if ``None``.
      ortho_phi_channels (int, optional):
        The number of (additional) channels in phi where the orthogonal
        transformations are applied to.
        Defaults to ``ORTHO_PHI_CHANNELS`` if ``None``.
      out_channels (int, optional):
        The number of channels for output.
        Defaults to ``OUTPUT_CHANNELS`` if ``None``.
      use_batchnorm (bool, optional):
        Whether to use batchnorm for layers.
        Defaults to ``USE_BATCHNORM`` if ``None``.
      **kwargs:
        Additional keyword arguments for ``torch.nn.Module``.
    """
    super(TSPOutput, self).__init__(*args, **kwargs)
    # Width.
    assert isinstance(width, int), \
      "width must be an integer."
    self.width = width
    # Channels.
    assert isinstance(channels[0], int) and isinstance(channels[1], int), \
      "channels must be a pair of integers."
    self.channels = channels
    # Bias.
    assert isinstance(bias, bool) or bias is None, \
      "bias must be a boolean or None."
    self.bias = BIAS if bias is None else bias
    # Final bias.
    assert isinstance(final_bias, bool) or final_bias is None, \
      "final_bias must be a boolean or None."
    self.final_bias = self.bias if final_bias is None else final_bias
    # Orthogonal.
    assert isinstance(orthogonal, int) or orthogonal is None, \
      "orthogonal must be an integer or None."
    self.orthogonal = ORTHOGONAL if orthogonal is None else orthogonal
    assert NONE_ORTHOGONAL <= self.orthogonal <= BOTH_ORTHOGONAL, \
      "orthogonal must be in the enum range [0, 2]."
    # Orthogonal phi channels.
    assert isinstance(ortho_phi_channels, int) or ortho_phi_channels is None, \
      "ortho_phi_channels must be an integer or None."
    self.ortho_phi_channels = ORTHO_PHI_CHANNELS if ortho_phi_channels is None \
                                                 else ortho_phi_channels
    # Out channels.
    assert isinstance(out_channels, int) or out_channels is None, \
      "out_channels must be an integer or None."
    self.out_channels = OUTPUT_CHANNELS if out_channels is None \
                        else out_channels
    # Batch norm.
    assert isinstance(use_batchnorm, bool) or use_batchnorm is None, \
      "use_batchnorm must be a boolean or None."
    self.use_batchnorm = USE_BATCHNORM if use_batchnorm is None \
                                       else use_batchnorm
    # Activation. Without this default, ``activation()`` below would fail
    # whenever the caller leaves ``activation`` as ``None``.
    activation = ACTIVATION if activation is None else activation
    # Model, convolution 1.
    self.conv1 = torch.nn.Conv2d(
      in_channels  = self.width + self.ortho_phi_channels,
      out_channels = self.channels[0],
      kernel_size  = (3, 3),
      stride       = (1, 1),
      padding      = 1,
      bias         = self.bias
    )
    # Identity modules (rather than lambdas) keep the module picklable.
    self.norm1 = torch.nn.BatchNorm2d(self.channels[0]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act1 = activation()
    # Model, convolution 2.
    self.conv2 = torch.nn.Conv2d(
      in_channels  = self.channels[0],
      out_channels = self.channels[1],
      kernel_size  = (3, 3),
      stride       = (1, 1),
      padding      = 1,
      bias         = self.bias
    )
    self.norm2 = torch.nn.BatchNorm2d(self.channels[1]) if self.use_batchnorm \
                                                        else torch.nn.Identity()
    self.act2 = activation()
    # Model, convolution 3.
    # This layer has no padding, as padding is added in a custom way.
    self.conv3 = torch.nn.Conv2d(
      in_channels  = self.channels[1],
      out_channels = self.out_channels,
      kernel_size  = (3, 3),
      stride       = (1, 1),
      padding      = 0,
      bias         = self.final_bias
    )

  def forward(self,
      # Arguments:
      phi_M: Tensor
    ) -> Tensor:
    """
    Forward function for ``TSPOutput``.

    Args:
      phi_M (Tensor):
        The input tensor from the final recurrence of ``TSPThought``.

    Returns:
      Tensor:
        In training mode, the Gumbel-Sinkhorn output; otherwise a float 0/1
        assignment matrix of shape ``(batch, 1, H, W)`` built from output
        channel 0 and placed on the input's device.
    """
    if self.orthogonal > NONE_ORTHOGONAL:
      # Drop the orthogonal channels carried at the front of the state.
      phi_M = phi_M[:, self.orthogonal:]
    t = self.act1(self.norm1(self.conv1(phi_M)))
    t = self.act2(self.norm2(self.conv2(t)))
    # Custom pad to cut leftmost and topmost parts: padding only the right and
    # bottom by one means the unpadded 3x3 convolution below shrinks each
    # spatial dimension by one.
    t = torch.nn.functional.pad(t, (0, 1, 0, 1))
    t = self.conv3(t)
    # If training, add Gumbel noise to sample and use Sinkhorn to generate a
    # doubly-stochastic and (hopefully) permutation matrix. Otherwise use
    # linear sum assignment.
    # NOTE(review): SINKHORN_NOISE is not passed here, so gumbel_sinkhorn
    # uses its default noise scale — confirm this is intended.
    if self.training:
      t = gumbel_sinkhorn(t, SINKHORN_TAU, SINKHORN_N)
    else:
      # Solved per sample on the CPU; only output channel 0 is used.
      t = torch.stack([
        linear_assignment(_t[0])
        for _t in t.cpu().detach().numpy()
      ]).unsqueeze(1).float().to(t.device)
    return t

def make_tsp_model(
    # Arguments:
    width:      int,
    h_channels: Tuple[int, int],
    *args,
    # Keyword Arguments:
    activation:               Optional[Type[Module]]   = None,
    bias:                     Optional[bool]           = None,
    constraint:               Optional[str]            = None,
    final_bias:               Optional[bool]           = None,
    in_channels:              Optional[int]            = None,
    max_iterations:           Optional[int]            = None,
    ones_channel:             Optional[bool]           = None,
    orthogonal:               Optional[int]            = None,
    ortho_phi_channels:       Optional[int]            = None,
    out_channels:             Optional[int]            = None,
    parametric_skip:          Optional[bool]           = None,
    use_batchnorm:            Optional[bool]           = None,
    use_incremental_progress: Optional[bool]           = None,
    get_json_dict:            bool                     = False,
    **kwargs
  ) -> Union[Module, Tuple[Module, JSONType]]:
  """
  Constructs TSP deep thinking model.

  Args:
    width (int):
      The width of the network, as defined in [1]. This should be the same as
      the width for ``TSPThought``.
    h_channels (Tuple[int, int]):
      The channels of the two "h" layers (except the last).
    *args:
      Additional arguments for every ``torch.nn.Module``.
    activation (torch.nn.Module (class), optional):
      The activation function.
      Defaults to ``ACTIVATION`` if ``None``.
    bias (bool, optional):
      Whether to use bias for convolutional layers.
      Defaults to ``BIAS`` if ``None``.
    constraint (str, optional):
      The type of constraint to use.
      Defaults to ``None``.
    final_bias (bool, optional):
      Whether to use bias for the final convolutional layer.
      Defaults to ``bias`` (argument) if ``None``.
    in_channels (int, optional):
      The number of channels for input.
      Defaults to ``INPUT_CHANNELS`` if ``None``.
    max_iterations (int, optional):
      The maximum number of iterations of the thought network to perform. This
      can be changed safely inside the model, and can be changed during
      inference using the ``max_iterations`` argument of the forward function.
      Defaults to ``MAX_ITERATIONS`` if ``None``.
    ones_channel (bool, optional):
      Whether to add an extra channel of ones to the input. This effectively
      adds bias to the first layer in the preprocessing module and the recall
      layer in the thought module.
      Defaults to ``ONES_CHANNEL`` if ``None``.
    orthogonal (int, optional):
      The orthogonal enum (0 = none, 1 = left, 2 = both) for orthogonal
      transforms.
      Defaults to ``ORTHOGONAL`` if ``None``.
    ortho_phi_channels (int, optional):
      The number of (additional) channels in phi where the orthogonal
      transformations are applied to.
      Defaults to ``ORTHO_PHI_CHANNELS`` if ``None``.
    out_channels (int, optional):
      The number of channels for output.
      Defaults to ``OUTPUT_CHANNELS`` if ``None``.
    parametric_skip (bool, optional):
      Whether the skip connection should have a parameterized weight.
      Defaults to ``PARAMETRIC_SKIP`` if ``None``.
    use_batchnorm (bool, optional):
      Whether to use batchnorm for layers.
      Defaults to ``USE_BATCHNORM`` if ``None``.
    use_incremental_progress (bool, optional):
      Whether to use incremental progress training, as defined in [1]. This
      only applies during training.
      Defaults to ``USE_INCREMENTAL_PROGRESS`` if ``None``.
    get_json_dict (bool, optional):
      If ``True``, this returns a dictionary (in JSON-ready format) containing
      the above parameters used for the model, with every defaulted value
      resolved to the concrete value that was used. This can be read to
      construct the model again with ``load_from_json_dict``.
      Defaults to ``False``.
    **kwargs:
      Additional keyword arguments for ``torch.nn.Module``.

  Returns:
    (Module | (Module, JSONType)):
      The model if ``get_json_dict`` is ``False``, otherwise a tuple containing
      both the model and the JSON-ready dictionary describing the model's
      hyperparameters.
  """
  assert isinstance(get_json_dict, bool), \
    "get_json_dict must be a boolean."
  max_iterations = MAX_ITERATIONS if max_iterations is None else max_iterations
  use_incremental_progress = USE_INCREMENTAL_PROGRESS \
                             if use_incremental_progress is None \
                             else use_incremental_progress
  model = DeepThinkingRecall(
    TSPInput(
      *args,
      **kwargs
    ),
    TSPPreprocessing(
      width,
      *args,
      activation         = activation,
      bias               = bias,
      in_channels        = in_channels,
      ones_channel       = ones_channel,
      orthogonal         = orthogonal,
      ortho_phi_channels = ortho_phi_channels,
      use_batchnorm      = use_batchnorm,
      **kwargs
    ),
    TSPThought(
      width,
      *args,
      activation         = activation,
      bias               = bias,
      constraint         = constraint,
      in_channels        = in_channels,
      ones_channel       = ones_channel,
      orthogonal         = orthogonal,
      ortho_phi_channels = ortho_phi_channels,
      parametric_skip    = parametric_skip,
      **kwargs
    ),
    TSPOutput(
      width,
      h_channels,
      *args,
      activation         = activation,
      bias               = bias,
      final_bias         = final_bias,
      orthogonal         = orthogonal,
      ortho_phi_channels = ortho_phi_channels,
      out_channels       = out_channels,
      use_batchnorm      = use_batchnorm,
      **kwargs
    ),
    *args,
    max_iterations           = max_iterations,
    use_incremental_progress = use_incremental_progress,
    **kwargs
  )
  if not get_json_dict:
    return model
  # Get defaults for JSON. Default values still need to be recorded, so that a
  # saved model is reproduced exactly even if the module defaults change.
  activation = ACTIVATION if activation is None else activation
  bias = BIAS if bias is None else bias
  # TSPOutput falls back to the (resolved) bias for the final layer.
  final_bias = bias if final_bias is None else final_bias
  in_channels = INPUT_CHANNELS if in_channels is None else in_channels
  ones_channel = ONES_CHANNEL if ones_channel is None else ones_channel
  orthogonal = ORTHOGONAL if orthogonal is None else orthogonal
  ortho_phi_channels = ORTHO_PHI_CHANNELS if ortho_phi_channels is None \
                                          else ortho_phi_channels
  out_channels = OUTPUT_CHANNELS if out_channels is None else out_channels
  parametric_skip = PARAMETRIC_SKIP if parametric_skip is None \
                                    else parametric_skip
  use_batchnorm = USE_BATCHNORM if use_batchnorm is None else use_batchnorm
  # The activation is stored by name and reloaded via getattr(torch.nn, name),
  # so warn unless that lookup yields exactly this class.
  if getattr(torch.nn, activation.__name__, None) is not activation:
    warn(
      f"'{activation.__name__}' cannot be converted back into the " + \
      "torch.nn class it comes from. The name will still be recorded, " + \
      "but experiments reading this model from a file may need to have " + \
      "its activation manually set to the correct class."
    )
  json_dict = {
    "width":                    width,
    "h_channels":               list(h_channels),
    "activation":               activation.__name__,
    "bias":                     bias,
    "constraint":               constraint,
    "final_bias":               final_bias,
    "in_channels":              in_channels,
    "max_iterations":           max_iterations,
    "ones_channel":             ones_channel,
    "orthogonal":               orthogonal,
    "ortho_phi_channels":       ortho_phi_channels,
    "out_channels":             out_channels,
    "parametric_skip":          parametric_skip,
    "use_batchnorm":            use_batchnorm,
    "use_incremental_progress": use_incremental_progress,
  }
  return (model, json_dict)

def load_from_json_dict(
    # Arguments:
    json_dict: JSONType,
    *,
    # Keyword Arguments:
    activation: Optional[Type[Module]] = None
  ) -> Module:
  """
  Constructs a tsp deep thinking model from a JSON-ready dictionary.

  Args:
    json_dict (JSONType):
      The JSON-ready dictionary containing the construction hyperparameters,
      as produced by ``make_tsp_model(..., get_json_dict = True)``. Keys other
      than ``width``, ``h_channels`` and ``activation`` are optional; missing
      ones fall back to the module defaults.
    activation (torch.nn.Module (class), optional):
      Activation class to use instead of looking up the recorded
      ``json_dict["activation"]`` name in ``torch.nn``. Needed for activations
      that do not live in ``torch.nn``.
      Defaults to ``None`` (look up the recorded name).

  Returns:
    Module:
      The reconstructed model.

  Raises:
    KeyError:
      If a required key is missing.
    AttributeError:
      If ``activation`` is not given and the recorded name is not a member of
      ``torch.nn``.
  """
  width = int(json_dict["width"])
  h_channels = tuple(int(x) for x in json_dict["h_channels"])
  if activation is None:
    activation_name = json_dict["activation"]
    try:
      activation = getattr(torch.nn, activation_name)
    except AttributeError as err:
      raise AttributeError(
        f"'{activation_name}' is not a member of torch.nn; pass the correct "
        "class via the 'activation' argument of load_from_json_dict."
      ) from err
  # Keyword arguments are spelled out explicitly (rather than **json_dict) so
  # that unknown keys are ignored and the mapping stays visible.
  return make_tsp_model(
    width                    = width,
    h_channels               = h_channels,
    activation               = activation,
    bias                     = json_dict.get("bias"),
    constraint               = json_dict.get("constraint"),
    final_bias               = json_dict.get("final_bias"),
    in_channels              = json_dict.get("in_channels"),
    max_iterations           = json_dict.get("max_iterations"),
    ones_channel             = json_dict.get("ones_channel"),
    orthogonal               = json_dict.get("orthogonal"),
    ortho_phi_channels       = json_dict.get("ortho_phi_channels"),
    out_channels             = json_dict.get("out_channels"),
    parametric_skip          = json_dict.get("parametric_skip"),
    use_batchnorm            = json_dict.get("use_batchnorm"),
    use_incremental_progress = json_dict.get("use_incremental_progress"),
    get_json_dict            = False
  )

# REFERENCES:
#
# [1] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang,
#     Micah Goldblum, and Tom Goldstein.
#     "End-to-end Algorithm Synthesis with Recurrent Networks: Logical
#      Extrapolation Without Overthinking".
#     Advances in Neural Information Processing Systems. Oct 2022.
#     https://arxiv.org/abs/2202.05826