from __future__ import annotations
from torch import Tensor
from torch.nn.init import xavier_normal_
from torch.distributions import Dirichlet
import torch.nn.functional as F
import torch
from torch.nn import Module, Parameter, Embedding

from src.autoregressive_models.arm import AutoRegressiveModel
from external_models.gekcs.gekc_models import TractableKBCModel
from external_models.gekcs.models import KBCModel
from src.priors import LearnableMarginalPrior, FrequencyPrior, UniformPrior

class GenerativeModel(Module):
    """Generative link-prediction model built around an autoregressive model.

    Entities and relations get Dirichlet-initialized embeddings, so every row
    lies on the probability simplex at initialization. Given an entity and a
    relation embedding, ``arm_model`` returns relation and entity predictions,
    which are combined with a marginal prior over entities (``self.prior``).
    """

    def __init__(self,
                 arm_model: AutoRegressiveModel,
                 embedding_dim: int,
                 n_entities: int,
                 n_relations: int,
                 prior_frequencies: Tensor,
                 prior_type: str,
                 ) -> None:
        """
        Args:
            arm_model: Model called on ``(entity_emb, relation_emb)`` that
                returns ``(relation_predictions, entity_predictions)``.
            embedding_dim: Size of the entity and relation embeddings.
            n_entities: Number of entities.
            n_relations: Number of relation ids. Presumably this includes the
                inverse relations used by ``inference_head_prediction``.
            prior_frequencies: Entity frequencies, used only when
                ``prior_type == 'frequency'``.
            prior_type: One of ``'learnable'``, ``'frequency'``, ``'uniform'``.

        Raises:
            ValueError: If ``prior_type`` is not one of the supported values.
        """
        super(GenerativeModel, self).__init__()
        self.embedding_dimension = embedding_dim
        self.n_entities = n_entities
        self.n_relations = n_relations
        self.ent_emb = self._init_embedding(self.n_entities, self.embedding_dimension)
        self.rel_emb = self._init_embedding(self.n_relations, self.embedding_dimension)
        self.prior_frequencies = prior_frequencies
        self.arm_model = arm_model
        # Must run last: it reads n_entities and prior_frequencies set above.
        self._set_prior(prior_type)

    def _set_prior(self, prior_type: str) -> None:
        """Instantiate ``self.prior`` according to ``prior_type``.

        Raises:
            ValueError: If ``prior_type`` is not recognized.
        """
        if prior_type == 'learnable':
            self.prior = LearnableMarginalPrior(self.n_entities)
        elif prior_type == 'frequency':
            self.prior = FrequencyPrior(self.prior_frequencies)
        elif prior_type == 'uniform':
            self.prior = UniformPrior(self.n_entities)
        else:
            raise ValueError(f"Unknown prior type: {prior_type}")

    @property
    def is_arm_model(self) -> bool:
        """True when the wrapped model is an ``AutoRegressiveModel``."""
        return isinstance(self.arm_model, AutoRegressiveModel)

    def _init_embedding(self, n_emb: int, emb_dim: int,
                        concentration: float = 0.01) -> Embedding:
        """Create an ``(n_emb, emb_dim)`` embedding whose rows are drawn from a
        symmetric Dirichlet distribution.

        Each row is therefore non-negative and sums to one. A small
        ``concentration`` concentrates most of each row's mass on a few
        coordinates.
        """
        # The default Embedding initialization is discarded below. The layer is
        # still built first so that the RNG consumption order stays unchanged.
        embedding = Embedding(n_emb, emb_dim)
        alpha = torch.full((emb_dim,), concentration)
        embedding.weight = Parameter(Dirichlet(alpha).sample((n_emb,)))
        return embedding

    def scoring_function(self, h_idx: Tensor, r_idx: Tensor, t_idx: Tensor):
        """Run the model on a batch of (head, relation) pairs.

        ``t_idx`` is accepted for interface compatibility but is not used.

        Returns:
            ``(prior_predictions, relation_predictions, entity_predictions)``,
            where ``prior_predictions`` is ``self.prior(batch_size)`` and the
            other two are the outputs of ``arm_model``.
        """
        head = self.ent_emb(h_idx).unsqueeze(1)
        relation = self.rel_emb(r_idx).unsqueeze(1)
        relation_predictions, entity_predictions = self.arm_model((head, relation))
        prior_predictions = self.prior(h_idx.shape[0])
        return prior_predictions, relation_predictions, entity_predictions

    def forward(self, triple: tuple[Tensor, Tensor, Tensor]) -> tuple[Tensor, Tensor, Tensor]:
        """Unpack ``(heads, relations, tails)`` and delegate to ``scoring_function``."""
        heads, relations, tails = triple
        return self.scoring_function(heads, relations, tails)

    def _compute_joint_probability(self, entity_idx, relation_idx, all_relation_predictions, all_entity_predictions):
        """Multiply the entity predictions by the query's prior and relation score.

        For each row, this takes the prior value of ``entity_idx`` and the
        relation prediction at ``relation_idx``, and scales
        ``all_entity_predictions`` by their product. Presumably this is the
        chain-rule factorisation p(e) * p(r | e) * p(e' | e, r); confirm
        against the training objective.
        """
        batch_size = entity_idx.shape[0]
        # Both gathers use a 2-D index, so both sources must be 2-D:
        # (batch, n_relations) and (batch, n_entities). Each yields (batch, 1).
        selected_relation_probs = all_relation_predictions.gather(1, relation_idx.unsqueeze(1))
        selected_entity_priors = self.prior(batch_size).gather(1, entity_idx.unsqueeze(1))
        # NOTE(review): if all_entity_predictions is (batch, 1, n_entities), the
        # (batch, 1) factors broadcast it to (batch, batch, n_entities).
        # Confirm the arm_model output shape.
        return (selected_entity_priors * selected_relation_probs * all_entity_predictions).squeeze(1)

    def _get_predictions(self, entity_emb, relation_emb):
        """Run ``arm_model`` on single-step sequences built from the embeddings."""
        entity_emb = entity_emb.unsqueeze(1)
        relation_emb = relation_emb.unsqueeze(1)
        # Call the module itself rather than .forward() so registered hooks run,
        # consistent with scoring_function.
        return self.arm_model((entity_emb, relation_emb))

    def inference_tail_prediction(self, h_idx: Tensor, r_idx: Tensor, t_idx: Tensor) -> Tensor:
        """Score every candidate tail for each (head, relation) query.

        ``t_idx`` is unused.
        """
        h_emb, r_emb = self.inference_get_embeddings(h_idx, r_idx)
        all_relation_predictions, all_tail_predictions = self._get_predictions(h_emb, r_emb)
        return self._compute_joint_probability(h_idx, r_idx, all_relation_predictions, all_tail_predictions)

    def inference_head_prediction(self, h_idx: Tensor, r_idx: Tensor, t_idx: Tensor) -> Tensor:
        """Score every candidate head by running a tail query on the inverse relation.

        ``h_idx`` is unused. This assumes relation id ``r + n_relations // 2``
        is the inverse of ``r``. NOTE(review): confirm with the dataset loader.
        """
        inverse_r_idx = r_idx + self.n_relations // 2
        t_emb, r_inv_emb = self.inference_get_embeddings(t_idx, inverse_r_idx)
        all_relation_predictions, all_head_predictions = self._get_predictions(t_emb, r_inv_emb)
        return self._compute_joint_probability(t_idx, inverse_r_idx, all_relation_predictions, all_head_predictions)

    def inference_get_embeddings(self, entity: torch.Tensor, relation: torch.Tensor):
        """Look up embeddings for the given entity and relation indices.

        Returns:
            ``(entity_embeddings, relation_embeddings)``.
        """
        entity_embeddings = self.ent_emb(entity)
        relation_embeddings = self.rel_emb(relation)
        return entity_embeddings, relation_embeddings


class KBCWrapper(Module):
    """Adapt a KBC / GEKC model to the link-prediction evaluation interface.

    ``inference_tail_prediction`` and ``inference_head_prediction`` return the
    wrapped model's scores for every entity as the missing side of each query.
    """

    def __init__(self, kbc_model, embedding_dim: int, n_entities: int, n_relations: int) -> None:
        """
        Args:
            kbc_model: Model exposing ``get_candidates`` and ``get_queries``,
                plus ``eval_circuit_all`` when it is a ``TractableKBCModel``.
            embedding_dim: Base embedding size. The stored ``emb_dim`` is
                doubled; presumably the model uses two components per
                dimension. Confirm.
            n_entities: Number of candidate entities to score.
            n_relations: Accepted for interface compatibility; not used here.
        """
        super(KBCWrapper, self).__init__()
        self.emb_dim = embedding_dim * 2
        self.kbc_model = kbc_model
        # Both names are kept because callers may read either one.
        self.number_of_entities = n_entities
        self.n_ent = n_entities

    @property
    def is_arm_model(self) -> bool:
        """True when the wrapped model is an ``AutoRegressiveModel``."""
        return isinstance(self.kbc_model, AutoRegressiveModel)

    @property
    def is_kbc_model(self) -> bool:
        """True when the wrapped model is a ``KBCModel``."""
        return isinstance(self.kbc_model, KBCModel)

    def _score_all_candidates(self, h: Tensor, r: Tensor, t: Tensor, target: str) -> Tensor:
        """Score all entities as the ``target`` side ("rhs" or "lhs") of each query.

        The queries are packed as rows ``(h, r, t)``. This runs without
        gradient tracking.
        """
        queries = torch.stack((h, r, t), dim=1)
        with torch.no_grad():
            candidates = self.kbc_model.get_candidates(0, self.number_of_entities, target=target)
            query_repr = self.kbc_model.get_queries(queries, target=target)
            if isinstance(self.kbc_model, TractableKBCModel):
                # Tractable models provide their own all-candidates evaluation.
                return self.kbc_model.eval_circuit_all(query_repr, candidates)
            # Other KBC models: inner product of query and candidate representations.
            return query_repr @ candidates

    def inference_tail_prediction(self, h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return scores for every candidate tail (right-hand side)."""
        return self._score_all_candidates(h, r, t, target="rhs")

    def inference_head_prediction(self, h: Tensor, r: Tensor, t: Tensor) -> Tensor:
        """Return scores for every candidate head (left-hand side)."""
        return self._score_all_candidates(h, r, t, target="lhs")

def get_nbf_wrapper():
    """Build and return the ``NBFWrapper`` class.

    The NBFNet import is deferred until this function is called because that
    environment is hard to set up. Importing this module does not require it.
    """
    # Hard-to-set-up environment: keep the import local.
    from external_models.nbfnet.model import NeuralBellmanFordNetwork

    class NBFWrapper(Module):
        """Adapt an NBFNet model to the link-prediction evaluation interface.

        A single ``nbf_model.predict`` call yields both tail scores (index 0
        of dim 1) and head scores (index 1). The last result is cached together
        with the batch it was computed for. A head prediction on that same
        batch reuses it instead of running the model again.
        """

        def __init__(self, dataset, nbf_model, embedding_dim: int, n_entities: int, n_relations: int) -> None:
            """
            Args:
                dataset: Stored on the wrapper; not used by this class.
                nbf_model: Model exposing ``predict(batch)``.
                embedding_dim: Stored as ``emb_dim``.
                n_entities: Number of entities.
                n_relations: Accepted for interface compatibility; not used here.
            """
            super(NBFWrapper, self).__init__()
            self.emb_dim = embedding_dim
            self.nbf_model = nbf_model
            self.dataset = dataset
            self.number_of_entities = n_entities
            self.n_ent = n_entities
            # Output of the last predict() call and the batch it belongs to.
            self.scores = None
            self._scored_batch = None

        @property
        def is_arm_model(self) -> bool:
            """True when the wrapped model is an ``AutoRegressiveModel``."""
            return isinstance(self.nbf_model, AutoRegressiveModel)

        @property
        def is_kbc_model(self) -> bool:
            """True when the wrapped model is a ``KBCModel``."""
            return isinstance(self.nbf_model, KBCModel)

        @property
        def is_nbf_model(self) -> bool:
            """True when the wrapped model is a ``NeuralBellmanFordNetwork``."""
            return isinstance(self.nbf_model, NeuralBellmanFordNetwork)

        @staticmethod
        def _make_batch(h: Tensor, r: Tensor, t: Tensor) -> Tensor:
            # predict() is fed rows ordered (head, tail, relation).
            return torch.stack((h, t, r), dim=1)

        def _predict(self, batch: Tensor) -> Tensor:
            """Run the model on ``batch`` without gradients and cache the result."""
            with torch.no_grad():
                self.scores = self.nbf_model.predict(batch)
            self._scored_batch = batch
            return self.scores

        def _is_cached(self, batch: Tensor) -> bool:
            """True when ``self.scores`` was computed for exactly this batch."""
            cached = self._scored_batch
            return (cached is not None
                    and cached.shape == batch.shape
                    and cached.dtype == batch.dtype
                    and cached.device == batch.device
                    and torch.equal(cached, batch))

        def inference_tail_prediction(self, h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
            """Score every candidate tail. Always runs the model and refreshes the cache."""
            return self._predict(self._make_batch(h, r, t))[:, 0, :]

        def inference_head_prediction(self, h: Tensor, r: Tensor, t: Tensor) -> Tensor:
            """Score every candidate head.

            Reuses the scores from the preceding tail prediction when it was made
            on the same batch. Otherwise runs the model, so a missing or
            mismatched earlier call can no longer give an AttributeError or
            stale scores.
            """
            batch = self._make_batch(h, r, t)
            scores = self.scores if self._is_cached(batch) else self._predict(batch)
            return scores[:, 1, :]

    return NBFWrapper
