import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from bs4 import BeautifulSoup
import requests
import random
import wikipedia

from concurrent.futures import ProcessPoolExecutor, as_completed


base_path = ''

def check_valid_entity(chunk):
    """Split Wikipedia titles into those that resolve to a page and those that do not.

    Each title is looked up with ``wikipedia.page(title, auto_suggest=False)``.
    A title is considered valid when the lookup returns without raising and
    invalid when the lookup raises any ``Exception``.

    Args:
        chunk: Sized collection of Wikipedia page titles (``len(chunk)`` is
            used for the progress bar total).

    Returns:
        A ``(valid_entities, invalid_entities)`` tuple of lists, each keeping
        the titles in the order they were encountered in ``chunk``.
    """
    valid_entities, invalid_entities = [], []

    for entity in tqdm(chunk, total=len(chunk)):
        try:
            # Only success/failure of the lookup matters; the page object is discarded.
            wikipedia.page(entity, auto_suggest=False)
        except Exception as e:
            # Best-effort classification: every failure marks the title as invalid.
            # NOTE(review): this also covers non-lookup failures (e.g. request
            # errors), so such titles end up in the invalid list as well.
            print(e)
            invalid_entities.append(entity)
            # Log only the failing title; printing the whole accumulated list
            # on every failure makes the output grow quadratically.
            print(f"Invalid entity: {entity!r}")
        else:
            valid_entities.append(entity)

    return valid_entities, invalid_entities


# Some documents are invalid when loaded from the reconstructed Wikipedia URL
# (e.g. https://en.wikipedia.org/wiki/Awsegeawg, https://en.wikipedia.org/wiki/Yakut_revolt).
if __name__ == "__main__":
    # Build a 1.5M-title subset of the knowledge base that contains every title
    # referenced by the train/val/test queries, then check in parallel which of
    # those titles resolve to a Wikipedia page and dump both lists as JSON.

    file_path = os.path.join(base_path, 'infoseek_wikipedia', 'Wiki6M_ver_1_0_title_only.jsonl')
    kb_wiki_meta = pd.read_json(file_path, lines=True)

    train_query_df = pd.read_csv(os.path.join(base_path, 'infoseek_train_section.csv'))
    val_query_df = pd.read_csv(os.path.join(base_path, 'infoseek_val_section.csv'))
    test_query_df = pd.read_csv(os.path.join(base_path, 'infoseek_test_section.csv'))

    query_df = pd.concat([train_query_df, val_query_df, test_query_df])

    # Sets remove duplicated documents and allow the difference below.
    query_meta = set(query_df['wikipedia_title'])
    wiki_meta = set(kb_wiki_meta['wikipedia_title'])

    # Titles present in the knowledge base but never referenced by a query.
    unseen_meta = wiki_meta - query_meta

    target_size = 1500 * 1000
    # random.sample() needs a sequence: passing a set is deprecated since
    # Python 3.9 and raises TypeError from 3.11 on, so materialize a list first.
    unseen_subset_meta = set(random.sample(list(unseen_meta), target_size - len(query_meta)))

    # Final subset = every queried title + randomly drawn unseen titles.
    unseen_subset_meta = list(unseen_subset_meta.union(query_meta))
    assert len(unseen_subset_meta) == target_size, \
        f"the final subset size is {len(unseen_subset_meta)}"

    # Parallel computing
    # Cap the worker count: too many CPUs cause 426 Client errors, since there
    # are too many requests. os.cpu_count() may return None, hence the fallback.
    num_workers = min(os.cpu_count() or 1, 32)
    chunk_size = len(unseen_subset_meta) // num_workers
    # Equal-sized slices; the last worker also takes the remainder.
    chunks = [
        unseen_subset_meta[i * chunk_size:]
        if i == num_workers - 1
        else unseen_subset_meta[i * chunk_size:(i + 1) * chunk_size]
        for i in range(num_workers)
    ]

    global_valid_kbs = []
    global_invalid_kbs = []
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(check_valid_entity, chunk) for chunk in chunks]

        # Merge per-chunk results as soon as each worker finishes.
        for future in as_completed(futures):
            chunk_valid, chunk_invalid = future.result()
            global_valid_kbs.extend(chunk_valid)
            global_invalid_kbs.extend(chunk_invalid)

    with open(os.path.join(base_path, 'infoseek_wikipedia', 'infoseek_valid_entities.json'), 'w') as f:
        json.dump(global_valid_kbs, f)

    with open(os.path.join(base_path, 'infoseek_wikipedia', 'infoseek_invalid_entities.json'), 'w') as f:
        json.dump(global_invalid_kbs, f)

    print('Done!')