import torch
import torch.nn as nn
import torch.nn.functional as F


class head(nn.Linear):
    """Linear output head with optional row-wise L2 weight normalization.

    When ``use_lln`` is True, each row of the weight matrix (one row per
    output unit) is L2-normalized along ``dim=1`` by :meth:`get_weight`.
    That normalized weight is used on every forward pass, and it is also
    written back into the parameter once at construction time. The
    normalization is computed in float64 and cast back to the parameter's
    dtype. When ``use_lln`` is False the module computes exactly what
    :class:`torch.nn.Linear` computes.

    Args:
        num_features: Size of each input sample (``in_features`` of the
            underlying ``nn.Linear``).
        num_classes: Size of each output sample (``out_features`` of the
            underlying ``nn.Linear``).
        use_lln: Whether to L2-normalize the weight rows.
        bias: Whether the layer has an additive bias. Default: ``True``.
    """

    def __init__(self,
                 num_features: int,
                 num_classes: int,
                 use_lln: bool,
                 bias: bool = True) -> None:
        super().__init__(num_features, num_classes, bias)
        self.use_lln = use_lln
        # Store the (possibly normalized) weight as the initial parameter
        # value. no_grad() replaces the discouraged ``.data`` access. It also
        # keeps this one-off copy from recording an autograd graph.
        with torch.no_grad():
            self.weight.copy_(self.get_weight())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(x, W, bias)`` where ``W`` comes from :meth:`get_weight`."""
        return F.linear(x, self.get_weight(), self.bias)

    def get_weight(self) -> torch.Tensor:
        """Return the weight tensor used by :meth:`forward`.

        Returns:
            ``self.weight`` itself when ``use_lln`` is False. Otherwise a new
            tensor holding ``self.weight`` with each row L2-normalized, in the
            parameter's dtype. That tensor stays differentiable w.r.t.
            ``self.weight``.
        """
        if not self.use_lln:
            return self.weight
        # Normalize in float64 for extra precision, then cast back so the
        # result matches the parameter's dtype.
        return F.normalize(self.weight.double(), dim=1).to(self.weight.dtype)
