import torch.nn as nn


class MLP(nn.Module):
    """Feed-forward sub-block: expand, apply GELU, project back, then dropout.

    The hidden width is fixed at four times ``config.d_model``.

    Attributes read from ``config``:
        d_model: input/output feature size (the last dimension of ``x``).
        bias: whether both linear layers carry a bias term.
        dropout: drop probability applied to the block's output.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.d_model
        hidden = width * 4  # expansion factor is hard-coded
        use_bias = config.bias
        # Registration order (c_fc, gelu, c_proj, dropout) determines the
        # state_dict key order and the RNG draws used for weight init.
        self.c_fc = nn.Linear(width, hidden, bias=use_bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Map ``x`` of shape ``(..., d_model)`` to a tensor of the same shape."""
        activated = self.gelu(self.c_fc(x))
        projected = self.c_proj(activated)
        return self.dropout(projected)
