from ast import arg
from copy import deepcopy
from modules.agents import REGISTRY as agent_REGISTRY
import torch as th
import numpy as np
from controllers import REGISTRY as mac_REGISTRY
import yaml
import os
import collections
import argparse
import wandb


def recursive_dict_update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place and return ``d``.

    Nested mappings in ``u`` are merged key-by-key into the matching entry of
    ``d``; every other value in ``u`` simply overwrites ``d``'s entry.

    Args:
        d: dict to update. It is mutated in place, including any nested
            mutable mappings it already holds.
        u: mapping of overrides.

    Returns:
        ``d`` itself, after the merge.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            base = d.get(key)
            if not isinstance(base, collections.abc.MutableMapping):
                # Missing, non-mapping (None, scalar, ...) or read-only entry:
                # merge into a fresh dict (a copy for read-only mappings) so the
                # recursive call never attempts item assignment on a non-dict.
                base = dict(base) if isinstance(base, collections.abc.Mapping) else {}
            d[key] = recursive_dict_update(base, value)
        else:
            d[key] = value
    return d


def load_config(config_name, subfolder):
    """Load ``config/<subfolder>/<config_name>.yaml``.

    The ``config`` directory is resolved relative to the parent of the
    directory containing this file. The C-accelerated ``yaml.CLoader`` is used
    when the installed PyYAML provides it, otherwise ``yaml.Loader``.

    NOTE(review): ``yaml.Loader``/``CLoader`` can construct arbitrary Python
    objects. That is only acceptable because these are the project's own
    config files; switch to ``yaml.safe_load`` if the path could ever point
    at untrusted input.

    Args:
        config_name: file stem of the YAML config.
        subfolder: sub-directory of ``config`` to look in (e.g. ``"algs"``).

    Returns:
        The parsed YAML document, or an empty dict if the file is empty.

    Raises:
        AssertionError: if the file is not valid YAML. It is raised
            explicitly so the check survives ``python -O``.
        OSError: if the file cannot be opened.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "config",
        subfolder,
        "{}.yaml".format(config_name),
    )
    # Pick the loader up front so parse errors from either loader are
    # reported the same way.
    loader = getattr(yaml, "CLoader", yaml.Loader)
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            config_dict = yaml.load(f, Loader=loader)
        except yaml.YAMLError as exc:
            raise AssertionError("{}.yaml error: {}".format(config_name, exc)) from exc

    # An empty YAML document parses to None; callers merge the result as a dict.
    return {} if config_dict is None else config_dict

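# Meta multi-agent controller: combines one defender controller with a bounded
# population of attacker controllers whose actions perturb the defender's observations.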


class MetaMAC:
    """Controller pairing a defender MAC with a population of attacker MACs.

    ``controller_dict["defender"]`` is a single controller built from the
    global args overridden by ``config/algs/<args.defender>.yaml``.

    ``controller_dict["attacker"]`` is a bounded deque. Its index 0 is the
    live attacker, built from ``config/algs/<args.attacker>.yaml``. The
    remaining entries are deep-copied snapshots added by
    :meth:`add_test_attacker`.
    """

    # Maximum number of attacker controllers kept (live attacker + snapshots).
    ATTACKER_POPULATION_SIZE = 10

    def __init__(self, scheme, groups, args):
        self.args = args
        self.scheme = scheme
        self.groups = groups

        self.attacker_args = self._build_sub_args(args.attacker)
        self.defender_args = self._build_sub_args(args.defender)
        # attacker can have population; index 0 is always the live attacker
        self.controller_dict = {
            "defender": mac_REGISTRY[self.defender_args.mac](scheme, groups, self.defender_args),
            "attacker": collections.deque(maxlen=self.ATTACKER_POPULATION_SIZE),
        }
        self.controller_dict["attacker"].append(
            mac_REGISTRY[self.attacker_args.mac](scheme, groups, self.attacker_args))
        self.chosen_attacker_idx = 0

    def _build_sub_args(self, alg_name):
        """Return ``self.args`` overridden by ``config/algs/<alg_name>.yaml`` as a Namespace.

        The global args are deep-copied first so the overrides never leak back
        into ``self.args``. When ``args.use_wandb`` is set, the raw override
        dict is also recorded under ``wandb.config[alg_name]``.
        """
        sub_args = deepcopy(vars(self.args))
        config_dict = load_config(alg_name, "algs")
        if self.args.use_wandb:
            wandb.config[alg_name] = config_dict
        sub_args = recursive_dict_update(sub_args, config_dict)
        return argparse.Namespace(**sub_args)

    def init_hidden(self, batch_size):
        """Pick the attacker for the coming episode and reset the defender's hidden state.

        When training defenders, an attacker is drawn uniformly from the whole
        population (including the live one at index 0).
        NOTE(review): only the defender's ``init_hidden`` is called; attacker
        controllers are presumably non-recurrent — confirm.
        """
        self.chosen_attacker_idx = np.random.randint(len(self.controller_dict["attacker"]))
        self.controller_dict["defender"].init_hidden(batch_size)

    def select_actions(self, ep_batch, t_ep, t_env, pretrain=False, bs=slice(None), test_mode=False):
        """Let the attacker perturb ``ep_batch["obs"]`` at step ``t_ep``, then query the defender.

        Training alternates every ``args.switch_interval`` env steps. During
        attacker phases (which end for good at ``args.attacker_stop``), the
        live attacker acts with the caller's ``test_mode`` and the defender
        runs with ``test_mode=True``. During defender phases, the attacker
        chosen in :meth:`init_hidden` runs with ``test_mode=True`` and the
        defender uses the caller's ``test_mode``.

        Side effect: ``ep_batch["obs"][bs, t_ep]`` is modified in place by
        adding the perturbation to the observation of the targeted agent.

        Returns:
            ``(chosen_actions, perturb_agent_id, perturb_action, discrete_emb,
            parameter_emb)``. ``perturb_action`` has its "no attack" rows zeroed.
        """
        n_agents = self.groups["agents"]
        attacker_learn = (t_env // self.args.switch_interval) % 2 == 0 and t_env < self.args.attacker_stop

        if attacker_learn:
            attacker = self.controller_dict["attacker"][0]
            attacker_test_mode = test_mode
        else:
            # randomly selected history attacker (see init_hidden)
            attacker = self.controller_dict["attacker"][self.chosen_attacker_idx]
            attacker_test_mode = True
        perturb_agent_id, perturb_action, discrete_emb, parameter_emb = attacker.select_actions(
            ep_batch, t_ep, t_env, pretrain, bs=bs, test_mode=attacker_test_mode)

        # NOTE: broadcast. perturb_agent_id >= n_agents means "no attack": zero its perturbation.
        perturb_action[perturb_agent_id.view(-1) >= n_agents] = 0
        # Clamp so "no attack" rows still index a valid agent (they add zeros).
        perturb_agent_id_clip = perturb_agent_id.clamp(0, n_agents - 1)
        # assumes perturb_agent_id is (batch, 1) and perturb_action is
        # (batch, obs_dim), as implied by the repeat/view below — TODO confirm
        batch_size = perturb_action.size(0)
        idx = perturb_agent_id_clip.repeat(1, perturb_action.size(-1)).view(batch_size, 1, -1)
        val = perturb_action.view(batch_size, 1, -1)
        obs = ep_batch["obs"][bs, t_ep, :]
        # Assign back explicitly: when `bs` is an index list, the slice is a copy.
        ep_batch["obs"][bs, t_ep, :] = obs.scatter_add_(1, idx, val)

        # The side that is not learning in this phase runs with test_mode=True.
        defender_test_mode = True if attacker_learn else test_mode
        defender = self.controller_dict["defender"]
        if vars(self.args).get("defense_loaded", False):
            # with defense_loaded the defender returns a pair; the second item is discarded
            chosen_actions, _ = defender.select_actions(
                ep_batch, t_ep, t_env, bs=bs, test_mode=defender_test_mode)
        else:
            chosen_actions = defender.select_actions(
                ep_batch, t_ep, t_env, bs=bs, test_mode=defender_test_mode)
        return chosen_actions, perturb_agent_id, perturb_action, discrete_emb, parameter_emb

    def add_test_attacker(self):
        """Append a deep copy of the live attacker (index 0) to the population.

        The population deque is bounded. A plain ``append`` on a full deque
        would discard index 0, i.e. the live attacker. When full, the oldest
        snapshot (index 1) is dropped instead, so the live attacker is kept.
        """
        population = self.controller_dict["attacker"]
        if population.maxlen is not None and len(population) >= population.maxlen and len(population) > 1:
            del population[1]
        population.append(deepcopy(population[0]))

    def reset_attacker_exploration(self):
        """Delegate to the live attacker's ``reset_attacker_exploration``."""
        self.controller_dict["attacker"][0].reset_attacker_exploration()

    def reinit(self):
        """Replace the live attacker (index 0) with a freshly constructed controller."""
        self.controller_dict['attacker'][0] = mac_REGISTRY[self.attacker_args.mac](self.scheme, self.groups, self.attacker_args)
