import copy
import json
import os
import re
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

import requests
from bs4 import BeautifulSoup
from tqdm import tqdm

# Directory that contains the `viquae_wikipedia` data folder; the empty
# string makes the paths below relative to the current working directory.
base_path = ""

# Retry policy used by fetch_url() when a request is rate limited (HTTP 429).
MAX_RETRIES = 5  # number of attempts before giving up
WAIT_PERIOD = 5  # seconds; the back-off grows linearly with each attempt

def fetch_url(wiki_url, timeout=30):
    """Download ``wiki_url`` with a descriptive User-Agent, backing off when throttled.

    Up to ``MAX_RETRIES`` attempts are made. An attempt is retried when the
    server answers HTTP 429 (Too Many Requests) or when the connection fails
    or times out. Before the next attempt it sleeps ``WAIT_PERIOD * attempt``
    seconds (linear back-off). No sleep happens after the final attempt.

    Args:
        wiki_url: URL of the page to fetch.
        timeout: Per-request timeout in seconds, forwarded to ``requests.get``.
            Without a timeout a stalled connection blocks the worker forever.

    Returns:
        The first ``requests.Response`` that is not a 429. Any other status,
        including 4xx/5xx, is returned unchanged for the caller to handle.

    Raises:
        requests.ConnectionError / requests.Timeout: if the last attempt failed
            at the network level (the original error is re-raised).
        Exception: if every attempt was answered with HTTP 429.
    """
    headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}
    last_error = None

    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = requests.get(wiki_url, headers=headers, timeout=timeout)
        except (requests.ConnectionError, requests.Timeout) as err:
            # Transient network failure: remember it and retry like a 429.
            last_error = err
        else:
            if response.status_code != 429:
                return response
            last_error = None

        # Sleeping after the final attempt would only delay the failure.
        if attempt < MAX_RETRIES:
            time.sleep(WAIT_PERIOD * attempt)

    if last_error is not None:
        raise last_error
    raise Exception("Maximum retries occurred")


def clean_table_html(content):
    """Return simplified HTML for a table-like BeautifulSoup tag.

    The cleaning runs on a detached copy (``copy.copy`` of a bs4 Tag copies
    its whole subtree). ``content`` and the soup it belongs to are therefore
    left untouched, and other extractors that read the same soup see the
    original markup.

    Cleaning steps:
      * unwrap every ``<a>`` (keep the link text, drop the hyperlink);
      * remove every ``<img>`` (images are collected by other functions);
      * remove ``style``, ``id``, ``class``, ``data-mw-deduplicate`` and
        ``typeof`` attributes from ``content`` itself and from every
        descendant tag;
      * remove ``<script>`` and ``<style>`` elements.

    Args:
        content: A BeautifulSoup ``Tag``, typically a ``<table>``.

    Returns:
        The cleaned markup as a string.
    """
    table = copy.copy(content)

    # Remove all hyperlinks but keep their text.
    for anchor in table.find_all('a'):
        anchor.unwrap()

    # Remove images, as the images are taken care of in other functions.
    for img in table.find_all('img'):
        img.decompose()

    # style/class: presentation only (colors, fonts, borders, spacing).
    # id: unique identifiers that only complicate the table structure.
    # data-mw-deduplicate: MediaWiki bookkeeping.
    # typeof: redundant resource-type metadata.
    dropped_attrs = {'style', 'id', 'class', 'data-mw-deduplicate', 'typeof'}
    # find_all() yields only descendants, so the root tag is added explicitly;
    # otherwise the table's own class/style/id would survive.
    for tag in [table, *table.find_all(True)]:
        tag.attrs = {key: value for key, value in tag.attrs.items() if key not in dropped_attrs}

    for node in table.find_all(['script', 'style']):
        node.decompose()

    return str(table)

def extract_images_with_sections(soup, title):
    """Collect article images with a description and the index of their section.

    Walks ``<p>``, ``<img>`` and heading (``h1``, ``h2``, ...) tags in document
    order. Nothing is recorded before the first ``<p>``. After it, every
    heading starts a new section. Section 0 is the lead (summary) and is
    named after ``title``.

    An ``<img>`` is kept only if all of the following hold:
      * its section is not one of ``ignored_sections``;
      * its ``height`` attribute parses as an integer >= 40 (a missing or
        non-numeric height counts as too small);
      * its ``src`` ends in jpg/jpeg/png/svg (case-insensitive, query string
        ignored).

    The description is the text of the enclosing ``<figure>``'s
    ``<figcaption>`` when there is one, otherwise the ``alt`` attribute.

    Args:
        soup: Parsed page (``BeautifulSoup``).
        title: Page title, used as the name of the lead section.

    Returns:
        ``(image_urls, image_reference_descriptions, image_section_indices)``,
        three lists of equal length.
    """
    # Only in the image & table extracting, we ignore the images in these sections.
    ignored_sections = {'References', 'External links', 'See also', 'Further reading', 'Footnotes', 'Notes'}

    # Extensions are compared lower-cased, so 'JPG', 'JPEG', 'jpeg', 'PNG',
    # 'SVG', 'svg', ... are all accepted. Other modalities such as GIF are not.
    supported_img_extensions = {'jpg', 'jpeg', 'png', 'svg'}
    min_height = 40  # Too small images are mostly redundant.

    image_urls = []
    image_reference_descriptions = []
    image_section_indices = []

    summary_started = False
    cur_section_name = title  # The Wikipedia title is the name of the summary part.
    cur_section_id = 0

    heading_re = re.compile(r'^h\d+$')  # h1, h2, h3, ...
    heading_tags = {tag.name for tag in soup.find_all() if heading_re.match(tag.name)}
    all_tags = list(heading_tags | {'p', 'img'})

    for element in soup.find_all(all_tags):
        if element.name == 'p':
            summary_started = True
        if not summary_started:
            continue

        if heading_re.match(element.name):
            # NOTE(review): if the heading tag itself contains an "[edit]" link
            # (older skins), that text is included and ignored_sections will not
            # match; confirm against the fetched HTML.
            cur_section_name = ' '.join(element.get_text().split())
            cur_section_id += 1
            continue

        if element.name != 'img' or cur_section_name in ignored_sections:
            continue

        # A non-numeric height (e.g. "auto") would otherwise raise ValueError
        # and lose the whole page; treat it like a missing height (skip).
        try:
            if int(element.get('height')) < min_height:
                continue
        except (TypeError, ValueError):
            continue

        src = element.get('src', '')
        extension = src.split('?', 1)[0].rsplit('.', 1)[-1].lower()
        if not src or extension not in supported_img_extensions:
            continue

        # src is expected to be protocol-relative ("//upload..."); do not
        # prepend a scheme to URLs that already carry one.
        img_url = src if src.startswith(('http://', 'https://')) else 'https:' + src
        img_description = element.get('alt', '')

        # Only the enclosing <figure> is searched. The nearest <div> ancestor
        # can be a page-wide container whose first <figcaption> belongs to a
        # different image.
        figure = element.find_parent('figure')
        if figure is not None:
            caption = figure.find('figcaption')
            if caption is not None:
                # Collapse whitespace without gluing words around inline tags;
                # get_text(strip=True) turns "an <a>old</a> map" into "anoldmap".
                img_description = ' '.join(caption.get_text().split())

        image_urls.append(img_url)
        image_reference_descriptions.append(img_description)
        image_section_indices.append(cur_section_id)

    return image_urls, image_reference_descriptions, image_section_indices


def extract_tables_with_sections(soup, title):
    """Collect the article's tables as cleaned HTML with the index of their section.

    Walks ``<p>``, ``<table>`` and heading (``h1``, ``h2``, ...) tags in
    document order. Nothing is recorded before the first ``<p>``. After it,
    every heading starts a new section, and section 0 is the lead (summary).

    A ``<table>`` is skipped when any of the following is true:
      * its section is one of ``ignored_sections``;
      * it has the ``infobox`` class (the infobox is prepended to the lead
        text by ``extract_text_with_sections`` instead);
      * it is nested inside another ``<table>``. Its markup is already part
        of the outer table, and sub-tables of a skipped infobox must stay
        skipped.

    Args:
        soup: Parsed page (``BeautifulSoup``).
        title: Page title, used as the name of the lead section.

    Returns:
        ``(tables, table_section_indices)``: cleaned table HTML strings (via
        ``clean_table_html``) and their section indices, of equal length.
    """
    # Only in the image & table extracting, we ignore the tables in these sections.
    ignored_sections = {'References', 'External links', 'See also', 'Further reading', 'Footnotes', 'Notes'}

    tables = []
    table_section_indices = []

    summary_started = False
    cur_section_name = title  # The title would be the title of the summary part.
    cur_section_id = 0

    heading_re = re.compile(r'^h\d+$')  # h1, h2, h3, ...
    heading_tags = {tag.name for tag in soup.find_all() if heading_re.match(tag.name)}
    all_tags = list(heading_tags | {'p', 'table'})

    for element in soup.find_all(all_tags):
        if element.name == 'p':
            summary_started = True
        if not summary_started:
            continue

        if heading_re.match(element.name):
            # Same whitespace handling as the image extractor, so both skip
            # the same sections.
            cur_section_name = ' '.join(element.get_text().split())
            cur_section_id += 1
            continue

        if element.name != 'table' or cur_section_name in ignored_sections:
            continue

        # The infobox is not treated as table here.
        if 'infobox' in element.get('class', []):
            continue

        if element.find_parent('table') is not None:
            continue

        tables.append(clean_table_html(element))
        table_section_indices.append(cur_section_id)

    return tables, table_section_indices


def extract_text_with_sections(soup, title):
    """Split a Wikipedia page into plain-text sections.

    Walks ``<p>``, ``<ul>``, ``<ol>`` and heading (``h1``, ``h2``, ...) tags
    in document order, starting at the first ``<p>``. Each heading closes
    the current section and opens a new one. Section 0 is the lead (summary).

    How each tag is rendered:
      * ``<p>``: one line of text. The old-revision banner is skipped.
      * ``<ul>``: one ``"- item"`` line per ``<li>``.
      * ``<ol>``: one ``"N. item"`` line per ``<li>``.

    Lists inside a ``<table>`` (e.g. the infobox card) are skipped. Lists
    nested in another list are also skipped, because the outer list's
    recursive ``<li>`` scan already covers them.

    Args:
        soup: Parsed page (``BeautifulSoup``).
        title: Page title, used as the name of the lead section.

    Returns:
        ``(section_texts, section_hie_titles, section_titles)``: three lists
        of equal length, one entry per section.

        * ``section_texts[i]``: the section's lines joined by newlines. If the
          page has an ``infobox`` table, its cleaned HTML is prepended to
          ``section_texts[0]``.
        * ``section_hie_titles[i]``: the ':'-joined heading path down to and
          including the section's own heading, each part suffixed with '.'.
          It is ``title`` when no heading encloses the section (the lead).
        * ``section_titles[i]``: the section's own heading text (``title``
          for the lead).
    """
    def _text(tag):
        # Collapse whitespace but keep the spaces around inline children.
        # get_text(strip=True) would turn "an <a>old revision</a> of" into
        # "anold revisionof".
        return ' '.join(tag.get_text().split())

    # Compared with all spaces removed, so it matches however the banner's
    # inline links are spaced.
    old_revision_banner = 'Thisisanoldrevisionofthispage'

    section_texts = []
    section_hie_titles = []
    section_titles = []

    cur_sec_title = title
    # Heading path of the current section; each entry ends with '.'.
    section_stack = []

    def _hie_title():
        return ':'.join(section_stack) if section_stack else title

    summary_started = False
    temp_section_text = []

    heading_re = re.compile(r'^h\d+$')  # h1, h2, h3, ...
    heading_tags = {tag.name for tag in soup.find_all() if heading_re.match(tag.name)}
    all_tags = list(heading_tags | {'p', 'ul', 'ol'})

    for element in soup.find_all(all_tags):
        if element.name == 'p':
            summary_started = True
        if not summary_started:
            continue

        if heading_re.match(element.name):
            # The section ends here: store what has been accumulated.
            section_texts.append('\n'.join(temp_section_text))
            section_hie_titles.append(_hie_title())
            section_titles.append(cur_sec_title)
            temp_section_text = []

            # An h<k> heading sits at stack depth k-1 (h2 -> 1, h3 -> 2, ...),
            # so pop back to its parent before pushing it. The emptiness guard
            # keeps an <h1> (k-1 == 0) from popping an empty list.
            section_level = int(element.name[1:])
            while section_stack and len(section_stack) >= section_level - 1:
                section_stack.pop()

            cur_sec_title = _text(element)
            section_stack.append(cur_sec_title + '.')  # Add '.' to match the format in the Viquae query dataset.

        elif element.name == 'p':
            paragraph_text = _text(element)
            # Ignore the warning message by the old wikipedia pages.
            if paragraph_text.replace(' ', '').startswith(old_revision_banner):
                continue
            temp_section_text.append(paragraph_text)

        else:  # 'ul' or 'ol'
            if element.find_parent('table') is not None:
                continue  # To exclude the infobox card and other table content.
            if element.find_parent(['ul', 'ol']) is not None:
                continue  # Nested list: already emitted by the outer list.
            items = [_text(li) for li in element.find_all('li')]
            if element.name == 'ul':
                temp_section_text.extend(f"- {item}" for item in items)
            else:
                temp_section_text.extend(f"{idx}. {item}" for idx, item in enumerate(items, start=1))

    # Add the last section, with the same title rule as the in-loop case.
    section_texts.append('\n'.join(temp_section_text))
    section_hie_titles.append(_hie_title())
    section_titles.append(cur_sec_title)

    # Manually add the infobox info.
    infobox = soup.find('table', {'class': 'infobox'})
    if infobox is not None:
        section_texts[0] = clean_table_html(infobox) + '\n' + section_texts[0]

    return section_texts, section_hie_titles, section_titles


def construct_interleaved_document(chunk):
    """Download and parse one worker's share of Wikipedia pages.

    Args:
        chunk: Sequence of ``(wiki_url, title)`` pairs.

    Returns:
        List of ``(wiki_url, texts, images, tables)`` tuples, one per page
        processed without error. ``texts``, ``images`` and ``tables`` are the
        return values of ``extract_text_with_sections``,
        ``extract_images_with_sections`` and ``extract_tables_with_sections``.
        A page that raises at any stage (download, HTTP status, parsing) is
        reported with ``print`` and left out.
    """
    documents = []
    for wiki_url, title in tqdm(chunk, total=len(chunk)):
        try:
            response = fetch_url(wiki_url)
            # No-op for successful responses; raises requests.HTTPError for 4xx/5xx.
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'lxml')

            # Keep this call order. All three read the shared `soup`, and the
            # text/table extractors pass tags to clean_table_html(), which may
            # modify them.
            images = extract_images_with_sections(soup=soup, title=title)
            texts = extract_text_with_sections(soup=soup, title=title)
            tables = extract_tables_with_sections(soup=soup, title=title)
        except Exception as e:
            # Best effort: one bad page must not abort the whole chunk.
            print(f"An error occurred: {e} in {wiki_url}")
            continue
        documents.append((wiki_url, texts, images, tables))
    return documents


# Fetch the HTML of every Wikipedia URL in the empty KB file and fill in each document's text, tables, and images.
if __name__ == '__main__':
    # Load the KB skeleton. JSON is UTF-8 by spec, so do not rely on the
    # platform's default encoding. The `with` block closes the file handle.
    file_path = os.path.join(base_path, 'viquae_wikipedia', 'viquae_kb_wiki_empty_200k.json')
    with open(file_path, 'r', encoding='utf-8') as f:
        kb_wiki = json.load(f)

    wiki_meta_list = [(wiki_url, entry['title']) for wiki_url, entry in kb_wiki.items()]

    # Parallel computing. Capped at 32 because more workers send too many
    # requests and get client errors back (originally noted as "426 Client
    # error"; NOTE(review): possibly 429 Too Many Requests, which fetch_url
    # retries, so confirm). os.cpu_count() may return None.
    num_workers = min(os.cpu_count() or 1, 32)

    # Round-robin split: chunk sizes differ by at most one page. Empty chunks
    # (fewer pages than workers) are not submitted.
    chunks = [wiki_meta_list[i::num_workers] for i in range(num_workers)]
    chunks = [chunk for chunk in chunks if chunk]

    global_kbs = []
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(construct_interleaved_document, chunk) for chunk in chunks]
        for future in as_completed(futures):
            global_kbs.extend(future.result())

    for doc_info in tqdm(global_kbs, total=len(global_kbs), desc="Filling the empty KB."):
        # Defensive: construct_interleaved_document omits failed pages rather
        # than returning None for them.
        if doc_info is None:
            continue

        wiki_url, texts, images, tables = doc_info
        entry = kb_wiki[wiki_url]
        # Update text info
        entry['section_texts'], entry['section_hie_titles'], entry['section_titles'] = texts
        # Update image info
        entry['image_urls'], entry['image_reference_descriptions'], entry['image_section_indices'] = images
        # Update table info
        entry['tables'], entry['table_section_indices'] = tables

    output_path = os.path.join(base_path, 'viquae_wikipedia', 'viquae_kb_wiki_200k.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(kb_wiki, f)

    print('Done!')
