import torch.nn as nn
import torch.nn.functional as F
import torch as th
import numpy as np


class RNNAttnAgent(nn.Module):
    """Recurrent agent network that attends over per-entity embeddings.

    The flat observation is split by ``decomposer`` into the agent's own
    features, a sequence of enemy features and a sequence of ally features.
    Each group is embedded to ``args.entity_embed_dim``, the own embedding
    queries all entity embeddings with single-head scaled dot-product
    attention, and the pooled vector is passed through
    ``fc1 -> GRUCell -> fc2`` to produce one value per action.

    Args:
        input_shape: Unused; kept so the constructor signature matches what
            callers pass.
        decomposer: Object providing ``own_obs_dim``, ``obs_nf_en``,
            ``obs_nf_al``, ``n_actions_no_attack``, ``n_actions``,
            ``obs_dim``, ``decompose_obs`` and ``decompose_action_info``.
        args: Config object providing ``entity_embed_dim``,
            ``attn_embed_dim``, ``rnn_hidden_dim``, ``n_actions``,
            ``n_agents``, ``obs_agent_id``, ``obs_last_action`` and (when
            attack actions exist) ``id_length``.
    """

    def __init__(self, input_shape, decomposer, args):
        super().__init__()
        self.args = args
        self.decomposer = decomposer
        self.entity_embed_dim = args.entity_embed_dim

        # Recurrent head. Layers are created in a fixed order so parameter
        # registration (state_dict key order) and init RNG use stay stable.
        self.fc1 = nn.Linear(self.entity_embed_dim, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        self.attn_embed_dim = args.attn_embed_dim

        # Per-entity feature widths reported by the decomposer.
        obs_own_dim = decomposer.own_obs_dim
        obs_en_dim, obs_al_dim = decomposer.obs_nf_en, decomposer.obs_nf_al
        n_actions_no_attack = decomposer.n_actions_no_attack

        self.n_agents = args.n_agents

        # Attack actions exist when the full action set is larger than the
        # attack-free subset.
        has_attack_action = n_actions_no_attack != decomposer.n_actions

        # Own features are widened only when BOTH the agent id and the last
        # action are part of the input; forward() applies the same condition.
        if args.obs_agent_id and args.obs_last_action:
            if has_attack_action:
                # own obs + agent id + compact last action (attack-free
                # actions plus one extra slot).
                wrapped_obs_own_dim = (
                    obs_own_dim + args.id_length + n_actions_no_attack + 1
                )
                # Each enemy row gets one extra attack-info feature.
                obs_en_dim += 1
            else:
                # NOTE(review): the id width here is args.n_agents while the
                # attack branch uses args.id_length -- confirm this
                # asymmetry is intended.
                wrapped_obs_own_dim = (
                    obs_own_dim + args.n_agents + n_actions_no_attack
                )
        else:
            wrapped_obs_own_dim = obs_own_dim

        # Entity embedders; their outputs double as the attention values.
        self.ally_value = nn.Linear(obs_al_dim, self.entity_embed_dim)
        self.enemy_value = nn.Linear(obs_en_dim, self.entity_embed_dim)
        self.own_value = nn.Linear(wrapped_obs_own_dim, self.entity_embed_dim)

        # Attention projections applied on top of the entity embeddings.
        self.query = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)
        self.key = nn.Linear(self.entity_embed_dim, self.attn_embed_dim)

    def init_hidden(self):
        """Return a zero hidden state of shape ``(1, rnn_hidden_dim)``.

        The tensor is derived from ``fc1.weight`` so it shares the model's
        dtype and device.
        """
        return self.fc1.weight.new_zeros((1, self.args.rnn_hidden_dim))

    def forward(self, inputs, hidden_state):
        """Compute per-action outputs and the next hidden state.

        Args:
            inputs: 2-D tensor whose columns are laid out as
                ``[obs (decomposer.obs_dim) | last action (args.n_actions) |
                agent id (remaining columns)]``.
            hidden_state: Previous hidden state; reshaped to
                ``(-1, args.rnn_hidden_dim)`` before use.

        Returns:
            Tuple ``(q, h)``: ``q`` is the ``fc2`` output with
            ``args.n_actions`` columns, ``h`` is the new GRU hidden state.
        """
        decomposer = self.decomposer
        obs_dim = decomposer.obs_dim
        action_end = obs_dim + self.args.n_actions

        # Split the flat input into its three column groups.
        obs_inputs = inputs[:, :obs_dim]
        last_action_inputs = inputs[:, obs_dim:action_end]
        agent_id_inputs = inputs[:, action_end:]

        # own_obs: [batch, own_obs_dim]; enemy_feats / ally_feats are
        # sequences of per-entity tensors (stacked below).
        own_obs, enemy_feats, ally_feats = decomposer.decompose_obs(obs_inputs)
        _, attack_action_info, compact_action_states = (
            decomposer.decompose_action_info(last_action_inputs)
        )

        if self.args.obs_last_action and self.args.obs_agent_id:
            # Mirror __init__: widen own features with id and last action.
            own_obs = th.cat(
                [own_obs, agent_id_inputs, compact_action_states], dim=-1
            )

        # Stack entities on a leading axis: [n_entities, batch, feat].
        enemy_feats = th.stack(enemy_feats, dim=0)
        if np.prod(attack_action_info.shape) > 0:
            # presumably [batch, n_enemies] -> [n_enemies, batch, 1] so it
            # lines up with the stacked enemy features -- TODO confirm
            # against the decomposer.
            attack_action_info = attack_action_info.transpose(0, 1).unsqueeze(-1)
            enemy_feats = th.cat([enemy_feats, attack_action_info], dim=-1)
        ally_feats = th.stack(ally_feats, dim=0)

        # Embed every entity and move batch to the front:
        # [batch, n_entities, entity_embed_dim].
        own_hidden = self.own_value(own_obs).unsqueeze(1)
        ally_hidden = self.ally_value(ally_feats).permute(1, 0, 2)
        enemy_hidden = self.enemy_value(enemy_feats).permute(1, 0, 2)
        total_hidden = th.cat([own_hidden, enemy_hidden, ally_hidden], dim=1)

        # Single-query scaled dot-product attention: the own embedding
        # queries all entities; the embeddings themselves are the values.
        query = self.query(own_hidden)
        key = self.key(total_hidden)
        scores = th.bmm(query, key.transpose(1, 2)) / np.sqrt(self.attn_embed_dim)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = th.bmm(attn_weights, total_hidden)
        # Only one query row exists, so drop the singleton axis.
        base_action_inputs = attn_output[:, 0, :]

        x = F.relu(self.fc1(base_action_inputs))
        h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        return q, h
