import json
import os
import pandas as pd
import time
import requests
from bs4 import BeautifulSoup

# Directory all dataset inputs and CSV outputs are resolved against
# ('' means paths are relative to the current working directory).
base_path: str = ''

# Retry policy for HTTP 429 (rate-limited) responses in fetch_url.
MAX_RETRIES: int = 5  # attempts before giving up
WAIT_PERIOD: int = 5  # seconds; multiplied by the attempt number as backoff

def fetch_url(wiki_url, timeout=30):
    """Fetch ``wiki_url``, retrying with linear backoff while rate limited.

    Args:
        wiki_url: URL to request.
        timeout: Seconds passed to ``requests.get``. Without a timeout, a
            stalled connection can block the whole preprocessing run forever.

    Returns:
        The first ``requests.Response`` whose status code is not 429. Any
        other status, including error statuses, is returned unchanged for the
        caller to inspect.

    Raises:
        RuntimeError: If all ``MAX_RETRIES`` attempts were answered with 429.
            It subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it.
        Exceptions raised by ``requests.get`` (connection errors, timeouts)
        propagate immediately; they are not retried.
    """
    # Identify the bot to the server; built once, since it never changes.
    headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}

    for attempt in range(MAX_RETRIES):
        response = requests.get(wiki_url, headers=headers, timeout=timeout)
        if response.status_code != 429:
            return response
        # Rate limited: back off linearly (WAIT_PERIOD, 2*WAIT_PERIOD, ...).
        # Skip the sleep after the last attempt; nothing follows it but the raise.
        if attempt + 1 < MAX_RETRIES:
            time.sleep(WAIT_PERIOD * (attempt + 1))

    raise RuntimeError("Maximum retries occurred")


def check_deprecated_urls(wiki_url):
    """Return True when the Wikipedia page at ``wiki_url`` should be discarded.

    A page is discarded when any of these hold:
      * fetching or parsing it raises any exception, including the HTTP
        4xx/5xx error surfaced by ``raise_for_status``; the error is printed;
      * it renders a redirect notice (an element with class ``redirectMsg``);
      * it renders a MediaWiki error box (``div.mw-message-box-error``), e.g.
        https://en.wikipedia.org/w/index.php?title=Open%20University&oldid=907350037

    Returns False only when the page was fetched and shows neither marker.
    """
    try:
        page = fetch_url(wiki_url)
        if page.status_code != 200:
            page.raise_for_status()

        parsed = BeautifulSoup(page.text, 'lxml')
        is_redirect = parsed.find(class_='redirectMsg')
        has_error_box = parsed.find('div', class_='mw-message-box-error')
        return bool(is_redirect or has_error_box)
    except Exception as e:
        # Best effort: an unreachable page is treated the same as a deprecated one.
        print(f"An error occurred: {e} in {wiki_url}")
        return True

def _first_valid_provenance(query, wiki_url_dict, valid_doc_urls):
    """Return ``(url, raw_section, title)`` of the first in-KB provenance entry.

    Entries of ``query['output']['provenance']`` are scanned in order. The first
    one whose mapped URL is in ``valid_doc_urls`` is returned. Returns ``None``
    when no entry qualifies, including when the provenance list is empty.
    """
    for provenance in query['output']['provenance']:
        doc_url = wiki_url_dict[provenance['wikipedia_id'][0]]
        if doc_url in valid_doc_urls:
            return doc_url, provenance['section'][0], provenance['title'][0]
    return None


def _clean_section_name(raw_section, document_title):
    """Strip the ``Section::::`` prefix; name the ``Abstract.`` section after the article."""
    section = raw_section.split("Section::::")[-1].strip()
    return document_title if section == 'Abstract.' else section


def preprocess_viquae(wiki_url_dict, valid_doc_urls, split='train', *,
                      base_dir=None, is_deprecated=None):
    """Convert one ViQuAE split into a flat CSV of (question, evidence) rows.

    Reads ``{base_dir}/viquae_dataset/{split}.jsonl`` (one JSON query per line)
    and writes ``{base_dir}/{split}_name.csv``. A query is kept only if:
      * one of its provenance documents maps (via ``wiki_url_dict``) to a URL
        in ``valid_doc_urls``; only the first such document is used;
      * that URL is not the known-mismatched Ethiopia revision;
      * ``is_deprecated(url)`` is falsy.

    Args:
        wiki_url_dict: Maps a provenance ``wikipedia_id`` to a Wikipedia URL.
        valid_doc_urls: Set of URLs present in the knowledge base.
        split: Split name; used in the input/output file names and stored in
            the ``viquae_split`` column.
        base_dir: Directory for input and output; defaults to ``base_path``.
        is_deprecated: Callable ``url -> bool``; defaults to
            ``check_deprecated_urls``. Each distinct URL is checked once per call.

    Raises:
        KeyError: If a provenance ``wikipedia_id`` is missing from ``wiki_url_dict``.
    """
    if base_dir is None:
        base_dir = base_path
    if is_deprecated is None:
        is_deprecated = check_deprecated_urls

    # The query's evidence section name and this document do not match.
    mismatched_doc_url = 'https://en.wikipedia.org/w/index.php?title=Ethiopia&oldid=920775651'

    # Many questions share an article; avoid re-checking the same URL over the network.
    deprecated_by_url = {}
    records = []

    with open(os.path.join(base_dir, f'viquae_dataset/{split}.jsonl'), encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines in the jsonl file
                continue
            query = json.loads(line)

            # Use only the first valid relevant document for simplicity.
            picked = _first_valid_provenance(query, wiki_url_dict, valid_doc_urls)
            if picked is None:
                continue
            doc_url, raw_section, document_title = picked

            if doc_url == mismatched_doc_url:
                continue

            # Some Wikipedia URLs are deprecated; only checked for otherwise-valid queries.
            if doc_url not in deprecated_by_url:
                deprecated_by_url[doc_url] = is_deprecated(doc_url)
            if deprecated_by_url[doc_url]:
                continue

            records.append({
                'question': query['input'],
                'answer': query['output']['original_answer'],
                'dataset_image_ids': query['image'],
                'wikipedia_url': doc_url,
                # Since we don't have the evidence id yet, this should be updated later.
                'evidence_section_name': _clean_section_name(raw_section, document_title),
                'viquae_split': split,
                'question_original': query['original_question'],
                'dataset_name': 'Viquae',
            })

    columns = ['question', 'answer', 'dataset_image_ids', 'wikipedia_url',
               'evidence_section_name', 'viquae_split', 'question_original',
               'dataset_name']
    clean_query = pd.DataFrame(records, columns=columns)
    clean_query.to_csv(os.path.join(base_dir, f'{split}_name.csv'), index=False)


# Convert ViQuAE queries into an Encyclopedic-VQA-like format.
# Some old Wikipedia URLs are deprecated. (train: 1171, valid: 1230, test: 1241)
# The resulting CSVs still need to be merged into another dataset.
if __name__ == '__main__':
    kb_dir = os.path.join(base_path, 'viquae_wikipedia')

    # Maps KILT wikipedia_id -> Wikipedia URL.
    with open(os.path.join(kb_dir, 'wikipedia_id_to_url.json'), 'r', encoding="utf-8") as f:
        wikipedia_id_to_url = json.load(f)

    # ViQuAE uses a 1.5M-article subset of the 6M-article KILT KB;
    # only URLs in that subset count as valid evidence documents.
    with open(os.path.join(kb_dir, 'viquae_kb_wiki_empty.json'), 'r', encoding="utf-8") as f:
        valid_document_urls = set(json.load(f).keys())

    for split in ('train', 'validation', 'test'):
        preprocess_viquae(wikipedia_id_to_url, valid_document_urls, split=split)

    print('Done!')