{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "# import open3d.ml.tf as o3dml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Particles(nn.Module):\n",
    "    \"\"\"Continuous-convolution particle network that advances particles by one timestep.\n",
    "\n",
    "    Layer layout: a first layer made of a fluid continuous conv, an obstacle continuous\n",
    "    conv and a dense layer whose outputs are concatenated, followed by conv + dense pairs\n",
    "    with a residual connection whenever the channel count is unchanged. The last layer's\n",
    "    output, scaled by 1/128, is a position correction applied after a constant-acceleration\n",
    "    integration step.\n",
    "\n",
    "    NOTE(review): the continuous convs come from Open3D-ML (`o3dml`), whose import is\n",
    "    commented out in the first cell; constructing this class raises NameError until\n",
    "    `o3dml` is bound (presumably to `open3d.ml.torch`, since this is a torch module).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 kernel_size = [4, 4, 4],\n",
    "                 radius_scale = 40,\n",
    "                 coordinate_mapping = 'ball_to_cube',\n",
    "                 interpolation = 'linear',\n",
    "                 use_window = True,\n",
    "                 particle_radius = 0.5,\n",
    "                 timestep = 0.1,\n",
    "                 gravity = (0, -9.81, 0),\n",
    "                 other_feats_channels = 0\n",
    "                 ):\n",
    "        \"\"\"Build all layers.\n",
    "\n",
    "        Args:\n",
    "            kernel_size: kernel resolution of every continuous conv (copied, so the\n",
    "                shared default list is never aliased between instances).\n",
    "            radius_scale, particle_radius: the filter extent (diameter) is\n",
    "                radius_scale * 6 * particle_radius.\n",
    "            coordinate_mapping, interpolation: forwarded to every continuous conv.\n",
    "            use_window: if true, the convs use the poly6 window function.\n",
    "            timestep: integration step dt.\n",
    "            gravity: constant acceleration used by update_pos_vel when no acceleration\n",
    "                is passed.\n",
    "            other_feats_channels: number of channels of the optional `other_feats`\n",
    "                tensor given to compute_correction (0 when it is None).\n",
    "        \"\"\"\n",
    "        super(Particles, self).__init__()\n",
    "        self.kernel_size = list(kernel_size)\n",
    "        self.radius_scale = radius_scale\n",
    "        self.coordinate_mapping = coordinate_mapping\n",
    "        self.interpolation = interpolation\n",
    "        self.use_window = use_window\n",
    "        self.particle_radius = particle_radius\n",
    "        self.timestep = timestep\n",
    "        self.gravity = tuple(gravity)\n",
    "        self.other_feats_channels = other_feats_channels\n",
    "\n",
    "        self.layer_channels = [32, 64, 64, 3]\n",
    "        self.filter_extent = np.float32(self.radius_scale * 6 *\n",
    "                                        self.particle_radius)\n",
    "        # (name, layer) pairs for every continuous conv built by conv_layer\n",
    "        self.all_convs = []\n",
    "\n",
    "        def window_poly6(r_sqr):\n",
    "            # (1 - r^2)^3 clamped to [0, 1]\n",
    "            return torch.clamp((1 - r_sqr)**3, min = 0, max = 1)\n",
    "\n",
    "        def conv_layer(name, activation = None, **kwargs):\n",
    "            \"\"\"Construct a continuous convolution layer from Open3D-ML and record it.\n",
    "\n",
    "            `name` is only kept for the all_convs bookkeeping: torch modules take no\n",
    "            name argument. Remaining kwargs (e.g. in_channels, filters) go to the layer.\n",
    "            \"\"\"\n",
    "            conv_fn = o3dml.layers.ContinuousConv\n",
    "            window_fn = window_poly6 if self.use_window else None\n",
    "\n",
    "            conv = conv_fn(kernel_size = self.kernel_size,\n",
    "                           activation = activation,\n",
    "                           align_corners = True,\n",
    "                           interpolation = self.interpolation,\n",
    "                           coordinate_mapping = self.coordinate_mapping,\n",
    "                           normalize = False,\n",
    "                           window_function = window_fn,\n",
    "                           radius_search_ignore_query_points = True,\n",
    "                           **kwargs)\n",
    "\n",
    "            self.all_convs.append((name, conv))\n",
    "            return conv\n",
    "\n",
    "        # fluid features are [ones (1 channel), v (3 channels), other_feats]\n",
    "        fluid_in_channels = 4 + other_feats_channels\n",
    "        self.conv_fluid = conv_layer(name = \"conv_fluid\",\n",
    "                                     in_channels = fluid_in_channels,\n",
    "                                     filters = self.layer_channels[0],\n",
    "                                     activation = None)\n",
    "        # box_feats has 3 channels in init(); TODO confirm for real obstacle data\n",
    "        self.conv_obstacle = conv_layer(name = \"conv_obstacle\",\n",
    "                                        in_channels = 3,\n",
    "                                        filters = self.layer_channels[0],\n",
    "                                        activation = None)\n",
    "        # nn.Linear takes no `name` argument (passing one raised TypeError)\n",
    "        self.dense_fluid = nn.Linear(fluid_in_channels, self.layer_channels[0])\n",
    "\n",
    "        # nn.ModuleList (not a plain list) registers the layers as submodules, so their\n",
    "        # parameters appear in .parameters()/state_dict() and move with .to(device)\n",
    "        self.convs = nn.ModuleList()\n",
    "        self.denses = nn.ModuleList()\n",
    "        for i in range(1, len(self.layer_channels)):\n",
    "            # layer 1 consumes the concatenation of three layer_channels[0]-wide outputs\n",
    "            in_ch = self.layer_channels[i - 1] * (3 if i == 1 else 1)\n",
    "            ch = self.layer_channels[i]\n",
    "            self.denses.append(nn.Linear(in_ch, ch))\n",
    "            self.convs.append(conv_layer(name = 'conv{0}'.format(i),\n",
    "                                         in_channels = in_ch,\n",
    "                                         filters = ch,\n",
    "                                         activation = None))\n",
    "\n",
    "    def update_pos_vel(self, p0, v0, a = None):\n",
    "        \"\"\"Apply acceleration and integrate position and velocity over one timestep.\n",
    "\n",
    "        Assumes the acceleration is constant during the timestep; the position uses the\n",
    "        average of the old and new velocity. If `a` is None, self.gravity is used.\n",
    "\n",
    "        Returns:\n",
    "            (p1, v1): position and velocity after one timestep of length self.timestep.\n",
    "        \"\"\"\n",
    "        if a is None:\n",
    "            # match v0's dtype/device so the update works for CPU and GPU tensors\n",
    "            a = torch.as_tensor(self.gravity, dtype = v0.dtype, device = v0.device)\n",
    "        dt = self.timestep\n",
    "        v1 = v0 + dt * a\n",
    "        p1 = p0 + dt * (v0 + v1) / 2\n",
    "        return p1, v1\n",
    "\n",
    "    def apply_correction(self, p0, p1, correction):\n",
    "        \"\"\"Apply the position correction and recompute the velocity from it.\n",
    "\n",
    "        p0, p1: particle positions before and after the basic integration step.\n",
    "        Returns (p_corrected, v_corrected) with v_corrected = (p_corrected - p0) / dt.\n",
    "        \"\"\"\n",
    "        dt = self.timestep\n",
    "        p_corrected = p1 + correction\n",
    "        v_corrected = (p_corrected - p0) / dt\n",
    "        return p_corrected, v_corrected\n",
    "\n",
    "    def compute_correction(self, p, v, other_feats, box, box_feats,\n",
    "                           fixed_radius_search_hash_table=None):\n",
    "        \"\"\"Predict a position correction for every fluid particle.\n",
    "\n",
    "        Precondition: p and v were already updated with the acceleration.\n",
    "\n",
    "        Args:\n",
    "            p, v: fluid positions and velocities (presumably (N, 3) -- TODO confirm).\n",
    "            other_feats: optional extra per-particle features, concatenated after v.\n",
    "            box, box_feats: obstacle positions and features.\n",
    "            fixed_radius_search_hash_table: currently unused.\n",
    "\n",
    "        Returns:\n",
    "            The last layer's output scaled by 1/128 (also kept in self.pos_correction).\n",
    "        \"\"\"\n",
    "        # extent of the filters (the diameter)\n",
    "        filter_extent = torch.tensor(self.filter_extent)\n",
    "\n",
    "        fluid_feats = [torch.ones_like(p[:, 0:1]), v]\n",
    "        if other_feats is not None:\n",
    "            fluid_feats.append(other_feats)\n",
    "        fluid_feats = torch.cat(fluid_feats, -1)\n",
    "\n",
    "        self.ans_conv_fluid = self.conv_fluid(fluid_feats, p, p, filter_extent)\n",
    "        self.ans_dense_fluid = self.dense_fluid(fluid_feats)\n",
    "        self.ans_conv_obstacle = self.conv_obstacle(box_feats, box, p, filter_extent)\n",
    "        self.ans_convs = [torch.cat([self.ans_conv_obstacle, self.ans_conv_fluid, self.ans_dense_fluid], -1)]\n",
    "\n",
    "        for conv, dense in zip(self.convs, self.denses):\n",
    "            # pass input features to conv and fully-connected layers\n",
    "            inp_feats = F.relu(self.ans_convs[-1])\n",
    "            ans_conv = conv(inp_feats, p, p, filter_extent)\n",
    "            ans_dense = dense(inp_feats)\n",
    "            ans = ans_conv + ans_dense\n",
    "            # residual connection when the channel count is unchanged\n",
    "            if ans_dense.shape[-1] == self.ans_convs[-1].shape[-1]:\n",
    "                ans = ans + self.ans_convs[-1]\n",
    "            self.ans_convs.append(ans)\n",
    "\n",
    "        # number of fluid neighbors per particle; used in the training loss\n",
    "        self.num_fluid_neighbors = o3dml.ops.reduce_subarrays_sum(\n",
    "            torch.ones_like(self.conv_fluid.nns.neighbors_index,\n",
    "                            dtype = torch.float32),\n",
    "            self.conv_fluid.nns.neighbors_row_splits)\n",
    "\n",
    "        self.last_features = self.ans_convs[-2]\n",
    "\n",
    "        # scale to better match the scale of the output distribution\n",
    "        self.pos_correction = (1.0 / 128) * self.ans_convs[-1]\n",
    "        return self.pos_correction\n",
    "\n",
    "    def forward(self, inputs, fixed_radius_search_hash_table = None):\n",
    "        \"\"\"Compute one simulation timestep.\n",
    "\n",
    "        inputs: tuple (p0, v0, feats, box, box_feats); feats may be None.\n",
    "        Returns (p_corrected, v_corrected).\n",
    "        \"\"\"\n",
    "        p0, v0, feats, box, box_feats = inputs\n",
    "\n",
    "        # no acceleration passed -> update_pos_vel falls back to self.gravity\n",
    "        p1, v1 = self.update_pos_vel(p0, v0)\n",
    "        pos_correction = self.compute_correction(p1, v1, feats, box, box_feats, fixed_radius_search_hash_table)\n",
    "        p_corrected, v_corrected = self.apply_correction(p0, p1, pos_correction)\n",
    "\n",
    "        return p_corrected, v_corrected\n",
    "\n",
    "    def init(self, feats_shape = None):\n",
    "        \"\"\"Run the network once on dummy data (one particle, one box point).\n",
    "\n",
    "        Layer shapes are fixed at construction in torch, so this mainly acts as a smoke\n",
    "        test. feats_shape, if given, is the shape of the dummy `feats` tensor.\n",
    "        \"\"\"\n",
    "        device = self.dense_fluid.weight.device\n",
    "        p0 = torch.zeros((1, 3), dtype = torch.float32, device = device)\n",
    "        v0 = torch.zeros((1, 3), dtype = torch.float32, device = device)\n",
    "        feats = None\n",
    "        if feats_shape is not None:\n",
    "            feats = torch.zeros(tuple(feats_shape), dtype = torch.float32, device = device)\n",
    "        box = torch.zeros((1, 3), dtype = torch.float32, device = device)\n",
    "        box_feats = torch.zeros((1, 3), dtype = torch.float32, device = device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            _ = self.forward((p0, v0, feats, box, box_feats))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'o3dml' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-11-517a48139963>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mParticles\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-10-2867e9cb4662>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, kernel_size, radius_scale, coordinate_mapping, interpolation, use_window, particle_radius, timestep)\u001b[0m\n\u001b[1;32m     51\u001b[0m         self.conv_fluid = conv_layer(name = \"conv_fluid\",\n\u001b[1;32m     52\u001b[0m                                      \u001b[0mfilters\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_channels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m                                      activation = None)\n\u001b[0m\u001b[1;32m     54\u001b[0m         self.conv_obstacle = conv_layer(name = \"conv_obstacle\",\n\u001b[1;32m     55\u001b[0m                                         \u001b[0mfilters\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_channels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-10-2867e9cb4662>\u001b[0m in \u001b[0;36mconv_layer\u001b[0;34m(name, activation, **kwargs)\u001b[0m\n\u001b[1;32m     29\u001b[0m             \"\"\"Construct a continous convolution layer from Open3d\n\u001b[1;32m     30\u001b[0m             TODO: need to have open3d in pytorch\"\"\"\n\u001b[0;32m---> 31\u001b[0;31m             \u001b[0mconv_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mo3dml\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mContinuousConv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     33\u001b[0m             \u001b[0mwindow_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'o3dml' is not defined"
     ]
    }
   ],
   "source": [
    "# NOTE: construction needs Open3D-ML bound as `o3dml` (its import is commented out in the first cell).\n",
    "# Move the parameters to the device selected above instead of leaving them on the CPU.\n",
    "model = Particles().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
