{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "\n",
    "plt.rc('font', family='serif')\n",
    "plt.rc('text', usetex=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modular_class(x, m):\n",
    "    return x % m\n",
    "\n",
    "\n",
    "def squared_modular_class(x, m):\n",
    "    return (x**2) % m\n",
    "\n",
    "\n",
    "def get_embeddings(n, d, norm=True):\n",
    "    emb = th.randn(n, d)\n",
    "    if norm:\n",
    "        emb /= emb.norm(dim=1, keepdim=True)\n",
    "    else:\n",
    "        emb /= sqrt(d)\n",
    "    return emb\n",
    "\n",
    "\n",
    "def get_q(x, P, rho):\n",
    "    counts = th.bincount(x)\n",
    "    _, order = th.sort(counts, descending=True)\n",
    "    if P is None:\n",
    "        P = len(counts)\n",
    "    P = min(P, th.sum(counts!=0).item())\n",
    "    idx = order[:P]\n",
    "    q = (counts[idx] / counts[idx].sum())**rho\n",
    "    return q, idx"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scaling with respect to d and T jointly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of input tokens\n",
    "n = 100\n",
    "# number of output classes\n",
    "m = 5\n",
    "# memory dimension\n",
    "ds = np.logspace(1, 3, num=30, dtype=int)\n",
    "dmax = np.max(ds)\n",
    "\n",
    "# number of data\n",
    "ts = np.logspace(1, 4, num=30, dtype=int)\n",
    "tmax = np.max(ts)\n",
    "\n",
    "# Zipf parameter\n",
    "alpha = .5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population data\n",
    "all_x = th.arange(n)\n",
    "proba = (all_x + 1.) ** (-alpha)\n",
    "proba /= proba.sum()\n",
    "all_y = modular_class(all_x, m)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,"
     ]
    }
   ],
   "source": [
    "P = None\n",
    "rho = 0\n",
    "\n",
    "nb_trials = 100\n",
    "\n",
    "errors = th.zeros(nb_trials, len(ts), len(ds))\n",
    "errors[:] = -1\n",
    "\n",
    "for i_n in range(nb_trials):\n",
    "    # Embeddings\n",
    "    E = get_embeddings(n, dmax, norm=False)\n",
    "    U = get_embeddings(m, dmax, norm=True)\n",
    "\n",
    "    # Empirical data\n",
    "    x = th.multinomial(proba, tmax, replacement=True)\n",
    "    y = all_y[x]\n",
    "\n",
    "    for i_t, t in enumerate(ts):\n",
    "        q, idx = get_q(x[:t], P, rho)\n",
    "        W = (E[idx].T * q) @ U[all_y[idx]]\n",
    "        for i_d, d in enumerate(ds):\n",
    "            scores = E[:, :d] @ W[:d,:d] @ U[:,:d].T\n",
    "            preds = scores.argmax(dim=1)\n",
    "            errors[i_n, i_t, i_d] = proba[preds != all_y].sum()\n",
    "\n",
    "        # print(len(idx), end=',')\n",
    "    print(i_n, end=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors_mean = errors.mean(dim=0)\n",
    "errors_std = errors.std(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJgAAACFCAYAAACnijFkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAANbUlEQVR4nO2dP2zbVh7Hv5YsS3ZsSZGDxk6iHCIZbntdDnK63HRoaXS4raUz3qgOtxVFg0xpp9RGxnaw0elWc7rtIKJjhzMs3HKXNoClxkocp6hl/fE/yZZ1g0KapEiJlPhEivp9AELSI6X3JH/9ez/+3u+9N9ZsNpsgCEb4nG4A4W1IYARTSGAEU0hgBFNIYARTSGAEU0hgBFNcJzBRFCEIgtPNIGxi3OkGKCmVSshms4hGo043hbCJgVgwQRCwtLSkKtvY2IAoitjY2JDLRFEEx3GDaBIxIAYiMJ7nEYvF5NeiKAIAOI5DLBaDIAhytyiKIra3t1EqlQbRNIIxjnSR2WwWqVQKABCNRpHJZLC6ugoAePjwYdf312o11Go1+fXl5SWKxSJmZ2cxNjbGptEupNlsolqt4tatW/D5Wrbi7OwM9XpdvmZiYgKhUMipJjrngxlZKElonXjy5Am++eYbm1s0vBQKBdy5cwdnZ2e4MTmNYzTkc3Nzc8jn846JzBGBpVIp5HI5AC2hffjhh5be/+jRI3zxxRfy63K5jLt37+Jff/0LrgXUXyl67x35eWThjvw8ePsPAIBAfEEua4TnWo8zNwEAB6etP9TB2YV8Tf7wFADw6+EJAODZXhUA8OJN67G4f6yqv/zmN9XrystfzH1JEzQv6jjZ2sDMzAwAoF6v4xgN/D1wF0H4UMMlvt/fRb1e97bARFFELpeDIAjgeR4cx2FtbU0u/+qrryx9XjAYRDAYbCu/FhjHtEZgM8EJ+Xl48uo9wWuTAIDA9DW5rPH2D9UIhwEA9UBLYLXAlcCmzlufH6r7AQATU61sp/HQJQDAH1RnP/kCk6rXY+Pt7e4XrVsQhA/BMR/ggkSsgQiM4zjs7OyoyiRR0V2jt3FVHIywh3enA5j0+XF62QAOnW2L6yL5hLcggRFMIYERTCEfzIPMz8/gmt+P4wb5YITHIYERTCGBEUwhH8yDxBZjmA6MI3h+AfzP2baQwAgArfy8RCKBXC6HdDrddl4URZRKJfm82aRQ6iIJ3fw8JblcDplMBjzPWxIXQAIbCSqViupQ5tIBrfy8RCIBoJWft7W1pTovWS9BEPDkyRNLdZPAPEj03ju4npyTU5Xi8TgikYh86ImkWwZxMpkEz/NIJpOqNPdukA82AhQKBYTfpiABaEt16pafd//+fbkbjcViKBaLpusmCzYChMNh1aEVGMdxKJVKcn4ez/MArtLXpfR2QRCwtbWlexNgBFkwAoB+fp4yfV06L4nPLCQwDxJZuIPwZBC+01r3ixnjqi5Smiu5trbmdFMIm3CVwDiOQyKRaEuvJoYXV83sjsViyGazSCaTg2gWMQBcM7NbFEUUi0XwPI9MJjOIZnmW4O0/IBi/J0/NcxLXzOz+/PPPZeGtrKx0fL92ZnelUmHXWKIvXDOzO5FImI6v0Mzu4cERJz+VSsnR4F5ndpfLZfkoFAosmknYwEAEppzZDRhHjs0SDAbbotPEFYH4AgJ3F1XLIjgFzey2mdn5aRy8PnK6Ga7BVXEwwnuQwAZMJP6+000YKDQW6UEa4Tk0ZmbQ8FWdbgpZMCcYJStGArOJe/P6d7LR+Zu65aMiMtMC++STT/D06VP8+OOPLNvTF8Xn5jMtrXBj0tiT+OB2u7Bm56dNfe4oiMy0wDiOw5dffomPPvqo7dwoDtUkYlOmrzWyYgAbkTVmbr71w4zrHRSmnfzr168bnhNFEZ9++qktDXITNyb9+P200f3Ct9ybDyP/Wv+fLTp/E6XXb3TP9Sqyy/NTHHe/zFFMC+zbb7/F5uZmW3mz2UQ+n/ekwPRIxqawUzzpep1ewFWyZEZC8yKmBcZxnGGWg57wRokPbofx31ftlssoqj9KQjMtsPv37+Pjjz9m2ZahIxGbQk5jzbTdZKeho06+mRkatRO81ik/OG2gHmigaqF7Z4VpJ79TEuAwC+9893lbmb+y39dnakMWZu8qvYhpgT148AA//PCDq8MULFGGKpI6d5DacIWeyJTHqGC6i/zss89YtsO1dLuTVHaTWl+s012lHSK7OHP/vkwUybdAp4CrhJ4lkw43ozcJR++8mc3KlIzUYHfpeQHRxbjheX95D43Irdbzyr68d5EeynCF1tk3uqu0W2T1Ex9+0ik/OLtALXCBI8UeS51QTsKRtlZUJoFK6e0cx2F9fR25XE5ejacbZMEsYuSLaSP7H9wOy4fT9Lt8UzQaRTqdlhepMysuwGUCEwQBGxsbA5nZXSvkTV97Y9Jv6rpEbEp3CEkptn6F18v77Vi+CQDS6bS8yqFZXNNFSv9FPM9jeXnZ8g5sEgc/H2D2vVlL7znffY7A3UUAxt2k0tm/MTmO309b3Y9kxZTRfaXItHEyJf1at/dvzZi6rt/lmwRBQDQaBcdxSCaTEATB9N9nIAKTVsbb3t6Wy7RrgkrzJK30773QzQ8zi1JkgL7QAGuD4lY5mdC/u80fnmLqfBwnR629LbtNjDHaXvHhw4dYXV0Fx3EQRRGCIGBnZwfr6+um2zjWbDYHsqvg8vKyHKyVvkg6nZZnGvE8j1wuB1EU8eDBg47rgOpNvI3H4/jnH1O45ve3WbDrSbWzrhRYMH4PAGQLJiFZMQAqZ18bslCKTIuZMct+ODmq4m9/fg/lchnhcBiVSgWRSAT/+OlnTE3PtJ13AtfM7E6lUlhZWUEikcDm5mbHkQMWE2+V3WQnJH9M2V1KaMWmF5C1k6OAubtEJ3HMydeb2b29vd1VXID1ibeHO8ZDP0bOvr+8d/VcZ+hIz/G/MTkuH0SLodyz22hLZYlujr6RH6a1YlqHH1B3l0qRabvOQYgseK5fx6+HJwjV/Tg7ZttFm2EoZ3b3QicrpkQ7+K20ZEBLaEYWTXsQHp7ZbSVcUSvkZWdfD6Ulk8t0LJoW1iKbOHe/iF0VaLWbg58PVK+VVqz0XO23KX2x893nupZMa82AK4umdxAuCrS6Aa0l07uzlESmtWh6sBaZv6o/sfbZXhUTU03UT5xfI8PzAuvUVZoJukqWzEhoRpgRoNHntXXHOl30sOB5gWk53NlXBV61IjPyx4yEZkQ3AVp9r27ZkdvnFI2IwKyGLSR/rJPQjDArwFHBUwJ7/bqKhTtR3XNKkWmtGKDfXXYSmhHdBGgn58enuuUv3lQxHrrExZnzFs5TArOCWZEB3VN7rAjQ6LP7+Qw34zmBvdqr4rbJNBYrIuuEldwyOz+j5oKtYrrhOYF1Qs8XMxKZFjtSfEYRTwqskxUzKzIteqJzmmqtrlte3D+GP9hEozYiY5FO8Gqvild75lf4O9zZNz1eSZjHkxZMiZ416xS2MCOybtaOuMLzAgOMRQbAcv4+YD4zgzVH5+5POBwJgQHGfpl2QBzoTXRuovzmN/gCk7g814+TDZKRERhgPoShJzolwy7AQeIpgf1ydI4/hTvnSFmJkxnRTYCD4rjh/PJM3fDcXeSzqv6tuxIrd5dEf3jKgllBT2T9Wja3UHn5C8bGg2heOB/pd0xgZibj9sqzah3vz0xYfp9Vy9aLIO3oollg12+vxTGB8TyvmiHcbYUXqyi7yl7EZoZeu1q7uujTS3t8MLt/eyWu6SL1JuMafUntzO5yudwqb17qXv+fyhnenQ7Y3GLnOW22BKadnN9s1FWP2n0MtNP+rPz2VnGNwABzK7wAxjO7v7/YNX7TYY+NGgIODg4QiUQwMTGBubk57P/7ahG56elpxOPqgfrHjx/j66+/VpWZ/e2t4pq7SCvbLGtndr948QIAsLu7qyrXHouLi6bKla+l50aPCwutXWULhcLA604mkwCAWCwGAAiFQsjn86rPe/nyZVudjx496vm3t4pjFkw5GZfnecMVXvQwmtkdiUQ6LvLh9/t1z2vLla+l550ege4r2LCoe3y89efz+a7sRCgUQigUMmyHHlZ+e8s0PUC5XG4CaJbL5Y7Xfffdd6bKla+l50aPT58+dX3dTjKw5ZtYIi1b5MQyRaNat1lc44P1QzAYxOPHjzsuiEJ1O4MnLBjhXjxhwQj34qo4mBcQBAHFYhGlUsneuzETSMM92Wx24HUbQRbMRqSVstPpdNdVGlnAcRwSiUTbUllOMpQCEwQBS0tLqrJuW6EMos5UKiWv3mjnStlmv28sFkM2m5UDsG5gKAXG87wcvQbUg7WxWExeSdGJOqWVsldXVwdatyiKKBaL4HneEetpxFAKTEu3rVAGVWcul8PKygoymYzh7sCs6k4kErJFY1m3VTzj5LMarLVSp7RStlN125nHZReesGAsB2vdVKcb6rbKUArMiVWrnVwp2w2rdPcKRfIJpgylBSOGBxIYwRQSGMEUEhjBFBIYwRQSGMEUEpgN5HI5LC0tyWOExBUkMBtIJBLMdo0bdkhgBFM8M9jtBGtra4hGo3Ie1vLystNNch0ksB4RBAE7OzvyAi5uysFyE9RF9kgmk1FlmUajUeca42JIYD2yvLysyn13Ih9tGKBsij5YW1uTM0vX19cRjUaxubnpcKvcBQmMYAp1kQRTSGAEU0hgBFNIYARTSGAEU0hgBFNIYARTSGAEU0hgBFNIYARTSGAEU/4PSGBidx0S1lEAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 100x100 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "plt.rc('font', family='serif')\n",
    "plt.rc('text', usetex=True)\n",
    "\n",
    "X, Y = np.meshgrid(ds, ts)\n",
    "fig, ax = plt.subplots(figsize=(1, 1))\n",
    "c = ax.contourf(X, Y, errors_mean.numpy(), levels=20, cmap='RdBu_r')\n",
    "ax.set_xscale('log')\n",
    "ax.set_yscale('log')\n",
    "ax.set_xticks([10, 100, 1000])\n",
    "ax.set_xticklabels([\"10\", r\"$10^2$\", r\"$10^3$\"], fontsize=6)\n",
    "ax.set_yticks([10, 100, 1000, 10000])\n",
    "ax.set_yticklabels([\"10\", r\"$10^2$\", r\"$10^3$\", r\"$10^4$\"], fontsize=6)\n",
    "ax.set_xlabel('d', fontsize=8)\n",
    "ax.set_ylabel('T', fontsize=8)\n",
    "bar = fig.colorbar(c, ax=ax, ticks=[0, .3, .6, .9])\n",
    "bar.set_ticklabels([r\"$0$\", r\"$0.3$\", r\"$0.6$\", r\"$0.9$\"], fontsize=6)\n",
    "# fig.savefig('weight_storage_error.pdf', bbox_inches='tight')\n",
    "fig.savefig('fill_storage_error.pdf', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
