import torch
from torch_geometric.graphgym.register import register_node_encoder
from torch_geometric.graphgym.config import cfg


@register_node_encoder("MLP2Node")
class MLP2FeatureEncoder(torch.nn.Module):
    """Two-layer MLP encoder: Linear -> activation -> Dropout -> Linear.

    Registered with GraphGym's node-encoder registry under ``"MLP2Node"``.

    Args:
        dim_in: Input size of the first ``Linear`` layer (last dimension of
            the input tensor).
        emb_dim: Output size of the second ``Linear`` layer.
        hidden_dim: Width of the hidden layer between the two ``Linear`` layers.
        activate_fn: Activation placed after the first ``Linear`` layer. One of:
            * a Python expression string evaluated in this constructor's scope
              (e.g. ``"torch.nn.ReLU()"``, ``"torch.nn.PReLU(hidden_dim)"``);
            * a ``torch.nn.Module`` instance;
            * a ``torch.nn.Module`` subclass, instantiated with no arguments.
        dropout: Dropout probability applied after the activation. When
            ``None`` (the default), ``cfg.model.ffn_dropout`` is used.

    Raises:
        TypeError: If ``activate_fn`` does not resolve to a ``torch.nn.Module``.
    """

    def __init__(self, dim_in, emb_dim, hidden_dim, activate_fn, dropout=None):
        super().__init__()

        if isinstance(activate_fn, str):
            # SECURITY: eval() runs arbitrary code — only feed it activation
            # strings from trusted configuration. It is kept inline in
            # __init__ (rather than moved to a helper) so that existing
            # expressions may still reference constructor arguments such as
            # `hidden_dim`.
            activation = eval(activate_fn)
        else:
            activation = activate_fn

        # Accept a bare class (e.g. torch.nn.ReLU) as well as an instance.
        if isinstance(activation, type) and issubclass(activation, torch.nn.Module):
            activation = activation()

        if not isinstance(activation, torch.nn.Module):
            raise TypeError(
                "activate_fn must resolve to a torch.nn.Module instance or "
                f"subclass, got {type(activation).__name__}"
            )

        # Only consult the global config when no explicit rate was given.
        p = cfg.model.ffn_dropout if dropout is None else dropout

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, hidden_dim),
            activation,
            torch.nn.Dropout(p=p),
            torch.nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``.

        The ``Linear`` layers act on the last dimension, so ``x`` is expected
        to have size ``dim_in`` there; the result has size ``emb_dim`` there.
        """
        return self.encoder(x)
