{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7f7568c3-eafe-413e-b864-28209880d853",
   "metadata": {},
   "source": [
    "# Quantization and Barycenter Computation for GMMs\n",
    "\n",
    "As our target data, we load MNIST images and fit Gaussian mixtures to the pixel intensities."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 287,
   "id": "0aa485f7-1282-4d0c-9f5b-6b758e014924",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeQAAAHiCAYAAAA597/kAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA9ZElEQVR4nO3d23YdxbWA4SljIsjGhyAzcLLhNvf7/R8i97klIXgEEWSTgMHx2heiRalUhzmrZnVX9/q/MTwS7KWDbdCvWVXdfXE6nU4CAAA29WjrTwAAABBkAACmQJABAJgAQQYAYAIEGQCACRBkAAAmQJABAJgAQQYAYAKPNS96//69fP311/LkyRO5uLgY/TkBAHAYp9NJ3rx5I3/605/k0aP8HKwK8tdffy1ffvml2ycHAMC5+eqrr+SLL77I/roqyE+ePPn1/33ChAwAgMHtHap/CFqapgryEuGLiwuCDACA0ekk1X5yqAsAgAkQZAAAJkCQAQCYAEEGAGACBBkAgAkQZAAAJkCQAQCYAEEGAGACBBkAgAkQZAAAJkCQAQCYAEEGAGACBBkAgAmonvYEePqf3/3v8I/x75//PvxjAIAngowh1oiu58cn4AC2RpDRbev4eqj9Hgg2gNEIMpocIcIWud8voQbghSDD5NxCXJP68yDSAFoQZFQRYZvwz4s4A9AiyMgixP3iP0MCDSCHICNpzRj/4fEX5rf517u/DfhMxlv+XAkzgBhBxj0jQ9wS3t73NWu4WdYGECPIEBH/EHvGt8cewk2cAYgQ5LPnEeJZ4ttD83tYI9rEGThfBPmM9cb4CCG2yP1+R4Wa/WbgvBDkM9UTY88Qv3j/x+a3/fbRP9w+jx6jQ02YgfNAkM9Qa4xbQ9wTXY/3u1W4wz8vjzgTZuDYCPKZscZ4tgi3qH0uawQ7/nPsCTRhBo6JIJ+JlqnYGuOZImyxRbCXP1vCDGBBkM/AyKl4rxG2yP0ePULtFWaiDOwfQT64UTFuCfFnF5+Y36bFP08/rPJxPEPdu9/MtAzsH0E+sBExtoR4rQC3ftxR4U79GVki3RNnwgzsF0E+KEuMvUK8VYBbaT5fr2jHf37aQLcuaRNmYH8I8gF5xrgWYmuEP7sc+6/cP9++c31/o6Id/rlq4twTZqIM7ANBPmNrxHh0gC0fzzvWdx8z8+egDbUlzi1hJsrAPhDkg9FOxz0xroV47Qhr1T6v0dO1JtDLn7t3mIkyML85v3KiyZYx1kb46lL1sibXb/veXvt7aA13+GdXi7MlzJYoi7CvDMzq4nQ6nWovev36tTx79kwuLp7IxcXFGp8XjDxiPCLEIwOs1RtqrdZQayZnzT6zZRmbKAPrOZ1Ocjq9kZubG3n69Gn2dQT5AEbGuLQ8XYrxDCHWGB1rS6Q9wkyUgflog8yS9ZnwjHEuxJoIX12+r7+o0fXbR+a30X7j0Bru1J9VLtKaJe3aUrZlb5l9ZWAuTMg7p5mOvWLcEuKRAbZoiXXbx7G9vjZB16Zmr4mZMAPjMCGfgZ5nGov0xzgXYk2Er373S/U1Ndc/f6h+belz8ox16s+kFOnwzzUV5+XvojQxl6KsPfTFtAxsjwl5x3qmY22MLVNxLnoe8bWyxNr8vh0CXpukc5NzaWIuhZlJGdgOh7oObsRSdWuMe0J89dFP1deErn/6yPT6e287MNIPPpYh2qU4W8NMlIH5EOSD856ONTHuDbE1vhY9oRZZJ9a1SM8UZqIM+CHIBzZrjFMhHhnhmt5IV99/Z8RLgc7FORVmpmVgbgT5oGaMsTXEn370n+yvWXz30+/NbzM60g8+niLa1jB7TctEGVgHQT6o1iBrYizyMMhhjFtD7BVgrZZQi6wX61Kkc3HWhpkoA/MhyAdVC7LndGyJcUuIP/3kTfHXc7774Ynt9Y2BbmUJuzXOPWEmysA2CPIBjZyOPWOcCnFrfLUskV470ItaqHNx1oSZKAPzIsgHNGo6HhnjWoifPbWF+ua1LrzWKfru7VaMdSnQqTjHYdZMy5YlbKIMjEGQD2aG6bgUY22IrQGuGR1oD5rI5+LcEuY1pmWiDOgR5IPxmo69Y6xZnq5F+Mnz19lfe/N9/l/elD0E+u5zyIQ6Fec4zF7TMlEGxiPIB9MSZK+l6jDGlqk4F+JSgLU0odbGeZTefe04zB7TMlEG1keQD2SN6XhkjD0CXLOHQIdysW4JM1EG5kaQD8RjOvZYqk4tU9em4lyMP776PvnzGj9ePy/+unaZe5ZAa+NsCTNRBuZBkA9kdJB7puMwyJoY94Q45hXmUBxpz0Nomm8AUnEuhZkoA/MjyAcxYrnaOh33xrgW4csXN8mff/vts+LbhUbEeQ25SMdhtkzLpSVsogysjyAfxOjDXK3TcW+McxHWKIV6r2EWScfZEmaiDMyJIB/E1svVtenYEuOeCKf0hHnRGujRJ8XjOJfC3LKETZSB9RDkg/AOcm65uvUgV+40dRjjUogfvyw8BPhX7765LP66R5gXb75/usqp8NTHTSmFuTfKIvfDTJSBMQjyQViD7LlcbVmqtsZYE+KcUqBzcbaGuZX20Frp84njrJ2WtfvKHlHmNpuAHkE+iFKQRyxXa6fjlhj3RDhl6zB7nRjPfU6WMGumZaIMbIMgH8DI5eqe6dga41KIT59fZX8tdPHqOvtruTBrTmlrA+15uVZJ6vMphZkoA/PTBjn9NHScpfh+1SL1ZxovtHvGi9PnV+oY116fC77m8/j46vtsbJdfa4nx5Yub6g/tx3zy/PW9b3zCb4jCLYTw7yr8hureqkf4zVdhaz5eWRFJb4fkLruLaR6OApw7goxmmsNPqVhaQqx9254oi9z/hsIjwh6vjz+H0VGOD/wRZWBdBPlAUl8sW8TPOBapP9dYqyfGNb171K0hFum/pCsX59K0/Ozpm7swE2Vg/wjymQu/MGukbiWp3Tv2irH1/Xhf/9z7/mvfOOTCHEpNy0QZ2DeCjKzS/rH1Wt1SRP/7+Zf3fvS8v96lawvt8vTjl2/v/cj9XO39p6blRRjlJczWKIc0UU4hykA7gnxgmhPWvTTTcS3GqZ/TBNoaZU+1m53UYmt5m9K0XNtX1kT57ueify9qUc5tkViiTJiB3xDkndJ+0dNKfYFeeO0fx7TTcCnMluVrrym5985jGnGYU9PyojXKdz9XOHk9OsoiTMvAgiBDTfsoQs10rI1x69uMmJJLS9SaaXi5dCv+UVKalnujrL0ciigD6yDIuCc1QcVS+8eW6bMlxqW3XWNK7p2uS59jLdCpKC+fjybKC88ox4gy0I8gw2zLRxh+8OqrTT6u5dnMMeup8FSYSxO45qDXInVQb+SNQ0SIMqBFkHFP/IACLUuwvKOauq1mz+00c3retkVtWtYsXy9SUV7rciiR2yhzAhsoI8hQix9ukBPGsHQP6rWtHVQPvVHO3dFrUTt5LeIXZRFOYAMlBHmntDf1n1XLlJx6G8t07GGLqLdGedF78jq2VpRFmJZxXggyROThQ+1D8fN31/bBq6+6l7lHh3TkNwEi+X3lRSrKlkNeC6+T1yJEGbAiyAeWeoTeTDSRLb3Gezr+8fq56nGMWy59Ww+IaQ95tTwdiigDvgjymQifdRs/B7dHGLBcqFr3kb0Of9UCGodY+4xkby13KPM45LUgysC2CDKywgfeW2kn1VR0NUvU2ulYE2PLz2vfr1YpwCOi3Hp7zZmiTJhxVAT5zF2/bftXwHotsnZKXus6Y83ytHVSHr2PrGU55LWoHfIS8Ymyx2VRIkzLOCaCfCDfPvqHy/spXYucuvQpt2wdB+ri1XUyzMtErIlx6n28++ZSPR1r94nj16feJvX+U59H6ZsRr29ANLcK1V6fHCrd47wlyiIsYQM5BHli//7576t+vNRJ62XZOnXSOpyStVEWyYe5JPc2uRDHsbSGOCX1PryinFu6Lykd8Gq5FCqkfVwjUQb8EOQds16LnDvYlVq2tk7JJbml3Fxkl58Pf6Tep2UqLnnz/dMHP0pSUY4/bsvyde+0bJ2SLUvXRBkY7+J0Op1qL3r9+rU8e/ZMLi6eyMXFxRqfF35V+2KT+sIVf6HTPBd5+SJ77wvvr1+Qwy/SyzQVfjEPv8jnno+88HgKU2+IrfvfqaVekfTkGf+e49+v9bKlmtTyfSj8cwn/PJY/g/Cbq3AVJDzQF35zFq6ixN/Ixaf3w28A734ucSlebatF+43n2itKgNbpdJLT6Y3c3NzI06f5rz9MyLintGwdCr+Qa5auF7nJVqN3KtZMvym5yVmzr5zaRx/J+g1Pbul6xKQsYj+BLcLtNnE+CPKZ6V22Dqeo3NK15tpkTZiX15RCrN0r1oT45vWT6nK8V5TXuse39fnJGqOizBI2zh1BxgPaKTlUClXpmt04uqUAh7ym4iXCYYhrYU79XjWHvWIeUfZeAtdMyTGvKIuwr4zzRpAn17IvFu/J9dxC0zIll8K3TLO9N9TIvQ/tVBwGuDYNl16Xet+lKJcOtnkrLVtrpmTr0rUIUQY8cKhrB7wPdonUD3eJ/PZFN/xCnDrgJZI/5CWSPvy0SB38CrXcaSsXYi/xEm/t91s76LVomXZb7lgW/5mFf149h7xE/A56iZQPe1muMOCwF7amPdRFkHegJcgi+tPWIr9FORVkkf4oi5TDbGE5PV0LcelJVqnrcxepfdfcaXORcVFuvYVo6sS1SD7IIn5RFvE5gU2UsRcE+WC2npJF0lGOo1WbHhctcfYIcc+jJFOBHjUtt8otjZeCLDJ3lEV8wkyUsRWCfDBrTsm3/799UhapT8u9cvvVI0Ics4a5FmWRcddnLyxBFmmLskj5OuXbn7v/z0QZ54AgH8yIIIvYp2SRtiiHegLtHeLa6fHU/Z3vfs24OjAyzB6n0nNTskh7lEXm21cmylgbQT4gj2Vrkf4pWSQfZRF9mBdxuLQ377BGuOdxkova71XEdsDNK8reQRYpRzn+s9xTlAky1kaQD2iNKVlkTJRF7DegyNkixDHrNyGlJWyR/GlzTZy1dz6bPcoiD8NMlHEEBPmANNdVtkzJIvml69t/ru8ni6SXd60nlWOay5VaQ1x6gMaidEOM2u/XYwk7lIqz5TakrddvE2WgD0E+qC2m5Nt/1kVZRDct9yrtDedCrAlwTcvv13o5WC3MrXLXdNemZJGxe8q3P3f/n4kyjoQgH1TrlCwyLsoiulDd+7WGSLdEWEQf4iUc8e8txRpm67Qs4hvm0g1WWm+uUoqyyHr7yrkoc8gLsyDIB+Y1JYv4RllEt4ztpSfCqft155QCXQtz70nsRU+cW+52JkKUAS8E+cA8p2SRviiL6MK86A10676wJcA1Ld+IWKZlkb7bjYa09w63PCHriFEmyBiJIB/cbFEWsYU5FMbLeiLacxJewpD6vaVYw9xyVzOv243WWKZkkXqURba5VpkoY0YE+Qy0Ll2L1PeTRepRvv05vzBreUQ4dbCopBTp2n56zyVSizXCPDrKIus8mIIoYzYE+QyMnpJF0o/N00zLIu2HoxaWU9G1CFsDXKL9/fYsY4v43gdcIxdkEdsd0jyXsIkyjoAgn4lZonz7c+1hbuEd4eWLf+r3ltIS5pa7mq0V5lKQF9oHeew5ygQZ3gjymegJsoh/lG9/vr4H2xJpr6Xo1M0oNCy/X+9lbJHyPcA94qwJssg2h71arlNmSsYsCPIZWSvKIv5h9jAywjmt++mWQ18i9jCLtMdZG2SRcUvYRBlHRJDPzNZRFikv9XrGWbsU7R3hFE2Ye6dlkbYwi+jjbIlxaMS0PEuUCTK8EOQzowmyiF+URdrC/PC1+VC3HMSyRDi1F5mS+30utCsEI5axRfSPs8zFuTXGi71GmSkZayHIZ2hUlEXs0/LCEucWIwJcY/lGZK1pWcT2nOklzqUYa14T8r6zl2eUWbrGlgjymRoZZZH2MIv4xXmLCKccIcw5PdO058MpZogyQUYvgnzGPKIs0hZmEV2cR+kJ8PIFvPR7S9GGea3T2IvWMGv2nWthHhXl0s1DWLrGrAjymds6ynevGRznlgDnHuWn4bGn3jsti/jvL4uMOZ1tebayyLxRJsjoQZDPnDbIIn1RFvGbKDXWDnCONcx7WMbuuZZ5yyizdI3ZEWS4RlmkHmYRe5xHGBHgnN4wzzQtj7y5iCXK2oNe2iizdI2tEWSIiH+UReYM85oRTum9kcoMl0h53oqzJcyaKGsPebF0jZkQZNyzxbS8GBFnjwDnli5TtL/fnjAT5XyUR+wns3SNtRBkPLDVtFxSirXn1GuJb03rCsEeojziSVKpMI+K8hpL1wQZVgQZSSOiLNIf5hE8I5zjdc12T5hr+8qWKK/5aMc1osyUjBkQZGRZoixiC/Nii0CvEeAca5hniLLHlHz54kbefvtM9VrvKI9cuuaAFzwRZBRZoyzSFmaRMXHeMr4lvWEeuYStjbImyJcvbpI/X4tzLcqa09dMydgbgoyqliiLtId5RrUvut7fhIyYlr2jXAtyLsaxUpzjMOeizNI1joAgQ6U1you9xFmzBKnlceDNe1ruibJ1StYGeZEKs2VSbl26thzwIsgYSRtk+/PtcCj//vnvXV9U/vXub66x87R8bt6fn/Z9fvvoH9kv9KkTv/EUFwflwT5pGKJgYhR5eB1vKF4Wzj0+0Usq4LUpPPymIXVHstC9b0yCb1pKDzOJvyFq3Vbp/YYWCBFkiEj/d/qj4jfz52EJc8qaUY5vTxlHOVa6Fab2EFdIE2XN7T3DyT9cFYj31lNabtm6lxUgHANBxp3eaXmxVhTDj7PlNwOaj71VlENxlEOjp2QRe5RzU3LqRHlo7SkZ8MIeMpLWWIrTTh9bT91WrQ/riAPhtafcs5+cW1q27iOHanvKuUuhLAe82EvGTNhDRhevabkkNeHOMvX28JqWvSZly37yDHJL10zJODqCjKI1wnxEtW8meqIchjmO8r3XZqJsWbou7SW38lq6XuT2kuMVhMWIvWQOd8EDQYbKEuajxjn8/Xn+HkdEOXZveTZ+GpIiyltMyS0nr2O5A14pTMnYA/aQ0WVPk4FXaEfcUCUVgdq1yqU9Zc1+culBFLW95J495FC8n2y9YYjlZiG9e8ncThOt2EPGKmaenEdNva3vzzot105gl/aUNfvJpaXrmpZLnzR6puSa3imZS6AwGkGGm63jPCrApY9n5RHlkGeUwwl0jb1kq9xTqxZrXJcMjESQMcToMMbx3fqbAIveKNdOX3upXZvsMSXXlr7XPHEda5mS97SFg/nwLSKGqwUr/CI249K3xvJ5a78g/+vd37Jf3L999I8HMfjn6Yd7y6r/fPvu3oR3/fa32Fy/fXQXoeufP7ybFq9/+uhuivzup9/fhey7H57cBe7m9ZPsJPrj9fMHS8pvv33mtp/s4eqjn4o3Ryn57OKT6ooEMBITMja39ZTryfJ7GDkpr7l03Tspx0HvuQQqlFu2Dqfk2rI1UzLWRJABZ1tGOWSNcs4at9XsZVm2LkmdbAfWQpCBASwTv/VuZKVl1Z79ZO21ySOmZE8zPHSCKRktCDIwUG+UczcPCW2xdJ3iGWWvZesQt9PE7AgyMNiIKI/cT84Jo5y7DGrLSdly5y4PTMnwRpCBFWwR5ValpevZo5wy6nAX4I0gAytZI8qhEUvXsRmi3LpsXaI93MWUDE8EGVjR6D3lWS6FErmNsvaHVe5a6bWXrQFPBBlYWe/11nGUvS+F0tji1pq5u3Zp5B7LWFq29jrcxZQMLYIMbEATZcvlUKUoWy+F0i5dz3i/6y3w0Al4IcjARnqiPMt+8ixRzt3b2rpszZSMLRFkYEPeUQ6tsZ8ssl2Ua09/immWrVsxJcMDQQY25rl8PepSqNgsUR6BS6CwFYIM7Jj1+mSvKTnFM8otb6tZtm6ZkrkECmshyMAEjrCfnPLj9XNzXC2vty5bazElYwsEGZjEmvvJoZH7yQttZEctda81JdcwJaOEIAMTWWs/ueWpUB5RLgVXE+Patci5ZesWI56VDJQQZGAyrTcOGb2fXKKNskg6zD2TseeyNXvJ2BJBBnao53GNIc9DXpYoi/wW5pEnsq2Hu2LsJWNNBBmYkOd+svetNa1R1jxL2ZPlYRMpTMnYCkEGJtV7z+tQ636yR5RF6tNyL+2ydcuUXMOUDC8EGdix0UvXJVtFWfOQid7DXZaHTuQwJcOKIAMTW3PpuveQ19aTcmlKZi8Ze0CQgcltdWtN636ySP1uXiI+UbZOyS3YS8baCDKwA1tcClVSinJIczevVqkoMyVjzwgycBBrXgoVm+2Q10I7JXOPa8yAIAM7MfPSdWytKGumZK9nJYswJWMsggzsSGuUPZeutfvJa9HsJ+fwJCjMhCADO+O5n1yivRRqxuuTtVNyi9plUEzJaEWQgQPyWrpu2U8O7W3pumVK1mJKRg1BBnZorVPXJXtYum65g1f2NYabhTAlowVBBnaqFuUt7+IV2vouXmGUNUvXXrfUBKwIMnBgI05dh2abkkXan5lsXbpumZJZtkYJQQZ2bLYDXrOIo9xywxCN2mVQgMX8/2UB6DLjYxotPr76/u6HRSnKmhuGeFwG1TIl43wRZGDnRj6mMaS9raZG6y01wzjHP1K0UdYsXed4T8ksW58vggwcwKgDXtopWaNlSrZMxdoohyzXJzMlYzSCDJyJ1gNeodyUPMvhLk2UNZdC9R7w6sWUfJ4IMnAQow54aU9czyK3hJ2Lsue9rhdMyWhBkIEz4j0lt9y9S8t6iEvz9poop2xxbTJT8vkhyMCBzDIll54E5e3yxY1cvrhJ/pol6kuULVNy77I1UzJCBBk4mNYDXqOM3EcOQ1yKchhmzX5y6oCXdUrmdpqwIsgAkkZcAtVzPXIoNxVrp2XL0vXoKbmEZevzQpCBA2qZkluXrUfcuSu31FwKbvy6mtqkPGJKTmHZGguCDMCN98GukCay8evjt9HsKdem5NTNQrRTcsuyNVPy+SDIwJnSTMk9y9aWg12lm3dop+LS24csS9eWm4XkMCVDiyADB+V5S83FmsvWWo9fvpXHL8vfHbREOea1l8zhLuQQZAC7FYa4FubWKbu2l7wGlq3PA0EGcI9l2VrLculTOL2WIpqLbynM4fvLTckL7Ylrr2VrgCADB7bFNckjD3YtakvUy2s0r1ssUW45cX33cyxbowNBBnBocZRzU3JKy3XJrTjYBYIMwCQ82OX5jGQty9Sbe5va0rV2Sl6sdX9rHBtBBg6u5bR17SYhW2mJce/blh48YVm27r2VJge7jo8gA2dOs4/cerCr5yET8YGuXFBPn1/d/agJ30dt6brnEiigBUEG4GKLa5HjCGvDXFK6ScnoZWv2kc8bQQawiZvXPg+aSNFGWXvAq/a85Fi4bA1oEWTgDHjvI1ufj7yFXJQte8mlZetF7SYhXP4ELYIMwN0a1yJr9C5fl7TsI3ODEJQQZABdNwixXvrk9Uxkrdq+cupuYK37yKNx0vrYCDKAXeqZfnPL1r37yFyPjB4EGYCKxz2t90izj5zSerCLk9bniyADZ2LE4xgtLA+YiPXcEGRhXba2WPvpTzgmggzAzZbPRfZW2kdeaA92ed6xC8d1nP96AKxKe+mT5m5db75/Wvz1d99wYS+OjyADyNryntZvv3222ccGtkCQAYiI/7ORR1+LfPHqeuj7B9ZGkAF0a30MY+r2mT9eP0++tnfZmoBjdgQZgJrXpU9r3xykRW1fG/BGkIEzsvWlT7NivxozIMgAXJUufSpdi5yaSONQxsvW2mVolquxB/nHkACAk+ufPmp6GEOrWoC5jAozIsgANnXz+smD21P+eP28eF/pENMvjoIlawDDeFz6VFu2Bo6CIANnxPvxfdq7daWkTlprTzb3RDl+2zD4qUuuUpdmLTR3IQO0CDKA6Y06Be31fkffBAXngSADWFXqpLXlBiGLlilZ+zapSb3n2ulwJeFcH2OJOoIMYBd6p9nSUrWIfrla+xjJ1ruX4XwRZAAuvAJU2ke2HvB6983l3Y81tDx+cssHeGAuXPYEYBWpa5G/++GJfPrJm8xb6C5/aoltaTrWHizjQBe8MSEDcNcyKZZOM4e2uM3lHu69jf0jyMCZqF3y9IfHX5jf52eXvy2yXQWD6tXle/P7CoVTampvtyfKmr1jEf03CJywhheCDOBs1GJcO12tPdDVw/u51NgPggycgdYbgrx4/8d7//zZxScen46Zx5RseX1tOk7tH4fL9OEBNy55ghZBBtC0XL22lii//fbZ3Y/a+6sd5lpjOsZ545Q1gEOwTMClGIfTce0wV+/+MZc8IcSEDByc9/2r15KaWHNTck+MrWqXO428Ici/f/77uHeOzRFk4Mzllqtr+8fhCWsv2pPNrVIx1kzHteXq3GVePQ/fwPkhyAC6aS55im8K0qp3wg1pbwKyCKdj63K15kAXJ6zPG0EGDszrdHWrq9/94vJ+Yi1Rrh3iap2OW7F/jBhBBg5KE2Pt6erScvWV422inz3N30azh+dULXJ/Os5d7gRYccoaOFPaveNYae84XK4Op+NwufrTj/7z2/8v3MfaQy7ELdOx9d7V7B/DigkZOCDPk9Wlm4F4TsdWtalXG2Mvpek43j9muRopBBk4QzNMx7F4ufrJ89fFz6XEskRtnY5zy9Whlum4dqCLS56OjyADB+N5kGvkdOyxXJ0KbynGpaVqYGvsIQMH4nmQK9Y7HY/meXDLOh2zXA0PBBk4iJ4YrzEdlw5z9Zyu1oa4Nh17PfOYw1xoxZI1cOY0MV5z71ikb/84pfrgCKcYt2L/GCIEGTiEta45FsnflUsb43A69ojxx1ffm14vUt479lyujrFcjRKWrIGd816qDrXEuKQ3xrn4fnz1vfvNP1rEy9U8/xgWBBnYsRn3jTU3AbHEuGUCtth6uRpYEGQAIjJ237h0idNWIU4Zebo6h/1jLNhDBnbKczoeuW9cOlG9RozXuva4drqa/WPUEGRgh2bbN54hxltM1IAnggzsTM99qnv2jTWHuHqvNV47qj23ymS5Gt7YQwZ2RBvj1HTce71xjuZuXLX7VO99umW5Gh4IMrATPTFOWWvfuLRMXQvx5Yubu///9ttnxdcu72+Gy5+AFgQZ2IHeGG+1b9wa4zDE8c9pwuxF82SnGMvVaEWQgcmNiHHvvvGIGKcinDI6zOH+cfY1hrtzAVoEGZjYGjEOp+PRMe4JcertvKIcHuiyYv8YXjhlDUxqyxjneMb48sVNc4xL70N7QGzEHbpYrkYPJmRgQmvHOFa7xMkjxiWPXz5cE373Tf47htK0/OT56+rTnmpaHyYBWBBkYDJbxNiyVD0ixqkA515TCvNsWK6GBUvWwES8L20SKR/gErHvGy88Yvz45VtVjOO3qfG4rjk8YZ1T2z9OYbkaOUzIwAQsd9/qPU3tcYjLK8atHr98+2BS9jjkpTlhncOjFtGLIAMbst4Gc40Yh3KHuBaaGHuGeGvsH2MkggxsxGMqFumPccxyiGvhFePT51fJn794df3g51JT8gjaG4LEUvvHLFejhCADK/OaikV8Ytx7iKsnxrkAp16XivKWWvaPgRIOdQErsk7FW8c4lNs3XoyKcen1HsvfPTcFATwxIQMrGblELTImxpZDXDkeMfZy81p3MxDNCesQB7rggSADg41eohYZG+OQ9RBXHOPeENeWrj1vpxmzHOhi/xgtCDIwkNdULLJdjFsPcWli/N/Pv7z3zx+8+urBa2JxlL0Od/Vc8gR4IMjAADNMxSK+MbYc4mqJcernNIHeAge6MAJBBpzNMBWLjIlxaK1l6i1pLnmK94+5XSZaEWTAySwhFinf9MPzucY1uRinpuPc60ZMyaUnPVkPdGmwfwwNggw4GH2CWqR/KhYZ81zj0nScoo1xi9SBrtKTnrSXPIUHuliuxigEGejk9UAIz6n49p/7YhwauW/ca6YDXSxXowdBBhptcV2xSH0qvv25/POMtTGuHeIqWSvGFtprkBct+8dAD4IMNNjjVCxij3GJdanai+YWmj9eP8/+GvvHmBVBBoxGxthrKhbxifHMS9VrYf8YayHIgNLWt75c9EzFIr4x1tgyxh4HurTYP0YvggworD0Vt4RYRD8Vi5SvM259gtPa1xyHB7pKt8ws7R+nDnSttX/McjVCBBkomGEq7l2eFmmbikW2j3F8DfKW+8cty9W1/WMgRJCBjJEx3iLEIuNinNOzXF27IUhtOtYsV9emY8sDJYBeBBmIbB1iEd99YpH6KereGHtPx6kYt07Hy3K15+lqze0yOV0NK4IMBLaO8UwhFtEd4NriXtUe03ENp6uxNoIM/Mojxtq94tmn4pw1rjeuTce5O3OVpuMUDnNhNgQZZ4+p+PsHn2NqMp5xqVo7HS/L1anpuPdmIC3L1UAKQcZZWzPG1qnYO8Qi68bYeqBLE+PcUrX3dMxyNbZAkHG2tlyibn0QhMi6IRaZczLOKU3HHliuxkgEGWfHEmIRW4x7l6hbb+whUl+aFukPsYg+xh+8+qo4JZcua0rFuHc6Tl3qtCxXt17qxHI1PBFknJUt94u9puLZQ5xSu6a4RhPjZToOY9y7dxwuVzMdYzSCjLOx5X6xdiqeMcQi7TFuCXHLvrF2qdpz75jpGN4IMg5v5BK1iC3G2qm49b7TIrYIi4wJcatSjDVKS9Wh3huBACMQZBza2vei1sa45dGIniHWXE88W4z3Ph2zXI0agozD2kOMe59RLNJ36VJoiztuLTxinLpFJtMx9oQg45Bm2S9uWaLOTcUtt7mcOcIi9dPUIrYYh8IYj56OAQ8EGYfitV8s4hvj3BJ1y/J0b4i3jvDCGuOc1KnqUOoyJ28sV8MDQcZh7CHGnlOxJcQ9Ea7dccvjJLWGZd+4tlStnY41T3UCvBBkHILXErXInDFuCbE1wj3PLvbgtVQdSi1Ve+NSJ3ghyNi9o8V4jWcTe8b3v59/aZqSNdOxZqk6NPN0zHI1tAgydu0IJ6k1U7HI/RhbQ7z19FtTut7Y8yCXlvZkNdMxPBFk7NZeY9yzRG0J8ZoR1k7Jo6fjlJbpOMZ0jDUQZOzO6MNbItvF2DIVp0I8+yScwnQM3CLI2JVzifGIEP/y8s/V18Q+/Oav6tda95JF1r3MieuOMTuCjN04WoytS9SWELfEt/R+LGHOsV7qVLvMKZQ6yBUKY1ziOR2zXA0rgoxdGH2SWmT9yXjREuPRIU69X48oh7SXOYU8bgLCdIxZ6b5tBDZ0xBhblqm3jrHFiD3s2nTshXtWY2sEGYcxW4xDR4jxDMFPsSxXX5dv7a3CcjVGYckaU9NOxzPGuHad8aInxrNGEoAdEzKmNWuM771G+fhEkfQhLmLcxnLtcQ77x5gNQcaUZo5x/OQmEdvtMFPPLy6ZKcYzfxPAtcfYO5asMR3r5U05o2OsfYRiz+VNLTG+uPq/4q/ffZzrv6heZ9VyPfIaPPaPgZGYkDGVWW+HKZLfN154X940MsbLay2v134Oo9TuzpWivf44xnI1tkCQMY2ZYhzb4vKmmGeMZ5B7dnMvj/3jXl6rPDgvBBm7s0aMWy9vCqVOVGu1XM/bsgTd8jatNwixPp85llp1SP0dpPb4t0CUYcUeMqbgcYhrqxiPui1m7MNv/lqdkk/Xf0lOyl77xdoYt+wjX764MT/paZQX7//osmy9/HvNdcnQYELGIcwU41AuxiOdrv/y4IcHj1tnxt945Jatwz+3npWG3OVq8b8HWrVT/TlMy9AgyNic1yVOsVqMY7kY52iea5wzav90lJYYj3wUZOo679RBu5n8z+/+9+4HkMKSNXbBulStibHmxh8ibbfFFGl/nOJszzT2fqhE7PHLt8VnIs/mD4+/6L4eORVllrVBkLEpzbTQG+OUlsubanfiOqLeGKf2kk+fX2UfxRjuI3989f2Dpz49e/omefmTxtXl3Ncia/5bINrHRpCxmd6lO22MPa411uwbe0zHMxg9Ebd48vx19qlPn370nwfXI1/97hfV5U97k/tvhlAfA0HG1HLTce4QV2ymGGtssVw9OsCaE9cty9affvKmek9rrc8uPnlw+8zSSWuPZWtPLIEfA4e6sIkRB1tG3vhDpO8Ql4jPQS6PeH74zV/v/dhCaWUg/CYmdd/v1Gn2RepgV89J69I3fn94/EXzqes1hIfIOEy2D0zImJZlOtbsG4c0lze1sEzHqSh98Oqr6pQ845JyjXUvOcW6bJ1S2kdOTckay7+nM03MOXGUmaLnwoSMQ2p5lGKsZ6k6trfLnNYSflMS/hnVpuRF6XCd5tK13ik5tEzMM0/NMabnuTAhY3U9J6tbLnHaYqm65yYgmil5jzyfAlU6bX310U8PHsV4dfn+7kETLVOy9c5dmijPNlFzV7HtEWQc2qg7ccUxLi1Vt0zHYbi84xxHcev4h0vX4eGu1K00PZata7yiXNMzSY+MOWHeDkHGqtaejkMeD4xIKS2perDEuWUCLb2Nd6xbp+TUNcmL0mnr3OVP4ZT82eXjB09/KkVZZPvHM6b+G/GONGFeH0HGbtUOcrXsG4da9o21S9Wt1x97LflaPt4aE7RlSl70LFv3miXMoTjSXoH+n9/9L1FeCYe6MJWe647D6XiWfeMjHOb64NVXrt8ItAZ++ean9HdgeRxj+O9IamVFc3L/xfs/qg99rc3zgBkHv9bBhIzVeP4HPXKpumXfOGXWGC+TqHVK95yWLUvXmik5tWxtnZJzS9ciUr0cyhLltafqMMoe9+BmWh6HIONwStOxl9alapHxt8vUXtsbvk77OY1cws4tW6f0HO4qnbhORVmk/RrlFI+JujXqHnFmb3kclqwxvZ4vYCOmY02Mc9PxqBhfvLq++9H69mtrCbv1muTl77l0TXL8Ddxnl4+zS9jhjy0tS+XhD6veJW2WsP0RZExD+8XB+pxjrb3tG/dGOPf+akbuJ2tuFBJL3UqztpesOeRX+/cqDvRskdYiyvNgyRqHkjtZ7XWZU+3WmKUYe03Ho6fZ1j3mVq2XQqWWrS0PnNDcLGSJcmoZO2VUlFuWyy0nwXtu/8m+sh8mZEyt9p1+z3TseWtMkfEx9pyGtR9va6k/U+t13+E3YPHS9b1v2i7z5w+8VmFa9Uzjlom5dRmbSdkHEzJW4fUfbOkLkXU6Xoy83tiLNY7aRxm2Lq+vdbgrZL0mOXe4q/as5Nq0HNNOz97C/xY0E7R1YrZOy0zK/QgyptDyXbl1aum9I9dizaVqS4itzxNe3qb0uV+8ul5l6Xr0snXqEqi7X0tcCrV8c5e753Vo1PRsCb0lztowE+X1sWSNabWers5NxynW6XitGFuWp999c9kU4/Dta5/LlrTL1rXDXdql6/s/v85ldCnLae/4R/XtlEvamv++9vTkqiMgyDgEyxfNnuk4Vrq8aY3JsifEI95P7JeXf5ZfXv7Z/Ha5P7vUN0SpU/C5xzLWolwLc+rH2rRx1oR5RJTZT27HkjV2o+cEa2rvONQyHadi7BXh3D7qFlp/Ty0htkg9cEKzlxwuX6f2k8Moa+59PTrKpWXzMMq5Je7a3cY0T7FqWb6GHRMyDsuyXH3uRl873Rvnls8v/PuNV0XiSTl/v+v3xcl5DdqpvDY5l76h9Z6UmZLbMCFjl1oP0tSWq2eYjvdgxAlr7cGu2nOSwym5+GjG6KBX/fT1elGuTea1Q2elW4CKpKdl7+c9w44JGYfS8kUzdRio5NxjrJGaiK1TsubP1HJNcvLuXdFWxjItl6bmNYSTeenfac3EnPy1zLRcm5SZksdiQsbu1fbwatceLzTT8ZoxnmkfeQ80U/IS5dyecmyNKJem8rvPI4qy5TIt6wMzapdFsZ88DhMyptT7RBztF1LrdHxEa9zuc9Gzl+y5z13aU15bPJVrJvTc5JybmEsPzBiJKdmGIONsaC93GjEd//fzLx/8OKpadEu/XnrYRCj8ewn/vsJVjvCbrdSKSCrK8Q8vre+zFueWMD/4uYbla5aux2DJGpDyE51aaIKrOcRUW7YOv0louZbYOh3Xfl+jL3XqsUQ5POSVWsIOjZicNe8ztYQeRjl1qVbqIFjqNqCpJezS8jVL1+thQgYa5YJlmX41r9XeZOTxy7fVZd3lNaXX5j6eZ4xHhjs3JS800/LWatN0amLe+vIs9GNCBja2hK53Wl707LXmwj/LZPz45du7lYDw8qf4JiG5A16L0rS8leSDMIIox1Nz7jKteFoeOSXDFxMyzo7HDUFaJsga7bQ8QmkKL31erbfGXN7WU3wJVG1SFrn9d2GWG8R8+tF/7v2Ipfahc/vL8aSs3U+24l7XvpiQgYzaga4RPPaVLWqBr8V4DV6/3yXK8bQscv+btNyNRHKv75X7eKX9bc1NTTSTciw3JffiKVA6BBmYjGYJOwxpa6xmiPEvL/8sH37z1wcfV/soxviuXaWl60UpzCLr31I19fE0B89aohzLXaMcY9l6HQQZU/r20T+6r0Vuobnz04jl6hRtmOLPJxVoy1L3LPvFOeE+skYqyiL1MGu1Xste+rjaE+GpKIvkbzai2U+2TsmctvZDkIGJWabFRes+81YhTk3JFrUpuWarm8OkPq7l8NkS5mVPOTcttyxdpzAlj0eQgU7W6dgaIO0p7Faaz3/rqdgqtXQtIslJeYTSde2lzyE3tacekhE/VrKktnQdG7WXjDKCDHT64NVXw6Ms4h/mPYW45WBXalIOQ9ka596byGi+OchdqlU6cFa6J/eD10ZTsnYvGWMRZKzi3z//fZpb6H33w5PswZ0fr5+bniBkFQZu+f8tYW6JsuWbBkuIL67+r/jrp+u/qN+XVryPnHokY00c1iWQ3ndt03z83B63Ncqh2uMkMR+CjN27fvvwOsvrnz/sflLP22+fPXgO8gg903LKEuuWqV2rFuH4tSOirGHZT+4NcfiNnGUPO/zYqRPhpShblq0xP4KMKfzr3d9MNxn459t35hsbfPfT77N3Y3rz/VP1F+SLV9dDnoLUc7ApNCrElghvITclWw951WhWUGqvyX0+vXvdxUdJGveRsT7+djCt+ESn5pBJ6guOdl+tpOXBDVY9d70a+fEurv6vO8bWt2+9jCy3ovHx1ffmrYjlbeIfHmrvy3I/7q1v+Qk/TMg4azevnzz4gteyj6w52KWNree03Po5zD4Ni+SvRy7tJ486HxB/I6Ddz14+n9TEXLsftwb7yPtCkHGWSge7NEYsW4fCcPbE2TpxjwzxmnvJLYe8LO+79zXx51YK8yKOcm4v2XLaGnMhyFjNyJPW2oNdpX3klLUOdpVoT2P3LHfvYSLewqi/e+1+d3yXsdZJOaX1BiEYhz1kTK10ZyDv6yZrB2nW2EfewlFj3BvTrb8RG40Yz4cgYxqa++F6HuxalvtSE0c4peSWPuMbVdSuDf7wm7+al5+Xtxm1pzxrjEfdlUzr6DH2pPnvlic96bBkjV0LL3/yWLauXf707pvLe49ijPeSNYe7RsR1ljtqleT2j0d9s9Fi9hCH3zxqbxDSIvWNL/exHo8JGatq+U555OVPLVNyvHSdmpS3nvA0Zp2OU1ofMWkJ7OwxLglvDhL+ex6esC5dg8xtM+dAkDGVEY9xS132Ubq7UbiXnDv1WouyyPbLriUzx1j75+a1p3/54qY7xo9fvr374WnUwzC89o9ZrvZFkLF74Xf3tS802ik59YUw3kvWBGHmKK9py+XqXGw9Qxz/3NGwXL0OgozVbbFsrZ2SWw94ieSXVUcvYc90rfEIrcvVOaNC7KF0HbLX5U6xeLna87GLTMc2BBnTaVm29pqSQ7nlQst+cohp+b7UdDzyz6g3xJZl6dFTcvjva+/+casR20vnjiDj0EZMybFUlLealoEa6/4xy9XruTidTqfai16/fi3Pnj2Ti4sncnFxscbnhTNQu2tX6ulPL97/8d4/f3bxyf1//vUSqPDyp6vL97/9/+ASqKuPfhKR+zfnD2+nudzjOrwMKr4Xcjhx5aYi7S02Wx+oEBr1CMVWlr3j1DcquW9stAe6clsMnieq4793zeeW+rzib/rCFZrc5U4t03EYZM1ydS7ItQmZ5erfnE4nOZ3eyM3NjTx9mj+ox3XIOKTwmuTrt4/uoly7Ljl1j+vw2uT4wRPhrTWXL8TxF+glKrUwp4JkjXQYulqcw1h6xrl0v+rcIa74994aYu39q8PX9cbZetq7J8a9vKZjlqvHYELGpkZOySL1SXmZkkXSk3L4JKj4hiFhmFNf1Ev7iL0Pphj1zGORtjhrHhoxcir2epDEyGuRNSEWKce4ZzqOY8x0vB7thEyQsSmPIIvcj7IlyCLppWsRe5RF7GFerBXoLe7opZ2KRdIxXiPEKR5xLn1+a8b49p9/+/8jYyxCkGMEGbux9ZQsot9PFmmLsoj95G1rpGeJsyXEIrYYjwxxiiXOms+ttEQt0h9jkb7pmBj7IsjYjdFTskj/0rXI/SiLtIdZpP2yGGuk145z7UYfHkvUa8fYU20qFmmLscgcS9UiBDmFIGNXZpqSRfyivKhNWKMDPXLPWXO3rS2m4tylarm/o5Fyn0trjEVYqt4TgoxdGTEli4yJskh5CVukPcyLlkCveXmVliXEIu0xLl0nXjMy0NoQi4yPsQhB3gpBxu7MsHQt0hZlEVuYReyHhkbtQXsHunTjk5lCXNIT6drnVJuKRdpiLMJS9awIMnanFmSR/qVrkb4oi7SFWUT3Rb7lZK8l1L2nuVt4RnhUgEfK3YK1FmKRtj3j23++/35Hx1iEIJcQZOxSy5QsMibKIu3T8qInzqGR07RnpGsPgmi5hEkb4ZZHFeb+fnqVPhdriEWI8d4RZOyS15QsMj7KIv1hDrUsk45e9vYyYhoe9azgEXJ320o91EQ7FYuMjbEIS9VeCDJ2y2tKFhkTZZF6mEX64hzaU6hb7+O8KEV4TwEWKd/yshZiEVuMU7fEJMbzIMjYtVFRjoMsko+ySH5fWeRhlEXSYRbJx3nRunQ6eunby+gIe93vufb31PNxNREWqYdYhBjvDUHGrq25dC3SHmURW5hFbF/01wq1yPhbRS48lqI9H7gwSirAd7+mCLGIfYlahBjPiCBj97Zaur7953yURfRhFinHeWGdzETaYr3FTTFE+iO89wDfvcYpxLc/d/+f4xCLjImxCEG2IsjYvdYpWWSdKIvYwiyii/OiJdKLrWLtdSK6dyl4FqkALzQhFmmbikWI8UwIMg5h7SiL+IVZpBznu9cYIh1qDfaoS31KPAI8c3gXpQAvWkN8+3MP359miVqEGG+JIOMwZoyyiC3MIro433t9Y6gXPRO2iD7c1tPP2um3NcCaKK4pFeC7X3MMsQgxnhVBxqG07ieL6KIs4htmkXKcF9ZI33vbjYOtMSK+swU3VgqwSDrCIvoQi+inYhFiPAOCjEPpmZJF0lEWaZuWb39OH+a7X1cEetET6nvvpzPaI2jjO2t4a8F98PpMgEXSEb79+fTrvaZiEWK8JoKMw5ktyrc//zDMIvU4i9gCnTNjuFuWmlviaw3jWkoBvnuNQ4hFxi5RixBjLwQZh7RWlEV8wnz364pA373WIdQxr3D3soZ3z9G99/pMgG9/Lf92niEWIcZbIcg4rDWjLOIbZhFbnKvva0C8F9qIeywt94TXGsc1lAL822vyv2YNsQhT8cwIMg5LE2SR8VFe5OJ8+2v1QD94G8dgJ9//wIiXtEZ3xuCK6KJ7//XlX89FWKQ9xCLEeAYEGYc2Msoi/mH+7TX2QBff3+B4r2VkdK3h9Pu49deUIizSF2IRYjwLgozD84iyiH1aFukP8/3X+0a6+LE2DLhHdLeKq4YmwIueEIv4TsUixHg0goyzMDrKIu1hXlgDnX4f60V7a7NG1xLcWC3AIvUIixDivSLIOBtrRFmkP8wiPnG2minmo2LbE8sRNAG+e23nNLwgxvMiyDgr2iiLjJuW716jjHPKFsGe2WyhjVnCe+/tHKbhhTXEIsR4bQQZZ8czyiJ9E/O913UEusUeoj5zaFsjW3yfigAvCPHxEGScJUuURXzCLKKP893rV4700YyIphdLfBfaCIsQ4j3SBpmvCjiUf//8d1OU//Xub9UoL18sS2Fevghrw9wSlCNHfObAxlqCG7MEWKQtwiKEeG+O+184zlZLlEXq07IlzCHr9Jx93zuK1sw8gmplDfCCEJ8XgoxDWr4geU/LIrowh+IAeAX6XG0RVKvWAC9aQyxCjPeMIOPQRk3LIvYwL0pBOddY7yGyJb0BFumLsAghPgKCjMOzRlmkLcwi9jjH9h6mI/OIbqw3wiKE+EgIMs5CS5RF9MvYC884w9eIoLbwiLAIIT4igoyz0bKvLGKblkOpABBpm1ki2sMrwAtCfFwEGWenZ1peWOO8aAnMXiJ+hHj28o7vggifB4KMs9Q6LS884qxF6OYxKrgpRPj8EGSctd4wi6wbZ9itGdEeBBgEGZD2ZexY6os/kW6zl5D2IsRYEGTgVx7TckocliMG+lzi6YUII4UgA5FRYV4Qr/NEhFFDkIGM0WHGsRFgWBFkoIIwo4TwwgtBBpQI8/kiulgDQQaM4i/OBHp+BBV7QJCBTgT6FtED+hBkwFkqTFtHmlgC8yPIwApKQeyNNbEFjoEgAxsjqABERB5t/QkAAACCDADAFAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABAgyAAATIMgAAEyAIAMAMAGCDADABB5rXnQ6ne79LwAA0NE2VBXkN2/e/Pr/fhCaDACA3Zs3b+TZs2fZX784Kcbe9+/fy9dffy1PnjyRi4sL108QAIAjO51O8ubNG/nTn/4kjx7ld4pVQQYAAGNxqAsAgAkQZAAAJkCQAQCYAEEGAGACBBkAgAkQZAAAJkCQAQCYwP8DzSjJCaZGMEQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 600x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ImageGMM as ImgGMM\n",
    "import GMM_utils as GMM\n",
    "import sliced_mw as smw \n",
    "\n",
    "idx = 20\n",
    "idx2 = 30\n",
    "gmm1, label1 = ImgGMM.get_mnist_gmm(idx, num_components=10, train=False, target_label=2, var=1., nsample=5000, type_cov=\"spherical\",\n",
    "                                   n_init=1)\n",
    "ImgGMM.plot_gmm_contours(gmm1, label=label1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9855bbb9-b2ba-418b-ae81-9e0ecfd59fb6",
   "metadata": {},
   "source": [
    "## DSMW Barycenter\n",
    "\n",
    "Next, we load images/GMMs for each class and estimate free-support barycenters based on our sliced DSMW (or SMSW in the code) metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "1a2f9786-d933-4dea-bf4d-e9e0afc90094",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perform 10 optimization schemes for class 0...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 167.49it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.14it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 173.08it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.75it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 176.50it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 172.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 181.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 0\n",
      "\n",
      "Perform 10 optimization schemes for class 1...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.63it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 180.85it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.41it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.51it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.24it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 172.62it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.48it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 180.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 183.11it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 1\n",
      "\n",
      "Perform 10 optimization schemes for class 2...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 181.60it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 181.83it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 175.50it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 173.35it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 178.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 171.38it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 175.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 171.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 2\n",
      "\n",
      "Perform 10 optimization schemes for class 3...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 183.82it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.22it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 170.80it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 164.21it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 169.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.09it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 182.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 180.24it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 181.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 3\n",
      "\n",
      "Perform 10 optimization schemes for class 4...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 170.50it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 170.74it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.68it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.84it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.01it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.28it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 172.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 4\n",
      "\n",
      "Perform 10 optimization schemes for class 5...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 173.93it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 173.71it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.80it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 165.71it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.10it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 164.15it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 164.06it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 175.19it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 172.80it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 183.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 5\n",
      "\n",
      "Perform 10 optimization schemes for class 6...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 182.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 168.55it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.89it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 178.71it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 181.23it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.69it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.19it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.27it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 183.20it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 179.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 6\n",
      "\n",
      "Perform 10 optimization schemes for class 7...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 168.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 167.27it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 171.40it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.58it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 187.12it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 181.04it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.28it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 170.94it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 173.93it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 7\n",
      "\n",
      "Perform 10 optimization schemes for class 8...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 184.01it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 185.89it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 178.88it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 174.84it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.29it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 180.23it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 168.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 184.76it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 177.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 8\n",
      "\n",
      "Perform 10 optimization schemes for class 9...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:02<00:00, 78.22it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.28it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 162.66it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 156.06it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.85it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 161.24it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 171.62it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 175.52it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 166.20it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:01<00:00, 163.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 9\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "restart_num = 10\n",
    "for tl in range( 10):\n",
    "    #gmm1 = GMM.DiagonalGMM(torch.randn(50, 2), optimize=\"diag\")\n",
    "    torch.manual_seed(tl)\n",
    "    np.random.seed(tl)\n",
    "    print(f\"Perform {restart_num} optimization schemes for class {tl}...\")  \n",
    "    target_gmms = [ImgGMM.get_mnist_gmm(idx, num_components=10, train=False, requires_grad=False, target_label=tl)[0] for idx in range(5)]\n",
    "    final_loss = 1e8\n",
    "    \n",
    "    # iterative over multiple initializations\n",
    "    for seed in range(restart_num):\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        gmm1 = GMM.RandomGaussianMixtureModel(100, 2, optimize=True, normal_means=True)\n",
    "        with torch.no_grad():\n",
    "            gmm1.means *= 5.\n",
    "            gmm1.means += 14.\n",
    "        l_ls = []\n",
    "        gmm1 = gmm1.cuda()\n",
    "        gmm2 = gmm2.cuda()\n",
    "        min_loss = 1e8\n",
    "        opt = torch.optim.Adam(gmm1.parameters(), lr=.03)\n",
    "        for i in tqdm(range(201)):\n",
    "            opt.zero_grad()\n",
    "            gmm1.set_conditions()\n",
    "            diag = torch.eye(gmm1.covariances.shape[-1]).unsqueeze(0).cuda()\n",
    "            s = gmm1.covariances.shape[0]\n",
    "            gmm1.covariances += .3 * diag.repeat(s, 1, 1)\n",
    "            multi_smsw = 0.\n",
    "            for gmm2 in target_gmms:\n",
    "                multi_smsw += smw.calc_parallel_SMSW(gmm1, gmm2, pnum=100)\n",
    "            multi_smsw.backward()\n",
    "            l_ls.append(multi_smsw.item())\n",
    "            opt.step()\n",
    "\n",
    "        # Keep the one with the lowest sliced distance\n",
    "        if multi_smsw < final_loss:\n",
    "            gmm_final = gmm1\n",
    "            final_loss = multi_smsw\n",
    "    print(f\"Done for {tl}\\n\")               \n",
    "    ImgGMM.plot_gmm_contours(gmm_final, label=tl, savename=\"Barycenter\")\n",
    "    for num, gmm2 in enumerate(target_gmms):\n",
    "        ImgGMM.plot_gmm_contours(gmm2, label=tl, savename=str(num))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f19b8d7a-141d-4195-90dd-de864dd6f18a",
   "metadata": {},
   "source": [
    "## Sliced Barycenter\n",
    "\n",
    "Alternatively, we may use particles and the sliced Wasserstein"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "6d92f295-b604-4fca-9598-a230b0a885f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perform 10 optimization schemes for class 0...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:06<00:00, 30.38it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.61it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.57it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.67it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.67it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 0\n",
      "\n",
      "Perform 10 optimization schemes for class 1...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.88it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.91it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.90it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.84it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.67it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:06<00:00, 33.32it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:06<00:00, 33.35it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.67it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 1\n",
      "\n",
      "Perform 10 optimization schemes for class 2...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.56it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.70it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.66it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.59it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.62it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.60it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 2\n",
      "\n",
      "Perform 10 optimization schemes for class 3...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.74it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.79it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.75it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.77it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.78it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 3\n",
      "\n",
      "Perform 10 optimization schemes for class 4...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.68it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.75it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.77it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.82it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.79it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.78it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.77it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.79it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 4\n",
      "\n",
      "Perform 10 optimization schemes for class 5...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.57it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.73it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.63it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.68it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 5\n",
      "\n",
      "Perform 10 optimization schemes for class 6...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.60it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.74it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.66it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.74it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.73it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.67it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.62it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 6\n",
      "\n",
      "Perform 10 optimization schemes for class 7...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.65it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.78it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.75it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.62it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.67it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 7\n",
      "\n",
      "Perform 10 optimization schemes for class 8...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.72it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.70it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.74it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.73it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.64it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 8\n",
      "\n",
      "Perform 10 optimization schemes for class 9...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.76it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:06<00:00, 33.47it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:06<00:00, 33.47it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.62it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.61it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.69it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.62it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.71it/s]\n",
      "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:05<00:00, 33.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 9\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "import ot\n",
    "import os\n",
    "\n",
    "np.random.seed(42)\n",
    "torch.manual_seed(42)\n",
    "\n",
    "\n",
    "def save_sw_bary(bary_in, label, savename=None):\n",
    "    plt.figure(figsize=(6, 6))\n",
    "    \n",
    "    plt.scatter(bary_in.detach().cpu()[:, 0], bary_in.detach().cpu()[:, 1], c=\"darkred\")\n",
    "    plt.xlim(0, 28)\n",
    "    plt.ylim(28, 0)\n",
    "    plt.xticks([])\n",
    "    plt.yticks([])\n",
    "    if savename:\n",
    "        save_dir = \"output_GMM/\" + str(label)\n",
    "        os.makedirs(save_dir, exist_ok=True)\n",
    "        plt.savefig(save_dir + \"/\" + str(savename), bbox_inches='tight', pad_inches=0)\n",
    "        plt.close()\n",
    "    else:\n",
    "        plt.show()\n",
    "        \n",
    "restart_num = 10\n",
    "for tl in range(10):\n",
    "    #gmm1 = GMM.DiagonalGMM(torch.randn(50, 2), optimize=\"diag\")\n",
    "    torch.manual_seed(tl)\n",
    "    np.random.seed(tl)\n",
    "    print(f\"Perform {restart_num} optimization schemes for class {tl}...\")  \n",
    "    target_gmms = [ImgGMM.get_mnist_gmm(idx, num_components=10, train=False, requires_grad=False, target_label=tl)[0] for idx in range(5)]\n",
    "    target_samples = [gmm.sample(10_000).cuda() for gmm in target_gmms]\n",
    "    final_loss = 1e8\n",
    "    \n",
    "    # iterative over multiple initializations\n",
    "    for seed in range(restart_num):\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        bary_sample = torch.randn(100, 2).cuda() * 5 + 14\n",
    "        bary_sample = bary_sample.requires_grad_(True)\n",
    "        min_loss = 1e8\n",
    "        opt = torch.optim.Adam([bary_sample], lr=.03)\n",
    "        for i in tqdm(range(201)):\n",
    "            opt.zero_grad()\n",
    "            gmm1.set_conditions()\n",
    "            multi_smsw = 0.\n",
    "            for tsample in target_samples:\n",
    "                multi_smsw += ot.sliced.sliced_wasserstein_distance(bary_sample, tsample, n_projections=100)**2\n",
    "            multi_smsw.backward()\n",
    "            l_ls.append(multi_smsw.item())\n",
    "            opt.step()\n",
    "\n",
    "        # Keep the one with the lowest sliced distance\n",
    "        if multi_smsw < final_loss:\n",
    "            bary_final = bary_sample\n",
    "            final_loss = multi_smsw\n",
    "\n",
    "    save_sw_bary(bary_final, label=tl, savename=\"SW_Barycenter\")\n",
    "    print(f\"Done for {tl}\\n\")               \n",
    "    #ImgGMM.plot_gmm_contours(gmm_final, label=tl, savename=\"Barycenter\")\n",
    "    #for num, gmm2 in enumerate(target_gmms):\n",
    "    #    ImgGMM.plot_gmm_contours(gmm2, label=tl, savename=str(num))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b07c4368-2f79-4f27-b4e4-36fb88245ab1",
   "metadata": {},
   "source": [
    "## Quantization\n",
    "\n",
    "We can also try to simply reduce the number of components."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 286,
   "id": "6e3d1d95-975e-4734-b8c4-8ba0ac6b9d12",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perform 20 optimization schemes for class 0...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 565.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 611.04it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 548.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 517.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 652.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 573.45it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 659.01it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 628.57it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 611.53it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 643.98it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 639.11it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 602.51it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 619.46it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 624.68it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 652.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 589.41it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 581.41it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 673.52it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 569.77it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 605.70it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 0\n",
      "\n",
      "Perform 20 optimization schemes for class 1...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 638.51it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 641.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 672.08it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 630.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 550.36it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 558.97it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 649.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 661.95it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 624.74it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 577.31it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 643.03it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 604.01it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 629.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 643.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 617.40it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 634.43it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 661.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 640.14it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 655.35it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 653.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 1\n",
      "\n",
      "Perform 20 optimization schemes for class 2...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 518.41it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 615.49it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 639.28it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 636.10it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 625.04it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 601.75it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 638.63it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 626.96it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 625.17it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 630.27it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 622.16it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 684.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 658.14it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 625.67it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 612.91it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 603.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 649.69it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 669.23it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 647.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 624.99it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 2\n",
      "\n",
      "Perform 20 optimization schemes for class 3...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 600.93it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 612.76it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 608.08it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 625.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 675.45it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 659.04it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 626.13it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 644.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 644.10it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 634.89it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 653.88it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 631.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 659.07it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 566.25it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 578.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 667.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 634.15it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 642.67it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 626.83it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 645.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 3\n",
      "\n",
      "Perform 20 optimization schemes for class 4...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 629.09it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 604.84it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 595.21it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 614.93it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 597.57it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 611.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 641.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 648.38it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 639.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 583.68it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 627.40it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 654.37it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 669.47it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 674.46it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 685.66it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 661.56it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 661.27it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 650.24it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 659.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 664.84it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 4\n",
      "\n",
      "Perform 20 optimization schemes for class 5...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 623.11it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 599.41it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 674.03it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 579.26it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 623.23it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 658.99it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 585.70it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 659.30it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 602.23it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 644.81it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 677.45it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 646.93it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 621.19it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 618.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 655.68it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 580.44it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 610.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 619.07it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 596.61it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 650.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 5\n",
      "\n",
      "Perform 20 optimization schemes for class 6...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 618.42it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 632.22it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 652.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 658.50it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 682.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 677.20it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 670.87it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 616.03it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 654.96it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 648.22it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 609.88it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 632.11it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 638.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 630.29it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 601.30it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 602.83it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 635.08it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 606.49it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 617.00it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 629.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 6\n",
      "\n",
      "Perform 20 optimization schemes for class 7...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 549.44it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 654.39it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 681.47it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 624.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 636.93it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 668.53it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 648.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 664.55it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 655.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 610.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 576.74it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 668.70it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 643.21it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 647.86it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 613.63it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 541.95it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 634.13it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 624.08it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 644.84it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 629.55it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 7\n",
      "\n",
      "Perform 20 optimization schemes for class 8...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 571.28it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 621.01it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 633.59it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 598.70it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 618.07it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 638.40it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 607.56it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 603.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 631.75it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 624.54it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 671.64it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 701.34it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 679.73it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 550.81it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 611.47it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 658.84it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 665.09it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 591.68it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 611.69it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 640.86it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 8\n",
      "\n",
      "Perform 20 optimization schemes for class 9...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 595.74it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 683.22it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 600.86it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 632.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 660.35it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 639.74it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 666.02it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 621.72it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 619.47it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 651.42it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 556.90it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 612.80it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 632.92it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 643.40it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 647.32it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 590.38it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 592.95it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 656.04it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 618.10it/s]\n",
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [00:00<00:00, 642.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done for 9\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import copy\n",
    "\n",
    "\n",
    "\n",
    "torch.manual_seed(42)\n",
    "np.random.seed(42)\n",
    "restart_num = 20\n",
    "c = 50 # components\n",
    "\n",
    "for tl in range(10):\n",
    "    smsw_final = 1e8\n",
    "    n = 10_000\n",
    "    gmm2, _ =  ImgGMM.get_mnist_gmm(0, num_components=100, train=False, \n",
    "                                 target_label=tl, var=1., nsample=n, \n",
    "                                 type_cov=\"full\", n_init=1)\n",
    "    #gmm2.means = gmm2.means.round()\n",
    "    #gmm2.covariances += 1. * diag.repeat(gmm2.covariances.shape[0], 1, 1)\n",
    "    print(f\"Perform {restart_num} optimization schemes for class {tl}...\")  \n",
    "    # iterative over multiple initializations\n",
    "    for seed in range(restart_num):\n",
    "        torch.manual_seed(seed)\n",
    "        np.random.seed(seed)\n",
    "\n",
    "        gmm1 = GMM.RandomGaussianMixtureModel(c, 2, optimize=True, normal_means=True)\n",
    "        with torch.no_grad():\n",
    "            gmm1.means.data = gmm2.sample(c)\n",
    "        l_ls = []\n",
    "        gmm1 = gmm1.cuda()\n",
    "        gmm2 = gmm2.cuda()\n",
    "        min_loss = 1e8\n",
    "        opt = torch.optim.Adam(gmm1.parameters(), lr=.03)\n",
    "        for i in tqdm(range(201)):\n",
    "            opt.zero_grad()\n",
    "            gmm1.set_conditions()\n",
    "            diag = torch.eye(gmm1.covariances.shape[-1]).unsqueeze(0).cuda()\n",
    "            s = gmm1.covariances.shape[0]\n",
    "            gmm1.covariances += 1. * diag.repeat(s, 1, 1)\n",
    "            smsw = smw.calc_parallel_SMSW(gmm1, gmm2, pnum=100)\n",
    "            smsw.backward()\n",
    "            l_ls.append(smsw.item())\n",
    "            opt.step()\n",
    "            #if i%100 == 0:\n",
    "            #    ImgGMM.plot_gmm_contours(gmm1, label=str(tl) + f\"_quant_{i}\")\n",
    "            #    plt.show()\n",
    "        # Keep the one with the lowest sliced distance\n",
    "        smsw_test = smw.calc_parallel_SMSW(gmm1, gmm2, pnum=2000)\n",
    "        if smsw_test < smsw_final:\n",
    "            smsw_final = smsw_test\n",
    "            gmm_final = gmm1\n",
    " \n",
    "    print(f\"Done for {tl}\\n\")   \n",
    "    ImgGMM.plot_gmm_contours(gmm_final, label=str(tl) + \"_quant\", savename=\"Quantization\")\n",
    "    ImgGMM.plot_gmm_contours(gmm2, label=str(tl) + \"_quant\", savename=\"0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d8c0b5a-e6d9-4725-bb6c-127dfe4f2791",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
