
import os
import sys
import json
import argparse
import traceback
from typing import Dict, Any, List, Optional
from tqdm import tqdm
from PIL import Image
import uuid

# Put the sibling "utils" directory (<parent of this file's directory>/utils) at the
# front of sys.path so modules stored there can be imported by bare name.
# NOTE(review): presumably this is where `qwen3_mobile_use` lives — confirm.
current_dir = os.path.dirname(os.path.abspath(__file__))
preprocess_dir = os.path.dirname(current_dir)
utils_dir = os.path.join(preprocess_dir, "utils")
sys.path.insert(0, utils_dir)

# vLLM and transformers are hard requirements: if either import fails the script
# prints an error and exits with status 1 at import time.
try:
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor
    VLLM_AVAILABLE = True
except ImportError:
    # The flag is set for completeness only; the process exits right after.
    # Note that a missing `transformers` package also lands here, although the
    # message only mentions vLLM.
    VLLM_AVAILABLE = False
    print("Error: Cannot import vLLM, please install vLLM: pip install vllm")
    sys.exit(1)

from qwen3_mobile_use import MOBILE_USE_TOOL_SCHEMA


def build_annotation_system_prompt() -> str:
    """Build the system prompt used to annotate a step with a known required tool_call.

    The prompt embeds ``MOBILE_USE_TOOL_SCHEMA`` (serialized as JSON) inside
    ``<tools>`` tags and instructs the model to answer with a one-sentence
    ``Thought:``, a one-sentence ``Action:`` and a single ``<tool_call>`` block
    that repeats, unmodified, the tool_call supplied in the user query.

    Returns:
        The complete system prompt text.
    """
    tool_json = json.dumps(MOBILE_USE_TOOL_SCHEMA, ensure_ascii=False)
    # The literal below is an f-string: doubled braces ({{ }}) are emitted as
    # single braces, and {tool_json} is the only interpolated value.
    return f"""# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_json}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>


Response format for every step:
1) Thought: one concise sentence explaining the next move based on the image content and instruction (no multi-step reasoning).
   - The Thought MUST be faithful to the image content: accurately describe what you see in the image.
   - The Thought MUST be faithful to the instruction: correctly interpret the task requirement.
2) Action: a short imperative describing what to do in the UI.
   - The Action MUST be faithful to the Thought: the action description must match what the Thought describes.
3) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.
   - The tool_call arguments MUST be faithful to the Action: the actual action parameters must match the Action description.
   - IMPORTANT: Coordinates in tool_call must be in normalized format (0-1000 range), NOT pixel coordinates. Convert pixel coordinates to 0-1000 range: normalized_x = pixel_x / image_width * 1000, normalized_y = pixel_y / image_height * 1000.


CRITICAL: You must ensure faithfulness at all levels:
1. Thought faithfulness: Your Thought must accurately reflect what you can see in the image and what the instruction says. Base your reasoning on actual observation, not on patterns or shortcuts.
2. Action faithfulness: Your Action description must be consistent with your Thought. The action you describe should be what your Thought suggests based on your observations.
3. Tool call faithfulness: The tool_call MUST be EXACTLY the same as the required tool_call provided in the user query. You MUST NOT modify any parameters, coordinates, or values. Use the exact tool_call JSON provided.
4. Ground Truth alignment: The user query will provide a REQUIRED tool_call that you must use exactly. Your output must use this exact tool_call without any modifications.


The user query will provide a REQUIRED tool_call that you must use exactly. Your task is to:
1. Observe the image carefully and describe what you see
2. Read the instruction carefully and interpret what it says
3. Understand what the required tool_call is asking you to do
4. Generate a Thought that explains WHY this required action makes sense based on your observation of the image and the instruction
5. Generate an Action that describes WHAT you will do (must match the required tool_call)
6. Use the EXACT tool_call provided (do not modify it)

Your Thought should explain WHY the required action is appropriate based on what you observe, not just state what the action is. Vary your phrasing naturally - use different sentence structures, vocabulary, and expressions to avoid repetitive patterns. For example:
- "I can see a search box in the image. Based on the instruction to search for 'chaos', I should click on the search box to activate it."
- "The screen displays a search field at the top. Since the task requires searching for 'chaos', I'll tap on it."
- "Looking at the interface, there's a search bar visible. The instruction asks me to search, so I need to click this element."
- "I notice a search input field. To complete the search task for 'chaos', I should activate it by clicking."

Similarly, vary your Action descriptions:
- "Click on the search box."
- "Tap the search field."
- "Select the search input."
- "Press the search bar."

IMPORTANT: You must use the exact tool_call provided in the user query. Do not generate your own tool_call or modify the provided one.

Rules:
- Output exactly in the order: Thought, Action, <tool_call>.
- You MUST start with "Thought:" label, then "Action:" label, then "<tool_call>" tag.
- Format: 
  Thought: [your one-sentence thought - vary your phrasing naturally]
  Action: [your one-sentence action - use different verbs and expressions]
  <tool_call>
  {{"name": "mobile_use", "arguments": {{...}}}}
  </tool_call>
- Be brief: one sentence for Thought, one for Action.
- Do not output anything else outside those three parts.
- Do NOT skip the "Thought:" or "Action:" labels.
- Vary your language: use different sentence structures, synonyms, and natural expressions to avoid repetitive patterns.
- If finishing, use action=terminate in the tool call.
- Always ensure faithfulness: Thought → Action → tool_call must be consistent.
- Always base your reasoning on actual observation of the image and instruction."""


def build_trap_annotation_system_prompt() -> str:
    """Build the system prompt used to annotate "trap" samples.

    The prompt embeds ``MOBILE_USE_TOOL_SCHEMA`` (serialized as JSON) inside
    ``<tools>`` tags and describes two trap types: image modification
    (mask/inpaint, judged by comparing an original and a modified screenshot)
    and instruction modification (the instruction does not match the screen).
    It asks for a one-sentence ``Thought:`` naming the modification or
    mismatch, a one-sentence ``Action:``, and a ``<tool_call>`` with a recovery
    action (Back in most cases, otherwise Home or terminate with failure).

    Returns:
        The complete system prompt text.
    """
    tool_json = json.dumps(MOBILE_USE_TOOL_SCHEMA, ensure_ascii=False)
    # The literal below is an f-string: doubled braces ({{ }}) are emitted as
    # single braces, and {tool_json} is the only interpolated value.
    return f"""# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_json}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>


Response format for every step:
1) Thought: one concise sentence explaining the next move based on the image content and instruction (no multi-step reasoning).
   - The Thought MUST be faithful to the image content: accurately describe what you see in the image.
   - The Thought MUST be faithful to the instruction: correctly interpret the task requirement.
2) Action: a short imperative describing what to do in the UI.
   - The Action MUST be faithful to the Thought: the action description must match what the Thought describes.
3) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.
   - The tool_call arguments MUST be faithful to the Action: the actual action parameters must match the Action description.
   - IMPORTANT: Coordinates in tool_call must be in normalized format (0-1000 range), NOT pixel coordinates.


CRITICAL: You must ensure faithfulness at all levels:
1. Thought faithfulness: Your Thought must accurately reflect what you can see in the image and what the instruction says.
2. Action faithfulness: Your Action description must be consistent with your Thought.
3. Tool call faithfulness: The tool_call MUST match what you describe in Action.


You will encounter two types of trap scenarios:

**Type 1: Image Modification (mask/inpaint)**
- The original screenshot (before modification) showing what the screen looked like
- The current screenshot (after modification) which has been altered - UI elements have been removed, masked, or obscured

**Type 2: Instruction Modification (instruction_modify)**
- The instruction you receive has been modified and does NOT match the current screen or task history
- The current screen may show a different app, different page, or different state than what the instruction asks for

**CRITICAL for Type 1 (mask/inpaint)**: Compare the two images carefully. The current modified screen has been altered using mask or inpaint techniques - you should be able to see masked regions, inpainted areas, or missing UI elements.

**CRITICAL for Type 2 (instruction_modify)**: Compare the instruction with what you see on the current screen. The instruction has been modified and does NOT align with the current screen state, app, or available features.

**IMPORTANT for Type 1 (mask/inpaint)**: In your Thought, you MUST:
1. Identify that the image has been modified (mention "mask", "inpaint", "masked region", "inpainted area", "altered section", etc.)
2. Describe the SPECIFIC LOCATION of the mask/inpaint using spatial terms such as:
   - Horizontal: "left side", "right side", "center", "middle"
   - Vertical: "top", "upper", "bottom", "lower"
   - Combined: "top-left corner", "top-right area", "bottom-left section", "bottom-right region", "center-left", "center-right", "upper-middle", "lower-middle", etc.
   - Relative: "near the top", "towards the bottom", "in the middle area", "on the left edge", etc.
3. Explain that the target element in that specific location is missing or obscured

**IMPORTANT for Type 2 (instruction_modify)**: In your Thought, you MUST:
1. Identify that the instruction has been modified (mention "instruction mismatch", "instruction doesn't match", "modified instruction", "instruction conflict", etc.)
2. Describe the SPECIFIC mismatch between instruction and screen:
   - What the instruction asks for vs. what you actually see
   - Which app the instruction mentions vs. which app you're in
   - Which feature the instruction refers to vs. what's available on screen
   - Any UI elements the instruction mentions vs. what's visible
3. Explain that you need to navigate back because the instruction doesn't align with the current screen

**DO NOT** attempt to click on any coordinates in the modified screen - the target elements are gone.

Your task: Recognize the mask/inpaint modification, describe its location, then choose the appropriate recovery action.


**MANDATORY**: Since the image has been modified and target UI elements are missing, you MUST use a recovery action. In most cases, you MUST press Back.

1. **Press Back (REQUIRED in most cases)**: Use `system_button` with `button: "Back"` when:
   - The target element is missing or obscured in the modified image
   - You need to return to a previous screen where the element might be accessible
   - The current screen shows unexpected content due to modification
   - You're still in the same app but the target is not visible

2. **Press Home**: Use `system_button` with `button: "Home"` when:
   - You're in a completely wrong app
   - The app seems stuck or unresponsive
   - You need to start the task from scratch

3. **Terminate with failure (ONLY as last resort)**: Use `terminate` with `status: "failure"` ONLY when:
   - The task is fundamentally impossible (e.g., the app doesn't exist)
   - Multiple recovery attempts have failed
   - The instruction itself is contradictory or impossible

**IMPORTANT**: Since the image is modified, the target element is missing. You MUST press Back (or Home if appropriate) - do NOT attempt the original action on the modified screen.


Vary your Thought and Action phrasing naturally to avoid repetitive patterns. Use different sentence structures, vocabulary, and expressions.


Thought examples (MUST mention mask/inpaint and its location):
- "Comparing the two images, I can see a masked region in the center of the screen where the target button was located in the original image."
- "The original screen had a search field in the top area, but the current modified image shows an inpainted region in that location, making the element no longer visible."
- "Looking at both screens, I notice a mask covering the bottom-right corner where the intended UI element was in the original."
- "I can see that the current screen has been modified with inpaint in the middle section - the target element that was there is now missing."
- "The modified image shows a masked area in the upper-left portion, and the button that was present in the original screenshot is no longer accessible."
- "There's an inpainted region in the center-left area of the modified screen, and the original element in that location has been removed."
- "I notice a mask covering the top region of the screen, obscuring the search field that was visible in the original image."
- "The bottom area of the current screen shows signs of inpaint modification, and the target element that was there is gone."


Thought examples (MUST identify the mismatch and describe what doesn't align):
- "The instruction asks me to search for 'weather', but I'm currently on a settings screen that doesn't have a search function, so the instruction doesn't match what I see."
- "I notice a mismatch: the instruction mentions opening App A, but the current screen shows I'm in App B, indicating the instruction has been modified."
- "The current screen displays a music player interface, but the instruction asks about creating a playlist in a different app, so there's a clear mismatch."
- "Looking at the screen, I see a calendar view, but the instruction refers to a feature that doesn't exist here, suggesting the instruction was changed."
- "The instruction asks me to click on a 'Settings' button, but I'm currently viewing a home screen with no such button visible, showing the instruction doesn't align with the current state."
- "I can see I'm on a browser page, but the instruction mentions an app feature that's not available here, indicating the instruction has been modified."
- "The screen shows a login page, but the instruction asks about features that would only be available after logging in, showing a mismatch between the instruction and current state."
- "Comparing the instruction with what I see, the instruction mentions App X but I'm clearly in App Y, so the instruction doesn't match the current screen."

Action examples (vary naturally):
- "Press Back to return to the previous screen."
- "Go back to find the element on the previous page."
- "Navigate back since the target is missing."
- "Return to the previous screen where the element might be accessible."
- "Press the Back button to go back."

Rules:
- Output exactly in the order: Thought, Action, <tool_call>.
- You MUST start with "Thought:" label, then "Action:" label, then "<tool_call>" tag.
- Format: 
  Thought: [your one-sentence thought - vary phrasing naturally]
  Action: [your one-sentence action - use different expressions]
  <tool_call>
  {{"name": "mobile_use", "arguments": {{...}}}}
  </tool_call>
- Be brief: one sentence for Thought, one for Action.
- Do not output anything else outside those three parts.
- Do NOT skip the "Thought:" or "Action:" labels.
- Vary your language: use different sentence structures, synonyms, and natural expressions.
- Always ensure faithfulness: Thought → Action → tool_call must be consistent.
- Always base your reasoning on actual observation of the images."""


def _label_point_is_normalized(point) -> bool:
    """Return True when both components of *point* are at most 1000."""
    return point[0] <= 1000 and point[1] <= 1000


def _label_point_to_pixels(point, image_width: int, image_height: int, normalized: bool):
    """Return ``(x, y)`` as integer pixel values for a label coordinate pair."""
    if not normalized:
        return int(point[0]), int(point[1])
    # A non-positive image dimension cannot be scaled against; keep the raw value.
    x_pixel = int(point[0] * image_width / 1000.0) if image_width > 0 else int(point[0])
    y_pixel = int(point[1] * image_height / 1000.0) if image_height > 0 else int(point[1])
    return x_pixel, y_pixel


def format_label_description(label: Dict[str, Any], image_width: int, image_height: int) -> str:
    """Render a ground-truth action label as a short English description.

    Coordinates are reported in pixels. A coordinate pair whose two components
    are both <= 1000 is treated as normalized to the 0-1000 range and scaled by
    the image size; otherwise it is taken to be in pixels already. This is a
    heuristic: a genuine pixel coordinate that is small on both axes is
    indistinguishable from a normalized one. For ``swipe`` the decision is made
    from the start point only and applied to both points.

    Args:
        label: Action dictionary; ``label["action"]`` selects the format and
            the remaining keys (``coordinate``, ``coordinate2``, ``text``,
            ``button``, ``time``, ``status``) are read as needed.
        image_width: Width of the screenshot in pixels.
        image_height: Height of the screenshot in pixels.

    Returns:
        A one-line description. Unknown actions yield ``"Action: <name>"``.
    """
    action_type = label.get("action", "")
    resolution_note = f"(pixel coordinates, image resolution {image_width}×{image_height})"

    if action_type == "click":
        point = label.get("coordinate", [])
        if not (point and len(point) >= 2):
            return "Click action (coordinate not specified)"
        x_pixel, y_pixel = _label_point_to_pixels(
            point, image_width, image_height, _label_point_is_normalized(point)
        )
        return f"Click at coordinate [{x_pixel}, {y_pixel}] {resolution_note}"

    if action_type == "type":
        text = label.get("text", "")
        return f"Type text: {text}"

    if action_type == "swipe":
        start = label.get("coordinate", [])
        end = label.get("coordinate2", [])
        if not (start and end and len(start) >= 2 and len(end) >= 2):
            return "Swipe action (coordinates not specified)"
        # Only the start point decides whether the pair is normalized.
        normalized = _label_point_is_normalized(start)
        x1_pixel, y1_pixel = _label_point_to_pixels(start, image_width, image_height, normalized)
        x2_pixel, y2_pixel = _label_point_to_pixels(end, image_width, image_height, normalized)
        return (
            f"Swipe from coordinate [{x1_pixel}, {y1_pixel}] to [{x2_pixel}, {y2_pixel}] "
            f"{resolution_note}"
        )

    if action_type == "system_button":
        button = label.get("button", "")
        return f"Press system button: {button}"

    if action_type == "long_press":
        point = label.get("coordinate", [])
        duration = label.get("time", "")
        if not (point and len(point) >= 2):
            return "Long press action (coordinate not specified)"
        x_pixel, y_pixel = _label_point_to_pixels(
            point, image_width, image_height, _label_point_is_normalized(point)
        )
        duration_note = f" for {duration} seconds" if duration else ""
        return f"Long press at coordinate [{x_pixel}, {y_pixel}]{duration_note} {resolution_note}"

    if action_type == "wait":
        duration = label.get("time", "")
        return f"Wait for {duration} seconds" if duration else "Wait action"

    if action_type == "terminate":
        status = label.get("status", "")
        return f"Terminate task with status: {status}" if status else "Terminate task"

    if action_type == "answer":
        text = label.get("text", "")
        return f"Answer: {text}" if text else "Answer action"

    return f"Action: {action_type}"


def convert_gt_action_to_tool_call(gt_action: Dict[str, Any], image_width: int = None, image_height: int = None) -> Dict[str, Any]:
    """Convert a ground-truth action dict into ``mobile_use`` tool_call arguments.

    The result always carries ``"action"``. Depending on the action type, the
    relevant fields are copied over when present and truthy; coordinate pairs
    are emitted as ``[int(x), int(y)]`` without any rescaling.

    Args:
        gt_action: Ground-truth action with an ``"action"`` key plus optional
            ``coordinate``, ``coordinate2``, ``text``, ``button``, ``time`` and
            ``status`` fields.
        image_width: Accepted for call-site symmetry; not used here.
        image_height: Accepted for call-site symmetry; not used here.

    Returns:
        The arguments dict for the tool_call.
    """
    # (action name, field to copy, whether the field is an [x, y] pair),
    # listed in the order the keys must appear in the output.
    field_plan = (
        ("click", "coordinate", True),
        ("type", "text", False),
        ("swipe", "coordinate", True),
        ("swipe", "coordinate2", True),
        ("system_button", "button", False),
        ("long_press", "coordinate", True),
        ("long_press", "time", False),
        ("wait", "time", False),
        ("terminate", "status", False),
        ("answer", "text", False),
    )

    kind = gt_action.get("action", "")
    arguments = {"action": kind}

    for action_name, field, is_point in field_plan:
        if not (kind == action_name):
            continue
        if is_point:
            pair = gt_action.get(field, [])
            if pair and len(pair) >= 2:
                arguments[field] = [int(pair[0]), int(pair[1])]
        else:
            value = gt_action.get(field, "")
            if value:
                arguments[field] = value

    return arguments


def build_user_query(instruction: str, label: Dict[str, Any], image_width: int, image_height: int,
                     action_history: Optional[str] = None) -> str:
    """Compose the user message for annotating a step with a fixed tool_call.

    The message states the task, a human-readable description of the required
    action, the exact ``<tool_call>`` JSON the model must repeat, and the
    formatting guidance. When ``action_history`` is non-empty it is appended
    as a trailing "Task progress" paragraph.

    Args:
        instruction: Task instruction shown to the model.
        label: Ground-truth action for this step.
        image_width: Screenshot width in pixels.
        image_height: Screenshot height in pixels.
        action_history: Optional summary of the operations already performed.

    Returns:
        The assembled user query text.
    """
    description = format_label_description(label, image_width, image_height)

    tool_call_payload = {
        "name": "mobile_use",
        "arguments": convert_gt_action_to_tool_call(label, image_width, image_height),
    }
    tool_call_text = json.dumps(tool_call_payload, ensure_ascii=False)

    segments = [
        f"Task: {instruction}\n",
        f"\nThe required action is: {description}\n",
        "\nYou must use this exact tool_call:\n",
        f"<tool_call>\n{tool_call_text}\n</tool_call>\n",
        "\nPlease:\n",
        "1. Look at the image and describe what you see\n",
        "2. Understand what the instruction is asking\n",
        "3. Write a Thought that explains why the required action makes sense based on your observation\n",
        "4. Write an Action that describes what you'll do (must match the required tool_call)\n",
        "5. Use the EXACT tool_call provided above - do not change any parameters or coordinates\n",
        "\nImportant:\n",
        "- Base your reasoning on what you actually see in the image\n",
        "- Vary your phrasing naturally - use different sentence structures, synonyms, and expressions to avoid repetitive patterns\n",
        "- The tool_call must match exactly (same action, coordinates, and parameters)\n",
        "- Your explanation should flow naturally and be consistent with the action\n",
        "- Format: Thought: [your thought] then Action: [your action] then <tool_call>...</tool_call>\n",
    ]

    if action_history:
        segments.append(
            f"\nTask progress (You have done the following operation on the current device): {action_history}\n"
        )

    return "".join(segments)


def build_trap_user_query(instruction: str, label: Dict[str, Any], image_width: int, image_height: int,
                          trap_mode: str, action_history: Optional[str] = None) -> str:
    """Compose the user message for annotating a "trap" sample.

    Two layouts are produced:

    * ``trap_mode`` is ``"mask"`` or ``"inpaint"``: the message describes the
      original/modified image pair, tells the model to locate the modification
      and embeds the required ``<tool_call>`` built from ``label``.
    * any other ``trap_mode``: the message states that the instruction was
      modified and asks for a recovery action (Back/Home); ``label`` is not
      consulted in this branch.

    Args:
        instruction: Instruction shown to the model.
        label: Ground-truth action; only read for mask/inpaint traps.
        image_width: Screenshot width, forwarded to the tool_call conversion.
        image_height: Screenshot height, forwarded to the tool_call conversion.
        trap_mode: Kind of trap applied to the sample.
        action_history: Optional summary of the operations already performed;
            inserted right after the instruction when non-empty.

    Returns:
        The assembled user query text.
    """
    # The original code also computed a label description that was never used;
    # it has been removed so that a missing or malformed label cannot break the
    # instruction-modify branch, which does not need the label at all.
    parts: List[str] = [f"The user query: {instruction}.\n"]

    if action_history:
        parts.append(
            f"Task progress (You have done the following operation on the current device): {action_history}\n"
        )

    if trap_mode in ("mask", "inpaint"):
        required_tool_call_args = convert_gt_action_to_tool_call(label, image_width, image_height)
        required_tool_call_json = json.dumps(
            {"name": "mobile_use", "arguments": required_tool_call_args}, ensure_ascii=False
        )

        parts.extend([
            "\n[Image Comparison Context]\n",
            "- The first image shows the ORIGINAL screen before any modification.\n",
            f"- The second image shows the CURRENT screen which has been MODIFIED using {trap_mode} technique - UI elements have been removed, masked, or inpainted.\n",
            "\nCRITICAL: Compare the two images carefully. Look for masked regions, inpainted areas, or altered sections in the modified screen.\n",
            "The current modified screen is missing UI elements that were present in the original.\n",
            "\nIn your Thought, you MUST:\n",
            f"1. Identify that the image has been modified with {trap_mode} (mention \"mask\", \"inpaint\", \"masked region\", \"inpainted area\", etc.)\n",
            "2. Describe the SPECIFIC LOCATION of the mask/inpaint (e.g., \"top-left corner\", \"center region\", \"bottom area\", \"right side\", \"middle section\", \"upper portion\", \"lower-left area\", etc.)\n",
            "3. Explain that the target element in that location is missing or obscured\n",
            "\nYou MUST NOT attempt to click on any coordinates in the modified screen - the target element is gone.\n",
            "You MUST press the Back button to return to the previous screen.\n",
            "\nREQUIRED tool_call (you MUST use this exact tool_call):\n",
            f"<tool_call>\n{required_tool_call_json}\n</tool_call>\n",
            "\nYour response should:\n",
            f"1. Thought: Identify the {trap_mode} modification, describe its specific location (top/bottom/left/right/center, etc.), and explain that the target element is missing, so you need to go back.\n",
            "2. Action: Describe pressing the Back button.\n",
            "3. tool_call: Use the EXACT tool_call provided above (do not modify it)\n",
            "\nOnly use terminate with failure if pressing Back is truly impossible.\n",
        ])
    else:
        parts.extend([
            "\n[Task Verification Context]\n",
            "- IMPORTANT: The instruction you received has been MODIFIED and does NOT match the current screen or task history.\n",
            "- Look at the current screen carefully and compare it with the instruction.\n",
            "- Based on the action history, you have been performing operations for a DIFFERENT task.\n",
            "\nCRITICAL: The instruction does NOT align with what you see on the screen.\n",
            "Examples of mismatches:\n",
            "- The instruction asks about App A but you're in App B\n",
            "- The instruction asks for a feature that doesn't exist on the current screen\n",
            "- The instruction mentions UI elements that are not visible\n",
            "- The instruction refers to actions that don't make sense given the current screen state\n",
            "\nIn your Thought, you MUST:\n",
            "1. Identify that the instruction has been modified (mention \"instruction mismatch\", \"instruction doesn't match\", \"modified instruction\", etc.)\n",
            "2. Describe the SPECIFIC mismatch (e.g., \"instruction asks for X but screen shows Y\", \"instruction mentions App A but I'm in App B\", \"instruction refers to feature Z that doesn't exist here\", etc.)\n",
            "3. Explain that you need to navigate back because the instruction doesn't align with the current screen\n",
            "\nYou MUST use a recovery action. In most cases, press Back:\n",
            "- Press Back to navigate to a more appropriate screen where the instruction might make sense.\n",
            "- Press Home if you're in a completely wrong app.\n",
            "- Only terminate with failure if the mismatch is fundamental and unrecoverable.\n",
            "\nYour response should:\n",
            "1. Thought: Identify the instruction modification, describe the specific mismatch (what the instruction asks vs. what you see), and explain that you need to navigate back.\n",
            "2. Action: Describe pressing the Back button (or Home if completely wrong app).\n",
            "3. tool_call: Use action=\"system_button\" with button=\"Back\" (or \"Home\" if appropriate)\n",
        ])

    return "".join(parts)


def load_image_as_pil(image_path: Optional[str]) -> Optional[Image.Image]:
    """Load an image from disk as a fully decoded RGB ``PIL.Image``.

    Args:
        image_path: Path to the image file; may be ``None`` or empty.

    Returns:
        An RGB image detached from the underlying file, or ``None`` when the
        path is missing, does not exist, or the file cannot be read/decoded
        (a warning is printed in the latter case).
    """
    if not image_path or not os.path.exists(image_path):
        return None

    try:
        # Image.open() is lazy: it only reads the header and keeps the file
        # open. Converting inside the context manager forces the pixel data to
        # be decoded here (so corrupt files are caught by this try block) and
        # returns a copy that no longer depends on the file handle, which is
        # closed on exit. convert() returns a copy even when the mode is
        # already "RGB".
        with Image.open(image_path) as image:
            return image.convert("RGB")
    except Exception as e:
        # Best effort: a single unreadable image must not abort the whole run.
        print(f"Warning: Cannot load image {image_path}: {e}")
        return None


class NormalDataAnnotator:
    """Batch annotator that asks a vLLM-served vision-language model for a
    Thought / Action / tool_call response for every sample.

    The normal or the trap prompt is chosen per sample (see
    ``_uses_trap_annotation``). For normal samples and for samples whose own
    ``trap_mode`` is ``"mask"``, the tool call derived from ``gt_action`` is
    authoritative: it replaces a missing or diverging model tool call.
    """

    # Screen size assumed when a sample carries no dimensions and its
    # screenshot cannot be read.
    _DEFAULT_IMAGE_WIDTH = 1080
    _DEFAULT_IMAGE_HEIGHT = 2400

    # Prompts for samples with trap_mode == "mask" carry two images (old
    # screenshot + current screenshot) regardless of ``is_trap_mode``, so the
    # engine must always accept two images per prompt.
    _MAX_IMAGES_PER_PROMPT = 2

    def __init__(self, model_path: str, batch_size: int = 512, device_ids: str = "[0]",
                 tensor_parallel_size: Optional[int] = None, seed: int = 42, is_trap_mode: bool = False):
        """Create the vLLM engine and the chat-template processor.

        Args:
            model_path: Model location passed to vLLM and ``AutoProcessor``.
            batch_size: Number of samples handled per ``generate`` call.
            device_ids: Device spec such as ``"[0,1]"``; exported as
                ``CUDA_VISIBLE_DEVICES``. Unparsable specs fall back to ``[0]``.
            tensor_parallel_size: Tensor-parallel degree; defaults to the
                number of parsed device ids (at least 1).
            seed: Seed passed to the vLLM engine.
            is_trap_mode: When true, samples are annotated with the trap
                prompt unless their ``trap_mode`` is ``"unrelated_mask"``.
        """
        self.model_path = model_path
        self.batch_size = batch_size
        self.seed = seed
        self.is_trap_mode = is_trap_mode

        device_ids_list = self._parse_device_ids(device_ids)

        if tensor_parallel_size is None:
            # An empty device list leaves CUDA_VISIBLE_DEVICES untouched; use
            # a single shard rather than asking the engine for zero.
            tensor_parallel_size = len(device_ids_list) or 1

        if device_ids_list:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids_list))

        mode_str = "TRAP" if is_trap_mode else "NORMAL"
        print(f"[Annotator] Mode: {mode_str}")
        print(f"[Annotator] Model path: {self.model_path}")
        print(f"[Annotator] Devices: {device_ids_list}, tensor_parallel_size: {tensor_parallel_size}")
        print(f"[Annotator] Batch size: {self.batch_size}")

        self._init_vllm_engine(tensor_parallel_size)

        print(f"[Annotator] Loading processor...")
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        print(f"[Annotator] Processor loaded")

        self.system_prompt = build_trap_annotation_system_prompt()
        print(f"[Annotator] Default system prompt: TRAP (will adjust per sample based on trap_mode)")

    @staticmethod
    def _parse_device_ids(device_ids) -> List[Any]:
        """Turn a device spec such as ``"[0,1]"`` into a list; default ``[0]``."""
        import ast  # local import: only this helper needs it

        if isinstance(device_ids, (list, tuple)):
            return list(device_ids)

        try:
            # NOTE(review): this used to be eval(). literal_eval accepts the
            # same list/tuple/int literals but cannot execute code that
            # arrives through the command line.
            parsed = ast.literal_eval(str(device_ids).strip())
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f"[Annotator] Warning: cannot parse device_ids {device_ids!r} ({e}), falling back to [0]")
            return [0]

        if isinstance(parsed, (list, tuple)):
            return list(parsed)
        if isinstance(parsed, int) and not isinstance(parsed, bool):
            # A bare "3" means the single device 3.
            return [parsed]

        print(f"[Annotator] Warning: device_ids {device_ids!r} is not a list of ids, falling back to [0]")
        return [0]

    def _init_vllm_engine(self, tensor_parallel_size: int):
        """Create ``self.llm`` and the greedy ``self.sampling_params``.

        Args:
            tensor_parallel_size: Tensor-parallel degree for the engine.
        """
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        print(f"[Annotator] Initializing vLLM engine...")

        engine_args = {
            "model": self.model_path,
            "tensor_parallel_size": tensor_parallel_size,
            "seed": self.seed,
            "max_model_len": 16384,
            "mm_processor_kwargs": {
                "min_pixels": 32 * 32,
                "max_pixels": 8192 * 8192,
                "fps": 1,
            },
            # Always allow two images: a "mask" sample gets a two-image
            # prompt even when the annotator itself runs in normal mode, and
            # a prompt above the limit would fail its whole batch.
            "limit_mm_per_prompt": {"image": self._MAX_IMAGES_PER_PROMPT, "video": 0, "audio": 0},
        }

        self.llm = LLM(**engine_args)
        # temperature=0 -> deterministic (greedy) decoding.
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1024
        )

        print(f"[Annotator] vLLM engine initialized")

    def annotate_batch(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Annotate ``samples`` and return one result dict per sample, in order.

        Inputs are built, run and parsed one batch at a time, so at most
        ``batch_size`` decoded screenshots are alive at once. Samples whose
        input cannot be built are never sent to the model, and a failed
        ``generate`` call only affects the samples of its own batch.

        Args:
            samples: Sample dicts. Keys read here: ``instruction``,
                ``gt_action``, ``image_path`` or ``images``, and optionally
                ``image_width``, ``image_height``, ``action_history``,
                ``trap_mode``, ``data_type``, ``old_images``.

        Returns:
            For each sample a shallow copy with ``output_text`` added, plus
            either ``annotation_error`` or the parsed ``thought``, ``action``
            and ``tool_call`` (and ``predicted_qwen3`` when the tool call has
            arguments).
        """
        if not samples:
            return []

        # A non-positive batch size would break range() below.
        batch_size = max(1, self.batch_size)
        total = len(samples)
        total_batches = (total + batch_size - 1) // batch_size
        results: List[Dict[str, Any]] = []

        print(f"[Annotator] Starting vLLM batch inference for {total} samples, batch size: {batch_size}...")

        batch_starts = range(0, total, batch_size)
        for batch_num, start in enumerate(tqdm(batch_starts, desc="vLLM inference batches"), 1):
            batch_samples = samples[start:start + batch_size]
            end = start + len(batch_samples)
            print(f"[Annotator] Processing batch {batch_num}/{total_batches}: samples {start}-{end - 1}")

            built = [self._build_sample_input(sample) for sample in batch_samples]
            outputs = self._generate_batch(built, batch_num)
            results.extend(
                self._build_result(metadata, output)
                for (metadata, _), output in zip(built, outputs)
            )
            print(f"[Annotator] Batch {batch_num} done, processed {end}/{total} samples")

        print(f"[Annotator] vLLM inference completed")
        return results

    def _uses_trap_annotation(self, sample: Dict[str, Any]) -> bool:
        """Return True when ``sample`` is annotated with the trap prompt."""
        trap_mode = sample.get("trap_mode", "")
        if trap_mode in ("mask", "instruction_modify"):
            return True
        if sample.get("data_type", 0) == 2:
            return True
        return bool(self.is_trap_mode and trap_mode != "unrelated_mask")

    def _build_sample_input(self, sample: Dict[str, Any]):
        """Return ``(metadata, vllm_input)``; ``vllm_input`` is None on error.

        Any failure is recorded in ``metadata["error"]`` so one bad sample
        never aborts the batch.
        """
        metadata = {
            "sample": sample,
            "image_width": self._DEFAULT_IMAGE_WIDTH,
            "image_height": self._DEFAULT_IMAGE_HEIGHT,
        }
        try:
            vllm_input = self._build_vllm_input(sample, metadata)
        except Exception as e:
            print(f"[Annotator] Failed to build input: {e}")
            traceback.print_exc()
            metadata["error"] = str(e)
            vllm_input = None
        return metadata, vllm_input

    def _build_vllm_input(self, sample: Dict[str, Any], metadata: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Build the vLLM prompt dict for one sample.

        Fills ``metadata["image_width"/"image_height"]`` and, when the sample
        is unusable, sets ``metadata["error"]`` and returns None.
        """
        image_path = sample.get("image_path", None)
        if not image_path:
            images = sample.get("images", [])
            if images:
                image_path = images[0]

        instruction = sample.get("instruction", "")
        gt_action = sample.get("gt_action", {})
        action_history = sample.get("action_history", "")

        image_width = sample.get("image_width", None)
        image_height = sample.get("image_height", None)

        # The screenshot is decoded at most once and reused for the prompt.
        pil_image = None
        if image_width is None or image_height is None:
            pil_image = load_image_as_pil(image_path)
            if pil_image is not None:
                if image_width is None:
                    image_width = pil_image.width
                if image_height is None:
                    image_height = pil_image.height

        if image_width is not None:
            metadata["image_width"] = image_width
        if image_height is not None:
            metadata["image_height"] = image_height
        image_width = metadata["image_width"]
        image_height = metadata["image_height"]

        if not instruction:
            metadata["error"] = "Missing instruction field"
            return None
        if not gt_action:
            metadata["error"] = "Missing gt_action field"
            return None

        sample_trap_mode = sample.get("trap_mode", "")
        use_trap_annotation = self._uses_trap_annotation(sample)

        if use_trap_annotation:
            system_prompt = build_trap_annotation_system_prompt()
            # Trap samples without an explicit mode are queried as "mask".
            user_query = build_trap_user_query(
                instruction, gt_action, image_width, image_height,
                sample_trap_mode or "mask", action_history,
            )
        else:
            system_prompt = build_annotation_system_prompt()
            user_query = build_user_query(instruction, gt_action, image_width, image_height, action_history)

        if not image_path:
            metadata["error"] = "Image path is empty"
            return None
        if not os.path.exists(image_path):
            metadata["error"] = f"Image not found: {image_path}"
            return None

        if pil_image is None:
            pil_image = load_image_as_pil(image_path)
        if pil_image is None:
            metadata["error"] = "Failed to load image"
            return None

        prompt_images = [pil_image]
        # Only samples whose *own* trap_mode is "mask" get the old screenshot
        # as an extra first image (the defaulted "mask" above does not count).
        if use_trap_annotation and sample_trap_mode == "mask":
            old_images = sample.get("old_images", [])
            old_pil_image = load_image_as_pil(old_images[0]) if old_images else None
            if old_pil_image is not None:
                prompt_images = [old_pil_image, pil_image]

        return self._build_chat_input(system_prompt, user_query, prompt_images)

    def _build_chat_input(self, system_prompt: str, user_query: str, images: List[Any]) -> Dict[str, Any]:
        """Render the chat template and attach ``images`` as multimodal data."""
        image_slots = [
            {"type": "image", "image": f"placeholder_{index}"}
            for index in range(len(images))
        ]
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}]
            },
            {
                "role": "user",
                "content": image_slots + [{"type": "text", "text": user_query}]
            }
        ]

        prompt_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        return {
            "prompt": prompt_text,
            "multi_modal_data": {"image": list(images)},
            # Full-length UUIDs: identifiers truncated to 8 hex digits are
            # likely to collide on large datasets.
            "multi_modal_uuids": {"image": [f"uuid_{uuid.uuid4().hex}" for _ in images]},
        }

    def _generate_batch(self, built: List[Any], batch_num: int) -> List[Any]:
        """Run vLLM on the usable inputs of one batch.

        Returns one entry per element of ``built``: the request output, or
        None for samples that were skipped or whose batch failed (the reason
        is stored in their metadata ``error``).
        """
        outputs: List[Any] = [None] * len(built)
        runnable = [index for index, (_, vllm_input) in enumerate(built) if vllm_input is not None]
        if not runnable:
            return outputs

        try:
            generated = self.llm.generate([built[index][1] for index in runnable], self.sampling_params)
            if len(generated) != len(runnable):
                raise RuntimeError(f"expected {len(runnable)} outputs, got {len(generated)}")
        except Exception as e:
            print(f"[Annotator] Batch {batch_num} failed: {e}")
            traceback.print_exc()
            for index in runnable:
                built[index][0]["error"] = f"vLLM inference failed: {str(e)}"
            print(f"[Annotator] Marked {len(runnable)} samples of the failed batch as errors, continuing...")
            return outputs

        for index, output in zip(runnable, generated):
            outputs[index] = output
        return outputs

    def _build_result(self, metadata: Dict[str, Any], output: Any) -> Dict[str, Any]:
        """Turn one model output (or recorded error) into a result dict."""
        original_sample = metadata["sample"]
        result = original_sample.copy()

        if "error" in metadata:
            result["annotation_error"] = metadata["error"]
            result["output_text"] = ""
            return result

        try:
            output_text = output.outputs[0].text
            result["output_text"] = output_text

            thought, action, tool_call = self._parse_output_text(output_text)
            tool_call = self._reconcile_tool_call(original_sample, metadata, tool_call)

            result["thought"] = thought
            result["action"] = action
            result["tool_call"] = tool_call

            if tool_call and "arguments" in tool_call:
                result["predicted_qwen3"] = tool_call["arguments"]

        except Exception as e:
            print(f"[Annotator] Failed to parse result: {e}")
            result["annotation_error"] = f"Parse failed: {str(e)}"
            result["output_text"] = output.outputs[0].text if hasattr(output, 'outputs') and output.outputs else ""

        return result

    @staticmethod
    def _parse_output_text(output_text: str):
        """Split a model response into ``(thought, action, tool_call)``.

        ``thought`` and ``action`` are single lines (empty when absent);
        ``tool_call`` is the parsed JSON object or None when it is missing,
        invalid JSON, or not an object with dict ``arguments``.
        """
        thought = ""
        action = ""
        tool_call = None

        if "Thought:" in output_text:
            thought_part = output_text.split("Thought:")[1].split("Action:")[0].strip()
            thought = thought_part.split("\n")[0].strip()
        elif "Action:" in output_text or "<tool_call>" in output_text:
            # No explicit marker: use the first lines before the next section.
            marker = "Action:" if "Action:" in output_text else "<tool_call>"
            leading = output_text.split(marker)[0].strip()
            thought = "\n".join([line.strip() for line in leading.split("\n")[:2] if line.strip()])

        if "Action:" in output_text:
            action_part = output_text.split("Action:")[1].split("<tool_call>")[0].strip()
            action = action_part.split("\n")[0].strip()
        elif "<tool_call>" in output_text:
            before_toolcall = output_text.split("<tool_call>")[0].strip()
            if "Thought:" in before_toolcall:
                # First non-empty line after the Thought line.
                after_thought = before_toolcall.split("Thought:")[-1].split("\n")
                action_lines = [line.strip() for line in after_thought[1:] if line.strip()]
                if action_lines:
                    action = action_lines[0]
            else:
                lines = [line.strip() for line in before_toolcall.split("\n") if line.strip()]
                if lines:
                    action = lines[-1]

        if "<tool_call>" in output_text and "</tool_call>" in output_text:
            tool_call_str = output_text.split("<tool_call>")[1].split("</tool_call>")[0].strip()
            try:
                tool_call = json.loads(tool_call_str)
            except json.JSONDecodeError as e:
                print(f"[Annotator] Failed to parse tool_call JSON: {e}")
                print(f"[Annotator] Raw tool_call string: {tool_call_str[:200] if tool_call_str else 'empty'}")
                tool_call = None
            else:
                # Valid JSON that is not a {"name", "arguments": {...}} object
                # (a list, a string, null arguments, ...) is as unusable as
                # invalid JSON; treating it so lets callers fall back cleanly.
                if not isinstance(tool_call, dict) or not isinstance(tool_call.get("arguments", {}), dict):
                    print(f"[Annotator] tool_call is not a JSON object with dict arguments, ignoring it")
                    tool_call = None
        elif len(output_text) > 0:
            print(f"[Annotator] No <tool_call> tag found. Output preview: {output_text[:300]}")

        return thought, action, tool_call

    def _reconcile_tool_call(self, sample: Dict[str, Any], metadata: Dict[str, Any], tool_call: Any) -> Any:
        """Return the tool call to store for ``sample``.

        Normal samples and ``trap_mode == "mask"`` samples with a
        ``gt_action`` get the tool call derived from ``gt_action`` whenever
        the model's is missing or differs. Other trap samples keep the
        model's tool call; it is only checked and warnings are printed.
        """
        trap_mode = sample.get("trap_mode", "")
        use_trap_annotation = self._uses_trap_annotation(sample)
        gt_action = sample.get("gt_action", {})
        has_arguments = bool(tool_call) and "arguments" in tool_call

        if gt_action and (not use_trap_annotation or trap_mode == "mask"):
            correct_tool_call_args = convert_gt_action_to_tool_call(
                gt_action, metadata["image_width"], metadata["image_height"]
            )
            correct_tool_call = {"name": "mobile_use", "arguments": correct_tool_call_args}

            if not has_arguments:
                return correct_tool_call
            if (tool_call.get("name") != correct_tool_call["name"] or
                    tool_call["arguments"] != correct_tool_call_args):
                if use_trap_annotation:
                    print(f"[Annotator] Warning: tool_call doesn't match gt_action (Back) in trap mode, using correct one")
                else:
                    print(f"[Annotator] Warning: tool_call doesn't match gt_action, using correct one")
                return correct_tool_call
            return tool_call

        if not use_trap_annotation:
            return tool_call

        if not has_arguments:
            print(f"[Annotator] Warning: Failed to parse tool_call in trap mode")
            return tool_call

        parsed_args = tool_call["arguments"]
        action_type = parsed_args.get("action", "")

        valid_actions = ["terminate", "system_button", "click", "type", "swipe", "long_press", "wait", "answer"]
        if action_type not in valid_actions:
            print(f"[Annotator] Warning: Invalid action type '{action_type}' in trap mode")

        if action_type == "system_button":
            button = parsed_args.get("button", "")
            if button not in ["Back", "Home", "Menu", "Enter"]:
                print(f"[Annotator] Warning: Invalid button '{button}' for system_button action")

        if action_type == "terminate":
            status = parsed_args.get("status", "")
            if status not in ["success", "failure"]:
                print(f"[Annotator] Warning: Invalid status '{status}' for terminate action (should be 'success' or 'failure')")

        return tool_call


def fix_answer_tool_call(answer_text: str, gt_action: Dict[str, Any], image_width: int, image_height: int, 
                        is_trap_mode: bool = False) -> str:
    """Make the ``<tool_call>`` block of ``answer_text`` agree with ``gt_action``.

    Args:
        answer_text: Raw model answer (Thought / Action / tool_call text).
        gt_action: Ground-truth action the tool call must encode.
        image_width: Screenshot width passed to the gt-action conversion.
        image_height: Screenshot height passed to the gt-action conversion.
        is_trap_mode: When true the answer is returned untouched.

    Returns:
        ``answer_text`` unchanged when it is empty, ``gt_action`` is empty,
        trap mode is on, or the existing tool call already matches. Otherwise
        the answer with its first ``<tool_call>`` block replaced by (or, when
        there is none, extended with) the tool call derived from ``gt_action``.
    """
    if not answer_text or not gt_action:
        return answer_text

    if is_trap_mode:
        return answer_text

    correct_args = convert_gt_action_to_tool_call(gt_action, image_width, image_height)
    correct_tool_call = {"name": "mobile_use", "arguments": correct_args}
    correct_tool_call_json = json.dumps(correct_tool_call, ensure_ascii=False)
    replacement = "<tool_call>\n" + correct_tool_call_json + "\n</tool_call>"

    if "<tool_call>" not in answer_text or "</tool_call>" not in answer_text:
        print(f"[Fix] No tool_call found in answer, appending...")
        return answer_text.rstrip() + "\n" + replacement

    # partition() keeps everything after the first closing tag, so text that
    # follows the block (including any further tags) is preserved verbatim.
    before_toolcall, _, remainder = answer_text.partition("<tool_call>")
    existing_toolcall_str, _, after_toolcall = remainder.partition("</tool_call>")
    existing_toolcall_str = existing_toolcall_str.strip()

    try:
        existing_toolcall = json.loads(existing_toolcall_str)
    except json.JSONDecodeError:
        existing_toolcall = None

    if not isinstance(existing_toolcall, dict):
        # Invalid JSON, or valid JSON that is not an object (list, string...).
        print(f"[Fix] Failed to parse existing tool_call, replacing...")
    elif existing_toolcall.get("name") != correct_tool_call["name"]:
        print(f"[Fix] Tool_call name mismatch, fixing...")
    elif existing_toolcall.get("arguments", {}) == correct_args:
        return answer_text
    else:
        print(f"[Fix] Tool_call mismatch detected, fixing...")
        print(f"  Existing: {existing_toolcall_str[:100]}...")
        print(f"  Correct: {correct_tool_call_json[:100]}...")

    return before_toolcall + replacement + after_toolcall


def main():
    """Command-line entry point: annotate a JSON dataset and save the result.

    Loads the JSON list given by ``--input_file``, annotates every sample
    whose ``gt_action`` is not null, stores the model output (with its tool
    call aligned to ``gt_action`` for non-trap samples) in the sample's
    ``answer`` field, and writes all samples to ``<input>_annotated.json``
    (``<input>_trap_annotated.json`` with ``--trap_mode``).
    """
    parser = argparse.ArgumentParser(description='Annotate normal or trap data using vLLM batch processing')
    parser.add_argument('--input_file', type=str, required=True,
                       help='Input file path')
    parser.add_argument('--model_path', type=str, default="/MODEL_PATH",
                       help='Model path')
    parser.add_argument('--batch_size', type=int, default=512,
                       help='Batch size')
    parser.add_argument('--device_ids', type=str, default="[0,1,2,3,4,5,6,7]",
                       help='CUDA device IDs (string format, e.g., "[0]" or "[0,1]")')
    # Default None (not 8) so the documented fallback to the number of device
    # ids applies; with the default device list this still resolves to 8.
    parser.add_argument('--tensor_parallel_size', type=int, default=None,
                       help='Tensor parallel size (if not specified, calculated from device_ids length)')
    parser.add_argument('--seed', type=int, default=42,
                       help='Random seed')
    parser.add_argument('--max_samples', type=int, default=None,
                       help='Maximum number of samples to process (for testing, None means all samples)')
    parser.add_argument('--trap_mode', action='store_true',
                       help='Enable trap mode for annotating trap data (default: False, annotate normal data)')

    args = parser.parse_args()

    print(f"[Load] Loading data: {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"[Load] Number of samples: {len(data)}")

    # Shallow copies of every sample; answers are written into these so that
    # samples that are skipped or fail still appear in the output file.
    original_data = [dict(sample) for sample in data]

    valid_indices = [idx for idx, sample in enumerate(data) if sample.get("gt_action") is not None]
    filtered_data = [data[idx] for idx in valid_indices]

    original_count = len(data)
    filtered_count = len(filtered_data)
    print(f"[Load] Filtered {original_count - filtered_count} samples with null gt_action, remaining: {filtered_count}")

    if args.max_samples is not None and args.max_samples > 0:
        valid_indices = valid_indices[:args.max_samples]
        filtered_data = filtered_data[:args.max_samples]
        print(f"[Load] Limited to first {len(filtered_data)} samples")

    if len(filtered_data) == 0:
        print(f"[Warning] No valid samples to process")
        return

    print(f"[Init] Initializing annotator...")
    annotator = NormalDataAnnotator(
        model_path=args.model_path,
        batch_size=args.batch_size,
        device_ids=args.device_ids,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
        is_trap_mode=args.trap_mode
    )

    print(f"[Annotate] Starting batch annotation...")
    results = annotator.annotate_batch(filtered_data)

    updated_count = 0
    fixed_count = 0
    for original_idx, result in zip(valid_indices, results):
        answer_text = result.get("output_text")
        if not answer_text:
            if "annotation_error" in result:
                print(f"[Warning] Sample {original_idx} has annotation error: {result['annotation_error']}")
            continue

        target = original_data[original_idx]
        gt_action = target.get("gt_action")

        if gt_action:
            use_trap_annotation = _is_trap_sample(target, args.trap_mode)
            if use_trap_annotation:
                # Trap answers are returned untouched by the fixer, so the
                # image size is irrelevant; avoid opening the screenshot.
                image_width, image_height = 1080, 2400
            else:
                image_width, image_height = _sample_image_size(target)

            fixed_answer = fix_answer_tool_call(answer_text, gt_action, image_width, image_height,
                                                is_trap_mode=use_trap_annotation)
            if fixed_answer != answer_text:
                fixed_count += 1
            target["answer"] = fixed_answer
        else:
            target["answer"] = answer_text

        updated_count += 1

    output_file = _annotated_output_path(args.input_file, args.trap_mode)

    print(f"[Save] Saving results to: {output_file}")
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(original_data, f, ensure_ascii=False, indent=2)

    total = len(results)
    success_count = sum(1 for r in results if "annotation_error" not in r and r.get("output_text"))
    error_count = sum(1 for r in results if "annotation_error" in r)
    # Guard the percentages against an empty result list.
    success_pct = success_count / total * 100 if total else 0.0
    error_pct = error_count / total * 100 if total else 0.0

    print(f"\n{'='*60}")
    print("Annotation completed!")
    print(f"{'='*60}")
    print(f"Total samples: {total}")
    print(f"Successfully annotated: {success_count} ({success_pct:.2f}%)")
    print(f"Failed: {error_count} ({error_pct:.2f}%)")
    print(f"Updated answer fields: {updated_count}")
    print(f"Fixed tool_call alignment: {fixed_count}")
    print(f"Output file: {output_file}")


def _is_trap_sample(sample, trap_flag):
    """Return True when ``sample`` is treated as trap data for answer fixing."""
    trap_mode = sample.get("trap_mode", "")
    if trap_mode in ("mask", "instruction_modify"):
        return True
    if sample.get("data_type", 0) == 2:
        return True
    return bool(trap_flag and trap_mode != "unrelated_mask")


def _sample_image_size(sample):
    """Return ``(width, height)`` for ``sample``, defaulting to 1080 x 2400.

    Explicit ``image_width`` / ``image_height`` values win; missing ones are
    read from the header of ``image_path`` (or the first entry of ``images``)
    without decoding the pixel data.
    """
    width = sample.get("image_width")
    height = sample.get("image_height")

    if width is None or height is None:
        image_path = sample.get("image_path")
        if not image_path:
            images = sample.get("images", [])
            if images:
                image_path = images[0]

        if image_path and os.path.exists(image_path):
            try:
                with Image.open(image_path) as image:
                    probed_width, probed_height = image.size
            except Exception as e:
                print(f"[Warning] Cannot read image size from {image_path}: {e}")
            else:
                if width is None:
                    width = probed_width
                if height is None:
                    height = probed_height

    return (1080 if width is None else width, 2400 if height is None else height)


def _annotated_output_path(input_file, trap_mode):
    """Derive the output path by inserting the annotation suffix before a
    trailing ``.json`` (or appending it when there is none)."""
    suffix = '_trap_annotated' if trap_mode else '_annotated'
    if input_file.endswith('.json'):
        # Only the final extension is touched; str.replace would also rewrite
        # any other ".json" occurring earlier in the path.
        return input_file[:-len('.json')] + suffix + '.json'
    return input_file + suffix


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

