from typing import TYPE_CHECKING, Protocol

import torch
from pykeen.sampling import BasicNegativeSampler, BernoulliNegativeSampler
from typing_extensions import runtime_checkable


@runtime_checkable
class NegativeSamplerProtocol(Protocol):
    """Structural interface for negative samplers.

    Any object exposing a compatible ``sample`` method satisfies this protocol
    for a static type checker. Because the class is ``@runtime_checkable``,
    ``isinstance(obj, NegativeSamplerProtocol)`` also works at runtime, but
    that check only verifies that a ``sample`` attribute exists -- it does not
    inspect the method's parameters or return type.

    NOTE(review): PyKeen's ``NegativeSampler.sample`` appears to take a single
    ``(batch_size, 3)`` tensor and return a ``(negative_batch, filter_mask)``
    pair, with the number of negatives fixed at construction time
    (``num_negs_per_pos``). If so, PyKeen samplers do not statically match the
    signature declared here -- verify against the installed PyKeen version.
    """

    def sample(
        self,
        positive_batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        num_samples: int = 1,
    ) -> torch.Tensor:
        """Sample negative triples for a batch of positive triples.

        Args:
            positive_batch: Tuple of (head, relation, tail) tensors.
                Presumably each is a 1-D tensor of length ``batch_size`` --
                TODO confirm with implementations.
            num_samples: Number of negative samples to draw per positive
                triple.

        Returns:
            Tensor of negative triples whose last dimension holds
            (head, relation, tail). For the default ``num_samples=1`` the
            shape is ``(batch_size, 3)``.
            NOTE(review): the shape for ``num_samples > 1`` is not specified
            (e.g. ``(batch_size * num_samples, 3)`` vs.
            ``(batch_size, num_samples, 3)``); pin it down so implementations
            agree.

        """
        ...


# Static-only check that PyKeen's samplers satisfy NegativeSamplerProtocol.
# Guarded by TYPE_CHECKING so that importing this module neither allocates a
# dummy tensor nor constructs samplers: variable annotations are never
# enforced at runtime, so executing these lines verified nothing.
# Distinct private names are used instead of annotating `_` twice, which a
# type checker may report as a redefinition.
# NOTE(review): if PyKeen exposes no type information to the checker, these
# classes are treated as Any and the check passes vacuously -- confirm it is
# actually effective in CI.
if TYPE_CHECKING:
    _dummy_triples = torch.randint(0, 100, (100, 3))
    _basic_sampler: NegativeSamplerProtocol = BasicNegativeSampler(
        num_entities=100, mapped_triples=_dummy_triples
    )
    _bernoulli_sampler: NegativeSamplerProtocol = BernoulliNegativeSampler(
        num_entities=100, mapped_triples=_dummy_triples
    )
