import torch


class MLP(torch.nn.Module):
    """Feed-forward network: two ReLU hidden layers followed by a linear head.

    The input is flattened to ``(batch, -1)`` before the first layer, so any
    input whose non-batch dimensions multiply to ``in_size`` is accepted.
    """

    def __init__(self, in_size, out_size, hidden=64):
        """Build the network.

        Args:
            in_size: Number of features after flattening all non-batch dims.
            out_size: Number of output features produced by the final layer.
            hidden: Width of both hidden layers.
        """
        super().__init__()
        # Feature extractor: Linear -> ReLU -> Linear -> ReLU.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(in_size, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden),
            torch.nn.ReLU(),
        )
        # Output head, no activation applied.
        self.out = torch.nn.Linear(hidden, out_size)

    def forward(self, inputs):
        """Run the network on a batch.

        Args:
            inputs: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened and must multiply to ``in_size``.

        Returns:
            Tensor of shape ``(batch, out_size)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted tensors, are accepted; it still returns a view when possible.
        inp = inputs.reshape(inputs.shape[0], -1)
        feat = self.fc(inp)
        return self.out(feat)