{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "from random import seed\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Figure style: small serif fonts, with all text typeset through LaTeX\n",
    "# (usetex=True requires a working LaTeX installation when figures are drawn).\n",
    "plt.rcParams.update({\n",
    "    'font.family': 'serif',\n",
    "    'font.size': 8,\n",
    "    'text.usetex': True,\n",
    "})\n",
    "\n",
    "# Seed PyTorch and Python's `random` module so that runs are repeatable.\n",
    "th.manual_seed(0)\n",
    "seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings(n, d, norm=True):\n",
    "    \"\"\"\n",
    "    Sample `n` random embeddings of dimension `d` with standard Gaussian entries.\n",
    "\n",
    "    n: int\n",
    "        Number of embeddings (rows of the returned matrix).\n",
    "    d: int\n",
    "        Embedding dimension (columns of the returned matrix).\n",
    "    norm: bool\n",
    "        If True, every row is rescaled to unit Euclidean norm.\n",
    "        If False, every entry is divided by sqrt(d) instead.\n",
    "\n",
    "    Returns a torch.Tensor of size n x d.\n",
    "    \"\"\"\n",
    "    emb = th.randn(n, d)\n",
    "    scale = emb.norm(dim=1, keepdim=True) if norm else sqrt(d)\n",
    "    return emb / scale\n",
    "\n",
    "\n",
    "class AssMem(nn.Module):\n",
    "    \"\"\"Associative memory whose scores are `E[x] @ W @ U`; `W` is the only trainable weight.\"\"\"\n",
    "\n",
    "    def __init__(self, E, U):\n",
    "        r\"\"\"\n",
    "        E: torch.Tensor\n",
    "            Input embedding matrix of size $n \\times d$,\n",
    "            where $n$ is the number of tokens and $d$ is the embedding dimension.\n",
    "        U: torch.Tensor\n",
    "            Output unembedding matrix of size $d \\times m$,\n",
    "            where $m$ is the number of classes and $d$ is the embedding dimension.\n",
    "        \"\"\"\n",
    "        # NOTE: the docstring above is a raw string; otherwise Python would read\n",
    "        # the `\\t` of `\\times` as a TAB character and garble the formulas.\n",
    "        super().__init__()\n",
    "        d = E.shape[1]\n",
    "        # W starts at zero, so the initial scores are zero for every class.\n",
    "        self.W = nn.Parameter(th.zeros(d, d))\n",
    "        # E and U are stored as given. Passed as plain tensors they are not\n",
    "        # registered as parameters, so an optimizer built from\n",
    "        # `self.parameters()` only updates W.\n",
    "        self.E = E\n",
    "        self.U = U\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: torch.Tensor\n",
    "            Integer token indices used to select rows of `E`.\n",
    "\n",
    "        Returns the scores `E[x] @ W @ U`, with one column per column of `U`.\n",
    "        \"\"\"\n",
    "        return self.E[x] @ self.W @ self.U\n",
    "\n",
    "\n",
    "class AssMemApprox(nn.Module):\n",
    "    \"\"\"\n",
    "    Surrogate of `AssMem` that tracks one scalar `q[j]` per token instead of a full matrix.\n",
    "\n",
    "    `prediction()` equals `E @ W_q @ U` with\n",
    "    `W_q = sum_j q[j] * outer(E[j], U[:, all_y[j]])`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, E, U, all_y):\n",
    "        r\"\"\"\n",
    "        E: torch.Tensor\n",
    "            Input embedding matrix of size $n \\times d$,\n",
    "            where $n$ is the number of tokens and $d$ is the embedding dimension.\n",
    "        U: torch.Tensor\n",
    "            Output unembedding matrix of size $d \\times m$,\n",
    "            where $m$ is the number of classes and $d$ is the embedding dimension.\n",
    "        all_y: torch.Tensor\n",
    "            Integer class index of each of the $n$ tokens, used to pick columns of U.\n",
    "        \"\"\"\n",
    "        # NOTE: the docstring above is a raw string; otherwise Python would read\n",
    "        # the `\\t` of `\\times` as a TAB character and garble the formulas.\n",
    "        super().__init__()\n",
    "        self.n, self.m = E.shape[0], U.shape[1]\n",
    "        # One scalar per token, all starting at zero.\n",
    "        self.q = th.zeros(self.n, dtype=th.float32)\n",
    "        # Gram matrix of the token embeddings: E2[i, j] = dot(E[i], E[j]).\n",
    "        self.E2 = E @ E.T\n",
    "        # U2[j, k] = dot(U[:, all_y[j]], U[:, k]): overlap between the\n",
    "        # unembedding of token j's class and the unembedding of class k.\n",
    "        self.U2 = U[:, all_y].T @ U\n",
    "\n",
    "    def update(self, x, gamma):\n",
    "        \"\"\"\n",
    "        Increase `q` in place from a batch of token indices.\n",
    "\n",
    "        x: torch.Tensor\n",
    "            1-D tensor of non-negative integer token indices (repeats allowed).\n",
    "        gamma: float\n",
    "            Step size applied to every occurrence of a token in `x`.\n",
    "\n",
    "        Token j moves by `gamma * count_j / (1 + exp(q[j]) / (m - 1))`, where\n",
    "        `count_j` is the number of times j appears in `x` and `q[j]` is the\n",
    "        value held before the call.\n",
    "        \"\"\"\n",
    "        counts = th.bincount(x, minlength=self.n)\n",
    "        # Per-occurrence increment; it shrinks towards zero as q[j] grows.\n",
    "        step = 1 / (1 + th.exp(self.q) / (self.m - 1))\n",
    "        self.q += gamma * (step * counts)\n",
    "\n",
    "    def prediction(self):\n",
    "        \"\"\"Return the n x m score matrix `(E2 * q) @ U2`, one row per token.\"\"\"\n",
    "        return (self.E2 * self.q) @ self.U2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 100     # number of input tokens\n",
    "m = 5       # number of output classes\n",
    "d = 10      # memory (embedding) dimension\n",
    "alpha = 2   # Zipf parameter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population data: all token ids, their Zipf law p(x) proportional to\n",
    "# (x + 1) ** (-alpha), and the class of each token, y = x mod m.\n",
    "all_x = th.arange(n)\n",
    "all_y = all_x % m\n",
    "zipf_weights = (all_x + 1.) ** (-alpha)\n",
    "proba = zipf_weights / zipf_weights.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 1000           # number of samples per batch\n",
    "nb_epoch = 100              # number of batches (epochs)\n",
    "T = batch_size * nb_epoch   # total number of samples over all epochs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,"
     ]
    }
   ],
   "source": [
    "lr = 1e0\n",
    "nb_trials = 100\n",
    "\n",
    "# errors[0] follows the model trained by SGD, errors[1] the approximate model.\n",
    "# Entries start at -1 and are overwritten after every epoch of every trial.\n",
    "errors = th.full((2, nb_trials, nb_epoch), -1.)\n",
    "\n",
    "for trial in range(nb_trials):\n",
    "    # Fresh random embeddings for each trial\n",
    "    E = get_embeddings(n, d, norm=False)\n",
    "    U = get_embeddings(m, d, norm=True).T\n",
    "\n",
    "    model = AssMem(E, U)\n",
    "    approx_model = AssMemApprox(E, U, all_y)\n",
    "    opti = th.optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "    for epoch in range(nb_epoch):\n",
    "        # New batch of tokens drawn from `proba`, labelled by y = x mod m\n",
    "        x = th.multinomial(proba, batch_size, replacement=True)\n",
    "        y = x % m\n",
    "\n",
    "        # real update: one SGD step on the batch-averaged cross-entropy\n",
    "        opti.zero_grad()\n",
    "        loss = F.cross_entropy(model(x), y)\n",
    "        loss.backward()\n",
    "        opti.step()\n",
    "\n",
    "        # approx update, with the step size divided by the batch size\n",
    "        approx_model.update(x, lr / batch_size)\n",
    "\n",
    "        # Error = total probability of the tokens whose argmax class is wrong\n",
    "        with th.no_grad():\n",
    "            all_scores = (model(all_x), approx_model.prediction())\n",
    "            for i_model, scores in enumerate(all_scores):\n",
    "                wrong = scores.argmax(dim=-1) != all_y\n",
    "                errors[i_model, trial, epoch] = proba[wrong].sum().item()\n",
    "\n",
    "    print(trial, end=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Statistics over the trial axis (dim 1): one curve per model\n",
    "std_errors = errors.std(dim=1).numpy()\n",
    "mean_errors = errors.mean(dim=1).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAACGCAYAAABHRL+dAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhMUlEQVR4nO2deXBc15Wfv7f0hrXR4L4KDVPWZokESKUslzUjs+GlnNiODZAzybgczZSAUimVcRwXOXRiO55MDQVOTY2TWOMBlJScsuYPEpAiW7bHMkBpJG+RSbQp2Za1obmBEEkQ3Q+NpZe35Y/b3dhJoNGNpfG+qi6g+233Ae/X595zzzlXsm3bxsHBYcnIK90AB4dSwRGTg0OBcMTk4FAgHDE5OBQIR0wODgXCEZODQ4FwxOTgUCAcMTk4FAh1pRuwlrEsi8HBQSorK5EkaaWb41AAbNtmdHSUbdu2IcuLszWOmJbA4OAgO3fuXOlmOBSBy5cvs2PHjkUd44hpCVRWVgLiD19VVbXCrXEoBPF4nJ07d+b+t4vBEdMSyHbtqqqqHDGVGPl02x0xFQPbBmcMtbrQE3DuH8X/xrbASEH9Q7DlAwW7hCOmIpAcH8Fb4V/pZjhM5UofjA1N/8zUC3oJR0xFIDkadcRUbKLnwUgCmR7AxLD4KclgJGAiCu4KUN1gGnD9jaI3yRFTEUgkk/hTo+BZ/CDWYQFMROH1U6K7topwxFQELNsW1qmUxWTbEI2IrpKREGMQ2xKfyyqYKbCMyfeqR1gNywTbFNvSE5Pnk1VhRWQVFDfICkhK5hwWKC4w0uK8o1dXnZDAEdOcdHd3c/z4cfr6+hZ/sGVhxK8xPNzP9g27C9+41cJvumH43ZVuxapiRcTU3d1NNBqlr6+PlpYWQqHQks6naRqdnZ0AHDlyZNp1AKLRKMFgcMHXaW5upqOjI6+2hL//LRrOfZVz6r1sP/CpvM6xrNi2sBaKKqyM4rr1MbELjpDmYNnFFA6HAWhtbUXTNOrq6ojFYks6Z29vL8PDw9TW1uY+i0Qi9PT05ETR1NS0ZNEuBLk2CMAG41rRr1UQfv1diA+K8V1qVHSzPJWim6V4hLgkBdxlYkCvuOHGWyvd6lXJsospGo3S09NDc3Mzfr+fQCBAOBymoaEht0/W0ky1Mp2dnYRCIYLB4KxzNjc3E41G0TQt91lvby9+vz/33u/309vbSygUylnGqQQCAZqbmxd0D0888QRPPPEEpmnO2la+ZQ8Am+0bC/+mXylGr8LIFfF7Mi5+mroY4DssmmUXUygUmmYhotHoNCGBePAbGho4ceIER44cobu7m0AgMKeQ5qO/v3+apQoEAjmxLVQ08/HYY4/x2GOPEY/Hqa6unratdutukrYLr6STHurHveWOJV2r4FgW/O5ZqNoG2qWVbk1JsaIpGG1tbTz55JNzbguFQjQ0NNDW1kY0Gl2yAIBZ1mg+ent7iUQiuTHXYqgp93DJ3gxA/N3/t+jji07sPNx4ByIvi7kah4KxYt687u5umpqabimSaDRKIBBY9Pnr6+undfuyToiFEAqF6O/vX/Q1QcR0XZW3cDsDJAZ+k9c5CoKeFBOZliHGOe4ycJXDtd+uXJtKnBURU3Y8EwqFCIfD+P3+WQ96OBwmHA7T1dVFZ2cn3d3di7JOoVCIo0eP5t5HIpFlcUAA3HBtBR3MG/kJMr+LviPmfWwLxm/A6HvCSzcTJ2awaCy7mCKRCC0tLbn3mqYxs6ispmmcPHmS9vZ2QHj+Ojs7iUQic1qX3t5eenp60DSNYDBIc3MzwWCQw4cP55wNx44dK+6NTSHu3QY6KBNXQbsM/mXIeYr8sxDRrXAK+BYNySmPnD9ZB8TIyMi0FIxv/8P/4NGrXyWm1FLzub+Duz5d+ItPjUyPvwd93yn8NUqd2z8G26c7v+b7ny4EJwKiCCS3NGK+J1FjDsPlM3Dbh6Fs8eO+efnl34sQnqod
4PI6ruxVgiOmIrBxy3Z+Y9exV4oI79m7p+HellsfuBDGhyE5In6PRgpzzvXARBTiA2IC2tQhPQrjQ2KMuWN/QS7hiKkIbPP7+KV1N3vlCMQuitCba7+DzXcv/eTaxaWfY71hJOHc06BPTP/84s8hUOeIaTWzK1DOU9bdPMrz2LEIkmXC2y9A2Qao3HzrE1x/Ey79Eqq2Q229mGB1+cQ2R0zCmiRHAEmMGyUJRq+JUChJhkQUxq6JKQHFDe+dmxSSf/dkyFTV9sJ8wWVwxFQEbqsto8++gyG7mo3pERj6PWy+R8TB3fM58W14M8avi1Cf0asiQxTAUwG+GvGQlAr5pPfbNvz+efE3XSz3NEPt+ybfz+GAWAqOmIqAqshsqPDw1MTHOOI6BZdfhU13i776b5+BfX8ClVvmP8H40OzPUmPitRqwDLjxtrAOlgGJmIjtsy2Rq6R6RM0FUwdsMU5x+UB2if3NtHglNXKZsoprMjPWVZYJsnULR4ueEBYnPQ4TU9z/kiyOt01h9SVJtMFTDRWbxDWyEfG+WgjUF/XPkreYDhw4wLFjx/jsZz9byPaUDLsCZTw9GuLP3d/DMz4Ew+/AhtszgnoWGv+diEqYi/Hh5WuoqWeS79ziQVV94C6f32LYtvhCiC0yFCkx34bMzIyZFt2zhXLbh2H3hzKnsDLCWlnyFlNra+ssIb344ot85CMfWXKjSoG6LQF+dnGcF9UH+YTeAxd/BrV7xEOaHIFX/0EMfHc9IL45s1im+KZfDi6/CpGXxO+yC6xMgRHVA6pXWIisRfFWQfkmiF+ZFNLGO8W+Xj94qzPZsbIY8Ku+THatJGowGElhlWRFCFeSoawWYZnsSQ+bqQsLlLVeiltYLNsSfydJBX1cdJuzrAIhwRLEJEkSjz76KPX19QSDQaLRKF1dXY6YMtyxtRq4zJP6x/mE8gqMXRcD4W37xA5GCi78HK6Ewb8Ldj8AFZuFC7eYKdnnXxFWsnwTXP/d5OeWLgbmlinaZqQmXfBzsf0AvO9gYdtWyLm4FSBvMT3++OOEQiFu3LjBjRuiH7vQqOz1wL5dfgB+m9qEeeeHUSK9wgrICmy5d3JHPQFDb4mXyyv6/kvFMoXXb2RAfLurXmE50uNw6Rdin+y4rKYO7vqMEE72YU7ERL0FfVxYFFMX+ydHxO+qW4jfYRp5i6mjo4ODB6d/M50+fXrJDSoV3r+lCrcMaUvhjfL7+UD12zByCd76kXiod31w9kF6UghgsSRH4MIr4njLEB4/Izn//pIMG+8Q45/6j4juWMWmye3lGxffBof8xXTw4EHi8TinTp0C4NChQ7PEtZ5RZImtfh8XownCMQ8fuO+PxCThxZ/D+ZfFOGPTnYs/sW1nLERKdNUkCd75CURnRKi7yoT3yl0urF9yBJIxYaU+cHh+54dD3uQtpvPnz9PS0pKL4m5vb6erq4u9e/cWqm1rnuDGCi5GE/xmxAtSQnigTAMGXoW3/0lYj633zR5Aj12frLkwkzefnyyo6K4Q3cbs2Gb3h4RIy2rF+EtWinp/DtPJW0zPPPMMZ8+enfbZsWPHHDFN4c6tVbz01hBvTVQxnhqi3KNC8A9g7KoY07zzAlx9HbY3Cs+YrIh4u9+cEgLz7xLWpXYP+Pwi2S8rJEmB9JR5p+2NQqwOK0beYqqrmz2Lv39/YWKcVoIl1cqbh4ZdNQCcT1Xy5tVR7ttZjSrL8IFDMPhruPBTkcT35g9g4FewdS9c/pU42LZESa3YBeg/LaIfsi7z2j1w56dEhERWUBv2FKzdDvmRt5gikdkRy+fPL19NgdVUK28+7q8LoMoSY4bMmxMV1EQnCG7IdM127IeN7xeiGvy16Nq98xNxoJQR3Ph1EWkQH5w+97R9v4gYWI6kQ4cFk7eYQqEQH/3oR2lsbAREtms2M3Y5WE218uajyueifmMFb10b5eX0HdTHf8XGCg+V3kz5L08l1D0I2xpgMCyqBcUHxFxUzW3iteN+
MecTOy/moHx+qCnhSrFrmLzFtG/fPjo6OnIPbmdnJ/v27StYw27FctfKy5fG3X7eujbKr0ermfBX8c71Me7dkenuZfFUCFGBEI7inn4S1SNc2Q6rmiXH5j3++OOFbM+SKGatPIBUKkUqlcq9j8fjN93fNE0eel8Nr7w5iG3b/NDzCR4ov0LZyO+5rXYe13QJuKxdVgqFOYq5lDglH5uXT628+YR3/PhxvvGNbyzofGNjYwwMDLBZNjke2kLSsJCAck8d4/Z23pUMFLk0KwVJZoodE7+jwrxJOFIJUlKxecWulXfs2DG+9KUv5d5nFxOeiWmaDAwMUFZWxrZALSMJg9Gkjm5auGSZSp+K1xjDp1qoJaYn24ahkTEGuJs9o6+uKwtVUrF5xa6V5/F48Hg8t9xP13Vs22bjxo14vV6SdgrZ5SY6lsbAxkDFKt+AaacolxNIJVYgamN1BRdG4+iyB8WauPUBJcKajc1bbbXy5kKSJCRJwqsq2LZJhVdlNKkzljJwqTIoHjyeMrzWxPSFv9Y467XOZd518/bv389XvvKVdZkcOHUVjLfffntWjbVkMsn58+epq6vD6/WSNkxiEzq2DSMJnZRhosoSgXI3siwRKHOjmslMhHYqly9XTMLnXqflC230vfxj/P7qWx+wCJJpnfOXB6mLvyq+KFYrBa6bl3dWVVtb25wOiPXAY489xhtvvMGZM2cWtL9LkVFlCUmCKp+KLEkYlk08aWBZoCV0LNUnUiC8NUVuvaBh770Eb9u1LNdaL5SUA2I1Yts2Cd1EVSRGkwYAbkVGS6RJ6iamaeNzKyTTJlU+Vcw/6SbY4HPJSOu1z7QGKagDYnh4GWsXrBESusldX3shr2Pf+PJ9lLnnjvzufekV2r54lI5vttP+zSfo+j+d+P3VHP3aX3GgYS+RCxcJ/eGHadh7L0e/9lc0PfQgPS+9QtvDnydY50RQFIM164BYSW62cuByEXroQfzVVQRq/DkhdT71NLWBGpo/8y8BaPr0YXq+d5LaQA2hh0SERfs3n6Djv59YsXaXMgsW04svvpgLbp2aCHj69GkikQjhcJj6+vp1kSB4s5UDZ+JzKbzxlx/LvTcti+i4cEaYlk1sIo1l23hcClUeF5IEipXCa4wiSTYWNx/YNuydTIHvO/c6/uoqel96BYD2b/zn3LbOp55GGxkhGtPyuWWHBbBgMbW0tHD69OlZ+UoHDx7k4MGDaJpGfX09X/7ylwvdxjWNJEmUuaf/mV2KwkhCVAJyqTIjE2lsQLcsKr0qslSGZCro5jh22sSlSsiSeN1sBNW49176z1/IWaHwudfpfOpphqMxjnzxMcLnXudM+DXC516fJkKHwrBgb94jjzySE9KFCxemvUAElT7yyCPFaGPJ4XUpeF1iLORRZap8Ioo8qZvcGE0zljIwZA8Jd4Bx1c+45SaRNplImyQNC8OyCZ97nciFS3Q+9XTuvK0P/wm1gRo6n3qa7ud+AMD+ffeijcTpfemVnFWKXLg0efx3nsahMCzYMk0NII3FYrS2tnL48OFpcWz19cWtmFlKVHpU0oaFZdt4XQqyJLx9hmUxnjJIGxY+l4LHpWCplaRlH6qdRrXSGLrOXffczdCFN2bF9x354mOzrjV1jJS1WgD9r/2yeDe4DlmwmKamNuzbt49Dhw7N6tKtFzduIRwQsiwRKHehTegYlo1blQmUu0nqJqNJA9200E0LOSVR5XPhVlR0WcWQPbiNcXRkVCOJkjmXS5nRDZQkUVRFVkX5YDOdqd5aWqFLq4kFiykSiTA6OppbMlOSpGnvgbwXVV5rLMYBcTMUWcZf5mJ4PJ2rYe9zK7hVmYRuinkoy0abSKPIEm5Fxq3KGEoliiyRtsuRbQMJC9VIIUuiC+l2ezKVWGe41W0bUqOi1JhDwVlwOJEsT59AtG17zvcr6S5ebuYLPZkZTnQrsuFGM7Fsm7GkQVK3sKfEGEkI0fjcCi5l9rDX51Ko8IpIi1nY
NqTiYomVIhmp9RpOtGDL1NraytGjRwkE5i5hOzw8zIkTpTF/ka0jcebMmWVJxXerCtU+iCf0ac+3LIkuXoXXRjcs0qZN2rAwLIuEbpLQTTyqsGQeVc6NnxK6ScoQX2qKLCNLQoAel4JHlZG81eAqF4JyrFTBWLCY2tra5qxIlKW6upq2trYFX/jEiRO5XKOlZsEWsrhKd3c3fr+fUChEJBKhs7OT1tbWJbVvIXhdCqosCY+dbs4Slcel4MmUjkgbFom0EEz2NYZws5e5VDwuGStzAsucrFueNCzhPfS6kBUVlCpRmCW79Iu7XFirbJF9Sxc/57JgiipqAOaLJIlF3GxLvNITogt6M7JF//V5l9SYRM10daVMV9dbJZbx0ROi9kbt+8TPArJgMS2kvsNCa0A0NTXR1dWF3++nsbFxyWIqZHGVqW3p7+9f1BfEUlEVmSqfTIVXJWVYJHWTtDG7iL9bzYydTIWEbuWcFWnDIm2k8RpKZr5qdjcvZVgMjaUyApVRJBey7EJVZCTEWFjxzCh+aZlCYNmHWHEL4Y3fAGxRJVb1Zlbxk0EaE9vv+2MoKxeCnIiK7qXqFYUyFdfsQv2JmBCUdjmz5pMGvgBsukNYUpdXiOPCz0ThTSMhzp2MixoZ+oRo69b7xPI9ct5x3Hmx7IudhcPhnGcwHA7PWacua2mmWpnOzk5CodCcmbPFKK7S29tLY2MjDQ2zV5YrdjiRLEn4XAo+l4JuWjnv3kxURaYyM2YyTBFQO5E2ciJ0KTIuRcKlyMiShJKJXAcxHkuk526/IksokoQsS7hVEfGuuMqQ3eXTd6zI1CSfWZHW5RNeRG8VuDNjRu8CHDW+GvHy3yKa/f0fv/W5VoBlF9PZs2eJRCK50KS2trZZ9er8fj8NDQ2cOHGCI0eO0N3dTSAQWHAKOiytuEo4HEbTNFpbWwmHw7MEVShv3kJwKTI1ZS5000Y3LQzTImVYs3peqiJRqah4VJl4Use07EwXcHIfRc6I1K3M7ZzIYFo2JjaYYiIZyHkK1Uw6iUuRV826SKuFZReTpmkEAoHcA3r27Nk5H9hsl6ytrY3GxsaCjFsWklYfiURyNdSPHz++rLUA50OSJNyqsBIgrEp23JSeYbHcqkxtuYe0aWFmhGfaNpYlRDKWMhhLiWIuHlXJ5VmpspinQgIJaVa2rGXDRNqETE0HRZJQFTGvlW2BhLCqelrUu7gxliI9ahAod+Mvm1G+rARZ9q+WYDA4zcIEAoE5q8NmiUaj83oQb8bMaIyFFlcJBoP09/fT09NDX1/f0mtI2LbwmOXzmmfWQpYkvC6F49/4L5z9+cs8/pdf5dIFUU33lZde5I7dWzn13ac4/U/f56+/8iXGhwbZWOnm9Vd/xofvuY3uf/wOP37+Of7iS/+eN956h5GEzg9//BPuv+8unv/RC3z2X32CS+8N8bd/8zc8+0w3zz37DN9/7lleeelF/sV9d/H8c88SjcXYe/cddHV3Z8ZqQrjCyyiE/oPXBvneuUG+84sL/K+fRjh55hLhS7E5u6ylwLJbplAoNK1bN1/Rk3A4TDgcpquri87OzpuW4JrvOsUsrrJg9An46235HfuVQTGQn4cNGzbwyU98DI9L4clv/R3/8++/zcHQQXbfVsen/vXnqPb7ufe+fRz6zCd59bU3aGoKsfu2IH/8R4cpq6hi7759PHzo07zwy9d44MGHqKyuptpfw4lvf4fnn3sG3bL4UJNI5/hvf/Ef+bdf+FOefuZ5vvKf/gMAP3n5F1RPGZfOh23DaNJgNGkwqCX56ds3qPSqVPtc1G+q4PbNFZS5VZK6mYtZXIssu5j8fj9tbW10dnaiaRrt7e3THAUguoInT57MdbFaW1vp7OwkEonMaV3WQnGVYpH9O8ZiMTHnpMjIMtQGajAsm911dWixGCOaRrXfjwQEakRq/D133E5c01D0CQJ+P6os8+AD95M2LN5943U++OBHUGUZ07bZsfs2Xv7nF3n4
0T+n8YEH6e46xR9+/FOMJg0s2855AlVZwtRNTNsmbVokdNH1nBpDaNk2IwmdkYTOpegEL715nXKPwnjKpMrn4oPBWhK6idcls7XaR3UmEPi9kQRpw6JuQ/mqDF1bdjHBrR0Afr9/1ljlZmOmUCg0p9UpdunjBeEqExYm32PnobOzk+HhYY4cOUI4HObMmTO5sacsSShGgqrKKt67Poy/pmaaBckKa0TTpm3Ljnm8LoWGhkauX7lIbYUY61y/commT36awcsX2bHrNrRYlO/932do+uSnZ7XNNgzGUyYnz7zHlVEhip01ZZS5Feo2lLMrUDZLDOMpMRaLJ3Re+N3VadtUWXgWs9MEGyo9/MGejeyaryruCrEiYlrrLMo1Lkk37arly/79++nr66O3tzf3WSQSyTlyTp06RSAQ4MyZM7x4upfaCjcp3UKS4Efff5Yqfw3hs2c59dwPAXj93K+5eOE8333qf/P5h/+Mzz/8Z3zrm3/L8889ixaLsW9fAzcGL/Nfj3yRH730M/Y3NPCppgdJjI7wb77wp4BwUpiWjWnJ0xwYSd3ineti6ZvXBkaoKXPlnBJbq71s9/vwupRZIWpZDMsmNwsN3BhN8Ux4gO1+HzXlbgzT4lJ0guDGCsrdCh6XjGllFnq3bGwbaspd7KwpE2tkFYm8S305FC42r9A0NjbOu87U1G2GOenps2wbyxYPXtq0sCw779A9PZ3iyqWLnB1WSVgK52+MoyXSxBMGb16No5uzz1zmVkjoJjU+N3s2V+Tc+FurvVT7xKSyEIaNOkc84kKp8rmo9Kggwf7dNQQ3Tp+gXpbYPIe1haZps8aiM7epinzTB8C0xNyWZdnolp2bc1ooiiyhIPG+TZMP7AP1tQxqCeJJg+GxFFe0BLEJPeN2h+hEmlfPR2edp8IjCngC7NlUyb07qtla7V302Cme0Ilnspzv2LJC4UQOa4NwOJyLKZwaQXKrbXOhyBLKlDSOCo9CIm2RMkzR9coDr0uZZQ0Sukk8oeNRZa6OJHn7+hiKLDGeMhgeT5M2rFyaP8Bb10Z569qo8Ah6XSAJb+HmSg8bKj2UuRUkMWGWc8MHMl3KpVi1W+F085bAau3mLQeGaYlpMElM4NqIYjGGZZNMJrl44QJXjErGDImh0VTe4rNtm9iETiJtUulVSegmrw1ovHNtbNHnlCSo8bmpKXchSRK7AmV8/oO7eej9m3L7ON28ArPcKRhrkbm+4RVZwQ3IlorXpRC6fTNer5exlOjSpQ2L3w6OcOHGwnOcJEmUkSbjw6nyufjoXVv4w9strsWTTKRNLNumzK0wqCUZTelMpEzhmrRFmJVtw7V4kvG0SXQiTXQiDcC718d4oL52mpiWwoqKaWq6w1JYrSkY69Xoz7zvCo9KRcaLtmdzJYm0yMUaGk0xkhnDjKUMJAmux1OMpW6d2uFWZXYGprvGd9fO7zW1bZvxlMnQWCo3Ztq9oYz76xYfXTMfKyYmTdM4fvx4QSZTV1sKhssluhFDQ0Ns3LhxVU4wFgvbthkaGkKSJFwu15z7+Nwi2DZQPne8XiJtolsWIxM60fE0g1qCK1oC3bQxTCuvLqMkSVR4VSq8k4/8wTs3ce8O/6LPNR8rJqZTp05x+PDhObet9RQMRVHYsWMHAwMDuVJo6wlJktixYweKkl9okM+t4EOhyutiZ6CM+3b6c9tShkn/9XGujSZF1nEm9SSpm2gT6Tnd7svFiogpHA7nHuq5KIUUjIqKCvbs2YOuz67tUOq4XK68hXQrPKrCXduquIvZzoGss0KbSGNmrJdHVTBtm/dGEgyNii7ejbF0Udq2ImKKRCK3fKBLIQVDUZSiPVQOs8k6K+bqPtZtmBxPmZad9U8UlGUXU7b2Q3d3N2fOnKG/v59gMDhndwqWloKRz/q22RQMh9KlWAtzL3s+05EjR2hubs5Fdjc1Nc0ppKkpGNFodN4u4XyEQqFpi5GtWAqGw7phxRwQvb299Pb25oIz
p1qNtZKCkXUBx+Pxgp/bYWXI/i/zmdZwIiCWwMDAADt37lzpZjgUgcuXL7Njx45FHeOIaQlYlsXg4CCVlZXT5pLi8Tg7d+7k8uXLiw5JWa2sl3uybZvR0VG2bduGvMhSYU440RKQZfmm315VVVUl8+BlWQ/3lG/FKadWk4NDgXDE5OBQIBwxFQGPx8PXv/51PB7PSjelYDj3dGscB4SDQ4FwLJODQ4FwxOTgUCAc13iBySchcTXS0tKSixqZGo2y1u5vsYmjS7o/26Fg9Pf3262trbn3oVBoBVuzNBoaGmy/32+HQiE7FovZtr0276+rq8s+cuSI3d7envtsvvtY6v053bwCMl9C4lrk2LFjxGIxenp6cve0Fu+vubl51iIO893HUu/P6eYVkJslJK41shH32Ryw1tbWkrm/+e5jqffniKnILCQhcTUyNSmyvr6eQ4cOzbnfWr2/mcx3H4u5P6ebV0DyXRNqtdHd3T1tOR6/308kEimZ+5vvPpZ6f46YCkipJCRmkzazaJpGQ0NDydzffPex1PtzIiAKzFTX6syKR2uJqYU429ract/Qa+3+ent76ejoQNM02tracu2d7z6Wcn+OmBwcCoTTzXNwKBCOmBwcCoQjJgeHAuGIycGhQDhicnAoEI6YHBwKhCOmEiQcDtPW1oYkSRw9epTOzk5OnDjB0aNHqampKVpwanblkGzKw3rDmWcqUbLhP7FYbFokdDgc5uzZswVZCGEujh49Sn19fdHOv5pxLFOJMt9iB/MtkFAopkZdrzccMa0TsiutA/NGgDssDScFo8TJjl9OnjxJV1cXMJn01tbWRigUoqmpiWg0Sl9fH+3t7bluYTgcpre3l2AwmFtTa2oUdSQSoaOjgwMHDhCNRnMi1TQttyhDT09P7rqapnHq1CmCwSCappXeAtx55wM7rGpisZgN5FLO29vb7f7+/mn7zEzn7urqmpbCPTNtu6GhIXe+WCxmB4PB3Pvsudrb2+3m5ubcMaFQyO7r68u1oaenJ7eto6OjIPe6WnC6eeuEmev1ZpnqnGhubqa3txdN0+jo6Jg1vgoGg5w6dQogZ2Gyxx87dizndDhw4MC082cT7Jqbm2lpaaGxsZETJ06UXHfTEdM6IRgM5rpo2bHTUtA0bVa9hKnv5yIQCBCLxXjyyScZHh6mpaVlye1YTThiKlHmS7fWNI2+vr5p77N0d3cTCoXw+/0cPnx41nxUOBzOWZPm5mbC4fC07beavzp+/HhucbupY7NSwXFAlCDhcJiOjg5APMDZdOz+/n46OzunraLY39+f69qdOXMm5yzIPvDZNYiz27ICCAaDdHR0cPTo0Vy3bsOGDZw8eRIQ2ayRSCTXlmAwSG1tLb29vQQCAaLRKIcPH16uP8my4EzarmPW8wRrMXC6eQ4OBcIR0zolW3Sxq6tr1tjHIT+cbp6DQ4FwLJODQ4FwxOTgUCAcMTk4FAhHTA4OBcIRk4NDgXDE5OBQIBwxOTgUCEdMDg4FwhGTg0OB+P9kKE/FsqcbpAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 150x100 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(1.5, 1))\n",
    "epochs = np.arange(nb_epoch)\n",
    "leg = []\n",
    "for mean, std in zip(mean_errors, std_errors):\n",
    "    # Mean curve with a shaded band of half a standard deviation on each side\n",
    "    curve, = ax.plot(mean)\n",
    "    ax.fill_between(epochs, mean - .5*std, mean + .5*std, alpha=.5)\n",
    "    leg.append(curve)\n",
    "\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlabel('Epochs', fontsize=10)\n",
    "ax.set_ylabel('Error', fontsize=10)\n",
    "ax.legend(leg, ['real', 'approx'], loc=\"best\", fontsize=8, handlelength=1, frameon=True)\n",
    "fig.savefig(f'approx_sgd_{d}_{batch_size}.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAANMAAACJCAYAAAC2Eg1IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZ00lEQVR4nO2de3BbV53Hv1eyrdixY1mKkyax3VhqkjYtTSLZ2b4WaHOVAp0O21SyYYBu2aUS3cDMUjIRLrt0uuzgStthB2ZCa6WFgYVZbKvtFvqg6CY8+0hlCSht+kh083JDnMRX13ISP2Tp7h/yvZZsyZalq/f5zGjG99yHzrH01Tnnd36/36EEQRBAIBCyRlHoChAI5QIRE4EgE0RMBIJMEDERCDJBxEQgyAQRE4EgE0RMBIJMEDERCDJBxDSL2+2G0WgsdDUIJQxVqh4QPM/D5XIBAPbv3y+Vu91uAADHcdDpdKBpOu1nmkwmeDweeStKqBiqCl2BTGEYBqOjo9BqtVIZy7LweDzo6+sDEBPHcsREIGRDyYrJbDaD4zjwPC+VMQwDtVotHavVajAMA5qm4Xa7wXFcwjM0Gg3MZnPa7zk1NYWpqSnpOBqNguM4aLVaUBSVcVsIxYMgCBgfH8f69euhUCxvFlSyYkpGIBBI6Kk0Go0ktuWIJhW9vb149NFHs34Oofg5c+YMWlpalnVPWYkpGfN7o1QwDAOWZeF2u1MKr6enBw899JB0PDY2hra2Npw5cwarVq2auzASBibHgZWarOpOyD+hUAitra1oaGhY9r1lJSa9Xp8w7BONEOlA0zQCgcCi16hUKqhUqgXlq1atmhPTX93AS/uAzZ8E7nki7boTiotMhu1lZRqnaRper1c6Zlk2/waIVeuBiSDw3gtAeDK/700oKCXbMzEMA4/HA57nodPpYDabodPp0N3dLRkbenp6cvLeBw4cwIEDBxCJRBaebL0JkYYNUI5/iEO//Cl27flSTuogF5FIBOFwuNDVyDvV1dVQKpWyPrNk15mKgVAohMbGRoyNjSXMmfq/cz+6p5/Di5GduOvbxbtudenSJQwPD6MSvwIURaGlpQX19fUJ5ak+03Qo2Z6pWPnNe+fxf5e2orvmOWylThW6OimJRCIYHh5GXV0dmpubK8q0LwgCLly4gOHhYWzatEm2HoqISWYUCgpBIWYJqqeKd84UDochCAKam5tRW1tb6OrknebmZpw8eRLhcFg2MZWVASJfHDhwAFu3bkVnZ+eCcx/b3IynrXcAAOoxgUi0uIdQldQjxZOLdhMxZcDevXtx9OjRBMthPKtXxxaOa6lpXJ4s3t6JIC9ETDlAVdco/X1lnC9cRUocv9+/YO2wmCFzplygrMYkqrECYUxe4oG16wpdo0URBAET4SRm/jSorVbmbKhoMBjSXnQvBoiYcsQV1GEFxjB5aazQVVmSiXAEW7/1Skb3Hv2PO1FXQ75GABnmZcRiBgiRSUUdAGDqcvGLKVcwDAO9Xg+GYWAymaThmt1uh9vthtPphN/vl8oYhoHdbgfLsgWsdeaQn5QM2Lt3L/bu3Sst8CVjUlEHRIHwleIXU221Ekf/486M700FTdNQq9XQaDQYHByEWq2Gy+WCVquVnInFgEytViu5fjkcDikmrZQgYsoRU8qVwAwwMxEqdFWWhKKonA7VDAaD9LfP55PizICYcERcLhd4nk/b07/YIMO8JLjdbrjdbtjt9oyfEa5aCQCITha/mPKJmGeDpmmpJ3K5XBgdHYXVapXKxOFfKVEwMTmdTulLmy08z8PpdMLpdCaUi893uVzSL+FSuN1uqNVqmM1maLVaKc/EcpmRxDSe0f3lgN/vB8uyCf9Dq9Uq/V/Fz76jowM8z4NhGKlXYlk26f3FTEGGeSaTSRpDG43GrKNg5cwHEV+XQCAAm82WUZ0i1bPBZVOVKyaDwYBgMLigPD4Bjkj8HCn+c1oqxqyYyLuY/H6/
lKfB7/fD5/MtuEbMPBT/T3e5XKBpOum6Qy7yQTAMA6PRmDDeF1k0BGMWoSbmjayoYDFVGnkX09DQEFiWlcyfNpttgeVGrVbDYDDA6XRi//79cLvd0Gg0y1rAyyYfhN/vB8/zsFqt8Pv9CwSVjjUPqljPRIUvp11nQmmT9zkTz/PQaDQwGAwwGAwYGhpKOtmkaRoGgwE2mw0cx8mSECUdKxHLsrBYLOjr64PRaMzYskTNiqlq5lJG9xNKj7z3TDqdLqGH0Wg0YFk26XAKiAlAo1l+YpJM80HodDpZxumK2piYOi/9Bkcf/yRmJnisrg5jXaMKVNNG4N6ngeoVWb8PoXjIe89E03TCCneqPA1+vx9+vx+Dg4PgOG7ZVr9C54NQrpgb/m299BpujBzF+sljoEbejuWHOHMkb3Uh5IeMxdTZ2Ylnn3122fep1WrYbDa4XC44nU44HI4EQwEQGwr29/dLBgir1QqO41K6mYj5IDwejyS6+HwQLpcrZ/kgUlEd5zk+Iqjx1emv4L7puHWriYVWLkJpk3EOiIMHD+KBBx5IKDt8+DDuuOMOWSpWCiyWL+D4kAfXvBCb59019R28R7UjEhVwZKMLa8/9Frj7e4Dx/vxXepbJyUmcOHEC7e3tWLGi8oabqdpfkBwQFEXhwQcfhF6vh06nA8dxGBwcrAgxpWMab926E+dfXos/hduguaYDn6itxotv/Q0h1GMtAEzw+aouIU9kLKbHHnsMNE3j4sWLuHjxIoD0s6eWOumYxlV1jVjzb+9jtxDFnQolvv3CUQAAL8S8yTHJ56m2aSAIQPhKZvdW1wEVGvo+n4zF1NfXh127diWUHTp0KOsKlRUUBYqKeVWvaYhlgh2NzIqpmHqm8BXgO+szu/fhs0DNypSn7Xa75Blus9mg0+nAMAwsFgscDgc0Gg08Hg/sdnta58R1SYfDgcHBQbhcrgQrrTgndzgcoGkaRqMRDodDlqWVpchYTLt27UIoFMLAwAAAoKura4G4CHOsXRUbl58Pz47Pi6lnyiHJQitET5auri5pgd5kMiEQCCx5Lj6kQ/zuiUKx2Wyw2WyScIE5L/V8kLGYTpw4AYvFIv0qiL8U27dvl6tuZcWaVbGe6cOpWTEVU89UXRfrYTK9dwlShVaIX3Jxzs3zvFS22DlxTdLn88FkMknPEwMR9+/fD5PJhP7+/rz0SCIZm8afeeYZDA0NYWBgAAMDAzh27Bj6+/vlrFtZsaYhJqLhidnE/8XUM1FUbKiWyWuR+dJSoRXiorroFRPfgyx2TsRoNCYslwQCARgMBrAsKzkHyBGVkC4Z90zt7e0Lyjo6OrKqTLEgfgBerzcheC0b1s72TOemVwAqFFfPlCM6Ojrg8/kSwl/ivV0GBgag0Wjg9XoXbH+a7Fx8SIbVaoXVapVCeTiOk8Rls9ng8/lgMBgklzCr1Zrz9mYspmQLqCdOnFjWM8TYoWw9E+Tc3za+TvEfXLbUq6pQV6PEWHh2sl4Bi7YGgyFlaAUA6f+abCiW7FyykI5k4RzivWq1OmkISK7IWEw0TWP37t1S5CTDMMv6Fed5Hr29vbJ4JuQ7nimddab5UBSFNk0dRs/NimlyDIhGgWVu9VhOxM+DlnOuWMn4k9yxYwf6+vogCAIEQYDL5VrWgu3AwAC6u7uTnhMjZ+NxuVwp3YnMZjP0en1CWap4JgCSi1H8a/7YerF4pqUyuqZiy1UNCEE0IwvAVGWGtC8WQVtq0bXxZNwzdXZ2oqenB4899tiy7/X7/VKQXjKKPZ4pU7Zc1YDnUYNpqgY1wnTMCFGrluXZpUSqCNylzhU7GfdMVqsVe/bsSSg7fPhwWveK1pbFKPV4pmRsWRsLywhhdk+gIjBCVOLeTEBu2p133zyn0ymZLL1eLwKBAHQ6XdnFMyVjy1UxMXHROqymOOB/7gG+4gVWrs7J+y1GdXU1KIrChQsXKnZ/JoqiUF1dLdtz8+6bF2998Xq96OzsTCqk+Hgm
cU6znN6JpumEVF0F2d92HhvUtahXVeGDaAs2K4eBCQ548yBwe37DQwBAqVSipaUFw8PDOHnyZN7fv9CIOwfKuRVnwXzzGIYBwzDSukN8ryHGM4nWQavVKhkgkvUuhdzfdjlQFIXtrWp8/fiXcc2GZlx77pfAO88CH90HnH4DCE8svGmDEVipXVguA/X19di0aRPZ01YmMo5n6ujowMMPP7xg3lRJZBL78rMjp/DN597GznXVGAh9HohMLX7D2huAB1+VobaEdMgmniljA4TNZsvYAFHJfOL6q6CggDf/FsbljXHDzqpaYN12YP2O2KtpY6x85G1gOsPwCEJeIcGBSciFO5GItl6Fne0avMFyeKXta9jTvhNQVAE3dgH1axIvdmyMeUpwLHDVDbLWgyA/shogRkdH07rX7XZDp9NhaGgIALJ21ykFd6J4dm6MiemPI9XY0/WvqS/UXgMMe4HR40RMJUDeDRCiG5HP54NOp0NTU1PWX9ZScCeKx3B1EwDAf2qJxcl4MRGKnrTFdPjwYcmdJz4Q8NChQ1KSdb1ev2SAoFqtllIipzJXF3t65LQyui7CjrYmUBRwcvQKXg+MormhBgDwBsvB9XsWD/x9O75w80ZAM+sixZXm5l+VRtpislgsOHTo0ILgv127dmHXrl3geR56vR779u1L63kulwsejweDg4MLzpWrO5FIY201Nq9pwPsj4/jswTcWnP/x66diYtLOion0TCVB2ta8Bx54QBLSyZMnE15ATADzU38thtVqhc1mS7kHUjm6E8XzxVs3YnW9Ck111dJL5NToZUSiAhFTiZF2zxT/Kx8MBmG1WtHd3Z3wJZ/vuZ0K0b2epmlYLBZYLJaU85lycieK5zM72/CZnW0JZdGogOsfeQUT4QhOjl6GXhzmXRmN+fFVoFNsKZF2zxQ//9ixYwe6urqwb98+bNy4USpPx7/L5XKht7dXOtZoNEnFUurpkTNBoaCweW3MCfb9c+OAqh5YOWsuDy4v8JKQf9IWE8uyGB8fRygUQigUAkVRCcehUCitX/Suri50dnZKO2vbbLYFc5JySI+cKaIz7HvnZvd10symB+CImIqdtN2JFApFQs8jCELS40zNxaVINq4nqXj6jyfw7ReOwrR1LQ7e1wE892XgL/8L3PHvMR8+Qk7JizuR1WrF8ePHwXEcOI5DMBiU/uY4DsePH1+WAYKQHOPsGtTvPriA86FJQDM7zyPDvKInbQOEzWZLmpFIpLGxMeP9XwlzbGtpRMfVTRg6FcRNvYfwjRYBVgBn2aPIMOcqIU9k7DVOyM0wD4j1Sv/4wzcBANuo43he9S0AwOmPfx9tmqWTPgIAGluADQagSiVbvSqBbD5TIqYsyJWYgJigzvITqJkew73MbZk9pO0W4IsvkcT6y6AgW8oQcsvHNjdLf5/h9uOk9yXpWKmIBRnW1aQIbouEgVOvAqdfA9jfAPrS9+QvBYiYkpDLEIxMaLn7YTwa3I3D751HdHYcYWlswX9ZtqW+6eVvAEeeiOWZUKX4ha2uBTbfCdz2UMxiuOPzgLot+bWEJSnIME90NPX5fIt6P6RLrkIwnE4n1Gp1Sq/2XA7zUuE/HcSeH7wGADC0qfGtu6/H9lb1wgvHPgSeuDmW7HIplKpYxG/bzcAXX67oYWFJDfPExO1WqxU8z6O9vT3rPGn5DsEoJIa2Jtx2zWr88fhF+E/zsDz5Graub8SKqoWrHLWNP0RjAw8AuLGlEfff0g6lIk4oI28DA/fNhc6ffh142gSs2wZcdzeg+3juG1RG5F1MHMfB4/HAbDZLe+3M98wu9hCMbOOZsuXJLxjhPcHhZ0dOgXn3PP5yhl/k6loAwPNnpvHbi6Ogr1sbd24bPrZ2NzaO/BrTVQ2omRmPxU8NeyF4n4Z3yz68e/Xnl6zP2bEJ/PqdEUyGI6AAbG9T45o1MU+OW/Ra3KTT4gg7ilfeGcHYRBjNDSp8zbQJqip5E5oUmoJb85qampL2
TAzDwO/3SyEYwOKhE+IeQKIA7XY7tFqtdGyz2WAymdLyPhdT9JrN5kVDMAoxzJtP4MIlHBsZRySa+po3T4zix6+fSnquAVfwaeWreCnyd/gH5auoxwRuUJzAbqUPUYHC58IP4/Xo9RnXb0W1As/vvQ33/OBVXJme+/H55qeuQ1dnK948weH2Lc2oUibxHxgNAGNnklR6XSzWiz8151kvEyU1zIvHZrPh4MGDSc+JQzKbzQaj0ShL6PhyQjB0Oh16e3uLwgCxGPrmeuib6xe95q4b18FwdRN+fXQk6fkg9LgZwAi2YARAAMDKc4/j1tCLeLL2CfRefRDjVXPOyDePvYibQq9gUlGHX6z+Ei7UbQZ93Vpcra3DlekIht55H7cEvovoaMyX8vIPgJ8CoGoo1NUocXl6BmCAU4cV0EajONewAi1NtYmVCk8CI39N3ajG1pjQ7vou0PnPS/+j8kDBxOR2u9PqKco1BCPffHr7Bnx6+4b0b5h+Cjh4OxovvIfH2HuB1Ztj5VPjwPjfpMtuuHwE0G4CLszNxXZyJ4BoOLmz2gwSyxUALs++5iFQClCrtwDU3A2RqXEox07P9VgvPgREZ4AV6rkbKQq4+lagcRntlYGCiEmcz9A0Db/fD7VaveCLXo4ZXUuKmjqg6yfAUyZgagy4+EHieeP9wPl3gTNHgNFjC+9Xt0H46H64372Ck6OX0aZZia6OFlCgELh4CX/44CLqapR460Me58aS5w6cWX09fvjgHihmjSaCIOCeA6+ibfpX6Gk/jg3nDseMJy8v3KMJ2muAfzkCKPP3Fc/7nIllWWlPJyBmbJhfBTHpSvwQazEDBMMw6OvrA8/zsNlskujiTePzDQxyUAxzppxzhQPOH00sa1gXm6tEI8DZPwMz8zLRUopY7r/qeUO3JBw/fwmPv/I+Jmfm5lORqIA/HItlvPr+Z3fgro+sAwC8evwi7pt1swKAJ3cpcePxJ6CITktlCiGCNRePxA5u+SrwEcviFWhYl5BijbgTFYiKEFOB+G/PB/jeoSQ9Xho8qPwF7NU/T+/i3f8ZE90sJWuAIBBS8U+3tsPtG8aHfGKv11hbDfsnrgXz7kgsT8Y8pmYi+AlrwkcVb8FYP4qaJOtvUUFAVACqFLMbY8sE6ZmygPRMuWUmEsX45ExCWZ1KueT61DeeeQs/98YMFI9btsFsbIHjV+/hF38+CwC4eGkKUzNR/Oj+Ttx+bWIW3YLkGicQck2VUoGmlTUJr3QWej8bl6hmv/sv+NPpIJ78XQAf8hP4kJ/A1ExsUe4nr5+Utb5ETISyY1urGj+6vxMAEBWAzz11BIIANDeo8PzeW/HUfR0AgN8fu4jz45OyvS8RE6Esuf3aNfjeZ7YDgOR58XXTZmxrVYPeuhY72tSIRAVp6CcHxABBKFvuvP4qtGnqcJq7guYGFT514zrp3L2GFigpCi1NaUYupwExQGQBMUAUP1MzEZwPTaG5QYUV1XPzrfnZtUSIaZxASIGqSonWJHkzcrEhNhFTFoideigUKnBNCHIhfpaZDNiImLJgfDyWdbW1tbXANSHIzfj4+LK3CyJzpiyIRqM4e/YsGhoaEoYNoVAIra2tOHPmTNnMpSqlTYIgYHx8HOvXr4dCsTxjN+mZskChUKClpSXl+VWrVpXNF0+kEtqUyQZ2AFlnIhBkg4iJQJAJIqYcoFKp8Mgjj0ClKp/UxKRNS0MMEASCTJCeiUCQCSImAkEmiJgIBJkg60wyk0l+82LEYrFI+wD39/dLyW1KrX3LzUOfVfsEgmwEAgHBarVKxzRNF7A22WEwGAS1Wi3QNC0Eg0FBEEqzfYODg8L+/fsFh8MhlaVqR7btI8M8GUmV37wU6enpQTAYhMfjkdpUiu0zm83Q6xNTKKdqR7btI8M8GQkEAgk7cWg0moSssqWE1+sFMJdS2mq1lk37UrUj2/YRMeWYdPKbFyPxCUD1ej26urqSXleq7ZtPqnYsp31kmCcj84cT
6eY3LzbcbndCamm1Wg2WZcumfanakW37iJhkhKZpaXgElG5+c51OB5PJJB3zPA+DwVA27UvVjmzbR9yJZCbX+c3zRfy+vjabTfqFLrX2LTcPfTbtI2IiEGSCDPMIBJkgYiIQZIKIiUCQCSImAkEmiJgIBJkgYiIQZIKIqQzx+/2w2WygKAp2ux0ulwtOpxN2ux1NTU05c05lGAZGo1EKeag0yDpTmSK6/wSDwQRPaL/fj6GhIVit1py8r91uh16vz9nzixnSM5UpGo0mabnBYMjp+8Z7XVcaREwVgt/vB8uyAJDSA5yQHSQEo8wR5y/9/f0YHBwEMBf0ZrPZQNM0TCYTOI6Dz+eDw+GQhoV+vx8Mw0Cn04FlWZjN5gQvapZl0dfXh87OTnAcJ4mU53kwDAOWZeHxeKT35XkeAwMD0Ol04HkeXq83IdSj5Mk4HphQ1ASDQQGAFHLucDiEQCCQcM38cO7BwcGEEO75YdsGg0F6XjAYFHQ6nXQsPsvhcAhms1m6h6ZpwefzSXXweDzSub6+PlnaWiyQYV6FEO/9HG/NizdOmM1mMAwDnufR19e3YH6l0+kwMDAAAFIPI97f09MjGR06OzsTni8G2JnNZlgsFhiNRjidzrIbbhIxVQg6nU4aoolzp2zgeX5BvoT442RoNBoEg0EcPHgQo6OjsFgsWdejmCBiKlNShVvzPA+fz5dwLOJ2u0HTNNRqNbq7uxesR/n9fqk3MZvN8Pv9CeeXWr/q7e0Fy7IwGAwJc7NygRggyhC/34++vj4AsS+wGI4dCATgcrmkfHhimTi083q9krFA/MI7nU7odDrpnCgAnU6Hvr4+2O12aVi3evVq9Pf3A4hFs7IsK9VFp9NBq9WCYRhoNBpwHIfu7u58/UvyAlm0rWAqeYE1F5BhHoEgE0RMFYqYdHFwcHDB3IeQGWSYRyDIBOmZCASZIGIiEGSCiIlAkAkiJgJBJoiYCASZIGIiEGSCiIlAkAkiJgJBJv4fSRhbqCC1fYMAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 150x100 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(1.5, 1))\n",
    "# Error curves of the first trial only (trial index 0), one line per model\n",
    "leg = [ax.plot(run)[0] for run in errors[:, 0]]\n",
    "\n",
    "ax.set_yscale('log')\n",
    "ax.set_xlabel('Epochs', fontsize=10)\n",
    "ax.set_ylabel('Error', fontsize=10)\n",
    "ax.legend(leg, ['real', 'approx'], loc=\"best\", fontsize=8, handlelength=1)\n",
    "fig.savefig(f'approx_sgd_{d}_{batch_size}_one_run.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dev",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
