import torch
from torch.nn import Linear
import torch.nn.functional as F


class MLP(torch.nn.Module):
    """Feed-forward network with one residual block and a parameter-free layer norm.

    The forward pass applies, in order:

    1. ``linear_in``: a linear projection from ``in_channels`` to ``d_model``,
       followed by dropout.
    2. ``ffn``: ``Linear(d_model, d_ff) -> ReLU -> Linear(d_ff, d_model)``,
       whose output goes through dropout and is added back to its input
       (residual connection).
    3. A layer normalization without learnable scale/shift, taken over every
       dimension except the first one.
    4. ``linear_out``: a linear projection from ``d_model`` to ``out_channels``.

    NOTE(review): for inputs with more than two dimensions the normalization
    statistics are shared across all non-leading dimensions, not just the
    feature dimension -- confirm this is intended by the callers.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output.
        d_model: Width of the hidden representation.
        d_ff: Width of the inner layer of the residual feed-forward block.
        bias: Whether every linear layer has an additive bias.
        dropout: Dropout probability; dropout is only active in training mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        d_model: int = 512,
        d_ff: int = 2048,
        bias: bool = True,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        self.linear_in = Linear(in_channels, d_model, bias=bias)

        # Kept as a Sequential with the activation at index 1 so the parameter
        # names ("ffn.0.*", "ffn.2.*") stay compatible with saved checkpoints.
        self.ffn = torch.nn.Sequential(
            Linear(d_model, d_ff, bias=bias),
            torch.nn.ReLU(),
            Linear(d_ff, d_model, bias=bias),
        )

        self.linear_out = Linear(d_model, out_channels, bias=bias)

        # Stored as a plain probability; applied functionally in forward().
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``in_channels``. An
                unbatched 1-D tensor of shape ``(in_channels,)`` is accepted.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_channels``.
        """
        x = self.linear_in(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Residual connection around the feed-forward block.
        x = x + F.dropout(self.ffn(x), p=self.dropout, training=self.training)

        # Normalize over everything but the leading dimension. A 1-D
        # (unbatched) activation has no leading batch dimension, and an empty
        # normalized_shape is rejected by layer_norm, so normalize over its
        # single feature dimension instead.
        normalized_shape = x.shape[1:] if x.dim() > 1 else x.shape
        x = F.layer_norm(x, normalized_shape)

        return self.linear_out(x)
