from __future__ import annotations

from abc import ABC, abstractmethod

import torch
from torch import nn

from kge.types import GrammaticalFunction


class GrammaticalEncoder(ABC):
    """Abstract interface for encoders parameterised by a grammatical role.

    An implementation maps a tensor ``x`` to an encoded tensor. The mapping
    is selected by ``as_``, the ``GrammaticalFunction`` the input plays.
    Concrete subclasses only have to provide ``__call__``. The ``encode_*``
    helpers are shorthands that fix ``as_`` to ``SUBJECT``, ``OBJECT`` or
    ``RELATION``.

    NOTE(review): the expected shape/dtype of ``x`` is not fixed by this
    interface; confirm against the concrete implementations.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor, as_: GrammaticalFunction) -> torch.Tensor:
        """Encode ``x`` in the grammatical role ``as_``; subclasses must override."""
        ...

    def encode_subject(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` in the subject role."""
        role = GrammaticalFunction.SUBJECT
        # Passed by keyword so implementations always receive it as ``as_``.
        return self(x, as_=role)

    def encode_object(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` in the object role."""
        role = GrammaticalFunction.OBJECT
        return self(x, as_=role)

    def encode_relation(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` in the relation role."""
        role = GrammaticalFunction.RELATION
        return self(x, as_=role)


class ParametricGrammaticalEncoder(nn.Module, GrammaticalEncoder):
    """A ``GrammaticalEncoder`` that is also a ``torch.nn.Module``.

    Subclasses implement :meth:`forward`. Calling an instance delegates to
    ``super().__call__``, which for this class is normally
    ``nn.Module.__call__`` because ``nn.Module`` precedes
    ``GrammaticalEncoder`` in the bases. That call runs the standard module
    call path, which ends in ``forward``.
    """

    def __init__(self):
        # Runs the cooperative base initialiser (nn.Module's state setup).
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, as_: GrammaticalFunction) -> torch.Tensor:
        """Compute the encoding of ``x`` in role ``as_``; subclasses must override."""
        ...

    def __call__(self, x: torch.Tensor, as_: GrammaticalFunction) -> torch.Tensor:
        # Re-declared to expose the typed ``(x, as_)`` signature. Dispatch is
        # left to the next class in the MRO. ``as_`` is forwarded positionally,
        # so ``forward`` receives it as its second positional argument.
        module_call = super().__call__
        return module_call(x, as_)
