import json
import shlex
import time
from typing import Any, Dict, Literal, Optional, Union

import litellm

from officearena.adapter.action_adapter import ActionAdapter
from officearena.agents.base import BaseAgent
from officearena.agents.config import ClaudeConfig
from officearena.utils.task_proposal import add_tasks_to_dataset

# System prompt used as the first message (marked with ephemeral cache_control)
# of every ClaudeAgent conversation, both on construction and after reset().
CLAUDE_SYSTEM_PROMPT = """<SYSTEM_CAPABILITY>
You are utilising an Ubuntu virtual machine with internet access. You are able to use the computer to solve Microsoft Office tasks.
</SYSTEM_CAPABILITY>

<IMPORTANT>
You absolutely must avoid asking any clarification or follow-up questions--just execute the task as best you can with what you're given.
Refrain from asking any "Yes" or "No" questions about whether you should proceed--just assume the answer is always "Yes".
When you are done with the task or are unable to complete it, use the finish tool to finish.
</IMPORTANT>"""


class ClaudeActionAdapter(ActionAdapter):
    """Execute Claude computer-use actions inside the sandbox.

    Mouse and keyboard actions are translated into ``xdotool`` command chains
    and run through ``self.sandbox.execute_command``; ``type`` goes through
    ``self.sandbox.write``. An action arrives either directly by name (see
    ``ComputerUseActionType``) or wrapped in a ``"computer"`` call whose
    concrete action name is stored in ``args["action"]``.
    """

    ComputerUseActionType = Literal[
        "screenshot",
        "wait",
        "mouse_move",
        "left_click",
        "right_click",
        "double_click",
        "key",
        "type",
        "left_click_drag",
        "middle_click",
        "triple_click",
        "left_mouse_down",
        "left_mouse_up",
        "scroll",
        "hold_key",
        "cursor_position",
    ]

    ActionType = ComputerUseActionType | Literal["computer", "add_tasks_to_dataset", "finish"]

    # Argument string given to ``xdotool click`` for each click action:
    # button 1 is left, 2 middle, 3 right. Double/triple clicks repeat
    # button 1 with ``--delay 10`` between clicks.
    click_buttons = {
        "left_click": 1,
        "right_click": 3,
        "middle_click": 2,
        "double_click": "--repeat 2 --delay 10 1",
        "triple_click": "--repeat 3 --delay 10 1",
    }

    # Mouse button clicked (``--repeat`` times) to scroll in each direction.
    scroll_buttons = {
        "up": 4,
        "down": 5,
        "left": 6,
        "right": 7,
    }

    @property
    def action_adapter_class(self) -> type[ActionAdapter]:
        return ClaudeActionAdapter

    @staticmethod
    def _quote_keys(keys: Any) -> str:
        """Shell-quote each whitespace-separated key name in ``keys``.

        Quoting token by token keeps multi-key sequences such as
        ``"ctrl+a Delete"`` working, while model-supplied text cannot inject
        extra commands if ``sandbox.execute_command`` runs through a shell.
        """
        return " ".join(shlex.quote(token) for token in str(keys).split())

    @staticmethod
    def _move_part(coordinate: Any) -> str:
        """Return a ``mousemove --sync X Y`` chain part, or ``""`` when no coordinate is given."""
        if not coordinate:
            return ""
        x, y = coordinate
        return f"mousemove --sync {x} {y}"

    def _run_xdotool(self, *parts: str) -> Any:
        """Run one ``xdotool`` command chain built from ``parts`` (empty parts are skipped).

        Returns whatever ``sandbox.execute_command`` returns.
        """
        command = " ".join(["xdotool", *(part for part in parts if part)])
        return self.sandbox.execute_command(command)

    def _click(self, action_type: str, args: Dict[str, Any]) -> None:
        """Perform a click action from ``click_buttons``.

        Args:
            action_type: One of the keys of ``click_buttons``.
            args: May contain ``coordinate`` (``[x, y]``) to move to before
                clicking (otherwise the click happens at the current pointer
                position) and ``key`` naming modifier key(s) held during the click.
        """
        modifiers = self._quote_keys(args["key"]) if args.get("key") else ""
        self._run_xdotool(
            self._move_part(args.get("coordinate")),
            f"keydown {modifiers}" if modifiers else "",
            f"click {self.click_buttons[action_type]}",
            # Always release what was pressed so modifiers do not stay stuck.
            f"keyup {modifiers}" if modifiers else "",
        )

    def _scroll(self, args: Dict[str, Any]) -> None:
        """Scroll by repeatedly clicking the button mapped to the requested direction.

        Accepts ``direction`` or ``scroll_direction`` for the direction and
        ``scroll_amount`` or ``amount`` (default 1) for the number of clicks.
        ``coordinate`` optionally moves the pointer first and ``text`` names
        modifier key(s) held while scrolling.

        Raises:
            ValueError: If no direction is given.
            KeyError: If the direction is not a key of ``scroll_buttons``.
        """
        direction = args.get("direction") or args.get("scroll_direction")
        if not direction:
            raise ValueError("scroll action missing 'direction' or 'scroll_direction'")
        button = self.scroll_buttons[direction]
        # int() keeps arbitrary model-supplied text out of the command line.
        amount = int(args.get("scroll_amount", args.get("amount", 1)))
        modifiers = self._quote_keys(args["text"]) if args.get("text") else ""
        self._run_xdotool(
            self._move_part(args.get("coordinate")),
            f"keydown {modifiers}" if modifiers else "",
            f"click --repeat {amount} {button}",
            f"keyup {modifiers}" if modifiers else "",
        )

    def _dispatch_action(self, action_type: ActionType, args: Dict[str, Any]) -> Any:
        """Execute one action in the sandbox.

        Args:
            action_type: The action name, or ``"computer"`` to read the action
                name from ``args["action"]``.
            args: Action parameters as sent by the model.

        Returns:
            ``"X=<x>,Y=<y>"`` for ``cursor_position``; ``None`` otherwise.

        Raises:
            ValueError: For an unknown action type or a scroll without direction.
        """
        if action_type == "computer":
            action_type = args["action"]

        if action_type == "screenshot":
            # Nothing to execute: a screenshot is taken after every action anyway.
            return None
        if action_type == "wait":
            time.sleep(args["duration"])
            return None
        if action_type == "mouse_move":
            x, y = args["coordinate"]
            self._run_xdotool(f"mousemove {x} {y}")
            return None
        if action_type in self.click_buttons:
            self._click(action_type, args)
            return None
        if action_type in ("left_mouse_down", "left_mouse_up"):
            verb = "mousedown" if action_type == "left_mouse_down" else "mouseup"
            # The button is an int; format it rather than passing it to str.join.
            self._run_xdotool(f"{verb} {self.click_buttons['left_click']}")
            return None
        if action_type == "left_click_drag":
            x, y = args["coordinate"]
            # computer_20250124 may send the drag origin as ``start_coordinate``;
            # without it the drag starts at the current pointer position.
            self._run_xdotool(
                self._move_part(args.get("start_coordinate")),
                f"mousedown 1 mousemove --sync {x} {y} mouseup 1",
            )
            return None
        if action_type == "scroll":
            self._scroll(args)
            return None
        if action_type == "cursor_position":
            output = self._run_xdotool("getmouselocation --shell")
            x = int(output.split("X=")[1].split("\n")[0])
            y = int(output.split("Y=")[1].split("\n")[0])
            return f"X={x},Y={y}"
        if action_type == "key":
            self._run_xdotool(f"key {self._quote_keys(args['text'])}")
            return None
        if action_type == "type":
            # TODO: Double-check, this implementation diverges from anthropic's
            self.sandbox.write(args["text"])
            return None
        if action_type == "hold_key":
            keys = self._quote_keys(args["text"])
            self._run_xdotool("keydown", keys, f"sleep {args['duration']}", "keyup", keys)
            return None
        if action_type == "add_tasks_to_dataset":
            add_tasks_to_dataset(args["tasks"])
            return None
        raise ValueError(f"Unknown action type: {action_type}")


class ClaudeAgent(BaseAgent):
    """Computer-use agent backed by a Claude model called through ``litellm``.

    The agent keeps the whole conversation in ``self.messages``. Each ``step``
    sends it together with the newest screenshot and returns the model's last
    tool call as a JSON action for ``ClaudeActionAdapter`` to execute.
    """

    def __init__(self, config: Optional[Union[Dict[str, Any], ClaudeConfig]] = None):
        """Build the agent from a ``ClaudeConfig``, a plain dict, or defaults.

        Args:
            config: A ``ClaudeConfig``; a dict of ``ClaudeConfig`` fields (a
                legacy ``display_size={"width": ..., "height": ...}`` entry is
                converted to ``display_width``/``display_height``); or ``None``
                for the default configuration.

        Raises:
            ValueError: If ``config`` has an unsupported type, or the resulting
                configuration lacks a model name or an API key.
            TypeError: If a dict config is rejected by ``ClaudeConfig`` for a
                reason other than the legacy ``display_size`` format.
        """
        if isinstance(config, ClaudeConfig):
            config.validate()
            self.agent_config = config
            config_dict = config.to_dict()
        elif isinstance(config, dict):
            config_dict = config.copy()
            try:
                self.agent_config = ClaudeConfig(**config_dict)
                self.agent_config.validate()
            except TypeError:
                # Only the legacy ``display_size`` format is recoverable. Any other
                # TypeError must surface instead of silently replacing the caller's
                # display dimensions with defaults.
                if "display_size" not in config_dict:
                    raise
                display_size = config_dict.pop("display_size")
                config_dict["display_width"] = display_size["width"]
                config_dict["display_height"] = display_size["height"]
                self.agent_config = ClaudeConfig(**config_dict)
                self.agent_config.validate()
                config_dict = self.agent_config.to_dict()
        elif config is None:
            self.agent_config = ClaudeConfig()
            config_dict = self.agent_config.to_dict()
        else:
            raise ValueError(f"Config must be dict or ClaudeConfig, got {type(config)}")

        # Initialize base class
        super().__init__(config_dict)

        self.model = self.agent_config.model_name
        base_url = self.agent_config.base_url
        api_key = self.agent_config.api_key

        if not self.model:
            raise ValueError("model_name must be provided in the config for Claude agent.")

        if not api_key:
            raise ValueError("api_key must be provided for the LLM endpoint.")

        self.api_key = api_key
        self.base_url = base_url
        # Id of the most recently returned tool call; None means the next step
        # is the first of the episode and must send the instruction.
        self.previous_response_id: Optional[str] = None
        # NOTE(review): never read in this class; kept for external compatibility.
        self.previous_computer_call_id: Optional[str] = None
        self.display_size: Dict[str, int] = self.agent_config.display_size
        self.tools = [
            {
                "type": "computer_20250124",
                "function": {
                    "name": "computer",
                    "parameters": {
                        "display_height_px": self.display_size["height"],
                        "display_width_px": self.display_size["width"],
                        "display_number": 1,
                    },
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "finish",
                    "description": "Finish the agent and return message to user.",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "message": {
                                "type": "string",
                                "description": "The message to the user with a reason for finishing.",
                            }
                        },
                    },
                },
            },
        ]

        self.messages = self._initial_messages()

    @staticmethod
    def _initial_messages() -> list:
        """Return a fresh conversation containing only the cached system prompt."""
        return [{"role": "system", "content": CLAUDE_SYSTEM_PROMPT, "cache_control": {"type": "ephemeral"}}]

    @property
    def action_adapter_class(self) -> type[ActionAdapter]:
        return ClaudeActionAdapter

    def step(self, screenshot: bytes, instruction: str) -> str:
        """Send the latest screenshot to the model and return its next action as JSON.

        On the first step the instruction and the screenshot form the first user
        message. On later steps the screenshot becomes the content of the tool
        result appended for the previously returned tool call.

        Args:
            screenshot: Screenshot bytes, sent as a ``data:image/png`` URL.
            instruction: Task instruction; only sent on the first step.

        Returns:
            A JSON string ``{"status": "completed", "output": [reasoning, action]}``
            describing the model's last tool call.

        Raises:
            ValueError: If the model response contains no tool call.
        """
        import base64

        screenshot_b64 = base64.b64encode(screenshot).decode("utf-8")
        image_part = {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/png;base64,{screenshot_b64}",
            },
            "cache_control": {"type": "ephemeral"},
        }

        if self.previous_response_id is None:
            self.messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": f"Instruction: {instruction}", "cache_control": {"type": "ephemeral"}},
                        image_part,
                    ],
                }
            )
        else:
            # The last message is the tool result appended by the previous step;
            # the new screenshot is that tool call's result.
            self.messages[-1]["content"] = [image_part]

        response = litellm.completion(
            messages=self.messages,
            model=self.model,
            api_key=self.api_key,
            base_url=self.base_url,
            tools=self.tools,
            extra_headers={"anthropic-beta": "computer-use-2025-01-24"},
            reasoning_effort="medium",
        )

        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls or []
        if not tool_calls:
            # Fail before touching the history so the conversation stays consistent.
            raise ValueError("Claude response contained no tool call; expected a 'computer' or 'finish' call.")
        self.messages.append(response_message)

        # Only the last tool call is executed. The Anthropic API expects every
        # tool_use block to be answered by a tool_result, so earlier calls get a
        # placeholder result.
        for skipped_call in tool_calls[:-1]:
            self.messages.append(
                {
                    "role": "tool",
                    "tool_call_id": skipped_call.id,
                    "name": skipped_call.function.name,
                    "content": "Not executed: only the last tool call of a turn is executed.",
                }
            )

        tool_call = tool_calls[-1]
        tool_name = tool_call.function.name
        self.previous_response_id = tool_call.id

        # Its content is filled with the next screenshot on the following step.
        self.messages.append(
            {
                "role": "tool",
                "tool_call_id": tool_call.id,
                "name": tool_name,
            }
        )

        # Tool calls without parameters may carry an empty argument string.
        arguments = json.loads(tool_call.function.arguments or "{}")
        reasoning_text = getattr(response_message, "reasoning_content", None) or ""

        action_output = {
            "status": "completed",
            "output": [
                {
                    "type": "reasoning",
                    "summary": [{"text": reasoning_text, "cache_control": {"type": "ephemeral"}}],
                },
                {
                    "type": "computer_call" if tool_name == "computer" else "call",
                    "action": {
                        "type": tool_name,
                        **arguments,
                    },
                },
            ],
        }

        return json.dumps(action_output)

    def reset(self) -> None:
        """Clear the conversation and tool-call tracking so a new task can start."""
        super().reset()
        self.messages = self._initial_messages()
        self.previous_response_id = None
        self.previous_computer_call_id = None
