from tool.base import BaseToolEnv, BaseTool
from typing import List, Tuple, Any

class AgentMathEnv(BaseToolEnv):
    """Tool environment that executes ``<code>`` blocks from model output.

    Code is extracted from ``<code>`` ... ``</code>`` lines in a raw model
    response and run through the single ``python`` tool supplied to the
    constructor. The tool output is then cleaned by
    :meth:`format_tool_response`, which strips the sandbox path and caps the
    length, and is returned to the caller.
    """

    # Path prefix removed from tool output before it is returned.
    _SANDBOX_PATH_PREFIX = "/miniconda3/envs/sandbox-runtime"
    # Comment line inserted before every extracted code / program block.
    _BLOCK_SEPARATOR = '\n# ========\n'

    def __init__(self, tools, max_tool_response_length=512):
        """
        Args:
            tools: sequence of tools; only ``tools[0]`` is used and its
                ``name`` must be ``"python"``.
            max_tool_response_length: tool output longer than this many
                characters is truncated and suffixed with ``"..."``.

        Raises:
            ValueError: if ``tools[0]`` is not the ``python`` tool.
        """
        self.tool = tools[0]
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if self.tool.name != "python":
            raise ValueError(
                f"AgentMathEnv requires a 'python' tool, got {self.tool.name!r}"
            )
        self.max_tool_response_length = max_tool_response_length

    def step(self, raw_response: str, ip_address="localhost") -> Tuple[str, List[bool], bool]:
        """Execute the code contained in a single response, if any.

        The whole ``raw_response`` is scanned, so every complete ``<code>``
        block in it is concatenated and executed together.
        NOTE(review): :meth:`batch_step` only inspects the latest turn of each
        response. Confirm which behaviour is intended for multi-turn input.

        Returns:
            ``(tool_response, [tool_success], True)`` when code was found and
            executed, otherwise ``("", [], False)``.
        """
        tool_calls = self.extract_tool_calls(raw_response)
        if not tool_calls:
            return "", [], False
        tool_response, tool_success = self.tool.execute(
            {"code": tool_calls[0]}, ip_address=ip_address
        )
        return self.format_tool_response([tool_response]), [tool_success], True

    def batch_step(self, raw_responses: List[str], ip_address="localhost") -> Tuple[List[str], List[List[bool]], List[bool]]:
        """Execute code for a batch of responses with one ``batch_execute`` call.

        For each response, only the text after the last
        ``"</answer>\\nassistant"`` marker and then after the last
        ``"\\n</interpreter>\\n"`` marker is scanned for a ``<code>`` block.

        Returns:
            Three lists aligned with ``raw_responses``:
            the formatted tool output (``""`` for inactive entries),
            ``[success]`` per entry (``[]`` for inactive entries), and
            a flag that is True when the entry had code that was executed.
        """
        n = len(raw_responses)
        batch_tool_response = [""] * n
        # One fresh list per slot; ``[[]] * n`` would alias a single list.
        batch_tool_successes = [[] for _ in range(n)]
        batch_active = [False] * n

        codes = []
        active_indices = []
        for idx, raw_response in enumerate(raw_responses):
            latest = raw_response.split("</answer>\nassistant")[-1]
            latest = latest.split("\n</interpreter>\n")[-1]
            tool_calls = self.extract_tool_calls(latest)
            if not tool_calls or "</code>" not in latest:
                continue
            batch_active[idx] = True
            active_indices.append(idx)
            codes.append({"code": tool_calls[0]})

        if not codes:
            # Nothing to run: skip the round trip to the execution backend.
            return batch_tool_response, batch_tool_successes, batch_active

        # Results are assumed to come back in the same order as ``codes``.
        results = self.tool.batch_execute(codes, ip_address=ip_address)
        for pos, idx in enumerate(active_indices):
            result = results[pos]
            batch_tool_response[idx] = self.format_tool_response([result["content"]])
            batch_tool_successes[idx] = [result["success"]]
        return batch_tool_response, batch_tool_successes, batch_active

    def stop(self, raw_response: str) -> bool:
        """Return True when ``raw_response`` has no complete ``<code>`` block."""
        return not self.extract_tool_calls(raw_response)

    def extract_tool_calls(self, raw_response: str) -> List[Any]:
        """
        Extract the code between ``<code>`` and ``</code>`` lines.

        A block opens on a line starting with ``<code>`` and closes on the
        next line containing ``</code>``. All blocks are concatenated, each
        preceded by a ``# ========`` separator comment. Lines inside a block
        that start with ``` (Markdown fences) are dropped. Text on the same
        line as the opening or closing tag is not captured.

        Returns:
            ``[code]`` with the concatenated code, or ``[]`` when no
            ``<code>`` line was found or the last block is not closed.
        """
        parts = []
        inside = False
        for line in raw_response.split('\n'):
            if line.startswith('<code>'):
                parts.append(self._BLOCK_SEPARATOR)
                inside = True
            elif '</code>' in line:
                # Also covers lines such as ".</code>" or "x = 1</code>".
                inside = False
            elif inside and not line.startswith('```'):
                parts.append(line + '\n')
        if inside or not parts:
            # No block at all, or the last block is incomplete.
            return []
        return [''.join(parts)]

    def extract_program(self, result: str, last_only=True):
        """
        Extract the program between a ```python fence and the next ``` line.

        Args:
            result: text to scan.
            last_only: if True, keep only the last fenced program; otherwise
                concatenate all of them, each preceded by a ``# ========``
                separator comment.

        Returns:
            ``[program]``, or ``[]`` when no program text was found or the
            last fence is not closed.
        """
        program = ''
        inside = False
        for line in result.split('\n'):
            if line.startswith('```python') or line.endswith('```python'):
                if last_only:
                    program = ''  # discard earlier programs
                else:
                    program += self._BLOCK_SEPARATOR
                inside = True
            elif line.startswith('```'):
                inside = False
            elif inside:
                program += line + '\n'
        if inside or not program:
            # Nothing found, or the last program is incomplete.
            return []
        return [program]

    def format_tool_response(self, tool_responses: List[str]) -> str:
        """
        Clean up the first tool response before returning it.

        The sandbox environment path prefix is removed and the text is
        truncated to ``max_tool_response_length`` characters, with ``"..."``
        appended when it was cut. Only ``tool_responses[0]`` is used, and the
        input list is not modified.
        """
        if not tool_responses:
            return ""
        response = tool_responses[0].replace(self._SANDBOX_PATH_PREFIX, '')
        if len(response) > self.max_tool_response_length:
            response = response[:self.max_tool_response_length] + "..."
        return response