# Generate chain of thoughts reasoning and prompting by simple rules

import numpy as np
from agentdriver.utils.geometry import CAR_LENGTH, CAR_WIDTH

cot_system_message = """
**A Language Agent for Autonomous Driving**
Role: You are the brain of an autonomous vehicle (a.k.a. ego-vehicle). In this step, you need to first determine notable objects and identify their potential effects on your driving route, and then derive a high-level driving plan.

Input
- You will receive your current ego-states.
- You will also receive current perception results.

Task
- You need to determine the notable objects based on perception results and ego-states. Notable objects are the objects that will have potential effects on your driving route.
- You need to describe the potential effects of those notable objects on your driving route.
- You need to derive a high-level driving plan based on the former information and reasoning results. The driving plan should be a combination of a meta action from ["STOP", "MOVE FORWARD", "TURN LEFT", "CHANGE LANE TO LEFT", "TURN RIGHT", "CHANE LANE TO RIGHT"], and a speed description from ["A CONSTANT SPEED", "A DECELERATION", "A QUICK DECELERATION", "AN ACCELERATION", "A QUICK ACCELERATION"] if the meta action is not "STOP".
"""

def generate_chain_of_thoughts(data_dict, working_memory):
    """
    Generate chain of thoughts reasoning and prompting by simple rules
    """
    ego_fut_trajs = data_dict['ego_fut_traj']
    ego_his_trajs = data_dict['ego_hist_traj']
    ego_fut_diff = data_dict['ego_fut_traj_diff']
    ego_his_diff = data_dict['ego_hist_traj_diff']
    vx = data_dict['ego_states'][0]*0.5
    vy = data_dict['ego_states'][1]*0.5
    ax = data_dict['ego_hist_traj_diff'][-1, 0] - data_dict['ego_hist_traj_diff'][-2, 0]
    ay = data_dict['ego_hist_traj_diff'][-1, 1] - data_dict['ego_hist_traj_diff'][-2, 1]
    ego_estimate_velos = [
        [0, 0],
        [vx, vy],
        [vx+ax, max(vy+ay, 0)],
        [vx+2*ax, max(vy+2*ay, 0)],
        [vx+3*ax, max(vy+3*ay, 0)],
        [vx+4*ax, max(vy+4*ay, 0)],
        [vx+5*ax, max(vy+5*ay, 0)],
    ]
    ego_estimate_trajs = np.cumsum(ego_estimate_velos, axis=0) # [7, 2]
    
    # aggregate detected objects from perception working memory

    detected_objects = []
    for function_name in working_memory["functions"].keys():
        if "detection" in function_name:
            detected_objects.extend(working_memory["functions"][function_name]["data"])

    num_objects = len(detected_objects)
    num_future_horizon = 7 # include current
    object_collisons = np.zeros((num_objects, num_future_horizon))
    for i in range(num_objects):
        obj_traj = np.concatenate([detected_objects[i]["bbox"][:2][None, :], detected_objects[i]["traj"][:6, :]], axis=0)
        if (obj_traj[:7, 1]<=0).all(): # negative Y, meaning the object is always behind us, we don't care 
            continue
        for t in range(num_future_horizon):
            ego_x, ego_y = ego_estimate_trajs[t]
            object_x, object_y = obj_traj[t]
            size_x, size_y = detected_objects[i]["bbox"][3]*0.5, detected_objects[i]["bbox"][4]*0.5  #   object_boxes[i, 3:5] * 0.5 # half size
            collision = collision_detection(ego_x, ego_y, CAR_WIDTH/2., CAR_LENGTH/2., object_x, object_y, size_x, size_y)
            if collision:
                object_collisons[i, t] = 1
                break
    cot_message = "*"*5 + "Chain of Thoughts Reasoning:" + "*"*5 + "\n"
    cot_message += f"Thoughts:\n"
    if (object_collisons==0).all(): # nothing to care about
        cot_message += f" - Notable Objects: None\n"
        cot_message += f"   Potential Effects: None\n"
        # cot_message += f"   Nothing to care.\n"
    else:
        for i in range(num_objects):
            obj_traj = np.concatenate([detected_objects[i]["bbox"][:2][None, :], detected_objects[i]["traj"][:6, :]], axis=0)
            for t in range(num_future_horizon):
                if object_collisons[i, t] > 0:
                    object_name = detected_objects[i]["name"]
                    ox, oy = detected_objects[i]["bbox"][:2] 
                    tx, ty = obj_traj[t]
                    time = t*0.5
                    # cot_message += f" ################################################################################\n"
                    cot_message += f" - Notable Objects: {object_name} at ({ox:.2f},{oy:.2f}), moving to ({tx:.2f},{ty:.2f}) at {time} second\n"
                    cot_message += f"   Potential Effects: within the safe zone of the ego-vehicle at {time} second\n"

    meta_action = generate_meta_action(
        ego_fut_diff=ego_fut_diff, 
        ego_fut_trajs=ego_fut_trajs, 
        ego_his_diff=ego_his_diff, 
        ego_his_trajs=ego_his_trajs
    )
    cot_message += ("Driving Plan: " + meta_action)
    return cot_message

def generate_chain_of_thoughts_new(data_dict, working_memory):
    """
    Generate chain of thoughts reasoning and prompting by simple rules
    """
    ego_fut_trajs = data_dict['ego_fut_traj']
    ego_his_trajs = data_dict['ego_hist_traj']
    ego_fut_diff = data_dict['ego_fut_traj_diff']
    ego_his_diff = data_dict['ego_hist_traj_diff']
    vx = data_dict['ego_states'][0]*0.5
    vy = data_dict['ego_states'][1]*0.5
    ax = data_dict['ego_hist_traj_diff'][-1, 0] - data_dict['ego_hist_traj_diff'][-2, 0]
    ay = data_dict['ego_hist_traj_diff'][-1, 1] - data_dict['ego_hist_traj_diff'][-2, 1]
    ego_estimate_velos = [
        [0, 0],
        [vx, vy],
        [vx+ax, max(vy+ay, 0)],
        [vx+2*ax, max(vy+2*ay, 0)],
        [vx+3*ax, max(vy+3*ay, 0)],
        [vx+4*ax, max(vy+4*ay, 0)],
        [vx+5*ax, max(vy+5*ay, 0)],
    ]
    ego_estimate_trajs = np.cumsum(ego_estimate_velos, axis=0) # [7, 2]
    
    # aggregate detected objects from perception working memory

    detected_objects = []
    for function_name in working_memory["functions"].keys():
        if "detection" in function_name:
            detected_objects.extend(working_memory["functions"][function_name]["data"])

    num_objects = len(detected_objects)
    num_future_horizon = 7 # include current
    object_collisons = np.zeros((num_objects, num_future_horizon))
    for i in range(num_objects):
        obj_traj = np.concatenate([detected_objects[i]["bbox"][:2][None, :], detected_objects[i]["traj"][:6, :]], axis=0)
        if (obj_traj[:7, 1]<=0).all(): # negative Y, meaning the object is always behind us, we don't care 
            continue
        for t in range(num_future_horizon):
            ego_x, ego_y = ego_estimate_trajs[t]
            object_x, object_y = obj_traj[t]
            size_x, size_y = detected_objects[i]["bbox"][3]*0.5, detected_objects[i]["bbox"][4]*0.5  #   object_boxes[i, 3:5] * 0.5 # half size
            collision = collision_detection(ego_x, ego_y, 0.925, 2.04, object_x, object_y, size_x, size_y)
            if collision:
                object_collisons[i, t] = 1
                break
    cot_message = "*"*5 + "Chain of Thoughts Reasoning:" + "*"*5 + "\n"
    cot_message += f"Thoughts:\n"
    if (object_collisons==0).all(): # nothing to care about
        cot_message += f" - Notable Objects: None\n"
        cot_message += f"   Potential Effects: None\n"
        # cot_message += f"   Nothing to care.\n"
    else:
        for i in range(num_objects):
            obj_traj = np.concatenate([detected_objects[i]["bbox"][:2][None, :], detected_objects[i]["traj"][:6, :]], axis=0)
            for t in range(num_future_horizon):
                if object_collisons[i, t] > 0:
                    object_name = detected_objects[i]["name"]
                    ox, oy = detected_objects[i]["bbox"][:2] 
                    tx, ty = obj_traj[t]
                    time = t*0.5
                    # cot_message += f" ################################################################################\n"
                    cot_message += f" - Notable Objects: {object_name} at ({ox:.2f},{oy:.2f}), moving to ({tx:.2f},{ty:.2f}) at {time} second\n"
                    cot_message += f"   Potential Effects: within the safe zone of the ego-vehicle at {time} second\n"

    meta_action = generate_meta_action(
        ego_fut_diff=ego_fut_diff, 
        ego_fut_trajs=ego_fut_trajs, 
        ego_his_diff=ego_his_diff, 
        ego_his_trajs=ego_his_trajs
    )
    cot_message += ("Driving Plan: " + meta_action)
    return cot_message

def collision_detection(x1, y1, sx1, sy1, x2, y2, sx2, sy2, x_space=2.0, y_space=6.0): # safe distance
    if (np.abs(x1-x2) < sx1+sx2+x_space) and (y2 > y1) and (y2 - y1 < sy1+sy2+y_space): # in front of you
        return True
    else:
        return False

def generate_meta_action( 
    ego_fut_diff,
    ego_fut_trajs,
    ego_his_diff,
    ego_his_trajs,
    ):
    meta_action = ""

    # speed meta
    constant_eps = 0.5
    his_velos = np.linalg.norm(ego_his_diff, axis=1)
    fut_velos = np.linalg.norm(ego_fut_diff, axis=1)
    cur_velo = his_velos[-1]
    end_velo = fut_velos[-1]

    if cur_velo < constant_eps and end_velo < constant_eps:
        speed_meta = "stop"
    elif end_velo < constant_eps:
        speed_meta = "a deceleration to zero"
    elif np.abs(end_velo - cur_velo) < constant_eps:
        speed_meta = "a constant speed"
    else:
        if cur_velo > end_velo:
            if cur_velo > 2 * end_velo:
                speed_meta = "a quick deceleration"
            else:
                speed_meta = "a deceleration"
        else:
            if end_velo > 2 * cur_velo:
                speed_meta = "a quick acceleration"
            else:
                speed_meta = "an acceleration"
    
    # behavior_meta
    if speed_meta == "stop":
        meta_action += (speed_meta + "\n")
        return meta_action.upper()
    else:
        forward_th = 2.0
        lane_changing_th = 4.0
        if (np.abs(ego_fut_trajs[:, 0]) < forward_th).all():
            behavior_meta = "move forward"
        else:
            if ego_fut_trajs[-1, 0] < 0: # left
                if np.abs(ego_fut_trajs[-1, 0]) > lane_changing_th:
                    behavior_meta = "turn left"
                else:
                    behavior_meta = "change lane to left"
            elif ego_fut_trajs[-1, 0] > 0: # right
                if np.abs(ego_fut_trajs[-1, 0]) > lane_changing_th:
                    behavior_meta = "turn right"
                else:
                    behavior_meta = "change lane to right"
            else:
                raise ValueError(f"Undefined behaviors: {ego_fut_trajs}")
        
        # heading-based rules
        # ego_fut_headings = np.arctan(ego_fut_diff[:,0]/(ego_fut_diff[:,1]+1e-4))*180/np.pi # in degree
        # ego_his_headings = np.arctan(ego_his_diff[:,0]/(ego_his_diff[:,1]+1e-4))*180/np.pi # in degree
        
        # forward_heading_th = 5 # forward heading is always near 0
        # turn_heading_th = 45

        # if (np.abs(ego_fut_headings) < forward_heading_th).all():
        #     behavior_meta = "move forward"
        # else:
        #     # we extract a 5-s curve, if the largest heading change is above 45 degrees, we view it as turn
        #     curve_headings = np.concatenate([ego_his_headings, ego_fut_headings])
        #     min_heading, max_heading = curve_headings.min(), curve_headings.max()
        #     if ego_fut_trajs[-1, 0] < 0: # left
        #         if np.abs(max_heading - min_heading) > turn_heading_th:
        #             behavior_meta = "turn left"
        #         else:
        #             behavior_meta = "chane lane to left"
        #     elif ego_fut_trajs[-1, 0] > 0: # right
        #         if np.abs(max_heading - min_heading) > turn_heading_th:
        #             behavior_meta = "turn right"
        #         else:
        #             behavior_meta = "chane lane to right"
        #     else:
        #         raise ValueError(f"Undefined behaviors: {ego_fut_trajs}")
        
        meta_action += (behavior_meta + " with " + speed_meta + "\n")
        return meta_action.upper()