import argparse
import csv
import json
import os
import re

import openai
import requests
from dotenv import load_dotenv
from tqdm import tqdm


def generate(api_url, headers, developer_prompt, user_prompt, timeout=300):
    """POST one chat-completion request and return the raw HTTP response.

    Args:
        api_url: Full chat-completions endpoint URL (built by ``load_env``).
        headers: HTTP headers sent with the request (``load_env`` supplies
            the ``api-key`` header).
        developer_prompt: Instructions sent as the ``system`` message.
        user_prompt: Text sent as the single ``user`` message part.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before raising ``requests.exceptions.Timeout``. ``None``
            disables the limit. Without a limit, one stalled connection
            would block the whole batch indefinitely.

    Returns:
        The ``requests.Response``. HTTP error statuses are NOT raised here;
        the caller decodes and inspects the body.
    """
    json_data = {
        "messages": [
            {"role": "system", "content": developer_prompt},
            {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
        ],
        "temperature": 0.7,  # alternatives noted by the author: 0.5, 0.6
        "top_p": 0.95,
        "max_tokens": 2000,
        # JSON mode: the model must answer with a single top-level JSON object.
        "response_format": {"type": "json_object"},
    }
    return requests.post(api_url, headers=headers, json=json_data, timeout=timeout)


def generate_queries(api_url, headers, country, category, subcategory):
    """Ask the model for ten image-search topics for one location/subject pair.

    Args:
        api_url: Chat-completions URL, passed straight to ``generate``.
        headers: HTTP headers, passed straight to ``generate``.
        country: Country the topics must be specific to.
        category: Top-level category name, inserted verbatim into the prompt.
        subcategory: Subcategory name, inserted verbatim into the prompt.

    Returns:
        The decoded JSON body of the HTTP response. No status check is done
        here: an API error body is returned as-is, so callers must check for
        the ``"choices"`` key before reading the model's answer. The model is
        instructed to put ``{"topics": [...]}`` in the message content.

    Raises:
        Whatever ``requests`` raises for transport failures (including
        timeouts), and a ``ValueError`` subclass if the body is not JSON.
    """
    # The request is sent with response_format={"type": "json_object"}, which
    # only permits a top-level JSON *object*, and process_data reads the
    # answer via topics["topics"]. So the prompt must ask for
    # {"topics": [...]}, not a bare list.
    # Plain (non-f) string: the JSON example below contains literal braces.
    system_prompt = """
        You are an AI specialized in generating highly relevant topics for image searches 
        based on a given country, category, and subcategory. Your task is to generate a list 
        of topics that are highly visual and well-suited for image searches. Ensure that the topics 
        reflect the cultural, historical, or modern significance of the specified location.

        Guidelines:

        1. Topics should be engaging, highly visual, and unique to the specified country.
        2. Ensure a mix of historical, modern, and futuristic aspects based on the subcategory.
        3. Use well-known landmarks, cultural elements, or emerging trends where relevant.
        4. Prioritize topics that are frequently searched for in image search engines.
        5. If the subcategory is broad, ensure a diverse selection covering different aspects.
        6. Do not include generic topics that could apply to any country; make them location-specific.
        7. If the subcategory is too narrow and lacks visual topics, expand the scope slightly to include related themes.
        8. Generate exactly 10 topics per request. If necessary, include related visual aspects.
        9. Avoid redundant or overly generic suggestions.
        10. Ensure diversity in the topics; avoid generating closely related topics.

        JSON Format:
        - Respond with a single JSON object whose only key is "topics". Its value is a list of
          exactly 10 topics, each short, clear, and descriptive
          (e.g., 'Futuristic Skyscrapers of Dubai' or 'Traditional Wooden Temples of Japan').

        ```json
        {
            "topics": [
                "Futuristic Skyscrapers of Dubai",
                "Traditional Wooden Temples of Japan"
            ]
        }
        ```
    """

    # Doubled braces below render as literal { } in the f-string.
    user_prompt = f"""
    Generate **exactly** 10 highly relevant topics for image search based on the following:

    - Country: {country}
    - Category: {category}
    - Subcategory: {subcategory}

    If there are fewer than 10 highly relevant topics, expand the scope slightly to related visual themes. 

    Ensure the topics are visually engaging, related to the specified country, and match common image search behavior. 
    The topics should cover a mix of historical, modern, and futuristic elements unique to the location.

    Respond only with a JSON object of the form {{"topics": ["...", "..."]}}.
    """

    response = generate(api_url, headers, system_prompt, user_prompt)

    return response.json()


def read_country_list(file_path):
    """Return every non-blank line of a UTF-8 text file, trimmed.

    Leading/trailing whitespace is removed from each line, lines that are
    empty after trimming are dropped, and file order is preserved.
    """
    countries = []
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            name = raw_line.strip()
            if name:
                countries.append(name)
    return countries


def read_category_list(file_path):
    """Group subcategory names by category from a tab-separated file.

    The file must have a header row with at least ``category`` and
    ``subcategory`` columns.

    Returns:
        A dict mapping each category to the list of its subcategories, in
        file order. Categories appear in order of first occurrence.

    Raises:
        KeyError: if either required column is missing from the header.
    """
    data = {}
    # newline="" is what the csv module requires for correct handling of
    # quoted fields. "utf-8-sig" strips a leading BOM (common in
    # spreadsheet exports), which would otherwise corrupt the first header
    # name into "\ufeffcategory". It decodes BOM-less UTF-8 unchanged.
    with open(file_path, "r", encoding="utf-8-sig", newline="") as file:
        reader = csv.DictReader(file, delimiter="\t")
        for row in reader:
            data.setdefault(row["category"], []).append(row["subcategory"])
    return data


def load_env(env_path):
    """Load Azure OpenAI settings from a ``.env`` file and build the request target.

    ``override=True`` means values in the file replace any variables already
    present in ``os.environ``.

    Required variables: ``AZURE_ENGINE_NAME`` (deployment name),
    ``AZURE_API_URL`` (resource endpoint), ``AZURE_API_KEY`` and
    ``AZURE_API_VERSION``.

    Returns:
        ``(api_url, headers)``: the chat-completions URL for the deployment,
        and a headers dict carrying the ``api-key``.

    Raises:
        KeyError: if any required variable is not set.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    deployment_name = os.environ["AZURE_ENGINE_NAME"]
    # Endpoints are often copied with a trailing slash
    # ("https://<resource>.openai.azure.com/"). Strip it so the URL below
    # does not contain "//openai/...".
    openai_api_base = os.environ["AZURE_API_URL"].rstrip("/")
    openai_api_key = os.environ["AZURE_API_KEY"]
    openai_api_version = os.environ["AZURE_API_VERSION"]
    api_url = f"{openai_api_base}/openai/deployments/{deployment_name}/chat/completions?api-version={openai_api_version}"
    headers = {"api-key": openai_api_key}
    return api_url, headers


def save_responses(response, file_path):
    """Atomically write *response* to *file_path* as indented UTF-8 JSON.

    The data is first written to ``<file_path>.tmp`` and then moved over the
    target with ``os.replace``. A crash or a non-serializable value therefore
    never leaves a truncated file at *file_path*. That matters because
    ``process_data`` treats any existing file there as a cache entry.

    Raises:
        TypeError: if *response* is not JSON-serializable. Any existing
            *file_path* is left untouched and the temp file is removed.
    """
    tmp_path = f"{file_path}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as file:
            json.dump(response, file, ensure_ascii=False, indent=4)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Clean up the partial temp file, then let the original error surface.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def transform_string(s):
    """Normalise a label into a lowercase, underscore-joined filename fragment.

    Commas and ampersands are deleted, surrounding whitespace is trimmed,
    and every remaining run of whitespace collapses to a single ``_``.

    >>> transform_string("Arts & Culture, Heritage")
    'arts_culture_heritage'
    """
    # str.translate with a deletion table drops ',' and '&' in one pass.
    cleaned = s.lower().translate(str.maketrans("", "", ",&"))
    return re.sub(r"\s+", "_", cleaned.strip())


def _load_cached_response(file_path):
    """Return the JSON stored at *file_path*, or ``None`` if missing or unreadable."""
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return json.load(file)
    except (OSError, ValueError):
        # Missing file, or a truncated/corrupt one (json.JSONDecodeError and
        # UnicodeDecodeError are both ValueError). Treat it as a cache miss
        # so the item is regenerated instead of failing on every run.
        return None


def process_data(api_url, headers, country_list, category_list, cached_dir):
    """Collect model-generated topics for every country/category/subcategory.

    Each combination is cached as
    ``<cached_dir>/<country>_<category>_<subcategory>.json``. Category and
    subcategory are normalised by ``transform_string``; the country is used
    verbatim. A cache file is reused only if it holds a successful response
    (one with a ``"choices"`` key). Otherwise the model is queried again,
    and only a successful response is written back. Failed calls are
    therefore retried on the next run instead of being cached forever.

    Args:
        api_url: Chat-completions URL, passed through to ``generate_queries``.
        headers: HTTP headers, passed through to ``generate_queries``.
        country_list: Countries to iterate.
        category_list: Mapping of category name to a list of subcategory
            names (the shape returned by ``read_category_list``).
        cached_dir: Directory for the per-combination cache files; it is
            expected to exist already.

    Returns:
        A list of ``{"country", "category", "subcategory", "topics"}`` dicts,
        one per combination that succeeded. Failures are printed and skipped.
    """
    topics_data = []

    for country in tqdm(country_list, desc="Processing countries"):
        for category_name, subcategory_list in tqdm(
            category_list.items(), desc="Processing categories", leave=False
        ):
            for subcategory in subcategory_list:
                category_str = transform_string(category_name)
                subcategory_str = transform_string(subcategory)
                label = f"{country} -> {category_name} -> {subcategory}"
                try:
                    file_path = os.path.join(
                        cached_dir,
                        f"{country}_{category_str}_{subcategory_str}.json",
                    )
                    response = _load_cached_response(file_path)
                    # A response without "choices" is an error body. It may
                    # be a stale entry written by an older run, so query again.
                    if response is None or "choices" not in response:
                        response = generate_queries(
                            api_url, headers, country, category_name, subcategory
                        )
                        if "choices" not in response:
                            # Not cached on purpose: the next run will retry it.
                            print(f"No choices returned for {label}: {response}")
                            continue
                        # Record what was asked so the cache file is self-describing.
                        response["country"] = country
                        response["category"] = category_name
                        response["subcategory"] = subcategory
                        save_responses(response, file_path)

                    content = response["choices"][0]["message"]["content"]
                    topics = json.loads(content)
                    topics_data.append(
                        {
                            "country": country,
                            "category": category_name,
                            "subcategory": subcategory,
                            "topics": topics["topics"],
                        }
                    )
                except Exception as e:
                    # Best effort: report and move on so one bad item does not
                    # abort the whole batch.
                    print(f"Error processing {label}: {e}")

    return topics_data


def save_output(topics_data, output_file):
    """Write *topics_data* as a JSON array plus a flattened TSV beside it.

    The TSV path is *output_file* with its extension replaced by ``.tsv``
    (``generated_topics.jsonl`` -> ``generated_topics.tsv``). If
    *output_file* already ends in ``.tsv``, ``.tsv`` is appended instead, so
    the TSV can never overwrite the JSON. The parent directory is created if
    needed.

    The TSV has a ``country, category, subcategory, topics`` header and one
    row per individual topic.

    Args:
        topics_data: List of dicts with ``country``, ``category``,
            ``subcategory`` and ``topics`` (a list) keys.
        output_file: Path of the JSON file to write.
    """
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Save the JSON output
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(topics_data, file, ensure_ascii=False, indent=4)

    # Derive the TSV path from the extension only. A plain str.replace would
    # also touch ".json" elsewhere in the path, and would leave the name
    # unchanged (clobbering the JSON) when there is no ".json" in it.
    root, ext = os.path.splitext(output_file)
    tsv_output_file = root + ".tsv"
    if ext.lower() == ".tsv":
        tsv_output_file = output_file + ".tsv"

    # Save the data in TSV format: one row per topic
    with open(tsv_output_file, "w", encoding="utf-8", newline="") as tsv_file:
        tsv_writer = csv.writer(tsv_file, delimiter="\t")
        tsv_writer.writerow(["country", "category", "subcategory", "topics"])
        tsv_writer.writerows(
            [item["country"], item["category"], item["subcategory"], topic]
            for item in topics_data
            for topic in item["topics"]
        )


def main():
    """Command-line entry point: read inputs, query the model, write outputs.

    Reads the country list and category TSV, loads Azure settings from the
    ``.env`` file, generates (or loads cached) topics for every combination,
    and writes the JSON and TSV outputs.

    Exits with status 1 if the ``.env`` file does not exist. argparse exits
    with status 2 if a required argument is missing.
    """
    parser = argparse.ArgumentParser(
        description="Generate search queries for Google Image Search"
    )
    # These paths were previously optional, but omitting any of them crashed
    # later with a TypeError on None, so argparse now enforces them.
    parser.add_argument(
        "-c", "--category_file", type=str, required=True,
        help="Path to the category file",
    )
    parser.add_argument(
        "-f", "--country_file", type=str, required=True,
        help="Path to the country file",
    )
    parser.add_argument(
        "-d",
        "--cached_dir",
        "--cashed_dir",  # original (misspelled) flag kept for compatibility
        dest="cached_dir",
        type=str,
        default="data/queries/topics/cache/",
        help="Directory to cache responses",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default="data/queries/topics/generated_topics.jsonl",
        help="Path to save the output JSON file",
    )
    parser.add_argument(
        "-e", "--env_file", type=str, required=True, help="Path to the .env file"
    )

    args = parser.parse_args()

    country_list = read_country_list(args.country_file)
    category_list = read_category_list(args.category_file)

    output_file = args.output_file

    env_path = args.env_file
    if not os.path.exists(env_path):
        print(f"Error: {env_path} not found!")
        # Non-zero status so shell scripts and CI notice the failure.
        raise SystemExit(1)

    cached_dir = args.cached_dir
    os.makedirs(cached_dir, exist_ok=True)

    api_url, headers = load_env(env_path)

    print(
        f"Processing {len(country_list)} countries and {len(category_list)} categories..."
    )
    topics_data = process_data(
        api_url, headers, country_list, category_list, cached_dir
    )
    save_output(topics_data, output_file)

    print("Queries generated successfully!")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
