import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class QoptMixer(nn.Module):
    """Mixer that turns per-agent values into two state-conditioned outputs.

    ``forward`` returns ``(q_tot, q_tot2)``:

    * ``q_tot`` -- shape ``(bs, ts, 1)``: sum of an unconstrained mixing head
      (``hyper_w_1`` / ``hyper_w_final``), a head whose weights are made
      non-negative with ``abs`` (``hyper_w_2`` / ``hyper_w_final2``) and the
      state value ``V(s)``.
    * ``q_tot2`` -- shape ``(bs, ts, n_agents)``: entry ``i`` is
      ``|weight_i| * Q_i`` plus a non-negative-weight mixing of all agents'
      values in which agent ``i``'s own value is replaced by the state
      baseline ``V2(s)_i``, plus ``V(s)``.

    All mixing weights and biases are generated from the global state by small
    networks (the ``hyper_*`` modules).

    Fields read from ``args``:
        n_agents: number of agents.
        state_shape: shape of the global state; flattened into ``state_dim``.
        mixing_embed_dim: width of each mixing head.
        hypernet_layers: 1 (single linear weight generators, the default) or
            2 (two-layer generators with ``hypernet_embed`` hidden units).
        hypernet_embed: hidden width, read only when ``hypernet_layers == 2``.

    Raises:
        NotImplementedError: if ``hypernet_layers > 2``.
        ValueError: for any other unsupported ``hypernet_layers`` value.
    """

    def __init__(self, args):
        super(QoptMixer, self).__init__()

        self.args = args
        self.n_agents = args.n_agents
        self.state_dim = int(np.prod(args.state_shape))
        self.embed_dim = args.mixing_embed_dim
        # Not used inside this class; kept so external readers of the
        # attribute keep working.
        self.hidden_dim = 128

        layers = getattr(args, "hypernet_layers", 1)
        if layers == 1:
            def make(out_dim):
                return nn.Linear(self.state_dim, out_dim)
        elif layers == 2:
            hypernet_embed = self.args.hypernet_embed

            def make(out_dim):
                return nn.Sequential(nn.Linear(self.state_dim, hypernet_embed),
                                     nn.ReLU(),
                                     nn.Linear(hypernet_embed, out_dim))
        elif layers > 2:
            raise NotImplementedError("Sorry >2 hypernet layers is not implemented!")
        else:
            raise ValueError("Error setting number of hypernet layers.")

        # Both branches now build every head that forward() uses (the
        # single-layer branch previously omitted hyper_w_3, hyper_w_final3
        # and weight, so forward() raised AttributeError). The creation order
        # matches the original two-layer branch so seeded init is unchanged.
        n_weights = self.embed_dim * self.n_agents
        self.hyper_w_1 = make(n_weights)
        self.hyper_w_final = make(self.embed_dim)
        self.hyper_w_2 = make(n_weights)
        self.hyper_w_final2 = make(self.embed_dim)
        self.hyper_w_3 = make(n_weights)
        self.hyper_w_final3 = make(self.embed_dim)

        # Per-agent scale on each agent's own value in q_tot2, initialised to
        # 0.2. xavier_uniform_ is still called (its values are discarded) so
        # the random-number stream -- and thus the initialisation of every
        # layer created after this one -- stays the same as before.
        init = th.empty(1, 1, self.n_agents)
        nn.init.xavier_uniform_(init)
        self.weight = nn.Parameter(th.full_like(init, 0.2))

        # State dependent bias for each hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
        self.hyper_b_2 = nn.Linear(self.state_dim, self.embed_dim)
        self.hyper_b_3 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                               nn.ReLU(),
                               nn.Linear(self.embed_dim, self.embed_dim),
                               nn.ReLU(),
                               nn.Linear(self.embed_dim, 1))

        # Per-agent state baseline that replaces an agent's own value in the
        # third (q_tot2) head.
        self.V2 = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                                nn.ReLU(),
                                nn.Linear(self.embed_dim, self.embed_dim),
                                nn.ReLU(),
                                nn.Linear(self.embed_dim, self.n_agents))

    def _hidden(self, inputs, w, b):
        """Return ``elu(inputs @ w + b)`` with ``w``/``b`` reshaped per sample.

        ``inputs`` is ``(B, rows, n_agents)``; ``w`` and ``b`` are the raw
        hypernetwork outputs, viewed as ``(B, n_agents, embed_dim)`` and
        ``(B, 1, embed_dim)``.
        """
        w = w.view(-1, self.n_agents, self.embed_dim)
        b = b.view(-1, 1, self.embed_dim)
        return F.elu(th.bmm(inputs, w) + b)

    def forward(self, agent_qs, states):
        """Mix per-agent values into ``(q_tot, q_tot2)``.

        Args:
            agent_qs: per-agent values holding ``bs * ts * n_agents``
                elements, where ``bs, ts = agent_qs.shape[:2]`` (presumably
                shaped ``(bs, ts, n_agents)``).
            states: global states whose last dimension is ``state_dim`` and
                which contain ``bs * ts`` rows in total.

        Returns:
            ``q_tot`` of shape ``(bs, ts, 1)`` and ``q_tot2`` of shape
            ``(bs, ts, n_agents)``.
        """
        bs, ts = agent_qs.size(0), agent_qs.size(1)
        n = self.n_agents

        v2 = self.V2(states).reshape(bs, ts, n)
        states = states.reshape(-1, self.state_dim)

        # Row i of agent_qs_all is agent_qs with agent i's own entry replaced
        # by the baseline v2[..., i]. The mask is built on the input's device
        # (not args.device) so it always matches the data; broadcasting gives
        # the same values as repeating the rows explicitly.
        eye = th.eye(n, device=agent_qs.device, dtype=agent_qs.dtype)
        qs_rows = agent_qs.reshape(bs, ts, 1, n)
        v2_rows = v2.reshape(bs, ts, 1, n)
        agent_qs_all = (qs_rows * (1 - eye) + v2_rows * eye).reshape(bs * ts, n, n)
        qs_flat = agent_qs.reshape(bs * ts, 1, n)

        # First layer: the first head is unconstrained, heads 2 and 3 use
        # non-negative (abs) weights.
        hidden = self._hidden(qs_flat, self.hyper_w_1(states), self.hyper_b_1(states))
        hidden2 = self._hidden(qs_flat, th.abs(self.hyper_w_2(states)), self.hyper_b_2(states))
        hidden3 = self._hidden(agent_qs_all, th.abs(self.hyper_w_3(states)), self.hyper_b_3(states))

        # Second layer
        w_final = self.hyper_w_final(states).view(-1, self.embed_dim, 1)
        w_final2 = th.abs(self.hyper_w_final2(states)).view(-1, self.embed_dim, 1)
        w_final3 = th.abs(self.hyper_w_final3(states)).view(-1, self.embed_dim, 1)

        # State-dependent bias
        v = self.V(states).view(bs, ts, 1)

        y = th.bmm(hidden, w_final).view(bs, ts, 1)
        y2 = th.bmm(hidden2, w_final2).view(bs, ts, 1)
        y3 = th.bmm(hidden3, w_final3).view(bs, ts, n)

        q_tot = y + y2 + v
        q_tot2 = agent_qs.reshape(bs, ts, n) * th.abs(self.weight) + y3 + v
        return q_tot, q_tot2

