# Prediction functions

get_leading_object_future_trajectory_info = {
    "name": "get_leading_object_future_trajectory",
    "description": "Get the predicted future trajectory of the leading object, the function will return a trajectory containing a series of waypoints. If there is no leading vehicle, return None",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_leading_object_future_trajectory(data_dict, short=False):
    objects = data_dict["objects"]
    prompts = "Leading object future trajectory:\n"
    detected_objs = []
    for obj in objects:
        # search for the leading object (at the same lane and in front of the ego vehicle in 10m)
        obj_x, obj_y = obj["bbox"][:2]
        if abs(obj_x) < 3.0 and obj_y >= 0.0 and obj_y < 10.0:
            if short:
                prompts += f"Leading object found, object type: {obj['name']}, object id: {obj['id']}, moving to: ({obj['traj'][5, 0]:.2f}, {obj['traj'][5, 1]:.2f})\n"
            else:
                trajectory_points = ', '.join(f"({x:.2f}, {y:.2f})" for x, y in obj['traj'][:6])
                prompts += f"Leading object found, object type: {obj['name']}, object id: {obj['id']}, future waypoint coordinates in 6s: [{trajectory_points}]\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_future_trajectories_for_specific_objects_info = {
    "name": "get_future_trajectories_for_specific_objects",
    "description": "Get the future trajectories of specific objects (specified by a List of object ids), the function will return trajectories for each object. If there is no object, return None",
    "parameters": {
        "type": "object",
        "properties": {
            "object_ids": {
                "type": "array",
                "items": {
                    "type": "integer",
                    "minimum": 0,
                },
                "description": "a list of integer object ids",
            },
        },
        "required": ["object_ids"],
    },
}

def get_future_trajectories_for_specific_objects(object_ids, data_dict, short=False):
    objects = data_dict["objects"]
    prompts = "Future trajectories for specific objects:\n"
    detected_objs = []
    for obj in objects:
        if obj["id"] in object_ids:
            if short:
                prompts += f"Object type: {obj['name']}, object id: {obj['id']}, moving to: ({obj['traj'][5, 0]:.2f}, {obj['traj'][5, 1]:.2f})\n"
            else:
                trajectory_points = ', '.join(f"({x:.2f}, {y:.2f})" for x, y in obj['traj'][:6])
                prompts += f"Object type: {obj['name']}, object id: {obj['id']}, future waypoint coordinates in 3s: [{trajectory_points}]\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_future_trajectories_in_range_info = {
    "name": "get_future_trajectories_in_range",
    "description": "Get the future trajectories where any waypoint in this trajectory falls into a given range (x_start, x_end)*(y_start, y_end)m^2, the function will return each trajectory that satisfies the condition. If there is no trajectory satisfied, return None",
    "parameters": {
        "type": "object",
        "properties": {
            "x_start": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "start range of x axis",
            },
            "x_end": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "end range of x axis",
            },
            "y_start": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "start range of y axis",
            },
            "y_end": {
                "type": "number",
                "minimum": -50,
                "maximum": 50,
                "multipleOf" : 0.01,
                "description": "end range of y axis",
            },
        },
        "required": ["x_start", "x_end", "y_start", "y_end"],
    },
}

def get_future_trajectories_in_range(x_start, x_end, y_start, y_end, data_dict, short=False):
    objects = data_dict["objects"]
    prompts = f"Future trajectories in X range {x_start:.2f}-{x_end:.2f} and Y range {y_start:.2f}-{y_end:.2f}:\n"
    detected_objs = []
    for obj in objects:
        # search for the objects in range
        obj_x, obj_y = obj["bbox"][:2]
        if obj_x >= x_start and obj_x <= x_end and obj_y >= y_start and obj_y <= y_end:
            if short:
                prompts += f"Object found, object type: {obj['name']}, object id: {obj['id']}, moving to: ({obj['traj'][5, 0]:.2f}, {obj['traj'][5, 1]:.2f})\n"
            else:
                trajectory_points = ', '.join(f"({x:.2f}, {y:.2f})" for x, y in obj['traj'][:6])
                prompts += f"Object found, object type: {obj['name']}, object id: {obj['id']}, future waypoint coordinates in 3s: [{trajectory_points}]\n"
            detected_objs.append(obj)
    if len(detected_objs) == 0:
        prompts = None
    return prompts, detected_objs

get_future_waypoint_of_specific_objects_at_timestep_info = {
    "name": "get_future_waypoint_of_specific_objects_at_timestep",
    "description": "Get the future waypoints of specific objects at a specific timestep, the function will return a list of waypoints. If there is no object or the object does not have a waypoint at the given timestep, return None",
    "parameters": {
        "type": "object",
        "properties": {
            "object_ids": {
                "type": "array",
                "items": {
                    "type": "integer",
                    "minimum": 0,
                },
                "description": "a list of object ids",
            },
            "timestep": {
                "type": "integer",
                "minimum": 1,
                "maximum": 6,
                "multipleOf" : 1,
                "description": "the selected timestep of the future trajectory, integer value range [1-6]",
            },
        },
        "required": ["object_ids", "timestep"],
    },
}

def get_future_waypoint_of_specific_objects_at_timestep(object_ids, timestep, data_dict):
    objects = data_dict["objects"]
    prompts = f"Future waypoints of specific objects at time {timestep/2 + 0.5}s:\n"
    detected_objs = []
    for object_id in object_ids:
        obj = objects[object_id]
        if len(obj["traj"]) > timestep:
            prompts += f"object type: {obj['name']}, object id: {obj['id']}, waypoint: ({obj['traj'][timestep, 0]:.2f}, {obj['traj'][timestep, 1]:.2f}) at timestep {timestep}\n"
        else:
            prompts = None
        detected_objs.append(obj)
    if len(prompts) == 0:
        prompts = None
    return prompts, detected_objs

get_all_future_trajectories_info = {
    "name": "get_all_future_trajectories",
    "description": "Get the predicted future trajectories of all objects in the whole scene, the function will return a list of object ids and their future trajectories. Always avoid using this function if there are other choices.",
    "parameters": {
      "type": "object",
      "properties": {},
      "required": [],
    },
}

def get_all_future_trajectories(data_dict, short=False):
    objects = data_dict["objects"]
    prompts = "All future trajectories:\n"
    for obj in objects:
        if short:
            prompts += f"Object type: {obj['name']}, object id: {obj['id']}, moving to: ({obj['traj'][5, 0]:.2f}, {obj['traj'][5, 1]:.2f})\n"
        else:
            trajectory_points = ', '.join(f"({x:.2f}, {y:.2f})" for x, y in obj['traj'][:6])
            prompts += f"Object type: {obj['name']}, object id: {obj['id']}, future waypoint coordinates in 3s: [{trajectory_points}]\n"
    if len(objects) == 0:
        prompts = None
    return prompts, objects