{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6e1b5e3c",
   "metadata": {},
   "source": [
    "# Validating DSM for Lie Derivatives\n",
    "\n",
    "This notebook checks that the following loss function indeed learns the correct score\n",
    "\n",
    "\n",
    "$$ L = \\mathbb{E}_{p_\\sigma(\\tilde{R}, R)} \\left[ \\frac{1}{2} \\parallel s_\\theta(R, \\sigma) - \\mathcal{L}_\\mathcal{V}(\\log p_\\sigma(\\tilde{R}| R))  \\parallel_2^2   \\right]$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5a6ba2c2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "%pylab is deprecated, use %matplotlib inline and import the required libraries.\n",
      "Populating the interactive namespace from numpy and matplotlib\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "from jaxlie import SO3\n",
    "from so3dm.distributions.isotropic_gaussian import IsotropicGaussianSO3\n",
    "\n",
    "center = SO3.identity()\n",
    "def make_dist(sigma=0.):\n",
    "  \"\"\"\n",
    "  Returns a mixture of Gaussians, convolved with a Gaussian of specified sigma\n",
    "  \"\"\"\n",
    "  return IsotropicGaussianSO3(center, jnp.sqrt(0.2**2+sigma**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4a439499",
   "metadata": {},
   "outputs": [],
   "source": [
    "import haiku as hk\n",
    "import optax\n",
    "\n",
    "# Create a random sequence\n",
    "rng_seq = hk.PRNGSequence(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ec719136",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create our reference distribution, and noisy distribution\n",
    "sigma_n = 0.5\n",
    "ref_dist = make_dist()\n",
    "dist = make_dist(sigma_n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c1610c50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def true_score(q):\n",
    "  return jax.grad(lambda s: dist.log_prob((SO3(q) @ SO3.exp(s)).wxyz))(jnp.zeros(3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f0785d0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_noise(q0,key):\n",
    "    dist_noise = IsotropicGaussianSO3(q0, sigma_n)\n",
    "    qn = dist_noise.sample(seed=key)\n",
    "    def score_fn(q):\n",
    "      return jax.grad(lambda s: dist_noise.log_prob((SO3(q) @ SO3.exp(s)).wxyz))(jnp.zeros(3))\n",
    "    return qn, score_fn(qn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e5f744e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "@jax.vmap\n",
    "def get_batch(key):\n",
    "    k1,k2 = jax.random.split(key)\n",
    "    q0 = ref_dist.sample(seed=k1)\n",
    "    qn, sn = score_noise(q0, k2)\n",
    "    st = true_score(qn)\n",
    "    return {'q0':q0, 'qn':qn, 'sn':sn, 'st':st}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "dbe227b2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:The use of `check_types` is deprecated and does not have any effect.\n",
      "WARNING:root:The use of `check_types` is deprecated and does not have any effect.\n",
      "WARNING:root:The use of `check_types` is deprecated and does not have any effect.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7f330422d2d0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWWUlEQVR4nO3df4zU9b3v8edb3OMeD/7oEVAiyHpSU62HSHGj0nNb7dWTWlLFUlP0GittGmJjT9vU/iHaiLQ1JaZp1OsJhhSrNKTEniJFQ9OfGDUeLQsFf1FvvV6tWymsWIE9QhXP+/6xI11nZ9nZ3Zmdme8+H8nEme/3M995z3zjaz985vP9TGQmkqTWd0SjC5Ak1YaBLkkFYaBLUkEY6JJUEAa6JBXEkY164UmTJmVHR0ejXl6SWtLmzZtfy8zJlfY1LNA7Ojro6upq1MtLUkuKiJcH2+eQiyQVhIEuSQVhoEtSQTRsDF1Scb399tt0d3dz4MCBRpfSstrb25k2bRptbW1VP2fIQI+IduAR4KhS+//IzCVlbQK4A5gLvAkszMwtw6hdUoF0d3dzzDHH0NHRQV88aDgyk927d9Pd3c2pp55a9fOqGXL5K/A/M/MsYBZwcUScV9bmE8BppdsiYHnVFUgqnAMHDnDCCScY5iMUEZxwwgnD/hfOkIGefXpLD9tKt/IlGucBq0ptnwCOj4ipw6pEUqEY5qMzks+vqi9FI2JCRGwFdgG/zMwny5qcDLzS73F3aVv5cRZFRFdEdPX09Ay7WEnS4Kr6UjQz3wFmRcTxwAMR8c+Z+Uy/JpX+lAxYaD0zVwArADo7O12IXRonFq99uqbH+878mSN63uuvv86CBQt46aWX6Ojo4P777+d973tfTWtrpGFNW8zMN4CHgYvLdnUD0/s9nga8OprCGmnx2qcH3CS1vmXLlnHhhRfyhz/8gQsvvJBly5Y1uqSaGjLQI2JyqWdORPw9cBHw+7Jm64HPRp/zgD2ZuaPWxUpStW699VY+8IEPcNFFF3HllVfy3e9+l5/+9Kdcc801AFxzzTWsW7eusUXWWDVDLlOB+yJiAn1/AO7PzIci4lqAzLwb2EDflMUX6Ju2+Lk61StJQ9q8eTNr1qzhd7/7HQcPHmT27NmcffbZ7Ny5k6lT++ZrTJ06lV27djW40toaMtAz8yngQxW2393vfgLX1bY0SRqZRx99lE996lMcffTRAFx66aUNrmhseOm/pEKqNO3vxBNPZMeOvtHgHTt2MGXKlLEuq64MdEmF89GPfpQHHniA/fv3s2/fPh588EGgr6d+3333AXDfffcxb968RpZZc67lIqnuRjrNcKRmz57NggULmDVrFjNmzOAjH/kIADfccAOf+cxnWLlyJaeccgo//vGPx7SuejPQJRXSTTfdxE033QTALbfcAsAJJ5zAr3/96wZWVV8OuUhSQdhDl1R47/bQi85Ar1Klq0XHelxQkg7HIRdJKggDXZIKwkCXpIJwDF1S/T34ldoe75I7anu8grCHLkkFMe576K51LhXTrbfeyqpVq5g+fTqTJ0/m7LPP5qGHHuLcc89l48aNvPHGG6xcufLQVaRFYA9dUuH0Xz537dq1bNq06dC+gwcP8tvf/pbbb7+dpUuXNrDK2hv3PXRJxXO45XPnz58PwNlnn81LL73UiPLqxh66pEKqtHwuwFFHHQXAhAkTOHjw4FiWVHcGuqTCGWz53KJzyEVS/Y3xNMPBls8tOnvokgrppptu4vnnn+cXv/gFp5xyCgAPP/wwnZ2dAEyaNMkxdElSc3LIRVLhjZflc+2hS1JBGOiSVBAGuiQVhIEuSQXhl6KS6m7pf9Z2zZQlc5bU9HgjMXHiRHp7ewfdv3XrVl599VXmzp0LwPr163nuuee44YYb6lbTkD30iJgeERsjYntEPBsRAxY2jogLImJPRGwt3W6uT7mS1Bq2bt3Khg0bDj2+9NJL6xrmUF0P/SBwfWZuiYhjgM0R8cvMfK6s3aOZ+cnalyhJw1dp+dzjjjuOFStW8NZbb/H+97+fH/7whxx99NEsXLiQY489lq6uLv785z9z2223cfnll9Pb28u8efP4y1/+wttvv823v/1t5s2b957Xufrqq7n88ssPbb/qqqtYsGABN998M/v37+exxx5j8eLF7N+/n66uLu666y527tzJtddey4svvgjA8uXL+fCHPzzq9zxkDz0zd2TmltL9fcB24ORRv7Ik1clgy+fOnz+fTZs2sW3bNs444wxWrlx56Dk7duzgscce46GHHjrUk25vb+eBBx5gy5YtbNy4keuvv57MfM9rfeELX+AHP/gBAHv27OHxxx9n7ty5fPOb32TBggVs3bqVBQsWvOc5X/7ylzn//PPZtm0bW7Zs4cwzz6zJ+x7WGHpEdAAfAp6ssHtORGwDXgW+npnPVnj+ImARcOhS3FZW/uMY35k/s0GVSOpvsOVzn3nmGb7xjW/wxhtv0Nvby8c//vFDz7nssss44ogj+OAHP8jOnTsByExuvPFGHnnkEY444gj+9Kc/sXPnTk466aRDzzv//PO57rrr2LVrF2vXruXTn/40Rx55+Gj9zW9+w6pVq4C+VR+PO+64mrzvqgM9IiYCPwG+mpl7y3ZvAWZkZm9EzAXWAaeVHyMzVwArADo7O7N8vyTVSqXlcxcuXMi6des466yzuPfee3n44YcP7Xt3WV3gUC989erV9PT0sHnzZtra2ujo6ODAgQMDjnv11VezevVq1qxZwz333FP7N1OlqqYtRkQbfWG+OjPXlu/PzL2Z2Vu6vwFoi4hJNa1Ukqo02PK5+/btY+rUqbz99tusXr16yOPs2bOHKVOm0NbWxsaNG3n55Zcrtlu4cCG33347wKHhk2OOOYZ9+/ZVbH/hhReyfPlyAN555x327i3vI4/MkD306PsztxLYnpnfG6TNScDOzMyIOIe+PxS7a1KhpJY31tMMB1s+91vf+hbnnnsuM2bMYObMmYMG7ruuuuoqLrnkEjo7O5k1axann356xXYnnngiZ5xxBpdddtmhbR/72MdYtmwZs2bNYvHixe9pf8cdd7Bo0SJWrlzJhAkTWL58OXPmzBndmwaifIB/QIOI/wE8CjwN/Hdp843AKQCZeXdEfAn4In0zYvYDX8vMxw933M7Ozuzq6hpd9TVQyx+Jdgxd6rN9+3bOOOOMRpdxyC233MLEiRP5+te/Xpfjv/nmm8ycOZMtW7bUbDwcKn+OEbE5MzsrtR+yh56ZjwGVf8vpb23uAu4aRp2SVAi/+tWv+PznP8/Xvva1mob5SHilqKTCq+fyuRdddBF//OMf63b84XAtF0l1MdRwrg5vJJ+fgS6p5trb29m9e7ehPkKZye7du2lvbx/W8xxykVRz06ZNo7u7m56enkaX0rLa29uZNm3asJ5joEuquba2Nk499dRGlzHuOOQiSQVhoEtSQRjoklQQBrokFYSBLkkFYaBLUkEY6JJUEAa6JBWEgS5JBWGgS1JBGOiSVBAGuiQVhIEuSQVhoEtSQRjoklQQBrokFYSBLkkFYaBLUkEY6JJUEAa6JBXEkIEeEdMjYmNEbI+IZyPiKxXaRETcGREvRMRTETG7PuVKkgZzZBVtDgLXZ+aWiDgG2BwRv8zM5/q1+QRwWul2LrC89F9J0hgZsoeemTsyc0vp/j5gO3ByWbN5wKrs8wRwfERMrXm1kqRBDWsMPSI6gA8BT5btOhl4pd/jbgaGviSpjqoO9IiYCPwE+Gpm7i3fXeEpWeEYiyKiKyK6enp6hlepJOmwqgr0iGijL8xXZ+baCk26gen9Hk8DXi1vlJkrMrMzMzsnT548knolSYMY8kvRiAhgJbA9M783SLP1wJciYg19X4buycwdtSuzdhavfXpMj/2d+TPr9nqS1F81s1z+BbgaeDoitpa23QicApCZdwMbgLnAC8CbwOdqXqkk6bCGDPTMfIzKY+T92yRwXa2KkiQNn1eKSlJBGOiSVBAGuiQVhIEuSQVhoEtSQRjoklQQBrokFYSBLkkFYaBLUkEY6JJUEAa6JBWEgS5JBWGgS1JBGOiSVBAGuiQVhIEuSQVhoEtSQRjoklQQBrokFYSBLkkFYaBLUkEY6JJUEAa6JBWEgS5JBWGgS1JBGOiSVBBDBnpE3BMRuyLimUH2XxAReyJia+l2c+3LlCQN5cgq2twL3AWsOkybRzPzkzWpSJI0IkP20DPzEeD1MahFkjQKtRpDnxMR2yLiZxFx5mCNImJRRHRFRFdPT0+NXlqSBLUJ9C3AjMw8C/jfwLrBGmbmiszszMzOyZMn1+ClJUnvGnWgZ+bezOwt3d8AtEXEpFFXJkkallEHekScFBFRun9O6Zi7R3tcSdLwDDnLJSJ+BFwATIqIbmAJ0AaQmXcDlwNfjIiDwH7giszMulUsSapoyEDPzCuH2H8XfdMaJUkN5JWiklQQ1VxYpFFYvPbpAdu+M39mAyqRVHT20CWpIAx0SSoIA12SCsJAl6SCMNAlqSAMdEkqCANdkgrCQJekgjDQJakgDHRJKggDXZIKwkCXpIIw0CWpIAx0SSoIA12SCsJAl6SCMNAlqSAK/YtFlX4tqJUt/c+lA7YtmbOkAZVIakaFDvRx4cGvDNi0dNI/Dthm8EvF55CLJBWEPfRW8vLjA7f9w+lVPbV8uMYeu1Q8BnqLW/pfvx+4cdKHB24r+2OwFMfjpaJxyEWSCsIeepOqNKOlapWGZiQV3pA99Ii4JyJ2RcQzg+yPiLgzIl6IiKciYnbty5QkDaWaIZd7gYsPs/8TwGml2yJg+ejLkiQN15BDLpn5SER0HKbJPGBVZibwREQcHxFTM3NHrYocD0Y1xDISlYZlXhs4p51L7qh/LZJqohZfip4MvNLvcXdp2wARsSgiuiKiq6enpwYvLUl6Vy2+FI0K27JSw8xcAawA6OzsrNhGJX6xKWmYahHo3cD0fo+nAa/W4LgaY5XmtDszXWodtRhyWQ98tjTb5Txgj+PnkjT2huyhR8SPgAuASRHRTV+nrQ0gM+8GNgBzgReAN4HP1avYQnOIRdIoVTPL5coh9idwXc0qGgeevPPqgRunjH0dVamwmqMzX6Tm5KX/klQQXvqvw/KLUql1GOgavvJhGIdgpKZgoGvYBvTa/Wk8qSkY6Bq9SjN05ox9GdJ4Z6A3wKr2VwZsm8JRDahEUpE4y0WSCsIeep1d1n3bgG2r2htQiKTCs4cuSQVhoEtSQRjoklQQjqHXWaUZLZJUDwa66qLST+p5sZFUXwa66sOLjaQxZ6DX0FMHvj9gW8fYlyFpnPJLUUkqCHvoGjOOq0v1ZQ9dkgrCQJekgnDIpUns2vvX9zyecmwBV1+sNPPlNX8sQ6oVe+iSVBD20Guoo3dbo0uQNI7ZQ5ekgrCHPgrla527zrmkRrKHLkkFYQ9dzeXBrwzc5swXqSpVBXpEXAzcAUwAvp+Zy8r2XwD8FPh/pU1rM/ObtStTRbX0v34/ZBuvJZWqM2SgR8QE4N+BfwW6gU0RsT4znytr+mhmfrIONVZl8dqnx/w1XetcUjOpZgz9HOCFzHwxM98C1gDz6luWJGm4qgn0k4H+XdHu0rZycyJiW0T8LCLOrHSgiFgUEV0R0dXT0zOCciVJg6lmDD0qbMuyx1uAGZnZGxFzgXXAaQOelLkCWAHQ2dlZfgypIldplKpTTQ+9G5je7/E04NX+DTJzb2b2lu5vANoiYlLNqpQkDamaHvom4LSIOBX4E3AF8L/6N4iIk4CdmZkRcQ59fyh217rYRiq/iAi8kEhScxky0DPzYER8Cfg5fdMW78nMZyPi2tL+u4HLgS9GxEFgP3BFZjqkIkljqKp56KVhlA1l2+7ud/8u4K7aliZJGg6vFK2Sc84bqMI66kvxi1KpnIGu1lTpxzLmjH0ZUjNxcS5JKggDXZIKwiGXCp468P0B2zrGvgxJGhYDvQJ/Sk5SKzLQm9SuvX8dsG3KsUc1oJLWUb5EgLNeNN4Y6CqO8pkvznrROOOXopJUEPbQVVz+nJ3GmXEf6M5okVQU4z7QNc7Ya1eBGegqrEo/QL3kH05vQCXS2Bj3ge6cc0lF4SwXSSqIcd9D1/hScRimAXVI9WCgtxCvHpV0OOMu0MunKXY0pgw1kaVrPj5g25Irft6ASqTRGXeB7pegqkZ5yBvwagV+KSpJBWGgS1JBjLshF6lmvOpUTcZAl6pQ8YtTrzpVk2nJQF+89umq2u197d/qXInGs0pz2in7kQ3whzY0dloy0PU3zk1vMuU/sgH+0IbGjIEu1Vv5WLvj7KqTqgI9Ii4G7gAmAN/PzGVl+6O0fy7wJrAwM7fUuNbDuqz7tgHbVrWPZQXNw157cxkwNFNhPL4S575ruIYM9IiYAPw78K9AN7ApItZn5nP9mn0COK10OxdYXvrvmFnV/spYvpxUd+U/eg2w5LXX37vB3r76qaaHfg7wQma+CBARa4B5QP9AnwesyswEnoiI4yNiambuqHnF+GXnSJT32u2xt4AK4/EDIr7K3n4lVc3S8Q9GS6km0E8G+nd/uxnY+67U5mTgPYEeEYuARaWHvRHx/LCq/ZtJwGsjfG6zaPX30Or1Q+u/h1HVfwu/qKLVnSM9fLXG9TkYoRmD7agm0KPCthxBGzJzBbCiitc8fEERXZnZOdrjNFKrv4dWrx9a/z20ev3Q+u+h2eqv5tL/bmB6v8fTgFdH0EaSVEfVBPom4LSIODUi/g64Alhf1mY98Nnocx6wp17j55KkyoYccsnMgxHxJeDn9E1bvCczn42Ia0v77wY20Ddl8QX6pi1+rn4lAzUYtmkCrf4eWr1+aP330Or1Q+u/h6aqP/ompkiSWp3L50pSQRjoklQQTR3oEXFxRDwfES9ExA0V9kdE3Fna/1REzG5EnYOpov4LImJPRGwt3W5uRJ2DiYh7ImJXRDwzyP6m/vyhqvfQ7OdgekRsjIjtEfFsRAxYhL2Zz0OV9Tf7OWiPiN9GxLbSexhwfVfTnIPMbMobfV/A/l/gn4C/A7YBHyxrMxf4GX3z4M8Dnmx03cOs/wLgoUbXepj38FFgNvDMIPub9vMfxnto9nMwFZhdun8M8H9a7P+Daupv9nMQwMTS/TbgSeC8ZjwHzdxDP7TkQGa+Bby75EB/h5YcyMwngOMjYupYFzqIaupvapn5CPD6YZo08+cPVPUemlpm7sjSQneZuQ/YTt9V2P017Xmosv6mVvpce0sP20q38tkkTXEOmjnQB1tOYLhtGqXa2uaU/in3s4g4c2xKq5lm/vyHoyXOQUR0AB+ir4fYX0uch8PUD01+DiJiQkRsBXYBv8zMpjwHzbwees2WHGiQamrbAszIzN6ImAuso2/FylbRzJ9/tVriHETEROAnwFczc2/57gpPaarzMET9TX8OMvMdYFZEHA88EBH/nJn9v5dpinPQzD30Vl9yYMjaMnPvu/+Uy8wNQFtETBq7EketmT//qrTCOYiINvrCcHVmrq3QpKnPw1D1t8I5eFdmvgE8DFxctqspzkEzB3qrLzkwZP0RcVJEROn+OfSdj91jXunINfPnX5VmPwel2lYC2zPze4M0a9rzUE39LXAOJpd65kTE3wMXAeU/KNsU56Bph1yyOZccqFqV9V8OfDEiDgL7gSuy9JV5M4iIH9E3A2FSRHQDS+j7QqjpP/93VfEemvocAP8CXA08XRrDBbgROAVa4jxUU3+zn4OpwH3R92M/RwD3Z+ZDzZhFXvovSQXRzEMukqRhMNAlqSAMdEkqCANdkgrCQJekgjDQJakgDHRJKoj/D9uXBZ5PjNvjAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Checking that distributions add as we expect\n",
    "batch = get_batch(jax.random.split(next(rng_seq),10000))\n",
    "\n",
    "angles = jnp.linalg.norm(jax.vmap(lambda x: SO3(x).log())(batch['q0']),axis=-1)\n",
    "hist(angles,64, density=True, range=[0,pi], label='q0',alpha=0.6);\n",
    "\n",
    "# Let's look at the distribution of axis angles\n",
    "angles = jnp.linalg.norm(jax.vmap(lambda x: SO3(x).log())(batch['qn']),axis=-1)\n",
    "hist(angles,64, density=True, range=[0,pi], label='qn',alpha=0.6);\n",
    "\n",
    "\n",
    "qtest = dist.sample(10000, seed=next(rng_seq))\n",
    "angles = jnp.linalg.norm(jax.vmap(lambda x: SO3(x).log())(qtest),axis=-1)\n",
    "hist(angles,64, density=True, range=[0,pi], label='qanalytic',alpha=0.6);\n",
    "\n",
    "legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e275889f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def forward(x):\n",
    "    net = hk.nets.MLP([512, 256, 128], activation=jax.nn.leaky_relu)(x)\n",
    "    net = hk.Linear(3)(net)\n",
    "    return net\n",
    "\n",
    "model = hk.without_apply_rng(hk.transform(forward))\n",
    "params = model.init(jax.random.PRNGKey(0), jnp.zeros([1,4]))\n",
    "\n",
    "losses = []\n",
    "totsteps=0\n",
    "from flax.metrics import tensorboard\n",
    "summary_writer = tensorboard.SummaryWriter('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a13f7315",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optax.adam(learning_rate=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "42a83bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_state = optimizer.init(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7b3feb40",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_fn(params, rng_key, batch):\n",
    "    score_pred = model.apply(params, batch['qn'])\n",
    "    \n",
    "    loss = (jnp.linalg.norm(score_pred - batch['sn'], axis=-1))**2\n",
    "    \n",
    "    return jnp.mean(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "dbd4c84d",
   "metadata": {},
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def update(params, rng_key, opt_state, batch):\n",
    "    loss, grads = jax.value_and_grad(loss_fn)(params, rng_key, batch)\n",
    "    updates, new_opt_state = optimizer.update(grads, opt_state)\n",
    "    new_params = optax.apply_updates(params, updates)\n",
    "    return loss, new_params, new_opt_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cc62f525",
   "metadata": {},
   "outputs": [],
   "source": [
    "for step in range(20000):\n",
    "    batch = get_batch(jax.random.split(next(rng_seq),1024))\n",
    "    if isnan(batch['sn'].mean()):\n",
    "        continue\n",
    "    loss, params, opt_state = update(params, next(rng_seq), opt_state, batch)\n",
    "    totsteps+=1\n",
    "    losses.append(loss)\n",
    "    if isnan(loss):\n",
    "        break\n",
    "    if step % 5 ==0:\n",
    "        summary_writer.scalar('train_loss', loss, totsteps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "dfa50019",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = get_batch(jax.random.split(next(rng_seq),1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f4561e66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f327438e170>]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXEAAAD6CAYAAABXh3cLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAYB0lEQVR4nO3de3xcdZnH8e9D01JaCm1pgNICabFFK3JrKGC5I4gswroigoKsgKy7gigqlkVXVtd9qbgssqxCBVwQBOWmCHIphXK/pVBoS6+UQgNtM703SXObefaPOTPNSU+aZmbSmZ9+3q9XXzM5c2bOk18m3/7ynHPmmLsLABCmHcpdAACgcIQ4AASMEAeAgBHiABAwQhwAAkaIA0DAegxxM7vVzBrMbE6nZdeY2Xwze9PMHjCzoX1aJQAgkfV0nLiZHSOpUdLt7n5AtOxkSU+6e4eZ/VSS3P27PW1sxIgRXlNTU3TRAPC3ZObMmavcvTrpsaqenuzuz5hZTZdlj3f68iVJZ25LITU1Naqrq9uWVQEAETN7t7vHStETv0DSIyV4HQBALxUV4mZ2laQOSXduZZ2LzazOzOpSqVQxmwMAdFFwiJvZ+ZJOk/RF30pj3d2nunutu9dWVye2dAAABeqxJ57EzE6R9F1Jx7p7c2lLAgBsq205xPAuSS9K2t/M6s3sQkk3SBoiaZqZzTKzG/u4TgBAgm05OuWchMW39EEtAIBe4oxNAAhYECE+fd5K/XLG4nKXAQAVJ4gQn7EgpV8/s6TcZQBAxQkixM3KXQEAVKYgQlySuBIoAGwpiBBnIg4AyYIIcUnq4cMWAeBvUhAhbmbq6SNzAeBvURAhDgBIFkyIMw8HgC0FEeJmIsUBIEEYIc7xKQCQKIgQl5iIA0CSIELcTBydAgAJwgjxchcAABUqiBCXaKcAQJIgQjzbTil3FQBQeQIJcRoqAJAkiBCXJKehAgBbCCLETbRTACBJECHO4SkAkCyMEBdHpwBAkiBC3GSkOAAkCCPEaacAQKIgQlzi6BQASBJEiHN0CgAkCyPEaacAQKIgQlxivyYAJAkixE1cKBkAkoQR4rRTACBRjyFuZreaWYOZzem0bLiZTTOzRdHtsL4tk3YKACTZlpn4/0k6pcuyKZKmu/s4SdOjr/sMR6cAQLIeQ9zdn5G0psviMyTdFt2/TdLfl7asLuinAECiQnvie7j7ckmKbncvXUkAgG3V5zs2zexiM6szs7pUKlXYa0S3HKECAHGFhvhKMxspSdFtQ3cruvtUd69199rq6uqCNkY3BQCSFRriD0o6P7p/vqQ/laacrWMiDgBx23KI4V2SXpS0v5nVm9mFkn4i6SQzWyTppOjrPmNRQ4UMB4C4qp5WcPdzunnoxBLX0i3aKQCQLIgzNnPYsQkAcUGEeP7olLJWAQCVJ4wQp50CAImCCPEcuikAEBdEiJvljk4hxQGgsyBCHACQLKgQp50CAHFBhDg7NgEgWRghLlIcAJIEEeI5tFMAIC6IEM+1Uzg6BQDiwgjxchcAABUqiBDPoZ0CAHFBhPjmdgoAoLMwQpyGCgAkCiLEc/goWgCICyLEaacAQLIgQhwAkCyoEKebAgBxQYS40U8BgERhhHi5CwCAChVEiL+/bpMkKUM/BQBiggjxW557R5I054P1Za4EACpLECGew0QcAOKCCnEAQBwhDgABCyrE6aYAQFwQIf7ZQ0dLkmp2G1TmSgCgsgQR4ofVDJMk9e8XRLkAsN0EkYpc7R4AkhUV4mb2TTOba2ZzzOwuMxtYqsKS0BMHgLiCQ9zMRkn6uqRadz9AUj9JZ5eqsNi2OPEeABIV206pkrSTmVVJGiTpg+JL6h4XhQCAuIJD3N3fl/RzSe9JWi5pvbs/XqrCYpiIA0CiYtopwySdIWmMpL0kDTazcxPWu9jM6sysLpVKFV4pAGALxbRTPiHpHXdPuXu7pPslfbzrSu4+1d1r3b22urq6iM3x2SkA0FUxIf6epCPMbJBlr9pwoqR5pSkrjm4KACQrpif+sqR7Jb0maXb0WlNLVBcAYBtUFfNkd/+BpB+UqJZuGWf7AECiIM7YzKEnDgBxQYQ483AASBZEiOc4J94DQEwQIZ6mjwIAiYII8VujCyUvXNlY5koAoLIEEeJfPHwfSdLY6sFlrgQAKksQIb7LTv3LXQIAVKQgQjyH1jgAxAUR4pzsAwDJggjxzZiKA0BnQYQ483AASBZEiOfQEweAuCBCPNcSJ8MBIC6MEKehAgCJggjxHNopABAXRIhzhCEAJAsixHP4FEMAiAsixHMTcdopABAXRojTTgGAREGEeA4zcQCICyTEmYoDQJJAQjyLHZsAEBdEiOfP2CTDASAmjBAvdwEAUKGCCHEAQLIgQpyLQgBAsiBCPIeeOADEBRHi+TM2OToFAGLCCHG6KQCQKIgQz6GdAgBxRYW4mQ01s3vNbL6ZzTOzI0tVWHw7ffGqABC+qiKf/wtJj7r7mWY2QNKgEtTULSbiABBXcIib2S6SjpH0j5Lk7m2S2kpTVpdtRbs2nX4KAMQU004ZKykl6Tdm9rqZ3Wxmg0tUVxztFABIVEyIV0k6VNKv3P0QSU2SpnRdycwuNrM6M6tLpVJFbI52CgB0VUyI10uqd/eXo6/vVTbUY9x9qrvXunttdXV1QRtiIg4AyQoOcXdfIWmZme0fLTpR0lslqarbbfblqwNAeIo9OuVSSXdGR6YskfTl4kva0ubPTiHFAaCzokLc3WdJqi1NKd2jnQIAyThjEwACFkSIc8YmACQLIsRzmIgDQFwQIb75jM0yFwIAFSaMEKedAgCJggjxHD47BQDigghxJuIAkCyIEM9hHg4AcWGEeDQVp5sCAHFBhLjRUAGAREGEeA5XuweAuCBCnEMMASBZECGex0QcAGKCCHE+iBYAkoUR4vRTACBRECGewyGGABAXRIgzEQeAZEGEeA6HGAJAXBAhnt+xSYYDQEwYIU47BQASBRHiOUzEASAukBBnKg4ASQIJ8axMhrk4AHQWRIjPWNAgSbr1+XfKXAkAVJYgQvyDdS2SpCWppjJXAgCVJYgQ3yF/UQjaKQDQWRAhnjvEkJY4AMQFEeK1+w6XJB3/4d3LXAkAVJYgQnzCXrtIko7bv7rMlQBAZQkixI0LJQNAoqJD3Mz6mdnrZvZQKQpK3AaXhQCARKWYiV8maV4JXqdbzMQBIFlRIW5moyX9naSbS1NOd9vJ3pLhABBX7Ez8OklXSMoUX0r3cu0UZuIAEFdwiJvZaZIa3H1mD+tdbGZ1ZlaXSqUK3FZBTwOAv3rFzMQnSzrdzJZKulvSCWZ2R9eV3H2qu9e6e211dXGHCHJlHwCIKzjE3f1Kdx/t7jWSzpb0pLufW7LKOuHKPgCQLKzjxMtbBgBUnKpSvIi7z5A0oxSvlSy3Y5MYB4DOgpqJAwDiwgjx6JaJOADEhRHiTMUBIFEQIZ7DIYYAEBdEiNNOAYBkYYQ4H4AFAInCCPHcIYZlrgMAKk0YIc6FkgEgURAhnkOEA0BcECHOEYYAkCyIEM9jKg4AMUGEeO5kH44TB4C4MEI8umW/JgDEhRHifBQtACQKI8S5xiYAJAojxPMzcVIcADoLI8TLXQAAVKggQjyHdgoAxIUR4uzYBIBEQYS4iY8xBIAkYYQ4M3EASBRGiEe3TMQBIC6MEM+ddk+KA0BMGCFe7gIAoEIFEeI5zMMBIC6IEOcamwCQLIwQ5xqbAJAoiBAX19gEgERBhDiXZwOAZGGEeHTLRBwA4oII8R24PBsAJCo4xM1sbzN7yszmmdlcM7uslIV1lgvxDBkOADFVRTy3Q9K33P01MxsiaaaZTXP3t0pUW16uJ56hnwIAMQXPxN19ubu/Ft3fKGmepFGlKqwzjhMHgGQl6YmbWY2kQyS9nPDYxWZWZ2Z1qVSqoNfPt1PopwBATNEhbmY7S7pP0jfcfUPXx919qrvXunttdXV1QdvoF4V4mqk4AMQUFeJm1l/ZAL/T3e8vTUlJ28neMhHvmbtrTVNbucsAtsmjc1aopT1d7jKCVszRKSbpFknz3P3a0pWUuC1J0vXTF+mp+Q1atqZZ982s1+rGVt03s17vrm5SzZSHdfkfZun+1+qVzrjSGdcLi1dp2ZpmdaQzuvHpt/XC4lWavyL7x8KyNc2xN8+5N7+sc6a+pKcXprRgxUbd/cp7qv2PJ3TWTS/q/XWbYvXcN7Ne1z6+IH8GaUc6o8bWDjVsbNEH6zaptSOt5rYONbd1aH1zu+Yt36DUxlZtastuz92VybjaOjKSpPq1zVrd2CpJak9nNGNBgy6963VlMq7mtg61dqT1p1nv67cvLtVPHpmvZxellMm47qlbpleXrtH/PrVYG1va5e76wYNzdeiPpum91c352tIZV8OGFt3+4lItWrlRf3z9fU17a+U2jX1Le1qH/fgJ3VO3rNt12joyenL+Sq3f1K7FDY1avn6T1m9qV/3a5vzYPbsopXdWNemnj85XzZSHdfOzS9TY2qF1zW1auqpJC1du1M8ena/UxlalMy5316xl69TWkdEdL72rdMZ1/fRFWrhyoySpua1Dv5yxWG8sWxerpT2d0d2vvKem1g6t39Su5xev0qrGVq3f1K4n3lqpb/5+lv7zL/M04d8e1fOLV0mSUhtbddsLS/M/g01taTW2dmhNU5vufPld3f9avVIbW7W6sVVvfbBBLe1pLVvTrGsfX6Bv3/OGjr3mKdVMeVg1Ux7WdU8slLvr5SWrlYneh++v25T/WbanM1q4cmN+/RXrW/K1N2xo0aqoBnfXqsZW3f7iUqUzrhueXKRDfvi4Hni9Xt/742wdd81Tuv3FpcpEP9sNLe1a29SmK++frasemK3FDRu1oaU9NjZv1q/TmCsf1i3PvaP2dCb/+/POqqb8OotWbtQZNzynm55+Wy+8vUr/8Mvn9d7qZrV1ZPTI7OWaXb9e76/bpG/94Q2tbWrLv4dz7nrlPX3p1le0JNWo+2bW6+E3l2ttU5tueHKR1je3q2bKw7r6wblqaU/rkdnL9dU7Zmrij6apYcPmcbj52SWatWydMhnPv09yP/P/mb5IHemMOtKZ/O/vY3NX6OifPakr7n1DC1Zk3x+vvbdWS1KNmj5vpb78m1f07uombWpL67G5K5TJuGYsaNCrS9do6jNva01Tm55a0KAHXq/XH15dptaO7OtmMp4fm8UNjfnfmUUrNyqdcbWn49/7+k3t+sj3H9Vx1zyllvZsBmxoaVdbR0aPz13R7e9PsazQU9nN7ChJz0qaLSn33fyru/+lu+fU1tZ6XV1dQdurmfJwQc8DgEpw50WHa/KHRhT0XDOb6e61SY8VfIihuz8nPuobALZJ7q+sUgvijE0ACF1zW9/0/glxANgOnllY2CHWPQkmxG86b2K5SwCAgvXVp7EGE+Kf/Oie+vMlR+krR4/JL6vaYfOo/PNx++Xvf/bQ0bru8wfr8pPG64BRu+iCyWO0/x5DJEkHjt5VnzpgT31u4ujY63/h8H309HeO0w/P+Kj++LXJ+eX3fPVIXXrCh3Tqx/bML/vYqF3z92t2G5S/P2roTvrxZw7QGQfvpaPHjdCzVxyvF6acoJnf+4RGD9sptr1zJu0jSardd5guO3Fc/gd85NjdNPW8ifrC4fvoq8ful/++xlYP1iH7DO12fHYbPEC3XTBJktRvB9PNX6qNPTZ0UH9Nqhmu847YV5K0xy47SpLOqh2t//78QRqy4+bdI18/cVzstYcPHpAfrw/tvrMOqxmmY8ZX65nvHK9/OnasjhlfrevPOSQ/BqcdODKxxgFVm99uOyS8oa/+9IRYbTljRwzW1Z+eoIcuPSq2fOig/qoesqPu+soROnrcCH1k5C755RdMHqNRQ+NjXrPbIJ0zaW8dPmZ4ftnJE/bQpOjrSTXD8z+n7582QR/fbzcdus9QnVW7+b1y4VFjNCHajiRdftL4xO9Vkk4/aC9N3HeYHrxk8haP7TM8+74ZMjA77j//3EEaPniAvn7iOF1z5oGJr3fjuRP1uYmjddmJ43TCh3ff4vELJo/RnrsMlCTtG70vP3PIKF1z5oE67cCRuusrR0iSzj1iH40aupP+fMlR2q96sI7fP37+xtHjNu98O2Z8tU6esEf+6xM+vLt+d9HhOu3AkRo0oF9ind0tn1QzXDvvWJWvratjxsfruOm8iXr2iuM1qWbzz+uio8Z0fZokaa9dB+rsw/bWL84+WAftPXSLx5PGq6sbvnCIfnfR4frOJ/fvcV1JGr/Hzvn7wwcP6HH9a886eJtet7cKPjqlEMUcndKdN5at09BB/bXvboMLev6mtrQGVO2gfl1SZW1Tm2bVr9Px+/f8wy+nDS3t2mVg/8THGls71L+faceq5F+qrWnY0KKmtrTGjChsXHOWrWnWroP6d1tjdzrSGbWlMxo0IL7vPfd+tSKnNasaW/XSktU67cC9Cnr+tLdWar/qwRpbvXO+rua2tAbvmHyswNqmNj27eJVOP6iw7W2r3GGbo4clB2V3MhlXR8Zj/9F2NmNBgw6rGb7F97e2qU0bWto1auhOquq3+bmNrR2q2sE0sH8/tbSn1Z7OaEjCe2BTW1rXPbFQ3zxpvAb275evZXGqUeOjidfWajbb+nuhYUOLVje15f+DL1Qm41rV1Krdhwzsdp31ze2q6mf5MWpPZ7S2uW2rz9lWWzs6JfgQB4C/dlsL8WDaKQCALRHiABAwQhwAAkaIA0DACHEACBghDgABI8QBIGCEOAAEbLue7GNmKUnvFvj0EZJWlbCcUqGu3qGu3qGu3qvU2oqpa193T7y+5XYN8WKYWV13ZyyVE3X1DnX1DnX1XqXW1ld10U4BgIAR4gAQsJBCfGq5C+gGdfUOdfUOdfVepdbWJ3UF0xMHAGwppJk4AKCLIELczE4xswVmttjMpmyH7S01s9lmNsvM6qJlw81smpktim6HdVr/yqi2BWb2yU7LJ0avs9jMrrdeXsnAzG41swYzm9NpWcnqMLMdzez30fKXzaymyNquNrP3o3GbZWanbs/azGxvM3vKzOaZ2Vwzu6wSxmwrdZV7vAaa2Stm9kZU179XyHh1V1dZx6vTa/Yzs9fN7KFKGC+5e0X/k9RP0tuSxkoaIOkNSRP6eJtLJY3osuxnkqZE96dI+ml0f0JU046SxkS19osee0XSkZJM0iOSPtXLOo6RdKikOX1Rh6R/kXRjdP9sSb8vsrarJX07Yd3tUpukkZIOje4PkbQw2nZZx2wrdZV7vEzSztH9/pJelnREBYxXd3WVdbw6be9ySb+T9FAl/E6WLZx7MWBHSnqs09dXSrqyj7e5VFuG+AJJI6P7IyUtSKpH0mNRzSMlze+0/BxJNxVQS43iQVmyOnLrRPerlD0RwYqorbtfsu1eW/S8P0k6qZLGrEtdFTNekgZJek3S4ZU0Xl3qKvt4SRotabqkE7Q5xMs6XiG0U0ZJWtbp6/poWV9ySY+b2Uwzuzhatoe7L5ek6DZ38c3u6hsV3e+6vFilrCP/HHfvkLRe0m5F1neJmb1p2XZL7s/K7V5b9GfoIcrO4ipmzLrUJZV5vKLWwCxJDZKmuXtFjFc3dUnlf39dJ+kKSZlOy8o6XiGEeFIfua8PqZns7odK+pSkr5nZMVtZt7v6tnfdhdRR6hp/JWk/SQdLWi7pv8pRm5ntLOk+Sd9w9w1bW7XMdZV9vNw97e4HKzvDnGRmB2ztWyhzXWUdLzM7TVKDu8/sqf7tWVcIIV4vae9OX4+W9EFfbtDdP4huGyQ9IGmSpJVmNlKSotuGHuqrj+6Xuu5S1pF/jplVSdpV0ppCC3P3ldEvX0bSr5Udt+1am5n1VzYo73T3+6PFZR+zpLoqYbxy3H2dpBmSTlEFjFdSXRUwXpMlnW5mSyXdLekEM7tDZR6vEEL8VUnjzGyMmQ1Qttn/YF9tzMwGm9mQ3H1JJ0uaE23z/Gi185Xtaypafna0V3mMpHGSXon+rNpoZkdEe56/1Ok5xShlHZ1f60xJT3rUjCtE7o0c+Yyy47bdaote4xZJ89z92k4PlXXMuqurAsar2syGRvd3kvQJSfNV/vFKrKvc4+XuV7r7aHevUTaHnnT3c8s9Xr3ayVauf5JOVXaP/tuSrurjbY1Vdo/yG5Lm5ranbF9quqRF0e3wTs+5KqptgTodgSKpVtk32tuSblDvd4Ddpeyfje3K/g99YSnrkDRQ0j2SFiu7t3xskbX9VtJsSW9Gb8aR27M2SUcp+6fnm5JmRf9OLfeYbaWuco/XgZJej7Y/R9K/lfq9XuK6yjpeXWo8Tpt3bJZ1vDhjEwACFkI7BQDQDUIcAAJGiANAwAhxAAgYIQ4AASPEASBghDgABIwQB4CA/T9gcl+Fx0akggAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot(losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "860bd472",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAOK0lEQVR4nO3da4xcd3nH8e9Tm0AuRHGUtWty6SaSFRqQ2kRbCKRCqKZSFEc4LxrkSkFb5MhC4hJoEVraF6EvIlkqqtJKLZKVQJeSAlaIsJWohWgLaiu1JuskJSQOzc11TLbx0hKCeJELffpiTqrNeh3PzJnZmX32+5FGM+c28zye9W//e86ZM5GZSJJq+ZVRFyBJGjzDXZIKMtwlqSDDXZIKMtwlqaCNoy4A4IILLsjJyclRlyFJa8rhw4d/kpkTKy0bi3CfnJxkfn5+1GVI0poSEf95qmXulpGkggx3SSrIcJekggx3SSrIcJekggx3SSrIcJekggx3SSrIcJekgsbiE6rrxeTMfSfNO7p3xwgqkVSdI3dJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCvPxAH7yMgKRx58hdkgoy3CWpIHfLDIi7aiSNE0fuklSQ4S5JBRnuklSQ4S5JBRnuklTQac+WiYgvAdcDJzLznc2884FvAJPAUeBDmfnTZtnngN3AL4FPZua3h1L5GrDSGTSStBq6Gbn/DXDtsnkzwFxmbgPmmmki4gpgF/COZpu/jogNA6tWktSV04Z7Zv4T8D/LZu8EZpvHs8ANS+Z/PTNfysxngCeBdw2mVElSt/rd574lMxcAmvvNzfwLgWeXrHe8mSdJWkWDPqAaK8zLFVeM2BMR8xExv7i4OOAyJGl96zfcn4+IrQDN/Ylm/nHg4iXrXQQ8t9ITZOa+zJzKzKmJiYk+y5AkraTfcD8ITDePp4EDS+bviog3R8SlwDbg++1KlCT1qptTIb8GvB+4ICKOA7cCe4H9EbEbOAbcCJCZj0bEfuAx4FXgY5n5yyHVLkk6hdOGe2b+/ikWbT/F+rcBt7UpSpLUjp9QlaSCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKshwl6SCDHdJKui0X7On4Zqcue+keUf37hhBJZIqceQuSQUZ7pJUkOEuSQUZ7pJUkOEuSQUZ7pJUkOEuSQUZ7pJUkOEuSQW1+oRqRHwauBlI4BHgI8BZwDeASeAo8KHM/GmrKkdopU+QStK463vkHhEXAp8EpjLzncAGYBcwA8xl5jZgrpmWJK2itrtlNgJnRsRGOiP254CdwGyzfBa4oeVrSJJ61He4Z+aPgS8Ax4AF4GeZ+R1gS2YuNOssAJtX2j4i9kTEfETMLy4u9luGJGkFbXbLbKIzSr8UeBtwdkTc1O32mbkvM6cyc2piYqLfMiRJK2izW+YDwDOZuZiZrwD3AO8Fno+IrQDN/Yn2ZUqSetEm3I8BV0fEWRERwHbgCHAQmG7WmQYOtCtRktSrvk+FzMxDEXE38CDwKvAQsA84B9gfEbvp/AK4cRCFSpK61+o898y8Fbh12eyX6IziJUkj4idUJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCjLcJakgw12SCmr1ZR0ajsmZ+06ad3TvjhFUImmtcuQuSQUZ7pJUkOEuSQUZ7pJUkOEuSQUZ7pJUkOEuSQUZ7pJUkOEuSQUZ7pJUkOEuSQW1CveIOC8i7o6IxyPiSES8JyLOj4j7I+KJ5n7ToIqVJHWn7cj9L4B/yMy3A78BHAFmgLnM3AbMNdOSpFXUd7hHxLnA+4A7ATLz5cx8AdgJzDarzQI3tCtRktSrNiP3y4BF4MsR8VBE3BERZwNbMnMBoLnfvNLGEbEnIuYjYn5xcbFFGZKk5dqE+0bgKuCLmXkl8At62AWTmfsycyozpyYmJlqUIUlark24HweOZ+ahZvpuOmH/fERsBWjuT7QrUZLUq77DPTP/C3g2Ii5vZm0HHgMOAtPNvGngQKsKJUk9a/s1e58A7oqIM4CngY/Q+YWxPyJ2A8eAG1u+hiSpR63CPTMfBqZWWLS9zfNKktrxE6qSVFDb3TLlTM7cN+oSJKk1R+6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFGe6SVJDhLkkFefmBNWL5ZRGO7t0xokokrQWO3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgoy3CWpIMNdkgpa15cfWP6RfkmqwpG7JBVkuEtSQYa7JBXUOtwjYkNEPBQR9zbT50fE/RHxRHO/qX2ZkqReDGLkfgtwZMn0DDCXmduAuWZakrSKWoV7RFwE7ADuWDJ7JzDbPJ4FbmjzGpKk3rU9FfJ24LPAW5fM25KZCwCZuRARm1faMCL2AHsALrnkkpZlrD8rncbptzNJek3fI/eIuB44kZmH+9k+M/dl5lRmTk1MTPRbhiRpBW1G7tcAH4yI64C3AOdGxFeB5yNiazNq3wqcGEShkqTu9T1yz8zPZeZFmTkJ7AL+MTNvAg4C081q08CB1lVKknoyjMsP7AX2R8Ru4Bhw4xBeo2deakDSejKQcM/M7wHfax7/N7B9EM8rSeqPn1CVpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqaCBfkK3xNTlz30nzju7dMYJKJK0mR+6SVJDhLkkFGe6SVJD73AtZaf+6pPXJkbskFWS4S1JBhrskFdR3uEfExRHx3Yg4EhGPRsQtzfzzI+L+iHiiud80uHIlSd1oM3J/FfijzPx14GrgYxFxBTADzGXmNmCumZYkraK+wz0zFzLzwebxz4EjwIXATmC2WW0WuKFljZKkHg1kn3tETAJXAoeALZm5AJ1fAMDmU2yzJyLmI2J+cXFxEGVIkhqtwz0izgG+CXwqM1/sdrvM3JeZU5k5NTEx0bYMSdISrcI9It5EJ9jvysx7mtnPR8TWZvlW4ES7EiVJvWpztkwAdwJHMvPPlyw6CEw3j6eBA/2XJ0nqR5vLD1wDfBh4JCIebub9MbAX2B8Ru4FjwI2tKpQk9azvcM/MfwHiFIu39/u8kqT2/ISqJBXkVSHXoeVXj/SbmaR6HLlLUkGO3OX3rEoFOXKXpIIMd0kqyHCXpIIMd0kqyHCXpIJKni2z0tkfkrSeOHKXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqyHCXpIIMd0kqqMTlB7zcgCS9niN3SSqoxMhdo9PNl237NX7S6jPctSIDWVrb3C0jSQU5clfXPHAtrR2O3CWpIEfuGqg2o/tuDs5K6s7QRu4RcW1E/CginoyImWG9jiTpZEMZuUfEBuCvgN8FjgMPRMTBzHxsGK+nmro9Y6eb9br9i6Kb5+/2L4phn3HU7/OP4kyo1f6rzLO9hjdyfxfwZGY+nZkvA18Hdg7ptSRJy0RmDv5JI34PuDYzb26mPwy8OzM/vmSdPcCeZvJy4EctXvIC4Ccttl9r1lu/YM/rhT335tcyc2KlBcM6oBorzHvdb5HM3AfsG8iLRcxn5tQgnmstWG/9gj2vF/Y8OMPaLXMcuHjJ9EXAc0N6LUnSMsMK9weAbRFxaUScAewCDg7ptSRJywxlt0xmvhoRHwe+DWwAvpSZjw7jtRoD2b2zhqy3fsGe1wt7HpChHFCVJI2Wlx+QpIIMd0kqaKzD/XSXMIiOv2yW/yAirup223HVb88RcXFEfDcijkTEoxFxy+pX358273OzfENEPBQR965e1e20/Nk+LyLujojHm/f7Patbfe9a9vvp5mf6hxHxtYh4y+pW358uen57RPxrRLwUEZ/pZduuZOZY3ugciH0KuAw4A/h34Ipl61wH/D2d8+qvBg51u+043lr2vBW4qnn8VuA/qve8ZPkfAn8H3DvqflajZ2AWuLl5fAZw3qh7Gla/wIXAM8CZzfR+4A9G3dOAet4M/BZwG/CZXrbt5jbOI/duLmGwE/hKdvwbcF5EbO1y23HUd8+ZuZCZDwJk5s+BI3T+Y4y7Nu8zEXERsAO4YzWLbqnvniPiXOB9wJ0AmflyZr6wirX3o9V7TOesvjMjYiNwFmvjMzOn7TkzT2TmA8ArvW7bjXEO9wuBZ5dMH+fksDrVOt1sO47a9Pz/ImISuBI4NPgSB65tz7cDnwX+d0j1DUObni8DFoEvN7ui7oiIs4dZ7AD03W9m/hj4AnAMWAB+lpnfGWKtg9ImgwaSX+Mc7qe9hMEbrNPNtuOoTc+dhRHnAN8EPpWZLw6wtmHpu+eIuB44kZmHB1/WULV5nzcCVwFfzMwrgV8A435Mqc17vInOqPVS4G3A2RFx04DrG4Y2GTSQ/BrncO/mEganWmetXv6gTc9ExJvoBPtdmXnPEOscpDY9XwN8MCKO0vnT9Xci4qvDK3Vg2v5sH8/M1/4qu5tO2I+zNv1+AHgmMxcz8xXgHuC9Q6x1UNpk0GDya9QHHt7ggMRG4Gk6v7FfO6jwjmXr7OD1B2G+3+2243hr2XMAXwFuH3Ufq9XzsnXez9o5oNqqZ+Cfgcubx58H/mzUPQ2rX+DdwKN09rUHnYPJnxh1T4Poecm6n+f1B1QHkl8j/0c4zT/QdXTO+ngK+JNm3keBjzaPg86XgjwFPAJMvdG2a+HWb8/Ab9P50+0HwMPN7bpR9zPs93nJc6yZcG/bM/CbwHzzXn8L2DTqfobc758CjwM/BP4WePOo+xlQz79KZ5T+IvBC8/jcU23b683LD0hSQeO8z12S1CfDXZIKMtwlqSDDXZIKMtwlqSDDXZIKMtwlqaD/A3/CDzaQPkv/AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# For sigma = 0.1\n",
    "hist(jnp.linalg.norm(model.apply(params, batch['qn']) - batch['st'], axis=1)/jnp.linalg.norm(batch['st'], axis=1),64,range=[0,0.1]);\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "aa133c47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-1.0, 1.0)"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAD8CAYAAABgmUMCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZFklEQVR4nO3df7BkZZ3f8fdn7tzRAakd0AGHARStKVgsAuotwGWTlVUUKHXAbDYQ45LsVk2RSNVqGStjmXJJrVWyUq6V3WKlRpdaTFRiSsCJogMSU2Td4DIgv0aYZWRR5kdgREFdJjID3/zR506aO9339rnd9xe8X1Vd95zzPKf7uU/fez59fj2dqkKSpEEtW+gGSJKWFoNDktSKwSFJasXgkCS1YnBIkloxOCRJrYwkOJJcm+SJJA/0KU+SP0uyI8l9Sd7UVXZeku1N2cZRtEeSNHdGtcfxV8B505SfD6xrHhuAzwIkGQOubspPAS5JcsqI2iRJmgMjCY6quh346TRV1gNfqI47gFVJ1gBnADuq6pGqeha4vqkrSVqkls/T66wFHuua39ks67X8zF5PkGQDnb0VDj/88DeffPLJc9NSSXqRuuuuu35SVauHfZ75Co70WFbTLD90YdUmYBPAxMREbd26dXStk6SXgCQ/GsXzzFdw7ASO75o/DtgNrOizXJK0SM3X5bibgd9rrq46C3i6qvYAdwLrkpyYZAVwcVNXkrRIjWSPI8mXgbcCr0qyE/gjYBygqq4BbgYuAHYAzwD/uik7kORyYAswBlxbVdtG0SZJ0twYSXBU1SUzlBfwgT5lN9MJFknSEuCd45KkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrIwmOJOcl2Z5kR5KNPco/kuSe5vFAkueSHNWUPZrk/qZs6yjaI0maO0N/53iSMeBq4FxgJ3Bnks1V9YPJOlV1FXBVU//dwIeq6qddT3NOVf1k2LZIkubeKPY4zgB2VNUjVfUscD2wfpr6lwBfHsHrSpIWwCiCYy3wWNf8zmbZIZIcBpwHfLVrcQG3JLkryYYRtEeSNIeGPlQFpMey6lP33cB3pxymOruqdic5Grg1yUNVdfshL9IJlQ0AJ5xwwrBtliTN0ij2OHYCx3fNHwfs7lP3YqYcpqqq3c3PJ4Ab6Rz6OkRVbaqqiaqaWL169dCNliTNziiC405gXZITk6ygEw6bp1ZK8mvAbwFf61p2eJIjJqeBdwAPjKBNkqQ5MvShqqo6kORyYAswBlxbVduSXNaUX9NUvQi4par+oWv1Y4Abk0y25UtV9a1h2yRJmjup6nc6YvGamJiorVu95UOS2khyV1VNDPs83jkuSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaGUlwJDkvyfYkO5Js7FH+1iRPJ7mneXx80HUlSYvL8mGfIMkYcDVwLrATuDPJ5qr6wZSq/6uq3jXLdSVJi8Qo9jjOAHZU1SNV9SxwPbB+HtaVJC2AUQTHWuCxrvmdzbKp3pLk3iTfTPKGluuSZEOSrUm27t27dwTNliTNxiiCIz2W1ZT5u4HXVNVpwJ8DN7VYt7OwalNVTVTVxOrVq2fbVknSkEYRHDuB47vmjwN2d1eoqp9X1S+b6ZuB8SSvGmRdSdLiMorguBNYl+TEJCuAi4HN3RWSvDpJmukzmtd9cpB1JUmLy9BXVVXVgSSXA1uAMeDaqtqW5LKm/Brgd4B/k+QAsA+4uKoK6LnusG2SJM2ddLbfS8vExERt3bp1oZshSUtKkruqamLY5/HOcUlSKwaHJKkVg0OS1IrBIUlqxeCQJLVicEiSWjE4JEmtGBySpFYMDklSKwaHJKkVg0OS1IrBIUlqxeCQJLVicEiSWjE4JEmtGBySpFYMDklSKwaHJKmVkQRHkvOSbE+yI8nGHuXvS3Jf8/ibJKd1lT2a5P4k9yTx+2AlaZFbPuwTJBkDrgbOBXYCdybZXFU/6Kr298BvVdXPkpwPbALO7Co/p6p+MmxbJElzbxR7HGcAO6rqkap6FrgeWN9doar+pqp+1szeARw3gteVJC2AUQTHWuCxrvmdzbJ+/gD4Ztd8AbckuSvJhn4rJdmQZGuSrXv37h2qwZKk2Rv6UBWQHsuqZ8XkHDrB8Ztdi8+uqt1JjgZuTfJQVd1+yBNWbaJziIuJiYmezy9Jmnuj2OPYCRzfNX8csHtqpST/CPg8sL6qnpxcXlW7m59PADfSOfQlSVqkRhEcdwLrkpyYZAVwMbC5u0KSE4AbgPdX1d91LT88yRGT08A7gAdG0CZJ0hwZ+lBVVR1IcjmwBRgDrq2qbUkua8qvAT4OvBL4iyQAB6pqAjgGuLFZthz4UlV9a9g2SZLmTqqW3umCiYmJ2rrVWz4kqY0kdzUf2ofineOSpFYMDklSKwaHJKkVg0OS1IrBIUlqxeCQJLVicEiSWjE4JEmtGBySpFYMDklSKwaHJKmVUXwfh7Ro3PT9XVy1ZTu7n9rHsatW8pF3nsSFb5zue8VeOuwbjYrBoXn1H266ny9/7zGeq2Is4ZIzj+cTF546kue+6fu7+OgN97Nv/3MA7HpqHx+94X6AgxvIF9PGs83vMl3fAFy1ZTu7ntrHWMJzVaxd4n2jueXouJpzkxu4XU/t61tn7aqVnHPyar7z0N5WG/XujeeyZqM31aqV4xz+suXsemof4YVfT7lyfIxPvvfUJbeBfN/n/jff/eFPX7Cs+3fpDujJr+js9Z8+vgz2Pz/z6019f35t5TgJPPXM/lYB3P1+rTpsnCp4at9+A2uejGp0XINDc2rqJ902Jjfy/UIFmPVzd1sWeL75N1i1cpwr3vOGVhvB7kBMYOXyZTyz//meG8PuDfpUYwlnve5IHn1y3ws+/XeH3ZGHjXPKmiMOCQ0GLJ9rq6YJlEH/FsbHwlW/c9qswuPFtEc5FwwOg2NJOPvK/zHtnsZsrRwfIxTPDPJxeRaOOWIFP/nl/hds4AOsWL6MXx2Ym9d8MVoWqOq9tzOdIw8b5/sffwcweBj0CqaV42P80zevbb0n+2I1quDwHIdaGeSfuLvOXH0sGXYvYyaP/+LZQ5YVGBotPT/LP4CfPbMfGOy81aSrtmw/5O9i3/7n+OIdPz74dzjd+hqcwfESM8yu/KAnn0dx+Ejqd1hv3/7nuGrL9kP+bnf32bOdml391tfgRhIcSc4D/hOd7xz/fFVdOaU8TfkFwDPAv6qquwdZV6PT5tNbr3U//JV7e/4TX7F524wnqKW2uvcUpuoVEseuWjnwYdF+IaPBDH0DYJIx4GrgfOAU4JIkp0ypdj6wrnlsAD7bYl2NSL9d+Q9/5V5u+v6uvutNBk6/QHhq3352NYelDA2NynR/SceuWnnIso+88yRWjo+9YFkOqdV/fQ1uFHscZwA7quoRgCTXA+uBH3TVWQ98oTpn4u9IsirJGuC1A6yrEen3Keu5Kj56w/1s/dFP+fq9e3hqX+f48uTVRmPuRWgRWTk+dvCqum6Te83dh2LPOXk1X71r1yEnzHutr8GNIjjWAo91ze8EzhygztoB1wUgyQY6eyuccMIJw7X4RaTNOYvpduX37X+O/3LHj1+wbPLEpqGh+Tb1fptJY8m0991c+Ma1h5RNvOYoL9EdsVEER6+9wanveb86g6zbWVi1CdgEnctx2zTwxartOYuPvPMkT1xrpFaOj/Gy5csO7qVOZ3xZeMXLl/PUM/tZOb6s76XUkyMK9NpTmM3Nmr3CRMMZxSCHO4Hju+aPA3YPWGeQddVHv3MWV23Z3rP+hW9cyyffeypj6XfkV+ovwNmvP4q1q1YSOjdmfvK9p/L0AKEBsP/54rAVy/nMPz+d6nP2YeX4GJ/+3dP4xIWn8sn3nnrIaxkAi8Mo9jjuBNYlORHYBVwM/IspdTYDlzfnMM4Enq6qPUn2DrCu+uh3zmK6K0Ym//Gm7nn0OzQgTSrg0Sf38d2Nv/2C5TMNJ9Nt91P7en7ggUMPQ7mnsHgNvcdRVQeAy4EtwIPAV6pqW5LLklzWVLsZeATYAXwO+LfTrTtsm14q+l0ZMtMVI5N7Ht2f5t53lueNNLNeH0p6Xc3Uz7GrVvb9YPN8lUGxRIzkPo6quplOOHQvu6ZruoAPDLquBtPrnMWgV4z0+zQ39QS5XpwmzxcAPU8c9xsqpteHkqlXM/W7lyd0/mb77aF4iezS4Z3jS1ivyw+HuWLkExeeevAKlLkYX0qLx9RDQlO1/VDS/UGk1+gBAd531gl9D5V6iezS4iCH6unEjd/wnMeAVq0c59kDz/W9Smhy6PJ+Q4f3G6bl8BVjPHvgefZ3Dfi0cnyMl48vOziW02ysXbXykPMUvQw7PM106zqK7cJwdFyDY07NZlTbIw8b55f/98ALNnRTBfiN1x/Fo0/uO/idDE/v2z/rwfDm0r8864R5uQegX19PBszU14dDP7GPLwsE9j/3/ztyfCxQHBI8Xp300uXouJpTvQ5V9No4TVo5PsYfvfsNAFyxedvB6/oPG1/Gy8bHpv3Cn5u+v4v/+N+3HfwUvWrlOO86bc3BobBfPt4Zyrw7XI7s8yVA55y8mi9978c9g+jwFWP8w7P972E5fMUYzzz73CHtnOuN7HRXx013ZVGvQBlkmaGhYbnHob56HU6Axf81o/2+l6H7ZPBiav90exyDHFKSBuWhKoND01hKx9CnC7rF2mYtTR6qkqaxlG4eG/XVcdJcMzikRWApBZ00irGqJEkvIQaHJKkVg0OS1IrBIUlqxeCQJLVicEiSWjE4JEmtGBySpFYMDklSKwaHJKmVoYIjyVFJbk3ycPPzyB51jk/ynSQPJtmW5A+7yq5IsivJPc3jgmHaI0mae8PucWwEbquqdcBtzfxUB4APV9WvA2cBH0hySlf5Z6rq9Obhd49L0iI3bHCsB65rpq8DLpxaoar2VNXdzfQvgAcBR3OTpCVq2OA4pqr2QCcggKOnq5zktcAbge91Lb48yX1Jru11qKtr3Q1JtibZunfv3iGbLUmarRmDI8m3kzzQ47G+zQsleQXwVeCDVfXzZvFngdcDpwN7gE/3W7+qNlXVRFVNrF69us1LS5JGaMbv46iqt/crS/J4kjVVtSfJGuCJPvXG6YTGF6vqhq7nfryrzueAr7dpvCRp/g17qGozcGkzfSnwtakVkgT4S+DBqvrTKWVrumYvAh4Ysj2SpDk2bHBcCZyb5GHg3GaeJMcmmbxC6mzg/cBv97js9lNJ7k9yH3AO8KEh2yNJmmNDfXVsVT0JvK3H8t3ABc30XwPps/77h3l9SdL8885xSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqZWhgiPJUUluTfJw8/PIPvUebb5b/J4kW9uuL0laPIbd49gI3FZV64Dbmvl+zqmq06tqYpbrS5IWgWGDYz1wXTN9HXDhPK8vSZpnwwbHMVW1B6D5eXSfegXckuSuJBtmsT5JNiTZmmTr3r17h2y2JGm2ls9UIcm3gVf3KPpYi9c5u6p2JzkauDXJQ1V1e4v1qapNwCaAiYmJarOuJGl0ZgyOqnp7v7IkjydZU1V7kqwBnujzHLubn08kuRE4A7gdGGh9SdLiMeyhqs3Apc30pcDXplZIcniSIyangXcADwy6viRpcRk2OK4Ezk3yMHBuM0+SY5Pc3NQ5BvjrJPcCfwt8o6q+Nd36kqTFa8ZDVdOpqieBt/VYvhu4oJl+BDitzfqSpMXLO8clSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVoYKjiRHJbk1ycPNzyN71DkpyT1dj58n+WBTdkWSXV1lFwzTHknS3Bt2j2MjcFtVrQNua+ZfoKq2V9XpVXU68GbgGeDGriqfmSyvqpuHbI8kaY4NGxzrgeua6euAC2eo/zbgh1X1oyFfV5K0QIYNjmOqag9A8/PoGepfDHx5yrLLk9yX5Npeh7okSYvLjMGR5NtJHujxWN/mhZKsAN4D/LeuxZ8FXg+cDuwBPj3N+huSbE2yde/evW1eWpI0QstnqlBVb+9XluTxJGuqak+SNcAT0zzV+cDdVfV413MfnE7yOeDr07RjE7AJYGJiomZqtyRpbgx7qGozcGkzfSnwtWnqXsKUw1RN2Ey6CHhgyPZIkubYsMFxJXBukoeBc5t5khyb5OAVUkkOa8pvmLL+p5Lcn+Q+4BzgQ0O2R5I0x2Y8VDWdqnqSzpVSU5fvBi7omn8GeGWPeu8f5vUlSfPPO8clSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrBockqRWDQ5LUisEhSWrF4JAktWJwSJJaMTgkSa0YHJKkVgwOSVIrQwVHkn+WZFuS55NMTFPvvCTbk+xIsrFr+VFJbk3ycPPzyGHaI0mae8PucTwAvBe4vV+FJGPA1cD5wCnAJUlOaYo3ArdV1TrgtmZekrSIDRUcVfVgVW2fodoZwI6qeqSqngWuB9Y3ZeuB65rp64ALh2mPJGnuLZ+H11gLPNY1vxM4s5k+pqr2AFTVniRH93uSJBuADc3sr5I8MBeNHbFXAT9Z6EYMwHaOzlJoI9jOUVsq7TxpFE8yY3Ak+Tbw6h5FH6uqrw3wGumxrAZY74UrVG0CNjVt2lpVfc+pLBa2c7SWQjuXQhvBdo7aUmrnKJ5nxuCoqrcP+Ro7geO75o8DdjfTjydZ0+xtrAGeGPK1JElzbD4ux70TWJfkxCQrgIuBzU3ZZuDSZvpSYJA9GEnSAhr2ctyLkuwE3gJ8I8mWZvmxSW4GqKoDwOXAFuBB4CtVta15iiuBc5M8DJzbzA9i0zDtnke2c7SWQjuXQhvBdo7aS6qdqWp9ukGS9BLmneOSpFYMDklSK4s2OKYbziTJR5vhS7YneWef9ed9OJMk/zXJPc3j0ST39Kn3aJL7m3ojuTyujSRXJNnV1dYL+tTrOVTMPLXxqiQPJbkvyY1JVvWptyB9OVPfpOPPmvL7krxpvtrW1Ybjk3wnyYPN/9If9qjz1iRPd/0tfHy+29m0Y9r3cZH050ld/XRPkp8n+eCUOgvSn0muTfJE9/1tg24DZ/V/XlWL8gH8Op2bVf4nMNG1/BTgXuBlwInAD4GxHut/CtjYTG8E/mSe2/9p4ON9yh4FXrWAfXsF8O9mqDPW9O3rgBVNn58yj218B7C8mf6Tfu/fQvTlIH0DXAB8k859TGcB31uA93kN8KZm+gjg73q0863A1+e7bW3fx8XQnz3+Bv4P8JrF0J/APwHeBDzQtWzGbeBs/88X7R5H9R/OZD1wfVX9qqr+HthBZ1iTXvUWZDiTJAF+F/jyfL3mHJhuqJg5V1W3VOeKPIA76Nz/s1gM0jfrgS9Uxx3AquZepXlTVXuq6u5m+hd0rmpcO59tGKEF788p3gb8sKp+tIBtOKiqbgd+OmXxINvAWf2fL9rgmEavIUx6/TO8YDgToO9wJnPgHwOPV9XDfcoLuCXJXc1QKgvh8maX/9o+u7CD9vN8+H06nzZ7WYi+HKRvFlP/keS1wBuB7/UofkuSe5N8M8kb5rdlB830Pi6q/qRzP1q/D4aLoT9hsG3grPp1Psaq6iuzG85kJEOYzNaAbb6E6fc2zq6q3emMzXVrkoeaTwzz0k7gs8Af0+m3P6ZzWO33pz5Fj3VH2s+D9GWSjwEHgC/2eZo578seBumbBf077ZbkFcBXgQ9W1c+nFN9N53DLL5tzXTcB6+a5iTDz+7iY+nMF8B7goz2KF0t/DmpW/bqgwVGzG85kuiFMus3JcCYztTnJcjpDzb95mufY3fx8IsmNdHYXR7qxG7Rvk3wO+HqPokH7edYG6MtLgXcBb6vmgGyP55jzvuxhkL6Z8/4bRJJxOqHxxaq6YWp5d5BU1c1J/iLJq6pqXgfsG+B9XBT92TgfuLuqHp9asFj6szHINnBW/boUD1VtBi5O8rIkJ9JJ87/tU28hhjN5O/BQVe3sVZjk8CRHTE7TOQk8ryP9Tjk2fFGf159uqJg5l+Q84N8D76mqZ/rUWai+HKRvNgO/11wNdBbw9ORhg/nSnGv7S+DBqvrTPnVe3dQjyRl0tglPzl8rB34fF7w/u/Q9orAY+rPLINvA2f2fz/fZ/xZXCVxEJw1/BTwObOkq+xidKwG2A+d3Lf88zRVYwCvpfDnUw83Po+ap3X8FXDZl2bHAzc306+hcuXAvsI3OYZn57tv/DNwP3Nf8kayZ2s5m/gI6V+L8cL7bSeeih8eAe5rHNYupL3v1DXDZ5HtP5xDA1U35/XRdGTiPbfxNOocd7uvqxwumtPPypu/upXMRwm8sQDt7vo+LrT+bdhxGJwh+rWvZgvcnnSDbA+xvtpt/0G8bOIr/c4cckSS1shQPVUmSFpDBIUlqxeCQJLVicEiSWjE4JEmtGBySpFYMDklSK/8PckXeAKVI4GYAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "scatter(batch['st'][:,1], model.apply(params, batch['qn'])[:,0] -  batch['st'][:,0])\n",
    "xlim(-10,10)\n",
    "ylim(-1,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "fad1809b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.lines.Line2D at 0x7f325c221300>"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAjDElEQVR4nO3df4zc9X3n8edr12PYJSFrDl8O1gaTCEGcQDDZI26RoiRcYwgp+OhVQJtG5Rr5IoVc4HquTC4qcBddLLnXQqUIDhGSIjggAWKRBNU5FVKu0UFYs3YcY/vkQoN37ZZNyeKAF7xev++P+c4yO54f35n97szsd18PaQXz/X535rPr2dd8vp+figjMzCy/ejpdADMzm18OejOznHPQm5nlnIPezCznHPRmZjm3pNMFqOb000+PVatWdboYZmYLxvbt238ZEcurnevKoF+1ahXDw8OdLoaZ2YIh6Re1zrnpxsws5xz0ZmY556A3M8s5B72ZWc456M3Mcq4rR920YuvIGFu27ePgxCRnDvSxcd15rF8z2OlimZl1XC6CfuvIGLc8vovJqWkAxiYmueXxXQAOezNb9HLRdLNl276ZkC+ZnJpmy7Z9HSqRmVn3yEXQH5yYbOq4mdlikougP3Ogr6njZmaLSS6CfuO68+gr9M461lfoZeO68zpUIjOz7pGLzthSh6tH3ZiZnSgXQQ/FsHewm5mdKBdNN2ZmVlvDoJd0n6RXJf28xnlJ+ktJ+yX9TNLFZecGJD0qaa+kPZJ+I8vCm5lZY2lq9N8GLq9z/grg3ORrA3BX2bk7gb+OiPOBDwN7WiummZm1qmEbfUQ8I2lVnUuuBu6PiACeTWrxZwBvAh8D/jB5nqPA0TmX2MzMmpJFG/0gcKDs8Why7H3AOPAtSSOS7pV0Sq0nkbRB0rCk4fHx8QyKZWZmkE3Qq8qxoHi3cDFwV0SsoVjD31TrSSLinogYioih5curbntoZmYtyCLoR4GVZY9XAAeT46MR8Vxy/FGKwW9mZm2URdA/AXwuGX2zFng9Ig5FxD8CBySVpqdeBryYweuZmVkTGnbGSnoI+DhwuqRR4FagABARdwNPAp8G9gNHgBvKvv1LwIOSlgIvVZwzM7M2SDPq5voG5wP4Yo1zO4ChlkpmZmaZ8MxYM7Occ9CbmeWcg97MLOcc9GZmOeegNzPLOQe9mVnOOejNzHLOQW9mlnMOejOznHPQm5nlnIPezCznHPRmZjnnoDczyzkHvZlZzjnozcxyzkFvZpZzDnozs5xruMNUXmwdGWPLtn0cnJjkzIE+Nq47j/VrBjtdLDOzedewRi/pPkmvSvp5jfOS9JeS9kv6maSLK873ShqR9IOsCt2srSNj3PL4LsYmJglgbGKSWx7fxdaRsU4VycysbdI03XwbuLzO+SuAc5OvDcBdFee/DOxppXBZ2bJtH5NT07OOTU5Ns2Xbvg6VyMysfRoGfUQ8A7xW55Krgfuj6FlgQNIZAJJWAFcC92ZR2FYdnJhs6riZWZ5k0Rk7CBwoezyaHAO4A/gT4HijJ5G0QdKwpOHx8fEMivWOMwf6mjpuZpYnWQS9qhwLSZ8BXo2I7WmeJCLuiYihiBhavnx5BsV6x8Z159FX6J11rK/Qy8Z152X6OmZm3SiLUTejwMqyxyuAg8C/A66S9GngZOBUSQ9ExGczeM2mlEbXeNSNmS1GWQT9E8CNkh4GPgq8HhGHgFuSLyR9HPjPnQj5kvVrBh3sZrYoNQx6SQ8BHwdOlzQK3AoUACLibuBJ4NPAfuAIcMN8FdbMzJrXMOgj4voG5wP4YoNrfgz8uJmCmZlZNrwEgplZzjnozcxyzkFvZpZzDnozs5xz0JuZ5ZyD3sws5xz0ZmY556A3M8u5RbHDlHeXMrPFLPdBX9pdqrTxSGl3KcBhb2aLQu6Dvt7uUrWC3ncAZpYnuQ/6ZneX8h2AmeVN7jtjm91dKu3+sltHxrh081Ocs+mHXLr5KW80bmZdK/dB3+zuUmnuAEq1/rGJSYJ3av0OezPrRrkP+vVrBvn6NRcwONCHgMGBPr5+zQU1m2HS3AGkrfWbmXWD3LfRQ3O7S21cd96sNno48Q6g2XZ/M7NOyn2Nvllp7gCabfc3M+skFTeIqnOBdB/wGeDViPhQlfMC7qS4neAR4A8j4gVJK4H7gX8FHAfuiYg70xRqaGgohoeHm/pB5ku1oZbACbX+Qq84ZekSXp+c8pBMM2s7SdsjYqjauTQ1+m8Dl9c5fwVwbvK1AbgrOX4M+OOI+ACwFviipNVpC91OtUbQbB0ZY+OjO2d1um58dCfArFr/sv4CBExMTrlz1sy6Tpo9Y5+RtKrOJVcD9yd7xz4raUDSGRFxCDiUPMevJe0BBoEXMyh3ZuqNm7/9+7uZmp59xzM1Hdz+/d2M/OmnZmrsl25+il8dmZp1XaNJWdXK4UlaZjYfsmijHwQOlD0eTY7NSD4o1gDP1XoSSRskDUsaHh8fz6BY6dQbQVMZ3iW/OjI1q7Y+185ZD9c0s/mURdCryrGZarCkdwGPATdFxOFaTxIR90TEUEQMLV++PINipdNqSG/87s6ZIJ5r56yHa5rZfMoi6EeBlWWPVwAHASQVKIb8gxHxeAavlbl6IT3QV6j5fVPHg9ue2A00Pymrkodrmtl8yiLonwA+p6K1wOsRcSgZjfNNYE9E/HkGrzMv6oX0bVd9kEJPtRuWoonJqZm29cmpaXpVvLbRpKxKWQ7X9NIMZlapYWespIeAjwOnSxoFbgUKABFxN/AkxaGV+ykOr7wh+dZLgT8AdknakRz7SkQ8mWH556wUxvU6Qm96ZEfN7y/vyJ2OmPmQWL9mMHUHa5pJWml4QTYzq6bhOPpO6KZx9ABr/uuPqnbMSlDt1zeYhHq18K5V06/2oTD8i9d46LkDTEfQK3H9R1fytfUX1Py+HonpKgUaHOjjJ5s+2cJPbgtJ2oqFR3jlU71x9A76FErj6cuHWhZ6dcLQyxJRbHYZq9LGniZ0t46M8V++t4s3j06fcK6v0MPvfGQFT+8dZ2xiElHW812DgJc3X9ngquYttMBYaOVtRuXdHFSvWKS9zhaeekG/KNa6matqzTufOH85Dzz7StXr39NXqBryUGxOuXTzUzXDptofYrnJqeOzXjfNx3Srbf1btu1jbGKS3uROYbCsvF/duosHn31l5vW7vZko781aaTfYaWUjnkby/AGaF7mv0Tf7Jkxz/Ve37qoZ8gD9hR4mp46nCuGSUoiWwjUrvT3if/zuh5v6w2v0YVNP+R1L5e/yE+cv5+m94x0JhEs3P9XyHVZJK4HWrhA8Z9MPq77fKu/m0l6Xlu8QuseirdE3W4tLc/3WkTEerBPyAEemjjdd1tJrtRKu9UwfD7Zs28fNj+xI3W5bq60/jYMTk2wdGeP27++e1a8xNjE568Ox2Rr1XAMzq0ltzdwRtPMuolZTYeXdXNrr0mr2DsG1/87IdY2+2VpcretL3zMfNe5KPYLj8/hPUlnbqhbKc7Gsv8BbU8dTf2Cl7bOoVWuE2iOmsuyorvfegOLPfetvf3BWaGVxF5FWp9rom7lDcO1/fi3aGn2ztbh6tbuxiUlufmRHU80xrZjPkIfZta25NNHU8sZbUzRzQ1P6nddaJbTWB+vk1DS3f3/3rA+V0r/R8C9eY+js004Y+lrNm28fY+vIWMPhsI1q/r86MjWz4F2j75mPiXBphgk3c11azdwhzEf/gKXjGn2K6/NIMKcmmqz0CH7vo2fx2PaxzD5wRLFDfGIy/V3KKUt7OXrsOFNln7Tltc20743y91aWNfpubfJoppaedf+AzTbXZYoXrGaXJqh2fV4FtWu57XQ84IFnX8n0riKgqZAHePPo9KyQh9nrDaWdvDY2Mcn7b3mSVZt+yKHXTwz5VibCfXXrLm5+ZMecFr2brxnTzWzV6Q17OifXNXpofdTNYqnZW33ltc1aE+fSGugrcNtVH2yqJr51ZKxmk+FAX4FTTlqSaoJU2vb70t/KQH+BCDLdSMdt9PPLE6ZaMB/t17bwLOsvMPKnnwKqT5xrRrUmm0YVkYtu/1Hqu5PS5LnBiudJ04TU6P2eVSB3axNUHjjoW1Reu08zA9Xyp9AD//LUPg5OTPKevgKvJ7uItaJ0d1DvrrE8rOtNykvzWr+/9iy+tv6Cmm3jJK+Tdkhtq6OFHO7tseiCfj7eWJXP6aYda9ay/gJXXnjGrBnF80nAX1x7UcMPlWaer9lOUzfXtM+iCvp2vLG2jozxx9/dyfR8j4W0XCn00NTQ0yz01qipt3KH2kqNvp1zCRa7RTXqph27NW3Zts8hb7P0F3qqbrVWrt0hD9VHVrUS8q2MFgJvqtMtchf07Xhj+U1qlZYu6eX31541s/lMN6sX8r0SotjMNNBXaDhkspGFOqQybxv45G5mbNZreTTzGrZ4TUxO8dj2sa6Ym9CqVps4y/uv3tNXQIKJI1MzC9lVToZr9e6gXfK40mnuavRz3b+11dco9KavyXV/nc9aUb6d5EKzrL9QdVx9o1ptKRRLk7kmJqf41ZGpmYldj20f43c+MphqQlW3aEfzb7vlrkaf9Voe9UbwlB9/8+1jqcY7e5hmvk1H0NujBdeH81ZFB0LaWm21UCw3OTXN03vH29rx2umVTrtRmj1j7wM+A7waER+qcl7AnRT3jT0C/GFEvJCcuzw51wvcGxGbMyx7TevXDGZSY2j0Zi9/jXM2/TDVcy6sP39rxfTx4KQlPbx9rAO9r4n+Qg9vHTueepG8ysXF0i5Alib82hmQWTS7tKP5t93SNN18G7i8zvkrgHOTrw3AXQCSeoFvJOdXA9dLWj2XwrZbrTf7TY/sOOFWNs2bYHCgj0LuGsusmk6GPICkpldCLQ+3tLXaNO/7dgZkFs0u7Wj+bbdU4+glrQJ+UKNG/z+BH0fEQ8njfcDHgVXAbRGxLjl+C0BEfL3R6w29+90x/JGPpP4h5suzL/1z3fM9Eu9bfgqnv+skfvnG27w0/ibH6/w+T+0rcLjJxbbM2kesfd9pALzwygRHj53YJLOkp4ehVctmHr/8yzf5p8Nv1XzG8r+Rdqj3N7v2ff8i9fP88o23eeW1SY4em2bpkl7OOq2vbT9Dq/S3fzuv69EPAgfKHo8mx6od/2jNQkobKN4RcOFJ3fELXbqkt+qbveR4BK+8Nsnp7zpp5k2w/9U3qdVAc3jy2HwU0ywj77xvzzqtj78ff5PKiuB0BL984+2Zys34r9+u83ztDXmo/Te7dElzq9KW/03nQRZBX22YQdQ5XlVE3APcA8WZsfz4xxkUbW5eSbGwWfm08NOBvxsZ46ZHdrSlfGY9QLONRLVmy5bPVj0d+K0aq3WWrqt1Ht4Zqrm2zaNrqv3NlspycReP9MlEnRFfWbQYjwIryx6vAA7WOb5glK+1XUtl++P6NYP0uyHe2qSVnoBT+5ZQ6JkdCtXaoCdqhHhpX+B6SzZ3aghlM+vjLyZZ1OifAG6U9DDFppnXI+KQpHHgXEnnAGPAdcDvZfB6bVUaXVNrDZ1qHTQnFXpb2iDcDIojZpadctK8Tcr71ZEpCr1iIFmNs9YQxHqjT+p1bg4O9HU0WLMadZcnaYZXPkSxc/V0SaPArUABICLuBp6kOLRyP8XhlTck545JuhHYRnF45X0RsXsefoa2aGZ8fq2akFkjAv77NRfOe/Pf1HRwyklL2HHrp2pes3HdeTUrNzfXKd9CHp3SKfO9lHPDoI+I6xucD+CLNc49SfGDIBfS1hRaWSKhR/O/Mbh1vyU9MPyL19oysa7R+PZ6lZtaSx8P9BVcm25SO5ZcyN3M2G5QqvE084fqkDcornDZrvXqB/oLDa+pVbmpVdu/7aoPZlrGxSDt5LS5cK9hhSxWrVu/ZtAzYK1l7XrvvPHWsZZXZXSnZ3baseSCa/RlGt1CNdOONugVLq3LTR2POdUa3emZjXYsueAafZl6t1CVq/SVPgRq1Yg+cf7yOZVlYa6BaAtNtyzUlbf135vRjiUXHPRl6t1CNbuGxtN7x6seH+grsKysbbSv0FN1iePA/ziLQaf/jbthoa5mK1F5045mMDfdlKl3C9VsO1qt469PTs3aYHnryBi3f3931cknHomfP6VZqYM1NuVoRqnCUG/iUj3dslBXOzoju918N4M56MvUGzdcazhZva3Sal1fausfm5j0+vQL3LL+AhPJRhuNVG6Ifenmp1oOeWg94AXzMla7VXlc/73bOOjLNJoUlXZmLFT/0ACYOHKUjd/dyVQyntIhv7C9NXWc33z/afzk71+re12190qngqz8jrIb5HH9927joK9Q6xaq2Z2rSsdve2L3rJ2n3jzaeg3Ous/k1DT/8M/1A7vUTLNl2z5ufmTHzHunE3sP11u3qRlZzuSsdydt2Ui1Hn27DQ0NxfDwcKeLkYlLNz/Vlj/muTYBtbIK4kIjwXtOLqTa8rGp56V2rXQwCcFqQfY7Hxms20bfK3E8YmaT7af3jqfeurK3R/TAzJ1j6TWz6OSrte7TXJ57vpcAWAwkzet69ItevTdp1rfntQJ9Sa+49l+v5Ac7D7UWZDnrLOgr9NYMoqw/fM+sE+al/p1qnY1P7x3n69dccMJdX2V5q6kWtiXL+gvc+tvFGarzEZ7z0XnaTWPy8/ih46Cfo0aTrLK8Pe8v9PDif7uCNVXWAZ+aDp7eOz6zSNVFt/+oqcDP0xIMpVp0rT/WaqFc6BHvOnkJE0feWc1x+Bev8cCzrzR8vVobxpeO11oA7ODE5KzVUZsJl7RNifMRUHnuPG3HujOd4KCfo0a1m1bWvanlpGRSRb11wktuu+qDTb1urc0oFppSLbq8hlgK0fL28a9fc0GqkBw6+7SaI66gWHsufV+tWmmazsZWarSdqgXnufM0r0M9Oz1fY8FrVLvJct2bUsDXG9JZsn7NIL+/9qxUM2z7Cr1c/9GVJ8zOq2ZZf4E7rr2Igb7ZC2L1F3pO2MxCZd8zV0qep69sU5dl/QU+u/asuhNNak3GAfjJpk/y8uYr+cmmT9btVP/Jpk9yx7UXVZ29WGoiqSdvm03n7ecpl9e7Fdfo5yhN7SardW9Kz5l2lMLX1l/QsEYK7+wGVLr2YBKK1UwcmapZk6zX/PDVrbtSNYNUUzn+vBlZ1dCaHXWV1fd2o7z9POXyerfiUTdzlGYEQrVrWun7vOPai05ojkj7h1arA7JWiDZ7fRqVZf7E+ctndR6ftKSHt4/NHvsz19Ec52z6YdXfc/lev2Yl8zGiqF086mYelf7xy5cxmJya5vbv7545X60G9Inzl/PI8weYmk4f9+VvtGbbZ5sdqzwfY5urlflr6y+Y9TjrEQ95raHZ/Mjr3YqDnmzC5Y23js16/KsjU2x8dCfwTsBVPmczQyHnOtGl1Qlf7X7DZ93B6Mk41qxuGuqZlVRBL+ly4E6Ke7/eGxGbK84vA+4D3g+8Bfz7iPh5cu5m4PMUWyp2ATdExFuZ/QRz1OpwqvIPh54aI1ampuuv9/16ypDPKpiafQPn4Q2f1xqaWTPSbA7eC3wD+C1gFHhe0hMR8WLZZV8BdkTEv5V0fnL9ZZIGgf8IrI6ISUnfAa4Dvp3xz9GyVjrrKj8c6g1LHJuYZOvIWFPD7soNOpjmLA8fWGZzkWZ45SXA/oh4KSKOAg8DV1dcsxr4G4CI2AuskvTe5NwSoE/SEqAfOJhJyTPSynCqah8O9dRaW7vaMLWSvkIvd1x7Ud2hf2ZmaaQJ+kHgQNnj0eRYuZ3ANQCSLgHOBlZExBjwZ8ArwCHg9Yj4UbUXkbRB0rCk4fHx6pt2zIc0Y9IrNTumttYGJeUbDkBx0hJ4/00zy1aaNvpqc24q2yo2A3dK2kGxHX4EOJa03V8NnANMAN+V9NmIeOCEJ4y4B7gHisMr0/4Ac5Wms66ys/Y9fc0vjFXrw8HNCmY239IE/SiwsuzxCiqaXyLiMHADgCQBLydf64CXI2I8Ofc48JvACUHfKY0666p11hZ6RaFHs1YGbDQu3sP5zKxT0gT988C5ks4Bxih2pv5e+QWSBoAjSRv+54FnIuKwpFeAtZL6gUngMqDrZkLVq1VXa4+fmg5Udp+zrL/AlReeUXPJWQ/nM7NOathGHxHHgBuBbcAe4DsRsVvSFyR9IbnsA8BuSXuBK4AvJ9/7HPAo8ALFJp0ekuaZhaJWk0v5QJu3po4zdPZpbm83s67kJRAaSLt2+VyWBjAzm6t6SyB49coG6g2BLLfQV7czs/xaVEsgtLLUQWVnba1ZsO5sNbNutWhq9LXWJa82kalSaU3yv7j2It598omfje5sNbNutmiCvt5SB2mUPigqx88v6y+4s9XMutqiabppdeeYrSNjVTdvLulfusQhb2ZdbdHU6FtZ6mDryBgbv7uz7ixYd8KaWbdbNEHfyj6XW7btmzX7tRp3wppZt1s0TTetrEveqLbuTlgzWwgWTdBD8wuINVov3p2wZrYQLJqmm1bUq60P9BUc8ma2IDjo61i/ZpDPrj3rhOOFHnHbVR/sQInMzJrnoG/ga+sv4I5rL2JwoA9RXNNmy+9+2LV5M1swFlUbfau8OYiZLWSu0ZuZ5ZyD3sws5xz0ZmY556A3M8u5VEEv6XJJ+yTtl7Spyvllkr4n6WeSfirpQ2XnBiQ9KmmvpD2SfiPLH8DMzOprGPSSeoFvUNwLdjVwvaTVFZd9BdgRERcCnwPuLDt3J/DXEXE+8GGK+86amVmbpKnRXwLsj4iXIuIo8DBwdcU1q4G/AYiIvcAqSe+VdCrwMeCbybmjETGRVeHNzKyxNEE/CBwoezyaHCu3E7gGQNIlwNnACuB9wDjwLUkjku6VdEq1F5G0QdKwpOHx8fEmfwwzM6slTdCryrHKtXs3A8sk7QC+BIwAxyhOyLoYuCsi1gBvAie08QNExD0RMRQRQ8uXL09ZfDMzayTNzNhRYGXZ4xXAwfILIuIwcAOAJAEvJ1/9wGhEPJdc+ig1gt7MzOZHmhr988C5ks6RtBS4Dnii/IJkZM3S5OHngWci4nBE/CNwQFJpGcjLgBczKruZmaXQsEYfEcck3QhsA3qB+yJit6QvJOfvBj4A3C9pmmKQ/1HZU3wJeDD5IHiJpOZvZmbtoYj6W+V1wtDQUAwPD3e6GGZmC4ak7RExVO2cZ8aameWcg97MLOcc9GZmOeegNzPLOQe9mVnOOejNzHLOQW9mlnMOejOznHPQm5nlnIPezCznHPRmZjnnoDczyzkHvZlZzjnozcxyzkFvZpZzDnozs5xz0JuZ5ZyD3sws51IFvaTLJe2TtF/Spirnl0n6nqSfSfqppA9VnO+VNCLpB1kV3MzM0mkY9JJ6gW8AVwCrgeslra647CvAjoi4EPgccGfF+S8De+ZeXDMza1aaGv0lwP6IeCkijgIPA1dXXLMa+BuAiNgLrJL0XgBJK4ArgXszK7WZmaWWJugHgQNlj0eTY+V2AtcASLoEOBtYkZy7A/gT4Hi9F5G0QdKwpOHx8fEUxTIzszTSBL2qHIuKx5uBZZJ2AF8CRoBjkj4DvBoR2xu9SETcExFDETG0fPnyFMUyM7M0lqS4ZhRYWfZ4BXCw/IKIOAzcACBJwMvJ13XAVZI+DZwMnCrpgYj4bAZlNzOzFNLU6J8HzpV0jqSlFMP7ifILJA0k5wA+DzwTEYcj4paIWBERq5Lve8ohb2bWXg1r9BFxTNKNwDagF7gvInZL+kJy/m7gA8D9kqaBF4E/mscym5lZExRR2dzeeUNDQzE8PNzpYpiZLRiStkfEULVznhlrZpZzDnozs5xz0JuZ5ZyD3sws5xz0ZmY556A3M8u5NDNjc2XryBhbtu3j4MQkZw70sXHdeaxfU7l0j5lZfiyqoN86MsYtj+9icmoagLGJSW55fBeAw97McmtRNd1s2bZvJuRLJqem2bJtX4dKZGY2/xZV0B+cmGzquJlZHiyqoD9zoK+p42ZmebCogn7juvPoK/TOOtZX6GXjuvM6VCIzs/m3qDpjSx2uHnVjZovJogp6KIa9g93MFpNF1XRjZrYYOejNzHLOQW9mlnMOejOznEsV9JIul7RP0n5Jm6qcXybpe5J+Jumnkj6UHF8p6WlJeyTtlvTlrH8AMzOrr2HQS+oFvgFcAawGrpe0uuKyrwA7IuJC4HPAncnxY8AfR8QHgLXAF6t8r5mZzaM0wysvAfZHxEsAkh4GrgZeLLtmNfB1gIjYK2mVpPdGxCHgUHL815L2AIMV35sJr0ppZlZdmqabQeBA2ePR5Fi5ncA1AJIuAc4GVpRfIGkVsAZ4rtqLSNogaVjS8Pj4eKrCl5RWpRybmCR4Z1XKrSNjTT2PmVkepQl6VTkWFY83A8sk7QC+BIxQbLYpPoH0LuAx4KaIOFztRSLinogYioih5cuXpyn7DK9KaWZWW5qmm1FgZdnjFcDB8guS8L4BQJKAl5MvJBUohvyDEfF4BmU+gVelNDOrLU2N/nngXEnnSFoKXAc8UX6BpIHkHMDngWci4nAS+t8E9kTEn2dZ8HJeldLMrLaGQR8Rx4AbgW3AHuA7EbFb0hckfSG57APAbkl7KY7OKQ2jvBT4A+CTknYkX5/O+ofwqpRmZrWlWtQsIp4Enqw4dnfZ//9f4Nwq3/d3VG/jz5RXpTQzqy03q1d6VUozs+q8BIKZWc456M3Mcs5Bb2aWcw56M7Occ9CbmeWcIipXM+g8SePAL6qcOh34ZZuLk5bL1hqXrTUuW2vyXLazI6Lq+jFdGfS1SBqOiKFOl6Mal601LltrXLbWLNayuenGzCznHPRmZjm30IL+nk4XoA6XrTUuW2tcttYsyrItqDZ6MzNr3kKr0ZuZWZMc9GZmObcggl7S5ZL2SdovaVOny1NO0n2SXpX0806XpZKklZKelrRH0m5JX278Xe0h6WRJP5W0Mynb7Z0uUzlJvZJGJP2g02WpJOkfJO1K9ncY7nR5yiWbED0qaW/yvvuNTpcJQNJ5ZXti7JB0WNJNnS5XiaSbk7+Dn0t6SNLJmT5/t7fRS+oF/h/wWxS3NXweuD4iXuxowRKSPga8AdwfER/qdHnKSToDOCMiXpD0bmA7sL4bfnfJ7mOnRMQbyXaTfwd8OSKe7XDRAJD0n4Ah4NSI+Eyny1NO0j8AQxHRdRN/JP0V8H8i4t5k17n+iJjocLFmSTJlDPhoRFSbmNnu8gxSfP+vjohJSd8BnoyIb2f1GguhRn8JsD8iXoqIo8DDwNUdLtOMiHgGeK3T5agmIg5FxAvJ//+a4g5hXbFofxS9kTwsJF9dUeuQtAK4Eri302VZSCSdCnyM4vahRMTRbgv5xGXA33dDyJdZAvRJWgL0U7Ev91wthKAfBA6UPR6lS8JqIZG0ClgDPNfhosxImkd2AK8C/zsiuqVsdwB/AhzvcDlqCeBHkrZL2tDpwpR5HzAOfCtp9rpX0imdLlQV1wEPdboQJRExBvwZ8ApwCHg9In6U5WsshKCvthVhV9T8FgpJ7wIeA26KiMOdLk9JRExHxEXACuASSR1v+pL0GeDViNje6bLUcWlEXExxf+YvJs2H3WAJcDFwV0SsAd4Euq1PbSlwFfDdTpelRNIyiq0U5wBnAqdI+myWr7EQgn4UWFn2eAUZ39bkWdL+/RjwYEQ83unyVJPc3v8YuLyzJQGKG9pflbSDP0xxY/sHOluk2SLiYPLfV4HvUWze7AajwGjZndmjFIO/m1wBvBAR/9TpgpT5N8DLETEeEVPA48BvZvkCCyHonwfOlXRO8ml8HfBEh8u0ICQdnt8E9kTEn3e6POUkLZc0kPx/H8U3+96OFgqIiFsiYkVErKL4XnsqIjKtXc2FpFOSjnWSZpFPAV0x4isi/hE4IOm85NBlQMc7/itcTxc12yReAdZK6k/+Zi+j2J+Wma7fHDwijkm6EdgG9AL3RcTuDhdrhqSHgI8Dp0saBW6NiG92tlQzLgX+ANiVtIUDfCUinuxckWacAfxVMgKiB/hORHTdUMYu9F7ge8U8YAnwvyLirztbpFm+BDyYVMpeAm7ocHlmSOqnOHrvP3S6LOUi4jlJjwIvAMeAETJeDqHrh1eamdncLISmGzMzmwMHvZlZzjnozcxyzkFvZpZzDnozs5xz0JuZ5ZyD3sws5/4/8mSX+n4JX5kAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "scatter(jnp.linalg.norm(batch['st'],axis=-1), jnp.linalg.norm(model.apply(params, batch['qn']),axis=-1) / \n",
    "                          jnp.linalg.norm(batch['st'],axis=-1))\n",
    "\n",
    "axhline(jnp.median(jnp.linalg.norm(model.apply(params, batch['qn']),axis=-1) / \n",
    "                          jnp.linalg.norm(batch['st'],axis=-1)),color='red')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0caac620",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(1.0014867, dtype=float32)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# That was for sigma of 0.25\n",
    "jnp.median(jnp.linalg.norm(model.apply(params, batch['qn']),axis=-1) / \n",
    "                          jnp.linalg.norm(batch['st'],axis=-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683734df",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
