from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
from multiprocessing import Pipe, Process

import numpy as np
import time
import copy


# Based (very) heavily on SubprocVecEnv from OpenAI Baselines
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
class ParallelRunner:

    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run

        # Make subprocesses for the envs
        self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)])
        env_fn = env_REGISTRY[self.args.env]
        env_args = [self.args.env_args.copy() for _ in range(self.batch_size)]
        for i in range(self.batch_size):
            env_args[i]["seed"] += i
            env_args[i]["common_reward"] = self.args.common_reward
            env_args[i]["reward_scalarisation"] = self.args.reward_scalarisation
        self.ps = [
            Process(
                target=env_worker,
                args=(worker_conn, CloudpickleWrapper(partial(env_fn, **env_arg)), args.delay_type, args.delay_value, args.delay_scope, args.delay_aware),
            )
            for env_arg, worker_conn in zip(env_args, self.worker_conns)
        ]

        for p in self.ps:
            p.daemon = True
            p.start()

        self.parent_conns[0].send(("get_env_info", None))
        self.env_info = self.parent_conns[0].recv()
        self.episode_limit = self.env_info["episode_limit"]

        self.t = 0

        self.t_env = 0

        self.train_returns = []
        self.test_returns = []
        self.train_stats = {}
        self.test_stats = {}
        self.last_test_stats = {}
        
        self.log_train_stats_t = -100000

    def setup(self, scheme, groups, preprocess, mac):
        if self.args.use_cuda and not self.args.cpu_inference:
            self.batch_device = self.args.device
        else:
            self.batch_device = "cpu" if self.args.buffer_cpu_only else self.args.device
        print(" &&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&&& self.batch_device={}".format(
            self.batch_device))
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.batch_device)
        self.mac = mac
        self.scheme = scheme
        self.groups = groups
        self.preprocess = preprocess

    def get_env_info(self):
        return self.env_info

    def save_replay(self):
        pass

    def close_env(self):
        for parent_conn in self.parent_conns:
            parent_conn.send(("close", None))

    def reset(self):
        self.batch = self.new_batch()

        if (self.args.use_cuda and self.args.cpu_inference) and str(self.mac.get_device()) != "cpu":
            self.mac.cpu()  # copy model to cpu

        # Reset the envs
        for parent_conn in self.parent_conns:
            parent_conn.send(("reset", None))

        pre_transition_data = {
            "state": [],
            "avail_actions": [],
            "obs": []
        }
        if self.args.delay_type != "":
            pre_transition_data["real_obs"] = []
            pre_transition_data["enemy_delay_values"] = []
            pre_transition_data["ally_delay_values"] = []
        # Get the obs, state and avail_actions back
        for parent_conn in self.parent_conns:
            data = parent_conn.recv()
            pre_transition_data["state"].append(data["state"])
            pre_transition_data["avail_actions"].append(data["avail_actions"])
            pre_transition_data["obs"].append(data["obs"])
            if self.args.delay_type != "":
                pre_transition_data["real_obs"].append(data["real_obs"])
                pre_transition_data["enemy_delay_values"].append(data["enemy_delay_values"])
                pre_transition_data["ally_delay_values"].append(data["ally_delay_values"])

        self.batch.update(pre_transition_data, ts=0, mark_filled=True)

        self.t = 0
        self.env_steps_this_run = 0

    def run(self, test_mode=False):
        self.reset()

        all_terminated = False
        if self.args.common_reward:
            episode_returns = [0 for _ in range(self.batch_size)]
        else:
            episode_returns = [
                np.zeros(self.args.n_agents) for _ in range(self.batch_size)
            ]
        episode_lengths = [0 for _ in range(self.batch_size)]
        self.mac.init_hidden(batch_size=self.batch_size)
        terminated = [False for _ in range(self.batch_size)]
        envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
        final_env_infos = []  # may store extra stats like battle won. this is filled in ORDER OF TERMINATION

        while True:
            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch for each un-terminated env
            actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode)
            cpu_actions = actions.to("cpu").numpy()  # (batch_size, n_agents)

            # Update the actions taken
            actions_chosen = {
                "actions": np.expand_dims(cpu_actions, axis=1),
            }
            self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Send actions to each env
            action_idx = 0
            for idx, parent_conn in enumerate(self.parent_conns):
                if idx in envs_not_terminated:  # We produced actions for this env
                    if not terminated[idx]:  # Only send the actions to the env if it hasn't terminated
                        parent_conn.send(("step", cpu_actions[action_idx]))
                    action_idx += 1  # actions is not a list over every env

            # Post step data we will insert for the current timestep
            post_transition_data = {
                "reward": [],
                "terminated": []
            }
            # Data for the next step we will insert in order to select an action
            pre_transition_data = {
                "state": [],
                "avail_actions": [],
                "obs": []
            }
            if self.args.delay_type != "":
                pre_transition_data["real_obs"] = []
                pre_transition_data["enemy_delay_values"] = []
                pre_transition_data["ally_delay_values"] = []
            # Receive data back for each unterminated env
            for idx, parent_conn in enumerate(self.parent_conns):
                if not terminated[idx]:
                    data = parent_conn.recv()
                    # Remaining data for this current timestep
                    post_transition_data["reward"].append((data["reward"],))

                    episode_returns[idx] += data["reward"]
                    episode_lengths[idx] += 1
                    if not test_mode:
                        self.env_steps_this_run += 1

                    env_terminated = False
                    if data["terminated"]:
                        final_env_infos.append(data["info"])
                    if data["terminated"] and not data["info"].get("episode_limit", False):
                        env_terminated = True
                    terminated[idx] = data["terminated"]
                    post_transition_data["terminated"].append((env_terminated,))

                    # Data for the next timestep needed to select an action
                    pre_transition_data["state"].append(data["state"])
                    pre_transition_data["avail_actions"].append(data["avail_actions"])
                    pre_transition_data["obs"].append(data["obs"])
                    if self.args.delay_type != "":
                        pre_transition_data["real_obs"].append(data["real_obs"])
                        pre_transition_data["enemy_delay_values"].append(data["enemy_delay_values"])
                        pre_transition_data["ally_delay_values"].append(data["ally_delay_values"])

            # Add post_transiton data into the batch
            self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Move onto the next timestep
            self.t += 1

            # Add the pre-transition data
            self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True)

            # TODO: 这段到底放在前面还是这里？
            # Update envs_not_terminated
            envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
            all_terminated = all(terminated)
            if all_terminated:
                break

        if not test_mode:
            self.t_env += self.env_steps_this_run

        # Get stats back for each env
        for parent_conn in self.parent_conns:
            parent_conn.send(("get_stats", None))

        env_stats = []
        for parent_conn in self.parent_conns:
            env_stat = parent_conn.recv()
            env_stats.append(env_stat)

        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns
        log_prefix = "test_" if test_mode else ""
        infos = [cur_stats] + final_env_infos

        cur_stats.update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])})
        cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0)

        cur_returns.extend(episode_returns)

        # test with chunksize
        # chunksize = 32
        # if cur_stats['n_episodes'] % chunksize == 0:
        #     print(len(self.test_returns))
        #     if cur_stats['n_episodes'] / chunksize == 1:
        #         if self.args.env in ["sc2", "sc2_v2"]:
        #             print("test_battle_won: {} return_mean: {}".format(cur_stats['battle_won'] / chunksize, np.mean(cur_returns)))
        #         elif self.args.env == "gymma":
        #             print("return_mean: {}".format(np.mean(cur_returns)))
        #         self.last_test_stats = copy.deepcopy(cur_stats)
        #     else:
        #         if self.args.env in ["sc2", "sc2_v2"]:
        #             print("test_battle_won: {} return_mean: {}".format((cur_stats['battle_won'] - self.last_test_stats['battle_won']) / chunksize, np.mean(cur_returns[-chunksize:])))
        #         elif self.args.env == "gymma":
        #             print("return_mean: {}".format(np.mean(cur_returns[-chunksize:])))
        #         self.last_test_stats = copy.deepcopy(cur_stats)

        n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size
        if test_mode and (len(self.test_returns) == n_test_runs):
            self._log(cur_returns, cur_stats, log_prefix)
        elif not test_mode and self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("eval/epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

        return self.batch

    def save_replay(self):
        print("----------------------------Replay----------------------------")
        if self.args.save_replay:
            for parent_conn in self.parent_conns:
                parent_conn.send(("save_replay", None))
            for parent_conn in self.parent_conns:
                _ = parent_conn.recv()

    def _log(self, returns, stats, prefix):
        if self.args.common_reward:
            self.logger.log_stat("eval/" + prefix + "return_mean", np.mean(returns), self.t_env)
            self.logger.log_stat("eval/" + prefix + "return_std", np.std(returns), self.t_env)
        else:
            for i in range(self.args.n_agents):
                self.logger.log_stat(
                    "eval/" + prefix + f"agent_{i}_return_mean",
                    np.array(returns)[:, i].mean(),
                    self.t_env,
                )
                self.logger.log_stat(
                    "eval/" + prefix + f"agent_{i}_return_std",
                    np.array(returns)[:, i].std(),
                    self.t_env,
                )
            total_returns = np.array(returns).sum(axis=-1)
            self.logger.log_stat(
                "eval/" + prefix + "total_return_mean", total_returns.mean(), self.t_env
            )
            self.logger.log_stat(
                "eval/" + prefix + "total_return_std", total_returns.std(), self.t_env
            )
        returns.clear()

        for k, v in stats.items():
            if k != "n_episodes":
                self.logger.log_stat("eval/" + prefix + k + "_mean", v / stats["n_episodes"], self.t_env)
        stats.clear()
        # print(self.logger.stats)


def env_worker(remote, env_fn, delay_type, delay_value, delay_scope, delay_aware):
    # Make environment
    env = env_fn.x()
    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            actions = data
            # Take a step in the environment
            _, reward, terminated, truncated, env_info = env.step(actions)
            terminated = terminated or truncated
            # Return the observations, avail_actions and state to make the next action
            state = env.get_state()
            avail_actions = env.get_avail_actions()
            if delay_type == "":
                # 环境无延迟
                obs = env.get_obs()
                remote.send({
                    # Data for the next timestep needed to pick an action
                    "state": state,
                    "avail_actions": avail_actions,
                    "obs": obs,
                    # Rest of the data for the current timestep
                    "reward": reward,
                    "terminated": terminated,
                    "info": env_info
                })
            else:
                # 真实观测和延迟观测
                real_obs = env.get_obs()
                delay_obs, enemy_delay_values, ally_delay_values = env.get_obs_delayed(delay_type, delay_value, delay_scope)
                if not delay_aware:
                    state = env.get_state_delayed(delay_type, delay_value, delay_scope)
                remote.send({
                    # Data for the next timestep needed to pick an action
                    "state": state,
                    "avail_actions": avail_actions,
                    "obs": delay_obs,
                    "real_obs": real_obs,
                    "enemy_delay_values": enemy_delay_values,
                    "ally_delay_values": ally_delay_values,
                    # Rest of the data for the current timestep
                    "reward": reward,
                    "terminated": terminated,
                    "info": env_info
                })
        elif cmd == "reset":
            env.reset()
            state = env.get_state()
            avail_actions = env.get_avail_actions()
            if delay_type == "":
                # 环境无延迟
                obs = env.get_obs()
                remote.send({
                    "state": state,
                    "avail_actions": avail_actions,
                    "obs" : obs
                })
            else:
                # 真实观测和延迟观测
                real_obs = env.get_obs()
                delay_obs, enemy_delay_values, ally_delay_values = env.get_obs_delayed(delay_type, delay_value, delay_scope)
                if not delay_aware:
                    state = env.get_state_delayed(delay_type, delay_value, delay_scope)
                remote.send({
                    "state": state,
                    "avail_actions": avail_actions,
                    "obs" : delay_obs,
                    "real_obs": real_obs,
                    "enemy_delay_values": enemy_delay_values,
                    "ally_delay_values": ally_delay_values
                })
        elif cmd == "close":
            env.close()
            remote.close()
            break
        elif cmd == "get_env_info":
            remote.send(env.get_env_info(delay_aware))
        elif cmd == "get_stats":
            remote.send(env.get_stats())
        elif cmd == "save_replay":
            remote.send(env.save_replay())
        elif cmd == "render":
            env.render()
        else:
            raise NotImplementedError


class CloudpickleWrapper():
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)
