import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class AAMSoftmaxLoss(nn.Module):
    """Additive angular margin softmax loss (AAM-softmax / ArcFace).

    For every sample the target-class cosine ``cos(theta)`` is replaced by
    ``cos(theta + margin)``. All cosines are then multiplied by ``scale`` and
    fed to cross-entropy.

    Args:
        margin: Additive angular margin ``m``, in radians.
        scale: Multiplier ``s`` applied to the (margin-adjusted) cosines.
        easy_margin: If True, apply the margin only where ``cos(theta) > 0``.
            Otherwise (default), fall back to ``cos(theta) - m * sin(m)`` once
            ``theta + m`` would exceed pi. This keeps the target logit
            decreasing in ``theta``.
        reduction: Passed to :class:`torch.nn.CrossEntropyLoss`. Defaults to
            ``'sum'`` for backward compatibility; ``'mean'`` makes the loss
            independent of batch size.
    """

    def __init__(self, margin=0.2, scale=30.0, easy_margin=False,
                 reduction='sum'):
        super().__init__()
        self.margin = margin
        self.scale = scale
        self.easy_margin = easy_margin
        self.ce = nn.CrossEntropyLoss(reduction=reduction)

    def forward(self, logits, target):
        """Compute the AAM-softmax loss.

        Args:
            logits: ``[batch_size, num_classes]`` cosine similarities between
                L2-normalized embeddings and L2-normalized class weights.
                Values are clamped to ``[-1, 1]``. They are deliberately NOT
                re-normalized across classes, because L2-normalizing a row of
                class scores does not yield cosine similarities.
            target: Either class indices of shape ``[batch_size]`` or one-hot /
                class-probability labels of shape ``[batch_size, num_classes]``.

        Returns:
            The cross-entropy loss reduced according to ``reduction``.
        """
        cosine = logits.clamp(-1.0, 1.0)

        if target.dim() == 1:
            # Class indices: build the mask used to select the target class.
            ce_target = target.long()
            one_hot = F.one_hot(
                ce_target, num_classes=cosine.size(1)
            ).to(cosine.dtype)
        else:
            # One-hot / soft labels: CrossEntropyLoss accepts probabilities.
            one_hot = target.to(cosine.dtype)
            ce_target = one_hot

        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        # sin(theta) >= 0 because theta = acos(cosine) lies in [0, pi].
        # The small floor keeps the sqrt gradient finite when |cosine| == 1.
        sine = torch.sqrt((1.0 - cosine * cosine).clamp(min=1e-12))
        phi = cosine * math.cos(self.margin) - sine * math.sin(self.margin)

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # theta + m > pi  <=>  cos(theta) < cos(pi - m). Past that point
            # cos(theta + m) would increase again, so use a linear penalty.
            threshold = math.cos(math.pi - self.margin)
            fallback = math.sin(math.pi - self.margin) * self.margin
            phi = torch.where(cosine > threshold, phi, cosine - fallback)

        # Apply the margin to the target class only; soft targets interpolate
        # linearly between the adjusted and unadjusted cosine.
        adjusted = one_hot * phi + (1.0 - one_hot) * cosine
        return self.ce(self.scale * adjusted, ce_target)
