"""Utilities for logging and averaging metrics."""

import logging
from collections import OrderedDict
import time
import torch


def log_fn(logger: "logging.Logger | None"):
    """Return the callable to use for emitting messages.

    Args:
        logger: a logger object exposing ``info``, or None.

    Returns:
        ``logger.info`` when a logger is given, otherwise the builtin ``print``.
    """
    if logger is None:
        return print
    return logger.info


def get_logger(file_name: str, mode="a"):
    """Configure the root logger to write to a file and to the console.

    Every handler currently attached to the root logger is removed (and
    closed) first, so calling this function repeatedly does not produce
    duplicated output or leak open log files.

    Args:
        file_name: path of the log file.
        mode: mode used to open the log file (default "a", i.e. append).

    Returns:
        The root logger, carrying exactly one FileHandler and one
        StreamHandler.
    """
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger = logging.getLogger()
    # remove old handlers to avoid duplicated outputs.
    # Iterate over a snapshot: removing from `logger.handlers` while looping
    # over that same list skips every other handler.
    for hdlr in list(logger.handlers):
        logger.removeHandler(hdlr)
        # Release resources held by the discarded handler (e.g. the open
        # file of a previous FileHandler).
        hdlr.close()
    logger.addHandler(logging.FileHandler(file_name, mode=mode))
    logger.addHandler(logging.StreamHandler())
    return logger


class Averager(object):
    """This class is used to keep track of the sum and count of various metrics.

    Attrs:
        sum, OrderedDict[key, float]: sum for each metric.
        cnt,  OrderedDict[key, int]: number of records for each metric.
        clock, float: record time (secs).
    """

    def __init__(self, *keys):
        """Create the averager, pre-registering `keys` with zero sum and count."""
        self.sum = OrderedDict()
        self.cnt = OrderedDict()
        self.clock = time.time()
        for key in keys:
            self.sum[key] = 0
            self.cnt[key] = 0

    def batch_update(self, batch_dict):
        """Record one value per metric from a {key: value} mapping."""
        for key, val in batch_dict.items():
            self.update(key, val)

    def update(self, key, val):
        """Add one record `val` for metric `key`.

        A torch.Tensor is converted to a Python number with `.item()` first.
        A key that has not been seen before is registered on the fly.
        """
        if isinstance(val, torch.Tensor):
            val = val.item()

        if self.sum.get(key) is None:
            # First record for this key: start its running sum and count.
            self.sum[key] = val
            self.cnt[key] = 1
        else:
            self.sum[key] = self.sum[key] + val
            self.cnt[key] += 1

    def reset(self):
        """Zero every metric's sum and count and restart the clock."""
        for key in self.sum:
            self.sum[key] = 0
            self.cnt[key] = 0

        self.clock = time.time()

    def get(self, key):
        """Return the average of `key`.

        Returns None for an unknown key and 0 for a key with no records.
        """
        if key not in self.sum:
            return None
        count = self.cnt[key]
        return self.sum[key] / count if count > 0 else 0

    def info(self):
        """Return a one-line summary: "<key>: <avg> ... (<elapsed> secs)".

        Averages are printed with 4 decimals; elapsed time is measured from
        construction or the last `reset()`.
        """
        # Go through `get` so a metric with zero records is reported as 0
        # instead of raising ZeroDivisionError (e.g. right after `reset()`).
        parts = [f"{key}: {self.get(key):.4f}" for key in self.sum]
        parts.append(f"({time.time()-self.clock:.3f} secs)")
        return " ".join(parts)


def get_timestamp() -> str:
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    stamp_format = "%Y%m%d_%H%M%S"
    now = time.localtime()
    return time.strftime(stamp_format, now)
