from abc import ABC, abstractmethod

import torch
from torch import nn


class Embedding(nn.Module, ABC):
    """Abstract interface for models that embed entities and relations.

    Concrete subclasses must override every ``@abstractmethod`` below;
    ``Embedding`` itself cannot be instantiated.

    Attributes:
        num_entities: Number of entities, stored exactly as passed in.
        num_relations: Number of relations, stored exactly as passed in.
        dimension: Size value stored exactly as passed in (presumably the
            embedding width — confirm against concrete subclasses).
    """

    def __init__(self, num_entities: int, num_relations: int, dimension: int):
        # Initialise nn.Module state before any attributes are set on it.
        super().__init__()
        self.num_entities, self.num_relations, self.dimension = (
            num_entities,
            num_relations,
            dimension,
        )

    @abstractmethod
    def embed_entities(
        self,
        entities_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Return embeddings for a batch of entities.

        Args:
            entities_batch: Tensor selecting the entities to embed
                (presumably integer ids — confirm against subclasses).

        Returns:
            Tensor holding the embeddings for ``entities_batch``.
        """

    @abstractmethod
    def embed_relations(
        self,
        relations_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Return embeddings for a batch of relations.

        Args:
            relations_batch: Tensor selecting the relations to embed
                (presumably integer ids — confirm against subclasses).

        Returns:
            Tensor holding the embeddings for ``relations_batch``.
        """

    @abstractmethod
    def get_all_entity_embeddings(self) -> torch.Tensor:
        """Return the embeddings of every entity in a single tensor."""

    @abstractmethod
    def regularization_term(self, lmbda: float) -> torch.Tensor:
        """Return one regularization term covering all embeddings.

        Args:
            lmbda: Scalar passed through to the implementation (presumably a
                regularization weight — confirm against subclasses).
        """
