import torch

# Used only in MLP2L in TransformerBlock
class GLU(torch.nn.Module):
    """Gated linear unit computing ``act(W_g x + b_g) * (W_p x + b_p)``.

    Args:
        d0: input feature size.
        d1: output feature size of both the gate branch and the projection branch.
        bias: whether both linear maps carry a bias term.
        act: activation module applied to the gate branch. The default
            ``torch.nn.ReLU()`` is a single instance created when the class is
            defined and is therefore shared by every GLU built with the default.
    """

    def __init__(self, d0, d1, bias=True, act=torch.nn.ReLU()):
        super().__init__()

        self.d0, self.d1, self.bias = d0, d1, bias
        # Registered as a submodule here and again inside ``gate`` below.
        self.act = act

        # Build the gate Linear before ``proj`` so parameter initialization
        # draws from the RNG in a fixed order (gate first, then proj).
        gate_linear = torch.nn.Linear(d0, d1, bias=bias)
        self.gate = torch.nn.Sequential(gate_linear, act)
        self.proj = torch.nn.Linear(d0, d1, bias=bias)

    def forward(self, x):
        """Return the elementwise product of the gated and projected views of ``x``."""
        gated = self.gate(x)
        projected = self.proj(x)
        return gated * projected

class MLP2L(torch.nn.Module):
    """Two-layer perceptron: first layer (with activation), dropout, linear output layer.

    Args:
        d0: input feature size.
        d1: nominal hidden size. For ``l1_type="glu"`` the actual hidden size is
            ``(2 * d1) // 3`` (stored as ``self.d1_``); see the cost note below.
        d2: output feature size.
        bias: whether the linear layers carry a bias term.
        act: activation module applied after the first linear layer ("linear")
            or used as the gate activation of the GLU ("glu").
        dropout: dropout probability applied to the hidden activations while
            ``self.training`` is True.
        l1_type: ``"linear"`` for ``Linear -> act``, or ``"glu"`` for a gated
            linear unit as the first layer.

    Raises:
        ValueError: if ``l1_type`` is neither ``"linear"`` nor ``"glu"``.
    """

    def __init__(self, d0, d1, d2, bias=True, act=torch.nn.ReLU(), dropout=0, l1_type="linear"):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.bias = bias
        self.act = act
        self.dropout = dropout
        self.l1_type = l1_type

        if l1_type == "linear":
            self.l1 = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
            self.l2 = torch.nn.Linear(d1, d2, bias)
        elif l1_type == "glu":
            # Shrink the hidden width so the GLU variant costs about the same
            # as the "linear" one (biases ignored):
            #   params, linear: d0*d1 + d1*d2
            #   params, GLU:    2*d0*d1_ + d1_*d2        (gate + proj, then l2)
            #   assuming d2 == d0: 2*d0*d1 = 3*d0*d1_  <=>  d1_ = (2/3)*d1
            #   FLOPs, GLU adds d1_ for the elementwise product:
            #   2*d0*d1 = 3*d0*d1_ + d1_  <=>  d1_ = [2*d0/(3*d0 + 1)]*d1 ~ (2/3)*d1
            self.d1_ = (2 * d1) // 3
            self.l1 = GLU(d0, self.d1_, bias, act)
            self.l2 = torch.nn.Linear(self.d1_, d2, bias)
        else:
            # Fail fast: otherwise l1/l2 would be missing and forward() would
            # raise an opaque AttributeError much later.
            raise ValueError(f"l1_type must be 'linear' or 'glu', got {l1_type!r}")

    def forward(self, x):
        """Apply ``l1``, dropout (training only), then ``l2``."""
        a1 = self.l1(x)
        a1 = torch.nn.functional.dropout(a1, p=self.dropout, training=self.training)

        y = self.l2(a1)

        return y

class MLP3L(torch.nn.Module):
    """Three-layer perceptron mapping ``d0 -> d1 -> d2 -> d3``.

    Each of the two hidden layers is followed by ``act`` and then dropout;
    the last layer is a plain linear map with no activation.

    Args:
        d0, d1, d2, d3: feature sizes of the input, the two hidden layers and the output.
        bias: whether the linear layers carry a bias term.
        act: activation module shared by both hidden layers.
        dropout: dropout probability applied after each hidden activation while
            ``self.training`` is True.
    """

    def __init__(self, d0, d1, d2, d3, bias=True, act=torch.nn.ReLU(), dropout=0):
        super().__init__()

        self.d0, self.d1, self.d2, self.d3 = d0, d1, d2, d3
        self.bias = bias
        self.act = act
        self.dropout = dropout

        # Created in the order l1, l2, l3 so seeded initialization is deterministic.
        self.l1 = torch.nn.Linear(d0, d1, bias=bias)
        self.l2 = torch.nn.Linear(d1, d2, bias=bias)
        self.l3 = torch.nn.Linear(d2, d3, bias=bias)

    def forward(self, x):
        """Run both hidden layers (linear -> act -> dropout), then the output layer."""
        h = x
        for hidden_layer in (self.l1, self.l2):
            h = torch.nn.functional.dropout(
                self.act(hidden_layer(h)), p=self.dropout, training=self.training
            )
        return self.l3(h)
