{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iwae import IWAE\n",
    "from data import list_datasets, collate, collate_sort, load_dataset\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wasserstein_distance\n",
    "from sklearn.neighbors import KDTree\n",
    "import torch\n",
    "torch.set_default_tensor_type('torch.cuda.FloatTensor') # Comment this out if you want to train without GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'mixture'\n",
    "batch_size = 64\n",
    "learning_rate = 1e-3\n",
    "weight_decay = 0\n",
    "epochs = 50\n",
    "display_step = 5\n",
    "patience = 10\n",
    "# params\n",
    "hidden_dim = 64\n",
    "num_layers = 2\n",
    "n_bins = 5\n",
    "att_layers = 5\n",
    "n_heads = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch    5, loss_train = -1.6041, loss_val = -1.6286\n",
      "Epoch   10, loss_train = -1.8546, loss_val = -1.8500\n",
      "Epoch   15, loss_train = -1.9103, loss_val = -1.9036\n",
      "Epoch   20, loss_train = -1.9112, loss_val = -1.9214\n",
      "Epoch   25, loss_train = -1.9286, loss_val = -1.9334\n",
      "Epoch   30, loss_train = -1.8373, loss_val = -1.8864\n",
      "Breaking due to early stopping at epoch 33\n"
     ]
    }
   ],
   "source": [
    "## Load data\n",
    "dset = load_dataset(dataset)\n",
    "trainset, valset, testset = dset.split_train_val_test()\n",
    "\n",
    "collate = collate\n",
    "dl_train = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)\n",
    "dl_val = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, collate_fn=collate)\n",
    "dl_test = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False, collate_fn=collate)\n",
    "\n",
    "## Train model\n",
    "model = IWAE(dim=trainset.dim,\n",
    "             hidden_dim=hidden_dim,\n",
    "             num_layers=num_layers,\n",
    "             n_bins=n_bins,\n",
    "             att_layers=att_layers,\n",
    "             n_heads=n_heads)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
    "\n",
    "impatient = 0\n",
    "best_loss = np.inf\n",
    "best_model = deepcopy(model.state_dict())\n",
    "training_val_losses = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Optimization\n",
    "    model.train()\n",
    "    for batch in dl_train:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        x, m, _ = batch\n",
    "        loss = model(x, m)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    loss_val = 0\n",
    "    for i, batch in enumerate(dl_val):\n",
    "        x, m, _ = batch\n",
    "        loss_val += model(x, m).item() / len(dl_val)\n",
    "\n",
    "    training_val_losses.append(loss_val)\n",
    "\n",
    "    # Early stopping\n",
    "    if (best_loss - loss_val) < 1e-4:\n",
    "        impatient += 1\n",
    "        if loss_val < best_loss:\n",
    "            best_loss = loss_val\n",
    "            best_model = deepcopy(model.state_dict())\n",
    "    else:\n",
    "        best_loss = loss_val\n",
    "        best_model = deepcopy(model.state_dict())\n",
    "        impatient = 0\n",
    "\n",
    "    if impatient >= patience:\n",
    "        print(f'Breaking due to early stopping at epoch {epoch}')\n",
    "        break\n",
    "\n",
    "    if (epoch + 1) % display_step == 0:\n",
    "        print(f\"Epoch {epoch+1:4d}, loss_train = {loss:.4f}, loss_val = {loss_val:.4f}\")\n",
    "\n",
    "\n",
    "## Test model\n",
    "model.load_state_dict(best_model)\n",
    "model.eval()\n",
    "\n",
    "test_loss = 0\n",
    "for i, batch in enumerate(dl_test):\n",
    "    x, m, _ = batch\n",
    "    loss = model(x, m, num_samples=5000)\n",
    "    test_loss += loss.item() / len(dl_test)\n",
    "\n",
    "\n",
    "## Sampling quality -- Wasserstein score\n",
    "dist_test, dist_samples = [], []\n",
    "for x in testset:\n",
    "    if len(x[0]) > 2:\n",
    "        dist_test.append(KDTree(x[0]).query(x[0], k=2)[0][:,1])\n",
    "        samples = model.sample(len(x[0])).detach().cpu().numpy()\n",
    "        dist_samples.append(KDTree(samples).query(samples, k=2)[0][:,1])\n",
    "\n",
    "dist_test = np.concatenate(dist_test, 0)\n",
    "dist_samples = np.concatenate(dist_samples, 0)\n",
    "\n",
    "wasserstein = float(wasserstein_distance(dist_test, dist_samples))\n",
    "\n",
    "\n",
    "## Save results\n",
    "results = { 'test_loss': test_loss, 'training_val_losses': training_val_losses,\n",
    "            'final_epoch': epoch, 'wasserstein': wasserstein }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss -1.9830\n"
     ]
    }
   ],
   "source": [
    "print(f'Test loss {test_loss:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAAEcCAYAAABnO2lWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAbPElEQVR4nO3df0xb190G8OeCSV7oKqBRscuPsSihCZpYSJdIq7qkKylDkwtFwN6B0kibhqqqUqS3GZFA7WhCmzZLF+2valISqdXSDnVNEEpira1KOlirplErUjcIqqUNC3KE23RQWkoIOPf9I7PB5l7fH772se99PlKk2FzDMfg+Pud7zzmWZFmWQUQkQJboBhCRczGAiEgYBhARCcMAIiJhGEBEJAwDiIiE0Qygrq4u3HvvvXjooYcUvy7LMp599lnU1taivr4eIyMjljeSiOxJM4Camppw7Ngx1a8PDQ1hfHwcb731Fp555hns27fPyvYRkY1pBtDWrVuRn5+v+vWBgQE0NjZCkiRUV1djZmYGX3zxhaWNJCJ7ciX6DYLBIDweT+S2x+NBMBhEUVFR3MdduHABq1evTvTHE1EamJ+fR3V1teHHJRxASis5JEnSfNzq1atRWVmZ6I8nojQwOjpq6nEJXwXzeDyYnJyM3J6cnNTs/RARARYEUE1NDfr7+yHLMi5cuIDbb7+dAUREumgOwfbs2YPz589jamoK27dvx+7du7G4uAgAaGtrw/3334/BwUHU1tYiNzcXzz33XNIbTUT2IInajmN0dJQ1ICKbMHs+cyY0EQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCaMrgIaGhlBXV4fa2locOXJkxde/+eYbPPbYY2hoaIDX68XJkyctbygR2Y9mAIVCIfT09ODYsWPw+Xw4c+YMLl26FHXMq6++inXr1uHUqVM4fvw4/vCHP+DGjRtJazQR2YNL6wC/34/y8nKUlZUBALxeLwYGBrB+/frIMZIkYXZ2FrIsY3Z2Fvn5+XC5NL81Zaj+4QBeePNTXJ2eQ3FBLvbWbUDj5hLRzaIMpNkDCgaD8Hg8kdtutxvBYDDqmJ07d+Kzzz7Dtm3b0NDQgCeffBJZWSwv2VH/cABdfZ8gMD0HGUBgeg5dfZ+gfzggummUgTRTQpblFfdJkhR1+91330VlZSX++c9/or+/Hz09Pfj222+tayWljRfe/BRzC6Go++YWQnjhzU8FtYgymWYAeTweTE5ORm4Hg0EUFRVFHdPX14ef//znkCQJ5eXlKC0txeeff259a0m4q9Nzhu43o384gPsOnsXaTh/uO3iWvSsb0wygqqoqjI+PY2JiAjdu3IDP50NNTU3UMXfddRfef/99AMC1a9dw+fJllJaWJqfFJFRxQa6h+43iEM9ZNCvFLpcL3d3daG9vRygUQnNzMyoqKtDb2wsAaGtrw+OPP46uri7U19dDlmV0dHTgjjvuSHrjKXXChefA9BwkAMsH5rk52dhbt8GSnxNviMdCt/1IslKRJwVGR0dRWVkp4keTQeFeyfJgCIdQicVXwdZ2+qD0gpQAXD7oteRnkPXMns+8Vk6alHol4fB5r7NG+UEmFRfkIqBQT7JqiEfphdfKSVMqCs9he+s2IDcnO+o+K4d4lF4YQKQp2YXn5Ro3l+D5piqUFORCwq1e1vNNVaz/2BSHYKRpb92GFTWgZPZKGjeXMHAcggFEmsJhwOUXZDUGEOliVa+E68hoOQYQpUzs5fzwJEMADCGHYgCRJfT0bDjJkGIxgChhSj2bJ167gA///R8821gVOc7o5XwO1+yPl+EpYWoTFV89dyVqDZeRy/lcE+YMDCBKmFoPRgaitukwMsmQ2344AwOIEhZvQuLycDIyyTCVs69JHNaAHMjq2soDG+/EK+euKH4tPzcn6rbey/lcE+YM7AE5TDJqK++Mfan6tZnrC6Y2FuOaMGdgD8jGlHo6ybgUHm9YdPO/e2sYnfPD2dfOwACyKbVJf7HhE5ZIbUVtuBTLaNBxTZj9MYBsSq2nky1JCCnsQae3thLbq3pg45347sai7naxiEzLMYBsSu1ED8kycnOyTa1sV+pVqRWfY7dtDWMRmZZjEdqm1E708KVvo/vt9A8H8Lu/faw6hIv9GTt/8n1IMfeziEyx2AOyqXh7+BitrYR7PkpDNyWB6Tmc/CgQ1QOSADT/mDUdisYAsikrryIp1ZPiyZYkxaUZ8S7XkzMxgGxMradjdCKikcJxbH3J7PchZ2ANyGHMTESMVzjOWfYKKszLidSXjH6fcNv4iajOwgByGDOLPPfWbVhRUA5buLn0/+v/vfHAxjsVj1W7H+Dqd6diADmMmUWejZtLFC+pxwoHmVqtJ14NiKvfnYk1IIcxu8izROds53jHhENOqQbF1e/OxB6Qw5hd5Kn0ODVqw7XiglzVoVbsqvnljyH7Yg/IYcxeno99XN6qbMzeUL7aJWPlTOhwyKkNtf4nJ8v0DG3KXJIs65xdZjGzH2ZP6eG+g2c1h2QlBbkrQm5tp0+xniQB+NOvqrn6PUOZPZ/ZAyJTtGozJQW5eK+zZsX98WpQXP3uPKwBkSnxajPxhk7caIyWYwCRKWpF6YLcnLiLW43sC032xyEYmZLIWjMOtSiMPSAyhR8aSFZgD4gM42e8k1UYQBRXqja2J2diAJGqVG5sT87EGhCpirexvRIumyCj2AMyyEnF12RsbE+0HHtABjhtzxqrN7YnisUekAFOK75aubG9Fif1LGmJrh7Q0NAQ6urqUFtbiyNHjige88EHH+Dhhx+G1+vFI488Ymkj04XT9qxJ1axlp/UsaYlmDygUCqGnpwcvvfQS3G43WlpaUFNTg/Xr10eOmZmZwf79+3Hs2DEUFxfjq6++SmqjRTG7mVcmS8WsZaf1LGmJZg/I7/ejvLwcZWVlWLVqFbxeLwYGBqKOOX36NGpra1FcXAwAWLNmTXJaKxgXUiaH03qWtEQzgILBIDweT+S22+1GMBiMOmZ8fBwzMzPYtWsXmpqa0N/fb31L0wAXUiaHWg/Szj1LukVzCKa0X5kUMw8kFAphZGQEL7/8Mq5fv47W1lZs2rQJa9euta6laYILKa0Xr9hN9qYZQB6PB5OTk5HbwWAQRUVFK44pLCxEXl4e8vLysGXLFoyNjdkygMh6Vn6KK2UWzQCqqqrC+Pg4JiYm4Ha74fP5cPjw4ahjduzYgZ6eHiwuLmJhYQF+vx+//vWvk9VmsiH2LJ1JM4BcLhe6u7vR3t6OUCiE5uZmVFRUoLe3FwDQ1taGdevWYdu2bWhoaEBWVhZaWlpw9913J73xRJTZuCk9ESXM7PnMpRhEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYzc8FI3KK/uEAP501xRhARLgVPss/nz4wPYeuvk8AgCGURByCEeHW59KHwydsbiGEF978VFCLnIEBRATg6vScofvJGgwgIgDFBbmG7idrMICIAOyt24DcnOyo+3JzsrG3boOgFjkDi9BEWCo08ypYajGAiP6rcXMJAyfFOAQjImEYQEQkDAOIiIRhABGRMCxCZwiuUyI7YgBlAK5TIrviECwDcJ0S2RUDKANwnRLZFQMoA3CdEtkVa0AZYG/dhqgaEJD8dUpGi94skpMZDCCdRJ5gqV6nZLTozSI5maVrCDY0NIS6ujrU1tbiyJEjqsf5/X5UVlbijTfesKyB6SB8ggWm5yBj6QTrHw6krA2Nm0vwXmcNLh/04r3OmqSe2EaL3iySk1maARQKhdDT04Njx47B5/PhzJkzuHTpkuJxf/zjH/HTn/40KQ0VyWknmNGiN4vkZJZmAPn9fpSXl6OsrAyrVq2C1+vFwMDAiuOOHz+Ouro6rFmzJikNFclpJ5jRojeL5GSWZgAFg0F4PJ7IbbfbjWAwuOKYt99+G62trda3MA047QQzujkXN/MiszSL0LIsr7hPkqSo2wcOHEBHRweys7NXHGsHIq5CiWS06J0pm3nxSl360Qwgj8eDycnJyO1gMIiioqKoYy5evIg9e/YAAKampjA4OAiXy4UHH3zQ4uaKkSknmJWMbs6V7pt58UpdetIMoKqqKoyPj2NiYgJutxs+nw+HDx+OOubs2bOR/3d2duJnP/uZbcInLN1PsFSL7U08sPFOvDP2ZdoGdLwLCenUTqfRDCCXy4Xu7m60t7cjFAqhubkZFRUV6O3tBQC0tbUlvZGUXpR6E6+cuxL5uujehdJQy2kXEjKFJCsVeVJgdHQUlZWVIn40Jei+g2cR0HHilhTk4r3OmhS0aElsOAK36nWrXVmYnltQfExJGvbYMo3Z85lrwcgwvb0GEb0LtaGWJGHFlbowERNL6RYGEBmmd/qBiGkKaqE3/d0Cnm+qQolKm+w8sTSdMYDIMKV5P7FETVOIN2crvJxFUjyC9SARGEBkWOPmkkhvQsKtGsojP/l+1O3nm6qE1FT0TIp02sTSdMbV8GRKuk5L0DNny2kTS9MZA4hsRyscnTixNF0xgMiRktGD41IP4xhARBbgUg9zWIQmsoDT9oyyCntAFuEeys7GpR7mMIAswD2UqbggV3F5Ci/tx8chmAW4hzJxUzZzGEAWsGqv5MD0HO47eJZrkjKQ0uRMUZMxMwmHYBYw2v1WOx6wdjjGOtMSq34X8b5Puk7OTGfsAVnAij2Ul7NiOJYOHyUkWv9wAPcdPIsfdPrwxGsXEv5d8HdqPe4HZBGzV8Hi7atTUpBr+h1bbc8eEXv0iKC0L1AsPb+L5X/XLElCSOV0cfqeQmbPZw7BLGJ2D2W1oJCAyP1mhmVOvyysVOiPpfW7iA0xtfABeCXTLA7BBFMajkkAYl/qRodldlzxHR5Sre30aRbr9QRtliSpfo/+4QB+97ePNUNsOV7JNI4BJJjS1RO191kjvRe7XRY2Wn/RE7QhWVb8Hv3DAex9/eO4PR414SuZekKSOARLCa360PLhW/idV+nFn5+bo/tn2m3Ft55PtVj+ey7Iy0FOloSFm/FDROmTMfadGtF8nJpEh85OwwAywMylXCOznsPHqr3zzt5YRP9wQPeL2U6XhbVqWrG/56nvbm1AHx7OlsSZ+hD7vdU2rwdu9SLjDcvUhs52+TtYjUMwncxeglV7595/emRFV12rcLoQkvG7v33syO69Vk1L7XcnY2noqbYftJG6WLx9pdU4pfBvBgNIJ7PLJ9RefFPfLUSF2d7XP9b1UTchWXbkHBStmla8kzz8d9JTF+sfDiBLZdPowrwcNG4uUfw+OdmS6l7TmVz4TzYGkE5mL2vrffGZqTk46aqL1lIHrd/z1ek5ze8R7uUq/SlysiU8Xf9DxbYU5uUA8srhF5DZhf9U4EREncxO7NMzIS4REoDLB71J+d7pSqkWByDu71nP30mt+J8tSTj8v5tU6zhqrw2tx9kJJyImmdmNzMMvvv977UJS2mX37r3SZ9Cf/Ciwoqj/fFMVnm+qwv7TI5ECdJjW30mr+H9TltG4uSSqLfm5OZCkW583pvYOHn4cqWMA6WTmsvbyF2x2nGn8ifjO4JWxTKJ0BfHVc1dUrzS911mzIij0/J20iv8ygM09b+Hb64uRoXK8K2Vhdn9zsAIDyAAjl7WNTOM3InaW9NR3C7ada6IUDHomaRqdfqDnKlVsr0qLBLD2owOL0EmiZy2SGUonoF2L0UYuXyfS20hGT0WG/d4QkoEBlCSJzv1Qu6Sr5+cZWTOVztSCIfZ3k+iVJj0fNW2U0blCTsUAShKz76olBbkYP+jFn35VjRy1CSlxfp6d9qxRm7ez06KPgQ4H9ROvXcBqVxYK83Ii37MwT/+yl1i89K4fa0BJ0D8cwOz84or7c7IkQLo1o1lNuCfTuLlE8YoOsLIOtPwFr2fNVKZI5nq22Brd9NwCcnOy8adfVUcK2bFXPXOyJdy2yoWv5xairoIV5OVAloGv5xYyfs1dqjGALBZeSR07sbAwLycykS3eRmTLe07TKoXP8NompZPSbvsAJWs9m1ZQK4XfAxvvxDtjX+LruQXcttqVUNBwu9xbGEAWU1tJLctL7+hq77CxXXe1vaPjTarjx8PoE++DAdZ2+iKhEP49W/lRSvxYpiWsAVlMbX5I7P16PkVB79ql5QXnBzbeaat9gJIlXiAr1c601gIaKfzzY5mWsAeUQrETBrWGF1o1EKV30pMfBdD84xK8M/al47v38SjNbI+1fEgWr8f0g05fVF1Oq0djt2FyIhhAFivMy1GdtGammx0vpNTeSd8Z+9IRG88D5mspjZtL8OG//4PeDyY093ruHw7E/SglwNg+QBwmL+EQzGJP1/8QOdnKl8/1drP1dued/k6ayJSD/uEATn4U0DVDvavvE8WhrRa1v4PdtstNBAPIYo2bS/BCyybVr+v9JAY9J5UdN543IpFaipGZ6uFe5fKanR5qfwd+iuoSDsGSoHFzieqldq1wMDKPx+wKfbtIpAdotJcY3k8o/DdQ24IjTOvvYKftchPBHlCSmO1mGzmpnP5OmkgPUO2YbEm5f1NckBs1NJ6dX1wx1A7fctrfIRG6ekBDQ0M4cOAAbt68iV/+8pd49NFHo75+6tQpHD16FABw2223Yd++fdi4caP1rc0gZmfxGi1QOvmdNJEeoNpjm39cErXfUPj+BzbeuWLmdE6WhMK8HEx/xxnQZmkGUCgUQk9PD1566SW43W60tLSgpqYG69evjxxTWlqKV155Bfn5+RgcHMTvf/97vP7660lteCYwEw5OH1YZkchSjXiP3VJ+x4r7lYbGCzdl5K1yYbj759Y/OYfQDCC/34/y8nKUlZUBALxeLwYGBqIC6J577on8v7q6GpOTk0loqjPY7fO8ki2RHqDaY5Xuf0JlR0unXHFMFs0ACgaD8Hg8kdtutxt+v1/1+BMnTmD79u3WtM6hnDysSlecu5McmkVopT3rJZVC3blz53DixAl0dHQk3jKiNMK5O8mh2QPyeDxRQ6pgMIiioqIVx42NjeGpp57C0aNHUVhYaG0riQTj0Dg5NAOoqqoK4+PjmJiYgNvths/nw+HDh6OOuXr1Knbv3o1Dhw5h7dq1SWsskUgcGltPM4BcLhe6u7vR3t6OUCiE5uZmVFRUoLe3FwDQ1taGF198EdPT09i/fz8AIDs7G319fcltORFlPH4wIRElzOz5zJnQRCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJIyuABoaGkJdXR1qa2tx5MiRFV+XZRnPPvssamtrUV9fj5GREcsbSkT2oxlAoVAIPT09OHbsGHw+H86cOYNLly5FHTM0NITx8XG89dZbeOaZZ7Bv375ktZeIbEQzgPx+P8rLy1FWVoZVq1bB6/ViYGAg6piBgQE0NjZCkiRUV1djZmYGX3zxRdIaTUT24NI6IBgMwuPxRG673W74/f64x3g8HgSDQRQVFal+3/n5eYyOjpppMxGlmfn5eVOP0wwgWZZX3CdJkuFjYlVXV2v9aCKyOc0hmMfjweTkZOS2Us8m9pjJycm4vR8iIkBHAFVVVWF8fBwTExO4ceMGfD4fampqoo6pqalBf38/ZFnGhQsXcPvttzOAiEiT5hDM5XKhu7sb7e3tCIVCaG5uRkVFBXp7ewEAbW1tuP/++zE4OIja2lrk5ubiueeeS3rDiSjzSbJSAYeIKAU4E5qIhGEAEZEwSQ8gOyzj0HoOp06dQn19Perr69Ha2oqxsTEBrVSn1f4wv9+PyspKvPHGGylsnT56nsMHH3yAhx9+GF6vF4888kiKW6hN6zl88803eOyxx9DQ0ACv14uTJ08KaKW6rq4u3HvvvXjooYcUv27qXJaTaHFxUd6xY4d85coVeX5+Xq6vr5f/9a9/RR3zj3/8Q/7tb38r37x5Ux4eHpZbWlqS2STD9DyHjz76SJ6enpZl+dbzSafnoKf94eN27dolt7e3y3//+98FtFSdnufw9ddfy7/4xS/kQCAgy7IsX7t2TURTVel5Dn/+85/lQ4cOybIsy1999ZW8detWeX5+XkRzFZ0/f16+ePGi7PV6Fb9u5lxOag/IDss49DyHe+65B/n5+QBuTbBcPidKND3tB4Djx4+jrq4Oa9asEdDK+PQ8h9OnT6O2thbFxcUAkHbPQ89zkCQJs7OzkGUZs7OzyM/Ph8uleaE6ZbZu3Rp5nSsxcy4nNYCUlnEEg8G4x4SXcaQLPc9huRMnTmD79u2paJouev8Gb7/9NlpbW1PdPF30PIfx8XHMzMxg165daGpqQn9/f6qbGZee57Bz50589tln2LZtGxoaGvDkk08iKytzyrRmzuWkxqucpGUcqWSkfefOncOJEyfw17/+NdnN0k1P+w8cOICOjg5kZ2enqlmG6HkOoVAIIyMjePnll3H9+nW0trZi06ZNWLt2baqaGZee5/Duu++isrISf/nLX3DlyhX85je/wZYtW/C9730vVc1MiJlzOakBZIdlHHqeAwCMjY3hqaeewtGjR1FYWJjKJsalp/0XL17Enj17AABTU1MYHByEy+XCgw8+mNK2qtH7OiosLEReXh7y8vKwZcsWjI2NpU0A6XkOfX19ePTRRyFJEsrLy1FaWorPP/8cP/rRj1LdXFPMnMtJ7d/ZYRmHnudw9epV7N69G4cOHUqbF3yYnvafPXs28q+urg5PP/102oQPoO857NixAx9++CEWFxcxNzcHv9+PdevWCWrxSnqew1133YX3338fAHDt2jVcvnwZpaWlIppriplzOak9IDss49DzHF588UVMT09j//79AIDs7Gz09fWJbHaEnvanOz3PYd26dZHaSVZWFlpaWnD33XcLbvkSPc/h8ccfR1dXF+rr6yHLMjo6OnDHHXcIbvmSPXv24Pz585iamsL27duxe/duLC4uAjB/LnMpBhEJkzkldiKyHQYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEub/AYsUuF7whssFAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = model.sample(100).detach().cpu().numpy()\n",
    "plt.figure(figsize=(4, 4))\n",
    "plt.scatter(x[:,0], x[:,1])\n",
    "plt.tight_layout()\n",
    "plt.xlim([0, 1])\n",
    "plt.ylim([0, 1])\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
