# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
基于 fileagent_libs.fileagent_mcp_host 的沙盒工具
支持 ExecuteCode 和 ExecuteShell 两个工具
与 fileagent_stateful_tool.py 接口对齐
"""

import base64
import json
import logging
import os
import subprocess
from typing import Any, Optional
from uuid import uuid4

from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema, ToolResponse
from verl.utils.rollout_trace import rollout_trace_op

# Import the sandbox client. If the optional dependency is missing, the module
# still imports, but the tool classes below refuse to instantiate.
try:
    from fileagent_libs.fileagent_mcp_host import (
        async_run_code,
        async_sandbox_create_sandbox,
        async_sandbox_destroy_sandbox,
        async_sandbox_execute_command,
        async_sandbox_write_file,
    )
except ImportError:
    FILEAGENT_AVAILABLE = False
    logging.warning("fileagent_mcp_host not available, FileagentSandboxTool will not work")
else:
    FILEAGENT_AVAILABLE = True

logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))

# A file to pre-write into a sandbox: (filename inside the sandbox, raw bytes).
FILE = tuple[str, bytes]

# Process-wide map of instance_id -> sandbox session_id. Different tool classes
# created for the same instance_id look it up here to share one sandbox.
_SHARED_SANDBOX_REGISTRY: dict[str, str] = {}


class FileagentExecuteCodeTool(BaseTool):
    """Run source code inside a fileagent sandbox through ``async_run_code``.

    The interface is aligned with ``FileAgentSandboxTool``. Each sandbox session is
    recorded both in this tool's ``self._session_ids`` and in the module-level
    ``_SHARED_SANDBOX_REGISTRY`` (keyed by ``instance_id``), so another tool created
    with the same ``instance_id`` (e.g. ``FileagentExecuteShellTool``) reuses it.
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        super().__init__(config, tool_schema)
        if not FILEAGENT_AVAILABLE:
            raise ImportError("fileagent_mcp_host module is required for FileagentExecuteCodeTool")

        # instance_id -> sandbox session_id attached to this tool object.
        self._session_ids: dict[str, str] = {}
        self.default_language = config.get("default_language", "python")
        # The timeout can be configured as ``default_timeout`` or as
        # ``custom.max_execution_time``; ``custom`` may be absent or None.
        self.default_timeout = config.get("default_timeout") or (config.get("custom") or {}).get(
            "max_execution_time", 60
        )
        logger.info(f"Init FileagentExecuteCodeTool with config: {config}")

    async def create(
        self,
        instance_id: Optional[str] = None,
        files: Optional[list[FILE]] = None,
        **kwargs,
    ) -> tuple[str, ToolResponse]:
        """Create the sandbox for ``instance_id``, or reuse an already shared one.

        If ``_SHARED_SANDBOX_REGISTRY`` already holds a sandbox for this instance,
        it is attached to this tool and returned as-is. In that case ``files`` and
        prewrites are NOT written. Otherwise a new sandbox is created, registered
        and pre-populated.

        Args:
            instance_id: Instance ID; ``BaseTool.create`` decides the ID when None.
            files: Files to pre-write, as ``[(filename, content_bytes), ...]``.
            **kwargs: Forwarded to ``BaseTool.create``. ``kwargs["create_kwargs"]``
                may contain ``"prewrites"``: a list of ``{"filename", "fpath"}``
                dicts whose local ``fpath`` is copied into the sandbox as ``filename``.
                Individual prewrite failures are logged and skipped.

        Returns:
            ``(instance_id, ToolResponse)``.

        Raises:
            Any exception from sandbox creation or from writing ``files``. In that
            case, a sandbox created by this call is unregistered and destroyed
            (best effort) before re-raising.
        """
        created_instance_id, tool_response = await super().create(instance_id, **kwargs)

        # Reuse a sandbox that another tool already created for this instance.
        if created_instance_id in _SHARED_SANDBOX_REGISTRY:
            session_id = _SHARED_SANDBOX_REGISTRY[created_instance_id]
            self._session_ids[created_instance_id] = session_id
            logger.info(f"[ExecuteCode] Reusing shared sandbox {session_id[:8]} for instance {created_instance_id[:8]}")
            return created_instance_id, ToolResponse(text=f"Reused sandbox: {session_id[:8]}")

        session_id = None
        try:
            response = await async_sandbox_create_sandbox()
            session_id = json.loads(response.result)["session_id"]

            # Register locally and globally so other tools can share the sandbox.
            self._session_ids[created_instance_id] = session_id
            _SHARED_SANDBOX_REGISTRY[created_instance_id] = session_id
            logger.info(f"[ExecuteCode] Created NEW sandbox {session_id[:8]} for instance {created_instance_id[:8]}")

            # Source 1: explicit ``files`` argument. Failures propagate.
            for filename, content in files or ():
                await self._write_file(session_id, filename, content)

            # Source 2: ``create_kwargs["prewrites"]``. Best effort, per file.
            create_kwargs = kwargs.get("create_kwargs") or {}
            await self._apply_prewrites(session_id, create_kwargs.get("prewrites") or [])

            return created_instance_id, tool_response

        except Exception as e:
            logger.error(f"Failed to create sandbox: {e}; open_fds={self._count_open_fds()}")
            if session_id is not None:
                # Don't leave a half-initialized sandbox registered for reuse.
                await self._discard_sandbox(created_instance_id, session_id)
            raise

    @staticmethod
    async def _write_file(session_id: str, filename: str, content: bytes) -> None:
        """Upload ``content`` to ``filename`` in the sandbox as base64-encoded binary."""
        await async_sandbox_write_file(
            session_id,
            filename,
            base64.b64encode(content).decode("ascii"),
            binary=True,
        )

    async def _apply_prewrites(self, session_id: str, prewrites: list) -> None:
        """Copy each prewrite's local ``fpath`` into the sandbox; log and skip failures."""
        for pw in prewrites:
            filename = pw.get("filename")
            fpath = pw.get("fpath")

            if not filename:
                continue

            if not (fpath and os.path.exists(fpath)):
                logger.warning(f"[ExecuteCode] Skipping {filename}: fpath not found or not exists")
                continue

            try:
                with open(fpath, "rb") as f:
                    content = f.read()
                await self._write_file(session_id, filename, content)
            except Exception as e:
                logger.warning(f"[ExecuteCode] Failed to write {filename}: {e}")

    async def _discard_sandbox(self, instance_id: str, session_id: str) -> None:
        """Unregister ``session_id`` for ``instance_id`` and destroy it; errors are logged."""
        if self._session_ids.get(instance_id) == session_id:
            del self._session_ids[instance_id]
        if _SHARED_SANDBOX_REGISTRY.get(instance_id) == session_id:
            del _SHARED_SANDBOX_REGISTRY[instance_id]
        try:
            await async_sandbox_destroy_sandbox(session_id)
        except Exception as destroy_err:
            logger.warning(
                f"[ExecuteCode] Failed to destroy sandbox {session_id[:8]} after init failure: {destroy_err}"
            )

    @staticmethod
    def _count_open_fds() -> str:
        """Return this process's open file-descriptor count (diagnostics only).

        Reads ``/proc/self/fd`` in-process. That avoids spawning a subprocess, which
        would block the event loop and needs extra descriptors, possibly while they
        are exhausted. Returns an ``"unavailable (...)"`` string when ``/proc`` is
        not readable (e.g. non-Linux hosts).
        """
        try:
            return str(len(os.listdir("/proc/self/fd")))
        except OSError as err:
            return f"unavailable ({err})"

    @rollout_trace_op
    async def execute(
        self,
        instance_id: str,
        parameters: dict[str, Any],
        **kwargs,
    ) -> tuple[ToolResponse, float, dict]:
        """Run ``parameters["code"]`` in the sandbox attached to ``instance_id``.

        Args:
            instance_id: Instance from ``create``. Raises ``KeyError`` if this tool
                holds no session for it (e.g. after ``release``).
            parameters: Tool arguments: ``code`` (default ``""``) and ``language``
                (default ``self.default_language``).

        Returns:
            ``(ToolResponse, step_reward, metrics)``. The response text is JSON
            ``{"code", "stdout", "stderr"}``, or ``{"success": false, "error": ...}``
            if the call raised. The step reward is always 0. ``metrics["succeeded"]``
            is true iff the normalized exit code is 0.
        """
        session_id = self._session_ids[instance_id]
        code = parameters.get("code", "")
        language = parameters.get("language", self.default_language)
        # NOTE(review): ``parameters["timeout"]`` / ``self.default_timeout`` are not
        # forwarded: ``async_run_code`` is called without a timeout here. Wire it up
        # once the host API's timeout parameter is confirmed.
        logger.info(f"[ExecuteCode] Executing code in session {session_id[:8]}")

        try:
            response = await async_run_code(code, language=language, session_id=session_id)
            result_dict = self._normalize_response(response)

            succeeded = result_dict["code"] == 0
            tool_metric = {"name": "ExecuteCode", "succeeded": succeeded}
            return ToolResponse(text=json.dumps(result_dict, ensure_ascii=False)), 0, tool_metric

        except Exception as e:
            error_msg = f"Code execution error: {type(e).__name__}: {e}"
            logger.error(error_msg)
            tool_metric = {"name": "ExecuteCode", "succeeded": False}
            return ToolResponse(text=json.dumps({
                "success": False,
                "error": error_msg
            })), 0, tool_metric

    @staticmethod
    def _normalize_response(response: Any) -> dict[str, Any]:
        """Normalize a ``async_run_code`` result to ``{"code", "stdout", "stderr"}``.

        The following shapes are tried in order:

        1. JSON object in ``response.result``.
        2. Attributes of an object with ``__dict__`` (``code``, else ``returncode``).
        3. A plain dict.

        Anything else becomes a failure record naming the type. Non-integer
        codes in shapes 2/3 raise, which the caller turns into an error response.
        """
        if hasattr(response, "result"):
            try:
                parsed = json.loads(response.result)
                if isinstance(parsed, dict):
                    return {
                        "code": int(parsed.get("code", 1)),
                        "stdout": str(parsed.get("stdout", "")),
                        "stderr": str(parsed.get("stderr", "")),
                    }
            except Exception:
                # Not JSON / malformed fields: fall through to the other shapes.
                pass

        if hasattr(response, "__dict__"):
            return {
                "code": int(getattr(response, "code", getattr(response, "returncode", 1))),
                "stdout": str(getattr(response, "stdout", "")),
                "stderr": str(getattr(response, "stderr", "")),
            }

        if isinstance(response, dict):
            return {
                "code": int(response.get("code", 1)),
                "stdout": str(response.get("stdout", "")),
                "stderr": str(response.get("stderr", "")),
            }

        return {"code": 1, "stdout": "", "stderr": f"Unexpected response type: {type(response).__name__}"}

    async def release(self, instance_id: str, **kwargs) -> None:
        """Release ``instance_id`` using a deferred-destroy policy.

        By default, only this tool's local mapping is dropped. The shared registry
        entry and the sandbox are kept so later turns of the same trajectory can
        reuse them.

        Pass ``force_destroy=True`` to also remove the registry entry and destroy
        the sandbox. Destroy errors are logged, not raised.

        NOTE(review): without a ``force_destroy`` call, sandboxes and registry
        entries are never cleaned up by this class.
        """
        session_id = self._session_ids.pop(instance_id, None)
        instance_short = instance_id[:8]
        session_short = session_id[:8] if session_id else "None"

        if not kwargs.get("force_destroy", False):
            # Keep the shared registry entry so the sandbox can be reused later.
            logger.info(f"[ExecuteCode] Released (but not destroyed) sandbox for instance {instance_short} (session {session_short})")
            return

        shared_session = _SHARED_SANDBOX_REGISTRY.pop(instance_id, None)
        # Fall back to the shared session if this tool never attached one.
        session_id = session_id or shared_session
        session_short = session_id[:8] if session_id else "None"

        if not session_id:
            logger.info(f"[ExecuteCode] No sandbox to destroy for instance {instance_short}")
            return

        try:
            await async_sandbox_destroy_sandbox(session_id)
            logger.info(f"[ExecuteCode] Force destroyed sandbox session {session_short} for instance {instance_short}")
        except Exception as e:
            logger.error(f"[ExecuteCode] Failed to destroy sandbox session {session_short} for instance {instance_short}: {e}")


class FileagentExecuteShellTool(BaseTool):
    """Run shell commands inside a fileagent sandbox through ``async_sandbox_execute_command``.

    The interface is aligned with ``FileAgentSandboxTool``. Each sandbox session is
    recorded both in this tool's ``self._session_ids`` and in the module-level
    ``_SHARED_SANDBOX_REGISTRY`` (keyed by ``instance_id``), so another tool created
    with the same ``instance_id`` (e.g. ``FileagentExecuteCodeTool``) reuses it.
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        super().__init__(config, tool_schema)
        if not FILEAGENT_AVAILABLE:
            raise ImportError("fileagent_mcp_host module is required for FileagentExecuteShellTool")

        # instance_id -> sandbox session_id attached to this tool object.
        self._session_ids: dict[str, str] = {}
        logger.info(f"Init FileagentExecuteShellTool with config: {config}")

    async def create(
        self,
        instance_id: Optional[str] = None,
        files: Optional[list[FILE]] = None,
        **kwargs,
    ) -> tuple[str, ToolResponse]:
        """Create the sandbox for ``instance_id``, or reuse an already shared one.

        If ``_SHARED_SANDBOX_REGISTRY`` already holds a sandbox for this instance
        (possibly created by ``FileagentExecuteCodeTool``), it is attached to this
        tool and returned as-is. In that case ``files`` and prewrites are NOT
        written. Otherwise a new sandbox is created, registered and pre-populated.

        Args:
            instance_id: Instance ID; ``BaseTool.create`` decides the ID when None.
            files: Files to pre-write, as ``[(filename, content_bytes), ...]``.
            **kwargs: Forwarded to ``BaseTool.create``. ``kwargs["create_kwargs"]``
                may contain ``"prewrites"``: a list of ``{"filename", "fpath"}``
                dicts whose local ``fpath`` is copied into the sandbox as ``filename``.
                Individual prewrite failures are logged and skipped.

        Returns:
            ``(instance_id, ToolResponse)``.

        Raises:
            Any exception from sandbox creation or from writing ``files``. In that
            case, a sandbox created by this call is unregistered and destroyed
            (best effort) before re-raising.
        """
        created_instance_id, tool_response = await super().create(instance_id, **kwargs)

        # Reuse a sandbox that another tool already created for this instance.
        if created_instance_id in _SHARED_SANDBOX_REGISTRY:
            session_id = _SHARED_SANDBOX_REGISTRY[created_instance_id]
            self._session_ids[created_instance_id] = session_id
            logger.info(f"[ExecuteShell] Reusing shared sandbox {session_id[:8]} for instance {created_instance_id[:8]}")
            return created_instance_id, ToolResponse(text=f"Reused sandbox: {session_id[:8]}")

        session_id = None
        try:
            response = await async_sandbox_create_sandbox()
            session_id = json.loads(response.result)["session_id"]

            # Register locally and globally so other tools can share the sandbox.
            self._session_ids[created_instance_id] = session_id
            _SHARED_SANDBOX_REGISTRY[created_instance_id] = session_id
            logger.info(f"[ExecuteShell] Created NEW sandbox {session_id[:8]} for instance {created_instance_id[:8]}")

            # Source 1: explicit ``files`` argument. Failures propagate.
            for filename, content in files or ():
                await self._write_file(session_id, filename, content)

            # Source 2: ``create_kwargs["prewrites"]``. Best effort, per file.
            create_kwargs = kwargs.get("create_kwargs") or {}
            await self._apply_prewrites(session_id, create_kwargs.get("prewrites") or [])

            return created_instance_id, tool_response

        except Exception as e:
            logger.error(f"Failed to create sandbox: {e}; open_fds={self._count_open_fds()}")
            if session_id is not None:
                # Don't leave a half-initialized sandbox registered for reuse.
                await self._discard_sandbox(created_instance_id, session_id)
            raise

    @staticmethod
    async def _write_file(session_id: str, filename: str, content: bytes) -> None:
        """Upload ``content`` to ``filename`` in the sandbox as base64-encoded binary."""
        await async_sandbox_write_file(
            session_id,
            filename,
            base64.b64encode(content).decode("ascii"),
            binary=True,
        )

    async def _apply_prewrites(self, session_id: str, prewrites: list) -> None:
        """Copy each prewrite's local ``fpath`` into the sandbox; log and skip failures."""
        for pw in prewrites:
            filename = pw.get("filename")
            fpath = pw.get("fpath")

            if not filename:
                continue

            if not (fpath and os.path.exists(fpath)):
                logger.warning(f"[ExecuteShell] Skipping {filename}: fpath not found or not exists")
                continue

            try:
                with open(fpath, "rb") as f:
                    content = f.read()
                await self._write_file(session_id, filename, content)
            except Exception as e:
                logger.warning(f"[ExecuteShell] Failed to write {filename}: {e}")

    async def _discard_sandbox(self, instance_id: str, session_id: str) -> None:
        """Unregister ``session_id`` for ``instance_id`` and destroy it; errors are logged."""
        if self._session_ids.get(instance_id) == session_id:
            del self._session_ids[instance_id]
        if _SHARED_SANDBOX_REGISTRY.get(instance_id) == session_id:
            del _SHARED_SANDBOX_REGISTRY[instance_id]
        try:
            await async_sandbox_destroy_sandbox(session_id)
        except Exception as destroy_err:
            logger.warning(
                f"[ExecuteShell] Failed to destroy sandbox {session_id[:8]} after init failure: {destroy_err}"
            )

    @staticmethod
    def _count_open_fds() -> str:
        """Return this process's open file-descriptor count (diagnostics only).

        Reads ``/proc/self/fd`` in-process. That avoids spawning a subprocess, which
        would block the event loop and needs extra descriptors, possibly while they
        are exhausted. Returns an ``"unavailable (...)"`` string when ``/proc`` is
        not readable (e.g. non-Linux hosts).
        """
        try:
            return str(len(os.listdir("/proc/self/fd")))
        except OSError as err:
            return f"unavailable ({err})"

    @rollout_trace_op
    async def execute(
        self,
        instance_id: str,
        parameters: dict[str, Any],
        **kwargs,
    ) -> tuple[ToolResponse, float, dict]:
        """Run ``parameters["command"]`` in the sandbox attached to ``instance_id``.

        Args:
            instance_id: Instance from ``create``. Raises ``KeyError`` if this tool
                holds no session for it (e.g. after ``release``).
            parameters: Tool arguments: ``command`` (default ``""``).

        Returns:
            ``(ToolResponse, step_reward, metrics)``. The response text is JSON
            ``{"code", "stdout", "stderr"}``, or ``{"success": false, "error": ...}``
            if the call raised. The step reward is always 0. ``metrics["succeeded"]``
            is true iff the normalized exit code is 0.
        """
        session_id = self._session_ids[instance_id]
        command = parameters.get("command", "")
        logger.info(f"[ExecuteShell] Executing command in session {session_id[:8]}: {command[:50]}...")

        try:
            response = await async_sandbox_execute_command(session_id=session_id, command=command)
            result_dict = self._normalize_response(response)

            succeeded = result_dict["code"] == 0
            tool_metric = {"name": "ExecuteShell", "succeeded": succeeded}
            return ToolResponse(text=json.dumps(result_dict, ensure_ascii=False)), 0, tool_metric

        except Exception as e:
            error_msg = f"Command execution error: {type(e).__name__}: {e}"
            logger.error(error_msg)
            tool_metric = {"name": "ExecuteShell", "succeeded": False}
            return ToolResponse(text=json.dumps({
                "success": False,
                "error": error_msg
            })), 0, tool_metric

    @staticmethod
    def _normalize_response(response: Any) -> dict[str, Any]:
        """Normalize a shell-command result to ``{"code", "stdout", "stderr"}``.

        The following shapes are tried in order:

        1. ``response.as_dict()["return"]``, either a dict or a JSON string
           encoding one. A non-JSON string becomes ``stderr`` with code 1.
        2. A plain dict.

        Anything else becomes a failure record naming the type. A non-integer
        code in shape 2 raises, which the caller turns into an error response.
        """
        if hasattr(response, "as_dict"):
            try:
                val = response.as_dict().get("return")
                if isinstance(val, str):
                    try:
                        val = json.loads(val)
                    except ValueError:
                        # Plain (non-JSON) text: surface it as the error output.
                        return {"code": 1, "stdout": "", "stderr": str(val)}
                if isinstance(val, dict):
                    return {
                        "code": int(val.get("code", 1)),
                        "stdout": str(val.get("stdout", "")),
                        "stderr": str(val.get("stderr", "")),
                    }
            except Exception:
                # Malformed payload: fall through to the other shapes.
                pass

        if isinstance(response, dict):
            return {
                "code": int(response.get("code", 1)),
                "stdout": str(response.get("stdout", "")),
                "stderr": str(response.get("stderr", "")),
            }

        return {"code": 1, "stdout": "", "stderr": f"Unexpected response type: {type(response).__name__}"}

    async def release(self, instance_id: str, **kwargs) -> None:
        """Release ``instance_id`` using a deferred-destroy policy.

        By default, only this tool's local mapping is dropped. The shared registry
        entry and the sandbox are kept so later turns of the same trajectory can
        reuse them.

        Pass ``force_destroy=True`` to also remove the registry entry and destroy
        the sandbox. Destroy errors are logged, not raised.

        NOTE(review): without a ``force_destroy`` call, sandboxes and registry
        entries are never cleaned up by this class.
        """
        session_id = self._session_ids.pop(instance_id, None)
        instance_short = instance_id[:8]
        session_short = session_id[:8] if session_id else "None"

        if not kwargs.get("force_destroy", False):
            # Keep the shared registry entry so the sandbox can be reused later.
            logger.info(f"[ExecuteShell] Released (but not destroyed) sandbox for instance {instance_short} (session {session_short})")
            return

        shared_session = _SHARED_SANDBOX_REGISTRY.pop(instance_id, None)
        # Fall back to the shared session if this tool never attached one.
        session_id = session_id or shared_session
        session_short = session_id[:8] if session_id else "None"

        if not session_id:
            logger.info(f"[ExecuteShell] No sandbox to destroy for instance {instance_short}")
            return

        try:
            await async_sandbox_destroy_sandbox(session_id)
            logger.info(f"[ExecuteShell] Force destroyed sandbox session {session_short} for instance {instance_short}")
        except Exception as e:
            logger.error(f"[ExecuteShell] Failed to destroy sandbox session {session_short} for instance {instance_short}: {e}")

