{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T22:37:04.877045Z",
     "start_time": "2020-08-30T22:37:04.867494Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "sys.path.append('../scripts')\n",
    "from datasets.argoverse_pickle_loader import read_pkl_data\n",
    "from datasets.helper import get_lane_direction\n",
    "from collections import namedtuple\n",
    "from glob import glob\n",
    "import time\n",
    "import gc\n",
    "import pickle\n",
    "from utils.deeplearningutilities.tf import Trainer, MyCheckpointManager\n",
    "from argoverse.map_representation.map_api import ArgoverseMap\n",
    "from train_utils import *\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "from EquiCtsConv import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T22:37:05.258721Z",
     "start_time": "2020-08-30T22:37:05.251084Z"
    }
   },
   "outputs": [],
   "source": [
    "# Dataset locations. os.path.join does not expand '~', so expand it explicitly\n",
    "# (expanduser is a no-op if the loader also expands it).\n",
    "dataset_path = os.path.expanduser('~/particle/argoverse/argoverse_forecasting/')\n",
    "lane_path = os.path.expanduser('~/particle/TrafficFluids/datasets/')\n",
    "\n",
    "val_path = os.path.join(dataset_path, 'val', 'clean_data')\n",
    "train_path = os.path.join(dataset_path, 'train', 'clean_data')\n",
    "\n",
    "TrainParams = namedtuple('TrainParams', ['epochs', 'batches_per_epoch', 'base_lr', 'batch_size'])\n",
    "train_params = TrainParams(epochs=50, batches_per_epoch=300, base_lr=0.001, batch_size=16)\n",
    "\n",
    "# Batch keys pos0..pos{train_window} and vel0..vel{train_window} are converted below.\n",
    "train_window = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T22:37:05.602252Z",
     "start_time": "2020-08-30T22:37:05.598223Z"
    }
   },
   "outputs": [],
   "source": [
    "# Forced to CPU by default; set FORCE_CPU = False to run on CUDA when it is available.\n",
    "FORCE_CPU = True\n",
    "use_cuda = torch.cuda.is_available() and not FORCE_CPU\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T22:37:07.014067Z",
     "start_time": "2020-08-30T22:37:06.879025Z"
    }
   },
   "outputs": [],
   "source": [
    "def create_model(radius_scale=40, num_theta=32):\n",
    "    \"\"\"Return a fresh ParticlesNetwork instance for training and evaluation.\n",
    "\n",
    "    Args:\n",
    "        radius_scale: forwarded to ParticlesNetwork (default 40).\n",
    "        num_theta: forwarded to ParticlesNetwork (default 32).\n",
    "    \"\"\"\n",
    "    from models.equivariant_model import ParticlesNetwork\n",
    "    return ParticlesNetwork(radius_scale=radius_scale, num_theta=num_theta)\n",
    "\n",
    "\n",
    "model = create_model().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T22:38:44.122741Z",
     "start_time": "2020-08-30T22:38:37.355236Z"
    }
   },
   "outputs": [],
   "source": [
    "# Map API handle (not referenced by any later cell in this notebook).\n",
    "am = ArgoverseMap()\n",
    "\n",
    "# Validation split: fixed order, single pass (not referenced by any later cell).\n",
    "val_dataset = read_pkl_data(val_path, batch_size=32, repeat=False, shuffle=False)\n",
    "\n",
    "# Training split: shuffled, repeating, batch_size=1; the next cell draws one batch.\n",
    "dataset = read_pkl_data(train_path, batch_size=1, shuffle=True, repeat=True)\n",
    "data_iter = iter(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:05:38.261761Z",
     "start_time": "2020-08-30T05:05:37.805134Z"
    }
   },
   "outputs": [],
   "source": [
    "batch = next(data_iter)\n",
    "batch_size = len(batch['pos0'])\n",
    "\n",
    "# One zero lane-mask entry per element of the batch.\n",
    "batch['lane_mask'] = [np.array([0])] * batch_size\n",
    "\n",
    "# Keys converted to float32 tensors, keeping only the first two components of the last axis.\n",
    "convert_keys = ([f'pos{t}' for t in range(train_window + 1)]\n",
    "                + [f'vel{t}' for t in range(train_window + 1)]\n",
    "                + ['pos_2s', 'vel_2s', 'lane', 'lane_norm'])\n",
    "batch_tensor = {\n",
    "    key: torch.tensor(np.stack(batch[key])[..., :2], dtype=torch.float32, device=device)\n",
    "    for key in convert_keys\n",
    "}\n",
    "\n",
    "# Masks: car_mask keeps its stacked shape; lane_mask gets a trailing singleton axis.\n",
    "batch_tensor['car_mask'] = torch.tensor(np.stack(batch['car_mask']), dtype=torch.float32, device=device)\n",
    "batch_tensor['lane_mask'] = torch.tensor(np.stack(batch['lane_mask']), dtype=torch.float32, device=device).unsqueeze(-1)\n",
    "\n",
    "# Track ids and city are copied over unconverted.\n",
    "for key in [f'track_id{t}' for t in range(30)] + ['city']:\n",
    "    batch_tensor[key] = batch[key]\n",
    "\n",
    "# Zero acceleration of shape (batch_size, 1, 2).\n",
    "accel = torch.zeros(batch_size, 1, 2, device=device)\n",
    "batch_tensor['accel'] = accel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:05:38.707761Z",
     "start_time": "2020-08-30T05:05:38.703691Z"
    }
   },
   "outputs": [],
   "source": [
    "# Positional model inputs, in the order unpacked by the step-by-step cell; slot 5 is left empty.\n",
    "inputs = [\n",
    "    batch_tensor['pos_2s'],\n",
    "    batch_tensor['vel_2s'],\n",
    "    batch_tensor['pos0'],\n",
    "    batch_tensor['vel0'],\n",
    "    batch_tensor['accel'],\n",
    "    None,\n",
    "    batch_tensor['lane'],\n",
    "    batch_tensor['lane_norm'],\n",
    "    batch_tensor['car_mask'],\n",
    "    batch_tensor['lane_mask'],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:05:39.084135Z",
     "start_time": "2020-08-30T05:05:39.077970Z"
    }
   },
   "outputs": [],
   "source": [
    "def rotate_field(theta, field):\n",
    "    \"\"\"Rotate the 2-vectors on the last axis of `field` by angle `theta`.\n",
    "\n",
    "    The rotation matrix comes from EquiCtsConv2d.RotMat(theta) and is applied\n",
    "    to every vector along the last axis. It is moved to field's device and cast\n",
    "    to field's dtype so float64 fields (or a float64 theta) do not make the\n",
    "    einsum fail on mixed dtypes.\n",
    "    \"\"\"\n",
    "    rot_mat = EquiCtsConv2d.RotMat(torch.as_tensor(theta))\n",
    "    rot_mat = rot_mat.to(device=field.device, dtype=field.dtype)\n",
    "    return torch.einsum('ij,...j->...i', rot_mat, field)\n",
    "\n",
    "\n",
    "def r90(field):\n",
    "    \"\"\"Rotate `field` by np.pi / 2 using rotate_field.\"\"\"\n",
    "    return rotate_field(np.pi / 2, field)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:05:39.603430Z",
     "start_time": "2020-08-30T05:05:39.595049Z"
    }
   },
   "outputs": [],
   "source": [
    "# Same layout as `inputs`; every tensor except the two masks is rotated with r90.\n",
    "rotated_inputs = [\n",
    "    r90(batch_tensor['pos_2s']),\n",
    "    r90(batch_tensor['vel_2s']),\n",
    "    r90(batch_tensor['pos0']),\n",
    "    r90(batch_tensor['vel0']),\n",
    "    r90(batch_tensor['accel']),\n",
    "    None,\n",
    "    r90(batch_tensor['lane']),\n",
    "    r90(batch_tensor['lane_norm']),\n",
    "    batch_tensor['car_mask'],\n",
    "    batch_tensor['lane_mask'],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:05:40.225943Z",
     "start_time": "2020-08-30T05:05:40.213935Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.,  0.]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: rotating by -np.pi / 2 and then applying r90 must return the input.\n",
    "probe = torch.tensor([[-1.0, 0.0]])\n",
    "roundtrip = r90(rotate_field(-np.pi / 2, probe))\n",
    "assert torch.allclose(roundtrip, probe, atol=1e-6), roundtrip\n",
    "roundtrip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "heading_collapsed": true
   },
   "source": [
    "### Equivariance check: new activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-27T05:49:46.069883Z",
     "start_time": "2020-08-27T05:49:42.566469Z"
    },
    "hidden": true
   },
   "outputs": [],
   "source": [
    "# Run the same model on the original and the 90-degree-rotated inputs, without gradients.\n",
    "# NOTE(review): identical to the run cells of the other sections; results presumably differ\n",
    "# because the model code was changed between runs.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    pr_pos1, pr_vel1, (states, _) = model(inputs)\n",
    "    rpr_pos1, rpr_vel1, (rstates, _) = model(rotated_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-27T05:49:48.553746Z",
     "start_time": "2020-08-27T05:49:48.546034Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.1920)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the predicted positions: rotate back by -90 degrees and compare.\n",
    "torch.norm(rotate_field(-np.pi / 2, rpr_pos1) - pr_pos1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-27T05:49:52.157068Z",
     "start_time": "2020-08-27T05:49:52.152548Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(1.9200), tensor(94.0037))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the predicted velocities, shown next to the velocity norm for scale.\n",
    "torch.norm(rotate_field(-np.pi / 2, rpr_vel1) - pr_vel1), torch.norm(pr_vel1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-27T05:49:59.382109Z",
     "start_time": "2020-08-27T05:49:59.375032Z"
    },
    "hidden": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the `states` returned by the model.\n",
    "torch.norm(rotate_field(-np.pi / 2, rstates) - states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Equivariance check: `F.relu` activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:06:43.743508Z",
     "start_time": "2020-08-30T05:06:40.452263Z"
    }
   },
   "outputs": [],
   "source": [
    "# Run the same model on the original and the 90-degree-rotated inputs, without gradients.\n",
    "# NOTE(review): identical to the run cells of the other sections; results presumably differ\n",
    "# because the model code was changed between runs.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    pr_pos1, pr_vel1, (states, _) = model(inputs)\n",
    "    rpr_pos1, rpr_vel1, (rstates, _) = model(rotated_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:06:45.102474Z",
     "start_time": "2020-08-30T05:06:45.094985Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.1919)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the predicted positions: rotate back by -90 degrees and compare.\n",
    "torch.norm(rotate_field(-np.pi / 2, rpr_pos1) - pr_pos1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:06:47.730940Z",
     "start_time": "2020-08-30T05:06:47.721567Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(1.9189), tensor(84.5314))"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the predicted velocities, shown next to the velocity norm for scale.\n",
    "torch.norm(rotate_field(-np.pi / 2, rpr_vel1) - pr_vel1), torch.norm(pr_vel1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:06:49.997844Z",
     "start_time": "2020-08-30T05:06:49.989458Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the `states` returned by the model.\n",
    "torch.norm(rotate_field(-np.pi / 2, rstates) - states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Equivariance check: no activation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-26T22:59:46.653713Z",
     "start_time": "2020-08-26T22:59:46.546692Z"
    }
   },
   "outputs": [],
   "source": [
    "# Run the same model on the original and the 90-degree-rotated inputs, without gradients.\n",
    "# NOTE(review): identical to the run cells of the other sections; results presumably differ\n",
    "# because the model code was changed between runs.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    pr_pos1, pr_vel1, (states, _) = model(inputs)\n",
    "    rpr_pos1, rpr_vel1, (rstates, _) = model(rotated_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-26T22:59:47.419223Z",
     "start_time": "2020-08-26T22:59:47.409958Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.1829, device='cuda:0')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the predicted positions: rotate back by -90 degrees and compare.\n",
    "torch.norm(rotate_field(-np.pi / 2, rpr_pos1) - pr_pos1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-26T22:59:48.400197Z",
     "start_time": "2020-08-26T22:59:48.388887Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(1.8293, device='cuda:0'), tensor(82.0080, device='cuda:0'))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the predicted velocities, shown next to the velocity norm for scale.\n",
    "torch.norm(rotate_field(-np.pi / 2, rpr_vel1) - pr_vel1), torch.norm(pr_vel1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-26T22:59:49.103811Z",
     "start_time": "2020-08-26T22:59:49.095624Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0., device='cuda:0')"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Equivariance error of the `states` returned by the model.\n",
    "torch.norm(rotate_field(-np.pi / 2, rstates) - states)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
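    "### Step-by-step equivariance comparison\n",
    "\n",
    "Re-runs the network's stages by hand on the original and the 90-degree-rotated inputs, printing a step number and the rotation error `compare_equi(rotated, original)` after each stage."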
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:05:49.241554Z",
     "start_time": "2020-08-30T05:05:49.236411Z"
    }
   },
   "outputs": [],
   "source": [
    "def compare_equi(rfeat, feat, theta=-np.pi / 2):\n",
    "    \"\"\"Return torch.norm(rotate_field(theta, rfeat) - feat).\n",
    "\n",
    "    With the default theta = -np.pi / 2 this undoes the r90 applied to the rotated\n",
    "    branch, so the result is the equivariance error (0 for an exactly equivariant map).\n",
    "    Pass another theta to check inputs rotated by a different angle.\n",
    "    \"\"\"\n",
    "    return torch.norm(rotate_field(theta, rfeat) - feat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-30T05:13:40.803789Z",
     "start_time": "2020-08-30T05:13:37.281121Z"
    },
    "code_folding": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 tensor(0.)\n",
      "2 tensor(0.)\n",
      "3 tensor(0.)\n",
      "4 tensor(0.0001)\n",
      "5 tensor(1.1335e-05)\n",
      "6 tensor(0.)\n",
      "7 tensor(0.0001)\n",
      "8 tensor(0.0001)\n",
      "9 tensor(7.7481e-05)\n",
      "10 tensor(6.8446e-05)\n",
      "11 tensor(9.6663e-05)\n",
      "12 tensor(7.4671e-05)\n",
      "13 tensor(5.8527e-05)\n",
      "14 tensor(0.0001)\n",
      "15 tensor(0.0001)\n",
      "16 tensor(9.1270e-05)\n",
      "17 tensor(0.0002)\n",
      "18 tensor(3.7308e-05)\n",
      "19 tensor(2.0197e-05)\n",
      "20 tensor(4.5504e-05)\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    # Calls the model's sub-layers by hand on the original and the rotated inputs and prints\n",
    "    # a running step number plus compare_equi(rotated, original) after each stage.\n",
    "    # NOTE(review): this must be kept in sync with the model's forward pass by hand.\n",
    "    i = 0\n",
    "    _, other_feats, p, v, __, ___, box, box_feats, fluid_mask, box_mask = inputs\n",
    "    _, rother_feats, rp, rv, __, ___, rbox, rbox_feats, rfluid_mask, rbox_mask = rotated_inputs\n",
    "    # Precondition: p and v were updated with acceleration.\n",
    "\n",
    "    i+=1;print(i, compare_equi(rv, v))\n",
    "    fluid_feats = [v.unsqueeze(-2)]\n",
    "    rfluid_feats = [rv.unsqueeze(-2)]\n",
    "    i+=1;print(i, compare_equi(rv.unsqueeze(-2), v.unsqueeze(-2)))\n",
    "    if other_feats is not None:\n",
    "        fluid_feats.append(other_feats)\n",
    "        rfluid_feats.append(rother_feats)\n",
    "    fluid_feats = torch.cat(fluid_feats, -2)\n",
    "    rfluid_feats = torch.cat(rfluid_feats, -2)\n",
    "    i+=1;print(i, compare_equi(rfluid_feats, fluid_feats))\n",
    "\n",
    "    # compute the correction by accumulating the output through the network layers\n",
    "    output_conv_fluid = model.conv_fluid(p, p, fluid_feats, fluid_mask)\n",
    "    routput_conv_fluid = model.conv_fluid(rp, rp, rfluid_feats, rfluid_mask)\n",
    "    i+=1;print(i, compare_equi(routput_conv_fluid, output_conv_fluid))\n",
    "\n",
    "    output_dense_fluid = model.dense_fluid(fluid_feats)\n",
    "    routput_dense_fluid = model.dense_fluid(rfluid_feats)\n",
    "    i+=1;print(i, compare_equi(routput_dense_fluid, output_dense_fluid))\n",
    "\n",
    "    output_conv_obstacle = model.conv_obstacle(box, p, box_feats.unsqueeze(-2), box_mask)\n",
    "    routput_conv_obstacle = model.conv_obstacle(rbox, rp, rbox_feats.unsqueeze(-2), rbox_mask)\n",
    "    i+=1;print(i, compare_equi(routput_conv_obstacle, output_conv_obstacle))\n",
    "\n",
    "    feats = torch.cat((output_conv_obstacle, output_conv_fluid, output_dense_fluid), -2)\n",
    "    rfeats = torch.cat((routput_conv_obstacle, routput_conv_fluid, routput_dense_fluid), -2)\n",
    "    i+=1;print(i, compare_equi(rfeats, feats))\n",
    "    output = feats\n",
    "    routput = rfeats\n",
    "\n",
    "    for conv, dense in zip(model.convs, model.denses):\n",
    "        # Activation applied through the magnitude of each vector on the last axis.\n",
    "        # NOTE(review): mags is the squared norm (+1e-6), not the norm -- confirm it matches the model.\n",
    "        mags = (torch.sum(output**2,axis=-1) + 1e-6).unsqueeze(-1)\n",
    "        in_feats = output/mags * model.activation(mags - model.relu_shift)\n",
    "        rmags = (torch.sum(routput**2,axis=-1) + 1e-6).unsqueeze(-1)\n",
    "        rin_feats = routput/rmags * model.activation(rmags - model.relu_shift)\n",
    "\n",
    "        # Alternatives (presumably those used for the 'F.relu' / 'No activation' sections):\n",
    "        # in_feats = model.activation(output)\n",
    "        # rin_feats = model.activation(routput)\n",
    "\n",
    "        # in_feats = output\n",
    "        # rin_feats = routput\n",
    "\n",
    "        i+=1;print(i, compare_equi(rin_feats, in_feats))\n",
    "\n",
    "        output_conv = conv(p, p, in_feats, fluid_mask)\n",
    "        routput_conv = conv(rp, rp, rin_feats, rfluid_mask)\n",
    "        i+=1;print(i, compare_equi(routput_conv, output_conv))\n",
    "\n",
    "        output_dense = dense(in_feats)\n",
    "        routput_dense = dense(rin_feats)\n",
    "        i+=1;print(i, compare_equi(routput_dense, output_dense))\n",
    "\n",
    "        # Residual connection only when the dense output has the same number of\n",
    "        # channels (axis -2) as the running output.\n",
    "        if output_dense.shape[-2] == output.shape[-2]:\n",
    "            output = output_conv + output_dense + output\n",
    "            routput = routput_conv + routput_dense + routput\n",
    "        else:\n",
    "            output = output_conv + output_dense\n",
    "            routput = routput_conv + routput_dense\n",
    "\n",
    "    i+=1;print(i, compare_equi(routput, output))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-27T06:39:36.809105Z",
     "start_time": "2020-08-27T06:39:36.801003Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.2366,  0.2920]), tensor([-0.2366,  0.2920]))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): leftover scratch cell. Its saved output (a tuple of two tensors) does not look like\n",
    "# the result of constructing EquiCtsConv2d with no arguments -- remove or replace this cell.\n",
    "EquiCtsConv2d()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
