import datetime
import json
from collections import OrderedDict
import time
import mm_agents.uipath.llm_client as llm_client
from mm_agents.uipath.types_utils import (
    PlanAction,
    ExecutionState,
    State,
    PlanActionType,
)
from mm_agents.uipath.action_planner_prompt_builder import (
    ComputerUseAgentInterface,
    PlanerCoTSections,
    user_command_template,
    user_task_info_template,
    PlannerOutput,
)
from mm_agents.uipath.utils import ValidationException, parse_message_json


class ActionPlanner(object):
    def __init__(self):
        self.number_history_steps_with_images = 2
        self.computer_use_agent_interface = ComputerUseAgentInterface()

    def build_message_output_format_info(self) -> str:
        """Render the JSON template describing the expected planner reply.

        Returns:
            A pretty-printed JSON object mapping each ``PlanerCoTSections``
            entry's ``display`` name to its ``description``, followed by an
            ``"action"`` placeholder entry.
        """
        format_spec = {
            section["display"]: section["description"]
            for section in PlanerCoTSections.values()
        }

        # The action placeholder is added after all reasoning sections.
        format_spec["action"] = (
            "<The action to perform in JSON format as specified in the system message>"
        )

        return json.dumps(format_spec, indent=4, ensure_ascii=False)

    def get_step_content(
        self, step: dict, following_step: dict | None
    ) -> tuple[str, str]:
        """Serialize one history step into assistant text and a user observation.

        Args:
            step: A previous step; this method reads its ``"actions"``,
                ``"additional_parameters"`` and, for extraction steps,
                ``"description"`` entries.
            following_step: The step executed right after ``step`` or ``None``.
                When given, its ``"review"`` parameter is reported as the
                observation for ``step``.

        Returns:
            ``(content, observation)``: two pretty-printed JSON strings. The
            first holds the step's reasoning sections and planned action, the
            second the performed actions and the optional observation.
        """
        performed_actions = step["actions"]
        params = step["additional_parameters"]

        # An extraction step is reported as a single synthetic extraction action
        # instead of the actions recorded on the step.
        if "extracted_data" in params:
            performed_actions = [
                {
                    "type": PlanActionType.ExtractData,
                    "description": step["description"],
                    "status": "data extracted",
                }
            ]

        observation = {"Performed actions": performed_actions}
        if following_step:
            observation["Observation"] = following_step["additional_parameters"].get(
                "review"
            )

        # Every section except "review" is echoed back under its display name;
        # sections missing from the step are emitted as null.
        content = {
            section["display"]: params.get(name)
            for name, section in PlanerCoTSections.items()
            if name != "review"
        }
        # NOTE(review): the history uses the key "actions" while
        # build_message_output_format_info asks for "action" - confirm intended.
        content["actions"] = json.loads(params["plan_action"])

        content_text = json.dumps(content, indent=4, ensure_ascii=False)
        observation_text = json.dumps(observation, indent=4, ensure_ascii=False)
        return content_text, observation_text

    def build_messages_chat(self, state: State, execution_info: dict) -> list[dict]:
        """Assemble the full chat transcript sent to the planner model.

        The transcript is: the system prompt, a user message with the task and
        today's date, then for every previous step an assistant message and a
        user observation (preceded by the step screenshot for the most recent
        ``number_history_steps_with_images`` steps), and finally a user message
        with the current screenshot and the command prompt.

        Args:
            state: Current state; ``task``, ``previous_steps`` and
                ``image_base64`` are read.
            execution_info: Passed to ``build_execution_info_message`` to
                produce optional feedback text for the final prompt.

        Returns:
            The list of chat message dicts, in order.

        Raises:
            AssertionError: If a step that should carry a screenshot has a
                ``None`` or empty ``"image"``.
        """
        chat = [
            {
                "role": "system",
                "content": self.computer_use_agent_interface.get_system_prompt(),
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": user_task_info_template.format(
                            task=state.task,
                            current_date=datetime.datetime.now().strftime("%Y-%m-%d"),
                        ),
                    }
                ],
            },
        ]

        history = state.previous_steps
        step_count = len(history)
        # Only steps at or after this position get their screenshot attached.
        first_step_with_image = max(
            0, step_count - self.number_history_steps_with_images
        )

        for position, step in enumerate(history):
            if position >= first_step_with_image:
                assert step["image"] is not None and len(step["image"]) > 0, (
                    "Step image is empty"
                )
                chat.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{step['image']}"
                                },
                            },
                        ],
                    }
                )

            # The last step has no successor, so it gets no observation review.
            next_step = history[position + 1] if position + 1 < step_count else None
            assistant_text, observation_text = self.get_step_content(step, next_step)

            chat.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": assistant_text}],
                }
            )
            chat.append(
                {
                    "role": "user",
                    "content": [{"type": "text", "text": observation_text}],
                }
            )

        screenshot_url = f"data:image/jpeg;base64,{state.image_base64}"
        command_text = user_command_template.format(
            task=state.task,
            execution_info_message=self.build_execution_info_message(execution_info),
            json_output_format=self.build_message_output_format_info(),
        )

        chat.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Current screenshot:"},
                    {"type": "image_url", "image_url": {"url": screenshot_url}},
                    {"type": "text", "text": command_text},
                ],
            }
        )
        return chat

    def extract_response(
        self, response_content: str
    ) -> tuple[PlanAction, dict[str, str]]:
        """Parse the planner's raw reply into an action plus its reasoning sections.

        Args:
            response_content: Raw text returned by the LLM; it is handed to
                ``parse_message_json``.

        Returns:
            ``(plan_action, additional_sections)``. ``additional_sections`` maps
            each key of ``PlanerCoTSections`` to the value found under that
            section's display name in the reply.

        Raises:
            ValidationException: If the parsed reply is not a JSON object, if a
                section's display name or the ``"action"`` key is missing, or
                if the action itself is not a JSON object.
        """
        response_json = parse_message_json(response_content)
        # A non-object reply must go through the ValidationException path so the
        # caller can retry, instead of failing later with a TypeError.
        if not isinstance(response_json, dict):
            raise ValidationException(
                f"Invalid response format, expected a JSON object: {response_content}"
            )

        additional_sections = OrderedDict()
        for section, section_info in PlanerCoTSections.items():
            section_display = section_info["display"]
            if section_display not in response_json:
                # Report the display name: it is the key that is actually looked
                # up in the reply, so it is what the model has to add.
                raise ValidationException(
                    f"Invalid response format, '{section_display}' key not found: {response_content}"
                )
            additional_sections[section] = response_json[section_display]

        if "action" not in response_json:
            raise ValidationException(
                f"Invalid response format, 'action' key not found: {response_content}"
            )

        action_dict = response_json["action"]
        if not isinstance(action_dict, dict):
            raise ValidationException(
                f"Invalid response format, 'action' must be a JSON object: {response_content}"
            )

        plan_action = PlanAction.from_dict(self.correct_action_type(action_dict))

        # Only drag actions are passed through validate_action here.
        if plan_action.action_type == PlanActionType.Drag:
            self.computer_use_agent_interface.validate_action(plan_action)

        return plan_action, additional_sections

    def build_execution_info_message(self, execution_info: dict) -> str:
        execution_info_message = ""
        if "planner_action_review" in execution_info:
            action_description = execution_info["planner_action_review"][
                "action_description"
            ]
            error_message = execution_info["planner_action_review"]["error_message"]

            execution_info_message = f"You predicted this action: '{action_description}' but it is not valid because: {error_message}. If the target element is not visible on the screenshot, scroll first to make the target element visible. If the target element is not correct, change the action description with more precise element description using nearby context."
        return execution_info_message

    def correct_action_type(self, response_json: dict) -> dict:
        action_type = response_json.get("type", "").lower()
        if action_type in ("press", "key_press", "press_key"):
            response_json["type"] = "key_press"
        elif action_type in ("mouse_move", "move_mouse"):
            response_json["type"] = "move_mouse"
        elif action_type in ("type_text", "type_into", "type"):
            response_json["type"] = "type"
        elif "scroll" in action_type:
            response_json["type"] = "scroll"
        elif "wait" in action_type:
            response_json["type"] = "wait"
        return response_json

    def predict(self, state: State, execution_state: ExecutionState) -> PlannerOutput:
        """Ask the LLM for the next action, retrying once on an invalid reply.

        The chat is built from ``state`` and sent to the model named by
        ``execution_state.model_name``. If the reply is empty or fails
        validation, the request is repeated with the rejected reply and the
        validation error appended to the original messages.

        Args:
            state: Current state used to build the chat messages.
            execution_state: Supplies ``execution_info`` and ``model_name``.

        Returns:
            The parsed ``PlannerOutput``.

        Raises:
            ValueError: If no valid reply was obtained within the allowed
                attempts; chained from the last ``ValidationException``.
        """
        messages = self.build_messages_chat(state, execution_state.execution_info)
        llm_messages = list(messages)
        max_attempts = 2
        retry_delay_seconds = 5

        for attempt in range(1, max_attempts + 1):
            # Reset per attempt so error reporting never shows a stale reply.
            response_content = None
            try:
                payload = {
                    "model": execution_state.model_name,
                    "messages": llm_messages,
                    "max_completion_tokens": 5000,
                    "reasoning_effort": "medium",
                }
                response_content = llm_client.send_messages(payload)
                if response_content is None or len(response_content.strip()) == 0:
                    raise ValidationException("Planner response is None or empty")
                plan_action, additional_sections = self.extract_response(
                    str(response_content)
                )
                return PlannerOutput(plan_action, additional_sections)
            except ValidationException as e:
                if attempt == max_attempts:
                    # Out of attempts: fail right away instead of sleeping first.
                    raise ValueError(
                        f"Invalid planner response format: {response_content}, {str(e)}"
                    ) from e

                retry_messages = []
                # Echo the rejected reply only when there is one; an assistant
                # message whose text is None is not a meaningful chat turn.
                if response_content is not None and response_content.strip():
                    retry_messages.append(
                        {
                            "role": "assistant",
                            "content": [
                                {
                                    "type": "text",
                                    "text": response_content,
                                },
                            ],
                        }
                    )
                retry_messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": f"{e.message}. Please try again and output a valid response in the correct format.",
                            },
                        ],
                    }
                )

                # Each retry starts from the original messages, so only the
                # most recent failure is shown to the model.
                llm_messages = messages + retry_messages
                time.sleep(retry_delay_seconds)

        # Defensive: only reachable if the loop above runs zero times.
        raise ValueError("Planner response is not valid")
