{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from confet import ConfetStochastic, ConfetFixed, ConfetExact\n",
    "from data import list_datasets, collate, collate_sort, load_dataset\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wasserstein_distance\n",
    "from sklearn.neighbors import KDTree\n",
    "import torch\n",
    "torch.set_default_tensor_type('torch.cuda.FloatTensor') # Comment this out if you want to train without GPUs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'mixture'      # Dataset name, options: mixture, thomas, checkins_ny, checkins_paris, crimes, matern\n",
    "model = ConfetExact      # Model class, options: ConfetStochastic, ConfetFixed, ConfetExact\n",
    "batch_size = 64          # Size of the mini-batch\n",
    "learning_rate = 1e-3     # Adam learning rate\n",
    "weight_decay = 1e-3      # Adam weight decay\n",
    "epochs = 50              # Maximum number of training epochs\n",
    "display_step = 5         # Print the losses every this many epochs\n",
    "patience = 10            # Epochs without sufficient improvement before early stopping\n",
    "seed = 0                 # Seed for the NumPy and torch random number generators\n",
    "\n",
    "# Model hyperparameters\n",
    "num_layers = 3           # Number of layers\n",
    "hidden_dim = 64          # Size of the hidden layers\n",
    "regularization_param = 0.0  # Unused\n",
    "solver = 'dopri5'        # Which solver to use, options: dopri5, rk4\n",
    "solver_step = None       # Solver step size, used when solver is rk4\n",
    "n_heads = 1              # Number of attention heads\n",
    "induced = False          # Whether to use induced attention\n",
    "n_points = None          # Number of inducing points when induced attention is used\n",
    "num_coupling_layers = 1  # Number of coupling layers\n",
    "interaction = 'sum'      # Aggregation function, options: sum, max, attention\n",
    "conditioner_dim = 8      # Size of the d_h dimension in the decoupled network\n",
    "\n",
    "# Seed the RNGs so shuffling and initialisation repeat across runs\n",
    "# (GPU kernels may still introduce small nondeterminism).\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch    5, loss_train = -1.0487, loss_val = -1.0257\n",
      "Epoch   10, loss_train = -1.1884, loss_val = -1.1550\n",
      "Epoch   15, loss_train = -1.3368, loss_val = -1.3326\n",
      "Epoch   20, loss_train = -1.8487, loss_val = -1.8596\n",
      "Epoch   25, loss_train = -1.9045, loss_val = -1.9290\n",
      "Epoch   30, loss_train = -1.9481, loss_val = -1.9623\n",
      "Epoch   35, loss_train = -1.9903, loss_val = -1.9773\n",
      "Epoch   40, loss_train = -1.9540, loss_val = -1.9915\n",
      "Epoch   45, loss_train = -2.0066, loss_val = -1.9977\n",
      "Epoch   50, loss_train = -1.9891, loss_val = -2.0025\n"
     ]
    }
   ],
   "source": [
    "## Load data\n",
    "dset = load_dataset(dataset)\n",
    "trainset, valset, testset = dset.split_train_val_test()\n",
    "\n",
    "collate = collate\n",
    "dl_train = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate)\n",
    "dl_val = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, collate_fn=collate)\n",
    "dl_test = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=collate)\n",
    "\n",
    "## Train model\n",
    "model = model(dim=trainset.dim,\n",
    "              hidden_dim=hidden_dim,\n",
    "              num_layers=num_layers,\n",
    "              regularization_param=regularization_param,\n",
    "              solver=solver,\n",
    "              solver_step=solver_step,\n",
    "              n_heads=n_heads,\n",
    "              induced=induced,\n",
    "              n_points=n_points,\n",
    "              num_coupling_layers=num_coupling_layers,\n",
    "              interaction=interaction,\n",
    "              conditioner_dim=conditioner_dim)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
    "\n",
    "impatient = 0\n",
    "best_loss = np.inf\n",
    "best_model = deepcopy(model.state_dict())\n",
    "training_val_losses = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Optimization\n",
    "    model.train()\n",
    "    for batch in dl_train:\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        x, m, _ = batch\n",
    "        loss = model(x, m)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    loss_val = 0\n",
    "    for i, batch in enumerate(dl_val):\n",
    "        x, m, _ = batch\n",
    "        loss_val += model(x, m).item() / len(dl_val)\n",
    "\n",
    "    training_val_losses.append(loss_val)\n",
    "\n",
    "    # Early stopping\n",
    "    if (best_loss - loss_val) < 1e-4:\n",
    "        impatient += 1\n",
    "        if loss_val < best_loss:\n",
    "            best_loss = loss_val\n",
    "            best_model = deepcopy(model.state_dict())\n",
    "    else:\n",
    "        best_loss = loss_val\n",
    "        best_model = deepcopy(model.state_dict())\n",
    "        impatient = 0\n",
    "\n",
    "    if impatient >= patience:\n",
    "        print(f'Breaking due to early stopping at epoch {epoch}')\n",
    "        break\n",
    "\n",
    "    if (epoch + 1) % display_step == 0:\n",
    "        print(f\"Epoch {epoch+1:4d}, loss_train = {loss:.4f}, loss_val = {loss_val:.4f}\")\n",
    "\n",
    "\n",
    "## Test model\n",
    "model.load_state_dict(best_model)\n",
    "model.eval()\n",
    "\n",
    "test_loss = 0\n",
    "for i, batch in enumerate(dl_test):\n",
    "    x, m, _ = batch\n",
    "    loss = model(x, m)\n",
    "    test_loss += loss.item() / len(dl_test)\n",
    "\n",
    "\n",
    "## Sampling quality -- Wasserstein score\n",
    "dist_test, dist_samples = [], []\n",
    "for x in testset:\n",
    "    if len(x[0]) > 2:\n",
    "        dist_test.append(KDTree(x[0]).query(x[0], k=2)[0][:,1])\n",
    "        samples = model.sample(len(x[0])).detach().cpu().numpy()\n",
    "        dist_samples.append(KDTree(samples).query(samples, k=2)[0][:,1])\n",
    "\n",
    "dist_test = np.concatenate(dist_test, 0)\n",
    "dist_samples = np.concatenate(dist_samples, 0)\n",
    "\n",
    "wasserstein = float(wasserstein_distance(dist_test, dist_samples))\n",
    "\n",
    "\n",
    "## Save results\n",
    "results = { 'test_loss': test_loss, 'training_val_losses': training_val_losses,\n",
    "            'final_epoch': epoch, 'wasserstein': wasserstein }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss -2.0289\n"
     ]
    }
   ],
   "source": [
    "print(f'Test loss {test_loss:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAAEcCAYAAABnO2lWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAcdUlEQVR4nO3dfWxUZaIG8GfaKe7Uj1II7UDb7ZKK0DWVopBdL4JrcZzoONjQJtsG2cTYu3FNSO6ykEDECojK4m68+4cxKSR+oDa7Fm4DTFaJrdtGInAhxYHe1l2Uhu6QjuK24pbSj+HcP3TGmfZ8zZmP98yZ55eYWHo6856ZOc+838cmSZIEIiIBckQXgIiyFwOIiIRhABGRMAwgIhKGAUREwjCAiEgYzQDatm0b7r33Xjz66KOyv5ckCbt374bL5YLX60Vvb2/SC0lE1qQZQOvWrcP+/fsVf9/d3Y2BgQEcO3YMzz//PHbs2JHM8hGRhWkG0IoVK1BQUKD4+46ODtTW1sJms6G6uhpXr17Fl19+mdRCEpE12RN9gGAwCKfTGfnZ6XQiGAyiqKhI9e/Onj2Lm266KdGnJyITGB8fR3V1ddx/l3AAya3ksNlsmn930003obKyMtGnJyIT6OvrM/R3CY+COZ1ODA0NRX4eGhrSrP0QEQFJCKCamhq0t7dDkiScPXsWt956KwOIiHTRbIJt2rQJp06dwvDwMFavXo2NGzdiamoKANDY2Ij7778fXV1dcLlccDgcePHFF1NeaCKyBpuo7Tj6+vrYB0RkEUavZ86EJiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImF0BVB3dzfcbjdcLhdaWlpm/P7bb7/FU089hbVr18Lj8eDgwYNJLygRWY9mAIVCIezatQv79++Hz+fD0aNHceHChZhj3nnnHVRUVODw4cM4cOAAfv/732NiYiJlhSYia9AMIL/fj/LycpSVlWHWrFnweDzo6OiIOcZms2F0dBSSJGF0dBQFBQWw2+0pKzQRWYNmSgSDQTidzsjPxcXF8Pv9McesX78ev/nNb7Bq1SqMjo7ilVdeQU4Ou5esor0ngJc/+AyXR8awYLYDW9yLUbusRHSxyAI0A0iSpBn/ZrPZYn7++OOPUVlZibfeeguXLl3CE088geXLl+OWW25JXklJiPaeALYdOoexyRAAIDAyhm2HzgEAQ4gSpllNcTqdGBoaivwcDAZRVFQUc8yhQ4fw0EMPwWazoby8HKWlpfjiiy+SX1pKu5c/+CwSPmFjkyG8/MFngkpEVqIZQFVVVRgYGMDg4CAmJibg8/lQU1MTc8z8+fPxySefAACuXLmCixcvorS0NDUlprS6PDIW178TxUOzCWa329Hc3IympiaEQiHU1dVh0aJFaG1tBQA0Njbi6aefxrZt2+D1eiFJEjZv3ow5c+akvPCUegtmOxCQCZsFsx1xPQ77kUiOTZLr5EmDvr4+VFZWinhqisP0PiAAcOTl4qV1VboDJBmPQeZm9HrmUBWpql1Wgrp7SpD7/cBDrs2GuntK4goO9iOREk7WIdXmUXtPAAfPBBD6vqIckiQcPBPA8vI5ukOI/UikhDWgLBduHgVGxiDhh2H29p4AAO3aS3tPACv3dGLhVh9W7umM/F00pf6iePuRyHoYQFlOK2DUai9a4RW2xb0YjrzcmH9z5OVii3tx8k6EMhIDKMtpNY/Uai96+3Zql5XgpXVVKJntgA1AyWwHO6AJAPuA
sp7WMPsW92LZEawt7sX47Z/Pyj5mOLw49E5aWAPKclrNI7Xai1rtSG/zjLIbAyjLJdI8UgsvDr2THmyCEWqXKc/r0bMYVa6ZpdU8U8JmW3ZhAJEqtZpMOLjkAkLPEo7pYfPAknk4eCbAlfdZhE2wLKBnro4So5MItfqW5PqI3jlxic22LMMakMUlup+P0cWo4cfecbgXI2OTAIAf5f3wfSdXs1JalMhRNetiAFmcVhNKi9owvJzokHDk5eDa5I3I74avTUbCL55lGNGjamyeWQubYBaX6DqseEbJpjerosMnLBx+epdhaI2q7Tjcq+txyJwYQBaXjHVYtctKcHxrDS7u8eD41hrFGodcSMi5PDIm20ckJ7zyXikwR8Ymsb39nObjkDkxgCxOzzqsRDqpo+mtVS2Y7YipWan5qP+ryN8oeefEJU5wzFAMIIvTakIlOmM5Orym3atAlg2ImWV9fGuNagiFQ01t4aoEcKQsQ7ETOguoTTRU61vR6tyd3jGsZ2/N/6iYExN+L3/wmewoW1i45lO7rAQ7j/Ri+Nqk7HHcWygzsQaU5dT6Vn6i0STT2+cTbeDrH4bUwzUvJTYADyyZF/n5Oe+dUKpkcW+hzMQAynJaF25gZAxb2j5F9c5jM/qIjNQ6wn+jJ7wkAAfPBCLPV7usBOt//uMZIcS9hTIXAyjL6blwJ0MSRsYmZ/QRGal1FDjyAOgPr+kzoXfXVuGVX1ZzbyGLYB9QltPqW5ETDgW5SYpaRiemIuGl1vyKNj2s1Pq0KLOwBkR4znunrjk50S6PjM0YYSvMz8NsR16kZnLzrJmPORmSIuGl9znZv2NdrAFRzLYagZEx2KC8LissenRKqTaycKtP9t/D4RV+zssjY8ix2SJ33ogWPWxP1sMAIgCxQaJneFxPKGgtZI1+TqWwksC1XlbGJhjNoDVBsDA/L+EdE6dTamZpzZSmzMYAIkVKAfKc904A6ks4wrWosclQ5K6qaiNWesMqWctGyBzYBCNFaluuqm2PASDmdyFJioSJUs1J7bnCuCWH9dgkSc8E+uQzejN7MoeVezpl+3fCTSal3x3fWpOS50zkcSlxRq9n1oDIECP7DCW6Xov3mLceBhAZojXCZWQb1+mmb8Fa4MiLbO+ayOOSebATmgxR6zROxr3g5bYJGZ2YQl5O7EowrgPLbKwBkSF6Oo0T2UBebrHqZEhCYX4e8mfZuTG9RTCAyDC1WdCJrtdS3Cbk2iR6mh8y/LhkLmyCkSklYy9rMj8GEJlSMvqRyPzYBCNT0tPHRJmPAZRlMunuolr9SJl0LiSPAZRFrLSUwUrnks3YB5RF1G7TnGmsdC7ZjAGURay0lMFK55LNdAVQd3c33G43XC4XWlpaZI85efIkHnvsMXg8Hjz++ONJLSQlh5WGtq10LtlMM4BCoRB27dqF/fv3w+fz4ejRo7hw4ULMMVevXsXOnTvx2muvwefz4U9/+lPKCkzGWWlo20rnks00A8jv96O8vBxlZWWYNWsWPB4POjo6Yo45cuQIXC4XFixYAACYO3duakpLCdG6TXMmsdK5ZDPNUbBgMAin0xn5ubi4GH6/P+aYgYEBTE1NYcOGDRgdHcWvfvUr1NbWJr+0lDAr3dLGSueSrTQDSG6/MpstdkVyKBRCb28v3njjDVy/fh0NDQ1YunQpFi5cmLySEpHlaAaQ0+nE0NBQ5OdgMIiioqIZxxQWFiI/Px/5+flYvnw5+vv7GUAmxkl8ZAaafUBVVVUYGBjA4OAgJiYm4PP5UFMTu/3lmjVrcPr0aUxNTWFsbAx+vx8VFRUpKzQlRm6vnfDtlonSSbMGZLfb0dzcjKamJoRCIdTV1WHRokVobW0FADQ2NqKiogKrVq3C2rVrkZOTg/r6etxxxx0pLzwZozaJj7UgSiduSp+FFm71yd751Abg4h5PuotDFmD0euZM6CzESXxkFgygLMRJfGQWXA2fhbjXDpkFAyhLcRIfmQGbYEQkDAOI
iIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGG7HQZbFO3+YHwOILCl854/w5vvhO38AYAiZCJtgZElqd/4g82AAkSVdHhmL699JDAYQWRLv/JEZGEBkSbzzR2ZgJzRZEu/8kRkYQGRZvPOH+bEJRkTCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwnAmdApwIywifRhAScaNsIj0YxMsybgRFpF+DKAk40ZYRPoxgJKMG2ER6ccASjJuhEWkn64A6u7uhtvthsvlQktLi+Jxfr8flZWVeP/995NWwExTu6wEL62rQslsB2wASmY78NK6KnZAE8nQHAULhULYtWsXXn/9dRQXF6O+vh41NTW4/fbbZxz3hz/8Affdd1/KCpsp0r0RFof9KVNpBpDf70d5eTnKysoAAB6PBx0dHTMC6MCBA3C73Th37lxqSmpiIgOAw/7mwC8BYzQDKBgMwul0Rn4uLi6G3++fccyHH36IN99809IBJPchAyA0ANSG/XkBpAe/BIzT7AOSJGnGv9lstpifX3jhBWzevBm5ubkzjrWK8IcsMDIGCT98yHYc7hU674fD/uJx7pdxmjUgp9OJoaGhyM/BYBBFRUUxx5w/fx6bNm0CAAwPD6Orqwt2ux0PPvhgkosrjtKHbPq/haUrABbMdiAg81wc9k8ffgkYpxlAVVVVGBgYwODgIIqLi+Hz+fDHP/4x5pjOzs7I/2/duhW/+MUvLBU+QPwfpnQFwBb34pjqP5DZw/6Z2JfCLwHjNJtgdrsdzc3NaGpqwiOPPIKHH34YixYtQmtrK1pbW9NRRlNQ+jAV5ucJnfdjpWF/pWZue09AdNFUce6XcTZJrpMnDfr6+lBZWSniqQ2Z3tEIfPche2ldFYD03wAvE2sKWlbu6ZStSZTMduD41hoBJdIvVe9HprzPRq9nrobXSetOm8n6UOj5wFl11CWT+1JSMffLqu9zNAZQHFI9wVDvB86qQ+9m7EvZ3n4OrScHEZIk5NpsaPxZGXbXVqXlua36PkfjWjAT0Tucm8k1BTVm60vZ3n4Ob5+4hND3vRQhScLbJy5h/b5P0vL8Vn2fozGATKK9JyD77Q/M/MBZdcV9ujvU23sCWLmnEwu3+rByT+eMzu7Wk4Oyf3f883+lpWPcqu9zNDbBTCDc9FIy/QMncug91Z2i6VpHp6e5G1IZn0lHM8hqUyzksAZkAnJNrzC5D5yoofdMHSaXo6e5mzttxn+0dDSDrDTFQglrQCag9mFW+sCle8U9YK1OUT39K40/K8PbJy7JHpeuZpCI9zmdGEAmoDT6o/YNLEKmdIrqaSbqGXHbXVuFi1/9G8c//1fMMY68XDywZB5W7uk0/fwcs2MTzATkRn+A7/ogzNTEyYROUb3NRL0jbu/85734719WxzSD6u4pwcEzAUs0RUVjAJlAuK0vV+Mx06pqsw2Ty9E7lSGe/pXaZSU4vrUGF/d4cHxrDT7q/4qr35OETbAkMzpKVLusBL/981nZ35mliaM1G9wM4mkmGu1fyZSmaCZgACVRolPnlfolChx5yS1oAszeKaqnbyfRqQRmnLGdqXJ37NixQ8QTX7lyBfPmzRPx1Alr7wngyTdPY/fR/8N7p/+JuTfPwpL5t+HJN0/jX9cmYo6duiHhXOAbPHnfQs3HnXvzLHT0BXFj2vSTkCThx3PysWT+bQmVLxvMvXkWuv7+FaaiXkRHXi6avT/Fkvm3Rb4kwu/Tt9en0PX3rzB0dQzPtvfqes2UnuORu5y6H8NqjF7PDKA4KX2ASwsdaDvzT9m/+ff1KfzXg3doPvaS+bfh9eMXcX3yRsy/35CgO8TUypcNF8OS+behtNCBc4Fv8O/rUyiZ7UCz96eRGo7il8Q/v8HV61MAtF8zued45C4nDp4JZO3rbvR6ZhMsTmqdnMmomo9cm5T9d739C1aaq2OU
WjNR6XWcPudZ6zWb/hwr93Rm/etuBEfB4qTWAak1SqS19ghIbKg7nvVk2SqeL4N4XjN2TBvDAIqTWkCoDe0me37KdO09AWxp+zTucmcbuddXabpnPK9ZJsyRMiM2weKktUBQqfqvt2kU/v+dR3ox/H1z7Ca79vfEyx98hsmQ/OJJ2/flJvmpBA8smYeDZwIx748NwE/mOnTPds6GhaOpwACKk9G5MEpVcaUmU3RH9MjYpOZwvlpVX/q+vGp/n06itxlV+pJ458SlSF+QBMQswdCaUpEJc6TMiHtCJyCeC0lpv2MbgFd+WT2jQzPevZGV/iZaeA/r6OdKdxio7a0t8mLV8/oBmbE/tQhGr2f2ARkU79YUW9yLZfsaJAC/+8unMX9npENzi3sx8nLVF69OXy4gYnsNs97ET29nMTuVk4sBNI2ekSog/gupdlnJjKHesOmLTo10aNYuK8HL9UtRmK8+azr6AhIRBmYdLdLbWcxO5eRiAEWJp0Zg5EIqUfnwRl/4RkfCapeVoKf5IQzs8Sg+V/QFJCIMzDpapLQjwXThbTi0vqBIHwZQlHhqBEYuJK0PefjCDw/nz45aA/ajvPjeKj0hJiIMUrWiXq7mqrc2G+4HG5sMae7B9Of/HTTcZNVbnmzCTugoC7f6ZJtJNgAX93hi/s1oZ2p7TwC/+8unsvsNhzs423sCeOZ/zmF0IjYMbfiuz6hEZ2exVgezqA7hZHd8y51HXq4NkIDJaQvrCvPz8Jz3TgDffeEERsYir6tRejqmzdr5nixGr2cGUJR4R5+MXkhad1nd0vap4pye6ccn+uEVPSSeDHpHsMKUwskouS+osPDrq1Q+q4yq8c6oSRDvZDKjW1OozRlZuadTM3yA5K0zMvv2GnrE22el5/WNh1KTVe6LZjrRne+iMYCipHMymdKFn4z1R9lGaRFwOqh9Qand7SRMdOe7aAygaZJVIzDatInnYsr2D2+YXM012c2syOPm2HDLj+wYuTap+b5qfUFwqQYDKCUS2Rlxi3ux7j6gbP/whinVXAFgx+FejIzJb3GiV3hcLN4asdqXid6BBKtjJ3QKGFlKEa29JxCzGHW2Iw+PLp2Pj/q/yujOYlGqdx5LKITUOpnVWH3kKxo7oU0k0Ql+Ss3A6GadmRaXmt2OtXfOCIJ4ht6NNnXD7010LSze+VxWx1cjBVIxwc9Kt0VON7l9mtb//MczJkTm5dhmrKdLRlN3fOqHnQ2Gr03yfYvCGlAKpGJvGG61mhi5WuXy8jmy/UbJHAXl+6aOAZQCqRjON+sizkym1NRNZjDwfVPHAEqRZE/w472oMhPfN3XsA8oQDyyZN2M/IQ7Fx0fEYtBMuJ21SKwBZYD2ngAOngnEjNrYANTdk/nLKNIl0bvWGsWtWtUxgDKAXEemBOCj/q/EFCgDiewMtsJ6u1RhEywDsCMzcXwNzYkBlAHMuotgJuFraE66Aqi7uxtutxsulwstLS0zfn/48GF4vV54vV40NDSgv78/6QXNZuzITBxfQ3PS7AMKhULYtWsXXn/9dRQXF6O+vh41NTW4/fbbI8eUlpbi7bffRkFBAbq6uvDss8/ivffeS2nBswk7MhPH19CcNAPI7/ejvLwcZWVlAACPx4OOjo6YALr77rsj/19dXY2hoaEUFDW7sSMzcXwNzUezCRYMBuF0OiM/FxcXIxgMKh7f1taG1atXJ6d0RGRpmjUgud06bAp3Djhx4gTa2trw7rvvJl4yIrI8zQByOp0xTapgMIiioqIZx/X392P79u3Yt28fCgsLk1tKIrIkzSZYVVUVBgYGMDg4iImJCfh8PtTUxG6qdfnyZWzcuBF79+7FwoULU1ZYIrIWzRqQ3W5Hc3MzmpqaEAqFUFdXh0WLFqG1tRUA0NjYiFdffRUjIyPYuXMnACA3NxeHDh1KbcmJKONxS1YiSpjR65kzoYlIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICI
SBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYXQHU3d0Nt9sNl8uFlpaWGb+XJAm7d++Gy+WC1+tFb29v0gtKRNajGUChUAi7du3C/v374fP5cPToUVy4cCHmmO7ubgwMDODYsWN4/vnnsWPHjlSVl4gsRDOA/H4/ysvLUVZWhlmzZsHj8aCjoyPmmI6ODtTW1sJms6G6uhpXr17Fl19+mbJCE5E12LUOCAaDcDqdkZ+Li4vh9/tVj3E6nQgGgygqKlJ83PHxcfT19RkpMxGZzPj4uKG/0wwgSZJm/JvNZov7mOmqq6u1npqILE6zCeZ0OjE0NBT5Wa5mM/2YoaEh1doPERGgI4CqqqowMDCAwcFBTExMwOfzoaamJuaYmpoatLe3Q5IknD17FrfeeisDiIg0aTbB7HY7mpub0dTUhFAohLq6OixatAitra0AgMbGRtx///3o6uqCy+WCw+HAiy++mPKCE1Hms0lyHThERGnAmdBEJAwDiIiESXkAWWEZh9Y5HD58GF6vF16vFw0NDejv7xdQSmVa5Q/z+/2orKzE+++/n8bS6aPnHE6ePInHHnsMHo8Hjz/+eJpLqE3rHL799ls89dRTWLt2LTweDw4ePCiglMq2bduGe++9F48++qjs7w1dy1IKTU1NSWvWrJEuXbokjY+PS16vV/rHP/4Rc8zf/vY36cknn5Ru3Lgh9fT0SPX19aksUtz0nMOZM2ekkZERSZK+Ox8znYOe8oeP27Bhg9TU1CT99a9/FVBSZXrO4ZtvvpEefvhhKRAISJIkSVeuXBFRVEV6zuG1116T9u7dK0mSJH399dfSihUrpPHxcRHFlXXq1Cnp/Pnzksfjkf29kWs5pTUgKyzj0HMOd999NwoKCgB8N8Eyek6UaHrKDwAHDhyA2+3G3LlzBZRSnZ5zOHLkCFwuFxYsWAAApjsPPedgs9kwOjoKSZIwOjqKgoIC2O2aA9Vps2LFisjnXI6RazmlASS3jCMYDKoeE17GYRZ6ziFaW1sbVq9enY6i6aL3Pfjwww/R0NCQ7uLpouccBgYGcPXqVWzYsAHr1q1De3t7uoupSs85rF+/Hp9//jlWrVqFtWvX4plnnkFOTuZ00xq5llMar1KKlnGkUzzlO3HiBNra2vDuu++muli66Sn/Cy+8gM2bNyM3NzddxYqLnnMIhULo7e3FG2+8gevXr6OhoQFLly7FwoUL01VMVXrO4eOPP0ZlZSXeeustXLp0CU888QSWL1+OW265JV3FTIiRazmlAWSFZRx6zgEA+vv7sX37duzbtw+FhYXpLKIqPeU/f/48Nm3aBAAYHh5GV1cX7HY7HnzwwbSWVYnez1FhYSHy8/ORn5+P5cuXo7+/3zQBpOccDh06hF//+tew2WwoLy9HaWkpvvjiC9x1113pLq4hRq7llNbvrLCMQ885XL58GRs3bsTevXtN84EP01P+zs7OyH9utxvPPfecacIH0HcOa9aswenTpzE1NYWxsTH4/X5UVFQIKvFMes5h/vz5+OSTTwAAV65cwcWLF1FaWiqiuIYYuZZTWgOywjIOPefw6quvYmRkBDt37gQA5Obm4tChQyKLHaGn/Gan5xwqKioifSc5OTmor6/HHXfcIbjkP9BzDk8//TS2bdsGr9cLSZKwefNmzJkzR3DJf7Bp0yacOnUKw8PDWL16NTZu3IipqSkAxq9lLsUgImEyp4udiCyHAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiE+X8PqJ9yuUGtAwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = model.sample(100).detach().cpu().numpy()\n",
    "plt.figure(figsize=(4, 4))\n",
    "plt.scatter(x[:,0], x[:,1])\n",
    "plt.tight_layout()\n",
    "plt.xlim([0, 1])\n",
    "plt.ylim([0, 1])\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
