{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "plt.rc('font', family='serif')\n",
    "plt.rc('text', usetex=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AssMem(nn.Module):\n",
    "    def __init__(self, d, n, m):\n",
    "        \"\"\"\n",
    "        d: int\n",
    "            Memory dimensionality\n",
    "        n: int\n",
    "            Number of input tokens\n",
    "        m: int\n",
    "            Number of classes\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.W = nn.Parameter(th.randn(d, d))\n",
    "        # self.W = th.eye(d)\n",
    "        self.E = nn.Parameter(th.randn(n, d) / sqrt(d))\n",
    "        self.UT = nn.Parameter(th.randn(d, m) / sqrt(d))\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.E[x] @ self.W\n",
    "        out = out @ self.UT\n",
    "        out = F.softmax(out, dim=-1)\n",
    "        return out\n",
    "\n",
    "\n",
    "def modular_class(x, m):\n",
    "    return x % m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of input tokens\n",
    "n = 100\n",
    "# memory dimension\n",
    "d = 2\n",
    "\n",
    "# Zipf parameter\n",
    "alpha = 2\n",
    "\n",
    "# Population data\n",
    "all_x = th.arange(n)\n",
    "proba = (all_x + 1.) ** (-alpha)\n",
    "proba /= proba.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of output classes\n",
    "ms = np.unique(np.logspace(0, 2, num=50).astype(int))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1,2,3,4,5,6,7,8,9,10,11,12,13,15,16,18,20,22,24,26,29,32,35,39,42,47,51,56,62,68,75,82,91,100,"
     ]
    }
   ],
   "source": [
    "nb_trials = 10\n",
    "errors = th.zeros(nb_trials, len(ms))\n",
    "\n",
    "for i_m, m in enumerate(ms):\n",
    "    assoc = AssMem(d, n, m)\n",
    "    opti = th.optim.Adam(assoc.parameters(), lr=1e-1)\n",
    "\n",
    "    all_y = modular_class(all_x, m)\n",
    "\n",
    "    # number of data\n",
    "    batch_size = 1000\n",
    "    nb_epoch = 1000\n",
    "    T = nb_epoch * batch_size\n",
    "\n",
    "    for i_n in range(nb_trials):\n",
    "        for i in range(nb_epoch):\n",
    "            x = th.multinomial(proba, batch_size, replacement=True)\n",
    "            y = x % m\n",
    "\n",
    "            opti.zero_grad()\n",
    "            softpred = assoc(x)\n",
    "            loss = - th.log(softpred[th.arange(len(y)), y]).mean()\n",
    "            loss.backward()\n",
    "            opti.step()\n",
    "\n",
    "        with th.no_grad():\n",
    "            pred = assoc(all_x).argmax(dim=1)\n",
    "            loss = proba[pred != all_y].sum()\n",
    "            errors[i_n, i_m] = loss.item()\n",
    "\n",
    "    print(m, end=',')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors_mean = errors.mean(dim=0)\n",
    "errors_std = errors.std(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAL8AAACGCAYAAACMn1tjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAVGklEQVR4nO2d228bV37Hv8MhObyIF8lOHDtW1qEDL7rYLmLJQlGgWKCWHPSlL4HkPPRZFoL0KQgkOEARpEAhW/+B7D71pYil5GEXaAtb9rZAFwVim5vsBrubrEV7rcTxRSIpSrwN59KH4TmcGQ4pXiYiJf4+gAByOBzNSN/5zff8zu+cI+i6roMgBhBPr0+AIHoFiZ8YWEj8xMBC4icGFhI/MbCQ+ImBhcRPDCwkfmJg8fb6BHqBpml48uQJIpEIBEHo9ekQXaLrOnZ2dnDixAl4PK3H84EU/5MnTzA6Otrr0yBcZmNjAydPnmx5/4EUfyQSAWD8saLRaI/PhuiWXC6H0dFR/n9tlYEUP7M60WiUxH+IaNfCUoOXcJ2CrDT9PJOXW9oGADulSt02t2oxSfyEq8iKht88zjbd5+tnO3XbfvnbJ/jj01zd9i83tuu2lSpax+dnhsRPuEpRVvG777ZRUZ0FWpAVbKQLddvLFQ3/9dVT/OH72g3wLFfCk2zRsp+m6ZAVEj/RhxQrKoqyahGxmdSLPCpqvW2paBp0Hfifb16gVFEBAH98ulNnoUqKCh1ke4g+hIn1N4+zjt58/cUuVK0+cqvVG6Ioq/j1g03ouo5vnu6gaLM4RVl17VxJ/ISrFKriTOdlPNzMWz6TFQ2Ptwp1kV/XdShabdvvvtvG3UcZ7JYVlBUVmumzkkuWByDxEy5TrNQi82Obt9/cLUPRdKiaVfyKZr8ZgF8/2OSvzcekyE/0LQWTOMu2KJ0vG5bILnbFoQ3Q6JilComf6FOKpgZqnfirIrZ7fsWhDWDGLHgSP9G3WCK/TaiFauS3Z0Hbi/zk+Yk+pantqX6m6Vbfb7dBdiyenyI/0S/Yyw+KTcRvztmbrY69AWzH/D2yPUTf8HynbElFmiNzWbEKNV+uvTcLvlFvMKNEkZ/oR0oVFdvFCn9tFrWsaJaOLmvkr23fK/IX5drNYW9HdAOJn+iKiqpjq1qRac/B63rN+ui6bmkPqKrZ8zeP/OabhiI/0Tcoqoat3TIAoOAgTCb+UkVr2Mjdq8FrTXVStofoE2RVQ5pHfgW6rlusDvP9eVuBmvlGKFe0ppWa7IlRVtQ9LVI7kPiJrlBMtqcgq/jFl0/wyb0NLlIm6kLZ+lQwN3L/5T/+gH/93xR2y86DYEoVo+1Qkt2L+gCJn+gSRdOQycvQNB3bxQoebRXwLFdGpmDcEMz2NIv8f3q2g4qq47uMtXafoek6ShUNJcU9vw+Q+IkukRWjInO7WEF6tzYUkQ1LLFc9ur0u3+zzmY9/vlNq+HvYOAE3IfETXcEyNVt5mdsfAMgUjPQn9/w228Miv67rvEH7Yqfc8PcUZKV/Iv/ExAQ+++wzN8+FOIAw757Oy8gWzOK32Z6yPfLXskDsGfB8p9xwcHqpGvndXEioY/FfunQJb7/9tmXbnTt3uj4h4mDBBqak82Vki7VShyyP/MzzV+t6NB2//TbLx/HmbVWgO6Xa+51SBf/51fd4miuhIKsoVTT82//9GT9f+hV+9239wPZ26XjeHkEQ8O677+L06dNIJBJIp9NYWVnB+fPnuz4p4uDAIv/mroycSfyZggxd13mPLPP8D17s4ldfv0BeVvF3Pz1elwV6vlNGNOgDAHz1XQ7fPNuFRxDw9tlXjehfUVFWNAT9Ytfn3nHkv3LlCnRdx+bmJj7//HM8ePAA6XS66xMiDhasHDmTly2pyrKicaECNc//PGf4emaL7Fkgs+9nDeCCrKJQUZEvK/x48ZCv63PvOPIvLy9jcnLSsu327dtdnxBxsGCRX9H0ut7XTL6CsmL07LKG72beEDfL3NizQOaMz2Y1e1SsqCjJKtKmNkUs2EPxT05OIpfL4caNGwCAixcv1t0MxOHHPBjdXsWZKcgoKyrysgLWTt2qCppleNgTQQCgA3jBSiVkhT9JirJhd1j6dEjywid2n6js+AgPHz7E+fPncfPmTdy8eRPj4+P44osvuj4h4mBh7qm1lyhkCxWUKxr39eWKygVdsuX/jwz5ARg3Q76sWOxPsWp5WIM6GnRnitmOj/Lpp5/i3r17lm2XL1/Gm2++ued3U6kUVldXkUgkkEqlcOnSJcTjccd9k8kk1tbWAAB3797F9evX+b7JZBIAMDY2hlQqhWw2i7GxsU4viWgTzTYTA/Pj0YAXuZJSjfwa9/Wbpn4Ae+QP+71QQzoyhQqe5kqWuTtVTcdOSeGl03EXLA/Qhfhff/31um3nzp1r6bszMzO4f/8+AONGmJ2dxcrKiuO+a2trmJ+fBwAsLS1hcnKSf3d5eRnXrl0DAExNTTU8BvHDINsGoTDxvxINIFfaRaYgQ1ZqkZ9VfwJGG0FWNB75faIHPzoSRqaQxVffbcNvszWbuzJ2q2nQqEvi79j2pFKpum0PHz5s+3uJRIJHdjvJZBKLi4v8/fT0NJLJJD/G+Pg4MpkMMpkMbt261fDpQfww2EuRWVrzWCwAANguVlBRNZ7Z2dq1zsRclFWe//d5BfzsZAwA8GirgMcZ65w/eVnhT4ueR/6pqSm89dZbGB8fB2BE6KtXr+75vbW1NYyMjFi2jYyMIJlM1lmWsbExXL9+nb/PZrN8fwYJvndUbB6/XH0SHAn74fUIUDQduVKlJn7bNOS7ssJndPCJHgyH/Dh1JIRHWwXeJhiSvNgtKyhVO7mAPoj8Z8+exfLyMq/fvnbtWksdXEzAdhr1EUxPT/PXn3zyCaamprjgs9ksVldXsbq6ioWFBcenEQCUy2XkcjnLD9E9FdsILFbEJvlEDIeMBmymICOdNzq8NnettTuFslKL/FWb8+ZonH8elkSMhI3jFCsqj/xupDmBLiL/xMQELl++jCtXrrhyIo1uCvPnq6ur3O8DsDSUE4kELly4gPX19brvLi4u4uOPP3blPIka9jk3WbZH8noQD/nwYreMbL6CXNEQeVnRIAhA0CeiIBuZH+b5mcd/bSSEkbAf6byMl4YkSF6jJ7dYUXlhm1vi3/fanng8Xhfl0+n0nvZlYWGhztebIz3LHDlF/8uXL2N7e5v/bGxs7HmexN4opgavpuu8AczEDxi+X9N1nr2JBXwI+gxBF2SVN2J9orGkkCAI+OvEEXgE4MyxCN+3aLI9PY/8ndb2TE1NYXl5uW57s0zR0tISFhYWkEgk+BMilUphcnISmUzGsq+9PQEAkiRBkqQWropoh0Y5fsnnQchvSIsNON8p1zI1sqnSc7ca+QO+Wq3OGy8P4b2/fQMeQcBOKc2P47bt2ffankQiYXmfSqVw7tw5S+7eHL1XV1cxNjbGhX/jxg3E43EkEglLA3ttbQ3T09PUAN5HrL27hqC9HgHDIQkBnyEtJn4W4YckL/xe47O8rPA0KPP2DE91cTl2nFI/ef5uantWVlawsLCAiYkJ3L1715KfX1xcxMTEBObn55FKpTAzM2P5bjwe517/3LlzWFpaQjwex/r6OuX59xmnyC/5PAj7RW5XmGDZzG6RgNcywIX1+B4dsoqfwao3C/1kexYWFvDhhx9afH+rtT3mqG3O5gCwCDiRSDQdvDA2NkY9uj3Eqa4n4BMRkrzcxjDBMtsTCXh56XNBVvjMDMNhP3LF+gHs5vYBa1PEXKjoBLqwPXNzczSYZcDJlxX899fP8Thd4LYn5BMtkd/J9rC05m5Z5dmekVCDyF89Ts40J2gk0OPaHhrMQiT/nMGX327j20wRY68NAzBsStAv8sivajoqam2EViTgg6/q+QvlWuQfCfvxaKt+lUZme5gBkLweeD3uDD3vWPxXrlzB1NQUNjc3sblpLCGztbXlykkRBwPWY5vOyzxrE5a8CPu98IkCPAKg6Ua6k1mWSMDLc/r5ssLr+o9GnLNxAa91xJY5K9QtNJiF6Bg2YF0H+Jw7Q5IXYUmEIAgI+kTkZZWXJwe8HvhED8/ppwsyrw96achZ/B6PgIDXwxeiY9kfN2hZ/Hfu3OEpSPPAldu3byOVSiGZTOL06dM0oGWA2DaN2WWLRQ9JXgSrOf6ATfxDVa/ObI+5Zn847IdPFBzX6A36xZr4vT2I/DMzM7h9+3Zdvf7k5CQmJyeRzWZx+vRpfPDBB66dHNHfmBuhLIJHA16Eqz6dWRQm8kjAyNIw2/O8ut3rEeATPQhLXj7rgxnjOMZ2qReRf3Z2lgv/0aNHls9OnTqFeDyO2dlZ106M6H+cUpPRoI/37rJMDRuaGJGqkb8qflboxm6SoQbiD5lmagj6RAgQXDn/lm+jI0eO8NeZTAYzMzNYXV217HP69GlXToo4GDhNLBsN+OD3euD3erg/Z2lQZntY5Gd9ACyjMyTVx2LjODXxSz7RtejfcuQ3lw2cPXsWFy9erLM4guDOHUkcDJzEP1wtUwj6xLq5dVh+Phayyo5F9iGH/P3xWIA/Qdi+bmV8WhZ/KpXCzs4O73EVBMHyHoBjOTFxOFFUjefoX4kG8DRnTDnCRlmFpXqRRiTjs1fjIct2Jv6wQ+Q/HgtabqJYwJ3eXaAN23P16lXE43EMDw9jeHgY8/PzlvfxeBxLS0uunRjR32RM3vy1IzUxj1RrdEJ+b734q5H9tRGr+JndiTiI/9V40BL54w16gjuh5ch/6dIlLCwsOJYMA0YHF4l/cNiqTj4V9Il4JRrg24erkT9kKnFgsMh+cjho2c7E7xT5j8Uk3oAGgOGwe5G/ZfHPzc05ztjAiMVimJubc+WkiP5ns5qmDPpFvFztnRUE4GjEuBGMyF8zFmFJhOgx2oQv2Xpz2RPB7vlDfhGSV8SISfD20uduaFn8Z8+edWUf4nDA5twM+USEJS8m/+JlALU5NO2en/l9QTBKkv2ih5c8hKufhf1eCEKtjoeVLh819f6+HKk9ZbrFnfI4YuBguXvWGP3piRjEamcVYER+s+1h0T3kF+EVPZB8JvFXjyF6BIT8Ip/Iis3ScCwqIegznhxHGtT9dwKJn+gIZnvMHVBshBbbLnk9fA5OZmlYL2/IJ/JKz5DJ64clLxc/i/zxkB//8FevwSMIjn0BnULLEhEd8aI6AZU5DSmZxG9YGIFbH5bJiVbFb/5e2PTaXNcf42lTL8KSF0G/iLCfxE/0mLQp28OwRH5JtHxei/z1mR1z5H/j5SH+monfHO3dWJSCQeInOoItPG1OQ0qmikufaJQ4/OXJGI7HAnht2MjtR2w3AWCN/KeOhvlNxJ4SlhvFRfGT5yc6gnVyNbI9gCHUN0fjllnYog7R3HwD+UQPXj8axp+e7ZqeEsbvYBNeuQVFfqIj2PybjRq8ABz9Oa/vMc3AwMTNOHNsCJGAF55qvwC7UYI+kW9zAxI/0TZlReV1PeZIXBf5pfoozaxMxFSjY7cyp46ELcMagz4RHkFw1fIAJH6iA5jf9whWwUu2UVZ2sZrLk83RPmR7QnhFD86arJIgCAhLYt1+3ULiJ9pmy5TmNJex72V7zCsomhuxTvZo1Fb8Fpa8FPmJ3sNGYIV8VtHabY+9hidxtJbGNAveyR7ZYXl+NyHxE23DbI9djHbxn4gHYR7f9ONXIvz1XpHfzpAkOlZ9dkNPUp3tLEjXbN92jkO4x5ZD7y5Q7/kDPmNxia1dGS9FJEtFJsvtC0Jr05GEbbVCbtAT8bezIF2zfds5DuEebCHpkE2MTmNrT8SC2NqVLVEfqPXqsjKIvWAlDm6y77annQXpmu3bznEId2kU+e0rKAKG9QGMhSbMDLHyhxZ9/KFo8DZbkK6dfds5DuEOqqZju1jB99vGBFV1tsch8r8aD+J4LFA3rfhPjsfws5MxTI+fbOl3h6X6AfHdsu+2p50F6Zrt285xyuUyyuXa7GBOC9L98y9/j3///LHjMQlA1fW6FdZHQn68Gg+iWFGRLVTqPD9gTCc+9qPhuu1Bv4hf/OPftPz7I5KPT3PoFn1T27PXgnSt7uv0WSsL0lVUjU+nTTQnLIn48bEI/unvf8J7bFVN58MU7dgtTye4HfWBHoi/nQXpmu3bznEuX76M999/n7/P5XIYHR217PP+hTO49POE/atEFY/HmHi20bw5jYTf1+j7zPr6uj42NmbZFo/H9Uwm09a+7RzHzvb2tg5A397ebvv8if6j0//nvkf+VhakYwvONdvXHuHtx2mGXh0hTYtRHw7Y/1FvsoSVIz/IrbgH6+vr+vz8vL6ysqLPz89bovX09LR+9erVlvZt9lkzNjY2dBhDS+nnEP1sbGy0pUNB19u9XQ4+mqbhzJkzuH//PgRBwMTEBG7fvo3R0VFsbGwgGo12dFy2umSn+zlt32tbs9cH5ZqavW/lmnRdx87ODk6cOAFPG0sW9U22Zz/xeDzw+/2IxWIAAFEU+R8yGo12LBTzcTrZz2n7Xttaed3v19TsfavXxP6X7TCwhW3vvfee42u3jtnJfk7b99rWyutu2I9ravb+h7gmxkDaHidyuRxisRi2t7c7jpL9Bl1TcwY28tuRJAkfffQRJMl5YbSDCF1TcyjyEwMLRX5iYCHxEwPLQKY67WSzWVy7dg0AMD8/3+Oz6YxG18AWDUyn00gkEpiamurJ+bVKu9fRzfWR+GGMG9ja2rKsOHnQcLqGVCqFW7duYXl5GQBw4cKFvhd/O9fR7fWR7QEwPT194JdRdbqGtbU1S61TPB7v+9Fu7VxHt9dH4j/ErK+vWyLoyMhIW+Mm+oVG19Ht9ZH4BwynkW4HkUbX0c71kfgPMXb7wBqFB41G19Ht9ZH4DzFTU1OWaslUKtX3DV4nGl1Ht9dHPbwwGlTLy8vIZrOYm5vD9PR0r0+pbRpdgzkVODIy0vfX1u51dHN9JH5iYCHbQwwsJH5iYCHxEwMLiZ8YWEj8xMBC4icGFhI/MbCQ+A8JyWQSc3NzEASB18PbmZmZwfDwMJaWlvb57PoT6uQ6RGSzWczOziKVSvEVa8yfLSws8Bp4giL/oeOdd95BKpWqW7nm3r17GB8f79FZ9Sck/kNGPB7HxYsXec0L0RgS/yFkbm6OD+0DjPbAuXPnenhG/QmJ/xAyNjYGAHx9skaLdgw6JP5DyvT0tCX6E/XQ7A2HlLm5OYyPj2NmZuZADmDZDyjyHzLYAG62sg2lNRtD4j8kJJNJzM7OYnFxkWd65ubmcOHCBQDGiKeVlRXcu3evYSfYoEGdXMTAQpGfGFhI/MTAQuInBhYSPzGwkPiJgYXETwwsJH5iYCHxEwMLiZ8YWEj8xMBC4icGlv8H1AEjZ0pgvqsAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 150x100 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(1.5, 1))\n",
    "ax.plot(ms, errors_mean)\n",
    "ax.fill_between(ms, (errors_mean - errors_std).clip(0, None), errors_mean + errors_std, alpha=.5)\n",
    "ax.set_xscale('log')\n",
    "ax.set_xlabel('M', fontsize=10)\n",
    "ax.set_ylabel(r'Error', fontsize=10)\n",
    "# ax.set_yticks([0, .02])\n",
    "# ax.set_yticklabels([0, .02], fontsize=8)\n",
    "ax.set_xticks([1, 10, 100])\n",
    "ax.set_xticklabels([1, 10, 100], fontsize=8)\n",
    "fig.savefig(\"error_m.pdf\", bbox_inches='tight')\n",
    "\n",
    "# ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = 5\n",
    "assoc = AssMem(d, n, m)\n",
    "opti = th.optim.Adam(assoc.parameters(), lr=1e-3)\n",
    "\n",
    "all_y = modular_class(all_x, m)\n",
    "\n",
    "# number of data\n",
    "batch_size = 1000\n",
    "nb_epoch = 10000\n",
    "T = nb_epoch * batch_size\n",
    "\n",
    "for i in range(nb_epoch):\n",
    "    x = th.multinomial(proba, batch_size, replacement=True)\n",
    "    y = x % m\n",
    "\n",
    "    opti.zero_grad()\n",
    "    softpred = assoc(x)\n",
    "    loss = - th.log(softpred[th.arange(len(y)), y]).mean()\n",
    "    loss.backward()\n",
    "    opti.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "U = assoc.UT.detach().numpy()\n",
    "X, Y = np.meshgrid(np.linspace(-1,1, num=100), np.linspace(-1, 1, num=100))\n",
    "x = np.stack((X.flatten(), Y.flatten())).T\n",
    "Z = np.argmax(x @ U, axis=1).reshape(X.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGEAAABhCAYAAADGBs+jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAADDElEQVR4nO2bMU4bQRSGfyKKHIEyrQsOQGfR0UGDECdAckHNCZKCJoUl2jQpcgiLjgOkIB0oBRdIQxFpU0TrbDaOPbM7M+9/s/9XWitrNN/8b96M13tN0zQQpryxHoCQBAokgQBJIEASCJAEAiSBAEkgQBII2A998PDTYc5xJOfL+5/WQwAAzL497nymyiSwCAilSgnekAQCJIEASSCgOgneNmWgQgkekQQCJIEASSBAEjIyu3gJei747kiEc36zjx+PH4BX4Dng+aqSwNCergVEoCQk5OT0Fth9afoPkpCAIau/iySMZOjq71LVnlCak9PbJN+jJAwg1eS3KAmRpBYAKAnBzC5e8O71c5bvVhICyCkAUBK20j355qSaJKQ+LY/t/WNQEjaQovePoZokpCJH97OL4CS0cT+/qTM8JctPn+AZXc2XAICrhz+f3R1dJx+QBaXLT59Ry/rq4SMA4Ph+AcBnSizKT58ks9ZNiYWQIZ1R7t4/huQzxV621pOfufePochy7ZYty5LFtPq7FJ2R1Xy5TkjpssUqADA8rJXaR0pdPYyBop3p7yOpypZl7x8DhYQ+3bIFbJfyv87IuvePgVJCn1ZKSNnysvq7uJDQsqlsdfG0+ru4vsBbzZdrMQwn36G4ltDiWQBQgYTFwRme315aD2MUriUsDs6sh5AEtxL6AjynwaWEWhLQ4k5Cvy2tAXcStv37xWtJciWhtjLU4kZCrQIAJxJqFgA4kBC7EXvcF6glHN8vgv+G6hlaCWMEeEsDpYSpJKCFTsLUBACEElIJ8FSSqCQwvSRWEhoJd0fX+Pr03XoYJlBIyCXAS0kylzDlBLSYSpCA35hKKCHAQ0kyk1D7pVwMJhIk4G+KS7AQwF6SikqY6mFsF0UlqBPaTDEJ1vsAc0kqIsFaADvZJUjAbrJKkIAwsklgfFOOdV/IImGKv46NIbkECYgnqQQPAhhLUjIJHgSwkkyCBAwniQRvd0JsJWmvaZrGehBTx/w3ZiEJFEgCAZJAgCQQIAkESAIBkkCAJBDwCx/o79mS97sxAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 100x100 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(1,1))\n",
    "ax.contourf(X, Y, Z, cmap='tab10', vmin=-.5, vmax=9.5)\n",
    "ax.set_axis_off()\n",
    "fig.savefig(\"output_embeddings.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_emb = th.zeros(m, 2)\n",
    "std_emb = th.zeros(m, 2)\n",
    "for i in range(m):\n",
    "    E = assoc.E[all_y == i]\n",
    "    mean_emb[i] = E.mean(dim=0)\n",
    "    std_emb[i] = E.std(dim=0)\n",
    "\n",
    "mean_emb = mean_emb.detach().numpy()\n",
    "std_emb = std_emb.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAGEAAABhCAYAAADGBs+jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIjElEQVR4nO2c728UxxnHP7u3vrPBsV0wAf+MqZ2GqMKmxiZFVdqQoBYlpKEvGylq1OpwHBelaVWl7T8QVRVq0+Y4nFPKi6g/3lQCQiuaRAoqEEIPLjZW1aSYYpML2MUB2zg2Pu/t9sVlj9v7xf1Y+2ZP+3nFM8PujOc78zzzzKwt6bqu41BS5FJ3wMERQQgcEQTAEUEAHBEEwBFBABwRBMARQQAcEQTAEUEAHBEEQCl1B8oFXVWZGhxk4XyIqq3d1Pf1ISm5Da8jgkVMDQ4y9aoPdJ3PzpwBYN3AQE7POu7IIhbOh8A4kNb1mJ0jzkooEsMNRT7++E6hJFG1tTvndzgiFEmiGwKoaGmhds9T1Pf15fwOR4QiMbkhwN3SknMsMHBEKABdVZny+5k5+ibRmZk7FXm6IQNHhAK47vPxqf+gqawQN2TgiFAAN//wx5SyQtyQgSNCHhg7IW12NqWuEDdkIDlfW9wdY/BnDh9hKXErauB2syl0PucMORlhVoKqqQRGAoQmQ3Sv78a72Ysii9G95G1oMmt/8P2CBQBBVoKqqex9ey/BiWC8rMZdw9MPPk1fZ1/JxRj/3rPMnz2bUp4YjIsRQYipFhgJmAQAmI3McnD4ILIk09/VX5J+GW5o/tw5c4XLRf3z/UUPvoEQIoQmM5+zZKtbLu4aA3S94J1QOoQ4wOte342ElLFupTFiQFoBgIqGBkvbE2IleDd7AXjjX29wa+lWvFyW5LibWslAnXwUkYjk8bDxyGFL2xNiJSiyQn9XP898+RlTuaZrBCeCHBg6wN6396JqKhAL5P5hP963vPiH/fFyq6ja2g1S+pW5dq8XV3W1pe0JsRIMvJu9vD7yOovRxZS64ESQwEiA/q5+AiMB/EN+dHTOXovtWqwK3rqqomsaFc3NANTsfgJkmdsfDMVvzKxGKBEUWaFzXWfKTsnACNKhyRA6MXeho1savNOdC927b59l70+HEO4okQOPHaB3Qy9u2Z1StxRdQtVUUyCXkCwL3rqqcuP3h0xl6c6JrEaIZC0d3re8vH/t/ZTy5upmdrfvBmDof0OWZtfXfT6mfveqqUzyeNg0PFT0u7Mh3EowyDS7w3NhBocHkSWZwDdjMcKqXVO6e+HKrk5L3p0NYUXwbvbSWN2Yts7qOKCrKtd9PvM9MVDR1ETra69Z1k4mhBVBkRWOPHWEnvU9uCRXSr2VSVxyclbR0kL9vh/S/vfjyJWVlrWTCWFFAKhUKjm06xC9G3pN5c3VzfEEzwoy3RNbcS6UC0KLYJDuWCMwErAseTMlZwXeExeDUHlCJoxZf3T0KOG5MOG5MP4hP4AlyZuRgCV+wriS2EIE41gjNBkiPBcGzMG52ORNUhRLT0XzxRbuyCBTkrZcydtKIWyylo7EK9At924BYglb4r9FuxrNBVuJkIhvyMfB4TtnPM91PcfAltK5lGKwlTtK5NilY1ltO2FbEcoJ24pgHOJlsu2EfaJXEn2dfciSbPpOya7YNjCXE7Z1R+WEI4IAOCIIgCOCAAi7O1IjKsd8F5gKz7G2qZrG9lomxmZp7Khj6677kF3Fzx81quF79xLBsRt8paWW4NhNPpy4xYMNNRx6todK98oMj7C7o8O/DvHJR9Np67Y9uZHeJzYW3cYr71zkN+/8h3QD4HbJRKIakgQ9rXV8dWM9H4Sn2dr6BZB0zo3dRNNBlmDbxrUM7GhHKXBilGwlaFGN4LHLXDgRJnI7CkBDRy3f3teF4laYCs9lfPbq6LQlfQiO3UgrAEAkqgGxC7fg+DTB8Vibp0anUv7ve5c+BeCFnfcX1I8VFUGLapz72xgfnZ1gcV5lcd58A3bt4gzHfBfY82I39c3VGVdCY0edJf3pbVvD6dGpjELkik5M0EJZVhGMQf/w7ASRzwc8eeCTMVbA7oHOrDEhLVEVTu6HK2egdTs8/BNwZf4RB3a0A6TEhIWlKIuqltfPGtV01KhWkEuyLCZoUY3zx8e5enEaXddBAnQyzuZMND1Qx54XC7yUOfFLOPFyrGEkeOTn8MhLeb9mbiHCt145ySfTt3N+RgJ+tPNLBbmkolZCfOBHp9E1Pe8BB3ApEtFobB40dNSye6CIj62unIG4c9E/t/OnusrN6Z89xtxChF2/PcW16QVWexTu8biY+mwJNaoRTZq6xbikokQ4f3ycf755Oe/nPKsUPKsUHnhoAz2Pt1my3QSg5SH477tmuwiqq9yceunRlHJja/uXUJgrN+aB2ErobVtTUDt5i5A4+2evL9z1/7urXHhWVxCZV5dn4BNJdqzLtPlWXDIv7LyfgR3t8Tyjt21NPMbk/b58H8g2+91VCpEFc+DterSFbU9+saDO5UVUhZE/m8vCqb9xaSWGGEW/J98HkvfoNfWV1KyrorGjjq5Hm/mrf4TJyzMoFS42f6OJnsfbiu5kTpzcDzfHzGWt21em7SLJW4TGjjrC/74ZtzdtbzBlr9/5cYk+Nxk/bbbr7ottUW1A3iIYe/Sro9PZ9+wrjbpktu9pypojiETevZRdsiXnNpZzLZTdFpjyOcpWF7PbAlMeIkRV8NSYy2pbStOXAigPEf7xK1hM+DNonlp4vrBsuRSUhwjDfzLbVXXgsfYXvpeT8hBhbjK7LTjlIULyQbCYl4UZKQ8Rmnqy24JTHiK0fS27LTjlIUI4mN0WHPuLsDANV95LKJBsc3BnIOwnLznzcqs5R3B54BdXbXNuBOWwEhIFAIgu2koAKAcRJDm7bQPs1+NklFXZbRtgfxFWr81u2wB7ixBVU09Lu75bmr4Ugb1FOLkfxk/dsdsehq//tHT9KRB7i5D8cZes2G5nBHYXoXU7xP8Ej/2SNAP7TZtEjK8pEj8AtiH2z5jLAHu7ozLBEUEAHBEEwBFBABwRBMARQQAcEQTAEUEAHBEE4P9Q72bOeqQa3QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 100x100 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, ax = plt.subplots(figsize=(1,1))\n",
    "# tmp = [3, 4, 0, 1, 2]\n",
    "for i in range(m):\n",
    "    emb = assoc.E[all_y == i].detach().numpy()\n",
    "    # ax.scatter(emb[:, 0], emb[:, 1], s=5, c='C' + str(tmp[i]))\n",
    "    ax.scatter(emb[:, 0], emb[:, 1], s=5)\n",
    "ax.set_axis_off()\n",
    "fig.savefig(f\"input_embeddings_{m}.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
