import torch

from kge.types import GrammaticalFunction

from .base import KGModel
from .embedding import Embedding
from .fusing_function import FusingFunction
from .grammatical_encoding import GrammaticalEncoder
from .mixture_layer import MixtureLayer


class TailMixtureModel(KGModel):
    """TailModel where the linear output layer for objects is replaced with a mixture layer.

    The subject and relation of a query are embedded and tagged with their
    grammatical function by ``grammatical_encoder``. They are then combined by
    ``fusing_function``, passed through dropout, and scored by ``mixture_layer``
    against the embeddings of all entities. Only objects (tails) are scored, so
    the model is registered with ``is_bidirectional=False``.

    Args:
        embedding: Source of entity and relation embeddings.
        fusing_function: Combines the encoded subject and relation embeddings
            into a single query representation.
        grammatical_encoder: Encodes an embedding according to its
            ``GrammaticalFunction`` (subject or relation).
        mixture_layer: Output layer that scores the fused query against all
            entity embeddings. Its ``return_log_prob`` flag and ``name`` are
            forwarded to the base model.
        fusing_dropout: Dropout probability applied to the fused query.
    """

    def __init__(
        self,
        embedding: Embedding,
        fusing_function: FusingFunction,
        grammatical_encoder: GrammaticalEncoder,
        mixture_layer: MixtureLayer,
        fusing_dropout: float = 0.0,
    ):
        super().__init__(
            name=f"{fusing_function.__class__.__name__}_{mixture_layer.name}",
            is_bidirectional=False,
            return_log_prob=mixture_layer.return_log_prob,
        )
        self.embedding = embedding
        self.fusing_function = fusing_function
        self.grammatical_encoder = grammatical_encoder
        self.mixture_layer = mixture_layer
        self.fusing_dropout = torch.nn.Dropout(p=fusing_dropout)

    def _encode_query(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Build the fused (subject, relation) query representation, with dropout applied."""
        # Keep the original call order (embed s, embed r, encode s, encode r) so
        # any stochastic layers inside these modules consume RNG identically.
        x_s = self.embedding.embed_entities(s)
        x_r = self.embedding.embed_relations(r)
        x_s = self.grammatical_encoder(x_s, GrammaticalFunction.SUBJECT)
        x_r = self.grammatical_encoder(x_r, GrammaticalFunction.RELATION)
        x = self.fusing_function(x_s, x_r)
        return self.fusing_dropout(x)

    def _score_all_objects(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Score every entity as the object of each (s, r) query; result is (b, E)."""
        x = self._encode_query(s, r)
        Y_o = self.embedding.get_all_entity_embeddings()
        return self.mixture_layer(x, class_embeddings=Y_o)

    def score_sro(self, s: torch.Tensor, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Return the score of each triple (s[i], r[i], o[i]).

        The mixture layer scores the query against all entities. The score of
        the requested object is then picked out per row.

        Args:
            s: Subject indices, one per query.
            r: Relation indices, one per query.
            o: Object indices, one per query (1-D, same length as the batch).

        Returns:
            Tensor of shape ``(b,)`` with one score per triple.
        """
        scores = self._score_all_objects(s, r)  # (b, E)
        # torch.gather requires an int64 index; .long() is a no-op for int64 input.
        index = o.long().unsqueeze(1)  # (b, 1)
        return torch.gather(scores, 1, index).squeeze(1)  # (b,)

    def score_o(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Return scores of all entities as objects for each (s, r) query, shape ``(b, E)``.

        Whether these are log-probabilities depends on
        ``mixture_layer.return_log_prob``.
        """
        return self._score_all_objects(s, r)

    def regularization_term(self) -> torch.Tensor:
        """Return the mixture layer's regularization term.

        Only the mixture layer contributes. ``self.embedding.regularization_term()``
        is deliberately not added. NOTE(review): confirm that excluding the
        embedding regularizer is intended.
        """
        return self.mixture_layer.regularization_term()
