import logging
from collections.abc import Iterator
from typing import Protocol

import torch
from pykeen.models import ComplEx, DistMult
from torch.nn.parameter import Parameter
from typing_extensions import runtime_checkable

from .base import KGModel

logger = logging.getLogger(__name__)


@runtime_checkable
class _PyKeenModel(Protocol):
    """Structural interface of the KG embedding models (from PyKeen) that can be wrapped.

    Because the protocol is ``runtime_checkable``, ``isinstance`` only checks that
    the methods below exist; signatures and documented shapes are for static type
    checkers and readers.

    The tensor arguments are positional-only. Callers pass them positionally, so
    this lets implementations name them freely (PyKeen presumably uses names such
    as ``hrt_batch`` / ``hr_batch`` / ``rt_batch`` -- verify against the installed
    version) while still matching the protocol under static type checking.
    """

    def score_hrt(
        self,
        hrt: torch.Tensor,
        /,
    ) -> torch.Tensor:
        """Score head-relation-tail triples.

        Args:
            hrt: Tensor of shape (batch_size, 3) containing head-relation-tail triples

        Returns:
            Tensor containing one score per triple, documented here as shape
            (batch_size,). NOTE(review): PyKeen's ``Model.score_hrt`` documents a
            (batch_size, 1) result -- confirm which shape consumers expect.

        """
        ...

    def score_t(
        self,
        hr: torch.Tensor,
        /,
    ) -> torch.Tensor:
        """Score all possible tails for given head-relation pairs.

        Args:
            hr: Tensor of shape (batch_size, 2) containing head-relation pairs

        Returns:
            Tensor of shape (batch_size, num_entities) containing the scores

        """
        ...

    def score_h(
        self,
        rt: torch.Tensor,
        /,
    ) -> torch.Tensor:
        """Score all possible heads for given relation-tail pairs.

        Args:
            rt: Tensor of shape (batch_size, 2) containing relation-tail pairs

        Returns:
            Tensor of shape (batch_size, num_entities) containing the scores

        """
        ...

    def collect_regularization_term(self) -> torch.Tensor:
        """Collect regularization term for all the model's parameters at once."""
        ...

    def parameters(self) -> Iterator[Parameter]:
        """Return an iterator over the model's parameters."""
        ...


class PykeenWrapper(KGModel):
    """Adapter that exposes a PyKeen model through the ``KGModel`` interface.

    Each ``KGModel`` scoring method stacks its index tensors along the last
    dimension into the column layout the PyKeen batch-scoring methods take
    (``(s, r, o)``, ``(r, o)``, ``(s, r)``). It then forwards the call to the
    wrapped model.

    Args:
        model: The PyKeen model (any object satisfying ``_PyKeenModel``) to wrap.
    """

    # Class-level default so the flag exists even on instances that bypassed
    # ``__init__`` (e.g. restored from an older pickle); set per instance once
    # the slow-path warning in ``score_s`` has been logged.
    _warned_slow_score_s: bool = False

    def __init__(self, model: _PyKeenModel):
        super().__init__(
            name=f"pykeen_{model.__class__.__name__}",
            # Only DistMult and ComplEx are flagged as bidirectional; for any
            # other model ``score_s`` logs a slow-path warning.
            is_bidirectional=isinstance(model, (DistMult, ComplEx)),
        )
        self.model = model

    def score_sro(self, s: torch.Tensor, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score explicit (subject, relation, object) triples.

        ``s``, ``r`` and ``o`` must have identical shapes (a ``torch.stack``
        requirement); presumably 1-D index tensors of length batch_size, which
        are stacked into a (batch_size, 3) tensor -- TODO confirm with callers.

        NOTE(review): PyKeen documents ``score_hrt`` as returning shape
        (batch_size, 1); the result is returned unchanged -- confirm that this
        matches what ``KGModel`` consumers expect.
        """
        return self.model.score_hrt(torch.stack([s, r, o], dim=-1))

    def score_s(self, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score every candidate subject for each (relation, object) pair.

        For models not flagged as bidirectional, a slow-path warning is logged
        the first time this is called on a given wrapper (not on every call,
        which would flood the log when scoring batch after batch).
        """
        # ``is_bidirectional`` is presumably stored by ``KGModel.__init__`` from
        # the keyword passed in ``__init__`` above -- verify in the base class.
        if not self.is_bidirectional and not self._warned_slow_score_s:
            logger.warning(
                "PykeenWrapper: score_s will be slow for %s",
                self.model.__class__.__name__,
            )
            self._warned_slow_score_s = True
        return self.model.score_h(torch.stack([r, o], dim=-1))

    def score_o(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Score every candidate object for each (subject, relation) pair."""
        return self.model.score_t(torch.stack([s, r], dim=-1))

    def regularization_term(self) -> torch.Tensor:
        """Return the wrapped model's aggregated regularization term."""
        return self.model.collect_regularization_term()
