import torch.nn as nn

class Net(nn.Module):
    """Single linear layer applied to each sample's flattened features.

    Every dimension after the first is flattened, and the result is passed
    through one ``nn.Linear(in_features, out_features)``.

    Args:
        in_features: Size of each flattened sample. Must equal the product
            of the input's non-leading dimensions, or 1 for rank-1 input.
        out_features: Number of output features per sample.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Kept as an ``nn.Sequential`` (rather than a bare ``nn.Linear``) so
        # the parameter names stay ``decoder.0.weight`` / ``decoder.0.bias``
        # and existing checkpoints keep loading.
        self.decoder = nn.Sequential(
            nn.Linear(self.in_features, self.out_features),
        )

    def forward(self, x, is_pre: bool = True):
        """Flatten each sample of ``x`` and apply the linear decoder.

        Args:
            x: Tensor whose dim 0 is treated as the batch dimension; all
                remaining dimensions are flattened into one feature axis.
            is_pre: Unused. Accepted only for backward compatibility with
                existing callers.

        Returns:
            Tensor of shape ``(x.size(0), out_features)``.
        """
        if x.dim() > 1:
            # ``flatten`` (unlike ``view``) also works on non-contiguous
            # inputs (copying only when necessary) and on an empty batch.
            x = x.flatten(1)
        else:
            # A rank-1 input is treated as N samples of one feature each,
            # i.e. shape (N, 1), matching the previous ``view(N, -1)``.
            x = x.reshape(x.size(0), -1)
        return self.decoder(x)