import torch
import torch.nn.functional as F
from torch import autograd, nn

class MSmoothArcLoss(nn.Module):
    """Margin-based softmax loss on cosine logits with optional soft targets.

    For every sample the target-class cosine ``cos(theta)`` is replaced by
    ``cos(theta + m2) - m3``; all logits are then multiplied by ``s`` and fed
    to a softmax cross-entropy. When ``smoothed_label`` is supplied, the
    cross-entropy is taken against that distribution instead of the hard
    ``label``.
    """

    def __init__(self, s=64., m2=0.5, m3=0.0):
        """Store the scale ``s`` and the two margins.

        Args:
            s: Factor applied to all logits before the softmax (default 64).
            m2: Margin added to the target-class angle (default 0.5).
            m3: Margin subtracted from the target-class cosine (default 0.0).
        """
        super().__init__()
        self.m2 = m2  # angular margin: added to the target angle, default is 0.5
        self.m3 = m3  # cosine margin: subtracted from the target cosine, default is 0.0
        self.s = s  # scalar value default is 64, see normface https://arxiv.org/abs/1704.06369
        # CrossEntropyLoss applies log-softmax itself, so it receives raw scaled logits
        self.ce = nn.CrossEntropyLoss()

    def forward(self, cos_affinity, label, smoothed_label=None):
        """Compute the margin loss.

        Args:
            cos_affinity: Cosine similarities between samples and classes.
                The row/label indexing below implies a 2-D
                ``(batch, num_classes)`` tensor.
            label: Integer class index per row, used to select the logit that
                receives the margin (and as the hard target when
                ``smoothed_label`` is ``None``).
            smoothed_label: Optional target weights broadcastable against the
                logits; when given, the loss is
                ``mean(sum(-smoothed_label * log_softmax(logits), dim=-1))``.

        Returns:
            Scalar loss tensor.
        """
        # 1 - 1e-8 rounds to exactly 1.0 in float32 (and lower precisions),
        # which would leave +/-1 unclamped and give acos an infinite
        # derivative. Use an epsilon the dtype can actually represent, while
        # keeping the historical 1e-8 where it is representable (float64).
        if cos_affinity.is_floating_point():
            eps = max(1e-8, torch.finfo(cos_affinity.dtype).eps)
        else:
            eps = 1e-8
        cos_affinity = torch.clamp(cos_affinity, -1 + eps, 1 - eps)

        # Build the row index on the input's device to avoid a host->device copy.
        rows = torch.arange(
            cos_affinity.size(0), dtype=torch.long, device=cos_affinity.device
        )

        # Apply the margin to the target entries only: running acos over the
        # full matrix is wasted work and lets non-target entries contribute
        # 0 * (huge derivative) terms to the gradient.
        # NOTE(review): theta + m2 can exceed pi for strongly negative target
        # cosines, where cos() is no longer monotonic; left as-is to preserve
        # the existing loss values.
        target_cos = cos_affinity[rows, label]
        target_logit = torch.cos(torch.acos(target_cos) + self.m2) - self.m3

        # Work on a copy so the clamped cosines are not modified in place.
        output = cos_affinity.clone()
        output[rows, label] = target_logit
        # scale up in order to make softmax work, first introduced in normface
        output = output * self.s

        if smoothed_label is None:
            return self.ce(output, label)

        log_probs = output.log_softmax(dim=-1)
        return torch.mean(torch.sum(-smoothed_label * log_probs, dim=-1))