import torch
import torch.nn as nn

class Net(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_size: Number of features per sample after flattening; this is the
            input width of the first linear layer.
        out_size: Number of values produced per sample by the final layer.
        hidden_dim: Width of the single hidden layer.
    """

    def __init__(self, in_size=2, out_size=2, hidden_dim=2):
        super().__init__()

        # Store the layer sizes on the instance.
        self.in_size, self.out_size, self.hidden_dim = in_size, out_size, hidden_dim

        # Registration order decides the state_dict key order and the order in
        # which the layers draw their random initial weights, so keep it stable.
        self.fc_layer_1 = nn.Linear(in_size, hidden_dim, bias=True)
        self.relu = nn.ReLU()
        self.fc_layer_2 = nn.Linear(hidden_dim, out_size, bias=True)

    def forward(self, x):
        """Run the network on ``x`` and return the raw output of the last layer.

        ``x`` is flattened from dimension 1 onward, so dimension 0 is preserved
        (presumably the batch dimension) and the product of the remaining
        dimensions must equal ``in_size``. No softmax or other activation is
        applied to the result.
        """
        flat = torch.flatten(x, start_dim=1)
        hidden = self.relu(self.fc_layer_1(flat))
        return self.fc_layer_2(hidden)

