"""GNN taken from GANF https://github.com/EnyanDai/GANF"""
import torch.nn as nn
import torch
import torch.nn.functional as F


class GNN(nn.Module):
    """Graph aggregation followed by a linear -> ReLU -> linear projection.

    The layer first mixes node features with the matrix ``a`` and then maps
    the result ``input_size -> hidden_size -> input_size``, so the output has
    the same trailing feature size as the input.

    Args:
        input_size: Size of the last (feature) dimension of ``h``.
        hidden_size: Width of the intermediate projection.
    """

    def __init__(self, input_size, hidden_size):
        super(GNN, self).__init__()

        # Kept as one-element nn.Sequential containers so the parameter names
        # (``lin_n.0.weight``, ``lin_r.0.bias``, ...) stay unchanged.
        self.lin_n = nn.Sequential(nn.Linear(input_size, hidden_size))
        self.lin_r = nn.Sequential(nn.Linear(hidden_size, input_size))
        self.act = nn.ReLU()

    def forward(self, h, a):
        """Aggregate over the node axis with ``a`` and project the features.

        Args:
            h: Tensor of shape ``(..., k, d)`` where ``k`` indexes nodes and
                ``d == input_size``. Any number of leading dimensions is
                accepted (the original 4-D and 5-D layouts included).
            a: Matrix of shape ``(k, j)``.

        Returns:
            Tensor of shape ``(..., j, d)`` where, before the projection,
            ``out[..., j, :] = sum_k a[k, j] * h[..., k, :]``.
        """
        # The ellipsis covers every leading dim, so no rank dispatch is needed.
        aggregated = torch.einsum('...kd,kj->...jd', h, a)
        return self.lin_r(self.act(self.lin_n(aggregated)))

class graph_constructor(nn.Module):
    """Learn a sparse, non-negative, one-directional adjacency matrix.

    Two node representations are built (from learned embeddings, or from
    ``static_feat`` when given), combined into an antisymmetric score matrix,
    squashed with ``relu(tanh(.))`` and finally sparsified by keeping at most
    ``k`` entries per row.

    Args:
        nnodes: Number of nodes (size of the embedding tables).
        k: Maximum number of entries kept per row of the adjacency matrix.
        device: Device the embedding-path sub-modules are created on.
        dim: Size of the node representations.
        alpha: Scale applied before each ``tanh``.
        static_feat: Optional ``(nnodes, xd)`` feature matrix used instead of
            learned embeddings.
    """

    def __init__(self, nnodes, k, device, dim, alpha=3, static_feat=None):
        super(graph_constructor, self).__init__()
        self.device = device
        self.nnodes = nnodes
        if static_feat is not None:
            feat_dim = static_feat.shape[1]
            # NOTE: unlike the embedding path below, these layers are not
            # moved to ``device`` here; the caller has to move the module.
            self.lin1 = nn.Linear(feat_dim, dim)
            self.lin2 = nn.Linear(feat_dim, dim)
        else:
            self.emb1 = nn.Embedding(nnodes, dim).to(device)
            self.emb2 = nn.Embedding(nnodes, dim).to(device)
            self.lin1 = nn.Linear(dim, dim).to(device)
            self.lin2 = nn.Linear(dim, dim).to(device)

        self.k = k
        self.dim = dim
        self.alpha = alpha
        self.static_feat = static_feat

    def forward(self, idx):
        """Build the adjacency matrix for the nodes selected by ``idx``.

        Args:
            idx: 1-D index tensor of the ``n`` nodes to include.

        Returns:
            ``(n, n)`` tensor with non-negative entries and at most
            ``min(k, n)`` non-zeros per row. The score matrix is
            antisymmetric, so ``adj[i, j] > 0`` implies ``adj[j, i] == 0``.
        """
        if self.static_feat is None:
            nodevec1 = self.emb1(idx)
            nodevec2 = self.emb2(idx)
        else:
            nodevec1 = self.static_feat[idx, :]
            nodevec2 = nodevec1

        nodevec1 = torch.tanh(self.alpha * self.lin1(nodevec1))
        nodevec2 = torch.tanh(self.alpha * self.lin2(nodevec2))

        # M - M^T is antisymmetric; relu then keeps only one direction per pair.
        scores = (torch.mm(nodevec1, nodevec2.transpose(1, 0))
                  - torch.mm(nodevec2, nodevec1.transpose(1, 0)))
        adj = F.relu(torch.tanh(self.alpha * scores))

        # topk raises if k exceeds the row length, so clamp to the node count.
        keep = min(self.k, adj.size(1))
        # Small random noise breaks ties between equal entries (e.g. the zeros).
        _, top_idx = (adj + torch.rand_like(adj) * 0.01).topk(keep, 1)

        # Build the mask from ``adj`` so it always shares its device and dtype,
        # even if the module was moved after construction.
        mask = torch.zeros_like(adj)
        mask.scatter_(1, top_idx, 1.0)
        return adj * mask

class GCN(nn.Module):
    """Fully connected stack: ``Linear`` layers separated by an activation.

    The network is ``in_dim -> hidden_dims[0] -> ... -> out_dim``. The bias of
    the last linear layer is initialised to zero.

    Args:
        in_dim: Size of the input features.
        hidden_dims: Sequence (list or tuple) of hidden layer widths. It may
            be empty and is never modified.
        out_dim: Size of the output features.
        activation: Name of an ``torch.nn`` activation class placed between
            linear layers.
        final_activation: Optional name of an ``torch.nn`` activation class
            appended after the last linear layer.
        wrapper_func: Optional callable applied to every linear layer except
            the first one; it must return the module to insert.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_dim, hidden_dims, out_dim, activation='Tanh',
                 final_activation=None, wrapper_func=None, **kwargs):
        super().__init__()

        wrap = wrapper_func if wrapper_func else (lambda module: module)

        # Build a fresh list: accepts tuples and leaves the caller's
        # sequence untouched.
        widths = list(hidden_dims) + [out_dim]

        # The first layer is intentionally left unwrapped, as before.
        layers = [nn.Linear(in_dim, widths[0])]
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(getattr(nn, activation)())
            layers.append(wrap(nn.Linear(fan_in, fan_out)))

        # At this point the last entry is always the final linear layer.
        layers[-1].bias.data.fill_(0.0)

        if final_activation is not None:
            layers.append(getattr(nn, final_activation)())

        self.net = nn.Sequential(*layers)

    def forward(self, x, **kwargs):
        """Apply the stack to ``x``; extra keyword arguments are ignored."""
        return self.net(x)