import torch

class LinearHead(torch.nn.Module):
    """Single fully connected layer that maps feature vectors to class scores.

    Both the weight matrix and the bias start at zero, so every output is
    zero until the parameters are updated by training.
    """

    def __init__(self, input_dim: int, num_classes: int) -> None:
        """Build the zero-initialized linear layer.

        Args:
            input_dim: Size of the last dimension of the input features.
            num_classes: Number of scores produced for each input vector.
        """
        super().__init__()
        self.fc = torch.nn.Linear(input_dim, num_classes)
        # Use torch.nn.init (which runs under no_grad) rather than the legacy
        # `.data` attribute, which silently bypasses autograd tracking.
        # Zero init is safe for a lone linear layer: the weight gradient is
        # (dL/dy)^T @ x, which does not depend on the current weights, so the
        # all-zero start does not cause a symmetry problem.
        torch.nn.init.zeros_(self.fc.weight)
        torch.nn.init.zeros_(self.fc.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map to the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., input_dim)``.

        Returns:
            Tensor of shape ``(..., num_classes)`` holding unnormalized scores
            (no activation or softmax is applied here).
        """
        return self.fc(x)