import numpy as np
import copy
from verl.workers.agent.tool_envs import ToolBase
from typing import Optional, List, Dict, Any
from PIL import Image
import re
import json
from verl.workers.agent.envs.mm_process_engine.prompt import PROMPT
from math import ceil, floor
# Temporary fix (currently disabled):
# ToolBase.registry = {}

class VisualToolBoxV3(ToolBase):
    """Zoom-in tool environment that crops a region out of the current image.

    The model requests a crop by emitting a fenced ```json block that carries
    a ``bbox_2d`` entry (``[left, top, right, bottom]`` in pixels); emitting
    an ``<answer>...</answer>`` block ends the episode instead. ``reset`` must
    be called with the image before ``execute`` can produce a crop.
    """

    name = "visual_toolbox_v3"
    user_prompt = '<image>\nHere is the zoomed in image for your grounding region {}' + PROMPT.USER_PROMPT_V3

    # Boxes with a side shorter than this many pixels are enlarged before cropping.
    MIN_BBOX_SIDE = 28
    # Boxes whose long side / short side exceeds this ratio are rejected.
    MAX_ASPECT_RATIO = 100

    def __init__(self, _name, _desc, _params, **kwargs):
        """Create the environment.

        ``_name``, ``_desc``, ``_params`` and ``kwargs`` are accepted for
        interface compatibility but are not used; the tool always registers
        under the class-level ``name``.
        """
        super().__init__(
            name=self.name,
        )
        self.chatml_history = []
        self.multi_modal_data = None  # To store the current image being processed
        # Image size in pixels; populated by reset().
        self.height = None
        self.width = None
        print(f"ENV: {self.name} initialized!")

    def extract_answer(self, action_string: str) -> Optional[str]:
        """Return the content of the last ``<answer>...</answer>`` block.

        Args:
            action_string: Raw model output.

        Returns:
            The text inside the last answer block, or ``None`` if there is none.
        """
        answer = re.findall(r'<answer>(.*?)</answer>', action_string, re.DOTALL)
        return answer[-1] if answer else None

    def extract_action(self, action_string: str) -> Optional[str]:
        """Return the raw text of the last fenced ```json block.

        Args:
            action_string: Raw model output.

        Returns:
            The (unparsed) text inside the last ```json fence, or ``None`` if
            the output contains no such block.
        """
        tool_call_match = re.findall(r'```json(.*?)```', action_string, re.DOTALL)
        return tool_call_match[-1] if tool_call_match else None

    def execute(self, action_string: str, **kwargs) -> tuple:
        """Run one environment step on the model output.

        Args:
            action_string: Raw model output, expected to contain either an
                ``<answer>`` block or a fenced ```json tool call.

        Returns:
            A tuple ``(observation, reward, done, info)``:
                observation: ``""`` when the episode ends; a list with one
                    user message when the tool call is rejected; otherwise a
                    dict with ``"chat"`` (the follow-up user message) and
                    ``"multi_modal_data"`` (the cropped image).
                reward: Always ``0.0``.
                done: ``True`` when an answer was given or no tool call was
                    found, ``False`` otherwise.
                info: Always an empty dict.
        """
        # A final answer ends the episode; no tool is run.
        answer = self.extract_answer(action_string)
        if answer:
            return "", 0.0, True, {}

        # Neither an answer nor a tool call: end the episode as well.
        action = self.extract_action(action_string)
        if not action:
            return "", 0.0, True, {}

        # The payload is model-generated, so any parse failure is reported
        # back to the model instead of being raised.
        try:
            tool_call = json.loads(action.strip())  # or use literal_eval
        except Exception as e:
            error_msg = f"Invalid tool call format: {action.strip()}. Error: {e}"
            return [{"role": "user", "content": error_msg}], 0.0, False, {}

        print(f"DEBUG{tool_call=}")
        cropped_bbox = self.get_bbox_2d(tool_call)
        if not cropped_bbox:
            user_msg = [{"role": "user", "content": "ZOOM IN ARGUMENTS ARE INVALID"}]
            return user_msg, 0.0, False, {}

        # TODO: modify here and process the final output
        try:
            # Always crop from the first image stored by reset().
            pil_img = self.multi_modal_data['image'][0]
            cropped_image = pil_img.crop(cropped_bbox)
        except Exception:
            user_msg = [{"role": "user", "content": "ZOOM IN AREA IS INVALID"}]
            return user_msg, 0.0, False, {}

        user_msg = self.user_prompt.format(cropped_bbox)
        chat_msg = [{"role": "user", "content": user_msg}]
        obs_dict = {"chat": chat_msg, "multi_modal_data": {"image": [cropped_image]}}
        return obs_dict, 0.0, False, {}

    def reset(self, raw_prompt, multi_modal_data, origin_multi_modal_data, **kwargs):
        """Start a new episode and remember the image to crop from.

        Args:
            raw_prompt: Stored as ``chatml_history``.
            multi_modal_data: Accepted but not used by this environment.
            origin_multi_modal_data: Mapping with a non-empty ``'image'`` list;
                the first image provides the size used to clamp boxes.

        Raises:
            AssertionError: If ``'image'`` is missing or empty.
        """
        self.chatml_history = raw_prompt
        self.multi_modal_data = origin_multi_modal_data
        assert 'image' in self.multi_modal_data.keys(), f'[ERROR] {origin_multi_modal_data=}'
        assert len(self.multi_modal_data['image']) > 0, f'[ERROR] {self.multi_modal_data["image"]=}'

        self.height = self.multi_modal_data['image'][0].height
        self.width = self.multi_modal_data['image'][0].width

    @staticmethod
    def _parse_bbox_candidate(action):
        """Turn one candidate into a Python object without executing code.

        Non-string candidates (already-parsed JSON) are returned unchanged.
        Strings are parsed as JSON first and as a Python literal second.
        """
        if not isinstance(action, str):
            return action
        try:
            return json.loads(action)
        except ValueError:
            # SECURITY: this text is model output. The previous implementation
            # passed it to eval(), which executes arbitrary expressions;
            # literal_eval only accepts plain literals.
            import ast
            return ast.literal_eval(action)

    def get_bbox_2d(self, action_list):
        """Find the first usable bounding box in a parsed tool call.

        Args:
            action_list: The parsed tool-call payload. May be a single dict,
                a single string, or a list/tuple of candidates, where each
                candidate is a dict, a list whose first item is a dict, or a
                string encoding one of those.

        Returns:
            ``[left, top, right, bottom]`` as produced by
            ``maybe_resize_bbox`` for the first valid candidate, or ``None``
            if no candidate yields a valid box.
        """
        if not action_list:
            return None

        # A lone dict or string is a single candidate; iterating it directly
        # would walk dict keys or characters instead.
        if isinstance(action_list, (dict, str)):
            candidates = [action_list]
        elif isinstance(action_list, (list, tuple)):
            candidates = action_list
        else:
            # Scalars such as numbers or booleans cannot carry a bbox.
            return None

        for action in candidates:
            if not action:
                continue
            try:
                bbox_info = self._parse_bbox_candidate(action)
                if isinstance(bbox_info, list):
                    bbox_2d = bbox_info[0]['bbox_2d']
                else:
                    bbox_2d = bbox_info['bbox_2d']
                if not isinstance(bbox_2d, (list, tuple)):
                    raise ValueError(f"[ERROR] invalid bbox_2d type: {bbox_2d=}")
                if len(bbox_2d) != 4:
                    raise ValueError(f"[ERROR] invalid size for {bbox_2d=}")
                bbox_result = self.maybe_resize_bbox(*bbox_2d)
                if not bbox_result:
                    continue
                return bbox_result
            except Exception as err:
                # Malformed candidates (missing key, non-numeric coordinates,
                # unparsable text, ...) are skipped rather than fatal.
                print(f' [ERROR] unexpected {err=}')
                continue
        return None

    def validate_bbox(self, left, top, right, bottom):
        """Check that a box has positive area and a sane aspect ratio.

        Returns:
            ``True`` if ``left < right``, ``top < bottom`` and the ratio of
            the longer side to the shorter side is at most
            ``MAX_ASPECT_RATIO``; ``False`` otherwise (the reason is printed).
        """
        try:
            # Explicit checks instead of assert so validation survives `python -O`.
            if not (left < right and bottom > top):
                print(f' [ERROR vl_agent #2] invalid shape for {left=}, {top=}, {right=}, {bottom=}')
                return False
            height = bottom - top
            width = right - left
            # Written as `not (<=)` so that a NaN ratio is rejected.
            if not (max(height, width) / min(height, width) <= self.MAX_ASPECT_RATIO):
                print(f' [ERROR vl_agent #2] aspect ratio error: {left=}, {top=}, {right=}, {bottom=}')
                return False
            return True
        except Exception as err:
            # E.g. coordinates of incomparable types.
            print(f' [ERROR vl_agent #2] {err=}')
            return False

    def maybe_resize_bbox(self, left, top, right, bottom):
        """Clamp a box to the image and enlarge it if it is too small.

        The box is first clipped to ``[0, width] x [0, height]``. If either
        side is then shorter than ``MIN_BBOX_SIDE``, the box is scaled about
        its centre so that the shorter side reaches that size.

        Returns:
            ``[left, top, right, bottom]``, or ``None`` if the box is invalid
            before or after enlargement.
        """
        left = max(0, left)
        top = max(0, top)
        right = min(self.width, right)
        bottom = min(self.height, bottom)
        if not self.validate_bbox(left, top, right, bottom):
            return None

        height = bottom - top
        width = right - left
        min_side = self.MIN_BBOX_SIDE
        if height < min_side or width < min_side:
            center_x = (left + right) / 2.0
            center_y = (top + bottom) / 2.0
            # Same scale factor on both axes keeps the aspect ratio.
            ratio = min_side / min(height, width)
            new_half_height = ceil(height * ratio * 0.5)
            new_half_width = ceil(width * ratio * 0.5)
            # NOTE(review): the enlarged box is not re-clamped, so it may
            # extend past the image edges; confirm the crop handles that as
            # intended.
            new_left = floor(center_x - new_half_width)
            new_right = ceil(center_x + new_half_width)
            new_top = floor(center_y - new_half_height)
            new_bottom = ceil(center_y + new_half_height)
            if not self.validate_bbox(new_left, new_top, new_right, new_bottom):
                return None
            return [new_left, new_top, new_right, new_bottom]
        return [left, top, right, bottom]

if __name__ == "__main__":
    # Manual smoke test: run a few representative model outputs through
    # execute() and print what comes back.
    tool = VisualToolBoxV3("visual_toolbox_v3", "Tool for image processing", {})

    # execute() crops from the image handed to reset(), so supply a synthetic one.
    demo_image = Image.new("RGB", (256, 256))
    tool.reset(
        raw_prompt=[],
        multi_modal_data=None,
        origin_multi_modal_data={"image": [demo_image]},
    )

    # Zoom-in request in a fenced ```json block, the format extract_action() looks for.
    zoom_in_action = """
    ```json
    [{"bbox_2d": [10, 10, 100, 100], "label": "demo region"}]
    ```
    """
    obs, reward, done, info = tool.execute(zoom_in_action)
    print(f"Zoom in result - Obs type: {type(obs).__name__}, Reward: {reward}, Done: {done}, Info: {info}")

    # Malformed JSON (missing closing bracket): reported back as an invalid tool call.
    invalid_action = """
    ```json
    [{"bbox_2d": [10, 10, 100, 100]}
    ```
    """
    obs, reward, done, info = tool.execute(invalid_action)
    print(f"Invalid JSON result - Obs: {obs}, Reward: {reward}, Done: {done}, Info: {info}")

    # Well-formed JSON that carries no bbox_2d entry.
    missing_bbox_action = """
    ```json
    [{"label": "no box here"}]
    ```
    """
    obs, reward, done, info = tool.execute(missing_bbox_action)
    print(f"Missing bbox result - Obs: {obs}, Reward: {reward}, Done: {done}, Info: {info}")

    # Final answer: the episode terminates without running the tool.
    answer_action = "<answer>a red car</answer>"
    obs, reward, done, info = tool.execute(answer_action)
    print(f"Answer result - Reward: {reward}, Done: {done}, Info: {info}")