{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3ca18a8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=7\n"
     ]
    }
   ],
   "source": [
    "# Expose only GPU 7 to CUDA for the benchmarks below.\n",
    "%env CUDA_VISIBLE_DEVICES=7\n",
    "\n",
    "# Re-import edited project modules automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ec9d9fe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import torch as t\n",
    "from torch import nn\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from phd.few_bits_gelu.custom_gelu import MyGelu, NBitsGelu\n",
    "from phd.few_bits_gelu.actnn import ActNNGelu, ActNNGeluFn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c11c6a41",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[autoreload of phd.few_bits_gelu.actnn failed: Traceback (most recent call last):\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 245, in check\n",
      "    superreload(m, reload, self.old_objects)\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 410, in superreload\n",
      "    update_generic(old_obj, new_obj)\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n",
      "    update(a, b)\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 302, in update_class\n",
      "    if update_generic(old_obj, new_obj): continue\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n",
      "    update(a, b)\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 302, in update_class\n",
      "    if update_generic(old_obj, new_obj): continue\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n",
      "    update(a, b)\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 302, in update_class\n",
      "    if update_generic(old_obj, new_obj): continue\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 347, in update_generic\n",
      "    update(a, b)\n",
      "  File \"/home/gnovikov/data/miniconda3/envs/ml/lib/python3.9/site-packages/IPython/extensions/autoreload.py\", line 302, in update_class\n",
      "    if update_generic(old_obj, new_obj): continue\n",
      "RecursionError: maximum recursion depth exceeded\n",
      "]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.002308070659637451\n",
      "0.0039139967411756516\n"
     ]
    }
   ],
   "source": [
    "def test_model(gelu_cls, n_layers: int = 8):\n",
    "    \"\"\"Stack `n_layers` fresh `gelu_cls()` modules into an nn.Sequential.\n",
    "\n",
    "    The model contains nothing but the activations, so timing it measures the\n",
    "    forward/backward cost of the GELU implementation itself.\n",
    "    \"\"\"\n",
    "    return nn.Sequential(*(gelu_cls() for _ in range(n_layers)))\n",
    "\n",
    "\n",
    "def computation(model, x):\n",
    "    \"\"\"Run one forward+backward step and block until the GPU has finished it.\"\"\"\n",
    "    loss = model(x).mean()\n",
    "    loss.backward()\n",
    "    # CUDA kernels launch asynchronously; wait so wall-clock timing is meaningful.\n",
    "    t.cuda.synchronize()\n",
    "\n",
    "    \n",
    "def run_test(gelu_cls, n_steps: int, shape, n_warmup: int = 10):\n",
    "    \"\"\"Return mean seconds per forward+backward step of `test_model(gelu_cls)` on CUDA.\n",
    "\n",
    "    `n_warmup` untimed steps run first so one-off start-up costs (e.g. CUDA\n",
    "    initialisation) are excluded from the measurement.\n",
    "    \"\"\"\n",
    "    model = test_model(gelu_cls).cuda()\n",
    "    x = t.randn(*shape, device='cuda', requires_grad=True)\n",
    "\n",
    "    for _ in range(n_warmup):\n",
    "        computation(model, x)\n",
    "\n",
    "    # perf_counter is monotonic and high-resolution, unlike time.time().\n",
    "    start_time = time.perf_counter()\n",
    "    for _ in range(n_steps):\n",
    "        computation(model, x)\n",
    "    return (time.perf_counter() - start_time) / n_steps\n",
    "\n",
    "\n",
    "N_STEPS = 2**7\n",
    "SHAPE = (16, 3, 256, 256)  # smaller alternative: (4, 3, 256, 256)\n",
    "\n",
    "# Same constructor argument (3) for both variants so they are compared like-for-like.\n",
    "for name, make_gelu in [('MyGelu', lambda: MyGelu(3)), ('ActNNGelu', lambda: ActNNGelu(3))]:\n",
    "    seconds = run_test(make_gelu, N_STEPS, SHAPE)\n",
    "    print(f'{name}: {seconds * 1e3:.3f} ms/step')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "633458bf",
   "metadata": {},
   "source": [
    "## Accuracy of the quantized backward pass\n",
    "\n",
    "The next cells define a minimal stand-in for the autograd context and a reference GELU. They then call the `forward`/`backward` methods of `ActNNGeluFn` and `NBitsGelu` by hand. Finally, they compare the resulting gradients with the analytic GELU derivative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "da29f6cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Ctx:\n",
    "    \"\"\"Minimal stand-in for an autograd context object.\n",
    "\n",
    "    Lets a custom autograd Function's `forward`/`backward` be called by hand.\n",
    "    Like the real `ctx.save_for_backward(*tensors)`, any number of tensors may\n",
    "    be saved in one call; they are exposed, in order, via `saved_tensors`.\n",
    "    Repeated calls accumulate.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.saved_tensors = []\n",
    "\n",
    "    def save_for_backward(self, *tensors):\n",
    "        # Variadic like the real ctx; a single positional tensor still works.\n",
    "        self.saved_tensors.extend(tensors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8e1193fb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f65fb222730>]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAhzElEQVR4nO3deXxU9b3/8dcnO4QAIQkQSNgjm8gWAqi1Xq2tSytea1toVVZ3f7W1vbft7ebD2/Z2tb3uooigVL2uxe1nXWq1yhaQPSxhDySQhSSEkHW+949MvRETEsJMzszk/Xw85pGZOV/mvOckvHPynTNzzDmHiIiEvyivA4iISGCo0EVEIoQKXUQkQqjQRUQihApdRCRCxHi14tTUVDdkyBCvVi8iEpbWrl1b4pxLa2mZZ4U+ZMgQcnNzvVq9iEhYMrN9rS3TlIuISIRos9DNLNPM/mZmW81si5nd0cIYM7N7zSzfzDaa2aTgxBURkda0Z8qlAfiec26dmSUBa83sLefc1mZjLgOy/JepwEP+ryIi0kna3EN3zhU659b5rx8D8oCBJw2bASx1TVYCvc0sPeBpRUSkVac1h25mQ4CJwKqTFg0EDjS7XcBnSx8zu9HMcs0st7i4+DSjiojIqbS70M2sB/AC8B3nXGVHVuacW+icy3bOZaeltXjUjYiIdFC7Ct3MYmkq82XOuRdbGHIQyGx2O8N/n4iIdJL2HOViwCIgzzl3TyvDlgPX+492mQZUOOcKA5hTRCQirHj839m18aOgPHZ7jnI5D7gO2GRm6/33/QcwCMA59zDwOnA5kA9UA3MDnlREJMzlvvII0/c/worGWoafc27AH7/NQnfO/QOwNsY44LZAhRIRiTT7d6xndO7PyIsbw5Q5vwvKOvROURGRIKuprqLhmdnUWSx9Zj9FTGxcUNajQhcRCbINj93CMN9e9n/+j/TLGB609ajQRUSCKPfVhUwtW86KAdcz/l++FtR1qdBFRILkwM4NjF7zU/JixzBl7h+Cvj4VuohIENRUV1HXCfPmzanQRUSCYMOiWxneuCfo8+bNqdBFRAIs99WFTC39CyvSgz9v3pwKXUQkgJrPm2fP/X2nrluFLiISIDUnjn8yb558/ZPExsV36vpV6CIiAbJh0e1N8+YX3EP/zBGdvn4VuohIAKx780mmlrzIyn6zGH/R1z3JoEIXETlDRft3MnzFD9gZPYJJ8/7kWQ4VuojIGWior+Pok7OJcY10m7WEuPgEz7Ko0EVEzsCapT9idP0W8ibfRcaIsz3NokIXEemgLR+9Ts7+Razp9SWyr7zF6zgqdBGRjigvKSLtr7dxKCqdMQsWeh0HUKGLiJw25/Ox9/E59HYV1F71KIlJvb2OBKjQRURO26pnf82E6hWsG/ldRow/3+s4n1Chi4ichl0bP2LStj+wodtUps78sddxPkWFLiLSTtVVFcS+vIAKSyJz7mIsKrQqNLTSiIiEsM2P3UxG4yEOX3IfffoO9DrOZ6jQRUTaIffVheSUv86qjDmcfd5XvI7TIhW6iEgbDu7ewqg1P2Nb7BimzPmt13FapUIXETmFutoajv95Dj6Lote1T3TKqeQ6SoUuInIK6xbfyVkNO8if9ivSB4/0Os4pqdBFRFqx8b0XmFa0jFUpM5h06Ryv47RJhS4i0oKSov0MfO+77I0axPj5D3odp11U6CIiJ/E1NlK4eDaJrhq+tpiE7j28jtQuKnQRkZOsXnYX42rXseHsHzJkdLbXcdpNhS4i0sz23HeZvOsB1iVeQM5X7/Q6zmlRoYuI+FWWl5L02s2UWh+Gz3885N7a35bwSisiEiTO52PnogX09RVTfvlD9OqT5nWk06ZCFxEBcl++j8nH3mXN0JsYlXOJ13E6RIUuIl3evm3rGLvhl2yJG0/Otb/wOk6HqdBFpEurOXGcxv+ZS63F03fOUqJjYryO1GEqdBHp0jYs+n8M8+1l/wW/I23A
EK/jnBEVuoh0WR//9SmmlrzAyr7fYPxFM72Oc8baLHQze9zMjpjZ5laWX2hmFWa23n/5WeBjiogE1uGCXQz96AfkRw9n4rw/eR0nINqzh/4EcGkbYz5wzk3wX+4+81giIsHT2NBA6dLZxLl64mcuJj6hu9eRAqLNQnfOvQ+UdUIWEZFOsXrpjxhTt4ktE39OZtZ4r+METKDm0Keb2QYze8PMxgboMUVEAm7rijfI2fcouT0vYcpVt3kdJ6ACcXzOOmCwc67KzC4HXgayWhpoZjcCNwIMGjQoAKsWEWm/8pIiUt68jcKo/oxa8KjXcQLujPfQnXOVzrkq//XXgVgzS21l7ELnXLZzLjstLfzeVisi4cv5fOx5fC7JrpwTMx6lR89kryMF3BkXupn1NzPzX8/xP2bpmT6uiEggrX7ut0ys/oh1Z91B1oTPeR0nKNqccjGzp4ELgVQzKwB+DsQCOOceBq4BbjGzBuAEMNM554KWWETkNO3atJIJW3/Phm5TyJn5E6/jBE2bhe6cm9XG8vuB+wOWSEQkgKqrKoh5aT7HLJGMuU8QFR3tdaSg0TtFRSSibV50C5mNBym6+F5S+mV4HSeoVOgiErHWvvYYOUdfY9XA6zn7czO8jhN0KnQRiUiH9mzjrNU/YXvMKLLn/M7rOJ1ChS4iEae+rpZjy67HmZF07VJi4+K9jtQpVOgiEnFyF3+fkQ3b2ZnzKwYMGel1nE6jQheRiLLp/ZeYXriU1X2+wuTL53odp1Op0EUkYpQUHSD93e+wNyqTcfMf8jpOp1Ohi0hE8DU2cuiJOfRwx3FfXUS3xCSvI3U6FbqIRITVT9/NOTW5bBj77wwdO9XrOJ5QoYtI2Nux7u9M3nkfHyeeT8413/c6jmdU6CIS1o5VlJH4yg2UWjLD5j2ORXXdWuu6z1xEwp7z+di+6Ab6+45w9LIH6ZXSz+tInlKhi0jYWvOXB8iufJvVQ25i9NQveR3Hcyp0EQlL+3es5+z1/8mWuHPIue6XXscJCSp0EQk7tTXV1D87l1qLI232UqJjAnE2zfCnQheRsPPxom8zvHE3+z73e/oOHOp1nJChQheRsLL+7aeZVvwcK/t+nQkXz/Q6TkhRoYtI2DhycA+D//Fv5EcPZ+K8//Y6TshRoYtIWGhsaKB4yXXEuzriZy4mPqG715FCjgpdRMLC6id/zNi6TWyZ+DMys8Z7HSckqdBFJOTlrXqTnL2PkNvzC2RfeavXcUKWCl1EQlpF6WGS37iVwqh+jFrwWJd+a39btGVEJGQ5n4/dj8+jjzvKiSsfpUfPZK8jhTQVuoiErNXP/56Jx//BuqxvkzXxAq/jhDwVuoiEpD1bVjFhy2/ZmDCFnFk/9TpOWFChi0jIOXH8GFEvzOeYJTJw7hNERUd7HSksqNBFJORsWnQLmY0FFF50Lyn9MryOEzZU6CISUta+vpicsldYNfA6xl0ww+s4YUWFLiIh49De7WSt/g+2x4wie87vvY4TdlToIhIS6utqOfbU9ZhzJH1rCbFx8V5HCjsqdBEJCblPfJ+RDdvYkfMLBgwd5XWcsKRCFxHPbfr7i0w/tJTVfb7C5CsWeB0nbKnQRcRTJUX7GfC377A3KpNx8x/yOk5YU6GLiGd8jY0ULp5NoqvGXbOYbolJXkcKayp0EfHMqmV3Ma52HRvO/hFDx0zxOk7YU6GLiCe2rXmbKbvuZ12Pz5Pz1e96HSciqNBFpNNVHC2h12s3cyQqleHzH9dH4gaItqKIdCrn87HrsbmkujIqr3iYXsmpXkeKGG0Wupk9bmZHzGxzK8vNzO41s3wz22hmkwIfU0QixeoX7mHS8fdZO/w2RmVf7HWciNKePfQngEtPsfwyIMt/uRHQcUci0qI9W1YxfvOv2ZiQTc637vI6TsRps9Cdc+8DZacYMgNY6pqsBHqbWXqgAopIZKiuqiDqhflUWSID5y7RR+IGQSDm0AcCB5rdLvDf9xlmdqOZ5ZpZbnFx
cQBWLSLhYvOiW/WRuEHWqS+KOucWOueynXPZaWlpnblqEfFQ7muPknP0VVZlzNZH4gZRIAr9IJDZ7HaG/z4REQ7uzmPU6p+yLXYM2bN/63WciBaIQl8OXO8/2mUaUOGcKwzA44pImKurreH4n6/HZ1H0uvYJfSRukMW0NcDMngYuBFLNrAD4ORAL4Jx7GHgduBzIB6qBucEKKyLhZd3iO5nWsIOPp9/LxMEjvY4T8dosdOfcrDaWO+C2gCUSkYiw4W/PMa1oGatSr2bql2Z7HadL0DtFRSTgDhfsYtDf72R31BDGz7vP6zhdhgpdRAKqob6OsiXXEefqiJm5hITuPbyO1GWo0EUkoNYs/h6j67eQN+U/GXTWBK/jdCkqdBEJmA3v/g/TDy1lVZ8ryf7yjV7H6XJU6CISEEUH8hn0/p3sih7K+AX6SCcvqNBF5IzV19VSvvQ64lw9cZo394wKXUTO2NrHv8uo+q3k5fyCzKzxXsfpslToInJG1r/zTNPx5ilXkX3FDV7H6dJU6CLSYUX7dzLkg++xK3oY4xc86HWcLk+FLiIdUl9XS8XSa4lxjcTPepKEboleR+ryVOgi0iFrF93ByIZtbJ/6SzJGnO11HEGFLiIdsP6tPzPt8NOsSr2ayZfP9zqO+KnQReS0FO7bztAPv09+9HAmLHjA6zjSjApdRNqttqaaqie/RZTz0e1bTxKf0N3rSNKMCl1E2m39o7eQ1bCT/PN+z8BhY72OIydRoYtIu6x5+QGmlr7MivTrmfjFa72OIy1QoYtIm3ZtWsm4j3/OlrjxTJn3B6/jSCtU6CJyShVHS4h/cTaVlkS/ecuIiY3zOpK0QoUuIq3yNTay+9Hr6OcrpuyyR0jtn+l1JDkFFbqItGrVsruYWP0Ra0feyaipX/Q6jrRBhS4iLdr8j+Xk7LqPtUn/wtSZ/+F1HGkHFbqIfMaRg3sY8PZtFERnMOrGJ7AoVUU40HdJRD6lrraGssWziHd18I2nSEzq7XUkaScVuoh8yseP3caohjy2Tf0Vg0dO8DqOnAYVuoh8Ys3L9zO1+HlW9pupD90KQyp0EQFgx7q/c87Hd7E5fgLZC+7zOo50gApdRCgpOkCv5XMps94MXPCM3jwUplToIl1cXW0NRxbNpKc7RvXVS0lOS/c6knSQCl2ki/v40VsYU7+ZLVN+yfBzzvU6jpwBFbpIF7bmxf9masmLrOz/LbK/fKPXceQMqdBFuqjtue8yfsPdbIqfSPb8P3kdRwJAhS7SBZUU7Sf51fmURKWQeYNeBI0UKnSRLqa2ppriRd+ghzvOiauX0ju1v9eRJEBU6CJdiPP52PjwPEbXbyVv6n8xfNw0ryNJAKnQRbqQVcvuYkr5G6zIvEHvBI1AKnSRLmL920+Tk38v63p8nqlzfuN1HAkCFbpIF7BnyyqyPvgOu2KGM/qWZURFR3sdSYKgXYVuZpea2XYzyzezH7awfI6ZFZvZev9lQeCjikhHlB4uIOG5b3HcutNr3vN0S0zyOpIESUxbA8wsGngAuAQoANaY2XLn3NaThj7rnLs9CBlFpINqa6opfuxrDHYVFFz1AlkDh3odSYKoPXvoOUC+c263c64OeAaYEdxYInKmnM/HxofmMKp+K1un/pasiRd4HUmCrD2FPhA40Ox2gf++k33VzDaa2fNmplODi3hs5VM/Z0rFm6wYdBOTL5/rdRzpBIF6UfQVYIhz7hzgLWBJS4PM7EYzyzWz3OLi4gCtWkROtva1x5i++17WJl3EtDm/9jqOdJL2FPpBoPked4b/vk8450qdc7X+m48Bk1t6IOfcQudctnMuOy0trSN5RaQNeaveZNzqH5AXO5axtz6lEzx3Ie35Tq8BssxsqJnFATOB5c0HmFnzD1C+EsgLXEQRaa/9O9aT/sY8Dkf1Jf2mF0noluh1JOlEbR7l4pxrMLPbgTeBaOBx59wWM7sbyHXOLQe+bWZXAg1AGTAniJlFpAWlhwuIefrr
+Igi6roX9RktXZA55zxZcXZ2tsvNzfVk3SKR5sTxYxz440Vk1u/lwIznOGvShV5HkiAxs7XOueyWlmlyTSTMNTY0sO2BbzCififbzvujyrwLU6GLhDHn87HmkZuZWP0hq0f9GxO/eK3XkcRDKnSRMLbyyZ8wrfg5VvabybRZP/Y6jnhMhS4SplY/fw/T9zxAbs9LyLnxQa/jSAhQoYuEoY/fXMLkTXezIWEK42/XpydKExW6SJjZ/OErjP3oTnbGjiLr9heIjYv3OpKECBW6SBjJ3/AhQ/56A4ei00m/5S9079HL60gSQlToImGiIH8zyS/Nosp60G3eX+iV0s/rSBJiVOgiYaBo/05inrqKKBx133yefhnDvY4kIUiFLhLiig/tpWHxl+nOcUr/9RkGnTXB60gSolToIiGs9HAB1Y9dQbKvnENffooR48/zOpKEMBW6SIiqKD1MxSNX0LfxCPsufYJR2Rd7HUlCnApdJARVlpdy5KErGNh4kPyLH2PM9Mu8jiRhQIUuEmKqKo9y8IEvM7h+N9s+fz/jLtApfKV9VOgiIeRYRRkF915GVt02Nk+/h/EXzfQ6koSRNk9wISKdo+JoCYcfuIzh9bvYdO6fmPSl2V5HkjCjQhcJARWlhzny4OUMadjDlvPvZ+Il3/Q6koQhFbqIx8pLiih96DIGN+wn7/MPMkHTLNJBKnQRD5UeLqDikSvIaDzItn9ZyPgLv+p1JAljKnQRjxzau53GJVeR7ithx8WLOEdHs8gZUqGLeGBf3lq6PXsNPahl3xV/ZlzOJV5HkgigQhfpZNtz36Xfq9dRTyylX3uJUWOneh1JIoSOQxfpRJvef4nMV2ZSZT2ou/4NhqrMJYBU6CKdJPe1Rxn5znyKotNJuOktBg4b7XUkiTCachEJMufzsfLJnzB9zwNsjTubgbf+hV7JqV7HkgikQhcJovq6Wj5+aC7Tj75Gbs8vMO7WJ4lP6O51LIlQKnSRIKksL2XfQ9eQU7uOFRnzmDbvD1iUZjkleFToIkFQuG87NUu+zqjGA6wZfzfTr77D60jSBajQRQJsy4evMeCtm0mkgW0XL2aK3jAknUSFLhIgzudj1bO/Jnvb7zgYPZCoWcsYlzXe61jShajQRQKg5sRxNj6ygGnlr/Nx4rmMuGkZSb36eB1LuhgVusgZKjqQT8WSb5LTsJ0VmTcwdc5viIqO9jqWdEEqdJEzsP7tpxnyj+/TwzWy7tz7mf6l67yOJF2YCl2kA+rralm76A6mHX6a/OjhJMxayqQRZ3sdS7o4FbrIaSrct53KJ69jWsN2VqVezfj595PQLdHrWCIqdJH2cj4fucsfYtTH/0kPYN20PzH1srlexxL5hApdpB3Kjhxk75KbmHL8A/LixtJz1iIm6cO1JMSo0EXasP6dZ8j44Aec7apYOeIOpsz6GdEx+q8joaddHyxhZpea2XYzyzezH7awPN7MnvUvX2VmQwKeVKSTlR05SO491zDhg5uoiE7m4NdfZ9p1d6vMJWS1WehmFg08AFwGjAFmmdmYk4bNB44650YAfwR+E+igIp3F+Xysfuk+oh7M4ZyKd1mZMZ+Mf/tIJ6OQkNeeXY0cIN85txvAzJ4BZgBbm42ZAdzlv/48cL+ZmXPOBTCrSNAdyN9E5f/cRk7dBvJix9D96vuZNnqy17FE2qU9hT4QONDsdgFw8q7KJ2Occw1mVgGkACXNB5nZjcCNAIMGDepgZJHAO1ZRxuZnfsrkQ8/Qy+JYNfanTLn6u3rHp4SVTp0MdM4tBBYCZGdna+9dPNfY0MDal+9l+OY/MZ0K1iRfxtCv/4apAwZ7HU3ktLWn0A8Cmc1uZ/jva2lMgZnFAL2A0oAkFAkC5/Ox+YOX6f7+3eQ07iEvdgxllz/FlIkXeB1NpMPaU+hrgCwzG0pTcc8EvnnSmOXAbGAFcA3wbrDmzw/uzqNgzXLiUwaR1HcwKQOG0atPX50JRtpt64o34N1fMK5+M4Wk
sTbnHiZdOlc/QxL22ix0/5z47cCbQDTwuHNui5ndDeQ655YDi4AnzSwfKKOp9IOicPN7TM371afuq3bxlESnUhHblxPd0mlMGkB070y6pQ6iZ78h9M3MoltiUrAiSZjYlvsO9W/9gnG16ygmmVWjf8SEGd8mXef4lAhhXh2Ikp2d7XJzc0/73zU2NFB2pICyQ7s5XryPurIDUHmQuOOHSKw5TJ+GI6S4cqLs08+rmGRKYtOp6p5BQ8/BxKQMpUf6CFIzR5LSL1MvfkUoX2Mjm/7+HDEr72ds3SbK6MmOrAVM+NfvkdC9h9fxRE6bma11zmW3uCzcCr096mprKCncR3nRHqqP7KW+bC/R5ftIrC4gpa6Qvq7kU4Vf42I5HN2P8vgB1CRm4lJG0D19JKmDx9Ivc4TeSBKGak4cZ+P/X0TfTQsZ4jtAEanszbqecVfeQWJSb6/jiXRYlyv0ttTWVHOkYBdHC3Zw4sguXNle4o/to2fNQfo3FJJoNf831sVSGJ3O0W6DqOk5jOi+WfQcMIq+Q8eSnJquedcQs3/Heg698zCjDr9Cb6rYHTWEsgk3M/7SecTGxXsdT+SMnarQu+SuZ3xCdzJHjCNzxLjPLHM+HyVHCjiyZwtVB/PwleQTX7mHPif2kl61grjCRtjQNLaSRApjMqhMHExD72HE9TuL3pljGTBsrObsO1F1VQVb33uWbhufYmzdBtJdNJuSziM2ZwFnn/8VhumXrnQRXXIPvaMa6uso2reD0v1bOVG0HSvNJ7FqL2m1B+h30lGaRaRRHJ9JddJQSM2i+4BRpA4eQ7+MEZqvD4C62hq2fvASDRueY0zlP+hutRyyvuwb8jWyvnQzqf31xjWJTJpy6QTVVRUU7t5C+YGt1B3eQWz5LnpV7yO9voAeduKTcSdcHIXRAyjvPpja3sOJ7XsWPTPGkD58nE4q3IbK8lJ2fvQyvu1vklXxIb2popwebO9zET0mz2T0tEv1y1IingrdQ87no7ToAIf3bKbqUB6uZCfdKveQUrOf/r7DxJjvk7El9OZIXCZVPYbg6zOChPRRpAweS/rgkcTExnn4LLzRUF/H7k0rKNvyDkkH3uOs2s3EWiPl9CC/53Rixl/DmPOvIi4+weuoIp1GhR6i6mprKNyzlbL9W6kp2k50WT5Jx/fSv76AZCr/b5yLpjA6nbKEQdT2HEpU2lkkZYwmbfAYUvoOjJgXZiuOllCwdRWV+R+SWLiaYSc2f/LXzZ6owRT1/zzJ47/CiEkXdslfcCKgQg9LFaWHKdy9icqCPBqLdxBfsYc+J/YxoPEQcdbwybgTLo7i6DTK4/pT030AjT0ziO0zmO59h9JnwHBS0weHXPnVnDhO0d48yvbnUVuUR3zxZvoe306GK/pkzN6oQRxOnkTMsM8xaOIXSBswxLvAIiFEhR5BGhsaKNq/k5J9mzlRtBPK9xN3/CA9aopIaThMChWfGu9zxlHrSWVUMlWxfahJSKWxWxr06Et0Ul9iE5OJS+xNQlIfuiX1pnvPFHok9T7tY++dz0f18UqOVx6lurKUmmNHqT5aRH35QXwVh4g5fpiEmsOk1B2kv6/4U+8DOGT9KOo+krq0cXQfMpnMseeSnJYekO0lEml02GIEiY6JYeCw0Qxs5XyWJ44f40jBLioKd3GieC++ykNEHT9CXE0J3etKSa04QJ/ycuKt/pTrqXbxNFg0DcQ0XSyGRmLwWRTRroEY10jT0qav3V0NieYjsYXHqnfRlFoyFTEpHEo6h329hxHbN4teGaPpP3QsA3r1YUAAto1IV6dCjzDdEpMYPHICjJzQ6hjn81FZeZSK4gJOHCuntqqc+uqjNFZX0HiiAldTgdUdB18D5qv3XxqI8tVjrhFfVCzOYnBRTReiYvDFJWEJPbGEXsQk9ia2e2+6J/cnuf9gklPT6R8dTf9O2woiXZMKvQuyqCh69k6hZ+8Ur6OISABFxuERIiKiQhcRiRQqdBGRCKFCFxGJ
ECp0EZEIoUIXEYkQKnQRkQihQhcRiRCefZaLmRUD+zr4z1OBkgDGCZRQzQWhm025To9ynZ5IzDXYOZfW0gLPCv1MmFluax9O46VQzQWhm025To9ynZ6ulktTLiIiEUKFLiISIcK10Bd6HaAVoZoLQjebcp0e5To9XSpXWM6hi4jIZ4XrHrqIiJxEhS4iEiHCotDN7Hdmts3MNprZS2bWu5Vxl5rZdjPLN7MfdkKur5nZFjPzmVmrhyCZ2V4z22Rm680s6CdSPY1cnbq9/OvsY2ZvmdlO/9fkVsY1+rfXejNbHqQsp3z+ZhZvZs/6l68ysyHByNGBXHPMrLjZ9lnQSbkeN7MjZra5leVmZvf6c280s0khkutCM6totr1+1km5Ms3sb2a21f//8Y4WxgR2mznnQv4CfBGI8V//DfCbFsZEA7uAYUAcsAEYE+Rco4GRwHtA9inG7QVSO3F7tZnLi+3lX+9vgR/6r/+wpe+lf1lVkHO0+fyBW4GH/ddnAs92wvZpT645wP2d9fPUbL0XAJOAza0svxx4AzBgGrAqRHJdCLzqwfZKByb5rycBO1r4XgZ0m4XFHrpz7q/OuQb/zZVARgvDcoB859xu51wd8AwwI8i58pxz24O5jo5oZ65O315+M4Al/utLgKs6YZ0tac/zb571eeBiM7MQyOUJ59z7QNkphswAlromK4HeZpYeArk84ZwrdM6t818/BuQBA08aFtBtFhaFfpJ5NP1GO9lA4ECz2wV8duN5xQF/NbO1Znaj12H8vNpe/Zxzhf7rRUC/VsYlmFmuma00s6uCkKM9z/+TMf4digog2Cdibe/35av+P9GfN7PMIGdqr1D+PzjdzDaY2RtmNrazV+6frpsIrDppUUC3WcicJNrM3oYWTwz/Y+fcX/xjfgw0AMtCKVc7nO+cO2hmfYG3zGybf6/C61xBcapszW8455yZtXbc7GD/NhsGvGtmm5xzuwKdNUy9AjztnKs1s5to+iviIo8zhbJ1NP08VZnZ5cDLQFZnrdzMegAvAN9xzlUGc10hU+jOuS+carmZzQG+DFzs/JNPJzkINN9TyfDfF9Rc7XyMg/6vR8zsJZr+rD6jQg9ArqBsLzh1NjM7bGbpzrlC/5+WR1p5jH9us91m9h5NezeBLPT2PP9/jikwsxigF1AawAwdyuWca57hMZpelwgFQfuZOhPNS9Q597qZPWhmqc65oH9ol5nF0lTmy5xzL7YwJKDbLCymXMzsUuDfgSudc9WtDFsDZJnZUDOLo+lFrKAcHXE6zCzRzJL+eZ2mF3hbfDW+k3m1vZYDs/3XZwOf+WvCzJLNLN5/PRU4D9ga4Bztef7Ns14DvNvKzkSn5jppjvVKmuZmQ8Fy4Hr/kRvTgIpm02ueMbP+/3ztw8xyaOq9YP9ixr/ORUCec+6eVoYFdpt19iu/HXy1OJ+meab1/ss/jzwYALx+0ivGO2jak/txJ+T6V5rmvGqBw8CbJ+ei6WiFDf7LllDJ5cX28q8zBXgH2Am8DfTx358NPOa/fi6wyb/NNgHzg5TlM88fuJumHQeABOA5/8/famBYJ22jtnL9l/9naQPwN2BUJ+V6GigE6v0/X/OBm4Gb/csNeMCfexOnOPKrk3Pd3mx7rQTO7aRc59P0+tnGZt11eTC3md76LyISIcJiykVERNqmQhcRiRAqdBGRCKFCFxGJECp0EZEIoUIXEYkQKnQRkQjxv4Xl++YoSoOfAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def gelu(x):\n",
    "    \"\"\"Exact GELU: x * Phi(x), where Phi is the standard normal CDF.\"\"\"\n",
    "    half_x = .5 * x\n",
    "    return half_x * (1 + t.erf(x / 2**.5))\n",
    "\n",
    "\n",
    "def gelu_deriv(x):\n",
    "    \"\"\"Analytic derivative of GELU: Phi(x) + x * phi(x) (standard normal CDF / PDF).\"\"\"\n",
    "    pdf_term = x * t.exp(-x * x / 2) / (2 * t.pi)**.5\n",
    "    return .5 * t.erf(x / 2**.5) + pdf_term + .5\n",
    "\n",
    "\n",
    "xs = t.linspace(-2, 2, 1000)\n",
    "# The reference implementation should agree with torch's default GELU.\n",
    "assert t.allclose(gelu(xs), t.nn.functional.gelu(xs), atol=1e-6)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(xs.numpy(), t.nn.functional.gelu(xs).numpy(), label='torch.nn.functional.gelu')\n",
    "ax.plot(xs.numpy(), gelu(xs).numpy(), '--', label='gelu (reference)')\n",
    "ax.set(title='GELU: torch vs. reference implementation', xlabel='x', ylabel='GELU(x)')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8609781e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.collections.PathCollection at 0x7f65f9b72910>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQQElEQVR4nO3dfYxddZ3H8fenT1pcpGiHRNti0S3Grm6ATBBDsrpRdwtmWxON0GwjJgSCLmY3GpNuNC7B/UOXrCFm2V0xa3zIyoP+QSYBQ7KKS0Io2yEgCgRTK9oWs4wI/UcU0O/+Mbfs7XBn7pn2dm7nx/uV3PQ8fOec7+/OnU/PnHPunVQVkqTlb8W4G5AkjYaBLkmNMNAlqREGuiQ1wkCXpEasGteO169fX5s3bx7X7iVpWbr//vt/VVUTg9aNLdA3b97M9PT0uHYvSctSkp/Pt85TLpLUCANdkhphoEtSIwx0SWqEgS5JjRjbXS7HavPu28fdgiSNzOOff9/ItrWsjtANc0mtGWWuLatAlyTNz0CXpEYY6JLUCANdkhqxrAJ9lFeDJelk8LK9ywVmB//459/H9Zecw4Z1awmwYd1adl1w5lHz119yDlvOeNVLvv70U1azMoO3fcrqFWSedYP072vXBWeysvfFKxO2nPGqgdt61ZqVR33dhW96zUtq1q1dzemnrD6q7si45+5n1wVnvvhcAAv2vyJw4Zte82Ltke3MN4ZdF5x51HM992vmPufzzS9GwsD9HlXT4Xmar+8jz/9CjtR2fSmcfsrqYxrr6hVH729F3w7XrZ3d5rq1qxe1zXFbxI+PGP1Basb1R6InJyfLT1uUpMVJcn9VTQ5at+yO0CVJgxnoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhoxNNCTfDXJk0l+PM/6JPlSkn1JHkpy3ujblCQN0+WPRH8N+BfgG/OsvwjY0nu8Hfi33r/S8vL17fCz/x53F3q5uebwyDY19Ai9qu4Gfr1AyQ7gGzVrD7AuyetG1aC0JAxzjcs1p41sU6M4h74BONA3f7C3TFo+DHM1YEkviia5Msl0kumZmZml3LUkNW8UgX4I2NQ3v7G37CWq6saqmqyqyYmJiRHsWpJ0xCgCfQr4cO9ulwuAw1X1yxFsV1o6Z71z3B1Ix63LbYs3AfcCb05yMMnlSa5KclWv5A5gP7AP+ArwsRPWrXSiXDZlqGs8RniXy9DbFqtq55D1BfzNyDqSxuWyqXF3IB0X3ykqSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiM6BXqSbUkeS7Ivye4B689McleSB5I8lOTi0bcqSVrI0EBPshK4AbgI2ArsTLJ1TtlngFur6lzgUuBfR92oJGlhXY7Qzwf2VdX+qnoOuBnYMaemgFf3pk8Dnhhdi5KkLroE+gbgQN/8wd6yftcAu5IcBO4APj5oQ0muTDKdZHpmZuYY2pUkzWdUF0V3Al+rqo3AxcA3k7xk21V1Y1VNVtXkxMTEiHYtSYJugX4I2NQ3v7G3rN/lwK0AVXUv8Epg/SgalCR10yXQ9wJbkpyVZA2zFz2n5tT8Ang3QJK3MBvonlORpCU0NNCr6gXgauBO4FFm72Z5OMm1Sbb3yj4JXJHkh8BNwEeqqk5U05Kkl1rVpaiq7mD2Ymf/ss/2TT8CXDja1iRJi+E7RSWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiM6BXqSbUkeS7Ivye55aj6U5JEkDyf51mjb
lCQNs2pYQZKVwA3Ae4GDwN4kU1X1SF/NFuDvgQur6ukkZ5yohiVJg3U5Qj8f2FdV+6vqOeBmYMecmiuAG6rqaYCqenK0bUqShukS6BuAA33zB3vL+p0NnJ3kniR7kmwbtKEkVyaZTjI9MzNzbB1LkgYa1UXRVcAW4F3ATuArSdbNLaqqG6tqsqomJyYmRrRrSRJ0C/RDwKa++Y29Zf0OAlNV9XxV/Qz4CbMBL0laIl0CfS+wJclZSdYAlwJTc2puY/bonCTrmT0Fs390bUqShhka6FX1AnA1cCfwKHBrVT2c5Nok23tldwJPJXkEuAv4VFU9daKaliS9VKpqLDuenJys6enpsexbkparJPdX1eSgdb5TVJIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGdAr0JNuSPJZkX5LdC9R9IEklmRxdi5KkLoYGepKVwA3ARcBWYGeSrQPqTgX+Frhv1E1KkobrcoR+PrCvqvZX1XPAzcCOAXWfA74A/HaE/UmSOuoS6BuAA33zB3vLXpTkPGBTVd2+0IaSXJlkOsn0zMzMopuVJM3vuC+KJlkBfBH45LDaqrqxqiaranJiYuJ4dy1J6tMl0A8Bm/rmN/aWHXEq8FbgB0keBy4AprwwKklLq0ug7wW2JDkryRrgUmDqyMqqOlxV66tqc1VtBvYA26tq+oR0LEkaaGigV9ULwNXAncCjwK1V9XCSa5NsP9ENSpK6WdWlqKruAO6Ys+yz89S+6/jbkiQtlu8UlaRGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjegU6Em2JXksyb4kuwes/0SSR5I8lOR7Sd4w+lYlSQsZGuhJVgI3ABcBW4GdSbbOKXsAmKyqPwW+A/zTqBuVJC2syxH6+cC+qtpfVc8BNwM7+guq6q6q+k1vdg+wcbRtSpKG6RLoG4ADffMHe8vmcznw3UErklyZZDrJ9MzMTPcuJUlDjfSiaJJdwCRw3aD1VXVjVU1W1eTExMQody1JL3urOtQcAjb1zW/sLTtKkvcAnwbeWVW/G017kqSuuhyh7wW2JDkryRrgUmCqvyDJucCXge1V9eTo25QkDTM00KvqBeBq4E7gUeDWqno4ybVJtvfKrgP+CPh2kgeTTM2zOUnSCdLllAtVdQdwx5xln+2bfs+I+5IkLZLvFJWkRhjoktQIA12SGmGgS1IjDHRJaoSBLkmNMNAlqREGuiQ1wkCXpEYY6JLUCANdkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqhIEuSY3oFOhJtiV5LMm+JLsHrH9Fklt66+9LsnnknUqSFrRqWEGSlcANwHuBg8DeJFNV9Uhf2eXA01X1x0kuBb4AXDLybq85beSblBZ0zeFxdyB11uUI/XxgX1Xtr6rngJuBHXNqdgBf701/B3h3koyuTQxzjYevOy0jXQJ9A3Cgb/5gb9nAmqp6ATgMvHYUDUqSulnSi6JJrkwynWR6ZmZmKXctSc3rEuiHgE198xt7ywbWJFkFnAY8NXdDVXVjVU1W1eTExMSxdSxJGqhLoO8FtiQ5K8ka4FJgak7NFHBZb/qDwPerqkbXpiRpmKGB3jsnfjVwJ/AocGtVPZzk2iTbe2X/Abw2yT7gE8BLbm08bt5toHHwdadlJOM6kJ6cnKzp6emx7FuSlqsk91fV5KB1vlNUkhphoEtSIwx0SWqEgS5JjTDQJakRBrokNcJAl6RGGOiS1AgDXZIaYaBLUiMMdElqxNA/QXey+euv3Ms9P/31SLcZ
4Fg+0eb6S87hujsf44lnnmXdKaupgsPPPj9w+plnn2dlwu/nfHbOkX2vTNj59k384/vftuA+b3vg0Iv7fP26tXzqL9/M+8/dcNTy/v3313Td1qCaQ888e9TzdPopq/mHv/qTgds9Xp+57UfcdN8Bfl/FisArVq3gt8//YcGxDNNlrIupOxEWu+/++tPWriaBp3/z/6+zDcfZ/7B+jnX9oOXAcT3vi3nu5mbIhW96Df95xTuOeXsnk2X14Vybd99+grqRpKV3/SXnLPo/iiY+nMswl9Sav7vlQW57YO7fCzp2yybQJalF19352Mi2ZaBL0hg98cyzI9uWgS5JY/T6dWtHti0DXZLG6MhdPqOwbAL98c+/b9wtSNJIHctdLgtZVvehG+qSNL9lc4QuSVqYgS5JjTDQJakRBrokNcJAl6RGjO3DuZLMAD8fy86Pz3rgV+Nu4jgt9zEs9/5h+Y/B/sfnDVU1MWjF2AJ9uUoyPd8nnS0Xy30My71/WP5jsP+Tk6dcJKkRBrokNcJAX7wbx93ACCz3MSz3/mH5j8H+T0KeQ5ekRniELkmNMNAlqREG+jySbEvyWJJ9SXYPWP+KJLf01t+XZPMY2lxQhzF8IskjSR5K8r0kbxhHn/MZ1n9f3QeSVJKT6ja0Lv0n+VDve/Bwkm8tdY/DdHgNnZnkriQP9F5HF4+jz/kk+WqSJ5P8eJ71SfKl3vgeSnLeUvc4UlXlY84DWAn8FHgjsAb4IbB1Ts3HgH/vTV8K3DLuvo9hDH8OnNKb/ujJNIYu/ffqTgXuBvYAk+Pue5HP/xbgAeD03vwZ4+77GMZwI/DR3vRW4PFx9z2nvz8DzgN+PM/6i4HvAgEuAO4bd8/H8/AIfbDzgX1Vtb+qngNuBnbMqdkBfL03/R3g3UmyhD0OM3QMVXVXVf2mN7sH2LjEPS6ky/cA4HPAF4DfLmVzHXTp/wrghqp6GqCqnlziHofpMoYCXt2bPg14Ygn7G6qq7gZ+vUDJDuAbNWsPsC7J65amu9Ez0AfbABzomz/YWzawpqpeAA4Dr12S7rrpMoZ+lzN7pHKyGNp/79fjTVV1+1I21lGX5/9s4Owk9yTZk2TbknXXTZcxXAPsSnIQuAP4+NK0NjKL/Tk5qS2rv1ikEyPJLmASeOe4e+kqyQrgi8BHxtzK8VjF7GmXdzH729HdSd5WVc+Ms6lF2gl8rar+Ock7gG8meWtV/WHcjb0ceYQ+2CFgU9/8xt6ygTVJVjH76+ZTS9JdN13GQJL3AJ8GtlfV75aoty6G9X8q8FbgB0keZ/b859RJdGG0y/N/EJiqquer6mfAT5gN+JNFlzFcDtwKUFX3Aq9k9oOvlotOPyfLhYE+2F5gS5Kzkqxh9qLn1JyaKeCy3vQHge9X7yrLSWLoGJKcC3yZ2TA/2c7fLth/VR2uqvVVtbmqNjN7DWB7VU2Pp92X6PIauo3Zo3OSrGf2FMz+JexxmC5j+AXwboAkb2E20GeWtMvjMwV8uHe3ywXA4ar65bibOmbjvip7sj6Yvfr9E2av8n+6t+xaZkMDZl+43wb2Af8DvHHcPR/DGP4L+F/gwd5jatw9L6b/ObU/4CS6y6Xj8x9mTxs9AvwIuHTcPR/DGLYC9zB7B8yDwF+Mu+c5/d8E/BJ4ntnfiC4HrgKu6vse3NAb349OttfQYh++9V+SGuEpF0lqhIEuSY0w0CWpEQa6JDXCQJekRhjoktQIA12SGvF/DLH4S5734XoAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "def gelu(x):\n",
    "    \"\"\"Exact GELU: x * Phi(x), where Phi is the standard normal CDF.\"\"\"\n",
    "    # 2**.5 instead of np.sqrt(2): numpy is never imported in this notebook.\n",
    "    return .5 * x * (1 + t.erf(x / 2**.5))\n",
    "\n",
    "\n",
    "def gelu_deriv(x):\n",
    "    \"\"\"Analytic derivative of GELU: Phi(x) + x * phi(x) (standard normal CDF / PDF).\"\"\"\n",
    "    pdf_term = x * t.exp(-x * x / 2) / (2 * t.pi)**.5\n",
    "    return .5 * t.erf(x / 2**.5) + pdf_term + .5\n",
    "\n",
    "    \n",
    "def pack_last_axis(tensor, n_bits):\n",
    "    \"\"\"Stochastically quantize `tensor` to `n_bits` per value along its last axis.\n",
    "\n",
    "    Each row is min-max scaled onto the integer grid [0, 2**n_bits - 1].\n",
    "    Values are then rounded up with probability equal to their fractional part\n",
    "    (unbiased stochastic rounding). Codes are stored as uint8, so n_bits must be <= 8.\n",
    "\n",
    "    Returns (quantized, mn, mx); mn/mx keep a trailing singleton dim so they\n",
    "    broadcast against `quantized` when unpacking.\n",
    "    \"\"\"\n",
    "    mn = tensor.min(dim=-1, keepdim=True).values\n",
    "    mx = tensor.max(dim=-1, keepdim=True).values\n",
    "\n",
    "    # A constant row has mx == mn; divide by 1 instead of 0 to avoid NaNs\n",
    "    # (all its codes become 0, which unpacks back to mn exactly).\n",
    "    value_range = t.where(mx > mn, mx - mn, t.ones_like(mx))\n",
    "    scaled = (tensor - mn) / value_range * (2**n_bits - 1)\n",
    "\n",
    "    quantized = t.floor(scaled).to(t.uint8)\n",
    "    probs = 1 - (scaled - quantized)  # probability of keeping the floor\n",
    "    # rand_like draws directly on the input's device instead of CPU + copy.\n",
    "    quantized += t.rand_like(probs) > probs\n",
    "\n",
    "    return quantized, mn, mx\n",
    "\n",
    "\n",
    "def unpack_last_axis(quantized, mn, mx, n_bits):\n",
    "    \"\"\"Map integer codes in [0, 2**n_bits - 1] back onto the per-row range [mn, mx].\"\"\"\n",
    "    span = mx - mn\n",
    "    levels = 2**n_bits - 1\n",
    "    return quantized * span / levels + mn\n",
    "\n",
    "\n",
    "def pack_unpack(tensor, n_bits, group_size):\n",
    "    \"\"\"Round-trip `tensor` through n_bits quantization in groups of `group_size`.\n",
    "\n",
    "    The flattened tensor is split into consecutive groups of `group_size` values,\n",
    "    each quantized with its own min/max; the result has the original shape.\n",
    "    \"\"\"\n",
    "    original_shape = tensor.shape\n",
    "    groups = tensor.reshape(-1, group_size)\n",
    "    quantized, mn, mx = pack_last_axis(groups, n_bits)\n",
    "    return unpack_last_axis(quantized, mn, mx, n_bits).reshape(*original_shape)\n",
    "\n",
    "\n",
    "N_BITS = 1\n",
    "# xs is random (and the quantizers may round stochastically); seed for a reproducible plot.\n",
    "t.manual_seed(0)\n",
    "\n",
    "xs = t.randn(1000, 256) * 3\n",
    "values = gelu_deriv(xs)\n",
    "\n",
    "ctx = Ctx()\n",
    "ActNNGeluFn.forward(ctx, xs, N_BITS)\n",
    "res1, _ = ActNNGeluFn.backward(ctx, t.ones_like(xs))\n",
    "\n",
    "ctx = Ctx()\n",
    "NBitsGelu.forward(ctx, xs, N_BITS)\n",
    "res2, _ = NBitsGelu.backward(ctx, t.ones_like(xs))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(values.numpy().flatten(), res1.numpy().flatten(), s=1, label='ActNNGeluFn')\n",
    "ax.scatter(values.numpy().flatten(), res2.numpy().flatten(), s=1, label='NBitsGelu')\n",
    "# Points on this diagonal are gradients that match the exact derivative.\n",
    "lims = [values.min().item(), values.max().item()]\n",
    "ax.plot(lims, lims, 'k--', linewidth=1, label='exact')\n",
    "ax.set(title=f'{N_BITS}-bit GELU backward vs. exact derivative',\n",
    "       xlabel='exact gelu_deriv(x)', ylabel='gradient from backward()')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "faa8ed21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUO0lEQVR4nO3df5Bd9Xnf8fcnItA0MUVYG0aRoCu7IhNMWtnewXRau6TEWOAOstsMlWYSZIex7Bg6TZ1pIzd/4LHLDG7ieIYZiiPXGqATg4mJy06QQxRqh2knslkC5VdMWLAIq8poDQ60JSUBP/3jfte5iP1xd+/dvbvo/Zq5s+c+53vOfY52tZ89P+65qSokSSe2Hxp2A5Kk4TMMJEmGgSTJMJAkYRhIkoCTht3AUm3YsKFGR0eH3YYkrSn33Xffd6tq5Pj6mg2D0dFRJiYmht2GJK0pSZ6are5hIkmSYSBJMgwkSRgGkiQMA0kShoEkCcNAkkQPYZBkf5JjSR7uqn0pyQPtcTjJA60+muQvu+Z9rmuZtyd5KMlkkuuSpNVPT3IwyePt6/pl2E5J0jx62TO4EdjeXaiqf1lV26pqG3A78Ltds5+YmVdVH+mq3wB8CNjaHjPr3AvcXVVbgbvbc0nSClowDKrqHuC52ea1v+4vA26Zbx1JNgKnVtWh6nyazs3A+9rsHcBNbfqmrvqyGd17J6N771zul5GkNaPfcwbvBJ6pqse7aluS3J/kj5K8s9U2AVNdY6ZaDeCMqjrapr8DnDHXiyXZk2QiycT09HSfrUvS2rHcf8T2Gwa7ePVewVHgrKp6K/Ax4ItJTu11ZW2vYc7P4ayqfVU1VlVjIyOvuc+SJGmJlnyjuiQnAf8cePtMrapeAl5q0/cleQI4GzgCbO5afHOrATyTZGNVHW2Hk44ttSdJ0tL0s2fws8C3quoHh3+SjCRZ16bfROdE8ZPtMNALSc5v5xkuB+5oi40Du9v07q66JGmF9HJp6S3AHwM/mWQqyRVt1k5ee+L4XcCD7VLTLwMfqaqZk88fBf4zMAk8AXy11a8F3p3kcToBc+3SN0eStBQLHiaqql1z1D8wS+12OpeazjZ+Ajh3lvqzwIUL9SFJWj6+A1mSZBhIkgwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJIkewiDJ/iTHkjzcVftEkiNJHmiPS7rmfTzJZJLHkrynq7691SaT7O2qb0nyjVb/UpKTB7mBkqSF9bJncCOwfZb6Z6tqW3scAEhyDrATeEtb5j8lWZdkHXA9cDFwDrCrjQX4dFvX3wO+B1zRzwZJkhZvwTCoqnuA53pc3w7g1qp6qaq+DUwC57XHZFU9WVV/BdwK7EgS4J8CX27L3wS8b3GbIEnqVz/nDK5K8mA7jLS+1TYBT3eNmWq1uepvBP6iql4+ri5JWkFLDYMbgDcD24CjwGcG1dB8kuxJMpFkYnp6eiVeUpJOCEsKg6p6pqpeqarvA5+ncxgI4AhwZtfQza02V/1Z4LQkJx1Xn+t191XVWFWNjYyMLKV1SdIslhQGSTZ2PX0/MHOl0TiwM8kpSbYAW4FvAvcCW9uVQyfTOck8XlUFfA34ubb8buCOpfQkSVq6kxYakOQW4AJgQ5Ip4GrggiTbgAIOAx8GqKpHktwGPAq8DFxZVa+09VwF3AWsA/ZX1SPtJX4VuDXJfwDuB74wqI2TJPVmwTCoql2zlOf8hV1V1wDXzFI/AByYpf4kf3OYSZI0BL4DWZJkGEiSDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCTRw+cZSJJebXTvnT+YPnzte4fYyeC4ZyBJMgxWyujeO1/114QkrSaGgSTJMJAkGQaSJHoIgyT7kxxL8nBX7deTfCvJg0m+kuS0
Vh9N8pdJHmiPz3Ut8/YkDyWZTHJdkrT66UkOJnm8fV2/DNspSZpHL3sGNwLbj6sdBM6tqr8P/Bnw8a55T1TVtvb4SFf9BuBDwNb2mFnnXuDuqtoK3N2eS5JW0IJhUFX3AM8dV/uDqnq5PT0EbJ5vHUk2AqdW1aGqKuBm4H1t9g7gpjZ9U1ddkrRCBnHO4BeBr3Y935Lk/iR/lOSdrbYJmOoaM9VqAGdU1dE2/R3gjLleKMmeJBNJJqanpwfQuiQJ+gyDJL8GvAz8disdBc6qqrcCHwO+mOTUXtfX9hpqnvn7qmqsqsZGRkb66FyS1G3Jt6NI8gHgnwEXtl/iVNVLwEtt+r4kTwBnA0d49aGkza0G8EySjVV1tB1OOrbUniRJS7OkPYMk24F/B1xaVS921UeSrGvTb6JzovjJdhjohSTnt6uILgfuaIuNA7vb9O6uuiRphSy4Z5DkFuACYEOSKeBqOlcPnQIcbFeIHmpXDr0L+GSSvwa+D3ykqmZOPn+UzpVJP0LnHMPMeYZrgduSXAE8BVw2kC2TJPVswTCoql2zlL8wx9jbgdvnmDcBnDtL/VngwoX6kCQtH9+BLEkyDCRJhoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CSRI9hkGR/kmNJHu6qnZ7kYJLH29f1rZ4k1yWZTPJgkrd1LbO7jX88ye6u+tuTPNSWuS5JBrmRkqT59bpncCOw/bjaXuDuqtoK3N2eA1wMbG2PPcAN0AkP4GrgHcB5wNUzAdLGfKhrueNfS5K0jHoKg6q6B3juuPIO4KY2fRPwvq76zdVxCDgtyUbgPcDBqnquqr4HHAS2t3mnVtWhqirg5q51SZJWQD/nDM6oqqNt+jvAGW16E/B017ipVpuvPjVL/TWS7EkykWRienq6j9YlSd0GcgK5/UVfg1jXAq+zr6rGqmpsZGRkuV9Okk4Y/YTBM+0QD+3rsVY/ApzZNW5zq81X3zxLXZK0QvoJg3Fg5oqg3cAdXfXL21VF5wPPt8NJdwEXJVnfThxfBNzV5r2Q5Px2FdHlXeuSJK2Ak3oZlOQW4AJgQ5IpOlcFXQvcluQK4Cngsjb8AHAJMAm8CHwQoKqeS/Ip4N427pNVNXNS+qN0rlj6EeCr7SFJWiE9hUFV7Zpj1oWzjC3gyjnWsx/YP0t9Aji3l14kSYPnO5AlSYaBJMkwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJ9BEGSX4yyQNdjxeS/HKSTyQ50lW/pGuZjyeZTPJYkvd01be32mSSvf1ulCRpcU5a6oJV9RiwDSDJOuAI8BXgg8Bnq+o3uscnOQfYCbwF+AngD5Oc3WZfD7wbmALuTTJeVY8utTdJ0uIsOQyOcyHwRFU9lWSuMTuAW6vqJeDbSSaB89q8yap6EiDJrW3ssofB6N47fzB9+Nr3LvfLSdKqNahzBjuBW7qeX5XkwST7k6xvtU3A011jplptrvprJNmTZCLJxPT09IBalyT1HQZJTgYuBX6nlW4A3kznENJR4DP9vsaMqtpXVWNVNTYyMjKo1UrSqjW6985XHcVYLoM4THQx8CdV9QzAzFeAJJ8Hfq89PQKc2bXc5lZjnrokaQUM4jDRLroOESXZ2DXv/cDDbXoc2JnklCRbgK3AN4F7ga1JtrS9jJ1trCRphfS1Z5DkR+lcBfThrvJ/TLINKODwzLyqeiTJbXRODL8MXFlVr7T1XAXcBawD9lfVI/30JUlanL7CoKr+L/DG42q/MM/4a4BrZqkfAA7004skael8B7IkyTCQJBkGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoP72Ms1b+bDI/z4S0nDthIfZnM89wwkqQ8r9Ulk
y809g2X0evgBkXRicM9AkmQYSJIMA0kShoEkCcNAksQAwiDJ4SQPJXkgyUSrnZ7kYJLH29f1rZ4k1yWZTPJgkrd1rWd3G/94kt399iVJ6t2gLi39mar6btfzvcDdVXVtkr3t+a8CFwNb2+MdwA3AO5KcDlwNjAEF3JdkvKq+N6D+etZ9OahvQJPU7fV8ufhyvc9gB3BBm74J+DqdMNgB3FxVBRxKclqSjW3swap6DiDJQWA7cMsy9SdJq84ww2YQYVDAHyQp4Leqah9wRlUdbfO/A5zRpjcBT3ctO9Vqc9VfJckeYA/AWWedNYDWV557HpJWo0GEwT+uqiNJfhw4mORb3TOrqlpQ9K0FzT6AsbGxgaxTkgZhrd/frO8wqKoj7euxJF8BzgOeSbKxqo62w0DH2vAjwJldi29utSP8zWGlmfrX++1tWF7PxxUlvT71dTVRkh9N8oaZaeAi4GFgHJi5Img3cEebHgcub1cVnQ883w4n3QVclGR9u/LoolaTJK2AfvcMzgC+kmRmXV+sqt9Pci9wW5IrgKeAy9r4A8AlwCTwIvBBgKp6LsmngHvbuE/OnEyWpNez1XIkoa8wqKongX8wS/1Z4MJZ6gVcOce69gP7++ln0BZ7DHC1fFMlabG8hbUkDVCvVwyutj8eDYMerPWrBCStHqstBGYYBpK0TFbrL/7ZGAZLtJa+yZK0EMNgEQwASa9X3sJakmQYSJIMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJNFHGCQ5M8nXkjya5JEk/7rVP5HkSJIH2uOSrmU+nmQyyWNJ3tNV395qk0n29rdJkqTF6ufDbV4GfqWq/iTJG4D7khxs8z5bVb/RPTjJOcBO4C3ATwB/mOTsNvt64N3AFHBvkvGqerSP3iRJi7DkMKiqo8DRNv2/k/wpsGmeRXYAt1bVS8C3k0wC57V5k1X1JECSW9tYw0CSVshAzhkkGQXeCnyjla5K8mCS/UnWt9om4OmuxaZaba76bK+zJ8lEkonp6elBtC5JYgBhkOTHgNuBX66qF4AbgDcD2+jsOXym39eYUVX7qmqsqsZGRkYGtVpJOuH1c86AJD9MJwh+u6p+F6Cqnuma/3ng99rTI8CZXYtvbjXmqUuSVkA/VxMF+ALwp1X1m131jV3D3g883KbHgZ1JTkmyBdgKfBO4F9iaZEuSk+mcZB5fal+SpMXrZ8/gHwG/ADyU5IFW+/fAriTbgAIOAx8GqKpHktxG58Twy8CVVfUKQJKrgLuAdcD+qnqkj74kSYvUz9VE/x3ILLMOzLPMNcA1s9QPzLecJGl5+Q5kSZJhIEkyDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiT6/DwD9Wd0752ven742ve+Zl53TdJgHf9/8ERmGKwis/1gdtdmgmG2cYaGNLvj/7AyAGZnGKwh8/0Qz7YnMd94w0Nr1VJ/mRsC8zMMXmd6/YHvZdxCwXL8X1qLDZjZ9nq0dMv977nY9fe6V6vVIVU17B6WZGxsrCYmJpa0rD+Qa98gf7nM9ottvoDrJRh7fc3lGr9SVmtfr2f9Bn2S+6pq7DV1w0CS1o7lCgMvLZUkGQaSJMNAksQqCoMk25M8lmQyyd5h9yNJJ5JVEQZJ1gHXAxcD5wC7kpwz3K4k6cSxKsIAOA+YrKonq+qvgFuBHUPuSZJOGKvlTWebgKe7nk8B7zh+UJI9wJ729P8keWwFehu0DcB3h91En9b6Ntj/8K31bRha//l036v4u7MVV0sY9KSq9gH7ht1HP5JMzHaN71qy1rfB/odvrW/DWu9/NqvlMNER4Myu55tbTZK0AlZLGNwLbE2yJcnJwE5gfMg9SdIJY1UcJqqql5NcBdwFrAP2V9UjQ25ruazpw1zNWt8G+x++tb4Na73/11iz9yaSJA3OajlMJEkaIsNAkmQYLJeFbq+R5JQkX2rz
v5FkdAhtzqmH/j+W5NEkDya5O8ms1y4PU6+3OEnyL5JUklV1qWAv/Se5rH0fHknyxZXucSE9/BydleRrSe5vP0uXDKPP2STZn+RYkofnmJ8k17VtezDJ21a6x4GqKh8DftA5Cf4E8CbgZOB/AuccN+ajwOfa9E7gS8Pue5H9/wzwt9v0L62m/nvdhjbuDcA9wCFgbNh9L/J7sBW4H1jfnv/4sPtewjbsA36pTZ8DHB523129vQt4G/DwHPMvAb4KBDgf+Mawe+7n4Z7B8ujl9ho7gJva9JeBC5NkBXucz4L9V9XXqurF9vQQnfeGrCa93uLkU8Cngf+3ks31oJf+PwRcX1XfA6iqYyvc40J62YYCTm3Tfwf4XyvY37yq6h7guXmG7ABuro5DwGlJNq5Md4NnGCyP2W6vsWmuMVX1MvA88MYV6W5hvfTf7Qo6fyGtJgtuQ9utP7OqVuNH3/XyPTgbODvJ/0hyKMn2FeuuN71swyeAn08yBRwA/tXKtDYQi/1/sqqtivcZaO1K8vPAGPBPht3LYiT5IeA3gQ8MuZV+nETnUNEFdPbM7kny01X1F8NsapF2ATdW1WeS/EPgvyQ5t6q+P+zGTjTuGSyPXm6v8YMxSU6is4v87Ip0t7Cebg+S5GeBXwMuraqXVqi3Xi20DW8AzgW+nuQwnWO+46voJHIv34MpYLyq/rqqvg38GZ1wWC162YYrgNsAquqPgb9F5yZwa8Hr6jY6hsHy6OX2GuPA7jb9c8B/q3ZWahVYsP8kbwV+i04QrLZj1bDANlTV81W1oapGq2qUznmPS6tqYjjtvkYvP0P/lc5eAUk20Dls9OQK9riQXrbhz4ELAZL8FJ0wmF7RLpduHLi8XVV0PvB8VR0ddlNL5WGiZVBz3F4jySeBiaoaB75AZ5d4ks5Jqp3D6/jVeuz/14EfA36nnff+86q6dGhNH6fHbVi1euz/LuCiJI8CrwD/tqpWy95lr9vwK8Dnk/wbOieTP7Ba/ihKcgudsN3QzmlcDfwwQFV9js45jkuASeBF4IPD6XQwvB2FJMnDRJIkw0CShGEgScIwkCRhGEiSMAwkSRgGkiTg/wPoGnsekcWibgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYMAAAD4CAYAAAAO9oqkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUO0lEQVR4nO3df5Bd9Xnf8fcnItA0MUVYG0aRoCu7IhNMWtnewXRau6TEWOAOstsMlWYSZIex7Bg6TZ1pIzd/4LHLDG7ieIYZiiPXGqATg4mJy06QQxRqh2knslkC5VdMWLAIq8poDQ60JSUBP/3jfte5iP1xd+/dvbvo/Zq5s+c+53vOfY52tZ89P+65qSokSSe2Hxp2A5Kk4TMMJEmGgSTJMJAkYRhIkoCTht3AUm3YsKFGR0eH3YYkrSn33Xffd6tq5Pj6mg2D0dFRJiYmht2GJK0pSZ6are5hIkmSYSBJMgwkSRgGkiQMA0kShoEkCcNAkkQPYZBkf5JjSR7uqn0pyQPtcTjJA60+muQvu+Z9rmuZtyd5KMlkkuuSpNVPT3IwyePt6/pl2E5J0jx62TO4EdjeXaiqf1lV26pqG3A78Ltds5+YmVdVH+mq3wB8CNjaHjPr3AvcXVVbgbvbc0nSClowDKrqHuC52ea1v+4vA26Zbx1JNgKnVtWh6nyazs3A+9rsHcBNbfqmrvqyGd17J6N771zul5GkNaPfcwbvBJ6pqse7aluS3J/kj5K8s9U2AVNdY6ZaDeCMqjrapr8DnDHXiyXZk2QiycT09HSfrUvS2rHcf8T2Gwa7ePVewVHgrKp6K/Ax4ItJTu11ZW2vYc7P4ayqfVU1VlVjIyOvuc+SJGmJlnyjuiQnAf8cePtMrapeAl5q0/cleQI4GzgCbO5afHOrATyTZGNVHW2Hk44ttSdJ0tL0s2fws8C3quoHh3+SjCRZ16bfROdE8ZPtMNALSc5v5xkuB+5oi40Du9v07q66JGmF9HJp6S3AHwM/mWQqyRVt1k5ee+L4XcCD7VLTLwMfqaqZk88fBf4zMAk8AXy11a8F3p3kcToBc+3SN0eStBQLHiaqql1z1D8wS+12OpeazjZ+Ajh3lvqzwIUL9SFJWj6+A1mSZBhIkgwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJIkewiDJ/iTHkjzcVftEkiNJHmiPS7rmfTzJZJLHkrynq7691SaT7O2qb0nyjVb/UpKTB7mBkqSF9bJncCOwfZb6Z6tqW3scAEhyDrATeEtb5j8lWZdkHXA9cDFwDrCrjQX4dFvX3wO+B1zRzwZJkhZvwTCoqnuA53pc3w7g1qp6qaq+DUwC57XHZFU9WVV/BdwK7EgS4J8CX27L3wS8b3GbIEnqVz/nDK5K8mA7jLS+1TYBT3eNmWq1uepvBP6iql4+ri5JWkFLDYMbgDcD24CjwGcG1dB8kuxJMpFkYnp6eiVeUpJOCEsKg6p6pqpeqarvA5+ncxgI4AhwZtfQza02V/1Z4LQkJx1Xn+t191XVWFWNjYyMLKV1SdIslhQGSTZ2PX0/MHOl0TiwM8kpSbYAW4FvAvcCW9uVQyfTOck8XlUFfA34ubb8buCOpfQkSVq6kxYakOQW4AJgQ5Ip4GrggiTbgAIOAx8GqKpHktwGPAq8DFxZVa+09VwF3AWsA/ZX1SPtJX4VuDXJfwDuB74wqI2TJPVmwTCoql2zlOf8hV1V1wDXzFI/AByYpf4kf3OYSZI0BL4DWZJkGEiSDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCTRw+cZSJJebXTvnT+YPnzte4fYyeC4ZyBJMgxWyujeO1/114QkrSaGgSTJMJAkGQaSJHoIgyT7kxxL8nBX7deTfCvJg0m+kuS0
Vh9N8pdJHmiPz3Ut8/YkDyWZTHJdkrT66UkOJnm8fV2/DNspSZpHL3sGNwLbj6sdBM6tqr8P/Bnw8a55T1TVtvb4SFf9BuBDwNb2mFnnXuDuqtoK3N2eS5JW0IJhUFX3AM8dV/uDqnq5PT0EbJ5vHUk2AqdW1aGqKuBm4H1t9g7gpjZ9U1ddkrRCBnHO4BeBr3Y935Lk/iR/lOSdrbYJmOoaM9VqAGdU1dE2/R3gjLleKMmeJBNJJqanpwfQuiQJ+gyDJL8GvAz8disdBc6qqrcCHwO+mOTUXtfX9hpqnvn7qmqsqsZGRkb66FyS1G3Jt6NI8gHgnwEXtl/iVNVLwEtt+r4kTwBnA0d49aGkza0G8EySjVV1tB1OOrbUniRJS7OkPYMk24F/B1xaVS921UeSrGvTb6JzovjJdhjohSTnt6uILgfuaIuNA7vb9O6uuiRphSy4Z5DkFuACYEOSKeBqOlcPnQIcbFeIHmpXDr0L+GSSvwa+D3ykqmZOPn+UzpVJP0LnHMPMeYZrgduSXAE8BVw2kC2TJPVswTCoql2zlL8wx9jbgdvnmDcBnDtL/VngwoX6kCQtH9+BLEkyDCRJhoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJGAaSJAwDSRKGgSQJw0CSRI9hkGR/kmNJHu6qnZ7kYJLH29f1rZ4k1yWZTPJgkrd1LbO7jX88ye6u+tuTPNSWuS5JBrmRkqT59bpncCOw/bjaXuDuqtoK3N2eA1wMbG2PPcAN0AkP4GrgHcB5wNUzAdLGfKhrueNfS5K0jHoKg6q6B3juuPIO4KY2fRPwvq76zdVxCDgtyUbgPcDBqnquqr4HHAS2t3mnVtWhqirg5q51SZJWQD/nDM6oqqNt+jvAGW16E/B017ipVpuvPjVL/TWS7EkykWRienq6j9YlSd0GcgK5/UVfg1jXAq+zr6rGqmpsZGRkuV9Okk4Y/YTBM+0QD+3rsVY/ApzZNW5zq81X3zxLXZK0QvoJg3Fg5oqg3cAdXfXL21VF5wPPt8NJdwEXJVnfThxfBNzV5r2Q5Px2FdHlXeuSJK2Ak3oZlOQW4AJgQ5IpOlcFXQvcluQK4Cngsjb8AHAJMAm8CHwQoKqeS/Ip4N427pNVNXNS+qN0rlj6EeCr7SFJWiE9hUFV7Zpj1oWzjC3gyjnWsx/YP0t9Aji3l14kSYPnO5AlSYaBJMkwkCRhGEiSMAwkSRgGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJGEYSJIwDCRJ9BEGSX4yyQNdjxeS/HKSTyQ50lW/pGuZjyeZTPJYkvd01be32mSSvf1ulCRpcU5a6oJV9RiwDSDJOuAI8BXgg8Bnq+o3uscnOQfYCbwF+AngD5Oc3WZfD7wbmALuTTJeVY8utTdJ0uIsOQyOcyHwRFU9lWSuMTuAW6vqJeDbSSaB89q8yap6EiDJrW3ssofB6N47fzB9+Nr3LvfLSdKqNahzBjuBW7qeX5XkwST7k6xvtU3A011jplptrvprJNmTZCLJxPT09IBalyT1HQZJTgYuBX6nlW4A3kznENJR4DP9vsaMqtpXVWNVNTYyMjKo1UrSqjW6985XHcVYLoM4THQx8CdV9QzAzFeAJJ8Hfq89PQKc2bXc5lZjnrokaQUM4jDRLroOESXZ2DXv/cDDbXoc2JnklCRbgK3AN4F7ga1JtrS9jJ1trCRphfS1Z5DkR+lcBfThrvJ/TLINKODwzLyqeiTJbXRODL8MXFlVr7T1XAXcBawD9lfVI/30JUlanL7CoKr+L/DG42q/MM/4a4BrZqkfAA7004skael8B7IkyTCQJBkGkiQMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoP72Ms1b+bDI/z4S0nDthIfZnM89wwkqQ8r9Ulk
y809g2X0evgBkXRicM9AkmQYSJIMA0kShoEkCcNAksQAwiDJ4SQPJXkgyUSrnZ7kYJLH29f1rZ4k1yWZTPJgkrd1rWd3G/94kt399iVJ6t2gLi39mar6btfzvcDdVXVtkr3t+a8CFwNb2+MdwA3AO5KcDlwNjAEF3JdkvKq+N6D+etZ9OahvQJPU7fV8ufhyvc9gB3BBm74J+DqdMNgB3FxVBRxKclqSjW3swap6DiDJQWA7cMsy9SdJq84ww2YQYVDAHyQp4Leqah9wRlUdbfO/A5zRpjcBT3ctO9Vqc9VfJckeYA/AWWedNYDWV557HpJWo0GEwT+uqiNJfhw4mORb3TOrqlpQ9K0FzT6AsbGxgaxTkgZhrd/frO8wqKoj7euxJF8BzgOeSbKxqo62w0DH2vAjwJldi29utSP8zWGlmfrX++1tWF7PxxUlvT71dTVRkh9N8oaZaeAi4GFgHJi5Img3cEebHgcub1cVnQ883w4n3QVclGR9u/LoolaTJK2AfvcMzgC+kmRmXV+sqt9Pci9wW5IrgKeAy9r4A8AlwCTwIvBBgKp6LsmngHvbuE/OnEyWpNez1XIkoa8wqKongX8wS/1Z4MJZ6gVcOce69gP7++ln0BZ7DHC1fFMlabG8hbUkDVCvVwyutj8eDYMerPWrBCStHqstBGYYBpK0TFbrL/7ZGAZLtJa+yZK0EMNgEQwASa9X3sJakmQYSJIMA0kShoEkCcNAkoRhIEnCMJAkYRhIkjAMJEkYBpIkDANJEoaBJAnDQJKEYSBJwjCQJNFHGCQ5M8nXkjya5JEk/7rVP5HkSJIH2uOSrmU+nmQyyWNJ3tNV395qk0n29rdJkqTF6ufDbV4GfqWq/iTJG4D7khxs8z5bVb/RPTjJOcBO4C3ATwB/mOTsNvt64N3AFHBvkvGqerSP3iRJi7DkMKiqo8DRNv2/k/wpsGmeRXYAt1bVS8C3k0wC57V5k1X1JECSW9tYw0CSVshAzhkkGQXeCnyjla5K8mCS/UnWt9om4OmuxaZaba76bK+zJ8lEkonp6elBtC5JYgBhkOTHgNuBX66qF4AbgDcD2+jsOXym39eYUVX7qmqsqsZGRkYGtVpJOuH1c86AJD9MJwh+u6p+F6Cqnuma/3ng99rTI8CZXYtvbjXmqUuSVkA/VxMF+ALwp1X1m131jV3D3g883KbHgZ1JTkmyBdgKfBO4F9iaZEuSk+mcZB5fal+SpMXrZ8/gHwG/ADyU5IFW+/fAriTbgAIOAx8GqKpHktxG58Twy8CVVfUKQJKrgLuAdcD+qnqkj74kSYvUz9VE/x3ILLMOzLPMNcA1s9QPzLecJGl5+Q5kSZJhIEkyDCRJGAaSJAwDSRKGgSQJw0CShGEgScIwkCRhGEiSMAwkSRgGkiT6/DwD9Wd0752ven742ve+Zl53TdJgHf9/8ERmGKwis/1gdtdmgmG2cYaGNLvj/7AyAGZnGKwh8/0Qz7YnMd94w0Nr1VJ/mRsC8zMMXmd6/YHvZdxCwXL8X1qLDZjZ9nq0dMv977nY9fe6V6vVIVU17B6WZGxsrCYmJpa0rD+Qa98gf7nM9ottvoDrJRh7fc3lGr9SVmtfr2f9Bn2S+6pq7DV1w0CS1o7lCgMvLZUkGQaSJMNAksQqCoMk25M8lmQyyd5h9yNJJ5JVEQZJ1gHXAxcD5wC7kpwz3K4k6cSxKsIAOA+YrKonq+qvgFuBHUPuSZJOGKvlTWebgKe7nk8B7zh+UJI9wJ729P8keWwFehu0DcB3h91En9b6Ntj/8K31bRha//l036v4u7MVV0sY9KSq9gH7ht1HP5JMzHaN71qy1rfB/odvrW/DWu9/NqvlMNER4Myu55tbTZK0AlZLGNwLbE2yJcnJwE5gfMg9SdIJY1UcJqqql5NcBdwFrAP2V9UjQ25ruazpw1zNWt8G+x++tb4Na73/11iz9yaSJA3OajlMJEkaIsNAkmQYLJeFbq+R5JQkX2rz
v5FkdAhtzqmH/j+W5NEkDya5O8ms1y4PU6+3OEnyL5JUklV1qWAv/Se5rH0fHknyxZXucSE9/BydleRrSe5vP0uXDKPP2STZn+RYkofnmJ8k17VtezDJ21a6x4GqKh8DftA5Cf4E8CbgZOB/AuccN+ajwOfa9E7gS8Pue5H9/wzwt9v0L62m/nvdhjbuDcA9wCFgbNh9L/J7sBW4H1jfnv/4sPtewjbsA36pTZ8DHB523129vQt4G/DwHPMvAb4KBDgf+Mawe+7n4Z7B8ujl9ho7gJva9JeBC5NkBXucz4L9V9XXqurF9vQQnfeGrCa93uLkU8Cngf+3ks31oJf+PwRcX1XfA6iqYyvc40J62YYCTm3Tfwf4XyvY37yq6h7guXmG7ABuro5DwGlJNq5Md4NnGCyP2W6vsWmuMVX1MvA88MYV6W5hvfTf7Qo6fyGtJgtuQ9utP7OqVuNH3/XyPTgbODvJ/0hyKMn2FeuuN71swyeAn08yBRwA/tXKtDYQi/1/sqqtivcZaO1K8vPAGPBPht3LYiT5IeA3gQ8MuZV+nETnUNEFdPbM7kny01X1F8NsapF2ATdW1WeS/EPgvyQ5t6q+P+zGTjTuGSyPXm6v8YMxSU6is4v87Ip0t7Cebg+S5GeBXwMuraqXVqi3Xi20DW8AzgW+nuQwnWO+46voJHIv34MpYLyq/rqqvg38GZ1wWC162YYrgNsAquqPgb9F5yZwa8Hr6jY6hsHy6OX2GuPA7jb9c8B/q3ZWahVYsP8kbwV+i04QrLZj1bDANlTV81W1oapGq2qUznmPS6tqYjjtvkYvP0P/lc5eAUk20Dls9OQK9riQXrbhz4ELAZL8FJ0wmF7RLpduHLi8XVV0PvB8VR0ddlNL5WGiZVBz3F4jySeBiaoaB75AZ5d4ks5Jqp3D6/jVeuz/14EfA36nnff+86q6dGhNH6fHbVi1euz/LuCiJI8CrwD/tqpWy95lr9vwK8Dnk/wbOieTP7Ba/ihKcgudsN3QzmlcDfwwQFV9js45jkuASeBF4IPD6XQwvB2FJMnDRJIkw0CShGEgScIwkCRhGEiSMAwkSRgGkiTg/wPoGnsekcWibgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Histogram of gelu_deriv over inputs drawn from N(0, 2^2).\n",
    "# Seed torch's RNG so the figure is identical on every Restart & Run All.\n",
    "t.manual_seed(0)\n",
    "xs = t.randn(1000, 256) * 2\n",
    "values = gelu_deriv(xs)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(values.numpy().flatten(), bins='rice')\n",
    "ax.set(title='gelu_deriv(xs), xs ~ N(0, 2^2)', xlabel='gelu_deriv(x)', ylabel='count')\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
