#from sparse import sparse_layer
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected network with three hidden layers.

    Each hidden layer applies ``Linear -> ReLU -> Dropout``. The final
    ``last_layer`` is a plain linear projection with no activation or
    dropout. Inputs are flattened to ``(batch, -1)`` before the first layer.

    Args:
        indim: Number of input features per sample after flattening.
        hiddim: Indexable sequence of hidden widths. Only ``hiddim[0]``,
            ``hiddim[1]`` and ``hiddim[2]`` are read.
        outdim: Number of output features.
        dropout: Drop probability shared by all three hidden layers.
    """

    def __init__(self, indim, hiddim, outdim, dropout) -> None:
        super().__init__()
        first, second, third = hiddim[0], hiddim[1], hiddim[2]

        # Parameterized layers are created first and in this order. This keeps
        # seeded weight initialization and the state_dict key order stable.
        self.Linear1 = nn.Linear(indim, first)
        self.Linear2 = nn.Linear(first, second)
        self.Linear3 = nn.Linear(second, third)
        self.last_layer = nn.Linear(third, outdim)

        # One ReLU per hidden layer. A single Dropout module is reused for all
        # three layers, which is safe because it holds no learnable state.
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Flatten ``x`` per sample and run it through the network.

        Args:
            x: Tensor whose first dimension is the batch. All remaining
                dimensions are flattened together; their product must
                equal ``indim``.

        Returns:
            Tensor of shape ``(batch, outdim)``.
        """
        hidden = x.reshape(x.shape[0], -1)
        stages = (
            (self.Linear1, self.relu1),
            (self.Linear2, self.relu2),
            (self.Linear3, self.relu3),
        )
        for linear, activation in stages:
            hidden = self.dropout(activation(linear(hidden)))
        return self.last_layer(hidden)

