import torch


class LinearModel(torch.nn.Module):
    """A single fully connected (affine) layer, with optional input flattening.

    When ``mnist_like`` is true, ``forward`` first reshapes its input to
    ``(-1, 784)`` (784 = 28 * 28; presumably flattened MNIST-style images --
    confirm against the data pipeline) and then applies the linear layer.
    """

    def __init__(self, input_dim: int, output_dim: int, mnist_like: bool = False) -> None:
        """Build the layer.

        Args:
            input_dim: Number of input features of the linear layer.
            output_dim: Number of output features of the linear layer.
            mnist_like: If true, ``forward`` reshapes inputs to ``(-1, 784)``
                before the linear layer. ``input_dim`` is not adjusted for
                this; the caller is expected to pass a matching value.
        """
        super().__init__()
        self.mnist_like = mnist_like
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the (optionally flattening) linear layer to ``x``.

        Args:
            x: Input tensor. With ``mnist_like`` set, its element count must
                be a multiple of 784; otherwise it is passed to the linear
                layer unchanged.

        Returns:
            The output of ``self.linear`` on the (possibly reshaped) input.
        """
        if self.mnist_like:
            # reshape (not view): view raises on non-contiguous tensors such
            # as transposed/permuted batches, while reshape returns a view
            # when possible and only copies when the memory layout needs it.
            x = x.reshape(-1, 784)
        return self.linear(x)
