{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "from random import seed\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "th.manual_seed(0)\n",
    "seed(0)\n",
    "plt.rc('font', family='serif')\n",
    "plt.rc('text', usetex=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings(n, d, norm=True):\n",
    "    emb = th.randn(n, d)\n",
    "    if norm:\n",
    "        emb /= emb.norm(dim=1, keepdim=True)\n",
    "    else:\n",
    "        emb /= sqrt(d)\n",
    "    return emb\n",
    "\n",
    "\n",
    "class AssMem(nn.Module):\n",
    "    def __init__(self, E, U):\n",
    "        \"\"\"\n",
    "        E: torch.Tensor\n",
    "            Input embedding matrix of size $n \\times d$,\n",
    "            where $n$ is the number of tokens and $d$ is the embedding dimension.\n",
    "        U: torch.Tensor\n",
    "            Output unembedding matrix of size $d \\times m$,\n",
    "            where $m$ is the number of classes and $d$ is the embedding dimension.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        # self.W = nn.Parameter(th.randn(d, d))\n",
    "        d = E.shape[1]\n",
    "        self.W = nn.Parameter(th.zeros(d, d))\n",
    "        self.E = E\n",
    "        self.U = U\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.E[x] @ self.W\n",
    "        out = out @ self.U\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of input tokens\n",
    "n = 30\n",
    "# number of output classes\n",
    "m = 5\n",
    "# memory dimension\n",
    "d = 30\n",
    "\n",
    "alpha = 1.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_x = th.arange(n)\n",
    "proba = (all_x + 1.) ** (-alpha)\n",
    "proba /= proba.sum()\n",
    "all_y = all_x % m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of data\n",
    "batch_size = 10000\n",
    "nb_epoch = 1\n",
    "T = nb_epoch * batch_size\n",
    "lr = 1e-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,"
     ]
    }
   ],
   "source": [
    "lr = 1e0\n",
    "num_models = 2\n",
    "\n",
    "nb_trials = 100\n",
    "\n",
    "train_loss = th.zeros(num_models, nb_trials, nb_epoch)\n",
    "test_loss = th.zeros(num_models, nb_trials, nb_epoch)\n",
    "train_loss[:] = -1\n",
    "test_loss[:] = -1\n",
    "\n",
    "for k in range(nb_trials):\n",
    "    # Embeddings\n",
    "    E = get_embeddings(n, d, norm=False)\n",
    "    U = get_embeddings(m, d, norm=False).T \n",
    "\n",
    "    # models\n",
    "    assoc = []\n",
    "    opti = []\n",
    "    for i in range(num_models):\n",
    "        assoc.append(AssMem(E, U))\n",
    "        opti.append(th.optim.Adam(assoc[-1].parameters(), lr=lr))\n",
    "\n",
    "    for i in range(nb_epoch):\n",
    "        x = th.multinomial(proba, batch_size, replacement=True)\n",
    "        y = x % m\n",
    "\n",
    "        for j in range(num_models):\n",
    "            if j == 0:\n",
    "                for t in range(batch_size):\n",
    "                    out = assoc[j](x[t])\n",
    "                    loss = F.cross_entropy(out, y[t])\n",
    "\n",
    "                    opti[j].zero_grad()\n",
    "                    loss.backward()\n",
    "                    opti[j].step()\n",
    "\n",
    "                train_loss[j, k, i] = loss.item()\n",
    "            else:\n",
    "                out = assoc[j](x)\n",
    "                loss = F.cross_entropy(out, y)\n",
    "                train_loss[j, k, i] = loss.item()\n",
    "\n",
    "                with th.no_grad():\n",
    "                    pred = assoc[j](all_x).argmax(dim=-1)\n",
    "                    test_loss[j, k, i] = proba[pred != all_y].sum().item()\n",
    "\n",
    "                opti[j].zero_grad()\n",
    "                loss.backward()\n",
    "                opti[j].step()\n",
    "\n",
    "            with th.no_grad():\n",
    "                pred = assoc[j](all_x).argmax(dim=-1)\n",
    "                test_loss[j, k, i] = proba[pred != all_y].sum().item()\n",
    "\n",
    "    print(k, end=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_mean = test_loss.mean(dim=1)\n",
    "error_std = test_loss.std(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_229100/3851978499.py:8: UserWarning: Data has no positive values, and therefore cannot be log-scaled.\n",
      "  ax.set_xscale('log')\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAACsCAYAAACHDDKzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAATMElEQVR4nO3dz27bVr4H8K/ietrk5g/jIJvAARKq2XTVUNYLJHRnX0j2E1hC4W0hwndxF7O4rrT3gvITOGK6n4jOA9QW0VVW1XEAG9mktminmLQwUt6FLxnLFiWSJinK/n6AYCqJIk8y5E+Hhzxf5hzHcUBENMC1cTeAiLKLBYKIfLFAEJEvFggi8sUCQUS+WCCIyBcLBBH5+mLcDSDg77//xrt373Dr1i3kcrlxN4cuMcdx8OHDBzx48ADXro3uH7BAZMC7d+/w8OHDcTeDrpDd3V3Mzs6OXI4FIgNu3boF4OT/tNu3b4+5NXSZHR0d4eHDh94+NwoLRAa4pxW3b99mgaBUBD2V5SAlEfliDyIhtm1jdXUVi4uLUBRl3M0hAJ8+fcLx8fG4m5Ga6elpTE1NXWgdLBAJ2d7ehm3b424G/b8//vgDe3t7uEqTl3O5HGZnZ3Hz5s3I6xhLgbAsC6ZpAgC2trawvr4OSZIuvGzUtiwtLaHT6fS9L4SAYRiQZRlCCFQqlVDbVVUV7XY7tnZSdJ8+fcLe3h5u3LiB+/fvX4lLyY7j4P3799jb28OTJ08i9yTGUiBM00StVgMANBoNPH/+/NwBGmXZsNwCYFnWuc/K5bK3HSEElpaW0Gq1Ytkupev4+BiO4+D+/fu4fv36uJuTmvv37+Pt27c4Pj6OXCByaQfGWJaF58+fo9frATg5+PL5PLrdLmRZjrzsReRyub6upxCir0AAwN27d712GIYBIcS59ZRKpb52aZoWaAzi6OgId+7cweHhIa9iJODPP//Ezs4OHj9+jK+++mrczUnNoL932H0t9R6EoihYX1/3Xrvn6TMzMxdaFjg5IFdWVvpOBSzLwvb2NiqVSuA2mqZ5bhszMzOwLAuKoqBUKgVeF9EkG8spxukDbGNjA6qq+p7fh1l2ZWWl71RACAFd16Hreqj2+Q0uHhwcBF6HaZp9py6DehFra2tYW1vDp0+fQrWPonMcBx+P4//3vj495Tu2IYRAtVqFJElYXFzEwcEBut0uqtVq4J7wuK6KjfUqhm3bMAwj0JhCkGUlScL6+jrK5TJWVlYiFYdRbQhKVVWoqjp0meXlZSwvL3vdPkrex+NP+OZ//h37et/865+48Y/Bh5MsyyiXywA+/+AZhgFN0wKPa43rqthYC4SmaWi324GuDgRdVpIkVKvVvrGLsCRJOtdbODg4iPXqCV1dQgi0221Uq9XA3xnXVbGxFYhGowFN0yDLslcZ/Q7AMMvatg1d17G5uYlqtRqpB6Gq6sDvzc3NhV4XZcv16Sm8+dc/E1nvKJ1OB6ZpwrZtFAoF7/TCtm00m82B31EUZWRPNEljKRCGYUBRFO+Af/HihTeIaFkWJEny/vGGLXuWbdt9YxDVajVwkbBt2ys6Z88LhRCYm5tjD+ISyOVyvqcCSSsUCt7BLoRAoVBAr9eDJEnepfysSf1fyr2EeJokSd5Bv7q6imKxiFqtNnLZs1ZXV/uueiiKgmq1imazOfA7pml63TZ3u+45YqvVgqZpKBaL2Nra4j0QlJgs9yBSvw+CzuN9EMka930Qg65iuGMQQQ9+0zRRr9ehKErgKxlx3AfBApEBLBDJGneBGJc4CgSnexORLxYIIvLFApEQ27ahadrAiWBEk4IFIiHMg6DLYCwFwrIsNBoNNBoNlMvlkQeSZVkoFAqJtWXQuoUQaDQaMAwDjUYj9ME+bM4I0aTIfB7EsMyGi2IeBKUhyclazWbT24fPxg3EIfUCYVkWVldXvQJRKpWgaRqEEAP/cklOrfZb99msB1mWvVQrIHgeBGWQ4wDH/4l/vdM3AJ/ZnElN1hJCoNvtolKpQFVVlMvl2H/EMp0HEdak5UFwuvcYHP8H+N8H8a/3v98B//ivQIvGNVnLNE3k8/m+9cYt83kQYUxaHgSne18tcU/WOj1/yH0dt4nJgwhi0vIgaAymb5z82iex3hHinqwlSVLiV8omJg8iKOZB0FC5XOBTgbRE7UHMzc1hY2Ojb9m4TUQeRBjMg6CsEUKg1WpBkiTMzMx4k7XcU+EgPYhBp6yKomB7exumaUII0Te2FxtnDFqtltNutx3HcZxer+fouu591ul0nG63e+47AJxerzd0vb1ezymVSn3rqlQqI9szaN2Konj/3e12HVVVR64nqsPDQweAc3h4mNg2rrKPHz86b968cT5+/DjupqRq0N877L6WeoHodrsOgL4/kiR5n5dKJader3uv2+22U6vVHABOrVZzWq2W77prtdq5A73T6fQVoNOGrbvb7XrvDVpvnFggksUCEb1AcLp3BnC6d7I43ZvTvYlGumq/hXH8ffnwXrr0pqenkcvl8P79+yv3bM5cLofp6enI62GBoEtvamoKs7Oz2Nvbw9u3b8fdnNS4T/eO+lxOgAUiUeN6GhKdd/PmTTx58gTHx8fjbkpqpqenL1QcABaIRDETIlumpqYufMBcNZkvEJZleTMpt7a2sL6+HvsdjZZlYWlp6dwt30IIb0q4EAKVSiXUtsf1NCSiuGS+QITJjoiCmRBE/jJdIMJmR0TBTAgif5kuEGGzI+LKgwDSyYRgHgRlXeZvlAqTHeHmQbjcPIiwxQGINxNiY2Nj4CnM8vIy3rx5g62trdDtI0pDpnsQpwXJjkg6D8JtR1DMhKBJF7kHUSwW8fPPP8fZlqGCZkeczoO4SHFgJgTRBQpEpVLB999/3/fe69evL9ygQc5mRwz7FT+bBxHVsJAOoqsi8ilGLpfDDz/8gHw+D1mWcXBwgFarhWfPnsXZPhiGAUVRvOLw4sUL3zEF27b7LkVWq9VQoTGnM/7OXokQQmBubo49CLpSIk/3/vrrr88NGG5ubsY64CaE6EvtBU66/n5RclGuYpimiXa7jUajgVqthmKx6A2MuoOcxWIRW1tb59YdF073prSE3dciF4jNzU08f/585Hs0GgsEpSW1AuFu7MWLFwCAhYUF7twRsUBQWlILjNnZ2cGzZ8/w6tUrvHr1CoVCAb/++mvU1RFRBkUepHz58iW2t7f73ltZWcG333570TYRUUZE7kE8fvz43Hu8BEh0uUQuEIMmKu3s7FyoMZeFbdvQNC2RJ5ITpSnyKYaqqvjuu+9QKBQAnFwurNfrsTVskjEohi6LyD2Ip0+fQtd1OCfP1kCz2Qx8k5RlWV5hGUYIAU3T0Gw2oWla7AedXzuEEGg0GjAMA41GI/R243oYMdHYRX0ox9zcnPPy5cvQ32u1Wk6n03GCbFqWZe+BNUGfkhVHO84+Vev007qCqtVqTqfTCbQsH5xDaQm7r0U+xfCbizGqFxE0R8ENZ3F/iRVFQbPZjG12JoNiiEbL7FwMv269G9hyVlxhMWkExbgYGENZF3kM4qeffoLjOPj999/xyy+/4LfffgsVpjKKoih9v9TuFQG/bcQVFpNGUIyLgTGUdZF7ELquD5yLERdZllGv19FsNrGwsOAVC7+4uaTDYhgUQ1dR5B6EpmnnAmPinqhVq9WgqiqEEN4BN+w8P46wGAbFEH0WuUBUq9XEA2Pc9Gr3dENRlKEHahxhMQyKIfpsrIOUpwNagJNxBkmSvF5CoVDAzs4OJEmCrutDb8S6SFgMg2KIBot1kHJ/f3/k90zThKZpAIDV1VUYhuF9dvZ1vV6HaZpoNpsol8tDz+tXV1f7IvIVRUG1WkWz2QzdjlarBU3TYBgGdF3nw3LoymJgTAYwD4LSEnZfC3yK8fr1a+9KwsLCglcINjc3IYSAZVnI5/MsEESXSOAexL1797C5uemb92DbNvL5fKDTDOrHHgSlJbEexNLSklcc3r592/fZo0ePIElS341KRDT5Ag9S3rt3z/vvXq+HcrncN7AH4FwCNRFNtsAF4vRlvqdPn2JhYQE//vgjHj165L2fy+XibNvEYmAMXRaBTzGEEPjw4QPcIYtcLtf3GgC63W78LZxADIyhyyLwIOW1a9f6egiO4wx8HWRmomVZWFpaGvogXuCkKLmzK4UQsU+p9muHEAKGYUCWZQghUKlUQt8opWkaFhcXB848PYuDlJSWxAYpK5UKNE3znSy1v7+PRqMxcj3ugRek+20YBmq1mvc6zGP0LtKOcrnsFQ0hRN8dmkRXSeACUa1WByZZu+7cuRNo/kOYPIWNjY2+AhEnBsYQjRa4QDx9+jSWZcKYmZlBoVBAq9WCEALz8/O+y05iYAxR1kWei5EGt1ufz+fRarWGHpyTGBiztraGb775BsViMVQbiVKTQC5mIEE23Wq1nHa77ei67gAYGVrb6/WcUqkUKuD2bDvq9bqjqmrfe7IsO61WK9D6omBoLaUltdDapAkhsLW15U3xVlUVhUIBmqb5nuufDozp9XqRtsvAGKLPMnuKYVlWX9dblmWsrKwMvb+AgTFE8RprgTh7sFuW5V0hUBTlXJjr/v6+730FpwNj3CyIoEXidDsYGEP0WeqnGKZpot1uAzgJaikWi97go/u6VqtBlmXMz8+j0Wh4B+ewA35YYMyggcph7XADY4rFIra2tngPBF1ZkQNjKD68k5LSEnZfy+wYBBGNHwsEEfligSAiXywQROSLBYKIfLFAEJEvFogEMHKOLgsWiAQwco4ui7EUCMuyUCgURi5nGAZs207sYPNrhxACjUYDhmGg0WiE3r6qqrw1my6F1AuEG5UfpPtdLpdx9+5d3L17F7lcDrlcLlCs3UXbUS6XUavVUCqVUCqV+LwPurJSn4sRNJHJtu1zITGNRiO2CDpGzhGNltk8CKD/IDYMY2hxmcTIubW1NaytrQVKAicah8wOUp4+0G3bxsHBwdBf6EmMnFteXsabN2/OTWsnyopM9yBcmqZ5yVJ+JEnC+vo6yuUyVlZWoOt6bBH5gH/hGERVVd/gGaJJkvkCYds2TNMMdFWAkXNE8crsKYZre3s78MHJyDmieGU2cu70e35P8zq7LkbOEcUr9QJhmiY0TQNwEvXm3o8w6LUryOXDYZFzYdvhRs4ZhgFd1xk5R1cWI+cygJFzlBZGzhFRbFggiMgXCwQR+WKBICJfLBBE5IsFgoh8sUAQkS8WiAQwk5IuCxaIBDCTki6LTGdSAie3RDebTZim2ZfslGQ7mElJdCL16d6GYUCW5UDdb9M00Wq1oOs6hBCYn59Ht9tNvB3lchmdTgfASbFwJ4ERXTWZzaQEgGq16h2osiyj3W4n3g5mUhJ9ltnAGCGEF9RiWRZkWR56ADKTkih+mS0Qbg6EYRhQVRXNZhOyLPseoG4mpXsq4GZSho2dizOT0qUoysDllpeXsby8jMPDQ0iShKOjo1BtJQrL3ccCT+J2xmTUpnVddwA4vV7PcRzH6fV6I7/T6/WcUqnkdDodp1KpRGpHvV53VFXte0+WZafVagVaXxS7u7sOAP7hn9T+7O7uBto3M9uDkGUZkiR5pwzu/7pd/UEmNZPywYMH2N3dxbNnz7C9vd332dHRER4+fIjd3d3MZ0UUi8XEE7rj2kaU9YT9TpDlRy0z7HO/z4btM47j4MOHD3jw4EGAv0GGTzGiDPidzaSMkmqtqurA7yWZSXnt2jXMzs7iiy++8C0Ct2/fznyBmJqaSryNcW0jynrCfifI8qOWGfb5qO/67TN37twZ2qbTMptJKcsy5ubmvGWEEJBl2bf3cBkyKZeXlxPfRpLSaH9c24iynrDfCbL8qGWGfZ7K/pLYibWPdrvt1Go1B4BTq9X6zu1LpZJTr9e9171ez6lUKo6u606lUnG63a7vemu1mjde4ep0Oo6u66Hb0e12vfcGrTdNh4eHDgDn8PBwbG2gyRLnPsNMyoz766+/sLq6ipWVFXz55Zfjbg5NgDj3GRYIIvLFyVpE5IsFgoh8sUBMIHd2q99DgYhOMwwj8OzpszgGkQG2bXsHe61W8953n/Z1cHAAWZahqipM04QQApVKxfs8zvkhlH1h9hfX/Px8pMmO7EFkgGma2N/f73tPCIF2u41SqYRKpYJ6vQ4A3sQ14OSuz6TvXKTsCbO/XBQLRAaUSiXk8/m+90zT7Ls5S5Ikb9o506qutrD7y0WwQGRUt9vFvXv3vNczMzOwbRuKonhzRWzbRrFYHFcTKUP89peLYoGYIAcHB1BVFbZte2MRHH8gP+4PibuvnH6CfVCZnax11eXz+b5fAHfgCfg8MHV6EIqutmH7i6qqkaMa2YPIKFVV+wYghRAsCOQrqf2FlzkzwDRN6LoO27ZRrVa904bTl61mZmZ4OkEA0t1fWCCIyBdPMYjIFwsEEfligSAiXywQROSLBYKIfLFAEJEvFgiaCKZpIp/Po9FooNlsolAooFAooNlsQtM05PP5QA+EpnB4qzVNBNu20W63vduH2+02ZmZmvOeuLi4uQgjh+1gEioY9CJoIp+cWDHJ6livFhwWCJsLCwkIsy1A4LBA0EYI82SyNp59dNSwQROSLBYKIfLFAEJEvFggi8sX7IGiiuPmK7k1RzWYTc3NzvP8hIQyMISJfPMUgIl8sEETkiwWCiHyxQBCRLxYIIvLFAkFEvlggiMgXCwQR+WKBICJfLBBE5IsFgoh8/R+8CVJ+NATsegAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 200x150 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(2, 1.5))\n",
    "# fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "leg = []\n",
    "for i in range(num_models):\n",
    "    a, = ax.plot(batch_size * np.arange(nb_epoch), error_mean[i], color= 'C' + str(i))\n",
    "    ax.fill_between(batch_size * np.arange(nb_epoch), error_mean[i] - .5 * error_std[i], error_mean[i] + .5 * error_std[i], color= 'C' + str(i), alpha=.3)\n",
    "    leg.append(a)\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "# ax.set_ylim(1e-2, 1e0)\n",
    "# ax.set_yticks([.2, .3, .4, .5, .6, .7, .8])\n",
    "# ax.set_yticklabels(['.2', '', '.4', '', '.6', '', ''])\n",
    "ax.set_xlabel('T', fontsize=10)\n",
    "ax.set_ylabel('Error', fontsize=10)\n",
    "ax.tick_params(axis='both', which='major', labelsize=8)\n",
    "ax.legend(leg, [r\"B=1\", r\"B=10\", r\"StepLR\"], fontsize=7)\n",
    "fig.savefig('sgd_step_lr_mean.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
