import torch


class MLP(torch.nn.Module):
    """Fully connected network over the concatenation of ``x``, ``t`` and optional ``y``.

    The layer stack is ``Linear(input_dim, hidden_dim)`` followed by
    ``num_hidden_layers - 1`` further ``Linear(hidden_dim, hidden_dim)`` layers,
    each of these followed by an activation, and finally an un-activated
    ``Linear(hidden_dim, output_dim)``. With the defaults this is exactly the
    original fixed architecture: three ReLU-activated hidden layers, with the
    ``Linear`` modules at ``hidden_layers.{0,2,4,6}`` (so ``state_dict`` keys
    and parameter-initialisation order are unchanged).

    ``input_dim`` must equal ``x.shape[1] + 1 + y.shape[1]`` (or
    ``x.shape[1] + 1`` when ``y`` is omitted), because ``forward`` appends
    ``t`` as one extra column and then ``y``.

    Args:
        input_dim: Width of the concatenated ``(x, t[, y])`` input.
        hidden_dim: Width of every hidden layer.
        output_dim: Width of the network output.
        num_hidden_layers: Number of activated hidden layers, at least 1.
            Defaults to 3.
        activation: Zero-argument callable returning an ``nn.Module``; a fresh
            instance is created after every hidden ``Linear``. Defaults to
            ``torch.nn.ReLU``.

    Raises:
        ValueError: If ``num_hidden_layers`` is less than 1.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        *,
        num_hidden_layers: int = 3,
        activation=torch.nn.ReLU,
    ):
        super().__init__()
        if num_hidden_layers < 1:
            raise ValueError(
                f"num_hidden_layers must be >= 1, got {num_hidden_layers}"
            )

        # Build layers in the same order as the original hand-written stack so
        # that module indices and RNG consumption during init are identical.
        layers = [torch.nn.Linear(input_dim, hidden_dim), activation()]
        for _ in range(num_hidden_layers - 1):
            layers += [torch.nn.Linear(hidden_dim, hidden_dim), activation()]
        layers.append(torch.nn.Linear(hidden_dim, output_dim))
        self.hidden_layers = torch.nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, t, y=None, **kwargs) -> torch.Tensor:
        """Apply the network to ``x`` with ``t`` (and ``y``) appended as features.

        Args:
            x: 2-D input of shape ``(batch, x_dim)``. Its dtype and device are
                also used for ``t`` and ``y``.
            t: Python number, 0-d tensor, or tensor holding ``batch`` values
                (e.g. shape ``(batch,)`` or ``(batch, 1)``). A single value is
                broadcast to every row of the batch.
            y: Optional 2-D tensor of shape ``(batch, y_dim)``, appended after
                ``t``.
            **kwargs: Accepted and ignored; not used by this module.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        batch = x.shape[0]

        # Convert t to x's dtype/device. Without this, torch.cat would promote
        # e.g. a float64 t + float32 x to float64, which the float32 Linear
        # weights reject; a CPU t with a GPU x would fail outright. as_tensor
        # also accepts plain Python numbers and keeps autograd for tensors.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(-1, 1)
        if t.shape[0] == 1:
            # One value for the whole batch: expand is a view, so no copy and
            # gradients still flow back to the original scalar.
            t = t.expand(batch, 1)

        if y is None:
            x_in = torch.cat((x, t), dim=1)
        else:
            # Same dtype/device alignment as for t.
            y = y.to(device=x.device, dtype=x.dtype)
            x_in = torch.cat((x, t, y), dim=1)

        return self.hidden_layers(x_in)
