{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import sys\n",
    "\n",
    "# Make the repository root importable; guarded so re-running the cell does not\n",
    "# append duplicate entries to sys.path.\n",
    "if \"../..\" not in sys.path:\n",
    "    sys.path.append(\"../..\")\n",
    "\n",
    "from agentdriver.functional_tools.prediction import get_leading_object_future_trajectory, get_future_trajectories_for_specific_objects, get_future_trajectories_in_range, get_future_waypoint_of_specific_objects_at_timestep, get_all_future_trajectories\n",
    "\n",
    "# Scene consumed as `data_dict` by the test cells below.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.\n",
    "DATA_PATH = '../../data/val/0a0d6b8c2e884134a3b48df43d54c36a.pkl'\n",
    "\n",
    "# Context manager guarantees the file handle is closed once loading finishes.\n",
    "with open(DATA_PATH, 'rb') as f:\n",
    "    data_dict = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Leading object future trajectory:\n",
      "Leading object found, object type: car, object id: 2, future waypoint coordinates in 6s: [(4.36, 9.56), (4.36, 9.56), (4.36, 9.57), (4.36, 9.57), (4.36, 9.56), (4.36, 9.56)]\n",
      "Leading object found, object type: car, object id: 3, future waypoint coordinates in 6s: [(-2.66, 13.82), (-1.69, 14.79), (-0.99, 16.13), (-0.25, 17.73), (0.19, 19.42), (0.57, 21.35)]\n",
      "\n",
      "Leading object future trajectory:\n",
      "Leading object found, object type: car, object id: 2, moving to: (4.36, 9.56)\n",
      "Leading object found, object type: car, object id: 3, moving to: (0.57, 21.35)\n",
      "\n",
      "Future trajectories for specific objects:\n",
      "Object type: car, object id: 0, future waypoint coordinates in 3s: [(-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55)]\n",
      "Object type: car, object id: 2, future waypoint coordinates in 3s: [(4.36, 9.56), (4.36, 9.56), (4.36, 9.57), (4.36, 9.57), (4.36, 9.56), (4.36, 9.56)]\n",
      "Object type: pedestrian, object id: 5, future waypoint coordinates in 3s: [(8.98, 10.39), (9.01, 11.08), (9.01, 11.76), (9.01, 12.45), (9.01, 13.15), (9.00, 13.83)]\n",
      "\n",
      "Future trajectories for specific objects:\n",
      "Object type: car, object id: 0, moving to: (-7.60, -5.55)\n",
      "Object type: car, object id: 2, moving to: (4.36, 9.56)\n",
      "Object type: pedestrian, object id: 5, moving to: (9.00, 13.83)\n",
      "\n",
      "Future trajectories in X range -20.00-20.00 and Y range 0.00-10.00:\n",
      "Object found, object type: car, object id: 1, future waypoint coordinates in 3s: [(-14.26, 8.84), (-14.25, 8.83), (-14.23, 8.82), (-14.24, 8.83), (-14.24, 8.83), (-14.22, 8.84)]\n",
      "Object found, object type: car, object id: 2, future waypoint coordinates in 3s: [(4.36, 9.56), (4.36, 9.56), (4.36, 9.57), (4.36, 9.57), (4.36, 9.56), (4.36, 9.56)]\n",
      "Object found, object type: pedestrian, object id: 5, future waypoint coordinates in 3s: [(8.98, 10.39), (9.01, 11.08), (9.01, 11.76), (9.01, 12.45), (9.01, 13.15), (9.00, 13.83)]\n",
      "Object found, object type: pedestrian, object id: 17, future waypoint coordinates in 3s: [(-12.47, 5.30), (-12.48, 5.35), (-12.47, 5.38), (-12.52, 5.39), (-12.51, 5.49), (-12.52, 5.54)]\n",
      "Object found, object type: pedestrian, object id: 22, future waypoint coordinates in 3s: [(-12.47, 5.30), (-12.49, 5.36), (-12.48, 5.39), (-12.53, 5.40), (-12.51, 5.49), (-12.52, 5.54)]\n",
      "Object found, object type: pedestrian, object id: 32, future waypoint coordinates in 3s: [(-12.54, 5.31), (-12.55, 5.37), (-12.54, 5.39), (-12.59, 5.40), (-12.57, 5.49), (-12.58, 5.53)]\n",
      "\n",
      "Future trajectories in X range -20.00-20.00 and Y range 0.00-10.00:\n",
      "Object found, object type: car, object id: 1, moving to: (-14.22, 8.84)\n",
      "Object found, object type: car, object id: 2, moving to: (4.36, 9.56)\n",
      "Object found, object type: pedestrian, object id: 5, moving to: (9.00, 13.83)\n",
      "Object found, object type: pedestrian, object id: 17, moving to: (-12.52, 5.54)\n",
      "Object found, object type: pedestrian, object id: 22, moving to: (-12.52, 5.54)\n",
      "Object found, object type: pedestrian, object id: 32, moving to: (-12.58, 5.53)\n",
      "\n",
      "Future waypoints of specific objects at time 3.0s:\n",
      "object type: car, object id: 0, waypoint: (-7.60, -5.55) at timestep 5\n",
      "object type: car, object id: 2, waypoint: (4.36, 9.56) at timestep 5\n",
      "object type: pedestrian, object id: 5, waypoint: (9.00, 13.83) at timestep 5\n",
      "\n",
      "All future trajectories:\n",
      "Object type: car, object id: 0, future waypoint coordinates in 3s: [(-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55), (-7.60, -5.55)]\n",
      "Object type: car, object id: 1, future waypoint coordinates in 3s: [(-14.26, 8.84), (-14.25, 8.83), (-14.23, 8.82), (-14.24, 8.83), (-14.24, 8.83), (-14.22, 8.84)]\n",
      "Object type: car, object id: 2, future waypoint coordinates in 3s: [(4.36, 9.56), (4.36, 9.56), (4.36, 9.57), (4.36, 9.57), (4.36, 9.56), (4.36, 9.56)]\n",
      "Object type: car, object id: 3, future waypoint coordinates in 3s: [(-2.66, 13.82), (-1.69, 14.79), (-0.99, 16.13), (-0.25, 17.73), (0.19, 19.42), (0.57, 21.35)]\n",
      "Object type: car, object id: 4, future waypoint coordinates in 3s: [(-20.89, 9.39), (-20.89, 9.40), (-20.89, 9.40), (-20.89, 9.40), (-20.89, 9.40), (-20.89, 9.41)]\n",
      "Object type: pedestrian, object id: 5, future waypoint coordinates in 3s: [(8.98, 10.39), (9.01, 11.08), (9.01, 11.76), (9.01, 12.45), (9.01, 13.15), (9.00, 13.83)]\n",
      "Object type: car, object id: 6, future waypoint coordinates in 3s: [(-21.40, 14.16), (-21.40, 14.16), (-21.40, 14.17), (-21.41, 14.16), (-21.40, 14.16), (-21.40, 14.17)]\n",
      "Object type: pedestrian, object id: 7, future waypoint coordinates in 3s: [(-12.93, 17.35), (-12.93, 16.59), (-12.95, 15.76), (-12.93, 14.95), (-12.97, 14.15), (-12.98, 13.45)]\n",
      "Object type: pedestrian, object id: 8, future waypoint coordinates in 3s: [(8.98, 13.17), (8.98, 12.42), (8.98, 11.66), (8.97, 10.89), (8.92, 10.14), (8.93, 9.40)]\n",
      "Object type: car, object id: 9, future waypoint coordinates in 3s: [(6.24, 42.27), (6.24, 42.27), (6.24, 42.27), (6.24, 42.27), (6.24, 42.27), (6.25, 42.28)]\n",
      "Object type: car, object id: 10, future waypoint coordinates in 3s: [(-0.09, -33.56), (-0.20, -31.64), (-0.23, -29.71), (-0.25, -27.70), (-0.32, -25.70), (-0.35, -23.63)]\n",
      "Object type: pedestrian, object id: 11, future waypoint coordinates in 3s: [(9.00, 13.15), (9.01, 12.47), (9.01, 11.80), (9.01, 11.12), (8.99, 10.45), (8.99, 9.78)]\n",
      "Object type: pedestrian, object id: 12, future waypoint coordinates in 3s: [(-6.25, -27.81), (-5.55, -28.13), (-4.80, -28.45), (-4.01, -28.82), (-3.27, -29.15), (-2.47, -29.55)]\n",
      "Object type: pedestrian, object id: 13, future waypoint coordinates in 3s: [(6.94, -23.25), (7.01, -23.93), (7.08, -24.63), (7.14, -25.32), (7.20, -26.01), (7.27, -26.69)]\n",
      "Object type: pedestrian, object id: 14, future waypoint coordinates in 3s: [(8.83, 27.74), (8.83, 28.44), (8.82, 29.17), (8.86, 29.92), (8.88, 30.62), (8.87, 31.33)]\n",
      "Object type: pedestrian, object id: 15, future waypoint coordinates in 3s: [(8.97, 13.09), (8.97, 12.34), (8.96, 11.58), (8.96, 10.81), (8.90, 10.04), (8.91, 9.31)]\n",
      "Object type: car, object id: 16, future waypoint coordinates in 3s: [(-4.14, 20.86), (-4.14, 20.85), (-4.14, 20.87), (-4.14, 20.87), (-4.13, 20.88), (-4.13, 20.86)]\n",
      "Object type: pedestrian, object id: 17, future waypoint coordinates in 3s: [(-12.47, 5.30), (-12.48, 5.35), (-12.47, 5.38), (-12.52, 5.39), (-12.51, 5.49), (-12.52, 5.54)]\n",
      "Object type: car, object id: 18, future waypoint coordinates in 3s: [(-4.14, 20.88), (-4.14, 20.88), (-4.14, 20.89), (-4.14, 20.90), (-4.13, 20.91), (-4.13, 20.89)]\n",
      "Object type: pedestrian, object id: 19, future waypoint coordinates in 3s: [(8.06, -14.43), (8.04, -13.84), (8.05, -13.24), (7.98, -12.64), (7.97, -12.04), (7.94, -11.43)]\n",
      "Object type: pedestrian, object id: 20, future waypoint coordinates in 3s: [(-12.64, 17.64), (-12.61, 16.86), (-12.61, 16.02), (-12.56, 15.20), (-12.58, 14.38), (-12.57, 13.66)]\n",
      "Object type: pedestrian, object id: 21, future waypoint coordinates in 3s: [(8.05, -14.42), (8.03, -13.83), (8.04, -13.23), (7.97, -12.63), (7.96, -12.03), (7.93, -11.43)]\n",
      "Object type: pedestrian, object id: 22, future waypoint coordinates in 3s: [(-12.47, 5.30), (-12.49, 5.36), (-12.48, 5.39), (-12.53, 5.40), (-12.51, 5.49), (-12.52, 5.54)]\n",
      "Object type: pedestrian, object id: 23, future waypoint coordinates in 3s: [(8.69, 27.36), (8.70, 28.08), (8.69, 28.82), (8.72, 29.58), (8.74, 30.30), (8.73, 31.01)]\n",
      "Object type: pedestrian, object id: 24, future waypoint coordinates in 3s: [(7.98, -14.37), (7.96, -13.78), (7.96, -13.17), (7.90, -12.58), (7.89, -11.97), (7.85, -11.36)]\n",
      "Object type: pedestrian, object id: 25, future waypoint coordinates in 3s: [(6.30, -27.25), (6.29, -27.88), (6.28, -28.54), (6.29, -29.18), (6.29, -29.84), (6.32, -30.47)]\n",
      "Object type: pedestrian, object id: 26, future waypoint coordinates in 3s: [(8.05, -14.42), (8.04, -13.82), (8.04, -13.23), (7.98, -12.63), (7.97, -12.02), (7.93, -11.42)]\n",
      "Object type: car, object id: 27, future waypoint coordinates in 3s: [(13.50, 28.56), (13.50, 28.56), (13.50, 28.56), (13.50, 28.57), (13.50, 28.57), (13.51, 28.57)]\n",
      "Object type: pedestrian, object id: 28, future waypoint coordinates in 3s: [(-22.10, -26.42), (-22.30, -26.94), (-22.55, -27.53), (-22.71, -28.17), (-22.95, -28.78), (-23.06, -29.29)]\n",
      "Object type: pedestrian, object id: 29, future waypoint coordinates in 3s: [(-19.94, -28.98), (-19.94, -28.98), (-19.93, -28.97), (-19.93, -28.97), (-19.93, -28.97), (-19.94, -28.97)]\n",
      "Object type: pedestrian, object id: 30, future waypoint coordinates in 3s: [(7.98, -14.37), (7.96, -13.77), (7.97, -13.17), (7.91, -12.57), (7.89, -11.97), (7.86, -11.36)]\n",
      "Object type: pedestrian, object id: 31, future waypoint coordinates in 3s: [(-12.24, 17.90), (-12.20, 17.13), (-12.20, 16.28), (-12.15, 15.45), (-12.16, 14.63), (-12.14, 13.90)]\n",
      "Object type: pedestrian, object id: 32, future waypoint coordinates in 3s: [(-12.54, 5.31), (-12.55, 5.37), (-12.54, 5.39), (-12.59, 5.40), (-12.57, 5.49), (-12.58, 5.53)]\n",
      "Object type: pedestrian, object id: 33, future waypoint coordinates in 3s: [(6.26, -23.75), (6.28, -24.48), (6.30, -25.23), (6.31, -25.99), (6.30, -26.76), (6.29, -27.47)]\n",
      "Object type: pedestrian, object id: 34, future waypoint coordinates in 3s: [(8.07, -14.44), (8.06, -13.85), (8.06, -13.25), (8.00, -12.65), (7.98, -12.04), (7.95, -11.44)]\n",
      "Object type: pedestrian, object id: 35, future waypoint coordinates in 3s: [(6.28, -26.72), (6.28, -26.75), (6.29, -26.76), (6.26, -26.78), (6.25, -26.78), (6.27, -26.83)]\n",
      "Object type: pedestrian, object id: 36, future waypoint coordinates in 3s: [(6.27, -25.64), (6.32, -26.31), (6.38, -27.00), (6.44, -27.69), (6.51, -28.37), (6.58, -29.03)]\n",
      "Object type: pedestrian, object id: 37, future waypoint coordinates in 3s: [(-22.24, -25.98), (-22.35, -26.63), (-22.50, -27.34), (-22.57, -28.09), (-22.72, -28.82), (-22.77, -29.43)]\n",
      "Object type: pedestrian, object id: 38, future waypoint coordinates in 3s: [(-19.67, -28.37), (-19.67, -28.36), (-19.66, -28.36), (-19.67, -28.36), (-19.66, -28.36), (-19.67, -28.36)]\n",
      "Object type: car, object id: 39, future waypoint coordinates in 3s: [(15.71, 44.77), (15.71, 44.77), (15.71, 44.77), (15.71, 44.77), (15.71, 44.77), (15.71, 44.78)]\n",
      "Object type: pedestrian, object id: 40, future waypoint coordinates in 3s: [(-22.26, -26.04), (-22.36, -26.70), (-22.51, -27.42), (-22.58, -28.17), (-22.72, -28.91), (-22.78, -29.53)]\n",
      "Object type: pedestrian, object id: 41, future waypoint coordinates in 3s: [(-13.56, -33.86), (-13.56, -34.47), (-13.58, -35.06), (-13.58, -35.64), (-13.59, -36.24), (-13.59, -36.84)]\n",
      "Object type: pedestrian, object id: 42, future waypoint coordinates in 3s: [(8.04, -14.41), (8.03, -13.82), (8.03, -13.21), (7.97, -12.61), (7.96, -12.00), (7.92, -11.39)]\n",
      "Object type: pedestrian, object id: 43, future waypoint coordinates in 3s: [(-19.75, -31.20), (-19.75, -31.20), (-19.75, -31.20), (-19.75, -31.20), (-19.75, -31.20), (-19.75, -31.19)]\n",
      "Object type: pedestrian, object id: 44, future waypoint coordinates in 3s: [(6.44, -27.14), (6.43, -27.15), (6.45, -27.15), (6.43, -27.15), (6.43, -27.12), (6.45, -27.15)]\n",
      "Object type: pedestrian, object id: 45, future waypoint coordinates in 3s: [(8.07, -14.43), (8.05, -13.84), (8.05, -13.24), (7.99, -12.64), (7.98, -12.03), (7.95, -11.43)]\n",
      "Object type: pedestrian, object id: 46, future waypoint coordinates in 3s: [(6.35, -27.18), (6.34, -27.19), (6.36, -27.19), (6.34, -27.19), (6.34, -27.16), (6.36, -27.19)]\n",
      "Object type: pedestrian, object id: 47, future waypoint coordinates in 3s: [(6.44, -25.97), (6.46, -26.61), (6.47, -27.27), (6.51, -27.93), (6.54, -28.59), (6.58, -29.22)]\n",
      "\n",
      "All future trajectories:\n",
      "Object type: car, object id: 0, moving to: (-7.60, -5.55)\n",
      "Object type: car, object id: 1, moving to: (-14.22, 8.84)\n",
      "Object type: car, object id: 2, moving to: (4.36, 9.56)\n",
      "Object type: car, object id: 3, moving to: (0.57, 21.35)\n",
      "Object type: car, object id: 4, moving to: (-20.89, 9.41)\n",
      "Object type: pedestrian, object id: 5, moving to: (9.00, 13.83)\n",
      "Object type: car, object id: 6, moving to: (-21.40, 14.17)\n",
      "Object type: pedestrian, object id: 7, moving to: (-12.98, 13.45)\n",
      "Object type: pedestrian, object id: 8, moving to: (8.93, 9.40)\n",
      "Object type: car, object id: 9, moving to: (6.25, 42.28)\n",
      "Object type: car, object id: 10, moving to: (-0.35, -23.63)\n",
      "Object type: pedestrian, object id: 11, moving to: (8.99, 9.78)\n",
      "Object type: pedestrian, object id: 12, moving to: (-2.47, -29.55)\n",
      "Object type: pedestrian, object id: 13, moving to: (7.27, -26.69)\n",
      "Object type: pedestrian, object id: 14, moving to: (8.87, 31.33)\n",
      "Object type: pedestrian, object id: 15, moving to: (8.91, 9.31)\n",
      "Object type: car, object id: 16, moving to: (-4.13, 20.86)\n",
      "Object type: pedestrian, object id: 17, moving to: (-12.52, 5.54)\n",
      "Object type: car, object id: 18, moving to: (-4.13, 20.89)\n",
      "Object type: pedestrian, object id: 19, moving to: (7.94, -11.43)\n",
      "Object type: pedestrian, object id: 20, moving to: (-12.57, 13.66)\n",
      "Object type: pedestrian, object id: 21, moving to: (7.93, -11.43)\n",
      "Object type: pedestrian, object id: 22, moving to: (-12.52, 5.54)\n",
      "Object type: pedestrian, object id: 23, moving to: (8.73, 31.01)\n",
      "Object type: pedestrian, object id: 24, moving to: (7.85, -11.36)\n",
      "Object type: pedestrian, object id: 25, moving to: (6.32, -30.47)\n",
      "Object type: pedestrian, object id: 26, moving to: (7.93, -11.42)\n",
      "Object type: car, object id: 27, moving to: (13.51, 28.57)\n",
      "Object type: pedestrian, object id: 28, moving to: (-23.06, -29.29)\n",
      "Object type: pedestrian, object id: 29, moving to: (-19.94, -28.97)\n",
      "Object type: pedestrian, object id: 30, moving to: (7.86, -11.36)\n",
      "Object type: pedestrian, object id: 31, moving to: (-12.14, 13.90)\n",
      "Object type: pedestrian, object id: 32, moving to: (-12.58, 5.53)\n",
      "Object type: pedestrian, object id: 33, moving to: (6.29, -27.47)\n",
      "Object type: pedestrian, object id: 34, moving to: (7.95, -11.44)\n",
      "Object type: pedestrian, object id: 35, moving to: (6.27, -26.83)\n",
      "Object type: pedestrian, object id: 36, moving to: (6.58, -29.03)\n",
      "Object type: pedestrian, object id: 37, moving to: (-22.77, -29.43)\n",
      "Object type: pedestrian, object id: 38, moving to: (-19.67, -28.36)\n",
      "Object type: car, object id: 39, moving to: (15.71, 44.78)\n",
      "Object type: pedestrian, object id: 40, moving to: (-22.78, -29.53)\n",
      "Object type: pedestrian, object id: 41, moving to: (-13.59, -36.84)\n",
      "Object type: pedestrian, object id: 42, moving to: (7.92, -11.39)\n",
      "Object type: pedestrian, object id: 43, moving to: (-19.75, -31.19)\n",
      "Object type: pedestrian, object id: 44, moving to: (6.45, -27.15)\n",
      "Object type: pedestrian, object id: 45, moving to: (7.95, -11.43)\n",
      "Object type: pedestrian, object id: 46, moving to: (6.36, -27.19)\n",
      "Object type: pedestrian, object id: 47, moving to: (6.58, -29.22)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "## test prediction functions\n",
    "# Each tool call is unpacked into (prompts, detected_objs); only the prompts are printed.\n",
    "\n",
    "# Query parameters shared by several calls below.\n",
    "OBJECT_IDS = (0, 2, 5)\n",
    "X_RANGE, Y_RANGE = (-20, 20), (0, 10)\n",
    "WAYPOINT_TIMESTEP = 5\n",
    "\n",
    "# Leading object: default arguments first, then the short form.\n",
    "prompts, detected_objs = get_leading_object_future_trajectory(data_dict)\n",
    "print(prompts)\n",
    "prompts, detected_objs = get_leading_object_future_trajectory(data_dict, short=True)\n",
    "print(prompts)\n",
    "\n",
    "# For the paired tools, the long form (short=False) is printed before the short form (short=True).\n",
    "# A fresh list of ids is built for every call in case a tool mutates its argument.\n",
    "for short in (False, True):\n",
    "    prompts, detected_objs = get_future_trajectories_for_specific_objects(list(OBJECT_IDS), data_dict, short=short)\n",
    "    print(prompts)\n",
    "\n",
    "for short in (False, True):\n",
    "    prompts, detected_objs = get_future_trajectories_in_range(*X_RANGE, *Y_RANGE, data_dict, short=short)\n",
    "    print(prompts)\n",
    "\n",
    "prompts, detected_objs = get_future_waypoint_of_specific_objects_at_timestep(list(OBJECT_IDS), WAYPOINT_TIMESTEP, data_dict)\n",
    "print(prompts)\n",
    "\n",
    "for short in (False, True):\n",
    "    prompts, detected_objs = get_all_future_trajectories(data_dict, short=short)\n",
    "    print(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "get_leading_object_detection(data_dict)/n\n",
      "Leading object detections:\n",
      "Leading object detected, object type: car, object id: 2, position: (4.36, 9.56), size: (1.86, 4.72)\n",
      "Leading object detected, object type: car, object id: 3, position: (-3.70, 13.08), size: (2.01, 4.92)\n",
      "\n",
      "get_surrounding_object_detections(data_dict)/n\n",
      "Surrounding object detections:\n",
      "Surrounding object detected, object type: car, object id: 0, position: (-7.60, -5.55), size: (1.87, 4.83)\n",
      "Surrounding object detected, object type: car, object id: 1, position: (-14.26, 8.85), size: (1.84, 4.40)\n",
      "Surrounding object detected, object type: car, object id: 2, position: (4.36, 9.56), size: (1.86, 4.72)\n",
      "Surrounding object detected, object type: car, object id: 3, position: (-3.70, 13.08), size: (2.01, 4.92)\n",
      "Surrounding object detected, object type: pedestrian, object id: 5, position: (8.97, 9.75), size: (0.69, 0.84)\n",
      "Surrounding object detected, object type: pedestrian, object id: 7, position: (-12.91, 18.02), size: (0.60, 0.63)\n",
      "Surrounding object detected, object type: pedestrian, object id: 8, position: (8.99, 13.87), size: (0.59, 0.63)\n",
      "Surrounding object detected, object type: pedestrian, object id: 11, position: (8.99, 13.81), size: (0.59, 0.63)\n",
      "Surrounding object detected, object type: pedestrian, object id: 15, position: (8.98, 13.78), size: (0.59, 0.63)\n",
      "Surrounding object detected, object type: pedestrian, object id: 17, position: (-12.47, 5.22), size: (0.54, 0.62)\n",
      "Surrounding object detected, object type: pedestrian, object id: 19, position: (8.08, -15.03), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 20, position: (-12.64, 18.32), size: (0.61, 0.65)\n",
      "Surrounding object detected, object type: pedestrian, object id: 21, position: (8.07, -15.02), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 22, position: (-12.48, 5.22), size: (0.54, 0.63)\n",
      "Surrounding object detected, object type: pedestrian, object id: 24, position: (8.00, -14.97), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 26, position: (8.08, -15.02), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 30, position: (8.00, -14.97), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 31, position: (-12.24, 18.59), size: (0.62, 0.65)\n",
      "Surrounding object detected, object type: pedestrian, object id: 32, position: (-12.54, 5.24), size: (0.54, 0.63)\n",
      "Surrounding object detected, object type: pedestrian, object id: 34, position: (8.09, -15.04), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 42, position: (8.06, -15.02), size: (0.68, 0.79)\n",
      "Surrounding object detected, object type: pedestrian, object id: 45, position: (8.09, -15.04), size: (0.68, 0.79)\n",
      "\n",
      "get_front_object_detections(data_dict)/n\n",
      "Front object detections:\n",
      "Front object detected, object type: car, object id: 2, position: (4.36, 9.56), size: (1.86, 4.72)\n",
      "Front object detected, object type: car, object id: 3, position: (-3.70, 13.08), size: (2.01, 4.92)\n",
      "Front object detected, object type: pedestrian, object id: 5, position: (8.97, 9.75), size: (0.69, 0.84)\n",
      "Front object detected, object type: pedestrian, object id: 8, position: (8.99, 13.87), size: (0.59, 0.63)\n",
      "Front object detected, object type: pedestrian, object id: 11, position: (8.99, 13.81), size: (0.59, 0.63)\n",
      "Front object detected, object type: pedestrian, object id: 14, position: (8.81, 27.05), size: (0.71, 0.80)\n",
      "Front object detected, object type: pedestrian, object id: 15, position: (8.98, 13.78), size: (0.59, 0.63)\n",
      "Front object detected, object type: car, object id: 16, position: (-4.13, 20.86), size: (1.89, 4.70)\n",
      "Front object detected, object type: car, object id: 18, position: (-4.13, 20.88), size: (1.89, 4.70)\n",
      "Front object detected, object type: pedestrian, object id: 23, position: (8.68, 26.66), size: (0.71, 0.81)\n",
      "\n",
      "get_object_detections_in_range(-20, 20, 0, 10, data_dict)/n\n",
      "Object detections in X range -20.00-20.00 and Y range 0.00-10.00:\n",
      "Object detected, object type: car, object id: 1, position: (-14.26, 8.85), size: (1.84, 4.40)\n",
      "Object detected, object type: car, object id: 2, position: (4.36, 9.56), size: (1.86, 4.72)\n",
      "Object detected, object type: pedestrian, object id: 5, position: (8.97, 9.75), size: (0.69, 0.84)\n",
      "Object detected, object type: pedestrian, object id: 17, position: (-12.47, 5.22), size: (0.54, 0.62)\n",
      "Object detected, object type: pedestrian, object id: 22, position: (-12.48, 5.22), size: (0.54, 0.63)\n",
      "Object detected, object type: pedestrian, object id: 32, position: (-12.54, 5.24), size: (0.54, 0.63)\n",
      "\n",
      "get_all_object_detections(data_dict)/n\n",
      "Full object detections:\n",
      "Object detected, object type: car, object id: 0, position: (-7.60, -5.55), size: (1.87, 4.83)\n",
      "Object detected, object type: car, object id: 1, position: (-14.26, 8.85), size: (1.84, 4.40)\n",
      "Object detected, object type: car, object id: 2, position: (4.36, 9.56), size: (1.86, 4.72)\n",
      "Object detected, object type: car, object id: 3, position: (-3.70, 13.08), size: (2.01, 4.92)\n",
      "Object detected, object type: car, object id: 4, position: (-20.89, 9.39), size: (1.96, 4.63)\n",
      "Object detected, object type: pedestrian, object id: 5, position: (8.97, 9.75), size: (0.69, 0.84)\n",
      "Object detected, object type: car, object id: 6, position: (-21.40, 14.16), size: (1.84, 4.30)\n",
      "Object detected, object type: pedestrian, object id: 7, position: (-12.91, 18.02), size: (0.60, 0.63)\n",
      "Object detected, object type: pedestrian, object id: 8, position: (8.99, 13.87), size: (0.59, 0.63)\n",
      "Object detected, object type: car, object id: 9, position: (6.23, 42.27), size: (1.84, 4.44)\n",
      "Object detected, object type: car, object id: 10, position: (-0.03, -35.43), size: (1.96, 4.58)\n",
      "Object detected, object type: pedestrian, object id: 11, position: (8.99, 13.81), size: (0.59, 0.63)\n",
      "Object detected, object type: pedestrian, object id: 12, position: (-6.91, -27.53), size: (0.79, 0.93)\n",
      "Object detected, object type: pedestrian, object id: 13, position: (6.88, -22.58), size: (0.65, 0.74)\n",
      "Object detected, object type: pedestrian, object id: 14, position: (8.81, 27.05), size: (0.71, 0.80)\n",
      "Object detected, object type: pedestrian, object id: 15, position: (8.98, 13.78), size: (0.59, 0.63)\n",
      "Object detected, object type: car, object id: 16, position: (-4.13, 20.86), size: (1.89, 4.70)\n",
      "Object detected, object type: pedestrian, object id: 17, position: (-12.47, 5.22), size: (0.54, 0.62)\n",
      "Object detected, object type: car, object id: 18, position: (-4.13, 20.88), size: (1.89, 4.70)\n",
      "Object detected, object type: pedestrian, object id: 19, position: (8.08, -15.03), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 20, position: (-12.64, 18.32), size: (0.61, 0.65)\n",
      "Object detected, object type: pedestrian, object id: 21, position: (8.07, -15.02), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 22, position: (-12.48, 5.22), size: (0.54, 0.63)\n",
      "Object detected, object type: pedestrian, object id: 23, position: (8.68, 26.66), size: (0.71, 0.81)\n",
      "Object detected, object type: pedestrian, object id: 24, position: (8.00, -14.97), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 25, position: (6.31, -26.68), size: (0.62, 0.70)\n",
      "Object detected, object type: pedestrian, object id: 26, position: (8.08, -15.02), size: (0.68, 0.79)\n",
      "Object detected, object type: car, object id: 27, position: (13.50, 28.55), size: (1.85, 4.53)\n",
      "Object detected, object type: pedestrian, object id: 28, position: (-21.89, -25.99), size: (0.54, 0.58)\n",
      "Object detected, object type: pedestrian, object id: 29, position: (-19.94, -28.99), size: (0.62, 0.70)\n",
      "Object detected, object type: pedestrian, object id: 30, position: (8.00, -14.97), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 31, position: (-12.24, 18.59), size: (0.62, 0.65)\n",
      "Object detected, object type: pedestrian, object id: 32, position: (-12.54, 5.24), size: (0.54, 0.63)\n",
      "Object detected, object type: pedestrian, object id: 33, position: (6.25, -23.05), size: (0.62, 0.71)\n",
      "Object detected, object type: pedestrian, object id: 34, position: (8.09, -15.04), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 35, position: (6.29, -26.68), size: (0.62, 0.70)\n",
      "Object detected, object type: pedestrian, object id: 36, position: (6.22, -25.00), size: (0.63, 0.71)\n",
      "Object detected, object type: pedestrian, object id: 37, position: (-22.12, -25.44), size: (0.53, 0.59)\n",
      "Object detected, object type: pedestrian, object id: 38, position: (-19.67, -28.37), size: (0.62, 0.71)\n",
      "Object detected, object type: car, object id: 39, position: (15.71, 44.77), size: (1.85, 4.35)\n",
      "Object detected, object type: pedestrian, object id: 40, position: (-22.14, -25.50), size: (0.53, 0.59)\n",
      "Object detected, object type: pedestrian, object id: 41, position: (-13.56, -33.22), size: (0.68, 0.80)\n",
      "Object detected, object type: pedestrian, object id: 42, position: (8.06, -15.02), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 43, position: (-19.75, -31.21), size: (0.63, 0.70)\n",
      "Object detected, object type: pedestrian, object id: 44, position: (6.44, -27.13), size: (0.63, 0.69)\n",
      "Object detected, object type: pedestrian, object id: 45, position: (8.09, -15.04), size: (0.68, 0.79)\n",
      "Object detected, object type: pedestrian, object id: 46, position: (6.35, -27.17), size: (0.63, 0.69)\n",
      "Object detected, object type: pedestrian, object id: 47, position: (6.43, -25.37), size: (0.63, 0.71)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "## test detection functions\n",
    "\n",
    "from agentdriver.functional_tools.detection import get_leading_object_detection, get_surrounding_object_detections, get_front_object_detections, get_object_detections_in_range, get_all_object_detections\n",
    "\n",
    "# Each entry pairs the label printed above a tool's output with the call that produces it.\n",
    "# Calls are wrapped in lambdas so they execute one at a time, in order, inside the loop.\n",
    "detection_calls = [\n",
    "    (\"get_leading_object_detection(data_dict)\", lambda: get_leading_object_detection(data_dict)),\n",
    "    (\"get_surrounding_object_detections(data_dict)\", lambda: get_surrounding_object_detections(data_dict)),\n",
    "    (\"get_front_object_detections(data_dict)\", lambda: get_front_object_detections(data_dict)),\n",
    "    (\"get_object_detections_in_range(-20, 20, 0, 10, data_dict)\", lambda: get_object_detections_in_range(-20, 20, 0, 10, data_dict)),\n",
    "    (\"get_all_object_detections(data_dict)\", lambda: get_all_object_detections(data_dict)),\n",
    "]\n",
    "\n",
    "for label, call in detection_calls:\n",
    "    prompts, detected_objs = call()\n",
    "    # Label, then a blank separator line, then the tool's prompt text.\n",
    "    print(f\"{label}\\n\")\n",
    "    print(prompts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "get_occupancy_at_locations_for_timestep([array([-7.5959854, -5.554348 ], dtype=float32), array([-14.235255,   8.831818], dtype=float32), array([4.3573565, 9.565505 ], dtype=float32), array([-0.24887705, 17.727306  ], dtype=float32), array([-20.891598,   9.395823], dtype=float32), array([-21.40746 ,  14.163794], dtype=float32), array([ 6.242978, 42.270267], dtype=float32), array([ -0.24772091, -27.702385  ], dtype=float32), array([-4.136358, 20.871729], dtype=float32), array([-4.1377354, 20.896942 ], dtype=float32), array([13.502827, 28.56776 ], dtype=float32), array([15.709743, 44.774147], dtype=float32)], 3, data_dict)/n\n",
      "Occupancy information:\n",
      "Location (-7.60, -5.55) is occupied at timestep 3\n",
      "Location (-14.24, 8.83) is occupied at timestep 3\n",
      "Location (4.36, 9.57) is occupied at timestep 3\n",
      "Location (-0.25, 17.73) is occupied at timestep 3\n",
      "Location (-20.89, 9.40) is occupied at timestep 3\n",
      "Location (-21.41, 14.16) is occupied at timestep 3\n",
      "Location (6.24, 42.27) is occupied at timestep 3\n",
      "Location (-0.25, -27.70) is occupied at timestep 3\n",
      "Location (-4.14, 20.87) is occupied at timestep 3\n",
      "Location (-4.14, 20.90) is occupied at timestep 3\n",
      "Location (13.50, 28.57) is occupied at timestep 3\n",
      "Location (15.71, 44.77) is not occupied at timestep 3\n",
      "\n",
      "check_occupancy_for_planned_trajectory([[25,25],[25,26],[26,27]], data_dict)/n\n",
      "Check collision of the planned trajectory:\n",
      "The planned trajectory does not collide with any other objects.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "## test occupancy functions\n",
    "from agentdriver.functional_tools.occupancy import get_occupancy_at_locations_for_timestep, check_occupancy_for_planned_trajectory\n",
    "from agentdriver.functional_tools.detection import get_all_object_detections\n",
    "from agentdriver.utils.geometry import location_to_pixel_coordinate\n",
    "\n",
    "prompts, detected_objs = get_all_object_detections(data_dict)\n",
    "\n",
    "# Query occupancy at each car's trajectory point, using the same index as the timestep.\n",
    "# NOTE(review): assumes row k of obj[\"traj\"] corresponds to timestep k -- confirm against the detection tool.\n",
    "OCC_TIMESTEP = 3\n",
    "loc = [detected_objs[i][\"traj\"][OCC_TIMESTEP, :] for i in range(len(detected_objs)) if detected_objs[i][\"name\"] == \"car\"]\n",
    "\n",
    "prompts, occ_list = get_occupancy_at_locations_for_timestep(loc, OCC_TIMESTEP, data_dict)\n",
    "print(f\"get_occupancy_at_locations_for_timestep({loc}, {OCC_TIMESTEP}, data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "# The label is built from the same list that is checked, so it always reports the actual input.\n",
    "planned_traj = [[2, 3], [3, 4], [4, 5]]\n",
    "prompts, occ_list = check_occupancy_for_planned_trajectory(planned_traj, data_dict)\n",
    "print(f\"check_occupancy_for_planned_trajectory({planned_traj}, data_dict)\\n\")\n",
    "print(prompts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "get_drivable_at_locations([(25, 25)], data_dict)\n",
      "\n",
      "Drivability of selected locations:\n",
      "Location (25.00, 25.00) is not drivable\n",
      "\n",
      "check_drivable_of_planned_trajectory([[2, 3], [3, 4], [4, 5]], data_dict)\n",
      "\n",
      "Drivability of the planned trajectory:\n",
      "All waypoints of the planned trajectory are in drivable regions\n",
      "\n",
      "get_lane_category_at_locations([(25, 25)], data_dict)\n",
      "\n",
      "Lane category of selected locations:\n",
      "Location (25.00, 25.00) has no lane category\n",
      "\n",
      "get_distance_to_shoulder_at_locations([(25, 25)], data_dict)\n",
      "\n",
      "Distance to both sides of road shoulders of selected locations:\n",
      "Location (25.00, 25.00) distance to left shoulder is 18.5m and distance to right shoulder is uncertain\n",
      "\n",
      "get_current_shoulder(data_dict)\n",
      "\n",
      "Distance to both sides of road shoulders of current ego-vehicle location:\n",
      "Current ego-vehicle's distance to left shoulder is 7.5m and right shoulder is 4.0m\n",
      "\n",
      "get_distance_to_lane_divider_at_locations([(25, 25)], data_dict)\n",
      "\n",
      "Get distance to both sides of road lane_dividers of selected locations:\n",
      "Location (25.00, 25.00) distance to left lane_divider is 22.0m and distance to right lane_divider is uncertain\n",
      "\n",
      "get_current_lane_divider(data_dict)\n",
      "\n",
      "Get distance to both sides of road lane_dividers of current ego-vehicle location:\n",
      "Current ego-vehicle's distance to left lane_divider is 1.0m and distance to right lane_divider is 1.5m\n",
      "\n",
      "get_nearest_pedestrian_crossing(data_dict)\n",
      "\n",
      "Get the nearest pedestrian crossing location:\n",
      "The nearest pedestrian crossing is at (-9.25, 9.75)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "## test map functions\n",
    "from agentdriver.functional_tools.map import get_drivable_at_locations, check_drivable_of_planned_trajectory, get_lane_category_at_locations, get_distance_to_shoulder_at_locations, get_current_shoulder, get_distance_to_lane_divider_at_locations, get_current_lane_divider, get_nearest_pedestrian_crossing\n",
    "\n",
    "# Inputs are defined once and interpolated into each printed header,\n",
    "# so a header can never disagree with the arguments actually passed.\n",
    "query_locations = [(25, 25)]\n",
    "planned_traj = [[2, 3], [3, 4], [4, 5]]\n",
    "\n",
    "prompts, drivable_list = get_drivable_at_locations(query_locations, data_dict)\n",
    "print(f\"get_drivable_at_locations({query_locations}, data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, trajectory_drivable_list = check_drivable_of_planned_trajectory(planned_traj, data_dict)\n",
    "print(f\"check_drivable_of_planned_trajectory({planned_traj}, data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, lane_category_list = get_lane_category_at_locations(query_locations, data_dict)\n",
    "print(f\"get_lane_category_at_locations({query_locations}, data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, distance_to_shoulder_list = get_distance_to_shoulder_at_locations(query_locations, data_dict)\n",
    "print(f\"get_distance_to_shoulder_at_locations({query_locations}, data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, current_shoulder = get_current_shoulder(data_dict)\n",
    "print(\"get_current_shoulder(data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, distance_to_lane_divider_list = get_distance_to_lane_divider_at_locations(query_locations, data_dict)\n",
    "print(f\"get_distance_to_lane_divider_at_locations({query_locations}, data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, current_lane_divider = get_current_lane_divider(data_dict)\n",
    "print(\"get_current_lane_divider(data_dict)\\n\")\n",
    "print(prompts)\n",
    "\n",
    "prompts, nearest_pedestrian_crossing = get_nearest_pedestrian_crossing(data_dict)\n",
    "print(\"get_nearest_pedestrian_crossing(data_dict)\\n\")\n",
    "print(prompts)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llmagent",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
