from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class UniquenessLoss(nn.Module):
    """
    Penalize alignment between a unique representation and its context(s).

    For each context, the loss is the absolute cosine similarity between
    ``unique_repr`` and that context, computed along ``dim`` and averaged over
    all remaining positions. The absolute value penalizes both positive and
    negative alignment, so minimizing the loss pushes the representations
    toward orthogonality. When a second context is supplied, the two
    per-context terms are summed.

    Context tensors are detached before comparison. Gradients therefore flow
    only into ``unique_repr``, and the contexts act as fixed targets.

    Args:
        dim: Dimension along which cosine similarity is computed. Defaults to
            the last dimension.
        eps: Small value passed to ``F.cosine_similarity`` to avoid division
            by zero. Defaults to ``1e-8``, PyTorch's own default.
    """

    def __init__(self, dim: int = -1, eps: float = 1e-8):
        super().__init__()
        self.dim = dim
        self.eps = eps

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"

    def _mean_abs_cosine(self,
                         unique_repr: torch.Tensor,
                         context_repr: torch.Tensor) -> torch.Tensor:
        """Return mean |cos(unique_repr, context_repr)| with no gradient into the context."""
        # detach(): keep the context fixed so only unique_repr is pushed away from it.
        sim = F.cosine_similarity(unique_repr, context_repr.detach(),
                                  dim=self.dim, eps=self.eps)
        return sim.abs().mean()

    def forward(self,
                unique_repr: torch.Tensor,
                context1_repr: torch.Tensor,
                context2_repr: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute the uniqueness loss.

        Args:
            unique_repr: Representation that should be unaligned with the
                contexts. It receives gradients.
            context1_repr: First context representation. It must be
                broadcastable with ``unique_repr`` and is detached.
            context2_repr: Optional second context, e.g. for the trivariate
                case. It is detached when given.

        Returns:
            A 0-dim tensor: the mean absolute cosine similarity to
            ``context1_repr``, plus the same term for ``context2_repr`` when
            it is provided.
        """
        loss = self._mean_abs_cosine(unique_repr, context1_repr)
        if context2_repr is not None:
            loss = loss + self._mean_abs_cosine(unique_repr, context2_repr)
        return loss