import argparse
import hashlib
import json
import logging
import os
import re

import openai
import requests

import vertexai
import vertexai.preview.generative_models as generative_models
from dotenv import load_dotenv
from tqdm import tqdm
from vertexai.generative_models import FinishReason, GenerativeModel, Part

# Root logging for the script: INFO and above, each record prefixed with its
# timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def generate_hash(text: str, algorithm: str = "sha256") -> str:
    """
    Return the hex digest of ``text`` (encoded as UTF-8) under ``algorithm``.

    :param text: String whose digest is wanted.
    :param algorithm: Any algorithm name accepted by ``hashlib.new``
        (e.g. 'sha256', 'md5').
    :return: Hexadecimal digest string.
    """
    hasher = hashlib.new(algorithm)
    hasher.update(text.encode("utf-8"))
    digest = hasher.hexdigest()
    return digest


def generate(
    sys_prompt,
    user_prompt,
    temp=0.5,
    *,
    project="your-gcp-project-id",
    location="us-central1",
    model_name="gemini-2.0-flash-001",
):
    """
    Send one non-streaming request to a Gemini model on Vertex AI.

    :param sys_prompt: System instruction given to the model.
    :param user_prompt: User message content.
    :param temp: Sampling temperature.
    :param project: GCP project passed to ``vertexai.init``. The default is a
        placeholder and must be replaced with a real project ID.
    :param location: Vertex AI region passed to ``vertexai.init``.
    :param model_name: Identifier of the Gemini model to call.
    :return: The response object returned by
        ``GenerativeModel.generate_content`` (callers in this file read its
        ``.candidates`` and ``.usage_metadata``).
    """
    # Block only content rated HIGH for each of the listed harm categories.
    safety_settings = {
        generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
        generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_ONLY_HIGH,
    }

    vertexai.init(project=project, location=location)

    generation_config = {
        "max_output_tokens": 8192,
        "temperature": temp,
        "top_p": 0.95,
    }
    model = GenerativeModel(model_name, system_instruction=[sys_prompt])

    return model.generate_content(
        [user_prompt],
        generation_config=generation_config,
        safety_settings=safety_settings,
    )


def usage_metadata_to_dict(metadata_str):
    """
    Parse the text form of a usage-metadata message into ``{field: int}``.

    Only top-level ``key: <integer>`` lines are kept. Nested blocks such as
    ``prompt_tokens_details { ... }`` are skipped, as are blank lines and
    values that are not integers. Those lines used to raise, either because
    they have no ``:`` to unpack or because ``int()`` failed.

    :param metadata_str: Text such as ``"prompt_token_count: 10\\n..."``.
    :return: Dict mapping each top-level field name to its integer value
        (empty for empty input).
    """
    metadata_dict = {}
    depth = 0  # nesting level of ``name { ... }`` blocks

    for raw_line in metadata_str.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.endswith("{"):
            depth += 1
            continue
        if line == "}":
            depth = max(depth - 1, 0)
            continue
        if depth > 0 or ":" not in line:
            continue

        # partition() tolerates values that themselves contain ':'
        key, _, value = line.partition(":")
        try:
            metadata_dict[key.strip()] = int(value.strip())
        except ValueError:
            # Non-numeric field (e.g. an enum name): not a token count.
            continue

    return metadata_dict


def generate_queries(country, category, subcategory, topic):
    """
    Ask the model for 50 image-search queries for one
    country / category / subcategory / topics entry.

    :param country: Country the queries should be localized to.
    :param category: Top-level category name.
    :param subcategory: Subcategory name.
    :param topic: Topics text inserted into the prompt (``process_data``
        passes the entry's topics lower-cased and newline-joined).
    :return: A list with one ``to_dict()`` entry per response candidate,
        followed by ``{"usage_metadata": {...}}`` holding the integer
        token counts.
    """

    system_prompt = """
    You are an expert at generating **highly relevant, human-like image search queries** optimized for **Google Image Search**.
    
    Your task is to generate 50 unique search queries based on a given **country**, **category**, **subcategory** and **topics**. 
    These queries should reflect **natural human behavior** and be suitable for image search intent:    

    Follow these guidelines to generate the queries:

    1. Reflect natural human behavior:
    - Include typos, slang, informal expressions, and incomplete or autocomplete-style phrases
    - Use a mix of short-form and punctuation styles

    2. Incorporate visual descriptors:
    - Use terms like "HD," "4K," "wallpaper," "real photo," "aesthetic," "close-up," "latest pics," etc.

    3. Mimic real-world search styles:
    - Pure keywords
    - Questions (e.g., "what do [topic] look like")
    - Autocomplete-like fragments (e.g., "best pics of…")
    - Trending styles (e.g., "free download")

    4. Include **localized and culturally relevant elements** from the specified country:
    - Use dialects, slang, and spelling variations specific to the country
    - Reference famous cities, landmarks, or cultural symbols
    - Incorporate country-specific visual cues, aesthetics, or references

    Ensure that the queries have the following characteristics:
    - **Short, human-like, and natural-sounding** (2-5 words on average)
    - **Highly visual** and suitable for image search intent
    - Focused on the **topics**, not just the country, category or subcategory
    - Avoid any formal tone
    
    Examples:
    pani puri hd images
    islamic tile patterns
    gold artifacts closeup aesthetic
    vintage islamic pattern doors

    - Always returned in **strict JSON format**:
    ```json
    {
    "queries": [
        "query 1",
        "query 2",
        "... up to query 50"
    ]
    }
    ```
    """

    # The opening "**" after "queries" in the previous wording had no match,
    # which left broken Markdown emphasis in the prompt.
    user_prompt = f"""
    Generate **50** unique, human-like image search queries based on the information below:

    - Country: {country}  
    - Category: {category}  
    - Subcategory: {subcategory}  
    - Topics: {topic}  
    """
    response = generate(system_prompt, user_prompt)

    # One dict per candidate, then a trailing usage-metadata entry.
    responses_dict = [candidate.to_dict() for candidate in response.candidates]
    usage_metadata_dict = usage_metadata_to_dict(str(response.usage_metadata))
    responses_dict.append({"usage_metadata": usage_metadata_dict})

    return responses_dict


def read_country_list(file_path):
    """
    Read a newline-separated list of countries.

    :param file_path: Path to a UTF-8 text file with one entry per line.
    :return: The non-blank lines, stripped of surrounding whitespace,
        in file order.
    """
    countries = []
    with open(file_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            name = raw.strip()
            if name:
                countries.append(name)
    return countries


def read_category_list(file_path):
    """
    Load the JSON document stored at ``file_path``.

    :param file_path: Path to a UTF-8 JSON file.
    :return: The decoded JSON value, unchanged.
    """
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)


def save_responses(response, file_path):
    """
    Write ``response`` to ``file_path`` as UTF-8 JSON with 4-space indents,
    replacing any existing file. Non-ASCII characters are written as-is.

    :param response: JSON-serializable value to store.
    :param file_path: Destination path.
    """
    with open(file_path, "w", encoding="utf-8") as handle:
        json.dump(response, handle, indent=4, ensure_ascii=False)


def transform_string(s):
    """
    Normalize a category/subcategory name for use in a cache filename.

    Lower-cases ``s`` and removes ',' and '&'. Path separators ('/' and '\\')
    are treated as whitespace, so a name like "Arts/Crafts" cannot turn the
    cache path into a nonexistent nested directory. Surrounding whitespace is
    then stripped, and each internal whitespace run becomes a single '_'.

    :param s: Name to normalize.
    :return: The normalized, filename-safe string.
    """
    s = s.lower()
    s = re.sub(r"[,&]", "", s)
    s = re.sub(r"[/\\]", " ", s)
    s = re.sub(r"\s+", "_", s.strip())
    return s


def parse_queries(response):
    """
    Decode the JSON payload from the first entry of a model response list.

    ``response`` is the list built by ``generate_queries`` or read back from
    the cache: candidate dicts first, then a ``{"usage_metadata": ...}``
    entry. The candidate text may be wrapped in a Markdown code fence
    (```json ... ``` or ``` ... ```) or be bare JSON; both are accepted.

    :param response: List of response dicts.
    :return: The decoded JSON value (expected to be a dict with a
        ``"queries"`` list).
    :raises ValueError: If the structure is not as expected, the text is not
        valid JSON, or the decoded value is empty.
    """
    if response is None:
        raise ValueError("'response' field is missing")

    if not isinstance(response, list) or not response:
        raise ValueError("'response' is not a list or is empty")

    # When the model returned no candidates, the first entry is the
    # usage-metadata dict, which has no 'content'.
    first = response[0]
    content = first.get("content") if isinstance(first, dict) else None
    if content is None:
        raise ValueError(
            "'content' field is missing in the first item of 'response'."
        )

    parts = content.get("parts") if isinstance(content, dict) else None
    if (
        not isinstance(parts, list)
        or not parts
        or not isinstance(parts[0], dict)
        or "text" not in parts[0]
    ):
        raise ValueError(
            "Missing expected fields in JSON structure: "
            "'parts' structure is invalid or 'text' is missing in the response"
        )

    text = parts[0]["text"]
    # Non-greedy + DOTALL: take the body of the first fence, across newlines.
    fenced = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
    payload = fenced.group(1) if fenced else text

    try:
        queries = json.loads(payload)
    except json.JSONDecodeError as e:
        raise ValueError(f"Response text is not valid JSON: {e}") from e

    if not queries:
        raise ValueError("Queries content is missing in the parsed JSON.")

    return queries


def process_data(data, cached_dir, output_file):
    """
    Build image-search queries for every entry in ``data`` and write them,
    as one JSON list, to ``output_file``.

    For each entry the model response is read from a cache file in
    ``cached_dir`` when one exists. Otherwise it is generated with
    ``generate_queries`` and written to the cache only after it parses
    successfully. A blocked or malformed response is therefore retried on
    the next run rather than being cached permanently.

    Entries without a ``"topics"`` key are skipped with a warning. Entries
    that raise during processing are logged and skipped, so one bad entry
    does not abort the run.

    :param data: List of dicts with ``"country"``, ``"category"``,
        ``"subcategory"`` and ``"topics"`` keys (as loaded from the topic file).
    :param cached_dir: Directory holding the per-entry response cache files.
    :param output_file: Path of the aggregated JSON output.
    """
    hierarchical_data = []

    for data_obj in tqdm(data, desc="Processing categories", leave=False):
        country = data_obj["country"]
        category_name = data_obj["category"]
        subcategory = data_obj["subcategory"]

        if "topics" not in data_obj:
            logging.warning(
                f"No topics found for {country} -> {category_name} -> {subcategory}"
            )
            continue
        topic_list = data_obj["topics"]
        logging.info(f"Processing country: {country}, category: {category_name}")

        try:
            # All topics of an entry go into a single request. The cache key
            # uses a hash of the topic text, which keeps the filename a fixed
            # length whatever the number of topics.
            topic_text = "\n".join(topic.lower() for topic in topic_list)
            category_str = transform_string(category_name)
            subtopic_str = transform_string(subcategory)
            topic_str = generate_hash(topic_text)

            file_path = os.path.join(
                cached_dir,
                f"{country}_{category_str}_{subtopic_str}_{topic_str}.json",
            )

            from_cache = os.path.exists(file_path)
            if from_cache:
                with open(file_path, "r", encoding="utf-8") as file:
                    response = json.load(file)["response"]
            else:
                response = generate_queries(
                    country,
                    category_name,
                    subcategory,
                    topic_text,
                )

            # Parse (and extract the list) before caching, so only usable
            # responses are persisted.
            queries = parse_queries(response)["queries"]

            if not from_cache:
                save_responses(
                    {
                        "country": country,
                        "category": category_name,
                        "subcategory": subcategory,
                        "topic": topic_list,
                        "response": response,
                    },
                    file_path,
                )

            hierarchical_data.append(
                {
                    "country": country,
                    "category": category_name,
                    "subcategory": subcategory,
                    "topic": topic_list,
                    "queries": queries,
                }
            )

        except Exception as e:
            # Per-entry boundary: log and move on to the next entry.
            logging.error(
                f"Error processing {country} -> {category_name} -> {subcategory}: {e}"
            )

    # Save the hierarchical JSON output
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(hierarchical_data, file, ensure_ascii=False, indent=4)
    logging.info(f"Data saved to {output_file}")


def main():
    """
    Command-line entry point.

    Parses arguments, optionally loads a .env file into the environment,
    reads the topic JSON file, makes sure the cache directory exists, and
    runs ``process_data``. Returns early, after logging an error, if an
    ``--env_file`` was given but does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Generate search queries for Google Image Search"
    )
    parser.add_argument(
        "-c",
        "--topic_file",
        type=str,
        required=True,
        help="Path to the topic JSON file",
    )
    parser.add_argument(
        "-d",
        "--cached_dir",
        type=str,
        default="data/queries/cache_queries/",
        help="Directory to cache responses",
    )
    parser.add_argument(
        "-o",
        "--output_file",
        type=str,
        default="",
        help="Path to save the output JSON file",
    )
    parser.add_argument("-e", "--env_file", type=str, help="Path to the .env file")

    args = parser.parse_args()

    # Fail fast: without this check the run would only crash at the very end,
    # after every model call, when the output file "" is opened.
    if not args.output_file:
        parser.error("the following arguments are required: -o/--output_file")

    env_path = args.env_file
    if env_path is not None:
        if not os.path.exists(env_path):
            logging.error(f"Error: {env_path} not found!")
            return
        # Make the .env variables visible through os.environ before any
        # model client is created.
        load_dotenv(env_path)

    topic_data = read_category_list(args.topic_file)

    cached_dir = args.cached_dir
    os.makedirs(cached_dir, exist_ok=True)

    logging.info(f"Processing topics {len(topic_data)}...")

    process_data(topic_data, cached_dir, args.output_file)


if __name__ == "__main__":
    main()
