################################################################################
# deepthinking/models/deep_thinking_vanilla.py
#
# 
# 
# 2023
#
# Implementation of a wrapper for 'vanilla' (non-recall) deep thinking
#  networks as originally described in [1].
# Extended to be a generalized version for implementation of individual
#  components (modules).

import torch

from typing import Callable, Optional, Tuple, Union

Module = torch.nn.Module
Tensor = torch.Tensor

class DeepThinking(Module):
  """
  Wrapper for 'vanilla' (non-recall) deep thinking systems as originally
  described by Schwarzschild et al. (2021).

  The wrapper chains three user-supplied modules: ``input_module`` produces
  the initial thought, ``thought_module`` is applied repeatedly to it, and
  ``output_module`` maps the final thought to the model output. Optionally,
  the incremental progress training scheme of Bansal et al. (2022) is used
  while the module is in training mode.
  """

  def __init__(self,
      # Arguments:
      input_module:   Module,
      thought_module: Module,
      output_module:  Module,
      *args,
      # Keyword Arguments:
      max_iterations:           int  = 1,
      use_incremental_progress: bool = False,
      **kwargs,
    ):
    """
    Initializes ``DeepThinking``.

    Args:
      input_module (Module):
        Transforms and pre-processes before thought.
      thought_module (Module):
        Will be iterated to perform the thought.
      output_module (Module):
        Transforms thought into output.
      max_iterations (int, optional):
        Maximum iterations for training and inference.
        Defaults to ``1``.
      use_incremental_progress (bool, optional):
        Whether to use incremental progress training.
        Defaults to ``False``.
      *args:
        Additional arguments for ``Module``.
      **kwargs:
        Additional keyword arguments for ``Module``.

    Raises:
      AssertionError:
        If ``max_iterations`` is not a non-negative integer or
        ``use_incremental_progress`` is not a bool.
    """
    super().__init__(*args, **kwargs)
    self.input_module   = input_module
    self.thought_module = thought_module
    self.output_module  = output_module
    # Type/value checking will only happen here. Modifications to these values
    #  elsewhere are unchecked.
    assert isinstance(max_iterations, int), \
      "max_iterations must be an integer."
    assert max_iterations >= 0, \
      "max_iterations must be positive or zero."
    self.max_iterations = max_iterations
    assert isinstance(use_incremental_progress, bool), \
      "use_incremental_progress must be a bool."
    self.use_incremental_progress = use_incremental_progress

  def incremental_progress(self,
      # Arguments:
      x_tilde:        Tensor,
      max_iterations: int
    ) -> Tuple[Tensor, Tensor]:
    """
    Implements the iterative portion of the "Incremental Progress Training
    Algorithm" by Bansal et al. (2022).

    Two thoughts are computed from ``x_tilde``: a "progressive" thought that
    runs ``n`` untracked iterations followed by ``k`` tracked iterations
    (``n`` and ``k`` sampled at random), and a thought that runs all
    ``max_iterations`` iterations with gradient tracking.

    Args:
      x_tilde (Tensor):
        Initial thought value (from input module).
      max_iterations (int):
        Maximum number of iterations. Must be at least ``1``.

    Returns:
      Tuple[Tensor, Tensor]:
        Two tensors (phi_prog, phi_m)

    Raises:
      ValueError:
        If ``max_iterations`` is less than ``1`` (the sampling ranges for
        ``n`` and ``k`` would be empty).
    """
    if max_iterations < 1:
      raise ValueError(
        "incremental progress requires max_iterations >= 1, "
        f"got {max_iterations}."
      )
    # n ~ U{0, m-1}, k ~ U{1, m-n}
    # torch.randint excludes ``high``, so each upper bound is one past the
    #  largest value that should be drawn.
    n: int = int(torch.randint(
      low = 0, high = max_iterations,
      size = (1,)
    ).item())
    k: int = int(torch.randint(
      low = 1, high = max_iterations - n + 1,
      size = (1,)
    ).item())
    # Set up progressive thought.
    phi: Tensor = x_tilde
    # First n iterations do not track gradients.
    with torch.no_grad():
      for _ in range(n):
        phi = self.perform_iteration(phi)
    # no_grad doesn't calculate gradients. detach ensures these are not part
    #  of the graph (this also cuts the link to x_tilde when n is zero).
    phi = phi.detach()
    # Final k iterations do track gradients.
    for _ in range(k):
      phi = self.perform_iteration(phi)
    phi_prog: Tensor = phi
    # Set up maximum iteration thought.
    phi = x_tilde
    # All iterations track gradients.
    for _ in range(max_iterations):
      phi = self.perform_iteration(phi)
    phi_m: Tensor = phi
    # Return both outputs.
    return (phi_prog, phi_m)

  def incremental_progress_loss(self,
      # Arguments:
      y_hat:   Union[Tuple[Tensor, Tensor], Tensor],
      y:       Tensor,
      loss_fn: Callable[[Tensor, Tensor], Tensor],
      # Keyword Arguments:
      alpha:       float = 0.0,
      batcherized: bool  = False
    ) -> Tensor:
    """
    The loss function for incremental progress for this particular deep thinking
    system.

    Args:
      y_hat (Tuple[Tensor, Tensor] or Tensor):
        Output module result (y_prog, y_m). When ``batcherized`` is set, a
        single tensor holding y_prog followed by y_m along dimension 0.
      y (Tensor):
        True value to compare against for loss.
      loss_fn (Callable[[Tensor, Tensor], Tensor]):
        Loss function to calculate L_prog and L_m.
      alpha (float, optional):
        Weighting for L_prog (L_m is weighted by ``1 - alpha``).
        Defaults to ``0.0``.
      batcherized (bool, optional):
        Used to signal that the incremental progress loss is in concatenated
        form.
        Defaults to ``False``.

    Returns:
      Tensor:
        Loss tensor.
    """
    if batcherized:
      # Split into exactly two halves along the batch dimension. torch.split
      #  with an integer would instead yield chunks *of size* two.
      y_hat_prog, y_hat_m = torch.chunk(y_hat, 2, dim = 0)
    else:
      y_hat_prog, y_hat_m = y_hat
    # Calculate individual losses.
    L_prog: Tensor = loss_fn(y_hat_prog, y)
    L_m:    Tensor = loss_fn(y_hat_m,    y)
    # Apply weighting and sum.
    L: Tensor = (1 - alpha) * L_m + alpha * L_prog
    return L

  def perform_iteration(self,
      # Arguments:
      phi: Tensor
    ) -> Tensor:
    """
    Implements a single iteration of the thought module(s).

    Args:
      phi (Tensor):
        Current state of the thought module.

    Returns:
      Tensor:
        Next state of the thought module.
    """
    # In the case of the vanilla DT network, this is just the thought module.
    phi = self.thought_module(phi)
    return phi

  def forward(self,
      # Arguments:
      x: Tensor,
      # Keyword Arguments:
      max_iterations: Optional[int]    = None,
      return_thought: bool             = False,
      phi:            Optional[Tensor] = None
    ) -> Union[
      Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor], Tensor
    ]:
    """
    Forward pass of the model.

    Args:
      x (Tensor):
        Model input.
      max_iterations (int, optional):
        Override attribute value.
        Defaults to ``None`` (don't override value).
      return_thought (bool, optional):
        Whether to return the thought after iterations.
        Defaults to ``False``.
      phi (Tensor, optional):
        The intermediary phi used only for inference, to avoid recalculating
        many phis. When given, iteration starts from it instead of from the
        input module result. Ignored on the incremental progress path.
        Defaults to ``None``.

    Returns:
      Tuple[Tensor, Tensor, Tensor, Tensor]:
        (y_hat_prog, y_hat_m, phi_prog, phi_m)...
        If ``use_incremental_progress`` *and* ``return_thought``.
      Tuple[Tensor, Tensor]:
        (y_hat_prog, y_hat_m) OR (y_hat, phi)...
        First case if *only* ``use_incremental_progress``.
        Second case if *only* ``return_thought``.
      Tensor:
        y_hat...
        If neither ``use_incremental_progress`` nor ``return_thought``.
    """
    m: int = self.max_iterations if max_iterations is None else max_iterations
    # Get the first value of thought from input module.
    x_tilde: Tensor = self.input_module(x)
    # Only use incremental training if currently training and has the flag set.
    if self.training and self.use_incremental_progress:
      phi_prog, phi_m = self.incremental_progress(x_tilde, m)
      y_hat_prog: Tensor = self.output_module(phi_prog)
      y_hat_m:    Tensor = self.output_module(phi_m)
      return (y_hat_prog, y_hat_m, phi_prog, phi_m) if return_thought else \
             (y_hat_prog, y_hat_m)
    # Otherwise just compute iterations normally.
    phi = x_tilde if phi is None else phi
    for _ in range(m):
      phi = self.perform_iteration(phi)
    y_hat: Tensor = self.output_module(phi)
    return (y_hat, phi) if return_thought else y_hat

# REFERENCES:
#
# [1] Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Furong Huang, Uzi Vishkin,
#     Micah Goldblum, and Tom Goldstein.
#     "Can You Learn an Algorithm? Generalizing from Easy to Hard Problems
#      with Recurrent Networks".
#     CoRR. June 2021.
#     https://arxiv.org/abs/2106.04537
#
# [2] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang,
#     Micah Goldblum, and Tom Goldstein.
#     "End-to-End Algorithm Synthesis with Recurrent Networks: Extrapolation
#      without Overthinking".
#     Advances in Neural Information Processing Systems. Oct 2022.
#     https://arxiv.org/abs/2202.05826