import torch
import torch.nn.functional as F

class NCDecoder(torch.nn.Module):
    """Decoder head that applies one affine projection to its input.

    ("NC" presumably stands for node classification — confirm with callers.)

    Args:
        args: Accepted for interface compatibility; never read.
        input_size: Size of the last dimension of the incoming tensor.
        output_size: Size of the last dimension of the returned tensor.
    """

    def __init__(self, args, input_size, output_size):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=input_size, out_features=output_size)

    def forward(self, h):
        """Return the un-normalised projection ``linear(h)``."""
        projection = self.linear
        return projection(h)

class LPDecoder(torch.nn.Module):
    """Score node pairs (edges) from node embeddings.

    For every edge, the source and destination embeddings are combined by
    element-wise product (see ``cross_layer``), passed through a stack of
    linear layers and finally squashed into (0, 1) with a sigmoid.
    ("LP" presumably stands for link prediction — confirm with callers.)

    Args:
        in_channels: Not read; kept so existing call sites keep working.
        hidden_channels: Input width of every linear layer and width of any
            hidden layers. Must equal the embedding width ``h.size(-1)``,
            because the element-wise product preserves that width.
        out_channels: Number of scores produced per edge.
        dropout: Dropout probability applied after each hidden layer; it
            only has an effect when ``num_layers > 1``.
        num_layers: Total number of linear layers. The default of 1
            reproduces the original single-layer decoder exactly.

    Raises:
        ValueError: If ``num_layers`` is less than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, dropout,
                 num_layers=1):
        super(LPDecoder, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.lins = torch.nn.ModuleList()
        # Hidden layers keep the width; only the final layer maps to out_channels.
        for _ in range(num_layers - 1):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the weights of every linear layer."""
        for lin in self.lins:
            lin.reset_parameters()

    def cross_layer(self, x_1, x_2):
        """Concatenate element-wise products of every (x_1, x_2) pair.

        Pairs are taken with ``x_1`` as the outer loop and ``x_2`` as the
        inner loop. The products are concatenated along dim 1, so the
        result is ``len(x_1) * len(x_2)`` times as wide as one input.
        Both sequences must be non-empty (``torch.cat`` rejects an empty
        list).
        """
        products = [torch.mul(xi, xj) for xi in x_1 for xj in x_2]
        return torch.cat(products, dim=1)

    def forward(self, h, edge):
        """Return sigmoid scores of shape ``(num_edges, out_channels)``.

        Args:
            h: Node embeddings indexed by node id along dim 0.
            edge: Index pair where ``edge[0]`` selects source rows and
                ``edge[1]`` selects destination rows of ``h`` (presumably
                a ``(2, E)`` long tensor — confirm with callers).
        """
        src_x = [h[edge[0]]]
        dst_x = [h[edge[1]]]

        # With one source and one destination tensor this is simply the
        # element-wise product, of width h.size(-1).
        x = self.cross_layer(src_x, dst_x)
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)