import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


class GAT(nn.Module):
    """Stack of DGL ``GATConv`` layers followed by a bilinear discriminator.

    The attention heads of the last layer are split into separate
    embeddings.  Heads 0, 1 and 2 are combined by :meth:`get_sub_emb` into
    neighbourhood-level embeddings, and the discriminator scores heads 1 and
    2 against the aggregated head-1 embedding.

    Args:
        g: Graph object handed to every ``GATConv`` call.
        num_layers: Number of ``GATConv`` layers applied in ``forward``.
        in_dim: Feature size of the network input.
        num_hidden: Output feature size of every head in every layer.
        heads: Sequence with the number of attention heads per layer.  The
            last layer needs at least three heads because ``forward`` reads
            heads 0, 1 and 2.
        activation: Activation passed to every ``GATConv``.
        feat_drop: Feature dropout rate passed to every ``GATConv``.
        attn_drop: Attention dropout rate passed to every ``GATConv``.
        negative_slope: LeakyReLU slope passed to every ``GATConv``.
    """

    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope):
        super().__init__()
        self.g = g
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(num_hidden)

        # Input projection layer.
        self.gat_layers.append(GATConv(
            in_dim, num_hidden, heads[0],
            feat_drop=feat_drop, attn_drop=attn_drop,
            negative_slope=negative_slope, residual=False,
            activation=self.activation))
        # Deeper layers consume the concatenated heads of the previous layer.
        for l in range(1, num_layers):
            self.gat_layers.append(GATConv(
                num_hidden * heads[l - 1], num_hidden, heads[l],
                feat_drop=feat_drop, attn_drop=attn_drop,
                negative_slope=negative_slope, residual=False,
                activation=self.activation))

    def forward(self, inputs, adj):
        """Run the GAT stack and score head embeddings with the discriminator.

        Args:
            inputs: Node feature tensor fed to the first ``GATConv``.
            adj: Square adjacency tensor forwarded to :meth:`get_sub_emb`.

        Returns:
            A tuple ``(heads, ret, s)``: the list of per-head embeddings of
            the last layer, the discriminator logits, and the second value
            returned by :meth:`get_sub_emb`.
        """
        h = inputs
        last_layer = self.num_layers - 1
        for l in range(self.num_layers):
            h = self.gat_layers[l](self.g, h)
            if l < last_layer:
                # GATConv emits (N, heads, num_hidden); the next layer was
                # declared with in_feats = num_hidden * heads, so merge the
                # head dimension into the feature dimension.
                h = h.flatten(1)

        # One embedding per slice along dim 1 (the heads of the last layer).
        heads = [h[:, i] for i in range(h.shape[1])]
        r_m, s = self.get_sub_emb(heads, adj, m=2)

        # Context: aggregated head 1; first sample: head 1; second: head 2.
        ret = self.disc(r_m.unsqueeze(0),
                        heads[1].unsqueeze(0),
                        heads[2].unsqueeze(0))

        return heads, ret, s

    def get_sub_emb(self, h, adj, m=2):
        """Build neighbourhood-level embeddings from three head embeddings.

        Args:
            h: Sequence whose first three items are 2-D tensors with one row
                per node (``h[0]``, ``h[1]``, ``h[2]``).
            adj: Square adjacency tensor; self-loops are added internally.
            m: Controls the distance exponent ``2 / (m - 1)``.  ``m == 1``
                raises ``ZeroDivisionError``.

        Returns:
            A tuple ``(r_m, s)``.  ``r_m`` is ``h[1]`` aggregated with the
            row-wise L1-normalised, self-looped adjacency.  ``s`` is ``h[2]``
            aggregated with weights proportional to the inverse powered
            distance between ``r_m`` and ``h[0]``, restricted to adjacent
            node pairs.
        """
        t = h[0]
        r = h[1]
        e = h[2]

        # Create the identity on adj's own device rather than hard-coding
        # GPU 0, so this also works on CPU and on other GPUs.
        adj = adj + torch.eye(adj.shape[0], device=adj.device)

        r_m = torch.mm(F.normalize(adj, p=1, dim=1), r)

        d_raw = torch.cdist(r_m, t, p=2)
        eps = 1e-12
        # Remember (near-)zero distances before clamping: their inverse
        # would dominate the normalisation, so they get zero weight instead.
        zero_mask = d_raw < eps
        d_raw = d_raw.clamp(min=eps)

        d = (1 / torch.pow(d_raw, 2 / (m - 1))).mul(adj)
        d = d.masked_fill(zero_mask, 0.0)

        # Rows that end up all-zero stay zero (F.normalize guards the
        # division with its own epsilon).
        u = F.normalize(d, p=1, dim=1)

        s = torch.mm(u, e)

        return r_m, s

    
class Discriminator(nn.Module):
    """Bilinear scorer comparing two sets of embeddings against a context.

    Args:
        num_hidden: Feature size of both inputs of the bilinear form.
    """

    def __init__(self, num_hidden):
        super().__init__()
        # Single-output bilinear form: score = x1^T W x2 + b.
        self.f_k = nn.Bilinear(num_hidden, num_hidden, 1)

        # Re-initialise every bilinear sub-module owned by this module.
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-initialise the weight and zero the bias of a bilinear layer.

        Modules that are not ``nn.Bilinear`` are left untouched.
        """
        if not isinstance(m, nn.Bilinear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is None:
            return
        m.bias.data.zero_()

    def forward(self, c, h_p, h_n):
        """Score ``h_p`` and ``h_n`` against the context ``c``.

        Args:
            c: Context embeddings, second input of the bilinear form.
            h_p: First set of embeddings to score.
            h_n: Second set of embeddings to score.

        Returns:
            Logits with the ``h_p`` scores followed by the ``h_n`` scores,
            concatenated along dimension 1.
        """
        first_scores = self.f_k(h_p, c).squeeze(2)
        second_scores = self.f_k(h_n, c).squeeze(2)
        return torch.cat((first_scores, second_scores), dim=1)


     