import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearNorm(nn.Linear):
    """Linear layer whose weight is L2-normalized column-wise on every forward pass.

    Each column ``weight[:, j]`` (the weights attached to input feature ``j``)
    is divided by its L2 norm, and the affine map ``x @ W_n.T + bias`` is then
    applied with the normalized matrix ``W_n``. The stored ``weight`` parameter
    itself is never modified; normalization happens on the fly.

    NOTE(review): many "normalized linear" / cosine-classifier layers normalize
    each *output row* (``dim=1``) instead; this layer normalizes each *input
    column*, exactly as the original implementation did — confirm intended.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If ``False``, the layer has no additive bias.
        device: Device for the parameters (forwarded to ``nn.Linear``).
        dtype: Dtype for the parameters (forwarded to ``nn.Linear``).
        eps: Lower bound applied to each column norm so that an all-zero
            column yields zeros instead of NaN (0/0). Columns whose norm is
            already ``>= eps`` are unaffected.
    """

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None,
            eps: float = 1e-12,
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the column-normalized affine map to ``x``."""
        # weight is (out_features, in_features); reducing over dim 0 gives one
        # norm per input column -> shape (in_features,).
        col_norm = torch.linalg.vector_norm(self.weight, ord=2, dim=0)
        # Clamp to avoid 0/0 = NaN for zero columns. Broadcasting over the last
        # axis divides weight[:, j] by col_norm[j].
        normed_weight = self.weight / col_norm.clamp_min(self.eps)
        return F.linear(x, normed_weight, self.bias)


if __name__ == '__main__':
    # Smoke test: push one random 32-dim sample through a 32 -> 10 layer
    # and report the resulting output shape.
    layer = LinearNorm(32, 10)
    sample = torch.randn(1, 32)
    output = layer(sample)
    print(output.size())
