import numpy as np
import pickle
from HIME import HIME
import hyper_math as hm
import torch
from HIME_dataset import HIME_Dataset
from sklearn.metrics import precision_recall_curve, roc_curve, auc
import networkx as nx
import argparse
import utils

def parse_args(argv=None):
    """Parse command-line options for evaluating a saved HIME model.

    Args:
        argv: optional list of argument strings; ``None`` (the default) means
            ``sys.argv[1:]``, exactly as ``ArgumentParser.parse_args`` does.

    Returns:
        argparse.Namespace with ``dataset``, ``emb_num``, ``emb_dim`` and
        ``epoch_num``. The last three are used by ``__main__`` to build the
        checkpoint path under ``saved_model/``.
    """
    parser = argparse.ArgumentParser()
    # Only these datasets have file paths wired up in this script; rejecting
    # anything else here yields a usage error instead of an unbound-name crash
    # much later.
    parser.add_argument('--dataset', default='dblp', type=str,
                        choices=['dblp', 'protein_go'], help='dblp or protein_go')
    parser.add_argument('--emb_num', default=4, type=int, help='branch vector number')
    parser.add_argument('--emb_dim', default=32, type=int, help='embedding dimension')
    parser.add_argument('--epoch_num', default=50, type=int, help='epoch number')
    return parser.parse_args(argv)


def read_labeled_data(dataset):
    """Load the labelled node pairs used by the node-pair retrieval task.

    Each non-empty line of ``dataset/<dataset>/label.lp.dat`` is split on '.'
    and its first three fields are parsed as ``node1``, ``node2`` and
    ``label`` (all ints).

    Args:
        dataset: ``"dblp"`` or ``"protein_go"``.

    Returns:
        ``(left_nodes, right_nodes, labels)`` as LongTensors. The two node
        tensors are moved to CUDA; ``labels`` stays on the CPU.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    paths = {
        "dblp": "dataset/dblp/label.lp.dat",
        "protein_go": "dataset/protein_go/label.lp.dat",
    }
    if dataset not in paths:
        raise ValueError("unknown dataset %r; expected one of %s" % (dataset, sorted(paths)))

    left_nodes = []
    right_nodes = []
    labels = []
    with open(paths[dataset], 'r') as file:
        for line in file:
            line = line.strip()
            if not line:  # tolerate blank / trailing lines
                continue
            items = line.split('.')
            left_nodes.append(int(items[0]))
            right_nodes.append(int(items[1]))
            labels.append(int(items[2]))

    left_nodes = torch.LongTensor(left_nodes)
    right_nodes = torch.LongTensor(right_nodes)
    labels = torch.LongTensor(labels)
    return left_nodes.cuda(), right_nodes.cuda(), labels


def node_pair_retrieval(left_nodes, right_nodes, labels):
    """Evaluate node-pair retrieval with AUPRC and AUROC.

    Each pair is scored by the negated distance from the module-level
    ``model.node_node_dist``, so closer pairs score higher. Pairs whose label
    equals 1 are the positives.

    Args:
        left_nodes: node ids for the left side of each pair.
        right_nodes: node ids for the right side of each pair.
        labels: tensor of labels, one per pair (positive class is 1).

    Returns:
        ``(AUPRC, AUROC)`` as floats. The original version returned the
        sklearn ``auc`` function object by mistake.
    """
    with torch.no_grad():
        dist = model.node_node_dist(left_nodes, right_nodes)
    # Flatten both sides so scores line up with labels whether the model
    # returns shape (N,) or (N, 1).
    scores = -dist.detach().cpu().reshape(-1).numpy()
    y_true = labels.reshape(-1).cpu().numpy()

    fpr, tpr, _ = roc_curve(y_true, scores, pos_label=1)
    # Scores are passed positionally: the keyword was renamed from
    # ``probas_pred`` to ``y_score`` in scikit-learn 1.5 and the old name was
    # later removed, so positional works on old and new releases alike.
    prec, recall, _ = precision_recall_curve(y_true, scores, pos_label=1)
    AUPRC = auc(recall, prec)
    AUROC = auc(fpr, tpr)
    print("AUPRC: ", AUPRC, "AUROC: ", AUROC)
    return AUPRC, AUROC


def label_path_retrieval(dataset, taxo_child2parents, level_by_level=True):
    """Evaluate top-down label-path retrieval and print the mean accuracy.

    For every line of ``dataset/<dataset>/label.taxo.dat``, split on '.' into
    a node id and a category name, walk the category's path starting from
    ``category2id["root"]``. At each step, predict the child of the current
    parent whose ``model.node_tag_dist`` to the node is smallest, and compare
    it with the true label at that level.

    Relies on module-level globals set in ``__main__``: ``model``,
    ``category2id`` and ``taxo_parent2children``.

    Args:
        dataset: ``"dblp"`` or ``"protein_go"``.
        taxo_child2parents: mapping passed to ``utils.find_path``.
        level_by_level: if True, descend along the true label after each step;
            if False, descend along the predicted label.

    Returns:
        Accuracy in percent over all evaluated steps, or NaN when no step was
        evaluated.

    Raises:
        ValueError: if ``dataset`` is not a supported name.
    """
    test_files = {
        "dblp": r"dataset/dblp/label.taxo.dat",
        "protein_go": r"dataset/protein_go/label.taxo.dat",
    }
    if dataset not in test_files:
        raise ValueError("unknown dataset %r; expected one of %s" % (dataset, sorted(test_files)))

    correct = []
    with open(test_files[dataset], 'r') as file, torch.no_grad():
        for line in file:
            line = line.strip()
            if not line:  # tolerate blank / trailing lines
                continue
            items = line.split('.')
            node = int(items[0])
            category = items[1]
            # NOTE(review): find_path presumably returns the path ordered from
            # `category` up to the root, so [:-1] drops the root and [::-1]
            # walks it top-down. Confirm against utils.find_path.
            category_path = utils.find_path(taxo_child2parents, category2id[category])[:-1]

            p = category2id["root"]
            for true_pos in category_path[::-1]:
                candidates = taxo_parent2children.get(p, None)
                if candidates is None:
                    # `p` can no longer change, so every later level would be
                    # skipped too; stop walking this path.
                    break
                node_list = torch.full((len(candidates),), node, dtype=torch.long).cuda()
                tag_list = torch.tensor(candidates, dtype=torch.long).cuda()
                best = model.node_tag_dist(node_list, tag_list).argmin().item()
                pred = candidates[best]
                correct.append(pred == true_pos)
                p = true_pos if level_by_level else pred

    if not correct:
        print("mean ACC: no label-path steps were evaluated")
        return float("nan")
    acc = np.sum(correct) * 100 / len(correct)
    print("mean ACC: ", acc)
    return acc


def get_node_tag(node_id):
    """Rank every tag for one node; return the mean rank of its true tags.

    Tags ``0 .. tag_num-1`` are ranked by ascending ``model.node_tag_dist`` to
    the node. Rank 1 is the closest tag, and ties keep tag-id order because
    the sort is stable.

    Relies on module-level globals set in ``__main__``: ``model``, ``tag_num``
    and ``nodeid2category``.

    Args:
        node_id: id of the node to evaluate.

    Returns:
        ``(avg_rank, n_true)``: the mean 1-based rank of the node's
        ground-truth tags, and how many ground-truth tags it has.

    Raises:
        ZeroDivisionError: if the node has no ground-truth tags.
    """
    node_list = torch.full((tag_num,), node_id, dtype=torch.long).cuda()
    tag_list = torch.arange(tag_num, dtype=torch.long).cuda()
    with torch.no_grad():
        score = model.node_tag_dist(node_list, tag_list).cpu()
    scores = score.reshape(-1).tolist()

    # Stable sort, then one pass to map tag -> 1-based rank. This replaces a
    # list.index() lookup per ground-truth tag.
    order = sorted(range(tag_num), key=lambda k: scores[k])
    rank_of = {tag: pos + 1 for pos, tag in enumerate(order)}

    ground_truth_tag = nodeid2category[node_id]
    avg_rank = sum(rank_of[t] for t in ground_truth_tag) / len(ground_truth_tag)
    return avg_rank, len(ground_truth_tag)


def get_tag_node(tag_id, node_num=1000, min_nodes=10):
    """Score how well a tag retrieves its member nodes, as AUPRC.

    Nodes ``0 .. node_num-1`` are scored by negated ``model.node_tag_dist``
    to ``tag_id``. Nodes listed in ``category2nodeid[tag_id]`` are the
    positives.

    Relies on module-level globals set in ``__main__``: ``model`` and
    ``category2nodeid``.

    Args:
        tag_id: tag to evaluate.
        node_num: number of candidate nodes (ids ``0 .. node_num-1``).
        min_nodes: tags with fewer ground-truth nodes than this are skipped.

    Returns:
        The AUPRC, or -1 when the tag is unknown, has fewer than
        ``min_nodes`` ground-truth nodes, or has no ground-truth node among
        the candidates.
    """
    if tag_id not in category2nodeid:
        return -1
    ground_truth_node = category2nodeid[tag_id]
    if len(ground_truth_node) < min_nodes:
        return -1

    positives = set(ground_truth_node)  # O(1) membership for the scan below
    ground_truth_label = [1 if i in positives else 0 for i in range(node_num)]
    if not any(ground_truth_label):
        return -1

    tag_list = torch.full((node_num,), tag_id, dtype=torch.long).cuda()
    node_list = torch.arange(node_num, dtype=torch.long).cuda()
    with torch.no_grad():
        score = model.node_tag_dist(node_list, tag_list).cpu()

    # Scores are passed positionally: the keyword was renamed from
    # ``probas_pred`` to ``y_score`` in scikit-learn 1.5 and the old name was
    # later removed.
    prec, recall, _ = precision_recall_curve(ground_truth_label,
                                             -score.detach().reshape(-1).numpy(),
                                             pos_label=1)
    return auc(recall, prec)


def node_retrieval(tag_num, node_num=1000):
    """Print and return the mean AUPRC of node retrieval over all tags.

    Tags for which ``get_tag_node`` returns -1 (i.e. tags it skips) are
    excluded from the mean.

    Args:
        tag_num: tags ``0 .. tag_num-1`` are evaluated.
        node_num: number of candidate nodes passed on to ``get_tag_node``.

    Returns:
        The mean AUPRC, or NaN if no tag could be evaluated. The original
        version divided by zero in that case.
    """
    auprcs = [get_tag_node(i, node_num) for i in range(tag_num)]
    auprcs = [a for a in auprcs if a != -1]
    if not auprcs:
        print("mean AUPRC: no tag could be evaluated")
        return float("nan")
    mean_auprc = sum(auprcs) / len(auprcs)
    print("mean AUPRC: ", mean_auprc)
    return mean_auprc


def label_retrieval(num_nodes=100):
    """Print and return the mean rank (MR) of true tags over the first nodes.

    Args:
        num_nodes: nodes ``0 .. num_nodes-1`` are evaluated with
            ``get_node_tag`` (default 100, as before).

    Returns:
        The mean over those nodes of each node's average true-tag rank, or NaN
        when ``num_nodes`` is 0.
    """
    if num_nodes <= 0:
        print("mean MR: no nodes to evaluate")
        return float("nan")
    total_rank = 0
    for node_id in range(num_nodes):
        mean_rank, _ = get_node_tag(node_id)
        total_rank += mean_rank
    mMR = total_rank / num_nodes
    print("mean MR: ", mMR)
    return mMR


if __name__ == '__main__':
    # Everything below is assigned at module level on purpose: the evaluation
    # functions above read `model`, `tag_num`, `category2id`,
    # `taxo_parent2children`, `nodeid2category` and `category2nodeid` as
    # globals.
    args = parse_args()
    save_dir = "saved_model/" + args.dataset + "_emb_num_" + str(args.emb_num) + "_emb_dim_" + str(args.emb_dim)
    model_path = save_dir + "/epoch_" + str(args.epoch_num) + '.pkl'

    if args.dataset not in ("dblp", "protein_go"):
        raise SystemExit("unknown dataset: " + args.dataset)
    data_dir = "dataset/" + args.dataset + "/"

    # The first two whitespace-separated fields of the graph file's first line
    # are the node count and the tag count.
    with open(data_dir + args.dataset + ".txt") as graph_file:
        items = graph_file.readline().strip().split()
    node_num, tag_num = int(items[0]), int(items[1])
    tot_num = node_num + tag_num

    json_file = data_dir + "taxo.json"
    taxo_file = data_dir + "taxo.dat"
    (taxo_parent2children, taxo_child2parents, nodeid2category,
     category2nodeid, category2id, nodeid2path) = utils.read_taxos(json_file, taxo_file,
                                                                   extend_label=True)

    # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
    with open(model_path, 'rb') as f:
        model = pickle.load(f).cuda()

    a, b, c = read_labeled_data(args.dataset)
    print("Node Pair Retrieval:")
    node_pair_retrieval(a, b, c)
    print("Label-Path Retrieval:")
    label_path_retrieval(args.dataset, taxo_child2parents)
    print("Node Retrieval:")
    node_retrieval(tag_num)
    print("Label Retrieval:")
    label_retrieval()
