{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 21247,
     "status": "ok",
     "timestamp": 1612291061721,
     "user": {
      "displayName": "YUBO WANG",
      "photoUrl": "",
      "userId": "04037309235043616004"
     },
     "user_tz": 300
    },
    "id": "wOkWBE5om6yD",
    "outputId": "b913ceb7-1a39-4d63-f01c-9dd779af9b3e"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from autograd.numpy import log, sqrt, sin, cos, exp, pi, prod\n",
    "from autograd.numpy.random import normal, uniform\n",
    "import os\n",
    "from scipy import stats\n",
    "from google.colab import drive\n",
    "drive.mount('/content/drive')\n",
    "\n",
    "path = \"/content/drive/My Drive\"\n",
    "\n",
    "os.chdir(path)\n",
    "os.listdir(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qYqi0H4-nFtr"
   },
   "outputs": [],
   "source": [
    "# import data\n",
    "import pickle\n",
    "PATH = 'anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/'\n",
    "\n",
    "f = open(PATH + \"result/popCSGLD_samples.txt\", \"rb\")\n",
    "pop_csgld_epoch, pop_csgld_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"SGLD_result/pop_SGLD_samples.txt\", \"rb\")\n",
    "pop_sgld_epoch, pop_sgld_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"reSGLD_result/pop_reSGLD_samples_var30.txt\", \"rb\")\n",
    "resgld_epoch, resgld_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"cyclicalSGLD_result/cyclicalSGLD_samples.txt\", \"rb\")\n",
    "cyclicalsgld_epoch, cyclicalsgld_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"result/CSGLD/CSGLD_samples.txt\", \"rb\")\n",
    "csgld_epoch, csgld_x = pickle.load(f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s1jzs0-39mIV"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "PATH = 'anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/'\n",
    "f = open(PATH + \"SPOS_result/population_SPOS_samples.txt\", \"rb\")\n",
    "pspos_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"SVGD_result/population_SVGD_samples.txt\", \"rb\")\n",
    "psvgd_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"SPOS_result/single_SPOS_samples.txt\", \"rb\")\n",
    "spos_x = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + \"SVGD_result/single_SVGD_samples.txt\", \"rb\")\n",
    "svgd_x = pickle.load(f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z2SDPFYLpddP"
   },
   "outputs": [],
   "source": [
    "def mixture(x):\n",
    "    energy = ((x[0]**2 + x[1]**2)/10 - (cos(2.0*pi*x[0]) + cos(2.0*pi*x[1]))) / 0.5 # 2\n",
    "    regularizer = ((x[0]**2 + x[1]**2) > 20) * ((x[0]**2 + x[1]**2) - 20)\n",
    "    return energy + regularizer\n",
    "\n",
    "def mixture_expand(x, y): return mixture([x, y])\n",
    "def function_plot(x, y): return np.exp(-mixture([x, y]))\n",
    "\n",
    "lower, upper = -2.5, 2.5\n",
    "axis_x = np.linspace(lower, upper, 500)\n",
    "axis_y = np.linspace(lower, upper, 500)\n",
    "axis_X, axis_Y = np.meshgrid(axis_x, axis_y)\n",
    "\n",
    "energy_grid = mixture_expand(axis_X, axis_Y)\n",
    "prob_grid = function_plot(axis_X, axis_Y)\n",
    "pmax = np.max(prob_grid)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tysQFGHgRgph"
   },
   "source": [
    "generate samples from the true density"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mARPrHRZ448K"
   },
   "outputs": [],
   "source": [
    "# # generate samples from the true density\n",
    "# def rejection_sampler(p,xbounds,pmax):\n",
    "#     while True:\n",
    "#         x = [np.random.rand(1)*(xbounds[1]-xbounds[0])+xbounds[0] for _ in range(2)]\n",
    "#         y = np.random.rand(1)*pmax\n",
    "#         if y<=p(x[0], x[1]):\n",
    "#             return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "70GvMaaz6KOj"
   },
   "outputs": [],
   "source": [
    "# true_samples = []\n",
    "# for iter in range(50000):\n",
    "#     sample = rejection_sampler(function_plot,[lower, upper],pmax)\n",
    "#     if sample:\n",
    "#         true_samples.append(sample)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sE7kGufH6zi4"
   },
   "outputs": [],
   "source": [
    "# true_samples = np.array(true_samples)\n",
    "# true_samples = true_samples.reshape((-1,2))\n",
    "# import pickle\n",
    "# f = open(PATH + \"simulation_figures/true_samples.txt\", 'wb')\n",
    "# pickle.dump(true_samples, f)\n",
    "# f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "S1x1ojj8ohSB"
   },
   "outputs": [],
   "source": [
    "f = open(PATH + \"simulation_figures/true_samples.txt\", 'rb')\n",
    "true_samples = pickle.load(f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 307
    },
    "executionInfo": {
     "elapsed": 60893,
     "status": "ok",
     "timestamp": 1612150917246,
     "user": {
      "displayName": "YUBO WANG",
      "photoUrl": "",
      "userId": "04037309235043616004"
     },
     "user_tz": 300
    },
    "id": "2bOydkbl69Fk",
    "outputId": "139d03fd-697d-48cb-fa4c-54e8ffb907c3"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.6/dist-packages/seaborn/_decorators.py:43: FutureWarning: Pass the following variable as a keyword arg: y. From version 0.12, the only valid positional argument will be `data`, and passing other arguments without an explicit keyword will result in an error or misinterpretation.\n",
      "  FutureWarning\n",
      "/usr/local/lib/python3.6/dist-packages/seaborn/distributions.py:1657: FutureWarning: The `bw` parameter is deprecated in favor of `bw_method` and `bw_adjust`. Using 0.15 for `bw_method`, but please see the docs for the new parameters and update your code.\n",
      "  warnings.warn(msg, FutureWarning)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAADKCAYAAAA/3nWKAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO1d0ettR3X+Jjfem/zihUIS2pLkNkKlIRSpEEXxzTY0SlNREPRBKBV8UVAQisU/oeBThRJQfBFFUKliSxpBCAUruREpiUmKCMaIYNQHL96YS5LpQ35zWWedtWbWmln77Dm/uz/4wb3n7D3znbXX2mtm9p71pZwzNmzYMI6b1iawYcNZwRZMGzYEYQumDRuCsAXThg1B2IJpw4YgbMG0YUMQbl6j0zvuuCPfe++9a3R9ZvGa8ITjpnR4HjcCnnzyyV/nnO/kn68STPfeey8uX768RtehuPLya+LnFy8cJuFr/WtYm9eh+l8aKaWfSZ+vEkzHCqvz0uOiHcgbQNK5Szi1hdeSdpkBWzAZEOHAwLgD1Xi8dO3Vvc9uPX+u2k6EQ/faZsnA1vpaur+jCSbrRTtkJpAcGKg7cQ8/jYPWv/Y959XLp8apxkuyyxJBVePGv4vsd/pg8t75lr7rthyYHyM5MGDnZ+FwtcLphPT/0rVXQwJqKbssmblb50T4y7TBNDK0oudH3XklZ9GcmDsw0Oc8LaetBZF0zMn5cyIfjyNH2uVQmdJz/khQTRlMLcO0hi+8rdE776gDlzY8zmPlcPXlNp+TC+eun6fxadlpxsCucZM4FkT7S8F0wdQzP6gNH0qbEXdejxNbHFjjZuEg9f/7P7yy8//bbrl559iTC+eun1+c2eLInkCKsMvI9ZL41b4b9ReKqdYntYtmGY/T461tt47RnNiSDehx1Nkkfldefm3nT21TCKTf/+GV638c/Dt6XmmL8+EcrIEUbZee61XaifCXHkyXmSgix+PA+LCKO4vkwAU0K5SMUMtQGgoPLZAKfvPSNfH82289f/3Y2265+Tof3oe2UKLx0ThxXhTFJuWcll0814tz4xwpTs7bfr83O02TmVoT26vXXq2Oy/n3IxmKtwvI2aAG+r10J7agdtcs7f/mpWtqIPHveYby8qHosQv/vpeH5RrW/EX6LiJDTZ2ZCjzG7skAlhWqAu4skiPzbCDx82Qn7rg0kAqev3J177xLF092eN5+6/m9DOXhM2KXYpNyLLcL7aNn/uRdDCnH8SxF4c1Ow5kppXRPSul7KaUfp5SeTil9crTNlmHKuNs6TudozQl439yJgXpGoJ9L8xULNMe1BlL5XPtuBD124Z97sqT1evXAMpqxImKY9wqAT+ec7wfwDgAfTynd72nAaiAtePhnSwz3KGrDKs8xlotXW70D9ECSjtGGexY+Fq5RdtHgmcdx9N54PRgOppzzL3POPzz99xUAzwC4a7RdL0Yz1CEQMVdZAj18lnZMjpFrtbOKuSDv0AWIlNK9AN4K4AcR7c3gdBHGby1WHBP4NRn5bT0Z0otDBn1YMKWU3gjg6wA+lXP+nfD9x1JKl1NKl1988cWobjdsmAYhwZRSegNeD6Qv55y/IR2Tc34k5/xAzvmBO+/c26QoorbScijwZzIbdqGtys2CQ16/iNW8BOALAJ7JOX9unFIfvEa7eOGmg21QKw5XOM5wk6Dw8CnHRjgpt0uBtERvvVbSbzm5cG7nbylEeNO7AHwEwLtTSj86/XtvQLt70AzBP6cGbT076Qko+sxk5Jgat/Jdy9HpsyTvMRYntnxXEGUXDdq1sj6vkxB5Y4tYzfvvnHPKOb8l5/xXp3//4WmDG4kah/9YfpfpudP03uXokKbmFPS7qGFQ+Z2lPdpHLaDod+Wc6KHZ0nah18ty7XoDZCQogSN5A+Lk/Dnzyp4nKxVcvHDTztLrrWR7AHD6tvXpqtBtt9x8fRXKc5cd
GeJpv//2W89ff27TylAa1xE+1C6WvsS2mF16HXrvmhl8pvXbj/bdvFp2Ato//OT8OZNjeA0kzREsd9XbbrlZnRMANqfZswHLToB/aOWZp7T4cE4RduGQrpd1uFd84oSNdCRfGc1KwOSZSbrbeM7lqAVSKzsB+xkK0PcQ0XOu/7vj7lt48GygZciSqaQg0xZCRviU9pawi+d6cX4cNd8ZWfCgmCqYJANpxtGgOYbFOFpAaY4D1LOU5DAWbtrTfimggF3HrQUR5WQd4tTePIi0S09m8AaUdJzWbg+mCiZAD6gCb0Wg0uYouOMA+tN1z+qiNozRsqQ0X7FO5qVAavFp3eAkuwD7tpGGdJ4bTO04LaAKPGUORjBdMAG6gQC/IbyB1HJkAKLzSBgZl1sDCrBvFeecehdoJD4Aq/FQsU3LLqPXjMNr915MGUxA20DWNiL61uZull2/tA0vP0tAAfYHqK1A8s5RShs8qIB920TZxMPPe/4opg0mYMxAo8aRHBmAe0EkelzeyggSLBmyZ04pcdL6lM7p4dDDz3JeBKYOJsA2EZaOXwqeBZGIQLIsyiy5ymnhQ9tt2Sb65qK1s4a/TPOcqQXLkCg6kGrPM1qv3kQ6jXROzyQ6KhvUztF+e8tm0Wj5wxL+Mn1motDuOmspKkhZ6pAOE5Ell8KhJv1rts1xNJmpQErfh94xWyA5cuTGtg1xOISPHFVmsqgbHEJNwaI+Ea2E0cPDwimSD+1LwiFswtvV/r9Ef0cRTJ67ylIXpsDqwFrB/h70Fly0FFuM4qNxkr5fOqg91WBvGEmZ3tQcYahZ1B4sxTk1aAICQJwKhsSphRFlEC836zln+jlTxBi39+JEq2D0OrCHR+vVJlqwv7QVqYIhcdLQqwxSQ4SkzGhATRlMvcOICGU6zYGtzltgVXuQuNUc1ysnQ4+x1Pb28uGcarykwI7O3DWeFEuoGE4XTCPDiNExeSuQPAXqrTIuUr8c1uL9NfCSyJF8OCcNWmD3zi97b7r0mMj521TBNDJPAerzgtK+VwWjx4HL91JQeeuMazwsChgFpcY4RTQfzkmDVR0kYthJeVJEKTtyTBNMo/MUelxrTA7UXztpObClxC914NFC+S0etRLJly6e7GwYjCjcXwtsyoujZRMOzaG9w06OJRZmgEmCyRpIUWp9Up+87xEHpuh1YIlHD4dyTAmqHj41TSbtBsO5SUFdUONAb37eWuM1f4lamKGYIpgoLGJaGjzzlFrfNYexBtHzV652O3BtKCUF0hPPX1F5vO3Sxb3jJT60b21j3cgNRvqOBlXLJtabH+VXA/WVcv5o2a/Vg0kzUoRaH4X3gaUlkDQn7nHgvf4Fx+VOWwsizpFzonxKf3zI0+ID1AOJ913AbzSl3Qi9KK+KIb0WtexoyU6rBxPF6OR2xzmc8pfSHU4LpJYTS07EM1Tpy3I3pL+9FkjffuzZnf8//OB9O5ykgAJkW1n4WLMk/YxyKOXJahKhGlojGMsiUeTNF1j5RVfrgzZpgkv/tGN7JVykYVWBJRvwY7k+UoubNrzThlHffuzZvUCSPufcf/PSNZPoWMRwUzrm+StXxcWK0l+vflXNXygk8ewR5ZVp3hq3LgF7VOlq/bQ+o6jNBYrDag4tteGRYZGOpQ5Z61M6hgd4CzUJGW8g1Y712KR2vVr+UlN8rEnbWG78USoYX0wp/Sql9FREe70YUaUD5CEeBXcCLRvUzmn1W8sCrb56wX+rJn2qQRtyWm4ygJyxLagtULV8oXbz7c1OUZnpSwAeCmpLRG+gSOk7Ys9RzUFq3/U6jmfRwctJg+ZUrWvhucm0MiS/VtastAZCginn/DiA30a0FQWrYWsXZzTTHQN6g/ssIPr6HmzOtCkHbjjrOFgw9SgHjsBa5bS2/DmiJXQsuBF+46EwzWpeC70XvVeBogX6DMfzXe/vsIiajaJW1pmi5zdwm/AHuRr4tYosDBN9I5nqoe0oasaxPIykVVKpykTB
2y5d3FkAePjB+6oPSss5an9C7W+pljjVYeIcrKgF+J5ChWArTYtJ4lTri4PfJDzidS1lEOucKEr8LWpp/CsAvg/gL1JKL6SUPupto9xxuB4S1yKSAoZ/XjOOVtOthlpWePjB+3b+LG14Ll7tWKvT0uNKcFszXU09sTdb8htMZIawaFdpfjQqvBa1mvfhnPOf5pzfkHO+O+f8Bct51rdxuUMVY0hG6ZFP4aB3R96+dXhCjy1OR9uqcdP4lnYoh1YQW4ecXj69dqHHaTeYljPXRNcknlZ/qeHo3s0rqMmmWF50BWzyKRSSFlMZNpQLUJ6JFIdovehK0aMny3WY6LClZ2glZSXJiS18vHbhNqE3GI+SYEFNO8viK0DdX3rmZlMFU0uHCLA5Y8Sig+Y4wL7zaOAZSVPuo/xqv19yXsv8ScoGNSeWtmDU5ieXLp7sPHz12oVy6HHmloph8/zGCOZo9jMtpUNUzi9tWtCSvCyTWu48EqyBJK1WSbIxWlBTx+WBVcsGBT18eAagaO36LZAyde/1otzKb7I8gB7VreJIOeeuE0fwwAMP5MuXL+98FrFtvcBimENsWy/o0ZLlv390x68U3CN8gMPZhWaG0S3rErzKjimlJ3POD/DPV89MGpbSIQJ2DSQFlSVDWTCyGLInG2PkQLepU3gzZI3PknZpOXOPiqGGaAXDaTITEFMttKBHPiWy1BfFyHzAumFSygraylUEH40T58X7LujR16WI9BWp71r/R5GZLFKPLUTqIkkZCoB5gmtVFa8VfaQcKPiqVS0reDNkbVhlmdP2KNB7s0KEr0j9WvuXMFUwAcuoZ48od3Nnps5TwItzUPToyC6htt7rxLVhMOXD+2jyGQgkelyvimHNZ85MRVdgXfVsy5h8T21CcaJIQeZaRgD0lc7a+3YeJ/ao0NdwSKX1nlW5TW290sYSfXvV1nvmbxJqQW3JChHZQOPD25cKPUrnRyDCV0o7I5g2mGaDJHk5WmetBUtQe5UngH4ntohVS/1JiLrBzITjZn9AaOPvq9de3fk7JHr76131WlrKci051ShMm5lm0meSoDkyL7cL9OsQeR9k83nTzpyqUROv11YejaaWTXoRFYSjGk1TBlPkHcproKgn7FYRAasqh9Q2UH/FSpJwKZwi+HjtMmoTKy+JH4cWyCMBNV0w9RoHGDeQVWQMiBMRkLh5XnEqsJQBLrxaNb17xdckXhSa2BnQpz6x6TN1wDrGHzFQjwMD+05cc97SpkdczPJOHAVVmuDHWIUDvHwopxq8KoberB2pz9SDqYJpRAwZ2B0+AHYBK28g1fbJSM7LOXpVOSQerZdL+eej2kwtPpRTC1YVQ0kW0ysrI0HTZxrRZgImCqZN7EzmUHiMvjUuYYQPoJewrqFHxbAnU1J+Eqyaw5vYWWNYJfXJ+x7RIqKIEBfrlbeh+5kkvSjAFlBaIElBdCgVQ41X4daCVcvLE1CrvzVuUTf3Vhq1vmDKIQXSaCa4dPGka6dtjUdPwXy+27a2r8nLB/DZpXfjJIUlkLwlDmp904A6irfGW4Fk3dNvEbBq9T0qn1ITF9O0oyhagWQRERjRZ7Ly8dhlVMWQfi7xKrBqeQHY69ebHSlWzUzW/UM1AWJp64ElA3DQvkfV+gpq2UDjVzgA2OMhcbAW5NfKfUm1IKx8gHi7tK6ZJ5Bq87eRHb/TZyarPhOHFFitDMCNJGVEi1pfbzagqnUckh5ShKTMtx97Vq1gxPlQe9Wux6FUDMs1az0i8SyEUAnQ0blbwZTv5vWsEvFjaopwLyn/pseX9mpqfdrnLbU+3tdVIZj5PJE7bo9ETDmHyrloKhi1dw1bw81RFUMJtWEnoPvK81eu7v1Jx0XI0awWTHSIpz1NL2gZRzNQDS9de1WdxK6h1if9ds6jV5vJgpoz1R4SU0TZZUQSkw4/NXvT77i/1PpuveYWVR75oZTScymln6SUPhPRZgEPJA1SQEXqlRZ4h1YaODeKVnYc
UQzk2an0IfUtgR7bM3drwZIhRp7/eeF5w344mFJK5wB8HsB7ANwP4MMppft729Pugpal15EHmBya00S3vTZ6bDZq56jA02Dlx7PT6FAvIjO9HcBPcs4/zTlfA/BVAO8LaPfg0OYpEVhyiLYkrEO8XmgZ8hgREUx3Afg5+f8Lp5/tYFMO3HDWcWaVA3vgqa5zo0CS9omER1VkdkQE0y8A3EP+f/fpZ13QHPoQynlanxEX/Cw5zVLo1WkqgV7Ot/qK9lpTLyKC6QkAb04pvSmldB7AhwB8a7RRbiALpML0vdDa8KjiacdKhfNrWPJG0tP26E3Go2IogQviSWj9Luvvpg9tWy+8DgdTzvkVAJ8A8CiAZwB8Lef8dOs8Ssz6lLlmAOm71msyFPS7pYY0FNaqpzTwivN6ApqjnEsDoRXcNfE33q4HrWDsqf5E+Wn+IqlxXO+zUyAPmOjdvIg9RIAiVRL0lrbldSJA17W1vJ+3xHt5Ei/L+3nSGxmW9/JG7aJdr963IDR4389rvZs3fTABNiNZi9QD+5mwtn9o5IVOQJbi9LxY2grqiBddAd2JLXw4JwtadrFcLy3YCw79ouuq7+ZJQz1pPCyJRGs6pbfdcrP7rXH+2UyatpIgM9e0bWEkkDjoXJYOl0btYlEx5J9RrjVB8Za/SP1yWDYITr85MErGBWjPzbRtGIC+Cc6r1qfdfTmitz1QbpZsYOEDxNrFK3Vj2Ujqkbep9W3ZHLh6MAHyS69A327bUWnFnoDSYA2k1jCmFlCAfaMi52Qd3kRv5Y8IJI0b5WeBN5CAI9jPVNCSK4lQe7DCooxHHUhT6wPsgVQ+47pQXHWd6upah1iWBRAvH49iYMGoiqHGDYBJHSRKGYRjiswEjJf5KuhRoPAMNb2LIT27fls8pBVOLSvURJlHh1W1reJSqTEKDwdPUcwCi7/0yttMPcwrsEorckNJY32vyoKlQlJPPQqgL5A4j9HyWpzX6LBK4sR5cXjlSb1FQzk/D24YGU6PlEuvrKJHXIwOuWrwzN+0On6tIZ91mCWtWlmygVeGs9cuETKcnJ8FN5wMZ02CUzpGarOn70PM3zxq4hbnpVV3NF5RfGhbfK7SwiFkOAHZX1pZ+MwU7gdipRVHZR1rcpMt5+kdk1sCivZPg7oWRJxTBB/absQ8xYtIydbSXi+mLKgShSglur2Huo0LdHJ+/5lN5EXlbZ9cOKe+Uzf6qKDGR7KL9Nu1z6PUA6Ou82g70wbTmpKMHsfhiFgM0UDbEfs+DR4tuGqw8LHYhfLTbKSdN3LNZ5DwXJ/BkcAqVxJRtMXD4ZCYWYZzBgnPKedMwNwynBbJEq4B5FXI83CwqD1wXhESKqUdjdcel4ZNejGLDOeUmallnFLzjte+62nLcnyvvI1W+LKHm/Z8p/XqDD+mxqnGx2IXy00m0iat47mfWHxmJDCny0w9D+PodyOyirPKcEY8KNWUDDm8MpyUE+fFoclw9urabjKcFVjfgKhBM1KPgbxvJe/U6zbq/xRuLQ6Uh0fQIFqGk3Ia0UQq5/cGVO/bMsD+cBM4gzKcHN73rUbmKb3v51FI+rYjciXapkXa1yFlOCmnXikXKk3aI1ZdPtd4UW4a6PeRc8hpgumYZDgB2y5O6rxA23kkDoVHz0uuBVwxsDegeiR/KHqkSblTewOp9dZ45KLMFMF0jDKcgOzEktQk4NO1bQkKePcPScd5A8pbo0PamtIrwdk7DNZAvx9RfeeYIpgoeoxz/Xhlku0dVo1sgpO+86jjWQq88H48KoY02K3OPGoXHlgRWZLykrgVtHbaWjW8LFh9aVy762hjcu1v59yGJpMFHod54vkrokPXCsNzbprETU0OVOu3xo9y0pQ4pGXkqF22Rc6lpVTSul6WFc7a3HbnWjRUUqzL5VNlppHJLX9b2nvHkVaovNmAfiZlg5o6nuW3UQ495bWKYl/hVNq+bjMnn15NW3pOS0mxhZ4NnHxOW84fUQ0EVs5M
1oiXJrj0r3Ys4H/Fp5YNAFshE3oMPdeiGaVlJW8gtb6rqRha+AA+u0iZtMZBy061rGVZJOpRDLT46urDvALrCpqma1uT4JT6qfXN0VsfruY4VhxCrc8jFyNlSdquBz0crM/dLJq2vG/vUJNjKJhSSh9MKT2dUnotpbS3jTcSXk3bUbSkOAuKVqum5VpQc5yrjfF/lEC0FS0+HC27tI6vKRf2vNwrrSxKspzRelCjmekpAB8A8HgAl0XgFYimaBlbE0KmiBI5662VJ6FXglM6RwK3gfUmswR4+0v2NxRMOedncs7PRZGJxFJKd5Hw1HdbAks7MoX3JiMNuZfcgnJUaus9yoHHEBAbNhQ0gyml9N2U0lPCn0u3tkc5cETa5RCyMMeONQTkCnq0naL2P0mI8JdmCznnvxnuZUVI27etF+X2W8+7J6mabIqVmwZawXUUGicrnwi7tNCjjySB223JG8jR3L4tF7BWP65Wi0Aai5fSvxRvu3RxZ6zvcRB6EfeKMbLaDrQ+HrD72zkHKywqhhKf658RPhxRdokC9xWtjwiFSYrRpfH3p5ReAPBOAN9JKT3a25YkEwLsS4RokKqWeqEFXO8F5xmg5+LVfsuIgqC1j9r1WMIuLUkZqaiMJidTQ2SV24LR1bxv5pzvzjlfyDn/cc75bz3nV8tZDegz8TbW0GeSlCcoP6s+E0WPPtPDD97X1Gfq4bOUXSyoObnkKxQ1f6nh6PSZlqjvDfQVppde6GxpEBVoEi6UY68+05JqfYAeyFZ5mxqvll08mWHUV0qfBZ6659PWGqcVOcW64mSsTn+8pQww0J+ya7IpVimXWiBRSFswtLkT798aUJahlZTBNT6avI3UF4dmF86hBq3mOfcVj9iZ1rd1P9PqmQlobxkHbA84tRrfnkAa3RwIjMu4WHbZWrJBlFpfa49V4QTUHwRHytvUuBV+FvTUX59eUiZi23qBpQzwktvWgfEJbvQ+Istws3dY1apHQdFrl1pN8Zav8MAaFTubPpiAccGz3lrWnoAC7G9meMbkEka2igNyJqC8LCqGEp8aJwt65yoFs4qdrT5noqgpUAC+B3meWtZSlrLKuEgViShG5m8tbSaKMnfhiFIxpHzKuRoni4j3iEZTTXaI8vPgTJb6kgIKsL3kOKrRpAWzpEEkLYpQ9AidFdQ4FHgEz7yPCyx8Shsem9D+vRw4amJsgP2l2DMtdga0Bc+8bY30PSLsFaXPdCi1Po1T7Qa3ptiZxE1r09NeL6YMJqAtYmVtI6Lv2hBCKmjI0SOf4lUxvM7nZflzzs/LyaogCOjzFYvcTq8+U0Tx/lFZmmmDCeg3UoRWj2VcDtTH5qNDCEtGAFhAN4JI49UzDOactP4kLKXP1BtUET4zdTABfiNFil61hhDauHx07tbicKiFGSsfiZOnfy+HGrw34Eh/mT6YClpGWko5bk193RoHq/PS46M49dxkltyLxGG9AUf7zNEEE6BfxBkkGNeAN6AOhUPdZGbDUQWTdqfhxfhvFMwYSIB/+Duq2Ke1aTkmst+j8L4rL79mNs4SfdcglRLubctznjeQIgOv9lpPrR+LjSLgsXOkz0yfmXplGSPuOGuqGNY4SP23RL3KeaMicFGBHSFGJ6EnOKL6njozram+rTmNx3G0Y63crO+gac91yndROrIePtJf7TwPDw+/Q5xbMG1mivhxvXeciBcpT9jKVqQsqNT/iKiXhU+EcLZVjC7qmvW0cebegLCKW1FETW4jnIYe1+s4EQJwkqhX4eThc2iBaC96haIjxaGBCYOp1zBaBiht9hqod/NZr9q61XEtol7AvsQO5bQkHwlWgWjP9RqZw2mB3OsvU82Z1p6nWLLB1ZdfNTuOVXStrFb2OG5N1It/T8+rKT0UHiN8aoJ0Fh69q7c9iyFR87fpMhOFddUKsMnRjw6rrBsEl5B5bG3Gq4kyl+MlGVCNT8uZRgXprDw0jKxyAraVTi+mCSbrxWt9H1UJlLcb
rWJoQcT2eYso8ygfKyctsDUenuGWd4HI4i/e4d5UwzyK3hoQPcvArYIuNaeR/qRjOT/rHLAGXtxF+uPH1nRse/nQbfS14O4RGLNcL45If/FgimBq1X7gKPORXkkW2p9nbGzNCCMqhhwjVZJa39H2rfDUfpD6jhYYA3z+sqSMz2h55H9JKT2bUvrflNI3U0p/FEWsoPU8RTJQ5N1GwqFVDGuwVCeqqb5HQcqS/N8cPTy8WUkKnqX8ZTQzPQbgL3PObwHwfwD+ebC9KqzPVCywrJ7Rdkcc0CIKXePA4S3zpSFC8pJD4yQNOTUeHL0PZCP8xdP3aK3x/8o5Fy/7HwB3j7Q3EzzDn9rdNxJLib/1BvcmRreLyDnTPwL4z8D2NkyOTVBuF01rpJS+C+BPhK8+m3P+99NjPgvgFQBfrrTzMQAfA4BLly51kd2wDmbdhDgbhpUDU0r/AODvAPx1rpSHzTk/AuAR4PWKrj6ar6MmuOUpM2Xqi1QAAmTxs4KWTlHUHZyKCEgF83sR+WxO46RVlwXa1673VbAIfznYc6aU0kMA/gnA3+ecwyYNkqDV9f9L1Xc6AunihZtEQ9W0mgosomWiIl9DbUHjsBRawmK9uHTxZCd4tJtNz02mdc16/MVb2VbD6JzpXwFcBPBYSulHKaV/62nEe+c5uXBu52/v+yDj1BChYmjh1go8i3pfTcLFC02tT7IHDyrtuKXR8pcoDI0/cs5/HkWEQyu62IJlyEKDt1Z9SKqnTYd6HsewZqWCvXJegl5UWWIuDmst3L/Dq5MPBdewar1OVM4B+gr3Wwpi9vjLaEHMaZZjWqW8aBne2vcUI+V/uePwAvW9KoY1bhq0oObOaxFCtgiuadCEBEq7xSatm4xVrXBko16Pv4ximmCS0FstlJ5L4bk43HEA7DkPcBgVDN4ez5JaRuBOHSWEXM6tBZSEyBuMtZZgr78cfeH+0Wqh5Vit7Z6+gf3a3nSFqEfxwYNaUPOA1jJCSwWDo7Y50Cp/WcOInAw/ZikVjB5MFUyATdXAWzF0pLZ3SwHDo0qn8axJuFAerSzZQo+cjHV+UrMJ79/LoYYZSlgXTBdMQHv+5B0OjPQtKWC0CuWX4yi84mu1gAJkrSiRx6DkJP3ccoPbSykAAALNSURBVKOR+tzjFDjpL+esXcIamDSYgHFVA9pGz3mtu7EHEZIyGoeltaJqnJaQ2unFpoJhQI+RDikpoyFq2AnsZwTK4RCBzY+12MUrtbOWCkZk39MHU8EadcS1QD60Kt2SE+0lVDB62otEq7rSEn0CRxRMa6JXdI2eH8EBWD+wW3wOyWGWPgq2YDLCcrernRPNY825gdTeoYfhM2ILpg7M4AweJz5rGWBWbMF05NiceB6kyhak5TpN6UUAPwto6g4Avw5oJwIbFxlnkcuf5Zzv5B+uEkxRSCldzjk/sDYPYOOi4Ubiso0RNmwIwhZMGzYE4diD6ZG1CRBsXGTcMFyOes60YcNMOPbMtGHDNDjqYDpErXMHlw+mlJ5OKb2WUlpl9Sql9FBK6bmU0k9SSp9ZgwPh8sWU0q9SSk+tyeOUyz0ppe+llH58eo0+uUQ/Rx1MOHCt8waeAvABAI+v0XlK6RyAzwN4D4D7AXw4pXT/GlxO8SUAD63YP8UrAD6dc74fwDsAfHwJ2xx1MM1U6zzn/EzO+bm1+gfwdgA/yTn/NOd8DcBXAbxvLTI558cB/Hat/ilyzr/MOf/w9N9XADwD4K7ofo46mBhu9FrndwH4Ofn/C1jAYY4dKaV7AbwVwA+i257+3byoWueH4rJhXqSU3gjg6wA+lXP+XXT70wdTVK3zQ3BZGb8AcA/5/92nn20AkFJ6A14PpC/nnL+xRB9HPcxbqtb5keIJAG9OKb0ppXQewIcAfGtlTlMgpZQAfAHAMznnzy3Vz1EHE4JqnUcgpfT+lNILAN4J4DsppUcP2f/pQswnADyK1yfYX8s5P31IDhQppa8A+D6Av0gpvZBS
+uhaXAC8C8BHALz71E9+lFJ6b3Qn2xsQGzYE4dgz04YN02ALpg0bgrAF04YNQdiCacOGIGzBtGFDELZg2rAhCFswbdgQhC2YNmwIwv8D1+q2aY8++I0AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 226.8x226.8 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light",
      "tags": []
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot true samples\n",
    "warm_sample = 50\n",
    "split_ = 1\n",
    "fig = plt.figure(figsize=(3.15, 3.15))\n",
    "ax = sns.kdeplot(true_samples[:,0], true_samples[:,1],  cmap=\"Blues\", shade=True, thresh=0.05, bw=0.15)\n",
    "ax.set_xlim(lower, upper)\n",
    "ax.set_ylim(lower, upper)\n",
    "plt.savefig('anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/true_samples.png')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "njsQpIVFRlIy"
   },
   "source": [
    "Calculate the KL divergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "apdTjSN7pinQ"
   },
   "outputs": [],
   "source": [
    "def KLdivergence(x, y):\n",
    "    \"\"\"Compute the Kullback-Leibler divergence between two multivariate samples.\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : 2D array (n,d)\n",
    "      Samples from distribution P, which typically represents the true\n",
    "      distribution.\n",
    "    y : 2D array (m,d)\n",
    "      Samples from distribution Q, which typically represents the approximate\n",
    "      distribution.\n",
    "    Returns\n",
    "    -------\n",
    "    out : float\n",
    "      The estimated Kullback-Leibler divergence D(P||Q).\n",
    "    References\n",
    "    ----------\n",
    "    Pérez-Cruz, F. Kullback-Leibler divergence estimation of\n",
    "  continuous distributions IEEE International Symposium on Information\n",
    "  Theory, 2008.\n",
    "    \"\"\"\n",
    "    from scipy.spatial import cKDTree as KDTree\n",
    "\n",
    "    # Check the dimensions are consistent\n",
    "    x = np.atleast_2d(x)\n",
    "    y = np.atleast_2d(y)\n",
    "\n",
    "    n,d = x.shape\n",
    "    m,dy = y.shape\n",
    "\n",
    "    assert(d == dy)\n",
    "\n",
    "\n",
    "    # Build a KD tree representation of the samples and find the nearest neighbour\n",
    "    # of each point in x.\n",
    "    xtree = KDTree(x)\n",
    "    ytree = KDTree(y)\n",
    "\n",
    "    # Get the first two nearest neighbours for x, since the closest one is the\n",
    "    # sample itself.\n",
    "    r = xtree.query(x, k=2, eps=.01, p=2)[0][:,1]\n",
    "    s = ytree.query(x, k=1, eps=.01, p=2)[0]\n",
    "\n",
    "    # There is a mistake in the paper. In Eq. 14, the right side misses a negative sign\n",
    "    # on the first term of the right hand side.\n",
    "    return -np.log(r/s).sum() * d / n + np.log(m / (n - 1.))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f2o4KB7DUGob"
   },
   "outputs": [],
   "source": [
    "cyclicalsgld_kl = []\n",
    "cyclicalsgld_epochs = []\n",
    "for repeat in range(20):\n",
    "    cyclicalsgld_path = []\n",
    "    cyclicalsgld_epoch_path = []  \n",
    "    for i in range(100, cyclicalsgld_x[repeat].shape[0]):\n",
    "        cyclicalsgld_path.append(KLdivergence(true_samples, cyclicalsgld_x[repeat][:i,:]))\n",
    "        cyclicalsgld_epoch_path.append(cyclicalsgld_epoch[repeat][i])\n",
    "    cyclicalsgld_kl.append(cyclicalsgld_path)\n",
    "    cyclicalsgld_epochs.append(cyclicalsgld_epoch_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b_AzBGQLLudi"
   },
   "outputs": [],
   "source": [
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "import pickle\n",
    "f = open(PATH + 'cyclicalsgld_kl.txt', 'wb')\n",
    "pickle.dump([cyclicalsgld_epochs, cyclicalsgld_kl], f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rwgRPDX0RUnK"
   },
   "outputs": [],
   "source": [
    "psgld_kl = []\n",
    "psgld_epochs = []\n",
    "for repeat in range(20):\n",
    "    psgld_path = []\n",
    "    psgld_epoch_path = []  \n",
    "    for i in range(100, pop_sgld_x[repeat].shape[0]):\n",
    "        psgld_path.append(KLdivergence(true_samples, pop_sgld_x[repeat][:i,:]))\n",
    "        psgld_epoch_path.append(pop_sgld_epoch[repeat][i])\n",
    "    psgld_kl.append(psgld_path)\n",
    "    psgld_epochs.append(psgld_epoch_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7dCHBfFMRvzx"
   },
   "outputs": [],
   "source": [
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "import pickle\n",
    "f = open(PATH + 'psgld_kl.txt', 'wb')\n",
    "pickle.dump([psgld_epochs, psgld_kl], f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MhgkKHOr-esv"
   },
   "outputs": [],
   "source": [
    "resgld_kl = []\n",
    "resgld_epochs = []\n",
    "for repeat in range(20):\n",
    "    resgld_path = []\n",
    "    resgld_epoch_path = []  \n",
    "    for i in range(100, resgld_x[repeat].shape[0]):\n",
    "        resgld_path.append(KLdivergence(true_samples, resgld_x[repeat][:i,:]))\n",
    "        resgld_epoch_path.append(resgld_epoch[repeat][i])\n",
    "    resgld_kl.append(resgld_path)\n",
    "    resgld_epochs.append(resgld_epoch_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WS5zQvzG-2rC"
   },
   "outputs": [],
   "source": [
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "import pickle\n",
    "f = open(PATH + 'resgld_kl.txt', 'wb')\n",
    "pickle.dump([resgld_epochs, resgld_kl], f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eUOUMTkPrpZh"
   },
   "outputs": [],
   "source": [
    "pcsgld_kl = []\n",
    "pcsgld_epochs = []\n",
    "for repeat in range(20):\n",
    "    pcsgld_path = []\n",
    "    pcsgld_epoch_path = []  \n",
    "    for i in range(100, pop_csgld_x[repeat].shape[0]):\n",
    "        pcsgld_path.append(KLdivergence(true_samples, pop_csgld_x[repeat][:i,:]))\n",
    "        pcsgld_epoch_path.append(pop_csgld_epoch[repeat][i])\n",
    "    pcsgld_kl.append(pcsgld_path)\n",
    "    pcsgld_epochs.append(pcsgld_epoch_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ed8lyHX4_Cbb"
   },
   "outputs": [],
   "source": [
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "import pickle\n",
    "f = open(PATH + 'pcsgld_kl.txt', 'wb')\n",
    "pickle.dump([pcsgld_epochs, pcsgld_kl], f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "st8GwYYyY2fi"
   },
   "outputs": [],
   "source": [
    "csgld_kl = []\n",
    "csgld_epochs = []\n",
    "for repeat in range(20):\n",
    "    csgld_path = []\n",
    "    csgld_epoch_path = []  \n",
    "    for i in range(100, csgld_x[repeat].shape[0]):\n",
    "        csgld_path.append(KLdivergence(true_samples, csgld_x[repeat][:i,:]))\n",
    "        csgld_epoch_path.append(csgld_epoch[repeat][i])\n",
    "    csgld_kl.append(csgld_path)\n",
    "    csgld_epochs.append(csgld_epoch_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5PMzzPIUZteE"
   },
   "outputs": [],
   "source": [
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "import pickle\n",
    "f = open(PATH + 'csgld_kl.txt', 'wb')\n",
    "pickle.dump([csgld_epochs, csgld_kl], f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Z78DnLjX_YAt"
   },
   "source": [
    "# Plot the KL divergence vs time step"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RxQR6X46L857"
   },
   "source": [
    "## calculate the confidence interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 232
    },
    "executionInfo": {
     "elapsed": 418,
     "status": "error",
     "timestamp": 1614199825893,
     "user": {
      "displayName": "YUBO WANG",
      "photoUrl": "",
      "userId": "04037309235043616004"
     },
     "user_tz": 300
    },
    "id": "WG7iK6giFmYN",
    "outputId": "c877ddbb-6bab-4f47-ed71-4ec52fe3af45"
   },
   "outputs": [],
   "source": [
    "# load the data\n",
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "import pickle\n",
    "f = open(PATH + 'psgld_kl.txt', 'rb')\n",
    "psgld_epochs, psgld_kl = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + 'cycSGLD_kl.txt', 'rb')\n",
    "cyclicalsgld_epochs, cyclicalsgld_kl = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + 'resgld_kl_var25.txt', 'rb')\n",
    "resgld_epochs, resgld_kl = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + 'pcsgld_kl.txt', 'rb')\n",
    "pcsgld_epochs, pcsgld_kl = pickle.load(f)\n",
    "f.close()\n",
    "f = open(PATH + 'csgld_kl.txt', 'rb')\n",
    "csgld_epochs, csgld_kl = pickle.load(f)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 232
    },
    "id": "1prlPqno_XeL",
    "outputId": "e143dd45-aeb0-40a5-b124-f64ec61fd14a"
   },
   "outputs": [],
   "source": [
    "cyclicalsgld_kl_array = np.array(cyclicalsgld_kl)\n",
    "cyclicalsgld_epoch_array = np.array(cyclicalsgld_epochs)\n",
    "\n",
    "csgld_epoch_array = np.array(csgld_epochs)\n",
    "csgld_kl_array = np.array(csgld_kl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fADFEzNtidjg"
   },
   "outputs": [],
   "source": [
    "def process_data(kl, epochs):\n",
    "    # compute the average of five chains at each iteration\n",
    "    final_kl_avg = []\n",
    "    final_epoch = []\n",
    "    for i in range(20):\n",
    "        tmp = []\n",
    "        kl_avg, epoch_uniq = [], []\n",
    "        for j in range(len(kl[i])):\n",
    "            if not epoch_uniq:\n",
    "                epoch_uniq.append(epochs[i][j])\n",
    "                tmp.append(kl[i][j])\n",
    "            elif epoch_uniq[-1] == epochs[i][j]:\n",
    "                tmp.append(kl[i][j])\n",
    "            else:\n",
    "                kl_avg.append(np.mean(tmp))\n",
    "                tmp = [kl[i][j]]\n",
    "                epoch_uniq.append(epochs[i][j])\n",
    "        kl_avg.append(np.mean(tmp))\n",
    "        final_kl_avg.append(kl_avg)\n",
    "        final_epoch.append(epoch_uniq)\n",
    "    \n",
    "    # compute the mu and sd of the 20 trials\n",
    "    mu, sd, epoch_uniq = [], [], []\n",
    "    for i in range(final_epoch[0][-1], -1, -20):\n",
    "        tmp = []\n",
    "        for j in range(20):\n",
    "            if not final_epoch[j]:\n",
    "                continue\n",
    "            if final_epoch[j][-1] == i:\n",
    "                tmp.append(final_kl_avg[j].pop())\n",
    "                final_epoch[j].pop()\n",
    "        if tmp:\n",
    "          mu.append(np.mean(tmp))\n",
    "          sd.append(np.std(tmp)/np.sqrt(len(tmp)))\n",
    "          epoch_uniq.append(i)\n",
    "    return np.array(mu[::-1]), np.array(sd[::-1]), np.array(epoch_uniq[::-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5TMqH8jRi_Fv"
   },
   "outputs": [],
   "source": [
    "psgld_mu, psgld_sd, psgld_epoch_uniq = process_data(psgld_kl, psgld_epochs)\n",
    "resgld_mu, resgld_sd, resgld_epoch_uniq = process_data(resgld_kl, resgld_epochs)\n",
    "pcsgld_mu, pcsgld_sd, pcsgld_epoch_uniq = process_data(pcsgld_kl, pcsgld_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EcqX57PIpjg1"
   },
   "outputs": [],
   "source": [
    "csgld_mu, csgld_sd, csgld_epoch_uniq = process_data(csgld_kl, csgld_epochs)\n",
    "csgld_epoch_uniq = [csgld_epoch_uniq[i] / 5 for i in range(0, len(csgld_epoch_uniq))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YvUCuaW_AQBT"
   },
   "outputs": [],
   "source": [
    "cyclicalsgld_sd = np.array([np.std(cyclicalsgld_kl_array[:,i]) for i in range(0, cyclicalsgld_kl_array.shape[1])]) / np.sqrt(20)\n",
    "cyclicalsgld_mu = np.array([np.mean(cyclicalsgld_kl_array[:,i]) for i in range(0, cyclicalsgld_kl_array.shape[1])]) \n",
    "cyclicalsgld_epoch_uniq = [cyclicalsgld_epochs[0][i] / 5 for i in range(0, len(cyclicalsgld_epochs[0]))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qCZ9AzzTRA-T"
   },
   "outputs": [],
   "source": [
    "# Fonts for the legend (font1) and the axis labels (font2); both names are kept and\n",
    "# carry the same settings.\n",
    "font1 = {'family': 'Times New Roman', 'weight': 'normal', 'size': 22}\n",
    "font2 = dict(font1)\n",
    "\n",
    "figsize = 9, 8\n",
    "plt.style.use('ggplot')\n",
    "figure, ax = plt.subplots(figsize=figsize)\n",
    "# The grid used to be switched on by passing the string \"darkgrid\" as the on/off flag,\n",
    "# which only worked because a non-empty string is truthy; request it explicitly.\n",
    "ax.grid(True)\n",
    "\n",
    "# One entry per sampler:\n",
    "# (epochs, mean KL, standard error, legend label, colour, line alpha, band alpha).\n",
    "# cycSGLD has its own colour: it previously shared \"purple\" (and both alphas) with\n",
    "# CSGLD, so the two curves could not be told apart in the figure or the legend.\n",
    "kl_curves = [\n",
    "    (csgld_epoch_uniq, csgld_mu, csgld_sd, \"CSGLD xT5\", \"purple\", 0.5, 0.3),\n",
    "    (cyclicalsgld_epoch_uniq, cyclicalsgld_mu, cyclicalsgld_sd, \"cycSGLD xT5\", \"green\", 0.5, 0.3),\n",
    "    (psgld_epoch_uniq, psgld_mu, psgld_sd, \"SGLD xP5\", \"red\", 0.5, 0.3),\n",
    "    (resgld_epoch_uniq, resgld_mu, resgld_sd, \"reSGLD xP5\", \"steelblue\", 0.6, 0.5),\n",
    "    (pcsgld_epoch_uniq, pcsgld_mu, pcsgld_sd, \"ICSGLD xP5\", \"orange\", 0.6, 0.4),\n",
    "]\n",
    "for curve_epochs, curve_mu, curve_sd, curve_label, curve_color, line_alpha, band_alpha in kl_curves:\n",
    "    ax.plot(curve_epochs, curve_mu, label=curve_label, linewidth=2,\n",
    "            alpha=line_alpha, color=curve_color)\n",
    "    # Shaded band of +/- one standard error around the mean curve.\n",
    "    ax.fill_between(curve_epochs, curve_mu - curve_sd, curve_mu + curve_sd,\n",
    "                    alpha=band_alpha, color=curve_color)\n",
    "\n",
    "ax.tick_params(labelsize=22, colors=\"black\")\n",
    "labels = ax.get_xticklabels() + ax.get_yticklabels()\n",
    "for tick_label in labels:\n",
    "    tick_label.set_fontname('Times New Roman')\n",
    "figure.subplots_adjust(left=0.15, bottom=0.128)\n",
    "\n",
    "ax.legend(prop=font1)\n",
    "ax.set_ylabel(\"KL divergence\", fontdict=font2, color=\"black\")\n",
    "ax.set_xlabel(\"Steps\", fontdict=font2, color=\"black\")\n",
    "figure.savefig(PATH + 'kl_vs_time_step.png')\n",
    "# The figure is only needed on disk, so release it once saved.\n",
    "plt.close(figure)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "DKlr4JJcmdz7"
   },
   "source": [
    "# SVGD & SPOS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EVuq9UzsgTC5"
   },
   "outputs": [],
   "source": [
    "# KL-divergence traces for the particle methods. The p* lists are filled in this\n",
    "# cell; svgd_path / spos_path are only created here and filled by a later cell.\n",
    "psvgd_path, pspos_path = [], []\n",
    "svgd_path, spos_path = [], []\n",
    "\n",
    "# Flatten the stored samples into rows of two coordinates each.\n",
    "pspos_x = np.array(pspos_x).reshape(-1, 2)\n",
    "psvgd_x = np.array(psvgd_x).reshape(-1, 2)\n",
    "\n",
    "# Rows 0-99 are skipped (presumably burn-in - TODO confirm). pspos_x is indexed with\n",
    "# the row range of psvgd_x, so it must have at least as many rows.\n",
    "for row in range(100, psvgd_x.shape[0]):\n",
    "    psvgd_path.append(KLdivergence(true_samples, psvgd_x[row]))\n",
    "    pspos_path.append(KLdivergence(true_samples, pspos_x[row]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yJ8VBdgMbCbU"
   },
   "outputs": [],
   "source": [
    "# Start from empty lists so that re-running this cell does not append a second copy\n",
    "# of every value to the paths that get pickled below.\n",
    "svgd_path = []\n",
    "spos_path = []\n",
    "# Entries 0-99 are skipped; spos_x is indexed with the range of svgd_x, so it must be\n",
    "# at least as long.\n",
    "for step in range(100, len(svgd_x)):\n",
    "    svgd_path.append(KLdivergence(true_samples, svgd_x[step]))\n",
    "    spos_path.append(KLdivergence(true_samples, spos_x[step]))\n",
    "\n",
    "# NOTE(review): this rebinds the notebook-wide PATH, so cells run afterwards that build\n",
    "# file names from PATH will point into simulation_figures/.\n",
    "PATH = \"anonymous/Contour-Stochastic-Gradient-Langevin-Dynamics/simulation_figures/\"\n",
    "# pickle is already imported by the data-loading cell near the top of the notebook.\n",
    "# The context manager closes the file even if pickle.dump raises.\n",
    "with open(PATH + 'particles_path.txt', 'wb') as f:\n",
    "    pickle.dump([psvgd_path, svgd_path, pspos_path, spos_path], f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DHvHj5SmmZqI"
   },
   "outputs": [],
   "source": [
    "# Shared x axis for the four SGLD-family traces: 100, 120, ... up to (not including) 4e5.\n",
    "x = np.arange(100, 4e5, 20)\n",
    "for kl_trace, trace_label in (\n",
    "    (pcsgld_path, \"pcsgld\"),\n",
    "    (psgld_path, \"psgld\"),\n",
    "    (csgld_path, \"csgld\"),\n",
    "    (sgld_path, \"sgld\"),\n",
    "):\n",
    "    plt.plot(x, kl_trace, label=trace_label)\n",
    "# The cyclical SGLD trace is drawn against its own epoch values, from index 99 onwards.\n",
    "plt.plot(epoch_path[99:], cyclicalsgld_path, label=\"cyclicalsgld\")\n",
    "\n",
    "plt.legend()\n",
    "plt.ylabel(\"KL divergence\")\n",
    "plt.xlabel(\"time steps\")"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyMXdpJKVT6ZdpE9I0S0/GLF",
   "name": "Plot.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
