from futuremind.tool.base import BaseToolEnv, BaseTool
from typing import List, Dict, Tuple, Any
import re
import json
import ast


def string_to_list_with_double_quotes(string_list):
    """Convert a string that is supposed to encode a list into a Python list.

    Several parsers are tried in order, from strictest to most lenient:

    1. ``ast.literal_eval`` (Python literal syntax),
    2. ``json.loads`` (strict JSON),
    3. ``json.loads`` after rewriting stand-alone single quotes to double
       quotes,
    4. a manual, quote-aware split on commas.

    Args:
        string_list: The value to convert. Anything that is not a ``str`` is
            returned unchanged.

    Returns:
        The input itself when it is not a string; otherwise a list. The last
        parser accepts any text, so a string input always produces a list
        (an empty or whitespace-only string produces ``[]``).
    """
    # Non-string inputs (e.g. an already-parsed list) pass straight through.
    if not isinstance(string_list, str):
        return string_list

    string_list = string_list.strip()

    # Method 1: Python literal syntax. literal_eval may raise more than
    # ValueError/SyntaxError on malformed input (e.g. TypeError for an
    # unhashable dict key), so all documented failure modes are caught and
    # treated as "this parser did not work".
    try:
        result = ast.literal_eval(string_list)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        result = None
    if isinstance(result, list):
        return result

    # Method 2: strict JSON (double-quoted strings).
    try:
        result = json.loads(string_list)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        result = None
    if isinstance(result, list):
        return result

    # Method 3: mixed quotes. Only single quotes that are not adjacent to a
    # word character are rewritten, so apostrophes inside words are left alone.
    fixed_string = re.sub(r"(?<!\w)'(?!\w)|(?<!\w)'(?=\s*[,\]])", '"', string_list)
    try:
        result = json.loads(fixed_string)
    except (ValueError, RecursionError):
        result = None
    if isinstance(result, list):
        return result

    # Method 4: manual parsing. Drop the outer brackets and split on commas
    # that are not inside a quoted section. This step cannot fail, so it is
    # the final fallback.
    inner = string_list.strip('[]')

    items = []
    buffer = []          # characters of the item currently being read
    quote_char = None    # the quote that opened the current quoted section

    for char in inner:
        if quote_char is None:
            if char == ',':
                # Item boundary: trim whitespace, then surrounding quotes.
                items.append("".join(buffer).strip().strip('"\''))
                buffer = []
                continue
            if char in ('"', "'"):
                quote_char = char
        elif char == quote_char:
            # Only the same quote character closes the section.
            quote_char = None
        buffer.append(char)

    # A trailing item is kept only when it is not blank.
    last_item = "".join(buffer).strip()
    if last_item:
        items.append(last_item.strip('"\''))

    return items


class NousToolEnv(BaseToolEnv):
    """Tool environment for model output that wraps JSON tool calls in tags.

    A tool call is a JSON object with a ``"name"`` and an arguments entry,
    enclosed between ``<tool_call>`` and ``</tool_call>``. The environment
    extracts those calls, runs the matching tools and formats the results as
    a chat turn made of ``<tool_response>`` blocks.
    """

    def __init__(self, tools: List[BaseTool], max_tool_response_length: int):
        """Store the tools and the markers used to parse and format messages.

        Args:
            tools: Available tools; each is indexed by its ``name`` attribute.
            max_tool_response_length: Maximum number of characters of a tool
                response kept by ``format_tool_response``.
        """
        self.tools = tools
        self.tool_map = {tool.name: tool for tool in self.tools}
        self.tool_call_start = "<tool_call>"
        self.tool_call_end = "</tool_call>"
        self.tool_response_start = "<tool_response>"
        self.tool_response_end = "</tool_response>"
        self.eos_token = "<|im_end|>"
        # When False, only the first tool call of a response is executed.
        self.parallel_tool_calls = False
        self.max_tool_response_length = max_tool_response_length

    def _resolve_tool_call(self, tool_call: Any, arguments_key: str = "arguments") -> Tuple[Any, Any, Any]:
        """Check the structure of one parsed tool call and look up its tool.

        Returns:
            ``(tool, arguments, None)`` when the call is well formed and names
            a known tool, otherwise ``(None, None, error_message)``.
        """
        if tool_call is None:
            return None, None, "Error: JSONDecodeError"
        # Valid JSON that is not an object (number, string, list) cannot
        # carry a tool name either.
        if not isinstance(tool_call, dict) or "name" not in tool_call:
            return None, None, "Error: No tool name"
        if arguments_key not in tool_call:
            return None, None, "Error: No tool arguments"
        try:
            tool = self.tool_map.get(tool_call["name"])
        except TypeError:
            # Unhashable name (e.g. a list or dict) can never match a tool.
            tool = None
        if tool is None:
            return None, None, "Error: ToolNotFoundError"
        return tool, tool_call[arguments_key], None

    @staticmethod
    def _coerce_arguments(arguments: Any) -> bool:
        """Convert string-encoded ``queries``/``limit`` arguments in place.

        Returns:
            False when ``limit`` is a string that is not an integer, True
            otherwise (including when ``arguments`` is not a dict and is
            therefore left for the tool's own validation to judge).
        """
        if not isinstance(arguments, dict):
            return True
        queries = arguments.get('queries')
        if isinstance(queries, str):
            arguments['queries'] = string_to_list_with_double_quotes(queries)
        limit = arguments.get('limit')
        if isinstance(limit, str):
            try:
                arguments['limit'] = int(limit)
            except ValueError:
                return False
        return True

    def step(self, raw_response: str, step_inference: bool = False, arguments_key="arguments") -> Tuple[Any, List[bool], bool]:
        """Execute the tool call(s) found in a single model response.

        Args:
            raw_response: Model output that may contain tool-call blocks.
            step_inference: When True the raw list of tool responses is
                returned; when False the responses are rendered into one
                string with ``format_tool_response``.
            arguments_key: Key of the tool-call object holding the arguments.

        Returns:
            ``(tool_responses, tool_successes, active)``. When the response
            contains no tool call the result is ``("", [], False)``.
        """
        tool_calls = self.extract_tool_calls(raw_response)
        if not tool_calls:
            return "", [], False
        if not self.parallel_tool_calls:
            tool_calls = tool_calls[:1]

        tool_responses = []
        tool_successes = []
        for tool_call in tool_calls:
            tool, arguments, error = self._resolve_tool_call(tool_call, arguments_key)
            if error is None:
                # Models sometimes emit list/int arguments as strings; repair
                # them before asking the tool to validate.
                if not (self._coerce_arguments(arguments) and tool.validate_args(arguments)):
                    error = "Error: Invalid tool arguments"
            if error is not None:
                tool_responses.append(error)
                tool_successes.append(False)
                continue
            tool_result = tool.execute(arguments)
            tool_responses.append(tool_result["content"])
            tool_successes.append(tool_result["success"])

        if not step_inference:
            tool_responses = self.format_tool_response(tool_responses)

        return tool_responses, tool_successes, True

    def batch_step(self, raw_responses: List[str]) -> Tuple[List[str], List[List[bool]], List[bool]]:
        """Execute the tool calls of many responses, batching calls per tool.

        Valid calls are first collected per tool name, then each tool's
        ``batch_execute`` is invoked once with all of its argument sets.

        NOTE(review): unlike ``step``, string-encoded ``queries``/``limit``
        arguments are not coerced here before validation — confirm whether
        that difference is intended.

        Args:
            raw_responses: One model output per batch element.

        Returns:
            ``(formatted_responses, successes, active)``, each with one entry
            per input. An input without any tool call yields ``""``, ``[]``
            and ``False`` respectively.
        """
        batch_size = len(raw_responses)
        # Independent lists per slot (``[[]] * n`` would alias one list).
        batch_tool_responses = [[] for _ in range(batch_size)]
        batch_tool_successes = [[] for _ in range(batch_size)]
        batch_active = [True] * batch_size
        # Valid calls awaiting batch execution, grouped by tool name.
        pending_arguments = {}  # tool_name -> [arguments]
        pending_positions = {}  # tool_name -> [(response index, call index)]

        for i, raw_response in enumerate(raw_responses):
            tool_calls = self.extract_tool_calls(raw_response)
            if not tool_calls:
                batch_active[i] = False
                continue
            if not self.parallel_tool_calls:
                tool_calls = tool_calls[:1]

            tool_responses = []
            tool_successes = []
            for j, tool_call in enumerate(tool_calls):
                tool, arguments, error = self._resolve_tool_call(tool_call)
                if error is None and not tool.validate_args(arguments):
                    error = "Error: Invalid tool arguments"
                if error is not None:
                    tool_responses.append(error)
                    tool_successes.append(False)
                    continue
                # Placeholder, overwritten once the batch has been executed.
                tool_responses.append("Executing...")
                tool_successes.append(False)
                tool_name = tool_call["name"]
                pending_arguments.setdefault(tool_name, []).append(arguments)
                pending_positions.setdefault(tool_name, []).append((i, j))
            batch_tool_responses[i] = tool_responses
            batch_tool_successes[i] = tool_successes

        # Batch execution: one call per tool with all of its argument sets.
        for tool_name, args_list in pending_arguments.items():
            tool = self.tool_map[tool_name]
            print("开始批量搜索")
            batch_results = tool.batch_execute(args_list)
            print("结束批量搜索")
            for batch_result, (i, j) in zip(batch_results, pending_positions[tool_name]):
                assert batch_tool_responses[i][j] == "Executing..."
                batch_tool_responses[i][j] = batch_result["content"]
                batch_tool_successes[i][j] = batch_result["success"]

        formatted_responses = [
            self.format_tool_response(tool_responses) if active else ""
            for tool_responses, active in zip(batch_tool_responses, batch_active)
        ]

        return formatted_responses, batch_tool_successes, batch_active

    def stop(self, raw_response: str) -> bool:
        """Return True when ``raw_response`` contains no tool-call block."""
        return not self.extract_tool_calls(raw_response)

    def extract_tool_calls(self, raw_response: str) -> List[Any]:
        """Parse every tool-call block found in ``raw_response``.

        Returns:
            One entry per block, in order of appearance: the decoded JSON
            value, or ``None`` when the block could not be decoded.
        """
        pattern = re.compile(
            f"{re.escape(self.tool_call_start)}(.*?){re.escape(self.tool_call_end)}",
            re.DOTALL,
        )
        decoder = json.JSONDecoder()
        tool_calls = []
        for snippet in pattern.findall(raw_response):
            text = snippet.strip()
            try:
                # 1. The block is pure JSON.
                tc_obj = json.loads(text)
            except json.JSONDecodeError:
                # 2. Otherwise decode the object that starts at '{"name"',
                #    ignoring any text around it. raw_decode tracks string
                #    literals, so braces inside values do not confuse it.
                tc_obj = None
                idx = text.find('{"name"')
                if idx != -1:
                    try:
                        tc_obj, _ = decoder.raw_decode(text, idx)
                    except json.JSONDecodeError:
                        tc_obj = None
            tool_calls.append(tc_obj)
        return tool_calls

    def format_tool_response(self, tool_responses: List[str]) -> str:
        """Render tool responses as a user turn followed by an assistant prompt.

        Each response longer than ``max_tool_response_length`` characters is
        cut to that length and suffixed with ``"..."``. The responses are
        wrapped in ``<tool_response>`` tags and separated by newlines.
        """
        limit = self.max_tool_response_length
        blocks = []
        for tool_response in tool_responses:
            if len(tool_response) > limit:
                tool_response = tool_response[:limit] + "..."
            blocks.append(f"<tool_response>\n{tool_response}\n</tool_response>")
        return (
            "<|im_end|>\n<|im_start|>user\n"
            + "\n".join(blocks)
            + "<|im_end|>\n<|im_start|>assistant\n<think>\n"
        )
        