import argparse
import json
import logging
import os
import random
import collections
import math
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataset import TestDataset
from model import KGReasoning
import time
import pickle
from collections import defaultdict
from tqdm import tqdm
from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple
from torchmetrics import SpearmanCorrCoef
# Small numeric tolerance; not referenced in the visible part of this file.
close_cri = 0.001

# Query catalogue grouped by structural template. The outer key is the
# template written as a string; each inner dict maps a nested-tuple query
# structure to its short task name.
# NOTE(review): the leaf tokens ('e', 'nv', 'ap', 'rp', 'np', 'rap') and the
# operator tokens ('i', 'u') are defined by the dataset/model code, which is
# not visible here - confirm their meaning there.
query_name_dict_classify = {
    '(p,(e))': {
        ('e', ('ap',)): '1ap',
        ('e', ('rp',)): '1p',
        ('nv', ('np',)): 'np',
        ('nv', ('rap',)): '1rap',
    },
    '(p,(p,(e)))': {
        (('e', ('ap',)), ('np',)): 'an',
        (('e', ('ap',)), ('rap',)): 'ar',
        (('e', ('rp',)), ('ap',)): 'pa',
        (('e', ('rp',)), ('rp',)): '2p',
        (('nv', ('rap',)), ('rp',)): 'rp',
        (('nv', ('rap',)), ('ap',)): 'ra',
        (('nv', ('np',)), ('np',)): 'nn',
        (('nv', ('np',)), ('rap',)): 'nr',
    },
    '(i,(p,(e)),(p,(e)))': {
        (('e', ('rp',)), ('nv', ('rap',)), ('i',)): 'pri',
        (('e', ('rp',)), ('e', ('rp',)), ('i',)): '2pi',
        (('e', ('ap',)), ('e', ('ap',)), ('i',)): '2ai',
        (('e', ('ap',)), ('nv', ('np',)), ('i',)): 'ani',
        (('nv', ('rap',)), ('e', ('rp',)), ('i',)): 'rpi',
        (('nv', ('rap',)), ('nv', ('rap',)), ('i',)): '2ri',
        (('nv', ('np',)), ('nv', ('np',)), ('i',)): 'nni',
        (('nv', ('np',)), ('e', ('ap',)), ('i',)): 'nai',
    },
    '(i,(p,(p,(e))),(p,(e)))': {
        ((('e', ('rp',)), ('rp',)), ('e', ('rp',)), ('i',)): 'pppi',
        ((('e', ('rp',)), ('rp',)), ('nv', ('rap',)), ('i',)): 'ppri',
        ((('e', ('rp',)), ('ap',)), ('e', ('ap',)), ('i',)): 'paai',
        ((('e', ('rp',)), ('ap',)), ('nv', ('np',)), ('i',)): 'pani',
        ((('e', ('ap',)), ('rap',)), ('e', ('rp',)), ('i',)): 'arpi',
        ((('e', ('ap',)), ('rap',)), ('nv', ('rap',)), ('i',)): 'arri',
        ((('e', ('ap',)), ('np',)), ('nv', ('np',)), ('i',)): 'anni',
        ((('e', ('ap',)), ('np',)), ('e', ('ap',)), ('i',)): 'anai',
        ((('nv', ('np',)), ('rap',)), ('e', ('rp',)), ('i',)): 'nrpi',
        ((('nv', ('np',)), ('rap',)), ('nv', ('rap',)), ('i',)): 'nrri',
        ((('nv', ('np',)), ('np',)), ('e', ('ap',)), ('i',)): 'nnai',
        ((('nv', ('np',)), ('np',)), ('nv', ('np',)), ('i',)): 'nnni',
        ((('nv', ('rap',)), ('ap',)), ('nv', ('np',)), ('i',)): 'rani',
        ((('nv', ('rap',)), ('ap',)), ('e', ('ap',)), ('i',)): 'raai',
        ((('nv', ('rap',)), ('rp',)), ('e', ('rp',)), ('i',)): 'rppi',
        ((('nv', ('rap',)), ('rp',)), ('nv', ('rap',)), ('i',)): 'rpri',
    },
    '(p,(i,(p,(e)),(p,(e))))': {
        ((('e', ('rp',)), ('e', ('rp',)), ('i',)), ('ap',)): '2pia',
        ((('e', ('rp',)), ('e', ('rp',)), ('i',)), ('rp',)): '2pip',
        ((('e', ('rp',)), ('nv', ('rap',)), ('i',)), ('rp',)): 'prip',
        ((('e', ('rp',)), ('nv', ('rap',)), ('i',)), ('ap',)): 'pria',
        ((('e', ('ap',)), ('e', ('ap',)), ('i',)), ('rap',)): '2air',
        ((('e', ('ap',)), ('e', ('ap',)), ('i',)), ('np',)): '2ain',
        ((('e', ('ap',)), ('nv', ('np',)), ('i',)), ('rap',)): 'anir',
        ((('e', ('ap',)), ('nv', ('np',)), ('i',)), ('np',)): 'anin',
        ((('nv', ('rap',)), ('e', ('rp',)), ('i',)), ('rp',)): 'rpip',
        ((('nv', ('rap',)), ('e', ('rp',)), ('i',)), ('ap',)): 'rpia',
        ((('nv', ('rap',)), ('nv', ('rap',)), ('i',)), ('rp',)): '2rip',
        ((('nv', ('rap',)), ('nv', ('rap',)), ('i',)), ('ap',)): '2ria',
        ((('nv', ('np',)), ('nv', ('np',)), ('i',)), ('rap',)): '2nir',
        ((('nv', ('np',)), ('nv', ('np',)), ('i',)), ('np',)): '2nin',
        ((('nv', ('np',)), ('e', ('ap',)), ('i',)), ('rap',)): 'nair',
        ((('nv', ('np',)), ('e', ('ap',)), ('i',)), ('np',)): 'nain',
    },
    '(i,(p,(e)),(p,(e)),(p,(e)))': {
        (('nv', ('np',)), ('nv', ('np',)), ('e', ('ap',)), ('i',)): '2na3i',
        (('nv', ('np',)), ('nv', ('np',)), ('nv', ('np',)), ('i',)): '3n3i',
        (('nv', ('rap',)), ('e', ('rp',)), ('e', ('rp',)), ('i',)): 'r2p3i',
        (('nv', ('rap',)), ('e', ('rp',)), ('nv', ('rap',)), ('i',)): 'rpr3i',
        (('nv', ('np',)), ('e', ('ap',)), ('nv', ('np',)), ('i',)): 'nan3i',
        (('nv', ('np',)), ('e', ('ap',)), ('e', ('ap',)), ('i',)): 'n2a3i',
        (('nv', ('rap',)), ('nv', ('rap',)), ('e', ('rp',)), ('i',)): '2rp3i',
        (('nv', ('rap',)), ('nv', ('rap',)), ('nv', ('rap',)), ('i',)): '3r3i',
        (('e', ('rp',)), ('nv', ('rap',)), ('e', ('rp',)), ('i',)): 'prp3i',
        (('e', ('rp',)), ('nv', ('rap',)), ('nv', ('rap',)), ('i',)): 'p2r3i',
        (('e', ('rp',)), ('e', ('rp',)), ('e', ('rp',)), ('i',)): '3p3i',
        (('e', ('rp',)), ('e', ('rp',)), ('nv', ('rap',)), ('i',)): '2pr3i',
        (('e', ('ap',)), ('e', ('ap',)), ('nv', ('np',)), ('i',)): '2an3i',
        (('e', ('ap',)), ('e', ('ap',)), ('e', ('ap',)), ('i',)): '3a3i',
        (('e', ('ap',)), ('nv', ('np',)), ('nv', ('np',)), ('i',)): 'a2n3i',
        (('e', ('ap',)), ('nv', ('np',)), ('e', ('ap',)), ('i',)): 'ana3i',
    },
    '(u,(p,(e)),(p,(e)))': {
        (('e', ('rp',)), ('e', ('rp',)), ('u',)): '2pu',
        (('e', ('rp',)), ('nv', ('rap',)), ('u',)): 'pru',
        (('e', ('ap',)), ('e', ('ap',)), ('u',)): 'aau',
        (('e', ('ap',)), ('nv', ('np',)), ('u',)): 'anu',
        (('nv', ('rap',)), ('nv', ('rap',)), ('u',)): '2ru',
        (('nv', ('rap',)), ('e', ('rp',)), ('u',)): 'rpu',
        (('nv', ('np',)), ('nv', ('np',)), ('u',)): '2nu',
        (('nv', ('np',)), ('e', ('ap',)), ('u',)): 'nau',
    },
    '(p,(u,(p,(e)),(p,(e))))': {
        ((('e', ('rp',)), ('e', ('rp',)), ('u',)), ('rp',)): '2pup',
        ((('e', ('rp',)), ('e', ('rp',)), ('u',)), ('ap',)): '2pua',
        ((('e', ('rp',)), ('nv', ('rap',)), ('u',)), ('rp',)): 'prup',
        ((('e', ('rp',)), ('nv', ('rap',)), ('u',)), ('ap',)): 'prua',
        ((('e', ('ap',)), ('nv', ('np',)), ('u',)), ('np',)): 'anun',
        ((('e', ('ap',)), ('nv', ('np',)), ('u',)), ('rap',)): 'anur',
        ((('e', ('ap',)), ('e', ('ap',)), ('u',)), ('rap',)): '2aur',
        ((('e', ('ap',)), ('e', ('ap',)), ('u',)), ('np',)): '2aun',
        ((('nv', ('np',)), ('e', ('ap',)), ('u',)), ('rap',)): 'naur',
        ((('nv', ('np',)), ('e', ('ap',)), ('u',)), ('np',)): 'naun',
        ((('nv', ('np',)), ('nv', ('np',)), ('u',)), ('np',)): '2nun',
        ((('nv', ('np',)), ('nv', ('np',)), ('u',)), ('rap',)): '2nur',
        ((('nv', ('rap',)), ('nv', ('rap',)), ('u',)), ('rp',)): '2rup',
        ((('nv', ('rap',)), ('nv', ('rap',)), ('u',)), ('ap',)): '2rua',
        ((('nv', ('rap',)), ('e', ('rp',)), ('u',)), ('ap',)): 'rpua',
        ((('nv', ('rap',)), ('e', ('rp',)), ('u',)), ('rp',)): 'rpup',
    },
}

# Flat structure -> name table. It starts with every grouped entry above, in
# declaration order, so the two tables cannot drift apart ...
query_name_dict = {
    structure: name
    for group in query_name_dict_classify.values()
    for structure, name in group.items()
}
# ... and ends with the structures that use the 'b' operator, which have no
# group in query_name_dict_classify.
query_name_dict.update({
    ((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('rap',)): 'aabr',
    ((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('np',)): 'aabn',
    (('e', ('ap',)), ('e', ('ap',)), ('b',)): 'aab',
    (((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('rap',)), ('ap',)): 'aabra',
    (((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('np',)), ('np',)): 'aabnn',
    (((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('np',)), ('rap',)): 'aabnr',
    (((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('rap',)), ('rp',)): 'aabrp',
    (('e', ('ap',)), ('e', ('ap',)), ('e', ('ap',)), ('b',)): '3ab',
    ((('e', ('rp',)), ('ap',)), ('e', ('ap',)), ('b',)): 'paab',
    ((('nv', ('rap',)), ('ap',)), ('e', ('ap',)), ('b',)): 'raab',
    ((('nv', ('rap',)), ('ap',)), (('e', ('rp',)), ('ap',)), ('b',)): 'rapab',
    ((('e', ('rp',)), ('ap',)), (('e', ('rp',)), ('ap',)), ('b',)): 'papab',
    ((('nv', ('rap',)), ('ap',)), (('nv', ('rap',)), ('ap',)), ('b',)): 'rarab',
    ((('e', ('rp',)), ('ap',)), (('nv', ('rap',)), ('ap',)), ('b',)): 'parab',
    ((('e', ('ap',)), ('e', ('ap',)), ('b',)), ('e', ('ap',)), ('b',)): '2abb',
})

# Starts empty; nothing in the visible part of the file fills it.
name_answer_dict = {}
# Reverse lookup: task name -> query structure.
name_query_dict = {name: structure for structure, name in query_name_dict.items()}
# Every known task name, in query_name_dict order.
all_tasks = list(name_query_dict)
# Same values as the '{:<9}' / '{:<11}' column widths hard-coded in
# verify_chain and verify; not referenced by name in the visible code.
espace = 9
rspace = 11
# Global id-string -> label table filled in by verify_chain / verify.
mapping = {}

def _str2bool(value):
    '''
    Convert a command-line token into a real boolean.

    argparse's `type=bool` calls bool() on the raw string, so any non-empty
    text (including "False") becomes True. This converter parses the text.
    The empty string maps to False, which is what bool('') returned before.
    Raises argparse.ArgumentTypeError for unrecognised text.
    '''
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))

def parse_args(args=None):
    '''
    Build the command-line parser and parse `args`.

    args: list of argument strings, or None to read sys.argv (argparse default).
    Returns the argparse.Namespace holding every option declared below.
    '''
    parser = argparse.ArgumentParser(
        description='Training and Testing Knowledge Graph Embedding Models',
        usage='train.py [<args>] [-h | --help]'
    )

    # Run-mode switches
    parser.add_argument('--do_valid', action='store_true', help="do valid")
    parser.add_argument('--do_test', action='store_true', help="do test")
    parser.add_argument('--do_cp', action='store_true', help="do cardinality prediction")
    parser.add_argument('--path', action='store_true', help="do interpretation study")

    parser.add_argument('--train', action='store_true', help="do train")
    parser.add_argument('--data_path', type=str, default=None, help="KG data path")
    parser.add_argument('--kbc_path', type=str, default=None, help="kbc model path")
    parser.add_argument('--test_batch_size', default=1, type=int, help='valid/test batch size')
    parser.add_argument('-cpu', '--cpu_num', default=10, type=int, help="used to speed up torch.dataloader")

    # Per their help text, nentity/nrelation are not meant to be set by hand
    parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
    parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')
    parser.add_argument('--fraction', type=int, default=1, help='fraction the entity to save gpu memory usage')
    parser.add_argument('--thrshd', type=float, default=0.001, help='thrshd for neural adjacency matrix')
    parser.add_argument('--eavthrshd', type=float, default=0.0005, help='thrshd for eav neural adjacency matrix')
    parser.add_argument('--vnpvthrshd', type=float, default=0.0001, help='thrshd for nvpv neural adjacency matrix')
    parser.add_argument('--fuzzythrshd', type=float, default=0.001, help='thrshd for fuzzyset')
    parser.add_argument('--neg_scale', type=int, default=1, help='scaling neural adjacency matrix for negation')
    parser.add_argument('--pre_fuzzy', type=str, default="pre_fuzzy.pkl", help='pretrain for dataset number value')
    parser.add_argument('--tasks', default='1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task")
    parser.add_argument('--seed', default=12345, type=int, help="random seed")
    # _str2bool instead of bool: "--use_newmetric False" must yield False
    parser.add_argument('--use_newmetric', default=False, type=_str2bool, help="New numerical evaluation indicators")
    parser.add_argument('-evu', '--evaluate_union', default="DNF", type=str, choices=['DNF', 'DM'], help='the way to evaluate union queries, transform it to disjunctive normal form (DNF) or use the De Morgan\'s laws (DM)')

    return parser.parse_args(args)

def log_metrics(mode, metrics, writer):
    '''
    Report every metric three ways: logging.info, stdout, and `writer`.

    mode: label put in front of each line (e.g. "Test").
    metrics: mapping of metric name -> numeric value.
    writer: object with a write(str) method; gets one line per metric.
    '''
    for name in metrics:
        # Format once and reuse the same text for all three sinks.
        line = '%s %s: %f' % (mode, name, metrics[name])
        logging.info(line)
        print(line)
        writer.write(line + '\n')

def read_triples(filenames, nrelation, datapath):
    '''
    Load space-separated "head relation tail" integer triples.

    filenames: paths whose triples populate the per-relation adjacency list.
    nrelation: number of relations, i.e. the length of the adjacency list.
    datapath: directory holding train.txt, valid.txt and test.txt.

    Returns (adj_list, edges_all, edges_vt):
      adj_list[r] - list of (head, tail) pairs for relation r, built only
                    from `filenames`;
      edges_all   - set of (head, relation, tail) from train/valid/test;
      edges_vt    - set of (head, relation, tail) from valid/test only.
    '''
    def _triples(path):
        # Yield one (head, relation, tail) int tuple per line of `path`.
        with open(path) as handle:
            for raw in handle:
                head, rel, tail = raw.strip().split(' ')
                yield int(head), int(rel), int(tail)

    adj_list = [[] for _ in range(nrelation)]
    for path in filenames:
        for head, rel, tail in _triples(path):
            adj_list[rel].append((head, tail))

    edges_vt = set()
    for name in ('valid.txt', 'test.txt'):
        edges_vt.update(_triples(os.path.join(datapath, name)))

    # edges_all is the valid/test edges plus the training edges.
    edges_all = set(edges_vt)
    edges_all.update(_triples(os.path.join(datapath, 'train.txt')))
    return adj_list, edges_all, edges_vt
def num_read_triples(filenames, nrelation, datapath):
    '''
    Load space-separated integer triples from the "*_num.txt" split files.

    filenames: paths whose triples populate the adjacency lists.
    nrelation: number of relations, i.e. the length of both adjacency lists.
    datapath: directory holding train_num.txt, valid_num.txt, test_num.txt.

    Returns (adj_list, edges_all, edges_vt, adj_reverse_list):
      adj_list[r]         - (head, tail) pairs for relation r;
      edges_all           - (head, relation, tail) from all three splits;
      edges_vt            - (head, relation, tail) from valid/test only;
      adj_reverse_list[r] - the same pairs as adj_list[r], as (tail, head).
    '''
    def _triples(path):
        # Yield one (head, relation, tail) int tuple per line of `path`.
        with open(path) as handle:
            for raw in handle:
                head, rel, tail = raw.strip().split(' ')
                yield int(head), int(rel), int(tail)

    adj_list = [[] for _ in range(nrelation)]
    adj_reverse_list = [[] for _ in range(nrelation)]
    for path in filenames:
        for head, rel, tail in _triples(path):
            adj_list[rel].append((head, tail))
            adj_reverse_list[rel].append((tail, head))

    edges_vt = set()
    for name in ('valid_num.txt', 'test_num.txt'):
        edges_vt.update(_triples(os.path.join(datapath, name)))

    edges_all = set(edges_vt)
    edges_all.update(_triples(os.path.join(datapath, 'train_num.txt')))
    return adj_list, edges_all, edges_vt, adj_reverse_list
def np_read_triples(filenames, nrelation, datapath):
    '''
    Load space-separated integer triples from the "*_np.txt" split files.

    filenames: paths whose triples populate the per-relation adjacency list.
    nrelation: number of relations, i.e. the length of the adjacency list.
    datapath: directory holding train_np.txt, valid_np.txt and test_np.txt.

    Returns (adj_list, edges_all, edges_vt), where adj_list[r] holds the
    (head, tail) pairs of relation r, edges_all every (head, relation, tail)
    from the three splits, and edges_vt those from valid/test only.
    '''
    def _load(path):
        # Parse the whole file into a list of (head, relation, tail) ints.
        with open(path) as handle:
            rows = [raw.strip().split(' ') for raw in handle]
        return [(int(h), int(r), int(t)) for h, r, t in rows]

    adj_list = [[] for _ in range(nrelation)]
    for path in filenames:
        for head, rel, tail in _load(path):
            adj_list[rel].append((head, tail))

    held_out = _load(os.path.join(datapath, 'valid_np.txt'))
    held_out += _load(os.path.join(datapath, 'test_np.txt'))
    edges_vt = set(held_out)
    edges_all = edges_vt | set(_load(os.path.join(datapath, 'train_np.txt')))

    return adj_list, edges_all, edges_vt

def verify_chain(chain, chain_structure, edges_y, edges_p): # (e, r, e, ..., e)
    '''
    Check one reasoning path (chain) edge by edge and render it as text.

    chain_structure holds one marker per position: 'e' (entity), 'r'
    (relation) or 'n' (negate the pending edge); chain holds the aligned ids.
    Per the original comments, edges_y is the train/valid/test edge set and
    edges_p the valid/test edge set.

    Each edge gets a class letter:
      plain edge   - 'n' if absent from edges_y, else 'p' if in edges_p,
                     else 'y';
      negated edge - 'y' if absent from edges_y, else 'p' if in edges_p,
                     else 'n'.

    Returns (judge, out). judge is True only when every plain edge is in
    edges_y and every negated edge is not. out renders the interior of the
    chain (both end entities dropped) from last to first, entities in
    9-wide and relations in 11-wide left-aligned columns. The global
    `mapping` is updated with id-string -> id2ent / id2rel labels.
    '''
    global mapping
    source = chain[0]
    relation = 0
    negated = False
    valid = True
    labels = []
    # Pass 1: walk the chain forwards, validating and classifying each edge.
    for kind, value in zip(chain_structure[1:], chain[1:]):
        if kind == 'r':
            relation = value
        elif kind == 'n':
            negated = True
        elif kind == 'e':
            triple = (source, relation, value)
            present = triple in edges_y
            if negated:
                satisfied = not present
                if satisfied: # not in train/val/test
                    labels.append('y')
                else: # 'p': in val/test, 'n': in train
                    labels.append('p' if triple in edges_p else 'n')
                negated = False
            else:
                satisfied = present
                if present: # 'p': in val/test, 'y': in train
                    labels.append('p' if triple in edges_p else 'y')
                else: # not in train/val/test
                    labels.append('n')
            valid = valid & satisfied
            source = value

    # Pass 2: render the interior backwards; labels are consumed in the
    # same reversed order, one per relation.
    inner_structure = chain_structure[1:-1]
    inner_chain = chain[1:-1]
    labels = labels[::-1]
    pieces = []
    label_idx = 0
    negated = False
    for kind, value in zip(inner_structure[::-1], inner_chain[::-1]):
        if kind == 'n':
            # Going backwards, the 'n' marker comes before its relation.
            negated = True
        elif kind == 'e':
            pieces.append('{:<9}'.format(str(value)))
            mapping[str(value)] = id2ent[value]
        elif kind == 'r':
            suffix = '-X' if negated else '-'
            pieces.append('{:<11}'.format(labels[label_idx]+'<-r'+str(value)+suffix))
            negated = False
            mapping['r'+str(value)] = id2rel[value]
            label_idx += 1
    return valid, ''.join(pieces)

def verify(ans_structure, ans, edges_y, edges_p, offset=0):
    '''
    Recursively check a nested reasoning path and render it as text.

    ans_structure: nested list describing the query shape; ans: the aligned
    nested list of ids. A node whose second element starts with 'r' is a
    relation chain and is handed to verify_chain; any other node is a set
    operation over its branches (union when the second-to-last element
    starts with 'u', intersection otherwise).
    edges_y, edges_p: edge sets passed through to verify_chain.
    offset: indentation (in characters) used after each branch's newline.

    Returns (judge, out). Intersections AND their branch verdicts, unions OR
    them. The global `mapping` is updated with id-string -> id2ent labels.
    '''
    global mapping
    if ans_structure[1][0] != 'r': # [[...], [...], 'e']
        is_union = ans_structure[-2][0] == 'u'
        # The header column shows the answer entity tagged with the operator.
        out = '{:<9}'.format(str(ans[-1])+('(u)' if is_union else '(i)'))
        if is_union:
            # Drop the trailing element so the 'u' marker is now last.
            ans_structure, ans = ans_structure[:-1], ans[:-1]
        # NOTE(review): for unions ans[-1] here is the entry aligned with the
        # 'u' marker, not the answer entity printed above - confirm intended.
        mapping[str(ans[-1])] = id2ent[ans[-1]]
        # Neutral start value: False for OR (union), True for AND.
        judge = not is_union
        offset += 9
        for branch_structure, branch_ans in zip(ans_structure[:-1], ans[:-1]):
            branch_judge, branch_out = verify(branch_structure, branch_ans, edges_y, edges_p, offset)
            judge = (judge | branch_judge) if is_union else (judge & branch_judge)
            out = out + branch_out + '\n' + ' '*offset
        return judge, out

    # [[...], ['r', ...], 'e']
    chain_structure = ['e']+ans_structure[1]+['e']
    if ans_structure[0] == 'e': # ['e', ['r', ...], 'e']
        chain = [ans[0]]+ans[1]+[ans[2]]
        judge, chain_out = verify_chain(chain, chain_structure, edges_y, edges_p)
        out = '{:<9}'.format(str(ans[2])) + chain_out + '{:<9}'.format(str(ans[0]))
        mapping[str(ans[2])] = id2ent[ans[2]]
        mapping[str(ans[0])] = id2ent[ans[0]]
        return judge, out

    # The chain starts from the answer of the nested sub-query.
    chain = [ans[0][-1]]+ans[1]+[ans[2]]
    chain_judge, chain_out = verify_chain(chain, chain_structure, edges_y, edges_p)
    # Indent the sub-query by the width of what is printed before it:
    # 11 columns per relation, 9 per entity.
    offset += sum(
        11 if ele == 'r' else 9 if ele == 'e' else 0
        for ele in ans_structure[1] + [ans_structure[2]]
    )
    sub_judge, sub_out = verify(ans_structure[0], ans[0], edges_y, edges_p, offset)
    out = '{:<9}'.format(str(ans[2])) + chain_out + sub_out
    mapping[str(ans[2])] = id2ent[ans[2]]
    return chain_judge & sub_judge, out

def get_cp_thrshd(model, tp_answers, fn_answers, args, dataloader, query_name_dict, device):
    '''
    Get the best threshold for cardinality prediction on the valid set.

    For each query structure, thresholds 0.0, 0.1, ..., 0.9 are tried. The
    predicted cardinality is the number of scores >= threshold, and the
    threshold with the lowest mean absolute percentage error against the
    true cardinality (hard + easy answers) is kept.

    model: object exposing embed_query(queries, query_structure, 0).
    tp_answers: hard answers, indexed by the unflattened query.
    fn_answers: easy answers, indexed by the unflattened query.
    args, query_name_dict: accepted for signature compatibility; unused here.
    dataloader: iterable of (queries, queries_unflatten, query_structures);
        only element 0 of the last two is used.
    device: torch device used for the model input and the error computation.

    Returns a dict mapping query structure -> best threshold.
    '''
    probs = defaultdict(list)
    cards = defaultdict(list)
    best_thrshds = dict()
    for queries, queries_unflatten, query_structures in tqdm(dataloader):
        queries = torch.LongTensor(queries).to(device)
        # NOTE(review): evaluate() in this file calls embed_query with a 4th
        # argument and unpacks four return values - confirm which signature
        # the model really has.
        embedding, _, _ = model.embed_query(queries, query_structures[0], 0)
        embedding = embedding.squeeze()
        query_key = queries_unflatten[0]
        num_hard = len(tp_answers[query_key])
        num_easy = len(fn_answers[query_key])

        # Scores are parked on the CPU and moved back to the device in chunks.
        probs[query_structures[0]].append(embedding.to('cpu'))
        cards[query_structures[0]].append(torch.tensor([num_hard+num_easy]))
    for query_structure in probs:
        prob = torch.stack(probs[query_structure])
        # squeeze(-1), not squeeze(): with a single query a bare squeeze()
        # gives a 0-dim tensor and the card[s:t] slicing below raises.
        card = torch.stack(cards[query_structure]).squeeze(-1).to(torch.float)
        ape = torch.zeros_like(card).to(torch.float).to(device)
        best_thrshd = 0
        best_mape = 10000
        nquery = prob.size(0)
        # Process the queries in 10 chunks; the last chunk takes the remainder.
        fraction = 10
        dim = nquery // fraction
        rest = nquery - fraction * dim
        for i in tqdm(range(10)):
            thrshd = i / 10
            for j in range(fraction):
                s = j * dim
                t = (j+1) * dim
                if j == fraction - 1:
                    t += rest
                fractional_prob = prob[s:t, :].to(device)
                fractional_card = card[s:t].to(device)
                pre_card = (fractional_prob >= thrshd).to(torch.float).sum(-1)
                # NOTE(review): a query with zero answers divides by zero here.
                ape[s:t] = torch.abs(fractional_card - pre_card) / fractional_card
            mape = ape.mean()
            if mape < best_mape:
                best_mape = mape
                best_thrshd = thrshd
        best_thrshds[query_structure] = best_thrshd
    print(best_thrshds)
    return best_thrshds

def _relative_error(value, reference):
    '''
    Relative difference |value - reference| / max(|value|, |reference|).
    Raises ZeroDivisionError when both arguments are the float zero.
    '''
    return math.sqrt(((value-reference)/max(abs(value),abs(reference)))**2)


def _dedup_hard_easy(hard_answer, easy_answer, cast):
    '''
    Cast both answer collections with `cast`, drop duplicates, and remove from
    the hard answers every value that is also an easy answer.

    Returns (hard_list, easy_list, easy_set).
    '''
    easy_set = set(map(cast, easy_answer))
    hard_list = [ans for ans in set(map(cast, hard_answer)) if ans not in easy_set]
    return hard_list, list(easy_set), easy_set


def _numbers_to_ids(answers, number2id):
    '''
    Translate numeric answers to their ids through `number2id`, whose keys are
    the string form of the number ('-0.0' is looked up as "0.0").
    '''
    ids = []
    for ans in answers:
        key = str(ans)
        if key == '-0.0':
            key = "0.0"
        ids.append(number2id[key])
    return ids


def _rank_metrics(ranks):
    '''
    Turn a list of 1-based raw ranks into (MRR, HITS@1, HITS@3, HITS@10).

    The ranks are sorted and the k-th smallest is lowered by k (the answers
    ranked before it are filtered out), then clamped to a minimum of 1.
    An empty list gives all-zero metrics.
    '''
    if len(ranks) == 0:
        return 0, 0, 0, 0
    filtered = np.maximum(np.sort(np.asarray(ranks)) - np.arange(len(ranks)), 1)
    return (
        np.mean(1. / filtered),
        np.mean((filtered <= 1).astype(float)),
        np.mean((filtered <= 3).astype(float)),
        np.mean((filtered <= 10).astype(float)),
    )


def _metric_log(hard_metrics, num_hard, easy_metrics, num_easy):
    '''Build the per-query log entry from (MRR, H1, H3, H10) tuples and answer counts.'''
    mrr_hard, h1_hard, h3_hard, h10_hard = hard_metrics
    mrr_easy, h1_easy, h3_easy, h10_easy = easy_metrics
    return {
        'MRR_hard': mrr_hard,
        'HITS1_hard': h1_hard,
        'HITS3_hard': h3_hard,
        'HITS10_hard': h10_hard,
        'num_hard_answer': num_hard,
        'MRR_easy': mrr_easy,
        'HITS1_easy': h1_easy,
        'HITS3_easy': h3_easy,
        'HITS10_easy': h10_easy,
        'num_easy_answer': num_easy,
    }


def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, device, writer, edges_y, edges_p, cp_thrshd,id2number_path):
    '''
    Evaluate queries in dataloader.

    Each batch (only its first query / structure is used) is run through
    model.embed_query and scored in one of three ways, chosen by the length of
    the squeezed result:
      * neither args.nentity nor args.nnum: numerical fuzzy-set answer, hard
        answers are matched against Affiliation_Center by relative error;
      * not args.nentity and args.use_newmetric: numbers are ranked by score
        and matched to the answers by relative error (< close_cri);
      * otherwise: filtered ranking of the answer ids.
    Queries without hard answers (after removing easy answers) are skipped.

    tp_answers / fn_answers are the hard / easy answers, indexed by query
    structure and then by unflattened query. id2number_path is a JSON file
    mapping id -> number. Per-structure and average MRR/HITS metrics are
    written through log_metrics(..., writer) and returned as a flat dict.
    With args.path, path interpretability rates are logged as well; with
    args.do_cp, cardinality MAPE / Spearman are logged using cp_thrshd.
    '''
    with open(id2number_path,'r') as f:
        id2number = json.load(f)
    number2id = {number: idx for idx, number in id2number.items()}
    global mapping
    mode = "Test"
    average_metrics = defaultdict(float)
    all_metrics = defaultdict(float)
    logs = defaultdict(list)
    rates = defaultdict(list)
    probs = defaultdict(list)
    cards = defaultdict(list)
    all_time = defaultdict(list)
    for queries, queries_unflatten, query_structures in tqdm(dataloader):
        structure = query_structures[0]
        query_key = queries_unflatten[0]
        queries = torch.LongTensor(queries).to(device)
        start = time.time()
        embedding, _, exec_query,Affiliation_Center = model.embed_query(queries, structure, 0,None) #Reasoning Funtion
        if embedding is None:
            # no result at all: counted as a complete miss
            logs[structure].append(_metric_log((0, 0, 0, 0), 0, (0, 0, 0, 0), 0))
            continue
        embedding = embedding.squeeze()
        end = time.time()
        all_time[structure].append(end-start)
        try:
            answer_space = embedding.shape[0]
        except (AttributeError, IndexError):
            # a 0-dim result has no first dimension; nothing to rank
            continue
        if answer_space != args.nentity and answer_space != args.nnum:
            # When the answer is a numerical fuzzy set
            hard_answer, easy_answer, easy_set = _dedup_hard_easy(
                tp_answers[structure][query_key], fn_answers[structure][query_key], float)
            num_hard = len(hard_answer)
            num_easy = len(easy_answer)
            if num_hard == 0:
                continue
            avg_rank_easy = []
            avg_rank_hard = []
            for ans in hard_answer:
                for i, num_result in enumerate(Affiliation_Center):
                    if num_result == 0.0 and ans == 0.0:
                        # kept from the original: a 0 vs 0 pair is not counted as a match
                        mape_hard = 10
                    else:
                        mape_hard = _relative_error(ans, num_result)
                    if mape_hard < close_cri:
                        if ans in easy_set:
                            avg_rank_easy.append(i+1)
                        else:
                            avg_rank_hard.append(i+1)
                        break
            logs[structure].append(_metric_log(
                _rank_metrics(avg_rank_hard), num_hard,
                _rank_metrics(avg_rank_easy), num_easy))
        elif answer_space != args.nentity and args.use_newmetric:
            # Using newmetric
            order = torch.argsort(embedding, dim=-1, descending=True) # indices from highest to lowest score
            hard_answer, easy_answer, easy_set = _dedup_hard_easy(
                tp_answers[structure][query_key], fn_answers[structure][query_key], float)
            num_hard = len(hard_answer)
            num_easy = len(easy_answer)
            if num_hard == 0:
                continue
            num_all = hard_answer + easy_answer
            avg_rank_easy = []
            avg_rank_hard = []
            # Numbers in ranking order. This list must NOT be sorted: the
            # position in it is the rank of the number.
            ans_all = [float(id2number[str(idx)]) for idx in order.tolist()]
            for target in num_all:
                for j, candidate in enumerate(ans_all):
                    try:
                        rel_err = _relative_error(candidate, target)
                    except ZeroDivisionError:
                        # both values are zero; kept as "no match" like before
                        continue
                    if rel_err < close_cri:
                        # each ranked number can be consumed by one answer only;
                        # j is the first position holding this value
                        del ans_all[j]
                        if target in easy_set:
                            avg_rank_easy.append(j+1)
                        else:
                            avg_rank_hard.append(j+1)
                        break
            logs[structure].append(_metric_log(
                _rank_metrics(avg_rank_hard), num_hard,
                _rank_metrics(avg_rank_easy), num_easy))
        else:
            # When the answer is an entity
            order = torch.argsort(embedding, dim=-1, descending=True)
            ranking = torch.argsort(order) # ranking[i] is the 0-based rank of id i
            hard_answer = list(tp_answers[structure][query_key])
            easy_answer = list(fn_answers[structure][query_key])

            # numeric answers over a non-entity space are stored as floats and
            # have to be translated to ids first
            if(len(easy_answer) > 0 and isinstance(easy_answer[0],float) and answer_space != args.nentity):
                easy_answer = _numbers_to_ids(easy_answer, number2id)
            if(len(hard_answer) > 0 and isinstance(hard_answer[0],float) and answer_space != args.nentity):
                hard_answer = _numbers_to_ids(hard_answer, number2id)
            hard_answer, easy_answer, _ = _dedup_hard_easy(hard_answer, easy_answer, int)
            num_hard = len(hard_answer)
            num_easy = len(easy_answer)
            if num_hard == 0:
                continue
            cur_ranking = ranking[list(easy_answer) + list(hard_answer)]
            all_path, h1_path, h3_path, h10_path = 0, 0, 0, 0
            num_h1, num_h3, num_h10 = 0, 0, 0
            if args.path:
                for root in list(hard_answer):
                    rank = ranking[root]
                    rank -= ((cur_ranking < rank).sum()-1)
                    ans, _ = model.find_ans(exec_query, structure, root)
                    mapping = dict()
                    judge, out = verify(name_answer_dict[query_name_dict[structure]], ans, edges_y, edges_p)
                    if judge:
                        all_path += 1
                    if rank <= 1:
                        num_h1 += 1
                        if judge:
                            h1_path += 1
                    if rank <= 3:
                        num_h3 += 1
                        if judge:
                            h3_path += 1
                    if rank <= 10:
                        num_h10 += 1
                        if judge:
                            h10_path += 1
                    print(judge, rank.item())
                    print(out, mapping)
            if args.do_cp:
                probs[structure].append(embedding.to('cpu'))
                cards[structure].append(torch.tensor([num_hard+num_easy]))
            try:
                cur_ranking, indices = torch.sort(cur_ranking)
                masks_hard = indices >= num_easy
                masks_easy = indices < num_easy
                # built on the ranking's own device so the subtraction cannot
                # fail on a cpu/cuda mismatch
                answer_list = torch.arange(num_hard + num_easy, dtype=torch.float, device=cur_ranking.device)
                cur_ranking = cur_ranking - answer_list + 1 # filtered setting
                cur_ranking_hard = cur_ranking[masks_hard] # take indices that belong to the hard answers
                cur_ranking_easy = cur_ranking[masks_easy] # take indices that belong to the easy answers

                mrr_hard = torch.mean(1./cur_ranking_hard).item()
                h1_hard = torch.mean((cur_ranking_hard <= 1).to(torch.float)).item()
                h3_hard = torch.mean((cur_ranking_hard <= 3).to(torch.float)).item()
                h10_hard = torch.mean((cur_ranking_hard <= 10).to(torch.float)).item()
                mrr_easy = torch.mean(1./cur_ranking_easy).item()
                h1_easy = torch.mean((cur_ranking_easy <= 1).to(torch.float)).item()
                h3_easy = torch.mean((cur_ranking_easy <= 3).to(torch.float)).item()
                h10_easy = torch.mean((cur_ranking_easy <= 10).to(torch.float)).item()
            except Exception as e:
                # Skip the query: falling through would log the metrics of a
                # previous query (or hit undefined names on the first one).
                print(e)
                continue
            if num_easy == 0:
                # no easy answers: the mean over an empty tensor is nan, report 1 instead
                mrr_easy, h1_easy, h3_easy, h10_easy = 1, 1, 1, 1

            logs[structure].append(_metric_log(
                (mrr_hard, h1_hard, h3_hard, h10_hard), num_hard,
                (mrr_easy, h1_easy, h3_easy, h10_easy), num_easy))
            if args.path:
                rate_prefix = query_name_dict[structure]
                if num_hard > 0:
                    rates[rate_prefix+" all path interpretability"].append(all_path / num_hard)
                if num_h1 > 0:
                    rates[rate_prefix+" HITS1 path interpretability"].append(h1_path / num_h1)
                if num_h3 > 0:
                    rates[rate_prefix+" HITS3 path interpretability"].append(h3_path / num_h3)
                if num_h10 > 0:
                    rates[rate_prefix+" HITS10 path interpretability"].append(h10_path / num_h10)
        # NOTE(review): this runs once per evaluated batch, so the running
        # interpretability rates are logged repeatedly -- confirm it was not
        # meant to sit after the loop.
        if args.path:
            rate_metric = defaultdict(float)
            for rate_name in rates:
                rate_metric[rate_name] = sum(rates[rate_name])/len(rates[rate_name])
            log_metrics('Interpretability', rate_metric, writer)
    metrics = collections.defaultdict(lambda: collections.defaultdict(int))
    # average embed_query wall-clock time per query structure
    for ty in all_time:
        print(ty,sum(all_time[ty])/len(all_time[ty]))
    for query_structure, structure_logs in logs.items():
        for metric in structure_logs[0].keys():
            if metric in ['num_hard_answer', 'num_easy_answer']:
                continue
            metrics[query_structure][metric] = sum(log[metric] for log in structure_logs)/len(structure_logs)
        metrics[query_structure]['num_queries'] = len(structure_logs)

    num_query_structures = 0
    num_queries = 0
    for query_structure in metrics:
        log_metrics(mode+" "+query_name_dict[query_structure], metrics[query_structure], writer)
        for metric in metrics[query_structure]:
            all_metrics["_".join([query_name_dict[query_structure], metric])] = metrics[query_structure][metric]
            if metric != 'num_queries':
                average_metrics[metric] += metrics[query_structure][metric]
        num_queries += metrics[query_structure]['num_queries']
        num_query_structures += 1

    # macro average: every query structure has the same weight
    for metric in average_metrics:
        average_metrics[metric] /= num_query_structures
        all_metrics["_".join(["average", metric])] = average_metrics[metric]
    log_metrics('%s average'%mode, average_metrics, writer)

    if args.do_cp and cp_thrshd is None and probs:
        # without thresholds the lookup below would raise after the whole
        # evaluation has already run
        logging.warning("do_cp is set but no cardinality thresholds were given; skipping cardinality metrics")
    elif args.do_cp:
        card_metrics = defaultdict(float)
        spearman = SpearmanCorrCoef()
        for query_structure in probs:
            prob = torch.stack(probs[query_structure])
            # cat keeps a 1-D vector even when the structure has a single query
            card = torch.cat(cards[query_structure]).to(torch.float)
            pre_card = (prob >= cp_thrshd[query_structure]).to(torch.float).sum(-1)
            mape = (torch.abs(card - pre_card) / card).mean()
            spm = spearman(pre_card, card)
            card_metrics[query_name_dict[query_structure]+" MAPE"] = mape
            card_metrics[query_name_dict[query_structure]+" Spearman"] = spm
        log_metrics('Card', card_metrics, writer)
    writer.write('\n')
    return all_metrics

def load_data(args, tasks):
    '''
    Load the test queries and their hard / easy answers from args.data_path.

    Reads "test-queries_test.pkl", "test-hard-answer.pkl" and
    "test-easy-answer.pkl" and returns the three unpickled objects as
    (test_queries, test_hard_answers, test_easy_answers).

    `tasks` is currently unused: no filtering of queries by task is applied.
    '''
    logging.info("loading data")

    def _load_pickle(file_name):
        # NOTE: pickle executes arbitrary code on load -- only use with
        # trusted dataset files.
        with open(os.path.join(args.data_path, file_name), 'rb') as f:
            return pickle.load(f)

    test_queries = _load_pickle("test-queries_test.pkl")
    test_hard_answers = _load_pickle("test-hard-answer.pkl")
    test_easy_answers = _load_pickle("test-easy-answer.pkl")
    return test_queries, test_hard_answers, test_easy_answers

def main(args):
    '''
    Run the test-set evaluation described by `args`.

    Reads the dataset sizes from <data_path>/stats.txt into `args`, loads the
    training graphs and the test queries, builds the KGReasoning model and
    calls evaluate(); the metrics are appended to
    results/<dataset>_<fraction>_<thrshd>.txt.

    args.data_path must have the dataset directory as its second
    '/'-separated component (e.g. data/FB1k-237-betae); the part before the
    first '-' is used as the dataset name.
    '''
    set_global_seed(args.seed)
    tasks = args.tasks.split('.') #Get task type
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)
    dataset_name = args.data_path.split('/')[1].split('-')[0]
    id2number_path = 'data/' + dataset_name + "-number/id2number.json"
    filename = 'results/'+dataset_name+'_'+str(args.fraction)+'_'+str(args.thrshd)+'.txt'
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # The results file is opened first so an unwritable path fails before any
    # expensive loading, and the with-block guarantees it is closed again.
    with open(filename, 'a+') as writer:
        with open('%s/stats.txt'%args.data_path) as f: #data_path:data/FB1k-237-betae
            entrel = f.readlines()
        # the first six lines end in: #entities, #relations, #attributes,
        # nnum, num_rank, nnumpre
        (args.nentity, args.nrelation, args.nattribute,
         args.nnum, args.num_rank, args.nnumpre) = (
            int(line.split(' ')[-1]) for line in entrel[:6])

        adj_list, edges_y, edges_p = read_triples([os.path.join(args.data_path, "train.txt")], args.nrelation, args.data_path)
        # NOTE(review): edges_p is overwritten by each of the two calls below,
        # so evaluate() receives the one returned by np_read_triples -- confirm
        # that is intended.
        num_adj_list, num_edges_y, edges_p,adj_reverse_list = num_read_triples([os.path.join(args.data_path, "train_num.txt")], args.nattribute, args.data_path)
        np_adj_list, np_edges_y, edges_p = np_read_triples([os.path.join(args.data_path, "train_np.txt")], args.nnumpre, args.data_path)

        test_queries, test_hard_answers, test_easy_answers = load_data(args, tasks)
        test_queries = flatten_query(test_queries,test_easy_answers)
        test_dataloader = DataLoader(
            TestDataset(
                test_queries,
                args.nentity,
                args.nrelation,
                args.nattribute,
            ),
            batch_size=args.test_batch_size,
            num_workers=args.cpu_num,
            collate_fn=TestDataset.collate_fn
        )

        with open(id2number_path,'r') as f:
            id2number = json.load(f)
        model = KGReasoning(args, device, adj_list,num_adj_list,np_adj_list,adj_reverse_list, query_name_dict, name_answer_dict,id2number) #Reasoning Model

        # The validation pass that would compute the cardinality thresholds
        # (get_cp_thrshd) is disabled, so no thresholds are available.
        cp_thrshd = None
        if getattr(args, 'do_cp', False):
            logging.warning("do_cp is set but no cardinality threshold is computed; cp_thrshd is None")
        evaluate(model, test_hard_answers, test_easy_answers, args, test_dataloader, query_name_dict, device, writer, edges_y, edges_p, cp_thrshd, id2number_path)

# Script entry point: when the file is executed directly, call main() with the
# result of parse_args().
if __name__ == '__main__':
    main(parse_args())