# run with -Xutf8 under Windows

import torch
import argparse
from os import path

from kge.model import KgeModel
from kge.util.io import load_checkpoint

# Lookup tables between names and libKGE numbers. They start out empty and are
# filled by read_mapping() in the __main__ block before rescoring begins.
# Entities: name -> number, and number -> name.
e2n, n2e = {}, {}
# Relations: name -> number, and number -> name.
r2n, n2r = {}, {}


def _rescore_line(model, token, s, r, o):
    """Build the re-scored output line for one ``Heads:``/``Tails:`` line.

    ``token`` is the whitespace-split input line: the label followed by
    alternating candidate names and (ignored) old scores. ``s``, ``r`` and
    ``o`` are the names from the triple line the candidates belong to.

    Returns the complete output line including the trailing newline, or
    ``None`` when the label is neither ``Heads:`` nor ``Tails:``.
    Raises ``KeyError`` for a name that is missing from the mapping tables.
    """
    label = token[0]
    if label not in ("Heads:", "Tails:"):
        return None

    # Candidate names sit at the odd positions; the old scores in between
    # are discarded.
    candidate_ids = [e2n[name] for name in token[1::2]]
    pieces = [label + " "]
    if candidate_ids:
        n = len(candidate_ids)
        rel_ids = [r2n[r]] * n
        if label == "Heads:":
            # Candidates replace the subject; the object stays fixed.
            scores = get_kge_scores(model, candidate_ids, rel_ids, [e2n[o]] * n, "s")
        else:
            # Candidates replace the object; the subject stays fixed.
            scores = get_kge_scores(model, [e2n[s]] * n, rel_ids, candidate_ids, "o")
        for candidate_id, score in zip(candidate_ids, scores):
            pieces.append(n2e[candidate_id] + "\t" + str(score.item()) + "\t")
    pieces.append("\n")
    return "".join(pieces)


def rescore(modelpath_kge, rankingpath_ab, model_name):
    """Re-score the candidates of an AnyBURL ranking file with a KGE model.

    The ranking file is processed in groups of three lines. The first line
    of a group is a whitespace-separated ``subject relation object`` triple
    and is copied to the output unchanged. The other two lines are expected
    to start with ``Heads:`` or ``Tails:`` followed by alternating candidate
    names and scores; each candidate is re-scored with the model and written
    back in the same order as ``name<TAB>score<TAB>``. Lines in those two
    positions with any other label produce no output.

    Blank lines are skipped and do not count towards the three-line grouping,
    so a stray empty line (e.g. at the end of the file) does not abort the run.

    Args:
        modelpath_kge: Path of the checkpoint handed to ``load_checkpoint``.
        rankingpath_ab: Path of the ranking file to read.
        model_name: Suffix of the output file, which is written to
            ``rankingpath_ab + "-" + model_name``.

    Raises:
        KeyError: A name is not present in the module-level ``e2n``/``r2n``
            tables (they must be filled with ``read_mapping`` beforehand).
        ValueError: A triple line does not consist of exactly three tokens.
    """
    checkpoint = load_checkpoint(modelpath_kge)
    model = KgeModel.create_from(checkpoint)

    s = r = o = ""
    counter = 0
    # The context manager closes both files even if a lookup or the model
    # raises midway. UTF-8 is explicit so that reading does not depend on the
    # platform's default encoding.
    with open(rankingpath_ab, encoding="utf-8") as ranking, \
            open(rankingpath_ab + "-" + model_name, "w", encoding="utf-8") as out:
        for line in ranking:
            if not line.strip():
                continue
            counter += 1
            if counter % 1000 == 0:
                print(">>> went through " + str(counter) + " lines in the ranking file")
            if counter % 3 == 1:
                s, r, o = line.split()
                out.write(line)
                continue
            rescored = _rescore_line(model, line.split(), s, r, o)
            if rescored is not None:
                out.write(rescored)


def get_kge_score(model, s, r, o, direction):
    """Score a single triple, given by names, with a KGE model.

    Args:
        model: Object providing ``score_spo(s, r, o, direction)``.
        s: Subject entity name, looked up in the module-level ``e2n`` table.
        r: Relation name, looked up in the module-level ``r2n`` table.
        o: Object entity name, looked up in ``e2n``.
        direction: Passed through to ``score_spo`` unchanged.

    Returns:
        ``score.item()`` of the value returned by ``model.score_spo``.

    Raises:
        KeyError: A name is missing from the mapping tables.
    """
    # Build int64 tensors directly. torch.Tensor([...]).long() goes through
    # float32 first, which cannot represent every integer above 2**24 and
    # would silently turn large ids into a neighbouring id.
    t_s = torch.tensor([e2n[s]], dtype=torch.long)
    t_r = torch.tensor([r2n[r]], dtype=torch.long)
    t_o = torch.tensor([e2n[o]], dtype=torch.long)
    score = model.score_spo(t_s, t_r, t_o, direction)
    return score.item()

def get_kge_scores(model, ss, rr, oo, direction):
    """Score a batch of triples, given by libKGE numbers, with a KGE model.

    Args:
        model: Object providing ``score_spo(s, r, o, direction)``.
        ss: Sequence of subject entity numbers.
        rr: Sequence of relation numbers, aligned with ``ss``.
        oo: Sequence of object entity numbers, aligned with ``ss``.
        direction: Passed through to ``score_spo`` unchanged; ``rescore``
            uses ``"s"`` for head candidates and ``"o"`` for tail candidates.

    Returns:
        Whatever ``model.score_spo`` returns for the three int64 index
        tensors (``rescore`` iterates over it and calls ``.item()`` on each
        element).
    """
    # Build int64 tensors directly. torch.Tensor(...).long() goes through
    # float32 first, which cannot represent every integer above 2**24 and
    # would silently turn large ids into a neighbouring id.
    t_s = torch.tensor(ss, dtype=torch.long)
    t_r = torch.tensor(rr, dtype=torch.long)
    t_o = torch.tensor(oo, dtype=torch.long)
    return model.score_spo(t_s, t_r, t_o, direction)
    

def read_mapping(path, n2x, x2n):
    """Load a number/name mapping file into two dictionaries.

    Every line that splits on whitespace into exactly two tokens is read as
    ``<number> <name>``; all other lines are ignored. A summary of how many
    mappings were read is printed afterwards.

    Args:
        path: File to read (decoded as UTF-8).
        n2x: Dictionary updated in place with ``int(number) -> name``.
        x2n: Dictionary updated in place with ``name -> int(number)``.

    Raises:
        ValueError: The first token of a two-token line is not an integer.
    """
    counter = 0
    # The context manager closes the file even when int() rejects a line;
    # UTF-8 is explicit so names do not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            token = line.split()
            if len(token) != 2:
                continue
            number = int(token[0])
            n2x[number] = token[1]
            x2n[token[1]] = number
            counter += 1
    print(">>> read ", counter, "mappings from", path,"between entity/relation names and libKGE numbers")

def parse_args():
    """Parse the command line and return the resulting namespace.

    All four options are mandatory: ``--dataset_folder``, ``--checkpoint``,
    ``--ab_ranking`` and ``--model_name``.
    """
    # (flag, help text) pairs, registered in this order.
    required_options = (
        ("--dataset_folder", "LibKGE dataset folder."),
        ("--checkpoint", "LibKGE model checkpoint (.pt file)."),
        ("--ab_ranking", "AnyBURL ranking file. Output is written to folder where this file is located."),
        ("--model_name", "Name of the model used for the output file name."),
    )
    cli = argparse.ArgumentParser()
    for flag, description in required_options:
        cli.add_argument(flag, required=True, help=description)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: load the name/number mappings of the dataset, then
    # re-score the ranking file with the given checkpoint.
    args = parse_args()

    libkge_datapath = args.dataset_folder
    rankingpath_ab = args.ab_ranking
    mn = args.model_name
    cp = args.checkpoint

    # Fill the module-level lookup tables that rescore() relies on.
    read_mapping(path.join(libkge_datapath, "entity_ids.del"), n2e, e2n)
    read_mapping(path.join(libkge_datapath, "relation_ids.del"), n2r, r2n)

    # basename() understands every separator the platform accepts, so the
    # file name is also extracted from forward-slash paths on Windows.
    print("*** converting", path.basename(cp), "to anyburl ranking format")
    rescore(cp, rankingpath_ab, mn)
