from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, AutoProcessor
from typing import Dict, Any, Union
from trl.data_utils import maybe_apply_chat_template
import torch
from copy import deepcopy
from open_r1.vlm_modules.vlm_module import VLMBaseModule
from PIL import Image

class Qwen2VLModule(VLMBaseModule):
    def __init__(self):
        """Create the Qwen VL adapter; all initialisation is delegated to the base module."""
        super(Qwen2VLModule, self).__init__()

    def get_vlm_key(self):
        return "qwen"

    def get_model_class(self, model_id: str, model_init_kwargs: dict):
        if "Qwen2-VL" in model_id:
            model_cls = Qwen2VLForConditionalGeneration
        elif "Qwen2.5-VL" in model_id:
            model_cls = Qwen2_5_VLForConditionalGeneration
        else:
            raise ValueError(f"Unsupported model: {model_id}")
        return model_cls
    
    def post_model_init(self, model, processing_class):
        pass
    
    def get_processing_class(self):
        """Return ``AutoProcessor`` as the processing class for Qwen VL checkpoints."""
        processor_cls = AutoProcessor
        return processor_cls
    
    def get_vision_modules_keywords(self):  
        return ['visual']
    
    def get_custom_multimodal_keywords(self):
        return ['pixel_values', 'image_grid_thw']

    def get_non_generate_params(self):
        return []
    
    def get_custom_processing_keywords(self):
        return [('image_processor', 'max_pixels'), ('image_processor', 'min_pixels')]
    
    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
        prompts_text = [maybe_apply_chat_template(example, processing_class)["prompt"] for example in inputs]
        return prompts_text
    
    def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
        # FIXME
        # This could only process pure-multimodal or pure-text inputs
        additional_output = None
        if len(images) > 0:
            prompt_inputs = processing_class(
                text=prompts_text,
                images=images,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens)
            additional_output = [{'image_grid_thw': image_grid_thw} for image_grid_thw in prompt_inputs['image_grid_thw']]
        else:
            prompt_inputs = processing_class(
                text=prompts_text,
                return_tensors=return_tensors,
                padding=padding,
                padding_side=padding_side,
                add_special_tokens=add_special_tokens)
        return prompt_inputs, additional_output
    
    @staticmethod
    def get_question_template(task_type: str) -> str:
        """Return the prompt template string for ``task_type``.

        The returned template contains named placeholders for the caller to fill.
        Presumably this is done with ``str.format``; the doubled ``{{``/``}}`` braces
        in the Android templates would then render as literal JSON braces. Confirm
        against the caller.

        Placeholders by task type:

        - ``"rec"``, ``"ic"``, ``"odLength"`` and any unrecognised type: ``{Question}``
        - ``"android"``, ``"android_tool_call_only"``, ``"android_high_level"``:
          ``{Instruction}``
        - ``"android_hierarchical"``, ``"android_hierarchical_no_thinking"``:
          ``{Task}``, ``{ActionHistory}`` and ``{Instruction}``

        The Android system prompts are triple-quoted literals. Their leading
        indentation, blank lines and trailing spaces are part of the prompt text, so
        any edit to them (even re-indenting) changes what the model sees.

        Args:
            task_type: Name of the task whose template is requested.

        Returns:
            The template string.
        """
        if task_type == "rec":
            # Reasoning in <think>, JSON-formatted final answer in <answer>.
            return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."
        elif task_type == "ic":
            return "{Question} First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> json format answer here </answer>"
        elif task_type == "odLength":
            # Unlike "rec"/"ic", the format instructions come before the question,
            # separated from it by a newline.
            SYSTEM_PROMPT = (
                #"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
                "First thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
                "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
                "<think> reasoning process here </think><answer> answer here </answer>"
            )
            return SYSTEM_PROMPT + '\n' + "{Question}"
        elif task_type == "android":
            # Think-then-act: reasoning in <think>, one JSON action in <tool_call>.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant. You need to operate the mobile phone based on the instruction.
                You will receive an instruction and a screenshot of the mobile phone screen.
                You should first think about the reasoning process in your mind, and then output the final answer in the form of a tool call. 
                The reasoning process and answer are enclosed within <think> </think> and <tool_call> </tool_call> tags, respectively, i.e., <think> reasoning process here </think><tool_call> JSON format answer here </tool_call>
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the instruction):
                - When the instruction is related to "long press" such as "long press the backspace key" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>.
                - When the instruction is related to "click" or "tap" such as "press the backspace key" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>.
                - When the instruction is related to "swipe" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "swipe", "coordinate": [xxx, xx], "coordinate2": [xxx, xxx]}}</tool_call>
                - When the instruction is related to "Back button" "Home button" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - When the instruction is related to "type" such as "Type the text '307.1' in the field", the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "type", "text": "307.1"}}</tool_call>
                - When the task is completed, the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
                """
            )
            return SYSTEM_PROMPT + '\n' + "{Instruction}"
        elif task_type == "android_tool_call_only":
            # Same action set as "android", but the answer is a bare <tool_call> with no <think>.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant. You need to operate the mobile phone based on the instruction.
                You will receive an instruction and a screenshot of the mobile phone screen.
                You should directly output the final answer in the form of a tool call without any thinking process.
                The answer should be enclosed within <tool_call> </tool_call> tags, i.e., <tool_call> JSON format answer here </tool_call>
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the instruction):
                - When the instruction is related to "long press" such as "long press the backspace key" , the output should be like:
                <tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>.
                - When the instruction is related to "click" or "tap" such as "press the backspace key" , the output should be like:
                <tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>.
                - When the instruction is related to "swipe" , the output should be like:
                <tool_call>{{"action": "swipe", "coordinate": [xxx, xx], "coordinate2": [xxx, xxx]}}</tool_call>
                - When the instruction is related to "Back button" "Home button" , the output should be like:
                <tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - When the instruction is related to "type" such as "Type the text '307.1' in the field", the output should be like:
                <tool_call>{{"action": "type", "text": "307.1"}}</tool_call>
                - When the task is completed, the output should be like:
                <tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
                """
            )
            return SYSTEM_PROMPT + '\n' + "{Instruction}"
        elif task_type == "android_hierarchical":
            # Multi-step variant: the user turn also carries the overall task and the
            # action history, and the model thinks before emitting a <tool_call>.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant that can operate mobile phones to complete tasks.
                You will receive:
                1. A high-level task description
                2. A screenshot of the current mobile phone screen  
                3. A history of previous actions taken (if any)
                
                Based on the current situation, you should analyze what needs to be done next to progress toward completing the overall task.
                You should first think about the reasoning process in your mind, considering the task goal, current screen state, and what actions have been taken so far.
                Then output the next specific action in the form of a tool call.
                
                The reasoning process and answer are enclosed within <think> </think> and <tool_call> </tool_call> tags, respectively, i.e., <think> reasoning process here </think><tool_call> JSON format answer here </tool_call>
                
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the current situation):
                - For click/tap actions, the output should be like:
                <think> I need to analyze the current screen and determine where to click based on the task goal and previous actions. </think><tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>
                - For long press actions, the output should be like:
                <think> Based on the task requirements, I need to long press on this element. </think><tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>
                - For swipe actions, the output should be like:
                <think> I need to swipe to navigate or scroll to find the required element. </think><tool_call>{{"action": "swipe", "coordinate": [xxx, xxx], "coordinate2": [xxx, xxx]}}</tool_call>
                - For system button actions, the output should be like:
                <think> I need to use a system button to navigate. </think><tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - For text input actions, the output should be like:
                <think> I need to type specific text to complete this step of the task. </think><tool_call>{{"action": "type", "text": "specific text"}}</tool_call>
                - When the task is completed, the output should be like:
                <think> The task has been successfully completed based on all the actions taken. </think><tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
                """
            )
            return SYSTEM_PROMPT + '\n' + "Task: {Task}\nAction History: {ActionHistory}\nCurrent Instruction: {Instruction}"
        elif task_type == "android_hierarchical_no_thinking":
            # Same inputs as "android_hierarchical", but the answer is a bare <tool_call>.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant that can operate mobile phones to complete tasks.
                You will receive:
                1. A high-level task description
                2. A screenshot of the current mobile phone screen  
                3. A history of previous actions taken (if any)
                
                Based on the current situation, you should analyze what needs to be done next to progress toward completing the overall task.
                You should directly output the next specific action in the form of a tool call without any thinking process.
                
                The answer should be enclosed within <tool_call> </tool_call> tags, i.e., <tool_call> JSON format answer here </tool_call>
                
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the current situation):
                - For click/tap actions, the output should be like:
                <tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>
                - For long press actions, the output should be like:
                <tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>
                - For swipe actions, the output should be like:
                <tool_call>{{"action": "swipe", "coordinate": [xxx, xxx], "coordinate2": [xxx, xxx]}}</tool_call>
                - For system button actions, the output should be like:
                <tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - For text input actions, the output should be like:
                <tool_call>{{"action": "type", "text": "specific text"}}</tool_call>
                - When the task is completed, the output should be like:
                <tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
                """
            )
            return SYSTEM_PROMPT + '\n' + "Task: {Task}\nAction History: {ActionHistory}\nCurrent Instruction: {Instruction}"
        elif task_type == "android_high_level":
            # Planner variant: output a natural-language next step in <instruction_step>
            # (no coordinates) after <think> reasoning.
            # NOTE(review): "instuction" in the prompt below is a typo. It is left
            # unchanged because editing it changes the prompt text the model sees.
            SYSTEM_PROMPT = (
                """
                You are a mobile operation Agent that performs precise screen interactions. Analyze the input and generate the next action instuction. 
                You should only output concise and clear action instructions, which can include action types and action targets, without specific coordinates.
                The reasoning process and answer are enclosed within <think> </think> and <instruction_step> </instruction_step> tags, respectively, i.e., <think> reasoning process here </think><instruction_step> answer here </instruction_step>
                """
            )
            return SYSTEM_PROMPT + '\n' + "{Instruction}"
        else:
            # Fallback for any other task type: generic think/answer format.
            return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
            
    @staticmethod
    def format_reward_rec(completions, **kwargs):
        """Check if the Qwen model output matches a specific format."""
        import re
        import os
        from datetime import datetime
        pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]

        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            if log_path is not None:
                with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} Format reward -------------\n")
                    for content, match in zip(completion_contents, matches):
                        f.write(f"Content: {content}\n")
                        f.write(f"Has format: {bool(match)}\n")
        return [1.0 if match else 0.0 for match in matches]
    
    @staticmethod
    def format_reward_android(completions, **kwargs):
        """Check if the Android model output matches the required format. 
        Supports three modes:
        - android: <think>...</think><tool_call>...</tool_call>
        - android_tool_call_only: <tool_call>...</tool_call>
        - android_hierarchical: <think>...</think><tool_call>...</tool_call>
        """
        import re
        import json
        import os
        from datetime import datetime
        
        completion_contents = [completion[0]["content"] for completion in completions]
        rewards = []
        
        # Get task type from kwargs to determine the expected format
        task_type = kwargs.get("task_type", "android")
        
        if task_type in ["android_tool_call_only", "android_hierarchical_no_thinking"]:
            # Pattern for tool_call only modes: <tool_call>...</tool_call>
            pattern = r"<tool_call>(.*?)</tool_call>"
        elif task_type == "android_high_level":
            # Pattern for high-level mode: <reasoning>...</reasoning><instruction>Instruction: ...</instruction>
            # pattern = r"<reasoning>.*?</reasoning>\s*<instruction>\s*Instruction:\s*(.*?)\s*</instruction>"
            pattern = r"<think>.*?</think>\s*<instruction_step>(.*?)</instruction_step>"
        elif task_type in ["android", "android_hierarchical"]:
            # Pattern for android and android_hierarchical modes: <think>...</think><tool_call>...</tool_call>
            pattern = r"<think>.*?</think>\s*<tool_call>(.*?)</tool_call>"
        else:
            # Default pattern for android mode: <think>...</think><tool_call>...</tool_call>
            pattern = r"<think>.*?</think>\s*<tool_call>(.*?)</tool_call>"
        
        for i, content in enumerate(completion_contents):
            reward = 0.0
            
            try:
                if task_type == "android_high_level":
                    # For high-level mode, use graded reward system (no strict pattern matching required)
                    reward = 0.0
                    
                    # think 部分奖励  
                    if "<think>" in content:
                        reward += 0.2
                    elif "(think)" in content:
                        reward += 0.1
                    elif "think" in content:
                        reward += 0.05
                        
                    if "</think>" in content:
                        reward += 0.2
                    elif "(/think)" in content:
                        reward += 0.1
                    elif "/think" in content:
                        reward += 0.05

                    # instruction_step 部分奖励
                    if "<instruction_step>" in content:
                        reward += 0.3
                    elif "(instruction_step)" in content:
                        reward += 0.15
                    elif "instruction_step" in content:
                        reward += 0.075
                        
                    if "</instruction_step>" in content:
                        reward += 0.3
                    elif "(/instruction_step)" in content:
                        reward += 0.15
                    elif "/instruction_step" in content:
                        reward += 0.075
                else:
                    # For other modes, use strict pattern matching
                    match = re.search(pattern, content, re.DOTALL)
                    if match:
                        content_inside_tags = match.group(1).strip()
                        
                        # Try to parse the tool_call content as JSON for other Android modes
                        try:
                            tool_data = json.loads(content_inside_tags)
                            
                            # Check if it has the required "action" field
                            if "action" in tool_data:
                                action = tool_data["action"]
                                
                                # Validate based on action type
                                if action in ["click", "long_press"]:
                                    # Should have coordinate field with [x, y] format
                                    if "coordinate" in tool_data and isinstance(tool_data["coordinate"], list) and len(tool_data["coordinate"]) == 2:
                                        reward = 1.0
                                elif action == "swipe":
                                    # Should have coordinate and coordinate2 fields
                                    if ("coordinate" in tool_data and isinstance(tool_data["coordinate"], list) and len(tool_data["coordinate"]) == 2 and
                                        "coordinate2" in tool_data and isinstance(tool_data["coordinate2"], list) and len(tool_data["coordinate2"]) == 2):
                                        reward = 1.0
                                elif action == "system_button":
                                    # Should have button field
                                    if "button" in tool_data and isinstance(tool_data["button"], str):
                                        reward = 1.0
                                elif action == "type":
                                    # Should have text field
                                    if "text" in tool_data and isinstance(tool_data["text"], str):
                                        reward = 1.0
                                elif action == "terminate":
                                    # Should have status field
                                    if "status" in tool_data and isinstance(tool_data["status"], str):
                                        reward = 1.0
                                # Add more action types as needed
                                
                        except json.JSONDecodeError:
                            # Invalid JSON in tool_call
                            reward = 0.0
                        
            except Exception as e:
                reward = 0.0
            
            rewards.append(reward)
            
            # Debug logging if enabled
            # ============================DEBUG MODE====================================================
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                if log_path is not None:
                    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                    format_log_path = log_path.replace(".txt", "_format.txt")
                    
                    # Build the complete log entry as a single string for atomic writing
                    log_entry_parts = [
                        f"------------- {current_time} Android Format Reward: {reward} -------------\n",
                        f"Sample {i}: \n"
                    ]
                
                # Check what went wrong if reward is 0
                if reward == 0.0:
                    log_entry_parts.append("❌ Format validation failed:\n")
                    
                    # Check pattern match based on task type
                    if task_type in ["android_tool_call_only", "android_hierarchical_no_thinking"]:
                        expected_format = "<tool_call>...</tool_call>"
                        tool_only_pattern = r"<tool_call>(.*?)</tool_call>"
                        match = re.search(tool_only_pattern, content, re.DOTALL)
                        if not match:
                            log_entry_parts.append(f"   - Missing {expected_format} structure\n")
                        else:
                            tool_call_content = match.group(1).strip()
                            log_entry_parts.append(f"   - Found tool_call content: {tool_call_content[:100]}...\n")
                    elif task_type == "android_high_level":
                        expected_format = "<reasoning>...</reasoning><instruction>...</instruction>"
                        high_level_pattern = r"<reasoning>.*?</reasoning>\s*<instruction>(.*?)</instruction>"
                        match = re.search(high_level_pattern, content, re.DOTALL)
                        if not match:
                            log_entry_parts.append(f"   - Missing {expected_format} structure\n")
                        else:
                            instruction_content = match.group(1).strip()
                            log_entry_parts.append(f"   - Found instruction content: {instruction_content[:100]}...\n")
                    else:
                        expected_format = "<think>...</think><tool_call>...</tool_call>"
                        think_tool_pattern = r"<think>.*?</think>\s*<tool_call>(.*?)</tool_call>"
                        match = re.search(think_tool_pattern, content, re.DOTALL)
                        if not match:
                            log_entry_parts.append(f"   - Missing {expected_format} structure\n")
                        else:
                            tool_call_content = match.group(1).strip()
                            log_entry_parts.append(f"   - Found tool_call content: {tool_call_content[:100]}...\n")
                    
                    if match and task_type != "android_high_level":
                        tool_call_content = match.group(1).strip()
                        log_entry_parts.append(f"   - Found tool_call content: {tool_call_content[:100]}...\n")
                        
                        try:
                            tool_data = json.loads(tool_call_content)
                            log_entry_parts.append(f"   - JSON parsing successful\n")
                            log_entry_parts.append(f"   - Found action: {tool_data.get('action', 'MISSING')}\n")
                            
                            # Check specific action validation
                            action = tool_data.get("action")
                            if action in ["click", "long_press"]:
                                coord = tool_data.get("coordinate", [])
                                log_entry_parts.append(f"   - Click/LongPress validation: coordinate={coord}, valid={isinstance(coord, list) and len(coord) == 2}\n")
                            elif action == "swipe":
                                coord1 = tool_data.get("coordinate", [])
                                coord2 = tool_data.get("coordinate2", [])
                                log_entry_parts.append(f"   - Swipe validation: coord1={coord1}, coord2={coord2}\n")
                            elif action == "type":
                                text = tool_data.get("text", "")
                                log_entry_parts.append(f"   - Type validation: text='{text}', valid={isinstance(text, str)}\n")
                                
                        except json.JSONDecodeError as e:
                            log_entry_parts.append(f"   - JSON parsing failed: {e}\n")
                else:
                    log_entry_parts.append("✅ Format validation passed\n")
                
                if task_type in ["android_tool_call_only", "android_hierarchical_no_thinking"]:
                    log_entry_parts.append(f"Expected format: <tool_call>{{valid_json}}</tool_call>\n")
                elif task_type == "android_high_level":
                    log_entry_parts.append(f"Expected format: <reasoning>...</reasoning><instruction>Instruction: ...</instruction>\n")
                else:
                    log_entry_parts.append(f"Expected format: <think>...</think><tool_call>{{valid_json}}</tool_call>\n")
                # Record complete content without truncation
                log_entry_parts.append("=" * 40 + " COMPLETE CONTENT " + "=" * 40 + "\n")
                log_entry_parts.append(f"Content: {content}\n")  # Record complete content
                log_entry_parts.append("=" * 80 + "\n")
                
                # Atomic write: single write operation to avoid race conditions
                complete_log_entry = "".join(log_entry_parts)
                with open(format_log_path, "a", encoding='utf-8') as f:
                    f.write(complete_log_entry)
            # ============================DEBUG MODE====================================================
        
        return rewards
    
    @staticmethod
    def iou_reward(completions, solution, **kwargs):
        """Reward each completion with the IoU of its predicted box against the ground truth.

        The first ``[x1, y1, x2, y2]`` integer list inside the completion's first
        ``<answer>...</answer>`` block is read in the model-input pixel space. It is
        rescaled to the original image size and compared with the JSON box in the last
        ``<answer>...</answer>`` block of the matching ``solution`` string.

        Args:
            completions: Per-sample completions; ``completion[0]["content"]`` is the text.
            solution: Per-sample ground-truth strings holding the box as JSON inside
                ``<answer>`` tags.
            **kwargs: ``image_grid_thw`` and ``image_path`` are required and indexed per
                sample; the first path of each sample's ``image_path`` entry is opened.
                ``problem`` is read only for debug logging.

        Returns:
            A list of float rewards. A completion scores 0.0 when no box can be parsed
            or the IoU cannot be computed.

        Raises:
            IndexError / json.JSONDecodeError: if a ``solution`` entry has no parsable
            ``<answer>`` box; ground truth is expected to be well formed.
        """
        import re
        import os
        from datetime import datetime
        import json

        def iou(box1, box2):
            # NOTE(review): the intersection uses an inclusive (+1) pixel convention
            # while the areas in the union do not. This is kept as-is so reward values
            # stay comparable with earlier runs.
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2]-1, box2[2]-1)
            inter_y2 = min(box1[3]-1, box2[3]-1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
            else:
                inter = 0
            union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
            return float(inter)/union

        def resize_bbox(bbox, input_height, input_width, image_height, image_width):
            """Scale ``bbox`` in place from model-input pixels to original-image pixels."""
            bbox[0] = bbox[0] / input_width * image_width
            bbox[1] = bbox[1] / input_height * image_height
            bbox[2] = bbox[2] / input_width * image_width
            bbox[3] = bbox[3] / input_height * image_height
            return bbox

        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
        contents = [completion[0]["content"] for completion in completions]
        # Debug logging needs both DEBUG_MODE=true and a LOG_PATH; without LOG_PATH it
        # is skipped instead of failing on open(None).
        log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
        rewards = []

        for i, (content, sol) in enumerate(zip(contents, solution)):
            image_grid_thw = kwargs.get("image_grid_thw")[i]
            image_path = kwargs.get("image_path")[i][0]
            # Only the size is needed. The context manager releases the file handle
            # that Image.open keeps open, which would otherwise leak once per sample.
            with Image.open(image_path) as image:
                image_width, image_height = image.size
            # assumes image_grid_thw is (t, h, w) counted in 14-pixel patches — TODO confirm
            input_height = int(image_grid_thw[1]*14)
            input_width = int(image_grid_thw[2]*14)

            sol = re.findall(answer_tag_pattern, sol, re.DOTALL)[-1]
            sol = json.loads(sol.strip())
            reward = 0.0
            try:
                content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                if content_answer_match:
                    content_answer = content_answer_match.group(1).strip()
                    bbox_match = re.search(bbox_pattern, content_answer)
                    if bbox_match:
                        bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                        bbox = resize_bbox(bbox, input_height, input_width, image_height, image_width)
                        reward = iou(bbox, sol)
            except Exception:
                # Best effort: malformed predictions (e.g. a zero-area union) score 0.0.
                pass

            rewards.append(reward)
            if log_path is not None:
                current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                logged_image_path = kwargs.get("image_path")[i] if "image_path" in kwargs else None
                problem = kwargs["problem"][i] if kwargs.get("problem") is not None else None
                if reward <= 1.0:  # always true for IoU; tighten this to filter debug output
                    with open(log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                        f.write(f"image_path: {logged_image_path}\n")
                        f.write(f"problem: {problem}\n")
                        f.write(f"Content: {content}\n")
                        f.write(f"Solution: {sol}\n")
        return rewards
    
    @staticmethod
    def android_reward(completions, solution, **kwargs):
        """Calculate reward for Android task considering both action type and action precision."""
        import re
        import json
        import os
        from datetime import datetime
        import math
        
        def calculate_coordinate_similarity(pred_coord, true_coord, image_width=1080, image_height=1920):
            """Calculate coordinate similarity using normalized euclidean distance."""
            if not (isinstance(pred_coord, list) and isinstance(true_coord, list) and 
                   len(pred_coord) == 2 and len(true_coord) == 2):
                return 0.0
            
            # Normalize coordinates by image dimensions
            pred_x_norm = pred_coord[0] / image_width
            pred_y_norm = pred_coord[1] / image_height
            true_x_norm = true_coord[0] / image_width
            true_y_norm = true_coord[1] / image_height
            
            # Calculate euclidean distance
            distance = math.sqrt((pred_x_norm - true_x_norm)**2 + (pred_y_norm - true_y_norm)**2)
            
            # Convert distance to similarity (closer = higher reward)
            # Use exponential decay for smoother reward curve
            similarity = math.exp(-distance * 5)  # 5 is a scaling factor
            return similarity
        
        def calculate_text_similarity(pred_text, true_text):
            """Calculate text similarity using exact match and partial match."""
            
            # Exact match gets full reward
            if pred_text.strip() == true_text.strip():
                return 1.0
            else:
                return 0.0
        

        
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
        
        for i, (content, sol) in enumerate(zip(contents, solution)):
            reward = 0.0
            
            try:
                # Parse ground truth solution
                sol_match = re.search(tool_call_pattern, sol, re.DOTALL)
                if not sol_match:
                    rewards.append(0.0)
                    continue
                
                # Ground truth format: {"name": "mobile_use", "arguments": {"action": "click", "coordinate": [x, y]}}
                true_tool_call_raw = json.loads(sol_match.group(1).strip())
                if "arguments" in true_tool_call_raw:
                    true_tool_call = true_tool_call_raw["arguments"]
                else:
                    true_tool_call = true_tool_call_raw
                true_action = true_tool_call.get("action", "")
                
                # Parse predicted tool call
                pred_match = re.search(tool_call_pattern, content, re.DOTALL)
                if not pred_match:
                    rewards.append(0.0)
                    continue
                
                # Predicted content might have the "arguments" wrapper or direct format
                pred_tool_call_raw = json.loads(pred_match.group(1).strip())
                if "arguments" in pred_tool_call_raw:
                    pred_tool_call = pred_tool_call_raw["arguments"]
                else:
                    pred_tool_call = pred_tool_call_raw
                pred_action = pred_tool_call.get("action", "")
                
                # Action type matching (base reward)
                if pred_action != true_action:
                    reward = 0.0
                else:
                    # Base reward for correct action type
                    action_reward = 0.3
                    
                    # Calculate precision reward based on action type
                    precision_reward = 0.0
                    
                    if pred_action in ["click", "long_press"]:
                        # Coordinate-based actions
                        pred_coord = pred_tool_call.get("coordinate", [])
                        true_coord = true_tool_call.get("coordinate", [])
                        precision_reward = calculate_coordinate_similarity(pred_coord, true_coord) * 0.7
                        
                    elif pred_action == "swipe":
                        # Swipe action with two coordinates
                        pred_coord1 = pred_tool_call.get("coordinate", [])
                        true_coord1 = true_tool_call.get("coordinate", [])
                        pred_coord2 = pred_tool_call.get("coordinate2", [])
                        true_coord2 = true_tool_call.get("coordinate2", [])
                        
                        coord1_sim = calculate_coordinate_similarity(pred_coord1, true_coord1)
                        coord2_sim = calculate_coordinate_similarity(pred_coord2, true_coord2)
                        precision_reward = (coord1_sim + coord2_sim) / 2 * 0.7
                        
                    elif pred_action == "type":
                        # Text input action
                        pred_text = pred_tool_call.get("text", "")
                        true_text = true_tool_call.get("text", "")
                        precision_reward = calculate_text_similarity(pred_text, true_text) * 0.7
                        
                    elif pred_action == "system_button":
                        # System button action (exact match required)
                        pred_button = pred_tool_call.get("button", "")
                        true_button = true_tool_call.get("button", "")
                        precision_reward = 0.7 if pred_button == true_button else 0.0
                        
                    elif pred_action == "terminate":
                        # Terminate action
                        pred_status = pred_tool_call.get("status", "")
                        true_status = true_tool_call.get("status", "")
                        precision_reward = 0.7 if pred_status == true_status else 0.0
                        
                    # Total reward = action_reward + precision_reward
                    reward = action_reward + precision_reward
                    
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                reward = 0.0
            except Exception as e:
                reward = 0.0
            
            rewards.append(reward)
            
            # Debug logging
            # ============================DEBUG MODE====================================================
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                android_log_path = log_path.replace(".txt", "_android.txt")
                
                # Build the complete log entry as a single string for atomic writing
                log_entry_parts = [
                    f"------------- {current_time} Android Accuracy Reward: {reward:.3f} -------------\n",
                    f"Sample {i}: \n"
                ]
                
                # Initialize default values to avoid reference errors
                pred_action = "UNKNOWN"
                true_action = "UNKNOWN"
                pred_tool_call = {}
                true_tool_call = {}
                
                # Try to extract more debugging info
                try:
                    # Parse ground truth solution
                    sol_match = re.search(tool_call_pattern, sol, re.DOTALL)
                    if sol_match:
                        true_tool_call_raw = json.loads(sol_match.group(1).strip())
                        if "arguments" in true_tool_call_raw:
                            true_tool_call = true_tool_call_raw["arguments"]
                        else:
                            true_tool_call = true_tool_call_raw
                        true_action = true_tool_call.get("action", "UNKNOWN")
                    
                    # Parse predicted tool call
                    pred_match = re.search(tool_call_pattern, content, re.DOTALL)
                    if pred_match:
                        pred_tool_call_raw = json.loads(pred_match.group(1).strip())
                        if "arguments" in pred_tool_call_raw:
                            pred_tool_call = pred_tool_call_raw["arguments"]
                        else:
                            pred_tool_call = pred_tool_call_raw
                        pred_action = pred_tool_call.get("action", "UNKNOWN")
                        
                    log_entry_parts.append(f"Predicted action: {pred_action}\n")
                    log_entry_parts.append(f"True action: {true_action}\n")
                    
                    if pred_action in ["click", "long_press"]:
                        pred_coord = pred_tool_call.get("coordinate", [])
                        true_coord = true_tool_call.get("coordinate", []) if true_tool_call else []
                        log_entry_parts.append(f"Predicted coordinate: {pred_coord}\n")
                        log_entry_parts.append(f"True coordinate: {true_coord}\n")
                        if len(pred_coord) == 2 and len(true_coord) == 2:
                            coord_sim = calculate_coordinate_similarity(pred_coord, true_coord)
                            log_entry_parts.append(f"Coordinate similarity: {coord_sim:.3f}\n")
                            
                    elif pred_action == "swipe":
                        pred_coord1 = pred_tool_call.get("coordinate", [])
                        true_coord1 = true_tool_call.get("coordinate", []) if true_tool_call else []
                        pred_coord2 = pred_tool_call.get("coordinate2", [])
                        true_coord2 = true_tool_call.get("coordinate2", []) if true_tool_call else []
                        log_entry_parts.append(f"Predicted start: {pred_coord1}, end: {pred_coord2}\n")
                        log_entry_parts.append(f"True start: {true_coord1}, end: {true_coord2}\n")
                        
                    elif pred_action == "type":
                        pred_text = pred_tool_call.get("text", "")
                        true_text = true_tool_call.get("text", "") if true_tool_call else ""
                        log_entry_parts.append(f"Predicted text: '{pred_text}'\n")
                        log_entry_parts.append(f"True text: '{true_text}'\n")
                        
                except Exception as debug_e:
                    log_entry_parts.append(f"Debug info extraction failed: {debug_e}\n")
                    log_entry_parts.append(f"Raw solution: {sol}\n")
                    log_entry_parts.append(f"Raw content preview: {content[:500]}...\n")
                
                # Record complete content without truncation
                log_entry_parts.append("=" * 40 + " COMPLETE CONTENT " + "=" * 40 + "\n")
                log_entry_parts.append(f"Content: {content}\n")
                log_entry_parts.append("=" * 40 + " COMPLETE SOLUTION " + "=" * 40 + "\n")
                log_entry_parts.append(f"Solution: {sol}\n")
                log_entry_parts.append("=" * 80 + "\n")
                
                # Atomic write: single write operation to avoid race conditions
                complete_log_entry = "".join(log_entry_parts)
                with open(android_log_path, "a", encoding='utf-8') as f:
                    f.write(complete_log_entry)
            # ============================DEBUG MODE====================================================
        
        return rewards

    @staticmethod
    def high_level_content_reward(completions, solution, **kwargs):
        """Calculate content reward for high-level instruction task using 72B API."""
        import dashscope
        import json
        import re
        import os
        from datetime import datetime
        
        # API configuration - get from environment or use default
        api_key = os.getenv("DASHSCOPE_API_KEY", "sk-895ab67e50c7493ca4b63c78a6f106e1")
        api_model = os.getenv("HIGH_LEVEL_API_MODEL", "qwen2.5-vl-72b-instruct")
        
        dashscope.api_key = api_key
        
        system_prompt = """
You are a mobile operation instruction validator. Strictly evaluate if the generated instruction is valid for the current step.

# Evaluation Criteria (ALL must be met for reward=1):
1. Atomic Action: Must represent ONE actionable step (e.g., "click X" not "click X then do Y")
2. Context Match: Must logically follow from the current screen state
3. Keyboard State: For text input instructions, keyboard MUST be visible/activated
4. Target Existence: Referenced UI element must be present in current screen

# Evaluation Rules:
- Reward=1 ONLY when ALL criteria are satisfied
- Reward=0 for ANY violation

# Examples:

[Valid Example 1]
Task: Search for hotels in Washington DC
Screen: Home screen with Chrome icon
Instruction: "click Chrome"
→ Reward=1

[Invalid Example 1]
Task: Search for hotels in Washington DC  
Screen: Chrome search page (no keyboard)
Instruction: "type hotels"
→ Violates Rule 3 → Reward=0

[Valid Example 2]
Task: Search for hotels in Washington DC
Screen: Chrome search page (keyboard visible)
Instruction: "type hotels in Washington DC"
→ Reward=1

[Invalid Example 2]
Task: Search for hotels in Washington DC
Screen: Chrome search page (no keyboard)
Instruction: "click search bar"
→ Valid action → Reward=1

[Invalid Example 3]
Task: Search for hotels in Washington DC
Screen: Home screen
Instruction: "open Chrome and search"
→ Violates Rule 1 → Reward=0

[Edge Case]
Task: Search for hotels in Washington DC  
Screen: Search results page
Instruction: "click back button"
→ Valid but unrelated to task → Still Reward=1

# Output Format:
{"reward": 1} or {"reward": 0}
NO explanations. Strict JSON format only.
"""
        
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        
        # Extract task and image information from kwargs
        tasks = kwargs.get("problem", [])  # Current instruction/problem
        raw_image_paths = kwargs.get("image_path", [])
        
        # Handle different image_path formats
        image_paths = []
        for path_item in raw_image_paths:
            if isinstance(path_item, list):
                # Handle nested list format: [["/path/to/image.png"]] -> "/path/to/image.png"
                if len(path_item) > 0:
                    image_paths.append(path_item[0])
                else:
                    image_paths.append(None)
            else:
                # Handle direct string format: "/path/to/image.png"
                image_paths.append(path_item)
        
        
        for i, (content, task, image_path) in enumerate(zip(contents, tasks, image_paths)):
            reward = 0.0
            
            try:
                # Extract instruction from content using regex
                instruction_pattern = r"<think>.*?</think>\s*<instruction_step>(.*?)</instruction_step>"
                instruction_match = re.search(instruction_pattern, content, re.DOTALL)
                
                if instruction_match:
                    generated_instruction = instruction_match.group(1).strip()
                    
                    # Prepare messages for 72B API
                    messages = [
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": system_prompt}]
                        },
                        {
                            "role": "user", 
                            "content": [
                                {"type": "image", "image": image_path},
                                {"type": "text", "text": f"Task: {task}\nGenerated Instruction: {generated_instruction}"}
                            ]
                        }
                    ]
                    
                    # Call 72B API
                    response = dashscope.MultiModalConversation.call(
                        model=api_model,
                        messages=messages
                    )
                    if response.status_code == 200:
                        # Correct API response parsing for DashScope
                        api_response = response.output["choices"][0]["message"]["content"][0]["text"]
                        
                        # Parse JSON response
                        try:
                            result = json.loads(api_response.strip())
                            reward = float(result.get("reward", 0))
                        except json.JSONDecodeError:
                            # Fallback: look for reward in text
                            if '"reward": 1' in api_response or '"reward":1' in api_response:
                                reward = 1.0
                            else:
                                reward = 0.0
                        
                    else:
                        # API call failed, default to 0
                        reward = 0.0
                        if os.getenv("DEBUG_MODE") == "true":
                            print(f"API call failed with status: {response.status_code}")
                            print(f"Response: {response}")
                        
                else:
                    # No instruction found in content
                    reward = 0.0
                    
            except Exception as e:
                # Any error results in 0 reward
                reward = 0.0
                if os.getenv("DEBUG_MODE") == "true":
                    print(f"High-level content reward error: {e}")
            
            rewards.append(reward)
            
            # Debug logging
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                if log_path is not None:
                    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                    high_level_log_path = log_path.replace(".txt", "_high_level.txt")
                    
                    with open(high_level_log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} High-Level Content Reward: {reward:.3f} -------------\n")
                        f.write(f"Sample {i}:\n")
                        f.write(f"Task: {task}\n")
                        f.write(f"Content: {content}\n")
                        f.write(f"Image: {image_path}\n")
                        f.write("=" * 80 + "\n")
        
        return rewards

    @staticmethod
    def select_reward_func(func: str, task_type: str):
        if func == "accuracy":
            if task_type == "rec":
                return Qwen2VLModule.iou_reward
            elif task_type == "android_high_level":
                return Qwen2VLModule.high_level_content_reward
            elif task_type in ["android", "android_tool_call_only", "android_hierarchical", "android_hierarchical_no_thinking"]:
                return Qwen2VLModule.android_reward
            else:
                raise ValueError(f"Unsupported reward function: {func}")
        elif func == "format":
            if task_type == "rec":
                return Qwen2VLModule.format_reward_rec
            elif task_type in ["android", "android_tool_call_only", "android_hierarchical", "android_hierarchical_no_thinking", "android_high_level"]:
                return Qwen2VLModule.format_reward_android
            else:
                raise ValueError(f"Unsupported reward function: {func}")
        else:
            raise ValueError(f"Unsupported reward function: {func}")
        