import argparse
import hashlib
import json
import logging
import os
import re

import anthropic
import openai
import requests
from dotenv import load_dotenv
from tqdm import tqdm

# Root logger setup: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def generate_hash(text: str, algorithm: str = "sha256") -> str:
    """
    Compute the hexadecimal digest of a string.

    :param text: The string to hash; it is encoded as UTF-8 before hashing.
    :param algorithm: Name of an algorithm accepted by ``hashlib.new``
        (e.g., 'sha256', 'md5').
    :return: The digest as a hexadecimal string.
    """
    hasher = hashlib.new(algorithm)
    payload = text.encode("utf-8")
    hasher.update(payload)
    return hasher.hexdigest()


def load_model(API_KEY):
    """
    Build an Anthropic API client.

    :param API_KEY: API key passed to ``anthropic.Anthropic``.
    :return: The constructed ``anthropic.Anthropic`` client.
    """
    return anthropic.Anthropic(api_key=API_KEY)


def generate(
    client,
    user_prompt,
    system_prompt,
    temp=0,
    model="claude-3-5-sonnet-20240620",
    max_tokens=4000,
):
    """
    Send a single-turn request through ``client.messages.create``.

    :param client: Object exposing ``messages.create`` (for example the client
        returned by ``load_model``).
    :param user_prompt: Text sent as the only user message.
    :param system_prompt: Text sent as the ``system`` argument.
    :param temp: Sampling temperature forwarded as ``temperature``.
    :param model: Model identifier forwarded as ``model``. Defaults to the
        value that used to be hard-coded here.
    :param max_tokens: Upper bound on generated tokens forwarded as
        ``max_tokens``. Defaults to the value that used to be hard-coded here.
    :return: Whatever ``client.messages.create`` returns, unmodified.
    """
    return client.messages.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temp,
        system=system_prompt,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt},
                ],
            }
        ],
    )


def generate_queries(client, country, category, subcategory, topic):
    """
    Ask the model for 50 image-search queries about the given topics.

    :param client: Client object forwarded to ``generate``.
    :param country: Country name interpolated into the user prompt.
    :param category: Category name interpolated into the user prompt.
    :param subcategory: Subcategory name interpolated into the user prompt.
    :param topic: Topic text interpolated into the user prompt.
    :return: The raw response object returned by ``generate``.
    """

    system_prompt = """
    You are an expert at generating **highly relevant, human-like image search queries** optimized for **Google Image Search**.

    Your task is to generate **50 unique search queries** that reflect **natural human behavior**, including:
    - **Typos, slang, informal expressions**, and **incomplete or autocomplete-style phrases**.
    - Use **descriptive visual terms** such as "HD," "4K," "wallpaper," "real photo," "aesthetic," "close-up," "latest pics," etc.
    - Mimic **real-world search styles**, including:
    - Pure **keywords**
    - **Questions** (e.g., “what do [topic] look like”)
    - **Autocomplete-like fragments** (e.g., “best pics of…”)
    - **Trending styles** (e.g., “free download,”)

    Incorporate **localized and culturally relevant elements** from the country provided, including:
    - Dialects, slang, and spelling variations
    - Famous **cities, landmarks**, or **cultural symbols**
    - Country-specific visual cues, aesthetics, or references

    Queries should be:
    - **Short, human-like and natural-sounding** (2-5 words on average)
    - **Highly visual** and suitable for image search intent
    - Focused on the **topics**, not just the country, category or subcategory
    - Always returned in **strict JSON format**:
    ```json
    {
    "queries": [
        "query 1",
        "query 2",
        "... up to query 50"
    ]
    }
    ```

    Remember to generate exactly 50 unique queries, ensuring a diverse range of search styles and incorporating elements specific to the given country. Focus on creating queries that real users might type when searching for images related to the provided topic.    
    """

    user_prompt = f"""
    Generate **50** unique, human-like image search queries** based on below information:

    - Country: {country}  
    - Category: {category}  
    - Subcategory: {subcategory}  
    - Topics: {topic}  
    """

    # Pass the prompts by keyword: ``generate`` takes the user prompt before
    # the system prompt, and passing them positionally in the other order
    # sends the instructions as the user turn and the request data as the
    # system prompt.
    return generate(client, user_prompt=user_prompt, system_prompt=system_prompt)


def read_country_list(file_path):
    """
    Read a UTF-8 text file and return its non-blank lines.

    :param file_path: Path of the text file to read.
    :return: List of lines with surrounding whitespace stripped; lines that
        are empty after stripping are omitted.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        stripped_lines = [raw.strip() for raw in handle.readlines()]
    return [entry for entry in stripped_lines if entry]


def read_category_list(file_path):
    """
    Parse a UTF-8 encoded JSON file.

    :param file_path: Path of the JSON file to read.
    :return: The decoded JSON content, exactly as ``json.load`` produces it.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def load_env(env_path):
    """
    Load a dotenv file and return the Anthropic API key from the environment.

    :param env_path: Path handed to ``load_dotenv`` as ``dotenv_path``; the
        call uses ``override=True``.
    :return: The value of the ``ANTHROPIC_API_KEY`` environment variable.
    :raises KeyError: If ``ANTHROPIC_API_KEY`` is not set after loading.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    return os.environ["ANTHROPIC_API_KEY"]


def save_responses(response, file_path):
    """
    Write ``response`` to ``file_path`` as indented UTF-8 JSON.

    :param response: JSON-serializable object to store.
    :param file_path: Destination path; an existing file is overwritten.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(file_path, "w", encoding="utf-8") as handle:
        json.dump(response, handle, **dump_options)


def transform_string(s):
    """
    Normalize a label into a lowercase, underscore-separated token.

    The text is lowercased, every ``,`` and ``&`` is removed, surrounding
    whitespace is trimmed, and each remaining run of whitespace becomes a
    single underscore.

    :param s: The label to normalize.
    :return: The normalized string.
    """
    lowered = s.lower()
    without_separators = re.sub(r"[,&]", "", lowered)
    return re.sub(r"\s+", "_", without_separators.strip())


def parse_queries(response):
    """
    Decode a model response into a Python object.

    A ``dict`` is returned unchanged. A string is parsed as JSON; when it
    contains a ```` ```json ```` fence, newlines are removed and only the text
    between that opener and the last closing fence is parsed.

    :param response: Already-decoded ``dict`` or raw response text.
    :return: ``response`` itself when it is a dict, otherwise the decoded JSON.
    :raises ValueError: If ``response`` is neither a dict nor a string, if a
        ```` ```json ```` fence is opened but never closed, or if the payload
        is not valid JSON.
    """
    if isinstance(response, dict):
        return response
    if not isinstance(response, str):
        raise ValueError(
            f"Error parsing JSON: expected str or dict, got {type(response).__name__}"
        )

    payload = response
    if "```json" in payload:
        # Newlines are dropped so that `.` in the pattern can span the block.
        payload = payload.replace("\n", "")
        fenced = re.search(r"```json(.*)```", payload)
        if fenced is None:
            # An unterminated fence would otherwise surface as AttributeError
            # from calling .group() on None.
            raise ValueError("Error parsing JSON: unterminated ```json code fence")
        payload = fenced.group(1)

    try:
        return json.loads(payload)
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        raise ValueError(f"Error parsing JSON: {e}") from e


def _load_or_generate_response(
    client, cache_path, country, category_name, subcategory, topic_list, topic_text
):
    """Return the raw response text for one entry, from the cache file if it exists, else from the model."""
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as file:
            return json.load(file)["response"]

    message = generate_queries(client, country, category_name, subcategory, topic_text)
    response_text = message.content[0].text

    # Store the raw text together with the inputs that produced it, so a
    # later run with the same inputs can skip the API call.
    save_responses(
        {
            "country": country,
            "category": category_name,
            "subcategory": subcategory,
            "topic": topic_list,
            "response": response_text,
        },
        cache_path,
    )
    return response_text


def process_data(client, data, cached_dir, output_file):
    """
    Generate (or load cached) queries for every entry and write them to JSON.

    Each entry of ``data`` is read through the keys ``"country"``,
    ``"category"``, ``"subcategory"`` and ``"topics"``. Entries without
    ``"topics"`` are skipped with a warning. Any exception raised while
    handling a single entry is logged and that entry is left out of the
    output; processing continues with the next one.

    :param client: Client object forwarded to ``generate_queries``.
    :param data: Iterable of dict entries as described above; ``"topics"`` is
        joined with newlines, so it must be an iterable of strings.
    :param cached_dir: Directory holding one cache file per entry, named from
        the country, the transformed category and subcategory, and a hash of
        the joined topics.
    :param output_file: Path of the combined JSON file to write. Its parent
        directory is created if it does not exist.
    :raises KeyError: If an entry lacks ``"country"``, ``"category"`` or
        ``"subcategory"`` (these are read outside the per-entry error handler).
    """
    # Create the destination directory up front so a missing path is found
    # before any requests are made rather than at the final write.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    hierarchical_data = []

    for data_obj in tqdm(data, desc="Processing categories", leave=False):
        country = data_obj["country"]
        category_name = data_obj["category"]
        subcategory = data_obj["subcategory"]

        if "topics" not in data_obj:
            logging.warning(
                f"No topics found for {country} -> {category_name} -> {subcategory}"
            )
            continue
        topic_list = data_obj["topics"]
        logging.info(f"Processing country: {country}, category: {category_name}")

        try:
            category_str = transform_string(category_name)
            subtopic_str = transform_string(subcategory)
            # All topics of an entry go into one request; their hash keeps the
            # cache file name bounded in length and tied to the exact topics.
            topic_text = "\n".join(topic_list)
            topic_str = generate_hash(topic_text)

            file_path = os.path.join(
                cached_dir,
                f"{country}_{category_str}_{subtopic_str}_{topic_str}.json",
            )
            response = _load_or_generate_response(
                client,
                file_path,
                country,
                category_name,
                subcategory,
                topic_list,
                topic_text,
            )

            queries = parse_queries(response)
            hierarchical_data.append(
                {
                    "country": country,
                    "category": category_name,
                    "subcategory": subcategory,
                    "topic": topic_list,
                    "queries": queries["queries"],
                }
            )
        except Exception as e:
            # Per-entry boundary: one bad entry must not abort the whole run.
            logging.error(
                f"Error processing {country} -> {category_name} -> {subcategory}: {e}"
            )

    # Save the hierarchical JSON output
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(hierarchical_data, file, ensure_ascii=False, indent=4)
    logging.info(f"Data saved to {output_file}")


def main():
    """
    Command-line entry point: parse arguments, build the client, run the job.

    ``--topic_file``, ``--output_file`` and ``--env_file`` are required, so a
    missing one is reported by argparse before any work starts instead of
    failing later on a ``None`` or empty path. If the env file does not exist,
    an error is logged and the function returns without processing.
    """
    parser = argparse.ArgumentParser(
        description="Generate search queries for Google Image Search"
    )
    parser.add_argument(
        "-c",
        "--topic_file",
        type=str,
        required=True,
        help="Path to the category file",
    )
    parser.add_argument(
        "-d",
        "--cached_dir",
        type=str,
        default="data/queries/cache_queries/",
        help="Directory to cache responses",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        required=True,
        help="Path to save the output JSON file",
    )
    parser.add_argument(
        "-e",
        "--env_file",
        type=str,
        required=True,
        help="Path to the .env file",
    )

    args = parser.parse_args()

    topic_data = read_category_list(args.topic_file)

    env_path = args.env_file
    if not os.path.exists(env_path):
        logging.error(f"Error: {env_path} not found!")
        return

    api_key = load_env(env_path)
    client = load_model(api_key)

    cached_dir = args.cached_dir
    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(cached_dir, exist_ok=True)

    logging.info(f"Processing topics {len(topic_data)}...")

    process_data(client, topic_data, cached_dir, args.output_file)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
