{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import sqrt\n",
    "from pathlib import Path\n",
    "from random import seed\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch as th\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Seed PyTorch's global RNG (used by th.randn / th.multinomial below) and\n",
    "# Python's `random` module so that a fresh Restart & Run All is reproducible.\n",
    "th.manual_seed(0)\n",
    "seed(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings(n, d, norm=True):\n",
    "    \"\"\"Sample ``n`` random Gaussian embedding vectors of dimension ``d``.\n",
    "\n",
    "    n: int\n",
    "        Number of vectors (rows of the result).\n",
    "    d: int\n",
    "        Dimension of each vector (columns of the result).\n",
    "    norm: bool\n",
    "        If True, every row is rescaled to unit Euclidean norm;\n",
    "        otherwise all entries are divided by ``sqrt(d)``.\n",
    "\n",
    "    Returns a ``torch.Tensor`` of shape ``(n, d)``.\n",
    "    \"\"\"\n",
    "    vectors = th.randn(n, d)\n",
    "    scale = vectors.norm(dim=1, keepdim=True) if norm else sqrt(d)\n",
    "    return vectors / scale\n",
    "\n",
    "\n",
    "class AssMem(nn.Module):\n",
    "    \"\"\"Linear associative memory computing the logits ``E[x] @ W @ U``.\n",
    "\n",
    "    Only ``W`` (initialised to zero) is a trainable parameter; the embedding\n",
    "    matrix ``E`` and the unembedding matrix ``U`` are kept fixed.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, E, U):\n",
    "        r\"\"\"\n",
    "        E: torch.Tensor\n",
    "            Input embedding matrix of size $n \\times d$,\n",
    "            where $n$ is the number of tokens and $d$ is the embedding dimension.\n",
    "        U: torch.Tensor\n",
    "            Output unembedding matrix of size $d \\times m$,\n",
    "            where $m$ is the number of classes and $d$ is the embedding dimension.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        d = E.shape[1]\n",
    "        self.W = nn.Parameter(th.zeros(d, d))\n",
    "        # Buffers rather than plain attributes: they are still excluded from\n",
    "        # `parameters()`, but now follow `.to(device)` / `.double()` together with W.\n",
    "        # `persistent=False` keeps `state_dict()` limited to W, as before.\n",
    "        self.register_buffer(\"E\", E, persistent=False)\n",
    "        self.register_buffer(\"U\", U, persistent=False)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits ``E[x] @ W @ U`` for the token indices ``x``.\"\"\"\n",
    "        return self.E[x] @ self.W @ self.U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of input tokens\n",
    "n = 10\n",
    "# number of output classes\n",
    "m = 5\n",
    "# memory (embedding) dimension; the trainable matrix W is d x d\n",
    "d = 5\n",
    "\n",
    "# exponent of the power-law token distribution: p(x) proportional to (x + 1) ** (-alpha)\n",
    "alpha = 1.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every token index together with its label (token index modulo m).\n",
    "all_x = th.arange(n)\n",
    "all_y = th.remainder(all_x, m)\n",
    "\n",
    "# Power-law token frequencies, normalised so that they sum to one.\n",
    "proba = th.pow(all_x + 1., -alpha)\n",
    "proba = proba / proba.sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of samples drawn per SGD step\n",
    "batch_size = 1\n",
    "# number of SGD steps (one freshly sampled batch per step)\n",
    "nb_epoch = 1000\n",
    "# total number of training samples seen\n",
    "T = nb_epoch * batch_size\n",
    "# SGD learning rate\n",
    "lr = 1e-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embeddings: inputs scaled by 1 / sqrt(d), outputs normalised to unit norm.\n",
    "E = get_embeddings(n, d, norm=False)\n",
    "U = get_embeddings(m, d, norm=True).T\n",
    "\n",
    "# Model and optimiser (plain SGD, no momentum).\n",
    "assoc = AssMem(E, U)\n",
    "opti = th.optim.SGD(assoc.parameters(), lr=lr, momentum=0)\n",
    "\n",
    "# Snapshots are written to ./sgd; create it up front so `savefig` cannot fail\n",
    "# with FileNotFoundError on a fresh checkout.\n",
    "Path('sgd').mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "\n",
    "def association_table(model):\n",
    "    \"\"\"Return softmax(E @ W @ U) over the last axis as a NumPy array.\"\"\"\n",
    "    with th.no_grad():\n",
    "        logits = model.E @ model.W @ model.U\n",
    "        return F.softmax(logits, dim=-1).numpy()\n",
    "\n",
    "\n",
    "def save_snapshot(table, token, label, path):\n",
    "    \"\"\"Save `table` as an image with the cell (row=token, col=label) outlined in red.\"\"\"\n",
    "    fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n",
    "    ax.imshow(table, aspect='auto')\n",
    "    ax.add_patch(plt.Rectangle((label - .5, token - .5), 1, 1, fill=False, edgecolor='red', lw=2))\n",
    "    ax.set_axis_off()\n",
    "    fig.savefig(path)\n",
    "    # Up to two figures are created per training step; pyplot keeps every open\n",
    "    # figure alive, so close it once it has been written to disk.\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "train_loss = []\n",
    "test_loss = []\n",
    "\n",
    "for i in range(nb_epoch):\n",
    "    x = th.multinomial(proba, batch_size, replacement=True)\n",
    "    y = x % m\n",
    "\n",
    "    out = assoc(x)\n",
    "    loss = F.cross_entropy(out, y)\n",
    "    train_loss.append(loss.item())\n",
    "\n",
    "    # Probability mass of the tokens that are currently misclassified.\n",
    "    with th.no_grad():\n",
    "        pred = assoc(all_x).argmax(dim=-1)\n",
    "        test_loss.append(proba[pred != all_y].sum().item())\n",
    "\n",
    "    # The snapshot logic highlights a single (token, label) cell, so it relies\n",
    "    # on batch_size == 1 (`.item()` raises otherwise).\n",
    "    token, label = x.item(), y.item()\n",
    "    mat = association_table(assoc)\n",
    "    # Decided once, before the update, so the before/after images come in pairs.\n",
    "    needs_snapshot = mat[token, label] < .8\n",
    "\n",
    "    if needs_snapshot:\n",
    "        save_snapshot(mat, token, label, f'sgd/mat_step{i}_0.png')\n",
    "\n",
    "    opti.zero_grad()\n",
    "    loss.backward()\n",
    "    opti.step()\n",
    "\n",
    "    if needs_snapshot:\n",
    "        mat = association_table(assoc)\n",
    "        save_snapshot(mat, token, label, f'sgd/mat_step{i}_1.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
