import ast
import re
from io import BytesIO
from typing import Dict
import math
import os
import json
import numpy as np
from PIL import Image
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import smart_resize, process_vision_info
from .qwen25vl_prompts import SYSTEM_PROMPT_MOBILE, SYSTEM_PROMPT_PC, USER_PROMPT, RESOLUTION_TOKEN
import logging


# Action names the model can emit that end or pause the current step.
# UITronAgent.parse_action compares each parsed "action" value against these.
FINISH_WORD: str = "terminate"  # ends the episode; its "status" decides DONE vs FAIL
WAIT_WORD: str = "wait"  # parse_action answers with ["WAIT"]


def parse_action_to_dict(actions_str: list[str], orig_img_size: tuple, resized_img_size: tuple):
    """Parse tool-call JSON lines into action dicts with coordinates in screenshot space.

    Each non-blank entry of ``actions_str`` must be a JSON object with an
    ``"arguments"`` key; only that ``"arguments"`` dict is kept. Its
    ``"coordinate"`` / ``"coordinate2"`` values are assumed to be in the resized
    image's pixel space (the resolution the model was shown) and are rescaled
    to the original screenshot size.

    Args:
        actions_str: one JSON tool call per entry. Blank or whitespace-only
            entries are skipped.
        orig_img_size: size of the original screenshot, indexed as (x/width, y/height).
        resized_img_size: size of the image fed to the model, same indexing.

    Returns:
        A list of action dicts, in input order. Scaled coordinates are truncated
        toward zero by ``int()``.

    Raises:
        json.JSONDecodeError: if a non-blank entry is not valid JSON.
        KeyError: if an entry has no ``"arguments"`` key.
    """
    scale_factor_x = orig_img_size[0] / resized_img_size[0]
    scale_factor_y = orig_img_size[1] / resized_img_size[1]
    logging.debug(f'Parsing action: {actions_str}')
    logging.debug(f'Scale factor: ({scale_factor_x}, {scale_factor_y}), orig_img_size: ({orig_img_size[0]}, {orig_img_size[1]}), resized_img_size: ({resized_img_size[0]}, {resized_img_size[1]})')

    def scale_coordinate(coord: list[int]):
        # Map a point from model-input pixel space back to screenshot pixel space.
        scaled_coord = [int(coord[0] * scale_factor_x), int(coord[1] * scale_factor_y)]
        logging.debug(f"({coord[0]}, {coord[1]}) -> ({scaled_coord[0]}, {scaled_coord[1]})")
        return scaled_coord

    actions = []
    for action_str in actions_str:
        if not action_str.strip():
            # Blank lines can appear inside a <tool_call> block; json.loads("") would raise.
            continue
        action = json.loads(action_str)["arguments"]
        # Scale the start point first, then the end point (if any).
        for key in ("coordinate", "coordinate2"):
            if key in action:
                action[key] = scale_coordinate(action[key])
        actions.append(action)
    return actions


def map_mobile_action_to_uitars15(action_dict: dict, screen_size: tuple[int]):
    """Render one parsed mobile action dict as a UI-TARS-1.5 style action string.

    Args:
        action_dict: action with an ``"action"`` key plus the fields that action
            reads: ``coordinate`` for click/long_press; ``coordinate``/``coordinate2``
            or ``direction`` for swipe; ``text`` for type/open; ``button`` for
            system_button.
        screen_size: indexed as (width, height). Its midpoint anchors the scroll
            when a swipe carries no coordinate pair.

    Returns:
        The action string. For an unrecognized action name, an error is logged
        and ``"WAIT"`` is returned.
    """
    mid_x, mid_y = screen_size[0] // 2, screen_size[1] // 2
    kind = action_dict["action"]

    if kind in ("click", "long_press"):
        px, py = action_dict["coordinate"]
        return f"{kind}(start_box=({px}, {py}))"

    if kind == "swipe":
        if "coordinate" in action_dict and "coordinate2" in action_dict:
            (x1, y1), (x2, y2) = action_dict["coordinate"], action_dict["coordinate2"]
            if abs(y1 - y2) > abs(x1 - x2):
                # Mostly vertical. On the cloud device the swipe direction is the reverse
                # of normal, so a swipe with y1 >= y2 is emitted as "down".
                heading = "down" if y1 >= y2 else "up"
            else:
                # Mostly horizontal (ties land here too).
                # NOTE(review): unlike the vertical case this is not inverted -- confirm.
                heading = "left" if x1 >= x2 else "right"
            return f"scroll(start_box=({x1}, {y1}),direction='{heading}')"
        # No endpoint pair: scroll from the screen centre, defaulting to "up".
        heading = action_dict.get("direction", "up")
        return f'scroll(start_box=({mid_x}, {mid_y}),direction="{heading}")'

    if kind == "type":
        typed_text = action_dict["text"]
        return f"type(content='{typed_text}')"

    if kind == "system_button":
        # Only "home" is distinguished; every other button name maps to back.
        return "press_home()" if action_dict["button"].lower() == "home" else "press_back()"

    if kind == "open":
        app_name = action_dict["text"]
        return f'do(action="Launch", app="{app_name}")'

    logging.error(f"unrecognized action: {str(action_dict)}")
    return "WAIT"


def map_pc_action_to_uitars15(action_dict):
    """Placeholder for the desktop (PC) counterpart of map_mobile_action_to_uitars15.

    Not implemented yet; UITronAgent currently accepts only the "mobile" action space.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


class UITronAgent:
    """Prompt builder and output parser for a screenshot-driven GUI agent.

    ``get_model_inputs`` turns an instruction plus a screenshot into a chat prompt
    built with the Qwen-VL processor's chat template. ``parse_action`` turns the
    model's ``<tool_call>`` output into action strings via
    ``self.action_code_mapper``. Per-episode history lives on the instance and is
    cleared by ``reset``.
    """

    def __init__(
        self,
        tokenizer_path,
        max_trajectory_length=15,
        history_n=5,
        action_space="computer",
    ):
        """Select the action space, load the tokenizer/processor and reset history.

        Args:
            tokenizer_path: path or hub id passed to both ``AutoTokenizer`` and
                ``AutoProcessor.from_pretrained``.
            max_trajectory_length: once this many responses containing
                ``<tool_call>`` are recorded, ``parse_action`` returns ``["FAIL"]``
                instead of the mapped actions.
            history_n: stored on the instance; not read anywhere in this class.
            action_space: only ``"mobile"`` is implemented.

        Raises:
            NotImplementedError: for any ``action_space`` other than ``"mobile"``
                (including the default ``"computer"``). Raised before any model
                files are loaded.
        """
        if action_space != "mobile":
            # TODO: write a PC action mapper (map_pc_action_to_uitars15) and use SYSTEM_PROMPT_PC.
            raise NotImplementedError(
                f"action_space={action_space!r} is not supported yet; only 'mobile' is implemented"
            )

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True, use_fast=False)
        self.processor = AutoProcessor.from_pretrained(tokenizer_path)
        self.max_trajectory_length = max_trajectory_length
        self.history_n = history_n
        self.action_space = action_space

        # Converts the model's tool-call JSON lines into action dicts (screen coordinates).
        self.customize_action_parser = parse_action_to_dict
        self.system_prompt = SYSTEM_PROMPT_MOBILE
        # Converts an action dict into a UI-TARS-1.5 action string for the environment.
        self.action_code_mapper = map_mobile_action_to_uitars15

        # Initializes history lists plus screen_size / model_input_img_size.
        self.reset()

    def get_model_inputs(self, instruction: str, obs: Dict):
        """Build the prompt and multimodal payload for the next model call.

        Args:
            instruction: task instruction inserted into ``USER_PROMPT``.
            obs: observation dict. ``obs["screenshot"]`` must be a PIL image or
                encoded image bytes. Bytes are decoded and written back into
                ``obs``, so the caller's dict is mutated.

        Returns:
            ``{"prompt": <chat-templated text>, "multi_modal_data": {"image": [screenshot]}}``.
            This is presumably the request format of a vLLM-style engine; confirm
            against the caller.

        Side effects:
            Sets ``self.screen_size`` (original screenshot size) and
            ``self.model_input_img_size`` (``[width, height]`` after ``smart_resize``).
            ``parse_action`` uses both to rescale coordinates.
        """
        # Earlier step conclusions are shown back to the model as history.
        # NOTE(review): every recorded step is included; self.history_n is not applied here.
        if not self.history_actions:
            history_action_str = "None"
        else:
            history_action_str = "\n".join(f"Step {i}: {act}" for i, act in enumerate(self.history_actions, start=1))
        user_prompt = USER_PROMPT.format(instruction=instruction, history=history_action_str)

        # Pixel budget for the resized screenshot, overridable via the environment.
        max_pixels = int(os.environ.get("MAX_PIXELS", 937664))  # 1196 * 28 * 28
        min_pixels = int(os.environ.get("MIN_PIXELS", 200704))  # 256 * 28 * 28
        logging.debug(f'min_pixels={min_pixels}, max_pixels={max_pixels}')

        if isinstance(obs["screenshot"], bytes):
            obs["screenshot"] = Image.open(BytesIO(obs["screenshot"]))
        self.screen_size = obs["screenshot"].size
        # smart_resize is called with (height, width) and unpacked as (height, width).
        resized_h, resized_w = smart_resize(
            self.screen_size[1], self.screen_size[0], min_pixels=min_pixels, max_pixels=max_pixels
        )
        self.model_input_img_size = [resized_w, resized_h]

        messages = [
            {
                "role": "system",
                # The resized resolution is written into the system prompt, presumably so the
                # model's coordinates are in that space (parse_action_to_dict rescales them).
                "content": self.system_prompt.replace(
                    RESOLUTION_TOKEN, f"{self.model_input_img_size[0]}x{self.model_input_img_size[1]}"
                ),
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": obs["screenshot"], "max_pixels": max_pixels, "min_pixels": min_pixels},
                    {"type": "text", "text": user_prompt},
                ],
            },
        ]
        logging.debug(f"=====messages: {messages}")

        prompt_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = {"prompt": prompt_text, "multi_modal_data": {"image": [obs["screenshot"]]}}
        return inputs

    def parse_action(self, response: str):
        """Parse one model response into executable action strings.

        Extracts three parts of the response:

        * ``<think>``, stored in ``self.thoughts``;
        * ``<conclusion>``, stored in ``self.history_actions`` and fed back as
          history by ``get_model_inputs``;
        * every ``<tool_call>...</tool_call>`` block.

        Each non-blank line inside the tool-call blocks is passed to
        ``self.customize_action_parser``. The resulting dicts are mapped with
        ``self.action_code_mapper``.

        Returns:
            * ``["FAIL"]`` if the response has no usable tool call, if the finish
              action (``FINISH_WORD``) carries a non-success ``status``, or once
              ``max_trajectory_length`` responses have been recorded.
            * ``["DONE"]`` on a successful finish action, or when the tool-call
              payload cannot be parsed.
            * ``["WAIT"]`` on a ``WAIT_WORD`` action.
            * Otherwise, the list of mapped action strings. Actions whose mapping
              raises are logged and dropped.
        """
        if "<tool_call>" not in response:
            return ["FAIL"]
        self.history_responses.append(response)

        thoughts = self._extract_tag(response, "think")
        implement = self._extract_tag(response, "conclusion")

        # Each call is wrapped in its own <tool_call>...</tool_call> pair. A non-greedy
        # findall collects all of them. A greedy (.*) would also capture the closing and
        # opening tags between consecutive calls, and those lines would fail JSON parsing.
        tool_call_bodies = re.findall(r"<tool_call>(.*?)</tool_call>", response, re.DOTALL)
        actions_str = [
            line.strip()
            for body in tool_call_bodies
            for line in body.split("\n")
            if line.strip()
        ]
        if not actions_str:
            # Unterminated or empty <tool_call> block: nothing to execute.
            return ["FAIL"]

        self.thoughts.append(thoughts)
        self.history_actions.append(implement)

        logging.debug(f"Screen size is {self.screen_size[0]}x{self.screen_size[1]}")

        try:
            parsed_actions = self.customize_action_parser(actions_str, self.screen_size, self.model_input_img_size)
            logging.debug(f"parsed_actions: {parsed_actions}")
        except Exception as e:
            # NOTE(review): an unparsable payload ends the episode as DONE, not FAIL -- confirm intent.
            logging.error(f"Parsing action error: {response}, with error:\n{e}")
            return ["DONE"]

        actions = []
        for parsed_response in parsed_actions:
            if "action" in parsed_response:
                if parsed_response["action"] == FINISH_WORD:
                    # Record what was mapped before the finish action, then stop.
                    self.actions.append(actions)
                    if parsed_response.get("status", "success") == "success":
                        return ["DONE"]
                    return ["FAIL"]
                if parsed_response["action"] == WAIT_WORD:
                    self.actions.append(actions)
                    return ["WAIT"]

            try:
                action_code = self.action_code_mapper(parsed_response, self.screen_size)
            except Exception as e:
                # Best effort: skip actions the mapper cannot translate, keep the rest.
                logging.error(f"Parsing pyautogui code error: {parsed_response}, with error:\n{e}")
                continue
            actions.append(action_code)

        self.actions.append(actions)

        if len(self.history_responses) >= self.max_trajectory_length:
            # Step budget exhausted: report failure instead of the mapped actions.
            return ["FAIL"]

        return actions

    @staticmethod
    def _extract_tag(response: str, tag: str) -> str:
        """Return the stripped text between the first ``<tag>`` and last ``</tag>``, or ``"None"``."""
        match = re.search(rf"<{tag}>(.*)</{tag}>", response, re.DOTALL)
        return match.group(1).strip() if match else "None"

    def reset(self):
        """Clear per-episode history and the cached screenshot / model-input sizes."""
        self.thoughts = []
        self.actions = []
        self.history_responses = []
        self.history_actions = []
        self.model_input_img_size = None
        self.screen_size = None
