import typing as tp

import torch
import logging
import torch.distributed as dist



class MetricsLogger:
    """Accumulate tensor metrics over several steps and report their sums or means.

    Values are detached from the autograd graph before being stored, so
    accumulating them does not keep computation graphs alive. ``pop`` returns
    the aggregated values and clears all internal state.
    """

    def __init__(self):
        # Number of times each metric key has been added since the last pop().
        self.counts: tp.Dict[str, int] = {}
        # Running (detached) sum of the values added for each metric key.
        self.metrics: tp.Dict[str, torch.Tensor] = {}

    def add(self, metrics: tp.Dict[str, torch.Tensor]) -> None:
        """Accumulate one set of metric values.

        Args:
            metrics: Mapping from metric name to tensor value. A key not seen
                since the last ``pop`` starts a new running sum; a known key is
                added in place to its existing running sum.
        """
        for k, v in metrics.items():
            value = v.detach()
            if k in self.counts:
                self.counts[k] += 1
                # The stored sum is already a private clone, so this in-place add
                # never writes into the caller's tensor; no extra clone is needed.
                self.metrics[k] += value
            else:
                self.counts[k] = 1
                # detach() shares storage with ``v``; clone so later in-place
                # adds do not mutate the caller's tensor.
                self.metrics[k] = value.clone()

    def pop(self, mean: bool = True) -> tp.Dict[str, torch.Tensor]:
        """Return the aggregated metrics and reset the logger.

        Args:
            mean: If True, divide each running sum by the number of times its
                key was added; otherwise return the raw sums.

        Returns:
            Dict mapping each metric name to its mean (or sum). Empty if
            nothing was added since the previous ``pop``.
        """
        if mean:
            result = {k: v / self.counts[k] for k, v in self.metrics.items()}
        else:
            result = dict(self.metrics)

        # reset
        self.counts = {}
        self.metrics = {}

        return result


def create_logger(logging_dir, accelerator):
    """
    Create a logger that writes to a log file and the console (stderr).

    On the main process (``accelerator.is_main_process`` is truthy) the root
    logger is configured through ``logging.basicConfig`` at INFO level with two
    handlers: a ``StreamHandler`` (stderr by default) and a ``FileHandler`` on
    ``{logging_dir}/log.txt``. On every other process the returned logger is
    silenced, so only the main process produces output.

    Args:
        logging_dir: Directory that holds ``log.txt``. Only used on the main
            process; ``FileHandler`` does not create missing directories.
        accelerator: Object exposing an ``is_main_process`` flag (presumably an
            ``accelerate.Accelerator`` — confirm against callers).

    Returns:
        The ``logging.Logger`` named after this module.
    """
    logger = logging.getLogger(__name__)
    if accelerator.is_main_process:  # real logger
        # NOTE: basicConfig is a no-op if the root logger already has handlers,
        # in which case the file handler below is never attached.
        logging.basicConfig(
            level=logging.INFO,
            format='[\033[34m%(asctime)s\033[0m] %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S',
            handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
        )
    else:  # dummy logger (does nothing)
        # Attach the NullHandler only once so repeated calls don't pile up handlers.
        if not any(isinstance(h, logging.NullHandler) for h in logger.handlers):
            logger.addHandler(logging.NullHandler())
        # Without this, records would still propagate to any handlers that
        # something else installed on the root logger in this process.
        logger.propagate = False
    return logger