from typing import DefaultDict, Union, Callable, Dict
from collections import Collection, Mapping
import torch
import numpy as np
import hashlib


def rec_map(data: Collection, func: Callable) -> Collection:
    """Recursively apply ``func`` to ``data`` and everything nested inside it.

    ``func`` is called on each node first. If it returns anything other than
    ``None``, that value replaces the node and recursion stops there.
    Otherwise mappings are rebuilt with their values mapped, other
    collections are rebuilt element-wise, and any other value is returned
    unchanged.

    ``str``, ``bytes``, NumPy arrays and torch tensors are treated as leaves.
    They satisfy ``collections.abc.Collection``, but rebuilding them
    element-wise is wrong. ``str(generator)`` produces the generator's repr
    instead of the text, and the array/tensor constructors do not accept a
    generator of elements.

    Args:
        data: A nested structure (dicts, lists, tuples, namedtuples, sets,
            ...) or a leaf value.
        func: Called on every visited node. Return ``None`` to descend into
            the node, or a replacement value to stop there.

    Returns:
        A structure with the same container types, where every node for
        which ``func`` returned a value has been replaced by that value.
    """
    result = func(data)
    if result is not None:
        return result

    # Collection-like leaves that must not be rebuilt element-wise.
    if isinstance(data, (str, bytes, np.ndarray, torch.Tensor)):
        return data

    if isinstance(data, Mapping):
        return type(data)({k: rec_map(v, func) for k, v in data.items()})
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple constructors take the fields positionally, not as one iterable.
        return type(data)(*(rec_map(v, func) for v in data))
    if isinstance(data, Collection):
        return type(data)(rec_map(v, func) for v in data)
    return data


def md5(x: Union[torch.nn.Module, torch.Tensor, np.ndarray, bytes]) -> str:
    """Return the hex MD5 digest of a module, tensor, array or raw bytes.

    A module is hashed over all of its parameters, flattened and concatenated
    in ``parameters()`` order. Tensors and arrays are hashed over their raw
    element bytes (``ndarray.tobytes()``, C order). Shape and dtype are not
    part of the hash, so e.g. ``zeros(2, 3)`` and ``zeros(6)`` of the same
    dtype produce the same digest.

    Args:
        x: Object to fingerprint. Any other value is passed straight to
            ``hashlib.md5``, so it must be bytes-like.

    Returns:
        The 32-character lowercase hex digest.
    """
    if isinstance(x, torch.nn.Module):
        params = list(x.parameters())
        # parameters_to_vector -> torch.cat raises on an empty list, so a
        # parameter-free module (e.g. nn.ReLU) is hashed as empty input.
        x = torch.nn.utils.parameters_to_vector(params) if params else b""
    if isinstance(x, torch.Tensor):
        # Detach first so the device copy is not recorded by autograd.
        x = x.detach().cpu()
        if x.dtype == torch.bfloat16:
            # NumPy has no bfloat16; reinterpret the same 2-byte patterns.
            x = x.view(torch.int16)
        x = x.numpy()
    if isinstance(x, np.ndarray):
        x = x.tobytes()
    return hashlib.md5(x).hexdigest()


class AverageMeter:
    """Running weighted average of a stream of values.

    ``update(x, n)`` adds value ``x`` with weight ``n``. ``mean`` is
    ``sum / count``, or ``0.0`` while no weight has been accumulated.
    ``meter += x`` is shorthand for ``meter.update(x)``.
    """

    count: float  # total weight seen since the last reset (an int if every n is an int)
    sum: float    # weighted sum of the values, i.e. the sum of x * n

    def __init__(self):
        self.reset()

    def __repr__(self):
        return f"{type(self).__name__}(count={self.count!r}, sum={self.sum!r})"

    @property
    def mean(self) -> float:
        """Weighted mean since the last reset, or ``0.0`` if the total weight is zero."""
        # Guard only the exact-zero case. Clamping count to a small constant
        # would skew the mean whenever the total weight is tiny but nonzero.
        if self.count == 0:
            return 0.0
        return self.sum / self.count

    def update(self, x, n=1):
        """Accumulate value ``x`` with weight ``n`` (default 1)."""
        self.count += n
        self.sum += x * n

    def reset(self):
        """Forget all accumulated values."""
        self.count = 0
        self.sum = 0

    def __iadd__(self, x):
        self.update(x)
        return self


class MeterLib(DefaultDict[str, AverageMeter]):
    """Dictionary of named ``AverageMeter`` objects, created on first access."""

    def __init__(self):
        # default_factory: looking up an unseen name creates a fresh AverageMeter.
        super().__init__(AverageMeter)

    def purge(self, prefix='') -> Dict[str, float]:
        """Snapshot and reset every meter whose name starts with ``prefix``.

        Args:
            prefix: Name prefix that selects meters; ``''`` selects all of them.

        Returns:
            A mapping from each selected name to that meter's mean just
            before it was reset.
        """
        means: Dict[str, float] = {}
        for name, avg in self.items():
            if not name.startswith(prefix):
                continue
            # Read the mean, then reset this meter before moving on.
            means[name] = avg.mean
            avg.reset()
        return means
