# Maintain a long-term memory to retrieve historical driving experiences.
import pickle
import numpy as np
from pathlib import Path

from agentdriver.memory.memory_prompts import memory_system_prompt
from agentdriver.llm_core.chat import run_one_round_conversation
from agentdriver.functional_tools.ego_state import extract_ego_inputs
from agentdriver.functional_tools.detection import (
    get_leading_object_detection,
    get_surrounding_object_detections,
    get_front_object_detections,
    get_object_detections_in_range,
    get_all_object_detections,
)

from agentdriver.functional_tools.prediction import (
    get_leading_object_future_trajectory,
    get_future_trajectories_for_specific_objects,
    get_future_trajectories_in_range,
    get_future_waypoint_of_specific_objects_at_timestep,
    get_all_future_trajectories,
)

class ExperienceMemory:
    r"""Memory of Past Driving Experiences."""
    def __init__(self, data_path, model_name = "gpt-3.5-turbo-0613", verbose=False, compare_perception=False) -> None:
        self.data_path = data_path / Path("memory") / Path("database.pkl")
        self.num_keys = 3
        self.keys = []
        self.values = []
        self.tokens = []
        self.load_db()
        self.key_coefs = [1.0, 10.0, 1.0]
        self.k = 3
        self.model_name = model_name
        self.verbose = verbose
        self.compare_perception = compare_perception

    def gen_vector_keys(self, data_dict):
        vx = data_dict['ego_states'][0]*0.5
        vy = data_dict['ego_states'][1]*0.5
        v_yaw = data_dict['ego_states'][4]
        ax = data_dict['ego_hist_traj_diff'][-1, 0] - data_dict['ego_hist_traj_diff'][-2, 0]
        ay = data_dict['ego_hist_traj_diff'][-1, 1] - data_dict['ego_hist_traj_diff'][-2, 1]
        cx = data_dict['ego_states'][2]
        cy = data_dict['ego_states'][3]
        vhead = data_dict['ego_states'][7]*0.5
        steeling = data_dict['ego_states'][8]

        return [
            np.array([vx, vy, v_yaw, ax, ay, cx, cy, vhead, steeling]),
            data_dict['goal'],
            data_dict['ego_hist_traj'].flatten(),
        ]

    def load_db(self):
        r"""Load the memory from a file."""
        data = pickle.load(open(self.data_path, 'rb'))
        temp_keys = []
        for token in data:
            key_arrays = self.gen_vector_keys(data[token])
            if temp_keys == []:
                temp_keys = [[] for _ in range(len(key_arrays))]
            for i, key_array in enumerate(key_arrays):
                temp_keys[i].append(key_array)
            temp_value = data[token].copy()
            temp_value.update({"token": token})
            self.values.append(temp_value)      
            self.tokens.append(token)
        for temp_key in temp_keys:
            temp_key = np.stack(temp_key, axis=0)
            self.keys.append(temp_key)
            
    def compute_similarity(self, queries, token):
        """Compute the similarity between the current experience and the past experiences in the memory."""        
        diffs = []
        for query, key, key_coef in zip(queries, self.keys, self.key_coefs):
            squared_diffs = np.sum((query - key)**2, axis=1)
            diffs.append(squared_diffs * key_coef)
        diffs = sum(diffs)

        confidence = np.exp(-diffs)

        if token in self.tokens:
            self_index = self.tokens.index(token)
            confidence[self_index] = 0.0

        sorted_indices = np.argsort(-confidence, kind="mergesort")

        top_k_indices = sorted_indices[:self.k]

        return top_k_indices, confidence[top_k_indices]

    def vector_retrieve(self, working_memory):
        """ Step-1 Vectorized Retrieval """        
        querys = self.gen_vector_keys(working_memory['ego_data'])
        top_k_indices, confidence = self.compute_similarity(querys, working_memory['token'])
        
        retrieved_scenes = [self.values[i] for i in top_k_indices].copy()
        return retrieved_scenes, confidence

    def gpt_retrieve(self, working_memory, retrieved_scenes, confidence):
        """ Step-2 GPT Retrieval """        
        mem_system_message = memory_system_prompt
        mem_prompts = "** Current Scenario: **:\n"
        mem_prompts += working_memory["ego_prompts"]
        if self.compare_perception:
            mem_prompts += working_memory["perception_prompts"]
            

        mem_prompts += f"Found {len(retrieved_scenes)} relevant experiences:\n"

        retrieve_ego_prompts = []
        for i in range(len(retrieved_scenes)):
            retrieve_prompt = f"** Past Driving Experience {i+1}: **\n"

            retrieve_ego_prompt, _ = extract_ego_inputs(retrieved_scenes[i])
            retrieve_ego_prompt = retrieve_ego_prompt.replace("Ego States:", "Past Ego States:")
            retrieve_prompt += retrieve_ego_prompt
            retrieve_ego_prompts.append(retrieve_ego_prompt)

            if self.compare_perception:
                # get the perception information
                for function_name in working_memory["functions"].keys():
                    function_args = working_memory["functions"][function_name]["args"]
                    try:
                        function_to_call = globals()[function_name]
                        function_prompt, _ = function_to_call(**function_args, data_dict = retrieved_scenes[i])
                        if function_prompt is None:
                            function_prompt = ""
                    except:
                        function_prompt = ""
                    retrieve_prompt += function_prompt

            mem_prompts += retrieve_prompt
        
        mem_prompts += f"Please return the index 1-{self.k} of the most similar experience: "

        if self.verbose:
            print(mem_system_message)
            print(mem_prompts)

        # run the conversation
        _, response_message = run_one_round_conversation(
            full_messages=[],
            system_message=mem_system_message,
            user_message=mem_prompts,
            model_name=self.model_name,
        )
 
        if self.verbose:
            print(f"Memory-GPT response: {response_message['content']}")

        try:
            ret_index = int(response_message["content"])-1
            assert ret_index >= 0 and ret_index < len(retrieved_scenes) - 1
        except:
            return None 

        retrieved_fut_traj = retrieved_scenes[ret_index]["ego_fut_traj"] 

        retrieved_mem_prompt = "*"*5 + "Past Driving Experience for Reference:" + "*"*5 + "\n"
        retrieved_mem_prompt += f"Most similar driving experience from memory with confidence score: {confidence[ret_index]:.2f}:\n"
        # retrieved_mem_prompt += retrieve_ego_prompts[ret_index]
        retrieved_mem_prompt += f"The planned trajectory in this experience for your reference:\n"

        fut_waypoints = [f"({point[0]:.2f},{point[1]:.2f})" for point in retrieved_fut_traj[1:]]
        traj_prompts = "[" + ", ".join(fut_waypoints) + "]\n"

        retrieved_mem_prompt += traj_prompts
        return retrieved_mem_prompt
        
    def retrieve(self, working_memory):  
        r"""Retrieve the most similar past driving experiences with current working memory as input."""

        retrieved_scenes, confidence = self.vector_retrieve(working_memory)
        retrieved_mem_prompt = self.gpt_retrieve(working_memory, retrieved_scenes, confidence)

        return retrieved_mem_prompt
        