import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


def get_activation_func(name, hidden_dim):
    """
    Build the activation module selected by ``name``.

    Supported names:
    'relu'
    'tanh'
    'leaky_relu'
    'elu'
    'prelu'

    :param name: identifier of the activation function (one of the names above).
    :param hidden_dim: number of learnable slopes created for 'prelu'
        (passed as ``num_parameters``); ignored by every other activation.
    :return: a newly constructed ``nn.Module`` implementing the activation.
    :raises ValueError: if ``name`` is not one of the supported identifiers.
    """
    if name == "relu":
        return nn.ReLU(inplace=True)
    if name == "tanh":
        return nn.Tanh()
    if name == "leaky_relu":
        return nn.LeakyReLU(negative_slope=0.01, inplace=True)
    if name == "elu":
        return nn.ELU(alpha=1., inplace=True)
    if name == "prelu":
        return nn.PReLU(num_parameters=hidden_dim, init=0.25)

    # Fail loudly here instead of returning None, which would only surface
    # later as an unrelated error when the result is placed in a container.
    raise ValueError(
        "Unknown activation function {!r}; expected one of "
        "'relu', 'tanh', 'leaky_relu', 'elu', 'prelu'.".format(name)
    )


class Hypernet(nn.Module):
    """
    Small MLP whose output is interpreted as the weights of a "main" layer.

    For every input row the network emits ``n_heads`` flattened weight
    matrices of size ``main_input_dim * main_output_dim``; ``forward``
    reshapes them so the heads are laid out side by side on the last axis.

    :param input_dim: feature size of the hypernet input.
    :param hidden_dim: width of the hidden layer.
    :param main_input_dim: input size of the generated main-layer weight.
    :param main_output_dim: output size of the generated main-layer weight.
    :param activation_func: activation name understood by ``get_activation_func``.
    :param n_heads: number of weight matrices generated per input row.
    """

    def __init__(self, input_dim, hidden_dim, main_input_dim, main_output_dim, activation_func, n_heads):
        super().__init__()

        self.n_heads = n_heads
        # Kept so that forward() can fold the flat output back into a matrix.
        self.main_input_dim = main_input_dim
        self.main_output_dim = main_output_dim

        # Size of one flattened main-layer weight matrix (a single head).
        flat_weight_dim = main_input_dim * main_output_dim

        layers = [
            nn.Linear(input_dim, hidden_dim),
            get_activation_func(activation_func, hidden_dim),
            nn.Linear(hidden_dim, flat_weight_dim * n_heads),
        ]
        self.multihead_nn = nn.Sequential(*layers)

    def forward(self, x):
        """
        :param x: tensor whose last dimension is ``input_dim``.
        :return: tensor of shape [-1, main_input_dim, main_output_dim * n_heads].
        """
        flat_weights = self.multihead_nn(x)
        stacked_out_dim = self.main_output_dim * self.n_heads
        return flat_weights.view(-1, self.main_input_dim, stacked_out_dim)


class Merger(nn.Module):
    """
    Collapse the head axis of a multi-head feature tensor.

    With more than one head, a learnable weight of shape [1, head, fea_dim]
    (initialised to ones) is softmax-normalised across the head axis and used
    for a weighted sum over heads. With a single head no parameters are
    created and the head axis is simply squeezed away.

    :param head: number of heads on axis 1 of the input.
    :param fea_dim: feature size of each head.
    """

    def __init__(self, head, fea_dim):
        super().__init__()
        self.head = head
        if head <= 1:
            # Nothing to learn when there is only one head to "merge".
            return

        head_logits = th.Tensor(1, head, fea_dim)
        head_logits.fill_(1.)
        self.weight = Parameter(head_logits)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        :param x: [bs, n_head, fea_dim]
        :return: [bs, fea_dim]
        """
        if self.head <= 1:
            return th.squeeze(x, dim=1)

        # Normalise across heads, independently for every feature.
        head_weights = self.softmax(self.weight)
        return (head_weights * x).sum(dim=1)


class CG_HPN_RNNAgent(nn.Module):
    """
    Recurrent agent whose input layer is generated by hypernetworks.

    Own features go through a regular linear layer. Enemy and ally features
    are each embedded with per-entity weights produced by a ``Hypernet`` and
    then summed over the entity axis, so the result does not depend on the
    order of the entities. The embedding feeds a GRU cell, and the final
    action values are the product of a projected hidden state with the
    externally supplied action representations.

    :param input_shape: triple ``(own_feats_dim, enemy_shape, ally_shape)``;
        only the last element of ``enemy_shape`` / ``ally_shape`` (the
        per-entity feature size) is used.
    :param args: configuration object; the attributes read here are
        ``n_agents``, ``n_enemies``, ``n_actions``, ``hpn_head_num``,
        ``rnn_hidden_dim``, ``use_action_repr``, ``obs_agent_id``,
        ``obs_last_action``, ``hpn_hyper_dim``, ``hpn_hyper_activation`` and
        ``action_latent_dim``.
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        self.n_agents = args.n_agents
        self.n_allies = self.n_agents - 1
        self.n_enemies = args.n_enemies
        self.n_actions = args.n_actions
        self.n_heads = args.hpn_head_num
        self.rnn_hidden_dim = args.rnn_hidden_dim
        # This agent only works together with action representations.
        assert args.use_action_repr

        # e.g. [4 + 1, (6, 5), (4, 5)]; keep only the per-entity feature size.
        own_dim, enemy_shape, ally_shape = input_shape
        self.own_feats_dim = own_dim
        self.enemy_feats_dim = enemy_shape[-1]  # from [n_enemies, feat_dim]
        self.ally_feats_dim = ally_shape[-1]  # from [n_allies, feat_dim]

        hidden = self.rnn_hidden_dim

        if args.obs_agent_id:
            # Lookup table mapping an agent id to a hidden-size vector.
            self.agent_id_embedding = nn.Embedding(self.n_agents, hidden)

        if args.obs_last_action:
            # Lookup table mapping an action id to a hidden-size vector.
            self.action_id_embedding = nn.Embedding(self.n_actions, hidden)

        # Own features are unique per agent, so a plain linear layer suffices.
        # Its bias is the only bias of the whole input layer.
        self.fc1_own = nn.Linear(self.own_feats_dim, hidden, bias=True)

        # Hypernet-based input layer: one weight generator per entity type.
        # Each hypernet sees a single entity's features and emits the
        # [feats_dim, hidden * n_heads] weights used to embed that entity.
        self.hyper_input_w_enemy = Hypernet(
            input_dim=self.enemy_feats_dim,
            hidden_dim=args.hpn_hyper_dim,
            main_input_dim=self.enemy_feats_dim,
            main_output_dim=hidden,
            activation_func=args.hpn_hyper_activation,
            n_heads=self.n_heads,
        )
        self.hyper_input_w_ally = Hypernet(
            input_dim=self.ally_feats_dim,
            hidden_dim=args.hpn_hyper_dim,
            main_input_dim=self.ally_feats_dim,
            main_output_dim=hidden,
            activation_func=args.hpn_hyper_activation,
            n_heads=self.n_heads,
        )

        # Reduces [*, n_heads, hidden] to [*, hidden].
        self.unify_input_heads = Merger(self.n_heads, hidden)

        self.rnn = nn.GRUCell(hidden, hidden)

        # Projects the hidden state into the action-representation space.
        self.fc2 = nn.Linear(hidden, args.action_latent_dim)

    def init_hidden(self):
        """
        :return: zero tensor of shape [1, rnn_hidden_dim] with the same dtype
            and device as the model weights.
        """
        return self.fc1_own.weight.new_zeros((1, self.rnn_hidden_dim))

    def _embed_entities(self, feats, hypernet, n_entities, n_rows):
        """
        Embed one entity type with hypernet-generated weights and sum the
        result over the entity axis.

        :param feats: per-entity features; assumed [n_rows * n_entities, feat_dim].
        :param hypernet: module producing per-entity weights from ``feats``.
        :param n_entities: number of entities of this type per agent.
        :param n_rows: size of the leading axis of the result (bs * n_agents).
        :return: [n_rows, n_heads, rnn_hidden_dim]
        """
        # [N, feat_dim, rnn_hidden_dim * n_heads]
        generated_w = hypernet(feats)
        # [N, 1, feat_dim] x [N, feat_dim, hidden * n_heads] -> [N, 1, hidden * n_heads]
        projected = th.matmul(feats.unsqueeze(1), generated_w)
        per_entity = projected.view(n_rows, n_entities, self.n_heads, self.rnn_hidden_dim)
        # Summing over entities removes any dependence on their ordering.
        return per_entity.sum(dim=1)

    def forward(self, inputs, hidden_state, action_repr):
        """
        :param inputs: tuple ``(bs, own_feats, enemy_feats, ally_feats,
            embedding_indices)``. Per the shape notes this code was written
            against, the feature tensors are [bs * n_agents, own_dim],
            [bs * n_agents * n_enemies, enemy_dim] and
            [bs * n_agents * n_allies, ally_dim]. ``embedding_indices[0]``
            holds agent ids and ``embedding_indices[-1]`` the last-action ids
            (``None`` at the first step).
        :param hidden_state: previous hidden state; reshaped to [-1, rnn_hidden_dim].
        :param action_repr: action representations; expected to be
            [bs, n_agents, n_actions, action_latent_dim] (or broadcastable to it).
        :return: ``(q, hh)`` with q viewed as [bs, n_agents, -1] and the new
            hidden state viewed as [bs, n_agents, -1].
        """
        bs, own_feats_t, enemy_feats_t, ally_feats_t, embedding_indices = inputs
        n_rows = bs * self.n_agents
        hidden = self.rnn_hidden_dim

        # (1) Own features -> [bs * n_agents, rnn_hidden_dim]
        own_embedding = self.fc1_own(own_feats_t)

        # (2) Optional id embeddings, each added as [bs * n_agents, rnn_hidden_dim]
        if self.args.obs_agent_id:
            agent_id_vec = self.agent_id_embedding(embedding_indices[0])
            own_embedding = own_embedding + agent_id_vec.view(-1, hidden)
        if self.args.obs_last_action:
            last_action_indices = embedding_indices[-1]
            if last_action_indices is not None:  # there is no last action at t == 0
                last_action_vec = self.action_id_embedding(last_action_indices)
                own_embedding = own_embedding + last_action_vec.view(-1, hidden)

        # (3) Enemy features -> [bs * n_agents, n_heads, rnn_hidden_dim]
        enemy_embedding = self._embed_entities(
            enemy_feats_t, self.hyper_input_w_enemy, self.n_enemies, n_rows
        )

        # (4) Ally features -> [bs * n_agents, n_heads, rnn_hidden_dim]
        ally_embedding = self._embed_entities(
            ally_feats_t, self.hyper_input_w_ally, self.n_allies, n_rows
        )

        # Merge the heads, then combine with the own embedding
        # -> [bs * n_agents, rnn_hidden_dim]
        merged_entities = self.unify_input_heads(enemy_embedding + ally_embedding)
        embedding = own_embedding + merged_entities

        x = F.relu(embedding, inplace=True)
        h_in = hidden_state.reshape(-1, hidden)
        hh = self.rnn(x, h_in)  # [bs * n_agents, rnn_hidden_dim]

        # Action values: inner product of each action representation with the
        # agent's key vector.
        latent_dim = self.args.action_latent_dim
        key = self.fc2(hh).view(bs, self.n_agents, latent_dim)
        key = key.unsqueeze(-1)  # [bs, n_agents, action_latent_dim, 1]
        # [bs, n_agents, n_actions, latent] x [bs, n_agents, latent, 1] -> [bs, n_agents, n_actions, 1]
        q = th.matmul(action_repr, key).squeeze(-1)
        return q.view(bs, self.n_agents, -1), hh.view(bs, self.n_agents, -1)
