import argparse
import json
import os
import re

import openai
import requests
from dotenv import load_dotenv
from tqdm import tqdm


def generate(api_url, headers, developer_prompt, user_prompt, timeout=120):
    """Send one chat-completion request and return the raw HTTP response.

    Args:
        api_url: Full chat-completions URL (as built by ``load_env``).
        headers: HTTP headers to send, e.g. the ``api-key`` header.
        developer_prompt: Text sent as the ``system`` message.
        user_prompt: Text sent as the single text part of the ``user`` message.
        timeout: Seconds passed to ``requests.post``. Without a timeout a
            stalled connection would block the whole batch run indefinitely.

    Returns:
        The ``requests.Response``; the HTTP status is not checked here.

    Raises:
        requests.exceptions.RequestException: on connection errors or timeout.
    """
    json_data = {
        "messages": [
            {"role": "system", "content": developer_prompt},
            {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
        ],
        "temperature": 0.5,
        "top_p": 0.95,
        "max_tokens": 2000,
        # Ask the model to reply with a JSON object so the content can be
        # parsed with json.loads downstream.
        "response_format": {"type": "json_object"},
    }
    return requests.post(api_url, headers=headers, json=json_data, timeout=timeout)


def generate_queries(api_url, headers, country, category):
    """Request 20 Arabizi image-search queries for one country/category pair.

    Args:
        api_url: Chat-completions URL passed through to ``generate``.
        headers: HTTP headers passed through to ``generate``.
        country: Country name interpolated into the user prompt.
        category: Topic name interpolated into the user prompt.

    Returns:
        The decoded JSON body of the chat-completions response.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status, so
            error payloads are reported instead of being treated as results.
        ValueError: if the response body is not valid JSON.
    """
    # Not an f-string: the {category}/{country} markers are literal
    # placeholders shown to the model as generic examples.
    developer_prompt = """
    Generate highly plausible, country-specific **image search queries in Romanized Arabic (Arabizi)** that mimic real human behavior when looking for images online.
    Ensure queries:
    - Use **common Arabizi transliterations** (e.g., "3arabi", "7elwa", "tasawir", "sowar", "b7r", "9ora").
    - Contain **natural typos, slang, and informal phrasing**.
    - Reflect **phonetic spelling variations** (e.g., "soar" vs. "sowar" for "صور", "ta9awir" for "تصاوير").
    - Use **descriptive and visual terms** (e.g., "HD", "wallpaper", "real photo", "close-up").
    - Include **search styles such as**:
    - **Keyword-based queries** (e.g., "soar 3arabiya hd", "tasawir {category} free download")
    - **Questions** (e.g., "wsho a7la soar {category} fe {country}?")
    - **Comparisons** (e.g., "soar {category} gadeema vs jadeeda")
    - **Incomplete phrases** (e.g., "{category} pic free", "download {category} {country}")
    - Reflect **country-specific dialects and informal expressions** in Arabizi.
    - You can mention **city names** (e.g., "Doha", "Riyadh", "Cairo", "Casablanca") associated with the specified country.
    - Prioritize **real-world terminology** and **slang that people use online**.
    """

    # f-string: single braces interpolate country/category; doubled braces
    # ({{ }}) are only used for the literal braces of the JSON example.
    user_prompt = f"""
    Generate **20 image search queries** for **{country} - {category}** in **Romanized Arabic (Arabizi)**.  
    - Use **common transliterations and spelling variations** (e.g., "soar" for "صور", "3arabi" for "عربي").  
    - Include **natural typos, slang, and informal phrasing**.  
    - Use **image search-specific terms** like "wallpaper," "4K," "soar," "real photo," "free pic," "tasawir".  
    - Ensure **varied search styles**:  
    - **Keywords** ("soar {category} HD", "{category} wallpaper 4K")  
    - **Questions** ("wsho a7la soar {category} fe {country}?")  
    - **Comparisons** ("soar {category} gadeema vs jadeeda")  
    - **Incomplete phrases** ("{category} pic free", "download {category} {country}")  

    ### **Example Output**
    ```json
    {{
        "{category}": [
            "soar 3arabiya 7elwa fe {country}",
            "tasawir {category} hd download",
            "a7la soar {category} 2024",
            "{category} pic free",
            "background {category} 4K",
            "soar {category} gadeema vs jadeeda",
            "wsho a7la soar {category} fe {country}?",
            "top insta pics {category} {country}",
            "tasawir {category} elgdeed fe {country}",
            "real vs AI tasawir {category}"
        ]
    }}
    ```
    """

    response = generate(api_url, headers, developer_prompt, user_prompt)
    response.raise_for_status()
    return response.json()


def read_country_list(file_path):
    """Return the non-blank lines of a UTF-8 text file, whitespace-stripped.

    Each remaining line is treated as one country name; blank or
    whitespace-only lines are skipped.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        stripped = [raw.strip() for raw in handle]
    return list(filter(None, stripped))


def read_category_list(file_path):
    """Load a UTF-8 JSON file and return the value under its ``"topics"`` key.

    ``process_data`` expects each entry to be a dict with a ``"main_topic"``
    string and a ``"subtopics"`` mapping of subtopic -> list of topics.

    Raises:
        KeyError: if the top-level object has no ``"topics"`` key.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)["topics"]


def load_env(env_path):
    """Load Azure OpenAI settings from a .env file and build request parameters.

    Reads ``AZURE_ENGINE_NAME``, ``AZURE_API_URL``, ``AZURE_API_KEY`` and
    ``AZURE_API_VERSION``. Values from the file override variables already
    present in the environment.

    Returns:
        ``(api_url, headers)``: the chat-completions URL for the deployment and
        the headers dict carrying the ``api-key``.

    Raises:
        KeyError: if any of the required variables is unset.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    deployment_name = os.environ["AZURE_ENGINE_NAME"]
    # The endpoint value may end with "/"; strip it so the URL below does not
    # contain "//openai".
    openai_api_base = os.environ["AZURE_API_URL"].rstrip("/")
    openai_api_key = os.environ["AZURE_API_KEY"]
    openai_api_version = os.environ["AZURE_API_VERSION"]
    api_url = (
        f"{openai_api_base}/openai/deployments/{deployment_name}"
        f"/chat/completions?api-version={openai_api_version}"
    )
    headers = {"api-key": openai_api_key}
    return api_url, headers


def save_responses(response, file_path):
    """Write *response* to *file_path* as pretty-printed UTF-8 JSON.

    Non-ASCII text (e.g. Arabic) is written as-is rather than ASCII-escaped.
    An existing file is overwritten.

    The payload is serialized before the file is opened. If serialization
    fails (e.g. TypeError for a non-JSON-serializable value), an existing
    cache file is left untouched rather than truncated into invalid JSON.
    """
    serialized = json.dumps(response, file=None) if False else json.dumps(
        response, ensure_ascii=False, indent=4
    )
    with open(file_path, "w", encoding="utf-8") as out:
        out.write(serialized)


def transform_string(s):
    """Turn a free-form label into a lowercase, underscore-separated slug.

    Commas and ampersands are dropped. Leading and trailing whitespace is
    removed, and every inner run of whitespace becomes a single underscore.
    ``process_data`` uses the result to build cache file names.
    """
    without_punct = s.lower().translate(str.maketrans("", "", ",&"))
    return "_".join(without_punct.split())


def _load_or_fetch_response(api_url, headers, country, topic, file_path):
    """Return the chat-completion response for one topic, using the disk cache.

    A cached file is reused only if it parses as a dict containing
    ``"choices"``. Otherwise the API is called again. The new response is
    written to the cache only when it contains ``"choices"``, so an error
    payload is retried on the next run instead of being stored permanently.
    """
    response = None
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            response = json.load(file)
    except FileNotFoundError:
        pass  # cache miss
    except json.JSONDecodeError:
        pass  # truncated/corrupt cache file -> refetch

    if isinstance(response, dict) and "choices" in response:
        return response

    response = generate_queries(api_url, headers, country, topic)
    if isinstance(response, dict) and "choices" in response:
        save_responses(response, file_path)
    return response


def process_data(
    api_url, headers, country_list, category_list, cached_dir, output_file
):
    """Generate (or load cached) queries for every country/topic and save them.

    The output nests as
    ``{country: {main_topic: {subtopic: {topic: {"queries": ...}}}}}``,
    where ``queries`` is the JSON object the model returned as message content.
    Each response is cached in *cached_dir* under a file name derived from
    the country and the slugified category/subtopic/topic names.

    Failures for an individual topic are reported and skipped so one bad
    request does not abort the whole batch. Skipped topics are absent from
    the output.
    """
    hierarchical_data = {}

    for country in tqdm(country_list, desc="Processing countries"):
        country_data = hierarchical_data[country] = {}

        for category in tqdm(category_list, desc="Processing categories", leave=False):
            category_name = category["main_topic"]
            category_data = country_data[category_name] = {}

            for subtopic, topic_list in category["subtopics"].items():
                subtopic_data = category_data[subtopic] = {}

                for topic in topic_list:
                    label = f"{country} -> {category_name} -> {subtopic} -> {topic}"
                    # Deliberately broad: per-topic best-effort boundary.
                    try:
                        output_str = transform_string(category_name)
                        subtopic_str = transform_string(subtopic)
                        topic_str = transform_string(topic)
                        file_path = os.path.join(
                            cached_dir,
                            f"{country}_{output_str}_{subtopic_str}_{topic_str}.json",
                        )

                        response = _load_or_fetch_response(
                            api_url, headers, country, topic, file_path
                        )
                        if "choices" not in response:
                            tqdm.write(f"Skipping {label}: response has no 'choices'")
                            continue

                        queries = json.loads(
                            response["choices"][0]["message"]["content"]
                        )
                        subtopic_data[topic] = {"queries": queries}

                    except Exception as e:
                        tqdm.write(f"Error processing {label}: {e}")

    # Save the hierarchical JSON output, creating its directory if needed.
    parent_dir = os.path.dirname(output_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(hierarchical_data, file, ensure_ascii=False, indent=4)


def main():
    """Command-line entry point: parse arguments, then generate and save queries.

    Exits with a non-zero status if required arguments are missing or the
    .env file does not exist. Both checks run before any API call, so no
    paid requests are wasted.
    """
    parser = argparse.ArgumentParser(
        description="Generate search queries for Google Image Search"
    )
    parser.add_argument(
        "-c",
        "--category_file",
        type=str,
        required=True,
        help="Path to the category file",
    )
    parser.add_argument(
        "-f",
        "--country_file",
        type=str,
        required=True,
        help="Path to the country file",
    )
    parser.add_argument(
        "-d",
        "--cached_dir",
        "--cashed_dir",  # original misspelled flag, kept for compatibility
        dest="cached_dir",
        type=str,
        default="output/queries/",
        help="Directory to cache responses",
    )
    # Required: the previous default ("") only failed when the final file
    # was opened, i.e. after every API request had already been made.
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        required=True,
        help="Path to save the output JSON file",
    )
    parser.add_argument(
        "-e", "--env_file", type=str, required=True, help="Path to the .env file"
    )

    args = parser.parse_args()

    env_path = args.env_file
    if not os.path.exists(env_path):
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit(f"Error: {env_path} not found!")

    country_list = read_country_list(args.country_file)
    category_list = read_category_list(args.category_file)

    cached_dir = args.cached_dir
    os.makedirs(cached_dir, exist_ok=True)

    api_url, headers = load_env(env_path)

    print(
        f"Processing {len(country_list)} countries and {len(category_list)} categories..."
    )
    process_data(
        api_url, headers, country_list, category_list, cached_dir, args.output_file
    )

    print("Queries generated successfully!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
