from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
import numpy as np
import copy
from metrics.agent_importance import AgentImportance
from metrics.shapley_value import ShapleyValue
from metrics.individual_reward import IndividualReward


class EpisodeRunner:
    """Roll out episodes in a single environment instance and record them in batches.

    Besides stepping the multi-agent controller (``mac``), the runner can
    optionally compute per-step attribution metrics (agent importance, exact
    or Monte-Carlo Shapley values, individual rewards). It can also track an
    "absolute metric": the best-scoring controller seen by ``run`` is kept and
    re-evaluated in ``close_env``.
    """

    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run
        # This runner steps exactly one environment instance.
        assert self.batch_size == 1

        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.episode_limit = self.env.episode_limit
        self.t = 0  # timestep within the current episode

        self.t_env = 0  # cumulative environment steps, used as the logging x-axis

        self.train_returns = []
        self.test_returns = []
        self.train_stats = {}
        self.test_stats = {}

        # Far in the past so the first run passes the runner_log_interval check.
        self.log_train_stats_t = -1000000

        # Optional per-step metric calculators: agent importance, Shapley values
        # via the original (exact) formula or a Monte-Carlo approximation, and
        # per-agent individual rewards. Action 0 is passed as the no-op action.
        if self.args.compute_agent_importance:
            self.agent_importance = AgentImportance(
                n_agents=self.env.n_agents,
                noop_action=0,
                logger=self.logger,
                args=args,
                log_extra_details=False)

        if self.args.shapley_value:
            self.shap_value = ShapleyValue(
                n_agents=self.env.n_agents,
                noop_action=0,
                logger=self.logger,
                args=args,
                use_original_shap=True)

        if self.args.monte_carlo:
            self.monte_carlo_shap_value = ShapleyValue(
                n_agents=self.env.n_agents,
                noop_action=0,
                logger=self.logger,
                args=args,
                use_original_shap=False)

        if self.args.individual_rewards:
            self.individual_reward = IndividualReward(
                n_agents=self.env.n_agents, args=args)

    def setup(self, scheme, groups, preprocess, mac):
        """Bind the episode-batch factory and the controller used for rollouts.

        Args:
            scheme, groups, preprocess: forwarded to ``EpisodeBatch``.
            mac: multi-agent controller that selects actions in ``run``.
        """
        # One extra timestep slot holds the final state/actions that ``run``
        # writes after the last environment step.
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac
        # Initialize the absolute metric params
        if self.args.compute_absolute_metric:
            # Start at -inf rather than 0. With 0, environments whose episode
            # returns are always negative never record a best controller, and
            # the absolute metric would evaluate the untrained initial copy.
            self.best_performance = float("-inf")
            self.best_mac = copy.deepcopy(self.mac)
            self.absolute_metric_val = []

    def get_env_info(self):
        """Return the environment's own info dictionary."""
        return self.env.get_env_info()

    def save_replay(self):
        """Ask the environment to save a replay."""
        self.env.save_replay()

    def close_env(self):
        """Close the environment, evaluating the absolute metric first if enabled.

        When ``compute_absolute_metric`` is set:
        - the best controller recorded by ``run`` is loaded into ``self.mac``;
        - it is evaluated for ``10 * test_nepisode`` test episodes;
        - the mean and std of those returns are logged.
        """
        if self.args.compute_absolute_metric:
            self.mac.load_state(self.best_mac)
            for _ in range(10*self.args.test_nepisode):
                self.run(test_mode=True, compute_absolute_metric=True)
            # Logging results.
            self.logger.log_stat("absolute_metric_return_mean", np.mean(
                self.absolute_metric_val), self.t_env)
            self.logger.log_stat("absolute_metric_return_std",
                                 np.std(self.absolute_metric_val), self.t_env)
        self.env.close()

    def reset(self):
        """Start a new episode: fresh batch, environment reset, timestep back to 0."""
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0

    def _get_pre_transition_data(self):
        """Return the current env state, available actions and observations as batch data."""
        return {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()]
        }

    def _compute_step_metrics(self, actions, test_mode):
        """Pass the current env and the chosen actions to every enabled per-step metric."""
        if self.args.compute_agent_importance:
            self.agent_importance.compute_per_step(
                self.env,
                actions,
                test_mode,
                self.t_env)
        if self.args.shapley_value:
            self.shap_value.compute_per_step(
                env=self.env, actions=actions, test_mode=test_mode, t_env=self.t_env)
        if self.args.monte_carlo:
            self.monte_carlo_shap_value.compute_per_step(
                env=self.env, actions=actions, test_mode=test_mode, t_env=self.t_env)
        if self.args.individual_rewards:
            self.individual_reward.store_rewards(
                env=self.env, actions=actions, t_env=self.t_env, test_mode=test_mode)

    def run(self, test_mode=False, compute_absolute_metric=False):
        """Roll out one episode and return the batch it was recorded into.

        Args:
            test_mode: passed to action selection and the metric calculators.
                It also selects the test return/stat buffers, and on its own
                stops ``t_env`` from being advanced.
            compute_absolute_metric: marks an absolute-metric evaluation
                episode. Per-step metrics are skipped, ``t_env`` is advanced,
                the return is appended to ``absolute_metric_val`` and the
                regular logging / best-controller tracking is bypassed.

        Returns:
            The episode batch filled during the rollout.
        """
        self.reset()

        terminated = False
        episode_return = 0
        self.mac.init_hidden(batch_size=self.batch_size)
        compute_step_metrics = not compute_absolute_metric

        while not terminated:
            self.batch.update(self._get_pre_transition_data(), ts=self.t)

            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch of size 1
            actions = self.mac.select_actions(
                self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)

            # Metrics receive the environment before it is stepped with these actions.
            if compute_step_metrics:
                self._compute_step_metrics(actions[0], test_mode)

            reward, terminated, env_info, _ = self.env.step(actions[0])
            if test_mode and self.args.render:
                self.env.render()
            episode_return += reward

            post_transition_data = {
                "actions": actions,
                "reward": [(reward,)],
                # Episodes cut off by the time limit (env_info["episode_limit"])
                # are not stored as terminal.
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
            }

            self.batch.update(post_transition_data, ts=self.t)

            self.t += 1

        last_data = self._get_pre_transition_data()
        if test_mode and self.args.render:
            print(f"Episode return: {episode_return}")
        self.batch.update(last_data, ts=self.t)

        # Select actions in the last stored state
        actions = self.mac.select_actions(
            self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode)
        self.batch.update({"actions": actions}, ts=self.t)

        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns
        log_prefix = "test_" if test_mode else ""
        # Accumulate the final step's env_info values across episodes.
        cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0)
                         for k in set(cur_stats) | set(env_info)})
        cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0)

        if not test_mode or compute_absolute_metric:
            self.t_env += self.t

        cur_returns.append(episode_return)
        if compute_absolute_metric:
            self.logger.log_stat("absolute_metric_return",
                                 episode_return, self.t_env)
            self.absolute_metric_val.append(episode_return)
            return self.batch
        elif test_mode and (len(self.test_returns) == self.args.test_nepisode):
            self._log(cur_returns, cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat(
                    "epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

        # Keep a copy of the best-scoring controller for the absolute metric.
        # NOTE(review): this compares single-episode returns from both training
        # and test episodes; confirm training episodes should be eligible.
        if self.args.compute_absolute_metric and self.best_performance <= episode_return:
            self.best_mac.load_state(self.mac)
            self.best_performance = episode_return

        return self.batch

    def _log(self, returns, stats, prefix):
        """Log return mean/std and the per-episode mean of each stat, then clear both.

        Args:
            returns: list of episode returns; emptied after logging.
            stats: accumulated stats including ``n_episodes``; emptied after logging.
            prefix: prepended to every logged key (e.g. ``"test_"``).
        """
        self.logger.log_stat(prefix + "return_mean",
                             np.mean(returns), self.t_env)
        self.logger.log_stat(prefix + "return_std",
                             np.std(returns), self.t_env)
        returns.clear()

        for k, v in stats.items():
            if k != "n_episodes":
                self.logger.log_stat(prefix + k + "_mean",
                                     v/stats["n_episodes"], self.t_env)
        stats.clear()
