import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcMarginLoss(nn.Module):
    """Additive-angular-margin penalty between paired embeddings.

    For each pair, both embeddings are L2-normalised along the last dimension,
    the angle ``theta`` between them is recovered with ``acos`` of their cosine
    similarity, and the element-wise loss is
    ``scale * relu(cos(theta) - cos(theta + margin))``.

    NOTE(review): using cos(a) - cos(b) = 2 sin((a + b) / 2) sin((b - a) / 2),
    the pre-ReLU term equals ``2 * sin(margin / 2) * sin(theta + margin / 2)``.
    For ``0 < margin < pi`` it is not monotonic in ``theta``. It peaks at
    ``theta = (pi - margin) / 2`` and is clipped to 0 by the ReLU once
    ``theta >= pi - margin / 2``. If this loss is minimised, it pulls pairs
    with small angles together but pushes pairs with large angles further
    apart. Confirm this is the intended objective.

    Args:
        margin: Angular margin ``m`` (radians) added to ``theta``.
        scale: Multiplier ``s`` applied to the element-wise penalty.
        reduction: How ``forward`` combines the per-pair mean losses:
            ``"mean"``, ``"sum"``, or ``"none"`` (return the 1-D tensor of
            per-pair losses).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, margin: float = 0.30, scale: float = 30.0,
                 reduction: str = "mean") -> None:
        super().__init__()
        self.m = margin
        self.s = scale
        if reduction not in ("mean", "sum", "none"):
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.reduction = reduction

    @staticmethod
    def _normalise(x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled to unit L2 norm along its last dimension."""
        return F.normalize(x, dim=-1, p=2)

    def _pair_loss(self, emb_src: torch.Tensor, emb_tgt: torch.Tensor) -> torch.Tensor:
        """Return the element-wise margin penalty for one pair of embedding tensors.

        The cosine similarity is a dot product over the last dimension. The
        result therefore has the broadcast shape of the two inputs with the
        last axis removed.
        """
        emb_src = self._normalise(emb_src)
        emb_tgt = self._normalise(emb_tgt)
        cos_theta = torch.sum(emb_src * emb_tgt, dim=-1)
        # Keep cos_theta strictly inside (-1, 1). The derivative of acos,
        # -1/sqrt(1 - x^2), is infinite at +/-1 and would poison the backward
        # pass with inf/NaN. A fixed 1e-7 rounds to exactly +/-1 in
        # float16/bfloat16, so use at least the dtype's machine epsilon. For
        # float32 and float64 this yields the same bounds as 1e-7 after rounding.
        eps = max(1e-7, torch.finfo(cos_theta.dtype).eps)
        cos_theta = cos_theta.clamp(-1.0 + eps, 1.0 - eps)
        cos_theta_m = torch.cos(torch.acos(cos_theta) + self.m)
        # For 0 < m < pi the difference goes negative only once
        # theta > pi - m/2; the ReLU clips those elements to zero.
        diff = F.relu(cos_theta - cos_theta_m)

        return self.s * diff

    def forward(self, emb_src, emb_tgt) -> torch.Tensor:
        """Compute the margin loss for a single pair or a sequence of pairs.

        Args:
            emb_src: A tensor of source embeddings, or a list/tuple of such
                tensors (one per pair, e.g. one per surrogate model).
            emb_tgt: The matching target embeddings. When ``emb_src`` is a
                list/tuple, this must yield the same number of entries.

        Returns:
            Each pair's loss averaged over its elements, then combined
            according to ``self.reduction``. With ``"none"`` this is a 1-D
            tensor with one entry per pair (length 1 for a single tensor pair).

        Raises:
            ValueError: If ``emb_src`` is an empty list/tuple, or if
                ``emb_src`` and ``emb_tgt`` differ in length.
        """
        if isinstance(emb_src, (list, tuple)):
            if len(emb_src) == 0:
                raise ValueError("emb_src must contain at least one embedding tensor")
            # zip() would silently drop unmatched pairs; fail loudly instead.
            if len(emb_src) != len(emb_tgt):
                raise ValueError(
                    "emb_src and emb_tgt must have the same length, "
                    f"got {len(emb_src)} and {len(emb_tgt)}"
                )
            losses = [self._pair_loss(s, t).mean() for s, t in zip(emb_src, emb_tgt)]
            stacked = torch.stack(losses)  # (N_surrogate,)
        else:
            stacked = self._pair_loss(emb_src, emb_tgt).mean().unsqueeze(0)

        if self.reduction == "mean":
            return stacked.mean()
        elif self.reduction == "sum":
            return stacked.sum()
        else:  # 'none'
            return stacked
