import os
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
from bs4 import BeautifulSoup
import requests
import random
import time

from concurrent.futures import ProcessPoolExecutor, as_completed


# Root directory that the data file paths in the __main__ section are joined
# onto; the empty string makes those paths relative to the working directory.
base_path = ''

# Number of requests fetch_url sends to a URL that keeps answering HTTP 429
# before it gives up and raises.
MAX_RETRIES = 5
# Base back-off delay after an HTTP 429; fetch_url multiplies it by the
# 1-based attempt number.
WAIT_PERIOD = 5  # seconds


def fetch_url(wiki_url, timeout=30):
    """Fetch ``wiki_url`` with a GET request, backing off on HTTP 429.

    Args:
        wiki_url: URL to request.
        timeout: Seconds to wait for the server before ``requests`` aborts
            the call. Without a timeout a stalled connection would block the
            calling worker forever.

    Returns:
        The first ``requests.Response`` whose status code is not 429. Other
        error statuses (404, 500, ...) are returned as-is for the caller to
        inspect.

    Raises:
        Exception: If all ``MAX_RETRIES`` attempts were answered with 429.
        requests.exceptions.RequestException: Propagated unchanged from
            ``requests.get`` (connection failure, timeout, ...).
    """
    # Identical for every attempt, so build it once outside the loop.
    headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}

    for attempt in range(MAX_RETRIES):
        response = requests.get(wiki_url, headers=headers, timeout=timeout)
        if response.status_code != 429:
            return response
        # Rate limited: wait longer after each consecutive 429 (linear back-off).
        time.sleep(WAIT_PERIOD * (attempt + 1))

    raise Exception("Maximum retries occurred")


def check_valid_entity(chunk):
    """Split the URLs in ``chunk`` into usable pages and unusable ones.

    A URL is rejected when fetching or parsing it raises any exception
    (including an HTTP error status), or when the page is a disambiguation
    page. Every other URL is accepted.

    Args:
        chunk: Sized collection of Wikipedia URLs to check.

    Returns:
        A ``(valid_urls, invalid_urls)`` tuple of lists.
    """
    accepted = []
    rejected = []

    for wiki_url in tqdm(chunk, total=len(chunk)):
        try:
            response = fetch_url(wiki_url)

            # Missing pages answer with an error status, for example
            # https://en.wikipedia.org/wiki/Awsegeawg
            if response.status_code != 200:
                response.raise_for_status()

            soup = BeautifulSoup(response.text, 'lxml')

            # A page is a disambiguation page when it carries the
            # disambiguation box (for example
            # https://en.wikipedia.org/wiki/Yakut_revolt) or when its category
            # links mention "disambiguation". The category scan only runs if
            # the box is absent.
            is_disambiguation = bool(soup.find('table', {'id': 'disambigbox'})) or any(
                'disambiguation' in catlinks.get_text().lower()
                for catlinks in soup.find_all('div', {'class': 'mw-normal-catlinks'})
            )
        except Exception as e:
            # Best effort: any failure marks the URL invalid instead of
            # aborting the whole chunk.
            print(f"Error occured: {e} with {wiki_url}")
            rejected.append(wiki_url)
            continue

        if is_disambiguation:
            rejected.append(wiki_url)
        else:
            accepted.append(wiki_url)

    return accepted, rejected

# Some documents are invalid when loaded from the reconstructed Wikipedia URL:
# the page may not exist (e.g. https://en.wikipedia.org/wiki/Awsegeawg) or may be
# a disambiguation page (e.g. https://en.wikipedia.org/wiki/Yakut_revolt).
if __name__ == "__main__":
    # Builds a 1.5M-URL subset of the knowledge base (every query URL plus a
    # random fill of other KB URLs), validates each URL in parallel, and
    # writes the valid / invalid URL lists as JSON.

    # Total number of URLs to validate (query URLs + random unseen URLs).
    subset_size = 1500 * 1000

    file_path = os.path.join(base_path, 'infoseek_wikipedia', 'Wiki6M_ver_1_0_title_only.jsonl')
    kb_wiki_meta = pd.read_json(file_path, lines=True)

    train_query_df = pd.read_csv(os.path.join(base_path, 'infoseek_train_section.csv'))
    val_query_df = pd.read_csv(os.path.join(base_path, 'infoseek_val_section.csv'))
    test_query_df = pd.read_csv(os.path.join(base_path, 'infoseek_test_section.csv'))

    query_df = pd.concat([train_query_df, val_query_df, test_query_df])

    # A set removes URLs that several queries share.
    query_meta = set(query_df['wikipedia_url'])

    # Reconstruct each knowledge-base URL from its title (spaces -> underscores).
    wiki_meta = {
        f"https://en.wikipedia.org/wiki/{title.replace(' ', '_')}"
        for title in tqdm(kb_wiki_meta['wikipedia_title'], total=len(kb_wiki_meta))
    }

    unseen_meta = wiki_meta - query_meta
    # random.sample() needs a sequence: passing a set raises TypeError on
    # Python 3.11+ (deprecated since 3.9), so materialize a list first.
    unseen_sample = random.sample(list(unseen_meta), subset_size - len(query_meta))

    # unseen_meta is disjoint from query_meta, so the union has exactly
    # subset_size elements.
    subset_urls = list(query_meta.union(unseen_sample))
    assert len(subset_urls) == subset_size, f"the final subset size is {len(subset_urls)}"

    # Parallel computing
    # Capped at 32 workers: too many concurrent requests trigger client errors
    # from the server. os.cpu_count() may return None, hence the fallback to 1.
    num_workers = min(os.cpu_count() or 1, 32)
    chunk_size = len(subset_urls) // num_workers
    # Equal-sized chunks, with the last one absorbing the remainder.
    chunks = [subset_urls[i * chunk_size:(i + 1) * chunk_size] for i in range(num_workers - 1)]
    chunks.append(subset_urls[(num_workers - 1) * chunk_size:])

    global_valid_kbs = []
    global_invalid_kbs = []
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(check_valid_entity, chunk) for chunk in chunks]

        for future in as_completed(futures):
            valid_urls, invalid_urls = future.result()
            global_valid_kbs.extend(valid_urls)
            global_invalid_kbs.extend(invalid_urls)

    with open(os.path.join(base_path, 'infoseek_wikipedia', 'infoseek_valid_urls.json'), 'w') as f:
        json.dump(global_valid_kbs, f)

    with open(os.path.join(base_path, 'infoseek_wikipedia', 'infoseek_invalid_urls.json'), 'w') as f:
        json.dump(global_invalid_kbs, f)

    print('Done!')