from typing import Dict
import re
import logging
import os
from PIL import Image
from copy import deepcopy

from transformers import Qwen2_5_VLProcessor

from .image import mask_image, compute_crop_box
from .generate import generate

def get_input(padded_size, user_query, processor: Qwen2_5_VLProcessor):
    """Build the chat-formatted grounding prompt for one screenshot.

    Args:
        padded_size: ``(width, height)`` pair quoted in the prompt as the
            screenshot resolution.
        user_query: Command whose target element the model should locate.
        processor: Processor whose ``apply_chat_template`` renders the
            message list into prompt text.

    Returns:
        Whatever ``processor.apply_chat_template`` returns when called with
        ``tokenize=False`` and ``add_generation_prompt=True``.
    """
    width, height = padded_size

    system_text = "You are a helpful assistant."
    task_text = f"The image is a screenshot of a computer or mobile phone interface, with a resolution of {width}x{height}. Please provide the coordinates of the object to be operated according to the command, which is as follows: {user_query}.\n"
    # The task is stated a second time after the image, together with the
    # required answer format.
    reminder_text = f"\nRepeat the task again for you:\nPlease provide the coordinates of the object to be operated according to the command, which is as follows: {user_query}. You must output in the following format, and the specific format is as follows: <|box_start|>(x1,y1),(x2,y2)<|box_end|>\n"

    # The image entry is only a placeholder here; no pixel data is attached.
    user_content = [
        {"type": "text", "text": task_text},
        {"type": "image", "image": None},
        {"type": "text", "text": reminder_text},
    ]
    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_content},
    ]

    return processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

def process_image(img, max_size=7494400):
    """Downscale an oversized image and pad it so both sides are multiples of 28.

    Args:
        img: Image exposing ``size`` and ``resize`` (used as a PIL image).
        max_size: Maximum pixel count (width * height) allowed before the
            image is scaled down.

    Returns:
        A tuple ``(padded_image, scale_factor, (padded_width, padded_height))``.
        ``padded_image`` is an RGB canvas with the resized image pasted at the
        top-left corner and black fill in the remaining area.
        ``scale_factor`` is 1.0 when no downscaling was applied.
    """
    src_w, src_h = img.size
    pixel_count = src_w * src_h

    if pixel_count > max_size:
        # Uniform factor that brings the area down to roughly ``max_size``.
        scale_factor = (max_size / pixel_count) ** 0.5
        target_w = max(1, int(src_w * scale_factor))
        target_h = max(1, int(src_h * scale_factor))
        logging.info(f"scale image from {src_w}×{src_h} to {target_w}×{target_h}")
    else:
        scale_factor = 1.0
        target_w, target_h = src_w, src_h

    # Ceiling-divide by 28, then multiply back: rounds up to a multiple of 28.
    padded_w = -(-target_w // 28) * 28
    padded_h = -(-target_h // 28) * 28

    resized = img.resize((target_w, target_h), Image.LANCZOS)
    canvas = Image.new('RGB', (padded_w, padded_h), (0, 0, 0))
    canvas.paste(resized, (0, 0))

    return canvas, scale_factor, (padded_w, padded_h)

def extract_output(output_text):
    """Parse a predicted bounding box from the model's decoded text.

    Expected inputs look like::

        skip_special_tokens=False: <|box_start|>(593,264),(681,354)<|box_end|><|im_end|>
        skip_special_tokens=True:  (593,264),(681,354)

    Whitespace inside or between the coordinate pairs is tolerated, so
    ``(593, 264), (681, 354)`` is parsed as well.

    Args:
        output_text: Decoded model output.

    Returns:
        ``(point_in_pixel, bbx_pred)`` where ``bbx_pred`` is the integer tuple
        ``(x1, y1, x2, y2)`` and ``point_in_pixel`` is the box centre as a pair
        of floats. Both are ``None`` unless exactly two ``(x,y)`` pairs are
        found, or when ``output_text`` is not text.
    """
    try:
        pairs = re.findall(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', output_text)
    except TypeError as e:
        # Non-text input (e.g. None) cannot be searched.
        logging.error("wrong_format in extract_output: %s", e)
        return None, None

    if len(pairs) != 2:
        # Anything other than exactly two corner points is ambiguous.
        logging.warning(
            "expected 2 coordinate pairs in model output, got %d: %r",
            len(pairs), output_text,
        )
        return None, None

    (x1, y1), (x2, y2) = ((int(x), int(y)) for x, y in pairs)
    bbx_pred = (x1, y1, x2, y2)
    point_in_pixel = ((x1 + x2) / 2, (y1 + y2) / 2)
    return point_in_pixel, bbx_pred

def compute_ground_result(img, text, model, processor):
    """Run the model on one image/prompt pair and parse the predicted box.

    Args:
        img: Image passed to the processor.
        text: Prompt text passed to the processor.
        model: Model forwarded to ``generate``.
        processor: Callable processor that also provides ``batch_decode``.

    Returns:
        A dict with ``point_pred`` and ``bbox_pred`` (as produced by
        ``extract_output``) and ``output_text`` (the decoded continuation of
        the first sequence, special tokens included).
    """
    # Encode the single sample and move the encoded inputs to 'cuda'.
    inputs = processor(
        text=[text],
        images=[img],
        max_length=40000,
        truncation=False,
        padding=True,
        return_tensors="pt",
    ).to('cuda')
    generation = generate(model=model, **inputs, max_new_tokens=50, return_scores=True)

    # Drop the first len(prompt) ids from every returned sequence so that only
    # the part after the prompt is decoded.
    full_sequences = generation["output_ids"]
    continuation_ids = []
    for prompt_ids, sequence_ids in zip(inputs.input_ids, full_sequences):
        continuation_ids.append(sequence_ids[len(prompt_ids):])

    decoded = processor.batch_decode(
        continuation_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=True,
    )
    output_text = decoded[0]
    point_pred, bbox_pred = extract_output(output_text)

    return {
        "point_pred": point_pred,
        "bbox_pred": bbox_pred,
        "output_text": output_text,
    }

class BaseAction:
    """Base class for callable pipeline actions.

    Subclasses implement ``compute``. Calling an instance runs ``compute`` and
    then verifies that the result contains every key in ``_required_keys_``.
    """

    # Keys that the dict returned by ``compute`` must contain.
    _required_keys_ = ["image"]

    def check_output(self, output_dict):
        """Raise ``ValueError`` for the first required key missing from ``output_dict``."""
        for required in self._required_keys_:
            if required in output_dict:
                continue
            raise ValueError(f"{required} not found in output_dict")

    def __call__(self, input_dict, model_dict):
        """Run ``compute``, validate its result and return it."""
        result = self.compute(input_dict, model_dict)
        self.check_output(result)
        return result

    def compute(self, input_dict, model_dict) -> Dict:
        """Produce the action's output dict; must be overridden by subclasses."""
        raise NotImplementedError

class Grounding(BaseAction):
    """Run one grounding pass and express the predicted box in absolute coordinates.

    The single input entry must provide ``image``, ``user_query`` and
    ``coord_abs``; ``model_dict`` must provide ``model`` and ``processor``.
    """
    _required_keys_ = ["image", "bbox_abs"]

    def compute(self, input_dict, model_dict):
        """Ground ``user_query`` on the input image.

        Returns:
            A deep copy of the input entry extended with the grounding result
            (``point_pred``, ``bbox_pred``, ``output_text``) plus:

            * ``bbox_pred``: predicted box divided by the rescale factor, or
              ``None`` when nothing could be parsed.
            * ``bbox_abs`` / ``point_abs``: that box (and its centre) shifted
              by the top-left corner of ``coord_abs``; ``None`` without a box.
            * ``image``: the resized and padded image given to the model.
            * ``rescale``: the scale factor returned by ``process_image``.
        """
        assert len(input_dict) == 1, f"The length of input_dict should be 1, but got {len(input_dict)}"
        entry = input_dict[0]
        image = entry["image"]
        user_query = entry["user_query"]
        coord_abs = entry["coord_abs"]
        model = model_dict["model"]
        processor = model_dict["processor"]

        # Preprocess, build the prompt and run the model.
        output_dict = deepcopy(entry)
        img, scale, padded_size = process_image(image)
        text = get_input(padded_size, user_query, processor)
        output_dict.update(compute_ground_result(img, text, model, processor))

        # Top-left corner used as the offset for absolute coordinates.
        if coord_abs is not None:
            assert len(coord_abs) == 4 # (x1, y1, x2, y2)
            offset_x, offset_y, *_ = coord_abs
        else:
            offset_x, offset_y = 0, 0

        predicted = output_dict["bbox_pred"]
        if predicted is None:
            output_dict["bbox_abs"] = None
            output_dict["point_abs"] = None
        else:
            # Undo the downscaling applied by process_image.
            raw_x1, raw_y1, raw_x2, raw_y2 = predicted
            rel_box = (raw_x1 / scale, raw_y1 / scale, raw_x2 / scale, raw_y2 / scale)
            abs_box = (
                offset_x + rel_box[0],
                offset_y + rel_box[1],
                offset_x + rel_box[2],
                offset_y + rel_box[3],
            )
            output_dict["bbox_abs"] = abs_box
            output_dict["point_abs"] = (
                (abs_box[0] + abs_box[2]) / 2,
                (abs_box[1] + abs_box[3]) / 2,
            )
            # NOTE(review): point_pred is left as parsed (not divided by
            # scale), unlike bbox_pred -- confirm that this is intended.
            output_dict["bbox_pred"] = rel_box
        output_dict["image"] = img
        output_dict["rescale"] = scale

        return output_dict

class MaskGrounding(Grounding):
    """Grounding pass that also masks the predicted region in the output image."""

    def compute(self, input_dict, model_dict):
        """Run ``Grounding.compute`` and mask the predicted box in ``image``.

        Returns:
            The parent's output dict; when a box was predicted, ``image`` is
            replaced by the result of ``mask_image``.
        """
        output_dict = super().compute(input_dict, model_dict)
        bbox_pred = output_dict["bbox_pred"]
        if bbox_pred is None:
            return output_dict

        # Grounding.compute divides bbox_pred by the rescale factor, whereas
        # output_dict["image"] is still the resized/padded image. Multiply the
        # box back so the mask lands on the region the model actually pointed
        # at; with a factor of 1.0 the box is passed through untouched.
        scale = output_dict.get("rescale") or 1.0
        if scale != 1.0:
            mask_bbox = tuple(coord * scale for coord in bbox_pred)
        else:
            mask_bbox = bbox_pred

        output_dict["image"] = mask_image(output_dict["image"], mask_bbox)
        return output_dict

class Crop(BaseAction):
    """Crop the first input's image around the boxes predicted by the other inputs."""
    _required_keys_ = ["image"]

    def __init__(self, crop_ratio: float = 0.25):
        """Store ``crop_ratio``, which is forwarded to ``compute_crop_box`` as ``ratio``."""
        super().__init__()
        self.crop_ratio = crop_ratio

    def compute(self, input_dict, model_dict):
        """Crop ``input_dict[0]["image"]`` using ``bbox_pred`` of the remaining entries.

        Returns:
            A deep copy of ``input_dict[0]`` whose ``image`` is the cropped
            image and whose ``bbox_crop`` and ``coord_abs`` both hold the crop
            box that was used.
        """
        assert len(input_dict) > 1
        output_dict = deepcopy(input_dict[0])
        base_image = output_dict["image"]

        # Collect the boxes of every entry after the first, skipping misses.
        candidates = (input_dict[i]["bbox_pred"] for i in range(1, len(input_dict)))
        ref_bbox = [box for box in candidates if box is not None]

        bbox_crop = compute_crop_box(ref_bbox, base_image.size, ratio=self.crop_ratio)

        # A degenerate box cannot be cropped; fall back to the whole image.
        left, top, right, bottom = bbox_crop
        if left >= right or top >= bottom:
            logging.warning(f"Invalid crop box: {bbox_crop}, using full image")
            bbox_crop = (0, 0, base_image.width, base_image.height)

        output_dict.update(
            image=base_image.crop(bbox_crop),
            bbox_crop=bbox_crop,
            coord_abs=bbox_crop,
        )
        return output_dict


class DrawDualBoxesSeparate(BaseAction):
    """Draw two bounding boxes on separate crops of the same base image.

    Each valid box is drawn on its own copy of the base image, which is then
    cropped to a window centred on the box. The window reaches 20% of the base
    image's width/height to each side of the box centre, clipped to the image.
    """
    _required_keys_ = ["image"]

    def compute(self, input_dict, model_dict):
        """Render both candidate boxes.

        Args:
            input_dict: Exactly two entries; ``bbox_abs`` is read from each,
                ``image`` and ``user_query`` from the first.
            model_dict: May carry ``original_image``, which is preferred over
                the first entry's ``image`` as the drawing surface.

        Returns:
            A deep copy of the first entry extended with ``image1``/``image2``
            (the rendered crops), ``bbox1``/``bbox2`` (validated boxes or
            ``None``) and ``user_query``; ``image`` is set to ``image1``.
        """
        assert len(input_dict) == 2, "DrawDualBoxesSeparate requires two inputs"

        bbox1 = input_dict[0].get("bbox_abs")
        bbox2 = input_dict[1].get("bbox_abs")
        # Get original image from pipeline initial state
        base_image = model_dict.get("original_image", input_dict[0]["image"])
        user_query = input_dict[0]["user_query"]

        logging.info("DrawDualBoxesSeparate - Processing two bounding boxes")
        logging.info(f"Box 1 (First grounding): {bbox1}")
        logging.info(f"Box 2 (After mask regrounding): {bbox2}")

        bbox1 = self._validate_bbox(bbox1, "Box 1")
        bbox2 = self._validate_bbox(bbox2, "Box 2")

        # Half-size of the crop window: 20% of the base image per side.
        img_width, img_height = base_image.size
        expand_width = int(img_width * 0.2)
        expand_height = int(img_height * 0.2)
        logging.info(f"Image expansion size: {expand_width}x{expand_height} (20% of original)")

        img1 = self._process_single_bbox(base_image, bbox1, "1", (0, 255, 0), expand_width, expand_height)
        img2 = self._process_single_bbox(base_image, bbox2, "2", (255, 0, 0), expand_width, expand_height)

        # Return single output dict containing two images
        output_dict = deepcopy(input_dict[0])
        output_dict["image1"] = img1
        output_dict["image2"] = img2
        output_dict["bbox1"] = bbox1
        output_dict["bbox2"] = bbox2
        output_dict["user_query"] = user_query
        # Keep image field to satisfy BaseAction requirements
        output_dict["image"] = img1  # Default to first image

        return output_dict

    @staticmethod
    def _validate_bbox(bbox, name):
        """Return ``bbox`` if it is a usable ``(x1, y1, x2, y2)`` box, else ``None``."""
        if bbox is None:
            return None
        if len(bbox) != 4:
            # A box of the wrong length would fail later while drawing.
            logging.warning(f"{name} malformed (expected 4 values): {bbox}")
            return None
        if bbox[0] >= bbox[2] or bbox[1] >= bbox[3]:
            logging.warning(f"{name} coordinates invalid: {bbox}")
            return None
        return bbox

    def _process_single_bbox(self, base_image, bbox, label, color, expand_width, expand_height):
        """Draw ``bbox`` with ``label`` on a copy of ``base_image`` and crop around it.

        Returns a plain copy of ``base_image`` when ``bbox`` is ``None``.
        """
        from PIL import ImageDraw, ImageFont, Image

        if bbox is None:
            # If no bbox, return original image
            return base_image.copy()

        # Draw on a transparent layer, then composite it over the image.
        overlay = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
        draw_overlay = ImageDraw.Draw(overlay)

        # Font loading is best effort: fall back to the built-in font, but do
        # not swallow KeyboardInterrupt/SystemExit as a bare except would.
        try:
            font = ImageFont.truetype("arial.ttf", 24)
        except Exception:
            font = ImageFont.load_default()

        rgba = color + (255,)
        draw_overlay.rectangle(bbox, outline=rgba, width=3)
        draw_overlay.text((bbox[0], bbox[1] - 30), label, fill=rgba, font=font)

        img = Image.alpha_composite(base_image.copy().convert('RGBA'), overlay).convert('RGB')

        # Crop window centred on the box, clipped to the image bounds.
        center_x = (bbox[0] + bbox[2]) / 2
        center_y = (bbox[1] + bbox[3]) / 2
        crop_left = max(0, int(center_x - expand_width))
        crop_top = max(0, int(center_y - expand_height))
        crop_right = min(img.width, int(center_x + expand_width))
        crop_bottom = min(img.height, int(center_y + expand_height))

        # If clipping left an empty window, widen it to up to 100 pixels on
        # the side that has room.
        if crop_right <= crop_left:
            if center_x < img.width / 2:
                crop_right = min(img.width, crop_left + 100)
            else:
                crop_left = max(0, crop_right - 100)

        if crop_bottom <= crop_top:
            if center_y < img.height / 2:
                crop_bottom = min(img.height, crop_top + 100)
            else:
                crop_top = max(0, crop_bottom - 100)

        return img.crop((crop_left, crop_top, crop_right, crop_bottom))




class GPTJudgeTwoImages(BaseAction):
    """Use GPT/OpenRouter to judge which of two annotated images is the better target."""
    _required_keys_ = ["image", "image1", "image2"]

    def __init__(self, api_key=None, base_url=None, model="openai/gpt-4o", site_url=None, site_title=None):
        """Resolve endpoint, credentials and model name.

        Explicit arguments win over environment variables. The base URL falls
        back to ``OPENROUTER_BASE_URL``, ``OPENAI_BASE_URL`` and finally
        ``https://openrouter.ai/api/v1``. For an OpenRouter URL the key falls
        back to ``OPENROUTER_API_KEY`` then ``OPENAI_API_KEY``, and a model
        name without a ``/`` gets the ``openai/`` prefix.
        """
        super().__init__()
        # Base URL: OpenRouter uses https://openrouter.ai/api/v1; OpenAI uses https://api.openai.com/v1
        self.base_url = base_url or os.environ.get("OPENROUTER_BASE_URL") or os.environ.get("OPENAI_BASE_URL") or "https://openrouter.ai/api/v1"
        # API Key: OpenRouter prioritizes OPENROUTER_API_KEY
        if "openrouter.ai" in self.base_url:
            self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY") or os.environ.get("OPENAI_API_KEY")
            # OpenRouter model names usually follow provider/model format, e.g. openai/gpt-4o
            if model and "/" not in model:
                model = f"openai/{model}"
        else:
            self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        self.model = model
        self.site_url = site_url or os.environ.get("OPENROUTER_SITE_URL")
        self.site_title = site_title or os.environ.get("OPENROUTER_SITE_TITLE") or "GUI-Agent"

    def compute(self, input_dict, model_dict):
        """Ask the judge to pick between ``image1`` and ``image2``.

        Args:
            input_dict: Exactly one entry providing ``image1``, ``image2``,
                ``user_query`` and optionally ``bbox1``/``bbox2``.
            model_dict: Unused.

        Returns:
            A deep copy of the entry extended with ``selected_image``,
            ``judge_reason``, ``judge_response``, and the final ``bbox_abs``,
            ``point_abs`` and ``image``. ``selected_image`` always reports the
            judge's answer; if the chosen box is missing, the final box falls
            back to box 1 and then to box 2.
        """
        assert len(input_dict) == 1
        input_dict = input_dict[0]

        image1 = input_dict["image1"]
        image2 = input_dict["image2"]
        bbox1 = input_dict.get("bbox1")
        bbox2 = input_dict.get("bbox2")
        user_query = input_dict["user_query"]

        logging.info("GPTJudgeTwoImages - Starting to judge two images")
        logging.info(f"User query: {user_query}")
        logging.info(f"Box 1: {bbox1}")
        logging.info(f"Box 2: {bbox2}")

        selected_image, reason, response_text = self.judge_two_images(
            image1, image2, user_query
        )

        logging.info(f"GPT judgment result: Selected image {selected_image}")
        logging.info(f"GPT judgment reason:\n{reason}")

        output_dict = deepcopy(input_dict)
        output_dict["selected_image"] = selected_image
        output_dict["judge_reason"] = reason
        output_dict["judge_response"] = response_text

        # Honour the judge's choice when that box exists; otherwise keep
        # whichever box is available instead of discarding a valid prediction.
        if selected_image == "1" and bbox1:
            final_bbox, final_image = bbox1, image1
        elif selected_image == "2" and bbox2:
            final_bbox, final_image = bbox2, image2
        elif bbox1:
            final_bbox, final_image = bbox1, image1
        elif bbox2:
            final_bbox, final_image = bbox2, image2
        else:
            final_bbox, final_image = bbox1, image1

        output_dict["bbox_abs"] = final_bbox
        if final_bbox:
            output_dict["point_abs"] = (
                (final_bbox[0] + final_bbox[2]) / 2,
                (final_bbox[1] + final_bbox[3]) / 2,
            )
        else:
            output_dict["point_abs"] = None
        output_dict["image"] = final_image

        return output_dict

    @staticmethod
    def _encode_png(image):
        """Return ``image`` serialised as PNG and base64-encoded into a str."""
        import base64
        from io import BytesIO

        buffered = BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    @staticmethod
    def _parse_response(response_text):
        """Extract ``(selected_image, reason)`` from the judge's tagged reply.

        Defaults to image ``"1"`` when no ``<answer>`` of 1 or 2 is present.
        """
        analysis_match = re.search(r'<analysis>(.*?)</analysis>', response_text, re.DOTALL)
        # Accept only the two valid answers, tolerating surrounding whitespace.
        answer_match = re.search(r'<answer>\s*([12])\s*</answer>', response_text)
        reason_match = re.search(r'<reason>(.*?)</reason>', response_text, re.DOTALL)

        selected_image = answer_match.group(1) if answer_match else "1"
        reason = reason_match.group(1).strip() if reason_match else "No reason provided"

        # If there's analysis content, include it in the reason
        if analysis_match:
            analysis = analysis_match.group(1).strip()
            logging.info(f"GPT detailed analysis:\n{analysis}")
            reason = f"{analysis}\n\nFinal selection: {reason}"

        return selected_image, reason

    def judge_two_images(self, image1, image2, user_query):
        """Use GPT to judge which image better meets user requirements.

        Returns:
            ``(selected_image, reason, response_text)``. If the request or the
            parsing fails, the error is logged and ``("1", "API call failed:
            ...", str(error))`` is returned.

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        try:
            from openai import OpenAI
        except ImportError:
            logging.error("Please install openai library: pip install openai")
            raise

        client = OpenAI(api_key=self.api_key, base_url=self.base_url)

        prompt = f"""You are comparing two images to determine which one better fulfills the user's intent.

User Command: "{user_query}"

Image 1: Shows a GUI element marked with a green box labeled "1"
Image 2: Shows a GUI element marked with a red box labeled "2"

Your task: Determine which image shows the element that will best fulfill the user's command.

ANALYSIS APPROACH:
1. Examine what GUI element is highlighted in each image
2. Consider which element better matches the user's intent
3. Think about standard GUI patterns and user expectations
4. Choose the image that shows the more appropriate interaction target

KEY PRINCIPLES:
- Focus on the functional purpose of the highlighted elements
- Consider standard UI patterns (buttons for actions, text fields for input, etc.)
- Choose interactive elements over static text/labels
- If one shows a selected state and the other shows normal state, prefer the normal state
- ELEMENT QUALITY HIERARCHY (best to worst):
   - Icon + Text together (most informative and complete)
   - Complete icon alone (clear visual indicator)  
   - Complete text alone (readable label)
   - Multiple elements in one box OR incomplete elements (ambiguous target)

COMMON PITFALLS TO AVOID:
    - Don't choose based on keyword matching alone
    - Don't overlook the user's actual goal in favor of literal interpretation

Remember: Provide SPECIFIC analysis based on what you actually observe, not generic descriptions.

**OUTPUT FORMAT**:
<analysis>
Image 1: [Describe what element is highlighted and its purpose]
Image 2: [Describe what element is highlighted and its purpose]
Comparison: [Explain which better serves the user's intent and why]
</analysis>

<answer>1 or 2</answer>
<reason>Brief explanation of why this image shows the better choice</reason>"""

        try:
            img1_base64 = self._encode_png(image1)
            img2_base64 = self._encode_png(image2)

            # OpenRouter-specific attribution headers.
            extra_headers = None
            if "openrouter.ai" in self.base_url:
                extra_headers = {
                    "HTTP-Referer": self.site_url or "https://localhost",
                    "X-Title": self.site_title,
                }

            response = client.chat.completions.create(
                model=self.model,
                extra_headers=extra_headers,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{img1_base64}",
                                    "detail": "high"
                                }
                            },
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/png;base64,{img2_base64}",
                                    "detail": "high"
                                }
                            }
                        ]
                    }
                ],
                temperature=0,
                max_tokens=9600
            )

            # The message content can be None; treat that as an empty reply so
            # it is parsed (defaulting to image 1) rather than reported as an
            # API failure.
            response_text = response.choices[0].message.content or ""
            selected_image, reason = self._parse_response(response_text)

        except Exception as e:
            # Best effort: any request/parsing failure falls back to image 1.
            logging.error(f"GPT API call failed: {e}")
            selected_image = "1"
            reason = f"API call failed: {str(e)}"
            response_text = str(e)

        return selected_image, reason, response_text

