import json
import sys

from llms.llm_utils import visualize_prompt

# Placeholder cache for built trajectory messages.
# NOTE(review): no code visible in this module reads or writes it — message
# caching in get_trajectory_osw is not implemented yet.
cache_trajectory_msgs = {}

import os
from pathlib import Path

from llms.prompt_utils import get_message
from osw_utils.osw_utils import TrajectoryView, annotate_action_on_image, trace_to_english
from utils.logger_utils import logger

if __name__ == "__main__" and not __package__:  # @debug
    # Debug-only: when this file is run directly as a script (no package
    # context), prepend the parent of this file's directory to sys.path.
    # NOTE(review): this executes *after* the module-level project imports
    # above (llms.*, osw_utils.*, utils.*), so it cannot make those imports
    # resolve; it only affects imports performed later. If that was the
    # intent, it needs to run before those imports.
    parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    sys.path.insert(0, parent_dir)
    # Manually set the package name so that relative imports work
    __package__ = "offline_experiments"


def get_trace_data_osw(
    trajectory_path: str,
    task_id: str | int | None = None,
):
    """Load an OSWorld trajectory via ``trace_to_english`` and attach path metadata.

    Args:
        trajectory_path: Path to the trajectory file; passed to
            ``trace_to_english(trace_path=...)``.
        task_id: Currently unused. Kept for callers that pass it; defaults to
            ``None`` so it may be omitted.

    Returns:
        The mapping returned by ``trace_to_english``, with two extra entries:
        ``"base_path"`` (the trajectory file's parent directory as a ``Path``)
        and ``"trajectory_path"`` (the path exactly as given).
    """
    trace_data = trace_to_english(trace_path=trajectory_path)
    trace_data["base_path"] = Path(trajectory_path).parent
    trace_data["trajectory_path"] = trajectory_path
    return trace_data


def get_intent_message_osw(
    trace_data,
    add_state_idxs: list[int] | None = None,
    state_img_intros: list[str] | None = None,
    objective_template: str = "## OBJECTIVE: {objective}",
    objective_first: bool = True,
):
    """Build prompt inputs describing the task objective plus optional screenshots.

    Args:
        trace_data: Trace data accepted by ``TrajectoryView``; must contain an
            ``"objective"`` entry.
        add_state_idxs: Indices into the trajectory's states whose *last*
            observation image is included, in the given order. ``None`` or
            empty adds no images.
        state_img_intros: Optional captions appended to each ``"Image (k)"``
            label. NOTE(review): a caption is looked up as
            ``state_img_intros[state_idx]`` (by state index, not by position
            in ``add_state_idxs``) — confirm this is intended.
        objective_template: Format string with an ``{objective}`` placeholder.
        objective_first: If True the objective precedes the images, otherwise
            it follows them.

    Returns:
        A flat list of prompt inputs: the formatted objective string and, per
        requested state, a label string followed by that state's last image.
    """
    # None defaults avoid sharing one mutable list object across calls.
    add_state_idxs = add_state_idxs or []
    state_img_intros = state_img_intros or []

    trajectory = TrajectoryView(trace_data, to_english=True)
    states = trajectory.states
    intent_str = trace_data["objective"]

    inputs = []
    # NOTE(review): the header is emitted when captions are supplied, not when
    # images are requested; kept as-is to preserve existing prompts.
    if state_img_intros:
        inputs.append("## IMAGES:\n")
    for i, state_idx in enumerate(add_state_idxs):
        img_obs = states[state_idx]["observation"]["images"][-1]

        img_intro = f"Image ({i + 1})"
        if state_img_intros:
            img_intro += f": {state_img_intros[state_idx]}"
        inputs.append(img_intro)
        inputs.append(img_obs)

    str_objective = objective_template.format(objective=intent_str)

    if objective_first:
        return [str_objective, *inputs]
    return [*inputs, str_objective]


def get_interaction_history_message(
    trace_data,
    add_state_idxs: list[int] | None = None,
    num_states: int = 0,
    last_img_per_state: bool = True,
    use_text_obs=False,
    use_img_obs=True,
    intro_txt_obs_template="",
    intro_img_obs_template="## STATE t-{t} screenshot",
    use_thoughts: bool = True,
    use_actions: bool = True,
    name_user="user",
    name_assistant="assistant",
    img_detail: str = "auto",
    annotate_action: bool = True,
):
    """Render selected trajectory steps as alternating user/assistant messages.

    For each selected state index (ascending), emits a user message holding the
    state's observation (optional text intro, text observation, image intro and
    image observations) followed by an assistant message holding that step's
    action text.

    Args:
        trace_data: Trace data accepted by ``TrajectoryView``.
        add_state_idxs: State indices to include; sorted before use. If ``None``
            or empty, ``range(len(states) - num_states)`` is used, i.e. every
            state except the last ``num_states``.
        num_states: Only consulted when ``add_state_idxs`` is empty.
        last_img_per_state: Keep only the last image of each state.
        use_text_obs: Include ``observation["text"]`` (and its intro template).
        use_img_obs: Include ``observation["images"]`` (and its intro template).
        intro_txt_obs_template: Format string with a ``{t}`` placeholder.
        intro_img_obs_template: Format string with a ``{t}`` placeholder.
            ``t`` counts down from ``len(idxs)`` so the highest included index
            gets ``t=1``.
        use_thoughts: Currently unused by this function.
        use_actions: Currently unused by this function.
        name_user: ``name`` passed to ``get_message`` for user messages.
        name_assistant: ``name`` passed to ``get_message`` for assistant messages.
        img_detail: Forwarded to ``get_message``.
        annotate_action: Pass the state's last image and the action text through
            ``annotate_action_on_image``. Skipped when the state contributes no
            images; failures are logged and the original images are kept.

    Returns:
        Messages from ``get_message`` in order ``[user, assistant, user, ...]``.
    """
    trajectory = TrajectoryView(trace_data, to_english=True)

    states = trajectory.states
    if not add_state_idxs:
        add_state_idxs = list(range(len(states) - num_states))

    idxs = sorted(add_state_idxs)

    messages = []
    # t counts down from len(idxs) to 1 so labels read "t-N" ... "t-1".
    for t, idx in zip(range(len(idxs), 0, -1), idxs):
        action = trajectory.actions[idx]
        state = states[idx]

        text_obs = state["observation"]["text"] if use_text_obs else ""
        img_observations = state["observation"]["images"] if use_img_obs else []
        if last_img_per_state and img_observations:
            img_observations = [img_observations[-1]]

        intro_txt_obs = ""
        if use_text_obs and intro_txt_obs_template:
            intro_txt_obs = intro_txt_obs_template.format(t=t)

        intro_img_obs = ""
        if use_img_obs and intro_img_obs_template:
            intro_img_obs = intro_img_obs_template.format(t=t)

        # Prefer the English rendering of the action when present.
        if eng_texts := action.get("texts_en"):
            text_generation = eng_texts[0]
        else:
            text_generation = action["texts"][0]

        # Guard on a non-empty image list: indexing [] here previously raised
        # IndexError, and the warning below then re-raised it while formatting.
        if annotate_action and img_observations:
            try:
                # NOTE: replaces the list with the single annotated last image,
                # even when last_img_per_state is False.
                img_observations = [annotate_action_on_image(img_observations[-1], text_generation)]
            except Exception as e:
                # Best-effort: keep the unannotated images on failure.
                logger.warning(
                    f"{__file__}: Failed to annotate action on image: {e}. Action: {text_generation}. Image: {img_observations[-1]}"
                )

        user_msg = get_message(
            inputs=[intro_txt_obs, text_obs, intro_img_obs, *img_observations],
            role="user",
            name=name_user,
            img_detail=img_detail,
        )
        assistant_msg = get_message(
            inputs=text_generation,
            role="assistant",
            name=name_assistant,
            img_detail=img_detail,
        )

        messages.extend([user_msg, assistant_msg])
    return messages


def get_trajectory_osw(config, trace_data):
    """Build interaction-history messages for a trajectory according to ``config``.

    Args:
        config: Experiment config mapping. Reads
            ``config["prompt_args"]["state_idxs"]`` (default ``[]``, which lets
            ``get_interaction_history_message`` choose its default states),
            ``config["img_detail"]`` (default ``"auto"``) and
            ``config["additional_config"]["annotate_actions_on_image"]``
            (default ``True``).
        trace_data: Trace data as produced by ``get_trace_data_osw``.

    Returns:
        The message list from ``get_interaction_history_message``.

    Raises:
        KeyError: If ``config`` has no ``"prompt_args"`` entry.
    """
    prompt_config = config["prompt_args"]
    state_idxs = prompt_config.get("state_idxs", [])

    img_detail = config.get("img_detail", "auto")
    annotate_action = config.get("additional_config", {}).get("annotate_actions_on_image", True)

    # TODO: cache the built messages, e.g. in cache_trajectory_msgs keyed on
    # (prompt settings, tuple(state_idxs), trace_data["trajectory_path"]).
    # prompt_args["trace_info"] is not consulted yet.
    return get_interaction_history_message(
        trace_data, state_idxs, img_detail=img_detail, annotate_action=annotate_action
    )


if __name__ == "__main__":
    # Debug entry point: render the messages for one sample trajectory to ./vis.html.
    trajectory_path = (
        "trace_data_osworld/ui-tars-1.5_50steps_2025-04-05/chrome/7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3/trajectory.json"
    )
    # Presumably the task id is the directory containing trajectory.json — confirm.
    task_id = Path(trajectory_path).parent.name
    config = {"prompt_args": {"trace_info": "utt", "state_idxs": []}}
    # Pass task_id explicitly so the call matches get_trace_data_osw's signature
    # (calling without it raised TypeError).
    trace_data = get_trace_data_osw(trajectory_path, task_id)
    msgs = get_trajectory_osw(config, trace_data)
    visualize_prompt(msgs, "./vis.html")
