import numpy as np
from typing import Optional, List
from collections import defaultdict
from src.verify.trainer.metrics.base_metric import BaseMetric
from src.verify.trainer.utils import is_number
import torch


class MetricsSummary(object):
    """Accumulate per-batch metric values and summarise them over an epoch.

    Values passed to :meth:`append` are routed three ways:

    * every metric in ``batch_metrics`` is evaluated on the batch immediately
      and its result is returned from :meth:`append`;
    * values whose key appears in some epoch metric's ``params`` are buffered
      and concatenated later, so :meth:`mean` can evaluate ``epoch_metrics``
      once on the whole epoch;
    * any other value accepted by ``is_number`` is tracked as a running mean
      weighted by ``batch_nums``.
    """

    def __init__(self,
                 batch_metrics: Optional[List['BaseMetric']] = None,
                 epoch_metrics: Optional[List['BaseMetric']] = None):
        """
        Args:
            batch_metrics: metrics evaluated on every appended batch.
            epoch_metrics: metrics evaluated once in :meth:`mean` on the
                concatenation of all buffered inputs.
        """
        self.batch_metrics = batch_metrics or []
        self.epoch_metrics = epoch_metrics or []

        # name -> list of (value, weight) pairs for scalar metrics.
        self._data = defaultdict(list)
        # name -> list of raw per-batch inputs needed by epoch metrics.
        self._to_cal_data = defaultdict(list)

    def to_detach(self, params):
        """Return ``params`` with every tensor replaced by ``tensor.detach()``.

        Dicts, tuples and lists are walked recursively and rebuilt as plain
        dict/tuple/list; any other object is returned unchanged.
        """
        if isinstance(params, dict):
            return {name: self.to_detach(value) for name, value in params.items()}
        if isinstance(params, torch.Tensor):
            return params.detach()
        if isinstance(params, tuple):
            return tuple(self.to_detach(item) for item in params)
        if isinstance(params, list):
            return [self.to_detach(item) for item in params]
        return params

    def _epoch_metrics_params(self):
        """Return the concatenated ``params`` lists of all epoch metrics."""
        r = []
        for metric in self.epoch_metrics:
            r.extend(metric.params)
        return r

    def append(self, metrics: dict, batch_nums: Optional[int] = None):
        """Record one batch of metric inputs.

        Args:
            metrics: mapping of name -> value for this batch. Tensors (also
                when nested in dicts/tuples/lists) are detached before use.
            batch_nums: weight of this batch in the running means; defaults
                to 1.

        Returns:
            dict mapping each batch metric's ``name`` to its result on this
            batch.
        """
        if batch_nums is None:
            batch_nums = 1
        metrics = self.to_detach(metrics)

        out_cal_metrics = {
            func.name: func(**metrics) for func in self.batch_metrics
        }

        # Hoisted out of the loop: the set of epoch-metric inputs does not
        # depend on the key being routed.
        epoch_params = self._epoch_metrics_params()
        for name, value in metrics.items():
            if name in epoch_params:
                # Raw inputs for epoch metrics are buffered, not averaged.
                self._to_cal_data[name].append(value)
            elif is_number(value):
                self._data[name].append((float(value), batch_nums))

        return out_cal_metrics

    @staticmethod
    def _concat(values):
        """Concatenate tensors along dim 0, promoting 0-d tensors to shape (1,).

        ``torch.cat`` rejects zero-dimensional tensors, so per-batch scalars
        are reshaped first.
        """
        return torch.cat([v.reshape(1) if v.dim() == 0 else v for v in values])

    @property
    def to_cal_data(self):
        """Buffered epoch-metric inputs merged across all appended batches.

        For each buffered name:

        * if the per-batch values are tensors, they are concatenated;
        * otherwise each value is treated as a sequence (e.g. a tuple) and
          merged position-wise: tensor positions are concatenated, any other
          position becomes a list of the per-batch items. The result is a
          tuple with one entry per position.
        """
        result = {}
        for name, chunks in self._to_cal_data.items():
            if isinstance(chunks[0], torch.Tensor):
                result[name] = self._concat(chunks)
                continue
            fields = []
            for field in zip(*chunks):
                if isinstance(field[0], torch.Tensor):
                    fields.append(self._concat(field))
                else:
                    fields.append(list(field))
            result[name] = tuple(fields)
        return result

    def mean(self):
        """Summarise everything recorded so far.

        Returns:
            dict with the ``batch_nums``-weighted mean of every scalar metric,
            plus one entry per epoch metric evaluated on :attr:`to_cal_data`.
        """
        result = {}
        for name, values in self._data.items():
            pairs = np.array(values)
            weights = pairs[:, 1]
            result[name] = np.sum(pairs[:, 0] * weights) / np.sum(weights)

        if self.epoch_metrics:
            # Concatenate the buffered inputs once and share them across all
            # epoch metrics instead of re-running torch.cat for each metric.
            # NOTE: epoch metrics therefore must not modify inputs in place.
            cal_data = self.to_cal_data
            for metric_func in self.epoch_metrics:
                result[metric_func.name] = metric_func(**cal_data)

        return result


if __name__ == '__main__':
    # Smoke demo: use one metric both per batch and per epoch over two batches.
    # Import through the same ``src.`` package root as the module-level imports
    # so the metrics package (and BaseMetric) is not loaded under two names.
    from src.verify.trainer.metrics import create_metric

    def wtf_metric(output, target):
        # ``output`` is a (tensor, extra) pair; only the tensor is scored.
        output, _ = output
        return torch.mean(output - target)

    metric = create_metric(wtf_metric)
    summary = MetricsSummary(batch_metrics=[metric], epoch_metrics=[metric])
    for _ in range(2):
        print(summary.append(dict(loss=1,
                                  output=(torch.ones([16, 2]), 'sss'),
                                  target=torch.ones([16, 2]))))
    print(summary.mean())
