import sys
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor
from io import BytesIO
from typing import List, Literal, Optional
import requests, ast, json
from PIL import Image
import numpy as np
import cv2
import torch
sys.path.append("./")
from gui_speaker.models.QwenVL import Qwen2_VL_Agent
from utils.schema.aguvisConstants import agent_system_message, chat_template, grounding_system_message, until
from utils.result_preprocess import AGUVIS_RES_PREPROCESS

class Aguvis_Agent(Qwen2_VL_Agent):
    """Qwen2-VL based Aguvis agent with optional screenshot-probing transforms.

    Depending on ``args.probing_method``, :meth:`get_action` feeds the model
    either the raw screenshot or a modified one (masked, inpainted, zoomed or
    blanked out) and returns the decoded model output.
    """

    # Android widget classes whose clickable nodes are used as candidate
    # element boxes for the AndroidControl dataset.
    _ANDROID_CONTROL_BBOX_CLASSES = (
        'android.widget.ImageButton',
        'android.widget.TextView',
        'android.widget.ImageView',
    )

    def __init__(self, device, accelerator, cache_dir='~/.cache', dropout=0.5, policy_lm=None):
        super().__init__(device, accelerator, cache_dir, dropout, policy_lm)
        self.policy_lm = policy_lm
        self.res_pre_process = self._res_pre_process()
        # Reuse EOS as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        # Only "max_new_tokens" is read (by _get_action); the sampling keys are
        # not forwarded to generate().
        # NOTE(review): do_sample=True together with temperature=0 is not a
        # valid sampling configuration in transformers if it is ever forwarded;
        # greedy decoding (do_sample=False) is presumably what is intended.
        self.generation_config = {
            "max_new_tokens": 1024,
            "do_sample": True,
            "temperature": 0,
        }

    def _load_model(self):
        """Load the Qwen2-VL policy from ``self.policy_lm`` and return it.

        The model is loaded in bfloat16 with FlashAttention-2 and
        ``device_map="auto"``; its pad token id is aliased to EOS.
        """
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.policy_lm,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )
        self.model.config.pad_token_id = self.model.config.eos_token_id
        return self.model

    def _res_pre_process(self):
        """Return the Aguvis result pre-processor used to parse labels."""
        return AGUVIS_RES_PREPROCESS()

    @staticmethod
    def _load_rgb_image(path):
        """Open ``path`` and return a fully loaded RGB copy, closing the file."""
        with Image.open(path) as img:
            # convert() materialises the pixel data, so the result stays
            # usable after the underlying file handle is closed.
            return img.convert('RGB')

    def get_action(self, obs, args):
        """Prepare the screenshot for ``args.probing_method`` and query the model.

        The (possibly modified) image is written in place into
        ``obs['messages']['content'][0]['image']`` before generation.

        Args:
            obs: Observation dict; reads ``'messages'``, ``'mode'`` and
                ``'images'`` (the first entry is the screenshot path), plus
                the keys needed by the selected probing method.
            args: Run configuration; reads ``probing_method`` and, for
                ``'structure_mask'``, ``dataset_type``.

        Returns:
            The decoded model output string. For ``'zoom'`` a tuple
            ``(output, label)`` is returned instead, where ``label`` is the
            remapped ground truth dict, or ``None`` when the ground-truth
            action is not one that :meth:`zoom_in` handles.
        """
        user_messages = obs['messages']
        system_message = {
            "role": "system",
            "content": grounding_system_message if obs['mode'] == "grounding" else agent_system_message,
        }
        image = self._load_rgb_image(obs['images'][0])
        probing_method = args.probing_method

        if probing_method in ('visual_mask', 'visual_edit'):
            # visual_mask() itself distinguishes between the two methods.
            user_messages['content'][0]['image'] = self.visual_mask(image, obs, args)
        elif probing_method == 'zoom':
            user_messages['content'][0]['image'], label = self.zoom_in(image, obs)
            if label is not None:
                label['label'] = f"assistantos\npyautogui.click(x={label['label'][0]}, y={label['label'][1]})"
            output = self._get_action(obs, system_message, user_messages)
            return output, label
        elif probing_method == 'structure_mask' and args.dataset_type == 'action_shortcuts':
            user_messages['content'][0]['image'] = self.structure_mask(image)
        else:
            user_messages['content'][0]['image'] = image
        return self._get_action(obs, system_message, user_messages)

    def _get_action(self, obs, system_message, user_messages):
        """Run one generation pass and return the decoded, stop-truncated text.

        The chat is rendered with the Aguvis ``chat_template`` and the
        assistant turn is opened manually with a prefix chosen by
        ``obs['is_low_level_instruction']`` / ``obs['mode']``.

        Raises:
            ValueError: If ``obs['mode']`` is not one of ``'grounding'``,
                ``'self-plan'`` or ``'force-plan'`` (and no low-level
                instruction is given).
        """
        if obs["is_low_level_instruction"]:
            recipient_text = f"<|im_start|>assistant<|recipient|>all\nAction: {obs['low_level_instruction']}\n"
        elif obs['mode'] == "grounding":
            recipient_text = "<|im_start|>assistant<|recipient|>os\n"
        elif obs['mode'] == "self-plan":
            recipient_text = "<|im_start|>assistant<|recipient|>"
        elif obs['mode'] == "force-plan":
            recipient_text = "<|im_start|>assistant<|recipient|>all\nThought: "
        else:
            raise ValueError(f"Invalid mode: {obs['mode']}")

        messages = [system_message, user_messages]
        # Rendered without a generation prompt so that the assistant turn can
        # be opened with the mode-specific prefix above.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False, chat_template=chat_template
        )
        text += recipient_text
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
        )
        inputs = inputs.to(self.model.device)
        cont = self.model.generate(
            **inputs,
            max_new_tokens=self.generation_config.get("max_new_tokens", 1024),
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Keep only the newly generated tokens (drop the echoed prompt).
        prompt_len = inputs.input_ids.shape[1]
        cont_toks = cont[0, prompt_len:].tolist()
        text_outputs = self.tokenizer.decode(cont_toks, skip_special_tokens=True).strip()
        # Truncate at the earliest occurrence of any stop string.
        for term in until:
            if term:
                text_outputs = text_outputs.split(term)[0]
        return text_outputs

    @staticmethod
    def _clamped_slice(start, stop):
        """Return ``slice(int(start), int(stop))`` with both ends clamped at 0.

        Prevents negative coordinates from being interpreted by NumPy as
        indices counted from the end of an axis.
        """
        return slice(max(0, int(start)), max(0, int(stop)))

    def _android_control_bboxes(self, path, image_width, image_height):
        """Read clickable widget boxes from an AndroidControl JSON-lines file.

        Returns:
            List of ``[x, y, w, h]`` pixel boxes that lie fully inside the
            image. Malformed lines/records are skipped.
        """
        bbox_data = []
        with open(path, "r", encoding="utf-8") as file:
            for line in file:
                try:
                    obj = json.loads(line)
                    bbox = obj.get("bbox_pixels", None)
                    class_name = obj.get("class_name", None)
                    if not (bbox and class_name in self._ANDROID_CONTROL_BBOX_CLASSES and obj.get("is_clickable")):
                        continue
                    x_min, y_min, x_max, y_max = bbox["x_min"], bbox["y_min"], bbox["x_max"], bbox["y_max"]
                    if 0 <= x_min < x_max <= image_width and 0 <= y_min < y_max <= image_height:
                        bbox_data.append([x_min, y_min, x_max - x_min, y_max - y_min])
                except (ValueError, KeyError, TypeError, AttributeError):
                    # Best-effort parse: skip undecodable lines and records
                    # with a missing or ill-typed bbox.
                    continue
        return bbox_data

    @staticmethod
    def _aitz_bboxes(path, image_path):
        """Read the UI element boxes recorded for ``image_path`` in an AITZ file.

        Returns:
            List of ``[x, y, w, h]`` boxes, or ``[]`` when no record matches.
        """
        with open(path, "r", encoding="utf-8") as file:
            records = json.load(file)
        ui_positions = None
        for record in records:
            # Substring match against the screenshot path; the last match wins.
            if record['image_path'] in image_path:
                ui_positions = ast.literal_eval(record['ui_positions'])
        if ui_positions is None:
            return []
        # The axis swap implies entries are stored as (y, x, h, w); reorder
        # them to the [x, y, w, h] layout used everywhere else.
        return [[x, y, w, h] for (y, x, h, w) in ui_positions]

    def _load_candidate_bboxes(self, obs, image_width, image_height):
        """Return candidate element boxes for ``obs`` as ``[x, y, w, h]`` pixels."""
        dataset_name = obs.get('dataset_name')
        if dataset_name == 'AndroidControl':
            return self._android_control_bboxes(obs['accessibility_trees'], image_width, image_height)
        if dataset_name == 'AITZ':
            return self._aitz_bboxes(obs['accessibility_trees'], obs.get('images')[0])
        # obs['bbox'] is [x_min, y_min, x_max, y_max] on a 0-1000 scale.
        bbox = obs.get('bbox')
        x_min = bbox[0] / 1000 * image_width
        y_min = bbox[1] / 1000 * image_height
        x_max = bbox[2] / 1000 * image_width
        y_max = bbox[3] / 1000 * image_height
        return [[x_min, y_min, x_max - x_min, y_max - y_min]]

    def visual_mask(self, image_input, obs, args):
        """Hide the ground-truth click target in ``image_input``.

        Candidate element boxes come from dataset-specific accessibility data
        or from ``obs['bbox']``. The boxes that contain the ground-truth click
        (parsed from ``obs['label']``) are the targets.

        * ``args.probing_method == 'visual_mask'``: draw, in place on
          ``image_input``, a black rectangle padded by 20 px over every
          target box. When no box contains the click, draw a black square of
          half-size ``args.mask_object_ratio`` around the click point.
        * otherwise (``'visual_edit'``): inpaint every target box and then a
          square of half-size ``args.mask_object_ratio`` around the click
          point with OpenCV's Telea algorithm, returning a new RGB image.

        Returns:
            The modified PIL image.
        """
        from PIL import ImageDraw

        image_width, image_height = image_input.size
        bbox_data = self._load_candidate_bboxes(obs, image_width, image_height)
        gt = self.res_pre_process.extract_action(obs['label'])
        gt = self.res_pre_process.extract_coordinates(gt)
        _, target_bboxes, point = self.remove_containing_bboxes(
            bbox_list=bbox_data, gt=gt, image_size=[image_width, image_height]
        )

        if args.probing_method == 'visual_mask':
            draw = ImageDraw.Draw(image_input)
            if target_bboxes:
                for x, y, w, h in target_bboxes:
                    draw.rectangle([x - 20, y - 20, x + w + 20, y + h + 20], fill="black")
            else:
                r = args.mask_object_ratio
                draw.rectangle([point[0] - r, point[1] - r, point[0] + r, point[1] + r], fill="black")
            return image_input

        image_bgr = cv2.cvtColor(np.array(image_input), cv2.COLOR_RGB2BGR)
        mask_shape = image_bgr.shape[:2]
        for x, y, w, h in target_bboxes:
            mask = np.zeros(mask_shape, dtype=np.uint8)
            # Boxes are [x, y, w, h], so the far edges are x + w / y + h.
            mask[self._clamped_slice(y, y + h), self._clamped_slice(x, x + w)] = 255
            image_bgr = cv2.inpaint(image_bgr, mask, 3, cv2.INPAINT_TELEA)
        r = args.mask_object_ratio
        x, y = point
        mask = np.zeros(mask_shape, dtype=np.uint8)
        mask[self._clamped_slice(y - r, y + r), self._clamped_slice(x - r, x + r)] = 255
        image_bgr = cv2.inpaint(image_bgr, mask, 3, cv2.INPAINT_TELEA)
        # Convert back so the returned PIL image has RGB channel order.
        return Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))

    def remove_containing_bboxes(self, bbox_list, gt, image_size):
        """Split boxes by whether they contain the ground-truth click point.

        Args:
            bbox_list: Iterable of ``[x, y, w, h]`` pixel boxes (``None`` is
                treated as empty).
            gt: Click coordinates, scaled here by ``image_size``.
                NOTE(review): this treats them as 0-1 fractions, while
                ``zoom_in`` divides the same parsed coordinates by 1000;
                one of the two scales is likely wrong - confirm against
                ``extract_coordinates``.
            image_size: ``[width, height]`` in pixels.

        Returns:
            ``(out_bbox_list, in_bbox_list, (click_x, click_y))``: boxes not
            containing the click, boxes containing it (edges inclusive), and
            the click in pixels.
        """
        click_x, click_y = gt[0] * image_size[0], gt[1] * image_size[1]
        if bbox_list is None:
            bbox_list = []
        out_bbox_list, in_bbox_list = [], []
        for bbox in bbox_list:
            x, y, w, h = bbox
            if x <= click_x <= x + w and y <= click_y <= y + h:
                in_bbox_list.append(bbox)
            else:
                out_bbox_list.append(bbox)
        return out_bbox_list, in_bbox_list, (click_x, click_y)

    def zoom_in(self, pil_image, obs):
        """Zoom into the image quadrant that contains the ground-truth click.

        Only applies when ``get_action_type`` returns 1 (presumably the click
        action type). The quadrant is cropped and resized back to the full
        image size with Lanczos resampling. The click and the optional
        ``obs['bbox']`` (``[x_min, y_min, x_max, y_max]`` on a 0-1000 scale)
        are remapped into the zoomed frame.

        Returns:
            ``(zoomed_image, {"label": [x, y], "bbox": new_bbox})`` where the
            label is normalised to 0-1 and ``new_bbox`` is on a 0-1000 scale
            (or ``None``). For other action types, ``(pil_image, None)``.

        Raises:
            ValueError: If ``obs`` has no usable ``'label'``.
        """
        try:
            content = obs['label']
        except (IndexError, KeyError, TypeError) as err:
            raise ValueError("Invalid message format in obs") from err

        ground_truth = self.res_pre_process.extract_action(content)
        task = self.res_pre_process.get_action_type(ground_truth)
        if task != 1:
            return pil_image, None

        bbox = obs.get("bbox")  # [x_min, y_min, x_max, y_max]
        w, h = pil_image.size
        click_x, click_y = self.res_pre_process.extract_coordinates(ground_truth)
        # NOTE(review): coordinates are treated as 0-1000 here but as 0-1 in
        # remove_containing_bboxes; confirm the scale extract_coordinates uses.
        click_x = click_x / 1000 * w
        click_y = click_y / 1000 * h

        # Pick the quadrant (left/right x top/bottom) containing the click.
        mid_x, mid_y = w // 2, h // 2
        x0, x1 = (0, mid_x) if click_x < mid_x else (mid_x, w)
        y0, y1 = (0, mid_y) if click_y < mid_y else (mid_y, h)
        region = (x0, y0, x1, y1)

        cropped = pil_image.crop(region)
        zoomed_image = cropped.resize((w, h), Image.LANCZOS)

        def transform_coord(x, y, region, w, h):
            """Map a pixel in the original image into the zoomed image."""
            rel_x, rel_y = x - region[0], y - region[1]
            scale_x = w / (region[2] - region[0])
            scale_y = h / (region[3] - region[1])
            return int(rel_x * scale_x), int(rel_y * scale_y)

        new_click_x, new_click_y = transform_coord(click_x, click_y, region, w, h)
        norm_click_x, norm_click_y = new_click_x / w, new_click_y / h

        new_bbox = None
        if bbox is not None:
            bbox = [bbox[0] / 1000 * w, bbox[1] / 1000 * h, bbox[2] / 1000 * w, bbox[3] / 1000 * h]
            x_min, y_min = transform_coord(bbox[0], bbox[1], region, w, h)
            x_max, y_max = transform_coord(bbox[2], bbox[3], region, w, h)
            new_bbox = [x_min / w * 1000, y_min / h * 1000, x_max / w * 1000, y_max / h * 1000]

        return zoomed_image, {"label": [norm_click_x, norm_click_y], "bbox": new_bbox}

    def structure_mask(self, image):
        """Return an all-zero (black) image with the same shape as ``image``."""
        masked_image = np.zeros_like(image)
        image_input = Image.fromarray(masked_image)
        return image_input