from abc import ABCMeta, abstractmethod
from typing import Sequence, Dict, Any
from collections import defaultdict


class MetricType:
    """String keys used to label the per-agent metrics that collectors record."""

    # Key for the count of recorded steps an agent has taken.
    LIVE_STEP = "live_step"
    # Key for the reward collected by an agent.
    REWARD = "reward"


class Metric(metaclass=ABCMeta):
    """Abstract interface for per-agent metric collectors.

    Subclasses record transitions through ``step``, reduce the recorded data
    with ``parse``, and combine several parsed results with ``merge_parsed``.
    """

    def __init__(self, agents: Sequence[str]):
        # Ids of the agents this collector reports on.
        self._agents = agents
        # Raw per-episode data; concrete subclasses decide its layout.
        self._episode_data = {}
        # Reduced statistics, populated by concrete subclasses.
        self._statistics = {}

    @abstractmethod
    def step(self, agent_id, observation, action, reward, done, info):
        """Record one transition for ``agent_id``."""

    @abstractmethod
    def parse(self):
        """Reduce the recorded episode data into statistics."""

    @abstractmethod
    def merge_parsed(self, agent_result_seq):
        """Combine a collection of previously parsed results."""

    def reset(self):
        """Drop all recorded episode data and statistics."""
        self._episode_data, self._statistics = {}, {}


class SimpleMetrics(Metric):
    """Per-agent total reward and live-step count for a single episode.

    ``step`` records one transition per call, ``parse`` reduces the recorded
    data into per-agent statistics, and ``merge_parsed`` averages statistics
    gathered from several episodes. Call ``reset`` between episodes, since the
    recorded data covers one episode only.
    """

    def __init__(self, agents: Sequence[str]):
        super().__init__(agents)
        self._episode_data = self._new_episode_data()
        self._statistics = self._new_statistics()

    @staticmethod
    def _new_episode_data():
        """Return empty per-agent buffers: a reward history and a step counter."""
        return {
            MetricType.REWARD: defaultdict(list),
            MetricType.LIVE_STEP: defaultdict(int),
        }

    @staticmethod
    def _new_statistics():
        """Return a mapping that lazily creates zeroed statistics per agent."""
        return defaultdict(
            lambda: {MetricType.REWARD: 0.0, MetricType.LIVE_STEP: 0}
        )

    def step(self, agent_id, observation, action, reward, done, info):
        """Record one transition for ``agent_id``.

        Only ``reward`` is stored, and the agent's live-step counter is
        incremented. ``observation``, ``action``, ``done`` and ``info`` are
        accepted to match the ``Metric`` interface but are not used here.
        """
        self._episode_data[MetricType.REWARD][agent_id].append(reward)
        self._episode_data[MetricType.LIVE_STEP][agent_id] += 1

    def parse(self):
        """Return agent-wise statistics for the recorded episode.

        Returns:
            The internal mapping ``agent_id -> {REWARD: summed reward,
            LIVE_STEP: number of recorded steps}``, filled in for every agent
            given to the constructor. An agent that never stepped reports a
            reward of 0 and 0 live steps. Data recorded for ids that are not
            in ``agents`` is not reported.
        """
        rewards = self._episode_data[MetricType.REWARD]
        live_steps = self._episode_data[MetricType.LIVE_STEP]
        for aid in self._agents:
            stats = self._statistics[aid]
            # ``.get`` avoids inserting empty entries into the episode buffers.
            stats[MetricType.REWARD] = sum(rewards.get(aid, ()))
            stats[MetricType.LIVE_STEP] = live_steps.get(aid, 0)
        return self._statistics

    def merge_parsed(self, agent_result_seq: Dict[str, Sequence[Dict[str, Any]]]):
        """Average a sequence of per-episode statistics for each agent.

        Args:
            agent_result_seq: Mapping ``agent_id -> [stats, ...]``. Each
                ``stats`` dict must contain the ``REWARD`` and ``LIVE_STEP``
                keys, like one entry returned by ``parse``.

        Returns:
            A new mapping ``agent_id -> {REWARD: mean reward, LIVE_STEP: mean
            live steps}``. An agent whose sequence is empty gets zeros.
        """
        agent_res = {}
        for aid, result_seq in agent_result_seq.items():
            count = len(result_seq)
            if count == 0:
                agent_res[aid] = {
                    MetricType.REWARD: 0.0,
                    MetricType.LIVE_STEP: 0.0,
                }
                continue
            # Sum first and divide once: fewer rounding steps than dividing
            # each term, and ``len`` is evaluated only once per agent.
            total_reward = sum((r[MetricType.REWARD] for r in result_seq), 0.0)
            total_steps = sum((r[MetricType.LIVE_STEP] for r in result_seq), 0.0)
            agent_res[aid] = {
                MetricType.REWARD: total_reward / count,
                MetricType.LIVE_STEP: total_steps / count,
            }
        return agent_res

    def reset(self):
        """Discard the recorded episode data and the computed statistics."""
        self._episode_data = self._new_episode_data()
        self._statistics = self._new_statistics()


def get_metric_handler(name):
    """Return the metric collector class registered under ``name``.

    Only ``"simple"`` (``SimpleMetrics``) is recognized.

    Raises:
        ValueError: if ``name`` does not match a known handler.
    """
    if name != "simple":
        raise ValueError(f"No such a metric handler named with: {name}")
    return SimpleMetrics
