{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# External imports \n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import random\n",
    "import numpy as np\n",
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import display, clear_output\n",
    "\n",
    "# Internal imports\n",
    "import sys; sys.path.insert(0, '..')\n",
    "from src import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_DIMS = 2\n",
    "MI = 1000\n",
    "NUM_SAMPLES = 100000\n",
    "BS = 500\n",
    "NUM_EPOCHS = 150\n",
    "SEED = 21\n",
    "LR = 5e-3\n",
    "DROPOUT = 0.20\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setting seed for reproducibility\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.0\n",
      "rho 1.0\n",
      "scale_p [[1. 1.]\n",
      " [1. 1.]]\n",
      "torch.Size([2])\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "The parameter covariance_matrix has invalid values",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-29-bd16c0f85313>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;31m# Define distributions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mq\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_dists_from_mi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mMI\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mN_DIMS\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/cob_pytorch/src/distributions.py\u001b[0m in \u001b[0;36mget_dists_from_mi\u001b[0;34m(mi, n_dims)\u001b[0m\n\u001b[1;32m     22\u001b[0m     \u001b[0mscale_m\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meye\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn_dims\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat32\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 24\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mget_dists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmu1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_p\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_q\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_rho_from_mi\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_dims\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/cob_pytorch/src/distributions.py\u001b[0m in \u001b[0;36mget_dists\u001b[0;34m(mu1, mu2, mu3, scale_p, scale_q, scale_m)\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget_dists\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmu1\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu2\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmu3\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_p\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_q\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscale_m\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmu1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m     p = MultivariateNormal(\n\u001b[0m\u001b[1;32m     34\u001b[0m         \u001b[0mloc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmu1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m         \u001b[0mcovariance_matrix\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscale_p\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/torch1.8/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, loc, covariance_matrix, precision_matrix, scale_tril, validate_args)\u001b[0m\n\u001b[1;32m    144\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    145\u001b[0m         \u001b[0mevent_shape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 146\u001b[0;31m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mMultivariateNormal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevent_shape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidate_args\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mvalidate_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    148\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mscale_tril\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/disk_c/han/anaconda3/envs/torch1.8/lib/python3.8/site-packages/torch/distributions/distribution.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, batch_shape, event_shape, validate_args)\u001b[0m\n\u001b[1;32m     51\u001b[0m                     \u001b[0;32mcontinue\u001b[0m  \u001b[0;31m# skip checking lazily-constructed args\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     52\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mconstraint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 53\u001b[0;31m                     \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"The parameter {} has invalid values\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     54\u001b[0m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mDistribution\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     55\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: The parameter covariance_matrix has invalid values"
     ]
    }
   ],
   "source": [
    "# Define model\n",
    "model = RatioCritic(dim_input=N_DIMS, dim_output=3, dropout=DROPOUT)\n",
    "model.apply(weights_init)\n",
    "\n",
    "# Define optimizer\n",
    "optim = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "# optim = torch.optim.SGD(model.parameters(), lr=LR)\n",
    "\n",
    "# Define distributions\n",
    "p, q, m = get_dists_from_mi(MI, N_DIMS)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define dataset & dataloader\n",
    "train_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES)\n",
    "test_ds = DistDataset(p, q, m, num_samples=NUM_SAMPLES) # Test dataset is only of size batch \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define dataloader\n",
    "train_dl = DataLoader(train_ds, batch_size=BS, shuffle=True)\n",
    "test_dl = DataLoader(test_ds, batch_size=BS, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up viz\n",
    "fig, [[ax1,ax2,ax3,ax4],[ax5,ax6,ax7,ax8]] = plt.subplots(2, 4,figsize=(20,10))\n",
    "\n",
    "line, = ax1.plot([0,1],[0,1])\n",
    "test_line, = ax4.plot([0,1],[0,1],label='Test Loss')\n",
    "test_line1, = ax4.plot([0,1],[0,1],label='Train Loss')\n",
    "\n",
    "test_line_p, = ax5.plot([0,1],[0,1],label='Test Loss P')\n",
    "test_line_q, = ax6.plot([0,1],[0,1],label='Test Loss Q')\n",
    "test_line_m, = ax7.plot([0,1],[0,1],label='Test Loss M')\n",
    "\n",
    "kld_line, = ax3.plot([0],[0],label='GT KLD: '+str(MI))\n",
    "x, y = np.random.random((2, 500))\n",
    "\n",
    "scat = ax2.scatter(x,y,label='GT LR vs CoB ',alpha=0.9,s=10.,c='r')\n",
    "\n",
    "ax1.set_xlabel(\"Iteration\")\n",
    "ax1.set_ylabel(\"Train Loss\")\n",
    "ax1.set_xlim([0,NUM_EPOCHS*(NUM_SAMPLES//BS)])\n",
    "ax1.set_ylim([0,10])\n",
    "\n",
    "ax2.set_xlabel(\"Log Ratio\")\n",
    "ax2.set_ylabel(\"Log Ratio CoB\")\n",
    "ax2.legend(loc='best')\n",
    "\n",
    "ax3.set_ylabel(\"KLD\")\n",
    "ax3.legend(loc='best')\n",
    "\n",
    "ax4.set_xlabel(\"Iteration\")\n",
    "ax4.set_ylabel(\"Test Loss\")\n",
    "ax4.set_xlim([0,NUM_EPOCHS*(NUM_SAMPLES//BS)])\n",
    "ax4.set_ylim([0,10])\n",
    "ax4.legend(loc='best')\n",
    "\n",
    "ax5.set_xlabel(\"Iteration\")\n",
    "ax5.set_ylabel(\"Test Loss p\")\n",
    "ax5.set_xlim([0,NUM_EPOCHS*(NUM_SAMPLES//BS)])\n",
    "ax5.set_ylim([0,5])\n",
    "ax5.legend(loc='best')\n",
    "\n",
    "ax6.set_xlabel(\"Iteration\")\n",
    "ax6.set_ylabel(\"Test Loss q\")\n",
    "ax6.set_xlim([0,NUM_EPOCHS*(NUM_SAMPLES//BS)])\n",
    "ax6.set_ylim([0,5])\n",
    "ax6.legend(loc='best')\n",
    "\n",
    "ax7.set_xlabel(\"Iteration\")\n",
    "ax7.set_ylabel(\"Test Loss m\")\n",
    "ax7.set_xlim([0,NUM_EPOCHS*(NUM_SAMPLES//BS)])\n",
    "ax7.set_ylim([0,5])\n",
    "ax7.legend(loc='best')\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "loss_store = []\n",
    "kld_store = []\n",
    "test_loss_store = []\n",
    "test_loss_store1 = []\n",
    "test_loss_store_p = []\n",
    "test_loss_store_q = []\n",
    "test_loss_store_m = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model = model.to(DEVICE)\n",
    "    \n",
    "i = 0\n",
    "loss_crit = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "for epoch in trange(NUM_EPOCHS):\n",
    "    for p_batch, q_batch, m_batch in iter(train_dl):\n",
    "        model.train()\n",
    "        i += 1\n",
    "        \n",
    "        optim.zero_grad()\n",
    "        \n",
    "        # CUDA\n",
    "        if torch.cuda.is_available():\n",
    "            p_batch, q_batch, m_batch = p_batch.to(DEVICE), q_batch.to(DEVICE), m_batch.to(DEVICE)\n",
    "            \n",
    "        logP = model(p_batch)\n",
    "        logQ = model(q_batch)\n",
    "        logM = model(m_batch)\n",
    "        \n",
    "        p_label = torch.empty(p_batch.shape[0], dtype=torch.long, device=DEVICE).fill_(0)\n",
    "        q_label = torch.empty(q_batch.shape[0], dtype=torch.long, device=DEVICE).fill_(1)\n",
    "        m_label = torch.empty(m_batch.shape[0], dtype=torch.long, device=DEVICE).fill_(2)\n",
    "        \n",
    "        loss = loss_crit(logP, p_label) + loss_crit(logQ, q_label) + loss_crit(logM, m_label)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "        loss_store.append(loss.item())\n",
    "\n",
    "        # Validation/Test\n",
    "        if i % 100 == 0:\n",
    "            model.eval()\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                for p_batch, q_batch, m_batch in iter(test_dl):\n",
    "                    log_ratio_p_q, _ = get_gt_ratio_kl(p, q, m_batch)\n",
    "                    _, kl_from_p_q = get_gt_ratio_kl(p, q, p_batch)\n",
    "\n",
    "                    if torch.cuda.is_available():\n",
    "                        p_batch, q_batch, m_batch = p_batch.to(DEVICE), q_batch.to(DEVICE), m_batch.to(DEVICE)\n",
    "                        \n",
    "                    logP = model(p_batch)\n",
    "                    logQ = model(q_batch)\n",
    "                    logM = model(m_batch)\n",
    "\n",
    "                    log_ratio_p_q_from_cob = logP[:, 0] - logP[:, 1]\n",
    "                    kl_from_cob = torch.mean(log_ratio_p_q_from_cob)\n",
    "                    \n",
    "                    log_ratio_p_q_from_cob = logM[:, 0] - logM[:, 1]\n",
    "\n",
    "                    p_label = torch.empty(p_batch.shape[0], dtype=torch.long, device=DEVICE).fill_(0)\n",
    "                    q_label = torch.empty(q_batch.shape[0], dtype=torch.long, device=DEVICE).fill_(1)\n",
    "                    m_label = torch.empty(m_batch.shape[0], dtype=torch.long, device=DEVICE).fill_(2)\n",
    "                    \n",
    "                    test_loss = loss_crit(logP, p_label) + loss_crit(logQ, q_label) + loss_crit(logM, m_label)\n",
    "\n",
    "                    # Visualize\n",
    "                    # First plot of loss\n",
    "                    line.set_data(range(len(loss_store)), loss_store)\n",
    "                    ax1.set_xlim( 0, len(loss_store))\n",
    "\n",
    "                    kld_store.append(kl_from_cob.cpu().detach().numpy())\n",
    "                    test_loss_store.append(test_loss.item())\n",
    "                    test_loss_store1.append(loss.item())\n",
    "                    \n",
    "                    test_loss_store_p.append(loss_crit(logP, p_label).item())\n",
    "                    test_loss_store_q.append(loss_crit(logQ, q_label).item())\n",
    "                    test_loss_store_m.append(loss_crit(logM, m_label).item())\n",
    "\n",
    "                    scat.set_offsets(np.vstack([log_ratio_p_q.cpu().detach().numpy(),log_ratio_p_q_from_cob.cpu().detach().numpy()]).T)\n",
    "                    ax2.set_xlim( log_ratio_p_q.min(), log_ratio_p_q.max() )\n",
    "                    ax2.set_ylim( log_ratio_p_q.min(), log_ratio_p_q.max() )\n",
    "\n",
    "                    kld_line.set_data(range(len(kld_store)),kld_store)\n",
    "                    ax3.set_xlim( 0, len(kld_store))\n",
    "                    ax3.set_ylim( min(kld_store), max(kld_store) )\n",
    "\n",
    "                    test_line.set_data(range(len(test_loss_store)), test_loss_store)\n",
    "                    test_line1.set_data(range(len(test_loss_store1)), test_loss_store1)\n",
    "                    ax4.set_xlim( 0, len(test_loss_store) )\n",
    "                    \n",
    "\n",
    "                    test_line_p.set_data(range(len(test_loss_store_p)), test_loss_store_p)\n",
    "                    ax5.set_xlim( 0, len(test_loss_store_p) )\n",
    "                    \n",
    "                    test_line_q.set_data(range(len(test_loss_store_q)), test_loss_store_q)\n",
    "                    ax6.set_xlim( 0, len(test_loss_store_q) )\n",
    "                    test_line_m.set_data(range(len(test_loss_store_m)), test_loss_store_m)\n",
    "                    ax7.set_xlim( 0, len(test_loss_store_m) )\n",
    "                    \n",
    "                    clear_output(wait=True)\n",
    "                    display(fig)\n",
    "                    break\n",
    "\n",
    "            model.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch1.8",
   "language": "python",
   "name": "torch1.8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
