from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
import numpy as np
import torch as th


class EpisodeRunner:
    """Roll out single episodes of one task's environment and log aggregate stats.

    The runner owns one environment instance (built from ``args.env`` and
    ``args.env_args``), steps it with actions chosen by the multi-agent
    controller given to :meth:`setup`, and writes every transition into a
    fresh episode batch. Train and test returns/stats are accumulated
    separately and flushed to ``logger`` by :meth:`_log`.
    """

    def __init__(self, args, logger, task):
        """Build the environment and initialise counters.

        Args:
            args: Run configuration; this class reads ``batch_size_run``,
                ``env``, ``env_args``, ``device``, ``test_nepisode``,
                ``runner_log_interval``, ``run_file`` and several optional
                flags via ``getattr``.
            logger: Object exposing ``log_stat(key, value, t)``.
            task: Task identifier forwarded to the controller.
        """
        self.args = args
        self.logger = logger
        # Mainly passed to the mac so that it can tell tasks apart.
        self.task = task
        self.batch_size = self.args.batch_size_run
        assert self.batch_size == 1, "EpisodeRunner only supports batch_size_run == 1"

        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.env_args = self.args.env_args
        self.episode_limit = self.env.episode_limit

        # Step counter inside the current episode.
        self.t = 0
        # Total number of training environment steps taken so far.
        self.t_env = 0

        self.train_returns = []
        self.test_returns = []
        self.train_stats = {}
        self.test_stats = {}

        # Large negative value so that the very first training run is logged.
        self.log_train_stats_t = -1000000

    def reset_env(self, seed=0):
        """Replace ``self.env`` with a newly constructed environment using ``seed``.

        The seed is written into a copy of the current environment arguments,
        so the mapping shared with ``args.env_args`` is left untouched.
        """
        # NOTE(review): the previous environment is not closed here; callers
        # that need its resources released should call close_env() first.
        new_env_args = dict(self.env_args)
        new_env_args["seed"] = seed
        self.env = env_REGISTRY[self.args.env](**new_env_args)
        self.env_args = new_env_args

    def setup(self, scheme, groups, preprocess, mac):
        """Store the controller and a factory that creates empty episode batches."""
        self.new_batch = partial(
            EpisodeBatch,
            scheme,
            groups,
            self.batch_size,
            # One extra slot for the state observed after the final step.
            self.episode_limit + 1,
            preprocess=preprocess,
            device=self.args.device,
        )
        self.mac = mac

    def get_env_info(self):
        """Return the environment's ``get_env_info()`` result."""
        return self.env.get_env_info()

    def save_replay(self):
        """Ask the environment to save a replay."""
        self.env.save_replay()

    def close_env(self):
        """Close the environment."""
        self.env.close()

    def reset(self):
        """Start a new episode: fresh batch, reset environment, step counter to 0."""
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0

    def _select_actions(self, learner, test_mode):
        """Query the controller for actions at the current episode step ``self.t``."""
        kwargs = {
            "t_ep": self.t,
            "t_env": self.t_env,
            "task": self.task,
            "test_mode": test_mode,
        }
        if getattr(self.args, "is_diffusion", False):
            # Diffusion controllers additionally receive the learner's critic.
            kwargs["q"] = learner.q_critic
        return self.mac.select_actions(self.batch, **kwargs)

    def run(self, task_encoding, test_mode=False, nolog=False, run_with_prior_encoding=False,
            learner=None, pretrain=False, return_detail=False, deterministic=False):
        """Run one full episode and return the collected batch.

        Args:
            task_encoding: Encoding handed to the controller. When
                ``run_with_prior_encoding`` is set it is instead combined with
                (or replaced by) the output of ``learner.prior_role_encoder``
                at every step.
            test_mode: Accumulate into the test statistics and do not advance
                ``t_env``.
            nolog: Skip logging (returns and stats are still accumulated).
            run_with_prior_encoding: Recompute the task encoding each step
                from the current observation via ``learner.prior_role_encoder``.
            learner: Required when ``run_with_prior_encoding`` is set or when
                ``args.is_diffusion`` is truthy.
            pretrain: Prefix log keys with ``"pretrain/"``.
            return_detail: Also return this episode's return and its
                ``battle_won`` flag as a float.
            deterministic: Request test-mode action selection from the
                controller without switching to the test statistics.

        Returns:
            ``(batch, return_mean, battle_won_mean)``; the two means are
            ``None`` unless stats were logged during this call. With
            ``return_detail`` the tuple is extended by
            ``(episode_return, battle_won)``.

        Raises:
            ValueError: If ``learner`` is needed but was not supplied.
        """
        use_role_history = getattr(self.args, "prior_role_use_history", False)
        needs_learner = run_with_prior_encoding or getattr(self.args, "is_diffusion", False)
        if needs_learner and learner is None:
            raise ValueError(
                "`learner` is required when run_with_prior_encoding is set "
                "or args.is_diffusion is enabled"
            )

        self.reset()

        terminated = False
        episode_return = 0
        self.mac.init_hidden(batch_size=self.batch_size, task=self.task)
        if not run_with_prior_encoding:
            self.mac.set_task_encoding(task_encoding, self.task)

        prior_role_hidden = None
        # The recurrent state is only consumed on the prior-encoding path, so
        # it is only built there; this also avoids touching `learner` when the
        # caller legitimately did not pass one.
        if run_with_prior_encoding and use_role_history:
            n_agents = self.mac.task2n_agents[self.task]
            prior_role_hidden = (
                learner.prior_role_encoder.init_hidden()
                .unsqueeze(0)
                .expand(self.batch_size, n_agents, -1)
            )

        # `deterministic` only affects action selection, not stat bookkeeping.
        select_in_test_mode = test_mode or deterministic

        while not terminated:
            pre_transition_data = {
                "state": [self.env.get_state()],
                "avail_actions": [self.env.get_avail_actions()],
                "obs": [self.env.get_obs()],
            }
            self.batch.update(pre_transition_data, ts=self.t)

            if run_with_prior_encoding:
                obs = self.batch["obs"][:, self.t:self.t + 1]
                if use_role_history:
                    prior_encoding, prior_role_hidden = learner.prior_role_encoder(
                        obs, task_encoding, self.task, hidden_state=prior_role_hidden
                    )
                else:
                    prior_encoding, _ = learner.prior_role_encoder(obs, task_encoding, self.task)

                if self.args.only_role_encoding:
                    new_task_encoding = prior_encoding[:, 0, :, :]
                else:
                    new_task_encoding = th.cat(
                        [task_encoding.unsqueeze(1), prior_encoding], dim=-1
                    )[:, 0, :, :]
                self.mac.set_task_encoding(new_task_encoding, self.task)

            # Pass the entire batch of experiences up till now to the agents
            # and receive the actions for each agent at this timestep.
            actions = self._select_actions(learner, select_in_test_mode)

            reward, terminated, env_info = self.env.step(actions[0])
            episode_return += reward

            post_transition_data = {
                "actions": actions,
                "reward": [(reward,)],
                # Stored flag is False when the episode ended only because the
                # environment reported "episode_limit".
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
            }
            self.batch.update(post_transition_data, ts=self.t)

            self.t += 1

        last_data = {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()],
        }
        self.batch.update(last_data, ts=self.t)

        # Select actions in the last stored state.
        actions = self._select_actions(learner, select_in_test_mode)
        self.batch.update({"actions": actions}, ts=self.t)

        if not test_mode:
            self.t_env += self.t

        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns

        log_prefix = f"{'pretrain/' if pretrain else ''}{self.task}/{'test_' if test_mode else ''}"
        if run_with_prior_encoding:
            log_prefix += "prior_role_"

        # Sum this episode's env_info into the running totals key by key.
        cur_stats.update({
            k: cur_stats.get(k, 0) + env_info.get(k, 0)
            for k in set(cur_stats) | set(env_info)
        })
        cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0)

        cur_returns.append(episode_return)

        return_mean = None
        battle_won_mean = None
        if not nolog:
            if test_mode:
                # `>=` rather than `==`: episodes run with nolog=True still
                # accumulate, and an exact match could then be skipped forever.
                if len(self.test_returns) >= self.args.test_nepisode:
                    _, return_mean, battle_won_mean = self._log(cur_returns, cur_stats, log_prefix)
            elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
                _, return_mean, battle_won_mean = self._log(cur_returns, cur_stats, log_prefix)
                if 'offline' not in self.args.run_file:
                    if hasattr(self.mac.action_selector, "epsilon"):
                        self.logger.log_stat(f"{self.task}/epsilon", self.mac.action_selector.epsilon, self.t_env)
                self.log_train_stats_t = self.t_env

        if return_detail:
            return self.batch, return_mean, battle_won_mean, episode_return, float(env_info.get("battle_won", 0))
        return self.batch, return_mean, battle_won_mean

    def _log(self, returns, stats, prefix):
        """Log mean/std of ``returns`` and per-episode means of ``stats``, then clear both.

        Args:
            returns: List of episode returns; emptied in place.
            stats: Dict of summed statistics containing ``"n_episodes"``;
                emptied in place.
            prefix: String prepended to every logged key.

        Returns:
            ``(log_dic, return_mean, battle_won_mean)`` where ``log_dic`` maps
            each logged key to its value and ``battle_won_mean`` is 0 when
            ``stats`` has no ``"battle_won"`` entry.
        """
        return_mean = np.mean(returns)
        return_std = np.std(returns)
        log_dic = {
            prefix + "return_mean": return_mean,
            prefix + "return_std": return_std,
        }
        self.logger.log_stat(prefix + "return_mean", return_mean, self.t_env)
        self.logger.log_stat(prefix + "return_std", return_std, self.t_env)
        returns.clear()

        battle_won_mean = 0
        n_episodes = stats["n_episodes"]
        for key, total in stats.items():
            if key == "n_episodes":
                continue
            per_episode = total / n_episodes
            if key == "battle_won":
                battle_won_mean = per_episode
            self.logger.log_stat(prefix + key + "_mean", per_episode, self.t_env)
            log_dic[prefix + key + "_mean"] = per_episode
        stats.clear()

        return log_dic, return_mean, battle_won_mean
