from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
from multiprocessing import Pipe, Process
import numpy as np
import torch as th


# Based (very) heavily on SubprocVecEnv from OpenAI Baselines
# https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py
class ADParallelRunner:
    """Collect episodes from ``args.batch_size_run`` environments running in subprocesses.

    Every environment lives in its own daemon ``Process`` executing ``env_worker``
    and is driven through a ``Pipe`` with ``(command, payload)`` tuples.
    ``run`` rolls out one episode per environment and stores, next to the agents'
    actions, the four ``ad_*`` outputs returned by ``mac.select_actions``
    (presumably the adversary's discrete/continuous actions and embeddings --
    confirm against the controller).
    """

    def __init__(self, args, logger):
        """Spawn one worker process per environment and cache the env info.

        Args:
            args: Config object; this method reads ``batch_size_run``, ``env``
                and ``env_args`` from it.
            logger: Object whose ``log_stat(key, value, t)`` is called by ``log_info``.
        """
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run

        # Make subprocesses for the envs
        self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)])
        env_fn = env_REGISTRY[self.args.env]
        self.ps = [Process(target=env_worker, args=(worker_conn, CloudpickleWrapper(partial(env_fn, **self.args.env_args))))
                   for worker_conn in self.worker_conns]

        for p in self.ps:
            p.daemon = True
            p.start()

        # Only the first worker is asked for the env info.
        self.parent_conns[0].send(("get_env_info", None))
        self.env_info = self.parent_conns[0].recv()
        self.episode_limit = self.env_info["episode_limit"]

        # Step index inside the episode currently being collected.
        self.t = 0

        # Environment steps accumulated over all runs whose tag is not 'test'.
        self.t_env = 0

        # Per-tag accumulators, emptied by log_info().
        self.returns = {}
        self.stats = {}

        self.log_train_stats_t = -100000
        self.battle_info = {'battle_won': 0, 'dead_allies': 0, 'dead_enemies': 0}

    def setup(self, scheme, groups, preprocess, mac):
        """Store the controller and build the factory used to create a fresh batch.

        ``self.new_batch`` is a ``partial`` over ``EpisodeBatch`` that is given
        ``scheme``, ``groups``, ``self.batch_size`` and ``self.episode_limit + 1``
        positionally, plus ``preprocess`` and ``args.device`` as keywords.
        """
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac
        self.scheme = scheme
        self.groups = groups
        self.preprocess = preprocess

    def get_env_info(self):
        """Return the env info dict fetched from the first worker in ``__init__``."""
        return self.env_info

    def save_replay(self):
        """No-op."""
        pass

    def close_env(self):
        """Send the ``close`` command to every worker."""
        for parent_conn in self.parent_conns:
            parent_conn.send(("close", None))

    def reset(self):
        """Create a new batch, reset every env and write the first observations at ``ts=0``."""
        self.batch = self.new_batch()

        # Reset the envs
        for parent_conn in self.parent_conns:
            parent_conn.send(("reset", None))

        pre_transition_data = {
            "state": [],
            "avail_actions": [],
            "obs": [],
            "obs_origin": []
        }
        # Get the obs, state and avail_actions back
        for parent_conn in self.parent_conns:
            data = parent_conn.recv()
            pre_transition_data["state"].append(data["state"])
            pre_transition_data["avail_actions"].append(data["avail_actions"])
            pre_transition_data["obs"].append(data["obs"])
            # "obs_origin" receives the same observation object as "obs".
            pre_transition_data["obs_origin"].append(data["obs"])

        self.batch.update(pre_transition_data, ts=0)

        self.t = 0
        self.env_steps_this_run = 0

    @staticmethod
    def _attack_ratio(ad_action, eps):
        """Return ``max(|a| / (||a||_1 + eps))`` over the entries of ``ad_action`` as a float."""
        return max(abs(ad_action) / (th.norm(ad_action, p=1) + eps)).item()

    def run(self, test_mode=False, tag='train', print_attack_obs=False):
        """Roll out one episode in every environment and return the filled batch.

        Args:
            test_mode: Forwarded to ``mac.select_actions``.
            tag: Key under which returns and stats are accumulated. ``'pretrain'``
                sets ``pretrain=True`` for the controller; a tag starting with
                ``'def'`` additionally passes ``attack_mode=int(tag[3:])`` and
                ``print_attack_obs``; ``'test'`` keeps the collected steps out
                of ``self.t_env``.
            print_attack_obs: Forwarded to the controller only for ``'def*'`` tags.

        Returns:
            ``self.batch`` after the rollout.
        """
        self.reset()
        self.battle_info = {'battle_won': 0, 'dead_allies': 0, 'dead_enemies': 0}

        pretrain = tag == 'pretrain'
        count_env_steps = tag != 'test'
        defense_loaded = vars(self.args).get("defense_loaded", False)
        n_agents = self.args.n_agents
        episode_returns = [0 for _ in range(self.batch_size)]
        episode_lengths = [0 for _ in range(self.batch_size)]
        self.mac.init_hidden(batch_size=self.batch_size)
        terminated = [False for _ in range(self.batch_size)]
        envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
        final_env_infos = []  # may store extra stats like battle won. this is filled in ORDER OF TERMINATION
        # Per env: how often each index 0..n_agents was chosen as discrete ad action.
        attack_id_arr = np.zeros((self.batch_size, n_agents + 1))

        # Keyword arguments of select_actions that do not change between steps.
        select_kwargs = {"pretrain": pretrain, "test_mode": test_mode}
        if tag[0:3] == 'def':
            select_kwargs["attack_mode"] = int(tag[3:])
            select_kwargs["print_attack_obs"] = print_attack_obs

        while True:

            # Pass the entire batch of experiences up till now to the agents
            # Receive the actions for each agent at this timestep in a batch for each un-terminated env
            (actions, ad_discrete_actions, ad_continuous_actions,
             ad_discrete_emb, ad_continuous_emb) = self.mac.select_actions(
                self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, **select_kwargs)

            cpu_actions = actions.to("cpu").numpy()

            # Row r of every tensor returned above belongs to envs_not_terminated[r].
            # That list is refreshed *before* the results of a step are received, so
            # it can still contain envs that finished on the previous step. Remember
            # the env -> row mapping so per-env lookups below stay aligned.
            action_row_of_env = {env_id: row for row, env_id in enumerate(envs_not_terminated)}

            # log agent id attacked
            attacked_ids = ad_discrete_actions.flatten().clone().detach().cpu().numpy()
            for env_id, row in action_row_of_env.items():
                attack_id_arr[env_id][attacked_ids[row]] += 1

            # Update the actions taken
            actions_chosen = {
                "actions": actions.unsqueeze(1),
                "ad_discrete_actions": ad_discrete_actions,
                "ad_continuous_actions": ad_continuous_actions,
                "ad_discrete_emb": ad_discrete_emb,
                "ad_continuous_emb": ad_continuous_emb
            }
            self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Send actions to each env that has not terminated; rows that belong to
            # already finished envs are skipped.
            for row, env_id in enumerate(envs_not_terminated):
                if not terminated[env_id]:
                    self.parent_conns[env_id].send(("step", cpu_actions[row]))

            # Update envs_not_terminated
            envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed]
            if all(terminated):
                break

            # Post step data we will insert for the current timestep
            post_transition_data = {
                "reward": [],
                "ad_reward": [],
                "terminated": []
            }
            # Data for the next step we will insert in order to select an action
            pre_transition_data = {
                "state": [],
                "avail_actions": [],
                "obs": [],
                "obs_origin": []
            }
            if defense_loaded:
                pre_transition_data["obs_perturbed"] = []

            # Receive data back for each unterminated env
            for idx, parent_conn in enumerate(self.parent_conns):
                if terminated[idx]:
                    continue
                data = parent_conn.recv()
                # The continuous ad action that was selected for *this* env.
                ad_action = ad_continuous_actions[action_row_of_env[idx]]

                # Remaining data for this current timestep
                post_transition_data["reward"].append((data["reward"],))
                # NOTE(review): the reward penalty uses "+ 1" in the denominator while
                # the logged max_attack uses "+ 1e-6"; kept as in the original -- confirm
                # the difference is intended.
                ad_reward = (data["reward"] - self._attack_ratio(ad_action, 1)) * self.args.ad_policy_reward_scale
                post_transition_data["ad_reward"].append((ad_reward,))
                data["info"]["max_attack"] = self._attack_ratio(ad_action, 1e-6)
                if 'battle_won' in data['info'].keys():
                    self.battle_info['battle_won'] += data['info']['battle_won']
                    self.battle_info['dead_allies'] += data['info']['dead_allies']
                    self.battle_info['dead_enemies'] += data['info']['dead_enemies']

                episode_returns[idx] += data["reward"]
                episode_lengths[idx] += 1
                if count_env_steps:
                    self.env_steps_this_run += 1

                env_terminated = False
                if data["terminated"]:
                    # Turn the attack counters of this env into ratios for logging.
                    n_attacks = attack_id_arr[idx].sum()
                    for i in range(n_agents):
                        data["info"][f"attack_{i}_ratio"] = attack_id_arr[idx][i] / n_attacks
                    data["info"]["no_attack_ratio"] = attack_id_arr[idx][n_agents] / n_attacks
                    final_env_infos.append(data["info"])
                # Hitting the episode limit ends the rollout but is not stored as a
                # termination in the batch.
                if data["terminated"] and not data["info"].get("episode_limit", False):
                    env_terminated = True
                terminated[idx] = data["terminated"]
                post_transition_data["terminated"].append((env_terminated,))

                # Data for the next timestep needed to select an action
                pre_transition_data["state"].append(data["state"])
                pre_transition_data["avail_actions"].append(data["avail_actions"])
                pre_transition_data["obs"].append(data["obs"])
                pre_transition_data["obs_origin"].append(data["obs"])
                if defense_loaded:
                    pre_transition_data["obs_perturbed"].append(data["obs"])

            # Add post_transiton data into the batch
            self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False)

            # Move onto the next timestep
            self.t += 1

            # Add the pre-transition data
            self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True)

        # Get stats back for each env
        for parent_conn in self.parent_conns:
            parent_conn.send(("get_stats", None))

        # The replies are not used, but they have to be read so that they do not
        # get mistaken for the answer to the next command.
        for parent_conn in self.parent_conns:
            parent_conn.recv()

        self.returns[tag] = self.returns.get(tag, []) + episode_returns
        if tag not in self.stats:
            self.stats[tag] = {}
        # Sum every info key over the stats so far and the final infos of this run.
        infos = [self.stats[tag]] + final_env_infos
        self.stats[tag].update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])})
        self.stats[tag]["n_episodes"] = self.batch_size + self.stats[tag].get("n_episodes", 0)
        self.stats[tag]["ep_length"] = sum(episode_lengths) + self.stats[tag].get("ep_length", 0)
        if count_env_steps:
            self.t_env += self.env_steps_this_run

        return self.batch

    def log_info(self, tag="train"):
        """Log and return the per-episode means accumulated under ``tag``, then reset them.

        Every stat except ``n_episodes`` is divided by ``n_episodes``; the mean
        return and, if the controller exposes ``action_selector.epsilon``, the
        current epsilon are added. Keys are prefixed with ``"{tag}/"`` and sent
        to ``logger.log_stat`` together with ``self.t_env``.
        """
        log_dic = {}
        log_dic["return_mean"] = float(np.mean(self.returns[tag]))
        n_episodes = self.stats[tag]["n_episodes"]
        for k, v in self.stats[tag].items():
            if k != "n_episodes":
                log_dic[f"{k}_mean"] = float(v / n_episodes)
        if hasattr(self.mac, 'action_selector') and hasattr(self.mac.action_selector, "epsilon"):
            log_dic['epsilon'] = float(self.mac.action_selector.epsilon)
        self.stats[tag].clear()
        self.returns[tag].clear()

        log_dic = {f"{tag}/{k}": v for k, v in log_dic.items()}
        for k, v in log_dic.items():
            self.logger.log_stat(k, v, self.t_env)
        return log_dic


def env_worker(remote, env_fn):
    """Serve environment commands received on ``remote`` until ``"close"`` arrives.

    The environment is built by calling ``env_fn.x()``. Each message read from
    ``remote`` is a ``(cmd, data)`` pair:

    * ``"step"``  -- step with ``data`` as actions; reply with the next
      state / avail_actions / obs plus reward, terminated and info.
    * ``"reset"`` -- reset; reply with state / avail_actions / obs.
    * ``"get_env_info"`` / ``"get_stats"`` -- reply with the env's answer.
    * ``"close"`` -- close the env and the connection, then return.

    Raises:
        NotImplementedError: for any other command.
    """
    env = env_fn.x()

    def snapshot():
        # Everything that is needed to pick the next action.
        return {
            "state": env.get_state(),
            "avail_actions": env.get_avail_actions(),
            "obs": env.get_obs(),
        }

    while True:
        cmd, data = remote.recv()

        if cmd == "close":
            env.close()
            remote.close()
            return

        if cmd == "step":
            # Advance the environment first, then read the resulting observations.
            reward, terminated, env_info = env.step(data)
            reply = snapshot()
            # Rest of the data for the current timestep
            reply["reward"] = reward
            reply["terminated"] = terminated
            reply["info"] = env_info
            remote.send(reply)
        elif cmd == "reset":
            env.reset()
            remote.send(snapshot())
        elif cmd == "get_env_info":
            remote.send(env.get_env_info())
        elif cmd == "get_stats":
            remote.send(env.get_stats())
        else:
            raise NotImplementedError


class CloudpickleWrapper:
    """Holder for an object ``x`` that is serialized with cloudpickle when pickled.

    multiprocessing would otherwise pickle the contents with the standard
    ``pickle`` module.
    """

    def __init__(self, x):
        # The wrapped payload; read back by the receiver as ``wrapper.x``.
        self.x = x

    def __getstate__(self):
        # Imported lazily, only when the wrapper is actually pickled.
        from cloudpickle import dumps
        return dumps(self.x)

    def __setstate__(self, ob):
        # ``ob`` is the byte string produced by __getstate__; it is loaded with
        # the standard pickle module.
        from pickle import loads
        self.x = loads(ob)

