import logging
import os
from typing import Any, Optional, Union
from uuid import uuid4

from verl.utils.rollout_trace import rollout_trace_op
from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse
from recipe.fileagent.agent_loop import ImageToolResponse
from recipe.fileagent.utils.metric_utils import build_tool_metric
from recipe.fileagent.utils.vision_utils import controlled_smart_resize

# Module-level logger; verbosity is controlled by the VERL_LOGGING_LEVEL env var (default "WARN").
logger = logging.getLogger(name=__name__)
logger.setLevel(level=os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class ImageZoomInTool(BaseTool):
    """Tool that crops a region out of one of a rollout's images and returns it resized.

    ``create`` stores the candidate images (and their model input sizes) per tool
    instance; ``execute`` crops ``images[image_id - 1]`` to ``bbox_2d`` and resizes the
    crop via ``controlled_smart_resize`` bounded by ``max_crop_pixels``.
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        super().__init__(config, tool_schema)
        # instance_id -> {"response": str, "images": ..., "model_input_sizes": ...}
        self._instance_dict: dict[str, dict[str, Any]] = {}
        self.custom_config = self.config.get("custom") or {}
        # Read with .get rather than .pop: popping mutated the caller's config, so every
        # tool constructed from the same config after the first lost the configured value.
        self.max_crop_pixels = self.custom_config.get("max_crop_pixels", 1024 * 8 * 28 * 28)

    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:
        """Register per-instance state for a new tool instance.

        Args:
            instance_id: Identifier for the instance; a random UUID4 string is used when None.
            **kwargs: Must contain ``create_kwargs`` holding ``images`` (an indexable sequence of
                image objects supporting ``crop``/``resize`` — presumably PIL images, TODO confirm)
                and ``model_input_sizes`` (one ``(width, height)`` pair per image, as unpacked by
                ``_refine_bbox``).

        Returns:
            The instance id and an empty ``ToolResponse``.

        Raises:
            ValueError: If ``images`` or ``model_input_sizes`` is missing or None.
        """
        if instance_id is None:
            instance_id = str(uuid4())
        # Missing keys are reported with the same ValueError as explicit None values
        # instead of surfacing as a bare KeyError.
        create_kwargs = kwargs.get("create_kwargs") or {}
        images = create_kwargs.get("images")
        model_input_sizes = create_kwargs.get("model_input_sizes")
        if images is None:
            raise ValueError("images must be provided when creating the image zoom in tool.")
        if model_input_sizes is None:
            raise ValueError("model_input_sizes must be provided when creating the image zoom in tool.")
        self._instance_dict[instance_id] = {
            "response": "",
            "images": images,
            "model_input_sizes": model_input_sizes,
        }
        return instance_id, ToolResponse()

    @rollout_trace_op
    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[Union[ImageToolResponse, ToolResponse], float, dict]:
        """Crop the requested image region and return it, resized, to the model.

        Args:
            instance_id: Id returned by ``create``; an unknown id raises ``KeyError``.
            parameters: Tool-call arguments: ``image_id`` (1-based, default 1), ``bbox_2d``
                (``[left, top, right, bottom]`` integer pixel coordinates) and optional ``label``.

        Returns:
            ``(response, 0, metric)``. On success ``response`` is an ``ImageToolResponse``
            carrying the resized crop in ``image`` and the raw crop in ``origin_image``; on
            invalid arguments it is a ``ToolResponse`` whose text describes the problem.
            The second element is always 0 here.
        """
        image_id = parameters.get("image_id", 1)
        bbox_2d = parameters.get("bbox_2d", None)
        label = parameters.get("label", "")

        is_valid, message = self._validate_bbox(bbox_2d)
        if not is_valid:
            metric = build_tool_metric(tool_name=self.name, succeeded=False)
            return ToolResponse(text=message), 0, metric

        images = self._instance_dict[instance_id]["images"]
        index = self._parse_image_index(image_id, len(images))
        if index is None:
            metric = build_tool_metric(tool_name=self.name, succeeded=False)
            # Report the id exactly as the model supplied it (not the 0-based index).
            return ToolResponse(text=f"Invalid image id:{image_id}"), 0, metric
        image = images[index]

        # bbox_2d is used as-is; the resized->original mapping in `_refine_bbox` is
        # currently disabled, so `model_input_sizes` is not consulted here.
        origin_cropped_image = image.crop(bbox_2d)
        # TODO: append the origin cropped_image

        resized_height, resized_width = controlled_smart_resize(
            height=origin_cropped_image.height,
            width=origin_cropped_image.width,
            max_pixels=self.max_crop_pixels,
        )
        # controlled_smart_resize yields (height, width); resize expects (width, height).
        cropped_image = origin_cropped_image.resize((resized_width, resized_height))

        if label:
            response_text = f"Zoomed in on the image to the region {bbox_2d} with label {label}."
        else:
            response_text = f"Zoomed in on the image to the region {bbox_2d}."

        metric = build_tool_metric(tool_name=self.name, succeeded=True)
        return ImageToolResponse(text=response_text, image=[cropped_image], origin_image=[origin_cropped_image]), 0, metric

    async def release(self, instance_id: str, **kwargs) -> None:
        """Drop the state for ``instance_id``; releasing an unknown or already-released id is a no-op."""
        self._instance_dict.pop(instance_id, None)

    @staticmethod
    def _parse_image_index(image_id: Any, num_images: int) -> Optional[int]:
        """Convert a 1-based ``image_id`` to a 0-based index, or return None if it is invalid."""
        try:
            index = int(image_id) - 1
        except (TypeError, ValueError, OverflowError):
            return None
        # Explicit bounds check: a negative index (image_id <= 0) would otherwise
        # silently select an image counted from the end of the list.
        if not 0 <= index < num_images:
            return None
        return index

    def _validate_bbox(self, bbox_2d: Any) -> tuple[bool, str]:
        """Check that ``bbox_2d`` is a list of 4 ints with left < right and top < bottom.

        Returns:
            ``(True, "")`` when valid, otherwise ``(False, <error message for the model>)``.
        """
        if not isinstance(bbox_2d, list) or len(bbox_2d) != 4:
            return False, "Invalid 'bbox_2d': expected a list of 4 integers [x1, y1, x2, y2]."

        # bool is a subclass of int; exclude it so e.g. [True, 0, 5, 5] is rejected.
        if not all(isinstance(v, int) and not isinstance(v, bool) for v in bbox_2d):
            return False, "Invalid 'bbox_2d': all elements must be integers."

        left, top, right, bottom = bbox_2d
        if left >= right or top >= bottom:
            return False, "Invalid 'bbox_2d': Must satisfy left < right and top < bottom."

        return True, ""

    def _refine_bbox(self, bbox_2d, image, model_input_size):
        """Map a bbox from model-input (resized) coordinates to original-image coordinates.

        Scales each coordinate by ``image.width / Wr`` and ``image.height / Hr`` (where
        ``model_input_size`` is ``(Wr, Hr)``), rounds, and clamps to the image bounds.
        Note: a box lying entirely outside the image can come back with left >= right or
        top >= bottom; the result is not re-validated here. Currently unused by ``execute``.
        """
        resized_w, resized_h = model_input_size
        orig_w, orig_h = image.width, image.height

        scale_x, scale_y = orig_w / resized_w, orig_h / resized_h
        x1, y1, x2, y2 = bbox_2d

        left = max(0, round(x1 * scale_x))
        top = max(0, round(y1 * scale_y))
        right = min(orig_w, round(x2 * scale_x))
        bottom = min(orig_h, round(y2 * scale_y))

        return [left, top, right, bottom]
