{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "from wdl.bregman import barycenter\n",
    "import matplotlib.pyplot as plt\n",
    "from utilities.simpleDistributions import simplexSample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Build one noisy sample: combine the first n_atoms columns of the stored\n",
    "# dictionary with random simplex weights via the barycenter solver, add\n",
    "# Gaussian noise, then floor and renormalize the result.\n",
    "\n",
    "# --- configuration ---\n",
    "SEED = 0  # fixed so the noise drawn below is repeatable across runs\n",
    "height, width = 28, 28  # grid size handed to the solver\n",
    "maxsinkiters = 50  # forwarded as maxsinkiter (presumably the Sinkhorn iteration cap - TODO confirm)\n",
    "reg = 0.003  # forwarded as reg (presumably the entropic regularization - TODO confirm)\n",
    "n_atoms = 3  # number of columns of D, and of weights, that are combined\n",
    "noise = 5e-4  # standard deviation of the additive Gaussian noise\n",
    "\n",
    "# Seeds torch's global RNG, which drives torch.randn_like below.\n",
    "# NOTE(review): simplexSample may draw from a different RNG (e.g. NumPy);\n",
    "# confirm and seed that one as well if so.\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "barySolver = barycenter(None, \"conv\", reg=reg, height=height, width=width, maxsinkiter=maxsinkiters)\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.\n",
    "with open(\"noise0.0005.pkl\", \"rb\") as f:\n",
    "    data = pickle.load(f)\n",
    "\n",
    "D = data[\"True D\"][:, :n_atoms]\n",
    "w = simplexSample(n_atoms)\n",
    "\n",
    "X = barySolver(D, w)\n",
    "\n",
    "# Out-of-place update so the tensor returned by the solver is not mutated.\n",
    "X = X + noise * torch.randn_like(X)\n",
    "# Floor every entry at 1e-4. This removes the negative values the noise can\n",
    "# introduce (and also lifts any entry that is below the floor). Without it,\n",
    "# numerical errors were observed during the backward pass; the exact cause\n",
    "# is not known.\n",
    "X = torch.clip(X, min=1e-4)\n",
    "# Renormalize so the entries sum to 1 along dim 0.\n",
    "X = X / X.sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAP5klEQVR4nO3dS1DX9ffH8TehIKAiqHlBvGAaCYgJ2VUtUQmdJmuV0+TkNDZdp2nRok3TskVTm1ZtmnFaaJktKLEpsQyFlIvmJUVTEBC5o8hFRP0t/vP7r3y/zowufseZ52PZaw6KfM98Zjid80m4fft2AODPA//rvwCAO6M5AadoTsApmhNwiuYEnBqnwqNHj8pf5Q4NDckvnpycHM2WL18uaxMSEmReV1cn86ysrGjW2dkpa9PS0mS+cOFCmR87dkzmmZmZ0ay7u1vWjo2NyTw7O1vmo6OjMp87d2406+3tlbWJiYky37dvn8xffvnlaHb27FlZe+rUKZm/8MILMu/r65N5SkpKNEtNTZW1g4ODMk9LS7vjh50nJ+AUzQk4RXMCTtGcgFM0J+AUzQk4RXMCTiUYWykybG9vl19czeRqa2tl7fr162Wu5k4h6Bns+PHj77o2hBAyMjJkfuDAAZlfvXo1mj3//POy9uDBgzJfvXq1zH///XeZz5w5M5q1trbedW0IIfT09MhczVirqqpkrZqRhhBCW1ubzNPT02Xe1dUVzazPy7hx8n8nCMuXL2fOCdxPaE7AKZoTcIrmBJyiOQGnaE7AKZoTcErOOY8cOSLnnAUFBfKLl5eXR7PS0lJZa80Kp0yZIvOJEydGM2tX9Pr16zK3ZmInT56UeX5+fjSzriHm5OTIvKKiQuYlJSUyV/NAa173+eefy/yjjz6SeVNTUzSzZqjq5x2CPYtUf3YIeu5eXV0ta5OSkmReXFzMnBO4n9CcgFM0J+AUzQk4RXMCTtGcgFP3tDJm/dq+rKzsLv5K/6ehoUHm1q/Wp0+fHs2sMc2NGzfuKbd+Lb9q1apoZq1VWetq1sjg+PHjMs/Ly4tm586dk7W5ubkyt85yLliwIJpZpzGtc6UtLS0ynz9/vszV58laZ1OnUEMI4dFHH2WUAtxPaE7AKZoTcIrmBJyiOQGnaE7AKZoTcEruADU2NsriqVOnynz37t3RrLCwUNZa8z5rLUudSrRWxh5//HGZnzhxQubvvfeezM+cORPNBgYGZO2kSZNkXl9fL/PTp0/LXL3G715PQE6ePFnm+/fvj2bWyc+LFy/KfOXKlTK3Tmf+8ccf0cw64/rrr7/KPIYnJ+AUzQk4RXMCTtGcgFM0J+AUzQk4RXMCTsnB1JUrV2TxY489JvPu7u5oZs3MrH1Na96n9veefvppWWu92vDSpUsyP3XqlMzPnz8fzayznNu3b5e5tZd45MgRmc+ZMyeaWf8uNTU1Mt+2bZvM1asRrfOT1gzVmv9a+6Dq89bR0SFrFy1aJPMYnpyAUzQn4BTNCThFcwJO0ZyAUzQn4BTNCTgl55zqjmgI9l6kmlWOHz9e1lqv+FOv0QtB3xJ9+OGHZe2FCxdkPmvWLJnfvHlT5mqWOXfuXFn7888/y3zevHkyt/ZF1c80LS1N1j711FMyV3usIeifmfV9bdy4UeZjY2My7+vru+t66yaudc83hicn4BTNCThFcwJO0ZyAUzQn4BTNCTglRynWuEOdMgxB/wpZrZOFYI8zbt26JfPe3t5oVl5eLms3bNgg866uLpn/9ddfMp8xY0Y0+/LLL2VtamqqzK2TosuWLZP5s88+G80++eQTWWutyll/N7W2Za10Wf8u1unMCRMmyDwlJSWaFRQUyNq7xZMTcIrmBJyiOQGnaE7AKZoTcIrmBJyiOQGn5JzTWoWx1njUycDFixfLWmvGap3WfO6556KZdVbz+PHjMrfOV3Z2dsrceg2fYq0+3b59W+ZZWVkyP3z4cDSzXqO3c+dOmWdn
Z8v82LFj0Uyd7AxBnxsNIYSPP/5Y5tYMNicnJ5pZP5MbN27IPIYnJ+AUzQk4RXMCTtGcgFM0J+AUzQk4RXMCTsk5p/Wqu6KiIpmrU4gTJ06UtYODgzK3znKqE5Nnz56VtdYMdffu3TK3dk3b2tqi2d69e2XtI488InNrlmjNaNXP1Pq+rHOm1k6l+rw0NjbKWmtu3tzcLPOlS5fKXM1BrVcAJiYmyjx2qpUnJ+AUzQk4RXMCTtGcgFM0J+AUzQk4RXMCTt3TKwAPHTok8yVLlkSzadOmyVpr57G4uFjm48bFv7Xh4WFZ29DQIPM9e/bI3NpFzcjIiGbWa/SsXVTrdXPXrl2TeXt7ezRbsWLFXdeGYN+1Va9WTEpKkrXW3Nyasf72228yX7t2bTSz5r93e9eWJyfgFM0JOEVzAk7RnIBTNCfgFM0JOEVzAk7JOeeiRYtkcWtrq8zPnTsXzaz3Lar3RIYQQlNTk8zV7dnq6mpZOzIyInNrbvXggw/KXM0L1XtFQ7DnoGq+G0II/f39MlfvobRuw1qzRPVe0hD0nqu1h2rdjrV2KletWiXz0dHRaGbNObdv3y7zrVu33vG/8+QEnKI5AadoTsApmhNwiuYEnKI5Aaf0790NmZmZMu/q6opm1trWvn37ZG69Eu6bb76JZj/++KOsfeKJJ2RujXHy8vJkrv7drFW49PR0mVdWVso8dobxv2pqaqLZli1bZO2VK1dkbp39VK8vPHr0qKy1Vsas7/vIkSMyX7NmTTSzzrhu3rxZ5jE8OQGnaE7AKZoTcIrmBJyiOQGnaE7AKZoTcErOOdXcKQT7dXQtLS3R7OTJk7J26tSpMk9NTZX5k08+Gc2s85DWmcR3331X5mpVLgR9cjQtLU3W3rx5U+bbtm2T+a5du2S+cePGaKZmoCHY62ozZ86UeWFhYTSzVsas9UbrnKk6fRmCPvNqzUgnTJgg8xienIBTNCfgFM0JOEVzAk7RnIBTNCfgFM0JOCUHU//++68stuaFOTk50UzNjUIIoaqqSuYdHR0yr62tjWbWLuimTZtkbu0tfvbZZzJvbGyMZtau6NDQkMytfxfrbKeas1qvNrT2NS3Hjh2LZtacMzk5WealpaUyt2b2X331VTR78803ZW1zc7PM582bd8f/zpMTcIrmBJyiOQGnaE7AKZoTcIrmBJyiOQGnEtTOZmVlpVzoVLc8Qwihvb09mlm7orNnz5b5mTNnZP7TTz9FM2vf0rqRmp+fL/OEhASZv//++9HMuktrzZazsrJkXl9fL3P16kRrL3H//v0yv3TpksyLioqi2eXLl2Wteq1iCPrVhiHYnzd1i1jdZw7B3nMtKiq64weGJyfgFM0JOEVzAk7RnIBTNCfgFM0JOEVzAk7JAYy6rxpCCIcOHZL5lClTollSUpKsraiokHlZWZnMDx8+HM2+++47WWuxbsu+8cYbMu/u7o5m1s6kNadMTEyUuTUPVHNOa9fU2iXNyMiQuXpv6rJly+7pz163bp3MYzuV/6V2TefPny9rrf3hGJ6cgFM0J+AUzQk4RXMCTtGcgFM0J+CUHKVYZxSrq6tlvn79+mg2bdo0WWuNWn744QeZFxQURDPrDGJvb6/MJ0+eLPOzZ8/KXL0K76GHHpK1FmsUY42/1GqWtVZljRTUudIQ9PnLsbExWWudvrRGJdYKohqH9PT0yFrrdZYxPDkBp2hOwCmaE3CK5gScojkBp2hOwCmaE3BKnsY8cOCAvF85adIk+cXVzM06F9jS0iJz68yimsnt2LFD1qqTniGEUFxcLPPp06fLPDs7O5pZM9KSkhKZ9/X1yfzPP/+UuVqNsmaNmZmZMn/gAf0sUGc/X3nlFVlrrRDW1NTI3PqZqRmsmluHYK+zcRoTuM/QnIBTNCfgFM0JOEVzAk7RnIBTNCfglJxzXrlyRc451ewnBL0XOTAwIGutV91Zu4X9/f3R7K233pK1ycnJMl++fLnMc3NzZa7mfeqcaAgh
XLhwQeaVlZUyHxoakrma91l7rNYO7jvvvCPzgwcPRrMXX3xR1lo7uBbrtY9qLq/m1iHYs+kQAnNO4H5CcwJO0ZyAUzQn4BTNCThFcwJO0ZyAU3LOWVdXJ+ecw8PD8ounpKREM2t/Tu0VhmDP3NQuaWNjo6z94osvZJ6TkyPz1NRUmRcWFkazPXv2yFprb/H777+XuTXnVDubH374oaydNWuWzK27tmq2vWTJEllr7QdbuTV3//bbb6PZli1bZG1nZ6fM8/PzmXMC9xOaE3CK5gScojkBp2hOwCmaE3BKjlJCCPc0SlEji7a2Nlk7MjIi8wULFshcndZsaGiQtRMnTpR5c3OzzHft2iVz9WpF61f61rjCOhmal5cnc7X+ZI2QrFfdbd26VebqbKf1edmwYYPMW1tbZW593tQoZvHixbLWeu3imjVrGKUA9xOaE3CK5gScojkBp2hOwCmaE3CK5gSckns01sxMrWWFoF8ReOvWLVm7bt06mVuzIzUPtF5daK2r/f333zJPS0u766+/ceNGWXvixAmZW2c7rbWtTz/9NJpVVFTIWrUiGEIITU1NMl+5cmU0s+ac1mnM7u5umVtrfmqGa/3Zc+bMkXkMT07AKZoTcIrmBJyiOQGnaE7AKZoTcIrmBJySc87Tp0/L4oSEO66h/b/Vq1dHM2sWaJ0TLC0tlfnVq1ej2S+//CJr16xZI3Nrn1PN60IIISMjI5pZ8zjrNXsffPCBzEdHR2Xe1dUVzaxdUmtX1JqLq51La1ZozX+tU6zt7e0ynzt3bjTr6emRtZcvX5Z5bB+UJyfgFM0JOEVzAk7RnIBTNCfgFM0JOEVzAk7JOadx0zZ0dHTIXM0a9+7dK2s3bdokc+s2rHplXH5+vqy1XpOndh5DCOHrr7+WubqxmpiYKGsHBwdlbu0lWjdW1Q7v0qVLZa01K7RuDc+YMSOaXb9+XdZevHhR5tb+sPWKwfPnz0cz6w5yQUGBzGN4cgJO0ZyAUzQn4BTNCThFcwJO0ZyAUzQn4JScc6anp8viefPmyVzNppKTk2WtdVu2v79f5mpeqPYpQ7DnnNeuXZP5q6++KvPZs2dHs/Lycllrfd+vvfaazOvq6mReWFgYzaz9XmsH19p7fOaZZ6KZNXO3dnCrqqpkbs1R1Wfi7bfflrXW/m8MT07AKZoTcIrmBJyiOQGnaE7AKZoTcCpB/Yr69OnT8vfXubm58ourkcO4cXKKY/7qfMeOHTLfvHlzNKutrZW16jxkCPaZRWukkJOTE81u3Lgha631owsXLsjcWilT609lZWWy9p9//pG59QrAmTNnRrOFCxfKWuu1jGpEFIK9cqY+ryMjI7J2xYoVMg8h3PHGLE9OwCmaE3CK5gScojkBp2hOwCmaE3CK5gScksNG60zj8PCwzNXr5qyZl3Vm8fXXX5d5fX19NFu2bJms3blzp8xfeuklmY+Njclc/d0yMzNlbUVFhcyteeDAwIDMp06dGs2stSprdr127VqZq1flWatyJSUlMrfW1VJSUmSuZrDWZ/Vu8eQEnKI5AadoTsApmhNwiuYEnKI5AadoTsApuc8J4H+HJyfgFM0JOEVzAk7RnIBTNCfgFM0JOPUfPM7RL8rR32QAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Render X as a height x width image with the reversed grayscale colormap\n",
    "# and save the figure to noisymnist.pdf.\n",
    "fig, ax = plt.subplots()\n",
    "# Reuse the grid size from the configuration cell instead of repeating 28, 28.\n",
    "ax.imshow(X.view(height, width), cmap=\"gray_r\")\n",
    "ax.axis(\"off\")  # hide the axis lines, ticks and labels\n",
    "fig.savefig(\"noisymnist.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
