################################################################################
# training/utils/logging.py
#
# 
# 
# 2023
#
# Implementation of a validation trainer that logs epoch results to files.

from datetime import datetime, timezone
from os.path  import isdir, isfile, join
from typing   import TextIO

import torch

from .utils.reduction_tools import *
from .utils.type_checking   import *
from .validation            import ValidationTrainer

# Short aliases for the torch classes referenced by the trainers.
DataLoader, Module, Optimizer, Tensor = (
  torch.utils.data.DataLoader,
  torch.nn.Module,
  torch.optim.Optimizer,
  torch.Tensor,
)

# Verbosity levels, an ordered enumeration (compared with ``>=`` / ``==``):
#   SILENT - no output.
#   QUIET  - only epoch results.
#   FULL   - epoch progress bars and live results.
SILENT: int
QUIET:  int
FULL:   int
SILENT, QUIET, FULL = range(3)

class LoggingTrainer(ValidationTrainer):
  """
  Implementation of ``ValidationTrainer`` which logs results.

  At the end of every epoch the average training and validation results are
  written to a ``.log`` file in ``log_dir``. Files are named
  ``[log_name][timestamp].log``, or ``[log_name][timestamp]E[epoch].log`` when
  ``file_per_epoch`` is set. The timestamp is taken in UTC when training
  starts.
  """

  def __init__(self,
      # Arguments:
      log_dir: str,
      *args,
      # Keyword Arguments:
      log_name:       str  = "log",
      file_per_epoch: bool = False,
      **kwargs
    ):
    """
    Initializes ``LoggingTrainer``.

    Args:
      log_dir (str):
        The directory to save the logs to. It must already exist.
      *args:
        Positional arguments forwarded to ``ValidationTrainer``.
      log_name (str, optional):
        The start of the log file names. The timestamp (and possibly epoch) is
        appended to this, and saved as a ``.log`` file. ``None`` is treated
        the same as the default.
        Defaults to ``"log"``.
      file_per_epoch (bool, optional):
        Whether to create a new file for each epoch, otherwise the same file
        is used for the whole training instance.
        Defaults to ``False``.
      **kwargs:
        Keyword arguments forwarded to ``ValidationTrainer``.

    Raises:
      ValueError:
        If ``log_dir`` is ``None`` or is not an existing directory.
    """
    super().__init__(*args, **kwargs)
    # Type check log_dir. It cannot be None (there would be nowhere to write),
    #  and an explicit raise, unlike ``assert``, still runs under ``python -O``.
    check_if_type_or_none(log_dir, str, "log_dir")
    if log_dir is None or not isdir(log_dir):
      raise ValueError(
        f"log_dir needs to be a valid directory, got {log_dir!r}."
      )
    self.log_dir = log_dir
    # Type check log_name. None falls back to the default, since it would
    #  otherwise fail later when the file names are built.
    check_if_type_or_none(log_name, str, "log_name")
    self.log_name = "log" if log_name is None else log_name
    # Type check file_per_epoch. None behaves like False.
    check_if_type_or_none(file_per_epoch, bool, "file_per_epoch")
    self.file_per_epoch = bool(file_per_epoch)

  @property
  def _current_filename(self) -> str:
    """
    Gets the path of the log file for the current epoch.

    The name follows the ``[log_name][timestamp]<E[epoch]>.log`` format and is
    joined onto ``log_dir``.

    Raises:
      RuntimeError:
        If called before ``on_train_start`` has set the timestamp.
    """
    stamp = getattr(self, "_train_start_stamp", None)
    if stamp is None:
      # Raising AttributeError from a property is easily masked (hasattr,
      #  __getattr__), so report the misuse explicitly instead.
      raise RuntimeError(
        "on_train_start must run before log file names can be built."
      )
    epoch_tag: str = f"E{self._current_epoch}" if self.file_per_epoch else ""
    filename: str = f"{self.log_name}{stamp}{epoch_tag}.log"
    return join(self.log_dir, filename)

  def on_train_start(self):
    """
    Records the training start timestamp used in the log file names, then
    defers to ``ValidationTrainer.on_train_start``.
    """
    # Get a timestamp used to differentiate this training instance from others
    #  in the logs. ``datetime.utcnow`` is deprecated (Python 3.12+), so an
    #  aware UTC datetime is used instead. The format has no offset field, so
    #  the stamp text is unchanged.
    self._train_start_stamp = datetime.now(timezone.utc).strftime(
      "%Y%m%dT%H%M%SZ"
    )
    if self.verbose >= QUIET:
      print(
        f"Creating log files in directory\n  \"{self.log_dir}\"\nwith "
        f"timestamp \"{self._train_start_stamp}\"."
      )
    super().on_train_start()

  def on_epoch_end(self):
    """
    Defers to ``ValidationTrainer.on_epoch_end``, then appends this epoch's
    average training and validation results to the current log file.
    """
    super().on_epoch_end()
    filename: str = self._current_filename
    # Build the entry before touching the file, so a failure while computing
    #  the averages does not leave an empty log file behind.
    entry: str = (
      f"EPOCH {self._current_epoch}:\n"
      f"  TRAIN: {self._train_average()}\n"
      f"  VALID: {self._validation_average()}\n"
    )
    # Only used for the status message. Mode "a" creates a missing file, so
    #  one mode covers both cases without a check-then-truncate race.
    appending: bool = isfile(filename)
    # ``with`` guarantees the handle is closed even if the write fails.
    with open(filename, "a", encoding="utf-8") as file:
      file.write(entry)
    if self.verbose == FULL:
      print(("Appended" if appending else "Written") + " to log.")
