import logging

import torch
from einops import rearrange

from kge.models.base import KGModel

from .fusing_function import FusingFunction
from .gnn_embedding import GNNEmbedding

logger = logging.getLogger(__name__)


class GNNModel(KGModel):
    """KG model scoring objects by a dot product with a fused subject/relation query.

    ``gnn_embedding.embed_sr`` supplies the subject, relation and all-entity
    embeddings. ``fusing_function`` combines the subject and relation embeddings
    into one query vector per example. After dropout, that query is multiplied
    against every entity embedding, giving one score per candidate object.
    """

    def __init__(
        self,
        gnn_embedding: GNNEmbedding,
        fusing_function: FusingFunction,
        fusing_dropout: float = 0.0,
    ):
        """Build the model.

        Args:
            gnn_embedding: Module whose ``embed_sr(s, r)`` returns the subject,
                relation and all-entity embeddings.
            fusing_function: Callable fusing subject and relation embeddings into
                a single query vector per example.
            fusing_dropout: Dropout probability applied to the fused query vector.
        """
        # The model name is derived from both component class names, e.g. "<Embedding>_<Fusing>".
        name = f"{gnn_embedding.__class__.__name__}_{fusing_function.__class__.__name__}"
        super().__init__(
            name=name,
            is_bidirectional=False,
        )
        self.embedding = gnn_embedding
        self.fusing_function = fusing_function
        self.fusing_dropout = torch.nn.Dropout(p=fusing_dropout)

    def score_sro(self, s: torch.Tensor, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score subject-relation-object triples.

        Computes the full score matrix via ``score_o`` and picks the entry for each
        requested object, so triple scores always agree with ``score_o``.

        Args:
            s: Subject indices.
            r: Relation indices.
            o: Object indices, one per row of the score matrix (used as a
                ``torch.gather`` index, so an integer/long tensor is expected).

        Returns:
            Scores for the given subject-relation-object triples.

        """
        all_scores = self.score_o(s, r)  # (b, E)
        # Select column o[i] from row i, then drop the singleton dimension.
        return torch.gather(all_scores, 1, o.unsqueeze(1)).squeeze(1)  # (b,)

    def score_o(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Score all possible objects for given subject and relation.

        Args:
            s: Subject indices.
            r: Relation indices.

        Returns:
            A ``(B, N)`` score matrix: one row per (s, r) pair, one column per
            entity embedding returned by ``embed_sr``.

        """
        s_emb, r_emb, all_ent_emb = self.embedding.embed_sr(s, r)
        x = self.fusing_function(s_emb, r_emb)
        x = self.fusing_dropout(x)
        # Identity reshape, kept on purpose: einops raises if the fused query is not 2-D (B, f).
        x = rearrange(x, "B f -> B f")
        # Transpose entity embeddings (N, f) -> (f, N) so the matmul yields (B, N).
        all_ent = rearrange(all_ent_emb, "N f -> f N")
        return torch.matmul(x, all_ent)

    def regularization_term(self) -> torch.Tensor:
        """Return the regularization penalty for this model.

        No regularization is currently applied, so the penalty is always zero.
        It is returned as a 0-dim tensor, rather than the Python float ``0.0``,
        to match the declared return type. 0-dim CPU tensors can be combined
        arithmetically with tensors on other devices.

        Returns:
            A scalar (0-dim) zero tensor.
        """
        # NOTE(review): the embedding's own ``regularization_term()`` is deliberately
        # not used here — confirm whether it should be re-enabled.
        return torch.zeros(())
