################################################################################
# spectral/modules/trainer.py
#
# 
# 
# 
# 2024
#
#

import torch

from os.path  import isdir, join
from typing   import Optional
from warnings import warn

from .warmup                   import BaseWarmup
from choochoo.training.logging import LoggingTrainer
from experilog.logger          import JSONType

# Recognized values for the trainer's ``scheduler_callback`` option. They name
# which epoch-average loss is handed to the LR scheduler's ``step`` call.
TRAINING:   str = "training"
VALIDATION: str = "validation"

# Alias used to annotate LR scheduler arguments.
# NOTE(review): ``_LRScheduler`` is a private torch name -- confirm that it is
# still the intended base class for the torch version in use.
Scheduler = torch.optim.lr_scheduler._LRScheduler

class DeepThinkingTrainer(LoggingTrainer):

  def __init__(self,
      # Arguments:
      config: JSONType,
      *args,
      # Keyword Arguments:
      clip:               Optional[float]      = None,
      clip_foreach:       Optional[bool]       = None,
      save_best:          Optional[str]        = None,
      save_dir:           Optional[str]        = None,
      scheduler:          Optional[Scheduler]  = None,
      scheduler_callback: Optional[str]        = None,
      warmup:             Optional[BaseWarmup] = None,
      name:               str                  = "model",
      **kwargs
    ) -> None:
    """
    Initializes ``DeepThinkingTrainer``.

    Args:
      config (JSONType):
        The config dictionary used to construct the trainer and model. It is
        stored on the trainer and written into every saved checkpoint.
      *args:
        Additional arguments for ``LoggingTrainer``.
      clip (float, optional):
        The float value to clip gradient norms to.
        Defaults to ``None`` (no clipping).
      clip_foreach (bool, optional):
        Whether gradient clipping applies to each parameter (``True``) or
        all parameters (``False``).
        Defaults to ``None`` (``False`` if ``clip`` is not ``None``).
      save_best (str, optional):
        If specified, this is either ``"max"`` or ``"min"``, and determines the
        measurement for comparison on deciding whether to overwrite the current
        saved version. Any other string triggers a warning and falls back to
        ``"min"``.
        Defaults to ``None`` (no saving).
      save_dir (str, optional):
        Must be given if ``save_best`` is not ``None``. This is the directory
        to save models and configs to; it must already exist.
        Defaults to ``None``.
      scheduler (Scheduler, optional):
        The LR scheduler for training.
        Defaults to ``None``.
      scheduler_callback (str, optional):
        Either ``"training"`` or ``"validation"``. The type of loss specified
        will be fed into the scheduler as a step. Otherwise, no value will be
        used.
        Defaults to ``None``.
      warmup (BaseWarmup, optional):
        The warmup to use for training.
        Defaults to ``None``.
      name (str, optional):
        The model name, used as the prefix of saved checkpoint files.
        Defaults to ``"model"``.
      **kwargs:
        Additional keyword arguments for ``LoggingTrainer``.

    Raises:
      AssertionError:
        If ``clip``, ``clip_foreach``, ``save_best`` or ``save_dir`` has the
        wrong type, or if ``save_dir`` is given but is not an existing
        directory.
      ValueError:
        If ``save_best`` is given without ``save_dir``.
    """
    super(DeepThinkingTrainer, self).__init__(*args, **kwargs)
    # Config.
    self.config = config
    # Clip.
    assert isinstance(clip, float) or clip is None, \
      "clip must be a float or None."
    self.clip = clip
    # Clip foreach. Only meaningful when clipping is enabled.
    assert isinstance(clip_foreach, bool) or clip_foreach is None, \
      "clip_foreach must be a bool or None."
    if self.clip is not None:
      self.clip_foreach = False if clip_foreach is None else clip_foreach
    else:
      self.clip_foreach = None
    # Save best.
    assert isinstance(save_best, str) or save_best is None, \
      "save_best must be a string or None."
    self.save_best = save_best
    # The best-score slot is created unconditionally: ``on_train_start`` reads
    # it even when saving is disabled, and would otherwise hit an
    # AttributeError.
    self._current_best = None
    if self.save_best is not None:
      # "min" comparison is the default; ties count as an improvement.
      self._is_better = lambda current, new: new <= current
      if self.save_best == "max":
        self._is_better = lambda current, new: new >= current
      elif self.save_best != "min":
        warn(
          f"'{self.save_best}' is not a recognized save mode. Defaulting " + \
          "to 'min'."
        )
        self.save_best = "min"
    # Save directory.
    assert isinstance(save_dir, str) or save_dir is None, \
      "save_dir must be a string or None."
    if save_dir is None and self.save_best is not None:
      # ValueError is a subclass of Exception, so existing handlers still work.
      raise ValueError("save_dir must be specified if save_best is used.")
    # Only validate the directory when one was supplied; ``isdir(None)`` raises
    # a TypeError, which previously made the default construction fail.
    if save_dir is not None:
      assert isdir(save_dir), \
        f"{save_dir} is not a valid directory."
    self.save_dir = save_dir
    # Scheduler and warmup.
    self.scheduler          = scheduler
    self.scheduler_callback = scheduler_callback
    self.warmup             = warmup
    # Name.
    self._name = name

  def _save_model(self):
    """
    Writes a checkpoint of the current training state into ``self.save_dir``.

    The file name is the model name followed directly by the training start
    stamp and the ``.tar`` extension. The checkpoint is a dictionary holding
    the start stamp, the config, the current epoch, the training and
    validation loss averages, and the model and optimizer state dicts.
    """
    basename = self._name + self._train_start_stamp + ".tar"
    checkpoint_path = join(self.save_dir, basename)
    checkpoint = {
      # Run metadata.
      "timestamp":   self._train_start_stamp,
      "config":      self.config,
      "epoch":       self._current_epoch,
      # Loss averages as reported by the trainer at save time.
      "train_loss":  self._train_average(),
      "valid_loss":  self._validation_average(),
      # State needed to restore the model and optimizer.
      "model_state": self.model.state_dict(),
      "optim_state": self.optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)

  def on_train_start(self):
    """
    Runs the parent hook, then seeds the best-score tracker if it is unset.

    The tracker starts at ``-inf`` when ``save_best`` is ``"max"`` and at
    ``+inf`` otherwise, so that the first contender always counts as an
    improvement. An already-set value is left untouched.
    """
    super(DeepThinkingTrainer, self).on_train_start()
    # ``getattr`` with a default: the attribute may be missing altogether when
    # the trainer was built without ``save_best``, which used to raise an
    # AttributeError here.
    if getattr(self, "_current_best", None) is None:
      self._current_best = -torch.inf if self.save_best == "max" else torch.inf

  def process_training_batch(self):
    """
    Runs the forward pass on the pending batch and stores the resulting loss.

    Consumes ``self._current_batch`` (an ``(input, target)`` pair), moves both
    elements to ``self.device``, evaluates the model and the loss function, and
    stores the loss in ``self._current_loss``. Afterwards, every constraint of
    ``self.model.thought_module`` has its ``_update_enabled`` flag set to
    ``True``. Does nothing if there is no pending batch.
    """
    if self._current_batch is None:
      # Nothing to process. The guard previously used ``pass``, which fell
      # through and failed when unpacking ``None`` below.
      return
    input_batch, target_batch = self._current_batch
    # Clear the slot so the same batch is never processed twice.
    self._current_batch = None
    input_batch  = input_batch.to(self.device)
    target_batch = target_batch.to(self.device)
    predicted_batch = self.model(input_batch)
    self._current_loss = self.loss_fn(predicted_batch, target_batch)
    for constraint in self.model.thought_module.constraints:
      constraint._update_enabled = True

  def optimization_step(self):
    """
    Backpropagates the pending loss, clips gradients, and steps the optimizer.

    Consumes ``self._current_loss``: calls ``backward()`` on it, records its
    scalar value in ``self._batch_loss``, and clears the slot. If ``self.clip``
    is set, gradient norms are clipped to that value, either one parameter at
    a time (``clip_foreach`` is ``True``) or jointly over all model
    parameters. Finally the optimizer takes a step. Does nothing if there is
    no pending loss.
    """
    if self._current_loss is None:
      # Nothing to optimize. The guard previously used ``pass``, which fell
      # through and failed on ``None.backward()`` below.
      return
    # NOTE: A strange bug has been observed on MPS where the loss can become
    # 1.0 after ``.backward()`` unless it is evaluated in some form first. The
    # workaround, ``_ = self._current_loss.item()`` placed before
    # ``backward()``, is currently disabled.
    self._current_loss.backward()
    self._batch_loss = self._current_loss.detach().item()
    self._current_loss = None
    if self.clip is not None:
      if self.clip_foreach:
        # Clip each parameter's gradient norm independently.
        for p in self.model.parameters():
          torch.nn.utils.clip_grad_norm_(p, self.clip)
      else:
        # Clip the joint gradient norm over all parameters.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
    self.optimizer.step()

  def on_epoch_end(self):
    """
    Runs the parent hook, then steps the scheduler and warmup and saves the
    model if it is the best seen so far.

    If a scheduler is configured, it is stepped once. When
    ``scheduler_callback`` is set, the step receives the training loss average
    (callback equal to ``"training"``) or the validation loss average (any
    other value); otherwise it is stepped without a value. If a warmup is
    configured, it is dampened. If ``save_best`` is set, the validation loss
    average is compared against the best value so far and, when it is at least
    as good, the model is saved.
    """
    super(DeepThinkingTrainer, self).on_epoch_end()
    # The scheduler is optional; stepping ``None`` used to raise an
    # AttributeError whenever no scheduler was supplied.
    if self.scheduler is not None:
      if self.scheduler_callback is not None:
        scheduler_step_value = \
          self._train_average() if self.scheduler_callback == TRAINING \
                                else self._validation_average()
        self.scheduler.step(scheduler_step_value)
      else:
        self.scheduler.step()
    if self.warmup is not None:
      self.warmup.dampen()
    if self.save_best is not None:
      contender = self._validation_average()
      best_so_far = self._current_best
      # With no recorded best yet, the first contender wins outright; this
      # avoids comparing a number against ``None``.
      if best_so_far is None or self._is_better(best_so_far, contender):
        if self.verbose >= 1:
          print("New best model. Saving.")
        self._current_best = contender
        self._save_model()