# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
基于 fileagent_libs.fileagent_mcp_host 的沙盒工具
支持 ExecuteCode 和 ExecuteShell 两个工具
"""

import json
import logging
import os
from typing import Any, Optional
from uuid import uuid4

from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse
from verl.utils.rollout_trace import rollout_trace_op

# Import the sandbox backend; when it is missing the tools below refuse to initialize.
try:
    from fileagent_libs.fileagent_mcp_host import (
        create_sandbox,
        sandbox_write_file,
        sandbox_execute_command,
        destroy_sandbox,
        run_code,
    )
    FILEAGENT_AVAILABLE = True
except ImportError:
    FILEAGENT_AVAILABLE = False
    logging.warning("fileagent_mcp_host not available, FileagentSandboxTool will not work")

logger = logging.getLogger(__name__)
# The log level is read from the VERL_LOGGING_LEVEL environment variable ("WARN" when unset).
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

# Module-level sandbox registry, keyed by instance_id. It lets ExecuteCode and
# ExecuteShell share one sandbox; each value is the per-instance dict holding
# "session_id", "execution_history", "ground_truth" and "ref_count".
_GLOBAL_SANDBOX_REGISTRY = {}


def _as_dict_return(obj: Any) -> Any:
    """
    Unwrap the ``"return"`` entry of objects that expose ``as_dict()``.

    Objects without an ``as_dict`` attribute are handed back untouched.
    Otherwise ``obj.as_dict().get("return")`` is looked up and:
      - a string is parsed as JSON when possible, e.g. '{"session_id":"..."}';
        a string that is not valid JSON is returned as-is
      - any other value (for example a dict such as
        {"code":0, "stdout":"...", "stderr":"..."}) is returned unchanged
    """
    if not hasattr(obj, "as_dict"):
        return obj

    payload = obj.as_dict().get("return")
    if not isinstance(payload, str):
        return payload

    try:
        parsed = json.loads(payload)
    except Exception:
        # Not JSON: fall back to the raw string.
        return payload
    return parsed


class FileagentExecuteCodeTool(BaseTool):
    """
    使用 fileagent_mcp_host.run_code 执行代码的工具
    对应你原来的 ExecuteCode 工具
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        """
        Initialize the tool and read its defaults from ``config``.

        Args:
            config: Tool configuration; ``default_language`` (default "python")
                and ``default_timeout`` (default 30) are read from it.
            tool_schema: Schema forwarded to the BaseTool constructor.

        Raises:
            ImportError: If the fileagent_mcp_host module could not be imported.
        """
        super().__init__(config, tool_schema)

        fileagent_missing = not FILEAGENT_AVAILABLE
        if fileagent_missing:
            raise ImportError("fileagent_mcp_host module is required for FileagentExecuteCodeTool")

        # Defaults applied when a tool call does not specify them.
        self.default_timeout = config.get("default_timeout", 30)
        self.default_language = config.get("default_language", "python")
        # Per-instance state (session_id, history, ...) keyed by instance_id.
        self._instance_dict = {}

        logger.info(f"Init FileagentExecuteCodeTool with config: {config}")

    async def create(
        self,
        instance_id: Optional[str] = None,
        **kwargs
    ) -> tuple[str, ToolResponse]:
        """
        Create the sandbox for a trajectory, or reuse one that already exists.

        When ``instance_id`` is already present in the module-level registry
        (shared with the shell tool), its reference count is incremented and
        the existing session is reused; prewrites are not processed on that
        path. Otherwise a new sandbox is created, registered both globally and
        on this tool, and any prewrite files are uploaded into it.

        Args:
            instance_id: Instance ID; a random UUID string is generated when None.
            kwargs: May contain ``create_kwargs``, a dict with the optional keys
                ``ground_truth`` (stored for reward calculation) and
                ``prewrites`` (a list of dicts with ``filename`` and ``fpath``;
                each existing local file is uploaded base64-encoded).

        Returns:
            (instance_id, response stating whether the sandbox was created or reused)

        Raises:
            RuntimeError: If the create_sandbox response carries no session_id.
            Exception: Any other failure during creation is logged and re-raised.
        """
        if instance_id is None:
            instance_id = str(uuid4())

        # Reuse path: the sandbox already exists (possibly created by the other tool).
        if instance_id in _GLOBAL_SANDBOX_REGISTRY:
            shared = _GLOBAL_SANDBOX_REGISTRY[instance_id]
            shared["ref_count"] += 1
            ref_count = shared["ref_count"]
            session_id = shared["session_id"]
            logger.info(f"Reusing existing sandbox for instance {instance_id[:8]}, ref_count={ref_count}")
            # Point the local entry at the same dict as the global registry.
            if instance_id not in self._instance_dict:
                self._instance_dict[instance_id] = shared
            return instance_id, ToolResponse(text=f"Sandbox reused: {session_id[:8]}")

        # Tolerate an explicit create_kwargs=None instead of failing with
        # AttributeError after the sandbox has already been created.
        create_kwargs = kwargs.get("create_kwargs") or {}

        try:
            resp = create_sandbox()
            val = _as_dict_return(resp)

            # Resolve the session_id: accept a dict directly, otherwise try to
            # decode the value as a JSON object.
            payload = val
            if not (isinstance(payload, dict) and "session_id" in payload):
                try:
                    payload = json.loads(val)
                except (TypeError, ValueError) as exc:
                    raise RuntimeError(f"CreateSandBox returned unexpected payload: {val!r}") from exc
                if not (isinstance(payload, dict) and "session_id" in payload):
                    raise RuntimeError(f"CreateSandBox returned unexpected payload: {val!r}")
            session_id = payload["session_id"]

            logger.info(f"Created sandbox for instance {instance_id[:8]}: session_id={session_id[:8]}")

            # Store the instance state in both the global registry and the
            # local dict (the same dict object is shared by both).
            instance_data = {
                "session_id": session_id,
                "execution_history": [],
                "ground_truth": create_kwargs.get("ground_truth"),  # kept for reward calculation
                "ref_count": 1,  # reference count
            }
            self._instance_dict[instance_id] = instance_data
            _GLOBAL_SANDBOX_REGISTRY[instance_id] = instance_data

            # Upload prewrite files, if any. Every file is read as bytes and
            # sent base64-encoded with binary=True; a failure on one file is
            # logged and does not abort sandbox creation.
            prewrites = create_kwargs.get("prewrites", [])
            if isinstance(prewrites, list):
                import base64
                for pw in prewrites:
                    if not isinstance(pw, dict):
                        logger.warning(f"Skipping malformed prewrite entry: {pw!r}")
                        continue

                    filename = pw.get("filename")
                    fpath = pw.get("fpath", "")
                    if not (filename and fpath and os.path.exists(fpath)):
                        logger.warning(f"Skipping prewrite {filename!r}: missing filename or local file {fpath!r}")
                        continue

                    try:
                        with open(fpath, 'rb') as f:
                            content = f.read()
                        sandbox_write_file(
                            session_id,
                            filename,
                            base64.b64encode(content).decode("ascii"),
                            binary=True
                        )
                        logger.info(f"Wrote file {filename} to sandbox {session_id[:8]}")
                    except Exception as we:
                        logger.warning(f"Failed to write file {filename}: {we}")

            return instance_id, ToolResponse(text=f"Sandbox created: {session_id[:8]}")

        except Exception as e:
            logger.error(f"Failed to create sandbox: {e}")
            raise

    @rollout_trace_op
    async def execute(
        self,
        instance_id: str,
        parameters: dict[str, Any],
        **kwargs
    ) -> tuple[ToolResponse, float, dict]:
        """
        Run a piece of code in the instance's sandbox via run_code.

        Args:
            instance_id: Instance ID; must have been registered on this tool by create().
            parameters: Tool parameters. ``code`` (coerced to str) and
                ``language`` (defaults to ``self.default_language``) are used.
                NOTE: a ``timeout`` parameter is not forwarded, because
                run_code is called with code, language and session_id only.

        Returns:
            (tool response, step reward, metrics dict). The response text is a
            JSON object with "success", "code", "output" and "stderr". The step
            reward is 0.0 after a completed call, -1.0 when the sandbox is not
            initialized and -0.1 when the call itself raised.
        """
        logger.info(f"🔧 [ExecuteCode] Starting execution for instance {instance_id[:8]}...")

        if instance_id not in self._instance_dict:
            logger.error(f"❌ [ExecuteCode] Sandbox not initialized for {instance_id[:8]}")
            return ToolResponse(text="Error: Sandbox not initialized"), -1.0, {}

        session_id = self._instance_dict[instance_id]["session_id"]
        language = parameters.get("language", self.default_language)

        # Coerce to str before the first use: len() on a non-string value
        # (e.g. an int) would otherwise raise outside the try block below.
        code = parameters.get("code", "")
        if not isinstance(code, str):
            code = str(code)

        logger.info(f"📝 [ExecuteCode] Session: {session_id[:8]}, Code length: {len(code)} chars")

        try:
            logger.info("▶️  [ExecuteCode] Calling run_code...")
            resp = run_code(code, language=language, session_id=session_id)
            logger.info("✅ [ExecuteCode] run_code completed")

            # Decode the response: objects with a `result` attribute carry a
            # JSON string; anything undecodable is inspected as-is.
            try:
                resp_dict = json.loads(resp.result) if hasattr(resp, 'result') else resp
            except Exception:
                resp_dict = resp

            # Extract the result. Mappings are checked first so that dict
            # subclasses (which also have __dict__) are read by key rather than
            # by attribute.
            if isinstance(resp_dict, dict):
                rc = int(resp_dict.get("code", 0))
                stdout = str(resp_dict.get("stdout", ""))
                stderr = str(resp_dict.get("stderr", ""))
            elif hasattr(resp_dict, "__dict__"):
                rc = int(getattr(resp_dict, "code", getattr(resp_dict, "returncode", 0)))
                stdout = str(getattr(resp_dict, "stdout", ""))
                stderr = str(getattr(resp_dict, "stderr", ""))
            else:
                rc, stdout, stderr = 1, "", f"Unexpected response type: {type(resp_dict).__name__}"

            success = (rc == 0)
            # Prefer stdout; fall back to stderr when stdout is empty.
            output = stdout if stdout else stderr

            logger.info(f"{'✅' if success else '❌'} [ExecuteCode] Execution {'succeeded' if success else 'failed'} (rc={rc}), Output length: {len(output)} chars")

            # Record the call in the execution history (used by calc_reward).
            self._instance_dict[instance_id]["execution_history"].append({
                "code": code,
                "success": success,
                "output": output,
                "returncode": rc
            })

            result_text = json.dumps({
                "success": success,
                "code": rc,
                "output": output,
                "stderr": stderr
            }, ensure_ascii=False)

            logger.info(f"📤 [ExecuteCode] Returning result: success={success}, output preview: {output[:100]}...")

            # No per-step reward by default; calc_reward computes the final score.
            step_reward = 0.0

            metrics = {
                "success": success,
                "output_length": len(output)
            }

            return ToolResponse(text=result_text), step_reward, metrics

        except Exception as e:
            error_msg = f"Code execution error: {type(e).__name__}: {e}"
            logger.error(error_msg)
            # ensure_ascii=False keeps non-ASCII error text readable, matching
            # the success payload above.
            return ToolResponse(text=json.dumps({
                "success": False,
                "code": 1,
                "output": "",
                "stderr": error_msg
            }, ensure_ascii=False)), -0.1, {"success": False}

    async def calc_reward(self, instance_id: str, **kwargs) -> float:
        """
        Compute the final reward for an instance.

        Reward strategy:
        1. Base reward: fraction of successful executions, weighted 0.3.
        2. Output reward: when a ground truth is stored and ``model_output``
           contains a ``**Answer**:`` line, 0.7 if that line contains the
           ground truth (verbatim, or after removing commas and spaces), or
           0.5 if the ground truth is longer than 10 characters and any of its
           words longer than 3 characters occurs in the answer.

        Args:
            instance_id: Instance ID.
            kwargs: ``model_output`` is read as the model's final text.

        Returns:
            Final reward in the range 0.0 - 1.0. 0.0 is returned when the
            instance is unknown to this tool or has no execution history.
        """
        if instance_id not in self._instance_dict:
            return 0.0

        record = self._instance_dict[instance_id]
        history = record["execution_history"]
        ground_truth = record.get("ground_truth")

        if not history:
            return 0.0

        # 1. Base reward: execution success rate (weight 0.3).
        success_count = sum(1 for h in history if h["success"])
        success_rate = success_count / len(history)
        base_reward = 0.3 * success_rate

        # 2. Output reward: does the answer line contain the ground truth (weight 0.7)?
        output_reward = 0.0
        model_output = kwargs.get("model_output", "")

        # Compare against None rather than truthiness so that falsy but valid
        # ground truths such as 0 are still scored. An empty or whitespace-only
        # ground truth is skipped: "" is a substring of every answer.
        ground_truth_str = "" if ground_truth is None else str(ground_truth).strip()

        if ground_truth_str and isinstance(model_output, str) and model_output:
            import re
            # Capture the text following "**Answer**:" up to the end of that line.
            answer_match = re.search(r"\*\*Answer\*\*:\s*(.*?)(?:\n|$)", model_output, re.DOTALL)
            if answer_match:
                answer_text = answer_match.group(1).strip()
                compact_truth = ground_truth_str.replace(",", "").replace(" ", "")
                compact_answer = answer_text.replace(",", "").replace(" ", "")

                # a) Verbatim containment.
                if ground_truth_str in answer_text:
                    output_reward = 0.7
                # b) Containment ignoring commas and spaces (e.g. "1,000" vs "1000").
                #    Guarded so a ground truth made only of commas/spaces cannot match.
                elif compact_truth and compact_truth in compact_answer:
                    output_reward = 0.7
                # c) Partial match for long answers: any word longer than 3 characters.
                elif len(ground_truth_str) > 10 and any(
                    word in answer_text for word in ground_truth_str.split() if len(word) > 3
                ):
                    output_reward = 0.5

        total_reward = base_reward + output_reward

        logger.debug(
            f"Reward calculation for {instance_id[:8]}: "
            f"base={base_reward:.2f}, output={output_reward:.2f}, total={total_reward:.2f}"
        )

        return total_reward

    async def release(self, instance_id: str, **kwargs) -> None:
        """
        Release this tool's hold on a sandbox (delayed-destroy policy).

        The instance is always dropped from this tool's local dict and its
        shared reference count is decremented. The sandbox itself is destroyed,
        and removed from the global registry, only when ``force_destroy`` is
        truthy or the reference count has dropped below zero. Per the original
        author's note, verl calls release right after every tool call while the
        sandbox should stay alive for the whole episode.

        Args:
            instance_id: Instance ID.
            kwargs: Optional ``force_destroy=True`` to force destruction.
        """
        if instance_id not in _GLOBAL_SANDBOX_REGISTRY:
            # Nothing shared to release; just forget the local entry.
            self._instance_dict.pop(instance_id, None)
            return

        shared = _GLOBAL_SANDBOX_REGISTRY[instance_id]
        shared["ref_count"] -= 1
        ref_count = shared["ref_count"]
        session_id = shared["session_id"]
        force_destroy = kwargs.get("force_destroy", False)

        logger.debug(f"Release called for {instance_id[:8]}, ref_count={ref_count}, force_destroy={force_destroy}")

        # Drop only the local entry; the global entry stays for later reuse.
        self._instance_dict.pop(instance_id, None)

        should_destroy = force_destroy or ref_count < 0
        if not should_destroy:
            return

        try:
            destroy_sandbox(session_id=session_id)
            logger.info(f"Destroyed sandbox {session_id[:8]} for instance {instance_id[:8]}")
        except Exception as e:
            logger.warning(f"Failed to destroy sandbox {session_id[:8]}: {e}")
        finally:
            # Forget the registry entry whether or not destruction succeeded.
            _GLOBAL_SANDBOX_REGISTRY.pop(instance_id, None)


class FileagentExecuteShellTool(BaseTool):
    """
    使用 fileagent_mcp_host.sandbox_execute_command 执行 Shell 命令的工具
    对应你原来的 ExecuteShell 工具
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        """
        Initialize the shell tool.

        Args:
            config: Tool configuration, forwarded to the BaseTool constructor.
            tool_schema: Schema forwarded to the BaseTool constructor.

        Raises:
            ImportError: If the fileagent_mcp_host module could not be imported.
        """
        super().__init__(config, tool_schema)

        fileagent_missing = not FILEAGENT_AVAILABLE
        if fileagent_missing:
            raise ImportError("fileagent_mcp_host module is required for FileagentExecuteShellTool")

        # Per-instance state (session_id, history, ...) keyed by instance_id.
        self._instance_dict = {}
        logger.info(f"Init FileagentExecuteShellTool with config: {config}")

    async def create(
        self,
        instance_id: Optional[str] = None,
        **kwargs
    ) -> tuple[str, ToolResponse]:
        """
        Create the sandbox for a trajectory, or reuse one that already exists.

        The sandbox registry is shared with the ExecuteCode tool: when
        ``instance_id`` is already registered, its reference count is
        incremented and the existing session is reused. Otherwise a new
        sandbox is created and registered both globally and on this tool.
        NOTE: this method does not process ``prewrites``.

        Args:
            instance_id: Instance ID; a random UUID string is generated when None.
            kwargs: May contain ``create_kwargs``, a dict whose optional
                ``ground_truth`` entry is stored for reward calculation.

        Returns:
            (instance_id, response stating whether the sandbox was created or reused)

        Raises:
            RuntimeError: If the create_sandbox response carries no session_id.
            Exception: Any other failure during creation is logged and re-raised.
        """
        if instance_id is None:
            instance_id = str(uuid4())

        # Reuse path: the sandbox already exists (shared with ExecuteCode).
        if instance_id in _GLOBAL_SANDBOX_REGISTRY:
            shared = _GLOBAL_SANDBOX_REGISTRY[instance_id]
            shared["ref_count"] += 1
            ref_count = shared["ref_count"]
            session_id = shared["session_id"]
            logger.info(f"[ExecuteShell] Reusing existing sandbox for instance {instance_id[:8]}, ref_count={ref_count}")
            # Point the local entry at the same dict as the global registry.
            if instance_id not in self._instance_dict:
                self._instance_dict[instance_id] = shared
            return instance_id, ToolResponse(text=f"Sandbox reused: {session_id[:8]}")

        # Resolve create_kwargs before creating the sandbox and tolerate an
        # explicit None: failing here afterwards would leave a sandbox that was
        # created but never registered, so it could never be destroyed.
        create_kwargs = kwargs.get("create_kwargs") or {}

        try:
            resp = create_sandbox()
            val = _as_dict_return(resp)

            # Resolve the session_id: accept a dict directly, otherwise try to
            # decode the value as a JSON object.
            payload = val
            if not (isinstance(payload, dict) and "session_id" in payload):
                try:
                    payload = json.loads(val)
                except (TypeError, ValueError) as exc:
                    raise RuntimeError(f"CreateSandBox returned unexpected payload: {val!r}") from exc
                if not (isinstance(payload, dict) and "session_id" in payload):
                    raise RuntimeError(f"CreateSandBox returned unexpected payload: {val!r}")
            session_id = payload["session_id"]

            logger.info(f"[ExecuteShell] Created sandbox for instance {instance_id[:8]}: session_id={session_id[:8]}")

            # Store the instance state in both the global registry and the
            # local dict (the same dict object is shared by both).
            instance_data = {
                "session_id": session_id,
                "execution_history": [],
                "ground_truth": create_kwargs.get("ground_truth"),  # kept for reward calculation
                "ref_count": 1,  # reference count
            }
            self._instance_dict[instance_id] = instance_data
            _GLOBAL_SANDBOX_REGISTRY[instance_id] = instance_data

            return instance_id, ToolResponse(text=f"Sandbox created: {session_id[:8]}")

        except Exception as e:
            logger.error(f"Failed to create sandbox: {e}")
            raise

    @rollout_trace_op
    async def execute(
        self,
        instance_id: str,
        parameters: dict[str, Any],
        **kwargs
    ) -> tuple[ToolResponse, float, dict]:
        """
        Run a shell command in the instance's sandbox via sandbox_execute_command.

        Args:
            instance_id: Instance ID; must have been registered on this tool by create().
            parameters: Tool parameters; ``command`` is used (coerced to str).

        Returns:
            (tool response, step reward, metrics dict). The response text is a
            JSON object with "success", "code", "output" and "stderr". The step
            reward is 0.0 after a completed call, -1.0 when the sandbox is not
            initialized and -0.1 when the call itself raised.
        """
        logger.info(f"🔧 [ExecuteShell] Starting execution for instance {instance_id[:8]}...")

        if instance_id not in self._instance_dict:
            logger.error(f"❌ [ExecuteShell] Sandbox not initialized for {instance_id[:8]}")
            return ToolResponse(text="Error: Sandbox not initialized"), -1.0, {}

        session_id = self._instance_dict[instance_id]["session_id"]

        # Coerce to str before the first use: slicing a non-string value
        # (e.g. an int) would otherwise raise outside the try block below.
        command = parameters.get("command", "")
        if not isinstance(command, str):
            command = str(command)

        logger.info(f"📝 [ExecuteShell] Session: {session_id[:8]}, Command: {command[:100]}...")

        try:
            logger.info("▶️  [ExecuteShell] Calling sandbox_execute_command...")
            resp = sandbox_execute_command(session_id=session_id, command=command)
            logger.info("✅ [ExecuteShell] Command completed")
            val = _as_dict_return(resp)

            if isinstance(val, dict):
                # A dict without "code" is treated as a failure (rc=1).
                rc = int(val.get("code", 1))
                stdout = str(val.get("stdout", ""))
                stderr = str(val.get("stderr", ""))
            else:
                # Any non-dict payload is reported as a failure, with its text as stderr.
                rc, stdout, stderr = 1, "", str(val)

            success = (rc == 0)
            # Prefer stdout; fall back to stderr when stdout is empty.
            output = stdout if stdout else stderr

            logger.info(f"{'✅' if success else '❌'} [ExecuteShell] Execution {'succeeded' if success else 'failed'} (rc={rc}), Output length: {len(output)} chars")

            # Record the call in the execution history (used by calc_reward).
            self._instance_dict[instance_id]["execution_history"].append({
                "command": command,
                "success": success,
                "output": output,
                "returncode": rc
            })

            result_text = json.dumps({
                "success": success,
                "code": rc,
                "output": output,
                "stderr": stderr
            }, ensure_ascii=False)

            logger.info(f"📤 [ExecuteShell] Returning result: success={success}, output preview: {output[:100]}...")

            # No per-step reward; calc_reward computes the final score.
            step_reward = 0.0
            metrics = {
                "success": success,
                "output_length": len(output)
            }

            return ToolResponse(text=result_text), step_reward, metrics

        except Exception as e:
            error_msg = f"Command execution error: {type(e).__name__}: {e}"
            logger.error(error_msg)
            # ensure_ascii=False keeps non-ASCII error text readable, matching
            # the success payload above.
            return ToolResponse(text=json.dumps({
                "success": False,
                "code": 1,
                "output": "",
                "stderr": error_msg
            }, ensure_ascii=False)), -0.1, {"success": False}

    async def calc_reward(self, instance_id: str, **kwargs) -> float:
        """
        Compute the final reward for an instance.

        Reward strategy:
        1. Base reward: fraction of successful executions, weighted 0.3.
        2. Output reward: when a ground truth is stored and ``model_output``
           contains a ``**Answer**:`` line, 0.7 if that line contains the
           ground truth (verbatim, or after removing commas and spaces), or
           0.5 if the ground truth is longer than 10 characters and any of its
           words longer than 3 characters occurs in the answer.

        Args:
            instance_id: Instance ID.
            kwargs: ``model_output`` is read as the model's final text.

        Returns:
            Final reward in the range 0.0 - 1.0. 0.0 is returned when the
            instance is unknown to this tool or has no execution history.
        """
        if instance_id not in self._instance_dict:
            return 0.0

        record = self._instance_dict[instance_id]
        history = record["execution_history"]
        ground_truth = record.get("ground_truth")

        if not history:
            return 0.0

        # 1. Base reward: execution success rate (weight 0.3).
        success_count = sum(1 for h in history if h["success"])
        success_rate = success_count / len(history)
        base_reward = 0.3 * success_rate

        # 2. Output reward: does the answer line contain the ground truth (weight 0.7)?
        output_reward = 0.0
        model_output = kwargs.get("model_output", "")

        # Compare against None rather than truthiness so that falsy but valid
        # ground truths such as 0 are still scored. An empty or whitespace-only
        # ground truth is skipped: "" is a substring of every answer.
        ground_truth_str = "" if ground_truth is None else str(ground_truth).strip()

        if ground_truth_str and isinstance(model_output, str) and model_output:
            import re
            # Capture the text following "**Answer**:" up to the end of that line.
            answer_match = re.search(r"\*\*Answer\*\*:\s*(.*?)(?:\n|$)", model_output, re.DOTALL)
            if answer_match:
                answer_text = answer_match.group(1).strip()
                compact_truth = ground_truth_str.replace(",", "").replace(" ", "")
                compact_answer = answer_text.replace(",", "").replace(" ", "")

                # a) Verbatim containment.
                if ground_truth_str in answer_text:
                    output_reward = 0.7
                # b) Containment ignoring commas and spaces (e.g. "1,000" vs "1000").
                #    Guarded so a ground truth made only of commas/spaces cannot match.
                elif compact_truth and compact_truth in compact_answer:
                    output_reward = 0.7
                # c) Partial match for long answers: any word longer than 3 characters.
                elif len(ground_truth_str) > 10 and any(
                    word in answer_text for word in ground_truth_str.split() if len(word) > 3
                ):
                    output_reward = 0.5

        total_reward = base_reward + output_reward

        logger.debug(
            f"Reward calculation for {instance_id[:8]}: "
            f"base={base_reward:.2f}, output={output_reward:.2f}, total={total_reward:.2f}"
        )

        return total_reward

    async def release(self, instance_id: str, **kwargs) -> None:
        """
        Release this tool's hold on a sandbox (delayed-destroy policy, shared with ExecuteCode).

        The instance is always dropped from this tool's local dict and its
        shared reference count is decremented. The sandbox itself is destroyed,
        and removed from the global registry, only when ``force_destroy`` is
        truthy or the reference count has dropped below zero. Per the original
        author's note, verl calls release right after every tool call while the
        sandbox should stay alive for the whole episode.

        Args:
            instance_id: Instance ID.
            kwargs: Optional ``force_destroy=True`` to force destruction.
        """
        if instance_id not in _GLOBAL_SANDBOX_REGISTRY:
            # Nothing shared to release; just forget the local entry.
            self._instance_dict.pop(instance_id, None)
            return

        shared = _GLOBAL_SANDBOX_REGISTRY[instance_id]
        shared["ref_count"] -= 1
        ref_count = shared["ref_count"]
        session_id = shared["session_id"]
        force_destroy = kwargs.get("force_destroy", False)

        logger.debug(f"[ExecuteShell] Release called for {instance_id[:8]}, ref_count={ref_count}, force_destroy={force_destroy}")

        # Drop only the local entry; the global entry stays for later reuse.
        self._instance_dict.pop(instance_id, None)

        should_destroy = force_destroy or ref_count < 0
        if not should_destroy:
            return

        try:
            destroy_sandbox(session_id=session_id)
            logger.info(f"[ExecuteShell] Destroyed sandbox {session_id[:8]} for instance {instance_id[:8]}")
        except Exception as e:
            logger.warning(f"[ExecuteShell] Failed to destroy sandbox {session_id[:8]}: {e}")
        finally:
            # Forget the registry entry whether or not destruction succeeded.
            _GLOBAL_SANDBOX_REGISTRY.pop(instance_id, None)

