import torch

from kge.types import GrammaticalFunction

from .base import GrammaticalEncoder


class GrammaticalComplexEncoder(GrammaticalEncoder):
    """Encoder that conjugates an embedding when it is used as an object.

    For every other recognised grammatical function the embedding is
    passed through untouched.
    NOTE(review): presumably intended for complex-valued embeddings
    (ComplEx-style scoring) -- confirm against the callers.
    """

    def __call__(self, x: torch.Tensor, as_: GrammaticalFunction) -> torch.Tensor:
        """Encode ``x`` according to the grammatical function ``as_``.

        Args:
            x: The tensor to encode.
            as_: The grammatical function ``x`` is being used as.

        Returns:
            ``torch.conj(x)`` when ``as_`` is ``OBJECT``; ``x`` itself when
            ``as_`` is ``SUBJECT``, ``RELATION`` or ``LATENT``.

        Raises:
            ValueError: If ``as_`` is none of the functions handled here.
        """
        # Subjects and relations are used exactly as given.
        if as_ == GrammaticalFunction.SUBJECT or as_ == GrammaticalFunction.RELATION:
            return x

        # Objects are the only role that gets conjugated.
        if as_ == GrammaticalFunction.OBJECT:
            conjugated = torch.conj(x)
            return conjugated

        # Latent embeddings are passed through as well.
        if as_ == GrammaticalFunction.LATENT:
            return x

        raise ValueError(f"Unknown grammatical-function {as_}")
