import pickle
import json
from pathlib import Path

from agentdriver.llm_core.chat import run_one_round_conversation, run_one_round_conversation_with_functional_call
from agentdriver.llm_core.timeout import timeout
from agentdriver.functional_tools.functional_agent import FuncAgent
from agentdriver.perception.perception_prompts import (
    init_system_message,
    detection_prompt,
    prediction_prompt,
    occupancy_prompt,
    map_prompt,
)

class PerceptionAgent:
    def __init__(self, token, split, data_path, model_name = "gpt-3.5-turbo-0613", verbose=True) -> None:
        self.token = token
        folder_name = Path("val") if "val" in split else Path("train")
        self.file_name = data_path / folder_name / Path(f"{self.token}.pkl")
        with open(self.file_name, "rb") as f:
            self.data_dict = pickle.load(f)
        self.func_agent = FuncAgent(self.data_dict)
        self.model_name = model_name
        self.verbose = verbose

        self.num_call_detection_times = 1
        self.num_call_prediction_times = 1
        self.num_call_occupancy_times = 1
        self.num_call_map_times = 1

    def functional_call(self, response_message):
        # Call the function from GPT response
        function_name = response_message["function_call"]["name"]
        function_args = json.loads(response_message["function_call"]["arguments"])
        function_to_call = getattr(self.func_agent, function_name)
        if not callable(function_to_call):
            print(f"Function {function_name} is not callable!")
            return None
        else:
            function_returns = function_to_call(
                **function_args
            )
            function_prmopt, function_ret_data = function_returns
            if function_prmopt is None:
                function_prmopt = ""
            function_response = {
                "name": function_name,
                "args": function_args,
                "prompt": function_prmopt,
                "data": function_ret_data,
            }
        if self.verbose:
            print(function_name)
            print(function_args)
            print(function_prmopt)
        return function_response

    def generate_detection_func_prompt(self):
        detection_func_prompt = "You can execute one of the following functions to get object detection results (don't execute functions that have been used before):\n"
        for info in self.func_agent.detection_func_infos:
            detection_func_prompt += ("- " + info["name"] + "(")
            if info["parameters"].get("required"):
                param_ind = 0
                for param in info["parameters"]["required"]:
                    if param_ind < len(info["parameters"]["required"]) - 1:
                        detection_func_prompt += (param + ", ")
                    else:
                        detection_func_prompt += param
                    param_ind += 1
            detection_func_prompt += ") #"
            detection_func_prompt += info["description"] + "\n"
        # detection_func_prompt += "You can also anwser NO if you do not need to use these functions for object detection.\n"
        if self.verbose:
            print(detection_func_prompt)
        return detection_func_prompt

    def generate_prediction_func_prompt(self):
        prediction_func_prompt = "You can execute one of the following functions to get object future trajectory predictions (don't execute functions that have been used before):\n"
        for info in self.func_agent.prediction_func_infos:
            prediction_func_prompt += ("- " + info["name"] + "(")
            if info["parameters"].get("required"):
                param_ind = 0
                for param in info["parameters"]["required"]:
                    if param_ind < len(info["parameters"]["required"]) - 1:
                        prediction_func_prompt += (param + ", ")
                    else:
                        prediction_func_prompt += param
                    param_ind += 1
            prediction_func_prompt += ") #"
            prediction_func_prompt += info["description"] + "\n"
        # prediction_func_prompt += "You can also anwser NO if you do not need to use these functions for future trajectory prediction.\n"
        if self.verbose:
            print(prediction_func_prompt)
        return prediction_func_prompt

    def generate_occupancy_func_prompt(self):
        occupancy_func_prompt = "You can execute one of the following functions to get occupancy information (don't execute functions that have been used before):\n"
        for info in self.func_agent.occupancy_func_infos:
            occupancy_func_prompt += ("- " + info["name"] + "(")
            if info["parameters"].get("required"):
                param_ind = 0
                for param in info["parameters"]["required"]:
                    if param_ind < len(info["parameters"]["required"]) - 1:
                        occupancy_func_prompt += (param + ", ")
                    else:
                        occupancy_func_prompt += param
                    param_ind += 1
            occupancy_func_prompt += ") #"
            occupancy_func_prompt += info["description"] + "\n"
        # occupancy_func_prompt += "You can also anwser NO if you do not need to use these functions for occupancy.\n"
        if self.verbose:
            print(occupancy_func_prompt)
        return occupancy_func_prompt

    def generate_map_func_prompt(self):
        map_func_prompt = "You can execute one of the following functions to get map information (don't execute functions that have been used before):\n"
        for info in self.func_agent.map_func_infos:
            map_func_prompt += ("- " + info["name"] + "(")
            if info["parameters"].get("required"):
                param_ind = 0
                for param in info["parameters"]["required"]:
                    if param_ind < len(info["parameters"]["required"]) - 1:
                        map_func_prompt += (param + ", ")
                    else:
                        map_func_prompt += param
                    param_ind += 1
            map_func_prompt += ") #"
            map_func_prompt += info["description"] + "\n"
        # map_func_prompt += "You can also anwser NO if you do not need to use these functions for map.\n"
        if self.verbose:
            print(map_func_prompt)
        return map_func_prompt

    def get_perception_results(self, ego_prompts):
        """
        Collecting necessary information from driving scenarios by chain-of-thought reasoning with function call
        """
        full_messages = []
        func_responses = []
        system_message = init_system_message + "\n" + ego_prompts + "\n" # general context and scenario-specific context

        if self.verbose:
            print(system_message)
            print(detection_prompt)

        # Detection information
        full_messages, detection_response = run_one_round_conversation(
            full_messages=full_messages, 
            system_message=system_message, 
            user_message=detection_prompt,
            model_name=self.model_name,
        )

        if self.verbose:
            print(detection_response["content"])

        if detection_response["content"] == "YES":
            for _ in range(self.num_call_detection_times):
                full_messages, detection_func_response = run_one_round_conversation_with_functional_call(
                    full_messages=full_messages, 
                    system_message=None, 
                    user_message=self.generate_detection_func_prompt(), 
                    functional_calls_info=self.func_agent.detection_func_infos,
                    model_name=self.model_name,
                )

                if self.verbose:
                    print(detection_func_response["content"])

                # Check if functional call is triggered
                if detection_func_response.get("function_call"):
                    function_response = self.functional_call(detection_func_response)
                    # Append function response to conversation
                    if function_response is not None:
                        full_messages.append(
                            {
                                "role": "function",
                                "name": function_response["name"],
                                "content": function_response["prompt"],
                            }
                        )
                else:
                    function_response = None
                func_responses.append(function_response)
        else:
            pass

        if self.verbose:
            print(prediction_prompt)

        # Prediction information
        full_messages, prediction_response = run_one_round_conversation(
            full_messages=full_messages, 
            system_message=None, 
            user_message=prediction_prompt, 
            model_name=self.model_name,
        )

        if self.verbose:
            print(prediction_response["content"])

        if prediction_response["content"] == "YES":
            for _ in range(self.num_call_prediction_times):
                full_messages, prediction_func_response = run_one_round_conversation_with_functional_call(
                    full_messages=full_messages, 
                    system_message=None, 
                    user_message=self.generate_prediction_func_prompt(), 
                    functional_calls_info=self.func_agent.prediction_func_infos,
                    model_name=self.model_name,
                )

                if self.verbose:
                    print(prediction_func_response["content"])

                # Check if functional call is triggered
                if prediction_func_response.get("function_call"):
                    function_response = self.functional_call(prediction_func_response)
                    # Append function response to conversation
                    if function_response is not None:
                        full_messages.append(
                            {
                                "role": "function",
                                "name": function_response["name"],
                                "content": function_response["prompt"],
                            }
                        )
                else:
                    function_response = None
                func_responses.append(function_response)
        else:
            pass

        if self.verbose:
            print(occupancy_prompt)

        # Occupancy information
        full_messages, occupancy_response = run_one_round_conversation(
            full_messages=full_messages, 
            system_message=None, 
            user_message=occupancy_prompt,
            model_name=self.model_name,
        )

        if self.verbose:
            print(occupancy_response["content"])

        if occupancy_response["content"] == "YES":
            for _ in range(self.num_call_occupancy_times):
                full_messages, occupancy_func_response = run_one_round_conversation_with_functional_call(
                    full_messages=full_messages, 
                    system_message=None, 
                    user_message=self.generate_occupancy_func_prompt(), 
                    functional_calls_info=self.func_agent.occupancy_func_infos,
                    model_name=self.model_name,
                )

                if self.verbose:
                    print(occupancy_func_response["content"])

                # Check if functional call is triggered
                if occupancy_func_response.get("function_call"):
                    function_response = self.functional_call(occupancy_func_response)
                    # Append function response to conversation
                    if function_response is not None:
                        full_messages.append(
                            {
                                "role": "function",
                                "name": function_response["name"],
                                "content": function_response["prompt"],
                            }
                        )
                else:
                    function_response = None
                func_responses.append(function_response)
        else:
            pass

        if self.verbose:
            print(map_prompt)

        # Map information
        full_messages, map_response = run_one_round_conversation(
            full_messages=full_messages, 
            system_message=None, 
            user_message=map_prompt, 
            model_name=self.model_name,
        )

        if self.verbose:
            print(map_response["content"])

        if map_response["content"] == "YES":
            for _ in range(self.num_call_map_times):
                full_messages, map_func_response = run_one_round_conversation_with_functional_call(
                    full_messages=full_messages, 
                    system_message=None, 
                    user_message=self.generate_map_func_prompt(), 
                    functional_calls_info=self.func_agent.map_func_infos,
                    model_name=self.model_name,
                )

                if self.verbose:
                    print(map_func_response["content"])

                # Check if functional call is triggered
                if map_func_response.get("function_call"):
                    function_response = self.functional_call(map_func_response)
                    # Append function response to conversation
                    if function_response is not None:
                        full_messages.append(
                            {
                                "role": "function",
                                "name": function_response["name"],
                                "content": function_response["prompt"],
                            }
                        )
                else:
                    function_response = None
                func_responses.append(function_response)
        else:
            pass

        return full_messages, func_responses

    def process_perception_results(self, ego_prompts, ego_data, full_messages, func_responses):
        """
        Process the results from perception
        """
        perception_prompts = "*"*5 + "Perception Results:" + "*"*5 + "\n"
        working_memory = {}
        working_memory["token"] = self.token
        working_memory["ego_data"] = ego_data
        working_memory["functions"] = {}
        for func_response in func_responses:
            if func_response is not None:
                perception_prompts += func_response["prompt"] + "\n"
                working_memory["functions"][func_response["name"]] = {
                    "data": func_response["data"],
                    "args": func_response["args"],
                } 
        if self.verbose:
            print(perception_prompts)
        working_memory.update({"perception_prompts": perception_prompts})
        working_memory.update({"ego_prompts": ego_prompts})
        return perception_prompts, working_memory

    @timeout(15)
    def run(self):
        ego_prompts, ego_data = self.func_agent.get_ego_states()
        full_messages, func_responses = self.get_perception_results(ego_prompts)
        perception_prompts, working_memory = self.process_perception_results(ego_prompts, ego_data, full_messages, func_responses)
        return ego_prompts, perception_prompts, working_memory