
from torch import nn
import torch
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

class GCN_emb(nn.Module):
    """Two-layer GCN whose second layer also propagates over learned motif nodes.

    Node features are passed through ``gc1``. One learned embedding per motif
    is then appended as extra rows ("motif nodes"), and ``gc2`` runs over the
    augmented graph ``adj_motif``. Only the rows of the original nodes are
    returned, each scaled element-wise by a sigmoid gate that is floored at
    ``threshold``.

    Args:
        nfeat: Input feature size per node (``in_channels`` of ``gc1``).
        nhid: Hidden/output feature size; also the motif embedding size.
        num_motif: Number of rows in the motif embedding table. The
            ``motif_num`` passed to ``forward`` must not exceed it.
        dropout: Dropout probability applied before ``gc2``.
        threshold: Lower bound of the output gate. Defaults to 0.5, the
            value previously hard-coded.
    """

    def __init__(self, nfeat, nhid, num_motif, dropout, threshold=0.5):
        super().__init__()

        self.gc1 = GCNConv(nfeat, nhid)
        self.gc2 = GCNConv(nhid, nhid)
        self.dropout = dropout
        self.motif = nn.Embedding(num_motif, nhid)
        self.project = nn.Linear(nhid, nhid)
        # Registered as a buffer so that `.to()` / `.cuda()` move it together
        # with the module. A plain tensor attribute stays on the CPU, and
        # `torch.max` against a CUDA gate then fails with a device mismatch
        # (a shape-(1,) tensor is not treated as a CPU scalar).
        # `persistent=False` keeps the state_dict keys unchanged, so existing
        # checkpoints still load with strict=True.
        self.register_buffer(
            "threshold", torch.tensor([float(threshold)]), persistent=False
        )

    def forward(self, x, adj_norm, motif_num, adj_motif, original_node_num, motif_emb):
        """Encode nodes with motif-augmented message passing.

        Args:
            x: Node feature matrix fed to ``gc1`` (presumably
                ``(num_nodes, nfeat)``; TODO confirm against callers).
            adj_norm: Graph connectivity for ``gc1``, in whatever form
                ``GCNConv`` accepts (edge_index or sparse adjacency).
            motif_num: Number of motif nodes to append. Embedding ids
                ``0 .. motif_num - 1`` are used.
            adj_motif: Connectivity for ``gc2`` over the augmented node set.
                Original nodes come first, then the motif nodes.
            original_node_num: Number of leading rows (original nodes) to
                keep in the output.
            motif_emb: Currently unused. Kept so existing call sites keep
                working.

        Returns:
            Gated node embeddings for the first ``original_node_num`` rows.
        """
        x = self.gc1(x, adj_norm)
        # Build the motif ids on the same device as the node features rather
        # than hard-coding `.cuda()`, so the model also runs on CPU.
        motif_ids = torch.arange(motif_num, device=x.device)
        motif_rows = self.motif(motif_ids)
        # Stack motif embeddings below the node rows. Rows [0, N) are the
        # original nodes; the following `motif_num` rows are the motifs.
        x_new = F.relu(torch.cat((x, motif_rows), dim=0))
        x_new = F.dropout(x_new, self.dropout, training=self.training)
        x2 = self.gc2(x_new, adj_motif)
        # Keep only the original nodes; motif rows were auxiliary.
        x2 = x2[:original_node_num, :]
        # Element-wise gate in [threshold, 1): sigmoid output floored at the
        # threshold, so no feature is suppressed below that fraction.
        gate = torch.sigmoid(self.project(x2))
        gate = torch.max(gate, self.threshold)
        return gate * x2