from abc import ABC, abstractmethod
from pathlib import Path

import torch
from torch import nn


class KGModel(nn.Module, ABC):
    """Base class for knowledge graph embedding models.

    Subclasses must implement ``score_sro``, ``score_o`` and
    ``regularization_term``. Models constructed with ``is_bidirectional=True``
    must also override ``score_s``.

    Attributes:
        name: Name of the model.
        is_bidirectional: Whether the model supports subject prediction efficiently.
        return_log_prob: Whether the model returns log probabilities.

    """

    def __init__(
        self,
        name: str,
        *,
        is_bidirectional: bool = False,
        return_log_prob: bool = False,
    ):
        """Initialize the model.

        Args:
            name: Name of the model
            is_bidirectional: Whether the model supports subject prediction efficiently.
            return_log_prob: Whether the model returns log probabilities.

        """
        super().__init__()
        self.is_bidirectional = is_bidirectional
        self.name = name
        self.return_log_prob = return_log_prob

    @abstractmethod
    def score_sro(self, s: torch.Tensor, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score subject-relation-object triples."""

    @abstractmethod
    def score_o(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Score all possible objects for given subject and relation."""

    def score_s(self, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score all possible subjects for given relation and object.

        This method must be overridden if is_bidirectional=True.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass. The
                message differs depending on ``is_bidirectional`` so that a
                bidirectional model missing its override is easy to spot.

        """
        if self.is_bidirectional:
            raise NotImplementedError("Bidirectional models must implement score_s")
        raise NotImplementedError("This model does not support subject prediction")

    def save_checkpoint(self, path: Path) -> None:
        """Save the model's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.state_dict(), path)

    def load_checkpoint(self, path: Path) -> None:
        """Load a ``state_dict`` from ``path`` into this model.

        Tensors are mapped onto the device this model currently lives on
        (see ``_checkpoint_device``), so a checkpoint saved on GPU can be
        loaded into a CPU model and vice versa.

        Raises:
            RuntimeError: Propagated from ``load_state_dict`` when the
                checkpoint keys/shapes do not match this model.

        """
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        state_dict = torch.load(path, map_location=self._checkpoint_device())
        self.load_state_dict(state_dict)

    def _checkpoint_device(self):
        """Return the device that loaded checkpoint tensors should be mapped to."""
        # nn.Module has no built-in ``device`` attribute, so honour one only if
        # a subclass defines it (nn.Module.__getattr__ raises AttributeError
        # otherwise, which getattr turns into the default).
        device = getattr(self, "device", None)
        if device is not None:
            return device
        # Fall back to wherever the model's tensors currently live.
        tensor = next(self.parameters(), None)
        if tensor is None:
            tensor = next(self.buffers(), None)
        return tensor.device if tensor is not None else torch.device("cpu")

    @abstractmethod
    def regularization_term(self) -> torch.Tensor:
        """Return regularization term for all the model's parameters at once."""
