#! /usr/bin/python
# -*- encoding: utf-8 -*-

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

'''
def l2_norm(input,axis=1):
    norm = torch.norm(input,2,axis,True)
    output = torch.div(input, norm)
    return output
'''

def weights_init_classifier(m):
    """Initialise a classifier layer in place.

    Modules whose class name contains ``Linear`` get their weights drawn
    from a normal distribution with std 0.001 (default mean) and, when a
    bias exists, that bias filled with zeros. Any other module is left
    untouched.
    """
    if 'Linear' not in m.__class__.__name__:
        return
    init.normal_(m.weight.data, std=0.001)
    bias = m.bias
    if bias is None:
        # Layer was built with bias=False; nothing more to initialise.
        return
    init.constant_(bias.data, 0.0)


class Arcface(nn.Module):
    """Margin-based classification head that returns a cross-entropy loss.

    Both the input embeddings and the per-class weight rows are
    L2-normalised, so their product is ``cos(theta)`` for the angle between
    each embedding and each class weight. For the ground-truth class the
    logit is replaced by ``cos(theta + m2) - m3``; all logits are then
    multiplied by ``s`` and passed to ``nn.CrossEntropyLoss``.

    Args:
        embedding_size: Size of each input embedding vector.
        num_class: Number of classes (rows of the weight matrix).
        s: Scale applied to every logit before the softmax.
        m2: Additive angular margin (radians) applied to the target class.
        m3: Additive cosine margin subtracted from the target logit.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, embedding_size=2048, num_class=51332, s=64., m2=0.5, m3=0.0, **kwargs):
        super(Arcface, self).__init__()
        self.classnum = num_class
        # Bias-free linear layer used only as a container for the class
        # weight matrix of shape (num_class, embedding_size).
        self.embedding_layer = nn.Linear(embedding_size, self.classnum, bias=False)
        weights_init_classifier(self.embedding_layer)
        self.m2 = m2  # angular margin, default 0.5
        self.m3 = m3  # cosine margin, default 0.0
        self.s = s
        self.ce = nn.CrossEntropyLoss()

    def forward(self, embeddings, label):
        """Compute the margin-penalised cross-entropy loss.

        Args:
            embeddings: Float tensor of shape (batch, embedding_size).
            label: Integer class indices of shape (batch,), used both to
                index the logits and as the cross-entropy target.

        Returns:
            The loss produced by ``nn.CrossEntropyLoss`` on the scaled logits.
        """
        # Normalise weights and features so the product below is cos(theta).
        class_weights = F.normalize(self.embedding_layer.weight, p=2, dim=1)
        n_embeddings = F.normalize(embeddings, p=2, dim=1)
        cos_theta = torch.mm(n_embeddings, class_weights.t())

        # Row index lives on the same device as the labels it is paired with.
        rows = torch.arange(cos_theta.size(0), dtype=torch.long, device=label.device)

        # Only the ground-truth entry of each row receives the margin, so the
        # acos/cos round trip is done on those entries alone. The clamp keeps
        # acos away from +-1, where its derivative is infinite and rounding
        # error could otherwise yield NaN values or gradients.
        eps = torch.finfo(cos_theta.dtype).eps
        target_cos = cos_theta[rows, label].clamp(-1.0 + eps, 1.0 - eps)
        # NOTE(review): theta + m2 may exceed pi for very large angles, where
        # cos stops being monotonic; no fallback is applied here.
        target_logit = torch.cos(torch.acos(target_cos) + self.m2) - self.m3

        # Work on a copy so cos_theta itself is never modified in place.
        output = cos_theta.clone()
        output[rows, label] = target_logit
        # Scale up so the softmax can separate the bounded cosine logits.
        output *= self.s
        return self.ce(output, label)
