import copy
import os
import json
import pandas as pd
import random
import time
import requests
from bs4 import BeautifulSoup
from tqdm import tqdm

from concurrent.futures import ProcessPoolExecutor, as_completed

base_path = ''

MAX_RETRIES = 5
WAIT_PERIOD = 5  # seconds

def fetch_url(wiki_url, timeout=30):
    """GET *wiki_url*, retrying with linear backoff on HTTP 429 (Too Many Requests).

    Args:
        wiki_url: URL to fetch.
        timeout: Per-request timeout in seconds forwarded to ``requests.get`` so a
            stalled connection cannot block the calling worker indefinitely.

    Returns:
        The first ``requests.Response`` whose status code is not 429. It may still
        carry another error status; the caller decides how to handle it.

    Raises:
        RuntimeError: if all MAX_RETRIES attempts were answered with 429.
        requests.RequestException: on connection errors or timeouts.
    """
    headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}
    for attempt in range(MAX_RETRIES):
        response = requests.get(wiki_url, headers=headers, timeout=timeout)
        if response.status_code != 429:
            return response
        # Linear backoff between attempts; no point sleeping after the final one.
        if attempt < MAX_RETRIES - 1:
            time.sleep(WAIT_PERIOD * (attempt + 1))

    raise RuntimeError("Maximum retries occurred")


def check_valid_entity(chunk):
    """Return the URLs in *chunk* whose pages are neither redirects nor error pages.

    Each URL is fetched with ``fetch_url`` and parsed with BeautifulSoup. A page is
    rejected when it contains an element of class ``redirectMsg`` (a redirect /
    deprecated page) or a ``div.mw-message-box-error`` (an errored page). Any
    exception raised while fetching or parsing is printed and that URL is skipped,
    so a single bad page never aborts the whole chunk.

    Args:
        chunk: Sized iterable of Wikipedia URLs to check.

    Returns:
        List of the accepted URLs, in the order they appear in *chunk*.
    """
    kept_urls = []

    for url in tqdm(chunk, total=len(chunk)):
        try:
            resp = fetch_url(url)
            if resp.status_code != 200:
                # Raises requests.HTTPError for 4xx/5xx, which the handler below skips.
                resp.raise_for_status()

            page = BeautifulSoup(resp.text, 'lxml')
            # Short-circuits: the error-box lookup only runs when no redirect notice exists.
            rejected = bool(page.find(class_='redirectMsg')) or bool(
                page.find('div', class_='mw-message-box-error'))
        except Exception as err:
            print(f"Error occured: {err} with {url}")
            continue

        if not rejected:
            kept_urls.append(url)

    return kept_urls

# Target number of documents in the output KB (query-related ones + random fill).
TARGET_KB_SIZE = 200 * 1000
# Number of random non-query URLs whose validity is checked before the fill is sampled.
NUM_CANDIDATES = 500 * 1000
# Seed for candidate/fill sampling so that reruns select the same documents.
RANDOM_SEED = 0
# Upper bound on worker processes: too many concurrent requests trigger HTTP 429 errors.
MAX_WORKERS = 32


def split_into_chunks(items, num_chunks):
    """Split the list *items* into at most *num_chunks* non-empty, near-equal chunks.

    Item i goes to chunk ``i % n`` (strided slicing), so chunk sizes differ by at
    most one. Never yields empty chunks; returns [] when *items* is empty.
    """
    if not items:
        return []
    num_chunks = max(1, min(num_chunks, len(items)))
    return [items[i::num_chunks] for i in range(num_chunks)]


def sample_up_to(population, k, rng=random):
    """Sample ``min(k, len(population))`` distinct items (nothing when k <= 0).

    The population is sorted first so the result depends only on its contents and
    the state of *rng*, not on set/hash iteration order or on the order in which
    parallel workers returned their results.
    """
    pool = sorted(population)
    return rng.sample(pool, max(0, min(k, len(pool))))


# The KB is much larger than the query datasets, so we extract a small subset and
# merge it with InfoSeek or Encyclopedic-VQA: every query-related document plus
# randomly selected ones, TARGET_KB_SIZE (200k) documents in total.
# Since the KB is based on old Wikipedia URLs, randomly selected pages that now
# redirect or show an error box are dropped (see check_valid_entity).
if __name__ == '__main__':
    rng = random.Random(RANDOM_SEED)

    query_df = pd.concat([
        pd.read_csv(os.path.join(base_path, f'{split}_name_clean.csv'))
        for split in ('train', 'validation', 'test')
    ])

    kb_path = os.path.join(base_path, 'viquae_wikipedia/viquae_kb_wiki_empty.json')
    with open(kb_path, 'r', encoding='utf-8') as f:
        document_kb = json.load(f)

    # Make directory for wikipedia document images
    os.makedirs(os.path.join(base_path, 'viquae_wikipedia', 'images'), exist_ok=True)

    # Keep every query-related document (a URL missing from the KB raises KeyError).
    # Documents are only read and dumped, never mutated, so no deep copy is needed.
    subset_doc_kb = {url: document_kb[url] for url in query_df['wikipedia_url']}

    num_needed = TARGET_KB_SIZE - len(subset_doc_kb)
    if num_needed > 0:
        unseen_kb_urls = sample_up_to(
            document_kb.keys() - subset_doc_kb.keys(), NUM_CANDIDATES, rng)

        # Parallel validity check; os.cpu_count() may return None.
        num_workers = min(os.cpu_count() or 1, MAX_WORKERS)
        chunks = split_into_chunks(unseen_kb_urls, num_workers)

        global_exist_doc_list = []
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(check_valid_entity, chunk) for chunk in chunks]
            for future in as_completed(futures):
                global_exist_doc_list.extend(future.result())

        if len(global_exist_doc_list) < num_needed:
            print(f"Warning: only {len(global_exist_doc_list)} valid documents found but "
                  f"{num_needed} are needed; the output KB will have fewer than "
                  f"{TARGET_KB_SIZE} documents.")

        for kb_url in sample_up_to(global_exist_doc_list, num_needed, rng):
            subset_doc_kb[kb_url] = document_kb[kb_url]

    out_path = os.path.join(base_path, 'viquae_wikipedia/viquae_kb_wiki_empty_200k.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(subset_doc_kb, f)
