import os
import json
import time
import requests

from tqdm import tqdm
from PIL import Image
from io import BytesIO
from concurrent.futures import ProcessPoolExecutor, as_completed


# Root under which the ``viquae_wikipedia`` directory lives; '' makes every
# path relative to the current working directory.
base_path = ''
# Pixel length that the shorter side of every saved image is resized to.
min_size = 256

# Retry policy used by fetch_image when the server answers HTTP 429:
# at most MAX_RETRIES attempts, backing off in multiples of WAIT_PERIOD seconds.
MAX_RETRIES, WAIT_PERIOD = 5, 5

def fetch_image(wiki_url, timeout=30):
    """Request ``wiki_url``, retrying with linear back-off on HTTP 429.

    The request is streamed (``stream=True``) and sent with a fixed
    User-Agent header. Any answer other than 429 -- including other error
    statuses -- is returned unchanged so the caller can decide what to do.

    Args:
        wiki_url: URL of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as the connect/read
            timeout, so a stalled server cannot block a worker forever.

    Returns:
        The ``requests.Response`` of the first non-429 answer. The caller
        owns it and is responsible for closing it.

    Raises:
        Exception: if all ``MAX_RETRIES`` attempts were answered with 429.
        requests.RequestException: on connection errors or timeouts.
    """
    headers = {'User-Agent': 'CoolBot/0.0 (https://example.org/coolbot/; coolbot@example.org)'}
    for attempt in range(MAX_RETRIES):
        response = requests.get(wiki_url, stream=True, headers=headers, timeout=timeout)
        if response.status_code != 429:
            return response
        # The streamed body is never read; close it so the pooled
        # connection is released before retrying.
        response.close()
        # Linear back-off (WAIT_PERIOD, 2*WAIT_PERIOD, ...), but don't sleep
        # when no further attempt will follow.
        if attempt < MAX_RETRIES - 1:
            time.sleep(WAIT_PERIOD * (attempt + 1))

    raise Exception("Maximum retries occurred")


def _short_side_size(width, height, short_side):
    """Return ``(new_width, new_height)`` whose shorter side equals ``short_side``.

    The aspect ratio is preserved and the longer side is rounded to the
    nearest pixel (at least 1). The short side is set directly instead of
    being multiplied by a float factor, because e.g.
    ``int(49 * (256 / 49))`` evaluates to 255, not 256.
    """
    if width < height:
        return short_side, max(1, round(height * short_side / width))
    return max(1, round(width * short_side / height)), short_side


def download_image(document_image_urls_chunk):
    """Download, resize and save every image URL in one worker's chunk.

    Each URL is fetched with ``fetch_image`` and decoded with Pillow. It is
    then resized so its shorter side is ``min_size`` pixels, keeping the
    aspect ratio (smaller images are upscaled). The result is saved to
    ``<base_path>/viquae_wikipedia/images/<last URL path segment>``, with the
    format inferred from that file name. A failure on one URL is printed and
    recorded; it does not abort the rest of the chunk.

    Args:
        document_image_urls_chunk: List of image URLs to process.

    Returns:
        ``(downloaded_image_urls, failed_image_urls)``: the input URLs split
        by outcome, each list in processing order.
    """
    failed_image_urls = []
    downloaded_image_urls = []
    image_dir = os.path.join(base_path, 'viquae_wikipedia', 'images')

    for doc_image_url in tqdm(document_image_urls_chunk):
        try:
            response = fetch_image(doc_image_url)
            # Closing the streamed response (also on error) releases its
            # pooled connection.
            with response:
                # Raises for 4xx/5xx statuses; a 200 passes through.
                response.raise_for_status()
                # ``content`` applies any Content-Encoding and yields bytes
                # that Pillow can read from a seekable buffer.
                payload = response.content

            with Image.open(BytesIO(payload)) as img:
                resized = img.resize(_short_side_size(*img.size, min_size))

            image_id = doc_image_url.split('/')[-1]
            resized.save(os.path.join(image_dir, image_id))

            downloaded_image_urls.append(doc_image_url)

        except Exception as e:
            # Best effort per URL: log it and move on to the next one.
            print(f"Error occured: {e} with {doc_image_url}")
            failed_image_urls.append(doc_image_url)

    return downloaded_image_urls, failed_image_urls


# Viquae KB does not provide the full images. Here, we download the images for each document.
if __name__ == '__main__':
    kb_dir = os.path.join(base_path, 'viquae_wikipedia')

    # Load document KB (``with`` closes the file once it is parsed).
    with open(os.path.join(kb_dir, 'viquae_kb_wiki_200k.json'), 'r') as f:
        document_kb = json.load(f)

    # Collect every document's image URLs, then drop duplicates (first
    # occurrence wins). Otherwise the same file would be requested twice and
    # could be written concurrently by two worker processes.
    kb_image_urls = []
    for document in tqdm(document_kb.values(), total=len(document_kb)):
        kb_image_urls.extend(document["image_urls"])
    kb_image_urls = list(dict.fromkeys(kb_image_urls))

    os.makedirs(os.path.join(kb_dir, 'images'), exist_ok=True)

    # Download the images and record the failed downloads.
    global_failed_image_urls = []
    global_downloaded_image_urls = []

    # Too many workers send too many requests and trigger HTTP client errors,
    # so the pool is capped at 32. os.cpu_count() may return None.
    num_workers = min(os.cpu_count() or 1, 32)
    # Strided split: chunk sizes differ by at most one. Empty chunks (fewer
    # URLs than workers) are dropped.
    chunks = [kb_image_urls[i::num_workers] for i in range(num_workers)]
    chunks = [chunk for chunk in chunks if chunk]

    if chunks:  # ProcessPoolExecutor rejects max_workers=0.
        with ProcessPoolExecutor(max_workers=len(chunks)) as executor:
            futures = [executor.submit(download_image, chunk) for chunk in chunks]

            for future in as_completed(futures):
                chunk_downloaded_image_urls, chunk_failed_image_urls = future.result()
                global_downloaded_image_urls.extend(chunk_downloaded_image_urls)
                global_failed_image_urls.extend(chunk_failed_image_urls)

    # Map each downloaded URL to the file name it was saved under.
    doc_image_url_to_id = {
        image_url: image_url.split('/')[-1] for image_url in global_downloaded_image_urls
    }

    with open(os.path.join(kb_dir, 'image_url_to_id.json'), 'w') as f:
        json.dump(doc_image_url_to_id, f)

    # Record the final failed images.
    with open(os.path.join(kb_dir, 'omitted_image_urls.json'), 'w') as f:
        json.dump(global_failed_image_urls, f)

    print("Done!")
