import logging

import torch

from kge.models.base import KGModel

from .fusing_function import FusingFunction
from .gnn_embedding import GNNEmbedding
from .mixture_layer import MixtureLayer

logger = logging.getLogger(__name__)


class GNNMixtureModel(KGModel):
    """Knowledge-graph model built from a GNN embedding, a fusing function and a mixture layer.

    Scoring pipeline (see ``score_o``):

    1. ``gnn_embedding.embed_sr(s, r)`` yields subject embeddings, relation
       embeddings and the embeddings of all entities.
    2. ``fusing_function`` combines the subject and relation embeddings.
    3. Dropout (``fusing_dropout``) is applied to the fused representation.
    4. ``mixture_layer`` scores the fused representation against all entity
       embeddings.
    """

    def __init__(
        self,
        gnn_embedding: GNNEmbedding,
        fusing_function: FusingFunction,
        mixture_layer: MixtureLayer,
        fusing_dropout: float = 0.0,
    ):
        """Assemble the model from its components.

        Args:
            gnn_embedding: Module whose ``embed_sr(s, r)`` returns a triple of
                (subject embeddings, relation embeddings, all-entity embeddings).
            fusing_function: Callable combining subject and relation embeddings.
            mixture_layer: Output layer that scores the fused representation
                against all entity embeddings. Its ``name`` contributes to the
                model name, and its ``return_log_prob`` flag is forwarded to
                the base model.
            fusing_dropout: Dropout probability applied to the fused
                representation before the mixture layer.
        """
        # The model name encodes which components were combined.
        name = (
            f"{gnn_embedding.__class__.__name__}_"
            f"{fusing_function.__class__.__name__}_"
            f"{mixture_layer.name}"
        )
        super().__init__(
            name=name,
            is_bidirectional=False,
            return_log_prob=mixture_layer.return_log_prob,
        )
        self.embedding = gnn_embedding
        self.fusing_function = fusing_function
        self.mixture_layer = mixture_layer
        self.fusing_dropout = torch.nn.Dropout(p=fusing_dropout)

    def score_sro(self, s: torch.Tensor, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score subject-relation-object triples.

        The full row of object scores is computed with ``score_o`` and the
        entry for each requested object is then selected. Presumably this is
        done so that any normalisation the mixture layer applies across
        entities is preserved; confirm against ``MixtureLayer``.

        Args:
            s: Subject indices, one per triple.
            r: Relation indices, one per triple.
            o: Object indices, one per triple. Any integer dtype is accepted.

        Returns:
            Scores for the given subject-relation-object triples, shape ``(b,)``.

        """
        all_scores = self.score_o(s, r)  # (b, E)
        # torch.gather requires an int64 index; .long() is a no-op for
        # tensors that are already int64, and it fixes int32 inputs.
        index = o.long().unsqueeze(1)  # (b, 1)
        return torch.gather(all_scores, 1, index).squeeze(1)  # (b,)

    def score_o(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Score all possible objects for the given subjects and relations.

        Args:
            s: Subject indices.
            r: Relation indices.

        Returns:
            Output of the mixture layer for the fused (subject, relation)
            representation against all entity embeddings. ``score_sro``
            treats it as ``(b, E)``.

        """
        s_emb, r_emb, all_ent_emb = self.embedding.embed_sr(s, r)
        fused = self.fusing_function(s_emb, r_emb)
        # Dropout is only active in training mode (standard nn.Dropout behaviour).
        fused = self.fusing_dropout(fused)
        return self.mixture_layer(fused, class_embeddings=all_ent_emb)

    def regularization_term(self) -> "torch.Tensor | float":
        """Return the model's regularization term.

        No regularization is applied: this always returns the Python float
        ``0.0``. The annotation reflects that, because it is not a tensor.

        NOTE(review): delegating to ``self.embedding.regularization_term()``
        was previously disabled here. Re-enable it deliberately if the GNN
        embedding should be regularized.
        """
        return 0.0
