{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "from math import sqrt\n",
    "\n",
    "import numpy as np\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def modular_class(x, m):\n",
    "    return x % m\n",
    "\n",
    "\n",
    "def squared_modular_class(x, m):\n",
    "    return (x**2) % m\n",
    "\n",
    "\n",
    "def get_embeddings(n, d, norm=True):\n",
    "    emb = th.randn(n, d)\n",
    "    if norm:\n",
    "        emb /= emb.norm(dim=1, keepdim=True)\n",
    "    else:\n",
    "        emb /= sqrt(d)\n",
    "    return emb\n",
    "\n",
    "\n",
    "def get_q(x, P, rho):\n",
    "    counts = th.bincount(x)\n",
    "    _, order = th.sort(counts, descending=True)\n",
    "    if P is None:\n",
    "        P = len(counts)\n",
    "    P = min(P, th.sum(counts!=0).item())\n",
    "    idx = order[:P]\n",
    "    q = (counts[idx] / counts[idx].sum())**rho\n",
    "    return q, idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of input tokens\n",
    "n = 10\n",
    "# number of output classes\n",
    "m = 5\n",
    "# memory dimension\n",
    "# d = 50\n",
    "d = 10\n",
    "\n",
    "# number of data\n",
    "t = 1000\n",
    "\n",
    "# Zipf parameter\n",
    "alpha = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Population data\n",
    "all_x = th.arange(n)\n",
    "proba = (all_x + 1.) ** (-alpha)\n",
    "proba /= proba.sum()\n",
    "all_y = modular_class(all_x, m)\n",
    "\n",
    "# Embeddings\n",
    "E = get_embeddings(n, d, norm=False)\n",
    "U = get_embeddings(m, d, norm=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Empirical data\n",
    "x = th.multinomial(proba, t, replacement=True)\n",
    "y = all_y[x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "P = None\n",
    "rho = 0\n",
    "q, idx = get_q(x[:t], P, rho)\n",
    "W = (E[idx].T * q) @ U[all_y[idx]]\n",
    "mat1 = E @ W @ U.T\n",
    "mat1 = F.softmax(mat1, dim=-1)\n",
    "mat1 = mat1.numpy()\n",
    "\n",
    "rho = 1\n",
    "q, idx = get_q(x[:t], P, rho)\n",
    "W = (E[idx].T * q) @ U[all_y[idx]]\n",
    "mat2 = E @ W @ U.T\n",
    "mat2 = F.softmax(mat2, dim=-1)\n",
    "mat2 = mat2.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAAFICAYAAAA24bcOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAF00lEQVR4nO3WO4vcdRiGYTcZkhg1FqZYREyhJh4qTWejCFpYCfYREaxNatuArYJYKAFtbGz1A8TC0lIDQkCJKIQUhpw2h/ETyD0Ww/uHva76LZ5lZ+757azX6/VDAPynA9MDAJZOKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAYbXp4YmvP9nmjsV67st70xNGXHvx6PSEETee2pmeMOLAnekFM345f3ajOy9KgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRBWmx4ev3homzsW6/I7+/PvPnxtZ3rCiEeurKcnjNg7tj//35vyogQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAYbXp4Y0nd7a5Y7F2f7o/PWHEreP78zd077H9+Tm//uy96QmLtj+/DQD/g1ACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIKw2PTz693qbOxbryNW96Qkj9h49Mj1hxPPvXZqeMOLSd6emJyyaFyVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECKtND/95ZpszluuJr36enjDi+muvTk8Yce3jE9MTRrz/+Q/TE4ac3ejKixIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglAChNWmh7uv/LXNHYu18/JL0xNG3D22np4w4vCvV6YnjPj0x7emJ4w498Jmd16UAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiCsNj28/e3uNncs1sXvv5ieMOKNMx9MTxhx4/TT0xNGnLxwa3rCjA83O/OiBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABhtenh1dMPtrljsd5+/d3pCSMuf3RwesKIUxduTk8Y8du5jVOwL3lRAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUBYbXr48J8Ht7ljuR48mF4w4uQ3t6cnjLh5/sb0hBH3fz8+PWHRvCgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQFhtenj/yDZnLNfd3cenJ4z4482j0xNG7H52eHrCiENn7kxPWDQvSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQdtbr9Xp6BMCSeVECBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAoR/AXRcVN1WbTxqAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 400x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAAFICAYAAAA24bcOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAFsUlEQVR4nO3Wv4qcdRiGYSc7q5IlIahEUKyiIFaxCahg5b9asLcXOw/HUvAwUlrZqWAnmpgEBd1oEjHZjEcg95dieD+Y66rf4hn225vfZrfb7Z4C4H+dmx4AsHZCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQI26WHj2+/ts8dq/XhS1enJ4y48/nb0xNGnNw+m54w4ta70wtm/PTZF4vuvCgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQNguPfzgk0/3OGO93vj2u+kJI379enrBjN+vHubb4dIP0wvW7TC/CoAnIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJELZLD2+9c36fO1brzlfXpieMuHflbHrCiIs/Hk1PGHH31cfTE1bNixIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBwnbp4b1Xzva5Y7Vevj69YMbm0dH0hBGnb/47PWHE0zePpyesmhclQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAjbxZcXH+5xxnpd+P7u9IQRN957bnrCiMvXj6cnjLj/8en0hFXzogQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAYbv48vR4jzPW6+ZHl6cnjNjszqYnjHj2j8P83b/9fHF6wqp5UQIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAsF16eHT/MJv64MXd9IQZjzbTC0Y8eP4wv/OTG4f5917qML8KgCcglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQtksPz7/+5x5nrNelLy9MTxhx59rR9IQRRw930xNG/H3lbHrCqnlRAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUDYLj18enu2zx2rdf6Xv6YnjNi9dWl6woiHJ5vpCTOeOcz/76W8KAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVA2C49vPfNC/vcsVr/vD+9YMajk8fTE4ZspgeMOHd6PD1h1bwoAYJQAgShBAhCCRCEEiAIJUAQSoAglABBKAGCUAIEoQQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECEIJEIQSIAglQBBKgCCUAEEoAYJQAgShBAhCCRCEEiAIJUDY7Ha73fQIgDXzogQIQgkQhBIgCCVAEEqAIJQAQSgBglACBKEECP8BUhJRbExiALoAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 400x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n",
    "c = ax.imshow(mat1, aspect='auto')\n",
    "ax.set_axis_off()\n",
    "fig.savefig('fill_mat_d.pdf')\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n",
    "c = ax.imshow(mat2, aspect='auto')\n",
    "ax.set_axis_off()\n",
    "fig.savefig('weight_mat_d.pdf')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
