# -*- coding: utf-8 -*-
import torch


class LogisticRegressionModel(torch.nn.Module):
    """Logistic regression expressed as a single affine layer.

    The model applies one ``torch.nn.Linear`` map and returns its output
    unchanged. No sigmoid or softmax is applied here, so the outputs are
    raw (pre-activation) scores. NOTE(review): the caller presumably pairs
    this with a loss that expects unnormalized scores (e.g.
    ``CrossEntropyLoss``); confirm against the training code.

    Args:
        in_features: Number of features in each flattened input sample.
            The default ``28 * 28`` presumably targets 28x28 single-channel
            images (e.g. MNIST) — TODO confirm against the caller.
        out_features: Number of output scores produced per sample.
    """

    def __init__(self, in_features: int = 28 * 28, out_features: int = 10):
        super().__init__()
        # Keep this attribute name: it determines the state_dict keys
        # ("linear.weight", "linear.bias") used by saved checkpoints.
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the affine scores for ``x``.

        Inputs with more than two dimensions are flattened from dim 1
        onward, which preserves dim 0 (presumably the batch axis). 1-D and
        2-D inputs go to the linear layer untouched. The guard matters for
        1-D inputs, where ``start_dim=1`` would be out of range.
        """
        features = x if x.ndim <= 2 else torch.flatten(x, start_dim=1)
        return self.linear(features)
