from abc import ABC, abstractmethod

import torch
from torch import nn


class GNNEmbedding(nn.Module, ABC):
    """Abstract interface for embedding modules used by the GNN models.

    Concrete subclasses must implement :meth:`embed_sr` and
    :meth:`regularization_term`; until both exist the class cannot be
    instantiated. The base registers no parameters of its own. It only
    keeps the sizes it was built with as plain attributes.
    """

    def __init__(self, num_entities: int, num_relations: int, dimension: int):
        super().__init__()
        # Attributes are set in this order: entities, relations, dimension.
        self.num_entities, self.num_relations, self.dimension = (
            num_entities,
            num_relations,
            dimension,
        )

    @abstractmethod
    def embed_sr(
        self,
        s: torch.Tensor,
        r: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed the inputs ``s`` and ``r`` and return all entity embeddings.

        Returns a 3-tuple ``(s_embedding, r_embedding, all_entity_embeddings)``.
        NOTE(review): ``s``/``r`` are presumably index tensors of subject
        entities and relations; the exact shapes are up to the subclass.
        Confirm this against the implementations.
        """
        ...

    @abstractmethod
    def regularization_term(self, lmbda: float) -> torch.Tensor:
        """Return one regularization term that covers every embedding.

        NOTE(review): ``lmbda`` is presumably the regularization weight.
        Confirm this against the concrete implementations.
        """
        ...
