import os
import io
import regex
import pickle
import traceback
import copy
import datetime
import dateutil.relativedelta
import multiprocess
from multiprocess import Pool
from typing import Any, Dict, Optional
from pebble import ProcessPool
from tqdm import tqdm
from concurrent.futures import TimeoutError
from functools import partial
from timeout_decorator import timeout
from contextlib import redirect_stdout
import resource
import ast
import rdkit
from adme_py import ADME
import builtins
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator
from typing import List, Optional

# Process resource limits, applied at import time.
# NOTE(review): the value below is 100 GiB; the previous inline note said "4G",
# which does not match the arithmetic. Confirm which limit is intended.
max_memory = 100 * 1024 * 1024 * 1024       # 100 GiB
# RLIMIT_AS caps the address space; the tuple is (soft limit, hard limit).
resource.setrlimit(resource.RLIMIT_AS, (max_memory, max_memory * 2))


class DangerousVisitor(ast.NodeVisitor):
    """AST walker that rejects attribute-style calls to a few risky names.

    Only calls written as ``<something>.<attr>(...)`` are inspected; the
    check looks at the attribute name alone, not at the object it is
    accessed on. Calls through a bare name are not examined.
    """

    def visit_Call(self, node):
        """Raise RuntimeError for a flagged call, otherwise keep walking."""
        callee = node.func
        flagged = isinstance(callee, ast.Attribute) and callee.attr in (
            "system",
            "remove",
            "rmtree",
        )
        if flagged:
            raise RuntimeError(f"语法树检查发现危险函数调用: {callee.attr}")
        # Descend so that calls nested in arguments are inspected as well.
        self.generic_visit(node)


class GenericRuntime:
    """Thin wrapper around ``exec``/``eval`` with one globals dict per instance.

    Subclasses configure the environment through class attributes:
    ``GLOBAL_DICT`` seeds the namespace, ``LOCAL_DICT`` is copied into
    ``_local_vars`` (which the methods below never read), and ``HEADERS``
    lists snippets executed once during construction.
    """

    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    def __init__(self):
        # Shallow copies: the top-level dict is per instance, but any nested
        # objects remain shared with the class-level template.
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = None if not self.LOCAL_DICT else copy.copy(self.LOCAL_DICT)

        for header in self.HEADERS:
            self.exec_code(header)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` in this runtime's globals.

        Raises a message-less RuntimeError when the text matches one of the
        textual guard patterns for ``input(`` / ``os.system(``.
        """
        guard_patterns = (r"(\s|^)?input\(", r"(\s|^)?os.system\(")
        if any(regex.search(pattern, code_piece) for pattern in guard_patterns):
            raise RuntimeError()
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` against this runtime's globals and return the value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Copy every key/value pair of ``var_dict`` into the globals."""
        self._global_vars.update(var_dict.items())

    @property
    def answer(self):
        """Value bound to the global name ``answer`` (KeyError if unset)."""
        return self._global_vars["answer"]


# Runtime with a restricted __import__ and a whitelist of builtins


class SecureRuntime(GenericRuntime):
    """Runtime whose globals expose a whitelist of builtins and a filtered import.

    ``GLOBAL_DICT["__builtins__"]`` replaces the interpreter's builtins for
    code executed through this runtime, so only the names listed below are
    available, and ``import`` statements are routed through
    ``secure_import``.

    NOTE(review): this is a name-based filter, not a complete sandbox. Modules
    that are not listed (for example ``importlib``) and object introspection
    via the exposed ``type`` builtin are not restricted here.
    """

    def secure_import(name, *args, **kwargs):
        """Import ``name`` unless its top-level package is blacklisted.

        Stored as a plain function in the ``__builtins__`` dict, so it is
        called with the ``__import__`` signature (no ``self``).

        Args:
            name: Module name as written in the import statement; may be
                dotted (``"os.path"``).
            *args, **kwargs: Remaining ``__import__`` arguments, passed through.

        Returns:
            Whatever the real ``__import__`` returns.

        Raises:
            RuntimeError: If the top-level package of ``name`` is blacklisted.
        """
        BLACKLIST = {
            # System and operating-system access
            "os", "shutil", "subprocess", "sys", "pty", "platform", "posix", "winreg", "msvcrt",

            # File-system operations
            "pathlib", "glob", "tempfile", "fileinput", "fnmatch", "mmap",

            # Networking and communication
            "socket", "http", "urllib", "requests", "ftplib", "smtplib", "telnetlib",
            "paramiko", "socketserver", "asyncore", "asynchat", "ssl",

            # Process and thread control
            "multiprocessing", "threading", "concurrent", "asyncio", "signal",

            # Shell command execution
            "commands", "popen2", "pipes",

            # Code execution and introspection
            "code", "codeop", "dis", "compileall", "py_compile", "inspect",

            # Database access
            "sqlite3", "pymysql", "psycopg2",

            # Other dangerous modules
            "pickle", "marshal", "shelve",  # unsafe serialization
            "ctypes", "cffi",  # low-level system access
        }
        # ``import os.path`` arrives here as "os.path" yet still binds the
        # top-level ``os`` module, so the check must use the first component
        # rather than the full dotted name.
        top_level = name.partition(".")[0]
        if top_level in BLACKLIST:
            raise RuntimeError(f"禁止导入危险模块: {name}")
        return __import__(name, *args, **kwargs)

    GLOBAL_DICT = {
        "__builtins__": {
            "__import__": secure_import,  # override the default import hook
            "__build_class__": builtins.__build_class__,
            "__name__": "__main__",
            "__file__": "<string>",
            "__doc__": None,
            "__package__": None,
            "print": print,               # builtins that are kept available
            "range": range,
            "enumerate": enumerate,
            "len": len,
            # Numeric helpers
            "abs": abs, "min": min, "max": max, "sum": sum,
            "round": round, "int": int, "float": float,
            "pow": pow, "divmod": divmod,

            # String handling
            "str": str, "repr": repr, "format": format,
            "ord": ord, "chr": chr,

            # Containers and sequences
            "list": list, "tuple": tuple, "dict": dict,
            "set": set, "frozenset": frozenset,
            "sorted": sorted, "zip": zip,

            # Type checks
            "isinstance": isinstance, "type": type,
            "bool": bool,

            # Iteration
            "iter": iter, "next": next,

            # Miscellaneous helpers
            "hash": hash, "bin": bin, "hex": hex, "oct": oct,
            "all": all, "any": any,

            # Built-in exception classes
            "Exception": builtins.Exception,
            "ValueError": builtins.ValueError,
            "TypeError": builtins.TypeError,
            "NameError": builtins.NameError,
            "RuntimeError": builtins.RuntimeError,
            "IndexError": builtins.IndexError,
            "KeyError": builtins.KeyError,
            "ZeroDivisionError": builtins.ZeroDivisionError,
            "AttributeError": builtins.AttributeError,
            "ImportError": builtins.ImportError,
            "SyntaxError": builtins.SyntaxError,

        }
    }
    
class DateRuntime(GenericRuntime):
    """Runtime pre-seeded with date helpers under short global names."""

    GLOBAL_DICT = dict(
        datetime=datetime.datetime,
        # Both names below are bound to dateutil's relativedelta class.
        timedelta=dateutil.relativedelta.relativedelta,
        relativedelta=dateutil.relativedelta.relativedelta,
    )


class CustomDict(dict):
    """dict whose iteration walks a snapshot of the keys taken up front."""

    def __iter__(self):
        # Materialise the keys first, then hand out an iterator over that list.
        key_snapshot = list(super().__iter__())
        return iter(key_snapshot)


class ColorObjectRuntime(GenericRuntime):
    """Runtime in which the global name ``dict`` is bound to CustomDict."""

    GLOBAL_DICT = dict(dict=CustomDict)


class PythonExecutor:
    """Execute batches of Python snippets in worker processes.

    Each snippet is statically checked with ``DangerousVisitor``, run through
    the configured runtime, and reduced to a ``(result, report)`` pair where
    ``report`` is ``"Done"`` on success or the last traceback line on failure.
    """

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 5,
    ) -> None:
        """Store the execution settings.

        Args:
            runtime: Object providing ``exec_code``/``eval_code``; a fresh
                ``GenericRuntime`` is created when omitted.
            get_answer_symbol: Global variable name to read the answer from.
            get_answer_expr: Expression to evaluate for the answer.
            get_answer_from_stdout: Use captured stdout as the answer.
            timeout_length: Per-snippet timeout in seconds.
        """
        self.runtime = runtime if runtime is not None else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        # The pool is created lazily (see the ``pool`` property). Nothing in
        # this class uses it, and building it eagerly spawned cpu_count
        # processes per executor that were never closed.
        self._pool = None
        self.timeout_length = timeout_length

    @property
    def pool(self):
        """``multiprocess`` pool sized to the CPU count, created on first access."""
        if self._pool is None:
            self._pool = Pool(multiprocess.cpu_count())
        return self._pool

    @pool.setter
    def pool(self, value):
        self._pool = value

    def process_generation_to_code(self, gens: List[str]) -> List[List[str]]:
        """Split every snippet in ``gens`` into its list of lines."""
        return [g.split("\n") for g in gens]

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout=True,
        runtime=None,
        answer_symbol=None,
        answer_expr=None,
        timeout_length=10,
    ):
        """Run one snippet and return ``(result, report)``.

        The snippet is executed once. The result is the captured stdout when
        ``get_answer_from_stdout`` is set; ``answer_symbol`` and then
        ``answer_expr`` override it when given. If the result is still an
        empty string, all lines but the last are executed again and the last
        line is evaluated as an expression to obtain the result.

        Args:
            code: Snippet as a list of lines or as a single string.
            get_answer_from_stdout: Capture stdout and use it as the result.
            runtime: Object providing ``exec_code``, ``eval_code`` and
                ``_global_vars``.
            answer_symbol: Global variable name holding the result.
            answer_expr: Expression whose value is the result.
            timeout_length: Timeout in seconds applied to each exec/eval call.

        Returns:
            ``(result, "Done")`` on success, or ``("", <last traceback line>)``
            when anything fails, including the static check, a timeout, or a
            result that cannot be stringified or pickled.
        """
        try:
            lines = code.split("\n") if isinstance(code, str) else list(code)
            source = "\n".join(lines)

            # Static safety analysis happens before any code is executed.
            tree = ast.parse(source)
            DangerousVisitor().visit(tree)

            run = timeout(timeout_length)(runtime.exec_code)
            evaluate = timeout(timeout_length)(runtime.eval_code)

            result = ""
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    run(source)
                result = program_io.getvalue()
            elif answer_symbol or answer_expr:
                run(source)
            if answer_symbol:
                result = runtime._global_vars[answer_symbol]
            if answer_expr:
                result = evaluate(answer_expr)
            # isinstance guard: comparing e.g. an array with "" may not yield
            # a plain bool.
            if isinstance(result, str) and result == "":
                # Nothing produced so far: treat the last line as an
                # expression and use its value as the result.
                run("\n".join(lines[:-1]))
                result = evaluate(lines[-1])
            report = "Done"
            str(result)
            pickle.dumps(result)  # serialization check
        except BaseException:
            # Deliberately broad: SystemExit raised by the snippet must be
            # reported rather than terminate the worker process.
            result = ""
            report = traceback.format_exc().split("\n")[-2]
        return result, report

    def apply(self, code):
        """Execute a single snippet and return its ``(result, report, truncated)``."""
        return self.batch_apply([code])[0]

    @staticmethod
    def truncate(s, max_length=1000):
        """Shorten ``s`` to its head and tail joined by ``"..."``.

        Returns:
            ``(text, truncated)``; ``text`` is ``s`` itself when it is not
            longer than ``max_length``.
        """
        if len(s) <= max_length:
            return s, False
        half = max(max_length, 0) // 2
        # ``s[-half:]`` would return the whole string when half == 0, so the
        # tail is sliced from an explicit start index instead.
        return s[:half] + "..." + s[len(s) - half:], True

    def batch_apply(self,
            batch_code,
            max_workers=min(4, os.cpu_count() or 1),
            max_length=1000
        ):
        """Execute every snippet in ``batch_code`` in a pebble ``ProcessPool``.

        Args:
            batch_code: List of snippet strings.
            max_workers: Number of worker processes.
            max_length: Truncation length applied to result and report.

        Returns:
            One ``(result, report, result_truncated)`` tuple per snippet, in
            input order. A snippet whose worker timed out or failed yields an
            empty result with the error described in ``report``.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)

        all_exec_results = []
        with ProcessPool(
            max_workers=max_workers
        ) as pool:
            executor = partial(
                self.execute,
                get_answer_from_stdout=self.get_answer_from_stdout,
                runtime=self.runtime,
                answer_symbol=self.answer_symbol,
                answer_expr=self.answer_expr,
                # NOTE(review): an earlier note said this inner timeout does
                # not take effect; the pool-level timeout below is the backstop.
                timeout_length=self.timeout_length,
            )
            future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
            iterator = future.result()

            if len(all_code_snippets) > 100:
                progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
            else:
                progress_bar = None

            while True:
                try:
                    outcome = next(iterator)
                except StopIteration:
                    break
                except TimeoutError as error:
                    print(error)
                    outcome = ("", "Timeout Error")
                except Exception as error:
                    # Record the failure for this snippet and keep going, so
                    # one crashed worker neither stops the process nor
                    # misaligns the remaining results.
                    print(error)
                    outcome = ("", f"{type(error).__name__}: {error}")
                all_exec_results.append(outcome)
                if progress_bar is not None:
                    progress_bar.update(1)

            if progress_bar is not None:
                progress_bar.close()

        batch_results = []
        for res, report in all_exec_results:
            # post processing
            res, res_truncate = self.truncate(str(res).strip(), max_length)
            report, _ = self.truncate(str(report).strip(), max_length)
            batch_results.append((res, report, res_truncate))   # result, report, result-was-truncated
        return batch_results
    
# FastAPI application object; the /execute route below is registered on it.
app = FastAPI()

class CodeRequest(BaseModel):
    # Request body for POST /execute.

    # Snippets to run; each list element is one program.
    batch_code: List[str] = Field(..., 
        min_items=1, 
        max_items=1000,
        example=["print('hello')", "1+1"],
        description="需要执行的Python代码列表（1-1000条）")
    
    # Truncation length applied to each result and report.
    max_length: Optional[int] = Field(1000,
        ge=10, 
        le=5000,
        description="结果截断长度（10-5000字符）")
    
    # Number of worker processes; os.cpu_count() may return None, hence "or 1".
    max_workers: Optional[int] = Field(
        default_factory=lambda: min(4, os.cpu_count() or 1),
        ge=1,
        le=32,
        description="最大并行工作进程数（1-32）")

    @validator('max_workers')
    def validate_workers(cls, v):
        """Clamp the requested worker count to the CPU count of this host.

        An explicit ``None`` is replaced by the same default the field uses,
        so a caller sending ``null`` does not end up in ``min(None, ...)``.
        """
        available_cores = os.cpu_count() or 1  # cpu_count() can be None
        if v is None:
            return min(4, available_cores)
        return min(v, available_cores)  # never exceed the actual core count


class ResultItem(BaseModel):
    # One entry of the /execute response, corresponding to one submitted snippet.
    # result: text produced by the snippet
    result: str = Field(..., example="2", description="执行结果")
    # report: status text ("Done" in the example)
    report: str = Field(..., example="Done", description="执行报告")
    # truncated: whether the result text was shortened
    truncated: bool = Field(..., example=False, description="是否被截断")


@app.post("/execute", response_model=List[ResultItem])
async def execute_code(request: CodeRequest):
    """Run the submitted snippets and return one result item per snippet.

    A new SecureRuntime and PythonExecutor are built for every request.
    ``batch_apply`` is called directly inside this coroutine, so the call
    does not yield to the event loop while the batch is running. Any
    exception is converted into an HTTP 500 response.
    """
    field_names = ("result", "report", "truncated")
    try:
        # Fresh restricted runtime, wrapped in an executor.
        runner = PythonExecutor(runtime=SecureRuntime())

        outcomes = runner.batch_apply(
            batch_code=request.batch_code,
            max_length=request.max_length,
            max_workers=request.max_workers,
        )

        # Each outcome is positional: result text, report, truncated flag.
        return [
            {key: outcome[position] for position, key in enumerate(field_names)}
            for outcome in outcomes
        ]
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"执行过程中发生错误: {str(e)}",
        )

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on port 3999.
    # host="0.0.0.0" listens on every network interface.
    # NOTE(review): this service executes caller-supplied code; confirm that
    # exposing it beyond localhost is intended.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=3999)