# -*- coding: utf-8 -*
import os.path as osp
from collections import Mapping
from typing import Dict

from loguru import logger

from torch.utils.tensorboard.writer import SummaryWriter

from videoanalyst.utils import ensure_dir

from ..monitor_base import TRACK_MONITORS, VOS_MONITORS, MonitorBase


@TRACK_MONITORS.register
@VOS_MONITORS.register
class TensorboardLogger(MonitorBase):
    r"""Log training info to tensorboard for better visualization

    Hyper-parameters
    ----------------
    exp_name : str
        experiment name
    exp_save : str
        directory to save snapshots
    log_dir : str
        places to save tensorboard file
        will be updated in update_params
        EXP_SAVE/EXP_NAME/logs/tensorboard 
    """

    default_hyper_params = dict(
        exp_name="",
        exp_save="",
        log_dir="",
    )

    def __init__(self, ):
        r"""
        Arguments
        ---------
        """
        super(TensorboardLogger, self).__init__()
        self._state["writer"] = None

    def update_params(self):
        self._hyper_params["log_dir"] = osp.join(
            self._hyper_params["exp_save"],
            self._hyper_params["exp_name"],
            "logs/tensorboard",
        )

    def init(self, engine_state: Dict):
        super(TensorboardLogger, self).init(engine_state)

    def update(self, engine_data: Dict):
        # from engine state calculate global step
        engine_state = self._state["engine_state"]
        epoch = engine_state["epoch"]
        max_epoch = engine_state["max_epoch"]
        iteration = engine_state["iteration"]
        max_iteration = engine_state["max_iteration"]
        global_step = iteration + epoch * max_iteration

        # build at first update
        if self._state["writer"] is None:
            self._build_writer(global_step=global_step)
            logger.info(
                "Tensorboard writer built, starts recording from global_step=%d"
                % global_step, )
            logger.info(
                "epoch=%d, max_epoch=%d, iteration=%d, max_iteration=%d" %
                (epoch, max_epoch, iteration, max_iteration))
        writer = self._state["writer"]

        # traverse engine_data and put to scalar
        self._add_scalar_recursively(writer, engine_data, "", global_step)

    def _build_writer(self, global_step=0):
        log_dir = self._hyper_params["log_dir"]
        ensure_dir(log_dir)
        self._state["writer"] = SummaryWriter(
            log_dir=log_dir,
            purge_step=global_step,
            filename_suffix="",
        )

    def _add_scalar_recursively(self, writer: SummaryWriter, o, prefix: str,
                                global_step: int):
        """Recursively add scalar from mapping-like o: tag1/tag2/tag3/...
        
        Parameters
        ----------
        writer : SummaryWriter
            writer
        o : mapping-like or scalar
            [description]
        prefix : str
            tag prefix, tag is the name to be passed into writer
        global_step : int
            global step counter
        """
        if isinstance(o, Mapping):
            for k in o:
                if len(prefix) > 0:
                    tag = "%s/%s" % (prefix, k)
                else:
                    tag = k
                self._add_scalar_recursively(writer, o[k], tag, global_step)
        else:
            writer.add_scalar(prefix, o, global_step=global_step)
