from typing import Optional, List

from dinoreg.utils.average_meter import AverageMeter
from dinoreg.utils.common import get_print_format


class SummaryBoard:
    r"""A named collection of AverageMeter objects.

    Meters are stored in `meter_dict` (name -> AverageMeter) while `meter_names` keeps the
    registration order, which is the order used by `tostring`, `summary` and `reset_all`
    when no explicit names are given.
    """

    def __init__(self, names: Optional[List[str]] = None, adaptive=False):
        r"""Instantiate a SummaryBoard.

        Args:
            names (List[str]=None): create AverageMeter with the names.
            adaptive (bool=False): whether register basic meters automatically on the fly.
        """
        self.meter_dict = {}
        self.meter_names = []
        self.adaptive = adaptive

        if names is not None:
            self.register_all(names)

    def register_meter(self, name) -> None:
        r"""Register a fresh AverageMeter under `name`.

        Registering a name that already exists replaces its meter with a new one but keeps a
        single entry (at its original position) in `meter_names`, so the meter is never
        listed twice.

        Args:
            name: key of the meter.
        """
        self.meter_dict[name] = AverageMeter()
        # Guard against duplicates: a repeated name would otherwise be reported twice by
        # `tostring` and reset twice by `reset_all`.
        if name not in self.meter_names:
            self.meter_names.append(name)

    def register_all(self, names) -> None:
        r"""Register one meter for every name in `names`.

        Args:
            names: iterable of meter keys.
        """
        for name in names:
            self.register_meter(name)

    def reset_meter(self, name) -> None:
        r"""Reset the meter stored under `name`.

        Raises:
            KeyError: if no meter is stored under `name`.
        """
        self.meter_dict[name].reset()

    def reset_all(self) -> None:
        r"""Reset every registered meter."""
        for name in self.meter_names:
            self.reset_meter(name)

    def check_name(self, name) -> None:
        r"""Make sure a meter exists for `name`.

        In adaptive mode an unknown name is registered on the fly; otherwise it is an error.

        Raises:
            KeyError: if `name` is unknown and the board is not adaptive.
        """
        if name in self.meter_names:
            return
        if not self.adaptive:
            raise KeyError('No meter for key "{}".'.format(name))
        self.register_meter(name)

    def update(self, name, value) -> None:
        r"""Feed `value` to the meter named `name`.

        Raises:
            KeyError: if `name` is unknown and the board is not adaptive.
        """
        self.check_name(name)
        self.meter_dict[name].update(value)

    def update_from_result_dict(self, result_dict) -> None:
        r"""Update the meters from a dict of `name -> value`.

        Unlike `update`, keys without a meter are silently skipped when the board is not
        adaptive; in adaptive mode they are registered first.

        Args:
            result_dict (dict): values to record, keyed by meter name.

        Raises:
            TypeError: if `result_dict` is not a dict.
        """
        if not isinstance(result_dict, dict):
            raise TypeError('`result_dict` must be a dict: {}.'.format(type(result_dict)))
        for key, value in result_dict.items():
            if key not in self.meter_names:
                if not self.adaptive:
                    # Not tracked by this board: ignore instead of raising.
                    continue
                self.register_meter(key)
            self.meter_dict[key].update(value)

    # NOTE: the statistic getters below go through `check_name`, so in adaptive mode
    # querying an unknown name registers a new meter before the statistic is read.

    def sum(self, name):
        r"""Return `sum()` of the meter named `name`."""
        self.check_name(name)
        return self.meter_dict[name].sum()

    def mean(self, name):
        r"""Return `mean()` of the meter named `name`."""
        self.check_name(name)
        return self.meter_dict[name].mean()

    def std(self, name):
        r"""Return `std()` of the meter named `name`."""
        self.check_name(name)
        return self.meter_dict[name].std()

    def median(self, name):
        r"""Return `median()` of the meter named `name`."""
        self.check_name(name)
        return self.meter_dict[name].median()

    def tostring(self, names=None) -> str:
        r"""Render the meter means as a single `name: value, name: value` string.

        Args:
            names (List[str]=None): meters to include; all registered meters if None.

        Returns:
            The comma-separated summary. Each value is formatted with the format spec
            returned by `get_print_format` for that value.

        Raises:
            KeyError: if a requested name has no meter (names are not auto-registered here).
        """
        if names is None:
            names = self.meter_names
        items = []
        for name in names:
            value = self.meter_dict[name].mean()
            template = '{}: {:' + get_print_format(value) + '}'
            items.append(template.format(name, value))
        return ', '.join(items)

    def summary(self, names=None, last_n=None) -> dict:
        r"""Collect the meter means into a dict.

        Args:
            names (List[str]=None): meters to include; all registered meters if None.
            last_n: forwarded unchanged to each meter's `mean()`; presumably limits the mean
                to the most recent `last_n` values (confirm against AverageMeter).

        Returns:
            A dict mapping each name to the mean of its meter.

        Raises:
            KeyError: if a requested name has no meter (names are not auto-registered here).
        """
        if names is None:
            names = self.meter_names
        return {name: self.meter_dict[name].mean(last_n) for name in names}
