import argparse
import concurrent.futures
import hashlib
import json
import logging
import os
import re
from io import BytesIO

import requests
from PIL import Image
from tqdm import tqdm

# Root logger setup: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def get_image_extension(url, response):
    """Map a response's ``Content-Type`` header to an image file extension.

    Only the media type is considered: any parameters after ``;`` (for
    example ``image/jpeg; charset=binary``) are dropped and the comparison
    is case-insensitive.

    :param url: URL the response came from. Unused here; kept so existing
        callers do not have to change.
    :param response: Object exposing a ``headers`` mapping with a ``get``
        method (e.g. a ``requests`` response).
    :return: Extension without the leading dot (``"jpg"``, ``"png"``, ...),
        or ``None`` when the header is missing or is not a supported image
        type.
    """
    mime_to_extension = {
        "image/jpeg": "jpg",
        "image/jpg": "jpg",  # non-standard alias some servers send
        "image/png": "png",
        "image/gif": "gif",
        "image/bmp": "bmp",
        "image/webp": "webp",
        "image/tiff": "tiff",
        "image/x-icon": "ico",
        "image/vnd.microsoft.icon": "ico",  # IANA-registered name for .ico
    }

    content_type = response.headers.get("Content-Type")
    if not content_type:
        return None

    # Strip header parameters and normalise case before the lookup; the raw
    # header value is frequently more than just the bare media type.
    media_type = content_type.split(";", 1)[0].strip().lower()
    return mime_to_extension.get(media_type)


def normalize_name(name):
    """Turn a free-text label into a lowercase, underscore-separated token.

    ``&`` becomes ``and``, apostrophes and commas are removed, surrounding
    whitespace is trimmed and every remaining whitespace run collapses to a
    single ``_``.
    """
    # One pass over the string for all three character-level substitutions.
    substitutions = str.maketrans({"&": "and", "'": None, ",": None})
    cleaned = name.translate(substitutions).strip()
    return re.sub(r"\s+", "_", cleaned).lower()


def download_image(image_obj, image_save_dir):
    """Download one image, convert it to RGB and store it on disk.

    The file is written to
    ``<image_save_dir>/images/<category>/<subcategory>/<image_id>.<ext>``,
    where category and subcategory are passed through ``normalize_name`` and
    the extension comes from the response's ``Content-Type`` header.

    :param image_obj: Dict with at least ``image_url`` and ``image_id``;
        ``category`` / ``subcategory`` default to ``"Unknown"``. On success
        its ``image_path`` key is set to the saved file's path.
    :param image_save_dir: Root directory under which images are stored.
    :return: ``{"url", "name", "status": "success"}`` on success, otherwise
        ``{"url", "status": "failed", "error"}``. Exceptions raised while
        downloading, decoding or saving are caught and reported this way.
    """
    url = image_obj["image_url"]
    image_id = image_obj["image_id"]
    category = normalize_name(image_obj.get("category", "Unknown"))
    subcategory = normalize_name(image_obj.get("subcategory", "Unknown"))

    try:
        # (connect timeout, read timeout) in seconds.
        response = requests.get(url, timeout=(5, 10))
        response.raise_for_status()
        extension = get_image_extension(url, response)
        if extension is None:
            content_type = response.headers.get("Content-Type")
            reason = f"unsupported Content-Type: {content_type!r}"
            logging.warning(f"Skipping {url}: {reason}")
            return {"url": url, "status": "failed", "error": reason}

        # Decode before touching the filesystem so a corrupt payload leaves
        # neither a directory nor a file behind.
        img = Image.open(BytesIO(response.content))
        if img.mode != "RGB":
            img = img.convert("RGB")

        new_directory = os.path.join(image_save_dir, "images", category, subcategory)
        os.makedirs(new_directory, exist_ok=True)
        image_path = os.path.join(new_directory, image_id + "." + extension)

        try:
            img.save(image_path)
        except Exception:
            # A failed save may leave a truncated file that would later look
            # like a valid download; remove it before reporting the error.
            try:
                os.remove(image_path)
            except OSError:
                pass
            raise

        # Record the path only once the file is known to be complete.
        image_obj["image_path"] = image_path
        logging.info(f"Downloaded {url}")
        return {"url": url, "name": image_path, "status": "success"}
    except Exception as e:
        logging.error(f"Failed to download {url}: {e}")
        return {"url": url, "status": "failed", "error": str(e)}


def download_images_in_parallel(image_objects, image_save_dir, max_workers=5):
    """Run ``download_image`` for every object on a thread pool.

    :param image_objects: Iterable of image dicts handed to ``download_image``.
    :param image_save_dir: Root directory forwarded to each download.
    :param max_workers: Size of the thread pool.
    :return: List of per-image result dicts in completion order. A task
        whose future raises is logged and contributes no entry.
    """
    outcomes = []
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
    with pool:
        pending = []
        for obj in image_objects:
            pending.append(pool.submit(download_image, obj, image_save_dir))

        for finished in concurrent.futures.as_completed(pending):
            try:
                outcome = finished.result()
            except Exception as e:
                logging.error(f"Error in downloading {e}")
            else:
                outcomes.append(outcome)

    return outcomes


def generate_hash(text: str, algorithm: str = "sha256") -> str:
    """
    Hash a string and return the digest in hexadecimal form.

    :param text: The string to hash; it is encoded as UTF-8 first.
    :param algorithm: Name of a hashlib algorithm (e.g., 'sha256', 'md5').
    :return: Hex digest of the encoded text.
    """
    hasher = hashlib.new(algorithm)
    payload = text.encode(encoding="utf-8")
    hasher.update(payload)
    digest = hasher.hexdigest()
    return digest


def list_and_json_files(input_dir):
    """Load every ``*.json`` file found directly inside ``input_dir``.

    Files are read as UTF-8 and processed in sorted filename order so that
    repeated runs over the same directory yield the same list.

    :param input_dir: Directory to scan (not recursive).
    :return: List of the decoded JSON documents, one entry per file.
    """
    image_objects = []
    # os.listdir order is arbitrary; sort it so downstream results that
    # depend on processing order are reproducible.
    for json_file in sorted(os.listdir(input_dir)):
        if not json_file.endswith(".json"):
            continue
        json_path = os.path.join(input_dir, json_file)
        if not os.path.isfile(json_path):
            # e.g. a directory that happens to be named "*.json"
            continue
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(json_path, "r", encoding="utf-8") as f:
            image_objects.append(json.load(f))
    return image_objects


def parse_files(image_objects):
    """Flatten query response documents into one record per unique image URL.

    Each input document supplies the query-level fields and a ``responses``
    list; every response that has an ``items`` list contributes its items.
    An item's stripped ``link`` is hashed to form ``image_id``; an id that
    was already seen is reported on stdout and skipped. ``ranked_index``
    counts the records kept so far for the current document.

    :param image_objects: Iterable of decoded query response documents.
    :return: List of flat image record dicts.
    """
    query_fields = (
        "q_id",
        "country",
        "category",
        "subcategory",
        "topic",
        "query",
        "topic_rank",
    )
    seen_ids = set()
    image_content_list = []

    for query_doc in tqdm(image_objects, desc="Parsing files..."):
        kept_so_far = 0
        for response in query_doc["responses"]:
            if "items" not in response:
                continue
            for item in response["items"]:
                image_url = item["link"].strip()
                image_id = generate_hash(image_url)
                if image_id in seen_ids:
                    print(f"duplicate: {image_id}")
                    continue
                seen_ids.add(image_id)

                # Query-level fields first, then the per-image fields, so the
                # key order of the record stays fixed.
                record = {field: query_doc[field] for field in query_fields}
                record["ranked_index"] = kept_so_far
                record["image_id"] = image_id
                record["image_url"] = image_url
                record["image_info"] = item["image"]

                kept_so_far += 1
                image_content_list.append(record)

    print(f"Number of images url: {len(image_content_list)}")
    return image_content_list


def check_chache(image_objects, dir_path):
    """Split image objects into already-on-disk and still-to-download.

    ``dir_path`` is searched recursively, because saved images live in
    nested ``images/<category>/<subcategory>/`` folders below it. A file
    counts as a cache hit when its name without extension equals an
    object's ``image_id``; that object then gets ``image_path`` set to the
    file found.

    :param image_objects: Iterable of dicts, each with an ``image_id`` key.
    :param dir_path: Root directory of previously saved images. A missing
        directory is treated as an empty cache.
    :return: List of the objects for which no file was found.
    """
    image_paths = {}
    # Walk the whole tree: a flat listing of dir_path would never see files
    # stored in the per-category subfolders.
    for root, _dirs, files in os.walk(dir_path):
        for file in files:
            basename = os.path.splitext(file)[0]
            image_paths[basename] = os.path.join(root, file)

    objects_to_process = []
    for img_object in image_objects:
        cached_path = image_paths.get(img_object["image_id"])
        if cached_path is None:
            objects_to_process.append(img_object)
        else:
            img_object["image_path"] = cached_path
    return objects_to_process


def save_to_jsonl(jsonl_save_path, image_objects, failed_downloads):
    """Write one JSON line per image whose file exists and did not fail.

    An item is skipped when it has no ``image_path``, when it matches an
    entry in ``failed_downloads``, or when its file is missing on disk.

    :param jsonl_save_path: Output file path; its parent directory is
        created if the path has one.
    :param image_objects: Iterable of image record dicts.
    :param failed_downloads: Iterable of failures. Each entry is either a
        result dict (matched against the item's ``image_url`` via its
        ``url`` key, or against ``image_path`` via its ``name`` key) or a
        plain image path string.
    """
    # dirname is "" for a bare filename, and os.makedirs("") raises.
    output_dir = os.path.dirname(jsonl_save_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Normalise failures into one lookup set. Result dicts carry the URL
    # (and possibly a path under "name"); bare strings are taken as paths.
    failed_keys = set()
    for failure in failed_downloads or ():
        if isinstance(failure, dict):
            for key in ("url", "name"):
                value = failure.get(key)
                if value:
                    failed_keys.add(value)
        else:
            failed_keys.add(failure)

    i = 0
    # UTF-8 explicitly: ensure_ascii=False emits raw non-ASCII characters.
    with open(jsonl_save_path, "w", encoding="utf-8") as jsonl_file:
        for item in image_objects:
            if "image_path" not in item:
                continue
            image_path = item["image_path"]
            if image_path in failed_keys or item.get("image_url") in failed_keys:
                continue
            if not os.path.exists(image_path):
                logging.info(f"{image_path} does not exist..")
                continue
            jsonl_file.write(json.dumps(item, ensure_ascii=False) + "\n")
            i += 1
    logging.info(f"Number of image downloaded: {i}")
    logging.info(f"Saved results to {jsonl_save_path}")


def main(input_dir, image_save_dir, jsonl_save_path, max_workers):
    """Run the pipeline: load responses, download images, write the manifest.

    :param input_dir: Directory holding the query response JSON files.
    :param image_save_dir: Directory for downloaded images; created if absent.
    :param jsonl_save_path: Path of the JSONL file to write.
    :param max_workers: Thread pool size for the downloads.
    """
    os.makedirs(image_save_dir, exist_ok=True)

    image_objects = parse_files(list_and_json_files(input_dir))
    logging.info(f"Number of image objects {len(image_objects)}")

    # Only objects that check_chache does not find on disk are downloaded.
    results = download_images_in_parallel(
        check_chache(image_objects, image_save_dir),
        image_save_dir,
        max_workers=max_workers,
    )

    failed_downloads = []
    for result in results:
        if result["status"] == "failed":
            failed_downloads.append(result)
    logging.info(f"Number of failed download: {len(failed_downloads)}.")

    save_to_jsonl(jsonl_save_path, image_objects, failed_downloads)


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    cli = argparse.ArgumentParser(description="Download images and save metadata.")

    # (flag, default, help) for the three string-valued path options.
    path_options = (
        ("--input_dir", "data/crawled_images/Country/responses/", "Input directory"),
        ("--image_save_dir", "data/crawled_images/Country/images/", "Image save directory"),
        ("--output_file", "data/output.jsonl", "JSONL save path"),
    )
    for flag, default_value, help_text in path_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    cli.add_argument(
        "--max_workers", type=int, default=5, help="Maximum number of workers"
    )

    options = cli.parse_args()
    main(
        options.input_dir,
        options.image_save_dir,
        options.output_file,
        options.max_workers,
    )
