import asyncio
import ast
import atexit
import base64
# import concurrent.futures
import glob
import io
import json
import os
import queue
import re
import shutil
import signal  # noqa
import stat
import subprocess
import sys
import threading
import time
import uuid
import random
import psutil
from pathlib import Path
from typing import Dict, List, Optional, Union
from ray.util import pdb

import json5

from qwen_agent.tools.base import BaseToolWithFileAccess, register_tool
from qwen_agent.utils.utils import extract_code, has_chinese_chars

# Provide a default sandbox endpoint only when the environment does not
# already define one; this runs before ``sandbox_fusion`` is imported below.
os.environ.setdefault("SANDBOX_FUSION_ENDPOINT", "http://172.22.11.236:8080")

from sandbox_fusion import run_code, RunCodeRequest, RunStatus
from requests.exceptions import Timeout


@register_tool('PythonInterpreter')
class PythonInterpreter(BaseToolWithFileAccess):
    """Tool that executes Python code through a SandboxFusion service.

    The code is submitted with ``sandbox_fusion.run_code``. The captured
    stdout/stderr (or an error message) is returned inside a result dict.
    """

    name = "PythonInterpreter"
    description = 'Execute Python code in a sandboxed environment. Use this to run Python code and get the execution results.\n**Make sure to use print() for any output you want to see in the results.**\nFor code parameters, use placeholders first, and then put the code within <code></code> XML tags, such as:\n<tool_call>\n{"purpose": <detailed-purpose-of-this-tool-call>, "name": <tool-name>, "arguments": {"code": ""}}\n<code>\nHere is the code.\n</code>\n</tool_call>\n'

    parameters = {
        "type": "object",
        "properties": {
            "code": {
                "type": "string",
                "description": "The Python code to execute. Must be provided within <code></code> XML tags. Remember to use print() statements for any output you want to see.",
            }
        },
        "required": ["code"],
    }

    # Captures the body of the first fenced code block (``` or ```python ...).
    _CODE_FENCE_RE = re.compile(r'```[^\n]*\n(.+?)```', re.DOTALL)

    def __init__(self, cfg: Optional[Dict] = None):
        super().__init__(cfg)

    @property
    def args_format(self) -> str:
        """Instruction telling the model how to format this tool's input.

        Uses ``cfg['args_format']`` when configured. Otherwise it picks a
        Chinese or English hint depending on whether the tool metadata
        contains Chinese characters.
        """
        fmt = self.cfg.get('args_format')
        if fmt is None:
            if has_chinese_chars([self.name_for_human, self.name, self.description, self.parameters]):
                # "The input to this tool should be a Markdown code block."
                fmt = '此工具的输入应为Markdown代码块。'
            else:
                fmt = 'Enclose the code within triple backticks (`) at the beginning and end of the code.'
        return fmt

    def observation(self, tool: dict, tool_dict: dict, tool_results, empty_mode: bool = False, readpage: bool = False, max_observation_length: Optional[int] = None, tokenizer=None):
        """Return ``tool_results`` unchanged as the observation text.

        Only ``tool_results`` is used; this implementation ignores the
        other parameters.

        Raises:
            AssertionError: if ``tool_results`` is not a ``str``.
        """
        assert isinstance(tool_results, str), f"result of python code should be str, instead of {type(tool_results)}. {tool_results}"
        return tool_results

    @property
    def function(self) -> dict:  # Bad naming. It should be `function_info`.
        """Function-calling schema for this tool: name, description, parameters.

        ``name_for_human`` and ``args_format`` are intentionally left out.
        """
        return {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters,
        }

    def call(self, params: Union[str, dict], files: List[str] = None, timeout: Optional[int] = 50, **kwargs) -> dict:
        """Run the requested Python code in the sandbox.

        Args:
            params: One of three forms:
                - a dict with a ``code`` key;
                - a JSON/JSON5 string encoding such a dict;
                - raw text containing the code, optionally inside a
                  ``` fenced block.
            files: Forwarded to the base-class ``call`` (see the inline
                comment below).
            timeout: Passed to ``run_code`` as ``client_timeout``
                (presumably seconds).

        Returns:
            On success: ``{"success": True, "results": <text>, "params": <params>}``.
            ``results`` is the formatted stdout/stderr, an interpreter
            error message, or ``'Finished execution.'`` when nothing was
            printed.
            On empty code or an unexpected error:
            ``{"success": False, "error_message": <text>}``.
        """
        try:
            super().call(params=params, files=files)  # copy remote files to work_dir
            code, params = self._code_from_params(params)

            if not code.strip():
                return {
                    "success": False,
                    "error_message": '[Python Interpreter Error]: Empty code.'
                }

            result = self._run_in_sandbox(code, timeout)
            return {
                "success": True,
                "results": result if result.strip() else 'Finished execution.',
                "params": params
            }

        except Exception as e:
            return {
                "success": False,
                "error_message": f"[Python Interpreter Error]: {str(e)}"
            }

    @classmethod
    def _code_from_params(cls, params: Union[str, dict]):
        """Extract the code string from ``params``.

        Returns a ``(code, params)`` pair. ``params`` is the parsed dict
        when parsing succeeded; otherwise it is the input unchanged.
        """
        try:
            # Parse into a separate name so the fallback below still sees the
            # caller's original text when the parsed value is not a dict.
            parsed = json5.loads(params) if isinstance(params, str) else params
            code = parsed.get('code', '')
            fence = cls._CODE_FENCE_RE.search(code)
            if fence:
                code = fence.group(1)
            return code, parsed
        except Exception:
            # Not an object with a string "code" field: treat the input as raw
            # code (possibly fenced) and let the generic extractor handle it.
            return extract_code(params), params

    @staticmethod
    def _run_in_sandbox(code: str, timeout: Optional[int]) -> str:
        """Execute ``code`` remotely and return stdout/stderr as text.

        Any failure is turned into a ``[Python Interpreter Error]`` message
        instead of being raised.
        """
        try:
            code_result = run_code(RunCodeRequest(code=code, language='python'), max_attempts=1, client_timeout=timeout)

            run_result = code_result.run_result
            if run_result is None:
                # The response may carry no run result (presumably on
                # sandbox-side failures); report its status rather than
                # failing on attribute access.
                return f'[Python Interpreter Error]: sandbox returned no run result (status: {code_result.status}).'

            sections = []
            if run_result.stdout:
                sections.append(f"stdout:\n{run_result.stdout}")
            if run_result.stderr:
                sections.append(f"stderr:\n{run_result.stderr}")
            return '\n'.join(sections)

        except Timeout:
            return '[Python Interpreter Error] TimeoutError: Execution timed out.'

        except Exception as e:
            return f'[Python Interpreter Error]: {str(e)}'


def _test(params: Optional[Dict] = None):
    """Manual smoke test: send one snippet to the sandbox and print the result.

    Args:
        params: Tool arguments passed to ``PythonInterpreter.call``. It
            defaults to a ``print("Hello, World!")`` snippet. Any dict with
            a ``code`` key can be passed instead; the code may be wrapped in
            a ``` fenced block.
    """
    from pprint import pprint

    if params is None:
        params = {"code": "print(\"Hello, World!\")"}
    executor = PythonInterpreter()
    out = executor.call(params)
    pprint(out)


if __name__ == '__main__':
    # Do not re-assign SANDBOX_FUSION_ENDPOINT here. The module already sets
    # a default at import time, but only when the variable is unset. An
    # unconditional assignment would clobber an endpoint supplied by the
    # caller's environment.
    _test()