import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Two-layer perceptron applied to inputs normalized along the last axis.

    The network is ``Linear -> ReLU -> Dropout -> Linear`` and is stored as
    ``self.encoder``. Layer sizes and the dropout probability are read from
    ``cfg``:

    * ``cfg.model.num_eigenvecs`` -- input feature width
    * ``cfg.model.pos_encoder.hidden_dim`` -- hidden layer width
    * ``cfg.model.pos_encoder.emb_dim`` -- output width (kept as ``dim_pe``)
    * ``cfg.model.ffn_dropout`` -- dropout probability

    NOTE(review): the config key names suggest the inputs are
    eigenvector-based positional features -- confirm against the caller.
    """

    def __init__(self, cfg):
        """Build the encoder stack from the sizes found in ``cfg``."""
        super().__init__()
        model_cfg = cfg.model
        pe_cfg = model_cfg.pos_encoder

        self.dim_pe = pe_cfg.emb_dim
        in_features = model_cfg.num_eigenvecs
        width = pe_cfg.hidden_dim

        # Layers are created in forward order so that parameter
        # initialization consumes the RNG in the same sequence as the
        # positional indices used by nn.Sequential (0..3).
        stack = [
            nn.Linear(in_features, width),
            nn.ReLU(),
            nn.Dropout(p=model_cfg.ffn_dropout),
            nn.Linear(width, self.dim_pe),
        ]
        self.encoder = nn.Sequential(*stack)

    def forward(self, x, batch):
        """Normalize ``x`` along its last dimension and encode it.

        Args:
            x: Input tensor whose last dimension is the feature axis fed to
                the first linear layer.
            batch: Not used by this method.

        Returns:
            The output of ``self.encoder`` on the normalized input.
        """
        # F.normalize defaults to the L2 norm, so every vector along the
        # last axis is rescaled to unit length before the first layer.
        unit_x = F.normalize(x, dim=-1)
        return self.encoder(unit_x)
