import torch.nn.functional as F
import torch
import torch.nn as nn
from models.layers import GraphConvNew
class GNNModel(nn.Module):
    """Two graph-convolution layers followed by masked max pooling over nodes.

    Pipeline: ``gc1`` (ReLU) -> dropout -> ``gc2`` -> mask -> max over dim 1.

    NOTE(review): shapes are inferred from usage, not enforced here. ``x`` is
    presumably (batch, nodes, features) and ``mask`` (batch, nodes), with
    padded nodes marked by 0 — TODO confirm against the data loader.
    """

    def __init__(self, in_features, out_features, hidden_dim, dropout=0.5, n_hidden_edge=32, adj_sq=False, scale_identity=False, filters=3, n_hidden=256):
        """Build the two graph-convolution layers and the dropout module.

        Args:
            in_features: input feature size passed to the first ``GraphConvNew``.
            out_features: output feature size of the second ``GraphConvNew``.
            hidden_dim: feature size between the two convolutions.
            dropout: drop probability for the ``nn.Dropout`` applied between layers.
            n_hidden_edge, adj_sq, scale_identity, filters, n_hidden: accepted
                but unused by this model; kept so existing call sites that pass
                them keep working.
        """
        super().__init__()
        self.gc1 = GraphConvNew(in_features, hidden_dim, activation=F.relu)
        # No activation is passed for the output layer, so its features may be
        # negative — the pooling in ``forward`` must not assume non-negativity.
        self.gc2 = GraphConvNew(hidden_dim, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Run both graph convolutions and max-pool node features per graph.

        Args:
            data: indexable whose first three items are ``(x, A, mask)``.
                Presumably node features, adjacency and node mask — confirm
                against ``GraphConvNew``, which receives them as a tuple.

        Returns:
            The per-graph maximum over dim 1, taken only over nodes whose mask
            entry is non-zero. Graphs with no valid node yield zeros.
        """
        x, A, mask = data[0], data[1], data[2]

        # GraphConvNew takes an (x, A, mask) tuple; only element 0 is used here.
        x = self.gc1((x, A, mask))[0]
        x = self.dropout(x)
        x = self.gc2((x, A, mask))[0]

        mask_e = mask.unsqueeze(-1)
        # Keep the original mask multiplication so valid-node values (including
        # any non-binary mask weighting) are exactly as before.
        x = x * mask_e

        # Zeroed padding rows would beat all-negative valid features in the max.
        # Exclude them by pushing masked-out positions to -inf instead.
        valid = mask_e.bool()
        x = x.masked_fill(~valid, float("-inf"))
        pooled = torch.max(x, dim=1)[0]

        # A graph with no valid node would pool to -inf; return 0 for it instead,
        # which is what the zero-masking version produced in that case.
        has_node = valid.any(dim=1)
        return pooled.masked_fill(~has_node, 0.0)
