from ast import arg
from copy import deepcopy
from modules.agents import REGISTRY as agent_REGISTRY
import torch as th
import numpy as np
from controllers import REGISTRY as mac_REGISTRY
import yaml
import os
import collections
import argparse
import wandb

def recursive_dict_update(d, u):
    """Deep-merge the mapping ``u`` into ``d`` in place and return ``d``.

    For every key in ``u``:

    * a Mapping value is merged recursively into ``d[key]``. If ``d`` has no
      entry for the key, or its entry is not a Mapping (e.g. a YAML ``null``
      or a scalar), a fresh dict is used as the merge target;
    * any other value simply overwrites ``d[key]``.

    Args:
        d: Mutable mapping receiving the updates (modified in place).
        u: Mapping of overrides.

    Returns:
        ``d`` itself, so calls can be chained.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            target = d.get(key)
            if not isinstance(target, collections.abc.Mapping):
                # A non-mapping base value cannot absorb nested keys (the
                # recursive call would fail on item assignment); the override
                # mapping replaces it instead.
                target = {}
            d[key] = recursive_dict_update(target, value)
        else:
            d[key] = value
    return d


def load_config(config_name, subfolder):
    """Load and return ``config/<subfolder>/<config_name>.yaml``.

    The ``config`` directory is resolved relative to the parent of the
    directory containing this source file.

    Args:
        config_name: File stem of the YAML file (without ``.yaml``).
        subfolder: Sub-directory of ``config`` to look in (e.g. ``"algs"``).

    Returns:
        Whatever ``yaml.load`` produces for the file (typically a dict).

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
        AssertionError: If the file is not valid YAML. It is raised
            explicitly rather than via ``assert`` so it survives ``python -O``.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "config", subfolder, "{}.yaml".format(config_name))
    # Prefer the libyaml-backed C loader; fall back to the pure-Python loader
    # when PyYAML was built without libyaml. Resolving this up front means a
    # parse error from either loader is handled by the same except clause.
    # NOTE: (C)Loader can construct arbitrary Python objects, so only use this
    # on trusted, repository-local config files.
    loader = getattr(yaml, "CLoader", yaml.Loader)
    with open(config_path, "r", encoding="utf-8") as f:
        try:
            config_dict = yaml.load(f, Loader=loader)
        except yaml.YAMLError as exc:
            raise AssertionError(
                "{}.yaml error: {}".format(config_name, exc)) from exc

    return config_dict

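# Evaluation-time controller: perturbs agents' observations, then delegates action selection to a wrapped defender MAC (any parameter sharing happens inside that defender).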


class DefenderMAC:
    """Evaluation controller that perturbs observations before a defender acts.

    ``__init__`` reads ``config/algs/eval.yaml`` into ``self.eval_args``. It
    then builds the defender MAC named by ``eval_args.defender`` from the
    experiment ``args`` overridden by ``config/algs/<defender>.yaml``.

    On each call, :meth:`select_actions` does three things:

    * samples an observation perturbation for at most one agent per batch row;
    * adds it in place to ``ep_batch['obs']`` at time step ``t_ep``;
    * returns the defender's actions on the perturbed observations.
    """

    def __init__(self, scheme, groups, args):
        self.args = args
        self.scheme = scheme
        self.groups = groups

        # Kept for attribute compatibility; no attacker config is merged here.
        self.attacker_args = deepcopy(vars(args))

        # Evaluation settings (defender name, perturbation budget, embedding dims).
        eval_config_dict = load_config('eval', 'algs')
        # NOTE(review): this was previously merged into a deep copy of itself,
        # which is a no-op. If eval_args was meant to inherit fields from
        # ``args``, start from deepcopy(vars(args)) instead -- TODO confirm.
        self.eval_args = argparse.Namespace(**deepcopy(eval_config_dict))

        if args.use_wandb:
            config = wandb.config
            config['def_eval'] = eval_config_dict

        # Defender args = experiment args overridden by the defender's alg config.
        config_dict = load_config(self.eval_args.defender, "algs")
        if args.use_wandb:
            config[self.eval_args.defender] = config_dict
        self.defender_args = argparse.Namespace(
            **recursive_dict_update(deepcopy(vars(args)), config_dict))
        self.defender = mac_REGISTRY[self.defender_args.mac](
            scheme, groups, self.defender_args)

    def add_attacker(self, path):
        # NOTE(review): ``self.attacker_list`` is never initialised by this
        # class, so this raises AttributeError unless a caller sets it first.
        # ``path`` is ignored. Behaviour left unchanged pending confirmation.
        self.attacker_list.append(self.attacker_list[0])

    def init_hidden(self, batch_size):
        """Reset the wrapped defender's recurrent hidden state."""
        self.defender.init_hidden(batch_size)

    def _uniform_noise(self, ep_batch, bs):
        """Draw uniform noise in [-1, 1) shaped (batch_size, obs_dim), then index rows by ``bs``."""
        noise = th.rand((ep_batch.batch_size, ep_batch['obs'].shape[-1]))
        return ((noise.to(ep_batch.device) - 0.5) * 2)[bs]

    def _scale_to_budget(self, perturb_action):
        """Rescale each row to L1 norm ``eval_args.perturbation_range``."""
        # keepdim=True yields a (rows, 1) norm that broadcasts across features.
        l1 = th.linalg.norm(perturb_action, ord=1, dim=1, keepdim=True)
        return perturb_action / l1 * self.eval_args.perturbation_range

    def select_actions(self, ep_batch, t_ep, t_env, pretrain=False, bs=slice(None), test_mode=False, attack_mode=0, attack_obs_dim=0, print_attack_obs=False):
        """Perturb the observations at ``t_ep`` in place and query the defender.

        attack_mode:
            * 0 / any unlisted value: target id uniform in [0, n_agents], where
              n_agents means "no attack". The noise is uniform in [-1, 1),
              clamped to +-perturbation_range.
            * 1: always agent 0; random direction with L1 norm perturbation_range.
            * 2: a uniform agent in [0, n_agents) (always attacks); random
              direction with L1 norm perturbation_range.
            * 3: always agent 0; the full budget is placed on observation feature 4.

        ``pretrain``, ``test_mode``, ``attack_obs_dim`` and ``print_attack_obs``
        are accepted for interface compatibility but are not used. The
        defender is always queried with ``test_mode=True``.

        Returns:
            ``(chosen_actions, perturb_agent_id, perturb_action, discrete_emb,
            parameter_emb)``. The two embeddings are zero tensors.
        """
        batch_size = ep_batch.batch_size
        device = ep_batch.device
        obs_dim = ep_batch['obs'].shape[-1]

        # Default draw (used as-is for attack_mode 0).
        perturb_agent_id = th.randint(
            0, self.args.n_agents + 1, (batch_size, 1)).to(device)[bs]
        perturb_action = self._uniform_noise(ep_batch, bs)
        discrete_emb = th.zeros(batch_size, self.eval_args.discrete_action_dim).to(device)[bs]
        parameter_emb = th.zeros(batch_size, self.eval_args.parameter_action_dim).to(device)[bs]
        # This draw is otherwise unused. It is kept so numpy's global RNG
        # stream matches earlier runs.
        np.random.randint(0, perturb_action.size()[1])

        if attack_mode == 1:
            # randint(0, 1) is always 0, i.e. target agent 0. This was
            # previously assigned to an unused name, so the random default id
            # leaked through.
            perturb_agent_id = th.randint(0, 1, (batch_size, 1)).to(device)[bs]
            perturb_action = self._scale_to_budget(self._uniform_noise(ep_batch, bs))
        elif attack_mode == 2:
            # A real agent every time: the "no attack" id n_agents is excluded.
            perturb_agent_id = th.randint(
                0, self.args.n_agents, (batch_size, 1)).to(device)[bs]
            perturb_action = self._scale_to_budget(self._uniform_noise(ep_batch, bs))
        elif attack_mode == 3:
            perturb_agent_id = th.randint(0, 1, (batch_size, 1)).to(device)[bs]
            perturb_action = th.zeros((batch_size, obs_dim)).to(device)[bs]
            perturb_action[:, 4] = 1.0
            perturb_action = self._scale_to_budget(perturb_action)

        # Ids >= number of agents mean "no attack": zero those rows, then clamp
        # the ids so they remain valid scatter indices.
        n_agents = self.groups["agents"]
        perturb_action[perturb_agent_id.view(-1) >= n_agents] = 0
        target_agent = perturb_agent_id.clamp(0, n_agents - 1)
        perturb_action = perturb_action.clamp(
            -self.eval_args.perturbation_range, self.eval_args.perturbation_range)

        # Add each row's perturbation to its target agent's observation along
        # the agent axis (dim 1 of obs[bs, t_ep]).
        n_rows = perturb_action.size(0)
        idx = target_agent.repeat(1, perturb_action.size(-1)).view(n_rows, 1, -1)
        val = perturb_action.view(n_rows, 1, -1)
        obs_t = ep_batch['obs'][bs, t_ep, :]
        # The explicit write-back matters when ``bs`` triggers advanced
        # indexing, in which case ``obs_t`` is a copy rather than a view.
        ep_batch['obs'][bs, t_ep, :] = obs_t.scatter_add_(1, idx, val).to(device)

        chosen_actions = self.defender.select_actions(
            ep_batch, t_ep, t_env, bs=bs, test_mode=True)
        if getattr(self.args, "defense_loaded", False):
            chosen_actions = chosen_actions[0]
        return chosen_actions, perturb_agent_id, perturb_action, discrete_emb, parameter_emb
