{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from torchsummary import summary\n",
    "\n",
    "from distributions.networks import DecoderCubText\n",
    "from distributions.ResidualBlocks import ResidualBlock1dConv, ResidualBlock1dTransposeConv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Linear-1                  [-1, 128]           8,320\n",
      "   ConvTranspose1d-2               [-1, 128, 4]          65,664\n",
      "              ReLU-3               [-1, 128, 4]               0\n",
      "   ConvTranspose1d-4               [-1, 128, 8]          65,664\n",
      "              ReLU-5               [-1, 128, 8]               0\n",
      "            Conv1d-6              [-1, 1590, 8]         205,110\n",
      "           Softmax-7              [-1, 1590, 8]               0\n",
      "================================================================\n",
      "Total params: 344,758\n",
      "Trainable params: 344,758\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.00\n",
      "Forward/backward pass size (MB): 0.22\n",
      "Params size (MB): 1.32\n",
      "Estimated Total Size (MB): 1.53\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): the saved summary below was produced when this cell called DecoderText;\n",
    "# re-run to refresh it for DecoderCubText, whose layers may differ.\n",
    "decoder = DecoderCubText(64, 1590)\n",
    "summary(decoder, input_size=(64,), device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "metadata": {},
   "outputs": [],
   "source": [
    "class EncoderCubText2(nn.Module):\n",
    "    \"\"\"Convolutional text encoder: a strided Conv1d stem followed by three residual blocks.\n",
    "\n",
    "    Args:\n",
    "        dim_text: base channel width; the stem outputs 4 * dim_text channels and\n",
    "            each residual block outputs 5 * dim_text channels.\n",
    "        vocab_size: size of the last input dimension, used as the stem's input channels.\n",
    "\n",
    "    ``forward`` expects x of shape (batch, seq_len, vocab_size); it swaps the last two\n",
    "    axes to channels-first, runs the convolutions, flattens everything but the batch\n",
    "    axis and applies ReLU. With seq_len=32 and the defaults the output is (batch, 480).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim_text=32, vocab_size=1590):\n",
    "        super().__init__()\n",
    "\n",
    "        # Stem: vocab_size -> 4*dim_text channels; stride 2 halves the sequence length.\n",
    "        self.conv1 = nn.Conv1d(vocab_size, 4 * dim_text,\n",
    "                               kernel_size=4, stride=2, padding=1, dilation=1)\n",
    "        # NOTE(review): an even kernel (4) with stride=1, padding=1 shortens the sequence\n",
    "        # by one position (16 -> 15 in the summary below); confirm this is intended.\n",
    "        self.resblock_1 = self.make_res_block_encoder_feature_extractor(\n",
    "            4 * dim_text, 5 * dim_text, kernelsize=4, stride=1, padding=1, dilation=1)\n",
    "        self.resblock_2 = self.make_res_block_encoder_feature_extractor(\n",
    "            5 * dim_text, 5 * dim_text, kernelsize=4, stride=2, padding=1, dilation=1)\n",
    "        self.resblock_3 = self.make_res_block_encoder_feature_extractor(\n",
    "            5 * dim_text, 5 * dim_text, kernelsize=4, stride=2, padding=1, dilation=1)\n",
    "\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Encode x of shape (batch, seq_len, vocab_size) into a (batch, features) tensor.\"\"\"\n",
    "        batch_size = x.size(0)\n",
    "        x = x.transpose(-2, -1)  # channels-first for Conv1d: (batch, vocab_size, seq_len)\n",
    "        out = self.conv1(x)\n",
    "        out = self.resblock_1(out)\n",
    "        out = self.resblock_2(out)\n",
    "        out = self.resblock_3(out)\n",
    "        out = out.view(batch_size, -1)  # flatten channels x length per sample\n",
    "        return self.relu(out)\n",
    "\n",
    "    @staticmethod\n",
    "    def make_res_block_encoder_feature_extractor(in_channels, out_channels, kernelsize, stride, padding, dilation,\n",
    "                                                 a_val=2.0, b_val=0.3):\n",
    "        \"\"\"Build one ResidualBlock1dConv wrapped in an nn.Sequential.\n",
    "\n",
    "        When the block changes stride, channel count or dilation, a Conv1d + BatchNorm1d\n",
    "        projection with the same conv hyper-parameters is passed as ``downsample``\n",
    "        (presumably applied to the skip path inside ResidualBlock1dConv -- confirm).\n",
    "        ``a_val`` and ``b_val`` are forwarded as the block's ``a`` and ``b`` arguments.\n",
    "        \"\"\"\n",
    "        downsample = None\n",
    "        if stride != 1 or in_channels != out_channels or dilation != 1:\n",
    "            downsample = nn.Sequential(\n",
    "                nn.Conv1d(in_channels, out_channels, kernel_size=kernelsize,\n",
    "                          stride=stride, padding=padding, dilation=dilation),\n",
    "                nn.BatchNorm1d(out_channels),\n",
    "            )\n",
    "        # Single-element Sequential keeps the original state_dict key layout (resblock_N.0.*).\n",
    "        return nn.Sequential(\n",
    "            ResidualBlock1dConv(in_channels, out_channels, kernelsize, stride, padding, dilation,\n",
    "                                downsample, a=a_val, b=b_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv1d-1              [-1, 128, 16]         814,208\n",
      "       BatchNorm1d-2              [-1, 128, 16]             256\n",
      "              ReLU-3              [-1, 128, 16]               0\n",
      "            Conv1d-4              [-1, 128, 16]          16,512\n",
      "           Dropout-5              [-1, 128, 16]               0\n",
      "       BatchNorm1d-6              [-1, 128, 16]             256\n",
      "              ReLU-7              [-1, 128, 16]               0\n",
      "            Conv1d-8              [-1, 160, 15]          82,080\n",
      "           Dropout-9              [-1, 160, 15]               0\n",
      "           Conv1d-10              [-1, 160, 15]          82,080\n",
      "      BatchNorm1d-11              [-1, 160, 15]             320\n",
      "ResidualBlock1dConv-12              [-1, 160, 15]               0\n",
      "      BatchNorm1d-13              [-1, 160, 15]             320\n",
      "             ReLU-14              [-1, 160, 15]               0\n",
      "           Conv1d-15              [-1, 160, 15]          25,760\n",
      "          Dropout-16              [-1, 160, 15]               0\n",
      "      BatchNorm1d-17              [-1, 160, 15]             320\n",
      "             ReLU-18              [-1, 160, 15]               0\n",
      "           Conv1d-19               [-1, 160, 7]         102,560\n",
      "          Dropout-20               [-1, 160, 7]               0\n",
      "           Conv1d-21               [-1, 160, 7]         102,560\n",
      "      BatchNorm1d-22               [-1, 160, 7]             320\n",
      "ResidualBlock1dConv-23               [-1, 160, 7]               0\n",
      "      BatchNorm1d-24               [-1, 160, 7]             320\n",
      "             ReLU-25               [-1, 160, 7]               0\n",
      "           Conv1d-26               [-1, 160, 7]          25,760\n",
      "          Dropout-27               [-1, 160, 7]               0\n",
      "      BatchNorm1d-28               [-1, 160, 7]             320\n",
      "             ReLU-29               [-1, 160, 7]               0\n",
      "           Conv1d-30               [-1, 160, 3]         102,560\n",
      "          Dropout-31               [-1, 160, 3]               0\n",
      "           Conv1d-32               [-1, 160, 3]         102,560\n",
      "      BatchNorm1d-33               [-1, 160, 3]             320\n",
      "ResidualBlock1dConv-34               [-1, 160, 3]               0\n",
      "             ReLU-35                  [-1, 480]               0\n",
      "================================================================\n",
      "Total params: 1,459,392\n",
      "Trainable params: 1,459,392\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 0.19\n",
      "Forward/backward pass size (MB): 0.43\n",
      "Params size (MB): 5.57\n",
      "Estimated Total Size (MB): 6.19\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# input_size excludes the batch axis: (seq_len=32, vocab_size=1590).\n",
    "encoder = EncoderCubText2()\n",
    "summary(encoder, input_size=(32, 1590), device='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1,480,502 was left here as a bare expression (Python reads it as the tuple (1, 480, 502))\n",
    "# and the saved output was unrelated help text for summary. Presumably it is a parameter\n",
    "# count to compare with the 1,459,392 total reported for EncoderCubText2 above --\n",
    "# TODO confirm which model it belongs to.\n",
    "REFERENCE_PARAM_COUNT = 1_480_502"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
