# This script was hastily generated by Gemini-2.5 Pro with some edits. What a wonderful time to be alive.

import requests
import time
import os
import json
import re
import io
import torch
import clip
from PIL import Image, UnidentifiedImageError

# Directory (relative to the working directory) that receives the final PNGs.
DOWNLOAD_DIR = "../data/factual_recall/images"

# Pause, in seconds, inserted between consecutive Wikimedia API requests.
REQUEST_DELAY = 1.0

# Upper bound on how many filtered candidates per entity are scored with CLIP.
MAX_CANDIDATES_TO_SCORE = 5

# Largest accepted long-side / short-side ratio: at 1.5, neither dimension may
# exceed 1.5x the other.
MAX_ASPECT_RATIO = 1.5

# Edge length in pixels of the final square image. The shorter side is scaled
# to this size and a centered square of this size is then cropped out.
RESIZE_TARGET = 256

# License names accepted by the license filter. Compare against the exact
# short names used on Wikimedia Commons before extending this list.
ACCEPTABLE_LICENSES = [
    # Creative Commons Attribution-ShareAlike (versions 1.0 to 4.0)
    "CC BY-SA",
    # Creative Commons Attribution (versions 1.0 to 4.0)
    "CC BY",
    # Creative Commons Zero (public domain dedication)
    "CC0",
    # Public domain markers (PD-old, PD-USGov, ...) - be cautious, verify specifics
    "PD",
    "Public domain",
]

# Factual-recall question templates, each mapped to a list of entries whose
# first element is the entity name. The templates were adopted from
# https://arxiv.org/abs/2308.09124, and the entity lists were taken from their
# dataset.
# The context manager closes the file handle right after parsing; JSON is read
# as UTF-8 so non-ASCII entity names do not depend on the platform's locale.
with open("../data/factual_recall/qa_raw.json", "r", encoding="utf-8") as qa_file:
    entity_qa_json = json.load(qa_file)

# Flatten every template's entries into one list of entity names. A nested
# comprehension avoids the quadratic list copying of ``sum(lists, [])``.
entities = [
    entity_answer[0]
    for entity_answers in entity_qa_json.values()
    for entity_answer in entity_answers
]


def sanitize_filename(name):
    """Turn an arbitrary entity name into a safe file-name stem.

    Characters that are invalid in file names are dropped, spaces become
    underscores (runs of underscores are collapsed to one), leading and
    trailing dots/underscores are stripped, and the result is cut to at most
    100 characters. If nothing is left, ``"unnamed_entity"`` is returned.
    """
    cleaned = re.sub(r'[\\/*?:"<>|]', "", name).replace(" ", "_")
    cleaned = re.sub(r"_+", "_", cleaned).strip("._")
    # Slicing is a no-op for short names; an empty result falls back to the
    # placeholder name.
    return cleaned[:100] or "unnamed_entity"


def process_and_save_image(image_url, entity_name, download_dir):
    """Download an image, reduce it to a centered square, and save it as PNG.

    The image is scaled so that its shorter side is exactly ``RESIZE_TARGET``
    pixels, a centered ``RESIZE_TARGET`` x ``RESIZE_TARGET`` square is cropped
    from the result, and the crop is written to
    ``<download_dir>/<sanitized entity_name>.png``.

    Args:
        image_url: URL of the image to fetch. A falsy value returns ``None``
            without any network access.
        entity_name: Entity the image belongs to; after sanitizing it becomes
            the output file name.
        download_dir: Directory the PNG is written into (not created here).

    Returns:
        The path of the saved PNG, or ``None`` if downloading, decoding,
        resizing, cropping, or saving failed. Failures are printed, never
        raised.
    """
    if not image_url:
        return None

    try:
        print(f"   Processing final selected image: {image_url}")
        headers = {"User-Agent": "MyImageDownloaderBot/1.0"}
        # The full body is needed for decoding, so a streamed response would
        # gain nothing.
        response = requests.get(image_url, headers=headers, timeout=20)
        response.raise_for_status()

        try:
            img = Image.open(io.BytesIO(response.content))
        except UnidentifiedImageError:
            print("     Error: Downloaded content is not a recognizable image format.")
            return None
        except Exception as e_open:
            print(f"     Error opening image with Pillow: {e_open}")
            return None

        # Palette, grayscale, CMYK, ... are normalized to RGBA before resizing.
        if img.mode not in ["RGB", "RGBA"]:
            img = img.convert("RGBA")

        width, height = img.size
        if width == 0 or height == 0:
            print("     Error: Image has zero width or height.")
            return None

        # Scale based on the MINIMUM dimension. Rounding (instead of
        # truncating) plus the lower clamp guarantees that both sides are at
        # least RESIZE_TARGET; with int() a float error could yield
        # RESIZE_TARGET - 1 and the crop box would extend past the image.
        scale = RESIZE_TARGET / min(width, height)
        new_width = max(RESIZE_TARGET, round(width * scale))
        new_height = max(RESIZE_TARGET, round(height * scale))

        print(
            f"     Resizing from {width}x{height} to {new_width}x{new_height} (based on min dim)"
        )
        try:
            resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
        except Exception as e_resize:
            print(f"     Error during Pillow resize: {e_resize}")
            return None

        # --- Crop Center Square ---
        left = (new_width - RESIZE_TARGET) // 2
        top = (new_height - RESIZE_TARGET) // 2
        crop_box = (left, top, left + RESIZE_TARGET, top + RESIZE_TARGET)
        print(f"     Cropping center {RESIZE_TARGET}x{RESIZE_TARGET} square")
        try:
            cropped_img = resized_img.crop(crop_box)
        except Exception as e_crop:
            print(f"     Error during Pillow crop: {e_crop}")
            return None

        # --- Save as PNG (one file per entity) ---
        filepath = os.path.join(download_dir, f"{sanitize_filename(entity_name)}.png")

        print(f"     Saving final image to: {filepath}")
        try:
            cropped_img.save(filepath, format="PNG")
            return filepath
        except IOError as e_save:
            print(f"     Error saving file {filepath}: {e_save}")
            return None
        except Exception as e_save_other:
            print(
                f"     An unexpected error occurred saving {filepath}: {e_save_other}"
            )
            return None

    # Download failures and anything unexpected (e.g. a decode error raised
    # lazily by convert()) are reported and turned into a None result.
    except requests.exceptions.Timeout:
        print(f"     Error: Timeout downloading {image_url}")
        return None
    except requests.exceptions.RequestException as e_req:
        print(f"     Error downloading {image_url}: {e_req}")
        return None
    except Exception as e_general:
        print(
            f"     An unexpected error occurred processing image from {image_url}: {e_general}"
        )
        return None


# --- Wikimedia API Interaction ---
# API endpoint on Wikimedia Commons that the search and image-info helpers
# below send their "query" requests to.
WIKIMEDIA_API_URL = "https://commons.wikimedia.org/w/api.php"


def search_wikimedia_images(query):
    """Return the titles of Wikimedia Commons files matching *query*.

    Runs a ``list=search`` query limited to namespace 6 and at most 25 hits.
    Request failures are printed and produce an empty list, as does a
    response without search results.
    """
    request_headers = {"User-Agent": "MyImageDownloaderBot/1.0"}
    search_params = dict(
        action="query",
        format="json",
        list="search",
        srsearch=query,
        srnamespace=6,
        srlimit=25,  # Generous limit so the later filters have enough candidates
    )

    try:
        reply = requests.get(
            WIKIMEDIA_API_URL,
            params=search_params,
            headers=request_headers,
            timeout=10,
        )
        reply.raise_for_status()
        payload = reply.json()
    except requests.exceptions.RequestException as e:
        print(f" - Error searching Wikimedia for '{query}': {e}")
        return []

    if "query" not in payload or "search" not in payload["query"]:
        return []
    return [hit["title"] for hit in payload["query"]["search"]]


def get_wikimedia_image_info(file_titles):
    """Fetch URL, extended metadata, and size for the given file titles.

    All titles are sent in a single ``prop=imageinfo`` request, joined with
    ``|``. Returns the ``query.pages`` mapping of the response, or an empty
    dict when no titles were given, the request failed (the error is
    printed), or the response lacks that mapping.
    """
    if not file_titles:
        return {}

    # NOTE: the original author flagged this User-Agent as a placeholder to change.
    request_headers = {"User-Agent": "MyImageDownloaderBot/1.0"}
    info_params = dict(
        action="query",
        format="json",
        prop="imageinfo",
        iiprop="url|extmetadata|size",
        titles="|".join(file_titles),
    )

    try:
        reply = requests.get(
            WIKIMEDIA_API_URL,
            params=info_params,
            headers=request_headers,
            timeout=15,
        )
        reply.raise_for_status()
        payload = reply.json()
    except requests.exceptions.RequestException as e:
        print(f" - Error getting image info from Wikimedia: {e}")
        return {}

    return payload.get("query", {}).get("pages", {})


def check_aspect_ratio(image_page_data):
    """Decide whether an image is close enough to square.

    Reads ``width`` and ``height`` from the first ``imageinfo`` entry of the
    page data and compares long side / short side against
    ``MAX_ASPECT_RATIO``.

    Returns:
        ``(is_ok, reason)``: a bool plus a human-readable explanation.
        Missing image info or non-positive / non-integer dimensions yield
        ``False``.
    """
    info_entries = image_page_data.get("imageinfo")
    if not info_entries:
        return False, "No image info for aspect ratio check"

    first_entry = info_entries[0]
    width, height = first_entry.get("width"), first_entry.get("height")

    dimensions_valid = (
        isinstance(width, int)
        and isinstance(height, int)
        and width > 0
        and height > 0
    )
    if not dimensions_valid:
        return False, f"Invalid dimensions ({width}x{height})"

    # Taking the larger of both quotients makes the ratio orientation-free.
    ratio = max(width / height, height / width)
    is_ok = 1.0 <= ratio <= MAX_ASPECT_RATIO

    range_note = "" if is_ok else f" - outside allowed range [1.0, {MAX_ASPECT_RATIO}]"
    reason = f"Aspect ratio {ratio:.2f} (WxH: {width}x{height}){range_note}"
    print(f"   Aspect ratio check: {is_ok} ({reason})")
    return is_ok, reason


# Markers of Creative Commons variants that restrict reuse (NonCommercial /
# NoDerivatives). The abbreviations are matched case-sensitively as whole
# words; the spelled-out forms are matched case-insensitively.
_NON_FREE_LICENSE_RE = re.compile(
    r"\b(?:NC|ND)\b|(?i:non[- ]?commercial|no[- ]?deriv)"
)


def check_license(image_page_data):
    """Check whether the license in an image's metadata is acceptable.

    The ``LicenseShortName`` value is tested first: it must start with one of
    ``ACCEPTABLE_LICENSES``, or mention "public domain" while a public-domain
    entry ("PD" / "Public domain") is in that list. If the short name does
    not qualify, the ``UsageTerms`` value is tested the same way, except that
    an acceptable name may appear anywhere in the text.

    A text that carries a NonCommercial or NoDerivatives marker is never
    accepted. Without this guard, prefix matching on "CC BY" would also admit
    "CC BY-NC ..." and "CC BY-ND ...". This is deliberately conservative: a
    text listing both a free and a restricted license is rejected.

    Returns:
        ``(is_acceptable, reason)``: a bool plus a human-readable
        explanation. Missing image info yields ``False``.
    """
    if "imageinfo" not in image_page_data or not image_page_data["imageinfo"]:
        return False, "No image info for license check"

    metadata = image_page_data["imageinfo"][0].get("extmetadata", {})
    license_short = metadata.get("LicenseShortName", {}).get("value", "").strip()
    usage_terms = metadata.get("UsageTerms", {}).get("value", "").strip()

    # Does not depend on the text being checked, so evaluate it once.
    public_domain_allowed = (
        "PD" in ACCEPTABLE_LICENSES or "Public domain" in ACCEPTABLE_LICENSES
    )

    def grants_free_use(text, names_acceptable_license):
        """Combine the allow-list match with the NC/ND and public-domain rules."""
        if _NON_FREE_LICENSE_RE.search(text):
            return False
        if names_acceptable_license:
            return True
        return public_domain_allowed and "public domain" in text.lower()

    found_license_str = ""
    is_acceptable = False

    if license_short:
        found_license_str = license_short
        is_acceptable = grants_free_use(
            license_short,
            license_short.startswith(tuple(ACCEPTABLE_LICENSES)),
        )

    if not is_acceptable and usage_terms:
        # Usage terms can be long; keep the logged reason readable.
        found_license_str = usage_terms[:100] + (
            "..." if len(usage_terms) > 100 else ""
        )
        is_acceptable = grants_free_use(
            usage_terms,
            any(acceptable in usage_terms for acceptable in ACCEPTABLE_LICENSES),
        )

    reason = f"License found: '{found_license_str}'. Acceptable: {is_acceptable}"
    if not is_acceptable:
        reason += f". Allowed: {ACCEPTABLE_LICENSES}"
    print(f"   License check: {reason}")
    return is_acceptable, reason


# --- CLIP Scoring Function ---
def calculate_clip_score(image_data, text_prompt, clip_model, clip_preprocess, device):
    """Score how well an image (raw bytes) matches a text prompt using CLIP.

    Both embeddings are normalized to unit length, so their product is the
    cosine similarity; it is multiplied by 100 as in the usual CLIP examples.

    Returns:
        The similarity as a Python number, or ``-1.0`` if the bytes cannot be
        decoded as an image or any other step fails (the error is printed).
    """
    try:
        pil_image = Image.open(io.BytesIO(image_data))
        # Model inputs live on the same device (CPU or GPU) as the model.
        image_batch = clip_preprocess(pil_image).unsqueeze(0).to(device)
        text_tokens = clip.tokenize([text_prompt]).to(device)

        # Inference only: gradient tracking is not needed.
        with torch.no_grad():
            image_embedding = clip_model.encode_image(image_batch)
            text_embedding = clip_model.encode_text(text_tokens)

            image_embedding = image_embedding / image_embedding.norm(
                dim=-1, keepdim=True
            )
            text_embedding = text_embedding / text_embedding.norm(
                dim=-1, keepdim=True
            )

            scaled_similarity = (100.0 * image_embedding) @ text_embedding.T

        # .item() converts the single-element tensor into a Python number.
        return scaled_similarity.item()
    except UnidentifiedImageError:
        print("     CLIP Error: Cannot identify image format.")
        return -1.0  # Sentinel for "no usable score"
    except Exception as e:
        print(f"     CLIP Error processing image: {e}")
        return -1.0  # Sentinel for "no usable score"


# --- Main Execution ---
def _filter_candidates(file_titles, all_image_pages):
    """Return up to MAX_CANDIDATES_TO_SCORE files passing the aspect-ratio and license checks."""
    # Index the pages by title once instead of rescanning every page for
    # every title; setdefault keeps the first page seen for a title.
    pages_by_title = {}
    for page in all_image_pages.values():
        pages_by_title.setdefault(page.get("title"), page)

    candidates = []  # [{"title": ..., "imageinfo": <first imageinfo entry>}, ...]
    seen_titles = set()
    for title in file_titles:
        if len(candidates) >= MAX_CANDIDATES_TO_SCORE:
            print(
                f"   Reached candidate limit ({MAX_CANDIDATES_TO_SCORE}), stopping filter."
            )
            break

        if title in seen_titles:
            continue
        seen_titles.add(title)

        page_data = pages_by_title.get(title)
        if not page_data:
            continue

        print(f"\n   Checking candidate: {title}")
        aspect_ok, _ = check_aspect_ratio(page_data)
        if not aspect_ok:
            continue

        license_ok, _ = check_license(page_data)
        if not license_ok:
            continue

        # Both checks only succeed when "imageinfo" is a non-empty list.
        print(f"   Candidate accepted: {title}")
        candidates.append({"title": title, "imageinfo": page_data["imageinfo"][0]})

    return candidates


def _score_candidates(candidate_images, entity_name, clip_model, clip_preprocess, device):
    """Download each candidate and return (score, candidate) pairs, best score first."""
    print(f"\n   Scoring {len(candidate_images)} candidates with CLIP...")
    # Simple prompt - might need adjustment for better results
    text_prompt = f"A photo of {entity_name}"
    print(f"   Using text prompt: '{text_prompt}'")

    headers = {"User-Agent": "MyImageDownloaderBot/1.0"}
    scored = []
    for candidate in candidate_images:
        img_url = candidate["imageinfo"].get("url")
        if not img_url:
            continue

        print(f"     Scoring image: {candidate['title']} ({img_url})")
        try:
            response = requests.get(img_url, headers=headers, timeout=15)
            response.raise_for_status()

            score = calculate_clip_score(
                response.content, text_prompt, clip_model, clip_preprocess, device
            )
            if score > -1.0:  # -1.0 is calculate_clip_score's failure sentinel
                print(f"       CLIP Score: {score:.2f}")
                scored.append((score, candidate))
            else:
                print("       CLIP score calculation failed.")
        except requests.exceptions.RequestException as e_score_dl:
            print(f"       Error downloading image for scoring: {e_score_dl}")
        except Exception as e_score_gen:
            print(f"       Unexpected error during scoring setup: {e_score_gen}")
        time.sleep(0.1)  # Small delay between scoring requests/processing

    scored.sort(key=lambda item: item[0], reverse=True)
    return scored


def main():
    """Find, rank, and save one Wikimedia Commons image per entity.

    For every distinct entity without an existing PNG in ``DOWNLOAD_DIR``:
    search Commons, fetch image info, keep candidates that pass the
    aspect-ratio and license checks, rank them with CLIP (or keep search
    order if CLIP could not be loaded), and save the best candidate that can
    be processed. A summary of saved and failed entities is printed at the
    end.
    """
    os.makedirs(DOWNLOAD_DIR, exist_ok=True)
    print(f"Images will be saved to: {DOWNLOAD_DIR}")
    print(
        f"Image Requirements: Aspect Ratio <= {MAX_ASPECT_RATIO}, Free License ({ACCEPTABLE_LICENSES})."
    )
    print(
        f"Selection: Top 1 image per entity based on CLIP score (checking up to {MAX_CANDIDATES_TO_SCORE} candidates)."
    )
    print(
        f"Final Image: Processed to {RESIZE_TARGET}x{RESIZE_TARGET} PNG (scaled by min dimension)."
    )

    # --- Load CLIP Model ---
    print("\nLoading CLIP model...")
    # Bind all three names up front so they exist even if loading fails.
    clip_model = clip_preprocess = device = None
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
        clip_model.eval()  # Set model to evaluation mode
        print(f"CLIP model loaded successfully on device: {device}")
    except Exception as e:
        print(f"Error loading CLIP model: {e}")
        print("CLIP scoring will be skipped.")
        clip_model = None  # Disable CLIP if loading fails

    entity_image_paths = {}  # entity name -> saved path, or None on failure
    skipped_existing = 0

    # The same entity can appear under several question templates;
    # dict.fromkeys drops the repeats while keeping first-seen order.
    for entity_name in dict.fromkeys(entities):
        if os.path.exists(
            os.path.join(DOWNLOAD_DIR, sanitize_filename(entity_name) + ".png")
        ):
            print(f"   Image already exists for '{entity_name}'. Skipping.")
            skipped_existing += 1
            continue

        print(f"\n{'='*10} Processing Entity: '{entity_name}' {'='*10}")
        entity_image_paths[entity_name] = None

        # 1. Search Wikimedia
        file_titles = search_wikimedia_images(entity_name)
        if not file_titles:
            print("   No files found in initial search.")
            time.sleep(REQUEST_DELAY)
            continue
        print(f"   Found {len(file_titles)} potential files in search.")
        time.sleep(REQUEST_DELAY)

        # 2. Get image info, at most chunk_size titles per request
        chunk_size = 50
        all_image_pages = {}
        for start in range(0, len(file_titles), chunk_size):
            chunk = file_titles[start : start + chunk_size]
            print(
                f"   Getting info for files {start+1} to {min(start+chunk_size, len(file_titles))}..."
            )
            image_pages_chunk = get_wikimedia_image_info(chunk)
            if image_pages_chunk:
                all_image_pages.update(image_pages_chunk)
            time.sleep(REQUEST_DELAY)

        if not all_image_pages:
            print("   Could not retrieve image info for found files.")
            continue

        # 3. Filter Candidates (Aspect Ratio & License)
        print("\n   Filtering candidates by aspect ratio and license...")
        candidate_images = _filter_candidates(file_titles, all_image_pages)

        # 4. Rank candidates: by CLIP score when available, else search order
        if not candidate_images:
            print("\n   No suitable candidates found after filtering.")
            scored_candidates = []
        elif clip_model is not None:
            scored_candidates = _score_candidates(
                candidate_images, entity_name, clip_model, clip_preprocess, device
            )
        else:
            print("\n   CLIP model not loaded. Skipping scoring.")
            # Arbitrary equal score; search order decides.
            scored_candidates = [(0.0, candidate) for candidate in candidate_images]

        # 5. Save the best candidate, falling back to the next-ranked one
        #    when downloading or processing a candidate fails.
        if scored_candidates:
            final_filepath = None
            for rank, (score, candidate) in enumerate(scored_candidates):
                if rank == 0:
                    print(
                        f"\n   Selected best candidate: {candidate['title']} (Score: {score:.2f})"
                    )
                else:
                    time.sleep(REQUEST_DELAY)
                    print(
                        f"\n   Falling back to next candidate: {candidate['title']} (Score: {score:.2f})"
                    )

                # .get: a candidate without a URL is treated as a failed
                # attempt instead of raising KeyError.
                final_filepath = process_and_save_image(
                    candidate["imageinfo"].get("url"), entity_name, DOWNLOAD_DIR
                )
                if final_filepath:
                    break
                if rank == 0:
                    print(
                        f"   Processing failed for the selected best image of {entity_name}."
                    )
                else:
                    print(
                        f"   Processing failed for fallback image of {entity_name}."
                    )
            entity_image_paths[entity_name] = final_filepath
        else:
            print(f"\n   Could not select a final image for '{entity_name}'.")

        # Delay before processing the next entity
        time.sleep(REQUEST_DELAY)

    print("\n--- CLIP-Enhanced Wikimedia Search and Download Complete ---")
    # --- Final Summary ---
    successful_saves = 0
    failed_entities_list = []
    print("\nSummary:")
    for entity, path in entity_image_paths.items():
        if path:
            print(f"- {entity}: Saved to {path}")
            successful_saves += 1
        else:
            print(f"- {entity}: Failed to save final image.")
            failed_entities_list.append(entity)

    # Report against the entities attempted in this run; entities skipped
    # because their image already existed are listed separately.
    print(
        f"\nSuccessfully saved final images for {successful_saves} out of "
        f"{len(entity_image_paths)} attempted entities "
        f"({skipped_existing} skipped because an image already existed)."
    )
    if failed_entities_list:
        print(f"Failed entities: {', '.join(failed_entities_list)}")


if __name__ == "__main__":
    main()
