import os
import socket
import time
from pathlib import Path
import wandb
from torch.utils.tensorboard import SummaryWriter
from .runner_basic import Runner_Base, make_envs
from xuance.tensorflow.agents import REGISTRY as REGISTRY_Agent
from gymnasium.spaces.box import Box
from tqdm import tqdm
import numpy as np
from copy import deepcopy


class Pettingzoo_Runner(Runner_Base):
    """Runner that trains, tests and benchmarks MARL agent groups on PettingZoo-style environments.

    ``args`` holds one config per handle (side) of the environment. The first config
    whose ``agent_name`` is not ``"random"`` is used to build the shared environments
    (through ``Runner_Base``) and, when wandb is selected, the wandb run.
    """

    def __init__(self, args):
        # Accept either a single config or a list with one config per handle.
        self.args = args if isinstance(args, list) else [args]
        self.fps = 20

        time_string = time.asctime().replace(" ", "").replace(":", "_")
        for arg in self.args:
            seed = f"seed_{arg.seed}_"
            arg.model_dir_load = arg.model_dir
            arg.model_dir_save = os.path.join(os.getcwd(), arg.model_dir, seed + time_string)
            if (not os.path.exists(arg.model_dir_save)) and (not arg.test_mode):
                os.makedirs(arg.model_dir_save)

            if arg.logger == "tensorboard":
                log_dir = os.path.join(os.getcwd(), arg.log_dir, seed + time_string)
                if not os.path.exists(log_dir):
                    os.makedirs(log_dir)
                # Only the last writer is kept; close an earlier one instead of leaking it.
                if getattr(self, "writer", None) is not None:
                    self.writer.close()
                self.writer = SummaryWriter(log_dir)
                self.use_wandb = False
            else:
                self.use_wandb = True

        # The first non-random config builds the environments and the wandb run.
        for arg in self.args:
            if arg.agent_name == "random":
                continue
            self.args_base = arg
            super(Pettingzoo_Runner, self).__init__(arg)
            self.running_steps = arg.running_steps
            self.training_frequency = arg.training_frequency
            self.train_per_step = arg.train_per_step

            # build environments
            self.n_handles = len(self.envs.handles)
            self.agent_keys = self.envs.agent_keys
            self.agent_ids = self.envs.agent_ids
            self.agent_keys_all = self.envs.keys
            self.n_agents_all = len(self.agent_keys_all)
            self.render = arg.render

            self.n_steps = arg.running_steps
            self.test_mode = arg.test_mode
            self.marl_agents, self.marl_names = [], []
            self.current_step, self.current_episode = 0, np.zeros((self.envs.num_envs,), np.int32)

            if self.use_wandb:
                config_dict = vars(arg)
                wandb_dir = Path(os.path.join(os.getcwd(), arg.log_dir))
                if not wandb_dir.exists():
                    os.makedirs(str(wandb_dir))
                wandb.init(config=config_dict,
                           project=arg.project_name,
                           entity=arg.wandb_user_name,
                           notes=socket.gethostname(),
                           dir=wandb_dir,
                           group=arg.env_id,
                           job_type=arg.agent,
                           name=time.asctime(),
                           reinit=True)
            break

        self.episode_length = self.envs.max_episode_length

        # environment details, representations, policies, optimizers, and agents.
        for h, arg in enumerate(self.args):
            arg.handle_name = self.envs.side_names[h]
            if self.n_handles > 1 and arg.agent != "RANDOM":
                arg.model_dir += "{}/".format(arg.handle_name)
            arg.handle, arg.n_agents = h, self.envs.n_agents[h]
            arg.agent_keys, arg.agent_ids = self.agent_keys[h], self.agent_ids[h]
            arg.state_space = self.envs.state_space
            arg.observation_space = self.envs.observation_space
            # Spaces of the first agent of the handle describe the whole handle.
            first_key = self.agent_keys[h][0]
            act_space = self.envs.action_space[first_key]
            if isinstance(act_space, Box):
                arg.dim_act = act_space.shape[0]
                arg.act_shape = (arg.dim_act,)
            else:
                arg.dim_act = act_space.n
                arg.act_shape = ()
            arg.action_space = self.envs.action_space
            obs_space_shape = self.envs.observation_space[first_key].shape
            if arg.env_name == "MAgent2":
                # MAgent2 observations are flattened into a single vector.
                arg.obs_shape = (np.prod(obs_space_shape),)
            else:
                arg.obs_shape = obs_space_shape
            arg.dim_obs = arg.obs_shape[0]
            arg.rew_shape, arg.done_shape, arg.act_prob_shape = (arg.n_agents, 1), (arg.n_agents,), (arg.dim_act,)
            self.marl_agents.append(REGISTRY_Agent[arg.agent](arg, self.envs, arg.device))
            self.marl_names.append(arg.agent)
            if arg.test_mode:
                self.marl_agents[h].load_model(arg.model_dir, arg.seed)

        self.print_infos(self.args)

    def log_infos(self, info: dict, x_index: int):
        """Log metrics to wandb or tensorboard.

        Args:
            info: mapping from metric name to value. With tensorboard, dict values are
                written with ``add_scalars`` and all other values with ``add_scalar``.
            x_index: step used as the x-axis.
        """
        if self.use_wandb:
            for k, v in info.items():
                wandb.log({k: v}, step=x_index)
        else:
            for k, v in info.items():
                if isinstance(v, dict):
                    self.writer.add_scalars(k, v, x_index)
                else:
                    self.writer.add_scalar(k, v, x_index)

    def log_videos(self, info: dict, fps: int, x_index: int = 0):
        """Log videos (mapping name -> frame array) to wandb (as GIF) or tensorboard."""
        if self.use_wandb:
            for k, v in info.items():
                wandb.log({k: wandb.Video(v, fps=fps, format='gif')}, step=x_index)
        else:
            for k, v in info.items():
                self.writer.add_video(k, v, fps=fps, global_step=x_index)

    def print_infos(self, args):
        """Print, per handle, the agent-name prefix, number of agents and algorithm name."""
        infos = []
        for h, arg in enumerate(args):
            # Drops the last two characters of the first agent key of the handle.
            agent_name = self.envs.agent_keys[h][0][0:-2]
            if arg.n_agents == 1:
                infos.append(agent_name + ": {} agent".format(arg.n_agents) + ", {}".format(arg.agent))
            else:
                infos.append(agent_name + ": {} agents".format(arg.n_agents) + ", {}".format(arg.agent))
        print(infos)
        time.sleep(0.01)

    def combine_env_actions(self, actions):
        """Convert per-handle action arrays into one ``{agent_key: action}`` dict per env.

        ``actions[h][e][i]`` is the action of the i-th agent of handle ``h`` in env ``e``.
        """
        actions_envs = []
        num_env = actions[0].shape[0]
        for e in range(num_env):
            act_handle = {}
            for h, keys in enumerate(self.agent_keys):
                act_handle.update({agent_name: actions[h][e][i] for i, agent_name in enumerate(keys)})
            actions_envs.append(act_handle)
        return actions_envs

    def get_actions(self, obs_n, test_mode, act_mean_last, agent_mask, state):
        """Query every agent group for its actions and auxiliary outputs.

        Args:
            obs_n: per-handle observations, indexed by handle ``h``.
            test_mode: forwarded to each group's ``act``.
            act_mean_last: per-handle mean actions of the previous step (used by MFQ/MFAC).
                The list itself is not modified.
            agent_mask: per-handle agent masks.
            state: global state passed to state-conditioned groups.

        Returns:
            dict with keys ``actions_n``, ``log_pi``, ``act_mean``, ``act_n_onehot`` and
            ``values``. Each list has one entry per handle. Entries are ``None`` for groups
            that do not produce that quantity, so consumers can always index by ``h``.
        """
        n_groups = len(self.marl_agents)
        actions_n = []
        log_pi_n = [None] * n_groups
        values_n = [None] * n_groups
        actions_n_onehot = [None] * n_groups
        # Shallow copy so the caller's list is not mutated in place.
        act_mean_current = list(act_mean_last)
        for h, mas_group in enumerate(self.marl_agents):
            name = self.marl_names[h]
            if name == "MFQ":
                _, a, a_mean = mas_group.act(obs_n[h], test_mode=test_mode, act_mean=act_mean_last[h],
                                             agent_mask=agent_mask[h])
                act_mean_current[h] = a_mean
            elif name == "MFAC":
                a, a_mean = mas_group.act(obs_n[h], test_mode, act_mean_last[h], agent_mask[h])
                act_mean_current[h] = a_mean
                _, values_n[h] = mas_group.values(obs_n[h], act_mean_current[h])
            elif name == "VDAC":
                _, a, values_n[h] = mas_group.act(obs_n[h], state=state, test_mode=test_mode)
            elif name in ["MAPPO", "IPPO"]:
                _, a, log_pi_n[h] = mas_group.act(obs_n[h], test_mode=test_mode, state=state)
                _, values_n[h] = mas_group.values(obs_n[h], state=state)
            elif name == "COMA":
                _, a, actions_n_onehot[h] = mas_group.act(obs_n[h], test_mode)
                _, values_n[h] = mas_group.values(obs_n[h], state=state, actions_n=a,
                                                  actions_onehot=actions_n_onehot[h])
            else:
                _, a = mas_group.act(obs_n[h], test_mode=test_mode)
            actions_n.append(a)
        return {'actions_n': actions_n, 'log_pi': log_pi_n, 'act_mean': act_mean_current,
                'act_n_onehot': actions_n_onehot, 'values': values_n}

    def _next_values(self, h, next_obs_n, next_state, actions_dict):
        """Return bootstrap value estimates of on-policy group ``h`` for ``next_obs_n``."""
        mas_group = self.marl_agents[h]
        name = self.marl_names[h]
        if name == "COMA":
            _, values_next = mas_group.values(next_obs_n[h],
                                              state=next_state,
                                              actions_n=actions_dict['actions_n'][h],
                                              actions_onehot=actions_dict['act_n_onehot'][h])
        elif name == "MFAC":
            _, values_next = mas_group.values(next_obs_n[h], actions_dict['act_mean'][h])
        elif name == "VDAC":
            _, _, values_next = mas_group.act(next_obs_n[h])
        else:
            _, values_next = mas_group.values(next_obs_n[h], state=next_state)
        return values_next

    def store_data(self, obs_n, next_obs_n, actions_dict, state, next_state, agent_mask, rew_n, done_n):
        """Store one transition into the memory of every non-random group.

        For on-policy groups whose memory becomes full, each env's trajectory is closed
        with ``finish_path``. The bootstrap value is 0.0 when all agents of the handle are
        done in that env, and the next-state value estimate otherwise.
        """
        for h, mas_group in enumerate(self.marl_agents):
            if mas_group.args.agent_name == "random":
                continue
            name = self.marl_names[h]
            data_step = {'obs': obs_n[h], 'obs_next': next_obs_n[h], 'actions': actions_dict['actions_n'][h],
                         'state': state, 'state_next': next_state, 'rewards': rew_n[h],
                         'agent_mask': agent_mask[h], 'terminals': done_n[h]}
            if mas_group.on_policy:
                data_step['values'] = actions_dict['values'][h]
                # Both PPO variants produce log-probs in get_actions and need them as log_pi_old.
                if name in ["MAPPO", "IPPO"]:
                    data_step['log_pi_old'] = actions_dict['log_pi'][h]
                elif name == "COMA":
                    data_step['actions_onehot'] = actions_dict['act_n_onehot'][h]
                mas_group.memory.store(data_step)
                if mas_group.memory.full:
                    values_next = self._next_values(h, next_obs_n, next_state, actions_dict)
                    for i_env in range(self.n_envs):
                        if done_n[h][i_env].all():
                            mas_group.memory.finish_path(0.0, i_env)
                        else:
                            mas_group.memory.finish_path(values_next[i_env], i_env)
                continue
            if name in ["MFQ", "MFAC"]:
                data_step['act_mean'] = actions_dict['act_mean'][h]
            mas_group.memory.store(data_step)

    def train_episode(self, n_episodes):
        """Collect ``n_episodes`` rounds of ``episode_length`` vectorized steps and train.

        Training happens every ``training_frequency`` steps when ``train_per_step`` is
        set, otherwise once per round. Envs that finish are reset from the ``reset_*``
        entries of their ``infos``.
        """
        act_mean_last = [np.zeros([self.n_envs, arg.dim_act]) for arg in self.args]
        terminal_handle = np.zeros([self.n_handles, self.n_envs], dtype=np.bool_)
        truncate_handle = np.zeros([self.n_handles, self.n_envs], dtype=np.bool_)
        episode_score = np.zeros([self.n_handles, self.n_envs, 1], dtype=np.float32)
        episode_info, train_info = {}, {}
        for _ in tqdm(range(n_episodes)):
            obs_n = self.envs.buf_obs
            state, agent_mask = self.envs.global_state(), self.envs.agent_mask()
            for _step in range(self.episode_length):
                actions_dict = self.get_actions(obs_n, False, act_mean_last, agent_mask, state)
                actions_execute = self.combine_env_actions(actions_dict['actions_n'])
                next_obs_n, rew_n, terminated_n, truncated_n, infos = self.envs.step(actions_execute)
                next_state, agent_mask = self.envs.global_state(), self.envs.agent_mask()

                self.store_data(obs_n, next_obs_n, actions_dict, state, next_state, agent_mask, rew_n, terminated_n)

                # train the model for each step
                if self.train_per_step and self.current_step % self.training_frequency == 0:
                    for mas_group in self.marl_agents:
                        if mas_group.args.agent_name == "random":
                            continue
                        train_info = mas_group.train(self.current_step)

                obs_n, state, act_mean_last = deepcopy(next_obs_n), deepcopy(next_state), deepcopy(
                    actions_dict['act_mean'])

                for h in range(len(self.marl_agents)):
                    episode_score[h] += np.mean(rew_n[h] * agent_mask[h][:, :, np.newaxis], axis=1)
                    terminal_handle[h] = terminated_n[h].all(axis=-1)
                    truncate_handle[h] = truncated_n[h].all(axis=-1)

                # An env is finished when every handle is terminated, or every handle truncated.
                finished = terminal_handle.all(axis=0) | truncate_handle.all(axis=0)
                values_next = {}  # per-handle bootstrap values, computed at most once per step
                for i_env in range(self.n_envs):
                    if not finished[i_env]:
                        continue
                    self.current_episode[i_env] += 1
                    for h, mas_group in enumerate(self.marl_agents):
                        # Random groups have no memory, but still need the reset bookkeeping below.
                        if mas_group.args.agent_name != "random" and mas_group.on_policy:
                            if h not in values_next:
                                values_next[h] = self._next_values(h, next_obs_n, next_state, actions_dict)
                            mas_group.memory.finish_path(values_next[h][i_env], i_env)
                        obs_n[h][i_env] = infos[i_env]["reset_obs"][h]
                        agent_mask[h][i_env] = infos[i_env]["reset_agent_mask"][h]
                        act_mean_last[h][i_env] = np.zeros([self.args[h].dim_act])
                        episode_score[h, i_env] = np.mean(infos[i_env]["individual_episode_rewards"][h])
                    state[i_env] = infos[i_env]["reset_state"]
                self.current_step += self.n_envs

            if self.n_handles > 1:
                for h in range(self.n_handles):
                    episode_info["Train_Episode_Score/side_{}".format(self.args[h].handle_name)] = episode_score[h].mean()
            else:
                episode_info["Train_Episode_Score"] = episode_score[0].mean()

            # train the model for each episode
            if not self.train_per_step:
                for mas_group in self.marl_agents:
                    if mas_group.args.agent_name == "random":
                        continue
                    train_info = mas_group.train(self.current_step)
            self.log_infos(train_info, self.current_step)
            self.log_infos(episode_info, self.current_step)

    def test_episode(self, env_fn):
        """Run ``episode_length`` test steps in fresh environments and log the scores.

        Args:
            env_fn: zero-argument callable returning the vectorized test environments.
                They are closed on return, including when an exception is raised.

        Returns:
            ``episode_score`` array of shape ``[n_handles, num_envs, 1]`` with the
            mask-weighted mean rewards accumulated over all test steps.
        """
        test_envs = env_fn()
        try:
            test_info = {}
            num_envs = test_envs.num_envs
            record_video = self.args_base.render_mode == "rgb_array" and self.render
            videos = [[] for _ in range(num_envs)]
            obs_n, infos = test_envs.reset()
            state, agent_mask = test_envs.global_state(), test_envs.agent_mask()
            if record_video:
                images = test_envs.render(self.args_base.render_mode)
                for idx, img in enumerate(images):
                    videos[idx].append(img)
            act_mean_last = [np.zeros([num_envs, arg.dim_act]) for arg in self.args]
            terminal_handle = np.zeros([self.n_handles, num_envs], dtype=np.bool_)
            truncate_handle = np.zeros([self.n_handles, num_envs], dtype=np.bool_)
            episode_score = np.zeros([self.n_handles, num_envs, 1], dtype=np.float32)

            for _step in range(self.episode_length):
                actions_dict = self.get_actions(obs_n, True, act_mean_last, agent_mask, state)
                actions_execute = self.combine_env_actions(actions_dict['actions_n'])
                next_obs_n, rew_n, terminated_n, truncated_n, infos = test_envs.step(actions_execute)
                if record_video:
                    images = test_envs.render(self.args_base.render_mode)
                    for idx, img in enumerate(images):
                        videos[idx].append(img)

                next_state, agent_mask = test_envs.global_state(), test_envs.agent_mask()

                obs_n, state, act_mean_last = deepcopy(next_obs_n), deepcopy(next_state), deepcopy(
                    actions_dict['act_mean'])

                for h in range(len(self.marl_agents)):
                    episode_score[h] += np.mean(rew_n[h] * agent_mask[h][:, :, np.newaxis], axis=1)
                    terminal_handle[h] = terminated_n[h].all(axis=-1)
                    truncate_handle[h] = truncated_n[h].all(axis=-1)

                finished = terminal_handle.all(axis=0) | truncate_handle.all(axis=0)
                for i in range(num_envs):
                    if not finished[i]:
                        continue
                    for h in range(len(self.marl_agents)):
                        obs_n[h][i] = infos[i]["reset_obs"][h]
                        agent_mask[h][i] = infos[i]["reset_agent_mask"][h]
                        act_mean_last[h][i] = np.zeros([self.args[h].dim_act])
                    # Only this env's slot is reset; the other envs keep their current state.
                    state[i] = infos[i]["reset_state"]
            scores = episode_score.mean(axis=1).reshape([self.n_handles])
            if self.args_base.test_mode:
                print("Mean score: ", scores)

            if record_video:
                # time, height, width, channel -> time, channel, height, width
                videos_info = {"Videos_Test": np.array(videos, dtype=np.uint8).transpose((0, 1, 4, 2, 3))}
                self.log_videos(info=videos_info, fps=self.fps, x_index=self.current_step)

            if self.n_handles > 1:
                for h in range(self.n_handles):
                    test_info["Test-Episode-Rewards/Side_{}".format(self.args[h].handle_name)] = scores[h]
            else:
                test_info["Test-Episode-Rewards"] = scores[0]
            self.log_infos(test_info, self.current_step)

            return episode_score
        finally:
            test_envs.close()

    def _build_test_envs(self):
        """Build evaluation environments with ``parallels`` set to ``test_episode``."""
        args_test = deepcopy(self.args_base)
        args_test.parallels = args_test.test_episode
        return make_envs(args_test)

    def _close_all(self):
        """Close the training environments and the active logger."""
        self.envs.close()
        if self.use_wandb:
            wandb.finish()
        else:
            self.writer.close()

    def run(self):
        """Test saved models (``test_mode``) or train, then save final models and close."""
        if self.args_base.test_mode:
            self.render = True
            for mas_group in self.marl_agents:
                mas_group.load_model(mas_group.model_dir_load, mas_group.args.seed)
            self.test_episode(self._build_test_envs)
            print("Finish testing.")
        else:
            n_train_episodes = self.args_base.running_steps // self.episode_length // self.n_envs
            self.train_episode(n_train_episodes)
            print("Finish training.")
            for mas_group in self.marl_agents:
                mas_group.save_model("final_train_model.ckpt")

        self._close_all()

    def benchmark(self):
        """Alternate training and evaluation, saving each handle's best-scoring model."""
        env_fn = self._build_test_envs

        n_train_episodes = self.args_base.running_steps // self.episode_length // self.n_envs
        # Clamp to at least one round per epoch. Otherwise an eval_interval shorter than
        # episode_length * n_envs would make the epoch count below divide by zero.
        n_eval_interval = max(1, self.args_base.eval_interval // self.episode_length // self.n_envs)
        num_epoch = n_train_episodes // n_eval_interval

        test_scores = self.test_episode(env_fn)
        best_scores = [{
            "mean": np.mean(test_scores, axis=1).reshape([self.n_handles]),
            "std": np.std(test_scores, axis=1).reshape([self.n_handles]),
            "step": self.current_step
        } for _ in range(self.n_handles)]
        for h in range(self.n_handles):
            # Same file name as the in-loop checkpoint so it always holds the best model so far.
            self.marl_agents[h].save_model("best_model.ckpt")

        for i_epoch in range(num_epoch):
            print("Epoch: %d/%d:" % (i_epoch, num_epoch))
            self.train_episode(n_episodes=n_eval_interval)
            test_scores = self.test_episode(env_fn)

            mean_test_scores = np.mean(test_scores, axis=1)
            for h in range(self.n_handles):
                if mean_test_scores[h] > best_scores[h]["mean"][h]:
                    best_scores[h] = {
                        "mean": mean_test_scores.reshape([self.n_handles]),
                        "std": np.std(test_scores, axis=1).reshape([self.n_handles]),
                        "step": self.current_step
                    }
                    # save best model
                    self.marl_agents[h].save_model("best_model.ckpt")

        # end benchmarking
        print("Finish benchmarking.")
        for h in range(self.n_handles):
            # handle_name is set from self.envs.side_names in __init__; it does not need `.envs`.
            print("Best Score for {}: ".format(self.args[h].handle_name))
            print("Mean: ", best_scores[h]["mean"], "Std: ", best_scores[h]["std"])

        self._close_all()
