import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def weights_init_classifier(m):
    """Re-initialize a linear classifier layer in place.

    Intended to be passed a single module, or used with ``Module.apply``.
    Modules whose class name does not contain ``'Linear'`` are left untouched.
    For matching modules:

    - the weight is drawn from a normal distribution with std 0.001
      (mean is ``init.normal_``'s default, 0);
    - the bias, when present, is set to 0.

    Args:
        m: the module to (possibly) re-initialize.

    Returns:
        None. The module's parameters are modified in place.
    """
    # Match on the class *name* rather than isinstance, so any class whose
    # name contains 'Linear' is picked up too.
    if 'Linear' not in m.__class__.__name__:
        return

    init.normal_(m.weight.data, std=0.001)

    bias = m.bias
    if bias is not None:
        init.constant_(bias.data, 0.0)

class softmax(nn.Module):
    """Linear classification head trained with plain softmax cross-entropy.

    Projects embeddings to ``num_class`` logits with a biased ``nn.Linear``
    layer. The layer is re-initialized by ``weights_init_classifier``, and
    the loss is computed with ``nn.CrossEntropyLoss`` using its default
    settings.

    Args:
        embedding_size: input feature dimension of the linear layer.
        num_class: number of output classes (logits).
        **kwargs: accepted and ignored. This presumably lets the head be
            built from a shared config dict; confirm against callers.
    """

    def __init__(self, embedding_size=2048, num_class=51332, **kwargs):
        super().__init__()
        self.classnum = num_class

        # Build and initialize the classifier before registering it. The
        # initialization acts on the same object, so the result is
        # identical to initializing after assignment.
        classifier = nn.Linear(embedding_size, self.classnum, bias=True)
        weights_init_classifier(classifier)
        self.embedding_layer = classifier

        self.ce = nn.CrossEntropyLoss()

    def forward(self, embeddings, label):
        """Return the cross-entropy loss of the logits against ``label``.

        Args:
            embeddings: input features. Presumably shaped
                (batch, embedding_size); TODO confirm with callers.
            label: targets in any form ``nn.CrossEntropyLoss`` accepts.
                Presumably integer class indices of shape (batch,).

        Returns:
            The loss tensor produced by ``nn.CrossEntropyLoss`` with default
            reduction (a scalar mean).
        """
        return self.ce(self.embedding_layer(embeddings), label)
