from modules.agents import REGISTRY as agent_REGISTRY
from components.action_selectors import REGISTRY as action_REGISTRY
from .basic_controller import BasicMAC
import numpy as np
import torch as th
from torch.optim import RMSprop
import copy
from torch.distributions import Categorical


def lnorm(p1, p2, l):
    """Return an l-norm-style distance between tensors ``p1`` and ``p2``.

    ``l == -1`` selects the max-abs (L-infinity) distance. Any other ``l``
    returns ``sum(|p1 - p2| ** l)``, i.e. the l-th power of the L_l distance
    (no l-th root is taken). The result is a freshly cloned tensor.
    """
    abs_diff = (p1 - p2).abs()
    if l != -1:
        return (abs_diff ** l).sum().clone()
    return abs_diff.max().clone()


# Multi-agent controller that does not share parameters between agents.
class IndividualMAC(BasicMAC):
    """Controller that adversarially perturbs the policies of selected agents.

    After the regular forward pass, the exploration policy of every agent in
    ``args.adversarial_agent_id`` is perturbed, per batch element with
    probability ``args.attack_frequency``, using either a random direction
    (``attack_type == 'random'``) or a critic/gradient-informed direction
    (``attack_type == 'gradient'``). The step size along that direction is
    searched so that KL(attacked || original) stays within
    ``args.kl_divergence_upperbound``.
    """

    def __init__(self, scheme, groups, args):
        super(IndividualMAC, self).__init__(scheme, groups, args)

        # Indices of the agents whose policies are attacked in select_actions.
        self.adversarial_agent_id = args.adversarial_agent_id
        # Upper bound on KL(attacked || original) used by the step-size search.
        self.kl_divergence_upperbound = args.kl_divergence_upperbound
        # 'random' or 'gradient'; any other value raises in adversarial_attack.
        self.attack_type = args.attack_type

        # Per-attack diagnostics; appended to here and never cleared in this class
        # (presumably drained by a logger elsewhere -- confirm).
        self.recent_log_likelihood = []
        self.recent_kl_divergence = []

        # Incremented once per attacked (batch element, agent) pair.
        self.attack_step = 0

        # Never stepped: only used to zero the agent's parameter gradients before
        # each per-action backward pass in _forward_log_pi_theta_grad.
        self._optimizer = RMSprop(params=self.agent.parameters())

    def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False, info=None, expl=False):
        """Sample actions for batch elements ``bs`` after attacking the adversarial agents.

        Args:
            info: dict providing ``'critic'`` (used to compute Q-values) and, for
                the gradient attack, ``'theta_grad'``.

        Returns:
            Long tensor of actions sampled from the (attacked) exploration policy.
        """
        # Checked before the forward pass so a rejected call does not advance
        # the agents' hidden states.
        assert not test_mode, "test mode for individual controller"
        agent_outputs = self.forward(t_env, ep_batch, t_ep, test_mode=test_mode, bs=bs, expl=expl)
        self._agent_outputs = agent_outputs[bs]
        with th.no_grad():
            critic = info['critic']
            sub_batch = ep_batch[bs]
            critic_inputs = critic._build_inputs_with_ts(sub_batch, sub_batch.batch_size, t_ep)
            self._q_vals = critic.forward(critic_inputs).detach().clone()

        # NOTE(review): self._avail_actions / _agent_inputs / _pre_hidden_states are
        # presumably set by BasicMAC.forward -- confirm against the base class.
        self._policy = self.action_selector.generate_expl_policy(self._agent_outputs, self._avail_actions)
        for agent_id in self.adversarial_agent_id:
            policy_before_attack = self._policy[:, agent_id].clone()
            self.adversarial_attack(agent_id, t_env, info=info)
            policy_after_attack = self._policy[:, agent_id].clone()
            # Batch mean of sum_a pi_after(a) * log pi_before(a).
            self.recent_log_likelihood.append(
                (policy_after_attack * th.log(policy_before_attack + 1e-6)).sum(dim=-1).mean())
        return Categorical(self._policy).sample().long()

    def adversarial_attack(self, agent_id, t_env, info=None):
        """Perturb ``self._policy[:, agent_id]`` in place.

        Each batch element is attacked independently with probability
        ``args.attack_frequency``.

        Raises:
            NotImplementedError: if ``self.attack_type`` is neither 'random' nor 'gradient'.
        """
        batch_size, _, n_actions = self._agent_outputs.shape
        if self.attack_type == 'random':
            gradient = th.rand(batch_size, n_actions, device=self._agent_outputs.device) * 2 - 1
            epsilon = self.action_selector.schedule.eval(t_env)
            for bi in range(batch_size):
                if np.random.rand() < self.args.attack_frequency:
                    self.attack_step += 1
                    self._policy[bi, agent_id, :] = self._apply_gradient_attack(
                        self._policy[bi, agent_id], gradient[bi], self._avail_actions[bi, agent_id], epsilon)
        elif self.attack_type == 'gradient':
            theta_grad = info['theta_grad']
            epsilon = self.action_selector.schedule.eval(t_env)
            for bi in range(batch_size):
                if np.random.rand() < self.args.attack_frequency:
                    self.attack_step += 1
                    gradient = self._gradient_attack_direction(bi, agent_id, theta_grad, n_actions, t_env)
                    self._policy[bi, agent_id, :] = self._apply_gradient_attack(
                        self._policy[bi, agent_id], gradient, self._avail_actions[bi, agent_id], epsilon)
        else:
            # NotImplementedError subclasses Exception, so existing handlers still catch it.
            raise NotImplementedError("Sorry {} attack method is not implemented!".format(self.attack_type))

    def _gradient_attack_direction(self, bi, agent_id, theta_grad, n_actions, t_env):
        """Return the weighted mix of random, expectation and variance directions for element ``bi``."""
        q_vals = self._q_vals[bi, 0, agent_id, :]
        # Q-values centred by their expectation under the agent's outputs.
        advantage = q_vals - (q_vals * self._agent_outputs[bi, agent_id, :]).sum(dim=-1, keepdim=True)
        log_pi_theta_grad = self._forward_log_pi_theta_grad(
            self._agent_inputs[bi, agent_id], self._avail_actions[bi, agent_id],
            self._pre_hidden_states[bi, agent_id], t_env)

        expectation_gradient = th.zeros_like(self._policy[bi, agent_id])
        variance_gradient = th.zeros_like(self._policy[bi, agent_id])
        for action_id, param_grads in enumerate(log_pi_theta_grad):
            if param_grads is None:  # unavailable action
                continue
            for i, param_grad in enumerate(param_grads):
                if theta_grad is not None:
                    # Inner product <d log pi(a) / d theta, theta_grad>.
                    expectation_gradient[action_id] += th.sum(param_grad.data * theta_grad[i].data)
                # Squared norm of d log pi(a) / d theta.
                variance_gradient[action_id] += th.sum(param_grad.data ** 2)
        expectation_gradient = expectation_gradient * advantage
        variance_gradient = -variance_gradient * (advantage ** 2)
        random_gradient = th.rand(n_actions, device=self._agent_outputs.device) * 2 - 1

        expectation_gradient = self._normalize_gradient(expectation_gradient)
        variance_gradient = self._normalize_gradient(variance_gradient)
        random_gradient = self._normalize_gradient(random_gradient)
        return self.args.random_gradient_magnitude * random_gradient + \
            self.args.expectation_gradient_magnitude * expectation_gradient + \
            self.args.variance_gradient_magnitude * variance_gradient

    def _forward_log_pi_theta_grad(self, _agent_inputs, avail_actions, _hidden_states, t_env):
        """Return d log pi(a) / d theta for every action of a single agent sample.

        Re-runs the agent network on one (input, hidden state) pair, rebuilds the
        exploration policy and back-propagates each available action's
        log-probability separately.

        Returns:
            A list with one entry per action: ``None`` for unavailable actions,
            otherwise a list of detached gradient tensors aligned with
            ``self.agent.parameters()``.
        """
        agent_inputs = _agent_inputs.clone()
        hidden_states = _hidden_states.clone()

        agent_outs, _ = self.agent(agent_inputs.unsqueeze(0), hidden_states.unsqueeze(0))
        # NOTE(review): assumes the agent returns (1, n_actions) for one sample -- confirm.
        agent_outs = agent_outs.squeeze()

        # Softmax the agent outputs if they're policy logits.
        if self.agent_output_type == "pi_logits":
            mask_before_softmax = getattr(self.args, "expl_mask_before_softmax", True)
            if mask_before_softmax:
                # Very negative logits so unavailable actions get ~0 softmax mass.
                agent_outs[avail_actions == 0] = -1e11
            agent_outs = th.nn.functional.softmax(agent_outs, dim=-1)

            # Epsilon floor: mix in a uniform distribution over all actions, or
            # over the available actions only when masking is enabled.
            epsilon = self.action_selector.schedule.eval(t_env)
            if mask_before_softmax:
                epsilon_action_num = avail_actions.sum(dim=-1, keepdim=True).float()
            else:
                epsilon_action_num = agent_outs.size(-1)
            agent_outs = ((1 - epsilon) * agent_outs
                          + th.ones_like(agent_outs) * epsilon / epsilon_action_num)

            if mask_before_softmax:
                # Zero out the unavailable actions.
                agent_outs[avail_actions == 0] = 0.0

        log_policy = th.log(self.action_selector.generate_expl_policy(agent_outs, avail_actions, require_grad=True))
        params = list(self.agent.parameters())
        ret = []
        for action_id in range(log_policy.shape[0]):
            if avail_actions[action_id] == 0:
                ret.append(None)
                continue

            self._optimizer.zero_grad()
            log_policy[action_id].backward(retain_graph=True)
            # Copy out of .grad because the next backward overwrites it. Parameters
            # the graph did not reach keep .grad None (zero_grad defaults to
            # set_to_none=True in recent torch); their gradient is zero.
            ret.append([p.grad.detach().clone() if p.grad is not None else th.zeros_like(p)
                        for p in params])
        return ret

    def _normalize_gradient(self, gradient):
        """Scale ``gradient`` to (approximately) unit L2 norm; 1e-6 guards against division by zero."""
        return gradient / ((gradient ** 2).sum() ** .5 + 1e-6)

    @staticmethod
    def _kl_divergence(attacked_policy, policy):
        """Return KL(attacked_policy || policy), adding 1e-6 inside both logarithms."""
        return th.sum(attacked_policy * th.log(attacked_policy + 1e-6)) - \
            th.sum(attacked_policy * th.log(policy + 1e-6))

    def _apply_gradient(self, policy, gradient, avail_actions, coef):
        """Move ``policy``'s log-probabilities against ``policy * gradient * coef`` and renormalise.

        Unavailable actions get probability zero in the returned policy.
        """
        logit_gradient = policy * gradient * coef
        # Centre the step; its entries sum to zero when ``policy`` sums to one.
        logit_gradient = logit_gradient - logit_gradient.sum() * policy
        new_logits = th.log(policy + 1e-6) - logit_gradient
        # Explicit dim: softmax without ``dim`` is deprecated and warns on every call.
        new_policy = th.nn.functional.softmax(new_logits, dim=-1)
        new_policy[avail_actions == 0] = 0.
        new_policy /= new_policy.sum(dim=-1, keepdim=True)
        return new_policy

    def _apply_gradient_attack(self, _policy, _gradient, avail_actions, epsilon):
        """Return a copy of ``_policy`` perturbed along ``_gradient`` within the KL bound.

        The upper coefficient ``rb`` (starting at 1) is doubled until the KL
        exceeds ``self.kl_divergence_upperbound`` or the step counter (starting
        at ``args.gradient_attack_binary_search_step``) reaches
        ``args.max_gradient_attack_binary_search_step``. Then [0, rb] is bisected
        for as many iterations as the counter's final value, and the lower end
        ``lb`` is applied. ``epsilon`` is accepted but not used.
        """
        policy = _policy.clone()
        gradient = _gradient.clone()

        # With a single available action there is nothing to perturb.
        if avail_actions.sum() == 1.:
            return policy

        gradient = self._normalize_gradient(gradient)
        bound = self.kl_divergence_upperbound
        lb = 0.
        rb = 1.
        n_steps = self.args.gradient_attack_binary_search_step

        # Expand rb until it violates the KL bound or the step budget runs out.
        while n_steps < self.args.max_gradient_attack_binary_search_step:
            attacked_policy = self._apply_gradient(policy, gradient, avail_actions, rb)
            if self._kl_divergence(attacked_policy, policy) > bound:
                break
            rb *= 2.
            n_steps += 1

        # Bisect: lb only ever moves to coefficients whose KL satisfied the bound.
        for _ in range(n_steps):
            mid = .5 * (lb + rb)
            attacked_policy = self._apply_gradient(policy, gradient, avail_actions, mid)
            if self._kl_divergence(attacked_policy, policy) > bound:
                rb = mid
            else:
                lb = mid

        attacked_policy = self._apply_gradient(policy, gradient, avail_actions, lb)
        self.recent_kl_divergence.append(self._kl_divergence(attacked_policy, policy))
        assert (policy == _policy).all(), "policy changed during attack"
        return attacked_policy.clone()
