# code adapted from https://github.com/wendelinboehmer/dcg
import math

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class DTAPSAgent(nn.Module):
    """Recurrent agent with top-k attention-based message passing.

    Each agent encodes its input with ``fc1`` followed by a ``GRUCell``. It then
    attends over the messages of ``topk`` peers (indices supplied by the
    caller) and feeds ``[h, aggregated_msg]`` through ``fc2`` to produce
    per-action outputs. The aggregated message is returned inside the hidden
    info, so on the next step it is the message that peers read.

    Args:
        input_shape: Size of the last dimension of the per-agent input.
        args: Config namespace. Reads ``n_agents``, ``hidden_dim``,
            ``n_actions``, ``state_shape``. Optionally reads ``attention_dim``
            (default 64), ``attn_temperature``, ``edge_prior_beta``,
            ``attn_logit_clip``, ``attn_dropout`` (default 0.0) and
            ``use_rnn`` (must be truthy).
    """

    def __init__(self, input_shape, args):
        super().__init__()
        self.args = args
        self.n_agents = args.n_agents

        self.attention_dim = getattr(args, "attention_dim", 64)
        # NOTE(review): attn_temperature, edge_prior_beta and attn_logit_clip are
        # stored but not used anywhere in this class (_communicate hard-codes a
        # logit clip of 10.0). Kept so external code/configs reading them still work.
        self.attn_temperature = getattr(args, "attn_temperature", 1.5)
        self.edge_prior_beta = getattr(args, "edge_prior_beta", 0.5)
        self.attn_logit_clip = getattr(args, "attn_logit_clip", 5.0)
        self.attn_dropout_p = getattr(args, "attn_dropout", 0.0)

        self.fc1 = nn.Linear(input_shape, args.hidden_dim)
        assert getattr(self.args, "use_rnn", True), "This agent expects use_rnn=True"
        self.rnn = nn.GRUCell(args.hidden_dim, args.hidden_dim)

        # Attention projections: queries from the ego hidden state, keys/values
        # from the received peer messages.
        self.msg_k = nn.Linear(args.hidden_dim, self.attention_dim)
        self.msg_q = nn.Linear(args.hidden_dim, self.attention_dim)
        self.msg_v = nn.Linear(args.hidden_dim, args.hidden_dim)

        # NOTE(review): msg_ln is registered (so it appears in state_dict) but is
        # not applied anywhere in this class.
        self.msg_ln = nn.LayerNorm(args.hidden_dim)
        # Dropout on the attention weights. It is an identity when attn_dropout
        # is 0.0 (the default) or in eval mode. It has no parameters, so
        # state_dict is unaffected.
        self.attn_dropout = nn.Dropout(self.attn_dropout_p)

        self.fc2 = nn.Sequential(
            nn.Linear(args.hidden_dim * 2, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, args.n_actions),
        )

        # Auxiliary head mapping a message to a vector of size prod(state_shape).
        self.state_dim = int(np.prod(args.state_shape))
        self.predict_net = nn.Sequential(
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Linear(args.hidden_dim, self.state_dim),
        )

        for m in [self.fc1, self.fc2[0], self.fc2[2]]:
            nn.init.xavier_uniform_(m.weight)
            nn.init.constant_(m.bias, 0.0)

        for m in [self.msg_q, self.msg_k]:
            nn.init.xavier_uniform_(m.weight, gain=0.8)
            nn.init.constant_(m.bias, 0.0)
        nn.init.xavier_uniform_(self.msg_v.weight, gain=0.8)
        nn.init.constant_(self.msg_v.bias, 0.0)

    def init_hidden(self):
        """Return zero ``(hidden_state, msg)``, each of shape ``(1, hidden_dim)``.

        Both tensors share the dtype and device of the model parameters.
        """
        h = self.fc1.weight.new_zeros(1, self.args.hidden_dim)
        msg = th.zeros_like(h)
        return (h, msg)

    def get_encoding(self, inputs, hidden_info):
        """Run the observation encoder and GRU step without communicating.

        Args:
            inputs: Per-agent inputs whose last dim is ``input_shape``.
                Presumably shaped ``(bs * n_agents, input_shape)``; TODO confirm.
            hidden_info: ``(hidden_state, msg)`` from the previous step.

        Returns:
            ``(h, (h, msg))``: the new hidden state, and a temporary hidden
            info that still carries the previous step's message.
        """
        hidden_state, msg = hidden_info
        h_in = hidden_state.reshape(-1, self.args.hidden_dim)

        x = F.relu(self.fc1(inputs))
        h = self.rnn(x, h_in)

        return h, (h, msg)

    def communicate_and_act(self, h, temp_hidden_info, topk_indices, dynamic_weights=None):
        """Aggregate peer messages and compute the action head.

        Args:
            h: New hidden states, ``(bs * n_agents, hidden_dim)``.
            temp_hidden_info: ``(h, msg)`` as returned by ``get_encoding``.
                Only the message is used.
            topk_indices: ``(bs, n_agents, topk)`` peer indices per agent.
            dynamic_weights: Optional weights multiplied into the attention
                (see ``_communicate``).

        Returns:
            ``(q, (h, msg), msg)``, where ``q`` has shape
            ``(bs * n_agents, n_actions)`` and ``msg`` is the aggregated
            message of shape ``(bs * n_agents, hidden_dim)``.
        """
        _, msg = temp_hidden_info

        msg = self._communicate(topk_indices, msg, h, dynamic_weights)

        out = th.cat([h, msg], dim=-1)
        q = self.fc2(out)

        return q, (h, msg), msg

    def forward(self, inputs, hidden_info, topk_indices, dynamic_weights=None):
        """Run encoding, communication and the action head in one step.

        Equivalent to ``get_encoding`` followed by ``communicate_and_act``.
        No detach is applied, so gradients from ``q`` flow into the message
        pathway.

        Returns:
            ``(q, (h, msg), msg)``. ``q`` has shape
            ``(bs * n_agents, n_actions)``; ``h`` and ``msg`` have shape
            ``(bs * n_agents, hidden_dim)``.
        """
        h, temp_hidden_info = self.get_encoding(inputs, hidden_info)
        return self.communicate_and_act(h, temp_hidden_info, topk_indices, dynamic_weights)

    def aux_forward(self, msg):
        """Map messages through ``predict_net`` to size ``prod(state_shape)``."""
        return self.predict_net(msg)

    def _communicate(self, topk_indices, other_msg, ego_h, dynamic_weights=None):
        """Attention-weighted aggregation of the selected peers' messages.

        Args:
            topk_indices: ``(bs, n_agents, topk)`` integer indices into the
                agent axis. Used as a ``gather`` index, so they must be int64.
            other_msg: Previous messages. Reshaped to
                ``(bs, n_agents, hidden_dim)``, so it must contain exactly
                ``bs * n_agents * hidden_dim`` elements.
            ego_h: Current hidden states. Reshaped to ``(bs, n_agents, -1)``.
            dynamic_weights: Optional tensor that is clamped to ``[0.1, 10]``,
                multiplied into the attention and renormalized. It must
                broadcast against ``(bs, n_agents, topk)``.

        Returns:
            Aggregated messages of shape ``(bs * n_agents, hidden_dim)``.
        """
        bs, n_agents, topk = topk_indices.shape
        assert n_agents == self.n_agents

        # Hard-coded logit clip. NOTE(review): this differs from self.attn_logit_clip.
        clip_value = 10.0
        hidden_dim = self.args.hidden_dim

        # (bs, n_agents, topk, hidden_dim)
        gather_idx = topk_indices[:, :, :, None].expand(-1, -1, -1, hidden_dim)
        # (bs, n_agents (repeated for every receiver), n_agents, hidden_dim)
        other_msg = other_msg.reshape(bs, 1, n_agents, hidden_dim).expand(
            -1, n_agents, -1, -1
        )

        # (bs, n_agents, topk, hidden_dim)
        msg_received = other_msg.gather(dim=2, index=gather_idx)

        ego_h = ego_h.reshape(bs, n_agents, -1)
        q = self.msg_q(ego_h).reshape(bs, n_agents, self.attention_dim, 1)
        k = self.msg_k(msg_received)

        # (bs, n_agents, topk). Logits are clipped before the 1/sqrt(d) scaling.
        attention_scores = th.matmul(k, q)[:, :, :, 0]
        attention_scores = th.clamp(attention_scores, -clip_value, clip_value)

        attention = F.softmax(attention_scores / math.sqrt(self.attention_dim), dim=-1)

        if dynamic_weights is not None:
            dynamic_weights = th.clamp(dynamic_weights, 0.1, 10.0)
            attention = attention * dynamic_weights
            attention = attention / (attention.sum(dim=-1, keepdim=True) + 1e-8)

        # Keep weights strictly inside (0, 1). After this clamp they may no
        # longer sum exactly to 1.
        attention = th.clamp(attention, 1e-5, 1.0 - 1e-5)
        attention = self.attn_dropout(attention)

        # (bs, n_agents, topk, hidden_dim)
        m = self.msg_v(msg_received)
        # (bs, n_agents, hidden_dim)
        m_aggregated = (attention[:, :, :, None] * m).sum(dim=2)

        return m_aggregated.reshape(bs * n_agents, -1)


class DTAPSContAgent(DTAPSAgent):
    """Variant of ``DTAPSAgent`` whose auxiliary output is the message itself.

    Instead of passing the message through ``predict_net``, ``aux_forward``
    returns the message L2-normalized along its last dimension.
    """

    def aux_forward(self, msg):
        """Return ``msg`` scaled to unit L2 norm over the last dimension."""
        unit_msg = F.normalize(msg, p=2.0, dim=-1)
        return unit_msg
