import torch
import torch.nn as nn
from torch.nn import functional as F

class MLP(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear -> Dropout.

    The first linear layer maps ``embed_dim`` features to ``hidden_dim``,
    the second maps them back to ``embed_dim``, so the output has the same
    feature size as the input.
    """

    def __init__(
        self,
        embed_dim: int,
        hidden_dim: int,
        dropout: float,
        bias: bool = True,
    ) -> None:
        """Build the sub-layers.

        Args:
            embed_dim: Feature size of the input and of the output.
            hidden_dim: Feature size between the two linear layers.
            dropout: Probability passed to ``nn.Dropout``.
            bias: Whether both linear layers carry a bias term.
        """
        super().__init__()
        # Attribute names double as state_dict prefixes; keep them stable.
        self.c_fc = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> GELU -> project -> dropout to ``x``.

        Args:
            x: Tensor whose last dimension equals ``embed_dim``.

        Returns:
            Tensor with the same last dimension as ``x``.
        """
        return self.dropout(
            self.c_proj(
                self.gelu(
                    self.c_fc(x)
                )
            )
        )