from open_r1.vlm_modules.vlm_module import VLMBaseModule
from typing import Dict, Any, Union
from transformers import AutoModel, AutoProcessor, AutoConfig
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers.feature_extraction_sequence_utils import BatchFeature

# Special tokens that delimit and fill the image span substituted for each
# "<image>" placeholder in a prompt.
IMG_START_TOKEN = "<img>"
IMG_END_TOKEN = "</img>"
IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"

# Per-channel normalisation statistics (ImageNet, going by the names);
# presumably consumed by the image transform helper defined elsewhere in this
# file -- TODO confirm.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

class InvernVLModule(VLMBaseModule):
    def __init__(self):
        super().__init__()
        self.conv_template = None
        self.num_image_token = None

    def get_vlm_key(self):
        """Return the short identifier string for this VLM family."""
        vlm_key = "internvl"
        return vlm_key
        
    def get_model_class(self, model_id: str, model_init_kwargs: dict):
        """Return the model class to instantiate and adapt ``model_init_kwargs`` in place.

        Side effects:
            * ``self.model_config`` is set from ``AutoConfig.from_pretrained``.
            * ``model_init_kwargs`` is mutated: ``trust_remote_code`` is forced
              to ``True``, ``use_cache`` is dropped, and an
              ``attn_implementation`` containing ``"flash_attention_2"`` is
              replaced by ``use_flash_attn=True``.

        Args:
            model_id: Model name or path; must contain ``"InternVL"``.
            model_init_kwargs: Keyword arguments later passed to the model
                constructor. Modified in place.

        Returns:
            ``AutoModel``.

        Raises:
            AssertionError: If ``model_id`` does not contain ``"InternVL"``.
        """
        assert "InternVL" in model_id, f"model_id must contain 'InternVL', but got {model_id}"
        self.model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

        # InternVL should be inputted with "trust_remote_code=True"
        model_init_kwargs["trust_remote_code"] = True
        # "use_cache" should be removed
        model_init_kwargs.pop("use_cache", None)

        # "flash_attention_2" should be modified to "use_flash_attn" in InternVL.
        # The key may be present with an explicit None value; treat that like
        # "not set" instead of failing on `in None`.
        attn_implementation = model_init_kwargs.get("attn_implementation") or ""
        if "flash_attention_2" in attn_implementation:
            model_init_kwargs["use_flash_attn"] = True
            model_init_kwargs.pop("attn_implementation")

        # The model class of InternVL when being mapped has been determined by its config
        return AutoModel

    def post_model_init(self, model, processing_class):
        """Cache model-specific settings and register the image-context token id.

        The conversation template and the image-token count are taken from
        ``model`` only if they have not been set on this module yet. The id of
        ``IMG_CONTEXT_TOKEN`` is looked up through ``processing_class`` and
        stored on ``model.img_context_token_id``.
        """
        if self.conv_template is None:
            self.conv_template = model.conv_template
        if self.num_image_token is None:
            self.num_image_token = model.num_image_token
        model.img_context_token_id = processing_class.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
    
    def is_embeds_input(self):
        """Return ``True`` unconditionally for this module."""
        embeds_input = True
        return embeds_input

    def get_processing_class(self):
        """Return the class used to build the processor (``AutoProcessor``)."""
        processing_cls = AutoProcessor
        return processing_cls
    
    def get_eos_token_id(self, processing_class):
        eos_token_id = processing_class.convert_tokens_to_ids(self.conv_template.sep.strip())
        return eos_token_id
        
    def get_vision_modules_keywords(self):
        """Return the name fragments that identify vision sub-modules."""
        keywords = ["vision_model"]
        return keywords

    def get_custom_multimodal_keywords(self):
        """Return the model-input keys that carry multimodal data."""
        keywords = ["pixel_values", "image_flags"]
        return keywords
    
    def get_non_generate_params(self):
        """Return the input keys listed as non-generate parameters.

        Presumably these are dropped before calling ``generate`` -- verify
        against the caller.
        """
        params = ["image_flags"]
        return params

    def get_custom_processing_keywords(self):
        """Return ``(component, attribute)`` pairs of custom processing settings.

        The component is the literal string ``'None'``; presumably that marks
        an attribute living directly on the processing class (this module
        reads ``processing_class.max_anyres_num``) -- verify against the caller.
        """
        keywords = [("None", "max_anyres_num")]
        return keywords

    def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
        """Render each example's conversation into a single prompt string.

        For every example in ``inputs`` a fresh copy of the conversation
        template is filled from ``example["prompt"]``: an optional system
        message overrides the template's own, the processed turns are assigned
        alternately to ``roles[0]`` and ``roles[1]``, and if the conversation
        ends on a ``roles[0]`` turn an empty ``roles[1]`` turn is appended so
        the prompt ends where generation should begin.

        Returns:
            A list with one prompt string per example, in input order.
        """

        def render(example):
            # Render one example with its own template copy.
            template = self.conv_template.copy()
            conversation = example["prompt"]

            system_message = extract_system_message(conversation)
            if system_message is not None:
                template.system_message = system_message

            turns = process_conversation_list(conversation, system_message)
            for position, turn in enumerate(turns):
                # Even positions -> roles[0], odd positions -> roles[1].
                template.append_message(template.roles[position % 2], turn)
            if len(turns) % 2 == 1:
                # Open an empty reply slot for the model to complete.
                template.append_message(template.roles[1], None)
            return template.get_prompt()

        return [render(example) for example in inputs]
    
    def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False):
        """Tokenise prompts and attach tiled image tensors.

        Each image is tiled by ``_load_image``; every ``"<image>"`` placeholder
        in the prompts is then replaced, in order, by
        ``IMG_START_TOKEN + IMG_CONTEXT_TOKEN * (num_image_token * tiles) + IMG_END_TOKEN``
        for the corresponding image. The number of placeholders across all
        prompts must equal the number of images.

        Returns:
            A ``(BatchFeature, None)`` tuple. The batch holds the tokeniser
            output plus ``pixel_values`` (all tiles concatenated along dim 0)
            and ``image_flags`` (a ``torch.long`` vector of ones, one per tile).
        """
        # Tile every image; remember how many tiles each one produced.
        tiles_per_image = [
            self._load_image(
                picture,
                input_size=self.model_config.vision_config.image_size,
                max_num=processing_class.max_anyres_num,
            )
            for picture in images
        ]
        patch_counts = [tiles.shape[0] for tiles in tiles_per_image]
        all_tiles = torch.cat(tiles_per_image, dim=0)

        # Expand "<image>" placeholders one at a time so that the n-th
        # placeholder (across all prompts) maps to the n-th image.
        expanded_prompts = []
        consumed = 0
        for prompt in prompts_text:
            while "<image>" in prompt:
                context_span = IMG_CONTEXT_TOKEN * self.num_image_token * patch_counts[consumed]
                prompt = prompt.replace("<image>", IMG_START_TOKEN + context_span + IMG_END_TOKEN, 1)
                consumed += 1
            expanded_prompts.append(prompt)
        assert consumed == len(patch_counts)

        encoded = processing_class(
            expanded_prompts,
            return_tensors=return_tensors,
            padding=padding,
            padding_side=padding_side,
            add_special_tokens=add_special_tokens,
        )
        encoded["pixel_values"] = all_tiles
        # Only support pure-image data currently (each sample should contain the image)
        encoded["image_flags"] = torch.ones(all_tiles.shape[0], dtype=torch.long)

        return BatchFeature(data=encoded), None

    def _load_image(self, image: Image.Image, input_size: int=448, max_num:int=12):
        """Tile ``image`` and return the transformed tiles stacked in one tensor.

        ``dynamic_preprocess`` produces the tiles (with ``use_thumbnail=True``
        and at most ``max_num`` tiles requested); each tile is passed through
        the transform built for ``input_size`` and the results are stacked
        along a new leading dimension.
        """
        to_tensor = build_transform(input_size=input_size)
        tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
        return torch.stack([to_tensor(tile) for tile in tiles])
    
    @staticmethod
    def get_question_template(task_type: str) -> str:
        """Return the prompt template string for ``task_type``.

        The returned text contains named placeholders in curly braces
        (``{Question}``, ``{Instruction}``, ``{Task}``, ``{ActionHistory}``,
        depending on the task) and doubled braces around the JSON examples;
        presumably the caller fills it with ``str.format`` -- verify against
        the caller.

        Recognised task types:
            * ``"rec"``: think/answer template built around ``{Question}``.
            * ``"android"``: single-step phone operation, think + tool call.
            * ``"android_hierarchical"``: task + action history + current
              instruction, think + tool call.
            * ``"android_hierarchical_no_thinking"``: same inputs, tool call only.
            * ``"android_high_level"``: reasoning + next-step instruction.

        Any other value falls back to the same template as ``"rec"``.
        """
        if task_type == "rec":
            return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
        elif task_type == "android":
            # Single-step phone operation: <think>...</think><tool_call>JSON</tool_call>.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant. You need to operate the mobile phone based on the instruction.
                You will receive an instruction and a screenshot of the mobile phone screen.
                You should first think about the reasoning process in your mind, and then output the final answer in the form of a tool call. 
                The reasoning process and answer are enclosed within <think> </think> and <tool_call> </tool_call> tags, respectively, i.e., <think> reasoning process here </think><tool_call> JSON format answer here </tool_call>
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the instruction):
                - When the instruction is related to "long press" such as "long press the backspace key" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>.
                - When the instruction is related to "click" or "tap" such as "press the backspace key" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>.
                - When the instruction is related to "swipe" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "swipe", "coordinate": [xxx, xx], "coordinate2": [xxx, xxx]}}</tool_call>
                - When the instruction is related to "Back button" "Home button" , the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - When the instruction is related to "type" such as "Type the text '307.1' in the field", the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "type", "text": "307.1"}}</tool_call>
                - When the task is completed, the output should be like:
                <think> reasoning process here </think><tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
                """
            )
            return SYSTEM_PROMPT + '\n' + "{Instruction}"
        elif task_type == "android_hierarchical":
            # Multi-step variant: the prompt also carries the overall task and the action history.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant that can operate mobile phones to complete tasks.
                You will receive:
                1. A high-level task description
                2. A screenshot of the current mobile phone screen  
                3. A history of previous actions taken (if any)
                
                Based on the current situation, you should analyze what needs to be done next to progress toward completing the overall task.
                You should first think about the reasoning process in your mind, considering the task goal, current screen state, and what actions have been taken so far.
                Then output the next specific action in the form of a tool call.
                
                The reasoning process and answer are enclosed within <think> </think> and <tool_call> </tool_call> tags, respectively, i.e., <think> reasoning process here </think><tool_call> JSON format answer here </tool_call>
                
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the current situation):
                - For click/tap actions, the output should be like:
                <think> I need to analyze the current screen and determine where to click based on the task goal and previous actions. </think><tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>
                - For long press actions, the output should be like:
                <think> Based on the task requirements, I need to long press on this element. </think><tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>
                - For swipe actions, the output should be like:
                <think> I need to swipe to navigate or scroll to find the required element. </think><tool_call>{{"action": "swipe", "coordinate": [xxx, xxx], "coordinate2": [xxx, xxx]}}</tool_call>
                - For system button actions, the output should be like:
                <think> I need to use a system button to navigate. </think><tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - For text input actions, the output should be like:
                <think> I need to type specific text to complete this step of the task. </think><tool_call>{{"action": "type", "text": "specific text"}}</tool_call>
                - When the task is completed, the output should be like:
                <think> The task has been successfully completed based on all the actions taken. </think><tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
            
                """
            )
            return SYSTEM_PROMPT + '\n' + "Task: {Task}\nAction History: {ActionHistory}\nCurrent Instruction: {Instruction}"
        elif task_type == "android_hierarchical_no_thinking":
            # Same inputs as the hierarchical variant, but the answer is the tool call alone.
            SYSTEM_PROMPT = (
                """
                You are a helpful assistant that can operate mobile phones to complete tasks.
                You will receive:
                1. A high-level task description
                2. A screenshot of the current mobile phone screen  
                3. A history of previous actions taken (if any)
                
                Based on the current situation, you should analyze what needs to be done next to progress toward completing the overall task.
                You should directly output the next specific action in the form of a tool call without any thinking process.
                
                The answer should be enclosed within <tool_call> </tool_call> tags, i.e., <tool_call> JSON format answer here </tool_call>
                
                The output MUST follow the rules below:
                # Output examples(You need to output the right answer according to the current situation):
                - For click/tap actions, the output should be like:
                <tool_call>{{"action": "click", "coordinate": [xxx, xxx]}}</tool_call>
                - For long press actions, the output should be like:
                <tool_call>{{"action": "long_press", "coordinate": [xxx, xxx]}}</tool_call>
                - For swipe actions, the output should be like:
                <tool_call>{{"action": "swipe", "coordinate": [xxx, xxx], "coordinate2": [xxx, xxx]}}</tool_call>
                - For system button actions, the output should be like:
                <tool_call>{{"action": "system_button", "button": "Back"}}</tool_call>
                - For text input actions, the output should be like:
                <tool_call>{{"action": "type", "text": "specific text"}}</tool_call>
                - When the task is completed, the output should be like:
                <tool_call>{{"action": "terminate", "status": "success"}}</tool_call>
                
                """
            )
            return SYSTEM_PROMPT + '\n' + "Task: {Task}\nAction History: {ActionHistory}\nCurrent Instruction: {Instruction}"
        elif task_type == "android_high_level":
            # High-level planner: output is <reasoning>...</reasoning> <instruction>Instruction: ...</instruction>.
            SYSTEM_PROMPT = (
                """
                You are a mobile operation Agent that performs precise screen interactions and generates next step action instructions.
                You will receive:
                1. A high-level task description
                2. A screenshot of the current mobile phone screen
                3. A history of previous actions taken (if any)
                Based on the current situation, you should analyze what needs to be done next to progress toward completing the overall task.
                You should first think about the reasoning process, then give the next single-step instruction that:
                - Represents ONE actionable step (e.g., "click Chrome" not "click Chrome then search")
                - Uses action types and targets without specific coordinates
                STRICTLY follow this output structure:
                <reasoning> reasoning process here </reasoning> <instruction>Instruction: ...</instruction>
                """
            )
            return SYSTEM_PROMPT + '\n' + "{Instruction}"
        else:
            # Unknown task types fall back to the same think/answer template as "rec".
            return "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags."
    
    @staticmethod
    def format_reward_rec(completions, **kwargs):
        """Check if the InternVL model output matches a specific format."""
        import re
        import os
        from datetime import datetime
        pattern = r"<think>.*?</think>\s*<answer>.*?\[\d+,\s*\d+,\s*\d+,\s*\d+\].*?</answer>"
        completion_contents = [completion[0]["content"] for completion in completions]
        matches = [re.search(pattern, content, re.DOTALL) is not None for content in completion_contents]
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            with open(log_path.replace(".txt", "_format.txt"), "a", encoding='utf-8') as f:
                f.write(f"------------- {current_time} Format reward -------------\n")
                for content, match in zip(completion_contents, matches):
                    f.write(f"Content: {content}\n")
                    f.write(f"Has format: {bool(match)}\n")
        return [1.0 if match else 0.0 for match in matches]
        
    @staticmethod
    def iou_reward(completions, solution, **kwargs):
        """Calculate IoU reward between predicted bounding box from InternVL model and ground truth bounding box."""
        """Adopt soft iou reward here"""
        import re
        import os
        import json
        from datetime import datetime
        def iou(box1, box2):
            inter_x1 = max(box1[0], box2[0])
            inter_y1 = max(box1[1], box2[1])
            inter_x2 = min(box1[2]-1, box2[2]-1)
            inter_y2 = min(box1[3]-1, box2[3]-1)
            if inter_x1 < inter_x2 and inter_y1 < inter_y2:
                inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
            else:
                inter = 0
            union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
            return float(inter)/union
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
        answer_tag_pattern = r'<answer>(.*?)</answer>'
        bbox_pattern = r'\[(\d+),\s*(\d+),\s*(\d+),\s*(\d+)]'
        for i, (content, sol) in enumerate(zip(contents, solution)):
            sol = re.findall(answer_tag_pattern, sol, re.DOTALL)[-1]
            sol = json.loads(sol.strip())
            reward = 0.0
            # Try symbolic verification first
            try:
                content_answer_match = re.search(answer_tag_pattern, content, re.DOTALL)
                if content_answer_match:
                    content_answer = content_answer_match.group(1).strip()
                    bbox_match = re.search(bbox_pattern, content_answer)
                    if bbox_match:
                        bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                        reward = iou(bbox, sol)
            except Exception:
                pass  # Continue to next verification method if this fails
                    
            rewards.append(reward)
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                image_path = kwargs.get("image_path")[i] if "image_path" in kwargs else None
                problem = kwargs.get("problem")[i]
                if reward <= 1.0:  # this condition can be changed for debug
                    with open(log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                        f.write(f"image_path: {image_path}\n")
                        f.write(f"problem: {problem}\n")
                        f.write(f"Content: {content}\n")
                        f.write(f"Solution: {sol}\n") 
        return rewards
    
    @staticmethod
    def format_reward_android(completions, **kwargs):
        """Check if the InternVL model output matches the required format for Android tasks. 
        Supports three modes:
        - android: <think>...</think><tool_call>...</tool_call>
        - android_tool_call_only: <tool_call>...</tool_call>
        - android_hierarchical: <think>...</think><tool_call>...</tool_call>
        """
        import re
        import json
        import os
        from datetime import datetime
        
        completion_contents = [completion[0]["content"] for completion in completions]
        rewards = []
        
        # Get task type from kwargs to determine the expected format
        task_type = kwargs.get("task_type", "android")
        
        if task_type in ["android_tool_call_only", "android_hierarchical_no_thinking"]:
            # Pattern for tool_call only modes: <tool_call>...</tool_call>
            pattern = r"<tool_call>(.*?)</tool_call>"
        elif task_type == "android_high_level":
            # Pattern for high-level mode: <reasoning>...</reasoning><instruction>Instruction: ...</instruction>
            pattern = r"<reasoning>.*?</reasoning>\s*<instruction>\s*Instruction:\s*(.*?)\s*</instruction>"
        elif task_type in ["android", "android_hierarchical"]:
            # Pattern for android and android_hierarchical modes: <think>...</think><tool_call>...</tool_call>
            pattern = r"<think>.*?</think>\s*<tool_call>(.*?)</tool_call>"
        else:
            # Default pattern for android mode: <think>...</think><tool_call>...</tool_call>
            pattern = r"<think>.*?</think>\s*<tool_call>(.*?)</tool_call>"
        
        for i, content in enumerate(completion_contents):
            reward = 0.0
            
            try:
                if task_type == "android_high_level":
                    # For high-level mode, use graded reward system (no strict pattern matching)
                    reward = 0.0
                    
                    # reasoning 部分奖励
                    if "<reasoning>" in content:
                        reward += 0.2
                    elif "(reasoning)" in content:
                        reward += 0.1
                    elif "reasoning" in content:
                        reward += 0.05
                        
                    if "</reasoning>" in content:
                        reward += 0.2
                    elif "(/reasoning)" in content:
                        reward += 0.1
                    elif "/reasoning" in content:
                        reward += 0.05

                    # instruction 部分奖励
                    if "<instruction>" in content:
                        reward += 0.2
                    elif "(instruction)" in content:
                        reward += 0.1
                    elif "instruction" in content:
                        reward += 0.05
                        
                    if "</instruction>" in content:
                        reward += 0.2
                    elif "(/instruction)" in content:
                        reward += 0.1
                    elif "/instruction" in content:
                        reward += 0.05
                        
                    # 额外检查是否有 "Instruction:" 前缀，给予bonus
                    if "Instruction:" in content:
                        reward += 0.2
                else:
                    # For other modes, use strict pattern matching
                    match = re.search(pattern, content, re.DOTALL)
                    if match:
                        content_inside_tags = match.group(1).strip()
                        
                        # Try to parse the tool_call content as JSON for other Android modes
                        try:
                            tool_data_raw = json.loads(content_inside_tags)
                            
                            # Handle potential "arguments" wrapper
                            if "arguments" in tool_data_raw:
                                tool_data = tool_data_raw["arguments"]
                            else:
                                tool_data = tool_data_raw
                            
                            # Check if it has the required "action" field
                            if "action" in tool_data:
                                action = tool_data["action"]
                                
                                # Validate based on action type
                                if action in ["click", "long_press"]:
                                    # Should have coordinate field with [x, y] format
                                    if "coordinate" in tool_data and isinstance(tool_data["coordinate"], list) and len(tool_data["coordinate"]) == 2:
                                        reward = 1.0
                                elif action == "swipe":
                                    # Should have coordinate and coordinate2 fields
                                    if ("coordinate" in tool_data and isinstance(tool_data["coordinate"], list) and len(tool_data["coordinate"]) == 2 and
                                        "coordinate2" in tool_data and isinstance(tool_data["coordinate2"], list) and len(tool_data["coordinate2"]) == 2):
                                        reward = 1.0
                                elif action == "system_button":
                                    # Should have button field
                                    if "button" in tool_data and isinstance(tool_data["button"], str):
                                        reward = 1.0
                                elif action == "type":
                                    # Should have text field
                                    if "text" in tool_data and isinstance(tool_data["text"], str):
                                        reward = 1.0
                                elif action == "terminate":
                                    # Should have status field
                                    if "status" in tool_data and isinstance(tool_data["status"], str):
                                        reward = 1.0
                                # Add more action types as needed
                                
                        except json.JSONDecodeError:
                            # Invalid JSON in tool_call
                            reward = 0.0
                        
            except Exception as e:
                reward = 0.0
            
            rewards.append(reward)
            
        # Debug logging if enabled
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            if log_path is not None:
                current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                format_log_path = log_path.replace(".txt", "_format.txt")
                
                with open(format_log_path, "a", encoding='utf-8') as f:
                    f.write(f"------------- {current_time} InternVL Android Format Reward -------------\n")
                    for i, (content, reward) in enumerate(zip(completion_contents, rewards)):
                        f.write(f"Sample {i}: Reward = {reward}\n")
                        f.write(f"Content: {content}\n")
                        f.write("=" * 80 + "\n")
                        
        return rewards

    @staticmethod
    def android_reward(completions, solution, **kwargs):
        """Calculate reward for Android task considering both action type and action precision."""
        import re
        import json
        import os
        from datetime import datetime
        import math
        
        def calculate_coordinate_similarity(pred_coord, true_coord, image_width=1080, image_height=1920):
            """Calculate coordinate similarity using normalized euclidean distance."""
            if not (isinstance(pred_coord, list) and isinstance(true_coord, list) and 
                   len(pred_coord) == 2 and len(true_coord) == 2):
                return 0.0
            
            # Normalize coordinates by image dimensions
            pred_x_norm = pred_coord[0] / image_width
            pred_y_norm = pred_coord[1] / image_height
            true_x_norm = true_coord[0] / image_width
            true_y_norm = true_coord[1] / image_height
            
            # Calculate euclidean distance
            distance = math.sqrt((pred_x_norm - true_x_norm)**2 + (pred_y_norm - true_y_norm)**2)
            
            # Convert distance to similarity (closer = higher reward)
            # Use exponential decay for smoother reward curve
            similarity = math.exp(-distance * 5)  # 5 is a scaling factor
            return similarity
        
        def calculate_text_similarity(pred_text, true_text):
            """Calculate text similarity using exact match and partial match."""
            # Exact match gets full reward
            if pred_text.strip() == true_text.strip():
                return 1.0
            else:
                return 0.0
        
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        tool_call_pattern = r'<tool_call>(.*?)</tool_call>'
        
        for i, (content, sol) in enumerate(zip(contents, solution)):
            reward = 0.0
            
            try:
                # Parse ground truth solution
                sol_match = re.search(tool_call_pattern, sol, re.DOTALL)
                if not sol_match:
                    rewards.append(0.0)
                    continue
                
                # Ground truth format: {"name": "mobile_use", "arguments": {"action": "click", "coordinate": [x, y]}}
                true_tool_call_raw = json.loads(sol_match.group(1).strip())
                if "arguments" in true_tool_call_raw:
                    true_tool_call = true_tool_call_raw["arguments"]
                else:
                    true_tool_call = true_tool_call_raw
                true_action = true_tool_call.get("action", "")
                
                # Parse predicted tool call
                pred_match = re.search(tool_call_pattern, content, re.DOTALL)
                if not pred_match:
                    rewards.append(0.0)
                    continue
                
                # Predicted content might have the "arguments" wrapper or direct format
                pred_tool_call_raw = json.loads(pred_match.group(1).strip())
                if "arguments" in pred_tool_call_raw:
                    pred_tool_call = pred_tool_call_raw["arguments"]
                else:
                    pred_tool_call = pred_tool_call_raw
                pred_action = pred_tool_call.get("action", "")
                
                # Action type matching (base reward)
                if pred_action != true_action:
                    reward = 0.0
                else:
                    # Base reward for correct action type
                    action_reward = 0.3
                    
                    # Calculate precision reward based on action type
                    precision_reward = 0.0
                    
                    if pred_action in ["click", "long_press"]:
                        # Coordinate-based actions
                        pred_coord = pred_tool_call.get("coordinate", [])
                        true_coord = true_tool_call.get("coordinate", [])
                        precision_reward = calculate_coordinate_similarity(pred_coord, true_coord) * 0.7
                        
                    elif pred_action == "swipe":
                        # Swipe action with two coordinates
                        pred_coord1 = pred_tool_call.get("coordinate", [])
                        true_coord1 = true_tool_call.get("coordinate", [])
                        pred_coord2 = pred_tool_call.get("coordinate2", [])
                        true_coord2 = true_tool_call.get("coordinate2", [])
                        
                        coord1_sim = calculate_coordinate_similarity(pred_coord1, true_coord1)
                        coord2_sim = calculate_coordinate_similarity(pred_coord2, true_coord2)
                        precision_reward = (coord1_sim + coord2_sim) / 2 * 0.7
                        
                    elif pred_action == "type":
                        # Text input action
                        pred_text = pred_tool_call.get("text", "")
                        true_text = true_tool_call.get("text", "")
                        precision_reward = calculate_text_similarity(pred_text, true_text) * 0.7
                        
                    elif pred_action == "system_button":
                        # System button action (exact match required)
                        pred_button = pred_tool_call.get("button", "")
                        true_button = true_tool_call.get("button", "")
                        precision_reward = 0.7 if pred_button == true_button else 0.0
                        
                    elif pred_action == "terminate":
                        # Terminate action
                        pred_status = pred_tool_call.get("status", "")
                        true_status = true_tool_call.get("status", "")
                        precision_reward = 0.7 if pred_status == true_status else 0.0
                        
                    # Total reward = action_reward + precision_reward
                    reward = action_reward + precision_reward
                    
            except (json.JSONDecodeError, KeyError, ValueError) as e:
                reward = 0.0
            except Exception as e:
                reward = 0.0
            
            rewards.append(reward)
            
            # Debug logging
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                if log_path is not None:
                    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                    android_log_path = log_path.replace(".txt", "_android.txt")
                    
                    with open(android_log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} InternVL Android Accuracy Reward: {reward:.3f} -------------\n")
                        f.write(f"Sample {i}:\n")
                        f.write(f"Content: {content}\n")
                        f.write(f"Solution: {sol}\n")
                        f.write("=" * 80 + "\n")
        
        return rewards

    @staticmethod
    def high_level_content_reward(completions, solution, **kwargs):
        """Calculate content reward for high-level instruction task using 72B API."""
        import dashscope
        import json
        import re
        import os
        from datetime import datetime
        
        # API configuration - get from environment or use default
        api_key = os.getenv("DASHSCOPE_API_KEY", "sk-851c66d9f4d146c4a2e73cfe5d0baae8")
        api_model = os.getenv("HIGH_LEVEL_API_MODEL", "qwen2.5-vl-72b-instruct")
        
        dashscope.api_key = api_key
        
        system_prompt = """
You are a mobile operation instruction validator. Strictly evaluate if the generated instruction is valid for the current step.

# Evaluation Criteria (ALL must be met for reward=1):
1. Atomic Action: Must represent ONE actionable step (e.g., "click X" not "click X then do Y")
2. Context Match: Must logically follow from the current screen state
3. Keyboard State: For text input instructions, keyboard MUST be visible/activated
4. Target Existence: Referenced UI element must be present in current screen

# Evaluation Rules:
- Reward=1 ONLY when ALL criteria are satisfied
- Reward=0 for ANY violation

# Examples:

[Valid Example 1]
Task: Search for hotels in Washington DC
Screen: Home screen with Chrome icon
Instruction: "click Chrome"
→ Reward=1

[Invalid Example 1]
Task: Search for hotels in Washington DC  
Screen: Chrome search page (no keyboard)
Instruction: "type hotels"
→ Violates Rule 3 → Reward=0

[Valid Example 2]
Task: Search for hotels in Washington DC
Screen: Chrome search page (keyboard visible)
Instruction: "type hotels in Washington DC"
→ Reward=1

[Invalid Example 2]
Task: Search for hotels in Washington DC
Screen: Chrome search page (no keyboard)
Instruction: "click search bar"
→ Valid action → Reward=1

[Invalid Example 3]
Task: Search for hotels in Washington DC
Screen: Home screen
Instruction: "open Chrome and search"
→ Violates Rule 1 → Reward=0

[Edge Case]
Task: Search for hotels in Washington DC  
Screen: Search results page
Instruction: "click back button"
→ Valid but unrelated to task → Still Reward=1

# Output Format:
{"reward": 1} or {"reward": 0}
NO explanations. Strict JSON format only.
"""
        
        contents = [completion[0]["content"] for completion in completions]
        rewards = []
        
        # Extract task and image information from kwargs
        tasks = kwargs.get("problem", [])  # Current instruction/problem
        image_paths = kwargs.get("image_path", [])
        
        for i, (content, task, image_path) in enumerate(zip(contents, tasks, image_paths)):
            reward = 0.0
            
            try:
                # Extract instruction from content using regex
                instruction_pattern = r"<instruction>\s*Instruction:\s*(.+?)\s*</instruction>"
                instruction_match = re.search(instruction_pattern, content, re.DOTALL)
                
                # Debug logging for instruction extraction
                if os.getenv("DEBUG_MODE") == "true":
                    print(f"Content: {content}")
                    print(f"Instruction pattern: {instruction_pattern}")
                    print(f"Instruction match found: {instruction_match is not None}")
                
                if instruction_match:
                    generated_instruction = instruction_match.group(1).strip()
                    if os.getenv("DEBUG_MODE") == "true":
                        print(f"Extracted instruction: {generated_instruction}")
                    
                    # Prepare messages for 72B API
                    messages = [
                        {
                            "role": "system",
                            "content": [{"type": "text", "text": system_prompt}]
                        },
                        {
                            "role": "user", 
                            "content": [
                                {"type": "text", "text": f"Task: {task}\nGenerated Instruction: {generated_instruction}"},
                                {"type": "image", "image": image_path}
                            ]
                        }
                    ]
                    
                    # Call 72B API
                    response = dashscope.MultiModalConversation.call(
                        model=api_model,
                        messages=messages,
                        max_tokens=50
                    )
                    
                    if response.status_code == 200:
                        # Correct API response parsing for DashScope
                        api_response = response.output["choices"][0]["message"]["content"]
                        
                        # Parse JSON response
                        try:
                            result = json.loads(api_response.strip())
                            reward = float(result.get("reward", 0))
                        except json.JSONDecodeError:
                            # Fallback: look for reward in text
                            if '"reward": 1' in api_response or '"reward":1' in api_response:
                                reward = 1.0
                            else:
                                reward = 0.0
                        
                        # Debug logging for API response
                        if os.getenv("DEBUG_MODE") == "true":
                            print(f"API Response: {api_response}")
                            print(f"Parsed reward: {reward}")
                    else:
                        # API call failed, default to 0
                        reward = 0.0
                        if os.getenv("DEBUG_MODE") == "true":
                            print(f"API call failed with status: {response.status_code}")
                            print(f"Response: {response}")
                        
                else:
                    # No instruction found in content
                    reward = 0.0
                    
            except Exception as e:
                # Any error results in 0 reward
                reward = 0.0
                if os.getenv("DEBUG_MODE") == "true":
                    print(f"High-level content reward error: {e}")
                    print(f"Exception type: {type(e)}")
                    import traceback
                    traceback.print_exc()
            
            rewards.append(reward)
            
            # Debug logging
            if os.getenv("DEBUG_MODE") == "true":
                log_path = os.getenv("LOG_PATH")
                if log_path is not None:
                    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
                    high_level_log_path = log_path.replace(".txt", "_high_level.txt")
                    
                    with open(high_level_log_path, "a", encoding='utf-8') as f:
                        f.write(f"------------- {current_time} High-Level Content Reward: {reward:.3f} -------------\n")
                        f.write(f"Sample {i}:\n")
                        f.write(f"Task: {task}\n")
                        f.write(f"Content: {content}\n")
                        f.write(f"Image: {image_path}\n")
                        f.write("=" * 80 + "\n")
        
        return rewards

    @staticmethod
    def select_reward_func(func: str, task_type: str):
        if func == "accuracy":
            if task_type == "rec":
                return InvernVLModule.iou_reward
            elif task_type == "android_high_level":
                return InvernVLModule.high_level_content_reward
            elif task_type in ["android", "android_tool_call_only", "android_hierarchical", "android_hierarchical_no_thinking"]:
                return InvernVLModule.android_reward
            else:
                raise ValueError(f"Unsupported reward function: {func}")
        elif func == "format":
            if task_type == "rec":
                return InvernVLModule.format_reward_rec
            elif task_type in ["android", "android_tool_call_only", "android_hierarchical", "android_hierarchical_no_thinking", "android_high_level"]:
                return InvernVLModule.format_reward_android
            else:
                raise ValueError(f"Unsupported reward function: {func}")
        else:
            raise ValueError(f"Unsupported reward function: {func}")


def process_conversation_list(conversation_list, system_message=None, image_newline=True):
    """Flatten chat messages into one plain string per message.

    Args:
        conversation_list: List of message dicts. Each needs a ``"content"``
            entry that is either a string or a list of items shaped like
            ``{"type": "image"}`` / ``{"type": "text", "text": ...}``.
        system_message: When not ``None``, the first message is dropped
            (presumably it is the system message already extracted by the
            caller; this function does not check its role).
        image_newline: If true, each image item becomes ``"<image>\\n"``;
            otherwise ``"<image>"``.

    Returns:
        A list of strings, one per remaining message, in order.

    Raises:
        ValueError: If a message's content is neither ``str`` nor ``list``,
            or if a content item has a ``"type"`` other than ``"image"`` or
            ``"text"``.
    """
    if system_message is not None:
        conversation_list = conversation_list[1:]

    image_placeholder = "<image>\n" if image_newline else "<image>"
    processed_list = []

    for item in conversation_list:
        content = item["content"]

        if isinstance(content, str):
            processed_list.append(content)
        elif isinstance(content, list):
            parts = []
            for content_item in content:
                item_type = content_item.get("type")
                if item_type == "image":
                    parts.append(image_placeholder)
                elif item_type == "text":
                    parts.append(content_item.get("text"))
                else:
                    # Report the offending "type" value; the Python type of
                    # the item itself (always a dict here) says nothing.
                    raise ValueError(f"Unsupported content item type: {item_type!r}")
            processed_list.append("".join(parts))
        else:
            raise ValueError(f"Unsupported content type: {type(content)}")

    return processed_list

def extract_system_message(conversation_list):
    """Return the system prompt text of a conversation, if it has one.

    Args:
        conversation_list: List of message dicts with ``"role"`` and
            ``"content"`` entries.

    Returns:
        ``None`` when the conversation is empty or its first message does not
        have role ``"system"``. Otherwise the system content: the ``"text"``
        of the first item when the content is a list, or the content itself.
    """
    # An empty conversation simply has no system message.
    if not conversation_list:
        return None

    first_message = conversation_list[0]
    if first_message["role"] != "system":
        return None

    content = first_message["content"]
    if isinstance(content, list):
        return content[0]["text"]
    return content


def build_transform(input_size):
    """Build the image preprocessing pipeline for one square tile.

    The returned transform converts non-RGB images to RGB, resizes to
    ``input_size`` x ``input_size`` with bicubic interpolation, converts to a
    tensor, and normalizes with the module's ImageNet mean and std.

    Args:
        input_size: Side length used for both output height and width.

    Returns:
        A ``torchvision.transforms.Compose`` instance.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched.
        return img if img.mode == 'RGB' else img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width divided by height of the source image.
        target_ratios: Iterable of candidate grids; for each candidate,
            element 0 divided by element 1 is its aspect ratio.
        width: Source image width.
        height: Source image height.
        image_size: Side length of one tile.

    Returns:
        The candidate with the smallest absolute aspect-ratio difference, or
        ``(1, 1)`` when there are no candidates. When a later candidate ties
        exactly with the current best difference, it replaces the choice only
        if the image area exceeds half of that candidate's total tile area.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')

    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Exact tie: take this grid only when the image is large enough to
        # fill more than half of the grid's pixels.
        if gap == smallest_gap and image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate

    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """Split an image into square tiles on a grid matching its aspect ratio.

    Args:
        image: Image object exposing ``size``, ``resize`` and ``crop``
            (presumably a ``PIL.Image.Image``).
        min_num: Minimum number of tiles a candidate grid may have.
        max_num: Maximum number of tiles a candidate grid may have.
        image_size: Side length of each tile.
        use_thumbnail: If true and more than one tile was produced, append a
            resized copy of the whole image as an extra tile.

    Returns:
        List of tile images in row-major order, optionally followed by the
        thumbnail.
    """
    src_width, src_height = image.size
    src_aspect = src_width / src_height

    # Enumerate every (columns, rows) grid whose tile count lies within
    # [min_num, max_num], ordered by tile count.
    grid_candidates = set(
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if i * j <= max_num and i * j >= min_num
    )
    grid_candidates = sorted(grid_candidates, key=lambda grid: grid[0] * grid[1])

    # Grid whose aspect ratio best matches the source image.
    grid = find_closest_aspect_ratio(
        src_aspect, grid_candidates, src_width, src_height, image_size)

    # Pixel size of the resized image and the number of tiles it yields.
    canvas_width = image_size * grid[0]
    canvas_height = image_size * grid[1]
    blocks = grid[0] * grid[1]
    tiles_per_row = canvas_width // image_size

    canvas = image.resize((canvas_width, canvas_height))

    # Cut the resized image into tiles, left to right and top to bottom.
    processed_images = []
    for tile_index in range(blocks):
        row, col = divmod(tile_index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tile = canvas.crop((left, top, left + image_size, top + image_size))
        processed_images.append(tile)
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images