import time
import wandb
import numpy as np
from functools import reduce
import torch
from code_ptmc_mappo.runner.shared.base_runner import Runner
import copy

def _t2n(x):
    return x.detach().cpu().numpy()

class StugHuntRunner(Runner):
    """Runner that performs training, evaluation and data collection for StagHunt.

    See the parent ``Runner`` class for details. When ``algorithm_name == "ippo"``,
    training observations / available actions are passed through ``data_mask`` and
    the environment rewards are replaced by the tacit reward returned by
    ``reward_reconfiguration`` before being stored in the buffer.
    """

    def __init__(self, config):
        super(StugHuntRunner, self).__init__(config)

        # Tacit-reward helpers used by the "ippo" branch of run() and warmup().
        from code_ptmc_mappo.runner.tacit_pretr import staghunt_tacit_reward_construction as tacit_module
        self.data_mask = tacit_module.data_mask
        self.reward_reconfiguration = tacit_module.reward_reconfiguration

    def run(self):
        """Run the full training loop.

        For each of ``num_env_steps // episode_length // n_rollout_threads`` episodes:
        collect ``episode_length`` steps into the buffer, compute returns, train, and
        then periodically save the model, log progress and evaluate.
        """
        logger = self.log_creater()
        # Agent positions per rollout step: slot 0 holds the positions at the start of
        # the rollout, slot step + 1 the positions returned by envs.step at `step`.
        # Last axis has size 2 (presumably x/y coordinates — confirm).
        agent_pos_temp = np.zeros(
            (self.episode_length + 1, self.n_rollout_threads, self.num_agents, 2),
            dtype=np.float32)
        agent_pos_temp = self.warmup(agent_pos_temp)

        start = time.time()
        # Equivalent to num_env_steps / (episode_length * n_rollout_threads), floored.
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads

        for episode in range(episodes):
            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            # Tacit statistics, reset every episode: a 4 x 4 matrix whose first axis is
            # (mean_r_tacit, num_all, num_positive, tacit_indicator) and whose second
            # axis is the 4 game situations.
            if self.algorithm_name == "ippo":
                tacit_indicator = np.zeros((4, 4), dtype=np.float32)

            for step in range(self.episode_length):
                # From the second episode on: carry the last positions of the previous
                # rollout into slot 0, clear the remaining slots, and take obs_t0 from
                # the last buffer slot of the previous rollout.
                if episode != 0 and step == 0:
                    agent_pos_temp[step] = agent_pos_temp[self.episode_length]
                    agent_pos_temp[step + 1:] = 0
                    obs_t0 = self.buffer.obs[self.episode_length]

                # Sample actions
                values, actions, action_log_probs, rnn_states, rnn_states_critic = self.collect(step)
                # Observe reward and next obs
                obs, share_obs, rewards, dones, available_actions, agent_pos = self.envs.step(actions)
                agent_pos = np.squeeze(agent_pos, axis=2)

                if self.algorithm_name == "ippo":
                    # Mask observations / available actions.
                    obs, available_actions = self.data_mask(obs, available_actions, self.all_args)
                    # Agent positions at the previous time step (t0).
                    agent_pos_t0 = agent_pos_temp[step]
                    if episode == 0 or step != 0:
                        obs_t0 = self.buffer.obs[step]
                    tacit_reward, tacit_indicator = self.reward_reconfiguration(
                        self.all_args, agent_pos_t0, agent_pos, obs_t0, tacit_indicator)
                    # Train on the tacit reward instead of the environment reward.
                    rewards = tacit_reward
                    # Record this step's agent positions (numpy slice assignment copies
                    # the data, so no explicit deepcopy is needed).
                    agent_pos_temp[step + 1] = agent_pos

                data = obs, share_obs, rewards, dones, available_actions, \
                    values, actions, action_log_probs, rnn_states, rnn_states_critic

                # insert data into buffer
                self.insert(data)

            # post process
            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads

            # compute return and update network
            self.compute()
            train_infos = self.train(total_num_steps)

            # Log this episode's tacit statistics.
            if self.algorithm_name == "ippo":
                self.log_tacit(tacit_indicator, total_num_steps)

            # save model
            if episode % self.save_interval == 0 or episode == episodes - 1:
                self.save(episode, total_num_steps)

            # log information
            if episode % self.log_interval == 0:
                end = time.time()
                progress_msg = "\n Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n".format(
                    self.algorithm_name,
                    self.experiment_name,
                    episode,
                    episodes,
                    total_num_steps,
                    self.num_env_steps,
                    int(total_num_steps / (end - start)))
                print(progress_msg)
                logger.info(progress_msg)

                info_log_reward = np.mean(self.buffer.rewards)
                train_infos["average_episode_rewards"] = info_log_reward
                print("average episode rewards is {}".format(info_log_reward))
                logger.info("average episode rewards is {}".format(info_log_reward))
                if self.use_wandb:
                    wandb.log({"average_episode_rewards": info_log_reward}, step=total_num_steps)
                else:
                    logger.info(f"average_episode_rewards: {info_log_reward}, step: {total_num_steps}")
                    self.writter.add_scalars("average_episode_rewards", {"average_episode_rewards": info_log_reward}, total_num_steps)

            # eval
            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    def warmup(self, agent_pos_temp):
        """Reset the training envs and seed slot 0 of the buffer and of ``agent_pos_temp``.

        Args:
            agent_pos_temp: per-step agent position array allocated by ``run``; slot 0
                is overwritten with the positions returned by ``envs.reset()``.

        Returns:
            The same ``agent_pos_temp`` object, updated in place.
        """
        # reset env
        obs, share_obs, available_actions, agent_pos = self.envs.reset()
        agent_pos = np.squeeze(agent_pos, axis=2)
        if self.algorithm_name == "ippo":
            # Mask observations / available actions, as done for every training step.
            obs, available_actions = self.data_mask(obs, available_actions, self.all_args)

        # replay buffer
        if not self.use_centralized_V:
            share_obs = obs

        self.buffer.share_obs[0] = share_obs.copy()
        self.buffer.obs[0] = obs.copy()
        self.buffer.available_actions[0] = available_actions.copy()
        agent_pos_temp[0] = agent_pos.copy()

        return agent_pos_temp

    @torch.no_grad()
    def collect(self, step):
        """Query the policy for actions and values at buffer slot ``step``.

        The per-thread buffer entries are concatenated into one batch for the policy
        and the outputs are split back into ``n_rollout_threads`` chunks.

        Returns:
            (values, actions, action_log_probs, rnn_states, rnn_states_critic) as numpy
            arrays whose first axis is the rollout thread.
        """
        self.trainer.prep_rollout()
        value, action, action_log_prob, rnn_state, rnn_state_critic \
            = self.trainer.policy.get_actions(np.concatenate(self.buffer.share_obs[step]),
                                              np.concatenate(self.buffer.obs[step]),
                                              np.concatenate(self.buffer.rnn_states[step]),
                                              np.concatenate(self.buffer.rnn_states_critic[step]),
                                              np.concatenate(self.buffer.masks[step]),
                                              np.concatenate(self.buffer.available_actions[step]))
        # Split back to (n_rollout_threads, ...) — presumably (threads, agents, dim).
        values = np.array(np.split(_t2n(value), self.n_rollout_threads))
        actions = np.array(np.split(_t2n(action), self.n_rollout_threads))
        action_log_probs = np.array(np.split(_t2n(action_log_prob), self.n_rollout_threads))
        rnn_states = np.array(np.split(_t2n(rnn_state), self.n_rollout_threads))
        rnn_states_critic = np.array(np.split(_t2n(rnn_state_critic), self.n_rollout_threads))

        return values, actions, action_log_probs, rnn_states, rnn_states_critic

    def insert(self, data):
        """Build RNN resets and masks from ``dones`` and insert one step into the buffer.

        Args:
            data: tuple (obs, share_obs, rewards, dones, available_actions, values,
                actions, action_log_probs, rnn_states, rnn_states_critic) as built in
                ``run``. ``dones`` and the rnn state arrays are modified in place.
        """
        obs, share_obs, rewards, dones, available_actions, \
            values, actions, action_log_probs, rnn_states, rnn_states_critic = data

        # Thread-level "episode finished" flag: all agents in the thread are done.
        dones_env = np.all(dones, axis=1)

        if self.algorithm_name == "ippo":
            # Agent-steps with zero (tacit) reward are flagged done. dones_env was
            # computed above, so this only affects active_masks below, i.e. these
            # agent-steps are marked inactive unless the whole env finished.
            rewards_squeezed = rewards.squeeze(-1)
            dones[rewards_squeezed == 0] = True

        # Reset recurrent states of threads whose episode ended.
        rnn_states[dones_env == True] = np.zeros(
            ((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
        rnn_states_critic[dones_env == True] = np.zeros(
            ((dones_env == True).sum(), self.num_agents, *self.buffer.rnn_states_critic.shape[3:]), dtype=np.float32)

        masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        masks[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)

        # Per-agent activity: 0 for done agents, but all agents stay active on the step
        # where the whole env finishes.
        active_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        active_masks[dones == True] = np.zeros(((dones == True).sum(), 1), dtype=np.float32)
        active_masks[dones_env == True] = np.ones(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)

        if not self.use_centralized_V:
            share_obs = obs

        self.buffer.insert(share_obs, obs, rnn_states, rnn_states_critic,
                           actions, action_log_probs, values, rewards, masks, active_masks=active_masks,
                           available_actions=available_actions)

    @torch.no_grad()
    def eval(self, total_num_steps):
        """Run deterministic evaluation episodes and log the mean episode reward.

        Steps ``self.eval_envs`` until at least ``all_args.eval_episodes`` episodes have
        finished, then logs the mean of the flattened per-episode reward sums to wandb
        (or the logger and tensorboard writer).
        """
        eval_episode = 0

        eval_episode_rewards = []
        # Rewards of the episode currently running in each eval thread. Kept per thread
        # so threads that finish at different steps do not mix their rewards (a single
        # shared list summed all threads together and reset for everyone whenever any
        # single thread finished).
        one_episode_rewards = [[] for _ in range(self.n_eval_rollout_threads)]

        logger = self.log_creater()
        # NOTE(review): unlike warmup()/run(), the "ippo" branch does not pass eval
        # observations through self.data_mask; confirm the policy is meant to be
        # evaluated on unmasked observations.
        eval_obs, eval_share_obs, eval_available_actions, eval_agent_pos = self.eval_envs.reset()
        eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size),
                                   dtype=np.float32)
        eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)

        while True:
            self.trainer.prep_rollout()
            if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
                eval_actions, eval_rnn_states = \
                    self.trainer.policy.act(np.concatenate(eval_share_obs),
                                            np.concatenate(eval_obs),
                                            np.concatenate(eval_rnn_states),
                                            np.concatenate(eval_masks),
                                            np.concatenate(eval_available_actions),
                                            deterministic=True)
            else:
                eval_actions, eval_rnn_states = \
                    self.trainer.policy.act(np.concatenate(eval_obs),
                                            np.concatenate(eval_rnn_states),
                                            np.concatenate(eval_masks),
                                            np.concatenate(eval_available_actions),
                                            deterministic=True)
            eval_actions = np.array(np.split(_t2n(eval_actions), self.n_eval_rollout_threads))
            eval_rnn_states = np.array(np.split(_t2n(eval_rnn_states), self.n_eval_rollout_threads))

            # Observe reward and next obs
            eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_available_actions, eval_agent_pos = \
                self.eval_envs.step(eval_actions)
            for eval_i in range(self.n_eval_rollout_threads):
                one_episode_rewards[eval_i].append(eval_rewards[eval_i])

            eval_dones_env = np.all(eval_dones, axis=1)
            eval_rnn_states[eval_dones_env == True] = np.zeros(
                ((eval_dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
            eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)
            eval_masks[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, 1),
                                                          dtype=np.float32)

            for eval_i in range(self.n_eval_rollout_threads):
                if eval_dones_env[eval_i]:
                    eval_episode += 1
                    eval_episode_rewards.append(np.sum(one_episode_rewards[eval_i], axis=0))
                    one_episode_rewards[eval_i] = []

            if eval_episode >= self.all_args.eval_episodes:
                flattened_rewards = [np.ravel(r) for r in eval_episode_rewards]
                # Guard the degenerate eval_episodes <= 0 case (nothing collected yet).
                mean_eval_reward = (float(np.mean(np.concatenate(flattened_rewards)))
                                    if flattened_rewards else float("nan"))
                print("eval_episode_rewards is {}.".format(mean_eval_reward))
                if self.use_wandb:
                    wandb.log({'eval_episode_rewards': mean_eval_reward}, step=total_num_steps)
                else:
                    logger.info(f"eval_episode_rewards: {mean_eval_reward} at step {total_num_steps}")
                    self.writter.add_scalars('eval_episode_rewards', {'eval_episode_rewards': mean_eval_reward}, total_num_steps)
                break