from __future__ import annotations

import torch
from torch import nn

from kge.types import GrammaticalFunction

from .base import ParametricGrammaticalEncoder


class GrammaticalSumEmbedder(ParametricGrammaticalEncoder):
    """Additive grammatical encoder.

    Holds one vector of length ``dim`` per ``GrammaticalFunction`` member and
    encodes a tensor by adding the vector of the requested function to it.
    Each vector is an ``nn.Parameter`` stored on the instance under the
    attribute name ``"<member name>_embedding"``.
    """

    def __init__(
        self,
        dim: int,
        **kwargs,
    ):
        """Create the per-function embedding vectors.

        Args:
            dim: Length of every embedding vector.
            **kwargs: Accepted but unused here; they are not forwarded to the
                base class constructor.
        """
        super().__init__()
        self.dim = dim

        # Members are visited in enum definition order, drawing one
        # ``torch.randn(dim)`` sample per member.
        for function in GrammaticalFunction:
            vector = nn.Parameter(torch.randn(dim))
            setattr(self, self._attribute_name(function), vector)

    @staticmethod
    def _attribute_name(function: GrammaticalFunction) -> str:
        """Return the attribute name under which ``function``'s embedding is stored."""
        return f"{function.name}_embedding"

    def forward(self, x: torch.Tensor, as_: GrammaticalFunction) -> torch.Tensor:
        """Return ``x`` plus the embedding of the grammatical function ``as_``.

        Args:
            x: Input tensor; it has to be broadcast-compatible with a
                ``(dim,)`` vector (presumably its last dimension is ``dim``
                -- confirm against callers).
            as_: Grammatical function whose embedding is added.

        Returns:
            The out-of-place sum ``x + embedding``.
        """
        offset = getattr(self, self._attribute_name(as_))
        return x + offset

    def get_grammatical_embeddings(self) -> dict[GrammaticalFunction, torch.Tensor]:
        """Map every grammatical function to its stored embedding.

        The values are the stored attribute objects themselves; no copy is made.
        """
        table = {}
        for function in GrammaticalFunction:
            table[function] = getattr(self, self._attribute_name(function))
        return table
