import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool, GINConv

class MotifEncoder(torch.nn.Module):
    """Encode per-node features with a single GCN layer followed by ReLU.

    The output is node-level: one ``out_channels``-dimensional embedding per
    node. No graph-level pooling is applied here. Callers that need a single
    vector per graph must pool the result themselves (e.g. with
    ``torch_geometric.nn.global_mean_pool``).

    Args:
        in_channels: Number of input features per node. Use ``1`` when node
            features are passed as a 1-D tensor (one scalar per node).
        out_channels: Size of each output node embedding.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        """Apply the GCN layer and a ReLU to the node features.

        Args:
            x: Node features, either 1-D ``(num_nodes,)`` (one scalar per
                node) or 2-D ``(num_nodes, in_channels)``.
            edge_index: Graph connectivity as accepted by ``GCNConv``
                (typically a ``(2, num_edges)`` index tensor).

        Returns:
            Node embeddings of shape ``(num_nodes, out_channels)`` after ReLU.
            No pooling is performed.
        """
        # 1-D scalar features need an explicit feature axis so GCNConv sees
        # (num_nodes, 1). Features that are already 2-D pass through
        # unchanged; unsqueezing them would produce a 3-D tensor that GCNConv
        # cannot propagate over.
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        x = self.conv(x, edge_index)
        return F.relu(x)