from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch

import numpy as np
import random
import torch as th

class EpisodeRunner:
    """Run single-environment episodes, optionally under observation/action/removal attacks.

    Every call to :meth:`run` plays one episode. A subset ``group_1`` of ``k``
    agents is sampled per episode and the remaining agents form ``group_2``.
    Depending on the flags passed to :meth:`run`, the runner can zero
    observation features that one group holds about the other, replace the
    actions of some ``group_1`` agents, or remove agents from the environment.
    """

    # (move_feats, enemy_feat_per, ally_feat_per) per map name, used to locate
    # the per-ally feature block inside an agent's observation vector.
    # NOTE(review): hard-coded values; confirm against the env's actual
    # observation feature sizes for each map.
    _OBS_LAYOUTS = {
        "8m": (4, 5, 5),
        "3m": (4, 5, 5),
        "3s_vs_3z": (4, 5, 5),
        "corridor": (4, 5, 5),
        "2s3z": (4, 5, 8),
        "1c3s5z": (4, 5, 8),
        "MMM": (4, 10, 8),
    }
    # Layout assumed for maps missing from _OBS_LAYOUTS.
    _DEFAULT_OBS_LAYOUT = (4, 5, 5)

    def __init__(self, args, logger):
        """Create the environment described by ``args`` and reset bookkeeping.

        Args:
            args: Config namespace; ``batch_size_run``, ``env`` and ``env_args``
                are read here (other methods read further fields).
            logger: Object exposing ``log_stat(key, value, t)``.
        """
        self.args = args
        self.logger = logger

        self.batch_size = self.args.batch_size_run
        # This runner drives exactly one environment instance.
        assert self.batch_size == 1

        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.episode_limit = self.env.episode_limit
        self.t = 0  # step index within the current episode
        self.t_env = 0  # env steps accumulated over non-test, non-eval episodes

        self.train_returns = []
        self.test_returns = []
        self.test_returns_no_attack = []
        self.train_stats = {}
        self.test_stats = {}
        self.test_stats_no_attack = {}

        # Far in the past so the first logging-interval check passes.
        self.log_train_stats_t = -1000000

    def setup(self, scheme, groups, preprocess, mac, mi_model, mi_model_obs):
        """Bind the episode-batch factory and the models used by :meth:`run`.

        ``mi_model`` is only used for MI-guided action attacks and
        ``mi_model_obs`` only for MI-guided observation masks.
        """
        # One extra time slot holds the final state/obs written after the episode ends.
        self.new_batch = partial(
            EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
            preprocess=preprocess, device=self.args.device,
        )
        self.mac = mac
        self.mi_model = mi_model
        self.mi_model_obs = mi_model_obs

    def get_env_info(self):
        """Return the environment's ``get_env_info()`` result."""
        return self.env.get_env_info()

    def save_replay(self):
        """Ask the environment to save a replay."""
        self.env.save_replay()

    def close_env(self):
        """Close the underlying environment."""
        self.env.close()

    def reset(self):
        """Start a fresh batch, reset the env, and zero the episode step counter."""
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0

    # ------------------------------------------------------------------
    # Observation-layout helpers
    # ------------------------------------------------------------------

    def _get_obs_layout(self, map_name: str):
        """Return ``(move_feats, enemy_feat_per, ally_feat_per)`` for ``map_name``.

        Unknown maps (including ``None``) get ``_DEFAULT_OBS_LAYOUT``.
        """
        return self._OBS_LAYOUTS.get(map_name, self._DEFAULT_OBS_LAYOUT)

    def _n_enemies(self):
        """Return ``self.env.n_enemies``; raise ``ValueError`` if it is missing or ``None``."""
        n_enemies = getattr(self.env, "n_enemies", None)
        if n_enemies is None:
            raise ValueError(
                "env.n_enemies is required to locate ally feature blocks in observations"
            )
        return n_enemies

    @staticmethod
    def _cross_group_pairs(group_1, group_2):
        """Return ``(target, observer)`` pairs that cross the two groups.

        First every ``group_1`` target seen by every ``group_2`` observer, then
        every ``group_2`` target seen by every ``group_1`` observer.
        """
        pairs = [(gi, gj) for gi in group_1 for gj in group_2]
        pairs.extend((gj, gi) for gj in group_2 for gi in group_1)
        return pairs

    def _dims_per_pair(self, ally_feat_per):
        """Return how many dims to mask per ``(target, observer)`` pair.

        Uses ``args.obs_dim_num`` when it is a positive integer, otherwise the
        full ally block width ``ally_feat_per``.
        """
        k_per_pair = int(getattr(self.args, "obs_dim_num", -1))
        return k_per_pair if k_per_pair > 0 else ally_feat_per

    def _ally_block_range(self, obs_dim: int, target_i: int, observer_j: int, map_name: str):
        """Locate agent ``target_i``'s ally block inside observer ``observer_j``'s obs.

        The observer's own id is skipped when numbering ally slots, so targets
        with a higher id than the observer shift down by one slot.

        Returns:
            ``(start, end)`` half-open dim range clipped to ``obs_dim``, or ``None``
            when ``target_i == observer_j`` or the block starts past ``obs_dim``.

        Raises:
            ValueError: If the env does not expose ``n_enemies``.
        """
        if target_i == observer_j:
            return None

        move_feats, enemy_feat_per, ally_feat_per = self._get_obs_layout(map_name)
        # Ally blocks follow the movement features and all enemy blocks.
        ally_base = move_feats + enemy_feat_per * self._n_enemies()
        ally_slot = target_i if target_i < observer_j else target_i - 1

        start = ally_base + ally_slot * ally_feat_per
        if start >= obs_dim:
            return None
        return (start, min(start + ally_feat_per, obs_dim))

    # ------------------------------------------------------------------
    # Observation mask plans (computed once per episode)
    # ------------------------------------------------------------------

    def _build_mi_obs_mask_plan(self, group_1, group_2, obs_dim: int, map_name: str):
        """Choose which obs dims to zero, guided by ``mi_model_obs`` scores.

        For every cross-group ``(target, observer)`` pair, the target's ally
        block in the observer's obs is located. The ``k_per_pair`` dims with the
        highest ``score[observer, target, dim]`` are then selected. Dims 0-3 and
        the last scored dim are never selected.

        Returns:
            ``(plan, meta)``. ``plan`` maps every agent id to a sorted list of
            unique dims to zero in its observation. ``meta`` records the
            parameters used and the raw (pre-dedup) per-observer pick counts.
        """
        # Indexed as score[observer, target, dim] below; assumed shape
        # (n_agents, n_agents, feat_dim) — TODO confirm.
        score = self.mi_model_obs.get_scores().detach().to("cpu")

        n_agents = self.args.n_agents
        score_dim = min(score.size(-1), obs_dim)
        _, _, ally_feat_per = self._get_obs_layout(map_name)
        k_per_pair = self._dims_per_pair(ally_feat_per)

        # NOTE(review): presumably keeps movement features and the final obs
        # feature unmasked — confirm intent.
        banned_global = {0, 1, 2, 3, score_dim - 1}

        plan = {j: [] for j in range(n_agents)}
        meta_counts = {j: 0 for j in range(n_agents)}

        for target_i, observer_j in self._cross_group_pairs(group_1, group_2):
            block_range = self._ally_block_range(obs_dim, target_i, observer_j, map_name)
            if block_range is None:
                continue
            start = block_range[0]
            end = min(block_range[1], score_dim)
            if end <= start:
                continue

            # Offsets (relative to start) of block dims that may be masked.
            local_allowed = [d - start for d in range(start, end) if d not in banned_global]
            if not local_allowed:
                continue
            allowed_idx = th.tensor(local_allowed, dtype=th.long)

            k = min(k_per_pair, len(local_allowed))
            block_scores = score[observer_j, target_i, start:end][allowed_idx]
            top_local = th.topk(block_scores, k=k, largest=True).indices
            plan[observer_j].extend((allowed_idx[top_local] + start).tolist())
            meta_counts[observer_j] += k

        plan = {
            j: sorted({d for d in dims if d not in banned_global})
            for j, dims in plan.items()
        }
        meta = {
            "k_per_pair": k_per_pair,
            "ally_feat_per": ally_feat_per,
            "banned_global": sorted(banned_global),
            "masked_dims_per_agent_rawsum": meta_counts,
        }
        return plan, meta

    def _build_random_obs_mask_plan(self, group_1, group_2, obs_dim: int, map_name: str):
        """Choose obs dims to zero uniformly at random (via ``random.sample``).

        For every cross-group ``(target, observer)`` pair, ``k_per_pair`` dims of
        the target's ally block in the observer's obs are sampled without
        replacement. Returns ``(plan, meta)`` in the same format as
        :meth:`_build_mi_obs_mask_plan`; here no dims are banned.
        """
        n_agents = self.args.n_agents
        _, _, ally_feat_per = self._get_obs_layout(map_name)
        k_per_pair = self._dims_per_pair(ally_feat_per)
        banned_global = set()

        plan = {j: [] for j in range(n_agents)}
        meta_counts = {j: 0 for j in range(n_agents)}

        for target_i, observer_j in self._cross_group_pairs(group_1, group_2):
            block_range = self._ally_block_range(obs_dim, target_i, observer_j, map_name)
            if block_range is None:
                continue
            start, end = block_range
            candidates = [d for d in range(start, end) if d not in banned_global]
            if not candidates:
                continue

            k = min(k_per_pair, len(candidates))
            plan[observer_j].extend(random.sample(candidates, k))
            meta_counts[observer_j] += k

        plan = {
            j: sorted({d for d in dims if d not in banned_global})
            for j, dims in plan.items()
        }
        meta = {
            "k_per_pair": k_per_pair,
            "ally_feat_per": ally_feat_per,
            "banned_global": sorted(banned_global),
            "masked_dims_per_agent_rawsum": meta_counts,
        }
        return plan, meta

    # ------------------------------------------------------------------
    # Layout-based ally-block removal (fallback when no mask plan is built)
    # ------------------------------------------------------------------

    def _remove_agent_info_generic(
        self,
        all_obs,
        i,
        j,
        move_feats,
        enemy_feat_per,
        ally_feat_per,
        fill_value=0.0,
        ally_slot_map=None,
        map_name_for_log="",
    ):
        """Overwrite agent ``i``'s ally block in observer ``j``'s row of ``all_obs``.

        Args:
            all_obs: 2-D array ``(n_agents, obs_dim)``; modified in place.
            i: Target agent whose features are removed.
            j: Observer agent whose observation row is edited.
            move_feats, enemy_feat_per, ally_feat_per: Observation layout sizes.
            fill_value: Value written over the block.
            ally_slot_map: Optional table where ``ally_slot_map[j][i]`` is ``i``'s
                slot in ``j``'s ally list; by default ``j``'s own id is skipped.
            map_name_for_log: Unused; kept for interface compatibility.

        Returns:
            ``all_obs`` (the same object).

        Raises:
            ValueError: If the env does not expose ``n_enemies``.
        """
        if i == j:
            return all_obs

        _, obs_dim = all_obs.shape
        ally_base = move_feats + enemy_feat_per * self._n_enemies()
        if ally_slot_map is not None:
            ally_slot = ally_slot_map[j][i]
        else:
            ally_slot = i if i < j else i - 1

        start = ally_base + ally_slot * ally_feat_per
        if start < obs_dim:
            all_obs[j, start:min(start + ally_feat_per, obs_dim)] = fill_value
        return all_obs

    def _remove_with_layout(self, map_name, all_obs, i, j, fill_value, ally_slot_map):
        """Call :meth:`_remove_agent_info_generic` with the layout of ``map_name``."""
        move_feats, enemy_feat_per, ally_feat_per = self._OBS_LAYOUTS[map_name]
        return self._remove_agent_info_generic(
            all_obs, i, j, move_feats, enemy_feat_per, ally_feat_per,
            fill_value, ally_slot_map, map_name,
        )

    def remove_agent_info_8m(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the 8m layout."""
        return self._remove_with_layout("8m", all_obs, i, j, fill_value, ally_slot_map)

    def remove_agent_info_3m(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the 3m layout."""
        return self._remove_with_layout("3m", all_obs, i, j, fill_value, ally_slot_map)

    def remove_agent_info_3s_vs_3z(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the 3s_vs_3z layout."""
        return self._remove_with_layout("3s_vs_3z", all_obs, i, j, fill_value, ally_slot_map)

    def remove_agent_info_2s3z(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the 2s3z layout."""
        return self._remove_with_layout("2s3z", all_obs, i, j, fill_value, ally_slot_map)

    def remove_agent_info_1c3s5z(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the 1c3s5z layout."""
        return self._remove_with_layout("1c3s5z", all_obs, i, j, fill_value, ally_slot_map)

    def remove_agent_info_MMM(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the MMM layout."""
        return self._remove_with_layout("MMM", all_obs, i, j, fill_value, ally_slot_map)

    def remove_agent_info_corridor(self, all_obs, i, j, fill_value=0.0, ally_slot_map=None):
        """Remove agent ``i``'s ally block from observer ``j`` using the corridor layout."""
        return self._remove_with_layout("corridor", all_obs, i, j, fill_value, ally_slot_map)

    def _layout_remove_fn(self, map_name):
        """Return the ``remove_agent_info_*`` method for ``map_name``, or ``None`` if unsupported."""
        return {
            "MMM": self.remove_agent_info_MMM,
            "8m": self.remove_agent_info_8m,
            "3m": self.remove_agent_info_3m,
            "3s_vs_3z": self.remove_agent_info_3s_vs_3z,
            "2s3z": self.remove_agent_info_2s3z,
            "1c3s5z": self.remove_agent_info_1c3s5z,
            "corridor": self.remove_agent_info_corridor,
        }.get(map_name)

    def _apply_layout_removal(self, obs, remove_fn, group_1, group_2):
        """Apply ``remove_fn`` to every cross-group ``(target, observer)`` pair."""
        for target, observer in self._cross_group_pairs(group_1, group_2):
            obs = remove_fn(obs, target, observer)
        return obs

    # ------------------------------------------------------------------
    # Per-episode / per-step pieces of run()
    # ------------------------------------------------------------------

    def _sample_attack_groups(self, eval_mode, fixed_K):
        """Sample ``k``, the agents to remove, and the two agent groups.

        Returns:
            ``(k, remove_agents, group_1, group_2)``. ``remove_agents`` is only
            sampled in ``eval_mode`` (otherwise ``[]``).
        """
        n_agents = self.args.n_agents
        remove_agents = []
        if eval_mode:
            k = int(fixed_K)
            # NOTE(review): removal targets are sampled independently of group_1.
            remove_agents = random.sample(range(n_agents), k)
        else:
            k = random.randint(0, min(self.args.K, n_agents))

        group_1 = random.sample(range(n_agents), k) if k > 0 else []
        group_2 = [i for i in range(n_agents) if i not in group_1]
        return k, remove_agents, group_1, group_2

    def _sample_action_attack_pro(self, k, action_attack, remove_attack, pro_min, pro_max):
        """Return the per-agent probability of replacing a ``group_1`` action this episode.

        The value is 0.0 without an action attack and 1.0 when
        ``args.action_pro_sampling`` is off. Otherwise it is drawn uniformly from
        ``[pro_min[k], pro_max[k]]``, where each bound defaults to ``1/k``. It is
        0.0 for ``k == 0`` or during a removal attack.
        """
        if not action_attack:
            return 0.0
        if not self.args.action_pro_sampling:
            return 1.0
        if k == 0 or remove_attack:
            return 0.0
        lo = pro_min[k] if pro_min is not None else 1.0 / k
        hi = pro_max[k] if pro_max is not None else 1.0 / k
        return random.uniform(lo, hi)

    def _build_obs_mask_plan(self, group_1, group_2, map_name):
        """Build the per-episode obs mask plan, or return ``None`` to use layout removal.

        A random plan is used when ``args.obs_random`` is set. Otherwise an
        MI-guided plan is used when enabled (``args.obs_attack_use_mi``, default
        ``True``) and ``mi_model_obs`` is available.
        """
        use_mi_obs = bool(getattr(self.args, "obs_attack_use_mi", True)) and (
            self.mi_model_obs is not None
        )
        obs_dim = int(np.asarray(self.env.get_obs(), dtype=np.float32).shape[1])

        if self.args.obs_random:
            plan, _ = self._build_random_obs_mask_plan(
                group_1=group_1, group_2=group_2, obs_dim=obs_dim, map_name=map_name
            )
            return plan
        if use_mi_obs:
            plan, _ = self._build_mi_obs_mask_plan(
                group_1=group_1, group_2=group_2, obs_dim=obs_dim, map_name=map_name
            )
            return plan
        return None

    def _attack_obs(self, obs_ori, obs_mask_plan, remove_fn, group_1, group_2):
        """Return a masked copy of ``obs_ori``; ``obs_ori`` itself is not modified."""
        attacked = obs_ori.copy()
        if obs_mask_plan is not None:
            for observer_j, dims in obs_mask_plan.items():
                if dims:
                    attacked[observer_j, dims] = 0.0
        elif remove_fn is not None:
            attacked = self._apply_layout_removal(attacked, remove_fn, group_1, group_2)
        return attacked

    def _select_attack_actions(self, test_mode):
        """Query the MAC for actions and history and normalise them for storage.

        Returns:
            ``(ori_joint, actions_store, history_store)``. ``ori_joint`` is a long
            tensor indexed as ``[0, agent]``, ``actions_store`` has a trailing
            singleton dim, and ``history_store`` is float32 on ``args.device``
            with a leading size-1 batch dim dropped when 3-D.

        Raises:
            TypeError: If ``select_actions_attack`` does not return a 2-sequence.
        """
        ret = self.mac.select_actions_attack(
            self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode
        )
        if not (isinstance(ret, (tuple, list)) and len(ret) == 2):
            raise TypeError(
                "mac.select_actions_attack must return (actions, history), got "
                f"{type(ret).__name__}"
            )
        actions, history = ret

        if th.is_tensor(history):
            history_store = history.detach().to(device=self.args.device, dtype=th.float32)
        else:
            history_store = th.as_tensor(history, dtype=th.float32, device=self.args.device)
        if history_store.dim() == 3:
            history_store = history_store[0]

        if th.is_tensor(actions):
            actions_store = actions.clone()
            ori_joint = (actions.squeeze(-1) if actions.dim() == 3 else actions).long()
        else:
            ori_joint = th.as_tensor(
                np.array(actions).reshape(1, -1), device=self.args.device
            ).long()
            actions_store = ori_joint.unsqueeze(-1)
        if actions_store.dim() == 2:
            actions_store = actions_store.unsqueeze(-1)
        return ori_joint, actions_store, history_store

    def _perturb_actions(
        self, ori_joint, avail_actions, history_store, group_1,
        group_1_mask, group_2_mask, action_attack_pro,
    ):
        """Replace the actions of a random subset of ``group_1`` agents.

        Each ``group_1`` agent is attacked with probability ``action_attack_pro``.
        If ``args.action_random`` is set, an attacked agent gets a uniformly
        random available action. Otherwise it gets the action proposed by
        ``mi_model.select_min_mi_actions_per_agent``.

        Returns:
            The executed joint action on ``ori_joint``'s device.
        """
        device = th.device(self.args.device)
        exec_joint = ori_joint.clone().to(device)

        idx = th.tensor(group_1, device=device, dtype=th.long)
        if action_attack_pro >= 1.0:
            idx_apply = idx
        else:
            idx_apply = idx[th.rand(idx.numel(), device=device) < action_attack_pro]

        if idx_apply.numel() > 0:
            if self.args.action_random:
                avail_b = th.as_tensor(avail_actions, device=device).bool()
                for a_id in idx_apply.tolist():
                    legal = th.nonzero(avail_b[a_id], as_tuple=False).squeeze(-1)
                    if legal.numel() == 0:
                        continue  # no legal action: keep the original one
                    pick = legal[th.randint(0, legal.numel(), (1,), device=device)]
                    exec_joint[0, a_id] = pick.item()
            else:
                g1_mask_b = th.as_tensor(group_1_mask, device=device, dtype=th.uint8).view(1, -1)
                g2_mask_b = th.as_tensor(group_2_mask, device=device, dtype=th.uint8).view(1, -1)
                avail_b2 = th.as_tensor(avail_actions, device=device).unsqueeze(0)
                with th.no_grad():
                    mi_best_joint, _ = self.mi_model.select_min_mi_actions_per_agent(
                        history=history_store.unsqueeze(0),
                        g1_index_mask=g1_mask_b,
                        g2_index_mask=g2_mask_b,
                        base_actions_all=ori_joint.to(device),
                        avail_actions=avail_b2,
                        g2_action_mask=avail_b2,
                    )
                exec_joint[0, idx_apply] = mi_best_joint[0, idx_apply]

        return exec_joint.to(ori_joint.device)

    def run(
        self,
        fixed_K=0,
        eval_mode=False,
        test_mode=False,
        obs_attack=False,
        action_attack=False,
        remove_attack=False,
        action_attack_pro_min=None,
        action_attack_pro_max=None,
    ):
        """Play one episode, optionally under attack, and return its batch.

        Args:
            fixed_K: Size of ``group_1`` in ``eval_mode``. Outside eval mode,
                ``k`` is drawn uniformly from ``[0, min(args.K, n_agents)]``.
            eval_mode: Use ``fixed_K`` and skip all stats bookkeeping/logging.
            test_mode: Forwarded to the MAC; test episodes do not advance ``t_env``.
            obs_attack: Zero observation features exchanged between the groups.
            action_attack: Replace actions of some ``group_1`` agents.
            remove_attack: Remove ``fixed_K`` sampled agents on every step
                (requires ``eval_mode=True``).
            action_attack_pro_min, action_attack_pro_max: Optional per-``k``
                bounds for the sampled action-attack probability.

        Returns:
            ``(batch, battle_won)``. ``battle_won`` comes from the final step's
            ``env_info`` (0 if absent).

        Raises:
            ValueError: If ``remove_attack`` is requested outside ``eval_mode``.
            TypeError: If ``mac.select_actions_attack`` does not return
                ``(actions, history)``.
        """
        if remove_attack and not eval_mode:
            raise ValueError("remove_attack is only supported with eval_mode=True")

        self.reset()
        self.mac.init_hidden(batch_size=self.batch_size)

        n_agents = self.args.n_agents
        k, remove_agents, group_1, group_2 = self._sample_attack_groups(eval_mode, fixed_K)

        group_1_mask = np.zeros(n_agents, dtype=np.uint8)
        group_2_mask = np.zeros(n_agents, dtype=np.uint8)
        if group_1:
            group_1_mask[group_1] = 1
        if group_2:
            group_2_mask[group_2] = 1

        action_attack_pro = self._sample_action_attack_pro(
            k, action_attack, remove_attack, action_attack_pro_min, action_attack_pro_max
        )

        map_name = self.args.env_args.get("map_name")
        obs_attack_active = bool(obs_attack) and k > 0
        obs_mask_plan = None
        remove_fn = None
        if obs_attack_active:
            obs_mask_plan = self._build_obs_mask_plan(group_1, group_2, map_name)
            if obs_mask_plan is None:
                # Fall back to zeroing whole ally blocks via the map's layout.
                remove_fn = self._layout_remove_fn(map_name)

        terminated = False
        episode_return = 0.0
        env_info = {}

        while not terminated:
            if remove_attack:
                for agent_id in remove_agents:
                    self.env.remove_selected_agent(agent_id)

            avail_actions = np.asarray(self.env.get_avail_actions(), dtype=np.int64)
            obs_ori = np.asarray(self.env.get_obs(), dtype=np.float32)
            state_temp = np.asarray(self.env.get_state(), dtype=np.float32)

            attacked_obs = obs_ori
            if obs_attack_active:
                attacked_obs = self._attack_obs(
                    obs_ori, obs_mask_plan, remove_fn, group_1, group_2
                )

            pre_transition_data = {
                "state": [state_temp],
                "avail_actions": [avail_actions],
                "obs": [obs_ori],
                "attacked_obs": [attacked_obs],
                "group_1": [group_1_mask],
                "group_2": [group_2_mask],
            }
            self.batch.update(pre_transition_data, ts=self.t)

            ori_joint, actions_store, history_store = self._select_attack_actions(test_mode)

            exec_joint = ori_joint
            if action_attack and k > 0 and action_attack_pro > 0.0:
                exec_joint = self._perturb_actions(
                    ori_joint, avail_actions, history_store, group_1,
                    group_1_mask, group_2_mask, action_attack_pro,
                )

            reward, terminated, env_info = self.env.step(exec_joint[0])
            episode_return += float(reward)

            # The batch stores the policy's (pre-attack) actions.
            post_transition_data = {
                "actions": actions_store,
                "history": history_store,
                "reward": [(reward,)],
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
            }
            self.batch.update(post_transition_data, ts=self.t)

            self.t += 1

        last_data = {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()],
        }
        self.batch.update(last_data, ts=self.t)

        actions_last = self.mac.select_actions(
            self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode
        )
        self.batch.update({"actions": actions_last}, ts=self.t)

        if not eval_mode:
            attacked = bool(obs_attack or action_attack or remove_attack)
            self._record_episode(env_info, episode_return, test_mode, attacked)

        return self.batch, env_info.get("battle_won", 0)

    def _record_episode(self, env_info, episode_return, test_mode, attacked):
        """Accumulate episode stats/returns and log them when due.

        Unattacked episodes go to the ``*_no_attack`` buffers (prefix
        ``"test_no_attack_"`` in test mode, ``""`` otherwise). Attacked episodes
        go to the test or train buffers. Test stats are flushed once their own
        buffer holds ``args.test_nepisode`` returns. Otherwise stats are flushed
        every ``args.runner_log_interval`` env steps.
        """
        if not attacked:
            cur_stats = self.test_stats_no_attack
            cur_returns = self.test_returns_no_attack
            log_prefix = "test_no_attack_" if test_mode else ""
        elif test_mode:
            cur_stats, cur_returns, log_prefix = self.test_stats, self.test_returns, "test_"
        else:
            cur_stats, cur_returns, log_prefix = self.train_stats, self.train_returns, ""

        for key in set(cur_stats) | set(env_info):
            cur_stats[key] = cur_stats.get(key, 0) + env_info.get(key, 0)
        cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0)

        if not test_mode:
            self.t_env += self.t

        cur_returns.append(episode_return)

        # Compare against the buffer this episode was written to, so the
        # no-attack test buffer is also flushed after test_nepisode episodes.
        if test_mode and len(cur_returns) == self.args.test_nepisode:
            self._log(cur_returns, cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

    def _log(self, returns, stats, prefix):
        """Log return mean/std and per-episode stat means, then clear both buffers."""
        self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env)
        self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env)
        returns.clear()

        n_eps = stats.get("n_episodes", 1)
        for key, value in stats.items():
            if key != "n_episodes":
                self.logger.log_stat(prefix + key + "_mean", value / n_eps, self.t_env)
        stats.clear()