import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearNorm(nn.Linear):
    """Linear layer that rescales its weight by the weight's global 2-norm.

    In training mode the weight matrix is divided by a single scalar, the
    2-norm taken over all of its entries, before the affine map is applied.
    In evaluation mode the stored weight is used unchanged.

    NOTE(review): the training and evaluation paths therefore differ by a
    factor of ``1 / ||weight||`` on the linear term. That asymmetry is kept
    exactly as it was; confirm with the callers that it is intentional.
    """

    def __init__(
            self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None
    ):
        """Build the layer; all arguments are forwarded to ``nn.Linear``."""
        super().__init__(in_features, out_features, bias, device, dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x``.

        Args:
            x: Input tensor accepted by ``F.linear`` for this layer's weight.

        Returns:
            ``F.linear(x, weight / ||weight||, bias)`` while training, and
            ``F.linear(x, weight, bias)`` otherwise.
        """
        if not self.training:
            return F.linear(x, self.weight, self.bias)

        norm = torch.norm(self.weight, p=2, keepdim=False)
        # An all-zero weight has norm 0 and 0 / 0 would turn every output into
        # NaN. Clamping to the smallest positive normal value leaves results
        # for any non-degenerate weight untouched and maps the zero weight to
        # a zero linear term instead.
        norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
        return F.linear(x, self.weight / norm, self.bias)
