{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Topological optimization of 2 clusters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we show how a topological loss can be use to optimize a data set for two clusters.\n",
    "\n",
    "We start by setting the working directory and importing the necessary libraries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set working directory\n",
    "import os\n",
    "os.chdir(\"..\")\n",
    "\n",
    "# Handling arrays and data.frames\n",
    "import numpy as np\n",
    "\n",
    "# Random sampling\n",
    "import random\n",
    "\n",
    "# Functions for deep learning (Pytorch)\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Pytorch compatible topology layer\n",
    "from topologylayer.nn import AlphaLayer\n",
    "from Code.losses import DiagramLoss\n",
    "\n",
    "# Plotting\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and view data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We generate the data as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAD8CAYAAACCRVh7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbjElEQVR4nO3df6zd9X3f8efLHpY7GwLY18bYuMaZVeS0CUG3gJpQBQUqc1vtkqll0JayjslDHWNZFC3eIlWZpkosWmlqjYYZikK2JhStYViZCwXSClcxzNepCzY/jeWWG198L4bEmMrhx33vj/M9zsn1+f39nvP9cV4P6eieH5/P+b7vtc95fz8/v4oIzMzMmlmUdwBmZlZcThJmZtaSk4SZmbXkJGFmZi05SZiZWUtOEmZm1lImSULSFkkvSTokaVuT139D0rPJ7buSPtZtXTMzy4/SrpOQtBh4GbgWmAb2AjdFxPMNZX4BeCEi3pJ0HfCliLiim7pmZpafLFoSlwOHIuJwRLwLPAhMNhaIiO9GxFvJw6eBdd3WNTOz/PyjDN5jLfBaw+Np4Io25W8F/rzPugCsXLkyNmzY0FuUZmYjbt++fW9ExFgvdbJIEmryXNM+LElXU0sSn+yj7lZgK8D69euZmprqPVIzsxEm6e96rZNFd9M0cFHD43XA0YWFJH0UuA+YjIjjvdQFiIgdETEeEeNjYz0lQjMz61MWSWIvsEnSxZKWADcCOxsLSFoPfAu4OSJe7qWumZnlJ3V3U0S8L+l24DFgMXB/RByUdFvy+j3A7wIrgD+SBPB+0ipoWjdtTGZmlo3UU2DzMD4+Hh6TMDPrjaR9ETHeSx2vuDYzs5aymN1kZiNofj44cvwdjp04xepzlrJhxTIWLWo2YdHKzEnCzHo2Px88evB1PvfQfk69N8/SsxZx1w2XsuUjFzhRVIy7m8ysZ0eOv3M6QQCcem+ezz20nyPH38k5Msuak4SZ9ezYiVOnE0TdqffmmX37VE4R2aA4SZhZz1afs5SlZ/3k18fSsxax6uylOUVkg+IkYWY927BiGXfdcOnpRFEfk9iwYlnOkVnWPHBtleYZOIOxaJHY8pELuOSOq5h9+xSrzvbftqqcJKyyPANnsBYtEhvHlrNxbHneodgAubvJKsszcMzSc5KwyvIMHLP0nCSssjwDxyw9JwmrLM/AMUvPA9dWWZ6BY5aek0SFeLrnmTwDxywdJ4mK8HRPMxuETMYkJG2R9JKkQ5K2NXn9Ekl7JP1I0ucXvHZE0nOS9kvylYT65OmeZjYIqVsSkhYDdwPXAtPAXkk7I+L5hmJvAncA17d4m6sj4o20sYyydtM93dViZv3KoiVxOXAoIg5HxLvAg8BkY4GImI2IvcB7GRzPmvB0TzMbhCySxFrgtYbH08lz3QrgLyTtk7S1VSFJWyVNSZqam5vrM9Tq8nRPMxuELAaum42KRg/1PxERRyWtAh6X9GJEPHXGG0bsAHYAjI+P9/L+I8HTPc1sELJIEtPARQ2P1wFHu60cEUeTn7OSHqbWfXVGkrDOPN3TzLKWRXfTXmCTpIslLQFuBHZ2U1HSMkln1+8DvwQcyCAmMzPLQOqWRES8L+l24DFgMXB/RByUdFvy+j2SLgCmgHOAeUmfBTYDK4GHJdVj+UZEPJo2JjMzy0Ymi+kiYhewa8Fz9zTcf51aN9RCJ4CPZRGDFYtXf5tVg1dcW+a8+tusOrwLrGXOq7/NqsNJwjLni/2YVYeThGXOq7/NqsNJwjLn1d9m1eGBa8ucV3+bVYeThA2EV3+bVYOTxBB4zYCZlZWTxIB5zYCZlZkHrjMwPx8cnjvJnlff4PDcSebnf7xJrdcMmFmZuSWRUqeWgq8YZ+24K9KKzi2JlDq1FLxmwFqpn2BMbN/NTfc+w8T23Tx68PWfaIma5c1JIqVOq4u9ZsBacVeklYG7m1KqtxQaE0VjS8FrBqwVd0VaGbglkVI3LYX6moErN65k49hyJwgD3BVp5ZBJkpC0RdJLkg5J2tbk9Usk7ZH0I0mf76Vu0dVbCrvuuIoHt17Brjuu8vRW64q7Iq0MFJFukEzSYuBl4Fpq17veC9wUEc83lFkF/DRwPfBWRPy3bus2Mz4+HlNTU6niNiuC+uwmd0XaMEjaFxHjvdTJYkzicuBQRBxOgngQmAROf9FHxCwwK+mXe61rVmXevsSKLovuprXAaw2Pp5PnMq0raaukKUlTc3NzfQVq2Wq3iNDMqiGLlkSztnG33xZd142IHcAOqHU3dfn+NiDebsRsNGTRkpgGLmp4vA44OoS6liPP8TcbDVkkib3AJkkXS1oC3AjsHEJdy5EvUWo2GlJ3N0XE+5JuBx4DFgP3R8RBSbclr98j6QJgCjgHmJf0WWBzRJxoVjdtTDZ4nRYRmlk1pJ4CmwdPgc1fWcYkvIGe2Y/lNQXWRlAZthspSyIzKzJvy2F9K/p2Ix5cz5+nSZefWxJWWd5AL19uyVWDWxJWWd5AL19uyVWDk4RVljfQG7x23UmeJl0N7m6yyirD4HqZdepO8jTpanBLooCanZ15ALA/RR9cz9Kw/4906k4qQ0vOn6vO3JIomGZnZ//91z/Ou++HBwCtpTwGiTtNDCh6S84D691xS6Jgmp2dPTv9Qw8AWlt5DBJ3MzGgyC05D6x3x0miYJqdnc0HHgC0tvIYJC5Dd1I7HljvjrubCqbZYN9i4QFAayuPQeKidyd14oH17rglUTDNzs5+bt2HSn3GZoOX11l9kbuTOilbSyivQXZv8FdAza57DPhayNaWr5fdu7L8zbIaZO9ngz8nCTOzgjs8d5KJ7bvP6BrbdcdVPW0x00+ScHeTmVnB5TnInkmSkLRF0kuSDkna1uR1SdqevP6spMsaXjsi6TlJ+yW5eWBmtkCe+5ClThKSFgN3A9cBm4GbJG1eUOw6YFNy2wp8dcHrV0fEpb02g8zMRkGeg+xZTIG9HDgUEYcBJD0ITALPN5SZBL4etQGQpyWdK2lNRMxkcHwzs0rLc7pxFkliLfBaw+Np4IouyqwFZoAA/kJSAP8jInZkEJOZWaXUpxsP+1ooWSSJZqls4ZSpdmU+ERFHJa0CHpf0YkQ8dcZBpK3UuqpYv359mnjNzKxLWQxcTwMXNTxeBxzttkxE1H/OAg9T6746Q0TsiIjxiBgfGxvLIGwzM+skiySxF9gk6WJJS4AbgZ0LyuwEfiuZ5XQl8MOImJG0TNLZAJKWAb8EHMggJjMzy0Dq7qaIeF/S7cBjwGLg/og4KOm25PV7gF3ABHAI+Afgt5Pqq4GHJdVj+UZEPJo2piKpr+g8duIUq88p7orOQRjl392sKrzieoBGeb/6Uf7dzYrKK64LZpT3qx/l392sSpwkBmgU9qtvtTPlKPzuZqPA15MYoKrvV9+uS6nqv7vZqHBLYoDKtl99r9p1KVX9dzcbFW5JDFDZr9zVSbsupY1jyyv9u5uNCieJActrKf0wdOpSqvLvbjYq3N1kfXOXkln1uSVhfat6d5qZOUlYSu5SsrLwDgD9cZIws8rzDgD985iEmVWedwDon5OEmVWedwDon5OEmVVefbp2I+8A0B0nCTOrPE/X7p8Hrs2s8jxdu39OEgXmKXtm2SnadO2yfL4zSRKStgB/SO3KdPdFxJ0LXlfy+gS1K9P9i4j4Xjd1R5Wn7JlVV5k+36nHJCQtBu4GrgM2AzdJ2ryg2HXApuS2FfhqD3VHkqfsmVVXmT7fWQxcXw4ciojDEfEu8CAwuaDMJPD1qHkaOFfSmi7rjiRP2TOrrjJ9vrNIEmuB1xoeTyfPdVOmm7oASNoqaUrS1NzcXOqgi85T9syqq0yf7yySRLMOtOiyTDd1a09G7IiI8YgYHxsb6zHE8vGUPbPqKtPnO4uB62ngoobH64CjXZZZ0kXdkeQpezYqij7LZxDxlenznUWS2AtsknQx8H3gRuDXF5TZCdwu6UHgCuCHETEjaa6LuiOraFP2zLJW9Fk+g4yvLJ/v1N1NEfE+cDvwGPAC8FBEHJR0m6TbkmK7gMPAIeBe4Hfa1U0bk5mVQxFm+czPB4fnTrLn1Tc4PHeS+fkf93gXIb68ZbJOIiJ2UUsEjc/d03A/gH/TbV0zGw2drpM+aJ1aCnnHVwTeu8nMcpP3LJ9OLYW84ysCJ4mSaNckNiurvGf5dFqvkHd8ReC9m0qg6IN7Zv3Ke5ZPvaXQmCgaWwp5x1cEbkmUgAfPrMrqs3yu3LiSjWPLh/oF3E1LIc/4isAtiRLw4JmlUfR1CHlyS6EzJ4kS6NQkNmvFXZWdlWW9Ql7c3VQCHjyzfrmr0tJyS6IEGpvEx06c4h8vWcy7H8xz5Pg7bhpbW+6qtLScJEpi0SKxYcUyXnz9bX77a3vddWBdcVelpeXuphJx14H1yl2VlpZbEiXirgPrlWfvWFpOEiXirgPrh2fvWBrubioRdx2Y2bC5JVEi7jows2FzkigZdx2Y2TCl6m6SdL6kxyW9kvw8r0W5LZJeknRI0raG578k6fuS9ie3iTTxmJlZttKOSWwDnoyITcCTyeOfIGkxcDdwHbAZuEnS5oYifxARlya3gV18yFttmw2OP1/Vlba7aRL4VHL/AeCvgC8sKHM5cCgiDgMk17meBJ5Peeyuef8as8Hx56va0rYkVkfEDEDyc1WTMmuB1xoeTyfP1d0u6VlJ97fqrkrLi9DMBsefr2rrmCQkPSHpQJPbZJfHaHYqUW+LfhX4MHApMAP8fps4tkqakjQ1NzfX5aFrOl19ysz6589XtXXsboqIa1q9JumYpDURMSNpDTDbpNg0cFHD43XA0eS9jzW8173At9vEsQPYATA+Pt5Th2eVFqH52gBWNFX6fNmZ0nY37QRuSe7fAjzSpMxeYJOkiyUtAW5M6pEklrrPAAdSxtNUVRah1ft+J7bv5qZ7n2Fi+24ePfi6BwktV1X5fFlziuj/C0bSCuAhYD3w98CvRcSbki4E7ouIiaTcBPAVYDFwf0T8XvL8/6TW1RTAEeBf18c42hkfH4+pqameYq2fgZd5EdrhuZNMbN99xhnbrjuu8roJy1UVPl+jQNK+iBjvpU6q2U0RcRz4dJPnjwITDY93AWdMb42Im9McvxdVWITmDf6sqKrw+bLmvHdTidT7fhu579fMBslJokTc92tmw+a9m0rEG/xZFXnGXrE5SZRMHn2//hDboHi1dvG5u8na8rRbG5T5+eC57/+AF18/wb+6aiNrPrTUq7ULyEmiIIq6QZq3XLBBqJ98/PMdT7P9yUPct/swN1/506cThVdrF4eTRAEU+WzdWy7YIDQ7+dj+nVf4Z5et84y9gnGSKIAin6172q0NQquTj8WL8Iy9gnGSKIAin6172q0NQquTj09fssqD1gXj2U0FUOQN0jztdnQMcxZb/eRj4aymn1t7rv9vFUyqvZvy0s/eTUXmaYCWt2H9H2xMRGs+tJQP5mHupE8+hqWfvZucJArCG6RZnoaxeaRPhvLXT5LwmERB1BfJXblxJRvHlvtDY0M1jHGxIk/QsNacJMxsKLPYijxBw1pzkjCzocxi83TqcvLsJjMbyiy2VjOaPJ262NJeme584E+BDdSuLHdDRLzVpNz9wK8AsxHxs73WX6iKA9dmo8ATNPKVx8D1NuDJiNgEPJk8buZrwJYU9c2sAjxBo3zSJolJ4IHk/gPA9c0KRcRTwJv91jezMw1jU8h2xyjqppSWrbRjEqsjYgYgImYkrRpUfUlbga0A69ev7zdes0oYxpqDdscAvOZhRHRsSUh6QtKBJrfJYQRYFxE7ImI8IsbHxsaGeWizwhnGmoN2x/Cah9HRsSUREde0ek3SMUlrklbAGmC2x+OnrW82ktqtOchqhXS7Y0Qw8ONbMaQdk9gJ3JLcvwV4ZMj1zUbSMNYctDtGkdY8eGxksNImiTuBayW9AlybPEbShZJ21QtJ+iawB/gZSdOSbm1X38zaG8bit3bHKMoW8kW+YFdVeIM/s5IaxpqDdscowpqHdhsTblixbGhbn5dFP+skvOLarKTqaw4GOQbQ7hjDOH4nrcZN3nznR7z4+tuefZUB791kZqXVamzkrMWLPPsqI04SZlZarcZG/uHdD7zjbEbc3WRmpdVqY8Ijx98p7CWBy8YtCTMrtWb7QRVl9lUVuCVhA9V4TWPPMLFhGcbW56PCScIGxtc0tjwVYfZVFbi7yQbG+/uYlZ+ThA2Mr2lsVn5OEjYwRdrfx8z64yRhA+MZJmbl54FrGxjPMDErPycJGyjPMDErN3c3mZlZS04SZmbWUqokIel8SY9LeiX5eV6LcvdLmpV0YMHzX5L0fUn7k9tEmnjMzCxbaVsS24AnI2IT8GTyuJmvAVtavPYHEXFpctvVooyZmeUgbZKYBB5I7j8AXN+sUEQ8BbyZ8lhmZjZkaZPE6oiYAUh+rurjPW6X9GzSJdW0u8rMzPLRMUlIekLSgSa3yQyO/1Xgw8ClwAzw+23i2CppStLU3NxcBoc2M7NOOq6TiIhrWr0m6ZikNRExI2kNMNvLwSPiWMN73Qt8u03ZHcAOgPHx8ejlOGZm1p+03U07gVuS+7cAj/RSOUksdZ8BDrQqa2Zmw5c2SdwJXCvpFeDa5DGSLpR0eqaSpG8Ce4CfkTQt6dbkpS9Lek7Ss8DVwL9PGY+ZmWUo1bYcEXEc+HST548CEw2Pb2pR/+Y0xzczs8HyimszM2vJScLMzFryLrCW2vx8cOT4Oxw7cYrV53g7cLMqcZKwVObng0cPvn76Wtb1Cwtt+cgFThRmFeDuJkvlyPF3TicIqF3D+nMP7efI8Xdyjsys3Obng8NzJ9nz6hscnjvJ/Hw+y8PckrBUjp04dTpB1J16b57Zt0/5QkNmfSpSC90tCUtl9TlLT1/Dum7pWYtYdfbSnCIyK78itdCdJCyVDSuWcdcNl55OFPUzng0rluUcmVl5tWuhD5u7myyVRYvElo9cwCV3XMXs26dYdbZnN5mlVW+hNyaKvFrobklYaosWiY1jy7ly40o2ji13gjBLqUgtdLckzErK61Oqq0gtdCeJkvIXxGgr0uwXG4x6Cz3vWYLubiqh+hfExPbd3HTvM0xs382jB1/PbR61DV+RZr9YtTlJlJC/IKxIs1+s2pwkSshfEOb1KTYsThIl5C8IK9LsF6u2VAPXks4H/hTYABwBboiItxaUuQj4OnABMA/siIg/7La+nan+BbFw0LIMXxAecM9GkWa/WLUpov/BTklfBt6MiDslbQPOi4gvLCizBlgTEd+TdDawD7g+Ip7vpn4z4+PjMTU11XfcVVD/si3TF4Rn5JjlS9K+iBjvpU7a7qZJ4IHk/gPA9QsLRMRMRHwvuf828AKwttv61lwZF7B5wN2sfNImidURMQO1ZACsaldY0gbg48AzvdaXtFXSlKSpubm5lGFbHjzgblY+HcckJD1BbTxhoS/2ciBJy4E/Az4bESd6qQsQETuAHVDrbuq1vuWvSPvRmFl3OiaJiLim1WuSjklaExEzydjDbItyZ1FLEH8SEd9qeKmr+lYNZR5wNxtVabfl2AncAtyZ/HxkYQFJAv4YeCEi7uq1vlWHZ+SYlU/a2U0rgIeA9cDfA78WEW9KuhC4LyImJH0S2A08R20KLMB/iohdrep3Oq5nN5mZ9a6f2U2pWhIRcRz4dJPnjwITyf2/BpqeKraqb93xmgMzGzTvAltSXnNgZsPgbTlKymsOzGwYnCRKymsOzGwY3N1UUl5zYKPCY2/5ckuipLwLqI0CX2Arf6mmwObFU2BryrjJn1kvDs+dZGL77jNazLvuuCr3y3qW0dCnwFq+inINXLNBaTf25v/3w+HuJjMrLF9gK39OEmZWWB57y5+7m8yssLzfV/6cJMys0Dz2li93N5mZWUtOEmZm1pKThJmZteQkYWZmLTlJmJlZS6XclkPSHPB3DU+tBN7IKZxelCFOx5gNx5iNMsQI5YhzJbAsIsZ6qVTKJLGQpKle9yPJQxnidIzZcIzZKEOMUI44+43R3U1mZtaSk4SZmbVUlSSxI+8AulSGOB1jNhxjNsoQI5Qjzr5irMSYhJmZDUZVWhJmZjYApUwSks6X9LikV5Kf57Upu1jS30j6dtFilLRU0v+T9LeSDkr6z8OMsYc4L5L0l5JeSOL8d0WLMSl3v6RZSQeGGNsWSS9JOiRpW5PXJWl78vqzki4bVmw9xHiJpD2SfiTp88OOr8sYfyP5+z0r6buSPlbAGCeT+PZLmpL0yaLF2FDu5yV9IOlXO75pRJTuBnwZ2Jbc3wb81zZlPwd8A/h20WIEBCxP7p8FPANcWcA41wCXJffPBl4GNhcpxuS1XwQuAw4MKa7FwKvARmAJ8LcL/y7ABPDnyb/1lcAzQ/737SbGVcDPA78HfH6Y8fUQ4y8A5yX3ryvo33E5P+7C/yjwYtFibCj3HWAX8Kud3reULQlgEngguf8AcH2zQpLWAb8M3DecsH5Cxxij5mTy8KzkNuxBom7inImI7yX33wZeANYOK0C6/PeOiKeAN4cUE8DlwKGIOBwR7wIPUou10STw9eTf+mngXElrihRjRMxGxF7gvSHG1aibGL8bEW8lD58G1hUwxpORfAsDyxj+Z7mb/48A/xb4M2C2mzcta5JYHREzUPsCo3Ym1MxXgP8AzLd4fZC6ijHpDttP7R/s8Yh4ZnghAt3/LQGQtAH4OLVWz7D0FOMQrQVea3g8zZnJs5syg5T38bvRa4y3UmudDVNXMUr6jKQXgf8L/MshxVbXMUZJa4HPAPd0+6aFveiQpCeAC5q89MUu6/8KMBsR+yR9KsPQGo+RKkaAiPgAuFTSucDDkn42IjLtU88izuR9llM7A/lsRJzIIraG984kxiFrdnm0hWeP3ZQZpLyP342uY5R0NbUkMez+/q5ijIiHqX2OfxH4L8A1gw6sQTcxfgX4QkR8IHV3db/CJomIaPnHlXRM0pqImEma7s2aTZ8A/qmkCWApcI6k/xURv1mgGBvf6weS/grYAmSaJLKIU9JZ1BLEn0TEt7KML6sYczANXNTweB1wtI8yg5T38bvRVYySPkqt6/i6iDg+pNjqevo7RsRTkj4saWVEDGtPp25iHAceTBLESmBC0vsR8X9avWlZu5t2Arck928BHllYICL+Y0Ssi4gNwI3Ad7JMEF3oGKOksaQFgaSfonbW8eKwAkx0E6eAPwZeiIi7hhhbXccYc7IX2CTpYklLqP0/27mgzE7gt5JZTlcCP6x3nRUoxrx1jFHSeuBbwM0R8XJBY/wnyWeFZBbbEmCYyaxjjBFxcURsSL4X/zfwO+0SRL1S6W7ACuBJ4JXk5/nJ8xcCu5qU/xTDn93UMUZqMyD+BniWWuvhd4v4t6TWtI8kzv3JbaJIMSaPvwnMUBuAnQZuHUJsE9Rme70KfDF57jbgtuS+gLuT158DxnP4N+4U4wXJ3+sE8IPk/jkFi/E+4K2G/39TBfw7fgE4mMS3B/hk0WJcUPZrdDG7ySuuzcyspbJ2N5mZ2RA4SZiZWUtOEmZm1pKThJmZteQkYWZmLTlJmJlZS04SZmbWkpOEmZm19P8BPIXrEF+SQV0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Data parameters\n",
    "points_per_cluster = 25\n",
    "mu1 = np.array([-0.2, 0])\n",
    "mu2 = np.array([0.2, 0])\n",
    "sigma = 0.1\n",
    "\n",
    "# Generate the data and labels\n",
    "np.random.seed(420)\n",
    "data = np.concatenate([mu1 + np.random.normal(size=(points_per_cluster, 2), scale=sigma),\n",
    "                       mu2 + np.random.normal(size=(points_per_cluster, 2), scale=sigma)])\n",
    "\n",
    "# Plot the data\n",
    "fig, ax = plt.subplots()\n",
    "sns.scatterplot(x=data[:,0], y=data[:,1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply topological optimization to the embedding\n",
    "\n",
    "We now show how we can use topological optimization to encourage the model underlying the data to become connected. As a topological loss, we will use the persistence of the second most prominent gap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define topological loss\n",
    "def g(p): return p[1] - p[0] # function that returns the persistence d - b of a point (b, d)\n",
    "TopLayer = AlphaLayer(maxdim=0) # alpha complex layer\n",
    "Component2Persistence = DiagramLoss(dim=0, i=2, j=2, g=g) # compute persistence of second most prominent gap\n",
    "\n",
    "# Construct topological loss function\n",
    "def top_loss(output):\n",
    "    dgminfo = TopLayer(output)            \n",
    "    loss = - Component2Persistence(dgminfo)\n",
    "    \n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now conduct the topological optimization as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[epoch 1] [topological loss: -0.131397]\n",
      "[epoch 10] [topological loss: -0.282094]\n",
      "[epoch 20] [topological loss: -0.427987]\n",
      "[epoch 30] [topological loss: -0.579585]\n",
      "[epoch 40] [topological loss: -0.736238]\n",
      "[epoch 50] [topological loss: -0.896679]\n",
      "[epoch 60] [topological loss: -1.062835]\n",
      "[epoch 70] [topological loss: -1.235956]\n",
      "[epoch 80] [topological loss: -1.402753]\n",
      "[epoch 90] [topological loss: -1.571399]\n",
      "[epoch 100] [topological loss: -1.741448]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD4CAYAAADvsV2wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAakklEQVR4nO3df5RU5Z3n8fe3u8G2+WUD3XQLYttrGxWDyqlj1IjrgOQgMxlcs3FIZhIza5bjnkGcuLsju8nOZs/kzGhm1x05cSaHIZmDJhvjRqNsgiSKesZZg2ujoCCaJoxgSwMdgoBoB5r67h99uy2KW91Vfau6qvv5vM7p03Xvfaqeh6L6c5967r3PNXdHRETGvqpyN0BEREaGAl9EJBAKfBGRQCjwRUQCocAXEQlETbkbMJjp06d7S0tLuZshIjJqbNmy5dfu3hC3raIDv6Wlhfb29nI3Q0Rk1DCzPbm2aUhHRCQQCnwRkUAo8EVEAqHAFxEJhAJfRCQQFX2WjoysdNp5+9BxDhztYcbkWlqmTaCqysrdLBEpEgW+AH1hv3HHfu5+dCs9J9PUjqvi/luvYPGcJoW+yBihIR0B4O1DxwfCHqDnZJq7H93K24eOD/q8dNrZ3f0+v/jVr9nd/T7ptKbbFqlU6uELAAeO9gyEfb+ek2kOHuuhtWFi7HP0rUBkdFEPXwCYMbmW2nGnfxxqx1XROKk2Zy9+uN8KRKQ8FPgCQMu0Cdx/6xUDod/fW59dX8fGHftZsvoFPvf3L7Fk9Qts3LGfdNoH/VYgIpVHQzoCQFWVsXhOExevnM/BYz00Tuo7SydXL/7ilfMHvhVkhn7/twIRqTzq4cuAqiqjtWEiV7dOp7VhIlVVNmgvPte3gpZpE8rRfBEZQlF6+Ga2GHgAqAbWuvu9Wdv/ELgnWnwf+Hfuvq0YdUvppNNO3fianL34XN8KdMBWpDIlDnwzqwYeBBYBncDLZrbe3d/IKPbPwL9098NmdhOwBvhE0rolueyLrWbX17H38AccONpD7ynngU1vsXJBG6uf7TjtTJz+Xnz/t4JcZ/KISOUoRg//KmCXu+8GMLNHgKXAQOC7+4sZ5TcDs4pQrySUfVrl+dPO5s4FbXztie0D4b5yQRsbt3dx+3WtVFfBwosb+fjMc9SLFxmFijGGPxN4J2O5M1qXy+3AU7k2mtlyM2s3s/bu7u4iNE9yyT4g+3tzZw6EPfSN1a9+toP5FzXy4HO7WL1pFx+ePKWwFxmlihH4cX/9sZdbmtnv0Bf498RtB3D3Ne6ecvdUQ0PsXbqkSLIPyJoRe4D2rJqPDsrqDByR0asYQzqdwHkZy7OAfdmFzGwusBa4yd0PFaHeoJRiYrNcp1VmL1/YOJHzp53NPYsv0Rk4IqNYMXr4LwNtZnaBmY0HlgHrMwuY2WzgceAL7v7LItQZlP6x9riLn5LIPq3y/2x7l//2+3NOO81y5YI27tu4k9XLrtSUCSKjXOIevrv3mtkK4Gf0nZb5XXffYWZ3RNu/Dfw5MA34WzMD6HX3VNK6y2WkpxEe7OKnwea5iWtj9vpPXTKDDSvns+fQcV595z1+23uK269rxQzc4eHNe+g60qOxe5ExoCjn4bv7BmBD1rpvZzz+MvDlYtRVbuWYMKzQic1ytfFTl8zg5zsPxK5/Y99RVm/axYoFF/Kdf9qtq2dFxiBdaVugckwYNtjEZoW0cUfXkZzrOw4eo3ZcFY9t6WTlgjZdPSsyBinwC5R0wrDhzB9f6BQGcW2srxvP0Z5evjy/lRULLqR5Su1A27uO9PBoe1/QH/7gBA9v3sPy61tZvewKfnrnfI3di4wRmjytQEkmDBvucFChUxhkt7F5Si1fvOZ8/u1D7addUPXw5j0c/uAEzVNqB4K+f/weYO6sKbRM1xW0ImOFevgU1utOMmFYkuGguInN8m3jZ1OzeGBTxxkXVH02NYv7b72COc1TuP/WKzj8wQkefG4Xa1/YzcVNk5k9VcM4ImNJsD38/rNVDh3/Lfve6+Gex17Lq9edZMKw4dxVajiy2/jhiVOx9c6dOYUbLmqkpqZKk6CJBCDIHn7mee3Pv/XrgbCH/HrdhfS2MxV68LUY3KF2XHVsvSdPOc93HCSd9mH/m0Rk9Agy8DOHVnJNJ1CKuzaVav747CGp3t70aRdqrXr8Nb7+6dMvqPqvn57DvRt38lrnEd2SUCQQQQ7pZA+tlPquTZkXO13aPImf3jmf7veLM3QSdyD4vs/M5f6n3xr4N53odU70nmL59a2kHaoMxtcYJ3qdtFP0ISURqUxBBn7mWSz9553nmu89qWJeqBV39WzcgeB7HnuN269r5cHndgFwy7xZ/OVTb56xU1t+fSuALqoSCUSQgd8/tHL3o1vpOtLDD9v3suYLKcZVW9GnShjOtAhx4nYc3/r8lVRhsUNS1RmDdbmGrWZPreOcunG6qEokEEEGfj5n2mT2ppun1HIq3Tf0UegOoVhn5mTvOOrrxtNx4H16Tp6KHZJKnT91YH21xQ9bXdQ4kct0MxORYAQZ+DD4rfkye9P1deP54jXnD5zHXuiQTJILtTJl7zhumdd3bn193Xi+cuNF/M9nfjnQvm9+Zi6Taqv54fKr+c3xE3QcOHZGmfs+M1dhLxKYYAN/MJm96f5gHe6QTObwUZJjBNk7jsxhmtqaqtMOyL73wQn+6qk3OfzBCb71+SuZNXUC923cOXCbwtT5U7m2dZrCXiQwCvwYmb3pwU7bzCfwk1yolSl7x9E/THPLvFn81cYzD8j2H7Rd8b9eZeNd8/mHL12li6pEAqfAj5Hdm046JDPY8FEhr5G542iaXMvHmibz5v6jsTuk/vlwek6m2X+0Z+CCKhEJV5AXXmXLvnBpdn3dwAVSj23p5K6FxZsueDizZfbLvBq2ZfpEFs9pYuHFM2KvonX/6LFOuxQRUA9/0JuFbMjoTX/q0qZhXyyVecZP7ynna0++zp5DHxZ8ADjuPPyPz5xyxjGCuxa28dAv9mguexE5jbknuy9qKaVSKW9vby9pHbu732fJ6hfOGLLZUOB58rnE7VD6pybuOtKTd12DXcAFfQeaDx7roWFiLdVVsP+oxutFQmRmW3LdQjb4IZ2kNzQZStyFV6uf7eCWebMKqmuwqZUzh3r+ReNEWqZrEjQROVPwgV/qGSxz7VD6D6rmW1epd0wiMvYFH/ilmsGyX64dSt+0xfnXVY6plUVkbCnKGL6ZLQYeAKqBte5+b9Z2i7YvAT4AvuTurwz1uiMxhg8fHQzNPk897iDpcCY8i5vNcuY5tUydcFber1nMSdhEZOwabAw/ceCbWTXwS2AR0Am8DHzO3d/IKLMEuJO+wP8E8IC7f2Ko1x6pwI9Tilku873wKdeOptDXEZHwDBb4xTgt8ypgl7vvjip7BFgKvJFRZinwkPftXTab2Tlm1uzuXUWovySKNcslFHbh1VA7mqQXcIlIuIoxhj8TeCdjuTNaV2gZAMxsuZm1m1l7d3d3EZo3PLkOku45dLygi6UKleRG5yIigylG4MeNKWQnYj5l+la6r3H3lLunGhoaEjduuHIdJH31nffYuGN/yUJfZ+OISKkUI/A7gfMylmcB+4ZRpqLEnb2zckEb/7u9s6Q9bp2NIyKlUozAfxloM7MLzGw8sAxYn1VmPfBF63M1cKSSx+/ho8nK1v3xVaxYcCG3X9c6cHVsz8k0B46Wpsdd6tNERSRciQ/aunuvma0AfkbfaZnfdfcdZnZHtP3bwAb6ztDZRd9pmX+ctN6RUFVl1I6rZu0Lu8+YeqFufHXJ6izGdMoiItmKMnmau2+gL9Qz130747EDf1KMuoZSjHPnM504deqMm5yvXNDGyVPpoZ88TDobR0RKYUzNllmKi5OmTTiLZ9/czzf/9eV8+Nte6s6qYd2Lu1l8WVORWy8iUlpjKvCLee58v9n1dSy76nz+7EfbBnYi37j5MmbX1xWz6SIiJTem5tIp1imNmTcp2dF1hK89sf20ncjXntjO3sMf5P0ahd7oRESkFMZUDz/71oRQ+CmN2cNCKxdeWPA9bTXvjYhUojHVwy/GKY3Zw0LpaFbLTEPtRHS1rIhUojHVwy/GKY3Zw0KPbek84yydoXYiuYaWDhzN/a1ARKTUxlTgQ/JTGrOHhbqO9PDD9r38cPnVfHjyVF47kVxDSydPOem0a1hHRMpiTA3pFEPcsNA9iy/h4zPPyfu2gS3TJnDfZ+aeMS3Df3nydfb+5rgO5opIWYy5Hn5SxRgWqqoyzj2nltuva8UM3OHhzXsAeGXve/znH7+ug7kiMuIU+DEKHRbq7U2zo+sIXUd6aJ5yNnOaJzNtwll8559On5Jh5cILB8IeinOdgIhIvjSkk1Bvb5ontr3LH6zZzB3fe4U/WPMLntj2LrOmnH3G0NBFjZM09bGIlI16+AnlujCrrXHiGUND/TcuT3KdgIjIcKmHn1D/dMmZek6m2X+kZ2BoqP9g7wXTNfWxiJSPevgJNU85O7bX3jTlzF67pj4WkXJSDz+hOc2T+cbNl53Wa//GzZcxp3lKbPnsXr/CXkRGinr4CdXUVHHz5TNpa5zI/iM9NE2pZU7zFGpqtC8VkcqiwC+CmpoqLj+vnsvPG7qsiEi5qBsqIhIIBb6ISCA0pJOnYt8rV0RkpCnw86AbmojIWJBoSMfMpprZ02bWEf2ujylznpk9Z2Y7zWyHmd2VpM5y0A1NRGQsSDqGvwrY5O5twKZoOVsv8O/d/RLgauBPzOzShPWOqGLdK1dEpJySDuksBW6IHq8DngfuySzg7l1AV/T4mJntBGYCbySse8T039Ckvm48t8ybhRlUGzRN1hw4IjJ6JA38GVGg4+5dZtY4WGEzawGuBF4apMxyYDnA7NmzEzavOFqmTeBbn7+SjgPv88Cmj251+LGmycyeqoO3IjI6DDmkY2bPmNn2mJ+lhVRkZhOBx4A/dfejucq5+xp3T7l7qqGhoZAqSqaqyrhg2sSBsAeN44vI6DNkD9/db8y1zcwOmFlz1LtvBg7mKDeOvrD/vrs/PuzWltHBY7nH8XXzEhEZDZIetF0P3BY9vg14MruAmRnwHWCnu9+fsL6y6R/Hz6S57EVkNEka+PcCi8ysA1gULWNm55rZhqjMJ4EvAAvMbGv0syRhvSMu7ubmmsteREYTc/dytyGnVCrl7e3t5W7GgP6rbTWXvYhUKjPb4u6puG260rYAhd7cXESkkmjyNBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQlEosA3s6lm9rSZdUS/6wcpW21mr5rZT5LUKSIiw5O0h78K2OTubcCmaDmXu4CdCesTEZFhShr4S4F10eN1wM1xhcxsFvC7wNqE9YmIyDAlDfwZ7t4FEP1uzFHub4A/A9IJ6xMRkWGqGaqAmT0DNMVs+mo+FZjZ7wEH3X2Lmd2QR/nlwHKA2bNn51OFiIjkYcjAd/cbc20zswNm1uzuXWbWDByMKfZJ4PfNbAlQC0w2s++5+x/lqG8NsAYglUp5Pv8IEREZWtIhnfXAbdHj24Answu4+39y91nu3gIsA57NFfYiIlI6SQP/XmCRmXUAi6JlzOxcM9uQtHEiIlI8Qw7pDMbdDwELY9bvA5bErH8eeD5JnSIiMjy60lZEJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBAKfBGRQCjwRUQCocAXEQmEAl9EJBCJAt/MpprZ02bWEf2uz1HuHDP7kZm9aWY7zeyaJPWKiEjhkvbwVwGb3L0N2BQtx3kA2OjuFwOXAzsT1isiIgVKGvhLgXXR43XAzdkFzGwycD3wHQB3P+Hu7yWsV0RECpQ08Ge4exdA9Lsxpkwr0A38g5m9amZrzWxCrhc0s+Vm1m5m7d3d3QmbJyIi/YYMfDN7xsy2x/wszbOOGmAe8HfufiVwnNxDP7j7GndPuXuqoaEhzypERGQoNUMVcPcbc20zswNm1uzuXWbWDByMKdYJdLr7S9Hyjxgk8EVEpDSSDumsB26LHt8GPJldwN33A++Y2ceiVQuBNxLWKyIiBUoa+PcCi8ysA1gULWNm55rZhoxydwLfN7PXgCuAv0xYr4iIFGjIIZ3BuPsh+nrs2ev3AUsylrcCqSR1iYhIMrrSVkQkEAp8EZFAKPBFRAKhwBcRCYQCX0QkEAp8EZFAKPBFRAKhwBcRCYQCX0QkEAp8EZFAKPBFRAKhwBcRCYQCX0QkEAp8EZFAKPBFRAKhwBcRCYQCX0QkEAp8EZFAKPBFRAKhwBcRCYQCX0QkEIkC38ymmtnTZtYR/a7PUe4rZrbDzLab2Q/MrDZJvSIiUrikPfxVwCZ3bwM2RcunMbOZwEog5e6XAdXAsoT1iohIgZIG/lJgXfR4HXBzjnI1wNlmVgPUAfsS1isiIgVKGvgz3L0LIPrdmF3A3d8F/juwF+gCjrj7z3O9oJktN7N2M2vv7u5O2DwREek3ZOCb2TPR2Hv2z9J8KojG9ZcCFwDnAhPM7I9ylXf3Ne6ecvdUQ0NDvv8OEREZQs1QBdz9xlzbzOyAmTW7e5eZNQMHY4rdCPyzu3dHz3kcuBb43jDbLCIiw5B0SGc9cFv0+DbgyZgye4GrzazOzAxYCOxMWK+IiBQoaeDfCywysw5gUbSMmZ1rZhsA3P0l4EfAK8DrUZ1rEtYrIiIFMncvdxtySqVS3t7eXu5miIiMGma2xd1Tcdt0pa2ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCAU+CIigVDgi4gEQoEvIhIIBb6ISCASBb6ZfdbMdphZ2sxi75IelVtsZm+Z2S4zW5WkThERGZ6kPfztwC3AP+YqYGbVwIPATcClwOfM7NKE9YqISIFqkjzZ3XcCmNlgxa4Cdrn77qjsI8BS4I0kdYuISGFGYgx/JvBOxnJntC6WmS03s3Yza+/u7i5540REQjFkD9/MngGaYjZ91d2fzKOOuO6/5yrs7muANQCpVCpnORERKcyQge/uNyasoxM4L2N5FrAv4WuKiEiBRmJI52WgzcwuMLPxwDJg/QjUKyIiGZKelvmvzKwTuAb4qZn9LFp/rpltAHD3XmAF8DNgJ/Cou+9I1mwRESlU0rN0fgz8OGb9PmBJxvIGYEOSukREJBldaSsiEggFvohIIBIN6YiISPGk087bh45z4GgPMybX0jJtAlVVg17YWhAFvohIBUinnY079nP3o1vpOZmmdlwV9996BYvnNBUt9DWkIyJSAd4+dHwg7AF6Tqa5+9GtvH3oeNHqUOCLiFSAA0d7BsK+X8/JNAeP9RStDgW+iEgFmDG5ltpxp0dy7bgqGifVFq0OBb6ISAVomTaB+2+9YiD0+8fwW6ZNKFodOmgrIlIBqqqMxXOauHjlfA4e66Fxks7SEREZs6qqjNaGibQ2TCzN65fkVUVEpOIo8EVEAqHAFxEJhAJfRCQQCnwRkUCYe+XeNtbMuoE95W5HlunAr8vdiAKNtjaPtvaC2jxSRluby9He8929IW5DRQd+JTKzdndPlbsdhRhtbR5t7QW1eaSMtjZXWns1pCMiEggFvohIIBT4hVtT7gYMw2hr82hrL6jNI2W0tbmi2qsxfBGRQKiHLyISCAW+iEggFPgxzGyqmT1tZh3R7/qYMueZ2XNmttPMdpjZXRnbvm5m75rZ1uhnSYnaudjM3jKzXWa2Kma7mdnqaPtrZjYv3+eWSh5t/sOora+Z2YtmdnnGtrfN7PXoPW2voDbfYGZHMv6//zzf55apvf8xo63bzeyUmU2NtpXrPf6umR00s+05tlfUZzmP9lbc5xgAd9dP1g/wTWBV9HgVcF9MmWZgXvR4EvBL4NJo+evAfyhxG6uBXwGtwHhgW3/9GWWWAE8BBlwNvJTvc8vY5muB+ujxTf1tjpbfBqaP8GchnzbfAPxkOM8tR3uzyn8aeLac73FU7/XAPGB7ju2V9lkeqr0V9Tnu/1EPP95SYF30eB1wc3YBd+9y91eix8eAncDMkWogcBWwy913u/sJ4BH62p1pKfCQ99kMnGNmzXk+tyxtdvcX3f1wtLgZmDUC7RpMkveqHO9zoXV+DvhBids0JHf/R+A3gxSpqM/yUO2twM8xoCGdXGa4exf0BTvQOFhhM2sBrgReyli9Ivo69924IaEimAm8k7HcyZk7nFxl8nluKRRa7+309er6OfBzM9tiZstL0L44+bb5GjPbZmZPmdmcAp9bTHnXaWZ1wGLgsYzV5XiP81Fpn+VCVMLnGAj4jldm9gzQFLPpqwW+zkT6/mD+1N2PRqv/DvgL+v5j/wL4H8C/GX5r46uOWZd9jm2uMvk8txTyrtfMfoe+P5TrMlZ/0t33mVkj8LSZvRn1tEopnza/Qt/8Je9Hx2ueANryfG6xFVLnp4H/6+6ZPdVyvMf5qLTPcl4q6HMMBBz47n5jrm1mdsDMmt29K/raeDBHuXH0hf333f3xjNc+kFHm74GfFK/lAzqB8zKWZwH78iwzPo/nlkI+bcbM5gJrgZvc/VD/enffF/0+aGY/pu/rfKn/UIZsc8aOHnffYGZ/a2bT83luCRRS5zKyhnPK9B7no9I+y0OqsM/xQOX6OfOAy19z+kHbb8aUMeAh4G9itjVnPP4K8EgJ2lgD7AYu4KODVXOyyvwupx/o+n/5PrdE72s+bZ4N7AKuzVo/AZiU8fhFYHGFtLmJjy5ivArYG73nI/4+51snMIW+MegJ5X6PM+pvIfdB0Ir6LOfR3or6HA/UP1IVjaYfYBqwCeiIfk+N1p8LbIgeX0ffV8fXgK3Rz5Jo28PA69G29WTsAIrcziX0nR30K+Cr0bo7gDuixwY8GG1/HUgN9twRem+HavNa4HDGe9oerW+N/pi3ATsqrM0rojZto+8A3bWDPbfc7Y2Wv0RWR6TM7/EPgC7gJH29+dsr+bOcR3sr7nPs7ppaQUQkFDpLR0QkEAp8EZFAKPBFRAKhwBcRCYQCX0QkEAp8EZFAKPBFRALx/wE4TQuGCOxCyQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Learning hyperparameters\n",
    "num_epochs = 100\n",
    "learning_rate = 1e-2\n",
    "\n",
    "# Conduct topological optimization\n",
    "Y = torch.autograd.Variable(torch.tensor(data).type(torch.float), requires_grad=True)\n",
    "optimizer = torch.optim.Adam([Y], lr=learning_rate)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    optimizer.zero_grad()\n",
    "    loss = top_loss(Y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch == 0 or (epoch + 1) % (int(num_epochs / 10)) == 0:\n",
    "        print (\"[epoch %d] [topological loss: %f]\" % (epoch + 1, loss.item()))\n",
    "        \n",
    "Y = Y.detach().numpy().copy()\n",
    "\n",
    "# View topologically optimized embedding\n",
    "fig, ax = plt.subplots()\n",
    "sns.scatterplot(x=Y[:,0], y=Y[:,1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the topological optimization served its purpose, i.e., it results in a point cloud consisting of two clusters. However, one of these clusters consists of merely one point. To accomodate for this, we can compute the topological loss from a random sample of our data as to represent larger clusters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[epoch 1] [topological loss: -0.308334]\n",
      "[epoch 10] [topological loss: -0.125648]\n",
      "[epoch 20] [topological loss: -0.298769]\n",
      "[epoch 30] [topological loss: -0.441468]\n",
      "[epoch 40] [topological loss: -0.331700]\n",
      "[epoch 50] [topological loss: -0.495583]\n",
      "[epoch 60] [topological loss: -0.605348]\n",
      "[epoch 70] [topological loss: -0.605039]\n",
      "[epoch 80] [topological loss: -0.156322]\n",
      "[epoch 90] [topological loss: -0.719416]\n",
      "[epoch 100] [topological loss: -0.644189]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAD7CAYAAABpJS8eAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAWRklEQVR4nO3df4xdZZ3H8c93uiUT+kNLZ/rDTodp46xYIgtmRBJkVRDTjsZqooSqLBrNhFVsDSaCa9Z//Af/YWmzKCnoitGkISsrDRnLatFsXaFLq112a4MtY5FCmZbqUloyQrnf/ePeW6aX++Pce36f834lTe+dezrnmTOdzzz3eb7Pc8zdBQAovr60GwAASAaBDwAlQeADQEkQ+ABQEgQ+AJQEgQ8AJRFJ4JvZWjN70swOmdltbY57l5m9ZmYfj+K8AIDgQge+mc2RdJekdZLWSNpgZmtaHPctSQ+HPScAoHt/FcHnuFzSIXefkiQz2yZpvaTfNRz3JUk/lvSuoJ94YGDAR0ZGImgiAJTD3r17X3D3wWavRRH4KyQ9M+v5EUnvnn2Ama2Q9DFJV6tD4JvZhKQJSRoeHtaePXsiaCIAlIOZPd3qtSjG8K3Jxxr3a7hT0q3u/lqnT+buW919zN3HBgeb/pICAPQgih7+EUkrZz0fkvRcwzFjkraZmSQNSBo3szPu/pMIzg8ACCCKwH9c0qiZrZL0rKTrJX1y9gHuvqr+2My+L+khwh4AkhU68N39jJndrGr1zRxJ33P3/WZ2U+31u8OeAwAQXhQ9fLn7pKTJho81DXp3/0wU5wQAdCeSwAdQVam4Dp84remTM1q6sF8ji+epr69ZXQOQPAI/YwiM/KpUXDv2P69b7t+nmVcr6p/bpzuuu1RrL17G9xCZwF46GVIPjPEtu7Thnt0a37JLO/Y/r0qFu5LlweETp8+GvSTNvFrRLffv0+ETp1NuGVBF4GcIgZFv0ydnzn7v6mZerejYSzMptQg4F4GfIQRGvi1d2K/+uef+SPXP7dOSBf0ptQg4F4GfIQRGvo0snqc7rrv07PewPoY/snheyi0Dqpi0zZB6YDRO+hEY+dDXZ1p78TJdtPEqHXtpRksWMOmObDH37E4Ijo2Nedk2T6tX6RAYAHphZnvdfazZa/TwM6avz7R6cL5WD85PuykACoYxfAAoCQIfAEqCwAeAkiDwAaAkCHwAKAkCHwBKgsAHgJIg8AGgJAh8ACgJVtoWEDdRAdAMgV8w3HUJWUVHJH0M6RQMN1FBFnE3t2wg8AuGm6ggi+iIZAOBnzGVimvq+Ck9+tQLmjp+quseEDdRQRbREckGAr8HYUO53ecN+7aXuy4hi+iIZAM3QOlSnJOiU8dPaXzLrnN6Qv1z+zS58aqu9sfnJirIGooJksMNUCLUaizyoi5DuZl2b3u7+dzcRAVZw+0fs4HA71JUodxM/W1vYw+ft70oAjoi6WMMv0txjkUy/g4gTvTwu1QP5caxyChCmbe9AOLEpG0PmBQFkFVM2kYsirFIlpkDSBqBnwJK1ACkgUnbNuJaYMUycwBpoIffQpy98DhLOwGgFXr4LcTZC29X2hnXuwoAIPBbiHOzp1b19sOLzmcLWQCxYUinhThXvbaqt49z2wYgj6hmixaB30KcC6yk5qWdjO0Dr6OaLXqRBL6ZrZW0WdIcSfe6++0Nr39K0q21p6ck/b27/3cU545LGqte2UsHeB3veKMXegzfzOZIukvSOklrJG0wszUNh/1B0nvd/RJJ35S0Nex5k1DvhV+xekCrB+fH3qtgLx3gddw0JXpR9PAvl3TI3ackycy2SVov6Xf1A9z917OOf0zSUATnLRz20gFexzve6EVRpbNC0jOznh+pfayVz0n6aQTnLZR6OebuP5yQJF0+sjiRdxVAVvGON3pR9PCbJVLTOkIze7+qgf+elp/MbELShCQNDw9H0LzmsjT7z+QU8Ea8441eFIF/RNLKWc+HJD3XeJCZXSLpXknr3P1Eq0/m7ltVG+MfGxuLpQA9awHL5BTQHDdNiVYUQzqPSxo1s1Vmdp6k6yVtn32AmQ1LekDSDe7++wjOGUrW9rJhcgpAEkIHvrufkXSzpIclHZB0v7vvN7ObzOym2mHfkLRY0rfNbJ+ZpbrJfdYCNs67aAFAXSR1+O4+KWmy4WN3z3r8eUmfj+JcUcja7H/ci7wAQCrpStusBSyTUwCSUNpbHHKbQgBFxC0Om2D2H0AvslTS3a3SBn5Yef6mA2UUxc9s1kq6u0Xg9yDv33SgbKL6mc37mhlugNKDrNXxA2gvqp/ZrJV0d4vA70Hev+lA2UT1M5v3NTMEfg/y/k0Hyiaqn9m8b+hW2rLMMBjDB/Ilyp/ZrJd0tyvLJPB7lPVvOoBzleVnljr8GFDHD+QLP7OM4QNAadDDzxgWdAGIC4GfIUwGA4gTQzoZwoIuAHEi8DOEBV0A4kTgZwgLugDEicDPkLyv4gOQbUzaxqhScf3xT6c1ffIvOv3KGV14wTytGmhddcOdrwDEicCPSaXieuTJaR2cPqXNOw8GrrphcQiAuDCkE5PDJ07riSMvng17iaobAOki8GMyfXJGFRdVNwAyg8CPydKF/ZpjouoGQGYQ+DEZWTxP7xh6kzZdM0rVDYBMYNI2Jn19pqvftlRvHZyvdw4v0suvnNFwhyodAIgTgR+jvj7TyMB8jQxQcQMgfQzpAEBJEPgAUBIEPgCUBIEPACVB4ANASVClEwFuSwggDwod+EkEMbclBJAXhR3SqQfx+JZd2nDPbo1v2aUd+59XpeKRnofbEgLIi8IGflJBzG0JAeRFYQM/qSDmtoQA8qKwgZ9UEHNbQgB5UdhJ23oQN06mRh3E3JYQQF6Ye7STmFEaGxvzPXv29Pzv61U6BDGApKRdpm1me919rNlrkfTwzWytpM2S5ki6191vb3jdaq+PS3pZ0mfc/TdRnLsd7g8LIElZL9MOPYZvZnMk3SVpnaQ1kjaY2ZqGw9ZJGq39mZD0nbDnBYCsyXqZdhSTtpdLOuTuU+7+iqRtktY3HLNe0g+86jFJbzaz5RGcGwAyI+tl2lEE/gpJz8x6fqT2sW6PkSSZ2YSZ7TGzPcePH4+geQCQjKyXaUcR+M0GphpngoMcU/2g+1Z3H3P3scHBwdCNA4CkZL1MO4pJ2yOSVs56PiTpuR6OAYDExFFNk/Uy7SgC/3FJo2a2StKzkq6X9MmGY7ZLutnMtkl6t6QX3f1oBOdGE2mXhQFZF2c1TZarA0MHvrufMbObJT2salnm99x9v5ndVHv9bkmTqpZkHlK1LPOzYc+L5rJeFgZkQatqmos2XpXJoI5KJHX47j6paqjP/tjdsx67pC9GcS60V9b/yEA32lXTFPnnpLB76ZRV1svCgCzIejVNXAj8ginrf2SgG1mvpolLYTdPK6ukNo0D8izr1TRxKfTmaWXFpnFAecW+eVpelKVcMctlYQDSU5rAp1wRQK+K0lkszaRt1nexA5BN9c7i+JZd2nDPbo1v2aUd+59XpZLd4fBWShP4lCsC+VSpuKaOn9KjT72gqeOnEg/aInUWSzOkUy9XnB36lCsC2dbLUGzUwy9FWqRVmh5+WetugTzrtncdx/BLkda2lKaHH6butigTNkDedNu7jmNrkSKtbSlN4Eu9lStWKq5HnpzWE0deVMWlOSa9Y+hNuvptSwl9IGbdDsXGMfxSpEVapQr8XvzxT6d1cPqUtv7H1Nnf7puuGdVbB+drZCBf43dA3nTbu45rrq4oa1sI/A6mT/5Fm3cePOct4uadB/XO4UUEPhCzbnvXYYZfyjB0S+B3cPqVM03fIr78ypmUWgSUSze9616HX8qyMLM0VTq9uvCCeU1n6IcvyN+EDVAG9V8QV6we0OrB+YECu0i19u0Q+B2sGmhezrlqgMAHiqIsCzMZ0umgSDP0QJGFGYMvy8JMAj+AoszQA0UVdgw+K7X2cU8csx8+gNybOn5K41t2vaGHPtnFgqsg95GIM5CjmjhmP3wAhRbFgqtO7+TjruSJY5VwIyZtAeReEvvdxF3Jk8TEMYEPIPeS2Bwx7kBO4pcWQzoAMqfbsfIkquniruRJYuKYSVsAmZLVVa9JtCvIxHEn7SZtCXwAmRJFxU1cogjkuFGlAyA3snyHqbyvyWHSFkCmFOkOU1lTuMBP+4bHAMLhdqTxKdSQTlYnewAEx/5V8SlUD78sW5wCRdfLFsforFCBX5YtTgF0h6HeqkIN6ZRli1MAwbUb6pVU+NsazlaoHj6TPQAatRrqfebPp/WrQy/oJ/ue1X8+dUKf/f5/acf+5wvd+y9UD5/JHgDSuVsz9Jlp0fnn6eiLrw/tLjr/PO19+v/0D//2P2d7/RuvHtW3dhzQRcsW5LbOvpNCBb6U/4URAMJpNoSz6ZpR/eDRp8+G/ifGhs6GvVTt9W955KA+957VgRd4xX2zkjgUakgHAJoN4WzeeVCfGBuSVB3q/eslC5oWeMzpU6A5v/ovlfEtu7Thnt0a37IrF8NBhevhAyi3VtV6l618s7ZNvFtLFvTLXU0LPC4bXqQ+qwZ6u956EjcriUOoHr6ZXWBmPzOzg7W/FzU5ZqWZ/cLMDpjZfjPbFOacANBOq60ZLlw872xd/6qBNxZ4/OOH1+if/v1Jrd3cubee1xLwsEM6t0na6e6jknbWnjc6I+kr7v52SVdI+qKZrQl5XgBoKki1Xr3AY3LjVfqXz4xp4m9X658fOaQnnj0ZaMFmXvf7CTuks17S+2qP75P0S0m3zj7A3Y9KOlp7/JKZHZC0QtLvQp4bAN4gaLVevcBj+uSMtuw8dM5rjbtzNk7QDi86P/ablcQhbOAvrQW63P2omS1pd7CZjUi6TNLuNsdMSJqQpOHh4ZDNA1BG3VTrdVqw2Wrh1gffvlSTOSsB7zikY2Y/N7P/bfJnfTcnMrP5kn4s6cvufrLVce6+1d3H3H1scHCwm1MAQNc6DQG1mqD9459fzt1+Px17+O7+gVavmdm0mS2v9e6XSzrW4ri5qob9j9z9gZ5bC6A0kqpz7zQElOUbsnQr7JDOdkk3Srq99veDjQeYmUn6rqQD7n5HyPMBKIGktzpvNwRUpD26wlbp3C7pWjM7KOna2nOZ2VvMbLJ2zJWSbpB0tZntq/0ZD3leAAWWpa3Oi7RHV6gevrufkHRNk48/J2m89vhXkrI/uAUgM7I0jFKkPbpYaQsgc7I2jFKUPbrYSwdA5hRpGCVL6OEDyJwgwyh53K0ybQQ+gExqN4ySdBVPUTCkAyB3slTFkycEPoDcyetulWkj8AHkTl53q0wbgQ8gd6ji6Q2TtgByp0iLoZJE4APIpaIshkoSQzoAUBIEPgCUBIEPACXBGH6CWAoOIE0EfkJYCg4gbQzpJISl4ADSRuAnhKXgANJG4CeEpeAA0kbgt1GpuKaOn9KjT72gqeOnVKl4z5+LpeAA0sakbQtRT7KyFBzoDlVt0SPwW2g1yXrRxqt6XsrNUnAgGKra4sGQTgtBJlmjHPIB8Dqq2uJBD7+F+iTr7NCfPclKDwSIT7sOF++Qe0cPv4VOk6z0QID4UNUWD3r4LXSaZKUHAsSn3uFqfAdNVVs4BH4b7SZZOw35AOgdVW3xYEinR9TVA/Gqd7iuWD2g1YPzCfsI0MPvET0QIDuo2Q+GwA+BunogfVTMBceQDoBco2IuOAIfQK6xE21wBD6AXKNmPzgCH0CuUTEXHJO2AHKNirngCHwAuUfFXDAM6QBASRD4AFASBD4AlESowDezC8zsZ2Z2sPb3ojbHzjGz35rZQ2HOCQDoTdge/m2Sdrr7qKSdteetbJJ0IOT5AAA9Chv46yXdV3t8n6SPNjvIzIYkfUjSvSHPBwDoUdjAX+ruRyWp9veSFsfdKemrkiotXgcAxKxjHb6Z/VzSsiYvfT3ICczsw5KOufteM3tfgOMnJE1I0vDwcJBTAAAC6Bj47v6BVq+Z2bSZLXf3o2a2XNKxJoddKekjZjYuqV/SQjP7obt/usX5tkraKkljY2Me5IsAAHQWdkhnu6Qba49vlPRg4wHu/jV3H3L3EUnXS3qkVdgDAOITNvBvl3StmR2UdG3tuczsLWY2GbZxAIDohNpLx91PSLqmycefkzTe5OO/lPTLMOcEAPSGlbYAUBIEPgCUBIEPACVB4ANASRD4AFASBD4AlASBDwAlQeADQEkQ+ABQEqFW2qKqUnEdPnFa0ydntHRhv0YWz1Nfn6XdLAA4B4EfUqXi2rH/ed1y/z7NvFpR/9w+3XHdpVp78TJCH0CmMKQT0uETp8+GvSTNvFrRLffv0+ETp1NuGQCci8APafrkzNmwr5t5taJjL82k1CIAaI7AD2npwn71zz33MvbP7dOSBf0ptQgAmiPwQxpZPE93XHfp2dCvj+GPLJ6XcssA4FxM2obU12dae/EyXbTxKh17aUZLFlClAyCbCPwI9PWZVg/O1+rB+Wk3BQBaYkgHAEqCwAeAkiDwAaAkCHwAKAkCHwBKwtw97Ta0ZGbHJT2ddjtmGZD0QtqNSBnXgGtQx3XI5jW40N0Hm72Q6cDPGjPb4+5jabcjTVwDrkEd1yF/14AhHQAoCQIfAEqCwO/O1rQbkAFcA65BHdchZ9eAMXwAKAl6+ABQEgQ+AJQEgd+GmV1gZj8zs4O1vxe1OXaOmf3WzB5Kso1xC3INzGylmf3CzA6Y2X4z25RGW6NmZmvN7EkzO2RmtzV53cxsS+31J8zsnWm0M04BrsGnal/7E2b2azP7mzTaGbdO12HWce8ys9fM7ONJti8oAr+92yTtdPdRSTtrz1vZJOlAIq1KVpBrcEbSV9z97ZKukPRFM1uTYBsjZ2ZzJN0laZ2kNZI2NPma1kkarf2ZkPSdRBsZs4DX4A+S3uvul0j6pnI2iRlEwOtQP+5bkh5OtoXBEfjtrZd0X+3xfZI+2uwgMxuS9CFJ9ybTrER1vAbuftTdf1N7/JKqv/hWJNXAmFwu6ZC7T7n7K5K2qXotZlsv6Qde9ZikN5vZ8qQbGqOO18Ddf+3uf649fUzSUMJtTEKQ/wuS9CVJP5Z0LMnGdYPAb2+pux+VqqEmaUmL4+6U9FVJlRav51nQayBJMrMRSZdJ2h1/02K1QtIzs54f0Rt/iQU5Js+6/fo+J+mnsbYoHR2vg5mtkPQxSXcn2K6ulf6OV2b2c0nLmrz09YD//sOSjrn7XjN7X4RNS0zYazDr88xXtYfzZXc/GUXbUtTsHpWNNcxBjsmzwF+fmb1f1cB/T6wtSkeQ63CnpFvd/TWz7N7etPSB7+4faPWamU2b2XJ3P1p7q97srdqVkj5iZuOS+iUtNLMfuvunY2py5CK4BjKzuaqG/Y/c/YGYmpqkI5JWzno+JOm5Ho7Js0Bfn5ldoupw5jp3P5FQ25IU5DqMSdpWC/sBSeNmdsbdf5JICwNiSKe97ZJurD2+UdKDjQe4+9fcfcjdRyRdL+mRPIV9AB2vgVX/l39X0gF3vyPBtsXpcUmjZrbKzM5T9Xu7veGY7ZL+rlatc4WkF+vDXwXR8RqY2bCkByTd4O6/T6GNSeh4Hdx9lbuP1HLgXyV9IWthLxH4ndwu6VozOyjp2tpzmdlbzGwy1ZYlJ8g1uFLSDZKuNrN9tT/j6TQ3Gu5+RtLNqlZcHJB0v7vvN7ObzOym2mGTkqYkHZJ0j6QvpNLYmAS8Bt+QtFjSt2vf9z0pNTc2Aa9DLrC1AgCUBD18ACgJAh8ASoLAB4CSIPABoCQIfAAoCQIfAEqCwAeAkvh/vKH4xVKuK9cAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Learning hyperparameters\n",
    "top_frac = 0.10\n",
    "num_epochs = 100\n",
    "learning_rate = 1e-2\n",
    "\n",
    "# Conduct topologically optimization\n",
    "Y = torch.autograd.Variable(torch.tensor(data).type(torch.float), requires_grad=True)\n",
    "optimizer = torch.optim.Adam([Y], lr=learning_rate)\n",
    "\n",
    "random.seed(42)\n",
    "for epoch in range(num_epochs):\n",
    "    optimizer.zero_grad()\n",
    "    I = random.sample(range(Y.shape[0]), int(Y.shape[0] * top_frac))\n",
    "    loss = top_loss(Y[I,:])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if epoch == 0 or (epoch + 1) % (int(num_epochs / 10)) == 0:\n",
    "        print (\"[epoch %d] [topological loss: %f]\" % (epoch + 1, loss.item()))\n",
    "        \n",
    "Y = Y.detach().numpy().copy()\n",
    "\n",
    "# View topologically optimized embedding\n",
    "fig, ax = plt.subplots()\n",
    "sns.scatterplot(x=Y[:,0], y=Y[:,1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
