from kge import Config, Dataset
from kge.model.kge_model import RelationalScorer, KgeModel
from kge.util import similarity, KgeLoss, rat

import torch
from torch import Tensor, sort
from torch.nn import functional as F
from torch import nn
from torch.nn.parameter import Parameter

from pytorch_pretrained_bert.modeling import BertEncoder, BertConfig, BertLayerNorm, BertPreTrainedModel

from functools import partial

import time
import networkx as nx
import numpy as np

from kge.util import sc
from kge.indexing import extract_subgraph_information

from kge.model.trme import TrmE, TrmEScorer


class UTGScorer(TrmEScorer):
    r"""Implementation of the TrmE KGE scorer."""

    def __init__(self, config: Config, dataset: Dataset, configuration_key=None):
        """Set up the scorer on top of the parent ``TrmEScorer`` state.

        Adds a learned ``unreachable_embed`` vector and, when the
        ``semantic_encoding`` option is enabled, a semantic encoder together
        with the BCE-with-logits loss used to train it.
        """
        super().__init__(config, dataset, configuration_key)

        # relation embedding used for unconnected (or > max_hop) entity pairs
        unreachable = torch.Tensor(1, self.dim)
        nn.init.normal_(unreachable, std=self.initializer_range)
        self.unreachable_embed = Parameter(unreachable)

        if not self.get_option('semantic_encoding'):
            return

        # semantic position encoding layer and its auxiliary loss
        encoder_dropout = self.get_option("hidden_dropout")
        self.semantic_encoder = rat.SemanticEncoding(self.dim, 1, encoder_dropout)
        self.se_loss = nn.BCEWithLogitsLoss()


    def _get_encoder_output(self, e_emb, p_emb, ids, gt_ent, gt_rel, output_repr=False):
        """Encode each (entity, relation) query together with its graph neighbors.

        The query pair and every kept neighbor pair are first encoded
        independently by ``self.atom_encoder``; the resulting vectors are then
        processed jointly by ``self.transformer_encoder``.

        Args:
            e_emb: source-entity embeddings, ``[bs, dim]``. NOTE(review): during
                training this tensor is overwritten in place for the rows
                sampled for MLM-style masking/replacement.
            p_emb: relation embeddings, ``[bs, dim]``.
            ids: entity ids of the batch; used to look up the ``'neighbor'``
                index and as targets of the MLM-like loss.
            gt_ent: ground-truth entity id per example (viewed as ``[bs, 1]``).
            gt_rel: ground-truth relation id per example (viewed as ``[bs, 1]``).
            output_repr: if True, skip the training-time masking and return the
                raw output of ``self.transformer_encoder``.

        Returns:
            * ``(src[:, 0], 0)`` when ``self.max_context_size == 0`` (checked
              before ``output_repr``);
            * ``(out, self.convert_mask(attention_mask))`` when ``output_repr``
              is True;
            * otherwise ``(out_rp, loss)`` where ``out_rp`` is the vector at
              position 0 of the last encoder output (averaged with the
              semantic-partition output if ``semantic_partition`` is enabled)
              and ``loss`` accumulates the optional auxiliary losses.
        """
        n = p_emb.size(0)
        device = p_emb.device

        # neighbor table and per-entity neighbor counts, looked up for this batch
        ctx_list, ctx_size = self.dataset.index('neighbor')
        ctx_ids = ctx_list[ids].to(device).transpose(1, 2)
        ctx_size = ctx_size[ids].to(device)

        # sample neighbors uniformly during training
        # (the permutation comes from sc.get_randperm_from_lengths and is
        # applied to both columns of ctx_ids via gather)
        if self.training:
            perm_vector = sc.get_randperm_from_lengths(ctx_size, ctx_ids.size(1))
            ctx_ids = torch.gather(
                ctx_ids, 1, perm_vector.unsqueeze(-1).expand_as(ctx_ids))

        # [bs, length, 2]
        # truncate to the configured context size and clamp the counts to match
        ctx_ids = ctx_ids[:, :self.max_context_size]
        ctx_size[ctx_size > self.max_context_size] = self.max_context_size

        # [bs, max_ctx_size]
        entity_ids = ctx_ids[...,0]
        relation_ids = ctx_ids[...,1]

        # initialize mask by length of context, seq=[CLS, S, N1, N2, ...]
        # 1 is the positions that will be attended to
        # (+2 accounts for the CLS and source positions)
        ctx_size = ctx_size + 2
        attention_mask = sc.get_mask_from_sequence_lengths(
            ctx_size, self.max_context_size + 2)

        if self.training and not output_repr:
            # mask out ground truth during training to avoid overfitting
            # a neighbor is kept only if it differs from the ground truth in
            # entity or relation, and survives the random context dropout
            gt_mask = ((entity_ids != gt_ent.view(n, 1))
                       | (relation_ids != gt_rel.view(n, 1)))
            ctx_random_mask = (attention_mask
                               .new_ones((n, self.max_context_size))
                               .bernoulli_(1 - self.get_option("ctx_dropout")))
            attention_mask[:,2:] = attention_mask[:, 2:] & ctx_random_mask & gt_mask

        # [bs, max_ctx_size, dim]
        entity_emb = self._entity_embedder().embed(entity_ids)
        relation_emb = self._relation_embedder().embed(relation_ids)

        if self.training and self.get_option("self_dropout") > 0 and self.max_context_size > 0 and not output_repr:
            # sample a proportion of input for masked prediction similar to the MLM in BERT
            self_dropout_sample = sc.get_bernoulli_mask(
                [n], self.get_option("self_dropout"), device)

            # replace with mask tokens
            # NOTE(review): writes into e_emb in place, so the caller's tensor changes
            masked_sample = sc.get_bernoulli_mask(
                [n], self.get_option("mlm_mask"), device) & self_dropout_sample
            e_emb[masked_sample] = self.local_mask.unsqueeze(0)

            # replace with random sampled entities, no back propagation here
            replaced_sample = sc.get_bernoulli_mask([n], self.get_option(
                "mlm_replace"), device) & self_dropout_sample & ~masked_sample
            e_emb[replaced_sample] = self._entity_embedder().embed(torch.randint(self.dataset.num_entities(
            ), replaced_sample.shape, dtype=torch.long, device=device))[replaced_sample].detach()

        

        # interleave (entity, relation) pairs: the query pair first, then one
        # pair per neighbor, giving [bs, max_ctx_size + 1, 2, dim]
        src = torch.cat([torch.stack([e_emb, p_emb], dim=1), torch.stack([entity_emb, relation_emb], dim=2)
                         .view(n, 2 * self.max_context_size, self.dim)], dim=1)
        src = src.reshape(n, self.max_context_size + 1, 2, self.dim)

        # only keep un-masked positions to reduce computational cost
        # (assumes attention_mask is a boolean mask - TODO confirm; the CLS
        # column is dropped with [:, 1:])
        src = src[attention_mask[:, 1:]]

        # add CLS (local) and pos embedding
        pos = self.atomic_type_embeds(torch.arange(
            0, 3, device=device)).unsqueeze(0).repeat(src.shape[0], 1, 1)
        src = torch.cat([self.cls.expand(src.size(0), 1, self.dim), src], dim=1) + pos

        src = F.dropout(src, p=self.get_option("output_dropout"),
                        training=self.training and not output_repr)
        src = self.atomic_layer_norm(src)

        # [bs, dim]
        # every kept pair is a length-3 sequence with an all-ones mask; the
        # vector at position 0 (local CLS) of the last layer is its summary
        out = self.atom_encoder(src,
                                self.convert_mask(src.new_ones(
                                    src.size(0), src.size(1), dtype=torch.long)),
                                output_all_encoded_layers=False)[-1][:,0]

        # recover results from output based on mask
        # (masked-out positions stay zero)
        src = out.new_zeros(n, self.max_context_size + 1, self.dim)
        src[attention_mask[:, 1:]] = out

        # when not using graph context, exit here
        if self.max_context_size == 0:
            return src[:, 0], 0

        # begin the processing of graph context with the upper transformer block
        # add CLS (global) and pos embeddings
        # type 0 -> global CLS, type 1 -> source pair, type 2 -> graph neighbors
        src = torch.cat([self.global_cls.expand(n, 1, self.dim), src], dim=1)
        pos = self.type_embeds(torch.arange(0, 3, device=device))
        src[:, 0] = src[:, 0] + pos[0].unsqueeze(0)
        src[:, 1] = src[:, 1] + pos[1].unsqueeze(0)
        src[:, 2:] = src[:, 2:] + pos[2].view(1, 1, -1)

        # NOTE(review): unlike the dropout above, this one is not disabled
        # when output_repr is set
        src = F.dropout(src, p=self.get_option(
            "hidden_dropout"), training=self.training)
        src = self.layer_norm(src)

        bias = None
        loss = 0
        if self.get_option('semantic_encoding'):
            # for training
            # [bs, 1+max_ctx_size, dim]
            # pairwise semantic scores between the source entity and its
            # neighbors, passed to the encoder as `bias`
            raw_emb = torch.cat([e_emb.unsqueeze(1), entity_emb], axis=1)
            scores = self.semantic_encoder(raw_emb, raw_emb)
            
            bias = scores

            if self.get_option('semantic_encoding_loss'):
                # contrastive auxiliary loss: real neighbors (positive) against
                # max_context_size uniformly drawn entities (negative)
                noise_neighbor_emb = self._entity_embedder().embed(torch.randint(low=0, high=self.dataset.num_entities(), size=(self.max_context_size,)
                , device=device))

                # per example: sum of the scores selected by the attention mask
                pos_score = torch.stack([scores[i, attention_mask[i, 1:]].sum() for i in range(scores.shape[0])])
                # NOTE(review): the bare squeeze() would also drop the batch
                # dimension when bs == 1 - confirm against the encoder's output shape
                neg_score = self.semantic_encoder(e_emb.unsqueeze(1), noise_neighbor_emb.unsqueeze(0)).squeeze().sum(1)
                # [bs, 2]
                combined = torch.stack([pos_score, neg_score], 1)

                # column 0 (real neighbors) -> label 1, column 1 (noise) -> label 0
                target = torch.zeros_like(combined, requires_grad=False)
                target[:, 0] = 1

                loss += self.se_loss(combined, target)

            

        out = self.transformer_encoder(
            src, None, self.convert_mask_rat(attention_mask), bias=bias)

        if output_repr:
            return out, self.convert_mask(attention_mask)

        # last layer, positions 0 (global CLS) and 1 (source pair)
        out = out[-1][:,:2]

        # compute the mlm-like loss if needed
        # self_dropout_sample is defined here: output_repr is False and
        # max_context_size > 0 once this line is reached
        if self.training and self.get_option("add_mlm_loss") and self.get_option("self_dropout") > 0.0 and self_dropout_sample.sum() > 0:
            all_entity_emb = self._entity_embedder().embed_all()
            all_entity_emb = F.dropout(all_entity_emb, p=self.get_option(
                "output_dropout"), training=self.training)
            # score the source position against all entities and recover the
            # original entity id for the sampled rows
            source_scores = self.similarity(
                out[:, 1], all_entity_emb, False).view(n, -1)
            self_pred_loss = F.cross_entropy(
                source_scores[self_dropout_sample], ids[self_dropout_sample], reduction='mean')
            loss += self_pred_loss
        

        out_rp = out[:, 0]
        if self.get_option('semantic_partition'):
            # average with the semantic-neighbor representation; its loss is
            # added with weight 0.1
            sp_out, sp_loss = self._get_semantic_partition_output(e_emb, p_emb, ids, gt_ent, gt_rel, output_repr=output_repr)
            out_rp =  (out_rp + sp_out) / 2
            loss += .1*sp_loss

        return out_rp, loss
        
        

    def _get_semantic_partition_output(self, e_emb, p_emb, ids, gt_ent, gt_rel, output_repr=False):
        """Encode each query together with its *semantic* neighbors.

        Follows the same pipeline as ``_get_encoder_output`` with these
        differences visible in the code:

        * neighbors come from the ``'semantic_neighbor'`` dataset index;
        * every neighbor uses ``self.unreachable_embed`` as its relation
          embedding instead of an embedded relation id;
        * neighbor positions receive type embedding 3 instead of 2;
        * the semantic bias is scaled by ``1 - 0.9 ** train_epoch`` when
          ``self.meta`` contains ``'train_epoch'``;
        * no semantic-encoding loss is added and no further partition output
          is mixed in.

        Args:
            e_emb: source-entity embeddings, ``[bs, dim]``. NOTE(review): during
                training this tensor is overwritten in place for the rows
                sampled for MLM-style masking/replacement.
            p_emb: relation embeddings, ``[bs, dim]``.
            ids: entity ids of the batch; used to look up the index and as
                targets of the MLM-like loss.
            gt_ent: ground-truth entity id per example (viewed as ``[bs, 1]``).
            gt_rel: ground-truth relation id per example (viewed as ``[bs, 1]``).
            output_repr: if True, skip the training-time masking and return the
                raw output of ``self.transformer_encoder``.

        Returns:
            * ``(src[:, 0], 0)`` when ``self.max_context_size == 0``;
            * ``(out, self.convert_mask(attention_mask))`` when ``output_repr``
              is True;
            * otherwise ``(out[:, 0], loss)``: the position-0 vector of the last
              encoder output and the optional MLM-like loss.
        """
        n = p_emb.size(0)
        device = p_emb.device

        # semantic-neighbor table and per-entity counts, looked up for this batch
        ctx_list, ctx_size = self.dataset.index('semantic_neighbor')
        ctx_ids = ctx_list[ids].to(device).transpose(1, 2)
        ctx_size = ctx_size[ids].to(device)

        # sample neighbors uniformly during training
        if self.training:
            perm_vector = sc.get_randperm_from_lengths(ctx_size, ctx_ids.size(1))
            ctx_ids = torch.gather(
                ctx_ids, 1, perm_vector.unsqueeze(-1).expand_as(ctx_ids))

        # [bs, length, 2]
        # truncate to the configured context size and clamp the counts to match
        ctx_ids = ctx_ids[:, :self.max_context_size]
        ctx_size[ctx_size > self.max_context_size] = self.max_context_size

        # [bs, max_ctx_size]
        entity_ids = ctx_ids[...,0]
        relation_ids = ctx_ids[...,1]

        # initialize mask by length of context, seq=[CLS, S, N1, N2, ...]
        # 1 is the positions that will be attended to
        # (+2 accounts for the CLS and source positions)
        ctx_size = ctx_size + 2
        attention_mask = sc.get_mask_from_sequence_lengths(
            ctx_size, self.max_context_size + 2)

        if self.training and not output_repr:
            # mask out ground truth during training to avoid overfitting
            # NOTE(review): relation_ids come from the semantic-neighbor index,
            # not from real edges - confirm the relation comparison is intended
            gt_mask = ((entity_ids != gt_ent.view(n, 1))
                       | (relation_ids != gt_rel.view(n, 1)))
            ctx_random_mask = (attention_mask
                               .new_ones((n, self.max_context_size))
                               .bernoulli_(1 - self.get_option("ctx_dropout")))
            attention_mask[:,2:] = attention_mask[:, 2:] & ctx_random_mask & gt_mask

        # [bs, max_ctx_size, dim]
        entity_emb = self._entity_embedder().embed(entity_ids)
        # set relation as unreachable embed
        # (relation_ids is only used for its shape here)
        relation_emb = self.unreachable_embed.expand(relation_ids.shape[0], relation_ids.shape[1], -1)

        if self.training and self.get_option("self_dropout") > 0 and self.max_context_size > 0 and not output_repr:
            # sample a proportion of input for masked prediction similar to the MLM in BERT
            self_dropout_sample = sc.get_bernoulli_mask(
                [n], self.get_option("self_dropout"), device)

            # replace with mask tokens
            # NOTE(review): writes into e_emb in place, so the caller's tensor changes
            masked_sample = sc.get_bernoulli_mask(
                [n], self.get_option("mlm_mask"), device) & self_dropout_sample
            e_emb[masked_sample] = self.local_mask.unsqueeze(0)

            # replace with random sampled entities, no back propagation here
            replaced_sample = sc.get_bernoulli_mask([n], self.get_option(
                "mlm_replace"), device) & self_dropout_sample & ~masked_sample
            e_emb[replaced_sample] = self._entity_embedder().embed(torch.randint(self.dataset.num_entities(
            ), replaced_sample.shape, dtype=torch.long, device=device))[replaced_sample].detach()

        

        # interleave (entity, relation) pairs: the query pair first, then one
        # pair per neighbor, giving [bs, max_ctx_size + 1, 2, dim]
        src = torch.cat([torch.stack([e_emb, p_emb], dim=1), torch.stack([entity_emb, relation_emb], dim=2)
                         .view(n, 2 * self.max_context_size, self.dim)], dim=1)
        src = src.reshape(n, self.max_context_size + 1, 2, self.dim)

        # only keep un-masked positions to reduce computational cost
        # (assumes attention_mask is a boolean mask - TODO confirm)
        src = src[attention_mask[:, 1:]]

        # add CLS (local) and pos embedding
        pos = self.atomic_type_embeds(torch.arange(
            0, 3, device=device)).unsqueeze(0).repeat(src.shape[0], 1, 1)
        src = torch.cat([self.cls.expand(src.size(0), 1, self.dim), src], dim=1) + pos

        src = F.dropout(src, p=self.get_option("output_dropout"),
                        training=self.training and not output_repr)
        src = self.atomic_layer_norm(src)

        # [bs, dim]
        # every kept pair is a length-3 sequence with an all-ones mask; the
        # vector at position 0 (local CLS) of the last layer is its summary
        out = self.atom_encoder(src,
                                self.convert_mask(src.new_ones(
                                    src.size(0), src.size(1), dtype=torch.long)),
                                output_all_encoded_layers=False)[-1][:,0]

        # recover results from output based on mask
        # (masked-out positions stay zero)
        src = out.new_zeros(n, self.max_context_size + 1, self.dim)
        src[attention_mask[:, 1:]] = out

        # when not using graph context, exit here
        if self.max_context_size == 0:
            return src[:, 0], 0

        # begin the processing of graph context with the upper transformer block
        # add CLS (global) and pos embeddings
        src = torch.cat([self.global_cls.expand(n, 1, self.dim), src], dim=1)
        pos = self.type_embeds(torch.arange(0, 4, device=device))
        src[:, 0] = src[:, 0] + pos[0].unsqueeze(0)
        src[:, 1] = src[:, 1] + pos[1].unsqueeze(0)
        # pos[3] for semantic neighbors
        src[:, 2:] = src[:, 2:] + pos[3].view(1, 1, -1)

        src = F.dropout(src, p=self.get_option(
            "hidden_dropout"), training=self.training)
        src = self.layer_norm(src)

        bias = None
        loss = 0
        if self.get_option('semantic_encoding'):
            # for training
            # [bs, 1+max_ctx_size, dim]
            raw_emb = torch.cat([e_emb.unsqueeze(1), entity_emb], axis=1)
            scores = self.semantic_encoder(raw_emb, raw_emb)
            if 'train_epoch' in self.meta:
                # ramp the bias in over training: the factor 1 - 0.9**epoch is
                # 0 at epoch 0 and approaches 1 as the epoch count grows
                bias = scores * (1- .9 ** self.meta["train_epoch"])
            else:
                bias = scores

            # The semantic-encoding loss is deliberately disabled for semantic
            # neighbors to avoid overfitting. An earlier variant mirrored the
            # positive-vs-noise BCE loss of _get_encoder_output and added it
            # with weight 1e-2.

            

        out = self.transformer_encoder(
            src, None, self.convert_mask_rat(attention_mask), bias=bias)

        if output_repr:
            return out, self.convert_mask(attention_mask)

        # last layer, positions 0 (global CLS) and 1 (source pair)
        out = out[-1][:,:2]

        # compute the mlm-like loss if needed
        # self_dropout_sample is defined here: output_repr is False and
        # max_context_size > 0 once this line is reached
        if self.training and self.get_option("add_mlm_loss") and self.get_option("self_dropout") > 0.0 and self_dropout_sample.sum() > 0:
            all_entity_emb = self._entity_embedder().embed_all()
            all_entity_emb = F.dropout(all_entity_emb, p=self.get_option(
                "output_dropout"), training=self.training)
            source_scores = self.similarity(
                out[:, 1], all_entity_emb, False).view(n, -1)
            self_pred_loss = F.cross_entropy(
                source_scores[self_dropout_sample], ids[self_dropout_sample], reduction='mean')
            loss += self_pred_loss
        
        return out[:, 0], loss
       
       


import pandas as pd

class UTG(TrmE):
    r"""Implementation of the TrmE KGE model."""

    def __init__(self, config: Config, dataset: Dataset, configuration_key=None, scorer=TrmEScorer):
        """Create the model and build the training graph used for neighbor lookups.

        The ``scorer`` argument is accepted for signature compatibility only;
        the parent is always initialised with ``UTGScorer``.

        Side effects: stores the ``max_hop`` option in ``dataset._indexes`` and
        writes ``train_set.csv`` and ``entity_ids.csv`` to the current working
        directory.
        """
        super().__init__(
            config, dataset, configuration_key=configuration_key, scorer=UTGScorer
        )
        dataset._indexes['max_hop'] = self.get_option("max_hop")
        train_triples = dataset.split('train')

        # dump the training triples as (head, relation, tail) rows
        triple_frame = pd.DataFrame(train_triples.numpy(), columns=['h','r','t'])
        triple_frame.to_csv('train_set.csv')

        # directed graph with one forward edge per triple plus a reverse edge
        # whose relation id is shifted by the number of relations;
        # add_edge creates missing endpoints itself
        inverse_offset = dataset.num_relations()
        graph = nx.DiGraph()
        for s, p, o in train_triples.tolist():
            graph.add_edge(s, o, type=p)
            graph.add_edge(o, s, type=p + inverse_offset)
        self.G = graph

        entity_ids = pd.Series(dataset._meta['entity_ids'])
        entity_ids.to_csv('entity_ids.csv')

        # exit()


    def forward(self, fn_name, *args, **kwargs):
        """Dispatch to the scoring method called ``fn_name``.

        The scorer needs access to the model's embedders while scoring, so the
        embedder getters are bound to it for the duration of the call and
        removed again afterwards, even if scoring raises.

        Args:
            fn_name: name of the method on ``self`` to invoke
                (score_sp/score_po during training, score_spo/score_sp_po
                during inference, or ``get_hitter_repr``).
            *args: positional arguments forwarded to that method.
            **kwargs: keyword arguments forwarded to that method; ``gt_ent``
                must be present in training mode because it is the loss target.

        Returns:
            The raw result of the method when ``fn_name`` is
            ``'get_hitter_repr'`` or the model is in eval mode. In training
            mode, the main loss on ``scores[0]`` plus the auxiliary loss
            ``scores[1]`` weighted by ``self_dropout / (1 + self_dropout)`` and
            scaled by the batch size.
        """
        scorer = self._scorer

        # bind entity/relation embedder to scorer to retrieve embeddings
        scorer._entity_embedder = self.get_s_embedder
        scorer._relation_embedder = self.get_p_embedder
        try:
            # build the semantic-neighbor index lazily on first use
            if (self.get_option('semantic_partition')
                    and self.dataset._indexes.get('semantic_neighbor') is None):
                self.update_semantic_neighbor()

            # call score_sp/score_po during training, score_spo/score_sp_po during inference
            scores = getattr(self, fn_name)(*args, **kwargs)
        finally:
            # delete references to embedder getter, also on failure, so a
            # raised exception does not leave them attached to the scorer
            del scorer._entity_embedder
            del scorer._relation_embedder

        if fn_name == 'get_hitter_repr':
            return scores

        if not self.training:
            return scores

        self_loss_w = self.get_option("self_dropout")
        # MLM-like loss is weighted by the proportion of entities sampled
        self_loss_w = self_loss_w / (1 + self_loss_w)
        return self.loss(scores[0], kwargs["gt_ent"]) + self_loss_w * scores[1] * scores[0].size(0)


    def update_semantic_neighbor(self):
        """Build the ``semantic_neighbor`` index from the current entity embeddings.

        For every entity that occurs in the training graph ``self.G``, the
        cosine similarity to all entities is computed. The entity itself and
        its one-hop successors are excluded, and up to
        ``min(16, max_context_size)`` of the most similar remaining entities
        are kept, provided their similarity is at least 0.8 times the lowest
        similarity found among the excluded one-hop set.

        Side effects:
            * stores ``(all_neighbor, all_neighbor_num)`` under
              ``dataset._indexes["semantic_neighbor"]``;
            * logs the elapsed time through ``dataset.config.log``;
            * writes a diagnostic dump to ``se_neighbors.csv`` in the current
              working directory.

        Returns:
            The ``(all_neighbor, all_neighbor_num)`` tuple that was stored.
            ``all_neighbor`` is a long tensor of shape
            ``[num_entities, 2, max_context_size]`` whose ``[:, 0]`` slice holds
            the neighbor entity ids (the ``[:, 1]`` slice is left at zero), and
            ``all_neighbor_num`` holds the number of valid neighbors per entity.
        """
        dataset = self.dataset
        G = self.G
        name = "semantic_neighbor"
        max_context_size = self.get_option("max_context_size")
        max_neighbor_num = min(16, max_context_size)
        num_entities = dataset.num_entities()

        def sim_matrix(a, b, eps=1e-8):
            """Pairwise cosine similarity; eps guards against zero-norm rows."""
            a_norm = a / a.norm(dim=1)[:, None].clamp(min=eps)
            b_norm = b / b.norm(dim=1)[:, None].clamp(min=eps)
            return torch.mm(a_norm, b_norm.transpose(0, 1))

        # columns of the diagnostic CSV
        input_nodes, se_neighbors, se_scores, st_neighbors, st_scores = [], [], [], [], []

        with torch.no_grad():
            embeddings = self.get_s_embedder().embed_all()

            all_neighbor = torch.zeros(
                (num_entities, 2, max_context_size), dtype=torch.long)
            all_neighbor_num = torch.zeros(num_entities, dtype=torch.long)

            start = time.time()

            # NOTE: dense [num_entities, num_entities] matrix, held on the CPU
            cosine_matrix = sim_matrix(embeddings, embeddings).detach().cpu()

            # torch.topk raises if k exceeds the number of candidates, so clamp
            # the pool of 50 for small datasets; excluded entries (-1e10) never
            # pass the similarity threshold below
            top_k = min(50, cosine_matrix.size(-1))

            for entity in range(cosine_matrix.size(0)):
                if entity not in G:
                    continue

                # row view: the masking below edits cosine_matrix in place,
                # which is fine because each row is visited only once
                sim_mat = cosine_matrix[entity]

                # structural neighborhood: one-hop successors plus the entity itself
                one_hop_neighbor = list(G.successors(entity)) + [entity]
                st_score = sim_mat[one_hop_neighbor]
                min_sim = st_score.min() * .8

                # exclude the structural neighborhood from the candidates
                sim_mat[one_hop_neighbor] = -1e10

                sim, neighbor = torch.topk(sim_mat, top_k, dim=-1, sorted=True)

                # keep ids and scores aligned: same filter, same truncation
                keep = sim >= min_sim
                neighbor = neighbor[keep][:max_neighbor_num]
                se_score = sim[keep][:max_neighbor_num]

                input_nodes.append(entity)
                se_neighbors.append(neighbor.numpy())
                se_scores.append(se_score.numpy())
                st_neighbors.append(one_hop_neighbor)
                st_scores.append(st_score.numpy())

                # neighbor already is a long tensor; assign it directly rather
                # than copy-constructing through torch.tensor
                num_found = neighbor.numel()
                all_neighbor[entity, 0, :num_found] = neighbor
                all_neighbor_num[entity] = num_found

        dataset._indexes[name] =  (all_neighbor, all_neighbor_num)
        dataset.config.log("Semantic Neighbors index finished, time:%fs" % (time.time()-start), prefix="  ")

        pd.DataFrame({'node':input_nodes, 
                    'se_neighbors':se_neighbors, 
                    'se_scores':se_scores, 
                    'st_neighbors':st_neighbors, 
                    'st_scores':st_scores}).to_csv('se_neighbors.csv')

        return dataset._indexes.get(name)
