{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:01.171039Z",
     "start_time": "2020-08-16T16:40:00.816672Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from CtsConv import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:01.234606Z",
     "start_time": "2020-08-16T16:40:01.172214Z"
    }
   },
   "outputs": [],
   "source": [
    "nn.Parameter?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:01.500612Z",
     "start_time": "2020-08-16T16:40:01.485245Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "class CtsConvLSTMCellOld(nn.Module):\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels, \n",
    "        out_channels, \n",
    "        kernel_sizes, \n",
    "        radius, \n",
    "        normalize_attention=False, \n",
    "        bias=True\n",
    "    ):\n",
    "        super(CtsConvLSTMCell, self).__init__()\n",
    "        \n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_sizes = kernel_sizes\n",
    "        self.radius = radius\n",
    "        self.normalize_attention = normalize_attention\n",
    "        self.bias = bias\n",
    "        \n",
    "        self.cts_conv_f = self.__get_ctsconv_layer('cts_conv_f')\n",
    "        self.cts_conv_i = self.__get_ctsconv_layer('cts_conv_i')\n",
    "        self.cts_conv_c = self.__get_ctsconv_layer('cts_conv_c')\n",
    "        self.cts_conv_o = self.__get_ctsconv_layer('cts_conv_o')\n",
    "        \n",
    "        self.weight_ci = self.__get_weight('weight_ci')\n",
    "        self.weight_cf = self.__get_weight('weight_cf')\n",
    "        self.weight_co = self.__get_weight('weight_co')\n",
    "        \n",
    "        if self.bias:\n",
    "            self.bias_i = self.__get_weight('bias_i')\n",
    "            self.bias_f = self.__get_weight('bias_f')\n",
    "            self.bias_o = self.__get_weight('bias_o')\n",
    "            self.bias_c = self.__get_weight('bias_c')\n",
    "        \n",
    "    def forward(self, inputs, states):\n",
    "        field, center, field_feat, field_mask = inputs\n",
    "        h, c = states\n",
    "        \n",
    "        feats = torch.cat([field_feat, h], axis=-1)\n",
    "        \n",
    "        # ========= UPDATE states C\n",
    "        # i gate\n",
    "        ans_i = self.cts_conv_i(field, center, feats, field_mask) + self.weight_ci * c\n",
    "        \n",
    "        # f gate\n",
    "        ans_f = self.cts_conv_f(field, center, feats, field_mask) + self.weight_cf * c\n",
    "        \n",
    "        # g gate\n",
    "        ans_g = self.cts_conv_c(field, center, feats, field_mask)\n",
    "        \n",
    "        if self.bias:\n",
    "            ans_i = ans_i + self.bias_i\n",
    "            ans_f = ans_f + self.bias_f\n",
    "            ans_g = ans_g + self.bias_c\n",
    "        \n",
    "        # activation\n",
    "        ans_i = F.relu(ans_i)\n",
    "        ans_f = F.relu(ans_f)\n",
    "        ans_g = torch.tanh(ans_g)\n",
    "        \n",
    "        c_out = c * ans_f + ans_i * ans_g\n",
    "        \n",
    "        # ========= UPDATE states H\n",
    "        # o gate\n",
    "        ans_o = self.cts_conv_o(field, center, feats, field_mask) + self.weight_co * c_out\n",
    "        \n",
    "        if self.bias:\n",
    "            ans_o = ans_o + self.bias_o\n",
    "        ans_o = F.relu(ans_o)\n",
    "        h_out = ans_o * torch.tanh(c_out)\n",
    "        \n",
    "        return (h_out, c_out)\n",
    "        \n",
    "    def __get_ctsconv_layer(self, layer_name):\n",
    "        return CtsConv(in_channels=self.in_channels + self.out_channels, \n",
    "                       out_channels=self.out_channels, \n",
    "                       kernel_sizes=self.kernel_sizes,\n",
    "                       radius=self.radius,\n",
    "                       normalize_attention=self.normalize_attention, \n",
    "                       layer_name=layer_name)\n",
    "    \n",
    "    def __get_weight(self, weight_name):\n",
    "        nn_param = torch.rand(self.out_channels) - 0.5\n",
    "        nn_param = nn_param / (self.out_channels)\n",
    "        nn_param = nn.Parameter(nn_param)\n",
    "        self.register_parameter(weight_name, nn_param)\n",
    "        return nn_param\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:01.892571Z",
     "start_time": "2020-08-16T16:40:01.873440Z"
    }
   },
   "outputs": [],
   "source": [
    "class CtsConvLSTMCell(nn.Module):\n",
    "    \n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels, \n",
    "        out_channels, \n",
    "        kernel_sizes, \n",
    "        radius, \n",
    "        normalize_attention=False, \n",
    "        bias=True\n",
    "    ):\n",
    "        super(CtsConvLSTMCell, self).__init__()\n",
    "        \n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.kernel_sizes = kernel_sizes\n",
    "        self.radius = radius\n",
    "        self.normalize_attention = normalize_attention\n",
    "        self.bias = bias\n",
    "        \n",
    "        self.cts_conv_f = self.__get_ctsconv_layer('cts_conv_f')\n",
    "        self.cts_conv_i = self.__get_ctsconv_layer('cts_conv_i')\n",
    "        self.cts_conv_g = self.__get_ctsconv_layer('cts_conv_g')\n",
    "        self.cts_conv_o = self.__get_ctsconv_layer('cts_conv_o')\n",
    "        \n",
    "        weight_shape = (self.in_channels + self.out_channels, self.out_channels)\n",
    "        self.dense_i = nn.Linear(*weight_shape, bias=self.bias)\n",
    "        self.dense_f = nn.Linear(*weight_shape, bias=self.bias)\n",
    "        self.dense_g = nn.Linear(*weight_shape, bias=self.bias)\n",
    "        self.dense_o = nn.Linear(*weight_shape, bias=self.bias)\n",
    "        \n",
    "        if self.bias:\n",
    "            self.bias_i = self.__get_bias('bias_i')\n",
    "            self.bias_f = self.__get_bias('bias_f')\n",
    "            self.bias_o = self.__get_bias('bias_o')\n",
    "            self.bias_c = self.__get_bias('bias_c')\n",
    "        \n",
    "    def forward(self, inputs, states):\n",
    "        field, center, field_feat, field_mask = inputs\n",
    "        h, c = states\n",
    "        \n",
    "        feats = torch.cat([field_feat, h], axis=-1)\n",
    "        \n",
    "        # ========= UPDATE states C\n",
    "        # i gate\n",
    "        ans_i = self.cts_conv_i(field, center, feats, field_mask)\n",
    "        ans_i = ans_i + self.dense_i(feats)\n",
    "        \n",
    "        # f gate\n",
    "        ans_f = self.cts_conv_f(field, center, feats, field_mask)\n",
    "        ans_f = ans_f + self.dense_f(feats)\n",
    "        \n",
    "        # g gate\n",
    "        ans_g = self.cts_conv_g(field, center, feats, field_mask)\n",
    "        ans_g = ans_g + self.dense_g(feats)\n",
    "        \n",
    "        if self.bias:\n",
    "            ans_i = ans_i + self.bias_i\n",
    "            ans_f = ans_f + self.bias_f\n",
    "            ans_g = ans_g + self.bias_c\n",
    "        \n",
    "        # activation\n",
    "        ans_i = F.relu(ans_i)\n",
    "        ans_f = F.relu(ans_f)\n",
    "        ans_g = torch.tanh(ans_g)\n",
    "        \n",
    "        c_out = c * ans_f + ans_i * ans_g\n",
    "        \n",
    "        # ========= UPDATE states H\n",
    "        # o gate\n",
    "        ans_o = self.cts_conv_o(field, center, feats, field_mask)\n",
    "        ans_o = ans_o + self.dense_o(feats)\n",
    "        \n",
    "        if self.bias:\n",
    "            ans_o = ans_o + self.bias_o\n",
    "        ans_o = F.relu(ans_o)\n",
    "        h_out = ans_o * torch.tanh(c_out)\n",
    "        \n",
    "        return (h_out, c_out)\n",
    "        \n",
    "    def __get_ctsconv_layer(self, layer_name):\n",
    "        return CtsConv(in_channels=self.in_channels + self.out_channels, \n",
    "                       out_channels=self.out_channels, \n",
    "                       kernel_sizes=self.kernel_sizes,\n",
    "                       radius=self.radius,\n",
    "                       normalize_attention=self.normalize_attention, \n",
    "                       layer_name=layer_name)\n",
    "\n",
    "    def __get_bias(self, bias_name):\n",
    "        nn_param = torch.rand(self.out_channels) - 0.5\n",
    "        nn_param = nn_param / (self.out_channels)\n",
    "        nn_param = nn.Parameter(nn_param)\n",
    "        self.register_parameter(bias_name, nn_param)\n",
    "        return nn_param\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:02.703342Z",
     "start_time": "2020-08-16T16:40:02.679604Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:03.166855Z",
     "start_time": "2020-08-16T16:40:03.158863Z"
    }
   },
   "outputs": [],
   "source": [
    "pos = torch.rand(8,60,3)\n",
    "vel = torch.rand(8,60,3)\n",
    "h = torch.zeros(8,60,32, dtype=torch.float32)\n",
    "c = torch.zeros(8,60,32, dtype=torch.float32)\n",
    "mask = torch.tensor([[1.]*50 + [0.]*10] * 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:04.284065Z",
     "start_time": "2020-08-16T16:40:04.267432Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 60, 35])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feat = torch.cat([vel, h], axis=-1)\n",
    "feat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:10.754803Z",
     "start_time": "2020-08-16T16:40:08.565429Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model = nn.DataParallel(CtsConvLSTMCell(3, 32, [4,4,4], 1, bias=False)).to(device)\n",
    "inputs = (pos, pos, vel, mask)\n",
    "states = (h, c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-16T16:40:18.121180Z",
     "start_time": "2020-08-16T16:40:11.334870Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0495, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.3063, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0485, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4205, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0496, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4676, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0499, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4857, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4927, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4954, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4964, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4968, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4970, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4970, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4970, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4970, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n",
      "tensor(0.0500, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.4971, device='cuda:0', grad_fn=<MaxBackward1>)\n"
     ]
    }
   ],
   "source": [
    "for i in range(20):\n",
    "    h_new, c_new = model(inputs, states)\n",
    "    states = (h_new, c_new)\n",
    "    print(torch.max(states[0]), torch.max(states[1]))\n",
    "# states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-15T18:53:16.830944Z",
     "start_time": "2020-08-15T18:53:16.822708Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0, device='cuda:0'),\n",
       " tensor(0, device='cuda:0'),\n",
       " tensor(0.1937, device='cuda:0', grad_fn=<MaxBackward1>),\n",
       " tensor(0.8584, device='cuda:0', grad_fn=<MaxBackward1>))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(states[0] > 1).sum(), (states[1] > 1).sum(), torch.max(states[0]), torch.max(states[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-13T19:35:31.611659Z",
     "start_time": "2020-08-13T19:35:31.560554Z"
    }
   },
   "outputs": [],
   "source": [
    "torch.nn.LSTMCell?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-12T22:00:14.624445Z",
     "start_time": "2020-08-12T22:00:14.620597Z"
    }
   },
   "outputs": [],
   "source": [
    "# self, in_channels, out_channels, kernel_sizes, radius, normalize_attention=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-12T22:00:36.860826Z",
     "start_time": "2020-08-12T22:00:36.853007Z"
    }
   },
   "outputs": [],
   "source": [
    "nn.init.xavier_normal_?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-12T22:08:47.052401Z",
     "start_time": "2020-08-12T22:08:47.044550Z"
    }
   },
   "outputs": [],
   "source": [
    "nn.Module.register_parameter?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-12T22:11:31.568400Z",
     "start_time": "2020-08-12T22:11:31.553582Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "None is not None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-13T04:24:12.226757Z",
     "start_time": "2020-08-13T04:24:12.221239Z"
    }
   },
   "outputs": [],
   "source": [
    "torch.rand?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-08-13T19:58:21.233865Z",
     "start_time": "2020-08-13T19:58:21.227787Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<function _VariableFunctions.stack>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.stack"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
