
import os
import sys
import json
import argparse
import traceback
from typing import Dict, Any, List, Optional
from tqdm import tqdm
from PIL import Image
import uuid

current_dir = os.path.dirname(os.path.abspath(__file__))
preprocess_dir = os.path.dirname(current_dir)
utils_dir = os.path.join(preprocess_dir, "utils")
sys.path.insert(0, utils_dir)

try:
    from vllm import LLM, SamplingParams
    from transformers import AutoProcessor
    VLLM_AVAILABLE = True
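except ImportError as e:
    VLLM_AVAILABLE = False
    # Either vLLM or transformers may be missing; report the actual import error.
    print(f"Error: Cannot import vLLM/transformers ({e}). Install them with: pip install vllm transformers")
    sys.exit(1)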

from qwen3_mobile_use import MOBILE_USE_TOOL_SCHEMA


def build_annotation_system_prompt() -> str:
    tool_json = json.dumps(MOBILE_USE_TOOL_SCHEMA, ensure_ascii=False)
    return f"""# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_json}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{{"name": <function-name>, "arguments": <args-json-object>}}
</tool_call>


Response format for every step:
1) Thought: one concise sentence explaining the next move based on the image content and instruction (no multi-step reasoning).
   - The Thought MUST be faithful to the image content: accurately describe what you see in the image.
   - The Thought MUST be faithful to the instruction: correctly interpret the task requirement.
   - Structure your Thought with: (1) Observation of what you see in the image, (2) Reasoning about why the action is needed, (3) Connection to the task goal.
   - Common patterns:
     * "X is visible/active, indicating Y. To Z, I need to..."
     * "The current screen shows X. To Y, I need to..."
     * "To X, I need to Y, which is likely accessed by..."
     * "The user wants to X, and the current screen shows Y. To Z, I need to..."
2) Action: a short imperative describing what to do in the UI.
   - The Action MUST be faithful to the Thought: the action description must match what the Thought describes.
   - CRITICAL: The Action should describe WHAT UI element to interact with, NOT technical details like coordinates.
   - DO NOT include coordinates, pixel values, or technical specifications in the Action description.
   - Use natural language following these patterns:
     * Click: "Click on the heart icon to add the Superstar Shoes to the wishlist", "Click on the \"REVIEWS\" link", "Click on the search button"
     * Type: "Type 'razer deathadder' into the search bar", "Type text: Chunky chelsea Boots"
     * Swipe: "Swipe left to reveal more physics topics", "Swipe up to view the reviews", "Swipe up to scroll down and reveal more information"
     * System button: "Press Back to return to the previous screen", "Press the Home button to return to the home screen"
     * Terminate: "Terminate the task as it cannot be completed on this screen"
3) A single <tool_call>...</tool_call> block containing only the JSON: {{"name": <function-name>, "arguments": <args-json-object>}}.
   - The tool_call arguments MUST be faithful to the Action: the actual action parameters must match the Action description.
   - IMPORTANT: Coordinates in tool_call must be in normalized format (0-1000 range), NOT pixel coordinates. Convert pixel coordinates to 0-1000 range: normalized_x = pixel_x / image_width * 1000, normalized_y = pixel_y / image_height * 1000.


CRITICAL: You must ensure faithfulness at all levels:
1. Thought faithfulness: Your Thought must accurately reflect what you can see in the image and what the instruction says. Base your reasoning on actual observation, not on patterns or shortcuts.
2. Action faithfulness: Your Action description must be consistent with your Thought. The action you describe should be what your Thought suggests based on your observations.
3. Tool call faithfulness: The tool_call MUST be EXACTLY the same as the required tool_call provided in the user query. You MUST NOT modify any parameters, coordinates, or values. Use the exact tool_call JSON provided.
4. Ground Truth alignment: The user query will provide a REQUIRED tool_call that you must use exactly. Your output must use this exact tool_call without any modifications.


The user query will provide a REQUIRED tool_call that you must use exactly. Your task is to:
1. Observe the image carefully and describe what you see
2. Read the instruction carefully and interpret what it says
3. Understand what the required tool_call is asking you to do
4. Generate a Thought that explains WHY this required action makes sense based on your observation of the image and the instruction
5. Generate an Action that describes WHAT you will do (must match the required tool_call)
6. Use the EXACT tool_call provided (do not modify it)

Your Thought should explain WHY the required action is appropriate based on what you observe, not just state what the action is. Follow these patterns:

For click actions:
- Good: "The heart icon is visible next to the product name, indicating the option to add the item to the wishlist. Clicking this icon will fulfill the task of adding the Superstar Shoes to the wishlist."
- Good: "The image shows a product page with a section for ratings and reviews, indicating a high customer satisfaction rate. To read the detailed reviews, I should click on the \"REVIEWS\" link, which is likely to lead to a page with customer feedback."
- Good: "The search bar is active, and the keyboard is visible, indicating that I can type a new search query. To find the desired item, I need to enter the specific product name."
- Bad: "I need to click on the search box." (without explaining why based on observation)
- Bad: "Click on the button." (too vague, no observation or reasoning)

For type actions:
- Good: "The search bar is active, and the keyboard is visible, indicating that I can type a new search query. To find the desired item, I need to enter the specific product name."
- Good: "The user wants to search for 'razer deathadder' on Newegg. The search bar is visible and ready for input, so the next step is to type the search term into it."
- Bad: "I need to type text." (without context or observation)

For swipe actions:
- Good: "To explore more physics topics, I need to swipe left to reveal additional content that might be hidden off-screen."
- Good: "To read the reviews for the Switch FWD Running Shoes, I need to scroll down further as the reviews section is partially visible."
- Good: "To find more tips on solo travel, I need to scroll down the page to view additional content."
- Bad: "I need to swipe." (without explaining why or what you observe)

For system_button actions:
- Good: "The current screen is unrelated to the task of opening Spotify and creating a playlist. The user is in a web browser with Best Buy search results."
- Good: "The 'Back' button is visible and accessible, indicating that the user can navigate back to the previous screen."
- Good: "The task requires opening the Spotify app, but currently, I am on a web browser with a search for 'bestbuy.com'. This indicates that I need to navigate away from the browser and open the Spotify app."
- Bad: "I need to press Back." (without explaining why)

For terminate actions:
- Good: "The task was to open Spotify, create a playlist, add songs, and share it via email. However, the current screen shows a Best Buy cart page, indicating that the user has navigated away from Spotify. Since the task cannot be completed here, and there are no further actions possible without returning to Spotify, the appropriate action is to terminate the process."
- Bad: "I need to terminate." (without explaining why)

Your Action should describe WHAT UI element to interact with in natural language, NOT technical details. Follow these patterns:

For click actions:
- Good: "Click on the heart icon to add the Superstar Shoes to the wishlist"
- Good: "Click on the \"REVIEWS\" link to read customer reviews"
- Good: "Click on the three-dot menu icon in the top right corner to access Chrome settings"
- Good: "Click on the 'Add Time Zone' button to add the user's home time zone"
- Bad: "Click at coordinate [319, 1240]", "Click at coordinate [468, 1209] (pixel coordinates, image resolution 1080×2400)"

For type actions:
- Good: "Type 'razer deathadder' into the search bar"
- Good: "Type text: Chunky chelsea Boots"
- Bad: "Type at coordinate [x, y]"

For swipe actions:
- Good: "Swipe left to reveal more physics topics"
- Good: "Swipe up to view the reviews"
- Good: "Swipe up on the screen to view the Nagpur weather map"
- Good: "Swipe up to scroll down and reveal more information about solo travel tips"
- Bad: "Swipe from coordinate [x1, y1] to [x2, y2]"

For system_button actions:
- Good: "Press Back to return to the previous screen and navigate towards Spotify"
- Good: "Press the Home button to return to the home screen"
- Good: "Press the Back button to return to the previous screen"
- Bad: "Press system button at coordinate [x, y]"

For terminate actions:
- Good: "Terminate the task as it cannot be completed on this screen"
- Good: "Terminate the task with status: success"

IMPORTANT: You must use the exact tool_call provided in the user query. Do not generate your own tool_call or modify the provided one.

Rules:
- Output exactly in the order: Thought, Action, <tool_call>.
- You MUST start with "Thought:" label, then "Action:" label, then "<tool_call>" tag.
- Format: 
  Thought: [your one-sentence thought]
  Action: [your one-sentence action]
  <tool_call>
  {{"name": "mobile_use", "arguments": {{...}}}}
  </tool_call>
- Be brief: one sentence for Thought, one for Action.
- Do not output anything else outside those three parts.
- Do NOT skip the "Thought:" or "Action:" labels.
- If finishing, use action=terminate in the tool call.
- Always ensure faithfulness: Thought → Action → tool_call must be consistent.
- Always base your reasoning on actual observation of the image and instruction."""


def format_label_description(label: Dict[str, Any], image_width: int, image_height: int) -> str:
    action_type = label.get("action", "")
    
    if action_type == "click":
        return "Click on the appropriate UI element"
    
    elif action_type == "type":
        text = label.get("text", "")
        if text:
            return f"Type '{text}' into the input field"
        else:
            return "Type text in the input field"
    
    elif action_type == "swipe":
        return "Swipe in the appropriate direction"
    
    elif action_type == "system_button":
        button = label.get("button", "")
        if button:
            if button == "Back":
                return "Press Back to return to the previous screen"
            elif button == "Home":
                return "Press the Home button"
            elif button == "Menu":
                return "Press the Menu button"
            elif button == "Enter":
                return "Press the Enter button"
            else:
                return f"Press the {button} system button"
        else:
            return "Press a system button"
    
    elif action_type == "long_press":
        time = label.get("time", "")
        if time:
            return f"Long press for {time} seconds"
        else:
            return "Long press on the appropriate UI element"
    
    elif action_type == "wait":
        time = label.get("time", "")
        if time:
            return f"Wait for {time} seconds"
        else:
            return "Wait for the interface to update"
    
    elif action_type == "terminate":
        status = label.get("status", "")
        if status:
            return f"Terminate the task with status: {status}"
        else:
            return "Terminate the task"
    
    elif action_type == "answer":
        text = label.get("text", "")
        if text:
            return f"Provide the answer: {text}"
        else:
            return "Provide an answer"
    
    else:
        return f"Perform {action_type} action"


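def convert_gt_action_to_tool_call(gt_action: Dict[str, Any], image_width: Optional[int] = None,
                                   image_height: Optional[int] = None) -> Dict[str, Any]:
    """Build mobile_use tool_call arguments from a ground-truth action.

    Coordinates are copied through as integers without rescaling; image_width and
    image_height are accepted for interface compatibility but are currently unused.
    """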
    action_type = gt_action.get("action", "")
    tool_call_args = {"action": action_type}
    
    if action_type == "click":
        coord = gt_action.get("coordinate", [])
        if coord and len(coord) >= 2:
            tool_call_args["coordinate"] = [int(coord[0]), int(coord[1])]
    
    elif action_type == "type":
        text = gt_action.get("text", "")
        if text:
            tool_call_args["text"] = text
    
    elif action_type == "swipe":
        coord1 = gt_action.get("coordinate", [])
        coord2 = gt_action.get("coordinate2", [])
        if coord1 and len(coord1) >= 2:
            tool_call_args["coordinate"] = [int(coord1[0]), int(coord1[1])]
        if coord2 and len(coord2) >= 2:
            tool_call_args["coordinate2"] = [int(coord2[0]), int(coord2[1])]
    
    elif action_type == "system_button":
        button = gt_action.get("button", "")
        if button:
            tool_call_args["button"] = button
    
    elif action_type == "long_press":
        coord = gt_action.get("coordinate", [])
        time = gt_action.get("time", "")
        if coord and len(coord) >= 2:
            tool_call_args["coordinate"] = [int(coord[0]), int(coord[1])]
        if time:
            tool_call_args["time"] = time
    
    elif action_type == "wait":
        time = gt_action.get("time", "")
        if time:
            tool_call_args["time"] = time
    
    elif action_type == "terminate":
        status = gt_action.get("status", "")
        if status:
            tool_call_args["status"] = status
    
    elif action_type == "answer":
        text = gt_action.get("text", "")
        if text:
            tool_call_args["text"] = text
    
    return tool_call_args
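

# Hypothetical helper (not called anywhere in this script): a minimal sketch of the
# pixel -> normalized conversion that the system prompt describes for tool_call
# coordinates, i.e. normalized = pixel / size * 1000 on each axis. Note that
# convert_gt_action_to_tool_call() above copies coordinates through without rescaling.
# Rounding to the nearest integer and clamping to [0, 1000] are assumptions of this sketch.
def pixel_to_normalized(pixel_x: float, pixel_y: float, image_width: int, image_height: int) -> list:
    if image_width <= 0 or image_height <= 0:
        raise ValueError("image_width and image_height must be positive")
    norm_x = round(pixel_x / image_width * 1000)
    norm_y = round(pixel_y / image_height * 1000)
    # Clamp so that points on or slightly past the image border stay in range.
    return [min(max(norm_x, 0), 1000), min(max(norm_y, 0), 1000)]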


def build_user_query(instruction: str, label: Dict[str, Any], image_width: int, image_height: int,
                     action_history: Optional[str] = None) -> str:
    label_desc = format_label_description(label, image_width, image_height)
    
    required_tool_call_args = convert_gt_action_to_tool_call(label, image_width, image_height)
    required_tool_call_json = json.dumps({"name": "mobile_use", "arguments": required_tool_call_args}, ensure_ascii=False)
    
    base_query = f"The user query: {instruction}.\n"
    
    base_query += f"\nRequired action type: {label_desc}\n"
    base_query += f"\nREQUIRED tool_call (you MUST use this exact tool_call):\n"
    base_query += f"<tool_call>\n{required_tool_call_json}\n</tool_call>\n"
    base_query += "\nYour task:\n"
    base_query += "1. Observe the image carefully and understand what UI elements are visible\n"
    base_query += "2. Read the instruction and understand what task needs to be completed\n"
    base_query += "3. Generate a Thought that explains WHY you need to perform the required action based on your observation\n"
    base_query += "4. Generate an Action that describes WHAT UI element you will interact with (use natural language, describe the element, NOT coordinates)\n"
    base_query += "5. Use the EXACT tool_call provided above (do not modify it)\n"
    base_query += "\nCRITICAL REQUIREMENTS:\n"
    base_query += "- Your Thought must be based on genuine observation of the image and interpretation of the instruction\n"
    base_query += "- Your Action must logically follow from your Thought and describe the UI element in natural language\n"
    base_query += "- DO NOT include coordinates, pixel values, or technical specifications in your Action description\n"
    base_query += "- For click actions: Describe what to click (e.g., 'Click on the heart icon', 'Click on the \"REVIEWS\" link', 'Click on the search button')\n"
    base_query += "- For type actions: Describe what to type and where (e.g., 'Type 'razer deathadder' into the search bar', 'Type text: Chunky chelsea Boots')\n"
    base_query += "- For swipe actions: Describe direction and purpose (e.g., 'Swipe left to reveal more content', 'Swipe up to view the reviews')\n"
    base_query += "- For system_button actions: Describe which button (e.g., 'Press Back to return to the previous screen', 'Press the Home button')\n"
    base_query += "- For terminate actions: Describe why (e.g., 'Terminate the task as it cannot be completed on this screen')\n"
    base_query += "- Your tool_call MUST be exactly the same as the required tool_call above (same action, same coordinates, same parameters)\n"
    base_query += "- The Thought → Action → tool_call chain must be logically consistent\n"
    
    if action_history:
        base_query += f"\nTask progress (You have done the following operation on the current device): {action_history}\n"
    
    return base_query


def load_image_as_pil(image_path: Optional[str]) -> Optional[Image.Image]:
    if not image_path or not os.path.exists(image_path):
        return None
    
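    try:
        # convert() returns a fully loaded copy, so the file handle can be closed right away
        # instead of staying open while the image waits in a large inference batch.
        with Image.open(image_path) as image:
            return image.convert("RGB")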
    except Exception as e:
        print(f"Warning: Cannot load image {image_path}: {e}")
        return None


class NormalDataAnnotator:
    
    def __init__(self, model_path: str, batch_size: int = 512, device_ids: str = "[0]", 
                 tensor_parallel_size: Optional[int] = None, seed: int = 42):
        self.model_path = model_path
        self.batch_size = batch_size
        self.seed = seed
        
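        # Parse device ids such as "[0]", "[0,1]" or "0,1" without eval().
        try:
            device_ids_list = [int(d) for d in str(device_ids).strip().strip("[]()").split(",") if d.strip()]
        except ValueError:
            device_ids_list = []
        if not device_ids_list:
            print(f"[Annotator] Warning: invalid device_ids {device_ids!r}, falling back to [0]")
            device_ids_list = [0]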
        
        if tensor_parallel_size is None:
            tensor_parallel_size = len(device_ids_list)
        
        if device_ids_list:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids_list))
        
        print(f"[Annotator] Model path: {self.model_path}")
        print(f"[Annotator] Devices: {device_ids_list}, tensor_parallel_size: {tensor_parallel_size}")
        print(f"[Annotator] Batch size: {self.batch_size}")
        
        self._init_vllm_engine(tensor_parallel_size)
        
        print(f"[Annotator] Loading processor...")
        self.processor = AutoProcessor.from_pretrained(self.model_path, trust_remote_code=True)
        print(f"[Annotator] Processor loaded")
        
        self.system_prompt = build_annotation_system_prompt()
    
    def _init_vllm_engine(self, tensor_parallel_size: int):
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        
        print(f"[Annotator] Initializing vLLM engine...")
        
        engine_args = {
            "model": self.model_path,
            "tensor_parallel_size": tensor_parallel_size,
            "seed": self.seed,
            "max_model_len": 16384,
            "mm_processor_kwargs": {
                "min_pixels": 32 * 32,
                "max_pixels": 8192 * 8192,
                "fps": 1,
            },
            "limit_mm_per_prompt": {"image": 1, "video": 0, "audio": 0},
        }
        
        self.llm = LLM(**engine_args)
        self.sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1024
        )
        
        print(f"[Annotator] vLLM engine initialized")
    
    def annotate_batch(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        vllm_inputs = []
        sample_metadata = []
        
        print(f"[Annotator] Building vLLM inputs for {len(samples)} samples...")
        
        for sample in tqdm(samples, desc="Building inputs"):
            image_path = sample.get("image_path", None)
            if not image_path:
                images = sample.get("images") or []
                if images:
                    image_path = images[0]
            
            instruction = sample.get("instruction", "")
            gt_action = sample.get("gt_action", {})
            action_history = sample.get("action_history", "")
            
            image_width = sample.get("image_width", None)
            image_height = sample.get("image_height", None)
            
            if (image_width is None or image_height is None) and image_path and os.path.exists(image_path):
                try:
                    # Image.open only reads the header, so the size is available without decoding pixels.
                    with Image.open(image_path) as probe:
                        if image_width is None:
                            image_width = probe.width
                        if image_height is None:
                            image_height = probe.height
                except Exception as e:
                    print(f"Warning: Cannot read image size for {image_path}: {e}")
            
            if image_width is None:
                image_width = 1080
            if image_height is None:
                image_height = 2400
            
            metadata = {
                "sample": sample,
                "image_width": image_width,
                "image_height": image_height
            }
            
            try:
                if not instruction:
                    vllm_inputs.append({"prompt": ""})
                    metadata["error"] = "Missing instruction field"
                    sample_metadata.append(metadata)
                    continue
                
                if not gt_action:
                    vllm_inputs.append({"prompt": ""})
                    metadata["error"] = "Missing gt_action field"
                    sample_metadata.append(metadata)
                    continue
                
                user_query = build_user_query(instruction, gt_action, image_width, image_height, action_history)
                
                if not image_path:
                    vllm_input = {"prompt": f"{self.system_prompt}\n\n{user_query}"}
                    metadata["error"] = "Image path is empty"
                    vllm_inputs.append(vllm_input)
                    sample_metadata.append(metadata)
                    continue
                
                if not os.path.exists(image_path):
                    vllm_input = {"prompt": f"{self.system_prompt}\n\n{user_query}"}
                    metadata["error"] = f"Image not found: {image_path}"
                    vllm_inputs.append(vllm_input)
                    sample_metadata.append(metadata)
                    continue
                
                # image_path is non-empty and exists at this point (both cases are handled above).
                pil_image = load_image_as_pil(image_path)
                if pil_image is not None:
                    messages = [
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": self.system_prompt}]
                        },
                        {
                            "role": "user",
                            "content": [
                                {"type": "image", "image": "placeholder_0"},
                                {"type": "text", "text": user_query}
                            ]
                        }
                    ]
                    
                    prompt_text = self.processor.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True
                    )
                    
                    image_uuid = f"uuid_{uuid.uuid4().hex[:8]}"
                    vllm_input = {
                        "prompt": prompt_text,
                        "multi_modal_data": {"image": [pil_image]},
                        "multi_modal_uuids": {"image": [image_uuid]},
                    }
                else:
                    vllm_input = {"prompt": f"{self.system_prompt}\n\n{user_query}"}
                    metadata["error"] = "Failed to load image"
                
                vllm_inputs.append(vllm_input)
                sample_metadata.append(metadata)
                
            except Exception as e:
                print(f"[Annotator] Failed to build input: {e}")
                traceback.print_exc()
                vllm_inputs.append({"prompt": ""})
                metadata["error"] = str(e)
                sample_metadata.append(metadata)
        
        # Only samples whose input was built successfully are sent to vLLM. Samples that
        # already carry an error would produce empty or text-only prompts, and an empty
        # prompt can make vLLM reject the entire batch.
        runnable_indices = [i for i, meta in enumerate(sample_metadata) if "error" not in meta]
        all_outputs: List[Any] = [None] * len(vllm_inputs)
        
        print(f"[Annotator] Starting vLLM batch inference for {len(runnable_indices)}/{len(vllm_inputs)} samples, batch size: {self.batch_size}...")
        total_batches = (len(runnable_indices) + self.batch_size - 1) // self.batch_size
        for batch_start in tqdm(range(0, len(runnable_indices), self.batch_size), desc="vLLM inference batches"):
            batch_indices = runnable_indices[batch_start:batch_start + self.batch_size]
            batch_inputs = [vllm_inputs[i] for i in batch_indices]
            batch_num = batch_start // self.batch_size + 1
            print(f"[Annotator] Processing batch {batch_num}/{total_batches}: {len(batch_inputs)} samples")
            
            try:
                batch_outputs = self.llm.generate(batch_inputs, self.sampling_params)
                # LLM.generate returns outputs in the same order as its inputs.
                for i, output in zip(batch_indices, batch_outputs):
                    all_outputs[i] = output
                print(f"[Annotator] Batch {batch_num} completed")
            except Exception as e:
                print(f"[Annotator] Batch {batch_num} failed: {e}")
                traceback.print_exc()
                # Record the failure instead of inserting empty outputs, which would otherwise
                # be parsed into results with an empty Thought/Action and no error flag.
                for i in batch_indices:
                    sample_metadata[i]["error"] = f"vLLM inference failed: {e}"
                print(f"[Annotator] Marked samples of failed batch as errors, continuing...")
        
        print(f"[Annotator] vLLM inference completed, parsing results...")
        
        results = []
        for output, metadata in zip(all_outputs, sample_metadata):
            original_sample = metadata["sample"]
            result = original_sample.copy()
            
            if "error" in metadata:
                result["annotation_error"] = metadata["error"]
                result["output_text"] = ""
                results.append(result)
                continue
            
            try:
                output_text = output.outputs[0].text
                result["output_text"] = output_text
                
                thought = ""
                action = ""
                tool_call = None
                
                if "Thought:" in output_text:
                    thought_part = output_text.split("Thought:")[1].split("Action:")[0].strip() if "Action:" in output_text else output_text.split("Thought:")[1].strip()
                    thought = thought_part.split("\n")[0].strip()
                else:
                    if "Action:" in output_text:
                        thought = output_text.split("Action:")[0].strip()
                        thought = "\n".join([line.strip() for line in thought.split("\n")[:2] if line.strip()])
                    elif "<tool_call>" in output_text:
                        thought = output_text.split("<tool_call>")[0].strip()
                        thought = "\n".join([line.strip() for line in thought.split("\n")[:2] if line.strip()])
                
                if "Action:" in output_text:
                    action_part = output_text.split("Action:")[1].split("<tool_call>")[0].strip() if "<tool_call>" in output_text else output_text.split("Action:")[1].strip()
                    action = action_part.split("\n")[0].strip()
                else:
                    if "<tool_call>" in output_text:
                        before_toolcall = output_text.split("<tool_call>")[0].strip()
                        if "Thought:" in before_toolcall:
                            action_part = before_toolcall.split("Thought:")[-1].split("\n")
                            action_lines = [line.strip() for line in action_part[1:] if line.strip()]
                            if action_lines:
                                action = action_lines[0]
                        else:
                            lines = [line.strip() for line in before_toolcall.split("\n") if line.strip()]
                            if len(lines) > 1:
                                action = lines[-1]
                            elif lines:
                                action = lines[0]
                
                if "<tool_call>" in output_text and "</tool_call>" in output_text:
                    try:
                        tool_call_str = output_text.split("<tool_call>")[1].split("</tool_call>")[0].strip()
                        tool_call = json.loads(tool_call_str)
                    except json.JSONDecodeError as e:
                        print(f"[Annotator] Failed to parse tool_call: {e}")
                        tool_call = None
                
                gt_action = original_sample.get("gt_action", {})
                if gt_action:
                    image_width = metadata.get("image_width", 1080)
                    image_height = metadata.get("image_height", 2400)
                    correct_tool_call_args = convert_gt_action_to_tool_call(gt_action, image_width, image_height)
                    correct_tool_call = {"name": "mobile_use", "arguments": correct_tool_call_args}
                    
                    if isinstance(tool_call, dict) and "arguments" in tool_call:
                        # Any difference in name or arguments means the model altered the required call.
                        if (tool_call.get("name") != correct_tool_call["name"] or
                                tool_call["arguments"] != correct_tool_call_args):
                            print(f"[Annotator] Warning: tool_call doesn't match gt_action, using correct one")
                            tool_call = correct_tool_call
                    else:
                        tool_call = correct_tool_call
                
                result["thought"] = thought
                result["action"] = action
                result["tool_call"] = tool_call
                
                if tool_call and "arguments" in tool_call:
                    result["predicted_qwen3"] = tool_call["arguments"]
                
            except Exception as e:
                print(f"[Annotator] Failed to parse result: {e}")
                result["annotation_error"] = f"Parse failed: {str(e)}"
                result["output_text"] = output.outputs[0].text if hasattr(output, 'outputs') and output.outputs else ""
            
            results.append(result)
        
        return results
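

# Hypothetical alternative (not used by NormalDataAnnotator): a compact, regex-based
# sketch of the Thought / Action / <tool_call> parsing done inline in annotate_batch().
# It assumes the well-formed layout requested by the system prompt and, like the primary
# branches there, keeps only the first line of Thought and Action; the heuristic fallbacks
# for missing labels are intentionally left out.
def parse_thought_action_tool_call(output_text: str) -> dict:
    import json
    import re

    def first_line(pattern: str) -> str:
        match = re.search(pattern, output_text, flags=re.DOTALL)
        if not match:
            return ""
        return match.group(1).strip().split("\n")[0].strip()

    # Thought runs until "Action:" or "<tool_call>"; Action runs until "<tool_call>".
    thought = first_line(r"Thought:(.*?)(?:Action:|<tool_call>|$)")
    action = first_line(r"Action:(.*?)(?:<tool_call>|$)")

    tool_call = None
    match = re.search(r"<tool_call>(.*?)</tool_call>", output_text, flags=re.DOTALL)
    if match:
        try:
            tool_call = json.loads(match.group(1).strip())
        except json.JSONDecodeError:
            tool_call = None
    return {"thought": thought, "action": action, "tool_call": tool_call}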


def fix_answer_tool_call(answer_text: str, gt_action: Dict[str, Any], image_width: int, image_height: int) -> str:
    """Make the <tool_call> block in answer_text agree with the ground-truth action.

    The first <tool_call>...</tool_call> block is replaced when it is not valid JSON or
    differs from the tool call derived from gt_action; if no block exists, the correct
    one is appended. Text before and after the block is preserved.
    """
    if not answer_text or not gt_action:
        return answer_text

    correct_tool_call_args = convert_gt_action_to_tool_call(gt_action, image_width, image_height)
    correct_tool_call = {"name": "mobile_use", "arguments": correct_tool_call_args}
    correct_tool_call_json = json.dumps(correct_tool_call, ensure_ascii=False)
    correct_block = "<tool_call>\n" + correct_tool_call_json + "\n</tool_call>"

    if "<tool_call>" not in answer_text or "</tool_call>" not in answer_text:
        print("[Fix] No tool_call found in answer, appending...")
        return answer_text.rstrip() + "\n" + correct_block

    # partition() splits only at the first occurrence, so everything after the first
    # closing tag is kept (split(...)[1] would drop text after a second </tool_call>)
    before_toolcall, _, rest = answer_text.partition("<tool_call>")
    existing_toolcall_str, _, after_toolcall = rest.partition("</tool_call>")
    existing_toolcall_str = existing_toolcall_str.strip()

    try:
        existing_toolcall = json.loads(existing_toolcall_str)
    except json.JSONDecodeError:
        print("[Fix] Failed to parse existing tool_call, replacing...")
    else:
        if isinstance(existing_toolcall, dict) and existing_toolcall.get("name") == correct_tool_call["name"]:
            # Comparing the full argument dicts also covers the "action" field
            if existing_toolcall.get("arguments", {}) == correct_tool_call_args:
                return answer_text
            print("[Fix] Tool_call mismatch detected, fixing...")
            print(f"  Existing: {existing_toolcall_str[:100]}...")
            print(f"  Correct: {correct_tool_call_json[:100]}...")
        else:
            print("[Fix] Tool_call name mismatch, fixing...")

    return before_toolcall + correct_block + after_toolcall

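
# Hypothetical helper (not called by main() below): a compact sketch of the
# filter-then-limit pattern main() uses. Samples with a null gt_action are dropped,
# but each kept sample remembers its position in the original list so its annotation
# can be written back to the right entry afterwards. Unlike main(), this sketch treats
# max_samples=0 as "keep nothing" rather than "no limit".
def select_valid_samples(
    data: List[Dict[str, Any]], max_samples: Optional[int] = None
) -> "tuple[List[int], List[Dict[str, Any]]]":
    valid_indices = [idx for idx, sample in enumerate(data) if sample.get("gt_action") is not None]
    if max_samples is not None:
        valid_indices = valid_indices[:max_samples]
    # Return the original positions alongside the samples they point to
    return valid_indices, [data[idx] for idx in valid_indices]
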

def main():
    parser = argparse.ArgumentParser(description='Annotate normal data using vLLM batch processing')
    parser.add_argument('--input_file', type=str, default="/INPUT_FILE",
                       help='Input file path')
    parser.add_argument('--model_path', type=str, default="/MODEL_PATH",
                       help='Model path')
    parser.add_argument('--batch_size', type=int, default=512,
                       help='Batch size')
    parser.add_argument('--device_ids', type=str, default="[0,1,2,3,4,5,6,7]",
                       help='CUDA device IDs as a list string, e.g., "[0]" or "[0,1]"')
    parser.add_argument('--tensor_parallel_size', type=int, default=8,
                       help='Tensor parallel size (default: 8, i.e. one shard per device in the default --device_ids)')
    parser.add_argument('--seed', type=int, default=42,
                       help='Random seed')
    parser.add_argument('--max_samples', type=int, default=None,
                       help='Maximum number of samples to process (for testing, None means all samples)')
    
    args = parser.parse_args()
    
    print(f"[Load] Loading data: {args.input_file}")
    with open(args.input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"[Load] Number of samples: {len(data)}")
    
    # Shallow copy of every sample; annotations are written back into these copies
    original_data = [dict(sample) for sample in data]
    
    valid_indices = []
    filtered_data = []
    for idx, sample in enumerate(data):
        if sample.get("gt_action") is not None:
            valid_indices.append(idx)
            filtered_data.append(sample)
    
    original_count = len(data)
    filtered_count = len(filtered_data)
    print(f"[Load] Filtered {original_count - filtered_count} samples with null gt_action, remaining: {filtered_count}")
    
    if args.max_samples is not None:
        valid_indices = valid_indices[:args.max_samples]
        filtered_data = filtered_data[:args.max_samples]
        print(f"[Load] Limited to first {len(filtered_data)} samples")
    
    if len(filtered_data) == 0:
        print(f"[Warning] No valid samples to process")
        return
    
    print(f"[Init] Initializing annotator...")
    annotator = NormalDataAnnotator(
        model_path=args.model_path,
        batch_size=args.batch_size,
        device_ids=args.device_ids,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed
    )
    
    print(f"[Annotate] Starting batch annotation...")
    results = annotator.annotate_batch(filtered_data)
    
    updated_count = 0
    fixed_count = 0
    for idx, result in enumerate(results):
        original_idx = valid_indices[idx]
        if original_idx < len(original_data):
            if "output_text" in result and result["output_text"]:
                answer_text = result["output_text"]
                gt_action = original_data[original_idx].get("gt_action")
                
                source_sample = original_data[original_idx]
                image_width = source_sample.get("image_width")
                image_height = source_sample.get("image_height")

                # Fall back to reading the image itself when the sample lacks size metadata
                if image_width is None or image_height is None:
                    image_path = source_sample.get("image_path")
                    images = source_sample.get("images") or []
                    if not image_path and images:
                        image_path = images[0]

                    if image_path and os.path.exists(image_path):
                        try:
                            pil_image = load_image_as_pil(image_path)
                            if pil_image is not None:
                                if image_width is None:
                                    image_width = pil_image.width
                                if image_height is None:
                                    image_height = pil_image.height
                        except Exception as e:
                            print(f"[Warning] Could not read image size from {image_path}: {e}")

                # Last resort: assume a 1080x2400 screen, the same default the annotator uses
                if image_width is None:
                    image_width = 1080
                if image_height is None:
                    image_height = 2400
                
                if gt_action:
                    fixed_answer = fix_answer_tool_call(answer_text, gt_action, image_width, image_height)
                    if fixed_answer != answer_text:
                        fixed_count += 1
                    original_data[original_idx]["answer"] = fixed_answer
                else:
                    original_data[original_idx]["answer"] = answer_text
                
                updated_count += 1
            elif "annotation_error" in result:
                print(f"[Warning] Sample {original_idx} has annotation error: {result['annotation_error']}")
    
    # Strip only the trailing ".json" (str.replace would also rewrite ".json" elsewhere in the path)
    if args.input_file.endswith('.json'):
        output_file = args.input_file[:-len('.json')] + '_annotated.json'
    else:
        output_file = args.input_file + '_annotated'
    
    print(f"[Save] Saving results to: {output_file}")
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(original_data, f, ensure_ascii=False, indent=2)
    
    total = len(results)
    success_count = sum(1 for r in results if "annotation_error" not in r and r.get("output_text"))
    error_count = sum(1 for r in results if "annotation_error" in r)
    denom = max(total, 1)  # avoid division by zero if the annotator returns no results

    print(f"\n{'='*60}")
    print("Annotation completed!")
    print(f"{'='*60}")
    print(f"Total samples: {total}")
    print(f"Successfully annotated: {success_count} ({success_count/denom*100:.2f}%)")
    print(f"Failed: {error_count} ({error_count/denom*100:.2f}%)")
    print(f"Updated answer fields: {updated_count}")
    print(f"Fixed tool_call alignment: {fixed_count}")
    print(f"Output file: {output_file}")


if __name__ == "__main__":
    main()

