import copy
import numpy as np
import torch as th
from itertools import combinations
from math import factorial
import random
from utils.custom_metric_logger import CustomMetricLogger


class ShapleyValue:
    """Evaluate the contribution of each agent to a cooperative task.

    Two estimators are available:

    * Exact Shapley values (``use_original_shap=True``): at every evaluation
      step each of the ``2^n`` agent coalitions is evaluated once on a deep
      copy of the environment, where ``n`` is the number of agents.
    * Monte-Carlo approximation (``use_original_shap=False``): at every
      evaluation step a single random coalition is drawn per agent and its
      marginal contribution is stored as one sample; the samples are averaged
      when the results are logged.

    In all cases, agents that are not part of the evaluated coalition have
    their action replaced with their no-op action.
    """

    def __init__(self, n_agents, noop_action, logger, args, use_original_shap=True):
        """
        Initialize the ShapleyValue class for computing the Shapley values either
        using the exact method or using the Monte-Carlo approximation.

        Args:
            n_agents (int): Number of agents.
            noop_action: No-operation action(s). A single value if all the agents
                use the same noop action, a list/tuple with one entry per agent
                otherwise.
            logger: Logger object for logging results (must expose `log_stat`).
            args: Arguments forwarded to `CustomMetricLogger`.
            use_original_shap (bool, optional): Whether to use the exact Shapley
                value method. Defaults to True.
        """
        self.n_agents = n_agents
        # Per-agent no-op actions, indexable by agent id.
        self.noop_action = self._expand_noop_actions(noop_action, n_agents)

        # Select the estimator; the name is also used in every logged metric key.
        self.use_original_shap = use_original_shap
        self.shap_type = "shapley_value" if use_original_shap else "monte_carlo_shap"

        # All possible coalitions (subsets of agents), each as a sorted tuple.
        self.agents_ids = tuple(range(self.n_agents))
        self.coalitions_list = [
            coalition
            for size in range(n_agents + 1)
            for coalition in combinations(self.agents_ids, size)
        ]

        # Per-agent samples collected during evaluation, emptied after logging.
        self.collected_samples = self._empty_samples()

        # Standard logger plus a dedicated logger for the raw samples.
        self.logger = logger
        self.logger_shap = CustomMetricLogger(args, f"{self.shap_type}")
        # True while evaluation samples are waiting to be logged.
        self._update = False

    @staticmethod
    def _expand_noop_actions(noop_action, n_agents):
        """Return a per-agent sequence of no-op actions.

        A list or tuple is taken as already being per-agent and is returned
        unchanged; any other value is repeated once per agent.
        """
        if isinstance(noop_action, (list, tuple)):
            return noop_action
        return [noop_action] * n_agents

    def _sample_key(self, agent_id):
        """Return the metric key under which samples of `agent_id` are stored."""
        return f"agent_{agent_id}_{self.shap_type}"

    def _empty_samples(self):
        """Return a fresh samples dict with an empty list for every agent."""
        return {self._sample_key(agent_id): [] for agent_id in range(self.n_agents)}

    def compute_per_step(self, env, actions, test_mode, t_env):
        """
        Collect Shapley samples during evaluation and log them afterwards.

        In test mode, one sample per agent is computed for the current step and
        stored. On the first call outside test mode that follows such calls,
        the mean of each agent's samples is logged, the raw samples are written
        to the custom logger under the key `t_env`, and the storage is emptied.

        Args:
            env: Environment object.
            actions: Actions taken by the agents (must provide `tolist()`).
            test_mode (bool): Whether it is test mode (evaluation).
            t_env: Environment time step used when logging.
        """
        if test_mode:
            selected_actions = actions.tolist()
            if self.use_original_shap:
                self.shapley_values(env, selected_actions)
            else:
                self.monte_carlo_approximation(env, selected_actions)
            # Remember that there are samples waiting to be logged.
            self._update = True
            return

        if not self._update:
            return

        self._update = False
        for agent_id in range(self.n_agents):
            key = self._sample_key(agent_id)
            self.logger.log_stat(key, np.mean(self.collected_samples[key]), t_env)

        # Write the raw samples to the custom logger, then start a new dict so
        # the object handed to the logger is not modified afterwards.
        self.logger_shap.write({t_env: self.collected_samples})
        self.collected_samples = self._empty_samples()

    def get_coalitions_for_agent(self, agent_id):
        """
        Get the coalition(s) used to evaluate the contribution of an agent.

        Args:
            agent_id (int): Agent ID.

        Returns:
            tuple: ``(with_agent, without_agent)``.
                Exact method: two aligned lists; the first holds every coalition
                containing the agent and the second holds the same coalitions
                with the agent removed.
                Monte-Carlo method: a single random coalition containing the
                agent and the same coalition with the agent removed, both as
                sorted tuples.
        """
        if not self.use_original_shap:
            other_agents = [a for a in range(self.n_agents) if a != agent_id]
            # Drawing the size uniformly and then a uniform subset of that size
            # is equivalent to taking the agents that precede `agent_id` in a
            # uniformly random ordering. This gives every coalition its Shapley
            # weight |S|! (n - |S| - 1)! / n!. Sampling uniformly over all
            # coalitions would instead estimate the Banzhaf value.
            n_predecessors = random.randrange(len(other_agents) + 1)
            c_without = tuple(sorted(random.sample(other_agents, n_predecessors)))
            # Keep the tuples sorted so equal coalitions share one cache key.
            c_with = tuple(sorted(c_without + (agent_id,)))
            return c_with, c_without

        coalitions_with_agent = [
            coalition for coalition in self.coalitions_list if agent_id in coalition
        ]
        coalitions_without_agent = [
            tuple(a for a in coalition if a != agent_id)
            for coalition in coalitions_with_agent
        ]
        return coalitions_with_agent, coalitions_without_agent

    def calculate_coalition_reward(self, coalition, selected_actions, env):
        """
        Replace the agent's action with a no-op if it's not included in 'coalition'
        and calculate the reward of the coalition.

        The step is executed on a deep copy of `env`, so `env` itself is not
        stepped.

        Args:
            coalition (tuple or list): Coalition of agents.
            selected_actions: Actions selected by the agents, indexed by agent id.
            env: Environment object whose `step` returns four values, the first
                being the reward.

        Returns:
            Reward of the coalition, as returned by the environment.
        """
        copy_env = copy.deepcopy(env)
        members = set(coalition)
        actions_coalition = [
            selected_actions[agent_id]
            if agent_id in members
            else self.noop_action[agent_id]
            for agent_id in range(self.n_agents)
        ]
        reward_coalition, _, _, _ = copy_env.step(th.tensor(actions_coalition))
        return reward_coalition

    def _marginal_contribution(
        self, c_with, c_without, coalitions_reward, selected_actions, env
    ):
        """Return reward(c_with) - reward(c_without).

        `coalitions_reward` caches rewards per coalition so that each coalition
        is evaluated at most once per step.
        """
        for coalition in (c_with, c_without):
            if coalition not in coalitions_reward:
                coalitions_reward[coalition] = self.calculate_coalition_reward(
                    coalition, selected_actions, env
                )
        return coalitions_reward[c_with] - coalitions_reward[c_without]

    def shapley_values(self, env, selected_actions):
        """
        Calculate the exact Shapley value of each agent for the current step and
        append it to the collected samples.

        Args:
            env: Environment object.
            selected_actions: Actions selected by the agents.
        """
        # Shared across agents: every coalition is evaluated only once.
        coalitions_reward = {}
        n_factorial = factorial(self.n_agents)

        for agent_id in range(self.n_agents):
            agent_shapley_value = 0
            (
                coalitions_with_agent,
                coalitions_without_agent,
            ) = self.get_coalitions_for_agent(agent_id)

            # Loop through all possible coalitions with and without the current agent
            for c_with, c_without in zip(
                coalitions_with_agent, coalitions_without_agent
            ):
                size_s = len(c_without)
                marginal_contribution = self._marginal_contribution(
                    c_with, c_without, coalitions_reward, selected_actions, env
                )
                # Shapley weight of a coalition of size |S| without the agent.
                weight = (
                    factorial(size_s)
                    * factorial(self.n_agents - size_s - 1)
                    / n_factorial
                )
                agent_shapley_value += weight * marginal_contribution

            # Store the shapley value for each agent
            self.collected_samples[self._sample_key(agent_id)].append(
                agent_shapley_value
            )

    def monte_carlo_approximation(self, env, selected_actions):
        """Draw one Monte-Carlo sample of the Shapley value of each agent.

        For every agent a random coalition is drawn and the resulting marginal
        contribution is appended to the collected samples.

        Args:
            env: Environment object.
            selected_actions: Actions selected by the agents.
        """
        # Shared across agents: a coalition drawn twice is evaluated only once.
        coalitions_reward = {}

        for agent_id in range(self.n_agents):
            c_with, c_without = self.get_coalitions_for_agent(agent_id)
            marginal_contribution = self._marginal_contribution(
                c_with, c_without, coalitions_reward, selected_actions, env
            )

            # Store the sample for each agent
            self.collected_samples[self._sample_key(agent_id)].append(
                marginal_contribution
            )
