"""
OpenCUA Agent for OfficeArena.

Ported from OSWorld implementation:
https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/opencua/opencua_agent.py

Uses litellm as the unified API backend.
"""

import json
import re
import time
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

import litellm
from PIL import Image

from officearena.adapter.action_adapter import ActionAdapter
from officearena.agents.base import BaseAgent
from officearena.agents.config import OpenCUAConfig

from .prompts import (
    ACTION_HISTORY_TEMPLATE,
    INSTRUCTION_TEMPLATE,
    OBSERVATION_HISTORY_TEMPLATE,
    STEP_TEMPLATE,
    SYSTEM_PROMPT_V1_L1,
    SYSTEM_PROMPT_V1_L2,
    SYSTEM_PROMPT_V1_L3,
    THOUGHT_HISTORY_TEMPLATE,
    build_sys_prompt,
)
from .utils import process_image_for_opencua, project_coordinate_to_absolute_scale


def parse_response_to_cot_and_action(input_string: str, screen_size: Tuple[int, int], coordinate_type: str) -> Tuple[str, List[str], dict]:
    """
    Split an OpenCUA model reply into chain-of-thought sections and one action.

    The reply may contain markdown ``## Observation`` / ``## Thought`` /
    ``## Action`` headers and must contain at least one fenced code block; only
    the LAST code block is interpreted as the action.

    Args:
        input_string: Raw text returned by the LLM.
        screen_size: ``(width, height)`` of the screen, forwarded to
            ``project_coordinate_to_absolute_scale``.
        coordinate_type: Coordinate convention forwarded to
            ``project_coordinate_to_absolute_scale``.

    Returns:
        A tuple ``(description, actions, sections)``:

        - ``description``: the ``## Action`` text; the raw code block for
          terminate actions; or an ``"<Error>: ..."`` string on any failure.
        - ``actions``: a one-element list holding ``"WAIT"``, ``"DONE"``,
          ``"FAIL"`` or the coordinate-projected PyAutoGUI code.
        - ``sections``: whichever of ``observation``, ``thought``, ``action``,
          ``original_code`` and ``code`` were parsed before returning.
    """
    sections: Dict[str, Any] = {}

    def _blank(key: str) -> bool:
        # A section counts as missing when it is absent, None, or "".
        value = sections.get(key)
        return value is None or value == ""

    # (section name, header regex) pairs; each capture stops at the next "##" header or end of text.
    header_patterns = (
        ("observation", r"^##\s*Observation\s*:?[\n\r]+(.*?)(?=^##\s*Thought:|^##\s*Action:|^##|\Z)"),
        ("thought", r"^##\s*Thought\s*:?[\n\r]+(.*?)(?=^##\s*Action:|^##|\Z)"),
        ("action", r"^##\s*Action\s*:?[\n\r]+(.*?)(?=^##|\Z)"),
    )

    try:
        for section_name, pattern in header_patterns:
            hit = re.search(pattern, input_string, re.DOTALL | re.MULTILINE)
            if hit:
                sections[section_name] = hit.group(1).strip()

        fenced_blocks = re.findall(r"```(?:code|python)?\s*(.*?)\s*```", input_string, re.DOTALL | re.IGNORECASE)
        if not fenced_blocks:
            print("No code blocks found in the input string")
            return f"<Error>: no code blocks found in the input string: {input_string}", ["FAIL"], sections

        snippet = fenced_blocks[-1].strip()
        sections["original_code"] = snippet
        snippet_lower = snippet.lower()

        if "computer.wait" in snippet_lower:
            sections["code"] = "WAIT"
            return sections.get("action", "Wait for operation to complete"), ["WAIT"], sections

        if "computer.terminate" in snippet_lower:
            # "fail" also matches "failure"; a failure status wins over success.
            if "fail" in snippet_lower:
                verdict = "FAIL"
            elif "success" in snippet_lower:
                verdict = "DONE"
            else:
                print("Terminate action found but no specific status provided in code block")
                return f"<Error>: terminate action found but no specific status provided in code block: {input_string}", ["FAIL"], sections
            sections["code"] = verdict
            return snippet, [verdict], sections

        # Record the raw code first so it is still reported if projection raises.
        sections["code"] = snippet
        sections["code"] = project_coordinate_to_absolute_scale(
            snippet,
            screen_width=screen_size[0],
            screen_height=screen_size[1],
            coordinate_type=coordinate_type,
        )

        if _blank("code") or _blank("action"):
            print("Missing required action or code section")
            return f"<Error>: no code parsed: {input_string}", ["FAIL"], sections

        return sections["action"], [sections["code"]], sections

    except Exception as e:
        error_message = f"<Error>: parsing response: {str(e)}\nTraceback:\n{traceback.format_exc()}\nInput string: {input_string}"
        print(error_message)
        return error_message, ["FAIL"], sections


class OpenCUAActionAdapter(ActionAdapter):
    """
    Action adapter for OpenCUA agent.

    Consumes the JSON string produced by ``OpenCUA.step``:

    - ``message`` items whose text contains ``"DONE"`` end the task.
    - ``computer_call`` items are executed:
      ``finish`` ends the task, ``wait`` sleeps locally, ``pyautogui`` runs the
      code through the sandbox's ``execute_python_command`` method, and any
      other type is forwarded to the parent adapter's ``_dispatch_action``.
    """

    @staticmethod
    def _extract_message_text(item: Dict[str, Any]) -> str:
        """
        Return the text carried by a ``message`` output item.

        Accepts ``content`` as a plain string or as a list of parts with a
        ``"text"`` key. Missing/empty content and non-dict or non-string parts
        yield ``""`` instead of raising. List parts are joined with newlines
        so separate fragments cannot fuse into a spurious keyword.
        """
        content = item.get("content")
        if isinstance(content, str):
            return content
        if not isinstance(content, list):
            return ""
        texts = [part.get("text") for part in content if isinstance(part, dict)]
        return "\n".join(text for text in texts if isinstance(text, str))

    def execute_action(self, action_str: str) -> Dict[str, Any]:
        """
        Execute an action from OpenCUA agent.

        Args:
            action_str: JSON string containing action output from OpenCUA
                (an object with an ``"output"`` list of items).

        Returns:
            Dictionary with ``success``, ``task_completed``, ``action_type`` and
            either ``result`` or ``error``. When several ``computer_call`` items
            are present, they are executed in order and the result of the last
            one is reported. Any exception is converted into an error result.
        """
        try:
            action_data = json.loads(action_str)
            output = action_data.get("output", [])

            last_result = None
            last_action_type = None
            found_any_computer_call = False

            for item in output:
                item_type = item.get("type")

                # OpenCUA.step signals both success and failure with a message
                # whose text contains "DONE".
                if item_type == "message" and "DONE" in self._extract_message_text(item):
                    return {
                        "success": True,
                        "task_completed": True,
                        "result": "Task finished by agent.",
                        "action_type": "finish_message",
                    }

                if item_type != "computer_call":
                    continue

                found_any_computer_call = True
                action = item.get("action", {})
                action_type = action.get("type", "")

                if action_type == "finish":
                    return {
                        "success": True,
                        "task_completed": True,
                        "result": action.get("message", "DONE. Task completed."),
                        "action_type": "finish",
                    }

                if action_type == "wait":
                    duration = action.get("duration", 20)
                    time.sleep(duration)
                    last_result = f"Waited for {duration} seconds"
                    last_action_type = "wait"
                elif action_type == "pyautogui":
                    code = action.get("code", "")
                    if code:
                        last_result = self._execute_pyautogui_code(code)
                        last_action_type = "pyautogui"
                elif action_type:
                    # Fallback to the parent dispatch. Pass a copy without
                    # "type" so the parsed payload itself is not mutated.
                    params = {key: value for key, value in action.items() if key != "type"}
                    last_result = self._dispatch_action(action_type, params)
                    last_action_type = action_type

            if not found_any_computer_call:
                return {
                    "success": False,
                    "task_completed": False,
                    "error": "No computer_call action found in the output.",
                    "action_type": "error",
                }

            return {
                "success": True,
                "task_completed": False,
                "result": last_result,
                "action_type": last_action_type or "computer_call",
            }

        except Exception as e:
            print(f"Error in execute_action: {e}")
            return {
                "success": False,
                "task_completed": False,
                "error": str(e),
                "action_type": "error",
            }

    def _execute_pyautogui_code(self, code: str) -> Any:
        """
        Execute PyAutoGUI code in the sandbox.

        A leading code-fence line (e.g. a line starting with three backticks)
        and a trailing line consisting only of a closing fence are removed
        before execution.

        Args:
            code: PyAutoGUI Python code to execute.

        Returns:
            Whatever ``self.sandbox.execute_python_command`` returns.
        """
        code = code.strip()
        if code.startswith("```"):
            # Drop the opening fence line (it may carry a language tag).
            lines = code.split("\n")[1:]
            if lines and lines[-1].strip() == "```":
                lines = lines[:-1]
            code = "\n".join(lines)

        return self.sandbox.execute_python_command(
            import_prefix=["pyautogui"],
            command=code,
        )


class OpenCUA(BaseAgent):
    """
    OpenCUA Agent for OfficeArena.

    This agent uses the OpenCUA model to interact with the environment using
    PyAutoGUI-based actions. It is based on the OSWorld OpenCUA implementation.

    On every ``step`` the agent sends several things to the model through
    ``litellm.completion``: the system prompt, the history of previous steps
    (older ones as text only, the most recent ones together with their stored
    screenshots) and the current screenshot. The reply is parsed by
    ``parse_response_to_cot_and_action`` and returned as a JSON string that
    ``OpenCUAActionAdapter`` executes.
    """

    def __init__(self, config: Optional[Union[Dict[str, Any], OpenCUAConfig]] = None):
        """
        Initialize the OpenCUA agent.

        Args:
            config: Configuration for the agent, accepted in one of three forms:

                - an ``OpenCUAConfig``;
                - a dict of its fields. A dict that ``OpenCUAConfig`` rejects
                  with ``TypeError`` is retried after converting a legacy
                  ``display_size`` entry into ``display_width`` /
                  ``display_height``;
                - ``None``, which uses ``OpenCUAConfig()`` defaults.

        Raises:
            ValueError: If ``config`` has an unsupported type, if ``model_name``
                or ``api_key`` is empty, if ``cot_level`` is invalid while
                ``use_old_sys_prompt`` is set, or if ``history_type`` is unknown.
            Errors raised by ``OpenCUAConfig`` construction or ``validate()``
            propagate unchanged.
        """
        # Handle configuration conversion
        if isinstance(config, OpenCUAConfig):
            config.validate()
            self.agent_config = config
            config_dict = config.to_dict()
        elif isinstance(config, dict):
            config_dict = config.copy()
            # Create OpenCUAConfig from dict for validation
            try:
                self.agent_config = OpenCUAConfig(**config_dict)
                self.agent_config.validate()
            except TypeError:
                # Handle legacy dict format with display_size
                display_size = config_dict.pop("display_size", {"width": 1024, "height": 768})
                config_dict["display_width"] = display_size.get("width", 1024)
                config_dict["display_height"] = display_size.get("height", 768)
                self.agent_config = OpenCUAConfig(**config_dict)
                self.agent_config.validate()
                config_dict = self.agent_config.to_dict()
        elif config is None:
            self.agent_config = OpenCUAConfig()
            config_dict = self.agent_config.to_dict()
        else:
            raise ValueError(f"Config must be dict or OpenCUAConfig, got {type(config)}")

        # Initialize base class
        super().__init__(config_dict)

        self.model = self.agent_config.model_name
        self.api_key = self.agent_config.api_key
        self.base_url = self.agent_config.base_url

        if not self.model:
            raise ValueError("model_name must be provided in the config for OpenCUA agent.")

        if not self.api_key:
            raise ValueError("api_key must be provided for the LLM endpoint.")

        self.display_size: Dict[str, int] = self.agent_config.display_size
        self.screen_size: Tuple[int, int] = (self.display_size["width"], self.display_size["height"])
        self.coordinate_type = self.agent_config.coordinate_type
        self.cot_level = self.agent_config.cot_level
        self.history_type = self.agent_config.history_type
        self.max_image_history_length = self.agent_config.max_image_history_length
        self.max_steps = self.agent_config.max_steps
        self.password = self.agent_config.password

        # Build system prompt
        if self.agent_config.use_old_sys_prompt:
            if self.cot_level == "l1":
                self.system_prompt = SYSTEM_PROMPT_V1_L1
            elif self.cot_level == "l2":
                self.system_prompt = SYSTEM_PROMPT_V1_L2.format(password=self.password)
            elif self.cot_level == "l3":
                self.system_prompt = SYSTEM_PROMPT_V1_L3
            else:
                raise ValueError("Invalid cot_level. Choose from 'l1', 'l2', or 'l3'.")
        else:
            self.system_prompt = build_sys_prompt(
                level=self.cot_level,
                password=self.password,
                use_random=False,
            )

        # Set up history template
        if self.history_type == "action_history":
            self.HISTORY_TEMPLATE = ACTION_HISTORY_TEMPLATE
        elif self.history_type == "thought_history":
            self.HISTORY_TEMPLATE = THOUGHT_HISTORY_TEMPLATE
        elif self.history_type == "observation_history":
            self.HISTORY_TEMPLATE = OBSERVATION_HISTORY_TEMPLATE
        else:
            raise ValueError(f"Invalid history type: {self.history_type}")

        # History tracking: one entry per completed step, kept index-aligned.
        self.actions: List[str] = []
        self.observations: List[Dict[str, Any]] = []
        self.cots: List[Dict[str, Any]] = []
        self.screenshots: List[str] = []  # base64 encoded

    @property
    def action_adapter_class(self) -> type[ActionAdapter]:
        """Adapter class that executes the JSON returned by ``step``."""
        return OpenCUAActionAdapter

    def _format_history_step(self, index: int) -> str:
        """Render past step ``index`` as STEP_TEMPLATE followed by the configured history template."""
        cot = self.cots[index]
        return STEP_TEMPLATE.format(step_num=index + 1) + self.HISTORY_TEMPLATE.format(
            observation=cot.get("observation", ""),
            thought=cot.get("thought", ""),
            action=cot.get("action", ""),
        )

    def _build_messages(self, screenshot: bytes, instruction: str, processed_image: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Build message list for LLM call.

        The messages are laid out in this order:

        1. the system prompt;
        2. one assistant message holding the text of all older steps;
        3. for each of the most recent ``max_image_history_length - 1`` steps,
           a user message with that step's stored screenshot, then an assistant
           message with its text;
        4. a user message with the current screenshot and the task instruction.

        Args:
            screenshot: Current screenshot bytes. Only processed when
                ``processed_image`` is not supplied.
            instruction: Task instruction.
            processed_image: Optional base64 image already produced by
                ``process_image_for_opencua`` for ``screenshot``; passing it
                avoids processing the same image twice.

        Returns:
            List of messages for LLM
        """
        if processed_image is None:
            processed_image, _, _, _, _ = process_image_for_opencua(screenshot)

        messages: List[Dict[str, Any]] = [{"role": "system", "content": self.system_prompt}]
        instruction_prompt = INSTRUCTION_TEMPLATE.format(instruction=instruction)

        num_steps = len(self.actions)
        # Steps with index > num_steps - max_image_history_length keep their
        # screenshot; the current screenshot uses the remaining image slot.
        # Clamping to [0, num_steps] also keeps the text history when
        # max_image_history_length <= 0 (previously it was silently dropped).
        first_image_step = min(max(num_steps - self.max_image_history_length + 1, 0), num_steps)

        if first_image_step > 0:
            text_history = "\n".join(self._format_history_step(i) for i in range(first_image_step))
            messages.append({"role": "assistant", "content": text_history})

        for i in range(first_image_step, num_steps):
            if i < len(self.screenshots):
                messages.append(
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/png;base64,{self.screenshots[i]}"},
                            }
                        ],
                    }
                )
            messages.append({"role": "assistant", "content": self._format_history_step(i)})

        # Add current screenshot with instruction
        messages.append(
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{processed_image}"},
                    },
                    {"type": "text", "text": instruction_prompt},
                ],
            }
        )

        return messages

    @staticmethod
    def _build_action_output(pyautogui_actions: Optional[List[str]], other_cot: Dict[str, Any]) -> Dict[str, Any]:
        """
        Translate the parsed action into the JSON-serialisable payload returned by ``step``.

        ``"DONE"`` and ``"FAIL"`` both become a message whose text contains
        ``"DONE"``, which ``OpenCUAActionAdapter`` treats as task completion.
        ``"WAIT"`` becomes a 20-second wait call. Anything else is executed as
        PyAutoGUI code.
        """
        if not pyautogui_actions:
            return {
                "status": "error",
                "output": [
                    {"type": "reasoning", "summary": [{"text": str(other_cot)}]},
                    {"type": "message", "content": [{"type": "output_text", "text": "Could not parse action from response."}]},
                ],
            }

        action_code = pyautogui_actions[0]
        reasoning = {"type": "reasoning", "summary": [{"text": other_cot.get("thought", "")}]}

        if action_code == "DONE":
            final_item: Dict[str, Any] = {"type": "message", "content": [{"type": "output_text", "text": "DONE. Task completed successfully."}]}
        elif action_code == "FAIL":
            final_item = {"type": "message", "content": [{"type": "output_text", "text": "DONE. Task failed."}]}
        elif action_code == "WAIT":
            final_item = {"type": "computer_call", "action": {"type": "wait", "duration": 20}}
        else:
            final_item = {"type": "computer_call", "action": {"type": "pyautogui", "code": action_code}}

        return {"status": "completed", "output": [reasoning, final_item]}

    def step(self, screenshot: bytes, instruction: str) -> str:
        """
        Takes a step in the environment using the OpenCUA model.

        The model is queried up to 5 times. A retry happens when the call
        raises, when the reply is empty, or when the reply cannot be parsed
        into an action.

        Args:
            screenshot: Current screenshot as bytes
            instruction: Task instruction

        Returns:
            JSON string with action output. If every attempt fails, an
            ``"error"`` payload is returned instead, and nothing is recorded
            in the step history.
        """
        current_step = len(self.actions)
        print(f"========= OpenCUA Step {current_step + 1} =======")
        print(f"Instruction: {instruction}")

        # Process the screenshot once. The same encoded image is sent as the
        # current observation and stored for later prompts' history.
        processed_image, _, _, _, _ = process_image_for_opencua(screenshot)
        messages = self._build_messages(screenshot, instruction, processed_image=processed_image)

        # Call LLM
        max_retry = 5
        low_level_instruction: Optional[str] = None
        pyautogui_actions: Optional[List[str]] = None
        other_cot: Dict[str, Any] = {}

        for retry_count in range(max_retry):
            try:
                completion_kwargs: Dict[str, Any] = {
                    "model": self.model,
                    "messages": messages,
                    "max_tokens": self.agent_config.max_tokens,
                    "top_p": self.agent_config.top_p,
                    # Retries use a temperature of at least 0.2, presumably so a
                    # greedy setting does not reproduce the same failing reply.
                    "temperature": self.agent_config.temperature if retry_count == 0 else max(0.2, self.agent_config.temperature),
                    "api_key": self.api_key,
                }

                if self.base_url:
                    completion_kwargs["base_url"] = self.base_url

                response = litellm.completion(**completion_kwargs)

                response_text = response.choices[0].message.content or ""

                print(f"Model Output:\n{response_text}")

                if not response_text:
                    print("No response found in the response.")
                    raise ValueError(f"No response found in the response:\n{response}.")

                low_level_instruction, pyautogui_actions, other_cot = parse_response_to_cot_and_action(response_text, self.screen_size, self.coordinate_type)

                if "<Error>" in low_level_instruction or not pyautogui_actions:
                    print(f"Error parsing response: {low_level_instruction}")
                    raise ValueError(f"Error parsing response: {low_level_instruction}")

                break

            except Exception as e:
                print(f"Error during LLM call or response parsing (attempt {retry_count + 1}/{max_retry}): {e}")
                if retry_count == max_retry - 1:
                    print("Maximum retries reached. Exiting.")
                    return json.dumps(
                        {
                            "status": "error",
                            "output": [
                                {"type": "reasoning", "summary": [{"text": str(e)}]},
                                {"type": "message", "content": [{"type": "output_text", "text": "FAIL. Maximum retries reached."}]},
                            ],
                        }
                    )
                time.sleep(1)

        print(f"Action: {low_level_instruction}")
        print(f"Code: {pyautogui_actions}")

        # Store in history
        self.screenshots.append(processed_image)
        self.observations.append({"screenshot": screenshot})
        self.actions.append(low_level_instruction or "")
        self.cots.append(other_cot)

        # Force termination at the step limit unless the model already ended
        # the task. parse_response_to_cot_and_action turns computer.terminate(...)
        # into "DONE"/"FAIL", so those markers must be recognised here;
        # otherwise a successful DONE on the final step would become FAIL.
        first_code = pyautogui_actions[0] if pyautogui_actions else ""
        already_terminated = first_code in ("DONE", "FAIL") or "computer.terminate" in first_code.lower()
        if len(self.actions) >= self.max_steps and pyautogui_actions and not already_terminated:
            print(f"Reached maximum steps {self.max_steps}. Forcing termination.")
            low_level_instruction = "Fail the task because reaching the maximum step limit."
            pyautogui_actions = ["FAIL"]
            # other_cot is the dict just stored in self.cots, so history records the forced FAIL too.
            other_cot["code"] = "FAIL"

        return json.dumps(self._build_action_output(pyautogui_actions, other_cot))

    def reset(self) -> None:
        """
        Resets the agent's state for a new task.

        Calls the base-class reset and clears the per-step history lists.
        """
        super().reset()
        self.actions = []
        self.observations = []
        self.cots = []
        self.screenshots = []
