import torch
import torch.nn.functional as F
from torch import nn
from torch import distributed as dist
from losses.gather import GatherLayer


class CosFaceLoss(nn.Module):
    """Additive cosine-margin (CosFace) loss computed from classifier logits.

    For every sample, ``margin`` is subtracted from the logit of its
    ground-truth class. All logits are then multiplied by ``scale``, and the
    result is passed to ``F.cross_entropy`` with its default settings.

    Args:
        scale (float): multiplier applied to every logit after the margin.
        margin (float): amount subtracted from the ground-truth class logit.
        **kwargs: accepted and ignored.
    """

    def __init__(self, scale=16, margin=0.1, **kwargs):
        super().__init__()
        self.s, self.m = scale, margin

    def forward(self, inputs, targets):
        """Return the CosFace loss for a batch.

        Args:
            inputs: logits (before softmax) with shape (batch_size, num_classes).
                Presumably cosine similarities from a normalized classifier;
                this is not enforced here.
            targets: ground-truth labels with shape (batch_size). They must be
                an int64 tensor, because they are used as a ``scatter_`` index
                and as ``cross_entropy`` class targets.
        """
        # Column index of each sample's true class, shaped (batch_size, 1).
        index = targets.view(-1, 1)
        # Indicator mask: 1.0 at each sample's true class, 0.0 elsewhere.
        target_mask = torch.zeros_like(inputs).scatter_(1, index, 1.0)
        # Penalize only the true-class logit by the margin.
        margined = inputs - target_mask * self.m
        logits = margined * self.s
        return F.cross_entropy(logits, targets)
