import json
import uuid
import ast
import requests

from .basic_handle  import SimulateMultiTurnMessages
from wildtoolbench.bench_test.utils import get_keywords

        
def _callee_name(func_node):
    """Return the source-level name of a call target, e.g. ``pkg.mod.fn``."""
    parts = []
    node = func_node
    # Walk an attribute chain (a.b.c) down to its root name.
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        parts.append(node.id)
        return ".".join(reversed(parts))
    # Targets like ``f()()`` or ``x[0]()`` have no simple dotted name; fall
    # back to the source text (ast.unparse, Python 3.9+) or the AST dump.
    unparse = getattr(ast, "unparse", None)
    return unparse(func_node) if unparse is not None else ast.dump(func_node)


def parse_python_function_call(call_str):
    """
    Parse one Python-syntax function call into a name/arguments dict.

    Parameters:
    - call_str (str): Source text such as ``pkg.fn(1, x=2)``.

    Returns:
    - dict: ``{"name": <dotted callee name>, "arguments": {...}}``. Keyword
      arguments are keyed by their name; positional arguments, if any, are
      collected as a list under the string key ``"None"``. Each argument
      value is converted with ``get_keywords``.

    Raises:
    - SyntaxError: if ``call_str`` is not valid Python.
    - ValueError: if ``call_str`` is empty or its first statement is not a
      function call.
    """
    tree = ast.parse(call_str)
    if not tree.body:
        raise ValueError("empty function call string")

    # Expression statements and simple assignments both expose ``.value``.
    call_node = getattr(tree.body[0], "value", None)
    if not isinstance(call_node, ast.Call):
        raise ValueError(f"not a function call: {call_str!r}")

    function_name = _callee_name(call_node.func)

    # Positional arguments are converted first, then keyword arguments.
    positional = [get_keywords(arg) for arg in call_node.args]
    parameters = {kw.arg: get_keywords(kw.value) for kw in call_node.keywords}

    if positional:
        parameters["None"] = positional

    return {"name": function_name, "arguments": parameters}


FN_CALL_DELIMITER = "<<function>>"


def strip_function_calls(content):
    """
    Return the function-call segments that follow each delimiter.

    Text before the first ``FN_CALL_DELIMITER`` is ignored. Every later
    segment is whitespace-trimmed, and segments that end up empty are dropped.
    """
    _preamble, *segments = content.split(FN_CALL_DELIMITER)
    trimmed = (segment.strip() for segment in segments)
    return [segment for segment in trimmed if segment]


def parse_function_call(call):
    """
    Parse a single function-call string, returning None if it cannot be parsed.

    Only Python call syntax is handled for now. Longer term, the parameter
    types from the user's function definitions should be unioned, and the
    parser language chosen as one whose types are a proper superset of that
    union.

    Any exception raised while parsing is printed and swallowed.
    """
    parsed = None
    try:
        parsed = parse_python_function_call(call)
    except Exception as err:
        print(f"error: {err}")
    return parsed


def format_response(response):
    """
    Formats the response from the OpenFunctions model.

    Each call in the model output follows a ``<<function>>`` delimiter. Text
    before the first delimiter is not part of any call (see
    ``strip_function_calls``).

    Parameters:
    - response (str): The response generated by the LLM.

    Returns:
    - str: The call strings joined with ", " when at least one call was
      found; otherwise ``response`` unchanged.
    - list[dict] | None: The successfully parsed function calls. Calls that
      fail to parse are dropped, so the list may be empty. The value is None
      when the response contains no function call at all.
    """
    try:
        calls = strip_function_calls(response)
    except (AttributeError, TypeError) as e:
        # Not a str (e.g. None or bytes): hand the value back untouched.
        print(f"error: {e}")
        return response, None

    if not calls:
        # Plain-text answer with no function call: return the text as-is.
        return response, None

    # parse_function_call reports its own errors and yields None on failure.
    # Drop failures identically for single and parallel calls.
    parsed = [parse_function_call(call) for call in calls]
    function_call_dicts = [d for d in parsed if d is not None]

    # For a single call this is just that call's string.
    return ", ".join(calls), function_call_dicts


class GorillaMultiTurnMessages(SimulateMultiTurnMessages):
    """Multi-turn function-calling handler for Gorilla (OpenFunctions) models.

    The whole conversation is flattened into a single Gorilla-style prompt
    (system text, history, tool schemas, latest question). The prompt is
    POSTed to ``model_url``, and the ``<<function>>``-delimited reply is
    parsed into tool calls.
    """

    def __init__(self, model_url, is_english=False):
        super().__init__(model_url, is_english)
        # Not read by any method of this class.
        self.model_messages = []

    def get_prompt(self, user_query: str, history, functions: "list | None" = None, env_info=None) -> str:
        """
        Generates a conversation prompt based on the user's query and a list of functions.

        Parameters:
        - user_query (str): The user's query.
        - history (list[dict]): Prior messages; each needs "role" and
          "content" keys and is rendered as ``<<role>>content<<role>>``.
        - functions (list | None): Tool schemas to include in the prompt,
          serialized verbatim with ``json.dumps``. None or empty means no
          function section.
        - env_info: Optional environment info, rendered via the inherited
          ``add_date_to_message`` helper. The rendered text must start with
          "当前日期" or "Current Date".

        Returns:
        - str: The formatted conversation prompt.
        """
        system = "You are an AI programming assistant, utilizing the Gorilla LLM model, developed by Gorilla LLM, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
        if env_info is not None:
            # Render env_info through a throwaway message and keep only its content.
            env_info = self.add_date_to_message([{"role":"", "content":""}], env_info)[0]["content"]
            # The rendered text must be a date line ("当前日期" = "Current Date").
            assert env_info.startswith(("当前日期", "Current Date"))
            # Append the date line to the system text. Plain assignment would
            # discard the system prompt and leave only "\n\n" + date.
            system += "\n\n" + env_info
        history_text = "\n".join([f"<<{m['role']}>>" + m["content"] + f"<<{m['role']}>>" for m in history])
        if not functions:
            return f"{system}\n### Instruction: <<history>> {history_text}\n\n<<question>> {user_query}\n### Response: "
        functions_string = json.dumps(functions)

        return f"{system}\n### Instruction:  <<history>> {history_text}\n\n<<function>>{functions_string}\n<<question>>{user_query}\n### Response: "

    def request_funcall(self, messages, tools, env_info=None):
        """
        Request the function call(s) from the model.

        Parameters:
        - messages (list): The messages in the conversation. At least one
          message must have role "user", otherwise IndexError is raised.
        - tools (list): The list of tools to include in the prompt.
        - env_info: Optional environment info forwarded to ``get_prompt``.

        Returns:
        - str | None: The stripped model answer, or None if the request or
          response handling failed.
        - list | None: ``[{"id": <uuid4 str>, "function": <parsed call>}, ...]``,
          or None when no valid function call was produced.
        """
        url = self.model_url
        headers = {"Content-Type": "application/json"}
        # All non-system turns become history. The latest user turn therefore
        # appears both in the history and as the question.
        history = [m for m in messages if m["role"] != "system"]
        query = [m for m in messages if m["role"] == "user"][-1]["content"]
        data = {
            "messages": [{
                "content": self.get_prompt(query, history, tools, env_info), "role": "user"
            }],
        }

        text = None
        tool_calls = None
        try:
            response = requests.post(url, headers=headers, json=data, timeout=90)
            if response.status_code != 200:
                print(f"error: HTTP {response.status_code}")
                return text, tool_calls
            result = response.json()
            text = result["answer"].strip()
            _, parsed = format_response(result["answer"])
            # Keep only successfully parsed calls. Accept them if the first
            # one has a string name.
            valid = [call for call in (parsed or []) if call is not None]
            if valid and isinstance(valid[0]["name"], str):
                tool_calls = [{"id": str(uuid.uuid4()), "function": call} for call in valid]
        except Exception as e:
            # Best effort: network/JSON/format errors yield (text or None, None).
            print(f"error: {e}")

        return text, tool_calls


def main():
    """Manual smoke test: query a Gorilla endpoint and exercise the parser.

    The endpoint URL is a placeholder, so the two live requests only produce
    useful output against a running server. The final section runs offline.
    """
    handle = GorillaMultiTurnMessages("http://111.111.111.111:12345")

    def run_case(messages, tools):
        # Send one request and dump the raw answer plus parsed tool calls.
        content, tool_calls = handle.request_funcall(messages, tools)
        print(content)
        print(json.dumps(tool_calls, ensure_ascii=False, indent=4))
        print("==="*10)

    # Case 1: OpenAI-style tool schema wrapped in {"type": "function", ...}.
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state, e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "celsius",
                                "fahrenheit"
                            ]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            }
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What's the weather like in the two cities of Boston and San Francisco?"
        }
    ]
    run_case(messages, tools)

    # Case 2: bare function schema with a dotted name. The question asks for
    # the freezing point while the tool computes the boiling point —
    # presumably a deliberate relevance probe; confirm.
    tools = [
        {
            "name": "thermodynamics.calculate_boiling_point",
            "description": "Calculate the boiling point of a given substance at a specific pressure.",
            "parameters": {
                "type": "object",
                "properties": {
                    "substance": {
                        "type": "string",
                        "description": "The substance for which to calculate the boiling point."
                    },
                    "pressure": {
                        "type": "number",
                        "description": "The pressure at which to calculate the boiling point."
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit of the pressure. Default is 'kPa'."
                    }
                },
                "required": [
                    "substance",
                    "pressure"
                ]
            }
        }
    ]
    messages = [
        {
            "role": "user",
            "content": "What is the freezing point of water at a pressure of 10 kPa?"
        }
    ]
    run_case(messages, tools)

    # Case 3 (offline): parse a canned reply and print the same checks
    # request_funcall applies, without crashing when parsing yields nothing.
    res = "<<function>>getActivityReport(user_id=456, include_details=true, date_range={'start_date': '2023-04-01', 'end_date': '2023-04-30'})"
    text, tool_calls = format_response(res)
    print(text)
    print(tool_calls)
    print(tool_calls is not None)
    if tool_calls is not None:
        valid = [call for call in tool_calls if call is not None]
        print(len(tool_calls) > 0)
        print(len(valid) > 0)
        if valid:
            print(isinstance(valid[0]["name"], str))


if __name__ == "__main__":
    # Run the manual smoke test only when executed as a script, not on import.
    main()
