{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited markovsbi modules automatically before each cell runs\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Third-party\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "import optax\n",
    "\n",
    "# Project (markovsbi)\n",
    "from markovsbi.tasks import get_task\n",
    "from markovsbi.utils.sde_utils import init_sde\n",
    "from markovsbi.models.simple_scoremlp import (\n",
    "    build_score_mlp,\n",
    "    precondition_functions,\n",
    "    precondition_functions_v3,\n",
    "    precondition_functions_v2,\n",
    ")\n",
    "from markovsbi.models.train_utils import build_batch_sampler, build_loss_fn\n",
    "from markovsbi.bm.plot_utils import use_style"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[cuda(id=0)]"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Devices JAX will run on (the saved output shows a single CUDA GPU)\n",
    "jax.devices()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 1\n",
    "key = jax.random.PRNGKey(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_NAME = \"simple1dstationary\"\n",
    "task = get_task(TASK_NAME)\n",
    "prior, simulator = task.get_prior(), task.get_simulator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a dedicated subkey: `key` is split again later, so it must not also be consumed here\n",
    "key, key_data = jax.random.split(key)\n",
    "# NOTE(review): args presumably (num_simulations, trajectory length); batches later have shape (batch, 2, 1) - confirm\n",
    "data = task.get_data(key_data, 1_000_000, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAFsAAABUCAYAAADtYEtMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIGUlEQVR4nO2cS2wU2RWGv+pyt9vdbmMwg3mMTTKOJjEKCwYQEmLBY2kkIiG23iBLbJDYsIINAsQKCSlLJFggsUBCSkBEZGxlAQR5cGINTGRHYyaQOLz8wPS73dVdZxa3bIMDdren+lZ1pj6p1Q9X3XP7r1Pn3nvObRsiIgRoIeR1B35OBGJrJBBbI4HYGgnE1kggtkYCsTUSiK2RQGyNNHjdgRUhNuR/gMIPYM9CeC3Et0JDi9c9WxKj7pbr5Sy8vQuzL0HKgK0+DzVC6z6Ib/G0e0tRX55tF2DqD2BNK48W672/FeHtnwEb4r/1qodLUj8xWwTefu0IXQBsiHZB81cQXq/e2wWYGYDZV1739qPUj9jZf0Dh345H2xDrhsg6FT6iv4DIRkCgnIfpO8rTfUZ9iF3OQPIBSEmFjugvwUws/N0woLETzFWAQGkaUt941t1PUR9ipwaVyDKrBI2s+99jDAOavgCjQQ2c6SGw3urv6xL4X+ziJGRHHbGBpl99+thQFCIbAEOFkeQDXb2sCP+LnRpEDX5FCLdDKLL08ZENKo5jQ/57KE7o6GVF+FtsawoKz5VXGyFo7Fj+HMN0BksDbMtX3u1vsdN/R3m15Xi1Wdl54c9USMGGwr9UKPIB/hW7lIT8mLNwMSCyqfJzjZATu0PqQqUGa9XLqvCv2OlhNauwLcdTq1zshtc68V1U7C4la9LNavCn2OUs5EbVvBoqi9WLMUwVeuZid/pvrnZxJfhT7My3SmjbgoY1EAqvrJ1Iu5p3Y0P2Oyjn3Oxl1fhPbLughJESIGpluFKMBhWC5trNfOtGD1eM/8TOfKfm1GKp1aIZ/WntRdarkCK2Ggc8zJn4S2zbguxjNTBKGaI/wavnCDVCQ5vTflbdNR7hL7FzoyprJ5ZKNJlxd9qNbHS8uwypRwsDr2b8I7aUITPsZPZKEPncvbbNJmhYrV6X05Adca/tKvCP2PkxKKWVV4diEG51t/3IRiCkLmTqGxXDNeMPscWG1JATq0vQWMVqsVLMODS0qtelGch9776NZfCH2Ll/QukdSFHlNOYGNLeZW8JLCVJ/dQrG+vBe7Pnbes6rO1UhoBaYiYXtDta0KrVpxHuxM09U2UuKTqyukVeDUz7rYN67kw+1zru9FbucgfSjhRlI4+ba2zSbFy5oOaWmgprwVuzkA2e1WFSrRbdnIJ+i8fOFeXd6SFu+2zux808hN5evFmjq0mc7FF3Ij9uzMPO1lsHSG7FLaZj5C1BWnh3Z6NQNNRLZoEIKttrKlnxYc5P6xbYteHtHZeHsWTX/DddgXr0cRkjtPzFMdXelh9TdVkP0ii0ltSmyOKGEBmj6EkIeRTOz2UnhGqo/03+Cwn9qZk7ft7Rn1bawwnM1IErJEVpz+FhMuN3JeYsqLkzdqpmH69kyPPsKZvrVKnEuVx3tgshnNTddESJObmZahZdQFBI7ILFr5VWij1A7sUXAeqOqI7kx1JaEWfUc/WKhguIXxIbCM7AmAMPJg6+GxFcQ+40rd6C7Yhdfq/111oTacVpKAbZTTywqL4n+GhqaXTPpKiJQfAWz44CtympGWFXpGzvUDCayDqIrW3y5K/Z/f4/akCcL1RYpofZ9tEOk07vBsBrsvBpbSu/Ue8N0CschFWY2HVfPVeLKLw9EhFwuB5mkszhwrp8RBjMGZhsUTCi8dsOcJlaBNCrByxnUz0kM9Rj/I7GO32FUmTBzRexcLkdzs09DQ43IZDLE49WV7Vy5p7PZrBvN1BUr+c6ueHYsFpt//ebNm6qveL2QzWZpb28HPvzOleKK2O/H
rng8/n8r9vtUG6/B6xTrz4xAbI3U3y9865jAszUSiK2RQGyNBGJrJBBbI66LPTw8zPbt21mzZg2tra3s3r2be/fuuW3GE44fP05HRwctLS1s2rSJEydOUCxWsclHXGZqakqeP38utm2Lbdty8+ZNaW5ullwu57Yp7YyMjEgmkxERkcnJSdm7d6+cPXu24vNd9+y2tjY2b96MYRiICKZpkslkeP3a/fTqy5cvOXz4MLt27WLHjh2Mjo66buN9uru751MRIkIoFGJsbKzyBmrjAyKrVq0S0zQFkN7eXtfbz+fzsnXrVhkYGBARkatXr8r+/ftdt7OYCxcuSDweF0Da2tpkaGio4nOrErunp8cpw3z88ezZsw+Oz+Vycu3aNbl8+XI1ZirizJkzcuzYsfn3w8PDYpqmWJbluq2PMTIyIqdOnZLx8fGKz6lK7GQyKZOTk598lMvlj563ZcsWuX//fjWmlqRcLsvatWvl0aNH85/dvXtXAEmlUq7ZWY4bN27IgQMHKj6+qhRrS8vK/pWbZVmMjY2xZ8+eFZ2/mMHBQd69e8fJkyfnP3vx4gWxWIxEIrHEme4y970qxu2rffv2bXn8+LFYliXZbFbOnz8vTU1N8vTpU9dsXLx4UXbu3PnBZ729vbJv3z7XbCwmnU7LlStXZGZmRmzblidPnkh3d7f09fVV3Ibrs5GpqSmOHDlCa2srnZ2d9Pf3c+fOHbq63NulOjExMV8xATUzGBgY4ODBg67ZWIxhGFy/fp2uri4SiQSHDh2ip6eHS5cuVd5IzVyhhpw7d056enrm3/f398vq1atlcnLSw14tT12K/fDhQ1m/fr3k83lJp9Oybdu2msx43KZuiwenT5/m1q1bxGIxjh49Sl9fn9ddWpa6FbseCbJ+GgnE1kggtkYCsTUSiK2RQGyNBGJrJBBbI4HYGgnE1kggtkYCsTXyI0zP0DGQCOOUAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 75x50 with 1 Axes>"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the prior density on [-3, 3] for plotting\n",
    "x = jnp.linspace(-3, 3, 1000)\n",
    "pdf = jnp.exp(jax.vmap(prior.log_prob)(x))\n",
    "\n",
    "color = \"#ffdf7fd9\"\n",
    "\n",
    "with use_style(\"pyloric\"):\n",
    "    fig, ax = plt.subplots(figsize=(.75, .5))\n",
    "    ax.plot(x, pdf, color=color, lw=2)\n",
    "    ax.fill_between(x, pdf, color=color, alpha=0.5)\n",
    "    # Minimal look: no y axis, no left spine, ticks only at the range ends\n",
    "    ax.yaxis.set_visible(False)\n",
    "    ax.spines[\"left\"].set_visible(False)\n",
    "    ax.set_xlim(-3, 3)\n",
    "    ax.set_xticks([-3, 3])\n",
    "    ax.set_xlabel(r\"$\\theta$\", labelpad=-8)\n",
    "    fig.savefig(\"prior.svg\", bbox_inches=\"tight\", pad_inches=0, transparent=True)\n",
    "\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGMAAABTCAYAAACLQbk4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAJHUlEQVR4nO2bS0wUeR7Hv1X9ot80r4YWZsWBmeyDZNm4ZF0TNcYYo1EJMTFojCcuxoMXT5hw8eLB6IGjHuS0MVnismvGTBxAWZlRYVDQMKMjIAIt0DTd9KP6/dtDQz/QBhqqqX+79Uk66a7+979+1Z/61f9VxRERQYYJeKkDkEkiy2AIWQZDyDIYQpbBELIMhpBlMIQsgyFkGQyhlDoAMYhEYxiecuPdvBeBcBSlRg3qv7LAaiqQOrSs4PJ9OsTlD6HrxQwWfKG07RyAv9eU4K87LeA4TprgsiSvZXgDYfzj+Qd4ApGMZfbsKsbfvi7exqg2T962GbEY4T/D9oQIlYLD7yuMaNhpwY7C5OXpx7EFjDt8UoWZFXkr4/mEE3Z3AACg5Dn8ocIEs1YFnudQVaRDVZE2Ufa7ETuEUFSqUDdMXspw+UP4aWwh8fnrUj20akVaGZu5AIU6FQAgGImh7+38tsa4GfJSRt9bB2LLLV2xXg2LXv1JGY7jsKtED3657X49s4SPy5nEKnknY3pRwG9zXgCAguOws0SXsaxayaPKkvz+MePZkXcy+t85Eu/LCzVQKdY+BKtZA40yXmZ6UcDUoj+n8W2FvJJhdwmYWhQAxHtPO8zrD+p4jsMOS7Ix73vjWKO0tOSVjGcTzsR7q6kAPL+x8EsNahSo4mU/LgUw7RJyEt9WyRsZDm8QY/Px8YKC51Bh1mz4txzHYUdhMjt+erewRmnpyBsZz1OyosyohmKDWbFCsUENtSLetZp0+rHgDYoanxjkhQy3EMavHz0AAJ5D2lm+UXiOQ4U5+bunY841SktDXsgYnHBiZQat2KCBcp0eVCbKTBoolgceb+Y88ATCYoUoCszL8AUjeDWzBADgOKDKsvlpcQXPodwUb2uIgOcTi6LEKBbMyxiadCG6PNy26FRQKxXr/GJtrKYCrEyov5p2IxBmZ86KaRnBcBQvp1yJz6mj6c2iVvIoW86OaIzw83t2soNpGcNTboQiMQCAWav8ZDJws1SkDBaHJl2JfUgNszIi0Rh+nkyetZWW7HtQmShQKVBiiE8uhqIxDKdkn5QwK+P1zBL8y2sQBo0CxgKVqPWndo8HJhYRiUqfHUzKiERjeDaeHAeImRUraNUKFOnjgoVwFK9n3KLvI1uYlDE85YY3GF9O1akVKNR9ul4hBraU7PhxzCl5djAnIxRJz4qvirbeg8qEQaOEZXk1UAhFMTgpbc+KORkDE04Iy31/o0aZWDrNFVUpsp+POyVdK2dKhssfwkDKqPh3xeK3FavRqRUoM8bHHeEo4b8SrgYyI4OI0PPLHKLLk1BFejUMIvegMlFp0SbWyl/NLEm2GsiMjFfTbkwsxP8EBc+heo21bbFRK/m0y9X3r2cRlqAxZ0KG0xdC76/Jy0OVRbvu2rbYlJs00GviI3y3EMYPo7PY7pstJZcRCEfxrxfTiCxPBpq1SpRvYG1bbDiOQ02pIXG5GrV7MDy1vWMPSWUEI1HcG5qGyx9fV1ArONSUGSSLR6tWYFepPvG5+5e5xKLWdiCZDG8gjH8OTiVu0VRwwDflxm2/PK2mxKBJm0j8bsSOlx9c27Lvbb8LnYjwbt6LH0bnEnNPPAfUWg2w5GiknS1EhLF5H+a9yccMvi034sC3pdCpc/dIy7bJiMYIEw4fBt8vpt0qo+Q51JYZYM7x4C5biAiTTj/s7uSNC2oFjz9/VYg/7TDDrBU/3pzK+G3OiwVvELNLAXxwCgit6i7q1Ap8YzWgQCXOOkUucHiCGHf4EF31L5UZNagoLIDVVIA/2syi7Cunj5H9++XM
Z7erFBxshQWwGjUbvhFNKkqMGpi0Kkw6/XCkXLbmPEHMeYIA3KgtM0Kt3Ppx5EQGEcHv9yMoJEeyPAdolDz0GiXMWhVioSDsC+zdu5QJLQeU6wBPIAxPIIJwSqp8/3ICx/5SveXH1XIiw+/3w2CQrosqBV6vF3q9fv2Ca5CTa4TPlx+PbYmJGMeck8zQ6ZLzPLOzs1s+Y1jF5/PBarUCSD/mzZITGanXTr1e/8XKSEWMx5vZ7sr8nyHLYIi8fij/S0PODIaQZTCELIMhZBkMIctgCNFl2O12nDhxAjabDRzH4cWLF2nf9/b2guM4GAyGxOvixYtih5ET2tvbsXv3bmg0GjQ2NqZ9d+DAAWg0mrTjmpn5/Kx1JkSXwfM8jhw5gnv37mUsYzab4fV6E6/29naxw8gJNpsNV65cQUtLy2e/v3btWtpx2Wy2rOoXXYbVasWFCxfQ0NAgdtWf8OzZM1RXV6OpqSmxrbOzEzdv3szJ/pqamtDY2IiSkpKc1C9Jm7Fy1lRWVuLs2bOYnp7eVD0NDQ3o6upCV1cXRkZGMDAwgCdPnuDSpUviBrxBrl69iqKiItTX16OjoyP7CigLjh07RgAyvsbHx9PKA6ChoaG0bXa7nUZGRigSiZDdbqfm5maqr6+naDSaTShpnDp1ig4dOkTnz5+ncDic2N7d3U1v377ddL2ZaGtro5MnT6Zt6+/vJ5fLRaFQiB48eEAmk4k6OzuzqjcrGW63m+bn5zO+Vv+hn5OxGo/HQzzP0+joaFaBp9LT00MAqL+/P237mTNntk3Gai5fvkynT5/Oqt6sptBNJlP2qbcOW516FgQBt27dwtGjR3Hjxg3s2bMHANDX14eenh7U1tbi4MGD2LdvnxjhbphNre1v4QTJiCAIJAgCAaCnT5+SIAiJrOnu7qaxsTGKxWLkcDjo3LlzVFdXR5FIJOv9xGIxamlpoTdv3tDw8DDxPE8jIyNERBQIBGjv3r2iHlc4HCZBEKi1tZWOHz9OgiBQMBikxcVFun//Pvl8PopEIvTw4UMym8109+7drOrPiQx8pj3p6ekhIqLr169TZWUl6XQ6Ki8vp+bmZnr//v2m9tPa2kqPHj1KfG5sbKTq6mrq7e2lwcFBamlpEeNwErS1tX1yXPv376e5uTlqaGggo9FIRqOR6urq6Pbt21nX/8VOoXd0dMDtduPw4cOoqamBQsHuvVkrfLHTIcXFxXj8+DHu3LmTFyIAeXGJKb7YzMhHZBkMIctgCFkGQ8gyGEKWwRCyDIaQZTCELIMhZBkMIctgCFkGQ/wPPij20IwaNd4AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 75x50 with 1 Axes>"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = jnp.linspace(-15, 15, 1000)\n",
    "# Proposal density N(0, 5^2); norm.pdf is already elementwise, so no vmap is needed\n",
    "pdf = jax.scipy.stats.norm.pdf(x, 0., 5.)\n",
    "\n",
    "color2 = \"#8ebad9ff\"\n",
    "\n",
    "with use_style(\"pyloric\"):\n",
    "    fig, ax = plt.subplots(figsize=(.75, .5))\n",
    "    ax.plot(x, pdf, color=color2, lw=2)\n",
    "    ax.fill_between(x, pdf, color=color2, alpha=0.5)\n",
    "    # Minimal look: no y axis, no left spine, ticks only at the range ends\n",
    "    ax.yaxis.set_visible(False)\n",
    "    ax.spines[\"left\"].set_visible(False)\n",
    "    ax.set_xlim(-15, 15)\n",
    "    ax.set_xticks([-15, 15])\n",
    "    ax.set_xlabel(r\"$x_t$\", labelpad=-8)\n",
    "    fig.savefig(\"proposal.svg\", bbox_inches=\"tight\", pad_inches=0, transparent=True)\n",
    "\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGMAAABTCAYAAACLQbk4AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAHBUlEQVR4nO2cT2gT2x7HvzNJE2/S1nu1vEgoj1YQF/7BgmTzFvGCC1HQUApSXYS36EYEdeFKIS7cuBBFuhBUHhVdiFCkunBzTayYRfEfVqj3Cs/b+17Nrb3tTW0yGSeT/O6izWSmmdx22k5zPJwPDEzOzJycM585v5NzZiYSEREETCA3ugCCKkIGQwgZDCFkMISQwRBCBkMIGQwhZDCEkMEQ3kYXwC2KpTJ+GpuEopXw4/Z/4Iegr9FFWhJuW8bo/2cxlpnD+LSCn95PNro4y4JbGf+dyhnr/5spoFRmfwqOWxnTec3y+U9Fq7MnO3Apo6CVoGglS9rUF7VBpVk+XMrI2rSCP5ViA0riDC5l5L7qNWnZgpDREPKLQhQA5NRaQazBpwyblmHXWliDSxl2J75g01pYg0sZ5paxoWm+ilqpDL1cblSRlgWXMiotQwIQ8HmM9PxXtlsHlzIqLcMjS/B5qlVkvRPnToZeLkMtzocjr0eCz1ut4heV7Z+33Mkwh6Imj2yVwfhYg0MZ1VDU5JEtYUrIWGfMMnyLwhTrYw2+ZXitLcNuMMgS3MnImfoMv1eGLEvwyhIA+2kSluBORl6rXv3+pvkxRiVUqcUSWH7Omz8ZplDkX5DQtBCqygQUiuy2Dq5lVMLTt9KJcyhj/sr3yhIkaUGGRzK2szwK50pGqUxGGPLIVQGWliFkrA+KZh3wVTDLmGV44MeVjC8F6xjDWPd8G/NTfMkwnegNXvuWIcLUOmEOQZWbSkClM59fZ3ngx5UM80Tgd03Vm0qSVL2vYe5XWIMrGeanBjeY7vAB1VBVLBGKJTZvv3Ijg4jwx9y8DK8sWX5NAdZ+I8voA23cyPii6tAWrni/t7ZaQVNL+ZQtrFu5nMCNjIzpBAcWhSgAaPZXX0WZEDLc5dfpvLH+faCpZnvQ70VlTP7btMLk7C0XMrKKhg+T8+9jyBLw/Xe1MjyyhI0LkgrFEn7+fW5dy7gcJDf/yOU/zz+ikjsBqHyofGF1G8FcCnO6+XPl2EqRK+m66UWYHwJN2L6lxbY8M3kNv0xWX6IJ+jzweGRImH/GCpLtYUvy7391ruzARbjyTh8RQVEUTE7PupF9XbyyBH8AmJjK1i2Xt/zVmNn9ukZdx5PRcfy485/GLPFKcUWGoihobm52I2tmyeVyCAaDq8rDlT4jn88vvRNnrEWdXWkZgUDAWJ+cnFz1FcMq+XweoVAIgLXOK8UVGebYGQwGuZVhZrX9BcDJT1teEDIYwtVxhsAZomUwhJDBEEIGQwgZDCFkMMSay8hkMjh8+DDC4TAkScKbN28s21OpFCRJQnNzs7GcPHlyrYvhCv39/di7dy/8fj9isZhl2759++D3+y31+vTpk6P811yGLMs4cOAAHjx4UHefjRs3IpfLGUt/f/9aF8MVwuEwzp8/j76+Ptvtly5dstQrHA47yn/NZYRCIZw4cQKRSGTVeY2MjKCzsxPd3d1G2uDgIK5evbrqvFdCd3c3YrEY2traXMm/IX1G5appb2/H8ePHMTExYbtfJBLB0NAQhoaGMDo6ihcvXuD58+c4ffr0+hZ4mVy8eBGbNm1CV1cXbt++7TwDcsChQ4cICzfb7JaPHz9a9gdAr1+/tqRlMhkaHR0lXdcpk8lQb28vdXV1UalUqvu9PT09tH//forH41QsFo30J0+e0IcPH5xUgYiIRkZGaPfu3fTw4UPHxxIRJRIJOnLkiCUtnU5TNpslTdPo8ePH1NraSoODg47ydSRjdnaWpqam6i6LT6idjMXMzc2RLMs0NjZWd59kMkkAKJ1OW9KP
HTtWV0Y8Hq+b3/379ykej6+pjMWcPXuWjh496ihfR2GqtbUVbW1tdRdZdh71lpp6LhQKuHnzJg4ePIgrV64Y6c+ePUMymcSdO3cwPDzs6Dt7enocl9MpKzkXrtzPUNXq/wFqmgZVVeHz+SDLMpLJJDo6OtDR0YGZmRmcOXMGO3bswLZt22ryISKcOnUKiUQCqqpiz549ePfuHXbu3IlIJIKtW7fiwoULblTBFl3XjaVcLkNVVciyDEVRkE6njZ+3qVQK169fx40bN5x9wYra6RLApj9JJpNERHT58mVqb2+nQCBAW7Zsod7eXhofH7fN59y5c/T06VPjcywWo87OTkqlUvTy5Uvq6+uz7K/rOkWjUYpGoxQKhYz1R48e1eS9kjCVSCRq6hWNRunz588UiUSopaWFWlpaaNeuXXTr1i1HeRM57DNYYmBggK5du0bv378nXddrtv9dn1HZvtI+wy2+2emQzZs3Y3h4GAMDA/B4ah/n/Dvu3r2Lt2/f4t69e3j16pVLJXSOuLnEEN9sy+ARIYMhhAyGEDIYQshgCCGDIYQMhhAyGELIYAghgyGEDIYQMhjiL8EkhZUpIFP1AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 75x50 with 1 Axes>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = jnp.linspace(-15, 15, 1000)\n",
    "# Illustrative transition density N(2, 1) over x_{t+1}; norm.pdf is elementwise\n",
    "pdf = jax.scipy.stats.norm.pdf(x, 2., 1.)\n",
    "\n",
    "color2 = \"#8ebad9ff\"\n",
    "\n",
    "with use_style(\"pyloric\"):\n",
    "    fig, ax = plt.subplots(figsize=(.75, .5))\n",
    "    ax.plot(x, pdf, color=color2, lw=2)\n",
    "    ax.fill_between(x, pdf, color=color2, alpha=0.5)\n",
    "    # Minimal look: no y axis, no left spine, ticks only at the range ends\n",
    "    ax.yaxis.set_visible(False)\n",
    "    ax.spines[\"left\"].set_visible(False)\n",
    "    ax.set_xlim(-15, 15)\n",
    "    ax.set_xticks([-15, 15])\n",
    "    ax.set_xlabel(r\"$x_{t+1}$\", labelpad=-8)\n",
    "    fig.savefig(\"transition.svg\", bbox_inches=\"tight\", pad_inches=0, transparent=True)\n",
    "\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMYAAABYCAYAAACu94huAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAATvElEQVR4nO2de4xUVZ7Hv+fcV916dHVXt3ZTtNigbhhsRphl2uBObJNNxGVRGzQmZmKGNbJkZnzN/DtMwEjiMMnMGMMfZMSsQ9SEMTMYIGazURoRmY2goML64CXQPJt+1+PWfZ3949x7u6r7Nt08uqq6OZ+k0rfuq35Q53t+v985v3OLMMYYBAJBCbTSBggE1YgQhkAQghCGQBCCEIZAEIIQhkAQghCGQBCCEIZAEIIQhkAQghCGQBCCEIZAEELVC2Pjxo1YtGgRNE1DR0dHybEHHngAmqYhHo8Hr3PnzlXGUMG0ouqFkU6nsWbNGqxatSr0+IYNG5DJZIJXOp0us4WC6YhcaQPGY8WKFQCAQ4cOoaur67rvl81mg23GGHK5HGKxGKLRKAgh131/wfSg6j3GeKxfvx6pVAoLFy7Eli1bxj2/OOxKJBJobGxEPB5HLpcrg7WCqULVe4wr8corr2DevHmIRqPYtWsXnnjiCSQSCSxfvrzSpgmmOFPaYyxevBjJZBKKomDJkiVYvXo1tm7desVrivORixcvlslSwVRjSnuMkVA6vs5jsVgZLBFMdareY9i2DcMwYNs2XNeFYRgwTRP9/f14//33kcvl4DgOPvzwQ2zatAmPPfZYpU0WTANItS9tXbduHV566aWSfe3t7Xj33XexbNkyfP311wCAlpYWvPjii3j66acnfO9sNot4PA6Ah1g3uzdxGYPrMv6XFb8HGLy/jIExgIFvA3wbI1sRAQgAQoj3F6CEgBACSgCJElBCIFECmZKqGxGsemFMJjebMBhjsF0Gp/jF+F/XZaPadjmhBFAkClmiUCUCRaIVFcu0yjEEHF8AtsNgu26w7Y7TBxZ7A8Db9ryEC74dnMdP4W4Bw96Bep6BEgLqeQOJArJEQQnAQAJP5AuTMcBlQMF2UbBd+DNNmkwRUSRocvlFIoQxxfFFYDkuLIfBdrgQxjqXeSESYwi8BRcQ3x7zOgBgoRET/DY7XuMlADSFN/aoKiGuydAUCa5b/G/gr2KhUALEVBm6KpVNIEIYUwzGGCyHwXRcWDZvRGHN2ReB4zc624XpNbiS42y45/bzCwaU5BkTDbb9fEGmPBSSKYUsAZTQIFQzLBeG5aI/ZwHgXqE2qiAVUxHT5MA222UoWA7ylgOXAUMFG3nLQW1UhUQnXxwix6jyHIN5DbdguzC9xh1yEgghsF0G0+tlTXv4PC4OP6Ti4rBdtySZHhbGsBj8v8z7jCByIgDxsmtaFD7JlECSeBg1krgmIxVTUacrcBmQNW3kCk4gakqAxpoI6uNqiVdgjCFvOcgYNhi4+Opj6qR7DuExqhDfKxRsBwXLhTOi76KEhy22w3vfnOkEYRBjDJbnIWzmBmGSLzA/3yhOwP0RJgYMj0Kx0v0AFwMhPHSSvBwiTARRVUIiIiOuSbBdhrzlIlOwkSnYOAMgXavjrlvjIIRgIG+hJ1OAYbk4P2DAdFyka/XgXoQQRFUZmkzRmzXhuAw50wm8y2QhPEYVeQzLcWFYDgwvfChGlfiUk2E5yBRsGBb3CH6DN73cwrCdoLe3HBaEW1wcfH9xUu44w2HUtTQETaaIaTz+l0BQGOHRGms0/GhWHfKWgwsDBnpzJgCgJiLjn29PQaIEjDH0ZEycHzAAAHfcEkM0pOHnCjaGCjZUmaIuql6DtRNHeIwK44cKedMpSZr9RFUiBDnTwWWvV/Wvsb3e37CcILyyHR5GWY4L0+b3clwG03ZgjpOYF3/uSCfgjzJJFJC86gL/njxBNoEsF8l9d9ajPqbhTG8Op3tzuDhYwO5vu/Hv
P5yBdK2OvqyJL8/2Y9CwcX4gj+Y6XtXckNCQNW0M5m1kC06oMHzPWY5ZaSGMCuG6DDnTRs5ySpJbf4jScXkvOpC3gmN+WJPxElH/PobtwLJ5IwXghVhcMJYzWgiUYJRHCj4DGJVs82FeBjgAwD8jrsl4qPVWaArF2b48vrkwhN6sic5vuvHEottw350N+GHBxq6vLyFTsHG6N4c7b42jLqaisSaCrr48cqZT8hkFT/iSNDo8sx0Xee98TZHG/o+9QQhhlBnGGLKmg1zBDkIXiRDoqgRdkWDYDs73G8gU7OAaTaawHBfdWTPo8R3XheMCAzkryAX85Ly4wfnXm/bw6JUviohMkYwqkCkNvAQlBFFVQrpWx8zaCODlMnnLxWDewqWhAk50Z5Ap2Pif/7uA/7z/DjTXRdE2O4V3D3ThbH8eZ/pymFmnB+EfMOyFBvNWEDIldcWzh6GrN4+C7UKiJNjvYzku+nMmGHhIqcmT7zOuWRh79+7FT37yEwD8S6m2Kf1qpGA5GDSsoGHKlCCm8cTSZcC5gTz6ssMeojbKG8jJy9kgXNJkCsdl6B60goauqxQ9GRNDBhcTJUBTTQQDhoVswQk8SSqm4vb6KGbW6sgUbBw83Y+Lg4VQW7+5MIQH727E3elkyX7bdRFRKA5838fF5uUs/3uiF2f78wCAmbU6erMm/nG8B5mCDV2R0Fyr42xfDt9eHILLgLqoglsTGgzLQVdvPvCAt9XpwXAsYzzR9jsJ2RNNOdraVQvjxIkTAIBdu3YFwjh69Ci2bNmC9evX31jrpgmMMWQKdtCTUwIkIkowo2tYDk5dzgWNv9ZrNCd7sujq440tqkq4rS6Kw2cH0OfNAaRrI4hrMj45dhku46L50e11qNUV/PeRC2DevvnNSdydTiIV4wnrpSEDO788P67dxWHYQN7CN+cH8WXXQNBQF7Wk8NXZAXx2qi8Q5Q+bk+jqy+P7y1kwALpCsWBWEl90DWDQ4HanYirmNSVwYcBAT4Z7AkqAWfVRJCK8MzBtB0OGHXhITaao0ZXQUbDJ4KqFcfr0aWzevBmdnZ344IMPMG/ePLS2tuLo0aOTYd+UhzGGgbwV9Nr+jK/f6xVsBye6s3BcBkUiuC0VRUyTcfTSUCCK2Q0xtKSi+Oi7y+jLWdBkisV31CMVU/H6nhNwGfBPjXH86w8aEVEkbN1/Gozx6x5qbUJkREyuyRJkSsZMxGOahHkzahBVKD769hJO9ebQkzGD4xGFolZX8fmpvuAeEYWiMRHxPBD3QrckNERVCce7eZGHRAhmpXREFAnHurNBLpOIyEjX6lAkgoLtIFcYHlAgAOIRGbpSvllv4DqGa/fu3Yu2tjYcOXIEhw8fRn19PZYuXXqj7ZtUyjFcO2RYgadI6kpJI2WM4UR3FjnTga5IaGmIQpYoDMvB3mOXAQCt6Ro0JXUMGRZ2fHEelAD/Nn8GkrqCs/15/HX/GSgSwX/8y+xgbH/nF+dw9FIGlADNdVHMSEaQjCqIqTIUr2ZpsGCjqzeHjGEHk32uy2C5PJco2KObhSpTWHbpTHtEodAkCZoyXM9UE+Gf4yfRlAANMQ0RlaJgDV+tKxIakxpimgwjZGROV3gnQssw0z2SCXuM1157Dc8//3zw3g+jMpkMnnrqqRtv2TTAn4wCRosCQJAoE/AwQvaSVT8siakSmpJ8sot41XoMPIGtici41euRc6aD//rkJO5qTOD2VBQLZ9WhYLs47Q2Znu69tvXsI0ev/Nl0xat+jShSUDKuSgSKN6Lm5wiaTBHXZEheubkvikRERn1MhSITFGwX3UOleY6uSIhpcllKP8Ziwh5j27Zt2LdvHzZs2BCslDty5AgeffRRHDt2bFKNnCwm22PkTBtDhg1FIkjFtNDjxy9lIVGCuTMSQfycM23sO94DALjjljhur4+CEoK9Ry8HjdxPpFWZ4h/He0Y1LoD35pSQYGLPr3/yyzvChmbDkLzaJ1UiUGUazGWoEoUi
832+QBSJQFckSJRCKVpnocoUtbqCmCbBZQgmIos/Q1ck6KpUtjziSkzYYyxfvhyzZs3CM888g9WrV+OPf/wj/va3v+HRRx+dTPumNt4XP1Zs7Pe4tstwpieHmSkdMqWIqjJa6mP4vieL490ZdPXlkK7V0ZquQUyT8K03Z9Cb5XG/RIDmOh0Fm4/1500HDmPBhOBEKC7z4KXixCsCJN4CIwTFgYrnHai3CCmicLGoEi1ZdOR7DF2lYODJfLZoKJkQICJLiCgSFKm6FitdVfIdiUTQ1dWF++67D0uXLsWhQ4fQ2to6WbZNeRSZAgUeghiWMyqUooSguU7HqZ4cBg0bmfNDvNAupmJOQxRRVcKxSxkUbBcnL2dxEty7tDREUbAZsgXukRyXNzpK+PBvVJWKigHZqMk8f3a7eP2E3yj99xIlkCW/MJBC8s6X/f2EBiLxr5UonwOJKNxbuOChmFk0ukUJT/41hQupmsRQzISF8dhjj2HHjh1Yvnw5du3ahc2bN0PX9fEvvIlRJApdkZC3HAzkLZi2Oyp2TugK5twSw9n+PAzLxeWMicsZE4pEENdk/GBGAlnTQX/ORF/OguMyZAq811VlilRM8aplS5el+lWzI/EFQeCJwCsELN6Gd0ymNPAgMiVQvPc+qswn2/xQijEWzOQV5+6qxM8pDrmqnQnnGHPmzMH27dsDD2EYBl544QU89dRTQSI+1SjHqNTIOQxguOxD9cIR/7xBw0Zf1gxKrEdCwBuev7AnKAT03o9XCMjFQIrWX/OknlIvjPI8hURLq2Yp4SJQvFBJoiQQWNhn+J7Ev2YqCGEkExbG/v378eMf/3jU/t/+9rd4+eWXb7hh5aCc1bWm7SJTsEbVLvkjPLLEww+JEriMV5JmvXzBsB3YITVPYZQ8oMDDb5ZXaqBycehU5EUIGVsE/nXcfhIIZyoKYSSi7LzMZed+aXnYOgufkU/R4A2NwXF5MR0vE0dp6IQxQid40Q0hvCrVf3pHsH3lhu/jexK5KCmfLiIIQxQRlhkeZlAkIrzuiK+X8MIix19WCriOP6AaDiG8ClXCjWmY/qiUL4CS7WksgLEQwqggMqWQVQp/CMMfQXI8T+CEJNNuUajE/CcUBJDhBxNgZKI9nFv4IdLwKrzxPcbNhhBGFcEbKyo64yvgVP0jOgWCSiCEIRCEIIQhEIQghCEQhFDVwjh//jweeeQRpNNpEEJw6NChkuO7d+8GIaTk58OeffbZyhgrmFZU9agUpRQPPfQQ1qxZg3vvvTf0nGQyif7+/vIaJpj2VLUwGhsb8Ytf/OKG3rP4V1szmUzofsHU53p/hbeqhTER/N/2ppSivb0dv//97zFz5swxz/dLQEbS2Ng4WSYKKsD1lvhULMdYtmyZV6AW/vr+++/HvcfcuXNx6NAhnDlzBgcOHABjDA8//DBcd+ILdATTk+uNACpWRDg4OAjTNMc8nkqlSn5skhCCgwcPYsGCBWNek8lkkEwmceTIEcydOzf0nOL/sKGhIcyYMQMAcOHChTG9iWBqkM1mA88/NDR0Xd9nxUKpmpqaG37PicSUY7nXeDxe8Yc6C24c11v7VdXDtQBfEGUY/JGOpmnCMIwgVOrs7MTJkyf507J7evDzn/8cd999N+66665KmiyYBlS9MHRdD5bQ3nvvvdB1HXv27AEAHDx4EPfffz/i8ThaW1th2zZ27twJSZr8h/4Kpjc39UIlgWAsqt5jCASVQAhDIAhBCEMgCEEIQyAIQQhDIAjhphWGKGmf2mzcuBGLFi2Cpmno6OgoOfbAAw9A07SS7+7cuXNXdf+bVhh+Sft777035jnJZBKZTCZ4bdy4sXwGCq5IOp3GmjVrsGrVqtDjGzZsKPnu0un0Vd3/phWGX9Le1tZWaVOuiU8//RSzZ8/GihUrgn1///vf8eqrr1bOqDKyYsUKdHR0oKGhYVLuf9MKYyL4PU1zczN++tOf4uzZs5U2KaCtrQ3bt2/H9u3b8dVX
X+HAgQP45JNP8OKLL1batKpg/fr1SKVSWLhwIbZs2XLV109LYdwsJe3z58/H8uXL8etf/xobN27Ehg0bgmOdnZ3X9IM++/fvxz333IOdO3feSFPLyiuvvILjx4/j4sWL+N3vfofnnnsO27Ztu6p7TEthvPPOO+ju7h7zNWvWrHHv0dTUhNbWVkiShKamJvz5z3/GF198ge+++64M/4KJ88tf/hIffPABVq9eDVkeLpbevHnzmNesXLlyzGOnTp3CwoULb6SJZWfx4sVIJpNQFAVLlizB6tWrsXXr1qu6x7QURk1NDRoaGsZ8Fa/zmCjV+AjLfD6PzZs3Y+nSpfjTn/4U7P/444/R2dmJt956Kyi4nCiPP/74jTaz4lzL9z3ll7ZeD345OzBc0q6qKiil6OzsREtLC1paWtDb24tf/epXVVXSzhjDCy+8gLVr18IwDCxYsACHDx9Ga2sr2traMGfOHKxbt67SZk4atm0HL9d1YRgGKKXI5XLYt29fMGS7e/dubNq0Ca+//vrVfQC7iQEw6tXZ2ckYY+wPf/gDa25uZtFolDU1NbEnn3ySnTp1qrIGF/Gb3/yGffTRR8H7jo4ONnv2bLZ792722WefsVWrVpWcb9s2a29vZ+3t7ayxsTHY3rlz56h7/+xnP2M7duyY9H/D9bB27dpR3117ezu7dOkSa2trY4lEgiUSCTZ//nz2xhtvXPX9Rdn5NGTLli0YGBjAgw8+iDvvvHPU+pSVK1fizTffHPP6lStX4vHHH8eyZcsm2dLqZVrmGDc79fX12LNnD/7yl79c9aKtt99+G19++SW2bt2Kzz//fJIsrH6ExxAIQhAeQyAIQQhDIAhBCEMgCEEIQyAIQQhDIAhBCEMgCEEIQyAIQQhDIAhBCEMgCEEIQyAIQQhDIAhBCEMgCOH/AT2wXj99GwWvAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 150x50 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.colors as mcolors\n",
    "\n",
    "# Illustrative 2D Gaussian evaluated on a square grid over [-15, 15]^2\n",
    "n_grid = 500\n",
    "x = jnp.linspace(-15, 15, n_grid)\n",
    "y = jnp.linspace(-15, 15, n_grid)\n",
    "xx, yy = jnp.meshgrid(x, y)\n",
    "pos = jnp.stack([xx.ravel(), yy.ravel()], axis=-1)  # (n_grid**2, 2) grid points\n",
    "L = jnp.array([[1., 1.], [0., 1.]])\n",
    "cov = jnp.dot(L, L.T) + jnp.eye(2)\n",
    "# multivariate_normal.pdf broadcasts over leading batch dimensions, so no vmap is needed\n",
    "pdf = jax.scipy.stats.multivariate_normal.pdf(pos, 2 * jnp.ones(2), 10 * cov)\n",
    "\n",
    "# Colormap fading from white to the proposal colour\n",
    "color2 = \"#8ebad9ff\"\n",
    "colors = [\"white\", color2]\n",
    "cmap = mcolors.LinearSegmentedColormap.from_list(\"custom_cmap\", colors)\n",
    "\n",
    "with use_style(\"pyloric\"):\n",
    "    fig, ax = plt.subplots(figsize=(1.5, .5))\n",
    "    # X/Y are passed explicitly, so contour's `origin`/`extent` options would be ignored\n",
    "    ax.contour(xx, yy, pdf.reshape(n_grid, n_grid), cmap=cmap, vmax=0.005)\n",
    "    ax.set_xticks([-15, 15])\n",
    "    ax.set_yticks([-15, 15])\n",
    "    ax.set_xlabel(r\"$x_{t+1}$\", labelpad=-8)\n",
    "    ax.set_ylabel(r\"$x_{t}$\")\n",
    "    fig.savefig(\"transition2d.svg\", bbox_inches=\"tight\", pad_inches=0, transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.001, dtype=float32)"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# NOTE(review): `pdf` is whichever density cell ran last; the saved 0.001 matches the 1D N(0, 5^2) proposal, not the 2D grid\n",
    "jnp.min(pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the SDE and the weight function (passed to build_loss_fn when the loss is constructed)\n",
    "sde, weight_fn = init_sde(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(0.039, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: the SDE's std at its minimal time T_min\n",
    "sde.std(sde.T_min)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reserve a subkey for network initialisation\n",
    "key, key_init = jax.random.split(key, num=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preconditioned score MLP, minibatch sampler and \"dsm\" loss (presumably denoising score matching)\n",
    "c_in, c_noise, c_out = precondition_functions_v2(sde)\n",
    "init_fn, score_net = build_score_mlp(\n",
    "    2, num_hidden=5, c_in=c_in, c_noise=c_noise, c_out=c_out\n",
    ")\n",
    "batch_sampler = build_batch_sampler(data)\n",
    "loss_fn = build_loss_fn(\n",
    "    \"dsm\",\n",
    "    score_net,\n",
    "    sde,\n",
    "    weight_fn,\n",
    "    control_variate=False,\n",
    "    control_variate_cutoff=1.,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10, 1) (10, 2, 1)\n"
     ]
    }
   ],
   "source": [
    "# Draw the probe batch with its own subkey: key_init is reserved for init_fn\n",
    "key, key_batch = jax.random.split(key)\n",
    "theta_batch, x_batch = batch_sampler(key_batch, 10)\n",
    "d = theta_batch.shape[1]\n",
    "print(theta_batch.shape, x_batch.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One dummy time value per batch element (presumably diffusion times); only used to trace shapes\n",
    "params = init_fn(key_init, jnp.ones((theta_batch.shape[0],)), theta_batch, x_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10, 1, 1)"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Output-shape check on the probe batch; the ones vector stands in for one time value per sample\n",
    "score_net(params, jnp.ones((theta_batch.shape[0],)), theta_batch, x_batch).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(29.565, dtype=float32)"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Initial loss with a fresh subkey, so `key` is not both consumed here and split later\n",
    "key, key_loss = jax.random.split(key)\n",
    "loss_fn(params, key_loss, theta_batch, x_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-cycle cosine LR over 100k steps (10 x 10k steps in the training loop)\n",
    "total_steps = 100_000\n",
    "peak_lr = 1e-4\n",
    "schedule = optax.cosine_onecycle_schedule(transition_steps=total_steps, peak_value=peak_lr)\n",
    "optimizer = optax.chain(\n",
    "    optax.adaptive_grad_clip(100),\n",
    "    optax.adamw(schedule),\n",
    ")\n",
    "opt_state = optimizer.init(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def update(params, rng, opt_state, theta_batch, x_batch):\n",
    "    \"\"\"Run one optimisation step and return (loss, new_params, new_opt_state).\"\"\"\n",
    "    grad_fn = jax.value_and_grad(loss_fn)\n",
    "    loss_value, grads = grad_fn(params, rng, theta_batch, x_batch)\n",
    "    updates, new_opt_state = optimizer.update(grads, opt_state, params=params)\n",
    "    new_params = optax.apply_updates(params, updates)\n",
    "    return loss_value, new_params, new_opt_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6.549075\n",
      "6.3933725\n",
      "6.3584175\n",
      "6.3702493\n",
      "6.389399\n",
      "6.361944\n",
      "6.3505454\n",
      "6.3405023\n",
      "6.3318152\n",
      "6.351687\n"
     ]
    }
   ],
   "source": [
    "# 10 x 10k = 100k steps, matching the schedule length in the optimizer cell\n",
    "num_epochs = 10\n",
    "steps_per_epoch = 10_000\n",
    "batch_size = 1000\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    epoch_loss = 0.\n",
    "    for _ in range(steps_per_epoch):\n",
    "        # Separate subkeys for batch sampling and for the loss's random draws, so no key is reused\n",
    "        key, key_batch, key_loss = jax.random.split(key, 3)\n",
    "        theta_batch, x_batch = batch_sampler(key_batch, batch_size)\n",
    "        loss, params, opt_state = update(params, key_loss, opt_state, theta_batch, x_batch)\n",
    "        epoch_loss += loss / steps_per_epoch\n",
    "    print(f\"epoch {epoch}: mean loss {float(epoch_loss):.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [],
   "source": [
    "from markovsbi.sampling.score_fn import (\n",
    "    CorrectedScoreFn,\n",
    "    FNPEScoreFn,\n",
    "    GaussCorrectedScoreFn,\n",
    "    ScoreFn,\n",
    "    UncorrectedScoreFn,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score function used for sampling. Alternatives taking the same positional arguments\n",
    "# (score_net, params, sde, prior): FNPEScoreFn, UncorrectedScoreFn, CorrectedScoreFn.\n",
    "# NOTE(review): the posterior precision estimate is a fixed constant 2.0 - confirm it suits this task\n",
    "score_fn = GaussCorrectedScoreFn(\n",
    "    score_net, params, sde, prior, posterior_precission_est_fn=lambda x: 2.0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.343]\n"
     ]
    }
   ],
   "source": [
    "key = jax.random.PRNGKey(5)\n",
    "# Independent subkeys for the parameter draw and the simulation (reusing one key correlates them)\n",
    "key_theta, key_sim = jax.random.split(key)\n",
    "theta_o = prior.sample(key_theta)\n",
    "num_obs = 101\n",
    "x_o = simulator(key_sim, theta_o, num_obs)\n",
    "print(theta_o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use_style is already imported in the imports cell at the top; no re-import needed here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAScAAABtCAYAAADj7qf/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAie0lEQVR4nO2deXRb5bnun615lmzJlm3ZceLEdmLsxKYhTig5JCUNtIS0BHrp4tZtOXBLW5qbnvZAerhkFQqlJbSrpdDTW1YogUN6kpbS00DTXk4gcSYyNJMzeIznQaOtedbe9w9J29oeJcuD7Hy/tbyWvSVtv5a1n/1+7/RRDMMwIBAIhAyDN9cGEAgEwlgQcSIQCBkJEScCgZCREHEiEAgZCREnAoGQkRBxIhAIGQkRJwKBkJEQcSIQCBkJEScCgZCREHEiEAgZyayL02uvvYbVq1dDLBbji1/8Iucxp9OJhx9+GCqVCnq9Hs8///xsm0cgEDIEwWz/woKCAjzzzDM4fPgwent7OY9t374dg4OD6O7uhtlsxqZNm1BcXIyvfvWrs20mgUCYY2ZdnLZt2wYAuHTpEkecvF4v9u/fj5MnT0Kj0UCj0WD79u144403JhQnj8fDfs8wDLxeL+RyOWQyGSiKmrk/hEAgzCgZE3Nqbm5GMBhEdXU1e6y6uhoNDQ0Tvk6hULBfSqUSer0eCoUCXq93hi0mEAgzScaIk9vthlwuh0Aw7MxpNBq4XK45tIowGT94+hl86jP3oaBkOSiKws6dO+faJMICYdaXdeMR93bC4TArUA6HA0qlcsLXud1u9nuPxwO9Xj+jdhKGOf7JWQzwc7H18R9AocnG8b+8g927X8S2bdtQW1s71+YR5jkZ4zmVl5dDKBTi8uXL7LFLly6hqqpqwtfJ5XLOF2F2cPtDaLAy0OYVQqHRgsfj4/Z7vwwAaGlpmWPrCAuBWRencDgMv9+PcDgMmqbh9/sRDAYhk8nw0EMPYdeuXXA4HGhtbcWrr76Kxx57bLZNJEwCwzA42mKBUBb1at12G2iaRndzND5YVlY2l+YRFgizLk4vvPACpFIpfvzjH+P999+HVCrF5s2bAURroNRqNQoLC/HpT38ajz76KCkjyEBMzgAc3hDkcjlyNAoMmfrQ13YdXqcdO3fuJEs6wrRALaQZ4h6PBwqFAsBwgJ0w/Zy6YUWHJVrCcUepDieutGPQG4JYLMH/vKMMcnHGhDIJ85iMiTkR5gehCI2ewWiZhkjAQ2GWDFXLipCdrYVcLofR6Z9jCwkLBSJOhJToHvQiHIk628VaOfg8CnlqCfu4yUHEiTA9EHEipES7Zbgiv0QXXTZr5WII+NFqfKPTjwUUKSDMIUScCEnj9IVgji3bVFIhtAoRAIDPo5CrjHpPvmAEDl9ozmwkLByIOBGSgmEYnGm3sT+X5Mg5vYt61fDS7nCjGTfMboQi9KzaSFhYkLQKISkaB1wwuwIAAIVYgDI9t3J/iU6OxgEn/KEIAqEITrfbcLrdFn1unhIr8lVzYTZhHkM8J8Kk2L1BNPTaAQAUBaxbqoWQz/3oSEV8fK4qD8VabvmGOxDGha4huAPh2TKXsEAg4kSYlKt9TkToaJB7Rb4KuQlLuERkIgHuKNXhrhV6LMtVQCUVso8ZSRaPkCJEnAgTEqEZ9Nt9AACxgIeVhZpJX5OnlqC2RIt1JVr2mNHhmykTCQsUIk6ECTE5/WxguyBLCj4v+QF+2QoRRILoR4yUGBBShYgTgUOYpnGsxYIPrxnh9ofQOzQ8tK8oS5bSuXgUxWbxAiEag57gtNo6EqcvhLMdNuKlLRCIOBE4tJrc6Bn0wuIK4NQNG3qHohc6n0chXz12rGkiEqvHZ7K1JRSh8XGTGa0mN060WUETL23eQ8SJwBKhGTQNONmfLa4AfMEIACBfLYGAn/rHJVHQkg2KMwwDbzDMBuGT4UqvA55YRjAQouHyk+zgfCflOqfT
p09j3759OHfuHIxGI2iahl6vR01NDR5++GFs2LBhBswkzAZdNg+8MTEaSWGKS7o4CrEACrEA7kAYFlcA4Qg9qcg1GV240DWEXJUEm1bkTrpRxaAniCajk3PM7g1CnZAtJMw/khancDiMb3/726ivr8eGDRvw4IMPQqlUgsfjwel0oqOjA9u3b0dlZSXeeOMNyGRT+zAT5gaGYXA9wWvKUYphiRVdUhRgyJJO6bwUFW0MbjO7EaEZGJ3+CYWOZhhc74/aYXb60e/ww6AZ/btpmsHVfgcsrgCGPEGMXMXZvSEUa0e9jDCPSFqcfvSjH2HdunV4/fXXJ3zee++9h3/7t3/DK6+8krZxhNTxBMJoNrqglAiwLFeR9PZY3YNeOLzRnrgcpRh3lufgrw0D8AUjKNBIIRHyp2xTgUaKNnN01vulHjsK1FLwxsn6mZ0B+EPD3lubyTWmOLWa3bjS6+Ack4n4rOdn985s8H2midAMum0eaGQiZMlFUzqHNxgGw2DeztdK2urvfOc7yM3NnfR527Ztw/r169MyijA1+u0+nGqzIhCOpv4ZYFSbyUgYhkGT0YWL3UPssYp8FcQCPu6pzEO/3Zdylm4khVlSaBUi2NxBOLwhNBldqCgYu50lPisqTp/dB08gPOoCu2EZ3thCLOBBLRNhdXEW/vu6CaEIjSHv/G4+vtrnwNU+B4R8Hr5QUwCxILWbg90bxN+vGkEzDO6+JQ9ahXiGLJ05ko5wxoWJpmnORpZjkZOTk55VhJRpNblwtNnMChMAXOgawtAE6fsIzeCTdhsudA2xy6KibBm7hJOJBFiWq4Q4Da8JiC7tblucjbgT19Brh3uMgDXDMOgZ4oqT2+3Bvvc/Qv0n59Bn94FhGAx5guzfpVOI8eDqIny2Qo8suQgaWTTO5AmEEQzP38bjLlv0GgtFaJidgZRf32JyIUIzYBiwXut8I2lxam5uxqZNmyCRSKBSqaBUKrFp0ya88sorGBwcnEkbCZMQitC40G1nBSbuZURoBifarGNOBwiGaRxtNrPjdgGg0qDG+lLdjOyUrFWIWS8uQjNo6LOPek5idlCnFKO3txdNTU1oHHDhLxd6sO+jCzjRZuV4TSU53F4+jWx4CTRfR7e4/CFOttHiTk2cwjSNLtuwyPcO+eZlAWzS4lRXV4e6ujqcPXsWhw4dQn5+PgoLC/Hyyy+jrKwMb7755kzaSZiAnkEvwjEBKtbKcd+qAmTH4hROXwiXeuyc50doBh81mtjUPp9H4Y5SHVYVaWZ0C/eVhRq2YdjoGF0x3p2wpGOG+nDyw4NQ6/QoLKuEWqeH0WhCY7cFLSYXa/fIRuO45wQAQ/M07tRv55ZcxBMTydI35ON4jf5QBFb3/HsvkhYnHo+Hr33ta6iursbdd98Ng8GAvXv3ore3F/v378drr72G3/3udzNpK2EcEqdTlucpWbGJT6dsNblgS7j79tl9bLW2WMjHXSv0oy7ymUAk4CFHGY19+IIRzqQChmHYeBOPR8Hc2YTD//l/odBkg8fjQa7OAgAEAn7WQyzMkrHtMXGyEjwn+zyNOw2MqHAf9AQRppNfoiZ+HuL0jlguzweSFqfCwkIcPXqU/TnxDrtp0ybU19djz54902ocYXLc/jBMCdMpdbHplEqJEFUGTfQ5bg8OfHwep0+fAcC9E9cuyWYFYzZI/F2JdhidfjbTlq+WoLx0GXxuJ06+/58I+n2o/9NefHLoAMTi4aLOkUs6AJzapvmYsYvQzKhiVZpmMJik5+MLRlhxk4r4bJwvXuk/n0hanF599VXs2LEDTzzxBM6fPz/qcZFINGmgnJAeNM3A6g5w7qLt1oT4i447nXJ5nhJdrY1oampCW48R39z5HHbu3MkRhdxZFCZgfHG6YU6cTa5AbW0tli5dij//+nn8YGs1PtjzMu6sWoLlRdFkS5ZMNGY7jUjAY2Nudm9oSrEWmmZwss2KDxr6Z13gzC4/Wxmf6BUmu7TrtHlYz3Jp
jgI5sfHJTl9o3sXgki4lyM/Px9mzZ7F792587nOfg91ux9q1a6HX68EwDM6fP49HH310Jm2dcSI0g9ZYPKM8Tzmj8ZepcPKGFd02Lwo0UmxcnguGYdiANkUBS0Z4EufOncWvn/0uvvHiG1BotNjy6Pex+xtbsGzz1yCTy6GWCtPOxKWKVi4Cj6JAMwx7wflDEXbZIRby2WxhRUUF9u3bh5aWFpSVlaG2thZM7HVqmXDc/0+WTAhPIIxQhIbJGeD09yXDDYsbndbo+3q5x447yycvoZkuEuNNK/JVuByLFyYbFE9McCzRySEW8Ni57z2DXqgN6rTs84ci4FHUqOX0TJC0OHm9XshkMuzatQvPPPMMTp8+jdOnT8NsNkOpVOLpp5/G2rVrOc+dT3iDYRxvscLqHh5FW5idOX+D1R1AdywD02/3weoOIBim2bhNvloKmYj772xpaUF/ezPkqizweDxk6Q24deMW+AN+yOTyWV3OxRHweciWi2B1B+DwheAPRdBp9bDeQolOzhnLUltby9lBmKKocYfdxdHIROwy5qNGExQSAWiaQYRhsCJfhVsKuBeowxfCDYsbi7JlyJKJcK1/uFK+3+FHMEzP+MXoCYTRYnLhRiztT1FAqV6BZqML/lAEFlcADMNMeMMc8gTZJIBOKYZKKgSPAs53RWvYrvY5oFOIkKeeWrV/v92H460W8CgK91TmQSlJrj2IYRi0mNwQCXhYrJUlfdNPWpyee+45rF+/Hlu2bAFFUVi3bh3WrVs36nn19fXYv38/fvOb3yR76jklTNPotnlxqcfOprEBwOoJZpQ4XevjVkO3mFzwBIbtXZarGPWasrIyAMCVU4ex8o7NcNttuG3zNjZuo5sDcQKiS7v4TcDqCnBKA5aO8XekSrFWhlazC4FQdPmbWFPV0ONAuV7J6e87026DxRVAi9GFRdkytoEYiC7xeoa8WJqTvl1jwTAMbljc+EfnEKfROVcpgVjAR45SjJ5BL4JhGg5fiFMqMZIOK9drAgCFRIglOjk6YjeAo80WVBWqIRbw4A/RcPpCYBD9/CRuUsEwDC73OmB0+HBLgRpahQinbthiexYyaDO7UbMoC4FwBB0WD/I10nF7GeO9kgBw+qoD4Z7LqKurm/S9SVqcnn32WdTV1eHFF1/Exo0bsXTpUqhUKlAUBZfLhc7OTtTX10MkEuHAgQPJnnZOaTW5cLnHzilcjOPMoPW53RscFdDstA7HFlRSIQrH6H2rra3FU089hV+8vBNqbS4EQhGKFhWz27TnzFHVcI5SjMaB6PcNfQ42q5ajFE9Ls65GJsIXqg3otHpww+Jmzx+hGdAMA5snyF6IoQjNCmWEZjgXeJxu28yIUzhC43T7IFtwCUTLIxZpZVgVmziqU4jZLKbJGRhXnGhm2HYej0KxdvjGWluiRSjCoHfIiwjN4FK3fdTrO60elOcpUV2kgYDPQ8+Qj70hHmuxQC4WIJDQVtRl86K6SINjLVaYnX5IB5zYWl0AAY/rYTIMg2ZjNFTS19cLo9GEX3znq9MrTlKpFO+++y4OHjyI/fv34w9/+ANMJhM7laC6uhrf+ta38KUvfSnjYjVjEQhH8I/OIc7cnzy1BBZXABGayShxSlxmyMUCeAJhTqNrRYFq3Pf8pZdewrZt23Cx3weeKpcVJomQD6VkbnquEpeTiRXsY3l/U0XI56FUr0RprPDzhtmN07GtrazuACtONvfopmEg2g/o8IXgCYRhdPjhD0XS6i8ci390DnGEqVSvxMpCNef3JMbLum1RARmLuI0AYNBIOe0u8dKSE62WCbN2zUYXjE4/FJ4+XLUxEEkV7OfFM2KDCk8gjOsDTjae5QtG0GX1jvJ84+1HHo8HRqMJap0ePz14aaK3hSXlT+fWrVuxdevWVF+WcZgcAVaYclViVBdlQacQ4dAVI+zeIFyBMGiaGbdBdTYIR2hc6rGzwVmxkI8N5Tk4dGWAvaBkIj4WT1KjVFtbizyb
BydareyxHKV4zm4iEiEfKqmQcwNYnq9ilyIzQXwDUACwuoYF0ZoQaI7bxONRWFWkQZfNg+v9TtCxGqzSSfoUU8GSsJwV8nlYu1SLRWOEEbJkQqhlQji8IZhdAbj8oTFjPWPtxJwIn0fhn8pyYI5V4UdoBgI+BZVECJPTj0s9dkRoBtdbOzDQPwAq5gHlqOVYXLoCNMOAoqI3kFZT1O7LI4p7W82uUeLUEvOaAgE//t9/vIpHfvhr8AXJyU5aUb4PP/xw3rauGJ3Dd5BbCtTsxaqSRt84mmbmdDsjTyCMv101si4xANxSoIJGJkJBQpf+inxVUnO9DRopJ84yF8HwROLLDrGQhzvLc/Cp4qwZFUu1VMgGta3uAFtikChOd5blYMPyXGyu0CNbLuKIRZPRNWGfYirQDINzHcPXTfUizZjCBEQTAImiPdayMxyhE7KdPOSPMcUhfi69SoLFOjmW5ipQrJUjSy7C8nwVPleVj7DPBaPRBE1uPgzLKqDM0uEnO+qgj5iwLFeBO8tzUbMoi/28jfQ4be4gp9jX6QthIFazpZSK0Hi2HvXv7UXQn1zNVVri9Prrr6OpqSmdU8wZAwmtG4m1PqqEu5LTP7NLu2CYRn2zBWc7bKOmPp7vGmI9Cz6PwqeKs7A85tLXLMqCSipEgUaa9FJIwOdx4lKzXd80kiqDGvdU5uEL1YYpD7JLBYqioIvF2PyhSGxpzLBtHWIhD0qJAAaNlO3gz5aL2BiY0xfC364O4GL3UNp9aq0mF5tVy5aLUDrJ/3CJTs4WU3ZYPaN+vykWigCic95T2YQijloqBPoa0N/eDIVGCx6PB4U6G65BC/puNKG2RAuDRgohnzdqtlfixIMWk5st9ziTIMBryhfhySefxAd7XsYPtlYnZVNaQYeSkhLcfvvt6ZxixqFpBr12HzosbrbmpbpIw2ZwdAoxx6NIDMg6fWEga+Zsaxxwsnc8lUSI5bFdca3uABsElYr42LRCz9kDTi0V4r5VBSn/vlWFGvhDEWTLRWzv3VxBUdSsj/HQKkTsNlcWdxAMwAZ5dYrRy1yKisZqjrdao1ktBrje74QvGMGaJdloHHChZ8gLg0aKigIVZ6PRM2fOcOqzErmeEEO8bUn2pB6jTCSAXiWB0eGH2x+G1R3keL4D9mFPZDyvKRmWly7Fo//8dai0OShdtRbH/uttAMNZ3zjFWjlb1iIW8LChPAfvX+5HMEyj3eJGp80DOuFmy+dRWJorZ+OfLS0tSdmTlji9+eabaGpqwsqVK1FVVYWVK1eivLwchw8fxubNm9M59bTAMAzqWyzsBxKIFqIlDjMbWaCn4ojTzHpOiW0KjQNOlOqV4FHgZFOqDGqOTemgkAhw1wr9tJxrPpKYnbS6A5x1iW4codTIRPh8VT4aB5y40utgs2K9Qz522sOQJ4gbFjdWF2djkVaGnTt34t9ffwNLq25Dy8VT2P7tb+Kll14CEBXDeJtOrkoy7u8dyRKdnP28tFvcHHGKf755FIW8SWrAJqK2thbf/5d/we7/87/YYzt37hwlrgaNFFlyEYY8Qaws0kAi5KMkR8HOn08UJomQj9qSbDZAP7JubSLSEqe6ujo8/vjjuHLlChoaGrB//340NzfDZrPBZDKlc+ppoXfIxxGmOImtAKPEKSGDNZPLumCYhs0zbIdlyIk3372MYoMeZn60RUMpEcxYfc3NiFYhBkVFNcnqCqC7uxuDbhpisQQ6xfhV4HwehUqDGhqpEMfbrKBpZtQYGl8wghNtFhQ0m/Hyyy/jqT2HkGNYDNtAN372zS9g27ZtqK2t5YxCSaVsoihbhnOdQwhHaHRYPVhZqIFUxOeMV9EpxWkXiyZ6N2N5fUD0/dhcoUcgTLOtQpUFKrj8IbgDYfBjFeRLdHIUa+VTWmYCaYpTIBBAeXk5ysvL8eCDD7LHf/nLX6Zz2mmBYRhOCn7NkmyYXQE28wVEe5e0I5Y3Aj6PTdc7fKEJq3I9gTBazW7k
qcQpV92aXcPd9fH6D7/XjYZ+FwqLAjAYCrGqSDOn2cKFhkjAg0oazXw1NLej+eolaHLywdA0bhx+By/95MUJX1+YLcP6Uh1OtFoRoRnolGJUGdRoMbnQN+QDw0TT8UsqP4Ucw2LweDxo8xfhM//jMbS0tKC2tpZzw1OlUMoh5PNQmqtA40B0a/hGoxO3LsritLsUaKbuNSWSjHcj4PM44ZBoJnl623zSktnVq1fj/fffH3X8u9/9bjqnnRZMDj9ssWBnllyEZbkK1CzScOICepVkTOGJZ+yCYXrMAs045zoHca3PgY8azbjYPQSzy4/jrRb87eoAzK6Jt0EyxaYbejwedLS1Qq3TY+nKNdDmF8FoNAF+57gZHMLUyVGI4fF4YDKZULyiGoZlFRDL5Nj905/gzJkzk76+MEuGLSsLsKlCj80VehRopLhjmY71WPiaPKy5+wG47TbQNA233YaK2o3IKakAAI7nlGz7R5zEzOyFtn689c7vcbG5g328YIptKZlKWuL0yCOP4L777psuW6aVawk7idwSK1KUiQSoTGh8HG9HkcSM3Xid3CNHW1zvd+K/r5nQbfNi0B3EsWbLhKUI8TEngYAff3/7V8MZEo0W5z86iEhPw7woZp1vFGvlCAT8+Gj/b9n3XJu/CACSDtQqJALOjU3A56EktvyWyuSoWbseDqsJHdfOw2E1IS8vDy5xLoJhmuM5pVoEKxXxsTRXgb6+XlxvbMYfPv4HTl28jr6+XkhFfM6gvYXAgt1UMx5XUkmFHA9kRb4SnyrOQs2irDGL1eKviTNeUNzqDky46WMgTONEq4XzHIZhwDAMAqEIWzOjkQnRef0irpw8jHAohNN/+yOO/fktlJctS/6PJSRNnlqCVVlhtF87jxsNZxEOBXHy/d8DGJ2VSoUyvYJN9xsMhVi+fDlWFWVhbeVSGAwGBMM0zE4/6znxKAqKKeyK4h9oQ39fH9Q6PT7z0DdYT5t22xbczWza+hfKyspw4sQJBINBbNq0KWPqnyryua0dFEWxKfvx4JQTjLNzrClha+3yPCXMrgAoRCtor/c74Q6EYXMHcaFrCLctyUY4QuNIsxk2d5CTxr+lpBBPPfUUdj+/nT02VobkZufs2bN46KGHUFNTg/feew9AdBuy7u7ulMMI/7RuDfTZGvzmqa+xx9J9z5WSaN1ZX6w9RK1S4v4774HZ6Ud9iwVAdOyJK+Y5KSSCKcUTO9ua0XCiHvc++q+sp+2wmuDqawOwasr2ZyLTJk7Xrl2DUBi9qK9cuTJdp02LeBNlqiQu6wY9Y8/RMTm4c3dWLx5+K7UKMT68ZkSEZtBiciFHKYbJ6Wd30eBkC1WSpDIkNztr1qzBwYMHUVNTgytXriAQCODkyZP4+c9/PqXzjTUrKl3K9UpWnEp0cogEPM7kh26bN9bVn/qSLk5ZWRnq3/tnLCpfiYq1G3Htk4/xwRs/w6H/ejdt+zONaROnuDCN/H4uKYhVtKaKRMhj+6wsrgACoQhnKFs4QsMaW5YpJYJRe6ply0VYvTgbZ2KNpqfbR1eAA1HXPl6vkkr9x81KVVUV7r//fnzve9+DwWDgjIU+cuQIioqKsGxZ8svh6X7P8zVS3LYkGy5/GKsKo7HNeIO1yx/mxCBVKQbD49TW1uJfv/897H5hB3tsoXra83Mr0CQpnoLXBESXfgaNlK0K7rf7sCSh3sjiDrCFZuNNWVyaI4fZ5UeHxcMRpluLs+AJhNlG0qmI583ME088gY0bN+LUqVMQJDSQ7tmzB88999wcWhZlrE1MdQoxJ0sHTN1zApKrRVoIzMiVcenSpZk4bUr4ve4xt7FOlsQ+tJ4RYyZMCZsc6sepyKUoCmsWZ3MyKIu0MizPU2L14mzcf2shJ3NImByfz4c9e/bg85//PH7xi1+wx48fP44jR47gnXfewbFjx+bQwrEZq8k61TKCkdTW1qKurm7BChOQpud06NChMY8fOHAAb731VjqnTpuD+99G
97F32baBVMlRiiEW8hAI0RhwRIfOx2tMEuNNucrxC98EfB7Wl+bg1A0rRHwe1iTRR0UYG4ZhsGPHDvzwhz+E3+9HdXU1rl69isrKSqxZswYlJSV49tln59rMMRmrRSVeS0cYn7Teoaeffhq33nrrqC7pa9eupWXUdJCtL8Du3U+zbQOpEl/atVs8CEdomJx+FGikCIQjbNuJWiaEVDTxADKVVIh7KvOn9DcQhtm1axe+8pWvoLS0FMDwXLE333wTSqUSFRUVc2zh+GhkQgj5PLblRcCnIJ3ljSXmI2mJ0549e9DV1YUHHniAc/z3v/99WkZNBxW1GwGAbRuYCoVZMnaIV+9QdNeTLpuXbTtZaBW5mcwLL7zA+fnPf/4z+/3bb7+NqqoqNDc3Y9myZeDzM+vCj45rEQ3PNpKMv3MMYZi021dGChMAPPzww+mcdlo49df9ANIrrMtXS9ilXO+Qb9SM6Zmc3EgADAZDUs/TarU4duwY3nrrrYwTpjg6zswwsqRLhimJ069+9asxjx8/fjwtY6aTv7/1StopVgGfh/yYd+QLRvCPzkFYYzVKWTIRsuZ4JtJCJ9kdfO6991788Y9/xIsvTty4O5ckxiYn2kGFMMyUxKmoqAhPPvkk6ISdZ69fv45HHnlk2gxLlyNHjuCnP/1p2uepKlSzbQmXO0wYHLTB4/GM2sCSQJgIvUqMWwxqLNbJJ518SYgyJXG6//778eUvfxmPPfYYzpw5g4ceeggrV67EqlWZUz5/2223Tct5suUiLM2JNls2NTWho6MTjdev461f/2xazk+4OaAoCtVFGnx6mW7Wd1mer0w55iSRSNDb24vbb78dXq8Xly5dwp/+9KfptC1jCJtvoKujHWqdHoZlFaB4PLz04+eTGrFBIBCmxpTE6YEHHkBNTQ2ysrLw8ccfQ6PRQCpduJmr9tZmfPLXA+yIjbziaDo72REbBAIhdaaUNrh48SIuXLiAyspKANFq1R07dqCurg533HHHtBqYCZSVleHi0b9i5fq7UVG7Acf/8h/scQKBMDNMSZwOHDjAChMQXeL99re/xa5duxakOMW39d79/P9mjy3UZksCIVOgmHQ34cogPB4PFIpoJsTtdrNbKU8XE233Q5hfbN26FQcPHgQwvbOiCNMHqQZLATLWZGEy3bOiCNMDmddBIIA7K+q1117jNIwfOXIEbW1tKZ/z3LlzWLVqFT744IPpNPWmgYgTgRDjiSeewOHDh/H444+PmhU1Hl//+tfHfayrqws1NTXTaeJNBREnAgEzMysqcS9HQuqQmBPhpmc+z4payBBxItz0pDorKhKJ4K677gIANDU1YcOGDQCAJ598Evfee++s2r6QIeJEuOlJdVYUn8/H0aNHAURjTnv37p1Fa28eSMyJcFMyG7Oi9u3bh4aGBhw4cAAXLlyYipk3NaQIk0AgZCTEcyIQCBkJEScCgZCREHEiEAgZSUaJU2dnJyiKgkKhYL/uu+++uTaLQCDMARlZStDb2wuNRjPXZhAIhDkkI8UpFTwez5jfEwiE+U1GLeviVFZWIi8vD1u3bkVTU9OEz01cAur1+lmykEAgzDSzJk5btmwBRVHjfnV2dkKn0+HMmTPo6OhAU1MTSktL8dnPfhZOp3O2zCQQCBnCrBVhOp1OBIPBcR/Pzs4Gj8fVSoZhkJ+fj7179+Kee+4Z83WJSzmGYeD1eiGXyyGTyciWzwTCPGbWYk4qlSrl18S9qokYWQUerxAnEAjzm4yKOZ05cwaNjY2IRCJwu93YuXMnKIrCunXr5to0AoEwy2SUOLW3t2PLli1QqVRYsmQJrl27hg8//BBqtXquTSMQCLPMgmr8JRAIC4eM8pwIBAIhDhEnAoGQkRBxIhAIGQkRJwKBkJEQcSIQCBkJEScCgZCREHEiEAgZCREnAoGQkRBxIhAIGQkRJwKBkJEQcSIQCBnJ/wfKEzKQ8Noo/AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 300x100 with 1 Axes>"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Observed trajectory x_o with the states x_t (t=50) and x_{t+1} (t=60) highlighted.\n",
    "with use_style(\"pyloric\"):\n",
    "    fig, ax = plt.subplots(figsize=(3., 1.))\n",
    "    ax.plot(x_o, lw=2, color=color2, alpha=0.8)\n",
    "    # Hide the x axis and bottom spine; a hidden axis also hides the x label and ticks set below.\n",
    "    ax.get_xaxis().set_visible(False)\n",
    "    ax.spines['bottom'].set_visible(False)\n",
    "    ax.set_xlabel(\"Time\", labelpad=0.5)\n",
    "    ax.set_ylabel(r\"$x_{1:T}(\\theta)$\", labelpad=-0.5)\n",
    "    ax.set_yticks([-5, 10])\n",
    "    ax.set_xticks([0, 100])\n",
    "    ax.set_ylim(-5, 10)\n",
    "    ax.set_xlim(0, 100)\n",
    "\n",
    "    # Label each highlighted state 5 units below the curve, with a thin vertical connector\n",
    "    # drawn one time step to the left of the marker.\n",
    "    for t, label in [(50, r\"$x_t$\"), (60, r\"$x_{t+1}$\")]:\n",
    "        ax.text(t, x_o[t, 0] - 5, label, fontsize=9)\n",
    "        ax.plot([t - 1, t - 1], [x_o[t, 0], x_o[t, 0] - 5], color=\"black\", lw=0.5)\n",
    "\n",
    "    ax.scatter(jnp.arange(0, 101, 10), x_o[::10], color=color2, s=10, edgecolors='black')\n",
    "\n",
    "    # Save once, after all formatting has been applied.\n",
    "    fig.savefig(\"x_o.svg\")\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diffusion sampler built from score_fn, stepped with an Euler-Maruyama kernel on a uniform time grid.\n",
    "from markovsbi.sampling.sample import Diffuser\n",
    "from markovsbi.sampling.kernels import EulerMaruyama\n",
    "\n",
    "N_DIFFUSION_STEPS = 500  # number of points in the time grid between sde.T_min and sde.T_max\n",
    "\n",
    "kernel = EulerMaruyama(score_fn)\n",
    "time_grid = jnp.linspace(sde.T_min, sde.T_max, N_DIFFUSION_STEPS)\n",
    "sampler = Diffuser(kernel, time_grid, (d,))  # (d,) is presumably the parameter shape - TODO confirm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Approximate posterior samples for the short observation x_o[:11], plus the true posterior\n",
    "# densities for both the short and the full observation.\n",
    "# Use a dedicated key: `key` is also consumed by other cells (e.g. task.get_data), and reusing\n",
    "# a JAX PRNG key correlates their random draws. fold_in keeps this cell idempotent on re-run.\n",
    "key_samples1 = jax.random.fold_in(key, 1)\n",
    "samples1 = jax.vmap(sampler.sample, in_axes=(0, None))(jax.random.split(key_samples1, 5000), x_o[:11])\n",
    "\n",
    "true_posterior1 = task.get_true_posterior(x_o[:11])\n",
    "true_posterior2 = task.get_true_posterior(x_o)\n",
    "# Density grid spanning the same [-2, 2] range used for the plot x-limits.\n",
    "x = jnp.linspace(-2, 2, 1000)\n",
    "pdf1 = jnp.exp(jax.vmap(true_posterior1.log_prob)(x))\n",
    "pdf2 = jnp.exp(jax.vmap(true_posterior2.log_prob)(x))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Approximate posterior samples for the full observation x_o.\n",
    "# Use a dedicated key rather than `key` itself, which other cells also consume: with partitionable\n",
    "# threefry keys, split(key, n) prefixes coincide across different n, so reuse correlates sample sets.\n",
    "key_samples2 = jax.random.fold_in(key, 2)\n",
    "samples2 = jax.vmap(sampler.sample, in_axes=(0, None))(jax.random.split(key_samples2, 1000), x_o)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJgAAADNCAYAAABEgbllAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZLElEQVR4nO2de1AUV77HvwMKCMzwGDM8FRCDgsig4itujKA3V2NKTYzm4TuP1aC7uO5VzJZrjJVSs1cSK8jelEkAYx6bW0kl5a6PZRX1qhHxETLggyWRh4BCUEeYgYEZ5nf/wGkdGBDp7pke5nyqpuzuc6bPd9ovv3P6dPevZUREYDBEws3RAhj9G2YwhqgwgzFEhRmMISrMYAxRYQZjiAozGENUmMEYosIMxhAVZjCGqDCDMUSFGYwhKsxgDFFhBmOICjMYQ1QGOKrhyMhI7Nq1C/PmzXOUBN6c+aXB5vbJ0YPtrES6SDKCmUwmsPsghcdoNNq9zV4bjIig1+t7/PTWFAsWLEBVVRVefvll+Pr6YtWqVZDJZNi9ezfi4+Ph4+MDnU4HmUyGoqIi7nu7du3CtGnTuPX6+nosWrQIISEhCA0Nxdq1a9Ha2trrHy8V3n//fTz++OOQy+WIjo7G7t27AQAVFRWQyWT4+OOPERkZCaVSidTUVLS1tQEAjh8/Dn9/f2RmZiIkJATBwcF4++23uf+H3NxcJCYm4u2330ZwcDBeeuklEBEyMjIQHR2NwMBAzJw5E9euXQMAZGVlYeTIkdDr9QCAM2fOwM/PD1euXOn7j6NeotPpCECPH51O19vdUUREBH333XfcOgCaPHky1dTUkMFgoPb2dgJAP/74I1fngw8+oKeeeoqIiMxmM02cOJHWrVtHer2eGhoaaNq0abRp06Zea+DLDz//avPzqHzzzTdUVVVFZrOZ8vPzycvLi06dOkXl5eUEgGbNmkV37tyhmpoaUqvVtGXLFiIiOnbsGLm5udHy5ctJr9fTlStXKDw8nHJzc4mIKCcnh9zd3Wnr1q3U2tpKer2e9u7dS6GhoaTRaKilpYXWrVtHcXFxZDQaiYhozpw5tGLFCtJqtRQZGUl79uzhdYwkZbAH1y3bujNYYWEhBQYGUnt7O1eel5dHw4YN67UGvghlsM7MnTuX3n33Xc5gZ8+e5cr+9re/UXR0NBF1GAwA1dXVceU7duyg6dOnE1GHwTofoxkzZtCOHTu4dYPBQHK5nE6fPk1ERA0NDRQaGkoJCQn0wgsv8P4tvR7ke3t7Q6fTPbQOH4YOHdrruhUVFdBqtQgMDOS2ERHa29t5aXAEX3zxBTIyMlBRUQGz2Yzm5mZERUVx5REREVbLNTU13LqXlxdUKlW35WFhYXBzuz8Sqq6uRmRkJLfu6emJ0NBQVFdXAwCUSiUWLlyIXbt2ITs7m/dv67XBZDIZfHx8eDdo4cEf3d02Hx8fNDc3c+s3btzglocMGQKVSmW1zRmpqqrCsmXLcPjwYUybNg0DBgzAvHnzrMazlZWVCAoK4uqHhYVxZQaDAfX19ZzJOpd3Pqbh4eGoqKjg1tva2lBbW4vw8HAAQEFBAXJycrBkyRKsXr0ap06dwoABfZ9scNhZZFBQEH755Zce64wdOxb79u2DyWRCUVER9u3bx5WNHz8eQ4YMwaZNm9DU1AQiQmVlJQ4dOiS2dEHR6XQgIqhUKri5ueHgwYPIy8uzqrN161ZotVrU1tZi+/btWLRoEVfm5uaGt956Cy0tLSgtLUVWVpZVeWcWL16M3bt34/Lly2htbcWmTZsQFhaGCRMm4O7du3jllVewc+dOZGdnQyaTYfPmzfx+IO9Oto/s37+fIiMjyc/Pj958880u4y0ioqKiIlKr1eTj40NPP/00bdy4kRuDERHV1dXR8uXLKSwsjORyOY0aNYo+/PBD+/4QAfjzn/9MSqWS/P39aenSpfTiiy9SWloaNwbbs2cPRUREUEBAAK1c
uZIMBgMRdYzB/Pz86MMPP6Tg4GBSqVS0adMmbsyVk5NDarXaqi2z2UzvvfceRUVFkb+/Pz399NNUVlZGREQvvfQSzZ8/n6tbXl5OAQEBlJ+fT0REK1eupJUrVz7Sb5MRsQknqVJRUYGoqCjcuXMH/v7+XcqPHz+OefPmQavV2l1bb5HkRCuj/8AMxhAV1kUyRIVFMIaoMIMxRIUZjCEqzGAMUWEGY4gKMxhDVJjBGKLCDMYQFWYwhqgwgzFEhRmMISrMYAxRYQZjiAozGENUmMH6QHcpAxhdYQZjiAozGENUmMF40N7ezpK0PARmsD6SmZkJhUKB0NBQ/Otf/3K0HMnC7snvAx9+9i3Slr3Arcvlcly6dAlDhgxxoCppwiLYI0JEyHrvHQDAkiVLMCoxCU1NTdi+fbuDlUmUR30KuSc6P5ndHzl16hQBIC8vL2poaKCsL77n1rVaraPlSY4+ZbU4ePCgze1ff/019u7dy8Pu0uezzz4DAEx/9jkolUokTngCkdExqPjl3/j++++xbNkyByuUFn0agyUmJmLs2LFdzqCKi4tx/vx5wcRJDbPZjLCwMNy8eRPvZ3+NSVNTAAA5WRn4+IMdmDVrVrd/fK5Knwx2/vx5VFZWYv78+Vbbv/zyS7zyyiuCiZMaZ8+exaRJk+Dt44uDhVfh4ekJALj276tY/MyT8PD0gvbObQwaNMjBSqVDnwb5SUlJXcwFoF+bCwAOHz4MAJg0NYUzFwBEPT4CYWFhaGs14OTJk46SJ0kEO4vUarUoLS1FaWmppLO98OHMmTMAgDETp1htl8lkSJw8FQDwz3/+0+66pAxvg2VmZmLEiBFQKpWIi4tDbGwslEolYmJikJmZKYRGSWA2m1FYWAgAGJU4rkv5+CnTAAAnTpywpyzJw+tFDOnp6Th58iQ2btyIhIQELoeVVquFRqPBp59+ymXlc3bKyspw584deHh6YfiIuC7lCeMmAACKioqg0+ng6+trb4mShJfB8vPzUVBQAHd39y5l48aNw5IlSzB58uR+YbCCggIAQOxoNQYMHNilPDg0HKrgUNTfrEVhYSFSUlLsLVGS8OoizWZzjxd7qSNNOp8mJMPZs2cBAHHqrt2jhdH3otgPP/xgF03OAK8INnfuXEyYMAErVqxAfHy8VRd56dIl5OTk4LnnnhNCp8OxRDBb4y8LCeMm4OiB73H69Gl7yZI8vC925+bmIisrC0VFRVyOend3dyQmJmLNmjX9Yma7ubkZCoUC7e3t+P7kT1CFhNqsV1ryE1bMmwE/Pz/cvn3bZqp2V0OwuymMRiMaGjpuJVYqlfDw8BBit5Lg5MmTmDp1KkJDQ/HN//3UbT2TyYSZY6PR3NyMy5cvIzY21o4qpYlgf2IDBw5ESEgIQkJC+pW5gPvd46RJk3qsN2DAAIyIT7T6jqsjegy/evWq2E2ITm8NBtwfo1kmZV0d0Q2WmpoqdhOiYzmDnDhx4kPrjhqTBIBFMAu8ziJfffXVh9Zx9ghWXV2NmpoauLu7wxQQga4zYNaMUo8FAJSUlKCpqQlyuVx8kRKGVwQrKSnBzZs3ufkuWx9nxxKJokfEYZD3w18GNlgVjIiICBARzp07J7Y8ycMrgn3++edYv349cnJyuq2TnJzMpwmH05v5r85MmjQJlZWVKCgocPkZfV4RLCYmBmlpaaitre22zieffMKnCYdjGX89isGCh8cDYOMwgD1V1CNGoxEKhQIGgwFf5Z1BxLDhvfpeyY/n8dsFszB48GDU19dDJpOJrFS6sKnmHtBoNDAYDPD398eQyGG9/l5M3GgMHOiBhoYG7oXrrgozWA9YuscRo8c80mUfD09PxIxKAMC6SWawHrg/wE965O/G3xuzMYMJiEajsbnsrNw32NhH/u6oMcxggMAGW7t2rc1lZ6ShoQFlZWUAgFE93APWHZaoV1RUhJaWFkG1OROCGuzBE1JnPzm1RJ6RI0dC4R/wyN8PDg2H8jEVTCYTLly4
ILQ8p0FQgz14Ou7sp+aWu1InT57cp+/LZDIuirlyNylaBHN2LHdDBA0f3ed9xLML3+ws0hZGo5F7RC1+zPg+72cUO5MUr4t0ZjQaDZqbm+ErVyByeEyf9zMyXg13d3fU1NSgurpaQIXOAxvk28Dy0MaoxHG87qsf5O2D6HvPULpqFBPUYDNnzrS57GwcPXoUQNcUAX3B1Qf6ghosPT3d5rIzYTKZcPz4cQDA+ClP8d5f/BjXvoWal8FMJpNQOiTDhQsX0NjYCLnCDzFxfT+DtGCJYBcuXEBrayvv/TkbvAymUqmQkpLSr5LOHTlyBAAwdtJvbKZEeFSGRA5DgPIxtLa2umRqJ14GU6vVyM/Px7hxj34pRaocOnQIAJD0xFRB9ieTyfBE8gwAwD/+8Q9B9ulM8DKYZVqiv0xP1NXVcTP4T04X7iTlNyn/CQD4+9//7tRn132Bl8HKy8uRnZ2Nc+fOdXtBt76+nk8TdmX//v0gIsQmjOk2PUBfGD/lKQwc6IFr167h8uXLgu3XGeBlsMbGRmRkZGDKlCmQy+WIiYnB/Pnz8c477+C7777Dzz//jIULFwqlVXS++eYbAMBT//GMoPv19vFF0pSOLnfn/2QLum+pw+ue/OTkZBw7dgxtbW24cuUKiouLodFoUFxcjOLiYtTW1kImk3FJUaTM9evXucfN/vdoIcIjogTd/5ED32Fz2m8RFBKG2uoql0mMwuuxNcvYy8PDA2q1Gmq12qr89u3bTjPhmpubCyLC2IlTBDcXADw5YxZ85QrU3ahBfn4+ZsyYIXgbUoTXn1FpaSm2bdvW7WNrgYGBTjHhajAY8NFHHwEAnl2wSJQ2PD298PScjszc77//vihtSBFeBtNoNA9NCGIr3bnUyM3NRW1tLVTBoZj+zFzR2nnp1Tfh5uaGQ4cO4eLFi6K1IyV4GUypVCIlJQWhocKdcdkbrVaLLVu2AAAW/XYNBoqYeio8Igoznu3I+Lhu3TqXmLJwjZFmNxAR1q5di7q6OgwdNhxzX1wqepsr1/0Jnl6DcOLECa5b7s+4tMFSN2zG3r174ebmhvR3M6ze3iEWIeFD8cbajQCAtLQ05OXlid6mI3FJg926dQtzFi7GRzvfBQCkbngbYyY8Ybf2X37tTUyfPQ9GoxHPzJ6NHTt29NsL4ZLJTUFEaG5uFny/DQ0NqKmpQW1tLcrLy/HDDz/gwMFDMLZ1/Ie+npaOl197U/B2H0Zbaxv+e/N/If/QfgAd49k5c+bgySefxJw5c+Dl5SVoe97e3g65pCcZg+n1evZ2DBHR6XTw8Xl4fjOhkUwXqdfrHS2hX+Oo48trJl9IvL29ueW6ujqH/LX1N/R6PYKCggBYH197IhmDPTg+8PHxYQYTGEfdUiWZLpLRP2EGY4iKZM4iGf0TFsEYosIMxhAVZjCGqDCDMUSFGYwhKsxgDFFhBmOICjMYQ1SYwRiiwgwmAs3Nzdi7dy9qamocLcXhMIOJwJYtW7B8+XIsWLDA0VIcDrsWKQJRUVGoqKgAANy9excKhcKxghwIi2ACYzKZcP36dW79ypUrDlTjeJjBBObGjRtWyV5+/vlnB6pxPMxgAvNg9AI6DOfK9Mlg169fx5o1a7BhwwbuAB44cADPP/+8oOKckc4G6+l95q5Anwz21ltvITk5GePHj0daWhpOnz6N2bNn486dO0Lrczo6v9HD1Q3Wp4c+EhMTuaw5CxYswLZt29DY2NhvcrXyoaGhAQAQEhKCGzdusC6yL1+6fPkyioqKuPU//elPuHbtGi5duiSULqfl9u3bAIDHH3/cat1V6ZPB1q1bB61Wa7Vt9erV+Oqrr4TQ5NRYDDV8+HAAHXkwXJk+dZHx8fFdtmm1WoSFhaG0tBRBQUHw9/fnq80p6RzBbt26BSJy2eED72mKzMxMjBgxAkqlEnFxcYiNjYVSqURMTAwyMzOF0OhUdI5gbW1toiR1cRZ4Pdmdnp6OkydPYuPGjUhISOCillar
hUajwaeffora2lps375dCK1OgcVgQ4YMgYeHB9ra2nDr1i3XfVKdeJCUlEQmk6nbcqPRSElJSXyacDrkcjkBoLKyMgoODiYA9OOPPzpalsPg1UWazeYe84wSkUvkIbVgNBrR1NQEoCPDtlKpBODaA31eXeTcuXMxYcIErFixAvHx8VZd5KVLl5CTk4PnnntOCJ1OwYMTzX5+fggMDATADNZnNm/ejKFDhyIrKwtFRUXcRV53d3ckJibi97//PZYtWyaIUGfAMv7y9/eHu7s7i2AQIH3T8uXLsXz5chiNRm4WW6lUwkPEdOBSxWIwS+Sy/OvKl9AEyw82cOBAhISECLU7p6Q7g7nybL7ot+tcvXpV7CYkAzNYV0Q3WGpqqthNSAZLV8gMdh9eXeSrr7760Dosgrn2GIxXBCspKcHNmze5+S5bH1fCYjA9Ot4YEhAQYLXdFeEVwT7//HOsX78eOTk53dZJTk7m04RTYTGSXOGPM780sC4SPCNYTEwM0tLSerxr85NPPuHThFNhMZLfvcjFDCbANEVKSkqP5dHR0XybcBosYy25wh/AfYMZDAa0tLRg0KBBjpLmMNhTRQJiiVSKe5fM5HI53N3drcpcDWYwAeEimJ8/gI6XH8j9AqzKXA1mMIEgIs5EinumAu5HMxbBBECj0dhcdgWampq4i/0KPz8AwJlfGrjxGDOYAKxdu9bmsitgMZCHpxc8ve4P5lkEE5AHJ1ZdbZL1fvfoZ7Xd0l0ygwnAg0/OuNpTNA9Osj4Ii2AC4mpR60G4CNbpcT0FO4tkCAGLYLYRrYt0NR4WwZjBBMCVB/mW++7lD8yBdax3DPqv36i3uyYpIKjBZs6caXPZFaiv7zBQgHKw1XZLBGu8q7W3JEkgqMHS09NtLrsCFoMFdjaY/z2Dadkg/5ExmUxC6XB66urqAAABgx+z2q64d11Sr2tyyePFy2AqlQopKSk4f/68UHqclvtdpLXBfBX3J15dcaDP634wtVqN/Px8lxvQd4aI7neRnSLYgAED4B84GNrbDbhx4wZUKpUjJDoMXhHMMi3hytMTQMfLFtra2gAAAYHKLuWPBQUDgEu+WoaXwcrLy5GdnY1z586hpaXFZh3LX3Z/xjL+8vbxtbrQbUEVHAqga4JgV4BXF9nY2IiMjAyUlZXBbDZj2LBhGD16NBISEpCQkIDRo0fj9ddfx/HjxwWSK02qqqoAAEEhYTbLHwvueOLdFSMYL4MlJCTg2LFjaGtrw5UrV1BcXAyNRoOCggJ8/PHHqK2tdYnus7KyEgAQHBZus/yxIGawPmExj4eHB9RqNdRqtVX57du3XWLC9b7Bhtgst4zBSv59zW6apAKvMVhpaSm2bdvW7WNrgYGBLjHh+jCDBYcPBQDUXq+0myapwCuCaTQa/PTTTz3WsbywoT9TVlYGAAi9Z6TOREZ3ZJyuqaqAwWCAl5eX3bQ5Gl4RTKlUIiUlBaGhoULpcTrMZjNKSkoAANEj4mzWUT4WBF+5AmazmTOjq8DuB+NJRUUFdDodPDw8ER45zGYdmUyGyOEjAIAzo6vADMaTU6dOAQCiR8ZhwIDuRxxx6jEAgBMnTthFl1RgBuPJoUOHAADjn5jaY72kyR3l+w8cgtlsFl2XVBAshSZfiMjp3ohx7do1fPvttwCASU9NR0uzvtu6sQlj4O3jixvVVfjyyy/tnn3b29vbIXOSknkpvF6vh6+vr6Nl9Ft0Op1D3jYimS5Sr+/+r5/BH0cdX8l0kd7e3txyXV2d677bR0D0ej2CgoIAWB9feyIZgz04PvDx8WEGExhHXROWTBfJ6J8wgzFERTJnkYz+CYtgDFFhBmOICjMYQ1SYwRiiwgzGEBVmMIaoSNJgBw4cwNSpUxEQEACVSoUXXnjBJZ8p5ENrayveeOMNREVFQS6XY+TIkcjOzra7Dkka7O7du0hPT8f169dRXl4OhUKBhQsXOlqWU2EymRASEoIjR46gsbERubm5+OMf/4i8
vDy76nCKiVaNRoMxY8agtbW1x7tG7UVtbS1+97vfobq6Gu3t7di3bx9iY2MdLeuhPP/884iPj8fWrVvt1qYkI1hnTpw4gdjYWEmYy2AwYObMmUhNTcXZs2exZs0arFmzxtGyHorBYEBhYSESEhLs2zDZmdmzZxOAbj/l5eVW9S9evEh+fn6Ul5dnb6k2eeedd2jVqlXc+sWLF8nd3Z2MRqMDVfWM2WymRYsW0bRp06i9vd2ubdvdYHfv3qVff/2128+DB0Cj0VBQUBDt27fP3jJt0t7eToMHD6bCwkJu2+HDhwkANTY2OlBZ95jNZlq1ahUlJSWRVqu1e/t273MUCkWv6hUXF2PGjBnYsWMHFi9eLLKq3lFQUACtVov169dz22pqauDt7Q25XO5AZbYhIqxevRpnz57F0aNH4dfpLST2EiE5SkpKSKVS0Z49exwtxYqMjAwaP3681balS5dScnKygxT1TGpqKiUkJFBDQ4PDNEhykL9z5078+uuv+MMf/gBfX1/uY0mT5Cjq6+u5W5CBjghx5MgRPPvssw5UZZvKykr89a9/RWlpKSIiIrhjuGrVKrvqcPxpmQ1ycnJ6fNG8o5DL5VbpQo8ePYqWlhYsXbrUgapsExERIYnUppKMYFIlJSUFFy5cgMFggE6nw4YNG/CXv/wFgwcPfviXXRSnmGiVEps2bcL+/fvh7e2N1157DW+88YajJUkaZjCGqLAukiEqzGAMUWEGY4gKMxhDVJjBGKLCDMYQFWYwhqgwgzFEhRmMISr/D2eOWO8m/OckAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 110x170 with 2 Axes>"
      ]
     },
     "execution_count": 122,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Posterior marginals for T=10^1 (top, samples1/pdf1) and T=10^2 (bottom, samples2/pdf2):\n",
    "# histogram of approximate samples against the true density.\n",
    "with use_style(\"pyloric\"):\n",
    "    fig, axes = plt.subplots(2, figsize=(1.1, 1.7))\n",
    "    panels = [(axes[0], samples1, pdf1, r\"$T=10^1$\"), (axes[1], samples2, pdf2, r\"$T=10^2$\")]\n",
    "    for ax, samples, pdf, ylabel in panels:\n",
    "        _, _, approx_bars = ax.hist(samples, bins=50, density=True, alpha=0.5, color=color2)\n",
    "        (true_line,) = ax.plot(x, pdf, color=\"black\")\n",
    "        ax.set_ylabel(ylabel)\n",
    "        ax.spines['left'].set_visible(False)\n",
    "        ax.set_yticks([])\n",
    "        ax.set_xlim(-2, 2)\n",
    "        ax.set_xticks([-2, 2])\n",
    "    # Only the bottom panel shows x tick labels and the theta label.\n",
    "    axes[0].set_xticklabels([])\n",
    "    axes[1].set_xlabel(\"$\\\\theta$\", labelpad=-8)\n",
    "\n",
    "    # Pass the handles explicitly: given labels only, fig.legend pairs them with artists in\n",
    "    # creation order, so (in recent matplotlib) \"true\" would label a histogram bar and\n",
    "    # \"approx.\" the density line. Both panels share styling, so the last panel's artists suffice.\n",
    "    fig.legend([true_line, approx_bars], [\"true\", \"approx.\"], ncol=2, loc=\"upper center\",\n",
    "               bbox_to_anchor=(0.5, 1.1), handlelength=.5)\n",
    "\n",
    "    # NOTE(review): the legend extends above the figure (y=1.1); unless the \"pyloric\" style\n",
    "    # saves with a tight bbox it may be clipped in samples.svg - confirm.\n",
    "    fig.savefig(\"samples.svg\")\n",
    "\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "markovsbi",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
