import asyncio
import json
import os
import sys

from llms.llm_utils import visualize_prompt
from llms.prompt_utils import get_conversation_payload_size, get_messages
from offline_experiments.vwa_specific import get_msg_for_url_link
from utils.logger_utils import logger

if __name__ == "__main__" and not __package__:  # @debug
    # Executed directly as a script (not via `python -m`), this module has no
    # package context, so the relative imports that follow would fail.
    # Make the directory one level above this file importable, then declare
    # the package name by hand so `from .xxx import ...` resolves.
    # NOTE(review): the absolute imports at the top of the file run before this
    # path tweak, so `llms`, `offline_experiments` and `utils` must already be
    # importable some other way (e.g. PYTHONPATH / working directory) — confirm.
    parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
    sys.path.insert(0, parent_dir)
    __package__ = "offline_experiments"


from .build_prompt import get_k_injection, get_verifier_prompts
from .config_run import build_all_eval_configs, gen_config, run_config
from .first_pass import get_k_resp_from_cache
from .runners import run_batch_mode, run_sequential
from .utils_offline_exper import get_intent_message, get_trace_data, get_trajectory_msgs

# NOTE(review): not referenced anywhere in the visible code; possibly a leftover
# cache — confirm before relying on or removing it.
cache_trajectory_msgs = {}
# Running counters, incremented by `build_llm_call_args` whenever a cached
# first-pass response is injected and reported by the `__main__` summary.
total_tokens, total_prompts = 0, 0


# ===============================================
# Task data to load
# ===============================================
def get_image(trajectory_msgs):
    """Return the last message of ``trajectory_msgs``, or ``None`` if there is none.

    Presumably the final trajectory message carries the latest screenshot — confirm
    against the producer of ``trajectory_msgs``.

    Args:
        trajectory_msgs: Sequence of trajectory messages (may be empty or ``None``).

    Returns:
        The last element, or ``None`` for an empty/``None`` input.
    """
    if not trajectory_msgs:
        return None
    return trajectory_msgs[-1]


def build_llm_call_args(
    task_id, config, run_config
) -> "tuple[list | None, str | None, str | None, str | None]":
    """Assemble the verifier prompt and output locations for a single task.

    Args:
        task_id: Task identifier. Used to locate the cached first-pass response,
            the trace data, and the per-task conversation file.
        config: Evaluation config. This function reads ``out_dir``, ``prompt_args``
            (optionally ``prompt_args["k_config"]["cached_k_dir"]``) and ``env``.
            The whole dict is also forwarded to the prompt/trace helpers.
        run_config: Run-level options. Reads ``overwrite`` (required) and
            ``skip_payload`` (optional size threshold compared against
            ``get_conversation_payload_size``; the log message converts it to MB,
            so presumably bytes — confirm).

    Returns:
        ``(full_prompt, conversation_dir, usage_dir, task_id)`` on success.
        ``(None, None, None, None)`` when any of these holds:
        the task is skipped (existing output, payload too large), a required
        input is missing, or building the prompt raises. Exceptions are logged
        with a traceback rather than propagated.

    Side effects:
        Increments the module-level ``total_tokens`` / ``total_prompts`` counters
        each time a cached first-pass response is injected.
    """
    global total_tokens, total_prompts
    skipped = (None, None, None, None)
    try:
        # NOTE(review): strip('-') removes leading/trailing '-' from out_dir;
        # the reason is not visible here — confirm it is intended.
        out_dir = config["out_dir"].strip("-")
        conversation_dir = f"{out_dir}/conversation"
        usage_dir = f"{out_dir}/usage"

        # Unless overwriting, skip tasks whose conversation file already exists.
        if not run_config["overwrite"]:
            full_conversation_path = f"{conversation_dir}/{task_id}.txt"
            if os.path.exists(full_conversation_path):
                logger.info(f"Skipping {task_id} because {full_conversation_path} exists.")
                return skipped

        # Optionally inject a cached first-pass response ("k") into the prompt.
        final_k = None
        k_config = config["prompt_args"].get("k_config")
        if k_config:
            k_resp = get_k_resp_from_cache(file_path=f"{k_config['cached_k_dir']}/conversation/{task_id}.html")
            if not k_resp:
                logger.error(f"Failed to build prompt for {config} {task_id}: unable to get first pass response.")
                return skipped

            # NOTE(review): len(k_resp) is a character count if k_resp is a str,
            # not a model token count — confirm what "tokens" should mean here.
            total_tokens += len(k_resp)
            total_prompts += 1
            final_k = get_k_injection(config).format(k=k_resp)

        trace_data = get_trace_data(config, task_id)
        if not trace_data:
            logger.error(f"Failed to build prompt for {config} {task_id}: unable to get trace data.")
            return skipped

        msg_intent = get_intent_message(config, trace_data, add_state_idxs=[], state_img_intros=[])
        trajectory_msgs = get_trajectory_msgs(config, trace_data)

        verifier_prompts = get_verifier_prompts(config)
        sys_prompt, eval_prompt = verifier_prompts["sys_prompt"], verifier_prompts["eval_prompt"]

        if config["env"] == "vwa":
            msg_for_url_link = get_msg_for_url_link(trajectory_msgs)
            if msg_for_url_link:
                logger.info(f"Adding IMG links for {task_id}, {config}")
                # Appends in place to the list returned by get_trajectory_msgs.
                trajectory_msgs.append(msg_for_url_link)

        # final_k is None when no k_config is set; presumably get_messages
        # tolerates/drops None entries — confirm.
        full_prompt = [{"role": "system", "content": sys_prompt}, msg_intent, trajectory_msgs, final_k, eval_prompt]

        max_payload = run_config.get("skip_payload", 0)
        if max_payload:
            payload_size = get_conversation_payload_size(get_messages(full_prompt))
            if payload_size > max_payload:
                logger.info(
                    f"Skipping {task_id} because payload size {payload_size / 1024 / 1024} MB is greater than {max_payload / 1024 / 1024} MB"
                )
                return skipped

        logger.info(f"VERIFY: Finished building llm call args for task {task_id}, config {config}")
        return full_prompt, conversation_dir, usage_dir, task_id
    except Exception as e:
        # Per-task boundary: record the failure with its traceback and skip the
        # task instead of aborting the whole run.
        logger.exception(f"Failed to build prompt for {config} {task_id}: {e}")
        return skipped


if __name__ == "__main__":
    # Script entry point: build every eval config, run the verifier over them
    # (batch or sequential), then print the counters accumulated by
    # build_llm_call_args.
    batch_mode = run_config["batch_mode"]
    all_configs = build_all_eval_configs()

    if not all_configs:
        logger.info("No configs to run")
        sys.exit()

    if batch_mode:
        logger.info("Running verify in batch mode")
        run_batch_mode(
            all_configs,
            run_config=run_config,
            gen_config=gen_config,
            build_llm_call_args_fn=build_llm_call_args,
        )
    else:
        logger.info("Running verify in sequential mode")
        asyncio.run(
            run_sequential(
                all_configs,
                run_config=run_config,
                gen_config=gen_config,
                build_llm_call_args_fn=build_llm_call_args,
            )
        )

    # NOTE(review): these are module globals mutated inside build_llm_call_args;
    # if a runner calls it in other processes, the counts here will not reflect
    # that work — confirm.
    print(f"Total tokens: {total_tokens}")
    print(f"Total prompts: {total_prompts}")
    # total_prompts stays 0 when no config uses k_config or every task was
    # skipped; avoid a ZeroDivisionError after a completed run.
    if total_prompts:
        print(f"Average tokens per prompt: {total_tokens / total_prompts}")
    else:
        print("Average tokens per prompt: n/a (no prompts counted)")
