import numpy as np


class Callback(object):
    """
    Base interface for simple data-collecting callbacks.

    Keeps an internal list of collected items and offers helpers to read and
    reset it. Subclasses decide how incoming data is stored by overriding
    ``__call__``.

    """

    def __init__(self):
        """
        Constructor: start with an empty collection.

        """
        self._data_list = []

    def __call__(self, dataset):
        """
        Store the given samples. Must be overridden by subclasses.

        Args:
            dataset (list): the samples to collect.

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError

    def get(self):
        """
        Returns:
             The data collected so far (the internal list itself, not a
             copy).

        """
        return self._data_list

    def clear(self):
        """
        Discard all collected data by rebinding the storage to a fresh empty
        list.

        """
        self._data_list = []


class CollectDataset(Callback):
    """
    This callback can be used to collect samples during the learning of the
    agent.

    A list argument is treated as a batch of samples whose elements are
    appended one by one; any other object is stored as a single sample.

    """

    def __call__(self, dataset):
        """
        Collect ``dataset``.

        Args:
            dataset: either a list of samples (extended into the stored data)
                or a single sample (appended as-is).

        """
        # isinstance rather than ``type(...) is list`` so that list subclasses
        # are also treated as batches of samples.
        if isinstance(dataset, list):
            self._data_list.extend(dataset)
        else:
            self._data_list.append(dataset)


def compute_J_all_agents(dataset, gamma=1.0):
    """
    Compute the cumulative discounted reward of each episode, for each agent.

    Episodes are delimited by steps whose ``"last"`` flag is truthy. The final
    step of the dataset also closes an episode, which may be truncated.

    Args:
        dataset (list): the dataset to consider. Each step is a mapping with
            ``"actions"`` and ``"rewards"`` sequences indexed by agent, and a
            ``"last"`` flag;
        gamma (float, 1.): discount factor.

    Returns:
        A list with one entry per agent. Each entry is the list of cumulative
        discounted rewards of the episodes in the dataset, in order. An empty
        dataset yields an empty list.

    """
    if not dataset:
        # The number of agents is read from the first step, so there is
        # nothing to compute for an empty dataset.
        return []

    num_agents = len(dataset[0]["actions"])
    last_index = len(dataset) - 1
    js_all_agents = list()
    for idx_agent in range(num_agents):
        js_agent = list()
        j = 0.0
        episode_steps = 0
        for i, step in enumerate(dataset):
            j += gamma**episode_steps * step["rewards"][idx_agent]
            episode_steps += 1
            # The last step of the dataset always closes the current episode,
            # so every agent gets at least one entry.
            if step["last"] or i == last_index:
                js_agent.append(j)
                j = 0.0
                episode_steps = 0
        js_all_agents.append(js_agent)

    return js_all_agents


def compute_episode_lengths(dataset):
    """
    Compute the number of steps in each episode of the dataset.

    An episode ends at every step whose ``"last"`` flag is truthy. The final
    step of the dataset always closes the current episode, even when it is
    not flagged as last.

    Args:
        dataset (list): the dataset to consider.

    Returns:
        A list with the length of each episode, in order. It is empty when
        the dataset is empty.

    """
    lengths = []
    current = 0
    final_index = len(dataset) - 1
    for index, step in enumerate(dataset):
        current += 1
        closes_episode = step["last"] or index == final_index
        if closes_episode:
            lengths.append(current)
            current = 0
    return lengths


def compute_action_norms_all_agents(dataset):
    """
    Compute the norm of every action taken by each agent over the dataset.

    Norms are computed with ``np.linalg.norm`` at its default order, which is
    the 2-norm for 1-D actions.

    Args:
        dataset (list): the dataset to consider. Each step is a mapping whose
            ``"actions"`` entry holds one action per agent.

    Returns:
        A list with one entry per agent. Each entry is the list of that
        agent's action norms, one per step of the dataset, with no duplicates
        at episode boundaries. An empty dataset yields an empty list.

    """
    if not dataset:
        # The number of agents is read from the first step.
        return []

    num_agents = len(dataset[0]["actions"])
    return [
        [np.linalg.norm(step["actions"][idx_agent]) for step in dataset]
        for idx_agent in range(num_agents)
    ]


def compute_action_norms(dataset):
    """
    Compute the action norms of every agent over the dataset.

    Args:
        dataset (list): the dataset to consider.

    Returns:
        The per-agent lists of action norms, as produced by
        ``compute_action_norms_all_agents``.

    """
    return compute_action_norms_all_agents(dataset)


def smac_battles_won(dataset, info_dataset):
    """
    Collect the ``"battle_won"`` outcome of every episode that terminates in
    the dataset.

    Args:
        dataset (list): the dataset to consider;
        info_dataset (list): the info dataset to consider, indexed in step
            with ``dataset``.

    Returns:
        A list containing the ``"battle_won"`` entry of ``info_dataset`` at
        each step of ``dataset`` whose ``"last"`` flag is truthy, in order.
        This is one value per terminated episode, not a count.

    """
    return [
        info_dataset[idx]["battle_won"]
        for idx, step in enumerate(dataset)
        if step["last"]
    ]
