import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MAVENmixer(nn.Module):
    """Monotonic hyper-network mixer conditioned on the state and a noise vector.

    Per-agent Q-values are combined into a single value by a two-layer mixing
    network whose weights are produced by linear hyper-networks. The
    hyper-network input is the flattened state concatenated with ``noise``.
    All mixing weights pass through ``abs`` so the output is non-decreasing
    in every agent Q-value.

    Fields read from ``args``:
        n_agents: number of agent Q-values being mixed.
        state_shape: shape of one state; flattened with ``np.prod``.
        noise_dim: size of the last axis of the noise vector.
        mixing_embed_dim: width of the hidden mixing layer.
        device (optional): device the layers are created on; defaults to
            ``None`` (PyTorch's default device).
        hyper_initialization_nonzeros (optional): if > 0, the two weight
            hyper-networks are re-initialised from a zero-mean normal with
            ``std = value ** -0.5``. Defaults to 0 (keep default init).
        skip_connections (optional): if truthy, adds a state-dependent
            non-negative linear skip path from the agent Q-values to the
            output. Defaults to ``False``.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.n_agents = args.n_agents
        # Hyper-network input = flattened state + noise vector.
        self.state_dim = int(np.prod(args.state_shape)) + args.noise_dim

        self.embed_dim = args.mixing_embed_dim
        # Optional: fall back to PyTorch's default device when unspecified.
        self.device = getattr(args, "device", None)
        self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents, device=self.device)
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim, device=self.device)

        # Initialise the hyper networks with a fixed variance, if specified.
        # Order (w_1 weight, w_1 bias, w_final weight, w_final bias) is kept
        # stable so seeded runs stay reproducible.
        nonzeros = getattr(args, "hyper_initialization_nonzeros", 0) or 0
        if nonzeros > 0:
            std = nonzeros ** -0.5
            for layer in (self.hyper_w_1, self.hyper_w_final):
                nn.init.normal_(layer.weight, std=std)
                nn.init.normal_(layer.bias, std=std)

        # Initialise the hyper-network of the skip-connections, such that the result is close to VDN
        if getattr(args, "skip_connections", False):
            self.skip_connections = nn.Linear(self.state_dim, self.n_agents, bias=True, device=self.device)
            nn.init.constant_(self.skip_connections.bias, 1.0)  # bias produces initial VDN weights

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim, device=self.device)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim, device=self.device),
                               nn.GELU(),
                               nn.Linear(self.embed_dim, 1, device=self.device))

    def forward(self, agent_qs, states, noise):
        """Mix per-agent Q-values into one value per (state, noise) entry.

        Args:
            agent_qs: tensor whose last axis has ``n_agents`` entries;
                presumably ``(batch, T, n_agents)`` - confirm with caller.
                Need not be contiguous.
            states: tensor with the same leading axes as ``agent_qs`` and a
                flattened state on the last axis.
            noise: tensor whose last axis is ``noise_dim``. Either has the
                same leading axes as ``states`` or is broadcastable to them
                (e.g. ``(batch, noise_dim)`` is repeated over the time axis).

        Returns:
            Tensor of shape ``(agent_qs.size(0), -1, 1)``.
        """
        bs = agent_qs.size(0)

        # Broadcast noise over any leading axes of `states` it is missing
        # (or has as size 1), so per-episode noise can be passed directly.
        if noise.shape[:-1] != states.shape[:-1]:
            while noise.dim() < states.dim():
                noise = noise.unsqueeze(-2)
            noise = noise.expand(*states.shape[:-1], -1)

        states = th.cat([states, noise], dim=-1)
        states = states.reshape(-1, self.state_dim)
        # reshape (not view): agent_qs may be a non-contiguous slice/permute.
        agent_qs = agent_qs.reshape(-1, 1, self.n_agents)

        # First layer: abs() keeps the mixing weights non-negative.
        w1 = th.abs(self.hyper_w_1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b_1(states).view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)

        # Second layer
        w_final = th.abs(self.hyper_w_final(states)).view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)

        # Skip connections: non-negative linear path from agent Qs to output.
        s = 0
        if getattr(self.args, "skip_connections", False):
            ws = th.abs(self.skip_connections(states)).view(-1, self.n_agents, 1)
            s = th.bmm(agent_qs, ws)

        # Compute final output
        y = th.bmm(hidden, w_final) + v + s
        # Reshape and return
        q_tot = y.reshape(bs, -1, 1)
        return q_tot