import torch
from torch_geometric.graphgym.register import register_node_encoder


@register_node_encoder("LinearNode")
class LinearFeatureEncoder(torch.nn.Module):
    """Project node features into an embedding space with one affine layer.

    Registered with GraphGym's node-encoder registry under ``"LinearNode"``.

    Args:
        dim_in: Size of the last dimension of the incoming features.
        emb_dim: Size of the last dimension of the produced embedding.
        hidden_dim: Unused by this encoder. Presumably accepted so it can be
            constructed with the same arguments as other registered node
            encoders (TODO confirm against callers).
        activate_fn: Unused by this encoder; same caveat as ``hidden_dim``.
    """

    def __init__(
        self,
        dim_in: int,
        emb_dim: int,
        hidden_dim=None,
        activate_fn=None,
    ):
        super().__init__()
        # These arguments are accepted but deliberately ignored.
        del hidden_dim, activate_fn
        # The attribute name ``encoder`` determines the state_dict key prefix,
        # so it must not be renamed without breaking saved checkpoints.
        self.encoder = torch.nn.Linear(in_features=dim_in, out_features=emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine projection to ``x`` along its last dimension.

        ``torch.nn.Linear`` requires ``x.shape[-1] == dim_in``; the output has
        the same leading dimensions with a last dimension of ``emb_dim``.
        """
        projected = self.encoder(x)
        return projected
