from modules.agents import REGISTRY as agent_REGISTRY
import torch as th
import numpy as np
from adversarial import P_TD3_relable
from adversarial import ActionRepresentation_vae
from copy import deepcopy
import pickle

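# Controller for the adversarial attacker: a single TD3 policy that acts in the
# latent space of an action-representation VAE and outputs a hybrid
# (discrete + continuous) attack action.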


class AttackMAC:
    """Controller for the adversarial attacker.

    Wraps a single TD3 policy (``P_TD3_relable.TD3``) that acts in the latent
    space of an action-representation VAE (``Action_representation``).  From
    the global ``"state"`` field of the batch it produces a hybrid action: a
    discrete choice (``n_agents + 1`` options, or one option when
    ``args.attack_agent_0`` is set) and a continuous vector shaped like
    ``scheme["obs"]["vshape"]``, rescaled to L1 norm
    ``args.perturbation_range``.
    """

    def __init__(self, scheme, groups, args):
        self.args = args
        input_shape = self._get_input_shape(scheme)
        output_shape_discrete = self._get_output_shape_discrete(scheme)
        output_shape_continuous = self._get_output_shape_continuous(scheme)
        self._build_agents(input_shape, output_shape_discrete,
                           output_shape_continuous)

        # Environment steps since the last exploration reset; drives the
        # exploration-noise schedule used in ``forward``.
        self.steps = 0
        # ``t_env`` at which the exploration schedule was last restarted.
        self.last_reset_attacker = 0
        self.hidden_states = None
        # Boundaries consumed by ``true_parameter_action``; assigned
        # externally (presumably from VAE training) before ``forward`` runs.
        self.c_rate = None
        # Forwarded to ``agent.train``; assigned externally.
        self.recon_s = None

    def select_actions(self, ep_batch, t_ep, t_env, pretrain=False, bs=slice(None), test_mode=False):
        """Select attacker actions for timestep ``t_ep`` of ``ep_batch``.

        Args:
            ep_batch: episode batch exposing ``batch_size``, ``device`` and the
                ``"state"`` / ``"obs"`` fields.
            t_ep: timestep within the episode.
            t_env: total environment steps; drives the exploration schedule.
            pretrain: if True, sample uniformly random actions instead of
                querying the policy; the embeddings are then all zeros.
            bs: which batch elements to return.
            test_mode: disables exploration noise.

        Returns:
            ``(discrete_action, parameter_action, discrete_emb, parameter_emb)``,
            each restricted to the batch elements in ``bs``.  Every row of
            ``parameter_action`` is rescaled to L1 norm
            ``args.perturbation_range`` (an all-zero row stays zero).
        """
        self.steps = t_env - self.last_reset_attacker
        device = ep_batch.device
        batch_size = ep_batch.batch_size
        if pretrain:
            n_discrete = 1 if self.args.attack_agent_0 else self.args.n_agents + 1
            discrete_action = th.randint(0, n_discrete, (batch_size, 1)).to(device)
            # Uniform in [-1, 1); the magnitude is set by the L1 rescaling below.
            parameter_action = (th.rand((batch_size, ep_batch['obs'].shape[-1])).to(device) - 0.5) * 2
            # Built on ``device`` so all four outputs share one device.
            discrete_emb = th.zeros(batch_size, self.args.ad_policy_dict['discrete_action_dim'], device=device)
            parameter_emb = th.zeros(batch_size, self.args.ad_policy_dict['parameter_action_dim'], device=device)
        else:
            discrete_action, parameter_action, discrete_emb, parameter_emb = self.forward(
                ep_batch, t_ep, test_mode=test_mode
            )
            discrete_action = discrete_action.reshape(batch_size, -1)
            parameter_action = parameter_action.reshape(batch_size, -1)
            discrete_emb = discrete_emb.reshape(batch_size, -1)
            parameter_emb = parameter_emb.reshape(batch_size, -1)
        # Rescale each row to L1 norm ``perturbation_range``.  ``keepdim``
        # broadcasts the norm across columns; the clamp turns 0/0 (NaN) for an
        # all-zero row into 0 without affecting any non-degenerate row.
        l1_norm = th.linalg.norm(parameter_action, ord=1, dim=1, keepdim=True)
        l1_norm = l1_norm.clamp_min(th.finfo(parameter_action.dtype).tiny)
        parameter_action = parameter_action / l1_norm * self.args.perturbation_range
        return discrete_action[bs], parameter_action[bs], discrete_emb[bs], parameter_emb[bs]

    def train_agent(self, replay_buffer, batch_size):
        """Train the TD3 agent on ``replay_buffer``.

        ``self.vae``, ``self.c_rate`` and ``self.recon_s`` are forwarded to
        ``agent.train``; only the last two of its four return values are kept.

        Returns:
            ``(critic_loss, actor_loss)``.
        """
        _, _, critic_loss, actor_loss = self.agent.train(replay_buffer, self.vae, self.c_rate, self.recon_s, batch_size)
        return critic_loss, actor_loss

    def _exploration_epsilon(self, test_mode=False):
        """Return the exploration-noise scale for the current ``self.steps``.

        Decays linearly from ``args.expl_noise_initial`` to ``args.expl_noise``
        over ``args.epsilon_steps`` steps and then stays at
        ``args.expl_noise``; always 0 in test mode.
        """
        if test_mode:
            return 0
        if self.steps < self.args.epsilon_steps:
            progress = self.steps / self.args.epsilon_steps
            return self.args.expl_noise_initial - (self.args.expl_noise_initial - self.args.expl_noise) * progress
        return self.args.expl_noise

    def forward(self, ep_batch, t, test_mode=False):
        """Query the policy at timestep ``t`` and decode its latent action.

        Returns:
            ``(discrete_action, parameter_action, discrete_emb,
            true_parameter_emb)``: the decoded discrete and parameter actions,
            the noisy clipped discrete embedding, and the noisy clipped
            parameter embedding mapped through ``true_parameter_action`` with
            ``self.c_rate``.
        """
        agent_inputs = self._build_inputs(
            ep_batch, t).view(ep_batch.batch_size, -1)
        discrete_emb, parameter_emb = self.agent.select_action(
            agent_inputs)

        epsilon = self._exploration_epsilon(test_mode)
        policy_cfg = self.args.ad_policy_dict
        max_action = policy_cfg['max_action']
        # NOTE(review): ``th.rand`` samples U[0, 1), so this exploration noise
        # is one-sided (it only shifts embeddings upward).  TD3 normally uses
        # zero-mean noise (e.g. ``th.randn``); confirm this is intended.
        discrete_noise = th.rand((ep_batch.batch_size, policy_cfg['discrete_action_dim'])).to(ep_batch.device) * epsilon
        discrete_emb = (discrete_emb + discrete_noise).clip(-max_action, max_action)
        parameter_noise = th.rand((ep_batch.batch_size, policy_cfg['parameter_action_dim'])).to(ep_batch.device) * epsilon
        parameter_emb = (parameter_emb + parameter_noise).clip(-max_action, max_action)

        true_parameter_emb = true_parameter_action(parameter_emb, self.c_rate)
        discrete_action_embedding = deepcopy(discrete_emb)
        discrete_action = self.vae.select_discrete_action(
            discrete_action_embedding)
        # Decode the parameter action conditioned on the embedding of the
        # discrete action actually selected, not the raw noisy embedding.
        discrete_emb_1 = self.vae.get_embedding(
            discrete_action)
        parameter_action = self.vae.select_parameter_action(
            agent_inputs, true_parameter_emb, discrete_emb_1)

        return discrete_action, parameter_action, discrete_action_embedding, true_parameter_emb

    def reset_attacker_exploration(self):
        """Restart the exploration schedule at the most recent ``t_env``.

        ``self.steps`` holds ``t_env - last_reset_attacker`` from the last
        ``select_actions`` call, so adding it recovers that ``t_env``.
        ``steps`` is zeroed so that calling this twice in a row is a no-op.
        """
        self.last_reset_attacker += self.steps
        self.steps = 0

    def init_hidden(self, batch_size):
        """No-op; ``forward`` does not use any hidden state."""
        pass

    def parameters(self):
        """Return ``(actor_parameters, critic_parameters)`` of the TD3 agent."""
        return self.agent.actor.parameters(), self.agent.critic.parameters()

    def load_state(self, other_mac: "AttackMAC"):
        """Copy actor/critic weights and ``c_rate`` from ``other_mac``.

        NOTE(review): ``vae`` and ``recon_s`` are not copied; confirm callers
        do not rely on them matching ``other_mac``.
        """
        self.agent.actor.load_state_dict(other_mac.agent.actor.state_dict())
        self.agent.critic.load_state_dict(other_mac.agent.critic.state_dict())
        self.c_rate = deepcopy(other_mac.c_rate)

    def cuda(self):
        """Move the TD3 agent to CUDA (the VAE's device is fixed at construction)."""
        self.agent.cuda()

    def save_models(self, path):
        """Save agent, VAE and ``c_rate`` (as ``path + "crate.npy"``)."""
        self.agent.save("ad", path)
        self.vae.save("vae", path)
        c_rate = np.array(self.c_rate)
        np.save(path + "crate.npy", c_rate)

    def load_models(self, path):
        """Load agent, VAE and ``c_rate`` written by ``save_models``."""
        self.agent.load("ad", path)
        self.vae.load("vae", path)
        self.c_rate = np.load(path + "crate.npy")

    def _build_agents(self, input_shape, output_shape_discrete, output_shape_continuous):
        """Create the TD3 agent and the action-representation VAE.

        Note: writes ``state_dim``, action dims and ``device`` into
        ``self.args.ad_policy_dict`` in place before building the agent.
        """
        self.args.ad_policy_dict['state_dim'] = input_shape
        self.args.ad_policy_dict['discrete_action_dim'] = self.args.discrete_action_dim
        self.args.ad_policy_dict['parameter_action_dim'] = self.args.parameter_action_dim
        self.args.ad_policy_dict['device'] = 'cpu' if self.args.ad_policy_cpu_only else 'cuda'
        self.agent = P_TD3_relable.TD3(**self.args.ad_policy_dict)
        self.ad_parameter_action_dim = output_shape_continuous
        kwargs = {
            'state_dim': input_shape,
            'action_dim': output_shape_discrete,
            'parameter_action_dim': output_shape_continuous,
            'reduced_action_dim': self.args.ad_policy_dict['discrete_action_dim'],
            'reduce_parameter_action_dim': self.args.ad_policy_dict['parameter_action_dim'],
            'recon_loss_c_weight': self.args.recon_loss_c_weight * (1.0 * (self.args.perturbation_range / 2)),
            'embed_lr': self.args.embed_lr,
            'device': 'cpu' if self.args.ad_policy_cpu_only else 'cuda'
        }
        self.vae = ActionRepresentation_vae.Action_representation(**kwargs)

    def _build_inputs(self, batch, t):
        """Return the global state at timestep ``t`` for every batch element."""
        return batch["state"][:, t]

    def _get_input_shape(self, scheme):
        """Flattened size of the global state."""
        input_shape = np.prod(scheme["state"]["vshape"])

        return input_shape

    def _get_output_shape_discrete(self, scheme):
        """Number of discrete actions: 1 if ``attack_agent_0`` else ``n_agents + 1``."""
        if self.args.attack_agent_0:
            output_shape = 1
        else:
            output_shape = self.args.n_agents + 1

        return output_shape

    def _get_output_shape_continuous(self, scheme):
        """Shape of the continuous (parameter) action: ``scheme["obs"]["vshape"]``."""
        output_shape = scheme["obs"]["vshape"]

        return output_shape


def count_boundary(c_rate):
    """Return the scale and shift that map [-1, 1] onto a boundary pair.

    Args:
        c_rate: a pair indexable as ``c_rate[0]`` and ``c_rate[1]``.

    Returns:
        ``(half_range, midpoint)`` where ``half_range = (c_rate[0] - c_rate[1]) / 2``
        and ``midpoint = c_rate[0] - half_range``; the affine map
        ``x * half_range + midpoint`` sends -1 to ``c_rate[1]`` and +1 to
        ``c_rate[0]``.
    """
    first, second = c_rate[0], c_rate[1]
    half_range = (first - second) / 2
    midpoint = first - half_range
    return half_range, midpoint


def true_parameter_action(parameter_action, c_rate):
    """Map a parameter embedding from [-1, 1] back onto per-dimension boundaries.

    Latent dimension ``j`` is rescaled with ``count_boundary(c_rate[j])``,
    which sends -1 to ``c_rate[j][1]`` and +1 to ``c_rate[j][0]``.

    Args:
        parameter_action: a 1-D sequence/array/tensor of latent values, or an
            array/tensor whose LAST axis indexes latent dimensions (e.g.
            ``(batch, parameter_action_dim)``).
        c_rate: indexable by latent dimension; each entry is a boundary pair
            accepted by ``count_boundary``.

    Returns:
        A rescaled copy; the input is left unmodified.
    """
    result = deepcopy(parameter_action)
    # For batched input the boundaries belong to latent dimensions (the last
    # axis), not to batch rows, so index columns rather than rows.  1-D input
    # (or a plain sequence) is indexed element-wise.
    batched = getattr(result, "ndim", 1) > 1
    n_dims = result.shape[-1] if batched else len(result)
    for j in range(n_dims):
        half_range, midpoint = count_boundary(c_rate[j])
        idx = (Ellipsis, j) if batched else j
        result[idx] = result[idx] * half_range + midpoint
    return result
