{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from autoregressive import Autoregressive\n",
    "from data import list_datasets, collate, collate_sort, load_dataset\n",
    "from copy import deepcopy\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import wasserstein_distance\n",
    "from sklearn.neighbors import KDTree\n",
    "import torch\n",
    "# Put new float tensors on the GPU by default when one is available; otherwise keep the CPU default\n",
    "if torch.cuda.is_available():\n",
    "    torch.set_default_tensor_type('torch.cuda.FloatTensor')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Configuration\n",
    "dataset = 'mixture'\n",
    "seed = 0\n",
    "\n",
    "# Seed numpy and torch so weight init, batch order and sampling are reproducible across reruns\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "# Optimization\n",
    "batch_size = 64\n",
    "learning_rate = 1e-3\n",
    "weight_decay = 0\n",
    "epochs = 50\n",
    "display_step = 5  # print losses every `display_step` epochs\n",
    "patience = 10     # epochs without val-loss improvement before early stopping\n",
    "\n",
    "# Model params\n",
    "hidden_dim = 64\n",
    "num_autoreg_layers = 2\n",
    "num_coupling_layers = 2\n",
    "n_bins = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch    5, loss_train = -0.5602, loss_val = -0.5453\n",
      "Epoch   10, loss_train = -0.9456, loss_val = -1.0025\n",
      "Epoch   15, loss_train = -1.1682, loss_val = -1.1885\n",
      "Epoch   20, loss_train = -1.2599, loss_val = -1.3411\n",
      "Epoch   25, loss_train = -1.3967, loss_val = -1.3857\n",
      "Epoch   30, loss_train = -1.4352, loss_val = -1.4551\n",
      "Epoch   35, loss_train = -1.4904, loss_val = -1.4812\n",
      "Epoch   40, loss_train = -1.4969, loss_val = -1.4866\n",
      "Epoch   45, loss_train = -1.5431, loss_val = -1.4576\n",
      "Epoch   50, loss_train = -1.5603, loss_val = -1.5178\n"
     ]
    }
   ],
   "source": [
    "## Load data\n",
    "dset = load_dataset(dataset)\n",
    "trainset, valset, testset = dset.split_train_val_test()\n",
    "\n",
    "# Batch all splits with collate_sort; use a new name so the imported `collate` is not shadowed\n",
    "collate_fn = collate_sort\n",
    "dl_train = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)\n",
    "dl_val = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)\n",
    "dl_test = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "## Train model\n",
    "model = Autoregressive(dim=trainset.dim,\n",
    "                       hidden_dim=hidden_dim,\n",
    "                       num_autoreg_layers=num_autoreg_layers,\n",
    "                       num_coupling_layers=num_coupling_layers,\n",
    "                       n_bins=n_bins)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
    "\n",
    "min_delta = 1e-4  # smallest decrease in val loss that counts as an improvement\n",
    "impatient = 0\n",
    "best_loss = np.inf\n",
    "best_model = deepcopy(model.state_dict())\n",
    "training_val_losses = []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Optimization; model(x, m) returns the scalar loss of a batch\n",
    "    model.train()\n",
    "    loss_train = 0\n",
    "    for x, m, _ in dl_train:\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(x, m)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Epoch average (not just the last batch) so it is comparable to loss_val\n",
    "        loss_train += loss.item() / len(dl_train)\n",
    "\n",
    "    # Validation: no gradients needed, so don't build the autograd graph\n",
    "    model.eval()\n",
    "    loss_val = 0\n",
    "    with torch.no_grad():\n",
    "        for x, m, _ in dl_val:\n",
    "            loss_val += model(x, m).item() / len(dl_val)\n",
    "\n",
    "    training_val_losses.append(loss_val)\n",
    "\n",
    "    # Early stopping. Written as `>=` so that a NaN val loss (every comparison is False)\n",
    "    # counts as no improvement instead of becoming the new best loss.\n",
    "    if (best_loss - loss_val) >= min_delta:\n",
    "        best_loss = loss_val\n",
    "        best_model = deepcopy(model.state_dict())\n",
    "        impatient = 0\n",
    "    else:\n",
    "        impatient += 1\n",
    "        # Still keep the best checkpoint even if the gain was below min_delta\n",
    "        if loss_val < best_loss:\n",
    "            best_loss = loss_val\n",
    "            best_model = deepcopy(model.state_dict())\n",
    "\n",
    "    if impatient >= patience:\n",
    "        print(f'Breaking due to early stopping at epoch {epoch}')\n",
    "        break\n",
    "\n",
    "    if (epoch + 1) % display_step == 0:\n",
    "        print(f\"Epoch {epoch+1:4d}, loss_train = {loss_train:.4f}, loss_val = {loss_val:.4f}\")\n",
    "\n",
    "\n",
    "## Test model\n",
    "model.load_state_dict(best_model)\n",
    "model.eval()\n",
    "\n",
    "test_loss = 0\n",
    "with torch.no_grad():\n",
    "    for x, m, _ in dl_test:\n",
    "        test_loss += model(x, m).item() / len(dl_test)\n",
    "\n",
    "\n",
    "## Sampling quality -- Wasserstein score\n",
    "# Compare nearest-neighbour distances within each test set against those within\n",
    "# an equally sized set of model samples.\n",
    "dist_test, dist_samples = [], []\n",
    "with torch.no_grad():\n",
    "    for item in testset:\n",
    "        points = item[0]\n",
    "        if len(points) > 2:\n",
    "            # k=2: the first neighbour of each point is itself, column 1 is its nearest other point\n",
    "            dist_test.append(KDTree(points).query(points, k=2)[0][:,1])\n",
    "            samples = model.sample(len(points)).detach().cpu().numpy()\n",
    "            dist_samples.append(KDTree(samples).query(samples, k=2)[0][:,1])\n",
    "\n",
    "if dist_test:\n",
    "    dist_test = np.concatenate(dist_test, 0)\n",
    "    dist_samples = np.concatenate(dist_samples, 0)\n",
    "    wasserstein = float(wasserstein_distance(dist_test, dist_samples))\n",
    "else:\n",
    "    # No test set has more than 2 points, so the score is undefined (np.concatenate would raise)\n",
    "    wasserstein = float('nan')\n",
    "\n",
    "\n",
    "## Save results\n",
    "results = { 'test_loss': test_loss, 'training_val_losses': training_val_losses,\n",
    "            'final_epoch': epoch, 'wasserstein': wasserstein }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test loss -1.5305\n"
     ]
    }
   ],
   "source": [
    "# Mean per-batch test loss of the best (lowest val-loss) checkpoint\n",
    "test_loss_msg = f'Test loss {test_loss:.4f}'\n",
    "print(test_loss_msg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAAEcCAYAAABnO2lWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAcWklEQVR4nO3dfWxb5b0H8K8TJ+AObtJWS0xeFFUhdNFV1hRaXVWsZUsx1mSS5jbRXSLopGnRhJAqjd5WagQLbXjrWtC9/6BJJRJoBSIubRSVRqOI9C4ZiLYqSmcaJdUKzZqlirllCYWQ5sWc+0drN3bOm4+P/TzH/n4kJGKf2o/tc77+nefl2KUoigIiIgFyRDeAiLIXA4iIhGEAEZEwDCAiEoYBRETCMICISBjDAGpvb8emTZvw6KOPqt6vKAqef/55+Hw+1NfXY3h42PZGElFmMgyg7du3o6urS/P+wcFBjI2N4YMPPsBzzz2Hffv22dk+IspghgG0ceNGFBQUaN7f39+PxsZGuFwu1NbW4vr16/jyyy9tbSQRZSZ3sg8QCoXg9Xqjf3u9XoRCIRQVFen+u/Pnz+OOO+5I9umJSAJzc3Oora1N+N8lHUBqKzlcLpfhv7vjjjtQXV2d7NMTkQRGRkYs/bukR8G8Xi8mJyejf09OThpWP0REgA0BVFdXh97eXiiKgvPnz+Puu+9mABGRKYanYLt27cLZs2cxNTWFLVu2YOfOnVhcXAQAtLa24qGHHsLAwAB8Ph88Hg9efPHFlDeaiDKDS9TlOEZGRtgHRJQhrB7PnAlNRMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCZP074IRJat3aAKHTl7E1elZlBR6sMe/Fo3rS0U3i9KAAURC9Q5NoL3nM8wuhAEAE9OzaO/5DAAYQlmAp2Ak1KGTF6PhEzG7EMahkxcFtYjSiQFEQl2dnk3odsosDCASqqTQk9DtlFkYQCTUHv9aePJyY27z5OVij3+toBZROrETmoSKdDRzFCw7MYBIuMb1pQycLMVTMCIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkjKkAGhwchN/vh8/nw+HDh5fd/8033+CJJ55AQ0MDAoEAjh07ZntDiSjzGAZQOBxGZ2cnurq60NfXhxMnTuDSpUsx27z11luorKzE8ePHceTIEfz+97/H/Px8yhpNRJnBMICCwSAqKipQXl6O/Px8BAIB9Pf3x2zjcrkwMzMDRVEwMzODgoICuN281hkR6TMMoFAoBK/XG/27uLgYoVAoZpvHHnsMn3/+OTZv3oyGhgY8/fTTyMlh9xIR6TNMCUVRlt3mcrli/v7oo49QXV2Nv/zlL+jt7UVnZye+/fZb+1pJRBnJMIC8Xi8mJyejf4dCIRQVFcVs09PTg0ceeQQulwsVFRUoKyvDF198YX9riSijGAZQTU0NxsbGMD4+jvn5efT19aGuri5mm3vuuQeffPIJAODatWu4fPkyysrKUtNiIsoYhj3FbrcbHR0daGtrQzgcRlNTE6qqqtDd3Q0AaG1txZNPPon29nbU19dDURTs3r0bq1atSnnjicjZXIpaJ08ajIyMoLq6WsRTE5HNrB7PHKoiImE4WScD9A5N8JdFyZEYQA7XOzSB9p7PMLsQBgBMTM+iveczAGAIkfR4CuZwh05ejIZPxOxCGIdOXhTUIiLzGEAOd3V6NqHbiWTCAHK4kkJPQrcTyYQB5HB7/GvhycuNuc2Tl4s9/rWCWkRkHjuhHS7S0cxRMHIiBlAGaFxfysAhR+IpGBEJwwqIbKU3KZITJikeA4hsozcpEgAnTNIyDKAsZ2dVYjQp
Uus+BlD2YgBlsUSWcWgF1dLbtS6roDcpkhMmsxsDKIvpVSxLA0grqM79/Z849unEsseIF5kUOaESNpwwmd04CpbFzC7j0Aqq7jPjhuETmRTJCZOkhgGUxcwu49AKqrDOtexcAEoLPXhpew2A2yGWe+sHDSL3sf8nuzGAspjZqkQrqHLjfh0lorTQg8sHAvh4781rh7f3fBY9/QorSvQ5tMKnd2gCDx44hTV7+/DggVPoHZpI6HWRczCAsljj+lK8tL0GpYUeuAAUevJwZ14OnnrnfMyBrxVUrf9WbhhgiV4uJNLfNHGrUzvS38QQykzshM5ykWUcZkbE1EbBNlSs0h3GT/RyIWY7xikzMIAIgPGBvzSoDp28iKfeOY9DJy9ij39t9FRLTUmhJ6HRL17fKLswgAiAuQPfqEpSmyu0x7825t8A+qNfiQYWORv7gAiAuRExvSpJq+8GgKl+pggO12cX/i4YAVhe3QBAXo4Ld93pxvR3C5qVCXBzyF3v/tJb1RAA1Woofjiei1adx+rxzACiqKUHfoEnDzPzi1gI3949XIDqcovSQo/uUgzgZtDcmZeDqe8WVO8vZdA4mtXjmX1ADpSqCmHphc0ePHAK07OxYaFgeQhFTo8OnbyoWQEBN0/V9GZNc3V8dmIfkMMYzZOxaxKfVqe0AkT7c5bOZlbru0kUf04o+7ACchijiX12XXNHq0+ntNCjOuy+dK6QViVU6MnD3OL3upUQh9uzCysgh9EbLtcKp99qjDjpsTIa1bi+FB/vrcN//6JW9d/ua/jX6IiYFg63ZxdWQA6jN09Gr3pItBpK5tc2jP6t2sxrgMPt2YijYA6jdeC+tL3GsCMY0D6FstKOZDvCOdyeOTgKliWMqov4cIpnRx9LIldS1MOfEyIGkANpHbhmOoLt6GORacEoqyhnYyd0hjHqCLajj0WWBaO8dIfzMYAyVPy1fuy8AqHZKymmWqLXGiL58BQsg6WqjyXRFe6pIkslRtaxAqKEpbK6SoQslRhZxwooS9jdWSvDCJYslRhZxwDKAnYNm8smmcmSJAcGUBaQadjcbjJUYmQd+4CyADtrSVamAmhwcBB+vx8+nw+HDx9W3ebMmTPYtm0bAoEAHn/8cVsbSclxQmctfwssOxmegoXDYXR2duL1119HcXExmpubUVdXh3vvvTe6zfXr17F//350dXWhpKQEX331VUobTYmRvbM2U/uoyJhhBRQMBlFRUYHy8nLk5+cjEAigv78/Zpv33nsPPp8PJSUlAIDVq1enprVkiahhc7NVDScUZi/DCigUCsHr9Ub/Li4uRjAYjNlmbGwMi4uL2LFjB2ZmZvDLX/4SjY2N9reWLEt3Z20iVQ37qLKXYQCpXa3DFfeb4OFwGMPDw3jjjTdw48YNtLS0YN26dVizZo19LSVHSWTkjb8Flr0MT8G8Xi8mJyejf4dCIRQVFS3bZvPmzVixYgVWrVqFDRs2YHR01P7WkmMkUtXwt8Cyl2EA1dTUYGxsDOPj45ifn0dfXx/q6mIvaLV161acO3cOi4uLmJ2dRTAYRGVlZcoaTfJLZORNlqUdlH6Gp2ButxsdHR1oa2tDOBxGU1MTqqqq0N3dDQBobW1FZWUlNm/ejIaGBuTk5KC5uRn33XdfyhtP8kp05I0TCrMTL8lKKcOLhWUPXpKVpMOqhoxwKQYRCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEcYtuAJFZvUMTOHTyIq5Oz6Kk0IM9/rVoXF8qulmUBAYQOULv0ATaez7D7EIYADAxPYv2ns8AgCHkYDwFI0c4dPJiNHwiZhfCOHTyoqAWkR0YQOQIV6dnE7qdnIEBRI5QUuhJ6HZyBgYQ
OcIe/1p48nJjbvPk5WKPf62gFpEd2AlNjhDpaOYoWGZhAJFjNK4vZeBkGJ6CEZEwDCAiEoanYA7FWcGUCRhADsRZwZQpTJ2CDQ4Owu/3w+fz4fDhw5rbBYNBVFdX4/3337etgbScbLOCe4cm8OCBU1iztw8PHjiF3qEJIe0g5zEMoHA4jM7OTnR1daGvrw8nTpzApUuXVLd7+eWX8ZOf/CQlDaXbZJoVHKnGJqZnoeB2NcYQIjMMAygYDKKiogLl5eXIz89HIBBAf3//su2OHDkCv9+P1atXp6ShdJtMs4Jlq8bIWQwDKBQKwev1Rv8uLi5GKBRats2HH36IlpYW+1tIy8g0K1imaoycxzCAFEVZdpvL5Yr5+4UXXsDu3buRm5u7bFuyX+P6Ury0vQalhR64AJQWevDS9hohHdAyVWPkPIajYF6vF5OTk9G/Q6EQioqKYra5cOECdu3aBQCYmprCwMAA3G43Hn74YZub61x2D5vLMit4j39tzIgcwDVaZJ5hANXU1GBsbAzj4+MoLi5GX18fXnnllZhtTp06Ff3/vXv34qc//SnDBzdDZ9/xYUzPLsTcnknD5lyjJRenzQ8zDCC3242Ojg60tbUhHA6jqakJVVVV6O7uBgC0tramvJFO1Ds0gT3v/hUL3y8/hQVud9TKvHOYJUs1li6yHuROnB/mUtQ6edJgZGQE1dXVutvI+kGb8eCBU5gw6Ih1Abh8IJCeBpEt4g9y4OYpp6g+uKW09rnSQg8+3luX0uc2czyrkXYmtFGayx5OZkaB2FHrPHrTDkTvf04ckZQ2gIzml8heapYUenQrIHbUOpPMB7nWPifzF520q+H1PmgnTH772Y9+qHnfyhV5UpTslLhkph2keslKquaHpbLd0lZAemku87cQcPMDO/bp8g/pB/m5eOHfGTxOZnXaQTo6iFMxIpnqdktbAemlueyT39QqNAAoXJHP8HE4q5NA01W1N64vxcd763D5QAAf761Len9LdbulrYCM0lzmyW+yV2iUuGQHPZy6T6S63dIGEKA9v0T2yW9O7AwkbXachjh1n0h1u6U9BQP0O7/sLjXtJNNiUUqeHachTt0nUt1uaSsgWWd1minF7arQZJ/rlMki773eVIpETkNkr9q1pLrd0s2ENvrg0zGrU0s6Z8HKPOM206m992pE7ouixX85vuxbjU0PrEv4caQ6BVt6dT0tE9Ozwi79mc75R06Y65SptEYxl5J9fk0qqV0F8x9T1jqlpToFM/PBA4i59CeQvlOydI5kOHXURFaJnM4areEDkHQlKmsXgxlqx6nGmmtDUgVQogeXHWtw4nfMn/3oh/jf0f9T3VHtGBHQOhDiby9ckYep7xaW/XvZR01klMi6wgJPnuHjlRZ6Ujq/RsYAWvoe2dlnI1UAGa2fUpNMRaC2Y755+kr0/vgdNdmLb2kdCOf+/k8c+3Qi5va8HBfycl1YCN/+uJ0waiKjRNYVxl+7KZ5dn4HMFa7al/LS/dNOUvUBqQ35GUmmIjBzyrd0R032UqhaB0L3mfFlty98r+AH+e6Y52p6oBSHTl50XJ+BaImuK9Si93kn2p8j62x+tf6dt05fMXyPcly6d2uSqgKKH/LLcbkQ1hmkS/bbyOy3zdLtzFx8S+s0S6u603qNX88u4Pyzj0Qf06l9BqJZWVcYT2/Ey8pnI+ulbNUCWe+Uy4Wb72PZSmvBKVUFBMROMHzlP9Ytq4giQWvHhdjNftsk2sej9TtZua7EviaWPi9Hxayzsq5QbVstVj4bmX5YYKlETgFLCz3RicCFK/ItPZ9UFVC8RCdBJTpxT+1bKF78zmf0HHo7o141Z/S8MvcZyC7RdYV5OS7cdacb098tmNqPrH426b6UrZnjQ6tadCG2ErKrWpM6gADzH5KVMlhtx4wfBfvZj36IQycv4ql3zqPAk4eZ+cVox7Dac2jtdBPTs8g1OKVcKv7b0KlriWSRynWFTvhszB4fWqeGTQ+Uao4O
J0P6ADLL7LCm2reA2XN7tRGS+OfQ+wYxGz5qw7yy9hlkAquVyNJZ+/EVQl6OC9/NL2LN3j4pll2YPT7SvWREygCysgZKr/KI7ATxw4lGVZLZEZKlz60WFPE7px7XrXY9eOBUzOt26lqiTBX/5aTg9udceKtSjszjkmHAIJHTxHSeGkoXQGql4p53/4r97w3rnpPrzSFaOpwYHwR6k7/M9q8sLbXVgiKRuU2R9qnttNn28zcy0xotKr21L8RXy6InGcp6mijdKJjaB7vwvYKp7xaWjSotZWYOkVYVohU0VkdI4i8VUmrxQ146kuLUdUMyS+Y91asoZBwwkPVyINJVQGY+pKXfJvHT6O/My8H0rbAySyto1E6nEh0h0Xocs65Oz3IOUAok+54aVRR2VxvJXprF6BTe7OPbfYkY6QLI7CmL2oE5PbsAT14u/usXtZqX9EhkONGufpf4x0k0HJ22bsgJkn1PjQYF7BwwsOsLSOsU3uzj62239k4LLwySXg/I7LVYAPVvmtJbQZHO4cREmPnVVMC485q/rGrdmr19qu9tIu+pXjVgZ6WQ6l88Nfv4ett1bbsnM34ZNb5aiJ97A9z+NnnqnfOqj3F1elbqUaNI242S3+h+0R2ITmZHp6zeoICdAwap7lMy+/h6I81WSRdAwPIPT+vbROs0K7ITyTpq1Li+FL/VCE+zZOhAdDInzatK9QiW2cfXm+M2/d28peeWbhRMjdYF6GXt2TfD6shYhAzrhpxM1rVYalK9n5t9/D3+tVBbzagAmLx+w9JzS1kBmSXzaZaRZEbG7LggFslbIcdL9X5u9vH1Kvf5RWtdyY4OIMA5O1G8SJv3HR82vAjWUk6p8Mheqd7PzT6+1npGi5cDcn4AOVnkQzfzEzDA7dE9JwauE/BnkIxprWe0OpQuPID4od8OIq1hTheA//pFbda9L+nEyZ7mrNS4Vnm+21oNJLQTWu/iXdlIr5PvP//nr1yGkUK84Jux3qEJfHtjcdntebkueP/F2kxEoQHEDz1W4/pSzVI2rCgM6RSScf2WbA6dvIgFld/f+UG+2/IVEYUGUCZ96HYtFjVz2dZsDulUkfUi8TLROi6/TmAQJZ7QAMqUD93OU0mzFy1zYkjLzMlzytIlFcer0ADKlA/dzlNJsxMUnRbSsnPSxERRUnG8Ch0Fc/JEwqXsPJW0cqF8sodT55Sli97xOjJy3dJjSrca3onsXq2cyM9FU+I49SN58e/hy77V2PTAuoQfR/g8oExg98JGfhOnDuf7JE/tPfzHlLU+SVN9QIODg/D7/fD5fDh8+PCy+48fP476+nrU19ejpaUFo6OjlhrjVOw/kI/WqCSnfiRP7T1UGZ03xbACCofD6OzsxOuvv47i4mI0Nzejrq4O9957b3SbsrIyvPnmmygoKMDAwAB+97vf4d1337XWIoeJL0U5Y1k8vSonk6Z+iGLne2VYAQWDQVRUVKC8vBz5+fkIBALo7++P2eb+++9HQUEBAKC2thaTk5O2NVBmnMktJ70qJ1Omfohk53tlGEChUAherzf6d3FxMUKhkOb2R48exZYtW+xpneRYzstJr8rJlKkfIqm9hzkWl8MbnoKpDZK5NGbrnj59GkePHsXbb79trTUOw3JeTnpX+MuUqR8iqb2HZSutVUWGAeT1emNOqUKhEIqKipZtNzo6imeeeQavvfYaVq5caakxTiPrj71lO6NRSY4yJi/+PRwZGbH0OIanYDU1NRgbG8P4+Djm5+fR19eHurrYuS1Xr17Fzp07cfDgQaxZs8ZSQ5yI5bycOCrpHIYVkNvtRkdHB9ra2hAOh9HU1ISqqip0d3cDAFpbW/Hqq69ienoa+/fvBwDk5uaip6cntS2XAMt5ebHKcQbOhCaipFk9nh3xqxhElJkYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnD
ACIiYRhARCQMA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJwwAiImEYQEQkDAOIiIRhABGRMAwgIhKGAUREwjCAiEgYBhARCcMAIiJhGEBEJAwDiIiEYQARkTAMICIShgFERMIwgIhIGAYQEQnDACIiYUwF0ODgIPx+P3w+Hw4fPrzsfkVR8Pzzz8Pn86G+vh7Dw8O2N5SIMo9hAIXDYXR2dqKrqwt9fX04ceIELl26FLPN4OAgxsbG8MEHH+C5557Dvn37UtVeIsoghgEUDAZRUVGB8vJy5OfnIxAIoL+/P2ab/v5+NDY2wuVyoba2FtevX8eXX36ZskYTUWZwG20QCoXg9XqjfxcXFyMYDOpu4/V6EQqFUFRUpPm4c3NzGBkZsdJmIpLM3NycpX9nGECKoiy7zeVyJbxNvNraWqOnJqIMZ3gK5vV6MTk5Gf1brbKJ32ZyclK3+iEiAkwEUE1NDcbGxjA+Po75+Xn09fWhrq4uZpu6ujr09vZCURScP38ed999NwOIiAwZnoK53W50dHSgra0N4XAYTU1NqKqqQnd3NwCgtbUVDz30EAYGBuDz+eDxePDiiy+mvOFE5HwuRa0Dh4goDTgTmoiEYQARkTApD6BMWMZh9BqOHz+O+vp61NfXo6WlBaOjowJaqc2o/RHBYBDV1dV4//3309g6c8y8hjNnzmDbtm0IBAJ4/PHH09xCY0av4ZtvvsETTzyBhoYGBAIBHDt2TEArtbW3t2PTpk149NFHVe+3dCwrKbS4uKhs3bpVuXLlijI3N6fU19crf/vb32K2+fOf/6z8+te/Vr7//ntlaGhIaW5uTmWTEmbmNXz66afK9PS0oig3X49Mr8FM+yPb7dixQ2lra1P+9Kc/CWipNjOv4euvv1Z+/vOfKxMTE4qiKMq1a9dENFWTmdfwhz/8QTl48KCiKIry1VdfKRs3blTm5uZENFfV2bNnlQsXLiiBQED1fivHckoroExYxmHmNdx///0oKCgAcHOC5dI5UaKZaT8AHDlyBH6/H6tXrxbQSn1mXsN7770Hn8+HkpISAJDudZh5DS6XCzMzM1AUBTMzMygoKIDbbThQnTYbN26M7udqrBzLKQ0gtWUcoVBId5vIMg5ZmHkNSx09ehRbtmxJR9NMMfsZfPjhh2hpaUl380wx8xrGxsZw/fp17NixA9u3b0dvb2+6m6nLzGt47LHH8Pnnn2Pz5s1oaGjA008/jZwc53TTWjmWUxqvSoqWcaRTIu07ffo0jh49irfffjvVzTLNTPtfeOEF7N69G7m5uelqVkLMvIZwOIzh4WG88cYbuHHjBlpaWrBu3TqsWbMmXc3UZeY1fPTRR6iursYf//hHXLlyBb/61a+wYcMG3HXXXelqZlKsHMspDaBMWMZh5jUAwOjoKJ555hm89tprWLlyZTqbqMtM+y9cuIBdu3YBAKampjAwMAC3242HH344rW3VYnY/WrlyJVasWIEVK1Zgw4YNGB0dlSaAzLyGnp4e/OY3v4HL5UJFRQXKysrwxRdf4Mc//nG6m2uJlWM5pfVdJizjMPMarl69ip07d+LgwYPS7PARZtp/6tSp6H9+vx/PPvusNOEDmHsNW7duxblz57C4uIjZ2VkEg0FUVlYKavFyZl7DPffcg08++QQAcO3aNVy+fBllZWUimmuJlWM5pRVQJizjMPMaXn31VUxPT2P//v0AgNzcXPT09IhsdpSZ9svOzGuorKyM9p3k5OSgubkZ9913n+CW32bmNTz55JNob29HfX09FEXB7t27sWrVKsEtv23Xrl04e/YspqamsGXLFuzcuROLi4sArB/LXIpBRMI4p4udiDIOA4iIhGEAEZEwDCAiEoYBRETCMICISBgGEBEJ8/9pBmcxQG2U2gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visual check: draw samples from the trained model and scatter the first two coordinates\n",
    "n_plot_samples = 100\n",
    "with torch.no_grad():\n",
    "    samples_plot = model.sample(n_plot_samples).detach().cpu().numpy()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "ax.scatter(samples_plot[:, 0], samples_plot[:, 1])\n",
    "# Axis limits assume the data lie in the unit square -- TODO confirm for datasets other than 'mixture'\n",
    "ax.set(xlim=(0, 1), ylim=(0, 1), xlabel='dim 0', ylabel='dim 1',\n",
    "       title=f'{n_plot_samples} samples from the model')\n",
    "# Lay out only after limits and labels are set so they are accounted for\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
