import sys
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info
import torch, json, ast
import cv2
import numpy as np
from PIL import Image
sys.path.append('./')
from QwenVL import Qwen2_VL_Agent
from utils.result_preprocess import UI_TARS_RES_PREPROCESS

class UI_TARS_Agent(Qwen2_VL_Agent):
    def __init__(self, device, accelerator, policy_lm, *args, **kwargs):
        """Initialise the UI-TARS agent on top of the Qwen2-VL agent base.

        Args:
            device: Forwarded unchanged to the base-class initialiser.
            accelerator: Forwarded unchanged to the base-class initialiser.
            policy_lm: Model name or path. It is also stored on the instance
                because ``_load_model`` and ``_get_action`` match substrings
                of it to pick the architecture and the coordinate handling.
            *args: Extra positional arguments for the base class.
            **kwargs: Extra keyword arguments for the base class.
        """
        # Positional extras are unpacked first so the call reads in the order
        # Python actually binds the arguments.
        super().__init__(*args, device=device, accelerator=accelerator, policy_lm=policy_lm, **kwargs)
        self.policy_lm = policy_lm
        self.res_pre_process = self._res_pre_process()
        # self.tokenizer is not created in this class; presumably the base
        # initialiser sets it — TODO confirm.
        # Reuse EOS as the padding token so padding never depends on a
        # dedicated pad token being defined.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
    

    def _load_model(self):
        """Load the checkpoint named by ``self.policy_lm`` and return it.

        Checkpoints whose name contains one of the tags listed below are
        loaded with the Qwen2-VL class; everything else is loaded with the
        Qwen2.5-VL class. The model's ``pad_token_id`` is then set to its
        ``eos_token_id``.

        Returns:
            The loaded model, which is also stored on ``self.model``.
        """
        qwen2_vl_tags = (
            "UI-TARS-7B-DPO",
            "UI-TARS-7B-SFT",
            "UI-TARS-2B-SFT",
            "UI-TARS-72B-SFT",
            "UI-TARS-72B-DPO",
        )
        if any(tag in self.policy_lm for tag in qwen2_vl_tags):
            model_cls = Qwen2VLForConditionalGeneration
        else:
            model_cls = Qwen2_5_VLForConditionalGeneration

        # device_map="auto" handles placement, so the model is not moved to
        # self.device explicitly.
        self.model = model_cls.from_pretrained(
            self.policy_lm,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
        self.model.config.pad_token_id = self.model.config.eos_token_id
        return self.model
    
    def _res_pre_process(self):
        """Create the UI-TARS response pre-processor used to parse action strings."""
        preprocessor = UI_TARS_RES_PREPROCESS()
        return preprocessor
    
    def get_action(self, obs, args):        
        messages = obs['messages']
        
        # 过滤掉 assistant 消息（ground truth），推理时不应该包含
        # 只保留 system 和 user 消息用于构造 prompt
        filtered_messages = [msg for msg in messages if msg.get('role') != 'assistant']
        
        if not filtered_messages:
            return ""

        chat_text = self.processor.apply_chat_template(
                    filtered_messages, tokenize=False, add_generation_prompt=True
                )
        image_inputs, video_inputs = process_vision_info(filtered_messages)
        if args.thought == 'false':
            chat_text = chat_text.rsplit("<|im_end|>", 1)[0].strip()
        if args.probing_method == 'visual_mask':
            image_inputs[-1] = self.visual_mask(
                image_inputs[-1].copy(),
                obs=obs,
                args=args
            )
            output = self._get_action(chat_text, image_inputs, video_inputs, obs)
            # 确保返回字符串
            return str(output) if output is not None else ""
        elif args.probing_method == 'zoom':
            image_inputs[-1], label = self.zoom_in(image_inputs[-1].copy(), obs=obs)
            x, y = label['label'][0], label['label'][1]
            label['label'] = f"click(start_box='({x},{y})')" 
            output = self._get_action(chat_text, image_inputs, video_inputs, obs)
            # zoom 方法返回元组，但评估脚本期望字符串，只返回 output
            return str(output) if output is not None else ""
        elif args.probing_method == 'visual_edit':
            image_inputs[-1] = self.visual_mask(
                image_inputs[-1].copy(),
                obs=obs,
                args=args
            )
            output = self._get_action(chat_text, image_inputs, video_inputs, obs)
            return str(output) if output is not None else ""
        elif args.probing_method == 'structure_mask' and args.dataset_type == 'action_shortcuts':
            image_inputs = [self.structure_mask(img.copy()) for img in image_inputs]
            output = self._get_action(chat_text, image_inputs, video_inputs, obs)
            return str(output) if output is not None else ""
        else:
            output = self._get_action(chat_text, image_inputs, video_inputs, obs)
            return str(output) if output is not None else ""
        
    def _get_action(self, chat_text, image_inputs, video_inputs, obs):
        """Run one greedy generation and return the decoded action string.

        Args:
            chat_text: Prompt text produced by the chat template.
            image_inputs: Images for the prompt (possibly edited by a probing
                method).
            video_inputs: Video inputs for the prompt.
            obs: Observation dict; only ``obs.get('image_size')`` is read here,
                to rescale click coordinates for UI-TARS-1.5-7B.

        Returns:
            The decoded model output, or a rewritten ``click(start_box=...)``
            string for UI-TARS-1.5-7B click actions, or ``""`` when generation
            fails or yields no text.
        """
        inputs = self.processor(
            text=[chat_text],
            images=[image_inputs],
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Follow the device the model actually ended up on.
        self.device = self.model.device
        inputs = inputs.to(self.device)

        with torch.no_grad():
            try:
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=128,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    do_sample=False,  # greedy decoding
                    temperature=None,  # sampling is disabled
                )
                if isinstance(generated_ids, torch.Tensor):
                    generated_ids = generated_ids.to(self.device)
                else:
                    generated_ids = generated_ids.sequences
            except Exception as e:
                print(f"[UI_TARS Generation Error]: {e}")
                return ""

        # Strip the prompt tokens so only newly generated tokens remain.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        if len(generated_ids_trimmed) == 0 or len(generated_ids_trimmed[0]) == 0:
            return ""

        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        if isinstance(output_text, list):
            raw_output = output_text[0] if output_text else ""
        else:
            raw_output = output_text

        if not raw_output or (isinstance(raw_output, str) and raw_output.strip() == ""):
            # Nothing is left once special tokens are skipped: decode again
            # keeping them and strip only the end-of-turn / end-of-text
            # markers. The check must look at the decoded string, because the
            # list returned by batch_decode is never empty at this point.
            with_special = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )
            if isinstance(with_special, list):
                with_special = with_special[0] if with_special else ""
            if not with_special:
                return ""
            return with_special.replace("<|im_end|>", "").replace("<|endoftext|>", "").strip()

        if self.res_pre_process.get_action_type(raw_output) == 1 and "UI-TARS-1.5-7B" in self.policy_lm:
            # Presumably UI-TARS-1.5 predicts click coordinates in the
            # processor-resized image space, so they are mapped back to the
            # original image size — TODO confirm.
            image_size = obs.get('image_size')
            if image_size:
                try:
                    scale_x, scale_y = self._get_scale(inputs, image_size)
                    pred_coord = self.res_pre_process.extract_coordinates(raw_output)
                    if pred_coord and len(pred_coord) >= 2:
                        return f"click(start_box='({pred_coord[0]*scale_x}, {pred_coord[1]*scale_y})')"
                except Exception as e:
                    # Best effort: report the failure and fall back to the
                    # unscaled model output.
                    print(f"[UI_TARS Coordinate Rescale Error]: {e}")
        return raw_output
        
    
    def _get_scale(self, inputs, image_size):
        resized_height = inputs['image_grid_thw'][0][1] * self.processor.image_processor.patch_size
        resized_width = inputs['image_grid_thw'][0][2] * self.processor.image_processor.patch_size
              
        # 处理image_size可能是tuple或PIL Image对象
        if hasattr(image_size, 'size'):
            origin_height = image_size.size[1]
            origin_width = image_size.size[0]
        elif isinstance(image_size, (list, tuple)) and len(image_size) >= 2:
            origin_width = image_size[0]
            origin_height = image_size[1]
        else:
            # 默认尺寸
            origin_width = 1080
            origin_height = 1920
            
        scale_x = origin_width / resized_width
        scale_y = origin_height / resized_height

        return (scale_x, scale_y)
    

    def visual_mask(self, image_input, obs, args):
        """Hide the ground-truth click target by blacking it out or inpainting it.

        Candidate boxes are collected as ``[x, y, w, h]`` in pixels from a
        source that depends on ``obs.get('dataset_name')``. Only boxes that
        contain the ground-truth click point are edited.

        Args:
            image_input: PIL image to edit. With ``probing_method ==
                'visual_mask'`` it is drawn on in place.
            obs: Observation dict. Reads ``dataset_name``,
                ``accessibility_trees``, ``images``, ``bbox`` and ``label``.
            args: Run configuration. Reads ``probing_method`` and
                ``mask_object_ratio``; also passed on to
                ``remove_containing_bboxes``.

        Returns:
            The edited PIL image.
        """
        from PIL import ImageDraw

        image_width, image_height = image_input.size[0], image_input.size[1]
        dataset_name = obs.get('dataset_name')
        bbox_data = []

        if dataset_name == 'AndroidControl':
            # One JSON object per line; keep clickable widgets of the listed
            # classes whose box lies fully inside the image.
            wanted_classes = (
                'android.widget.ImageButton',
                'android.widget.TextView',
                'android.widget.ImageView',
            )
            with open(obs['accessibility_trees'], "r", encoding="utf-8") as file:
                for line in file:
                    try:
                        obj = json.loads(line)
                        bbox = obj.get("bbox_pixels", None)
                        class_name = obj.get("class_name", None)
                        if not (bbox and class_name in wanted_classes and obj.get("is_clickable")):
                            continue
                        x_min, y_min, x_max, y_max = bbox["x_min"], bbox["y_min"], bbox["x_max"], bbox["y_max"]
                        if (
                            0 <= x_min < x_max <= image_width and
                            0 <= y_min < y_max <= image_height
                        ):
                            bbox_data.append([x_min, y_min, x_max - x_min, y_max - y_min])
                    except Exception:
                        # Malformed lines are skipped on purpose.
                        continue
        elif dataset_name == 'AITZ':
            tree_path = obs.get('accessibility_trees')
            if tree_path:
                try:
                    with open(tree_path, "r", encoding="utf-8") as file:
                        tree_entries = json.load(file)
                    images = obs.get('images', [])
                    if images:
                        image_path = images[0] if isinstance(images, list) else images
                        for entry in tree_entries:
                            if entry.get('image_path') and image_path and entry['image_path'] in image_path:
                                positions = ast.literal_eval(entry['ui_positions'])
                                # The first/second and third/fourth values are
                                # swapped; presumably ui_positions is stored
                                # as (y, x, h, w) — TODO confirm.
                                bbox_data = [[b, a, d, c] for (a, b, c, d) in positions]
                                break
                except Exception as e:
                    # Best effort: continue without candidate boxes, but say so.
                    print(f"[UI_TARS visual_mask] failed to read AITZ boxes: {e}")
        else:
            # A single [x_min, y_min, x_max, y_max] box on a 0-1000 scale.
            raw_bbox = obs.get('bbox')
            if raw_bbox and isinstance(raw_bbox, (list, tuple)) and len(raw_bbox) >= 4:
                left = raw_bbox[0] / 1000 * image_width
                top = raw_bbox[1] / 1000 * image_height
                right = raw_bbox[2] / 1000 * image_width
                bottom = raw_bbox[3] / 1000 * image_height
                bbox_data = [[left, top, right - left, bottom - top]]

        gt = self.res_pre_process.extract_action(obs['label'])
        gt = self.res_pre_process.extract_coordinates(gt)
        _, bbox_list, point = self.remove_containing_bboxes(
            bbox_list=bbox_data, gt=gt, image_size=[image_width, image_height], args=args
        )

        if args.probing_method == 'visual_mask':
            draw = ImageDraw.Draw(image_input)
            if len(bbox_list) > 0:
                for x, y, w, h in bbox_list:
                    # 20 px of padding around every box that holds the click.
                    draw.rectangle([x - 20, y - 20, x + w + 20, y + h + 20], fill="black")
            else:
                # No box contains the click: black out a square around it.
                # mask_object_ratio is used as a half-side length in pixels.
                r = args.mask_object_ratio
                draw.rectangle([point[0] - r, point[1] - r, point[0] + r, point[1] + r], fill="black")
            return image_input

        def inpaint_rect(canvas, x0, y0, x1, y1):
            """Inpaint the rectangle [x0, x1) x [y0, y1), clamped to the image."""
            height, width = canvas.shape[:2]
            # Clamp first: a negative slice bound would wrap around in NumPy
            # and select the wrong region.
            x0, y0 = max(0, int(x0)), max(0, int(y0))
            x1, y1 = min(width, int(x1)), min(height, int(y1))
            if x0 >= x1 or y0 >= y1:
                return canvas
            mask = np.zeros((height, width), dtype=np.uint8)
            mask[y0:y1, x0:x1] = 255
            return cv2.inpaint(canvas, mask, 3, cv2.INPAINT_TELEA)

        canvas = cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR)
        for x, y, w, h in bbox_list:
            # Boxes are [x, y, w, h], so the far corner is (x + w, y + h).
            canvas = inpaint_rect(canvas, x, y, x + w, y + h)
        # The square around the click point is always inpainted as well.
        r = args.mask_object_ratio
        click_x, click_y = point
        canvas = inpaint_rect(canvas, click_x - r, click_y - r, click_x + r, click_y + r)
        # Convert back to RGB before building the PIL image; otherwise the red
        # and blue channels come out swapped.
        return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))

    def remove_containing_bboxes(self, bbox_list, gt, image_size, args):
        if '1.5' in args.model_name:
            click_x, click_y = gt[0], gt[1]
        else:
            click_x, click_y = gt[0] / 1000 * image_size[0], gt[1] / 1000 * image_size[1]
        out_bbox_list = []
        in_bbox_list = []
        if len(bbox_list) > 0:
            for bbox in bbox_list:
                x, y, w, h = bbox
                if not (x <= click_x <= x+w and y <= click_y <= y+h):
                    out_bbox_list.append(bbox)
                else:
                    in_bbox_list.append(bbox)
        return out_bbox_list, in_bbox_list, (click_x, click_y)
    
    def zoom_in(self, pil_image, obs):
        from PIL import Image

        try:
            content = obs['label']
        except (IndexError, KeyError, TypeError):
            raise ValueError("Invalid message format in obs")

        ground_truth = self.res_pre_process.extract_action(content)
        task = self.res_pre_process.get_action_type(ground_truth)
        bbox = obs.get("bbox")  # [x_min, y_min, x_max, y_max]

        w, h = pil_image.size

        if task == 1:
            # 原始点击点
            click_x, click_y = self.res_pre_process.extract_coordinates(ground_truth)
            click_x = click_x / 1000 * w
            click_y = click_y / 1000 * h

            # 计算象限区域
            mid_x, mid_y = w // 2, h // 2
            if click_x < mid_x and click_y < mid_y:
                region = (0, 0, mid_x, mid_y)
            elif click_x >= mid_x and click_y < mid_y:
                region = (mid_x, 0, w, mid_y)
            elif click_x < mid_x and click_y >= mid_y:
                region = (0, mid_y, mid_x, h)
            else:
                region = (mid_x, mid_y, w, h)

            # 裁剪并缩放
            cropped = pil_image.crop(region)
            zoomed_image = cropped.resize((w, h), Image.LANCZOS)

            # 缩放变换函数
            def transform_coord(x, y, region, w, h):
                rel_x, rel_y = x - region[0], y - region[1]
                scale_x = w / (region[2] - region[0])
                scale_y = h / (region[3] - region[1])
                new_x = int(rel_x * scale_x)
                new_y = int(rel_y * scale_y)
                return new_x, new_y

            # 点击点映射
            new_click_x, new_click_y = transform_coord(click_x, click_y, region, w, h)
            norm_click_x = new_click_x / w * 1000
            norm_click_y = new_click_y / h * 1000
            

            # bbox 映射
            new_bbox = None
            if bbox is not None:
                bbox = [bbox[0]/1000*w, bbox[1]/1000*h, bbox[2]/1000*w, bbox[3]/1000*h]
                x_min, y_min = transform_coord(bbox[0], bbox[1], region, w, h)
                x_max, y_max = transform_coord(bbox[2], bbox[3], region, w, h)
                new_bbox = [x_min/w*1000, y_min/h*1000, x_max/w*1000, y_max/h*1000]

            return zoomed_image, {"label": [norm_click_x, norm_click_y], "bbox": new_bbox}

        return pil_image, None
    

    def structure_mask(self, image):
        """Return an all-zero image with the same array shape and dtype as *image*."""
        blank_pixels = np.zeros_like(image)
        return Image.fromarray(blank_pixels)