import os
import random
import torch, torch.nn, torch.nn.utils
from torch.autograd import Variable
from typing import Dict, List, Tuple
from transformers import BertTokenizer

from AlignKGC import kb, evaluate
from AlignKGC.utils import update_imp_sc, get_rel_align_dict
from AlignKGC.alignkgc_base import AlignKgcBaseTrainer

from AlignKGC.BERT_alignments import Basic_Bert_Unit_model, ent2Tokens_gene, \
    ent2bert_input,get_embeddings,cos_sim_mat_generate,batch_topk


class AlignKGCmBERT(AlignKgcBaseTrainer):
    """AlignKGC trainer with two auxiliary alignment losses.

    On top of the base KGC loss, ``step`` adds:
      * an entity-alignment loss over mutual top-1 mBERT neighbours computed
        once at construction (``self.bert_align``);
      * a relation-alignment loss over relation pairs produced periodically by
        ``update_imp_sc`` (``self.equiv_rel``), whose scalar parameter ``self.b``
        is updated by hand.
    """

    def __init__(self, **kwargs):
        """Build the trainer.

        Besides the base-trainer arguments, ``kwargs`` must contain
        ``mbert_path`` (directory with the mBERT weights and
        ``combined_mBERT.p``) and ``dbp5l`` (root whose ``entity_lists/``
        sub-directory is forwarded to :meth:`get_bert_entity_alignments`).
        """
        super().__init__(**kwargs)
        # Parameters handed to update_imp_sc. ``b`` is a trainable scalar leaf
        # that step() updates manually. Values are MAGIC (hand-tuned).
        self.a = 100
        self.b = torch.tensor(90.0, dtype=torch.float32, device="cuda",
                              requires_grad=True)
        # Rank against the union of train, valid and every test KB.
        # dltestmap values are indexed as pairs: item[0] is presumably the
        # language and item[1] a loader exposing ``.kb``.
        test_kbs = [item[1].kb for item in self.dltestmap.values()]
        self.ranker = evaluate.ranker(
            self.scoring_function,
            kb.union([self.dltrain.kb, self.dlvalid.kb] + test_kbs))

        trainpath = os.path.join(self.dataset_root, 'train.txt')
        self.rel_ent_pairs, self.yylang = get_rel_align_dict(
            self.meta, trainpath,
            emap=self.dltrain.kb.entity_map,
            rmap=self.dltrain.kb.relation_map)

        bert_model_path = kwargs["mbert_path"]
        bert_model_file = os.path.join(bert_model_path, "combined_mBERT.p")
        tokenizer_path = "bert-base-multilingual-cased"
        dbp5lentlist_path = os.path.join(kwargs["dbp5l"], "entity_lists/")

        self.bert_align = self.get_bert_entity_alignments(
            os.path.join(self.dataset_root, 'mapping.txt'),
            bert_model_path, bert_model_file, tokenizer_path,
            dbp5lentlist_path, em=self.dltrain.kb.entity_map)

    @staticmethod
    def _read_entity_mapping(entmap_path):
        """Parse the mapping file into ``{(lang, index): entity_name}``.

        Each non-blank line is whitespace-separated; field 0 is the entity
        name, field 1 an integer index, field 2 the language. Blank lines
        are skipped.
        """
        mapping = {}
        with open(entmap_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                mapping[(fields[2], int(fields[1]))] = fields[0]
        return mapping

    @staticmethod
    def _nearest_neighbours(src_emb, tgt_emb):
        """Return ``[(score, tgt_row), ...]``, the top-1 neighbour in
        ``tgt_emb`` for every row of ``src_emb``, scored by
        ``cos_sim_mat_generate`` and selected by ``batch_topk``."""
        sim = cos_sim_mat_generate(src_emb, tgt_emb, cuda_num=0)
        scores, indices = batch_topk(sim, topn=1, largest=True, cuda_num=0)
        return list(zip(scores.view(-1).tolist(), indices.view(-1).tolist()))

    def get_bert_entity_alignments(self,
                                   entmap_path,
                                   bert_dir_path,
                                   model_pkl_path,
                                   tokenizer_path,
                                   dbp5lentlist_path,
                                   em=None):
        """Predict cross-lingual entity alignments with a fine-tuned mBERT.

        Args:
            entmap_path: whitespace-separated file of ``name index lang``
                lines (see ``_read_entity_mapping``).
            bert_dir_path: directory passed to ``Basic_Bert_Unit_model``.
            model_pkl_path: state dict loaded into that model via
                ``torch.load``. It is unpickled, so only use trusted files.
            tokenizer_path: name/path for ``BertTokenizer.from_pretrained``.
            dbp5lentlist_path: only logged here.
            em: maps entity names to KB ids. Although it defaults to None,
                it is required as soon as any candidate pair is found.

        Returns:
            List of ``(em[x], score, em[y])``. For every pair in
            ``self.meta.lang_pairs``, each pair of entities is a mutual top-1
            neighbour across the two languages. Both names must be in the
            mapping file and in ``em``, and the two names must differ.
        """
        print("Getting Alignment from mBERT.................\n",
            "entmap_path", entmap_path, "\n", "bert_dir_path", bert_dir_path, "\n",
            "model_pkl_path", model_pkl_path, "\n", "tokenizer_path", tokenizer_path, "\n",
            "dbp5lentlist_path", dbp5lentlist_path)

        mapping = self._read_entity_mapping(entmap_path)

        tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
        print("tokenizer loaded from", tokenizer_path)
        model = Basic_Bert_Unit_model(768, 300, bert_dir_path).cuda()
        print("model loaded from", bert_dir_path)
        # Close the checkpoint file deterministically instead of leaking it.
        with open(model_pkl_path, "rb") as fh:
            model.load_state_dict(torch.load(fh))
        print("state_dict loaded from", model_pkl_path)
        # NOTE(review): the model is never switched to eval(); confirm that
        # get_embeddings disables dropout / gradients itself.

        entities, entids = self.meta.get_entities()
        ent2tokenids = ent2Tokens_gene(tokenizer, entids)
        ent2data = ent2bert_input(entids, tokenizer, ent2tokenids)

        # lang -> whatever get_embeddings returns for that language's entities.
        emb = {lang: get_embeddings(model, entities[lang], ent2data)
               for lang in self.meta.langs}

        # (src_lang, tgt_lang) -> [(score, tgt_row), ...], one entry per src row.
        alignment: Dict[Tuple[str, str], List[Tuple[float, int]]] = {}
        for (lang1, lang2) in self.meta.lang_pairs:
            print(lang1, lang2)
            assert lang1 != lang2
            for src, tgt in ((lang1, lang2), (lang2, lang1)):
                # Skip directions already computed for an earlier pair.
                if (src, tgt) not in alignment:
                    alignment[(src, tgt)] = self._nearest_neighbours(emb[src],
                                                                     emb[tgt])

        predicted_alignments = []
        for (lang1, lang2) in self.meta.lang_pairs:
            print(lang1, lang2)
            assert lang1 != lang2
            align21 = alignment[(lang2, lang1)]
            for index1, (score12, index12) in enumerate(alignment[(lang1, lang2)]):
                # Keep only mutual nearest neighbours.
                if align21[index12][1] != index1:
                    continue
                # Rows are assumed to line up with the mapping-file indices.
                x = mapping.get((lang1, index1))
                y = mapping.get((lang2, index12))
                if x is None or y is None or x == y:
                    continue
                if x in em and y in em:
                    predicted_alignments.append((em[x], score12, em[y]))

        return predicted_alignments

    def ent_alignment_loss(self, align, E_re, E_im):
        """Score-weighted L1 distance between aligned entity embeddings.

        For each ``(ent1, sc, ent2)`` in ``align`` the term is
        ``sc * (mean|E_re(ent1) - E_re(ent2)| + mean|E_im(ent1) - E_im(ent2)|)``.
        The result is the mean of those terms. All pairs are looked up in one
        batched call per table rather than one call per pair.

        Returns a 0-dim zero tensor when ``align`` is empty.
        """
        if not align:
            return torch.zeros((), device=self.torchdev)
        left, scores, right = zip(*align)
        n = len(left)
        idx1 = torch.tensor(left, device=self.torchdev)
        idx2 = torch.tensor(right, device=self.torchdev)
        # Average over all non-batch dims, matching .mean() on a 1-row lookup.
        dist_re = (E_re(idx1) - E_re(idx2)).abs().reshape(n, -1).mean(dim=1)
        dist_im = (E_im(idx1) - E_im(idx2)).abs().reshape(n, -1).mean(dim=1)
        weights = torch.tensor(scores, device=dist_re.device, dtype=dist_re.dtype)
        return (weights * (dist_re + dist_im)).mean()

    def rel_alignment_loss(self, losstype="L1", entity_bactrack=0):
        """Sum of score-weighted distances between relations paired in
        ``self.equiv_rel``.

        Args:
            losstype: ``"L1"`` (mean absolute difference of the scaled R_re /
                R_im rows) or ``"cos"`` (``sc * (1 - cosine)`` for each part).
            entity_bactrack: 0 detaches each score ``sc`` so no gradient flows
                into it. Any other value keeps ``sc`` attached.

        Returns:
            A tensor, which is zero when there are no relation pairs.

        Raises:
            ValueError: on an unknown ``losstype``.
        """
        if losstype not in ("L1", "cos"):
            raise ValueError(losstype)
        R_re = self.scoring_function.R_re
        R_im = self.scoring_function.R_im
        cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
        loss = torch.zeros((), device=self.torchdev)
        for lid in self.meta.lids():
            for rel, (sc, rel2) in self.equiv_rel[str(lid)].items():
                if entity_bactrack == 0:
                    sc = sc.detach()
                p1 = torch.tensor([rel], device=self.torchdev)
                p2 = torch.tensor([rel2], device=self.torchdev)
                if losstype == "L1":
                    loss = loss + ((sc*R_re(p1) - sc*R_re(p2)).abs().mean()
                                   + (sc*R_im(p1) - sc*R_im(p2)).abs().mean())
                else:
                    loss = loss + (sc*(1 - cos(R_re(p1), R_re(p2)))
                                   + sc*(1 - cos(R_im(p1), R_im(p2))))
        return loss

    def _update_b(self):
        """Manual gradient step (lr 0.8, MAGIC) on ``self.b`` from the cosine
        relation-alignment loss.

        The backward pass also accumulates gradients into the model
        parameters. step() clears them with ``optim.zero_grad()`` afterwards.
        """
        f = 0.01 * self.rel_alignment_loss(losstype="cos", entity_bactrack=1)
        if not f.requires_grad:
            return  # no relation pairs, so there is nothing to differentiate
        # retain_graph: the attached scores in self.equiv_rel are presumably
        # reused by later steps until the next update_imp_sc refresh.
        f.backward(retain_graph=True)
        if self.b.grad is None:
            return
        # Assign through .data so the update is not recorded by autograd.
        self.b.data = self.b.data - 0.8 * self.b.grad.data
        self.b.grad.data.zero_()

    def step(self, loop_no) -> Dict[str, float]:
        """Run one optimisation step and return the loss components.

        All entities are used as negatives when ``negative_sample_count`` is 0
        or the loss is ``crossentropy_loss_AllNeg_subsample``. Otherwise
        sampled negatives are used. Every 10th step adds the alignment terms.
        """
        full_softmax = (self.negative_sample_count == 0
                        or self.loss.name == 'crossentropy_loss_AllNeg_subsample')
        if full_softmax:
            # None presumably tells scoring_function to score every entity.
            ns = no = None
            s, r, o, _, _ = self.dltrain.tensor_sample(self.batch_size, 1)
        else:
            s, r, o, ns, no = self.dltrain.tensor_sample(self.batch_size,
                                                         self.negative_sample_count)

        # About 0.5% of steps pass flag_debug=2 instead of 1.
        flag_debug = 1 if random.randint(1, 10001) > 9950 else 0
        fp = self.scoring_function(s, r, o, flag_debug=flag_debug+1)
        fno = self.scoring_function(s, r, no, flag_debug=flag_debug+1)
        fns = self.scoring_function(ns, r, o, flag_debug=flag_debug+1)

        if self.regloss_coeff > 0:
            reg = self.regularizer(s, r, o)
            if self.scoring_function.reg != 3:
                # Normalise by batch size only (also dividing by the
                # embedding dim was found to be too strong).
                reg = reg / self.batch_size
        else:
            # Keep a tensor so reg_loss.item() below works.
            reg = torch.zeros((), device=self.torchdev)
        reg_loss = self.regloss_coeff * reg

        if full_softmax:
            if self.loss.name == 'crossentropy_loss_AllNeg_subsample':
                kbc_loss = (self.loss(s, fns, self.negative_sample_count)
                            + self.loss(o, fno, self.negative_sample_count))
            elif self.flag_add_reverse == 0 or self.scoring_function.flag_avg_scores:
                kbc_loss = self.loss(s, fns) + self.loss(o, fno)
            else:
                kbc_loss = self.loss(o, fno)
        else:
            kbc_loss = self.loss(fp, fns) + self.loss(fp, fno)

        ea_loss = torch.zeros(1, device=self.torchdev)
        ra_loss = torch.zeros(1, device=self.torchdev)
        if loop_no % 10 == 0:  # MAGIC schedule constants throughout
            refresh = (loop_no % 5000 == 0
                       or (loop_no > 13000 and loop_no % 1000 == 0)
                       or getattr(self, "equiv_rel", None) is None)
            if refresh:
                self.equiv_rel = update_imp_sc(self.meta, self.rel_ent_pairs,
                                               self.yylang,
                                               self.scoring_function.E_re,
                                               self.scoring_function.E_im,
                                               self.a, self.b)
            if loop_no > 13000 and loop_no % 50 == 0:
                self._update_b()
            else:
                ra_loss = self.raloss_coeff * self.rel_alignment_loss(losstype="L1")
            if loop_no % 50 == 0:
                ea_loss = self.ealoss_coeff * self.ent_alignment_loss(
                    self.bert_align,
                    self.scoring_function.E_re,
                    self.scoring_function.E_im)

        total_loss = reg_loss + kbc_loss + ra_loss + ea_loss
        self.optim.zero_grad()
        total_loss.backward()
        if self.gradient_clip:
            torch.nn.utils.clip_grad_norm_(self.scoring_function.parameters(),
                                           self.gradient_clip)
        self.optim.step()
        if hasattr(self.scoring_function, "post_epoch"):
            self.scoring_function.post_epoch()
        return {"total": total_loss.item(),
                "kbc": kbc_loss.item(), "ea": ea_loss.item(), "ra": ra_loss.item(),
                "reg": reg_loss.item()}
