from collections import defaultdict
from collections.abc import Mapping

from torchmetrics import MeanMetric


class MeanDictMetric:
    """Track a running mean for each key of a stream of metric dictionaries.

    One ``torchmetrics.MeanMetric`` is created lazily per key, the first time
    that key appears in :meth:`update`. It is moved with
    ``.to(device=device, dtype=dtype)``.

    NOTE(review): :meth:`reset` resets each existing per-key metric but keeps
    the keys, so :meth:`compute` still reports keys that received no updates
    since the last reset. Confirm that this is intended.
    """

    def __init__(self, device=None, dtype=None):
        # Keep device/dtype on the instance rather than capturing them in a
        # lambda. A lambda default_factory makes the whole object unpicklable,
        # whereas a bound-method factory can be pickled and deep-copied.
        self._device = device
        self._dtype = dtype
        self.metrics = defaultdict(self._new_metric)

    def _new_metric(self):
        """Create a fresh ``MeanMetric`` on the configured device/dtype."""
        return MeanMetric().to(device=self._device, dtype=self._dtype)

    def update(self, metrics_dict):
        """Feed one value per key into the corresponding running mean.

        Args:
            metrics_dict: Mapping from metric name to a value passed unchanged
                to that key's ``MeanMetric.update``.

        Raises:
            ValueError: If ``metrics_dict`` is not a mapping.
        """
        if not isinstance(metrics_dict, Mapping):
            raise ValueError(
                f"update() expects a mapping, got {type(metrics_dict).__name__}"
            )

        for key, value in metrics_dict.items():
            # Indexing the defaultdict creates the metric on first sight of a key.
            self.metrics[key].update(value)

    def compute(self):
        """Return ``{key: metric.compute()}`` for every key seen so far."""
        return {key: metric.compute() for key, metric in self.metrics.items()}

    def reset(self):
        """Reset every per-key metric. The keys themselves are kept."""
        for metric in self.metrics.values():
            metric.reset()
