from typing import Dict, List, Any
import os
from sandbox_fusion import set_sandbox_endpoint, run_concurrent, run_code, RunCodeRequest, RunStatus
from tool.base import BaseTool


class ProxyManager:
    """Context manager that temporarily removes proxy environment variables.

    Both lowercase and uppercase spellings of ``http_proxy``, ``https_proxy``
    and ``all_proxy`` are removed. ``urllib.request.getproxies()`` (and hence
    ``requests``) honours either spelling, so removing only one of them does
    not reliably disable the proxy. Every variable that was set on entry is
    restored on exit, including when the ``with`` body raises.

    Note: ``os.environ`` is process-global, so the removal is visible to every
    thread for the duration of the ``with`` block.
    """

    _PROXY_VARS = (
        "http_proxy",
        "https_proxy",
        "all_proxy",
        "HTTP_PROXY",
        "HTTPS_PROXY",
        "ALL_PROXY",
    )

    def __enter__(self):
        # name -> original value, recorded only for variables that were actually set.
        self._saved = {}
        for var in self._PROXY_VARS:
            value = os.environ.pop(var, None)
            if value is not None:
                self._saved[var] = value
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Put back exactly what was removed; variables that were unset stay unset.
        os.environ.update(self._saved)
        # Implicitly returns None, so exceptions raised in the body propagate.

class PythonTool(BaseTool):
    """Tool that executes model-supplied Python code in a SandboxFusion sandbox.

    Each snippet is submitted through ``sandbox_fusion.run_code``. Every
    snippet yields a ``{"content": str, "success": bool}`` dict.
    """

    name = "python"
    description = "Python code sandbox, which can be used to execute Python code."
    parameters = {
        "type": "object",
        "properties": {
            "code": {
                "type": "string",
                "description": "The Python code to execute. The Python code should be complete scripts, including necessary imports. IMPORTANT: Use print() statements to output any results you want to see, otherwise they won't be visible.",
            }
        },
        "required": ["code"],
    }

    # Port of the sandbox HTTP server; combined with ``ip_address`` to build the endpoint.
    sandbox_port = 8080

    def __init__(self):
        super().__init__()
        # Passed as RunCodeRequest.run_timeout (presumably seconds — confirm with sandbox_fusion docs).
        self.run_timeout = 10
        # Maximum number of requests run_concurrent keeps in flight at once.
        self.concurrency = 64
        # Retry budget forwarded to run_code for each request.
        self.max_attempts = 5

    @staticmethod
    def _format_result(result) -> Dict[str, Any]:
        """Convert one sandbox response into a ``{"content", "success"}`` dict.

        On success the content is the stripped stdout. Empty or
        whitespace-only output is replaced by a fixed placeholder message. On
        failure the content is the first non-blank value among the runtime
        stderr, the compile stderr and the sandbox message, in that order. If
        all are blank it is ``"Unknown error"``.
        """
        if result.status == RunStatus.Success:
            stdout = (result.run_result.stdout if result.run_result else None) or ""
            stdout = stdout.strip()
            # Whitespace-only output (e.g. a bare print()) counts as "no output".
            return {"content": stdout or "Execution successful but no output", "success": True}

        candidates = (
            result.run_result.stderr if result.run_result else None,
            result.compile_result.stderr if result.compile_result else None,
            result.message,
        )
        for candidate in candidates:
            if candidate and candidate.strip():
                return {"content": candidate.strip(), "success": False}
        return {"content": "Unknown error", "success": False}

    def batch_execute(self, args_list: List[Dict], ip_address="localhost") -> List[Dict[str, Any]]:
        """Run several code snippets concurrently in the sandbox at ``ip_address``.

        Args:
            args_list: Tool-call argument dicts. Each snippet is read from the
                ``"code"`` key, and a missing key runs an empty script.
            ip_address: Host of the sandbox server. It is reached over HTTP on
                ``self.sandbox_port``.

        Returns:
            One ``{"content": str, "success": bool}`` dict per input. Callers
            (see ``execute``) rely on input order, which assumes
            ``run_concurrent`` preserves order — TODO confirm.
        """
        endpoint = f'http://{ip_address}:{self.sandbox_port}'
        set_sandbox_endpoint(endpoint)
        print('batch_execute() ip address =====', endpoint)
        if not args_list:
            return []

        call_kwargs = [
            {
                "request": RunCodeRequest(
                    run_timeout=self.run_timeout,
                    code=args.get("code", ""),
                    language='python',
                ),
                'max_attempts': self.max_attempts,
            }
            for args in args_list
        ]
        # Proxy env vars are stripped while the requests run, presumably so
        # sandbox traffic is not routed through a configured proxy.
        with ProxyManager():
            results = run_concurrent(run_code, kwargs=call_kwargs, concurrency=self.concurrency)
        return [self._format_result(result) for result in results]

    def execute(self, args: Dict, ip_address="localhost", **kwargs) -> Dict[str, Any]:
        """Run a single snippet. This is a thin wrapper over ``batch_execute``.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        return self.batch_execute([args], ip_address=ip_address)[0]