#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import numpy as np
import torch as th

from .casec_controller import CASECMAC


# This multi-agent controller shares parameters between agents
class CASEC_HPN_MAC(CASECMAC):
    """Variant of ``CASECMAC`` whose agent network consumes factored observations.

    Each raw observation is split into move / enemy / ally / own feature groups
    (layout given by ``args.obs_component``, see ``_build_inputs``) instead of being
    passed as one flat vector. Action representations are recomputed from the
    current observations at every step and used to score pairwise
    (agent_i, agent_j) joint-action payoffs.
    """

    def __init__(self, scheme, groups, args):
        super(CASEC_HPN_MAC, self).__init__(scheme, groups, args)
        self.n_enemies = args.n_enemies
        # Each agent's observation holds one ally slot per other agent.
        self.n_allies = self.n_agents - 1

    def _get_obs_component_dim(self):
        """Return flattened and structured sizes of the observation components.

        ``args.obs_component`` is unpacked as
        ``(move_dim, enemy_shape, ally_shape, own_dim)``, e.g. ``[4, (6, 5), (4, 5), 1]``.

        Returns:
            ``(flat_dims, (enemy_shape, ally_shape))`` where ``flat_dims`` is
            ``(move_dim, prod(enemy_shape), prod(ally_shape), own_dim)`` and is used
            as the section sizes for ``th.split`` along the feature axis.
        """
        move_feats_dim, enemy_feats_dim, ally_feats_dim, own_feats_dim = self.args.obs_component  # [4, (6, 5), (4, 5), 1]
        # np.prod returns a NumPy scalar; cast so th.split receives plain Python ints.
        enemy_feats_dim_flatten = int(np.prod(enemy_feats_dim))
        ally_feats_dim_flatten = int(np.prod(ally_feats_dim))
        return (move_feats_dim, enemy_feats_dim_flatten, ally_feats_dim_flatten, own_feats_dim), (
            enemy_feats_dim, ally_feats_dim)

    def _build_inputs(self, batch, t):
        """Split the time-``t`` observations into the agent network's input groups.

        Args:
            batch: episode batch; ``batch["obs"][:, t]`` is expected to be
                ``[bs, n_agents, obs_dim]``.
            t: time step to read.

        Returns:
            ``(bs, own_context, enemy_feats, ally_feats, embedding_indices)``:
            ``own_context`` is ``[bs * n_agents, move_dim + own_dim]``,
            ``enemy_feats`` is ``[bs * n_agents * n_enemies, enemy_fea_dim]``,
            ``ally_feats`` is ``[bs * n_agents * n_allies, ally_fea_dim]``, and
            ``embedding_indices`` is a list with an agent-id index tensor if
            ``args.obs_agent_id`` and a previous-action index tensor (``None`` at
            ``t == 0``) if ``args.obs_last_action``.
        """
        bs = batch.batch_size
        obs_component_dim, _ = self._get_obs_component_dim()
        raw_obs_t = batch["obs"][:, t]  # [batch, agent_num, obs_dim]
        move_feats_t, enemy_feats_t, ally_feats_t, own_feats_t = th.split(raw_obs_t, obs_component_dim, dim=-1)
        # One row per (batch, observing agent, enemy) triple.
        enemy_feats_t = enemy_feats_t.reshape(bs * self.n_agents * self.n_enemies, -1)
        # One row per (batch, observing agent, ally) triple.
        ally_feats_t = ally_feats_t.reshape(bs * self.n_agents * self.n_allies, -1)
        # Merge move features and own features to simplify computation.
        context_feats = [move_feats_t, own_feats_t]  # [batch, agent_num, own_dim]
        own_context = th.cat(context_feats, dim=2).reshape(bs * self.n_agents, -1)  # [bs * n_agents, own_dim]

        embedding_indices = []
        if self.args.obs_agent_id:
            # agent-id indices, [bs, n_agents]
            embedding_indices.append(th.arange(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1))
        if self.args.obs_last_action:
            # action-id indices, [bs, n_agents]; there is no previous action at t == 0.
            if t == 0:
                embedding_indices.append(None)
            else:
                embedding_indices.append(batch["actions"][:, t - 1].squeeze(-1))

        return bs, own_context, enemy_feats_t, ally_feats_t, embedding_indices

    def _get_input_shape(self, scheme):
        """Return ``(own_context_dim, enemy_shape, ally_shape)`` from ``args.obs_component``."""
        move_feats_dim, enemy_feats_dim, ally_feats_dim, own_feats_dim = self.args.obs_component
        own_context_dim = move_feats_dim + own_feats_dim
        return own_context_dim, enemy_feats_dim, ally_feats_dim

    def update_action_repr(self, action_repr):
        """Store action representations and build their pairwise concatenation.

        Args:
            action_repr: per-sample, per-agent action embeddings, expected to be
                ``[bs, n_agents, n_actions, action_latent_dim]``; stored detached.

        Sets ``self.p_action_repr`` to
        ``[bs, n_agents * n_agents, n_actions * n_actions, 2 * action_latent_dim]``
        whose entry for agent pair (i, j) and joint action (a_i, a_j) is
        ``cat(repr[i, a_i], repr[j, a_j])``.
        """
        self.action_repr = action_repr.detach().clone()  # [bs, n_agents, n_actions, action_latent_dim]
        n_agents, n_actions = self.n_agents, self.n_actions
        # Take the batch size from the tensor itself rather than self.bs, which is
        # only set once calculate() has run.
        bs = self.action_repr.shape[0]

        # Broadcast views to [bs, n_agents, n_agents, n_actions, n_actions, latent]; expand
        # does not copy, so only the concatenation below allocates memory.
        # repr_i[b, i, j, a_i, a_j] == action_repr[b, i, a_i]
        repr_i = self.action_repr[:, :, None, :, None, :].expand(-1, -1, n_agents, -1, n_actions, -1)
        # repr_j[b, i, j, a_i, a_j] == action_repr[b, j, a_j]
        repr_j = self.action_repr[:, None, :, None, :, :].expand(-1, n_agents, -1, n_actions, -1, -1)
        # th.cat returns a fresh contiguous tensor, so .view is valid directly.
        self.p_action_repr = th.cat([repr_i, repr_j], dim=-1).view(
            bs, n_agents * n_agents, n_actions * n_actions, 2 * self.args.action_latent_dim
        )

    def calculate(self, ep_batch, t):
        """Compute individual utilities and pairwise payoffs at time step ``t``.

        Updates ``self.hidden_states``, ``self.p_hidden_states``, ``self.bs``,
        ``self.action_repr`` and ``self.p_action_repr`` as side effects.

        Returns:
            ``(f_i, delta_ij, q_ij, his_cos_similarity, atten_ij)``: ``f_i`` is
            ``[bs, n_agents, n_actions]``; ``delta_ij`` and ``q_ij`` are
            ``[bs, n_agents, n_agents, n_actions, n_actions]`` with the entries
            selected by ``self.eye2`` zeroed; ``q_ij[b, i, j, a_i, a_j]`` is
            ``f_i[b, i, a_i] + f_i[b, j, a_j] + delta_ij[b, i, j, a_i, a_j]`` with
            ``f_i`` detached.
        """
        # Tuple (bs, own_context, enemy_feats, ally_feats, embedding_indices).
        agent_inputs = self._build_inputs(ep_batch, t)
        self.bs = ep_batch.batch_size

        # Action representations are recomputed from the current observations.
        action_repr = self.action_encoder(ep_batch["obs"][:, t])
        self.update_action_repr(action_repr)

        agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states, self.action_repr)
        agent_outs = agent_outs.view(self.bs, self.n_agents, self.n_actions)
        f_i = agent_outs.clone()

        if self.independent_p_q:
            self.p_hidden_states = self.p_agent.h_forward(agent_inputs, self.p_hidden_states).view(
                self.bs, self.n_agents, -1)
        else:
            self.p_hidden_states = self.hidden_states.clone().view(self.bs, self.n_agents, -1)

        # (bs, n, hidden) -> (bs, n, n, |A|, |A|)
        delta_ij, his_cos_similarity = self._calculate_delta(self.p_hidden_states)

        # Broadcast [bs, n, 1, |A|, 1] + [bs, 1, n, 1, |A|] -> [bs, n, n, |A|, |A|] instead of
        # materialising two repeated copies of f_i; values and addition order are unchanged.
        f_detached = f_i.detach()
        q_ij = f_detached[:, :, None, :, None] + f_detached[:, None, :, None, :] + delta_ij

        atten_ij = self._calculate_attention(self.p_hidden_states)

        # Zero the agent pairs selected by self.eye2 (presumably the i == j self-pairs;
        # the mask is defined in the parent controller).
        delta_ij[:, self.eye2] = 0
        q_ij[:, self.eye2] = 0
        return f_i, delta_ij, q_ij, his_cos_similarity, atten_ij

    def _calculate_delta(self, hidden_states):
        """Score pairwise joint-action payoffs from each pair of agents' hidden states.

        Args:
            hidden_states: ``[bs, n_agents, hidden_dim]``.

        Returns:
            ``(f_ij, history_cos_similarity)``. ``f_ij`` is
            ``[bs, n_agents, n_agents, n_actions, n_actions]``, averaged with its
            (j, i, a_j, a_i) transpose (the transposed term is detached).
            ``history_cos_similarity`` is ``self.cosine_similarity`` of the paired
            hidden states when ``self.construction_history_similarity`` is set,
            otherwise ``self.zeros`` repeated over the batch.
        """
        # (bs, n_agents, hidden) -> (bs, n_agents, n_agents, hidden); pair (i, j) holds (h_i, h_j).
        input_i = hidden_states.unsqueeze(2).repeat(1, 1, self.n_agents, 1)
        input_j = hidden_states.unsqueeze(1).repeat(1, self.n_agents, 1, 1)

        pair_hidden_dim = self.args.pair_rnn_hidden_dim if self.independent_p_q else self.args.rnn_hidden_dim
        inputs = th.cat([input_i, input_j], dim=-1).view(-1, 2 * pair_hidden_dim)

        if self.construction_history_similarity:
            history_cos_similarity = self.cosine_similarity(input_i.detach(), input_j.detach()).detach()
        else:
            history_cos_similarity = self.zeros.repeat(self.bs, 1, 1)

        if self.use_action_repr:
            # delta yields one 2*latent key per agent pair; each joint action is scored as
            # the dot product of its pair representation with that key, scaled by 1 / (2 * latent).
            key = self.delta(inputs).view(self.bs, self.n_agents * self.n_agents, 2 * self.args.action_latent_dim, 1)
            # [bs, n*n, |A|*|A|, 2*latent] @ [bs, n*n, 2*latent, 1] -> [bs, n*n, |A|*|A|, 1]
            f_ij = th.matmul(self.p_action_repr, key) / self.args.action_latent_dim / 2
        else:
            f_ij = self.delta(inputs)

        f_ij = f_ij.view(self.bs, self.n_agents, self.n_agents, self.n_actions, self.n_actions)
        # Average with the (j, i, a_j, a_i) transpose so the payoff is permutation invariant.
        f_ij = (f_ij + f_ij.permute(0, 2, 1, 4, 3).detach()) / 2.
        return f_ij, history_cos_similarity

    def action_repr_forward(self, ep_batch, t):
        """Return ``self.action_encoder.predict`` on the time-``t`` observations and actions."""
        return self.action_encoder.predict(ep_batch["obs"][:, t], ep_batch["actions_onehot"][:, t],
                                           ep_batch["actions"][:, t])
