import asyncio
import copy
import json
import logging
import os
from typing import Any
from uuid import uuid4
from typing import Union
import numpy as np

from verl.experimental.agent_loop.tool_agent_loop import ToolAgentLoop, AgentData, AgentState
from verl.experimental.agent_loop.tool_agent_loop import AgentData as BaseAgentData
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
from verl.experimental.agent_loop.tool_parser import FunctionCall
from verl.tools.schemas import ToolResponse
from verl.utils.profiler import simple_timer
from verl.utils.rollout_trace import rollout_trace_op
from recipe.fileagent.utils.metric_utils import (
    build_tool_metric,
    init_trajectory_tool_metrics,
    update_trajectory_tool_metrics,
    FILEAGENT_TOOL_METRICS_KEY,
)

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

import re
def shift_picture_indices(text: str, offset: int) -> str:
    """Return ``text`` with the number of every ``Picture <n>:`` label shifted by ``offset``.

    Every occurrence of ``Picture``, one or more whitespace characters, a run of
    digits and a colon is rewritten as ``Picture <n + offset>:`` (with a single
    space). Text that does not match this pattern is returned unchanged.
    """

    def _renumber(found):
        # Parse the captured index, apply the shift, and rebuild the label.
        shifted = int(found.group(1)) + offset
        return "Picture " + str(shifted) + ":"

    return re.sub(r"Picture\s+(\d+):", _renumber, text)

# To do append origin images
class AgentData(BaseAgentData):
    """Base agent data extended with tool metrics and original image data.

    Args:
        origin_image_data: Value stored as ``self.origin_image_data``.
        tool_metrics: Mapping stored as ``self.tool_metrics``.
        **kwargs: Forwarded unchanged to ``BaseAgentData.__init__``.
    """

    def __init__(
        self,
        origin_image_data: Any,
        tool_metrics: dict[str, dict[str, Any]],
        **kwargs,
    ):
        # Let the base class set up its own state first, then attach the extras.
        super().__init__(**kwargs)
        self.origin_image_data = origin_image_data
        self.tool_metrics = tool_metrics

class ImageToolResponse(ToolResponse):
    """``ToolResponse`` variant that can additionally carry ``origin_image`` entries."""

    # Optional list of image objects; ``None`` when the tool returned none.
    origin_image: Union[list[Any], None] = None

class FileAgentAgentLoop(ToolAgentLoop):
    """Tool agent loop that also tracks tool metrics and original tool images.

    On top of the state handlers inherited from ``ToolAgentLoop`` this class:

    * accumulates per-trajectory tool metrics in ``agent_data.tool_metrics`` and
      exports them under ``FILEAGENT_TOOL_METRICS_KEY`` in the output's
      ``extra_fields``;
    * keeps ``agent_data.origin_image_data`` alongside ``agent_data.image_data``
      and writes both back into ``tools_kwargs["global_tool"]["create_kwargs"]``
      after every tool turn;
    * force-releases every tool for the request once the trajectory ends.
    """

    @rollout_trace_op
    async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
        """Run one trajectory through the agent state machine.

        Args:
            sampling_params: Sampling parameters forwarded to the pending and
                generating state handlers.
            **kwargs: Reads ``raw_prompt`` (required), ``multi_modal_data["image"]``
                and ``tools_kwargs`` (optional), and
                ``extra_info["interaction_kwargs"]`` when an interaction config
                file is set.

        Returns:
            The ``AgentLoopOutput`` for this trajectory; ``extra_fields`` holds
            ``turn_scores`` and the trajectory tool metrics.

        Raises:
            ValueError: If an interaction is configured but ``interaction_kwargs``
                has no ``"name"`` or names an unknown interaction.
        """
        messages = list(kwargs["raw_prompt"])
        image_data = copy.deepcopy(kwargs.get("multi_modal_data", {}).get("image", None))
        metrics = {}
        request_id = uuid4().hex
        tools_kwargs = kwargs.get("tools_kwargs", {})

        # Initialize interaction if needed
        interaction = None
        interaction_kwargs = {}
        if self.interaction_config_file:
            interaction_kwargs = kwargs["extra_info"]["interaction_kwargs"]
            if "name" not in interaction_kwargs:
                raise ValueError("'name' key is required in interaction_kwargs")
            interaction_name = interaction_kwargs["name"]
            if interaction_name not in self.interaction_map:
                raise ValueError(
                    f"Interaction '{interaction_name}' not found in interaction_map. Available interactions: "
                    f"{list(self.interaction_map.keys())}"
                )
            interaction = self.interaction_map[interaction_name]
            await interaction.start_interaction(request_id, **interaction_kwargs)

        # Create AgentData instance to encapsulate all state
        origin_images = tools_kwargs.get("global_tool", {}).get("create_kwargs", {}).get("images", None)
        agent_data = AgentData(
            messages=messages,
            image_data=image_data,
            metrics=metrics,
            request_id=request_id,
            tools_kwargs=tools_kwargs,
            interaction=interaction,
            interaction_kwargs=interaction_kwargs,
            tool_metrics=init_trajectory_tool_metrics(self.tools.keys()),
            origin_image_data=copy.deepcopy(origin_images),
        )

        try:
            # State machine loop
            state = AgentState.PENDING
            while state != AgentState.TERMINATED:
                if state == AgentState.PENDING:
                    state = await self._handle_pending_state(agent_data, sampling_params)
                elif state == AgentState.GENERATING:
                    state = await self._handle_generating_state(agent_data, sampling_params)
                    agent_data.assistant_turns += 1
                elif state == AgentState.PROCESSING_TOOLS:
                    state = await self._handle_processing_tools_state(agent_data)
                elif state == AgentState.INTERACTING:
                    state = await self._handle_interacting_state(agent_data)
                    agent_data.user_turns += 1
                else:
                    logger.error(f"Invalid state: {state}")
                    state = AgentState.TERMINATED

            # Finalize output: the trailing len(response_mask) ids are the response,
            # everything before them is the prompt.
            response_ids = agent_data.prompt_ids[-len(agent_data.response_mask) :]
            prompt_ids = agent_data.prompt_ids[: len(agent_data.prompt_ids) - len(agent_data.response_mask)]
            multi_modal_data = {"image": agent_data.image_data} if agent_data.image_data is not None else {}
            output = AgentLoopOutput(
                prompt_ids=prompt_ids,
                response_ids=response_ids[: self.response_length],
                response_mask=agent_data.response_mask[: self.response_length],
                multi_modal_data=multi_modal_data,
                response_logprobs=agent_data.response_logprobs[: self.response_length]
                if agent_data.response_logprobs
                else None,
                num_turns=agent_data.user_turns + agent_data.assistant_turns + 1,
                metrics=agent_data.metrics,
                extra_fields={},
            )
            output.extra_fields.update({"turn_scores": agent_data.turn_scores})
            output.extra_fields.update({FILEAGENT_TOOL_METRICS_KEY: agent_data.tool_metrics})
            return output
        finally:
            # Runs on success AND on failure so a crashed trajectory does not
            # leave its sandbox behind.
            await self._destroy_request_sandboxes(agent_data.request_id)

    async def _destroy_request_sandboxes(self, request_id: str) -> None:
        """Force-release every tool for ``request_id``; failures are logged, not raised."""
        for tool_name, tool in self.tools.items():
            try:
                await tool.release(request_id, force_destroy=True)
                logger.info(f"Cleaned up {tool_name} sandbox for request {request_id[:8]}")
            except Exception as e:
                # Best-effort cleanup: one failing tool must not block the others.
                logger.error(f"Failed to cleanup {tool_name} sandbox: {e}")

    @staticmethod
    def _count_images(image_data: Any) -> int:
        """Return how many images ``image_data`` holds (``None`` -> 0, single object -> 1)."""
        if image_data is None:
            return 0
        if isinstance(image_data, (list, tuple)):
            return len(image_data)
        return 1

    @staticmethod
    def _non_null_items(value: Any) -> list[Any]:
        """Normalize ``value`` (list, single object or ``None``) to a list without ``None`` entries."""
        if isinstance(value, list):
            return [item for item in value if item is not None]
        return [] if value is None else [value]

    @staticmethod
    def _build_tool_message(tool_response: ToolResponse) -> dict[str, Any]:
        """Build the ``role="tool"`` chat message for a single tool response."""
        if tool_response.image or tool_response.video:
            # Multi-modal content with structured format
            content = []
            if tool_response.image:
                content.append({"type": "image"})
            if tool_response.video:
                content.append({"type": "video"})
            if tool_response.text:
                content.append({"type": "text", "text": tool_response.text})
            return {"role": "tool", "content": content}
        # Text-only content
        return {"role": "tool", "content": tool_response.text or ""}

    async def _handle_processing_tools_state(self, agent_data: AgentData) -> AgentState:
        """Handle the processing tools state: execute tool calls and prepare tool responses.

        Runs up to ``self.max_parallel_calls`` pending tool calls concurrently,
        records their metrics, appends the resulting tool messages and images to
        ``agent_data``, refreshes the global tool ``create_kwargs`` and finally
        appends the tokenized tool responses to ``agent_data.prompt_ids`` with a
        zero response mask.

        Returns:
            ``AgentState.TERMINATED`` when adding the tool-response tokens would
            reach ``self.response_length``; otherwise ``AgentState.GENERATING``.

        Raises:
            NotImplementedError: If a tool response contains video.
        """
        add_messages: list[dict[str, Any]] = []
        new_images_this_turn: list[Any] = []  # Local variable instead of agent_data attribute
        # Only the number of pre-existing images is needed later (to offset the
        # "Picture N:" labels), so count them instead of deep-copying them.
        num_images_before = self._count_images(agent_data.image_data)

        tasks = [
            self._call_tool(tool_call, agent_data)
            for tool_call in agent_data.tool_calls[: self.max_parallel_calls]
        ]

        with simple_timer("tool_calls", agent_data.metrics):
            tool_results = await asyncio.gather(*tasks)

        # Handle responses for interaction if needed
        if self.interaction_config_file:
            for response, _ in tool_results:
                if response.text:
                    agent_data.messages.append({"role": "tool", "content": response.text})

        # Process tool responses and update multi_modal_data
        for tool_response, tool_metric in tool_results:
            # Handle tool metrics
            update_trajectory_tool_metrics(agent_data.tool_metrics, tool_metric)

            # Create message from tool response
            add_messages.append(self._build_tool_message(tool_response))

            # Handle image data
            if isinstance(tool_response, ImageToolResponse):
                if agent_data.image_data is None:
                    agent_data.image_data = []
                elif not isinstance(agent_data.image_data, list):
                    agent_data.image_data = [agent_data.image_data]

                if agent_data.origin_image_data is None:
                    agent_data.origin_image_data = []
                elif not isinstance(agent_data.origin_image_data, list):
                    agent_data.origin_image_data = [agent_data.origin_image_data]

                assert len(agent_data.origin_image_data) == len(agent_data.image_data), "origin_image_data and image_data should have the same length!"

                # Add new image data, skipping None entries
                fresh_images = self._non_null_items(tool_response.image)
                agent_data.image_data.extend(fresh_images)
                new_images_this_turn.extend(fresh_images)
                agent_data.origin_image_data.extend(self._non_null_items(tool_response.origin_image))

            # Handle video data
            if tool_response.video:
                # Currently not supported, raise informative error
                logger.warning("Multimedia type 'video' is not currently supported. Only 'image' is supported.")
                raise NotImplementedError(
                    "Multimedia type 'video' is not currently supported. Only 'image' is supported."
                )

        # Append this turn's tool messages exactly once. Extending inside the loop
        # re-added every earlier message whenever several tools ran in parallel.
        agent_data.messages.extend(add_messages)

        # Publish the current images to the kwargs used for subsequent tool.create calls.
        # NOTE(review): if "global_tool"/"create_kwargs" is absent these writes land in a
        # temporary dict and are dropped (unchanged from the previous behaviour).
        create_kwargs = agent_data.tools_kwargs.get("global_tool", {}).get("create_kwargs", {})
        model_input_sizes = create_kwargs.get("model_input_sizes", None)
        if agent_data.image_data:
            model_input_sizes = [(img.width, img.height) for img in agent_data.image_data]
        create_kwargs["images"] = copy.deepcopy(agent_data.origin_image_data)
        create_kwargs["model_input_sizes"] = copy.deepcopy(model_input_sizes)

        # Update prompt with tool responses
        if self.processor is not None:
            raw_tool_response = await self.loop.run_in_executor(
                None,
                lambda: self.processor.apply_chat_template(
                    add_messages,
                    add_generation_prompt=True,
                    tokenize=False,
                    **self.apply_chat_template_kwargs,
                ),
            )
            # The template only saw this turn's messages, so its "Picture N:" labels
            # are shifted by the number of images that already existed.
            if num_images_before and self.apply_chat_template_kwargs.get("add_vision_id", False):
                raw_tool_response = shift_picture_indices(raw_tool_response, num_images_before)
            # Use only the new images from this turn for processing tool responses
            current_images = new_images_this_turn if new_images_this_turn else None
            model_inputs = self.processor(text=[raw_tool_response], images=current_images, return_tensors="pt")
            response_ids = model_inputs.pop("input_ids").squeeze(0).tolist()
        else:
            response_ids = await self.loop.run_in_executor(
                None,
                lambda: self.tokenizer.apply_chat_template(add_messages, add_generation_prompt=True, tokenize=True),
            )

        # Drop the leading len(self.system_prompt) ids from the rendered chunk.
        response_ids = response_ids[len(self.system_prompt) :]
        if len(agent_data.response_mask) + len(response_ids) >= self.response_length:
            return AgentState.TERMINATED
        # Update prompt_ids and response_mask (mask 0: tool tokens are not model output)
        agent_data.prompt_ids += response_ids
        agent_data.response_mask += [0] * len(response_ids)
        if agent_data.response_logprobs:
            agent_data.response_logprobs += [0.0] * len(response_ids)
        agent_data.user_turns += 1
        return AgentState.GENERATING

    def _truncate_tool_text(self, text: Any) -> Any:
        """Shorten ``text`` to ``self.max_tool_response_length`` characters plus a marker.

        ``"left"`` keeps the leading characters, ``"right"`` keeps the trailing
        ones, and any other ``self.tool_response_truncate_side`` keeps half from
        each end. Empty/``None`` or short text is returned unchanged.
        """
        limit = self.max_tool_response_length
        if not text or len(text) <= limit:
            return text
        if self.tool_response_truncate_side == "left":
            return text[:limit] + "...(truncated)"
        if self.tool_response_truncate_side == "right":
            return "(truncated)..." + text[-limit:]
        half = limit // 2
        return text[:half] + "...(truncated)..." + text[-half:]

    async def _call_tool(self, tool_call: FunctionCall, agent_data: AgentData) -> tuple[Union[ToolResponse, ImageToolResponse], dict]:
        """Call tool and return tool response.

        Args:
            tool_call: Parsed function call; ``arguments`` must be a JSON string.
            agent_data: Trajectory state; supplies ``request_id`` and the
                ``"global_tool"`` entry of ``tools_kwargs``.

        Returns:
            ``(response, tool_metric)``. Any exception during argument parsing,
            creation or execution is turned into a text ``ToolResponse`` with a
            failed metric from ``build_tool_metric`` instead of propagating.
        """
        tool, instance_id = None, None
        try:
            # TODO: append malformed tool_call to the prompt: invalid function name or arguments
            tool_name = tool_call.name
            tool_args = json.loads(tool_call.arguments)
            tool = self.tools[tool_name]
            global_kwargs = copy.deepcopy(agent_data.tools_kwargs.get("global_tool", {}))
            # request_id doubles as instance_id so every turn of the same rollout
            # reuses the same sandbox.
            instance_id, _ = await tool.create(
                instance_id=agent_data.request_id,
                create_kwargs=global_kwargs.get("create_kwargs", {}),
            )
            tool_execution_response, _, tool_metric = await tool.execute(instance_id, tool_args)
        except Exception as e:
            logger.warning(f"Error when executing tool: {e}")
            tool_metric = build_tool_metric(tool_name=tool_call.name, succeeded=False)
            return ToolResponse(
                text=f"Error when executing tool: {e}",
            ), tool_metric
        finally:
            if tool and instance_id:
                await tool.release(instance_id)

        # Create ToolResponse from tool execution result
        tool_response_kwargs = {"text": self._truncate_tool_text(tool_execution_response.text)}

        # Add multimedia data if present
        for attr_name in ("image", "video", "origin_image"):
            attr_value = getattr(tool_execution_response, attr_name, None)
            if attr_value is not None:
                tool_response_kwargs[attr_name] = attr_value

        # Only responses that carry original images need the extended response type.
        response_cls = ImageToolResponse if "origin_image" in tool_response_kwargs else ToolResponse
        return response_cls(**tool_response_kwargs), tool_metric


def apply_worker_patch(worker_cls):
    """Wrap ``worker_cls._postprocess`` so its output also carries tool metrics.

    The replacement calls the original ``_postprocess`` and then stores, under
    ``FILEAGENT_TOOL_METRICS_KEY`` in ``output.non_tensor_batch``, a NumPy array
    built from each input's ``extra_fields[FILEAGENT_TOOL_METRICS_KEY]``.
    The class is modified in place; nothing is returned.
    """
    original_postprocess = worker_cls._postprocess

    def _patched_postprocess(self, inputs):
        result = original_postprocess(self, inputs)
        # One metrics entry per input, in input order.
        per_input_metrics = []
        for item in inputs:
            per_input_metrics.append(item.extra_fields[FILEAGENT_TOOL_METRICS_KEY])
        result.non_tensor_batch[FILEAGENT_TOOL_METRICS_KEY] = np.array(per_input_metrics)
        return result

    worker_cls._postprocess = _patched_postprocess
