import torch
import torch.nn as nn
import torch.nn.functional as F



class LogitNorm(nn.Module):
    """Rescale the input to (near) unit L2 norm over its last dimension, then divide by ``t``.

    Args:
        t: Scalar divisor applied after the normalization; stored as
            ``self.t`` and read on every forward pass. Defaults to ``1.0``.
            NOTE(review): presumably a softmax-style temperature applied to
            classifier logits -- confirm against the caller.
    """

    def __init__(self,  t=1.0):
        super().__init__()
        self.t = t

    def forward(self, x):
        """Return ``x / (||x||_2 + 1e-7) / t`` with the norm taken over ``dim=-1``.

        Args:
            x: Tensor to normalize; every slice along the last dimension is
                scaled independently.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # The small additive constant keeps the division finite when a
        # slice is entirely zero (its norm would otherwise be exactly 0).
        safe_norm = x.norm(p=2, dim=-1, keepdim=True).add(1e-7)
        unit = x / safe_norm
        return unit / self.t