{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Product-Key Memory (PKM)\n",
    "**Minimalist implementation of a Product-Key Memory layer** https://arxiv.org/abs/1907.05242\n",
    "\n",
    "This notebook contains a simple implementation of a PKM layer.\n",
    "<br>\n",
    "Overall, the PKM layer can be seen as a network with very high capacity that maps elements from $R^d$ to $R^n$, but very efficiently.\n",
    "<br>\n",
    "In particular, a 12-layer transformer model that leverages a PKM layer outperforms a 24-layer model without memory, and is almost twice as fast at inference.\n",
    "\n",
    "A more detailed implementation can be found at https://github.com/facebookresearch/XLM/tree/master/src/model/memory,\n",
    "with options to make the query network more powerful, to shuffle the key indices, to compute the value scores differently\n",
    "than with a softmax, etc., but the code below is much simpler and implements a configuration that worked well in our experiments (and that we used to report the majority of our results).\n",
    "\n",
    "#### Note: at training time, we recommend using a different optimizer for the values, as these are learned with sparse updates. In particular, we obtained our best performance with the Adam optimizer and a constant learning rate of 1e-3 to learn the values, independently of the optimizer / learning rate used to learn the rest of the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_uniform_keys(n_keys, dim, seed):\n",
    "    \"\"\"\n",
    "    Draw `n_keys` keys of dimension `dim` uniformly from\n",
    "    [-1 / sqrt(dim), 1 / sqrt(dim)), i.e. the same initialization as nn.Linear.\n",
    "    The draw uses a private RandomState seeded with `seed`, and the result is\n",
    "    returned as a float32 array of shape (n_keys, dim).\n",
    "    \"\"\"\n",
    "    limit = 1 / math.sqrt(dim)\n",
    "    generator = np.random.RandomState(seed)\n",
    "    samples = generator.uniform(low=-limit, high=limit, size=(n_keys, dim))\n",
    "    return samples.astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class HashingMemory(nn.Module):\n",
    "    \"\"\"\n",
    "    Product-key memory layer (https://arxiv.org/abs/1907.05242).\n",
    "\n",
    "    Maps inputs of shape (..., input_dim) to outputs of shape (..., output_dim).\n",
    "    Each of the `heads` queries selects its `knn` best product keys among\n",
    "    `n_keys ** 2` memory slots and returns a softmax-weighted sum of the\n",
    "    corresponding value vectors; the contributions of all heads are summed.\n",
    "\n",
    "    `params` must expose `k_dim`, `n_keys`, `heads`, `knn`, `sparse`,\n",
    "    `query_batchnorm`, `input_dropout`, `query_dropout` and `value_dropout`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim, params):\n",
    "        \"\"\"\n",
    "        Build the sub-keys, the value table and the query network.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # global parameters\n",
    "        self.input_dim = input_dim\n",
    "        self.output_dim = output_dim\n",
    "        self.k_dim = params.k_dim\n",
    "        self.v_dim = output_dim\n",
    "        self.n_keys = params.n_keys\n",
    "        self.size = self.n_keys ** 2\n",
    "        self.heads = params.heads\n",
    "        self.knn = params.knn\n",
    "        # each query is split into two halves, one per sub-key set\n",
    "        assert self.k_dim >= 2 and self.k_dim % 2 == 0\n",
    "\n",
    "        # dropout\n",
    "        self.input_dropout = params.input_dropout\n",
    "        self.query_dropout = params.query_dropout\n",
    "        self.value_dropout = params.value_dropout\n",
    "\n",
    "        # initialize keys / values\n",
    "        self.initialize_keys()\n",
    "        self.values = nn.EmbeddingBag(self.size, self.v_dim, mode='sum', sparse=params.sparse)\n",
    "        nn.init.normal_(self.values.weight, mean=0, std=self.v_dim ** -0.5)\n",
    "\n",
    "        # query network (the optional BatchNorm entry is dropped by filter when it is None)\n",
    "        self.query_proj = nn.Sequential(*filter(None, [\n",
    "            nn.Linear(self.input_dim, self.heads * self.k_dim, bias=True),\n",
    "            nn.BatchNorm1d(self.heads * self.k_dim) if params.query_batchnorm else None\n",
    "        ]))\n",
    "\n",
    "        if params.query_batchnorm:\n",
    "            print(\"WARNING: Applying batch normalization to queries improves the performance \"\n",
    "                  \"and memory usage. But if you use it, be sure that you use batches of \"\n",
    "                  \"sentences with the same size at training time (i.e. without padding). \"\n",
    "                  \"Otherwise, the padding token will result in incorrect mean/variance \"\n",
    "                  \"estimations in the BatchNorm layer.\\n\")\n",
    "\n",
    "    def initialize_keys(self):\n",
    "        \"\"\"\n",
    "        Create two subkey sets per head.\n",
    "        `self.keys` is of shape (heads, 2, n_keys, k_dim // 2)\n",
    "        \"\"\"\n",
    "        half = self.k_dim // 2\n",
    "        # one deterministic seed per (head, sub-key set) pair\n",
    "        keys = torch.from_numpy(np.array([\n",
    "            get_uniform_keys(self.n_keys, half, seed=(2 * i + j))\n",
    "            for i in range(self.heads)\n",
    "            for j in range(2)\n",
    "        ])).view(self.heads, 2, self.n_keys, half)\n",
    "        # register the tensor as a learnable parameter (wrapped exactly once)\n",
    "        self.keys = nn.Parameter(keys)\n",
    "\n",
    "    def _get_indices(self, query, subkeys):\n",
    "        \"\"\"\n",
    "        Generate scores and indices for a specific head.\n",
    "\n",
    "        `query` is (bs, k_dim) and `subkeys` holds the two sub-key sets of the\n",
    "        head, each (n_keys, k_dim // 2). Returns `(scores, indices)`, both\n",
    "        (bs, knn); indices address the flattened n_keys x n_keys product grid.\n",
    "        \"\"\"\n",
    "        assert query.dim() == 2 and query.size(1) == self.k_dim\n",
    "        bs = query.size(0)\n",
    "        knn = self.knn\n",
    "        half = self.k_dim // 2\n",
    "        n_keys = len(subkeys[0])\n",
    "\n",
    "        # split query for product quantization\n",
    "        q1 = query[:, :half]                                          # (bs,half)\n",
    "        q2 = query[:, half:]                                          # (bs,half)\n",
    "\n",
    "        # compute indices with associated scores\n",
    "        scores1 = F.linear(q1, subkeys[0], bias=None)                 # (bs,n_keys)\n",
    "        scores2 = F.linear(q2, subkeys[1], bias=None)                 # (bs,n_keys)\n",
    "        scores1, indices1 = scores1.topk(knn, dim=1)                  # (bs,knn)\n",
    "        scores2, indices2 = scores2.topk(knn, dim=1)                  # (bs,knn)\n",
    "\n",
    "        # cartesian product on best candidate keys\n",
    "        all_scores = (\n",
    "            scores1.view(bs, knn, 1).expand(bs, knn, knn) +\n",
    "            scores2.view(bs, 1, knn).expand(bs, knn, knn)\n",
    "        ).view(bs, -1)                                                # (bs,knn**2)\n",
    "        all_indices = (\n",
    "            indices1.view(bs, knn, 1).expand(bs, knn, knn) * n_keys +\n",
    "            indices2.view(bs, 1, knn).expand(bs, knn, knn)\n",
    "        ).view(bs, -1)                                                # (bs,knn**2)\n",
    "\n",
    "        # select best scores with associated indices\n",
    "        scores, best_indices = torch.topk(all_scores, k=knn, dim=1)   # (bs,knn)\n",
    "        indices = all_indices.gather(1, best_indices)                 # (bs,knn)\n",
    "\n",
    "        assert scores.shape == indices.shape == (bs, knn)\n",
    "        return scores, indices\n",
    "\n",
    "    def get_indices(self, query):\n",
    "        \"\"\"\n",
    "        Generate scores and indices.\n",
    "\n",
    "        `query` is (bs * heads, k_dim); both returned tensors are (bs * heads, knn).\n",
    "        \"\"\"\n",
    "        assert query.dim() == 2 and query.size(1) == self.k_dim\n",
    "        query = query.view(-1, self.heads, self.k_dim)\n",
    "        bs = len(query)\n",
    "        outputs = [self._get_indices(query[:, i], self.keys[i]) for i in range(self.heads)]\n",
    "        s = torch.cat([s.view(bs, 1, self.knn) for s, _ in outputs], 1)  # (bs,heads,knn)\n",
    "        i = torch.cat([i.view(bs, 1, self.knn) for _, i in outputs], 1)  # (bs,heads,knn)\n",
    "        return s.view(-1, self.knn), i.view(-1, self.knn)\n",
    "\n",
    "    def forward(self, input):\n",
    "        \"\"\"\n",
    "        Read from the memory.\n",
    "\n",
    "        `input` is (..., input_dim) with any number of leading dimensions\n",
    "        (including none); the result is (..., v_dim) with the same leading shape.\n",
    "        \"\"\"\n",
    "        # input dimensions\n",
    "        assert input.shape[-1] == self.input_dim\n",
    "        prefix_shape = tuple(input.shape[:-1])\n",
    "        # plain Python int: np.prod returns a NumPy scalar, and the float 1.0\n",
    "        # for an empty prefix, which `view` does not accept as a size\n",
    "        bs = int(np.prod(prefix_shape))\n",
    "\n",
    "        # compute query\n",
    "        input = F.dropout(input, p=self.input_dropout, training=self.training)  # (...,i_dim)\n",
    "        query = self.query_proj(input.contiguous().view(-1, self.input_dim))    # (bs,heads*k_dim)\n",
    "        query = query.view(bs * self.heads, self.k_dim)                         # (bs*heads,k_dim)\n",
    "        query = F.dropout(query, p=self.query_dropout, training=self.training)  # (bs*heads,k_dim)\n",
    "        assert query.shape == (bs * self.heads, self.k_dim)\n",
    "\n",
    "        # retrieve indices and scores\n",
    "        scores, indices = self.get_indices(query)                               # (bs*heads,knn)\n",
    "        scores = F.softmax(scores.float(), dim=-1).type_as(scores)              # (bs*heads,knn)\n",
    "\n",
    "        # merge heads / knn (since we sum heads)\n",
    "        indices = indices.view(bs, self.heads * self.knn)                       # (bs,heads*knn)\n",
    "        scores = scores.view(bs, self.heads * self.knn)                         # (bs,heads*knn)\n",
    "\n",
    "        # weighted sum of values\n",
    "        output = self.values(indices, per_sample_weights=scores)                # (bs,v_dim)\n",
    "        output = F.dropout(output, p=self.value_dropout, training=self.training)# (bs,v_dim)\n",
    "\n",
    "        # restore the leading dimensions of the input (no-op for 2D inputs)\n",
    "        output = output.view(prefix_shape + (self.v_dim,))                      # (...,v_dim)\n",
    "\n",
    "        return output\n",
    "\n",
    "    @staticmethod\n",
    "    def register_args(parser):\n",
    "        \"\"\"\n",
    "        Register memory parameters on an argparse-style `parser`.\n",
    "        \"\"\"\n",
    "        import argparse\n",
    "\n",
    "        def bool_flag(value):\n",
    "            \"\"\"Parse a textual boolean command-line value.\"\"\"\n",
    "            lowered = value.lower()\n",
    "            if lowered in ('off', 'false', '0'):\n",
    "                return False\n",
    "            if lowered in ('on', 'true', '1'):\n",
    "                return True\n",
    "            raise argparse.ArgumentTypeError(\"Invalid value for a boolean flag!\")\n",
    "\n",
    "        # memory parameters\n",
    "        parser.add_argument(\"--sparse\", type=bool_flag, default=False,\n",
    "                            help=\"Perform sparse updates for the values\")\n",
    "        parser.add_argument(\"--k_dim\", type=int, default=256,\n",
    "                            help=\"Memory keys dimension\")\n",
    "        parser.add_argument(\"--heads\", type=int, default=4,\n",
    "                            help=\"Number of memory heads\")\n",
    "        parser.add_argument(\"--knn\", type=int, default=32,\n",
    "                            help=\"Number of memory slots to read / update - k-NN to the query\")\n",
    "        parser.add_argument(\"--n_keys\", type=int, default=512,\n",
    "                            help=\"Number of keys\")\n",
    "        parser.add_argument(\"--query_batchnorm\", type=bool_flag, default=False,\n",
    "                            help=\"Query MLP batch norm\")\n",
    "\n",
    "        # dropout\n",
    "        parser.add_argument(\"--input_dropout\", type=float, default=0,\n",
    "                            help=\"Input dropout\")\n",
    "        parser.add_argument(\"--query_dropout\", type=float, default=0,\n",
    "                            help=\"Query dropout\")\n",
    "        parser.add_argument(\"--value_dropout\", type=float, default=0,\n",
    "                            help=\"Value dropout\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class AttrDict(dict):\n",
    "    \"\"\"Dictionary whose items can also be read and written as attributes.\"\"\"\n",
    "\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        # alias the instance namespace to the mapping itself, so attribute\n",
    "        # access and item access share a single storage\n",
    "        self.__dict__ = self\n",
    "\n",
    "\n",
    "# configuration read by HashingMemory through attribute access\n",
    "params = AttrDict(\n",
    "    sparse=False,\n",
    "    k_dim=128,\n",
    "    heads=4,\n",
    "    knn=32,\n",
    "    n_keys=512,  # the memory will have (n_keys ** 2) values\n",
    "    query_batchnorm=True,\n",
    "    input_dropout=0,\n",
    "    query_dropout=0,\n",
    "    value_dropout=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING: Applying batch normalization to queries improves the performance and memory usage. But if you use it, be sure that you use batches of sentences with the same size at training time (i.e. without padding). Otherwise, the padding token will result in incorrect mean/variance estimations in the BatchNorm layer.\n",
      "\n",
      "HashingMemory(\n",
      "  (values): EmbeddingBag(262144, 100, mode=sum)\n",
      "  (query_proj): Sequential(\n",
      "    (0): Linear(in_features=50, out_features=512, bias=True)\n",
      "    (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# use the GPU when one is visible, otherwise fall back to the CPU so that\n",
    "# the notebook still runs end to end on a machine without CUDA\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "input_dim = 50\n",
    "output_dim = 100\n",
    "memory = HashingMemory(input_dim, output_dim, params).to(device=device)\n",
    "print(memory)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.14277362823486328\n",
      "torch.Size([2, 3, 4, 100])\n"
     ]
    }
   ],
   "source": [
    "# smoke test: read the memory with a random batch that has three leading dimensions\n",
    "x = torch.randn(2, 3, 4, input_dim).to(device=device)\n",
    "output = memory(x)\n",
    "# scalar checksum of the output, then its shape (leading dims of x + output_dim)\n",
    "print(output.sum().item(), output.shape, sep=\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
