import pickle
import torch
# Load the three pickled inputs from the current working directory.
# NOTE: pickle.load can execute arbitrary code while unpickling, so only run
# this script against dataset files you trust.
with open("test-queries.pkl", "rb") as queries_file:
    # Iterated below as: query structure -> iterable of queries.
    queries = pickle.load(queries_file)
with open("test-hard-answer.pkl", "rb") as hard_file:
    # Indexed later in this script as test_hard_answers[structure][query].
    test_hard_answers = pickle.load(hard_file)
with open("test-easy-answer.pkl", "rb") as easy_file:
    # Loaded but not referenced in the visible part of this script.
    test_easy_answers = pickle.load(easy_file)

# Flatten the per-structure grouping into one list of (query, structure) pairs.
all_queries = [
    (query, structure)
    for structure in queries
    for query in queries[structure]
]
# A `global` statement has no effect at module scope; kept as in the original.
global out
from tqdm import tqdm
# Module-level accumulator that destack_tuple appends into. Callers that want
# the flattening of a single tuple must rebind it to a fresh list first.
out = []


def destack_tuple(t):
    """Append the non-tuple leaves of ``t`` to the module-level ``out`` list.

    Nested tuples are walked depth-first, left to right, so leaves are
    appended in the order they appear when reading ``t``. Nothing is cleared:
    leaves pile up on top of whatever ``out`` already holds.

    Returns the ``out`` list itself (not a copy).
    """
    # Explicit stack of iterators instead of recursion; the innermost
    # partially-consumed tuple is always on top.
    pending = [iter(t)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, tuple):
                # Descend into the nested tuple, then resume this level later.
                pending.append(iter(item))
                break
            out.append(item)
        else:
            # Current level exhausted: go back to the enclosing tuple.
            pending.pop()
    return out
# Filter all_queries down to the entries that:
#   * have a tuple as the query,
#   * flatten to the same non-zero number of leaves as their query structure,
#   * have at least one entry in test_hard_answers[structure][query].
#
# The survivors are collected in a single pass and written back at the end.
# Calling list.remove() for every rejected entry would rescan and shift the
# list each time, which makes the filtering quadratic in the number of queries.
print(len(all_queries))
kept_queries = []
for que in all_queries:
    if not isinstance(que[0], tuple):
        continue
    # destack_tuple appends into the module-level `out` list and returns it,
    # so `out` has to be rebound to a fresh list before each call; otherwise
    # the leaves of the previous call would be counted as well.
    out = []
    length1 = len(destack_tuple(que[0]))
    out = []
    length2 = len(destack_tuple(que[1]))
    # When the lengths are equal, length1 == 0 already implies length2 == 0.
    if length1 != length2 or length1 == 0:
        continue
    if len(test_hard_answers[que[1]][que[0]]) == 0:
        continue
    kept_queries.append(que)
    # Progress report: number of queries kept so far, every 10000 kept.
    if len(kept_queries) % 10000 == 0:
        print(len(kept_queries))
# Slice assignment updates the existing list object in place.
all_queries[:] = kept_queries
print(len(all_queries))
# Regroup the surviving (query, structure) pairs as structure -> set of queries.
queries_all = {}
for query, query_structure in all_queries:
    queries_all.setdefault(query_structure, set()).add(query)

# NOTE: this writes to the same path the queries were loaded from at the top
# of the script, so the original "test-queries.pkl" is overwritten.
with open("test-queries.pkl", "wb") as out_f:
    pickle.dump(queries_all, out_f)