import inspect
import torch


class BaseMetric(object):
    """Base class for metrics evaluated from a bag of keyword inputs.

    Subclasses implement :meth:`call` with explicitly named parameters. The
    names of those parameters (positional-or-keyword and keyword-only) are
    recorded as :attr:`params`. When the metric is invoked, only the keyword
    arguments whose names appear in :attr:`params` are forwarded to
    :meth:`call`, and the call runs under ``torch.no_grad()``.

    Args:
        name: Identifier stored on the instance as ``self.name``.

    Raises:
        TypeError: if ``call`` declares a ``*args`` parameter. Inputs are
            only ever passed by keyword, so it could never receive anything.
    """

    def __init__(self, name):
        self.name = name

        argspec = inspect.getfullargspec(self.call)
        # Raise instead of `assert` so the check still runs under `python -O`.
        if argspec.varargs is not None:
            raise TypeError(
                f'{type(self).__name__}.call must not declare '
                f'*{argspec.varargs}; inputs are passed by keyword only'
            )
        # getfullargspec on a bound method still reports `self` first, so drop
        # it. Keyword-only parameters are included so __call__ forwards them
        # too (previously they were filtered out and never received values).
        self._params = argspec.args[1:] + argspec.kwonlyargs

    @property
    def params(self):
        """Names of the keyword inputs forwarded to :meth:`call`."""
        return self._params

    def __call__(self, **kwargs):
        """Evaluate the metric on the subset of ``kwargs`` named in :attr:`params`.

        Keyword arguments not listed in :attr:`params` are ignored. If
        :meth:`call` raises, the metric class name and the forwarded inputs
        are printed. The original exception then propagates unchanged.
        """
        inputs = {key: value for key, value in kwargs.items() if key in self.params}
        try:
            with torch.no_grad():
                result = self.call(**inputs)
        except Exception:
            print(f'metric: {self.__class__.__name__}')
            print('inputs:', inputs)
            raise  # bare raise re-raises the active exception as-is
        return result

    def call(self, **kwargs):
        """Compute the metric. Subclasses must override this with named params."""
        raise NotImplementedError


def create_metric(func, name=None):
    """Wrap a plain callable as a ``BaseMetric`` instance.

    Args:
        func: Callable computing the metric. It is always invoked with keyword
            arguments only. Its named parameters (positional-or-keyword and
            keyword-only) become the metric's ``params``. For a bound method
            the already-bound first parameter (``self``/``cls``) is excluded.
        name: Metric name. Defaults to ``func.__name__``. Callables without
            ``__name__`` (e.g. ``functools.partial`` objects) fall back to
            their type name.

    Returns:
        An instance of a ``BaseMetric`` subclass whose ``call`` forwards the
        filtered keyword inputs to ``func``.
    """
    if name is None:
        name = getattr(func, '__name__', type(func).__name__)
    argspec = inspect.getfullargspec(func)

    positional = list(argspec.args)
    # getfullargspec keeps the already-bound first parameter of a bound
    # method. It can never be supplied by keyword, so don't advertise it.
    if inspect.ismethod(func):
        positional = positional[1:]
    # Keyword-only parameters must be listed too. Otherwise BaseMetric's input
    # filter would drop them before `func` is called.
    params = positional + list(argspec.kwonlyargs)

    class _Metric(BaseMetric):
        def __init__(self):
            super().__init__(name)
            # BaseMetric derives params from `call`, which here is only a
            # **kwargs pass-through, so replace them with func's parameters.
            self._params = list(params)

        def call(self, **kwargs):
            return func(**kwargs)

    return _Metric()
