from typing import Protocol

import torch
from pykeen.losses import MarginRankingLoss
from typing_extensions import runtime_checkable


@runtime_checkable
class MarginLossProtocol(Protocol):
    """Structural type for callables that compute a margin-based loss.

    An object conforms when its ``__call__`` takes a tensor of positive scores
    and a tensor of negative scores and returns a tensor. Because the class is
    ``@runtime_checkable``, ``isinstance`` can be used against it, but such a
    check only confirms that ``__call__`` exists, not its signature.
    """

    def __call__(
        self, positive_scores: torch.Tensor, negative_scores: torch.Tensor
    ) -> torch.Tensor:
        """Compute the loss for a pair of positive/negative score tensors.

        Args:
            positive_scores: Scores for positive triples
            negative_scores: Scores for negative triples

        Returns:
            Loss value

        """


# Protocol conformance check: the annotation lets a static type checker confirm
# that MarginRankingLoss is assignable to MarginLossProtocol. Python does not
# enforce variable annotations, so nothing is verified when this line executes;
# at import time it only constructs a throwaway loss instance bound to `_`.
# NOTE(review): if a check at runtime is wanted, an explicit isinstance() test
# against the @runtime_checkable protocol would be required — confirm intent.
_: MarginLossProtocol = MarginRankingLoss(margin=1.0)
