import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Single fully-connected layer projecting features onto per-class scores.

    No activation is applied to the output, so ``forward`` returns raw,
    unnormalized scores.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        num_classes: Number of scores produced for each input.
    """

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        # Keep the constructor configuration available for later inspection.
        self.input_dim, self.num_classes = input_dim, num_classes

        # The attribute name "fc" determines the state_dict keys
        # ("fc.weight", "fc.bias"); keep it stable for checkpoint compatibility.
        self.fc = nn.Linear(
            in_features=input_dim,
            out_features=num_classes,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ fc.weight.T + fc.bias`` over the last dimension of ``x``."""
        return self.fc(x)
