{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:53:20.855574Z",
     "start_time": "2020-10-01T07:53:18.996423Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.append('/home/leo/particle/TrafficFluidsPt/TrafficFluids')\n",
    "sys.path.append('../datasets')\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from tensorpack import dataflow\n",
    "import pandas as pd\n",
    "import helper\n",
    "import tqdm\n",
    "import pickle\n",
    "import time\n",
    "from collections import defaultdict\n",
    "from typing import Any, Dict, List, Tuple, Union\n",
    "from datasets.argoverse_lane_loader import read_pkl_data\n",
    "from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader\n",
    "from argoverse.map_representation.map_api import ArgoverseMap\n",
    "from train_utils import *\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.lines as mlines\n",
    "%matplotlib inline\n",
    "\n",
    "device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:53:22.702179Z",
     "start_time": "2020-10-01T07:53:22.224885Z"
    }
   },
   "outputs": [],
   "source": [
    "dataset_path = '/home/leo/particle/argoverse/argoverse_forecasting/'\n",
    "afl = ArgoverseForecastingLoader(os.path.join(dataset_path, 'test_obs', 'data'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T02:34:05.945723Z",
     "start_time": "2020-10-01T02:34:05.938625Z"
    }
   },
   "outputs": [],
   "source": [
    "def _filter_imcomplete_data(data, tstmps, window=20):\n",
    "        complete_id = list()\n",
    "        for idx, subdf in data[data['TIMESTAMP'].isin(tstmps[:window])].groupby('TRACK_ID'):\n",
    "            if len(subdf) == window:\n",
    "                complete_id.append(idx)\n",
    "        return data[data['TRACK_ID'].isin(complete_id)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-30T04:07:54.384634Z",
     "start_time": "2020-09-30T03:57:56.274568Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 78143/78143 [09:58<00:00, 130.65it/s]\n"
     ]
    }
   ],
   "source": [
    "num_list = []\n",
    "for i in tqdm.tqdm(range(len(afl))):\n",
    "    data = afl[i].seq_df\n",
    "    tstmps = data.TIMESTAMP.unique()\n",
    "    tstmps.sort()\n",
    "\n",
    "    data = _filter_imcomplete_data(data, tstmps, 20)\n",
    "    num_list.append(len(data.TRACK_ID.unique()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-30T15:26:03.003780Z",
     "start_time": "2020-09-30T15:26:02.998411Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "78143"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Sanity check: the counting loop appends one entry per scene in afl.\n",
     "len(num_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T01:43:19.447156Z",
     "start_time": "2020-10-01T01:43:19.236449Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "78143"
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "file_list = !ls /home/leo/particle/argoverse/argoverse_forecasting/test_obs/lane_data/\n",
    "len(file_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T00:54:51.778869Z",
     "start_time": "2020-10-01T00:54:51.742465Z"
    }
   },
   "outputs": [],
   "source": [
    "lack_files = list(set(range(78143)).difference([int(f.split('.')[0]) for f in file_list]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T00:42:43.683622Z",
     "start_time": "2020-10-01T00:42:37.025511Z"
    }
   },
   "outputs": [],
   "source": [
     "am = ArgoverseMap()\n",
     "# NOTE(review): `lane_path` is not assigned in any visible cell. Presumably it comes from the\n",
     "# `from train_utils import *` star import or from leftover kernel state; confirm it\n",
     "# survives Restart & Run All.\n",
     "putil = process_utils(lane_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T00:49:22.843676Z",
     "start_time": "2020-10-01T00:49:22.811417Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "class ArgoverseTest(object):\n",
    "    \"\"\"\n",
    "    Data flow for argoverse dataset\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, file_path: str, shuffle: bool = True, random_rotation: bool = False,\n",
    "                 max_car_num: int = 50, freq: int = 10, use_interpolate: bool = False, \n",
    "                 lane_path: str = \"/home/leo/particle/TrafficFluids/datasets\", \n",
    "                 use_lane: bool = False, use_mask: bool = True):\n",
    "        if not os.path.exists(file_path):\n",
    "            raise Exception(\"Path does not exist.\")\n",
    "\n",
    "        self.afl = ArgoverseForecastingLoader(file_path)\n",
    "        self.shuffle = shuffle\n",
    "        self.random_rotation = random_rotation\n",
    "        self.max_car_num = max_car_num\n",
    "        self.freq = freq\n",
    "        self.use_interpolate = use_interpolate\n",
    "        self.am = ArgoverseMap()\n",
    "        self.use_mask = use_mask\n",
    "        self.file_path = file_path\n",
    "        \n",
    "\n",
    "    def get_feat(self, scene):\n",
    "\n",
    "        data, city = self.afl[scene].seq_df, self.afl[scene].city\n",
    "\n",
    "        lane = np.array([[0., 0., 0.]], dtype=np.float32)\n",
    "        lane_drct = np.array([[0., 0., 0.]], dtype=np.float32)\n",
    "\n",
    "\n",
    "        tstmps = data.TIMESTAMP.unique()\n",
    "        tstmps.sort()\n",
    "\n",
    "        data = self._filter_imcomplete_data(data, tstmps, 20)\n",
    "\n",
    "        data = self._calc_vel(data, self.freq)\n",
    "\n",
    "        agent = data[data['OBJECT_TYPE'] == 'AGENT']['TRACK_ID'].values[0]\n",
    "\n",
    "        car_mask = np.zeros((self.max_car_num, 1), dtype=np.float32)\n",
    "        car_mask[:len(data.TRACK_ID.unique())] = 1.0\n",
    "\n",
    "        feat_dict = {'city': city, \n",
    "                     'lane': lane, \n",
    "                     'lane_norm': lane_drct, \n",
    "                     'scene_idx': scene,  \n",
    "                     'agent_id': agent, \n",
    "                     'car_mask': car_mask}\n",
    "\n",
    "        pos_enc = [subdf[['X', 'Y']].values[np.newaxis,:] \n",
    "                   for _, subdf in data[data['TIMESTAMP'].isin(tstmps[:19])].groupby('TRACK_ID')]\n",
    "        pos_enc = np.concatenate(pos_enc, axis=0)\n",
    "        pos_enc = np.insert(pos_enc, 0, axis=1, values=pos_enc[:,0])\n",
    "        pos_enc = self._expand_dim(pos_enc)\n",
    "        feat_dict['pos_2s'] = self._expand_particle(pos_enc, self.max_car_num, 0)\n",
    "\n",
    "        vel_enc = [subdf[['vel_x', 'vel_y']].values[np.newaxis,:] \n",
    "                   for _, subdf in data[data['TIMESTAMP'].isin(tstmps[:19])].groupby('TRACK_ID')]\n",
    "        vel_enc = np.concatenate(vel_enc, axis=0)\n",
    "        vel_enc = np.insert(vel_enc, 0, axis=1, values=vel_enc[:,0])\n",
    "        vel_enc = self._expand_dim(vel_enc)\n",
    "        feat_dict['vel_2s'] = self._expand_particle(vel_enc, self.max_car_num, 0)\n",
    "\n",
    "        pos = data[data['TIMESTAMP'] == tstmps[19]][['X', 'Y']].values\n",
    "        pos = self._expand_dim(pos)\n",
    "        feat_dict['pos0'] = self._expand_particle(pos, self.max_car_num, 0)\n",
    "        vel = data[data['TIMESTAMP'] == tstmps[19]][['vel_x', 'vel_y']].values\n",
    "        vel = self._expand_dim(vel)\n",
    "        feat_dict['vel0'] = self._expand_particle(vel, self.max_car_num, 0)\n",
    "        track_id =  data[data['TIMESTAMP'] == tstmps[19]]['TRACK_ID'].values\n",
    "        feat_dict['track_id0'] = self._expand_particle(track_id, self.max_car_num, 0, 'str')\n",
    "        feat_dict['frame_id0'] = 0\n",
    "\n",
    "        return feat_dict\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(glob.glob(os.path.join(self.file_path, '*')))\n",
    "            \n",
    "    @classmethod\n",
    "    def __expand_df_generator(cls, df, city_name):\n",
    "        ids = df.TRACK_ID.unique()\n",
    "        tstmps = df.TIMESTAMP.unique()\n",
    "        for tstmp, sub_df in df.groupby('TIMESTAMP'):\n",
    "            for idx in ids:\n",
    "                if not idx in sub_df.TRACK_ID.values:\n",
    "                    yield pd.DataFrame(dict(TIMESTAMP = [tstmp], TRACK_ID = [idx], X = [np.nan], Y = [np.nan], \n",
    "                                       CITY_NAME = [city_name], \n",
    "                                            OBJECT_TYPE = [df[df['TRACK_ID'] == idx]['OBJECT_TYPE'].iloc[0]]))\n",
    "                else:\n",
    "                    yield df[(df['TIMESTAMP'] == tstmp) & (df['TRACK_ID'] == idx)]\n",
    "\n",
    "    @classmethod\n",
    "    def _expand_df(cls, df, city_name):\n",
    "        return pd.concat(cls.__expand_df_generator(df, city_name), axis=0)\n",
    "\n",
    "\n",
    "    @classmethod\n",
    "    def __calc_vel_generator(cls, df, freq=10):\n",
    "        for idx, subdf in df.groupby('TRACK_ID'):\n",
    "            sub_df = subdf.copy()\n",
    "            sub_df[['vel_x', 'vel_y']] = sub_df[['X', 'Y']].diff() * freq\n",
    "            yield sub_df.iloc[1:, :]\n",
    "\n",
    "    @classmethod\n",
    "    def _calc_vel(cls, df, freq=10):\n",
    "        return pd.concat(cls.__calc_vel_generator(df, freq=freq), axis=0)\n",
    "    \n",
    "    @classmethod\n",
    "    def _expand_dim(cls, ndarr, dtype=np.float32):\n",
    "        return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype)\n",
    "    \n",
    "    @classmethod\n",
    "    def _linear_interpolate_generator(cls, data, col=['X', 'Y']):\n",
    "        for idx, df in data.groupby('TRACK_ID'):\n",
    "            sub_df = df.copy()\n",
    "            sub_df[col] = sub_df[col].interpolate(limit_direction='both')\n",
    "            yield sub_df\n",
    "    \n",
    "    @classmethod\n",
    "    def _linear_interpolate(cls, data, col=['X', 'Y']):\n",
    "        return pd.concat(cls._linear_interpolate_generator(data, col), axis=0)\n",
    "    \n",
    "    @classmethod\n",
    "    def _filter_imcomplete_data(cls, data, tstmps, window=20):\n",
    "        complete_id = list()\n",
    "        for idx, subdf in data[data['TIMESTAMP'].isin(tstmps[:window])].groupby('TRACK_ID'):\n",
    "            if len(subdf) == window:\n",
    "                complete_id.append(idx)\n",
    "        return data[data['TRACK_ID'].isin(complete_id)]\n",
    "    \n",
    "    @classmethod\n",
    "    def _expand_particle(cls, arr, max_num, axis, value_type='int'):\n",
    "        dummy_shape = list(arr.shape)\n",
    "        dummy_shape[axis] = max_num - arr.shape[axis]\n",
    "        dummy = np.zeros(dummy_shape)\n",
    "        if value_type == 'str':\n",
    "            dummy = np.array(['dummy' + str(i) for i in range(np.product(dummy_shape))]).reshape(dummy_shape)\n",
    "        return np.concatenate([arr, dummy], axis=axis)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T00:49:35.491677Z",
     "start_time": "2020-10-01T00:49:28.506051Z"
    }
   },
   "outputs": [],
   "source": [
    "at = ArgoverseTest(os.path.join(dataset_path, 'test_obs', 'data'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T01:23:46.627444Z",
     "start_time": "2020-10-01T00:56:09.635281Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SAVED ============= 0 / 11269 ....... 0.00034737586975097656\n",
      "SAVED ============= 1000 / 11269 ....... 146.92148852348328\n",
      "SAVED ============= 2000 / 11269 ....... 147.34624314308167\n",
      "SAVED ============= 3000 / 11269 ....... 147.59493041038513\n",
      "SAVED ============= 4000 / 11269 ....... 145.740553855896\n",
      "SAVED ============= 5000 / 11269 ....... 147.39939260482788\n",
      "SAVED ============= 6000 / 11269 ....... 147.2357213497162\n",
      "SAVED ============= 7000 / 11269 ....... 147.13746738433838\n",
      "SAVED ============= 8000 / 11269 ....... 145.92743849754333\n",
      "SAVED ============= 9000 / 11269 ....... 146.34984636306763\n",
      "SAVED ============= 10000 / 11269 ....... 148.0656123161316\n",
      "SAVED ============= 11000 / 11269 ....... 148.17516469955444\n"
     ]
    }
   ],
   "source": [
    "batch_start = time.time()\n",
    "for i, scene in enumerate(lack_files):\n",
    "    if i % 1000 == 0:\n",
    "        batch_end = time.time()\n",
    "        print(\"SAVED ============= {} / {} ....... {}\".format(i, len(lack_files), batch_end - batch_start))\n",
    "        batch_start = time.time()\n",
    "\n",
    "    data = {k:[v] for k, v in at.get_feat(scene).items()}\n",
    "    datas = process_func(putil, data, am)\n",
    "    with open(os.path.join(dataset_path, 'test_obs/lane_data', str(datas['scene_idx'][0])+'.pkl'), 'wb') as f:\n",
    "        pickle.dump(datas, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T00:53:39.363416Z",
     "start_time": "2020-10-01T00:53:39.167477Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'city': ['PIT'],\n",
       " 'lane': [array([[1505.4978 ,  231.5036 ,    0.     ],\n",
       "         [1502.8154 ,  230.56395,    0.     ],\n",
       "         [1500.1332 ,  229.62431,    0.     ],\n",
       "         ...,\n",
       "         [1587.8226 ,  247.47318,    0.     ],\n",
       "         [1588.9114 ,  247.42273,    0.     ],\n",
       "         [1590.     ,  247.37227,    0.     ]], dtype=float32)],\n",
       " 'lane_norm': [array([[-2.682357  , -0.93964   ,  0.        ],\n",
       "         [-2.682357  , -0.93964   ,  0.        ],\n",
       "         [-2.682357  , -0.93964   ,  0.        ],\n",
       "         ...,\n",
       "         [ 1.0896193 , -0.02142323,  0.        ],\n",
       "         [ 1.0886799 , -0.05045327,  0.        ],\n",
       "         [ 1.0886799 , -0.05045327,  0.        ]], dtype=float32)],\n",
       " 'scene_idx': [4],\n",
       " 'agent_id': ['00000000-0000-0000-0000-000000062167'],\n",
       " 'car_mask': [array([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.],\n",
       "         [0.]], dtype=float32)],\n",
       " 'pos_2s': [array([[[1471.72851562,  212.58920288,    0.        ],\n",
       "          [1471.72851562,  212.58920288,    0.        ],\n",
       "          [1472.8984375 ,  212.99705505,    0.        ],\n",
       "          ...,\n",
       "          [1490.70373535,  219.22186279,    0.        ],\n",
       "          [1492.03833008,  219.68772888,    0.        ],\n",
       "          [1493.36877441,  220.15060425,    0.        ]],\n",
       "  \n",
       "         [[1524.34484863,  241.84013367,    0.        ],\n",
       "          [1524.34484863,  241.84013367,    0.        ],\n",
       "          [1524.27844238,  241.84616089,    0.        ],\n",
       "          ...,\n",
       "          [1522.70812988,  240.5657196 ,    0.        ],\n",
       "          [1521.79455566,  240.24189758,    0.        ],\n",
       "          [1522.04101562,  240.24653625,    0.        ]],\n",
       "  \n",
       "         [[1490.2557373 ,  226.48826599,    0.        ],\n",
       "          [1490.2557373 ,  226.48826599,    0.        ],\n",
       "          [1490.34289551,  226.66937256,    0.        ],\n",
       "          ...,\n",
       "          [1492.09729004,  229.39562988,    0.        ],\n",
       "          [1492.14892578,  229.58010864,    0.        ],\n",
       "          [1492.34936523,  229.79858398,    0.        ]],\n",
       "  \n",
       "         ...,\n",
       "  \n",
       "         [[   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          ...,\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ]],\n",
       "  \n",
       "         [[   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          ...,\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ]],\n",
       "  \n",
       "         [[   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          ...,\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ],\n",
       "          [   0.        ,    0.        ,    0.        ]]])],\n",
       " 'vel_2s': [array([[[12.98771858,  4.63325071,  0.        ],\n",
       "          [12.98771858,  4.63325071,  0.        ],\n",
       "          [11.69968128,  4.07853413,  0.        ],\n",
       "          ...,\n",
       "          [14.97850323,  5.21672106,  0.        ],\n",
       "          [13.34574223,  4.65865231,  0.        ],\n",
       "          [13.30472374,  4.62867308,  0.        ]],\n",
       "  \n",
       "         [[-1.34566402, -1.41019392,  0.        ],\n",
       "          [-1.34566402, -1.41019392,  0.        ],\n",
       "          [-0.66434371,  0.06033424,  0.        ],\n",
       "          ...,\n",
       "          [13.8757143 ,  3.68315363,  0.        ],\n",
       "          [-9.13572025, -3.23834133,  0.        ],\n",
       "          [ 2.46516275,  0.04652267,  0.        ]],\n",
       "  \n",
       "         [[ 1.28162742,  2.38161683,  0.        ],\n",
       "          [ 1.28162742,  2.38161683,  0.        ],\n",
       "          [ 0.87070954,  1.811064  ,  0.        ],\n",
       "          ...,\n",
       "          [ 1.6546371 ,  2.88599658,  0.        ],\n",
       "          [ 0.51643306,  1.84479105,  0.        ],\n",
       "          [ 2.00441027,  2.1846776 ,  0.        ]],\n",
       "  \n",
       "         ...,\n",
       "  \n",
       "         [[ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          ...,\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ]],\n",
       "  \n",
       "         [[ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          ...,\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ]],\n",
       "  \n",
       "         [[ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          ...,\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        ]]])],\n",
       " 'pos0': [array([[1494.70178223,  220.6166687 ,    0.        ],\n",
       "         [1521.75195312,  240.07633972,    0.        ],\n",
       "         [1492.48864746,  230.03465271,    0.        ],\n",
       "         [1499.71618652,  236.0511322 ,    0.        ],\n",
       "         [1479.53796387,  208.68539429,    0.        ],\n",
       "         [1467.40209961,  221.82922363,    0.        ],\n",
       "         [1455.68310547,  217.74343872,    0.        ],\n",
       "         [1444.32128906,  213.85362244,    0.        ],\n",
       "         [1450.7286377 ,  215.9201355 ,    0.        ],\n",
       "         [1481.61572266,  209.02430725,    0.        ],\n",
       "         [1462.42700195,  219.96299744,    0.        ],\n",
       "         [1509.45092773,  218.25204468,    0.        ],\n",
       "         [1490.86022949,  235.40797424,    0.        ],\n",
       "         [1483.25817871,  229.38735962,    0.        ],\n",
       "         [1504.39404297,  206.51811218,    0.        ],\n",
       "         [1497.4954834 ,  214.12487793,    0.        ],\n",
       "         [1568.10461426,  257.8213501 ,    0.        ],\n",
       "         [1581.61022949,  253.20635986,    0.        ],\n",
       "         [1492.76086426,  212.43614197,    0.        ],\n",
       "         [1504.41796875,  214.23857117,    0.        ],\n",
       "         [1488.63354492,  242.02023315,    0.        ],\n",
       "         [1507.0760498 ,  208.19821167,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ],\n",
       "         [   0.        ,    0.        ,    0.        ]])],\n",
       " 'vel0': [array([[ 1.33305054e+01,  4.66071701e+00,  0.00000000e+00],\n",
       "         [-2.89108872e+00, -1.70206892e+00,  0.00000000e+00],\n",
       "         [ 1.39334333e+00,  2.36074400e+00,  0.00000000e+00],\n",
       "         [ 3.08830309e+00,  1.12145448e+00,  0.00000000e+00],\n",
       "         [ 4.88446765e-02,  9.23772436e-03,  0.00000000e+00],\n",
       "         [ 8.22565854e-01, -1.79458205e-02,  0.00000000e+00],\n",
       "         [-4.74196959e+00, -3.97785157e-01,  0.00000000e+00],\n",
       "         [-4.86471429e-02,  2.12336969e+00,  0.00000000e+00],\n",
       "         [-5.81180096e-01, -9.81987894e-01,  0.00000000e+00],\n",
       "         [ 1.62082314e-01,  5.33448569e-02,  0.00000000e+00],\n",
       "         [-3.92243886e+00, -3.19223547e+00,  0.00000000e+00],\n",
       "         [-2.68084317e-01, -5.27247190e-01,  0.00000000e+00],\n",
       "         [-2.94170070e+00, -1.65730059e+00,  0.00000000e+00],\n",
       "         [-8.40914026e-02,  3.89038771e-02,  0.00000000e+00],\n",
       "         [ 8.65505874e-01, -9.13275063e-01,  0.00000000e+00],\n",
       "         [ 2.13372684e+00,  1.87010932e+00,  0.00000000e+00],\n",
       "         [ 1.48126495e+00, -1.38603523e-01,  0.00000000e+00],\n",
       "         [-8.65328312e-01, -3.50441790e+00,  0.00000000e+00],\n",
       "         [ 7.97225654e-01,  1.02660310e+00,  0.00000000e+00],\n",
       "         [ 3.23006582e+00,  9.60197568e-01,  0.00000000e+00],\n",
       "         [-4.85400319e-01, -8.21378171e-01,  0.00000000e+00],\n",
       "         [ 1.52900757e-03,  9.04129446e-02,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00],\n",
       "         [ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00]])],\n",
       " 'track_id0': [array(['00000000-0000-0000-0000-000000000000',\n",
       "         '00000000-0000-0000-0000-000000062167',\n",
       "         '00000000-0000-0000-0000-000000062188',\n",
       "         '00000000-0000-0000-0000-000000062223',\n",
       "         '00000000-0000-0000-0000-000000062226',\n",
       "         '00000000-0000-0000-0000-000000062233',\n",
       "         '00000000-0000-0000-0000-000000062241',\n",
       "         '00000000-0000-0000-0000-000000062247',\n",
       "         '00000000-0000-0000-0000-000000062249',\n",
       "         '00000000-0000-0000-0000-000000062280',\n",
       "         '00000000-0000-0000-0000-000000062315',\n",
       "         '00000000-0000-0000-0000-000000062328',\n",
       "         '00000000-0000-0000-0000-000000062361',\n",
       "         '00000000-0000-0000-0000-000000062368',\n",
       "         '00000000-0000-0000-0000-000000062374',\n",
       "         '00000000-0000-0000-0000-000000062376',\n",
       "         '00000000-0000-0000-0000-000000062382',\n",
       "         '00000000-0000-0000-0000-000000062388',\n",
       "         '00000000-0000-0000-0000-000000062412',\n",
       "         '00000000-0000-0000-0000-000000062413',\n",
       "         '00000000-0000-0000-0000-000000062414',\n",
       "         '00000000-0000-0000-0000-000000062438', 'dummy0', 'dummy1',\n",
       "         'dummy2', 'dummy3', 'dummy4', 'dummy5', 'dummy6', 'dummy7',\n",
       "         'dummy8', 'dummy9', 'dummy10', 'dummy11', 'dummy12', 'dummy13',\n",
       "         'dummy14', 'dummy15', 'dummy16', 'dummy17', 'dummy18', 'dummy19',\n",
       "         'dummy20', 'dummy21', 'dummy22', 'dummy23', 'dummy24', 'dummy25',\n",
       "         'dummy26', 'dummy27'], dtype=object)],\n",
       " 'frame_id0': [0]}"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
     "# Spot-check the processed features for a single scene (index 4).\n",
     "# NOTE(review): this rebinds the global names `data`/`datas` that the caching loop also uses.\n",
     "data = {k:[v] for k, v in at.get_feat(4).items()}\n",
     "datas = process_func(putil, data, am)\n",
     "datas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "### Baseline model parameters (LSTM encoder/decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-30T21:21:44.863877Z",
     "start_time": "2020-09-30T21:21:44.849939Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "class EncoderRNN(nn.Module):\n",
    "    \"\"\"Single-step LSTM encoder: ReLU(linear embedding) followed by an LSTMCell.\"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 input_size: int = 2,\n",
    "                 embedding_size: int = 8,\n",
    "                 hidden_size: int = 16):\n",
    "        \"\"\"Create the input embedding and the recurrent cell.\n",
    "\n",
    "        Args:\n",
    "            input_size: number of features per input step\n",
    "            embedding_size: width of the embedding fed to the LSTMCell\n",
    "            hidden_size: width of the LSTMCell hidden and cell states\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        # registration order (linear1, then lstm1) fixes the parameters()/state_dict order\n",
    "        self.linear1 = nn.Linear(input_size, embedding_size)\n",
    "        self.lstm1 = nn.LSTMCell(embedding_size, hidden_size)\n",
    "\n",
    "    def forward(self, x: torch.FloatTensor, hidden: Any) -> Any:\n",
    "        \"\"\"Advance the encoder by one time step.\n",
    "\n",
    "        Args:\n",
    "            x: input features of the current step\n",
    "            hidden: (h, c) state handed to the LSTMCell\n",
    "        Returns:\n",
    "            the new (h, c) tuple returned by the LSTMCell\n",
    "        \"\"\"\n",
    "        return self.lstm1(F.relu(self.linear1(x)), hidden)\n",
    "\n",
    "\n",
    "class DecoderRNN(nn.Module):\n",
    "    \"\"\"Single-step LSTM decoder mapping the previous output to the next one.\"\"\"\n",
    "\n",
    "    def __init__(self, embedding_size=8, hidden_size=16, output_size=2):\n",
    "        \"\"\"Create the input embedding, the recurrent cell and the output projection.\n",
    "\n",
    "        Args:\n",
    "            embedding_size: width of the embedding fed to the LSTMCell\n",
    "            hidden_size: width of the LSTMCell hidden and cell states\n",
    "            output_size: features per step, used both as input and as output width\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        # registration order (linear1, lstm1, linear2) fixes the parameters()/state_dict order\n",
    "        self.linear1 = nn.Linear(output_size, embedding_size)\n",
    "        self.lstm1 = nn.LSTMCell(embedding_size, hidden_size)\n",
    "        self.linear2 = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x, hidden):\n",
    "        \"\"\"Advance the decoder by one time step.\n",
    "\n",
    "        Args:\n",
    "            x: previous-step features, output_size wide\n",
    "            hidden: (h, c) state handed to the LSTMCell\n",
    "        Returns:\n",
    "            output: linear projection of the new hidden state h\n",
    "            hidden: the new (h, c) tuple\n",
    "        \"\"\"\n",
    "        h_next, c_next = self.lstm1(F.relu(self.linear1(x)), hidden)\n",
    "        return self.linear2(h_next), (h_next, c_next)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-30T21:24:51.677290Z",
     "start_time": "2020-09-30T21:24:51.671028Z"
    }
   },
   "outputs": [],
   "source": [
    "# baseline seq2seq: encoder over 5 input features, decoder emitting 2 features per step\n",
    "encoder, decoder = EncoderRNN(input_size=5), DecoderRNN(output_size=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-30T21:26:19.604998Z",
     "start_time": "2020-09-30T21:26:19.598558Z"
    }
   },
   "outputs": [],
   "source": [
    "# element count of every parameter tensor, encoder first, then decoder\n",
    "param = [np.prod(p.shape) for module in (encoder, decoder) for p in module.parameters()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-09-30T21:26:29.121112Z",
     "start_time": "2020-09-30T21:26:29.115317Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3434"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# total number of parameters of the baseline encoder + decoder\n",
    "np.asarray(param).sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:53:38.365492Z",
     "start_time": "2020-10-01T07:53:38.348077Z"
    },
    "code_folding": [
     8
    ]
   },
   "outputs": [],
   "source": [
    "\"Functions loading the .pkl version preprocessed data\"\n",
    "from tensorpack import dataflow\n",
    "from glob import glob\n",
    "import pickle\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class ArgoversePklLoader(dataflow.RNGDataFlow):\n",
    "    \"\"\"tensorpack dataflow over preprocessed Argoverse scenes, one pickle file per scene.\n",
    "\n",
    "    Scenes whose lane graph has fewer than `min_lane_nodes` or more than\n",
    "    `max_lane_nodes` nodes are skipped. The lane arrays of the remaining scenes are\n",
    "    padded to `max_lane_nodes` rows and a float32 `lane_mask` marks the real rows.\n",
    "    \"\"\"\n",
    "    def __init__(self, data_path: str, shuffle: bool=True, max_lane_nodes=650, min_lane_nodes=0):\n",
    "        super(ArgoversePklLoader, self).__init__()\n",
    "        self.data_path = data_path\n",
    "        self.shuffle = shuffle\n",
    "        self.max_lane_nodes = max_lane_nodes\n",
    "        self.min_lane_nodes = min_lane_nodes\n",
    "\n",
    "    def __iter__(self):\n",
    "        \"\"\"Yield one padded scene dict per pickle, in sorted or (shuffle=True) rng-shuffled order.\"\"\"\n",
    "        pkl_list = sorted(glob(os.path.join(self.data_path, '*')))\n",
    "        if self.shuffle:\n",
    "            self.rng.shuffle(pkl_list)\n",
    "\n",
    "        for pkl_path in pkl_list:\n",
    "            # NOTE: pickle.load can execute arbitrary code -- only read trusted files\n",
    "            with open(pkl_path, 'rb') as f:\n",
    "                data = pickle.load(f)\n",
    "\n",
    "            num_lane_nodes = data['lane'][0].shape[0]\n",
    "            if not self.min_lane_nodes <= num_lane_nodes <= self.max_lane_nodes:\n",
    "                continue\n",
    "\n",
    "            # stored values are wrapped (read through [0]); unwrap them\n",
    "            data = {k: v[0] for k, v in data.items()}\n",
    "            lane_mask = np.zeros(self.max_lane_nodes, dtype=np.float32)\n",
    "            lane_mask[:len(data['lane'])] = 1.0\n",
    "\n",
    "            data['lane'] = self.expand_particle(data['lane'], self.max_lane_nodes, 0)\n",
    "            data['lane_norm'] = self.expand_particle(data['lane_norm'], self.max_lane_nodes, 0)\n",
    "            data['lane_mask'] = lane_mask\n",
    "\n",
    "            yield data\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of files in data_path.\n",
    "\n",
    "        NOTE: counted before the lane-node filter of __iter__, so this is an upper\n",
    "        bound on the number of yielded scenes whenever that filter drops any.\n",
    "        \"\"\"\n",
    "        return len(glob(os.path.join(self.data_path, '*')))\n",
    "\n",
    "    @classmethod\n",
    "    def expand_particle(cls, arr, max_num, axis, value_type='int'):\n",
    "        \"\"\"Pad `arr` along `axis` until it has `max_num` entries there.\n",
    "\n",
    "        The padding is float zeros, or with value_type='str' the strings\n",
    "        'dummy0', 'dummy1', ... in row-major order.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if arr already has more than max_num entries along axis.\n",
    "        \"\"\"\n",
    "        pad = max_num - arr.shape[axis]\n",
    "        if pad < 0:\n",
    "            raise ValueError('cannot pad axis {} of length {} to {} entries'.format(\n",
    "                axis, arr.shape[axis], max_num))\n",
    "        dummy_shape = list(arr.shape)\n",
    "        dummy_shape[axis] = pad\n",
    "        if value_type == 'str':\n",
    "            # np.prod replaces np.product (deprecated in NumPy 1.25, removed in 2.0)\n",
    "            dummy = np.array(['dummy' + str(i) for i in range(np.prod(dummy_shape))]).reshape(dummy_shape)\n",
    "        else:\n",
    "            dummy = np.zeros(dummy_shape)\n",
    "        return np.concatenate([arr, dummy], axis=axis)\n",
    "    \n",
    "\n",
    "def read_pkl_data(data_path: str, batch_size: int, \n",
    "                  shuffle: bool=False, repeat: bool=False, remainder: bool=True, **kwargs):\n",
    "    \"\"\"Build a batched tensorpack dataflow over the per-scene pickles in data_path.\n",
    "\n",
    "    Args:\n",
    "        data_path: directory holding one pickle file per scene\n",
    "        batch_size: number of scenes per batch\n",
    "        shuffle: shuffle the file order on every pass\n",
    "        repeat: repeat the data indefinitely\n",
    "        remainder: also yield the final, smaller batch when the number of scenes is\n",
    "            not a multiple of batch_size. BatchData drops it by default, which silently\n",
    "            lost the last 7 of 78143 test scenes with batch_size=8 (78136 predictions);\n",
    "            pass False to drop the incomplete batch.\n",
    "        **kwargs: forwarded to ArgoversePklLoader (max_lane_nodes, min_lane_nodes)\n",
    "    Returns:\n",
    "        a reset, ready-to-iterate dataflow whose batches map each key to a list\n",
    "    \"\"\"\n",
    "    df = ArgoversePklLoader(data_path=data_path, shuffle=shuffle, **kwargs)\n",
    "    if repeat:\n",
    "        df = dataflow.RepeatedData(df, -1)\n",
    "    df = dataflow.BatchData(df, batch_size=batch_size, remainder=remainder, use_list=True)\n",
    "    df.reset_state()\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:53:45.540928Z",
     "start_time": "2020-10-01T07:53:40.376294Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/leo/.local/lib/python3.6/site-packages/torch/serialization.py:657: SourceChangeWarning: source code of class 'EquiCtsConv.EquiCtsConv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
      "  warnings.warn(msg, SourceChangeWarning)\n"
     ]
    }
   ],
   "source": [
    "# map_location puts every tensor straight onto `device`, so checkpoints saved on another\n",
    "# GPU (or on a GPU machine while running on CPU) still load.\n",
    "# NOTE: these files are whole pickled modules and torch.load unpickles them -- trusted files only.\n",
    "model_unequi = torch.load('./weights/ctsconv_map2.pth', map_location=device)\n",
    "model_equi = torch.load('./weights/rho1_ctsconv_map.pth', map_location=device)\n",
    "model_reg = torch.load('./weights/reg__ctsconv_map.pth', map_location=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:53:46.945999Z",
     "start_time": "2020-10-01T07:53:46.634477Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9767"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_path = os.path.join(dataset_path, 'test_obs', 'lane_data')\n",
    "# 1590 is both the lane padding size and the upper lane-node filter of the loader\n",
    "dataset = read_pkl_data(\n",
    "    test_path,\n",
    "    batch_size=8,\n",
    "    repeat=False,\n",
    "    shuffle=True,\n",
    "    min_lane_nodes=0,\n",
    "    max_lane_nodes=1590,\n",
    ")\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:53:49.426535Z",
     "start_time": "2020-10-01T07:53:49.405382Z"
    },
    "code_folding": [
     10
    ]
   },
   "outputs": [],
   "source": [
    "def get_agent(pr, pr_id, agent_id, device='cpu'):\n",
    "    \"\"\"Select the entries of `pr` whose track id equals the agent id.\n",
    "\n",
    "    Args:\n",
    "        pr: per-car values; assumes its leading dims match pr_id (e.g. [batch, car, ...])\n",
    "        pr_id: track ids, compared elementwise against agent_id\n",
    "        agent_id: agent id(s) broadcastable against pr_id (e.g. shape [batch, 1])\n",
    "        device: unused; kept for call compatibility\n",
    "    Returns:\n",
    "        pr indexed by the boolean match mask (matches flattened into the first axis)\n",
    "    \"\"\"\n",
    "    is_agent = pr_id == agent_id\n",
    "    return pr[is_agent, :]\n",
    "\n",
    "def get_history(scene, agent_id):\n",
    "    \"\"\"Return the X, Y columns of track `agent_id` in sequence `scene` of afl as an array.\"\"\"\n",
    "    seq_df = afl[scene].seq_df\n",
    "    return seq_df.loc[seq_df.TRACK_ID == agent_id, ['X', 'Y']].values\n",
    "\n",
    "def unequi_evaluate(model, val_dataset,\n",
    "             train_window=3, device='cpu', \n",
    "             batch_size=32, use_normalize_input=False, normalize_scale=3,\n",
    "             pred_len=30):\n",
    "    \"\"\"Roll the non-equivariant model out over every batch and collect the agent forecasts.\n",
    "\n",
    "    The model is called once on the observed history and then fed its own\n",
    "    predictions for pred_len - 1 further steps, giving pred_len predicted steps.\n",
    "\n",
    "    Args:\n",
    "        model: network called as model(inputs) and model(inputs, states),\n",
    "            returning (positions, velocities, states)\n",
    "        val_dataset: iterable of batches as produced by read_pkl_data\n",
    "        device: torch device the input tensors are created on\n",
    "        train_window, batch_size, use_normalize_input, normalize_scale:\n",
    "            unused; kept for call compatibility\n",
    "        pred_len: number of predicted steps per scene (30 by default)\n",
    "    Returns:\n",
    "        dict mapping each integer Argoverse sequence id (the csv file stem) to an\n",
    "        array holding that scene's predicted agent trajectory repeated 6 times\n",
    "    \"\"\"\n",
    "    print('evaluating.. ', end='', flush=True)\n",
    "\n",
    "    prediction_gt = {}\n",
    "    tensor_keys = ['pos0', 'vel0', 'pos_2s', 'vel_2s', 'lane', 'lane_norm', 'car_mask', 'lane_mask']\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for count, sample in enumerate(val_dataset):\n",
    "            if count % 10 == 0:\n",
    "                print('{}'.format(count + 1), end=' ', flush=True)\n",
    "\n",
    "            data = {k: torch.tensor(np.stack(sample[k]), dtype=torch.float32, device=device)\n",
    "                    for k in tensor_keys}\n",
    "            for k in ['track_id0', 'city', 'agent_id', 'scene_idx']:\n",
    "                data[k] = np.stack(sample[k])\n",
    "\n",
    "            scenes = data['scene_idx'].tolist()\n",
    "            # assumes one agent id per batch row; the new axis lets it broadcast against track_id0\n",
    "            agent_id = data['agent_id'][:, np.newaxis]\n",
    "            data['car_mask'] = data['car_mask'].squeeze(-1)\n",
    "            data['accel'] = torch.zeros(1, 1, 3).to(device)\n",
    "\n",
    "            inputs = ([\n",
    "                data['pos_2s'], data['vel_2s'],\n",
    "                data['pos0'], data['vel0'],\n",
    "                data['accel'], None,\n",
    "                data['lane'], data['lane_norm'],\n",
    "                data['car_mask'], data['lane_mask']\n",
    "            ])\n",
    "            pr_pos1, pr_vel1, states = model(inputs)\n",
    "            pred = [get_agent(pr_pos1, data['track_id0'], agent_id, device).unsqueeze(1).detach().cpu().numpy()]\n",
    "\n",
    "            pos0, vel0 = data['pos0'], data['vel0']\n",
    "            for _ in range(pred_len - 1):\n",
    "                # previous state goes into the (now one-step) history slots, the latest\n",
    "                # prediction into the current-state slots\n",
    "                inputs = (torch.unsqueeze(pos0, 2), torch.unsqueeze(vel0, 2), pr_pos1, pr_vel1,\n",
    "                          data['accel'], None,\n",
    "                          data['lane'], data['lane_norm'], data['car_mask'], data['lane_mask'])\n",
    "                pos0, vel0 = pr_pos1, pr_vel1\n",
    "                pr_pos1, pr_vel1, states = model(inputs, states)\n",
    "                pred.append(get_agent(pr_pos1, data['track_id0'], agent_id, device).unsqueeze(1).detach().cpu().numpy())\n",
    "                clean_cache(device)\n",
    "\n",
    "            predict_result = np.concatenate(pred, axis=1)\n",
    "            for idx, scene_id in enumerate(scenes):\n",
    "                # generate_forecasting_h5 is typed Dict[int, ...] and stores the key as the\n",
    "                # sequence id, so use the integer file stem rather than a string\n",
    "                seq_id = int(afl[scene_id].current_seq.stem)\n",
    "                prediction_gt[seq_id] = np.array([predict_result[idx]] * 6)\n",
    "\n",
    "    return prediction_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T11:42:53.213694Z",
     "start_time": "2020-10-01T07:53:54.332230Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "evaluating.. 1 11 21 31 41 51 61 71 81 91 101 111 121 131 141 151 161 171 181 191 201 211 221 231 241 251 261 271 281 291 301 311 321 331 341 351 361 371 381 391 401 411 421 431 441 451 461 471 481 491 501 511 521 531 541 551 561 571 581 591 601 611 621 631 641 651 661 671 681 691 701 711 721 731 741 751 761 771 781 791 801 811 821 831 841 851 861 871 881 891 901 911 921 931 941 951 961 971 981 991 1001 1011 1021 1031 1041 1051 1061 1071 1081 1091 1101 1111 1121 1131 1141 1151 1161 1171 1181 1191 1201 1211 1221 1231 1241 1251 1261 1271 1281 1291 1301 1311 1321 1331 1341 1351 1361 1371 1381 1391 1401 1411 1421 1431 1441 1451 1461 1471 1481 1491 1501 1511 1521 1531 1541 1551 1561 1571 1581 1591 1601 1611 1621 1631 1641 1651 1661 1671 1681 1691 1701 1711 1721 1731 1741 1751 1761 1771 1781 1791 1801 1811 1821 1831 1841 1851 1861 1871 1881 1891 1901 1911 1921 1931 1941 1951 1961 1971 1981 1991 2001 2011 2021 2031 2041 2051 2061 2071 2081 2091 2101 2111 2121 2131 2141 2151 2161 2171 2181 2191 2201 2211 2221 2231 2241 2251 2261 2271 2281 2291 2301 2311 2321 2331 2341 2351 2361 2371 2381 2391 2401 2411 2421 2431 2441 2451 2461 2471 2481 2491 2501 2511 2521 2531 2541 2551 2561 2571 2581 2591 2601 2611 2621 2631 2641 2651 2661 2671 2681 2691 2701 2711 2721 2731 2741 2751 2761 2771 2781 2791 2801 2811 2821 2831 2841 2851 2861 2871 2881 2891 2901 2911 2921 2931 2941 2951 2961 2971 2981 2991 3001 3011 3021 3031 3041 3051 3061 3071 3081 3091 3101 3111 3121 3131 3141 3151 3161 3171 3181 3191 3201 3211 3221 3231 3241 3251 3261 3271 3281 3291 3301 3311 3321 3331 3341 3351 3361 3371 3381 3391 3401 3411 3421 3431 3441 3451 3461 3471 3481 3491 3501 3511 3521 3531 3541 3551 3561 3571 3581 3591 3601 3611 3621 3631 3641 3651 3661 3671 3681 3691 3701 3711 3721 3731 3741 3751 3761 3771 3781 3791 3801 3811 3821 3831 3841 3851 3861 3871 3881 3891 3901 3911 3921 3931 3941 3951 3961 3971 3981 3991 4001 4011 4021 4031 4041 4051 4061 4071 4081 4091 4101 4111 4121 4131 4141 4151 4161 4171 
4181 4191 4201 4211 4221 4231 4241 4251 4261 4271 4281 4291 4301 4311 4321 4331 4341 4351 4361 4371 4381 4391 4401 4411 4421 4431 4441 4451 4461 4471 4481 4491 4501 4511 4521 4531 4541 4551 4561 4571 4581 4591 4601 4611 4621 4631 4641 4651 4661 4671 4681 4691 4701 4711 4721 4731 4741 4751 4761 4771 4781 4791 4801 4811 4821 4831 4841 4851 4861 4871 4881 4891 4901 4911 4921 4931 4941 4951 4961 4971 4981 4991 5001 5011 5021 5031 5041 5051 5061 5071 5081 5091 5101 5111 5121 5131 5141 5151 5161 5171 5181 5191 5201 5211 5221 5231 5241 5251 5261 5271 5281 5291 5301 5311 5321 5331 5341 5351 5361 5371 5381 5391 5401 5411 5421 5431 5441 5451 5461 5471 5481 5491 5501 5511 5521 5531 5541 5551 5561 5571 5581 5591 5601 5611 5621 5631 5641 5651 5661 5671 5681 5691 5701 5711 5721 5731 5741 5751 5761 5771 5781 5791 5801 5811 5821 5831 5841 5851 5861 5871 5881 5891 5901 5911 5921 5931 5941 5951 5961 5971 5981 5991 6001 6011 6021 6031 6041 6051 6061 6071 6081 6091 6101 6111 6121 6131 6141 6151 6161 6171 6181 6191 6201 6211 6221 6231 6241 6251 6261 6271 6281 6291 6301 6311 6321 6331 6341 6351 6361 6371 6381 6391 6401 6411 6421 6431 6441 6451 6461 6471 6481 6491 6501 6511 6521 6531 6541 6551 6561 6571 6581 6591 6601 6611 6621 6631 6641 6651 6661 6671 6681 6691 6701 6711 6721 6731 6741 6751 6761 6771 6781 6791 6801 6811 6821 6831 6841 6851 6861 6871 6881 6891 6901 6911 6921 6931 6941 6951 6961 6971 6981 6991 7001 7011 7021 7031 7041 7051 7061 7071 7081 7091 7101 7111 7121 7131 7141 7151 7161 7171 7181 7191 7201 7211 7221 7231 7241 7251 7261 7271 7281 7291 7301 7311 7321 7331 7341 7351 7361 7371 7381 7391 7401 7411 7421 7431 7441 7451 7461 7471 7481 7491 7501 7511 7521 7531 7541 7551 7561 7571 7581 7591 7601 7611 7621 7631 7641 7651 7661 7671 7681 7691 7701 7711 7721 7731 7741 7751 7761 7771 7781 7791 7801 7811 7821 7831 7841 7851 7861 7871 7881 7891 7901 7911 7921 7931 7941 7951 7961 7971 7981 7991 8001 8011 8021 8031 8041 8051 8061 8071 8081 8091 8101 8111 8121 8131 8141 8151 8161 8171 
8181 8191 8201 8211 8221 8231 8241 8251 8261 8271 8281 8291 8301 8311 8321 8331 8341 8351 8361 8371 8381 8391 8401 8411 8421 8431 8441 8451 8461 8471 8481 8491 8501 8511 8521 8531 8541 8551 8561 8571 8581 8591 8601 8611 8621 8631 8641 8651 8661 8671 8681 8691 8701 8711 8721 8731 8741 8751 8761 8771 8781 8791 8801 8811 8821 8831 8841 8851 8861 8871 8881 8891 8901 8911 8921 8931 8941 8951 8961 8971 8981 8991 9001 9011 9021 9031 9041 9051 9061 9071 9081 9091 9101 9111 9121 9131 9141 9151 9161 9171 9181 9191 9201 9211 9221 9231 9241 9251 9261 9271 9281 9291 9301 9311 9321 9331 9341 9351 9361 9371 9381 9391 9401 9411 9421 9431 9441 9451 9461 9471 9481 9491 9501 9511 9521 9531 9541 9551 9561 9571 9581 9591 9601 9611 9621 9631 9641 9651 9661 9671 9681 9691 9701 9711 9721 9731 9741 9751 9761 "
     ]
    }
   ],
   "source": [
    "# move the non-equivariant model to the evaluation device and switch it to inference mode\n",
    "model_unequi = model_unequi.to(device).eval()\n",
    "unequi_result = unequi_evaluate(model_unequi, dataset, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T06:27:58.010375Z",
     "start_time": "2020-10-01T06:27:57.996770Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "class temp_loader(object):\n",
    "    \"\"\"Reload, one scene at a time, the test scenes missing from a result dict.\n",
    "\n",
    "    Every yielded value is wrapped in a one-element list, presumably matching what\n",
    "    read_pkl_data(batch_size=1) would give the *_evaluate functions.\n",
    "\n",
    "    Args:\n",
    "        finished: result dict keyed by sequence id (str or int); None means the\n",
    "            global unequi_all_result, looked up at iteration time\n",
    "        num_scenes: number of scene indices (pickle files) to consider\n",
    "        max_lane_nodes: lane padding size\n",
    "        data_path: pickle directory; None means the global test_path\n",
    "    \"\"\"\n",
    "    def __init__(self, finished=None, num_scenes=78143, max_lane_nodes=1590, data_path=None):\n",
    "        self.finished = finished\n",
    "        self.num_scenes = num_scenes\n",
    "        self.max_lane_nodes = max_lane_nodes\n",
    "        self.data_path = data_path\n",
    "\n",
    "    def _missing_scenes(self):\n",
    "        \"\"\"Scene indices whose sequence id is not yet a key of the result dict.\"\"\"\n",
    "        finished = unequi_all_result if self.finished is None else self.finished\n",
    "        # result keys are sequence ids, not scene indices: map each index through afl\n",
    "        # (as the *_evaluate functions do) and compare as strings\n",
    "        done = {str(k) for k in finished}\n",
    "        return [scene for scene in range(self.num_scenes)\n",
    "                if afl[scene].current_seq.stem not in done]\n",
    "\n",
    "    def __iter__(self):\n",
    "        data_path = test_path if self.data_path is None else self.data_path\n",
    "        for scene in self._missing_scenes():\n",
    "            pkl_path = os.path.join(data_path, str(scene) + '.pkl')\n",
    "            with open(pkl_path, 'rb') as f:\n",
    "                data = pickle.load(f)\n",
    "\n",
    "            lane_mask = np.zeros(self.max_lane_nodes, dtype=np.float32)\n",
    "            lane_mask[:len(data['lane'][0])] = 1.0\n",
    "\n",
    "            data['lane'] = [ArgoversePklLoader.expand_particle(data['lane'][0], self.max_lane_nodes, 0)]\n",
    "            data['lane_norm'] = [ArgoversePklLoader.expand_particle(data['lane_norm'][0], self.max_lane_nodes, 0)]\n",
    "            data['lane_mask'] = [lane_mask]\n",
    "            yield data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._missing_scenes())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:43:03.892436Z",
     "start_time": "2020-10-01T07:42:56.771872Z"
    }
   },
   "outputs": [],
   "source": [
    "# real (unpadded) lane-node count of every scene; iterate over all scenes of a batch,\n",
    "# not only the first, so the result does not depend on batch_size\n",
    "lane_num = []\n",
    "for batch in dataset:\n",
    "    lane_num.extend(mask.sum() for mask in batch['lane_mask'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T07:43:03.896892Z",
     "start_time": "2020-10-01T07:43:03.893919Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "78143"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of lane-node counts collected\n",
    "len(lane_num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T15:11:53.241775Z",
     "start_time": "2020-10-01T15:11:53.186133Z"
    }
   },
   "outputs": [],
   "source": [
    "# keep the first two feature channels (presumably x, y) of every predicted trajectory\n",
    "unequi_all_result = {seq_id: traj[..., :2] for seq_id, traj in unequi_result.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "start_time": "2020-10-01T15:11:54.473Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27879/78136"
     ]
    }
   ],
   "source": [
    "from argoverse.evaluation.competition_util import generate_forecasting_h5\n",
    "\n",
    "output_path = 'competition_file/'\n",
    "# create the output directory up front so writing the .h5 file cannot fail on a missing folder\n",
    "os.makedirs(output_path, exist_ok=True)\n",
    "\n",
    "generate_forecasting_h5(unequi_all_result, output_path, 'unequi_result')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T11:42:53.715301Z",
     "start_time": "2020-10-01T07:54:13.048Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def equi_evaluate(model, val_dataset,\n",
    "             train_window=3, device='cpu', \n",
    "             batch_size=32, use_normalize_input=False, normalize_scale=3,\n",
    "             pred_len=30):\n",
    "    \"\"\"Roll an equivariant (2-channel) model out over every batch and collect the agent forecasts.\n",
    "\n",
    "    Kinematic and lane inputs are cut to their first two channels, lane_mask gets a\n",
    "    trailing singleton axis and the zero acceleration is (1, 1, 2). The model is\n",
    "    called once on the observed history and then fed its own predictions for\n",
    "    pred_len - 1 further steps.\n",
    "\n",
    "    Args:\n",
    "        model: network called as model(inputs) and model(inputs, states),\n",
    "            returning (positions, velocities, states)\n",
    "        val_dataset: iterable of batches as produced by read_pkl_data\n",
    "        device: torch device the input tensors are created on\n",
    "        train_window, batch_size, use_normalize_input, normalize_scale:\n",
    "            unused; kept for call compatibility\n",
    "        pred_len: number of predicted steps per scene (30 by default)\n",
    "    Returns:\n",
    "        dict mapping each integer Argoverse sequence id (the csv file stem) to an\n",
    "        array holding that scene's predicted agent trajectory repeated 6 times\n",
    "    \"\"\"\n",
    "    print('evaluating.. ', end='', flush=True)\n",
    "\n",
    "    prediction_gt = {}\n",
    "    tensor_keys = ['pos0', 'vel0', 'pos_2s', 'vel_2s', 'lane', 'lane_norm']\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for count, sample in enumerate(val_dataset):\n",
    "            if count % 10 == 0:\n",
    "                print('{}'.format(count + 1), end=' ', flush=True)\n",
    "\n",
    "            data = {k: torch.tensor(np.stack(sample[k])[..., :2], dtype=torch.float32, device=device)\n",
    "                    for k in tensor_keys}\n",
    "            for k in ['track_id0', 'city', 'agent_id', 'scene_idx']:\n",
    "                data[k] = np.stack(sample[k])\n",
    "            # car_mask keeps its stored shape; lane_mask gets a trailing singleton axis\n",
    "            data['car_mask'] = torch.tensor(np.stack(sample['car_mask']), dtype=torch.float32, device=device)\n",
    "            data['lane_mask'] = torch.tensor(np.stack(sample['lane_mask']), dtype=torch.float32, device=device).unsqueeze(-1)\n",
    "\n",
    "            scenes = data['scene_idx'].tolist()\n",
    "            # assumes one agent id per batch row; the new axis lets it broadcast against track_id0\n",
    "            agent_id = data['agent_id'][:, np.newaxis]\n",
    "            data['accel'] = torch.zeros(1, 1, 2).to(device)\n",
    "\n",
    "            inputs = ([\n",
    "                data['pos_2s'], data['vel_2s'],\n",
    "                data['pos0'], data['vel0'],\n",
    "                data['accel'], None,\n",
    "                data['lane'], data['lane_norm'],\n",
    "                data['car_mask'], data['lane_mask']\n",
    "            ])\n",
    "            pr_pos1, pr_vel1, states = model(inputs)\n",
    "            pred = [get_agent(pr_pos1, data['track_id0'], agent_id, device).unsqueeze(1).detach().cpu().numpy()]\n",
    "\n",
    "            pos0, vel0 = data['pos0'], data['vel0']\n",
    "            for _ in range(pred_len - 1):\n",
    "                # previous state goes into the (now one-step) history slots, the latest\n",
    "                # prediction into the current-state slots\n",
    "                inputs = (torch.unsqueeze(pos0, 2), torch.unsqueeze(vel0, 2), pr_pos1, pr_vel1,\n",
    "                          data['accel'], None,\n",
    "                          data['lane'], data['lane_norm'], data['car_mask'], data['lane_mask'])\n",
    "                pos0, vel0 = pr_pos1, pr_vel1\n",
    "                pr_pos1, pr_vel1, states = model(inputs, states)\n",
    "                pred.append(get_agent(pr_pos1, data['track_id0'], agent_id, device).unsqueeze(1).detach().cpu().numpy())\n",
    "                clean_cache(device)\n",
    "\n",
    "            predict_result = np.concatenate(pred, axis=1)\n",
    "            for idx, scene_id in enumerate(scenes):\n",
    "                # generate_forecasting_h5 is typed Dict[int, ...] and stores the key as the\n",
    "                # sequence id, so use the integer file stem rather than a string\n",
    "                seq_id = int(afl[scene_id].current_seq.stem)\n",
    "                prediction_gt[seq_id] = np.array([predict_result[idx]] * 6)\n",
    "\n",
    "    return prediction_gt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T11:42:53.715914Z",
     "start_time": "2020-10-01T07:54:13.667Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# move the equivariant (rho1 weights) model to the evaluation device, inference mode\n",
    "model_equi = model_equi.to(device).eval()\n",
    "equi_result = equi_evaluate(model_equi, dataset, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T11:42:53.716405Z",
     "start_time": "2020-10-01T07:54:29.766Z"
    }
   },
   "outputs": [],
   "source": [
    "# the reg model is evaluated with the same 2-channel pipeline as model_equi\n",
    "model_reg = model_reg.to(device).eval()\n",
    "reg_result = equi_evaluate(model_reg, dataset, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T11:42:53.716999Z",
     "start_time": "2020-10-01T07:54:34.217Z"
    }
   },
   "outputs": [],
   "source": [
    "# export the equivariant (rho1) model's forecasts as 'rho1_result' in output_path\n",
    "generate_forecasting_h5(equi_result, output_path, 'rho1_result')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-10-01T11:42:53.717446Z",
     "start_time": "2020-10-01T07:54:34.681Z"
    }
   },
   "outputs": [],
   "source": [
    "# export the reg model's forecasts as 'reg_result' in output_path\n",
    "generate_forecasting_h5(reg_result, output_path, 'reg_result')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
