import argparse
import copy
import json
import logging
import os
import time
from typing import List

import pandas as pd
import requests
from dotenv import load_dotenv
from pydantic import BaseModel
from tqdm import tqdm

# Root logger setup: emit INFO and above, each record prefixed with its
# timestamp and severity.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def load_env_azure(env_path):
    """Load Azure OpenAI settings from a dotenv file and build the request target.

    Args:
        env_path: Path to the dotenv file. Values in it override variables
            that are already set in the process environment.

    Returns:
        A ``(api_url, headers)`` tuple: the chat-completions URL for the
        configured deployment and a header dict carrying the API key.

    Raises:
        KeyError: If ``AZURE_ENGINE_NAME``, ``AZURE_API_URL``,
            ``AZURE_API_KEY`` or ``AZURE_API_VERSION`` is not set after the
            dotenv file has been loaded.
    """
    load_dotenv(dotenv_path=env_path, override=True)
    deployment_name = os.environ["AZURE_ENGINE_NAME"]
    # Endpoints are frequently configured with a trailing slash; strip it so
    # the joined URL does not contain "//openai/...".
    openai_api_base = os.environ["AZURE_API_URL"].rstrip("/")
    openai_api_key = os.environ["AZURE_API_KEY"]
    openai_api_version = os.environ["AZURE_API_VERSION"]
    api_url = f"{openai_api_base}/openai/deployments/{deployment_name}/chat/completions?api-version={openai_api_version}"
    headers = {"api-key": openai_api_key}
    return api_url, headers


def generate_azure(
    api_url, headers, developer_prompt, user_prompt, timeout=(10, 300)
):
    """Send one chat-completion request and return the raw HTTP response.

    Args:
        api_url: Full chat-completions endpoint URL.
        headers: HTTP headers to send (expected to carry the API key).
        developer_prompt: Text sent as the ``system`` message.
        user_prompt: Text sent as the single ``user`` message.
        timeout: Passed to ``requests.post``; either one number of seconds or
            a ``(connect, read)`` tuple. Without a timeout a stalled
            connection would block the caller indefinitely.

    Returns:
        The ``requests.Response`` object. The HTTP status is NOT checked
        here; callers decide how to treat error responses.

    Raises:
        requests.exceptions.RequestException: On connection failures or when
            the timeout elapses.
    """
    # NOTE(review): "json_object" mode presumably makes the model return a
    # top-level JSON object, while the prompts built in this file show a
    # top-level array as the expected output -- confirm this is intended.
    json_data = {
        "messages": [
            {"role": "system", "content": developer_prompt},
            {"role": "user", "content": [{"type": "text", "text": user_prompt}]},
        ],
        "temperature": 0.1,
        "top_p": 0.95,
        "max_tokens": 4000,
        "response_format": {"type": "json_object"},
    }
    return requests.post(api_url, headers=headers, json=json_data, timeout=timeout)


def process_data(api_url, headers, dataset, op_dir_save, max_rows=5):
    """Score the queries of each dataset row with one request per row.

    For every processed row a file ``<id>.json`` is written to
    ``op_dir_save`` containing the row's fields plus a ``"response"`` key
    holding the decoded API reply. Rows whose output file already exists are
    skipped, so an interrupted run can be resumed.

    Args:
        api_url: Chat-completions endpoint URL.
        headers: HTTP headers for the request.
        dataset: pandas DataFrame; each row must provide the columns ``id``,
            ``filterd_queries`` and ``country``.
        op_dir_save: Existing directory that receives the output files.
        max_rows: Number of leading rows to process (default 5, the limit
            that was previously hard-coded). ``None`` processes every row.

    Returns:
        None. Request and write failures are logged and the row is left
        without an output file so that a later run retries it.
    """
    developer_prompt = """You are an expert in evaluating search query effectiveness for image search. Your task is to rank image search queries based on their relevance to a given location. Focus on specificity, uniqueness, and cultural significance when ranking them. Assign a relevance score from 1 to 100 and return the output in JSON format."""

    subset = dataset.iloc[:max_rows]
    for row_idx in tqdm(range(len(subset))):
        data_obj = subset.iloc[row_idx].to_dict()

        file_id = data_obj["id"]
        # f-string rather than "+" so that non-string ids (e.g. integers
        # parsed by pandas) do not raise TypeError.
        save_file_name = os.path.join(op_dir_save, f"{file_id}.json")

        # Check before building the prompt: nothing below is needed for rows
        # that were already processed.
        if os.path.exists(save_file_name):
            logging.info(f"Skipping existing file: {save_file_name}")
            continue

        query_list = data_obj["filterd_queries"]
        location = data_obj["country"]

        user_prompt = f"""
        Given the following list of **image search queries** related to {location}, **evaluate and assign a relevance score to every single query**.

        Each query must receive a **relevance score from 1 to 100**, where:
        - 100 represents the highest relevance.
        - Higher scores go to queries highlighting **iconic landmarks, cultural elements, or unique aspects of {location}**.
        - Queries mentioning a **location outside of {location}** should receive a **low relevance score**.
        - **DO NOT skip any query** – every query in the list must be assigned a score.
        - Queries are in **Arabic and English** – evaluate both equally. 
        - Queries those are not related to {location} should receive a very low score.

        ### **List of Queries:**  
        ```json
        {json.dumps(query_list, ensure_ascii=False)}
        ```

        **Expected JSON Output Format:**        
        ```json
        [
        {{"Q": "Eiffel Tower at sunset", "score": 100}},
        {{"Q": "Paris street art", "score": 90}},
        ]
        ```        
        """

        try:
            response = generate_azure(api_url, headers, developer_prompt, user_prompt)
            # Do not persist HTTP error payloads (rate limits, server
            # errors): a saved file makes later runs skip this row forever.
            response.raise_for_status()
            data_obj["response"] = response.json()

            # Serialise before opening the file so a serialisation error
            # cannot leave a truncated file behind.
            payload = json.dumps(data_obj, ensure_ascii=False, indent=4)
            # Explicit UTF-8: the payload contains non-ASCII text and the
            # platform default encoding may not be able to represent it.
            with open(save_file_name, "w", encoding="utf-8") as json_file:
                json_file.write(payload)
            logging.info(f"Processed file: {save_file_name}")
        except Exception as e:
            # Best-effort batch job: log, pause briefly, move on to next row.
            logging.error(f"Error processing file {file_id}: {e}")
            time.sleep(2)


def process_data_batch(api_url, headers, dataset, op_dir_save):
    """Score every row's queries in small batches, one output file per batch.

    Each row's query list is cut into groups of 20, and every group is cut
    again into sub-groups of 10. One request is sent per sub-group and the
    decoded reply is written to
    ``<id>_batch<batch_idx>_nbatch<n_batch_idx>.json`` in ``op_dir_save``.
    Sub-groups whose file already exists are skipped, so the function can be
    re-run to resume or to retry failed requests.

    Args:
        api_url: Chat-completions endpoint URL.
        headers: HTTP headers for the request.
        dataset: pandas DataFrame; each row must provide the columns ``id``,
            ``filterd_queries``, ``country``, ``category``, ``subcategory``
            and ``topic`` (the last three must be strings).
        op_dir_save: Existing directory that receives the output files.

    Returns:
        None. Request and write failures are logged and leave no output
        file, so a later run retries them.
    """
    # The two sizes determine the output file names; changing them would
    # make existing result files no longer line up with their queries.
    batch_size = 20
    n_batch_size = 10

    developer_prompt = """You are an expert in evaluating search query effectiveness for image search. Your task is to rank image search queries based on their relevance to a given location. Focus on specificity, uniqueness, and cultural significance when ranking them. Assign a relevance score from 1 to 100 and return the output in JSON format."""

    for index in tqdm(range(len(dataset))):
        data_obj = dataset.iloc[index].to_dict()
        file_id = data_obj["id"]
        query_list = data_obj["filterd_queries"]
        location = data_obj["country"]
        category = data_obj["category"]
        subcategory = data_obj["subcategory"]
        topic = data_obj["topic"]

        batches = [
            query_list[i : i + batch_size]
            for i in range(0, len(query_list), batch_size)
        ]

        for batch_idx, batch in enumerate(batches):
            batch_file_name = os.path.join(
                op_dir_save, f"{file_id}_batch{batch_idx}.json"
            )

            # NOTE(review): this function never writes a batch-level file;
            # presumably the check honours output produced by an earlier
            # version of the script -- confirm before removing.
            if os.path.exists(batch_file_name):
                continue

            n_batches = [
                batch[i : i + n_batch_size] for i in range(0, len(batch), n_batch_size)
            ]

            for n_batch_idx, n_batch in enumerate(n_batches):
                n_batch_file_name = os.path.join(
                    op_dir_save, f"{file_id}_batch{batch_idx}_nbatch{n_batch_idx}.json"
                )

                if os.path.exists(n_batch_file_name):
                    continue

                # The last rule originally read "Do generate empty json
                # object.", which inverted the intended instruction.
                user_prompt = f"""
                Given the following list of **image search queries** related to:
                **Location:** {location},
                **Category:** {category.lower()},
                **Subcategory:** {subcategory.lower()}, and
                **Topic:** {topic.lower()}, evaluate each query and assign a relevance score.

                Each query must receive a **relevance score from 1 to 100**, where:
                - 100 represents the highest relevance.
                - Higher scores go to queries highlighting **iconic landmarks, cultural elements, or unique aspects of {location}**.
                - Queries mentioning a **location outside of {location}** should receive a **low relevance score**.
                - **DO NOT skip any query** – every query in the list must be assigned a score.
                - Queries are in **Arabic and English** – evaluate both equally.
                - Queries those are not related to {location} should receive a very low score within.
                - In all condition, you should maintain the expect json output format.
                - You must predict a score for given queries.
                - Do not generate empty json object.
                
                ### **List of Queries:**
                ```json
                {json.dumps(n_batch, ensure_ascii=False)}
                ```
                
                **Expected JSON Output Format:**
                ```json
                [
                    {{"Q": "Eiffel Tower at sunset", "score": 10}},
                    {{"Q": "Corniche promenade sunset photography", "score": 95}}
                ]
                ```
                """

                try:
                    response = generate_azure(
                        api_url, headers, developer_prompt, user_prompt
                    )
                    # Do not persist HTTP error payloads (rate limits,
                    # server errors): a saved file makes later runs skip
                    # this sub-group forever.
                    response.raise_for_status()
                    response_data = response.json()

                    # Serialise first so a failure cannot leave a truncated
                    # file that would later be mistaken for a finished one.
                    payload = json.dumps(response_data, ensure_ascii=False, indent=4)
                    # Explicit UTF-8: replies contain non-ASCII text and the
                    # platform default encoding may not represent it.
                    with open(n_batch_file_name, "w", encoding="utf-8") as json_file:
                        json_file.write(payload)
                    logging.info(f"Processed n-batch file: {n_batch_file_name}")
                    time.sleep(1)  # Prevent API rate limiting
                except Exception as e:
                    # Best-effort batch job: log which sub-group failed,
                    # pause briefly, and continue with the next one.
                    logging.error(
                        f"Error processing n-batch file {n_batch_file_name}: {e}"
                    )
                    time.sleep(2)


def main(input_file, env_path, output_dir):
    """Run the batched query-scoring pipeline end to end.

    Args:
        input_file: Path to the JSON Lines dataset to score.
        env_path: Path to the dotenv file with the Azure OpenAI settings.
        output_dir: Directory that receives the result files; it is created
            (including parents) when it does not exist.
    """
    # Use the directory exactly as given. Taking os.path.dirname() of it
    # pointed at the PARENT directory (or the current working directory for
    # a bare name) whenever the argument had no trailing slash.
    output_dir_path = os.path.abspath(output_dir)
    os.makedirs(output_dir_path, exist_ok=True)

    api_url, headers = load_env_azure(env_path)
    dataset = pd.read_json(input_file, lines=True)

    logging.info(f"Number of samples in the dataset: {len(dataset)}")

    process_data_batch(api_url, headers, dataset, output_dir_path)


if __name__ == "__main__":
    # Command-line entry point: three required string options, declared as
    # (short flag, long flag, help text) and registered in a single loop.
    cli = argparse.ArgumentParser(description="Process some data.")
    option_specs = (
        ("-i", "--input_file", "Path to the input file"),
        ("-e", "--env_path", "Path to the environment file"),
        ("-o", "--output_dir", "Path to the output directory"),
    )
    for short_flag, long_flag, help_text in option_specs:
        cli.add_argument(
            short_flag, long_flag, type=str, required=True, help=help_text
        )

    parsed = cli.parse_args()
    main(parsed.input_file, parsed.env_path, parsed.output_dir)
