################################################################################
# training/utils/simple.py
#
# 
# 
# 2023
#
# Simple implementations of the trainer.

from collections.abc import Callable
from tqdm            import tqdm

import torch

from .utils.reduction_tools import *
from .utils.type_checking   import *
from .trainer               import Trainer

# Short aliases for the torch types used in annotations and type checks below.
Tensor = torch.Tensor
Module = torch.nn.Module
Optimizer = torch.optim.Optimizer
DataLoader = torch.utils.data.DataLoader

# Verbosity levels, ordered from least to most output.
SILENT: int = 0  # No output.
QUIET: int = 1   # Only epoch results.
FULL: int = 2    # Epoch progress bars and live results.

class SimpleTrainer(Trainer):
  """
  Simple implementation of ``Trainer`` for training a model.

  Calling the instance runs the whole training loop. For every epoch, each
  batch produced by ``train_dataloader`` goes through, in order:

    1. ``on_batch_start``          - zero the gradients,
    2. ``process_training_batch``  - forward pass and loss,
    3. ``optimization_step``       - backward pass and optimizer step,
    4. ``on_batch_end``            - bookkeeping and progress output.

  A batch that is ``None`` is skipped: no loss is computed, no optimizer
  step is taken and nothing is added to the epoch average, but the progress
  bar still advances so that it reaches its total.
  """

  def __init__(self,
      # Arguments:
      model:            Module,
      loss_fn:          Callable[[Tensor, Tensor], Tensor],
      optimizer:        Optimizer,
      train_dataloader: DataLoader,
      # Keyword Arguments:
      device:                str        = "cpu",
      verbose:               int        = 0,
      format_loss_string:    str        = "Loss: %.3f",
      progress_bar_fn                   = tqdm
    ):
    """
    Initializes ``SimpleTrainer``.

    Args:
      model (Module):
        Model to train.
      loss_fn (Callable[[Tensor, Tensor], Tensor]):
        Loss function for training. Called as
        ``loss_fn(predicted_batch, target_batch)``.
      optimizer (Optimizer):
        Optimizer.
      train_dataloader (DataLoader):
        DataLoader used for loading the training dataset.
      device (str, optional):
        Device the input and target batches are moved to.
        Defaults to ``"cpu"``.
      verbose (int, optional):
        Verbosity. Values outside ``SILENT``..``FULL`` are clamped into
        that range and ``None`` is treated as ``SILENT``.
        Defaults to ``0`` (SILENT).
      format_loss_string (str, optional):
        ``%``-style format string with one numeric placeholder. Used for the
        progress bar description on verbose = FULL and for the epoch summary
        on verbose >= QUIET.
        Defaults to ``"Loss: %.3f"``.
      progress_bar_fn (Callable, optional):
        Override for ``tqdm`` if alternative classes are needed. It is
        called with a ``total`` keyword and the result must provide
        ``set_description``, ``update`` and ``close``.
        Defaults to ``tqdm``.
    """
    # Type check model.
    check_if_type_or_none(model, Module, "model")
    self.model = model
    # Assume loss_fn is callable in some way.
    self.loss_fn = loss_fn
    # Type check optimizer.
    check_if_type_or_none(optimizer, Optimizer, "optimizer")
    self.optimizer = optimizer
    # Type check train_dataloader.
    check_if_type_or_none(train_dataloader, DataLoader, "train_dataloader")
    self.train_dataloader = train_dataloader
    # Type check device.
    check_if_type_or_none(device, str, "device")
    self.device = device
    # Type check verbose.
    check_if_type_or_none(verbose, int, "verbose")
    # NOTE(review): check_if_type_or_none presumably lets None through, and
    # min()/max() cannot compare None with an int, so fall back to SILENT.
    if verbose is None:
      self.verbose = SILENT
    else:
      self.verbose = max(SILENT, min(FULL, verbose))
    # Type check format_loss_string.
    check_if_type_or_none(format_loss_string, str, "format_loss_string")
    self.format_loss_string = format_loss_string
    # Progress bar.
    self.progress_bar_fn = progress_bar_fn
    # Per-batch state. Initialised here so that the guards in
    # process_training_batch, optimization_step and on_batch_end can test
    # for None without raising AttributeError.
    self._current_batch = None
    self._current_loss = None
    self._batch_loss = None

  def on_train_start(self):
    """Announces the start of training on verbose >= QUIET."""
    if self.verbose >= QUIET:
      print("Beginning training...")

  def on_epoch_start(self):
    """Resets the epoch average and, on verbose = FULL, opens a new bar."""
    self._train_average = RunningAverage()
    if self.verbose == FULL:
      self._progress_bar = self.progress_bar_fn(
        total = len(self.train_dataloader)
      )

  # on_batch_start is a good place to perform zero_grad.
  def on_batch_start(self):
    """Clears the gradients left over from the previous batch."""
    self.optimizer.zero_grad()

  # Get the current batch and calculate the loss.
  def process_training_batch(self):
    """
    Runs the forward pass on ``self._current_batch`` and stores the loss in
    ``self._current_loss``.

    The batch is expected to unpack into ``(input_batch, target_batch)``.
    If there is no current batch, nothing is computed and the loss is set to
    ``None`` so that ``optimization_step`` skips this batch as well.
    """
    batch = self._current_batch
    if batch is None:
      self._current_loss = None
      return
    input_batch, target_batch = batch
    # Consume the batch so it cannot be processed twice.
    self._current_batch = None
    input_batch  = input_batch.to(self.device)
    target_batch = target_batch.to(self.device)
    predicted_batch = self.model(input_batch)
    self._current_loss = self.loss_fn(predicted_batch, target_batch)

  # Backwards pass through the function and update the parameters.
  def optimization_step(self):
    """
    Backpropagates ``self._current_loss`` and steps the optimizer.

    The scalar loss is kept in ``self._batch_loss`` for ``on_batch_end``.
    If there is no loss (the batch was skipped), no step is taken and
    ``self._batch_loss`` is set to ``None`` so that the previous batch's
    loss is not reported a second time.
    """
    loss = self._current_loss
    if loss is None:
      self._batch_loss = None
      return
    loss.backward()
    self._batch_loss = loss.item()
    # Drop the reference to the loss tensor once it has been used.
    self._current_loss = None
    self.optimizer.step()

  def on_batch_end(self):
    """
    Records the batch loss and, on verbose = FULL, updates the progress bar.

    A skipped batch (``self._batch_loss`` is ``None``) is not added to the
    epoch average, but the bar still advances by one.
    """
    has_loss = self._batch_loss is not None
    if has_loss:
      # NOTE(review): the result of ``+`` is discarded, so the loss is only
      # recorded if RunningAverage.__add__ updates the instance in place -
      # confirm against utils.reduction_tools.
      self._train_average + self._batch_loss
    if self.verbose == FULL:
      if has_loss:
        self._progress_bar.set_description(
          self.format_loss_string % self._batch_loss
        )
      self._progress_bar.update(1)

  def on_epoch_end(self):
    """Closes the bar (verbose = FULL) and prints the epoch average."""
    if self.verbose == FULL:
      self._progress_bar.close()
    if self.verbose >= QUIET:
      print(
        f"Epoch {self._current_epoch} (average): " +
        self.format_loss_string % self._train_average()
      )

  def on_train_end(self):
    """Announces the end of training on verbose >= QUIET."""
    if self.verbose >= QUIET:
      print("Training complete or stopped.")

  def _full_batch(self):
    """Runs the four per-batch hooks, in order, for the current batch."""
    self.on_batch_start()
    self.process_training_batch()
    self.optimization_step()
    self.on_batch_end()

  def _full_epoch(self):
    """Runs one epoch: every batch of ``train_dataloader`` exactly once."""
    self.on_epoch_start()
    for _current_batch in self.train_dataloader:
      self._current_batch = _current_batch
      self._full_batch()
    self.on_epoch_end()

  def _full_train(self):
    """Runs epochs ``_start_epoch`` up to, but excluding, ``_max_epochs``."""
    self.on_train_start()
    for _current_epoch in range(self._start_epoch, self._max_epochs):
      self._current_epoch = _current_epoch
      self._full_epoch()
    self.on_train_end()

  def __call__(self,
      # Keyword Arguments:
      max_epochs:  int = 1,
      start_epoch: int = 0
    ):
    """
    ``SimpleTrainer`` implementation of ``__call__``.

    Trains for the epochs in ``range(start_epoch, max_epochs)``. Negative
    values are clamped to ``0``. If ``start_epoch >= max_epochs`` no epoch
    is run, although the train start/end hooks are still called.

    Args:
      max_epochs (int, optional):
        Maximum number of epochs (exclusive upper bound of the epoch index).
        Defaults to ``1``.
      start_epoch (int, optional):
        Epoch to start on.
        Defaults to ``0``.
    """
    # Type check both epoch number inputs and set.
    check_if_type(max_epochs, int, "max_epochs")
    self._max_epochs = max(0, max_epochs)
    check_if_type(start_epoch, int, "start_epoch")
    self._start_epoch = max(0, start_epoch)
    self._full_train()