import argparse
import hashlib
import json
import logging
import os
import re

import openai
import requests
from dotenv import load_dotenv
from tqdm import tqdm

# Root logger setup: INFO and above, each record prefixed with time and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def generate_hash(text: str, algorithm: str = "sha256") -> str:
    """
    Compute the hexadecimal digest of a string.

    The text is encoded as UTF-8 before it is fed to the hash function.

    :param text: The string to digest.
    :param algorithm: Name of the hashlib algorithm (e.g., 'sha256', 'md5').
    :return: The digest rendered as a hexadecimal string.
    """
    hasher = hashlib.new(algorithm)
    payload = text.encode("utf-8")
    hasher.update(payload)
    digest = hasher.hexdigest()
    return digest


def generate(api_url, headers, developer_prompt, user_prompt, timeout=120):
    """
    Send one chat-completion request and return the raw HTTP response.

    :param api_url: Full URL of the chat-completions endpoint.
    :param headers: HTTP headers to send (carries the API key).
    :param developer_prompt: Text sent as the "system" message.
    :param user_prompt: Text sent as the single "user" message.
    :param timeout: Seconds to wait for the server before giving up. Without
        a timeout ``requests`` waits indefinitely, so one stalled connection
        would block the whole batch.
    :return: The ``requests.Response`` object; the status code is not checked
        here, so callers must inspect the response themselves.
    :raises requests.exceptions.RequestException: On connection failures or
        when the timeout elapses.
    """
    payload = {
        "messages": [
            {"role": "system", "content": developer_prompt},
            {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
        ],
        "temperature": 0.5,
        "top_p": 0.95,
        "max_tokens": 2000,
        # Callers parse the message content with json.loads, so request a
        # JSON-object reply from the service.
        "response_format": {"type": "json_object"},
    }
    return requests.post(api_url, headers=headers, json=payload, timeout=timeout)


def generate_queries(api_url, headers, country, category, subcategory, topic):
    """
    Request image-search queries for one country/category/subcategory/topic set.

    Builds a fixed system prompt plus a user prompt filled in with the given
    values, sends both through ``generate`` and returns the decoded JSON body
    of the HTTP response.

    :param api_url: Full URL of the chat-completions endpoint.
    :param headers: HTTP headers forwarded to ``generate``.
    :param country: Country name interpolated into the user prompt.
    :param category: Category name interpolated into the user prompt.
    :param subcategory: Subcategory name interpolated into the user prompt.
    :param topic: Topic text interpolated into the user prompt.
    :return: The response body parsed from JSON.
    """
    # Fixed instructions for the model; sent as the system message.
    instructions = """
    You are an expert at generating **highly relevant, human-like image search queries** optimized for **Google Image Search**.
    
    Your task is to generate 50 unique search queries based on a given **country**, **category**, **subcategory** and **topics**. 
    These queries should accurately reflect **natural human behavior** and how real people from the Arab world search for images, using both Modern Standard Arabic (MSA) and country-specific dialects where appropriate. 

    Follow these guidelines to generate the queries:

    1. Reflect natural human behavior:
    - Include typos, slang, informal expressions, and incomplete or autocomplete-style phrases
    - Use a mix of capitalization and punctuation styles

    2. Incorporate visual descriptors:
    - Use terms like "HD," "4K," "wallpaper," "real photo," "aesthetic," "close-up," "latest pics," etc.

    3. Mimic real-world search styles:
    - Pure keywords
    - Questions (e.g., "what do [topic] look like")
    - Autocomplete-like fragments (e.g., "best pics of…")
    - Trending styles (e.g., "free download")

    4. Include **localized and culturally relevant elements** from the specified country:
    - Use dialects, slang, and spelling variations specific to the country
    - Reference famous cities, landmarks, or cultural symbols
    - Incorporate country-specific visual cues, aesthetics, or references

    Ensure that the queries have the following characteristics:
    - **Short, human-like, and natural-sounding** (2-5 words on average)
    - **Highly visual** and suitable for image search intent
    - Focused on the **topics**, not just the country, category or subcategory and should reflect realistic search behavior of Arabic-speaking users.
    - Avoid any formal tone
    - Do not include any explanations or additional text outside of the JSON structure. 
    
    - Always returned in **strict JSON format**:
    ```json
    {
    "queries": [
        "query 1",
        "query 2",
        "... up to query 50"
    ]
    }
    ```
    """

    # Per-call request; sent as the user message.
    request_text = f"""
    Generate **50** unique, human-like image search queries** based on below information:

    - Country: {country}  
    - Category: {category}  
    - Subcategory: {subcategory}  
    - Topics: {topic}  
    """

    return generate(api_url, headers, instructions, request_text).json()


def read_country_list(file_path):
    """
    Read a UTF-8 text file and return its non-blank lines.

    :param file_path: Path of the text file to read.
    :return: List of lines with surrounding whitespace removed; lines that
        are empty after stripping are dropped.
    """
    entries = []
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            entry = raw_line.strip()
            if entry:
                entries.append(entry)
    return entries


def read_category_list(file_path):
    """
    Load a UTF-8 encoded JSON file.

    :param file_path: Path of the JSON file to read.
    :return: The decoded JSON value, exactly as ``json.load`` produces it.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def load_env(env_path):
    """
    Load Azure OpenAI settings from a .env file and build the request target.

    The file is loaded with ``override=True``, then four variables are read
    from the environment: ``AZURE_ENGINE_NAME``, ``AZURE_API_URL``,
    ``AZURE_API_KEY`` and ``AZURE_API_VERSION``.

    :param env_path: Path of the .env file to load.
    :return: Tuple ``(api_url, headers)`` where ``api_url`` is the
        chat-completions URL of the deployment and ``headers`` holds the
        API key.
    :raises KeyError: If any of the four variables is not set.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    deployment_name = os.environ["AZURE_ENGINE_NAME"]
    # Endpoints are often configured with a trailing slash; strip it so the
    # joined URL does not contain "//openai/...".
    openai_api_base = os.environ["AZURE_API_URL"].rstrip("/")
    openai_api_key = os.environ["AZURE_API_KEY"]
    openai_api_version = os.environ["AZURE_API_VERSION"]
    api_url = f"{openai_api_base}/openai/deployments/{deployment_name}/chat/completions?api-version={openai_api_version}"
    headers = {"api-key": openai_api_key}
    return api_url, headers


def save_responses(response, file_path):
    """
    Serialize an object to a UTF-8 JSON file, replacing any existing file.

    Non-ASCII characters are written as-is and the output is indented by
    four spaces.

    :param response: JSON-serializable object to store.
    :param file_path: Destination path.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(file_path, mode="w", encoding="utf-8") as out_file:
        json.dump(response, out_file, **dump_options)


def transform_string(s):
    """
    Turn a display name into a lowercase, underscore-separated token.

    Commas and ampersands are removed, surrounding whitespace is trimmed and
    every remaining run of whitespace becomes a single underscore.

    :param s: The text to normalize.
    :return: The normalized string.
    """
    without_punctuation = re.sub(r"[,&]", "", s.lower())
    trimmed = without_punctuation.strip()
    return re.sub(r"\s+", "_", trimmed)


def process_data(api_url, headers, data, cached_dir, output_file):
    """
    Generate queries for every entry in ``data`` and write them to one file.

    For each entry the raw API response is looked up in ``cached_dir`` (file
    name built from country, category, subcategory and a hash of the topics);
    on a miss the API is called through ``generate_queries``. Only responses
    that contain ``"choices"`` are cached, so a failed call (for example an
    error payload from the service) is retried on the next run instead of
    being replayed from disk forever.

    :param api_url: Full URL of the chat-completions endpoint.
    :param headers: HTTP headers for the API call.
    :param data: Iterable of dicts with keys ``country``, ``category``,
        ``subcategory`` and, optionally, ``topics`` (a list of strings).
    :param cached_dir: Existing directory holding per-entry response caches.
    :param output_file: Path of the aggregated JSON output.
    :raises KeyError: If an entry lacks ``country``, ``category`` or
        ``subcategory``.
    """
    hierarchical_data = []

    for data_obj in tqdm(data, desc="Processing categories", leave=False):
        country = data_obj["country"]
        category_name = data_obj["category"]
        subcategory = data_obj["subcategory"]

        if "topics" not in data_obj:
            logging.warning(
                f"No topics found for {country} -> {category_name} -> {subcategory}"
            )
            continue
        topic_list = data_obj["topics"]
        logging.info(f"Processing country: {country}, category: {category_name}")

        # One bad entry must not abort the batch, so every failure below is
        # logged and the loop moves on to the next entry.
        try:
            category_str = transform_string(category_name)
            subtopic_str = transform_string(subcategory)

            # The topic list can be long, so its hash is used in the file name.
            topic_text = "\n".join(topic_list)
            topic_str = generate_hash(topic_text)

            file_path = os.path.join(
                cached_dir,
                f"{country}_{category_str}_{subtopic_str}_{topic_str}.json",
            )

            response = None
            if os.path.exists(file_path):
                with open(file_path, "r", encoding="utf-8") as file:
                    response = json.load(file)
                if "choices" not in response:
                    # Cache file written from a failed call; ignore it and
                    # query the API again.
                    response = None

            if response is None:
                response = generate_queries(
                    api_url,
                    headers,
                    country,
                    category_name,
                    subcategory,
                    topic_text,
                )
                if "choices" not in response:
                    # Do not cache failures, otherwise they would be reused
                    # on every later run.
                    logging.warning(
                        f"No choices found in response for {country} -> {category_name} -> {subcategory}"
                    )
                    continue

                response["country"] = country
                response["category"] = category_name
                response["subcategory"] = subcategory
                response["topic"] = topic_list
                save_responses(response, file_path)

            queries = json.loads(response["choices"][0]["message"]["content"])
            hierarchical_data.append(
                {
                    "country": country,
                    "category": category_name,
                    "subcategory": subcategory,
                    "topic": topic_list,
                    "queries": queries["queries"],
                }
            )

        except Exception as e:
            logging.error(
                f"Error processing {country} -> {category_name} -> {subcategory}: {e}"
            )
            continue

    # Save the hierarchical JSON output
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(hierarchical_data, file, ensure_ascii=False, indent=4)
    logging.info(f"Data saved to {output_file}")


def main():
    """
    Command-line entry point.

    Parses the arguments, loads the topic file and the .env settings, makes
    sure the cache and output directories exist, then runs ``process_data``.
    If the .env file is missing, an error is logged and the function returns
    without processing anything.
    """
    parser = argparse.ArgumentParser(
        description="Generate search queries for Google Image Search"
    )
    # The three paths below have no usable default, so they are required:
    # argparse then reports a missing option up front instead of the script
    # failing later (an empty output path used to fail only after all API
    # calls had been made).
    parser.add_argument(
        "-c",
        "--topic_file",
        type=str,
        required=True,
        help="Path to the category file",
    )
    parser.add_argument(
        "-d",
        "--cached_dir",
        type=str,
        default="data/queries/cache_queries/",
        help="Directory to cache responses",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        required=True,
        help="Path to save the output JSON file",
    )
    parser.add_argument(
        "-e",
        "--env_file",
        type=str,
        required=True,
        help="Path to the .env file",
    )

    args = parser.parse_args()
    topic_data = read_category_list(args.topic_file)

    env_path = args.env_file
    if not os.path.exists(env_path):
        logging.error(f"Error: {env_path} not found!")
        return

    cached_dir = args.cached_dir
    os.makedirs(cached_dir, exist_ok=True)

    # Create the output directory now so the final write cannot fail on a
    # missing folder after the whole batch has run.
    output_file = args.output_file
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    api_url, headers = load_env(env_path)

    logging.info(f"Processing topics {len(topic_data)}...")

    process_data(api_url, headers, topic_data, cached_dir, output_file)

    logging.info("Queries generated successfully!")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
