from ast import arg
from copy import deepcopy
from modules.agents import REGISTRY as agent_REGISTRY
import torch as th
import numpy as np
from controllers import REGISTRY as mac_REGISTRY
import yaml
import os
import collections
import argparse
import wandb
from modules.trajectory_encoder import REGISTRY as enc_REGISTRY
from modules.decoder import REGISTRY as dec_REGISTRY

def recursive_dict_update(d, u):
    """Recursively merge mapping ``u`` into dict ``d`` in place and return ``d``.

    Values of ``u`` that are mappings are merged key by key into the
    corresponding entry of ``d``; every other value overwrites the entry.

    Args:
        d: dict that is updated in place.
        u: mapping whose entries take precedence over those in ``d``.

    Returns:
        The same ``d`` object, after the merge.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            if not isinstance(existing, collections.abc.Mapping):
                # Missing key, or a scalar/None placeholder being overridden
                # by a nested section: there is nothing to merge into, so
                # start from an empty dict instead of recursing into a
                # non-mapping (which would raise).
                existing = {}
            d[k] = recursive_dict_update(existing, v)
        else:
            d[k] = v
    return d


def load_config(config_name, subfolder):
    """Load ``config/<subfolder>/<config_name>.yaml`` and return its parsed content.

    The ``config`` directory is resolved relative to the parent of the
    directory that contains this file.

    Args:
        config_name: file name of the config, without the ``.yaml`` suffix.
        subfolder: sub-directory of ``config`` that holds the file.

    Returns:
        Whatever the YAML document parses to.

    Raises:
        AssertionError: if the file is not valid YAML.
        OSError: if the file cannot be opened.
    """
    config_path = os.path.join(
        os.path.dirname(os.path.dirname(__file__)),
        "config", subfolder, "{}.yaml".format(config_name))
    # Prefer the libyaml-backed loader; PyYAML only exposes ``CLoader`` when it
    # was built with libyaml. Picking the loader up front means parse errors
    # from either loader go through the same handler below.
    # NOTE(review): Loader/CLoader can construct arbitrary Python objects, so
    # this must only be pointed at trusted config files.
    loader = getattr(yaml, "CLoader", yaml.Loader)
    with open(config_path, "r") as f:
        try:
            config_dict = yaml.load(f, Loader=loader)
        except yaml.YAMLError as exc:
            # Raised explicitly (instead of ``assert False``) so the failure
            # is still reported when Python runs with -O.
            raise AssertionError(
                "{}.yaml error: {}".format(config_name, exc)) from exc

    return config_dict

# This multi-agent controller shares parameters between agents


class PopulationMAC:
    """Controller that pairs one defender controller with a population of attackers.

    On every call to :meth:`select_actions` the currently selected attacker
    proposes a perturbation for one agent, the perturbation is added to the
    stored observations, and the defender then chooses actions from the
    perturbed observations. The class also owns a trajectory encoder/decoder
    pair together with the optimiser that trains them.
    """

    # Once ``t_env`` is within this many environment steps of ``args.t_max``
    # the attacker is no longer treated as learning (it is run in test mode).
    _ATTACKER_FREEZE_STEPS = 1000000

    def __init__(self, scheme, groups, args):
        """Build the defender, the attacker population and the trajectory autoencoder.

        Args:
            scheme: passed unchanged to every underlying controller.
            groups: passed unchanged to every underlying controller;
                ``groups["agents"]`` is used as the number of agents.
            args: run configuration namespace. ``attacker`` and ``defender``
                name the algorithm configs merged on top of it for each role.
        """
        self.args = args
        self.scheme = scheme
        self.groups = groups

        self.attacker_args = self._build_role_args(args.attacker)
        self.defender_args = self._build_role_args(args.defender)
        # Only the attacker side is a population; there is a single defender.
        self.controller_dict = {
            "defender": mac_REGISTRY[self.defender_args.mac](scheme, groups, self.defender_args),
            "attacker": [
                mac_REGISTRY[self.attacker_args.mac](scheme, groups, self.attacker_args)
                for _ in range(self.args.n_attackers)
            ],
        }
        self.chosen_attacker_idx = 0
        self.reproduction_attackers = []
        self._build_trajectory_autoencoder(self.args)
        # Encoder and decoder are trained jointly by a single optimiser.
        self.enc_params = list(self.traj_encoder.parameters()) + list(self.traj_decoder.parameters())
        self.enc_optimiser = th.optim.Adam(params=self.enc_params, lr=self.args.lr)

    def _build_role_args(self, config_name):
        """Return a copy of ``self.args`` overlaid with the ``algs`` config ``config_name``."""
        role_config = load_config(config_name, "algs")
        if self.args.use_wandb:
            # Record the role-specific config under its own key in the run config.
            wandb.config[config_name] = role_config
        merged = recursive_dict_update(deepcopy(vars(self.args)), role_config)
        return argparse.Namespace(**merged)

    def set_attacker(self, attacker_idx):
        """Select which attacker :meth:`select_actions` uses.

        Indices below ``args.n_attackers`` address the base population; larger
        indices address ``reproduction_attackers`` (offset by ``n_attackers``).
        """
        self.chosen_attacker_idx = attacker_idx

    def init_hidden(self, batch_size):
        """Initialise the defender controller's hidden state (attackers are not touched here)."""
        self.controller_dict["defender"].init_hidden(batch_size)

    def select_actions(self, ep_batch, t_ep, t_env, pretrain=False, bs=slice(None), test_mode=False):
        """Perturb the observations at ``t_ep`` with the chosen attacker, then let the defender act.

        Attacker and defender alternate: within each window of
        ``attacker_train + defender_train`` environment steps the first
        ``attacker_train`` steps are the attacker's turn, unless ``t_env`` has
        reached ``t_max - 1000000``. Whichever side is not on its turn is
        called with ``test_mode=True``; the other side receives the caller's
        ``test_mode``.

        Side effects: ``ep_batch['obs'][bs, t_ep, :]`` is modified in place,
        and rows of the attacker's ``perturb_action`` that correspond to
        "no attack" are zeroed in place.

        Returns:
            Tuple ``(chosen_actions, perturb_agent_id, perturb_action,
            discrete_emb, parameter_emb)``: the defender's actions followed by
            the four values returned by the attacker.
        """
        cycle_length = self.args.attacker_train + self.args.defender_train
        attacker_learn = bool(
            t_env % cycle_length < self.args.attacker_train
            and t_env < self.args.t_max - self._ATTACKER_FREEZE_STEPS
        )

        n_base = self.args.n_attackers
        if self.chosen_attacker_idx < n_base:
            attacker = self.controller_dict["attacker"][self.chosen_attacker_idx]
        else:
            # Indices past the base population address the reproduced copies.
            attacker = self.reproduction_attackers[self.chosen_attacker_idx - n_base]
        perturb_agent_id, perturb_action, discrete_emb, parameter_emb = attacker.select_actions(
            ep_batch, t_ep, t_env, pretrain, bs=bs,
            test_mode=test_mode if attacker_learn else True)

        n_agents = self.groups["agents"]
        # NOTE: perturb_agent_id >= n_agents means "no attack"; the boolean
        # mask selects those rows of perturb_action and zeroes them.
        perturb_action[perturb_agent_id.view(-1) >= n_agents] = 0
        # Clamp so "no attack" ids still index a valid agent; their
        # perturbation is already zero, so adding it changes nothing.
        perturb_agent_id_clip = perturb_agent_id.clamp(0, n_agents - 1)
        # NOTE: perturb_action is not clamped to
        # attacker_args.perturbation_range here.
        n_rows = perturb_action.size()[0]
        idx = perturb_agent_id_clip.repeat(1, perturb_action.size()[-1]).view(n_rows, 1, -1)
        val = perturb_action.view(n_rows, 1, -1)
        # Add each row's perturbation to the observation of the targeted agent
        # (dim 1 of the sliced observations is indexed by agent id).
        # NOTE(review): recent PyTorch releases document ``reduce=`` with a
        # tensor source as deprecated -- confirm before upgrading.
        ep_batch['obs'][bs, t_ep, :] = ep_batch['obs'][bs, t_ep, :].scatter_(
            1, idx, val, reduce="add")

        # While the attacker is on its turn the defender runs in test mode.
        chosen_actions = self.controller_dict["defender"].select_actions(
            ep_batch, t_ep, t_env, bs=bs,
            test_mode=True if attacker_learn else test_mode)
        return chosen_actions, perturb_agent_id, perturb_action, discrete_emb, parameter_emb

    def init_reproduction_attackers(self):
        """Replace ``reproduction_attackers`` with deep copies of randomly chosen base attackers.

        ``int(n_attackers * reproduction_ratio)`` distinct attackers are
        sampled without replacement.
        """
        n_selected = int(self.args.n_attackers * self.args.reproduction_ratio)
        selected_attackers_id = np.random.choice(
            range(self.args.n_attackers), n_selected, replace=False)
        self.reproduction_attackers = [
            deepcopy(self.controller_dict["attacker"][idx])
            for idx in selected_attackers_id
        ]

    def reset_attacker_exploration(self, idx):
        """Reset exploration of reproduced attacker ``idx`` (an index into ``reproduction_attackers``)."""
        self.reproduction_attackers[idx].reset_attacker_exploration()

    def reinit(self, idx):
        """Replace base attacker ``idx`` with a freshly constructed controller."""
        self.controller_dict['attacker'][idx] = mac_REGISTRY[self.attacker_args.mac](
            self.scheme, self.groups, self.attacker_args)

    def decode_forward(self, encoded_z, state_input, hidden_state):
        """Run the trajectory decoder and return its output unchanged."""
        return self.traj_decoder(encoded_z, state_input, hidden_state)

    def encode_forward(self, trajectory_input, trajectory_mask):
        """Run the trajectory encoder and return its output unchanged."""
        return self.traj_encoder(trajectory_input, trajectory_mask)

    def save_attackers(self, path):
        """Ask every base attacker to save its models under ``path``."""
        # NOTE(review): every attacker receives the same ``path``; confirm that
        # ``save_models`` writes distinct files per attacker.
        for attacker in self.controller_dict['attacker']:
            attacker.save_models(path)

    def train_traj_encoder(self, batch):
        """Run one optimisation step of the trajectory encoder/decoder on ``batch``.

        The encoder embeds the state sequence; the decoder is then stepped
        through time, predicting the next state from the embedding, the
        current state and its hidden state. The loss is the masked mean
        Euclidean distance between predicted and actual next states.

        Returns:
            Tuple ``(loss, grad_norm)``: the loss as a Python float and the
            value returned by ``clip_grad_norm_``.
        """
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        # A step only counts if the episode had not terminated one step earlier.
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        states = batch["state"]
        encoded_z = self.encode_forward(states[:, :-1], mask)

        dec_hidden = self.traj_decoder.init_hidden().repeat(batch.batch_size, 1)
        pred_states = []
        for t in range(batch.max_seq_length - 1):
            pred_state, dec_hidden = self.decode_forward(encoded_z, states[:, t], dec_hidden)
            pred_states.append(pred_state)
        pred_states = th.stack(pred_states, dim=1)  # [bs, seq_len - 1, state_dim]
        # The prediction made at step t is scored against the state at t + 1.
        target_states = states[:, 1:]

        # Per-step Euclidean distance between prediction and target.
        pred_error = th.sqrt(th.sum((pred_states - target_states) ** 2, dim=-1, keepdim=True))

        # Masked mean; the same expanded mask is used for numerator and denominator.
        pred_mask = mask.expand_as(pred_error)
        pred_loss = (pred_error * pred_mask).sum() / pred_mask.sum()

        self.enc_optimiser.zero_grad()
        pred_loss.backward()
        pred_grad_norm = th.nn.utils.clip_grad_norm_(self.enc_params, self.args.grad_norm_clip)
        self.enc_optimiser.step()
        return pred_loss.item(), pred_grad_norm

    def _build_trajectory_autoencoder(self, args):
        """Instantiate the trajectory encoder and decoder named in ``args``; move them to CUDA if enabled."""
        self.traj_encoder = enc_REGISTRY[args.traj_encoder](args)
        self.traj_decoder = dec_REGISTRY[args.traj_decoder](args)
        if self.args.use_cuda:
            self.traj_encoder.cuda()
            self.traj_decoder.cuda()
