# Detection functions

import numpy as np
import math
from agentdriver.utils.geometry import CAR_LENGTH, CAR_WIDTH, GRID_SIZE, rotate_bbox
from skimage.draw import polygon
from agentdriver.utils.box_distance import polygons_overlap, polygon_distance
import matplotlib.pyplot as plt

debug = False

get_leading_object_detection_info ={
    "name": "get_leading_object_detection",
    "description": "Get the detection of the leading object, the function will return the leading object id and its position and size. If there is no leading object, return None",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_leading_object_detection(data_dict):
    objects = data_dict["objects"]
    prompts = "Leading object detections:\n"
    detected_objs = []
    for obj in objects:
        # search for the leading object (at the same lane and in front of the ego vehicle in 10m)
        obj_x, obj_y = obj["bbox"][:2]
        if abs(obj_x) < 3.0 and obj_y >= 0.0 and obj_y < 10.0:
            prompts += f"Leading object detected, object type: {obj['name']}, object id: {obj['id']}, position: ({obj_x:.2f}, {obj_y:.2f}), size: ({obj['bbox'][3]:.2f}, {obj['bbox'][4]:.2f})\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_surrounding_object_detections_info = {
    "name": "get_surrounding_object_detections",
    "description": "Get the detections of the surrounding objects in a 20m*20m range, the function will return a list of surroundind object ids and their positions and sizes. If there is no surrounding object, return None",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_surrounding_object_detections(data_dict):
    objects = data_dict["objects"]
    prompts = "Surrounding object detections:\n"
    detected_objs = []
    for obj in objects:
        # search for the surrounding objects (20m*20m range)
        obj_x, obj_y = obj["bbox"][:2]
        if abs(obj_x) < 20.0 and abs(obj_y) < 20.0:
            prompts += f"Surrounding object detected, object type: {obj['name']}, object id: {obj['id']}, position: ({obj_x:.2f}, {obj_y:.2f}), size: ({obj['bbox'][3]:.2f}, {obj['bbox'][4]:.2f})\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_front_object_detections_info = {
    "name": "get_front_object_detections",
    "description": "Get the detections of the objects in front of you in a 10m*20m range, the function will return a list of front object ids and their positions and sizes. If there is no front object, return None",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_front_object_detections(data_dict):
    objects = data_dict["objects"]
    prompts = "Front object detections:\n"
    detected_objs = []
    for obj in objects:
        # search for the front objects (10m*20m range)
        obj_x, obj_y = obj["bbox"][:2]
        if abs(obj_x) < 5.0 and obj_y >= 0.0 and obj_y < 20.0:
            prompts += f"Front object detected, object type: {obj['name']}, object id: {obj['id']}, position: ({obj_x:.2f}, {obj_y:.2f}), size: ({obj['bbox'][3]:.2f}, {obj['bbox'][4]:.2f})\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_object_detections_in_range_info = {
    "name": "get_object_detections_in_range",
    "description": "Get the detections of the objects in a given range (x_start, x_end)*(y_start, y_end)m^2, the function will return a list of object ids and their positions and sizes. If there is no object, return None",
    "parameters": {
        "type": "object",
        "properties": {
            "x_start": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "start range of x axis",
            },
            "x_end": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "end range of x axis",
            },
            "y_start": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "start range of y axis",
            },
            "y_end": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "end range of y axis",
            },
        },
        "required": ["x_start", "x_end", "y_start", "y_end"],
    },
}

def get_object_detections_in_range(x_start, x_end, y_start, y_end, data_dict):
    x_start, x_end, y_start, y_end = float(x_start), float(x_end), float(y_start), float(y_end)
    objects = data_dict["objects"]
    prompts = f"Object detections in X range {x_start:.2f}-{x_end:.2f} and Y range {y_start:.2f}-{y_end:.2f}:\n"
    detected_objs = []
    for obj in objects:
        # search for the objects in range
        obj_x, obj_y = obj["bbox"][:2]
        if obj_x >= x_start and obj_x <= x_end and obj_y >= y_start and obj_y <= y_end:
            prompts += f"Object detected, object type: {obj['name']}, object id: {obj['id']}, position: ({obj_x:.2f}, {obj_y:.2f}), size: ({obj['bbox'][3]:.2f}, {obj['bbox'][4]:.2f})\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_all_object_detections_info ={
    "name": "get_all_object_detections",
    "description": "Get the detections of all objects in the whole scene, the function will return a list of object ids and their positions and sizes. Always avoid using this function if there are other choices.",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_all_object_detections(data_dict):
    objects = data_dict["objects"]
    prompts = f"Full object detections:\n"
    detected_objs = []
    for obj in objects:
        obj_x, obj_y = obj["bbox"][:2]
        prompts += f"Object detected, object type: {obj['name']}, object id: {obj['id']}, position: ({obj_x:.2f}, {obj_y:.2f}), size: ({obj['bbox'][3]:.2f}, {obj['bbox'][4]:.2f})\n"
        detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

def check_rotate_object_collision_for_planned_trajectory(trajectory, data_dict, safe_margin=1.):
    objects = data_dict["objects"]

    if debug:
        plt.figure()

    agents_final_corners = []
    for obj in objects:
        x, y, z, dx, dy, dz, rotation_z, rotation_y, rotation_x = obj["bbox"]
        cx, cy = x, y

        rotated_corners = rotate_bbox(0, 0, dx, dy, rotation_z)

        # Get the box corners
        if debug:
            plt.scatter(x+dx/2, y+dy/2, c='b', s=50)
            plt.scatter(x+dx/2, y-dy/2, c='b', s=50)
            plt.scatter(x-dx/2, y-dy/2, c='b', s=50)
            plt.scatter(x-dx/2, y+dy/2, c='b', s=50)
            plt.show()

            final_corners_0 = [(cx + x_prime, cy + y_prime) for x_prime, y_prime in rotated_corners]

            for pt in final_corners_0:
                plt.scatter(pt[0], pt[1], c='g', s=50)

            for pt in obj["traj"]: ## NOTE: traj consists of the center of the bbox
                plt.scatter(pt[0], pt[1], c='r', s=100)
        
        agent_final_corners = []    
        for pt in obj["traj"][:6]:
            cx, cy = pt[0], pt[1]
            agent_final_corners.append([(cx + x_prime, cy + y_prime) for x_prime, y_prime in rotated_corners]) # only save future 6 timesteps

            if debug:
                for pt in agent_final_corners[-1]:
                    plt.scatter(pt[0], pt[1], c='g', s=50)
        
        agents_final_corners.append(agent_final_corners)

    ego_final_corners = []
    for pt in trajectory:
        ego_cx, ego_cy = pt[0], pt[1]
        ego_rotated_corners = rotate_bbox(ego_cx, ego_cy, CAR_WIDTH, CAR_LENGTH, 0) # NOTE ego vehicle is always facing front in evaluation, we can consider to rotate it in practice
        ego_final_corners.append(ego_rotated_corners) # only save future 6 timesteps
        
        if debug:
            for pt in ego_final_corners[-1]:
                plt.scatter(pt[0], pt[1], c='r', s=50)

    collision = np.full(len(trajectory), False)
    for ts in range(len(trajectory)):
        for obj in agents_final_corners:
            if polygons_overlap(ego_final_corners[ts], obj[ts]):
                collision[ts] = True
                if debug:
                    print("Collision detected")
                    plt.figure()
                    plt.scatter(ego_final_corners[ts][:, 0], ego_final_corners[ts][:, 1], c='b', s=50)
                    plt.scatter(np.array(obj[ts])[:,0], np.array(obj[ts])[:,1], c='y', s=50)

            elif polygon_distance(ego_final_corners[ts], obj[ts]) < safe_margin:
                collision[ts] = True
                if debug:
                    print("Collision detected")
                    plt.figure()
                    plt.scatter(ego_final_corners[ts][:, 0], ego_final_corners[ts][:, 1], c='b', s=50)
                    plt.scatter(np.array(obj[ts])[:,0], np.array(obj[ts])[:,1], c='y', s=50)

    return collision