import json
from collections import defaultdict
import pickle
from tqdm import tqdm
# Load the lookup table that maps a numeric value (as a string key) to its
# integer id; split_queries uses it to resolve "nv" leaves.
# NOTE: the path is dataset-specific and must be changed for other datasets.
# The encoding is pinned so decoding does not depend on the platform locale.
with open("./data/FB15k-number/FB15Knumber2id.json", 'r', encoding="utf-8") as f:
    num2id = json.load(f)
def convert_to_tuple(nested_list):
    """Return a tuple copy of *nested_list*, converting inner lists too.

    Every element that is itself a ``list`` is converted recursively;
    any other element (including tuples) is kept unchanged.
    """
    return tuple(
        convert_to_tuple(item) if isinstance(item, list) else item
        for item in nested_list
    )
def get_subpattern(pattern):
    """Split *pattern* into its top-level, comma-separated components.

    The first and last characters of *pattern* (the enclosing
    parentheses) are dropped, then the remainder is cut at every comma
    that is not inside a nested parenthesis. The pieces are returned as
    a list of strings, in order.
    """
    inner = pattern[1:-1]
    depth = 0
    pieces = []
    start = 0

    for index, char in enumerate(inner):
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1

        # Only commas outside any open parenthesis separate components.
        if depth <= 0 and char == ",":
            pieces.append(inner[start:index])
            start = index + 1

    # Whatever follows the last separator is the final component.
    pieces.append(inner[start:])
    return pieces
# Operators whose pattern carries an integer id in position 1, followed by
# one to three sub-queries (see split_queries).
doublesub_htype = ["np", "b"]
# Operators whose pattern carries no id: the sub-queries start at position 1.
doublesub_ntype = ["i", "u"]
# Integer code that split_queries emits in the symbol output for each of
# the id-less operators above.
ntype_dict = {"i": -1, "u": -2}
# Leaf node types at which split_queries stops recursing.
stop_symbol = ["e", "nv"]
def split_queries(pattern):
    """Recursively parse a query pattern into a structure and its symbols.

    *pattern* is a parenthesised string such as ``"(p,(3),(e,(5)))"``,
    which yields ``(["e", ["p"]], [5, [3]])``.

    Returns a pair ``(structure, symbols)``:

    * For a leaf (type listed in ``stop_symbol``) the pair is
      ``(type, id)``. An ``"e"`` leaf uses its integer directly; any
      other leaf type is looked up in ``num2id`` (with ``'-0.0'``
      normalised to ``'0.0'`` first).
    * For every other operator both values are lists built in parallel:
      one entry per sub-query, in order, followed by a final entry for
      the operator itself. That final entry is ``[operator]`` in the
      structure; in the symbols it is ``[id]`` for operators carrying an
      id, or the bare ``ntype_dict`` code for operators in
      ``doublesub_ntype``.
    """
    parts = get_subpattern(pattern)
    operator = str(parts[0])
    arity = len(parts)

    def bare(text):
        # Drop every parenthesis, so "(12)" becomes "12".
        return text.replace("(", "").replace(")", "")

    def descend(positions):
        # Parse the sub-patterns found at the given indices, in order.
        structures, symbols = [], []
        for position in positions:
            child_structure, child_symbols = split_queries(parts[position])
            structures.append(child_structure)
            symbols.append(child_symbols)
        return structures, symbols

    # Id-carrying operator with one, two or three sub-queries.
    if operator in doublesub_htype and arity in (3, 4, 5):
        operator_id = int(bare(parts[1]))
        structure, symbols = descend(range(2, arity))
        structure.append([operator])
        symbols.append([operator_id])
        return structure, symbols

    # Id-less operator: three sub-queries when there are exactly four
    # parts, otherwise the first two.
    if operator in doublesub_ntype:
        child_count = 3 if arity == 4 else 2
        structure, symbols = descend(range(1, child_count + 1))
        structure.append([operator])
        symbols.append(ntype_dict[operator])
        return structure, symbols

    # Leaf node: nothing further to parse.
    if operator in stop_symbol:
        raw_value = bare(parts[1])
        if operator == "e":
            return operator, int(raw_value)
        if raw_value == '-0.0':
            raw_value = '0.0'
        return operator, int(num2id[raw_value])

    # Anything else is handled as an id-carrying operator with a single
    # sub-query in position 2.
    operator_id = int(bare(parts[1]))
    structure, symbols = descend((2,))
    structure.append([operator])
    symbols.append([operator_id])
    return structure, symbols
# Accumulators, all keyed by query structure (the first value returned by
# split_queries, converted to nested tuples so it is hashable):
#   queries_all[structure]      -> set of grounded queries
#   easy_ans[structure][query]  -> train answers + valid answers
#   hard_ans[structure][query]  -> test answers
queries_all = {}
easy_ans = {}
hard_ans = {}
for sym in tqdm(range(0, 18)):
    # NOTE: the path is dataset-specific and must be changed for other
    # datasets or splits.
    with open("./Generate_Queries/sampled_data/FB15K_e34_test_queries_" + str(sym) + ".json", 'r', encoding="utf-8") as f:
        data = json.load(f)
    # The file is fully loaded, so it can be closed before processing.
    for type_all in data:
        for queries in data[type_all]:
            a, b = split_queries(queries)
            a = convert_to_tuple(a)
            b = convert_to_tuple(b)
            answers = data[type_all][queries]
            easy_ans.setdefault(a, {})[b] = answers["train_answers"] + answers["valid_answers"]
            hard_ans.setdefault(a, {})[b] = answers["test_answers"]
            # Bug fix: the first query of a structure used to be stored as
            # set(b), which iterates b and stores its elements instead of
            # the query tuple itself. Every query is now added whole.
            queries_all.setdefault(a, set()).add(b)
with open("test-queries.pkl", "wb") as f:
    pickle.dump(queries_all, f)
with open("test-easy-answer.pkl", "wb") as f:
    pickle.dump(easy_ans, f)
with open("test-hard-answer.pkl", "wb") as f:
    pickle.dump(hard_ans, f)