import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pdb

class ADERMixer(nn.Module):
    """Monotonic (QMIX-style) mixing network driven by state hypernetworks.

    Agent Q-values are combined by a two-layer network whose weights are
    generated from the flattened state. Both generated weight matrices pass
    through ``abs`` and the hidden activation is ELU, so the mixed value is
    non-decreasing in every agent Q-value. A state-value head ``V(s)`` takes
    the place of the final-layer bias.

    ``args`` must provide ``n_agents``, ``state_shape`` and
    ``mixing_embed_dim``. It may provide ``hypernet_layers`` (1 or 2,
    default 1), and must provide ``hypernet_embed`` when that is 2.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.n_agents = args.n_agents
        self.state_dim = int(np.prod(args.state_shape))
        self.embed_dim = args.mixing_embed_dim

        depth = getattr(args, "hypernet_layers", 1)
        if depth not in (1, 2):
            if depth > 2:
                raise Exception("Sorry >2 hypernet layers is not implemented!")
            raise Exception("Error setting number of hypernet layers.")

        # Sub-modules are created in a fixed order (w_1, w_final, b_1, V) so
        # seeded initialisation and state_dict keys stay stable.
        self.hyper_w_1 = self._hypernet(depth, self.embed_dim * self.n_agents)
        self.hyper_w_final = self._hypernet(depth, self.embed_dim)

        # State-dependent bias of the hidden layer.
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) acts as the bias of the output layer.
        self.V = nn.Sequential(
            nn.Linear(self.state_dim, self.embed_dim),
            nn.ReLU(),
            nn.Linear(self.embed_dim, 1),
        )

    def _hypernet(self, depth, out_features):
        """Build a 1- or 2-layer hypernetwork from the state to ``out_features``."""
        if depth == 1:
            return nn.Linear(self.state_dim, out_features)
        width = self.args.hypernet_embed
        return nn.Sequential(
            nn.Linear(self.state_dim, width),
            nn.ReLU(),
            nn.Linear(width, out_features),
        )

    def forward(self, agent_qs, states, grad=False):
        """Mix agent Q-values into a joint value.

        Args:
            agent_qs: agent Q-values. They are reshaped to ``(-1, 1, n_agents)``,
                presumably from ``(batch, T, n_agents)`` (TODO confirm).
            states: states, reshaped to ``(-1, state_dim)``. They must hold the
                same number of rows as ``agent_qs``.
            grad: when True, also return the gradient of ``q_tot.sum()``
                w.r.t. the reshaped agent Q-values.

        Returns:
            ``q_tot`` of shape ``(agent_qs.size(0), -1, 1)``. When ``grad`` is
            set, returns ``(q_tot, dq)``, where ``dq`` has shape
            ``(-1, 1, n_agents)`` and is non-negative by construction. This
            path requires ``agent_qs`` to be part of an autograd graph.
        """
        batch_size = agent_qs.size(0)
        flat_states = states.reshape(-1, self.state_dim)
        qs = agent_qs.view(-1, 1, self.n_agents)

        # Hidden layer: non-negative weights keep the mixing monotonic.
        w1 = self.hyper_w_1(flat_states).abs().view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b_1(flat_states).view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(qs, w1) + b1)

        # Output layer, with V(s) as the bias.
        w_final = self.hyper_w_final(flat_states).abs().view(-1, self.embed_dim, 1)
        v = self.V(flat_states).view(-1, 1, 1)
        q_tot = (th.bmm(hidden, w_final) + v).view(batch_size, -1, 1)

        if not grad:
            return q_tot
        dq = th.autograd.grad(q_tot.sum(), qs, retain_graph=True)[0]
        return q_tot, dq

class ADERMixer_Ent(nn.Module):
    """Mixing network evaluated once per agent with an agent-id-augmented state.

    For each agent ``i`` the (shared) hypernetwork mixer receives:

    * the flattened state concatenated with a one-hot vector for agent ``i``
      (hence ``state_dim = prod(state_shape) + n_agents``), and
    * the agent Q-values cyclically rotated so that agent ``i`` is in slot 0.

    The ``n_agents`` mixer outputs are concatenated along the last dimension.

    ``args`` must provide ``n_agents``, ``state_shape`` and
    ``mixing_embed_dim``. It may provide ``hypernet_layers`` (1 or 2,
    default 1; ``hypernet_embed`` is required when it is 2) and ``use_cuda``
    (default False).
    """

    def __init__(self, args):
        super(ADERMixer_Ent, self).__init__()

        self.args = args
        self.n_agents = args.n_agents
        # Mixer input = raw state + one-hot agent id.
        self.state_dim = int(np.prod(args.state_shape)) + args.n_agents

        self.embed_dim = args.mixing_embed_dim

        # One (1, n_agents) one-hot row per agent. These are plain tensors in a
        # list (not registered buffers), so module.to()/.cuda() does not move
        # them; forward() casts them to the input's device/dtype instead.
        use_cuda = getattr(args, "use_cuda", False)
        self.states_ind = []
        for i in range(self.n_agents):
            one_hot = th.zeros(1, self.n_agents)
            one_hot[0, i] = 1
            self.states_ind.append(one_hot.cuda() if use_cuda else one_hot)

        layers = getattr(args, "hypernet_layers", 1)
        if layers == 1:
            self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents)
            self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)
        elif layers == 2:
            hypernet_embed = self.args.hypernet_embed
            self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed),
                                           nn.ReLU(),
                                           nn.Linear(hypernet_embed, self.embed_dim * self.n_agents))
            self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed),
                                               nn.ReLU(),
                                               nn.Linear(hypernet_embed, self.embed_dim))
        elif layers > 2:
            raise Exception("Sorry >2 hypernet layers is not implemented!")
        else:
            raise Exception("Error setting number of hypernet layers.")

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                               nn.ReLU(),
                               nn.Linear(self.embed_dim, 1))

    def _mix(self, agent_qs, states, bs):
        """Apply the monotonic mixer to flat inputs.

        Args:
            agent_qs: ``(N, n_agents)`` Q-values.
            states: ``(N, self.state_dim)`` augmented states.
            bs: leading dimension of the returned tensor.

        Returns:
            Tensor of shape ``(bs, -1, 1)``.
        """
        # First layer (abs keeps the mixing monotonic in agent_qs)
        w1 = th.abs(self.hyper_w_1(states)).view(-1, self.n_agents, self.embed_dim)
        b1 = self.hyper_b_1(states).view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs.view(-1, 1, self.n_agents), w1) + b1)
        # Second layer, with V(s) as the state-dependent bias
        w_final = th.abs(self.hyper_w_final(states)).view(-1, self.embed_dim, 1)
        v = self.V(states).view(-1, 1, 1)
        return (th.bmm(hidden, w_final) + v).view(bs, -1, 1)

    def forward(self, agent_qs, states):
        """Evaluate the mixer once per agent and stack the results.

        Args:
            agent_qs: agent Q-values, reshaped to ``(-1, n_agents)``; presumably
                ``(batch, T, n_agents)`` — TODO confirm against callers.
            states: raw states, reshaped to ``(-1, state_dim - n_agents)``.
                They must hold the same number of rows as ``agent_qs``.

        Returns:
            Tensor of shape ``(agent_qs.size(0), -1, n_agents)``. Column ``i``
            is the mixer output for agent ``i``'s one-hot id and rotated
            Q-values.
        """
        bs = agent_qs.size(0)
        n = self.n_agents
        # Loop-invariant: the raw state and flat Q-values are shared by all agents.
        flat_qs = agent_qs.reshape(-1, n)
        flat_states = states.reshape(-1, self.state_dim - n)
        n_rows = flat_states.shape[0]

        v_tot = []
        for i in range(n):
            # Cast to the input's device/dtype so this works regardless of
            # where the module was moved after construction.
            one_hot = self.states_ind[i].to(flat_states).repeat(n_rows, 1)
            states_i = th.cat([flat_states, one_hot], dim=-1)
            # Rotate columns so agent i is first: [i, i+1, ..., n-1, 0, ..., i-1].
            qs_i = th.roll(flat_qs, shifts=-i, dims=1)
            v_tot.append(self._mix(qs_i, states_i, bs))
        return th.cat(v_tot, dim=-1)



