import math
import torch
import torch.nn as nn

class AngularMargin(nn.Module):
    """
    EXPERIMENTAL DO NOT USE.

    Additive angular margin (ArcFace-style) cross-entropy loss.

    For each sample, the logit of the ground-truth class is treated as
    cos(theta) and replaced by cos(theta + margin). All logits are then
    multiplied by ``scale`` and passed to ``nn.CrossEntropyLoss``. When
    theta + margin would exceed pi (i.e. cos(theta) <= -cos(margin)), the
    linear fallback ``cos(theta) - margin * sin(margin)`` is used instead of
    cos(theta + margin), since cos stops decreasing past pi.

    Args:
        margin: additive angular margin, in radians.
        scale: multiplier applied to all logits before cross-entropy.
        eps: lower bound applied to ``1 - cos(theta)**2`` before the square
            root. It keeps sin(theta) and its gradient finite when
            ``|cos(theta)|`` is at (or, through rounding, slightly past) 1.
    """

    def __init__(self, margin, scale, eps=1e-7):
        super().__init__()
        self._margin = margin
        self._scale = scale
        self._eps = eps
        self._cos_m = math.cos(self._margin)
        self._sin_m = math.sin(self._margin)
        # Offset used by the linear fallback when theta + m > pi.
        self._msin_m = self._margin * self._sin_m
        self.ce = nn.CrossEntropyLoss()

    def forward(self, outputs, ys):
        """
        Compute the margin-adjusted, scaled cross-entropy loss.

        Args:
            outputs: (N, C) logits. The ground-truth entries are interpreted
                as cos(theta), so they are expected to lie in [-1, 1]
                (presumably cosine similarities between normalized
                embeddings and class weights -- confirm against callers).
            ys: (N,) integer ground-truth class indices.

        Returns:
            Scalar loss from ``nn.CrossEntropyLoss`` (default mean reduction).
        """
        # Build the row index on the same device as the inputs, and use it
        # for both the gather and the scatter below.
        row_idx = torch.arange(outputs.size(0), device=outputs.device)
        cos_t = outputs[row_idx, ys]  # ground-truth logit per row

        # Clamp before sqrt: a negative argument gives NaN, and sqrt'(0) is
        # infinite, which would poison the backward pass.
        sin_t = torch.sqrt((1.0 - cos_t ** 2).clamp(min=self._eps))
        cos_t_plus_m = cos_t * self._cos_m - sin_t * self._sin_m

        # theta + m < pi  <=>  cos(theta) > cos(pi - m) = -cos(m).
        # torch.where (unlike mask multiplication) does not propagate 0 * NaN
        # and keeps the input dtype, so the index assignment below works for
        # reduced-precision inputs.
        angular_margin_logits_gt = torch.where(
            (cos_t + self._cos_m) > 0.0,
            cos_t_plus_m,
            cos_t - self._msin_m,
        )

        angular_outputs = outputs.clone()
        angular_outputs[row_idx, ys] = angular_margin_logits_gt
        return self.ce(self._scale * angular_outputs, ys)
