import torch
from torch import nn
import torch.nn.functional as F

class GraphNet(nn.Module):
    """Three-layer network that mixes node features through a dense adjacency matrix.

    Each layer applies a linear projection to the per-node features, aggregates
    the result across nodes with a batched matrix product against ``adj``, and
    then applies GELU (GELU is applied after the final layer as well).

    Args:
        input_channels: Feature width expected by the first linear layer. Node
            features narrower than this are zero-padded up to this width.
        out_channels: Feature width produced per node.
        hidden_channels: Base hidden width; the first hidden layer uses twice this.
    """

    def __init__(self, input_channels: int, out_channels: int, hidden_channels: int):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels

        # Attribute names stay layer1..layer3 so previously saved state_dicts load.
        self.layer1 = nn.Linear(self.input_channels, self.hidden_channels * 2)
        self.layer2 = nn.Linear(self.hidden_channels * 2, self.hidden_channels)
        self.layer3 = nn.Linear(self.hidden_channels, self.out_channels)

    def forward(self, x: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Run the three graph layers.

        Args:
            x: Tensor of shape (B, L, V). It is transposed to (B, V, L), so each of
                the V nodes carries an L-length feature vector.
            adj: Tensor of shape (B, V, V), used as the left operand of
                ``torch.bmm`` at every layer. It is moved to ``x``'s device.

        Returns:
            Tensor of shape (B, V, out_channels).
        """
        x = x.transpose(1, 2)  # (B, V, L)
        # Follow the input's device instead of hard-coding CUDA, so the module
        # also runs on CPU or on a non-default GPU.
        adj = adj.to(x.device)

        # Zero-pad the feature axis up to input_channels.
        # NOTE(review): if L > input_channels the pad amount is negative, and
        # F.pad then appears to crop the extra features silently — confirm intended.
        x = F.pad(x, (0, self.input_channels - x.shape[-1]), value=0)

        for layer in (self.layer1, self.layer2, self.layer3):
            # project features -> aggregate over neighbours via adj -> nonlinearity
            x = F.gelu(torch.bmm(adj, layer(x)))
        return x