#! /usr/bin/python
# -*- encoding: utf-8 -*-


import torch
import torch.nn as nn


class AMSoftmax(nn.Module):
    '''
    examplify:
    >>> criteria = AMSoftmax(1024, 10)
    >>> a = torch.randn(20, 1024)
    >>> lb = torch.randint(0, 10, (20, ), dtype=torch.long)
    >>> loss = criteria(a, lb)
    >>> loss.backward()
    '''
    def __init__(self,
                 embedding_size = 512,
                 num_class=10,
                 m=0.3,
                 s=15,
                 **kwargs):
        super(AMSoftmax, self).__init__()
        self.m = m
        self.s = s
        self.in_feats = embedding_size
        self.W = torch.nn.Parameter(torch.randn(embedding_size, num_class), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.W, gain=1)

    def forward(self, x, lb):
        assert x.size()[0] == lb.size()[0]
        assert x.size()[1] == self.in_feats
        x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12)
        x_norm = torch.div(x, x_norm)
        w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12)
        w_norm = torch.div(self.W, w_norm)
        costh = torch.mm(x_norm, w_norm)
        lb_view = lb.view(-1, 1)
        if lb_view.is_cuda: lb_view = lb_view.cpu()
        delt_costh = torch.zeros(costh.size()).scatter_(1, lb_view, self.m)
        if x.is_cuda: delt_costh = delt_costh.cuda()
        costh_m = costh - delt_costh
        costh_m_s = self.s * costh_m
        loss = self.ce(costh_m_s, lb)
        return loss