{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:08.453405Z",
     "start_time": "2020-08-03T19:45:06.920783Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import argoverse\n",
    "\n",
    "import sys\n",
    "import os\n",
    "sys.path.append(\"../datasets/\")\n",
    "from argoverse_pickle_loader import read_pkl_data\n",
    "\n",
    "from CtsConv import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:09.050506Z",
     "start_time": "2020-08-03T19:45:09.021920Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:09.691680Z",
     "start_time": "2020-08-03T19:45:09.662692Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "class ParticlesNetwork(nn.Module):\n",
    "    def __init__(self, \n",
    "                 kernel_sizes = [4, 4, 4],\n",
    "                 radius_scale = 40,\n",
    "                 coordinate_mapping = 'ball_to_cube',\n",
    "                 interpolation = 'linear',\n",
    "                 use_window = True,\n",
    "                 particle_radius = 0.5,\n",
    "                 timestep = 0.1,\n",
    "                 encoder_func = 'lstm',\n",
    "                 encoder_hidden_size = 32\n",
    "                 ):\n",
    "        super(ParticlesNetwork, self).__init__()\n",
    "        \n",
    "        # init parameters\n",
    "        \n",
    "        self.kernel_sizes = kernel_sizes\n",
    "        self.radius_scale = radius_scale\n",
    "        self.coordinate_mapping = coordinate_mapping\n",
    "        self.interpolation = interpolation\n",
    "        self.use_window = use_window\n",
    "        self.particle_radius = particle_radius\n",
    "        self.timestep = timestep\n",
    "        self.encoder_func = encoder_func\n",
    "        self.layer_channels = [32, 64, 64, 64, 3]\n",
    "        self.filter_extent = np.float32(self.radius_scale * 6 *\n",
    "                                        self.particle_radius)\n",
    "        \n",
    "        self.encoder_hidden_size = encoder_hidden_size\n",
    "        \n",
    "        self.in_channel = 1 + 3 + self.encoder_hidden_size\n",
    "        \n",
    "        # create continuous convolution and fully-connected layers\n",
    "        \n",
    "        convs = []\n",
    "        denses = []\n",
    "        \n",
    "        self.conv_fluid = CtsConv(in_channels = self.in_channel, \n",
    "                                  out_channels = self.layer_channels[0],\n",
    "                                  kernel_sizes = self.kernel_sizes,\n",
    "                                  radius = self.particle_radius)\n",
    "        \n",
    "        self.conv_obstacle = CtsConv(in_channels = 3, \n",
    "                                     out_channels = self.layer_channels[0],\n",
    "                                     kernel_sizes = self.kernel_sizes,\n",
    "                                     radius = self.particle_radius)\n",
    "        \n",
    "        self.dense_fluid = nn.Linear(self.in_channel, self.layer_channels[0])\n",
    "        \n",
    "        in_ch = 3 * self.layer_channels[0] # concat conv_obstacle, conv_fluid, dense_fluid\n",
    "        for i in range(1, len(self.layer_channels)):\n",
    "            out_ch = self.layer_channels[i]\n",
    "            dense = nn.Linear(in_ch, out_ch)\n",
    "            denses.append(dense)\n",
    "            conv = CtsConv(in_channels = in_ch, \n",
    "                           out_channels = out_ch,\n",
    "                           kernel_sizes = self.kernel_sizes,\n",
    "                           radius = self.particle_radius)\n",
    "            convs.append(conv)\n",
    "            in_ch = self.layer_channels[i]\n",
    "        \n",
    "        self.convs = nn.ModuleList(convs)\n",
    "        self.denses = nn.ModuleList(denses)\n",
    "        \n",
    "        if self.encoder_func == 'lstm':\n",
    "            self.encoder = nn.LSTMCell(3, self.encoder_hidden_size)\n",
    "        else:\n",
    "            raise Exception('Unknown encoder type ' + encoder_func)\n",
    "            \n",
    "    def update_pos_vel(self, p0, v0, a):\n",
    "        \"\"\"Apply acceleration and integrate position and velocity.\n",
    "        Assume the particle has constant acceleration during timestep.\n",
    "        Return particle's position and velocity after 1 unit timestep.\"\"\"\n",
    "        \n",
    "        dt = self.timestep\n",
    "        v1 = v0 + dt * a\n",
    "        p1 = p0 + dt * (v0 + v1) / 2\n",
    "        return p1, v1\n",
    "\n",
    "    def apply_correction(self, p0, p1, correction):\n",
    "        \"\"\"Apply the position correction\n",
    "        p0, p1: the position of the particle before and after basic integration. \"\"\"\n",
    "        dt = self.timestep\n",
    "        p_corrected = p1 + correction\n",
    "        v_corrected = (p_corrected - p0) / dt\n",
    "        return p_corrected, v_corrected\n",
    "    \n",
    "    def dense_forward(self, in_feats, dense_layer):\n",
    "        flatten_in_feats = in_feats.reshape(in_feats.shape[0] * in_feats.shape[1], in_feats.shape[2])\n",
    "        flatten_output = dense_layer(flatten_in_feats)\n",
    "        return flatten_output.reshape(in_feats.shape[0], in_feats.shape[1], -1)\n",
    "\n",
    "    def compute_correction(self, p, v, other_feats, box, box_feats, fluid_mask, box_mask):\n",
    "        \"\"\"Precondition: p and v were updated with accerlation\"\"\"\n",
    "\n",
    "        # compute the extent of the filters (the diameter) and the fluid features\n",
    "        filter_extent = torch.tensor(self.filter_extent)\n",
    "        fluid_feats = [torch.ones_like(p[:,:, 0:1]), v]\n",
    "        if not other_feats is None:\n",
    "            fluid_feats.append(other_feats)\n",
    "        fluid_feats = torch.cat(fluid_feats, -1)\n",
    "\n",
    "        # compute the correction by accumulating the output through the network layers\n",
    "        self.output_conv_fluid = self.conv_fluid(p, p, fluid_feats, fluid_mask)\n",
    "        self.output_dense_fluid = self.dense_forward(fluid_feats, self.dense_fluid)\n",
    "        self.output_conv_obstacle = self.conv_obstacle(box, p, box_feats, box_mask)\n",
    "        \n",
    "        feats = torch.cat((self.output_conv_obstacle, self.output_conv_fluid, self.output_dense_fluid), -1)\n",
    "        self.outputs = [feats]\n",
    "        \n",
    "        for conv, dense in zip(self.convs, self.denses):\n",
    "            # pass input features to conv and fully-connected layers\n",
    "            in_feats = F.relu(self.outputs[-1])\n",
    "            output_conv = conv(p, p, in_feats, fluid_mask)\n",
    "            output_dense = self.dense_forward(in_feats, dense)\n",
    "            \n",
    "            # if last dim size of output from cur dense layer is same as last dim size of output\n",
    "            # current output should be based off on previous output\n",
    "            if output_dense.shape[-1] == self.outputs[-1].shape[-1]:\n",
    "                output = output_conv + output_dense + self.outputs[-1]\n",
    "            else:\n",
    "                output = output_conv + output_dense\n",
    "            self.outputs.append(output)\n",
    "\n",
    "        # compute the number of fluid particle neighbors.\n",
    "        # this info is used in the loss function during training.\n",
    "        # TODO: test this block of code\n",
    "        self.num_fluid_neighbors = torch.sum(fluid_mask, dim = -1) - 1\n",
    "    \n",
    "        self.last_features = self.outputs[-2]\n",
    "\n",
    "        # scale to better match the scale of the output distribution\n",
    "        self.pos_correction = (1.0 / 128) * self.outputs[-1]\n",
    "        return self.pos_correction\n",
    "    \n",
    "    def encode_forward(self, v0_enc):\n",
    "        v0_enc_reshape = v0_enc.reshape(-1, *v0_enc.shape[-2:])\n",
    "        hx = torch.zeros(v0_enc_reshape.shape[0], self.encoder_hidden_size, device=device)\n",
    "        cx = torch.zeros(v0_enc_reshape.shape[0], self.encoder_hidden_size, device=device)\n",
    "\n",
    "        for i in range(v0_enc.shape[2]):\n",
    "            hx, cx = self.encoder(v0_enc_reshape[:,i,:], (hx, cx))\n",
    "            \n",
    "        hx, cx = hx.reshape(*v0_enc.shape[:2], -1), cx.reshape(*v0_enc.shape[:2], -1)\n",
    "        \n",
    "        return hx, cx\n",
    "    \n",
    "    def forward(self, inputs):\n",
    "        \"\"\" inputs: 8 elems tuple\n",
    "        p0_enc, v0_enc, p0, v0, a, feats, box, box_feats\n",
    "        Computes 1 simulation timestep\"\"\"\n",
    "        p0_enc, v0_enc, p0, v0, a, feats, box, box_feats, fluid_mask, box_mask = inputs\n",
    "        \n",
    "        # TODO: check the following calls\n",
    "        if self.encoder_func == 'lstm':\n",
    "            state_h, state_c = self.encode_forward(v0_enc)\n",
    "            \n",
    "        if feats is None:\n",
    "            feats = state_h\n",
    "        else:\n",
    "            feats = torch.cat((feats, state_h, state_c), -1)\n",
    "\n",
    "        p1, v1 = self.update_pos_vel(p0, v0, a)\n",
    "        pos_correction = self.compute_correction(p1, v1, feats, box, box_feats, fluid_mask, box_mask)\n",
    "        p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction)\n",
    "\n",
    "        return p_corrected, v_corrected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:12.518207Z",
     "start_time": "2020-08-03T19:45:10.373595Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the network with its default hyper-parameters and move it to `device`.\n",
    "model = ParticlesNetwork().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:13.411211Z",
     "start_time": "2020-08-03T19:45:13.050643Z"
    }
   },
   "outputs": [],
   "source": [
    "#On Roselab5\n",
    "# expanduser is needed because os.path.join does not expand '~'; without it the\n",
    "# paths below only work if read_pkl_data happens to expand '~' itself.\n",
    "dataset_path = os.path.expanduser('~/particle/argoverse/argoverse_forecasting/')\n",
    "#On Maplewalnut\n",
    "# (no path recorded for this machine)\n",
    "\n",
    "BATCH_SIZE = 16\n",
    "val_path = os.path.join(dataset_path, 'val', 'clean_data')\n",
    "train_path = os.path.join(dataset_path, 'train', 'clean_data')\n",
    "dataset = read_pkl_data(train_path, batch_size=BATCH_SIZE)\n",
    "data_loader = iter(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:14.450412Z",
     "start_time": "2020-08-03T19:45:14.070566Z"
    }
   },
   "outputs": [],
   "source": [
    "data = next(data_loader)\n",
    "# Placeholder lane mask, one entry per sample. The sample count is taken from\n",
    "# the batch itself rather than hard-coded to 16, so this stays correct if the\n",
    "# loader's batch_size changes or a batch comes back smaller.\n",
    "data['lane_mask'] = [np.array([0]) for _ in range(len(data['car_mask']))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:14.921709Z",
     "start_time": "2020-08-03T19:45:14.904843Z"
    }
   },
   "outputs": [],
   "source": [
    "convert_keys = (['pos' + str(i) for i in range(30)] + \n",
    "                ['vel' + str(i) for i in range(30)] + \n",
    "                ['pos_2s', 'vel_2s', 'lane', 'lane_norm', 'car_mask', 'lane_mask'])\n",
    "\n",
    "# stack each per-sample list and turn it into a float32 tensor on `device`\n",
    "batch_tensor = {\n",
    "    key: torch.tensor(np.stack(data[key]), dtype=torch.float32, device=device)\n",
    "    for key in convert_keys\n",
    "}\n",
    "# track ids and city are copied through unconverted\n",
    "passthrough_keys = ['track_id' + str(i) for i in range(30)] + ['city']\n",
    "batch_tensor.update((key, data[key]) for key in passthrough_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:16.213795Z",
     "start_time": "2020-08-03T19:45:16.198501Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 60]), (16, 1, 60, 1))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "car_mask_2d = batch_tensor['car_mask'].squeeze(-1)  # drop the last axis if it has size 1\n",
    "car_mask_2d.shape, (16, 1, 60, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:16.730898Z",
     "start_time": "2020-08-03T19:45:16.724892Z"
    }
   },
   "outputs": [],
   "source": [
    "# zero acceleration, one (1, 3) vector per batch element\n",
    "zero_accel = torch.zeros(batch_tensor['vel0'].shape[0], 1, 3, device=device)\n",
    "inputs = [\n",
    "    batch_tensor['pos_2s'], batch_tensor['vel_2s'],   # p0_enc, v0_enc\n",
    "    batch_tensor['pos0'], batch_tensor['vel0'],       # p0, v0\n",
    "    zero_accel, None,                                 # a, feats\n",
    "    batch_tensor['lane'], batch_tensor['lane_norm'],  # box, box_feats\n",
    "    batch_tensor['car_mask'].squeeze(-1),             # fluid_mask\n",
    "    batch_tensor['lane_mask'],                        # box_mask\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:21.595757Z",
     "start_time": "2020-08-03T19:45:21.582839Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pos_2s: torch.Size([16, 60, 19, 3])\n",
      "vel_2s: torch.Size([16, 60, 19, 3])\n",
      "pos0: torch.Size([16, 60, 3])\n",
      "vel0: torch.Size([16, 60, 3])\n",
      "accel (0s): torch.Size([16, 1, 3])\n",
      "feats (None): None\n",
      "lane: torch.Size([16, 1, 3])\n",
      "lane_norm: torch.Size([16, 1, 3])\n",
      "car_mask: torch.Size([16, 60])\n",
      "lane_mask: torch.Size([16, 1])\n"
     ]
    }
   ],
   "source": [
    "def print_inputs_shape(inputs):\n",
    "    \"\"\"Print one labelled line per model input: the shape of each tensor,\n",
    "    except ``feats`` (index 5), whose value is printed as-is.\"\"\"\n",
    "    labels = ('pos_2s:', 'vel_2s:', 'pos0:', 'vel0:', 'accel (0s):',\n",
    "              'feats (None):', 'lane:', 'lane_norm:', 'car_mask:', 'lane_mask:')\n",
    "    for idx, label in enumerate(labels):\n",
    "        value = inputs[idx]\n",
    "        print(label, value if idx == 5 else value.shape)\n",
    "\n",
    "print_inputs_shape(inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-03T19:45:22.738432Z",
     "start_time": "2020-08-03T19:45:22.519917Z"
    },
    "scrolled": false
   },
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'CtsConv' object has no attribute 'normalize_attention'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-cc1c37d1c8fd>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;31m# print(kernel_on_field.shape, unsqueezed_feat.shape, attention.shape)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;31m# out = torch.einsum('bmnoi,bmni->bmo', kernel_on_field, unsqueezed_feat*attention)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    548\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 550\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    551\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    552\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-d058f4e6974f>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m    162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    163\u001b[0m         \u001b[0mp1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_pos_vel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m         \u001b[0mpos_correction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_correction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfluid_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    165\u001b[0m         \u001b[0mp_corrected\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv_corrected\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_correction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_correction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-3-d058f4e6974f>\u001b[0m in \u001b[0;36mcompute_correction\u001b[0;34m(self, p, v, other_feats, box, box_feats, fluid_mask, box_mask)\u001b[0m\n\u001b[1;32m    102\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    103\u001b[0m         \u001b[0;31m# compute the correction by accumulating the output through the network layers\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 104\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_conv_fluid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_fluid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfluid_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfluid_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    105\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_dense_fluid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdense_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfluid_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdense_fluid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_conv_obstacle\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_obstacle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox_feats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbox_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    548\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    549\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 550\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    551\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    552\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/particle/TrafficFluidsPt/TrafficFluids/models/CtsConv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, field, center, field_feat, field_mask, ctr_feat)\u001b[0m\n\u001b[1;32m    150\u001b[0m     ):\n\u001b[1;32m    151\u001b[0m         out = self.ContinuousConv(\n\u001b[0;32m--> 152\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfield\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcenter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfield_feat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfield_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mctr_feat\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    153\u001b[0m         )\n\u001b[1;32m    154\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/particle/TrafficFluidsPt/TrafficFluids/models/CtsConv.py\u001b[0m in \u001b[0;36mContinuousConv\u001b[0;34m(self, kernel, field, center, field_feat, field_mask, ctr_feat)\u001b[0m\n\u001b[1;32m    129\u001b[0m         \u001b[0;31m# attention: [batch, num_m, num_n, 1]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    130\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m         \u001b[0mpsi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mattention\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormalize_attention\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    133\u001b[0m         \u001b[0mscaled_field\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLambda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrelative_field\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/.local/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    592\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    593\u001b[0m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0;32m--> 594\u001b[0;31m             type(self).__name__, name))\n\u001b[0m\u001b[1;32m    595\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    596\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'CtsConv' object has no attribute 'normalize_attention'"
     ]
    }
   ],
   "source": [
    "# CtsConv debugging notes:\n",
    "# print(kernel_on_field.shape, unsqueezed_feat.shape, attention.shape)\n",
    "# out = torch.einsum('bmnoi,bmni->bmo', kernel_on_field, unsqueezed_feat*attention)\n",
    "\n",
    "# one forward step on the batch prepared above\n",
    "prediction = model(inputs)\n",
    "prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-02T22:30:34.351191Z",
     "start_time": "2020-08-02T22:30:34.329864Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([960, 32]) torch.Size([960, 32]) torch.Size([960, 3])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 60, 32]), torch.Size([16, 60, 32]))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stand-alone check of the encoding done in ParticlesNetwork.encode_forward,\n",
    "# using a fresh (untrained) LSTMCell.\n",
    "encoder_hidden_size = 32\n",
    "v0_enc = batch_tensor['vel_2s']\n",
    "encoder = nn.LSTMCell(3, encoder_hidden_size).to(device)\n",
    "\n",
    "# fold every leading axis into the LSTM batch: (..., time, 3) -> (N, time, 3)\n",
    "v0_enc_reshape = v0_enc.reshape(-1, *v0_enc.shape[-2:])\n",
    "num_seqs = v0_enc_reshape.shape[0]\n",
    "hx, cx = (torch.zeros(num_seqs, encoder_hidden_size, device=device) for _ in range(2))\n",
    "print(hx.shape, cx.shape, v0_enc_reshape[:,0,:].shape)\n",
    "for step in range(v0_enc.shape[2]):\n",
    "    hx, cx = encoder(v0_enc_reshape[:, step], (hx, cx))\n",
    "\n",
    "lead_shape = v0_enc.shape[:2]\n",
    "hx = hx.reshape(*lead_shape, -1)\n",
    "cx = cx.reshape(*lead_shape, -1)\n",
    "hx.shape, cx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-02T22:30:34.358841Z",
     "start_time": "2020-08-02T22:30:34.353757Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda', index=0)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# device holding the obstacle convolution's kernel (should follow model.to(device))\n",
    "obstacle_kernel = model.conv_obstacle.kernel\n",
    "obstacle_kernel.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-02T22:40:00.165470Z",
     "start_time": "2020-08-02T22:39:59.705110Z"
    }
   },
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'CtsConv' object has no attribute 'device'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-14-a57409624050>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/envs/tf-gpu/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m    574\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    575\u001b[0m         raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0;32m--> 576\u001b[0;31m             type(self).__name__, name))\n\u001b[0m\u001b[1;32m    577\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    578\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__setattr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'CtsConv' object has no attribute 'device'"
     ]
    }
   ],
   "source": [
    "# device holding the first hidden-layer convolution's kernel\n",
    "first_conv_kernel = model.convs[0].kernel\n",
    "first_conv_kernel.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
