# -*- coding: utf-8 -*-
import time
from tianshou.utils import MovAvg
from tensorboard.backend.event_processing import event_accumulator
from tianshou.utils import TensorboardLogger as TSLogger
from tianshou.utils.logger.base import LOG_DATA_TYPE, BaseLogger
from aim import Run


class NoMovAvg(MovAvg):
  """``MovAvg`` with a window of one element.

  The parent is constructed with ``size=1``, so the moving-average cache
  holds a single value and smoothing is effectively disabled.
  """

  def __init__(self):
    window_size = 1
    super(NoMovAvg, self).__init__(size=window_size)


class AimRayTrainableLogger(BaseLogger):
  """Tianshou logger that tracks metrics to an Aim ``Run`` from a Ray Trainable.

  Typical wiring, inside ``Trainable.setup()``::

    repo = os.getcwd()
    exp_name = self._trial_info.experiment_name
    logger = AimRayTrainableLogger(repo=repo, exp_name=exp_name,
                                   config=config)

  The "save/*" step counters are kept in the in-memory
  ``run_steps_cache_dict`` (they are not tracked on the Aim run), and
  ``save_data`` never calls ``save_checkpoint_fn``. Persisting checkpoints,
  and this cache, is presumably the surrounding Trainable's job.
  NOTE(review): confirm against the caller.
  """

  def __init__(self,
               train_interval: int = 1000,
               test_interval: int = 1,
               update_interval: int = 50,  # by gradient step
               save_interval: int = 1,
               repo=None,
               exp_name=None,
               config=None,
               **aim_run_kwargs) -> None:
    """Open the Aim run and initialise the step bookkeeping.

    Args:
      train_interval: forwarded to ``BaseLogger``.
      test_interval: forwarded to ``BaseLogger``.
      update_interval: forwarded to ``BaseLogger``; counted in gradient steps.
      save_interval: stored for interface parity; ``save_data`` of this class
        does not consult it.
      repo: Aim repository passed to ``aim.Run(repo=...)``.
      exp_name: experiment name passed to ``aim.Run(experiment=...)``.
      config: stored on the run under ``"ts_config"`` and on ``self.config``.
      **aim_run_kwargs: extra keyword arguments forwarded to ``aim.Run``.
    """
    super().__init__(train_interval, test_interval, update_interval)
    self.save_interval = save_interval
    self.last_save_step = -1
    # In-memory store for the "save/*" counters written by save_data() and
    # read back by restore_data().
    self.run_steps_cache_dict = dict()

    # Keep retrying until the Aim run can be opened. Failures are presumably
    # transient (e.g. several Ray trials contending for one repo).
    # NOTE(review): confirm. A permanent failure (bad repo path) spins forever.
    # Catch Exception, not a bare ``except``, so Ctrl-C / SystemExit still
    # break out of the loop.
    while True:
      try:
        self.writer = Run(repo=repo, experiment=exp_name, **aim_run_kwargs)
        break
      except Exception:
        time.sleep(0.1)

    self.writer['ts_config'] = config
    self.config = config

  def write(self,
            step_type: str,
            step: int,
            data: LOG_DATA_TYPE,
            epoch=None,
            context=None) -> None:
    """Track every ``name -> value`` pair of ``data`` on the Aim run.

    ``step_type`` is required by the ``BaseLogger.write`` signature but is
    unused here: each key of ``data`` becomes the Aim metric name.
    """
    for k, v in data.items():
      self.writer.track(
          value=v, name=k, step=step, epoch=epoch, context=context)

  def save_data(self, epoch: int, env_step: int, gradient_step: int,
                save_checkpoint_fn, last_log_update_step: int) -> None:
    """Record the current step counters in ``run_steps_cache_dict``.

    Runs on every call, regardless of ``save_interval``. ``save_checkpoint_fn``
    is accepted for interface compatibility but intentionally not called
    (see the class docstring).
    """
    self.last_save_step = epoch
    self.run_steps_cache_dict["save/epoch"] = epoch
    self.run_steps_cache_dict["save/env_step"] = env_step
    self.run_steps_cache_dict["save/gradient_step"] = gradient_step
    self.run_steps_cache_dict[
        "save/last_log_update_step"] = last_log_update_step

  def restore_data(self):
    """Restore step bookkeeping from ``run_steps_cache_dict``.

    Returns:
      ``(epoch, env_step, gradient_step)``. ``epoch``/``gradient_step`` are
      ``(0, 0)`` when either was never saved; ``env_step`` is 0 when it was
      never saved (offline trainers).
    """
    cache = self.run_steps_cache_dict

    # Check both required keys before touching any logger state, so a
    # partially populated cache cannot leave last_save_step /
    # last_log_test_step pointing at an epoch we then report as 0.
    if "save/epoch" in cache and "save/gradient_step" in cache:
      epoch = cache["save/epoch"]
      gradient_step = cache["save/gradient_step"]
      self.last_save_step = self.last_log_test_step = epoch
      # Optional: fall back to gradient_step (tianshou's own behaviour)
      # rather than discarding the recovered epoch/gradient_step.
      self.last_log_update_step = cache.get("save/last_log_update_step",
                                            gradient_step)
    else:
      epoch, gradient_step = 0, 0

    # offline trainer doesn't have env_step
    if "save/env_step" in cache:
      env_step = cache["save/env_step"]
      self.last_log_train_step = env_step
    else:
      env_step = 0

    return epoch, env_step, gradient_step


class TensorboardLogger(TSLogger):
  """Tianshou ``TensorboardLogger`` that also saves/restores
  ``last_log_update_step``, for use from a Ray Trainable.

  Typical wiring, inside ``Trainable.setup()``::

    now = datetime.now().strftime("%y%m%d-%H%M%S")
    self.config.algo_name = "ppo"
    log_name = os.path.join(self.config.env, self.config.algo_name,
                            'seed' + str(self.config.seed), now)
    self.log_path = os.path.join(self.config.logdir, log_name)

    writer = SummaryWriter(self.log_path)
    writer.add_text("config", str(self.config))
    self.logger = TensorboardLogger(writer, update_interval=50)
  """

  def save_data(self, epoch: int, env_step: int, gradient_step: int,
                save_checkpoint_fn, last_log_update_step: int) -> None:
    """Checkpoint and record the step counters every ``save_interval`` epochs.

    Each counter is written as a scalar whose step equals its value, so
    ``restore_data`` can read it back from the event file.
    """
    if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
      self.last_save_step = epoch
      save_checkpoint_fn(epoch, env_step, gradient_step)
      self.write("save/epoch", epoch, {"save/epoch": epoch})
      self.write("save/env_step", env_step, {"save/env_step": env_step})
      # SG: fix gradient_step = last_log_update_step bug
      self.write("save/gradient_step", gradient_step,
                 {"save/gradient_step": gradient_step})
      self.write("save/last_log_update_step", last_log_update_step,
                 {"save/last_log_update_step": last_log_update_step})

  @staticmethod
  def _last_saved_step(ea, tag, default=None):
    """Return the step of the latest event for scalar ``tag``, or ``default``
    when the tag is absent (the reservoir raises KeyError for unknown tags)."""
    try:
      return ea.scalars.Items(tag)[-1].step
    except KeyError:
      return default

  def restore_data(self):
    """Restore step bookkeeping from the "save/*" scalars in the log dir.

    Returns:
      ``(epoch, env_step, gradient_step)``. ``epoch``/``gradient_step`` are
      ``(0, 0)`` when either tag is missing; ``env_step`` is 0 when missing
      (offline trainers).
    """
    ea = event_accumulator.EventAccumulator(self.writer.log_dir)
    ea.Reload()

    epoch = self._last_saved_step(ea, "save/epoch")
    gradient_step = self._last_saved_step(ea, "save/gradient_step")
    if epoch is None or gradient_step is None:
      epoch, gradient_step = 0, 0
    else:
      # Only mutate logger state once the required tags are known to exist.
      self.last_save_step = self.last_log_test_step = epoch
      # SG: fix gradient_step = last_log_update_step bug.
      # Logs written by stock tianshou lack this tag; fall back to
      # gradient_step (what tianshou's restore_data does) instead of
      # discarding the recovered epoch/gradient_step via a KeyError.
      self.last_log_update_step = self._last_saved_step(
          ea, "save/last_log_update_step", gradient_step)

    # offline trainer doesn't have env_step
    env_step = self._last_saved_step(ea, "save/env_step")
    if env_step is None:
      env_step = 0
    else:
      self.last_log_train_step = env_step

    return epoch, env_step, gradient_step
