import torch
import torch.nn as nn
import torch.nn.functional as F

class Discriminator_net(nn.Module):
    """Recurrent encoder -> gated latent -> decoder network.

    Each call to :meth:`forward` performs one recurrent step:

    1. encode ``inputs`` with a two-layer MLP (ReLU, LayerNorm, Dropout),
    2. update a ``GRUCell`` hidden state and add a residual connection
       from the encoded input,
    3. project to a latent vector ``z`` that is scaled by a scalar
       sigmoid gate per row (dot product of a query built from the encoded
       input and a key built from the updated hidden state),
    4. decode ``z`` to ``output_dim`` features.

    Args:
        input_dim: Size of the last dimension of ``inputs``.
        output_dim: Size of the last dimension of the decoder output.
        args: Config object. Reads ``n_agents`` and ``hidden_dim``, and
            optionally ``dropout`` (0.1 when the attribute is absent).
            The latent size is ``hidden_dim // 2``.
    """

    def __init__(self, input_dim: int, output_dim: int, args):
        super().__init__()
        self.args = args
        # Stored for callers; not used by forward() in this class.
        self.n_agents = args.n_agents
        self.hidden_dim = args.hidden_dim
        self.latent_dim = self.hidden_dim // 2
        dr = getattr(args, "dropout", 0.1)

        self.encoder_mlp = nn.Sequential(
            nn.Linear(input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(dr),
        )
        self.gru = nn.GRUCell(self.hidden_dim, self.hidden_dim)
        # Query/key share hidden_dim so their per-row dot product is defined;
        # the value projection lands in the (smaller) latent space.
        self.q_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.k_layer = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.v_layer = nn.Linear(self.hidden_dim, self.latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Dropout(dr),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, output_dim),
        )

    def forward(self, inputs: torch.Tensor, hidden_state: torch.Tensor):
        """Run one recurrent step.

        Args:
            inputs: Tensor whose last dimension is ``input_dim``. It is passed
                straight to ``GRUCell``, so presumably it is 2-D
                ``(B, input_dim)``. TODO confirm against callers.
            hidden_state: Previous hidden state. It is flattened to
                ``(-1, hidden_dim)``, so any leading shape works as long as
                the row count matches ``B``.

        Returns:
            ``(out, z, h)``:

            - ``out`` is the decoder output ``(B, output_dim)``.
            - ``z`` is the gated latent ``(B, hidden_dim // 2)``.
            - ``h`` is the updated hidden state ``(B, hidden_dim)``, including
              the residual from the encoded input.
        """
        hidden_state = hidden_state.reshape(-1, self.hidden_dim)

        x = self.encoder_mlp(inputs)
        h = self.gru(x, hidden_state)
        # NOTE(review): the residual is added to the state that is returned.
        # If callers feed ``h`` back as the next ``hidden_state``, ``x``
        # accumulates across steps. Confirm this is intended.
        h = h + x

        q = self.q_layer(x)
        k = self.k_layer(h)
        v = self.v_layer(h)
        # One scalar score per row, squashed by a sigmoid into a (0, 1) gate.
        # This is not a softmax over multiple keys. The score is not scaled
        # by sqrt(hidden_dim), so large activations can saturate the gate.
        attn_weight = torch.sigmoid((q * k).sum(-1, keepdim=True))
        z = attn_weight * v

        out = self.decoder(z)
        return out, z, h

    def init_hidden(self):
        """Return a zero hidden state of shape ``(1, hidden_dim)``.

        The tensor is allocated with ``new_zeros`` on the first encoder
        layer's weight, so it shares that weight's dtype and device.
        """
        return self.encoder_mlp[0].weight.new_zeros(1, self.hidden_dim)
