from torch import nn

class ProjBlock(nn.Module):
    """Linear projection followed by a GELU-MLP residual branch.

    ``linear1`` maps the last dimension of the input from ``emb_dim`` to
    ``seq_len``. The projected tensor ``h`` is then refined as
    ``h + linear2(GELU(h))``, so input and output of the residual branch
    share the width ``seq_len``.

    Args:
        emb_dim: size of the last dimension of the input.
        seq_len: size of the last dimension of the output. Despite the name,
            it is only used here as a feature width of the linear layers.
    """

    def __init__(self, emb_dim, seq_len):
        super().__init__()
        # Registration order (linear1, activation, linear2) is intentional:
        # it fixes state_dict key order and the order of seeded initialisation.
        self.linear1 = nn.Linear(in_features=emb_dim, out_features=seq_len)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(in_features=seq_len, out_features=seq_len)

    def forward(self, x):
        """Project ``x`` to width ``seq_len`` and add the GELU-MLP residual."""
        projected = self.linear1(x)
        # The residual skips from the *pre-activation* projection.
        return projected + self.linear2(self.activation(projected))

class ProjMLP(nn.Module):
    """Two stacked ``ProjBlock`` modules joined by a GELU residual connection.

    ``l1`` maps the last dimension from ``emb_dim`` to ``seq_len``. Its output
    ``h`` is then refined as ``h + l2(GELU(h))``, where ``l2`` keeps the
    width at ``seq_len``.

    Args:
        emb_dim: size of the last dimension of the input.
        seq_len: size of the last dimension of the output (used as a feature
            width, not a sequence length, by the layers visible here).
    """

    def __init__(self, emb_dim, seq_len):
        super().__init__()
        # Submodule registration order (l1, l2, gelu) is preserved so that
        # state_dict keys and seeded initialisation remain unchanged.
        self.l1 = ProjBlock(emb_dim=emb_dim, seq_len=seq_len)
        self.l2 = ProjBlock(emb_dim=seq_len, seq_len=seq_len)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply ``l1``, then add the GELU -> ``l2`` residual branch."""
        hidden = self.l1(x)
        refined = self.l2(self.gelu(hidden))
        return hidden + refined