from meta_researcher.tool.base import BaseToolEnv, BaseTool
from typing import List, Dict, Tuple, Any
import re
import os
import json
from datetime import datetime
from meta_researcher.tool.prompts.prompt import SYSTEM_PROMPT

class NousToolEnv(BaseToolEnv):
    """Tool environment driven by ``<tool_call>`` JSON blocks in model output.

    A tool call is a JSON object such as ``{"name": ..., "arguments": ...}``
    wrapped in ``<tool_call>`` / ``</tool_call>``. Tool outputs are wrapped in
    ``<tool_response>`` tags and embedded in a chat turn delimited by
    ``<|im_start|>`` / ``<|im_end|>`` markers that hands control back to the
    assistant (see ``format_tool_response``).
    """

    def __init__(self, tools: List[BaseTool], max_tool_response_length: int, max_step_num: int):
        """Create the environment.

        Args:
            tools: tools available to the model; each is looked up by its
                ``name`` attribute.
            max_tool_response_length: tool outputs longer than this many
                characters are truncated by ``format_tool_response``.
            max_step_num: stored on the instance; not read by any method of
                this class.
        """
        self.tools = tools
        self.tool_map = {tool.name: tool for tool in self.tools}
        self.tool_call_start = "<tool_call>"
        self.tool_call_end = "</tool_call>"
        self.tool_response_start = "<tool_response>"
        self.tool_response_end = "</tool_response>"
        self.eos_token = "<|im_end|>"
        # When False, only the first tool call of each response is executed.
        self.parallel_tool_calls = False
        self.max_tool_response_length = max_tool_response_length
        self.max_step_num = max_step_num

    def tools_format_func(self) -> str:
        """Render ``SYSTEM_PROMPT`` with ``str.format``.

        ``CURRENT_TIME`` is always supplied.

        ``locale`` is also supplied, taken from the ``LOCALE`` environment
        variable (default ``"zh-CN"``). So are ``tools``, one JSON tool
        description per line, and ``tool_call_format``, so the template may
        reference any of them. ``str.format`` ignores unused keyword arguments,
        so a template that only contains ``{CURRENT_TIME}`` renders exactly as
        before.

        NOTE(review): which placeholders SYSTEM_PROMPT actually contains is
        not visible here; ``{CURRENT_TIME}`` is the only one known to be used.
        """
        # NOTE(review): datetime.now() is naive, so "%z" renders as an empty
        # string (leaving a trailing space). Kept as-is to avoid changing the
        # prompt text; use datetime.now().astimezone() if the offset is wanted.
        current_time = datetime.now().strftime("%a %b %d %Y %H:%M:%S %z")
        locale = os.getenv('LOCALE', "zh-CN")
        tool_call_format = "{'name': <function-name>, 'arguments': <args-json-object>}"
        tools = "\n".join(
            json.dumps(tool.tool_description, ensure_ascii=False) for tool in self.tools
        )
        return SYSTEM_PROMPT.format(
            CURRENT_TIME=current_time,
            locale=locale,
            tools=tools,
            tool_call_format=tool_call_format,
        )

    def _check_tool_call(self, tool_call: Any) -> Tuple[str, Any]:
        """Validate one parsed tool call (shared by ``step`` and ``batch_step``).

        Returns:
            ``(error, tool)``.

            When the call can be executed, ``error`` is ``""`` and ``tool`` is
            the resolved tool. Otherwise ``error`` is the message reported to
            the model and ``tool`` is None.
        """
        if tool_call is None:
            return "Error: JSONDecodeError", None
        # json.loads may also yield a list, string or number; only a dict can
        # carry "name"/"arguments" fields.
        if not isinstance(tool_call, dict) or "name" not in tool_call:
            return "Error: No tool name", None
        if "arguments" not in tool_call:
            return "Error: No tool arguments", None
        try:
            tool = self.tool_map.get(tool_call["name"])
        except TypeError:  # unhashable name such as a list or dict
            tool = None
        if tool is None:
            return "Error: ToolNotFoundError", None
        if not tool.validate_args(tool_call["arguments"]):
            return "Error: Invalid tool arguments", None
        return "", tool

    def step(self, raw_response: str) -> Tuple[str, List[bool], bool]:
        """Execute the tool call(s) found in a single model response.

        ``tool.execute`` must return a mapping with ``"content"`` and
        ``"success"`` keys.

        Returns:
            ``(formatted_response, successes, active)``.

            ``successes`` holds one flag per attempted call. When the response
            contains no ``<tool_call>`` block, the result is ``("", [], False)``.
        """
        tool_calls = self.extract_tool_calls(raw_response)
        if not tool_calls:
            return "", [], False
        if not self.parallel_tool_calls:
            tool_calls = tool_calls[:1]
        tool_responses = []
        tool_successes = []
        for tool_call in tool_calls:
            error, tool = self._check_tool_call(tool_call)
            if error:
                tool_responses.append(error)
                tool_successes.append(False)
                continue
            tool_result = tool.execute(tool_call["arguments"])
            tool_responses.append(tool_result["content"])
            tool_successes.append(tool_result["success"])
        return self.format_tool_response(tool_responses), tool_successes, True

    def _batch_execute(self, tool_name: str, args_list: List[Any]) -> List[Any]:
        """Run every pending call of one tool; return one result dict per call."""
        # "plan" and "reflect" are not dispatched to the tool: their "content"
        # argument is echoed back wrapped in the matching tag.
        if tool_name == "plan":
            return [{"content": "<plan>\n" + args["content"] + "\n</plan>", "success": True} for args in args_list]
        if tool_name == "reflect":
            return [{"content": "<reflect>\n" + args["content"] + "\n</reflect>", "success": True} for args in args_list]
        return list(self.tool_map[tool_name].batch_execute(args_list))

    def batch_step(self, raw_responses: List[str]) -> Tuple[List[str], List[List[bool]], List[bool]]:
        """Batched version of ``step``.

        Calls to the same tool across all responses are grouped and executed
        with a single ``batch_execute`` invocation.

        Returns:
            Three parallel lists, one entry per input response:

            - the formatted response (``""`` for inactive rows);
            - the per-call success flags;
            - the active flag, which is False when the response has no
              ``<tool_call>`` block.

        Raises:
            RuntimeError: if a tool returns a different number of results than
                calls it was given.
        """
        num_rows = len(raw_responses)
        # One independent list per row ([[]] * n would alias a single list).
        batch_tool_responses: List[List[Any]] = [[] for _ in range(num_rows)]
        batch_tool_successes: List[List[bool]] = [[] for _ in range(num_rows)]
        batch_active = [True] * num_rows
        # tool name -> arguments of each valid call, and the (row, position)
        # slot its result must be written back to, in the same order.
        pending_arguments: Dict[str, List[Any]] = {}
        pending_index: Dict[str, List[Tuple[int, int]]] = {}

        for i, raw_response in enumerate(raw_responses):
            tool_calls = self.extract_tool_calls(raw_response)
            if not tool_calls:
                batch_active[i] = False
                continue
            if not self.parallel_tool_calls:
                tool_calls = tool_calls[:1]
            for j, tool_call in enumerate(tool_calls):
                error, _ = self._check_tool_call(tool_call)
                if error:
                    batch_tool_responses[i].append(error)
                    batch_tool_successes[i].append(False)
                    continue
                tool_name = tool_call["name"]
                # Placeholder, overwritten once the batched call returns.
                batch_tool_responses[i].append("Executing...")
                batch_tool_successes[i].append(False)
                pending_arguments.setdefault(tool_name, []).append(tool_call["arguments"])
                pending_index.setdefault(tool_name, []).append((i, j))

        for tool_name, args_list in pending_arguments.items():
            batch_results = self._batch_execute(tool_name, args_list)
            positions = pending_index[tool_name]
            # zip() would silently drop the surplus and leave "Executing..."
            # placeholders in the text shown to the model.
            if len(batch_results) != len(positions):
                raise RuntimeError(
                    f"Tool {tool_name!r} returned {len(batch_results)} results "
                    f"for {len(positions)} calls"
                )
            for batch_result, (i, j) in zip(batch_results, positions):
                assert batch_tool_responses[i][j] == "Executing..."
                batch_tool_responses[i][j] = batch_result["content"]
                batch_tool_successes[i][j] = batch_result["success"]

        formatted_responses = [
            self.format_tool_response(tool_responses) if active else ""
            for tool_responses, active in zip(batch_tool_responses, batch_active)
        ]
        return formatted_responses, batch_tool_successes, batch_active

    def stop(self, raw_response: str) -> bool:
        """Return True when ``raw_response`` contains no ``<tool_call>`` block."""
        return not self.extract_tool_calls(raw_response)

    def extract_tool_calls(self, raw_response: str) -> List[Any]:
        """Parse every ``<tool_call>...</tool_call>`` block in ``raw_response``.

        Returns:
            One entry per block, in order of appearance. Each entry is the
            ``json.loads`` result of the block body, or None when the body is
            not valid JSON.

        NOTE: bare ``<plan>`` / ``<reflect>`` tags are not parsed into tool
        calls here (that code was disabled). ``batch_step`` still special-cases
        tools named "plan" and "reflect".
        """
        start = re.escape(self.tool_call_start)
        end = re.escape(self.tool_call_end)
        # The body may not contain another opening tag, so an unterminated
        # earlier <tool_call> cannot swallow the next, complete one.
        pattern = re.compile(f"{start}((?:(?!{start}).)*?){end}", re.DOTALL)
        tool_calls = []
        for body in pattern.findall(raw_response):
            try:
                tool_calls.append(json.loads(body))
            except json.JSONDecodeError:
                tool_calls.append(None)
        return tool_calls

    def format_tool_response(self, tool_responses: List[str]) -> str:
        """Wrap tool outputs in a user turn and reopen an assistant turn.

        Each output longer than ``max_tool_response_length`` characters is cut
        to that length and suffixed with "...". Outputs are wrapped in
        ``<tool_response>`` tags and joined by newlines.

        The result is built in three parts:

        1. close the current turn;
        2. open a user turn containing the wrapped outputs;
        3. open an assistant turn that starts with ``<think>``.
        """
        limit = self.max_tool_response_length
        blocks = []
        for tool_response in tool_responses:
            if len(tool_response) > limit:
                tool_response = tool_response[:limit] + "..."
            blocks.append(f"{self.tool_response_start}\n{tool_response}\n{self.tool_response_end}")
        return (
            f"{self.eos_token}\n<|im_start|>user\n"
            + "\n".join(blocks)
            + f"{self.eos_token}\n<|im_start|>assistant\n<think>\n"
        )