import logging

import torch
from einops import rearrange, reduce

from kge.models.base import KGModel
from kge.types import GrammaticalFunction

from .embedding import Embedding
from .fusing_function import EncoderBottleneck, FusingFunction
from .grammatical_encoding import GrammaticalEncoder

logger = logging.getLogger(__name__)


class TailModel(KGModel):
    """KGE model that fuses subject and relation, then matches the result against objects.

    The subject and relation embeddings are grammatically encoded, combined by
    ``fusing_function`` and passed through dropout. The resulting query vector
    is scored against a given object with a feature-wise dot product
    (``score_sro``) or against the whole entity table with a matrix
    multiplication (``score_o``).

    Args:
        embedding: Source of entity/relation lookups and of the full entity table.
        grammatical_encoder: Called on each embedding together with its
            ``GrammaticalFunction`` (SUBJECT, RELATION or OBJECT).
        fusing_function: Combines the encoded subject and relation embeddings.
        fusing_dropout: Dropout probability applied to the fused query.
    """

    def __init__(
        self,
        embedding: Embedding,
        grammatical_encoder: GrammaticalEncoder,
        fusing_function: FusingFunction,
        fusing_dropout: float = 0.0,
    ):
        # Model name is the fusing function's class name; EncoderBottleneck
        # variants are suffixed with their bottleneck size to tell them apart.
        name = type(fusing_function).__name__
        if isinstance(fusing_function, EncoderBottleneck):
            name += f"@{fusing_function.bottleneck_dimension}"
        super().__init__(
            name=name,
            is_bidirectional=False,
        )
        # Assignment order is kept as-is: torch.nn.Module registers submodules
        # in this order, which fixes parameter / state_dict ordering.
        self.embedding = embedding
        self.fusing_function = fusing_function
        self.grammatical_encoder = grammatical_encoder
        self.fusing_dropout = torch.nn.Dropout(p=fusing_dropout)

    def _fused_query(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Embed, grammatically encode, fuse and dropout the (subject, relation) pair.

        Shared by ``score_sro`` and ``score_o`` so both score the same query.
        """
        x_s = self.embedding.embed_entities(s)
        x_r = self.embedding.embed_relations(r)
        x_s = self.grammatical_encoder(x_s, GrammaticalFunction.SUBJECT)
        x_r = self.grammatical_encoder(x_r, GrammaticalFunction.RELATION)
        return self.fusing_dropout(self.fusing_function(x_s, x_r))

    def score_sro(self, s: torch.Tensor, r: torch.Tensor, o: torch.Tensor) -> torch.Tensor:
        """Score subject-relation-object triples.

        Returns:
            One score per row: the sum over features of the fused (s, r) query
            multiplied element-wise with the grammatically encoded object.
        """
        x = self._fused_query(s, r)
        y_o = self.embedding.embed_entities(o)
        y_o = self.grammatical_encoder(y_o, GrammaticalFunction.OBJECT)
        # The "b f -> b" pattern requires 2-D (batch, features) operands.
        return reduce(x * y_o, "b f -> b", "sum")  # TODO: Return torch.real for complex embeddings

    def score_o(self, s: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        """Score every entity as the object of each (subject, relation) query.

        Returns:
            A (batch, num_entities) score matrix.
        """
        x = self._fused_query(s, r)
        # Transpose the entity table so that
        # (batch, features) @ (features, num_entities) -> (batch, num_entities).
        all_entities = rearrange(self.embedding.get_all_entity_embeddings(), "N f -> f N")
        # NOTE(review): unlike score_sro, candidate objects here are NOT passed
        # through grammatical_encoder(..., GrammaticalFunction.OBJECT), so the two
        # scorers may disagree on the same triple — confirm this is intended.
        return torch.matmul(x, all_entities)

    def regularization_term(self) -> "torch.Tensor | float":
        """Return the regularization penalty; currently always ``0.0``."""
        # Embedding regularization (self.embedding.regularization_term()) is
        # disabled, so a plain float zero is returned.
        return 0.0
