"""
OpenCUA Utilities for OfficeArena.

Ported from OSWorld implementation:
https://github.com/xlang-ai/OSWorld/blob/main/mm_agents/opencua/utils.py
"""

import ast
import base64
import math
import re
from io import BytesIO
from typing import Optional, Tuple

from PIL import Image


def encode_image(image_content: bytes) -> str:
    """
    Return the base64 text representation of raw image bytes.

    Args:
        image_content: Raw bytes to encode (no validation that they form an image is done)

    Returns:
        The standard base64 encoding of the input, decoded as a UTF-8 string
    """
    encoded_bytes = base64.b64encode(image_content)
    return str(encoded_bytes, "utf-8")


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
    max_aspect_ratio_allowed: Optional[float] = None,
    size_can_be_smaller_than_factor: bool = False,
) -> Tuple[int, int]:
    """
    Compute target dimensions for an image under the following constraints:

    1. Both dimensions (height and width) are multiples of 'factor'.
    2. The total pixel count is pushed into the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is kept as close to the input as possible.

    Args:
        height: Input height in pixels
        width: Input width in pixels
        factor: Value both output dimensions must be divisible by
        min_pixels: Lower bound on the output area
        max_pixels: Upper bound on the output area
        max_aspect_ratio_allowed: When not None, the largest accepted long-side / short-side ratio
        size_can_be_smaller_than_factor: When False, inputs with a side below 'factor' are rejected

    Returns:
        Tuple of (resized_height, resized_width) -- note that height comes first

    Raises:
        ValueError: If a side is smaller than 'factor' (and that is not allowed), or if the
            aspect ratio exceeds 'max_aspect_ratio_allowed'
    """
    too_small = height < factor or width < factor
    if too_small and not size_can_be_smaller_than_factor:
        raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor} " f"(when size_can_be_smaller_than_factor is False)")

    if max_aspect_ratio_allowed is not None:
        aspect_ratio = max(height, width) / min(height, width)
        if aspect_ratio > max_aspect_ratio_allowed:
            raise ValueError(f"absolute aspect ratio must be smaller than {max_aspect_ratio_allowed}, " f"got {aspect_ratio}" f"(when max_aspect_ratio_allowed is not None)")

    # First guess: snap each side to the nearest multiple of factor (at least one multiple).
    snapped_height = factor * max(1, round(height / factor))
    snapped_width = factor * max(1, round(width / factor))
    snapped_area = snapped_height * snapped_width

    if snapped_area > max_pixels:
        # Too large: shrink both original sides by the same ratio and round DOWN so the
        # result cannot exceed the budget.
        shrink = math.sqrt((height * width) / max_pixels)
        snapped_height = factor * max(1, math.floor(height / shrink / factor))
        snapped_width = factor * max(1, math.floor(width / shrink / factor))
    elif snapped_area < min_pixels:
        # Too small: grow both original sides by the same ratio and round UP so the
        # result cannot fall below the minimum.
        grow = math.sqrt(min_pixels / (height * width))
        snapped_height = factor * math.ceil(height * grow / factor)
        snapped_width = factor * math.ceil(width * grow / factor)

    return snapped_height, snapped_width


def process_image_for_opencua(image_bytes: bytes) -> Tuple[str, int, int, int, int]:
    """
    Resize an image to the dimensions chosen by smart_resize and return it as base64 PNG.

    Args:
        image_bytes: Encoded image data in any format PIL's Image.open can read

    Returns:
        Tuple of (base64_encoded_image, original_width, original_height, processed_width, processed_height)
    """
    source_image = Image.open(BytesIO(image_bytes))
    original_width, original_height = source_image.size

    # smart_resize takes and returns (height, width); PIL sizes are (width, height).
    target_height, target_width = smart_resize(
        height=original_height,
        width=original_width,
        factor=28,
        min_pixels=3136,
        max_pixels=12845056,
    )
    resized_image = source_image.resize((target_width, target_height))

    # The resized image is always re-encoded as PNG, whatever the input format was.
    png_buffer = BytesIO()
    resized_image.save(png_buffer, format="PNG")
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    return encoded_png, original_width, original_height, target_width, target_height


def project_coordinate_to_absolute_scale(
    pyautogui_code: str,
    screen_width: int,
    screen_height: int,
    coordinate_type: str = "relative",
) -> str:
    """
    Convert the model-space coordinates in the pyautogui code to absolute coordinates
    based on the logical screen size.

    Every ``pyautogui.<name>(...)`` call whose argument list contains no ``)`` is
    examined on its own and rewritten in place. A call is left exactly as written when:

    * its argument list cannot be parsed (e.g. a string argument containing ``)``),
    * it uses ``*args`` / ``**kwargs`` unpacking,
    * it has no ``x`` and ``y`` arguments, or they are not literal values that convert
      to float (``None``, variables, expressions, ...),
    * the projection fails with ValueError -- this includes an unsupported
      ``coordinate_type`` and a ValueError raised by ``smart_resize``.

    A problem in one call never prevents other calls from being converted.

    For the click/move/drag functions listed in ``xy_first_functions`` the projected
    ``x`` and ``y`` are emitted as the first two positional arguments. All other
    arguments keep their original positional/keyword form and order: literal values
    are written with ``repr`` and non-literal expressions are copied from the source
    text. For any other function only ``x=`` / ``y=`` keyword values are replaced.

    Args:
        pyautogui_code: PyAutoGUI code with relative coordinates
        screen_width: Actual screen width
        screen_height: Actual screen height
        coordinate_type: Type of coordinate system ("relative" or "qwen25")

    Returns:
        PyAutoGUI code with absolute coordinates
    """
    # Functions whose first two positional parameters are (x, y).
    xy_first_functions = frozenset(("click", "rightClick", "middleClick", "doubleClick", "tripleClick", "moveTo", "dragTo"))
    call_pattern = r"(pyautogui\.\w+)\(([^\)]*)\)"

    def _coordinate_projection(x: float, y: float, width: int, height: int, coord_type: str) -> Tuple[int, int]:
        """Map one model-space point to integer screen pixels."""
        if coord_type == "relative":
            return int(round(x * width)), int(round(y * height))
        if coord_type == "qwen25":
            if 0 <= x <= 1 and 0 <= y <= 1:
                # Already normalized: scale by the real screen size, exactly like
                # "relative" (scaling by the resized size would land off-screen).
                return int(round(x * width)), int(round(y * height))
            # Pixel coordinates in the smart_resize'd image: scale back to the screen.
            resized_height, resized_width = smart_resize(height=height, width=width, factor=28, min_pixels=3136, max_pixels=12845056)
            return int(x / resized_width * width), int(y / resized_height * height)
        raise ValueError(f"Invalid coordinate type: {coord_type}. Expected one of ['relative', 'qwen25'].")

    def _render_argument(node: ast.AST, source: str) -> str:
        """Return source text for one argument: repr of a literal, else the original expression."""
        try:
            return repr(ast.literal_eval(node))
        except (TypeError, ValueError, SyntaxError):
            # Not a literal (e.g. pyautogui.easeInQuad): fall through and copy it verbatim.
            pass
        segment = ast.get_source_segment(source, node)
        if segment is None:
            raise ValueError("argument source text is unavailable")
        return segment

    def _rewrite_call(match: "re.Match") -> str:
        """Return the replacement text for one matched pyautogui call."""
        original_call = match.group(0)
        func_name, args_str = match.group(1), match.group(2)

        source = f"func({args_str})"
        try:
            body = ast.parse(source).body
        except (SyntaxError, ValueError):
            # Only this call is skipped; the rest of the code is still converted.
            return original_call
        call = getattr(body[0], "value", None) if body else None
        if not isinstance(call, ast.Call):
            return original_call

        # Unpacking hides which argument is x / y, so such calls are not touched.
        if any(isinstance(arg, ast.Starred) for arg in call.args):
            return original_call
        if any(kw.arg is None for kw in call.keywords):
            return original_call

        xy_first = func_name.split(".")[-1] in xy_first_functions
        coordinate_nodes = dict(zip(("x", "y"), call.args)) if xy_first else {}
        coordinate_nodes.update((kw.arg, kw.value) for kw in call.keywords if kw.arg in ("x", "y"))
        if "x" not in coordinate_nodes or "y" not in coordinate_nodes:
            return original_call

        try:
            x_rel = float(ast.literal_eval(coordinate_nodes["x"]))
            y_rel = float(ast.literal_eval(coordinate_nodes["y"]))
            x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
        except (TypeError, ValueError, SyntaxError, OverflowError):
            # Non-literal or non-numeric coordinates (None, names, ...), an
            # unsupported coordinate type, or a failed smart_resize: keep the call.
            return original_call

        try:
            if xy_first:
                # x and y always lead; everything else keeps its original form.
                parts = [str(x_abs), str(y_abs)]
                parts.extend(_render_argument(node, source) for node in call.args[2:])
            else:
                parts = [_render_argument(node, source) for node in call.args]
            for kw in call.keywords:
                if kw.arg == "x" or kw.arg == "y":
                    if xy_first:
                        continue  # already emitted positionally
                    value_text = str(x_abs if kw.arg == "x" else y_abs)
                else:
                    value_text = _render_argument(kw.value, source)
                parts.append(f"{kw.arg}={value_text}")
        except ValueError:
            return original_call

        return f"{func_name}({', '.join(parts)})"

    # re.sub replaces each match at its own position, so identical calls (or a
    # rewritten call that happens to equal a later original one) cannot be confused.
    return re.sub(call_pattern, _rewrite_call, pyautogui_code)


def bbox_to_center_1000(bbox: str) -> Tuple[int, int]:
    """
    Find a bounding box of four integers in a string and return its center.

    The patterns are tried in order and the first one that matches anywhere in the
    string wins: the forms wrapped in <|box_start|>...<|box_end|> are tried before
    the same forms without the wrapper.

    Args:
        bbox: Text containing a box such as "(x1,y1),(x2,y2)", "[[x1, y1, x2, y2]]",
            "[[x1, y1, x2, y2]" or "(x1, y1, x2, y2)"

    Returns:
        Tuple of (x_center, y_center), each the floor of the midpoint of the matched
        values (same 0-1000 scale as the input)

    Raises:
        ValueError: If no pattern matches
    """
    candidate_patterns = (
        r"<\|box_start\|>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|>",
        r"<\|box_start\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]<\|box_end\|>",
        r"<\|box_start\|>\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]<\|box_end\|>",
        r"<\|box_start\|>\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)<\|box_end\|>",
        r"\((\d+),(\d+)\),\((\d+),(\d+)\)",
        r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]\]",
        r"\[\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\]",
        r"\((\d+),\s*(\d+),\s*(\d+),\s*(\d+)\)",
    )

    # Lazy generator: searching stops at the first pattern that matches.
    search_results = (re.search(candidate, bbox) for candidate in candidate_patterns)
    found = next((result for result in search_results if result), None)
    if found is None:
        raise ValueError(f"Bounding box coordinates not found in the input string: {bbox}")

    left, top, right, bottom = (int(group) for group in found.groups())
    return (left + right) // 2, (top + bottom) // 2


def bbox_to_center_1(bbox: str) -> Tuple[int, int]:
    """
    Find a bounding box of four decimal numbers in a string and return its center.

    Only the form "[x1, y1, x2, y2]" is recognized, and every number must be written
    with a decimal point and digits on both sides (an optional leading minus sign is
    accepted); integer-only boxes do not match.

    Args:
        bbox: Text containing a bounding box with float coordinates (0-1 scale)

    Returns:
        Tuple of (x_center, y_center) in 0-1000 scale; each coordinate is multiplied
        by 1000 and truncated to int before the midpoint is floored

    Raises:
        ValueError: If no bounding box is found
    """
    float_box_patterns = (
        r"\[\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*,\s*(-?\d+\.\d+)\s*\]",
    )

    search_results = (re.search(candidate, bbox) for candidate in float_box_patterns)
    found = next((result for result in search_results if result), None)
    if found is None:
        raise ValueError(f"Bounding box coordinates not found in the input string: {bbox}")

    left, top, right, bottom = (int(float(group) * 1000) for group in found.groups())
    return (left + right) // 2, (top + bottom) // 2

