import torch


class MLP(torch.nn.Module):
    """Two-layer feed-forward encoder: Linear -> ReLU -> Dropout -> Linear.

    Layer sizes and the dropout rate are read from ``cfg``:

    * ``cfg.model.dim_in`` -- width of the input features.
    * ``cfg.model.feat_encoder.hidden_dim`` -- width of the hidden layer.
    * ``cfg.model.feat_encoder.emb_dim`` -- width of the output; also stored
      on the instance as ``dim_feat``.
    * ``cfg.model.ffn_dropout`` -- dropout probability applied after the ReLU.
    """

    def __init__(self, cfg):
        super().__init__()
        enc_cfg = cfg.model.feat_encoder
        in_width = cfg.model.dim_in
        hidden_width = enc_cfg.hidden_dim
        out_width = enc_cfg.emb_dim

        # Output embedding width, kept as a public attribute (presumably read
        # by code outside this file -- confirm with callers).
        self.dim_feat = out_width

        # Layers are created in the same order as they are applied, so the
        # Sequential indices (0: Linear, 1: ReLU, 2: Dropout, 3: Linear) and
        # hence the state_dict keys are fixed.
        layers = [
            torch.nn.Linear(in_width, hidden_width),
            torch.nn.ReLU(),
            torch.nn.Dropout(p=cfg.model.ffn_dropout),
            torch.nn.Linear(hidden_width, out_width),
        ]
        self.encoder = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the encoder stack and return the embedding."""
        embedded = self.encoder(x)
        return embedded
