import os
import re
import tempfile
import time
from typing import Optional, Tuple, List, Union

import torch
from PIL import Image
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
import logging  

# Module-scoped logger; LenovoAction7BModel._debug sends its messages here at INFO level.
logger = logging.getLogger(name=__name__)


def image_to_temp_filename(image: Image.Image) -> str:
    """
    Save ``image`` as a PNG in a new temporary file and return the file's path.

    The file is created with ``delete=False`` and is NOT removed automatically:
    the caller owns it and should ``os.remove`` it when finished.

    The temp-file handle is closed before returning. If saving fails, the partial
    file is removed and the exception is re-raised.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        with tmp:
            # Write through the open handle instead of re-opening tmp.name: re-opening
            # a still-open NamedTemporaryFile by name fails on Windows. A file object
            # needs an explicit format for Pillow.
            image.save(tmp, format="PNG")
    except BaseException:
        os.remove(tmp.name)
        raise
    return tmp.name


def process_image_keep_ratio_and_pad(
    image_path: str,
    max_pixels: int = 7494400,
    pad_multiple: int = 28,
) -> Tuple[Image.Image, float, Tuple[int, int]]:
    """
    Downscale an image (keeping aspect ratio) to at most ``max_pixels`` pixels, then
    zero-pad its right/bottom edges up to multiples of ``pad_multiple``
    (consistent with lenovo_grounding_eval_ss_pro.py).

    Images already within ``max_pixels`` are never upscaled. Because the content is
    pasted at (0, 0), pixel coordinates in the padded image equal those in the
    resized image.

    Args:
        image_path: Path of the image file to load.
        max_pixels: Upper bound on ``width * height`` before padding.
        pad_multiple: Each padded side is rounded up to a multiple of this (> 0).

    Returns:
        (padded_image, scale_factor, (padded_w, padded_h)). ``padded_image`` is RGB;
        ``scale_factor`` maps original -> resized coordinates (1.0 if not resized).

    Raises:
        ValueError: If ``pad_multiple`` is not positive.
    """
    if pad_multiple <= 0:
        raise ValueError(f"pad_multiple must be positive, got {pad_multiple}")

    with Image.open(image_path) as img:
        # Convert before resizing: Pillow silently falls back to NEAREST for "P"/"1"
        # images regardless of the requested filter, and the canvas is RGB anyway.
        src = img.convert("RGB")

    w, h = src.size
    scale = 1.0
    new_w, new_h = w, h
    if w * h > max_pixels:
        scale = (max_pixels / (w * h)) ** 0.5
        new_w = max(1, int(w * scale))
        new_h = max(1, int(h * scale))
        src = src.resize((new_w, new_h), Image.LANCZOS)

    # Ceil each side to the next multiple (unchanged when already aligned).
    padded_w = -(-new_w // pad_multiple) * pad_multiple
    padded_h = -(-new_h // pad_multiple) * pad_multiple

    canvas = Image.new("RGB", (padded_w, padded_h), (0, 0, 0))
    canvas.paste(src, (0, 0))
    return canvas, scale, (padded_w, padded_h)


def extract_bbox_from_text(text: str) -> Optional[List[int]]:
    """
    Extract an integer pixel bbox from model output.

    Looks for ``<|box_start|>(x1,y1),(x2,y2)<|box_end|>``. If the tags are absent,
    the whole text is scanned for generic ``(x,y)`` pairs instead. The first two
    pairs found form the box. Only integer coordinates are recognized.

    The corners are reordered so that ``x1 <= x2`` and ``y1 <= y2``. The centre is
    unaffected, and the result always satisfies the "x1y1x2y2" format it is
    reported in.

    Returns:
        ``[x1, y1, x2, y2]``, or ``None`` when fewer than two pairs are found.
    """
    tag_match = re.search(r"<\|box_start\|>(.*?)<\|box_end\|>", text, flags=re.S)
    source = tag_match.group(1) if tag_match else text
    pairs = re.findall(r"\((-?\d+)\s*,\s*(-?\d+)\)", source)
    if len(pairs) < 2:
        return None
    # The pattern only captures optionally-signed digit runs, so int() cannot fail.
    (ax, ay), (bx, by) = ((int(x), int(y)) for x, y in pairs[:2])
    return [min(ax, bx), min(ay, by), max(ax, bx), max(ay, by)]


def center_from_bbox(bbox_xyxy: List[int]) -> Tuple[float, float]:
    """Return the midpoint ``(cx, cy)`` of an ``[x1, y1, x2, y2]`` box as floats."""
    left, top, right, bottom = bbox_xyxy
    cx = (left + right) / 2.0
    cy = (top + bottom) / 2.0
    return cx, cy


class LenovoAction7BModel:
    """
    Qwen2.5-VL based GUI-grounding model.

    Same interface as OSAtlas7BModel: load_model / set_generation_config /
    ground_only_positive / ground_allow_negative. Internally uses
    Qwen2_5_VLForConditionalGeneration + Qwen2_5_VLProcessor.

    The model is prompted to answer with ``<|box_start|>(x1,y1),(x2,y2)<|box_end|>``
    in the pixel space of the resized-and-padded image it is shown. Results are
    mapped back to the original image and normalized to 0~1.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.device = "cuda"
        self.generation_kwargs = {
            "max_new_tokens": 50,
            "do_sample": False,
            "temperature": 0.0,
        }
        # Echo debug messages to stdout (they are always logged as well).
        # Controlled by TIANXI7B_VERBOSE; on by default.
        self.verbose = self._parse_flag(os.getenv("TIANXI7B_VERBOSE", "1"))

    @staticmethod
    def _parse_flag(value: str) -> bool:
        """Parse an env-var flag: integers by truthiness; otherwise anything except ""/false/no/off is True."""
        value = value.strip().lower()
        try:
            return bool(int(value))
        except ValueError:
            return value not in ("", "false", "no", "off")

    def load_model(self, model_name_or_path: str, device: str = "cuda"):
        """
        Load the weights (bfloat16, FlashAttention-2) and the matching processor.

        ``device="auto"`` delegates placement to ``device_map="auto"``. Any other value
        loads the model normally and then moves it to ``device``.
        """
        self.device = device
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto" if device == "auto" else None,
        )
        if device != "auto":
            self.model.to(device)
        self.model.eval()
        self.processor = Qwen2_5_VLProcessor.from_pretrained(model_name_or_path)

    def set_generation_config(self, **kwargs):
        """Merge ``kwargs`` into the keyword arguments passed to ``model.generate``."""
        self.generation_kwargs.update(kwargs)

    def set_verbose(self, verbose: bool = True):
        """Enable/disable echoing debug messages to stdout."""
        self.verbose = verbose

    def _debug(self, msg: str):
        """Print ``msg`` when verbose, and always log it at INFO level."""
        if self.verbose:
            print(msg)
        logger.info(msg)

    def _build_messages(self, padded_size: Tuple[int, int], image_path: str, instruction: str):
        """
        Build the chat messages for one grounding query.

        The instruction is stated both before and after the image, and the second
        copy spells out the required ``<|box_start|>...<|box_end|>`` answer format.
        ``padded_size`` is reported to the model as the screenshot resolution.
        """
        x, y = padded_size
        qwen_prompts = "You are a helpful assistant."
        user_prompt = f"The image is a screenshot of a computer or mobile phone interface, with a resolution of {x}x{y}. Please provide the coordinates of the object to be operated according to the command, which is as follows: {instruction}.\n"
        user_prompt_repeat = (
            "\nRepeat the task again for you:\n"
            f"Please provide the coordinates of the object to be operated according to the command, which is as follows: {instruction}. "
            "You must output in the following format, and the specific format is as follows: <|box_start|>(x1,y1),(x2,y2)<|box_end|>\n"
        )
        return [
            {"role": "system", "content": qwen_prompts},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f"{user_prompt}"},
                    {"type": "image", "image": image_path},
                    {"type": "text", "text": f"{user_prompt_repeat}"},
                ],
            },
        ]

    def _forward(self, prompt_text: str, pil_image: Image.Image) -> str:
        """
        Run one generation and return only the newly generated text.

        Special tokens are not skipped when decoding, so the ``<|box_start|>`` /
        ``<|box_end|>`` markers, if emitted, stay in the text for the box parser.
        """
        inputs = self.processor(
            text=[prompt_text],
            images=[pil_image],
            max_length=40000,
            truncation=False,
            padding=True,
            return_tensors="pt",
        )
        # With device_map="auto" inputs are sent to "cuda"; presumably the first
        # shard lives there -- TODO confirm for multi-GPU placements.
        inputs = inputs.to(self.device if self.device != "auto" else "cuda")
        output_ids = self.model.generate(**inputs, **self.generation_kwargs)
        # The returned sequences start with the prompt tokens; drop them.
        generated = [out[len(prompt_ids):] for prompt_ids, out in zip(inputs.input_ids, output_ids)]
        return self.processor.batch_decode(
            generated, skip_special_tokens=False, clean_up_tokenization_spaces=True
        )[0]

    def _normalize_point_and_bbox(self, bbox_xyxy: Optional[List[int]], center_xy: Optional[Tuple[float, float]],
                                  orig_w: int, orig_h: int, scale: float) -> Tuple[Optional[List[float]], Optional[List[float]]]:
        """
        Map model-space coordinates back to the original image and normalize them.

        Model coordinates are in the resized+padded image. Padding is only added on
        the right/bottom (content pasted at (0, 0)), so dividing by ``scale``
        recovers original pixels. Those are then divided by the original (pre-resize)
        ``orig_w`` / ``orig_h``.

        Returns:
            ``(point_norm, bbox_norm)``; each is ``None`` when its input is ``None``.
        """
        bbox_norm = None
        if bbox_xyxy is not None:
            x1, y1, x2, y2 = (v / scale for v in bbox_xyxy)
            bbox_norm = [x1 / orig_w, y1 / orig_h, x2 / orig_w, y2 / orig_h]

        point_norm = None
        if center_xy is not None:
            cx, cy = (v / scale for v in center_xy)
            point_norm = [cx / orig_w, cy / orig_h]
        return point_norm, bbox_norm

    def _ground(self, instruction: str, image: Union[str, Image.Image]
                ) -> Tuple[str, Optional[List[float]], Optional[List[float]], int, int]:
        """
        Shared grounding pipeline: resolve the image, query the model, then parse and normalize.

        A PIL ``image`` is written to a temporary PNG, which is always deleted
        afterwards.

        Returns:
            ``(output_text, point_norm, bbox_norm, orig_w, orig_h)``.
        """
        tmp_path = None
        if isinstance(image, str):
            image_path = image
        else:
            assert isinstance(image, Image.Image)
            tmp_path = image_to_temp_filename(image)
            image_path = tmp_path
        try:
            assert os.path.exists(image_path) and os.path.isfile(image_path), "Invalid input image path."

            with Image.open(image_path) as im:
                orig_w, orig_h = im.size

            # The model sees the padded image; the prompt reports the padded resolution.
            img_rgb, scale, padded_size = process_image_keep_ratio_and_pad(image_path)
            messages = self._build_messages(padded_size, image_path, instruction)
            prompt_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            output_text = self._forward(prompt_text, img_rgb)
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Best-effort cleanup; don't mask the real result or error.
                    logger.warning("Could not remove temporary image %s", tmp_path, exc_info=True)

        bbox_xyxy = extract_bbox_from_text(output_text)
        center_xy = center_from_bbox(bbox_xyxy) if bbox_xyxy else None
        point_norm, bbox_norm = self._normalize_point_and_bbox(bbox_xyxy, center_xy, orig_w, orig_h, scale)
        return output_text, point_norm, bbox_norm, orig_w, orig_h

    def _log_result(self, tag: str, output_text: str, point_norm: Optional[List[float]],
                    bbox_norm: Optional[List[float]], orig_w: int, orig_h: int,
                    status: Optional[str] = None):
        """Emit the debug trace for one grounding call, including pixel-space point/bbox."""
        self._debug("[tianxi7b] {} raw_response: {}".format(tag, output_text))
        if status is not None:
            self._debug("[tianxi7b] {} status: {}".format(tag, status))
        self._debug("[tianxi7b] {} point_norm: {}, bbox_norm: {}".format(tag, point_norm, bbox_norm))
        if point_norm:
            px_point = [round(point_norm[0] * orig_w, 2), round(point_norm[1] * orig_h, 2)]
            self._debug("[tianxi7b] {} point_px: {}".format(tag, px_point))
        if bbox_norm:
            px_bbox = [round(bbox_norm[0] * orig_w, 1), round(bbox_norm[1] * orig_h, 1),
                       round(bbox_norm[2] * orig_w, 1), round(bbox_norm[3] * orig_h, 1)]
            self._debug("[tianxi7b] {} bbox_px: {}".format(tag, px_bbox))

    def ground_only_positive(self, instruction: str, image: Union[str, Image.Image]):
        """
        Ground ``instruction`` in ``image`` assuming the target exists.

        Return:
        {
            "result": "positive",
            "format": "x1y1x2y2",
            "raw_response": str,
            "bbox": [x1,y1,x2,y2] (0~1) or None,
            "point": [x,y] (0~1) or None
        }
        """
        output_text, point_norm, bbox_norm, orig_w, orig_h = self._ground(instruction, image)
        self._log_result("ground_only_positive", output_text, point_norm, bbox_norm, orig_w, orig_h)
        return {
            "result": "positive",
            "format": "x1y1x2y2",
            "raw_response": output_text,
            "bbox": bbox_norm,
            "point": point_norm,
        }

    def ground_allow_negative(self, instruction: str, image: Union[str, Image.Image]):
        """
        Ground ``instruction`` in ``image``, allowing the model to report no target.

        ``result`` is "positive" when a box was parsed. It is "negative" when the
        response contains "Target does not exist" (case-insensitive). Otherwise it is
        "wrong_format". The remaining keys match ``ground_only_positive``.
        """
        output_text, point_norm, bbox_norm, orig_w, orig_h = self._ground(instruction, image)

        if bbox_norm or point_norm:
            status = "positive"
        elif "Target does not exist".lower() in output_text.lower():
            status = "negative"
        else:
            status = "wrong_format"

        self._log_result("ground_allow_negative", output_text, point_norm, bbox_norm,
                         orig_w, orig_h, status=status)
        return {
            "result": status,
            "format": "x1y1x2y2",
            "raw_response": output_text,
            "bbox": bbox_norm,
            "point": point_norm,
        }