import re
from copy import deepcopy
import ast
import random
import math

# Sentinel action / status names. Nothing in this part of the file references
# them (the parsers below match the literal "finished" directly), so they are
# presumably consumed by callers elsewhere — TODO confirm.
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"

def _scale_box_to_pixels(raw_box, image_width, image_height):
    """Parse a relative box and return it as ``[x1, y1, x2, y2]`` scaled to pixels."""
    # NOTE: the box used to be parsed with eval(); ast.literal_eval accepts the
    # same list/tuple literals without executing arbitrary expressions.
    box = ast.literal_eval(raw_box) if isinstance(raw_box, str) else raw_box
    # Copy into a fresh list so tuple literals work and the input is never mutated.
    box = list(box)
    if len(box) == 2:
        # A single point is expanded to a degenerate box.
        box = box + box
    if len(box) != 4:
        raise ValueError(f"Expected a box with 2 or 4 coordinates, got {raw_box!r}")
    return [
        box[0] * image_width,
        box[1] * image_height,
        box[2] * image_width,
        box[3] * image_height,
    ]


def new_action_to_old_action(action, image_width, image_height):
    """Convert a new-format action dict into the old ``type/custom/boxes`` layout.

    Args:
        action: Mapping with a ``"type"`` entry and a ``"params"`` mapping.
            ``params["start_box"]`` / ``params["end_box"]``, when present, hold
            a box literal such as ``"[0.1, 0.2]"`` or ``"[0.1, 0.2, 0.3, 0.4]"``.
        image_width: Multiplier applied to the x coordinates.
        image_height: Multiplier applied to the y coordinates.

    Returns:
        A dict with keys ``"type"``, ``"custom"`` (every param other than
        ``type`` / ``start_box`` / ``end_box``) and ``"boxes"`` (the scaled
        start box followed by the scaled end box, each only if present).

    Raises:
        ValueError: If a box does not have exactly 2 or 4 coordinates or is
            not a valid literal.
    """
    params = action["params"]
    old_format_action = {
        "type": action["type"],
        "custom": {},
        "boxes": [],
    }
    for box_key in ("start_box", "end_box"):
        if box_key in params:
            old_format_action["boxes"].append(
                _scale_box_to_pixels(params[box_key], image_width, image_height)
            )
    for key, value in params.items():
        if key in ("type", "start_box", "end_box"):
            continue
        old_format_action["custom"][key] = value
    return old_format_action

def parse_action(action_str):
    """Parse a single call expression such as ``click(start_box='(1,2)')``.

    Only keyword arguments are collected; positional arguments are ignored.
    A keyword whose value is a literal constant yields that constant, any
    other expression yields ``None``.

    Args:
        action_str: Source text of one call expression.

    Returns:
        ``{'function': <name or None>, 'args': {keyword: value}}`` on success,
        or ``None`` (after printing a diagnostic) when the text cannot be
        parsed as a call expression.
    """
    try:
        call = ast.parse(action_str, mode='eval').body

        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Plain names (`click`) and attribute access (`pyautogui.click`) both
        # resolve to the final identifier; anything else has no usable name.
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        # ast.Constant covers every literal (including strings). The deprecated
        # ast.Str alias is deliberately not consulted: it is absent on newer
        # interpreters, where merely naming it would abort the whole parse.
        kwargs = {
            kw.arg: kw.value.value if isinstance(kw.value, ast.Constant) else None
            for kw in call.keywords
        }

        return {
            'function': func_name,
            'args': kwargs
        }

    except Exception as e:
        # Best-effort parser: report the problem and let the caller decide.
        print(f"Failed to parse action '{action_str}': {e}")
        return None
    
def escape_single_quotes(text):
    """Prefix every single quote in ``text`` that is not already preceded by a backslash with one."""
    unescaped_quote = re.compile(r"(?<!\\)'")
    return unescaped_quote.sub(r"\\'", text)

def parse_action_qwen2vl(text, factor, image_height, image_width):
    """Parse a model response into a list of structured action dicts.

    The response is expected to contain an optional ``Thought:`` /
    ``Reflection: ... Action_Summary:`` / ``Action_Summary:`` preamble followed
    by ``Action:`` and one call expression per line.

    Args:
        text: Raw model response.
        factor: Divisor applied to every ``start_box`` / ``end_box`` coordinate.
        image_height: Unused; kept for interface compatibility.
        image_width: Unused; kept for interface compatibility.

    Returns:
        One dict per successfully parsed action with keys ``reflection``,
        ``thought``, ``action_type``, ``action_inputs`` and ``text``. Box
        inputs are stored as the string form of ``[x1, y1, x2, y2]``; a
        two-number point is duplicated into a degenerate box. Actions that
        cannot be parsed are reported via ``print`` and skipped.

    Raises:
        AssertionError: If ``text`` contains no ``Action:`` marker.
        ValueError: If a box parameter is not a string of numbers.
    """
    text = text.strip()
    if text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
    else:
        # Covers both an explicit "Thought:" prefix and any other prefix.
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"

    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        groups = thought_match.groups()
        if len(groups) == 1:
            thought = groups[0].strip()
        elif len(groups) == 2:
            reflection = groups[0].strip()
            thought = groups[1].strip()

    # Raised explicitly (rather than via `assert`) so the check still runs under
    # `python -O`; the exception type is unchanged for existing callers.
    if "Action:" not in text:
        raise AssertionError("Response does not contain an 'Action:' section")

    action_str = text.split("Action:")[-1]
    action_str = action_str.replace("\n\n", "\n")

    # Typed content may itself span lines. Protect its newlines with a
    # placeholder so the per-line split below keeps each type(...) call whole.
    type_pattern = r"type\(content='(.*?)'\)"
    for typed_content in re.findall(type_pattern, action_str, re.DOTALL):
        protected = typed_content.replace("\n", "<special_newline>")
        action_str = action_str.replace(typed_content, protected)

    all_action = []
    for raw_action in action_str.split("\n"):
        if "type(content" in raw_action:
            # Unwrap the content, escape its quotes so it survives being parsed
            # as a string literal, restore the newlines, then wrap it again.
            content = re.sub(type_pattern, lambda match: match.group(1), raw_action)
            content = escape_single_quotes(content)
            content = content.replace("<special_newline>", "\n")
            raw_action = "type(content='" + content + "')"
        all_action.append(raw_action)

    parsed_actions = [parse_action(action.replace("\n", "\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance is None:
            print(f"Action can't parse: {raw_str}")
            continue
        action_type = action_instance["function"]
        params = action_instance["args"]

        action_inputs = {}
        for param_name, param in params.items():
            if param == "":
                continue
            input_key = param_name.strip()
            is_box = "start_box" in param_name or "end_box" in param_name

            if not isinstance(param, str):
                # parse_action yields non-string values for numeric literals and
                # None for non-literal expressions; those have no .lstrip().
                if is_box:
                    raise ValueError(
                        f"Box parameter '{param_name}' must be a quoted string, got {param!r}"
                    )
                action_inputs[input_key] = param
                continue

            param = param.lstrip()
            action_inputs[input_key] = param

            if is_box:
                # Remove parentheses and split the string by commas
                numbers = param.replace("(", "").replace(")", "").split(",")

                # Convert to float and divide by the scaling factor
                float_numbers = [float(num) / factor for num in numbers]
                if len(float_numbers) == 2:
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[input_key] = str(float_numbers)

        actions.append({
            "reflection": reflection,
            "thought": thought,
            "action_type": action_type,
            "action_inputs": action_inputs,
            "text": text
        })
    return actions

def parse_refine_coordinate_response(text, factor=1000):
    """Extract ``start_box`` / ``end_box`` integer points from ``text``.

    Recognises ``start_box='(x,y)'`` and ``end_box='(x,y)'`` (quotes optional,
    whitespace allowed around the numbers and the comma).

    Args:
        text: Text to scan.
        factor: Divisor applied to each coordinate.

    Returns:
        A dict mapping each key found to the string form of
        ``[x, y, x, y]`` with the divided coordinates. When a key occurs
        several times the last occurrence wins. Empty if nothing matches.
    """
    # Whitespace inside the parentheses is tolerated so "(579, 853)" parses
    # the same as "(579,853)".
    pattern = r"(start_box|end_box)='?\(\s*(\d+)\s*,\s*(\d+)\s*\)'?"

    results = {}
    for key, raw_x, raw_y in re.findall(pattern, text):
        x = int(raw_x) / factor
        y = int(raw_y) / factor
        results[key] = str([x, y, x, y])
    return results

def _pyautogui_box_center(box, image_width, image_height):
    """Return the pixel ``(x, y)`` centre of a relative ``[x1, y1, x2, y2]`` box or ``[x, y]`` point."""
    if isinstance(box, str):
        # NOTE: boxes used to be parsed with eval(); ast.literal_eval accepts the
        # same list/tuple literals without executing arbitrary expressions.
        box = ast.literal_eval(box)
    if len(box) == 4:
        x1, y1, x2, y2 = box
    elif len(box) == 2:
        x1, y1 = box
        x2, y2 = x1, y1
    else:
        raise ValueError(f"Expected a box with 2 or 4 coordinates, got {box!r}")
    x = round(float((x1 + x2) / 2) * image_width, 3)
    y = round(float((y1 + y2) / 2) * image_height, 3)
    return x, y


def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True) -> tuple:
    """Render parsed actions as a pyautogui script.

    Args:
        responses: One action dict or a list of them. Each dict is read for
            ``action_type`` and ``action_inputs``; box inputs are relative
            coordinates given as a list or its string form.
        image_height: Multiplier turning relative y coordinates into pixels.
        image_width: Multiplier turning relative x coordinates into pixels.
        input_swap: When true, text is entered by copying it to the clipboard
            with pyperclip and pasting; otherwise it is typed with
            ``pyautogui.write``.

    Returns:
        ``(code, coordinate)``. ``code`` is the generated script text; a
        ``finished`` action replaces everything generated so far with
        ``"DONE"``. ``coordinate`` is the pixel point of the last
        drag/scroll/click action that had one, or ``None``.

    Raises:
        ValueError: If a box does not have 2 or 4 coordinates.
    """
    pyautogui_code = "import pyautogui\nimport time\n"
    coordinate = None
    if isinstance(responses, dict):
        responses = [responses]

    for response in responses:
        action_type = response.get("action_type")
        action_inputs = response.get("action_inputs", {})

        if action_type == "hotkey":
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")

            if hotkey:
                keys = hotkey.split()  # keys are separated by whitespace
                pyautogui_code += f"\npyautogui.hotkey({', '.join(repr(k) for k in keys)})"

            pyautogui_code += "\ntime.sleep(6)\n"

        elif action_type == "type":
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            if content:
                # NOTE(review): the text is spliced into a single-quoted literal
                # as-is; a raw newline in the middle of it would presumably
                # yield an invalid script — confirm against the callers.
                stripped = content.strip()
                if input_swap:
                    pyautogui_code += "\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{stripped}')"
                    pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
                else:
                    pyautogui_code += f"\npyautogui.write('{stripped}', interval=0.1)"
                pyautogui_code += "\ntime.sleep(0.5)\n"
                # A trailing newline (real or written as backslash-n) means "submit".
                if content.endswith("\n") or content.endswith("\\n"):
                    pyautogui_code += "\npyautogui.press('enter')"

            pyautogui_code += "\ntime.sleep(10)\n"

        elif action_type in ["drag", "select"]:
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                sx, sy = _pyautogui_box_center(start_box, image_width, image_height)
                ex, ey = _pyautogui_box_center(end_box, image_width, image_height)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )
                coordinate = (ex, ey)

            pyautogui_code += "\ntime.sleep(6)\n"

        elif action_type == "scroll":
            start_box = action_inputs.get("start_box")
            x, y = None, None
            if start_box:
                x, y = _pyautogui_box_center(start_box, image_width, image_height)
                coordinate = (x, y)

            direction = action_inputs.get("direction", "").lower()
            if "up" in direction:
                clicks = 5
            elif "down" in direction:
                clicks = -5
            else:
                clicks = None  # other directions produce no scroll call

            if clicks is not None:
                if x is None:
                    pyautogui_code += f"\npyautogui.scroll({clicks})"
                else:
                    pyautogui_code += f"\npyautogui.scroll({clicks}, x={x}, y={y})"

            pyautogui_code += "\ntime.sleep(6)\n"

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            start_box = action_inputs.get("start_box")
            # Parse before testing for emptiness: stringifying a missing box
            # would give the truthy text "None" and defeat the guard below.
            if isinstance(start_box, str) and start_box:
                start_box = ast.literal_eval(start_box)
            if start_box:
                x, y = _pyautogui_box_center(start_box, image_width, image_height)
                coordinate = (x, y)
                if action_type == "left_single" or action_type == "click":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                elif action_type == "left_double":
                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                elif action_type == "right_single":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                elif action_type == "hover":
                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"

            pyautogui_code += "\ntime.sleep(6)\n"

        elif action_type in ["finished"]:
            # Discards everything generated so far.
            pyautogui_code = "DONE"

        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"

    return pyautogui_code, coordinate

def _workarena_box_center(box, image_width, image_height):
    """Return the pixel ``(x, y)`` centre of a relative ``[x1, y1, x2, y2]`` box or ``[x, y]`` point."""
    if isinstance(box, str):
        # NOTE: boxes used to be parsed with eval(); ast.literal_eval accepts the
        # same list/tuple literals without executing arbitrary expressions.
        box = ast.literal_eval(box)
    if len(box) == 4:
        x1, y1, x2, y2 = box
    elif len(box) == 2:
        x1, y1 = box
        x2, y2 = x1, y1
    else:
        raise ValueError("No coordinates")
    x = round(float((x1 + x2) / 2) * image_width, 3)
    y = round(float((y1 + y2) / 2) * image_height, 3)
    return x, y


def parsing_response_to_workarena_action(responses, image_height: int, image_width: int, input_swap: bool = True) -> tuple:
    """Translate parsed actions into a list of action-call strings.

    Args:
        responses: One action dict or a list of them. Each dict is read for
            ``action_type`` and ``action_inputs``; box inputs are relative
            coordinates given as a list or its string form.
        image_height: Multiplier turning relative y coordinates into pixels.
        image_width: Multiplier turning relative x coordinates into pixels.
        input_swap: When true, text is emitted as ``keyboard_insert_text``;
            otherwise as ``keyboard_type``.

    Returns:
        ``(actions, coordinate)``. ``actions`` is the list of generated call
        strings; a ``finished`` action replaces everything generated so far
        with ``["DONE"]``. ``coordinate`` is the pixel point of the last
        drag/scroll/click action that had one, or ``None``.

    Raises:
        ValueError: For an unrecognised action type or a box that does not
            have 2 or 4 coordinates.
    """
    actions = []
    coordinate = None
    if isinstance(responses, dict):
        responses = [responses]

    for response in responses:
        action_type = response.get("action_type")
        action_inputs = response.get("action_inputs", {})

        if action_type == "hotkey":
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")

            if hotkey:
                keys = hotkey.split()  # keys are separated by whitespace
                # Press in order, release in reverse order.
                for key in keys:
                    actions.append(f"keyboard_down('{key}')")
                for key in reversed(keys):
                    actions.append(f"keyboard_up('{key}')")

        elif action_type == "type":
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            if content:
                if input_swap:
                    actions.append(f"keyboard_insert_text('{content.strip()}')")
                else:
                    actions.append(f"keyboard_type('{content.strip()}')")
                # A trailing newline (real or written as backslash-n) means "submit".
                if content.endswith("\n") or content.endswith("\\n"):
                    actions.append("keyboard_press('enter')")

        elif action_type in ["drag", "select"]:
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                sx, sy = _workarena_box_center(start_box, image_width, image_height)
                ex, ey = _workarena_box_center(end_box, image_width, image_height)
                actions.append(f"mouse_move({sx}, {sy})")
                actions.append(f"mouse_drag_and_drop({sx}, {sy}, {ex}, {ey})")
                coordinate = (ex, ey)
            # No sleep is emitted here: this function builds an action list,
            # not a pyautogui script.

        elif action_type == "scroll":
            start_box = action_inputs.get("start_box")
            x, y = None, None
            if start_box:
                x, y = _workarena_box_center(start_box, image_width, image_height)
                coordinate = (x, y)

            direction = action_inputs.get("direction", "").lower()
            if "up" in direction:
                scroll_action = "scroll(0, -720)"
            elif "down" in direction:
                scroll_action = "scroll(0, 720)"
            elif "left" in direction:
                scroll_action = "scroll(-720, 0)"
            elif "right" in direction:
                scroll_action = "scroll(720, 0)"
            else:
                scroll_action = None  # unknown direction: emit nothing

            if scroll_action is not None:
                if x is not None:
                    # Position the pointer first so the scroll applies at the
                    # requested location, for every direction alike.
                    actions.append(f"mouse_move({x}, {y})")
                actions.append(scroll_action)

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            start_box = action_inputs.get("start_box")
            # Parse first, then test: an absent or empty box is silently skipped.
            if isinstance(start_box, str) and start_box:
                start_box = ast.literal_eval(start_box)
            if start_box:
                x, y = _workarena_box_center(start_box, image_width, image_height)
                coordinate = (x, y)
                if action_type == "left_single" or action_type == "click":
                    actions.append(f"mouse_click({x}, {y}, 'left')")
                elif action_type == "left_double":
                    actions.append(f"mouse_dblclick({x}, {y}, 'left')")
                elif action_type == "right_single":
                    actions.append(f"mouse_click({x}, {y}, 'right')")
                elif action_type == "hover":
                    actions.append(f"mouse_move({x}, {y})")

        elif action_type in ["finished"]:
            # Discards everything generated so far.
            actions = ["DONE"]

        else:
            raise ValueError(f"Unrecognized action type: {action_type}")

    return actions, coordinate

# Image-resizing parameters used by smart_resize() below.
# Both resized dimensions are made multiples of IMAGE_FACTOR.
IMAGE_FACTOR = 28
# Default lower / upper bounds on the resized image's total pixel count.
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
# smart_resize() rejects images whose long side / short side exceeds this.
MAX_RATIO = 200
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
    """Parse a model response into a list of structured action dicts.

    The response is expected to contain an optional ``Thought:`` /
    ``Reflection: ... Action_Summary:`` / ``Action_Summary:`` preamble followed
    by ``Action:`` and one call expression per blank-line-separated chunk.

    Args:
        text: Raw model response.
        factor: Divisor applied to box coordinates for every ``model_type``
            other than ``"qwen25vl"``.
        origin_resized_height: Image height passed to ``smart_resize`` for
            ``"qwen25vl"``.
        origin_resized_width: Image width passed to ``smart_resize`` for
            ``"qwen25vl"``.
        model_type: ``"qwen25vl"`` treats box coordinates as absolute pixels
            in the smart-resized image and divides them by its width/height;
            any other value divides them by ``factor``.
        max_pixels: Upper pixel bound forwarded to ``smart_resize``.
        min_pixels: Lower pixel bound forwarded to ``smart_resize``.

    Returns:
        One dict per action with keys ``reflection``, ``thought``,
        ``action_type``, ``action_inputs`` and ``text``. Box inputs are stored
        as the string form of ``[x1, y1, x2, y2]``; a two-number point is
        duplicated into a degenerate box.

    Raises:
        AssertionError: If ``text`` contains no ``Action:`` marker.
        ValueError: If an action cannot be parsed or a box parameter is not a
            string of numbers.
    """
    text = text.strip()
    if model_type == "qwen25vl":
        smart_resize_height, smart_resize_width = smart_resize(
            origin_resized_height,
            origin_resized_width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

    if text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
    else:
        # Covers both an explicit "Thought:" prefix and any other prefix.
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"

    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        groups = thought_match.groups()
        if len(groups) == 1:
            thought = groups[0].strip()
        elif len(groups) == 2:
            reflection = groups[0].strip()
            thought = groups[1].strip()

    # Raised explicitly (rather than via `assert`) so the check still runs under
    # `python -O`; the exception type is unchanged for existing callers.
    if "Action:" not in text:
        raise AssertionError("Response does not contain an 'Action:' section")

    action_str = text.split("Action:")[-1]

    type_pattern = r"type\(content='(.*?)'\)"
    all_action = []
    for raw_action in action_str.split("\n\n"):
        if "type(content" in raw_action:
            # Unwrap the content, escape its quotes so it survives being parsed
            # as a string literal, then wrap it again.
            content = re.sub(type_pattern, lambda match: match.group(1), raw_action)
            raw_action = "type(content='" + escape_single_quotes(content) + "')"
        all_action.append(raw_action)

    parsed_actions = [parse_action(action.replace("\n", "\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance is None:
            print(f"Action can't parse: {raw_str}")
            raise ValueError(f"Action can't parse: {raw_str}")
        action_type = action_instance["function"]
        params = action_instance["args"]

        action_inputs = {}
        for param_name, param in params.items():
            if param == "":
                continue
            input_key = param_name.strip()
            is_box = "start_box" in param_name or "end_box" in param_name

            if not isinstance(param, str):
                # parse_action yields non-string values for numeric literals and
                # None for non-literal expressions; those have no .lstrip().
                if is_box:
                    raise ValueError(
                        f"Box parameter '{param_name}' must be a quoted string, got {param!r}"
                    )
                action_inputs[input_key] = param
                continue

            param = param.lstrip()
            action_inputs[input_key] = param

            if is_box:
                # Remove parentheses and split the string by commas
                numbers = param.replace("(", "").replace(")", "").split(",")

                # Qwen2.5vl outputs absolute coordinates, qwen2vl outputs
                # relative coordinates.
                if model_type == "qwen25vl":
                    # Even positions are x values (divide by width), odd
                    # positions are y values (divide by height).
                    float_numbers = [
                        float(num) / (smart_resize_height if idx % 2 else smart_resize_width)
                        for idx, num in enumerate(numbers)
                    ]
                else:
                    float_numbers = [float(num) / factor for num in numbers]

                if len(float_numbers) == 2:
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[input_key] = str(float_numbers)

        actions.append({
            "reflection": reflection,
            "thought": thought,
            "action_type": action_type,
            "action_inputs": action_inputs,
            "text": text
        })
    return actions


def round_by_factor(number: int, factor: int) -> int:
    """Round 'number' to the nearest multiple of 'factor' (ties resolved by the built-in round)."""
    multiple_count = round(number / factor)
    return multiple_count * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Round 'number' up to the next multiple of 'factor' (unchanged if already a multiple)."""
    multiple_count = math.ceil(number / factor)
    return multiple_count * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Round 'number' down to the previous multiple of 'factor' (unchanged if already a multiple)."""
    multiple_count = math.floor(number / factor)
    return multiple_count * factor

def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.

    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].

    3. The aspect ratio of the image is maintained as closely as possible.

    Neither returned dimension is ever smaller than 'factor'; when the pixel
    budget is very tight this takes precedence over condition 2.

    Returns the pair (resized_height, resized_width).

    Raises ValueError if a dimension is not positive or if the ratio of the
    longer side to the shorter side exceeds MAX_RATIO.
    """
    if height <= 0 or width <= 0:
        # Reject up front instead of failing with ZeroDivisionError below.
        raise ValueError(f"height and width must be positive, got {height} x {width}")
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {aspect_ratio}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too large: shrink both sides by the same ratio and round down.
        # Flooring could reach 0 for a very thin image, so clamp to one factor.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Too small: grow both sides by the same ratio and round up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

if __name__ == '__main__':
    # Example usage: parse a mock response, then render it as pyautogui code.

    mock_response = """Action: finish(content='(873,667)')"""
    mock_response = mock_response.replace("Thought:", "Action_Summary:")

    # parse_action_qwen2vl takes exactly (text, factor, image_height, image_width).
    mock_response_dict = parse_action_qwen2vl(mock_response, 1000, 1080, 1920)
    print(mock_response_dict)
    input()  # pause until Enter is pressed before running the next example
    rc_response = parse_refine_coordinate_response("drag(start_box='(579,853)', end_box='(607,853)')")
    print(rc_response)

    # Returns a (code, coordinate) pair.
    code_and_coordinate = parsing_response_to_pyautogui_code(mock_response_dict, 1080, 1920)
    print(code_and_coordinate)