from torch import nn


class MLP_base(nn.Module):
    """One-hidden-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        n_classes: Size of the output layer (one output unit per class).
        n_inputs: Expected size of the last dimension of the input tensor
            (``nn.Linear`` acts on the last dimension).
        hidden: Width of the hidden layer.
    """

    def __init__(self, n_classes, n_inputs, hidden):
        super().__init__()
        # The head is registered before the trunk on purpose: registration
        # order determines state_dict / parameters() ordering and the order
        # in which weights are drawn from the RNG at construction time.
        self.classifier = nn.Linear(in_features=hidden, out_features=n_classes)

        trunk_layers = [
            nn.Linear(in_features=n_inputs, out_features=hidden),
            nn.ReLU(),
        ]
        self.feature = nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Project ``x`` through the hidden layer, then to class outputs."""
        return self.classifier(self.feature(x))


def MLP(n_classes):
    """Build the standard-width MLP: 512 input features, 400 hidden units.

    Args:
        n_classes: Number of output units (one per class).

    Returns:
        A freshly constructed ``MLP_base``.
    """
    return MLP_base(n_classes=n_classes, n_inputs=512, hidden=400)


def reduced_MLP(n_classes):
    """Build a narrower MLP: 512 input features, 200 hidden units.

    Args:
        n_classes: Number of output units (one per class).

    Returns:
        A freshly constructed ``MLP_base``.
    """
    return MLP_base(n_classes=n_classes, n_inputs=512, hidden=200)


def MLP_LLM(n_classes):
    """Build an MLP with 768 input features and 256 hidden units.

    NOTE(review): the name suggests the 768-wide input is an LLM embedding;
    confirm against callers.

    Args:
        n_classes: Number of output units (one per class).

    Returns:
        A freshly constructed ``MLP_base``.
    """
    return MLP_base(n_classes=n_classes, n_inputs=768, hidden=256)


def reduced_MLP_LLM(n_classes):
    """Build a narrower MLP with 768 input features and 128 hidden units.

    NOTE(review): the name suggests the 768-wide input is an LLM embedding;
    confirm against callers.

    Args:
        n_classes: Number of output units (one per class).

    Returns:
        A freshly constructed ``MLP_base``.
    """
    return MLP_base(n_classes=n_classes, n_inputs=768, hidden=128)
