"""
references: UNITER
"""

import logging
from tensorboardX import SummaryWriter
import os

_LOG_FMT = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s'
_DATE_FMT = '%m/%d/%Y %H:%M:%S'
logging.basicConfig(format=_LOG_FMT, datefmt=_DATE_FMT, level=logging.INFO)
LOGGER = logging.getLogger('__main__')  # this is the global logger


def add_log_to_file(log_path):
    fh = logging.FileHandler(log_path)
    formatter = logging.Formatter(_LOG_FMT, datefmt=_DATE_FMT)
    fh.setFormatter(formatter)
    LOGGER.addHandler(fh)


class TensorboardLogger(object):
    def __init__(self):
        self._logger = None
        self._global_step = 0

    def create(self, path):
        if "AZUREML_TB_PATH" in os.environ:
            self._logger = SummaryWriter(os.environ["AZUREML_TB_PATH"])
        else:
            self._logger = SummaryWriter(path)

    def noop(self, *args, **kwargs):
        return

    def step(self):
        self._global_step += 1

    @property
    def global_step(self):
        return self._global_step

    @global_step.setter
    def global_step(self, step):
        self._global_step = step

    def log_scalar_dict(self, log_dict, prefix=''):
        """ log a dictionary of scalar values"""
        if self._logger is None:
            return
        if prefix:
            prefix = f'{prefix}_'
        for name, value in log_dict.items():
            if isinstance(value, dict):
                self.log_scalar_dict(value, self._global_step,
                                     prefix=f'{prefix}{name}')
            else:
                self._logger.add_scalar(f'{prefix}{name}', value,
                                        self._global_step)

    def __getattr__(self, name):
        if self._logger is None:
            return self.noop
        return self._logger.__getattribute__(name)


TB_LOGGER = TensorboardLogger()


class RunningMeter(object):
    """ running meteor of a scalar value
        (useful for monitoring training loss)
    """
    def __init__(self, name, val=None, smooth=0.99):
        self._name = name
        self._sm = smooth
        self._val = val

    def __call__(self, value):
        self._val = (value if self._val is None
                     else value*(1-self._sm) + self._val*self._sm)

    def __str__(self):
        return f'{self._name}: {self._val:.4f}'

    @property
    def val(self):
        return self._val

    @property
    def name(self):
        return self._name
