import json
import logging
import os
from typing import Any, Optional
from uuid import uuid4
from fastmcp.exceptions import ClientError
from verl.tools.utils.mcp_clients.McpClientManager import ClientManager
from verl.utils.rollout_trace import rollout_trace_op
from .base_tool import BaseTool
from .schemas import OpenAIFunctionToolSchema, ToolResponse
# Module-level logger. Its verbosity comes from the VERL_LOGGING_LEVEL
# environment variable, falling back to "WARN" when the variable is unset.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=os.getenv("VERL_LOGGING_LEVEL", default="WARN"))
class MCPBaseTool(BaseTool):
    """Tool that forwards calls to an MCP server through ``ClientManager``.

    Per-rollout state lives in ``self._instance_dict``, keyed by ``instance_id``.
    The stripped text of every successful tool result is appended to that
    instance's ``"reward"`` list, and :meth:`calc_reward` returns that list.

    Config keys read here:
        timeout: forwarded as-is to ``ClientManager.call_tool`` (default 30).
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        super().__init__(config, tool_schema)
        # instance_id -> {"response": str, "reward": list[str]}.
        # create() adds entries; release() removes them.
        self._instance_dict: dict[str, dict[str, Any]] = {}
        self.timeout = config.get("timeout", 30)
        logger.info(f"Initialized MCPBaseTool with config: {config}")

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        """Return the schema this tool was constructed with.

        Presumably ``BaseTool.__init__`` stores it as ``self.tool_schema``.
        """
        return self.tool_schema

    async def create(self, instance_id: Optional[str] = None, **kwargs) -> tuple[str, ToolResponse]:
        """Register a new tool instance.

        Args:
            instance_id: Id for the instance. A random UUID4 string is used
                when ``None``. An existing entry with the same id is reset.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``(instance_id, ToolResponse())``: the id and an empty response.
        """
        if instance_id is None:
            instance_id = str(uuid4())
        self._instance_dict[instance_id] = {
            "response": "",
            "reward": [],
        }
        return instance_id, ToolResponse()

    async def _call_tool(self, instance_id, parameters) -> tuple[str, dict]:
        """Invoke the MCP tool named ``self.name`` and parse its result.

        Exceptions raised by the call are not propagated. They are reported
        through the return value instead.

        Returns:
            ``(text, metadata)``.

            On success, ``text`` is the joined text content from
            :meth:`_parse_tool_result` and ``metadata["api_request_error"]``
            is ``None``.

            On failure, ``text`` is the error description and
            ``metadata["api_request_error"]`` holds the error message.
        """
        try:
            call_tool_result = await ClientManager.call_tool(self.name, parameters, self.timeout)
        except ClientError as e:
            err_msg = f"\n Tool call failed: {e}"
        except ConnectionError as e:
            err_msg = f"\n Connection failed: {e}"
        except Exception as e:
            err_msg = f"\n An unexpected error occurred: {e}"
        else:
            logger.debug(
                "Tool result for instance %s with tool %s: %s",
                instance_id,
                self.name,
                call_tool_result.content,
            )
            result, metadata = self._parse_tool_result(call_tool_result.content)
            metadata["api_request_error"] = None
            return result, metadata

        # Failure path. ``call_tool_result`` was never bound here, so return
        # the error instead of trying to parse a result.
        error_text = err_msg.strip()
        logger.error("[MCPBaseTool] Tool %s failed for instance %s: %s", self.name, instance_id, error_text)
        return error_text, {"api_request_error": err_msg}

    @rollout_trace_op
    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> tuple[ToolResponse, float, dict]:
        """Call the MCP tool for ``instance_id`` with ``parameters``.

        Args:
            instance_id: Id previously returned by :meth:`create`.
            parameters: Arguments forwarded to the MCP tool. ``None`` is rejected.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``(response, reward, metrics)``.
            - ``reward`` is always ``0.0`` here.
            - ``metrics`` carries ``query_count``, ``status``, ``total_results``
              and ``api_request_error``, taken from the parsed metadata.
            - If an unexpected exception escapes the call, ``metrics`` is
              instead ``{"error": <message>}``.
        """
        if not self.name or parameters is None:
            error_msg = "Error: 'parameters' is missing or empty."
            logger.error(f"[MCPTool] {error_msg} Received tool name: {self.name}, parameters: {parameters}")
            return ToolResponse(text=json.dumps({"result": error_msg})), 0.0, {}
        try:
            result_text, metadata = await self._call_tool(instance_id, parameters)
        except Exception as e:
            error_result = json.dumps({"result": f"Tool execution failed: {e}"})
            logger.exception(f"[MCPBaseTool] Execution failed: {e}")
            return ToolResponse(text=error_result), 0.0, {"error": str(e)}

        api_request_error = metadata.get("api_request_error")
        instance = self._instance_dict.get(instance_id)
        if instance is None:
            # A missing create() must not turn a successful call into a failure.
            logger.warning("[MCPBaseTool] Unknown instance %s; result not recorded", instance_id)
        elif api_request_error is None:
            # Only real tool outputs are recorded; error descriptions are not.
            instance["reward"].append(result_text.strip())

        metrics = {
            "query_count": metadata.get("query_count", 0),
            "status": metadata.get("status", "unknown"),
            "total_results": metadata.get("total_results", 0),
            "api_request_error": api_request_error,
        }
        return ToolResponse(text=result_text), 0.0, metrics

    async def calc_reward(self, instance_id: str, **kwargs) -> list[str]:
        """Return the list of stripped result texts recorded for ``instance_id``.

        Raises:
            KeyError: if ``instance_id`` was never created or was already released.
        """
        return self._instance_dict[instance_id]["reward"]

    async def release(self, instance_id: str, **kwargs) -> None:
        """Drop the state for ``instance_id``. Unknown ids are ignored."""
        self._instance_dict.pop(instance_id, None)

    def _parse_tool_result(self, content: list) -> tuple[str, dict]:
        """Join the ``.text`` of every content part whose ``.type`` is ``"text"``.

        Parts are joined with single spaces. Returns ``(joined_text, {})``;
        subclasses may override this to fill in the metadata dict.
        """
        joined = " ".join(part.text for part in content if part.type == "text")
        return joined, {}