import copy
import numpy as np
import torch as th
from utils.custom_metric_logger import CustomMetricLogger


class AgentImportance:
    """Online approximation of each agent's Shapley value during evaluation.

    The contribution of an agent is computed as the difference between the
    team reward and the team reward obtained when that agent's action is
    replaced with a no-op action. To ensure the same environment state is acted
    upon by every coalition, the underlying environment is deep-copied and each
    copy is stepped once: `n + 1` copies per evaluation step, where `n` is the
    number of agents (one copy per agent plus one for the unmodified joint
    action). The environment passed in is never stepped by this class.
    """

    def __init__(self, n_agents, noop_action, logger, args, log_extra_details=False):
        """
        Initialize the AgentImportance class to calculate the contribution of
        each agent during the evaluation per step.

        Arguments:
            n_agents (int): Number of agents.
            noop_action: No-operation action(s). A single value if all the
                agents use the same noop action, a list with one entry per
                agent otherwise.
            logger: Logger object for logging the results in each evaluation.
                Must provide `log_stat(key, value, t_env)`.
            args: Arguments forwarded to `CustomMetricLogger` for the
                "agent_importance" metric.
            log_extra_details (bool, optional): Whether to also record the team
                reward and the per-agent counterfactual rewards.
                Defaults to False.

        Raises:
            ValueError: If `noop_action` is a list with fewer than `n_agents`
                entries.
        """
        self.n_agents = n_agents

        # Data collected during an evaluation interval, keyed by metric name.
        self.cached_data = self._empty_cache()

        # List of noop actions of each agent.
        if isinstance(noop_action, list):
            if len(noop_action) < n_agents:
                # Fail here rather than with an IndexError in the middle of an
                # evaluation.
                raise ValueError(
                    f"noop_action has {len(noop_action)} entries but "
                    f"{n_agents} agents were specified"
                )
            self.noop_action = noop_action
        else:
            self.noop_action = [noop_action] * n_agents

        # Define the logger
        self.logger = logger
        self.log_extra_details = log_extra_details

        # Create a custom logger for agent importance
        self.logger_agent_importance = CustomMetricLogger(args, "agent_importance")

        # True while there are evaluation samples that have not been logged yet.
        self._update = False

    def _empty_cache(self):
        """Return a fresh cache holding an empty list for every tracked metric."""
        cache = {
            f"agent_{agent_id}_importance_value": []
            for agent_id in range(self.n_agents)
        }
        cache["team_reward"] = []
        cache.update(
            {
                f"reward_without_agent_{agent_id}": []
                for agent_id in range(self.n_agents)
            }
        )
        return cache

    def compute_per_step(self, env, actions, test_mode, t_env):
        """
        Compute agent importance values per step during the evaluation.

        While `test_mode` is True every call records one sample per agent. The
        first call with `test_mode` False after at least one recorded sample
        logs the mean of the collected samples, writes the raw samples to the
        custom logger and clears the cache. Other calls do nothing.

        Arguments:
            env: Environment object. It must be deep-copyable and its `step`
                must return a 4-tuple whose first element is the reward.
            actions: Actions taken by the agents in an evaluation step. Must
                provide `tolist()` (e.g. a torch tensor).
            test_mode (bool): Whether it is test mode (during evaluation).
            t_env: Environment time step, used as the logging timestamp.
        """
        if test_mode:
            self._record_step(env, actions)
            self._update = True
        elif self._update:
            # Log the collected samples after finishing the evaluation.
            self._update = False
            self._flush_evaluation(t_env)

    def _record_step(self, env, actions):
        """Cache the marginal contribution of every agent for one env step."""
        selected_actions = actions.tolist()

        # Every counterfactual acts on its own deep copy of `env`. Copies are
        # created right before they are stepped and dropped afterwards, so at
        # most one extra environment is alive at a time.
        team_reward, _, _, _ = copy.deepcopy(env).step(actions)

        for agent_id in range(self.n_agents):
            # Joint action with this agent's action replaced by its noop.
            actions_without_agent = [
                self.noop_action[agent_id] if agent_id == j else action
                for j, action in enumerate(selected_actions)
            ]
            reward_without_agent, _, _, _ = copy.deepcopy(env).step(
                th.tensor(actions_without_agent)
            )

            self.cached_data[f"agent_{agent_id}_importance_value"].append(
                team_reward - reward_without_agent
            )
            if self.log_extra_details:
                self.cached_data[f"reward_without_agent_{agent_id}"].append(
                    reward_without_agent
                )

        if self.log_extra_details:
            self.cached_data["team_reward"].append(team_reward)

    def _flush_evaluation(self, t_env):
        """Log the mean of the cached samples, persist them and reset the cache."""
        for agent_id in range(self.n_agents):
            importance_key = f"agent_{agent_id}_importance_value"
            self.logger.log_stat(
                importance_key, np.mean(self.cached_data[importance_key]), t_env
            )
            if self.log_extra_details:
                without_key = f"reward_without_agent_{agent_id}"
                self.logger.log_stat(
                    without_key, np.mean(self.cached_data[without_key]), t_env
                )

        if self.log_extra_details:
            self.logger.log_stat(
                "team_reward", np.mean(self.cached_data["team_reward"]), t_env
            )

        # Write and store the raw samples into the custom logger, keyed by the
        # environment time step at which the evaluation finished.
        self.logger_agent_importance.write({t_env: self.cached_data})

        # Rebind (rather than clear in place) so the dict handed to the custom
        # logger above is not mutated afterwards.
        self.cached_data = self._empty_cache()
