import copy
import os
import json
import pandas as pd
import random

base_path = ''

# Total number of documents in the reduced KB (test-entity documents + random fill).
KB_SUBSET_SIZE = 500 * 1000
# Due to GPU limits we cannot handle very long documents: documents with more
# section titles than this are never picked for the random fill.
MAX_SECTION_TITLES = 25
# Fixed seed so every run of the script produces the same KB subset.
RANDOM_SEED = 0


def build_subset_kb(document_kb, required_urls, target_size=KB_SUBSET_SIZE,
                    max_section_titles=MAX_SECTION_TITLES, seed=None):
    """Return a reduced KB holding every required document plus a random fill.

    Args:
        document_kb: Mapping of URL -> document dict. Each document must have a
            'section_titles' sequence.
        required_urls: Iterable of URLs that must appear in the subset (e.g. the
            test-set entities). Duplicates are allowed. Every URL must be a key
            of ``document_kb``. Required documents are kept regardless of their
            number of section titles.
        target_size: Desired total number of documents. If the required URLs
            already reach it, no fill is added. If there are not enough
            candidates, every candidate is added.
        max_section_titles: Fill candidates with more section titles than this
            are excluded.
        seed: Seed for the sampling RNG; ``None`` gives a fresh random sample.

    Returns:
        A new dict: required documents first (in ``required_urls`` order), then
        the sampled fill documents. Document values are shared with
        ``document_kb``, not copied.

    Raises:
        KeyError: if a required URL is not a key of ``document_kb``.
    """
    subset_kb = {url: document_kb[url] for url in required_urls}

    # random.sample needs a sequence (sets are rejected on Python 3.11+).
    # Sorting makes the sample for a given seed independent of the KB's key order.
    candidate_urls = sorted(
        url for url, doc in document_kb.items()
        if url not in subset_kb and len(doc['section_titles']) <= max_section_titles
    )

    # Clamp so that an oversized test set or too few candidates does not make
    # random.sample raise ValueError.
    n_fill = max(0, min(target_size - len(subset_kb), len(candidate_urls)))
    rng = random.Random(seed)
    for url in rng.sample(candidate_urls, n_fill):
        subset_kb[url] = document_kb[url]
    return subset_kb


# For fast evaluation of retrieval performance on the test set, randomly extract
# a 500k-document KB subset that contains every test-set entity.
if __name__ == '__main__':

    test_query_df = pd.read_csv(os.path.join(base_path, 'test_clean.csv'))
    with open(os.path.join(base_path, 'infoseek_wikipedia/infoseek_kb_wiki_800k.json'), 'r',
              encoding='utf-8') as f:
        document_kb = json.load(f)

    subset_doc_kb = build_subset_kb(document_kb, test_query_df['wikipedia_url'],
                                    seed=RANDOM_SEED)

    with open(os.path.join(base_path, 'infoseek_wikipedia/infoseek_kb_wiki_test.json'), 'w',
              encoding='utf-8') as f:
        json.dump(subset_doc_kb, f)
