import math
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def weights_init_classifier(m):
    """Initialise a linear classifier layer in place.

    If the class name of ``m`` contains ``'Linear'`` (case-sensitive
    substring test), its weight is drawn from a normal distribution with
    mean 0 and std 0.001, and its bias, when it has one, is set to zero.
    Modules whose class name does not match are left untouched.

    Because matching is by class name only, any matching module is expected
    to expose ``weight`` and ``bias`` attributes.

    Args:
        m: the module to initialise.
    """
    if 'Linear' not in type(m).__name__:
        return  # not a linear layer: nothing to initialise

    init.normal_(m.weight.data, std=0.001)

    bias = m.bias
    if bias is not None:
        init.constant_(bias.data, 0.0)

class MSoftmax(nn.Module):
    """Softmax cross-entropy over L2-normalised class weight vectors.

    Each row of the classifier weight matrix (one row per class) is rescaled
    to unit L2 norm before the logits are formed. The incoming embeddings
    are used as-is, so each logit is ``||x|| * cos(theta)`` between an
    embedding ``x`` and a class weight vector, not a pure cosine.

    Args:
        embedding_size: dimensionality of the input embeddings.
        num_class: number of output classes.
        **kwargs: accepted and ignored (nothing in this class reads them).
    """

    def __init__(self, embedding_size=2048, num_class=51332, **kwargs):
        super().__init__()
        self.classnum = num_class
        # Bias-free linear layer; only its (num_class, embedding_size)
        # weight matrix is used, in forward(), via self.kernel.
        self.embedding_layer = nn.Linear(embedding_size, num_class, bias=False)
        weights_init_classifier(self.embedding_layer)
        # Same Parameter object as embedding_layer.weight. Assigning it here
        # registers it a second time under the name 'kernel', so state_dict
        # holds both keys while parameters() yields it only once.
        self.kernel = self.embedding_layer.weight
        self.ce = nn.CrossEntropyLoss()

    def forward(self, embeddings, label):
        """Return the cross-entropy loss (default 'mean' reduction).

        Args:
            embeddings: 2-D tensor of shape (batch, embedding_size);
                ``mm`` below requires exactly two dimensions.
            label: targets accepted by ``nn.CrossEntropyLoss``; presumably
                integer class indices of shape (batch,) — confirm with
                callers.
        """
        unit_kernel = F.normalize(self.kernel, p=2, dim=1)
        # Embeddings are not normalised here, so these logits are
        # norm-scaled cosines rather than cosine similarities.
        logits = embeddings.mm(unit_kernel.t())
        # The 1.0 factor is an identity logit scale.
        return self.ce(logits * 1.0, label)
