{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Demo of `RelationalAttention`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append('..')\n",
    "from relational_attention import RelationalAttention\n",
    "from symbol_retrieval import PositionRelativeSymbolRetriever, SymbolicAttention\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Documentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mInit signature:\u001b[0m\n",
      "\u001b[0mRelationalAttention\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0md_model\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_heads\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_relations\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdropout\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0.0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_kv_heads\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mrel_activation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'identity'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mrel_proj_dim\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0madd_bias_kv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0madd_bias_out\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mtotal_n_heads\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msymmetric_rels\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0muse_relative_positional_symbols\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m     \n",
      "Base class for all neural network modules.\n",
      "\n",
      "Your models should also subclass this class.\n",
      "\n",
      "Modules can also contain other Modules, allowing to nest them in\n",
      "a tree structure. You can assign the submodules as regular attributes::\n",
      "\n",
      "    import torch.nn as nn\n",
      "    import torch.nn.functional as F\n",
      "\n",
      "    class Model(nn.Module):\n",
      "        def __init__(self):\n",
      "            super().__init__()\n",
      "            self.conv1 = nn.Conv2d(1, 20, 5)\n",
      "            self.conv2 = nn.Conv2d(20, 20, 5)\n",
      "\n",
      "        def forward(self, x):\n",
      "            x = F.relu(self.conv1(x))\n",
      "            return F.relu(self.conv2(x))\n",
      "\n",
      "Submodules assigned in this way will be registered, and will have their\n",
      "parameters converted too when you call :meth:`to`, etc.\n",
      "\n",
      ".. note::\n",
      "    As per the example above, an ``__init__()`` call to the parent class\n",
      "    must be made before assignment on the child.\n",
      "\n",
      ":ivar training: Boolean represents whether this module is in training or\n",
      "                evaluation mode.\n",
      ":vartype training: bool\n",
      "\u001b[0;31mInit docstring:\u001b[0m\n",
      "An implementation of Relational Attention (RA).\n",
      "\n",
      "Relational attention defines a differentiable information-retrieval operation where the information retrieved\n",
      "is the relations between objects. The \"message\" sent from one object to another is the relations between the\n",
      "sender and the receiver, tagged with a symbol identifying the sender. These messages are aggregated based on the\n",
      "receiver's features via softmax attention scores.\n",
      "\n",
      "The learnable parameters include a set of query/key projections which determine the attention scores, and hence\n",
      "the ``selection criteria'', as well as a set of query/key projections for computing relations between objects.\n",
      "They also include per-head projections for the symbols and relations, as well as a final output projection.\n",
      "\n",
      "Compared to RCA (below), RA disentangles the \"attention\" operation from the computation of relations.\n",
      "\n",
      "This module supports symmetric relations, position-relative symbolic embeddings,\n",
      "multi-query attention/grouped query attention, and control over total number of heads (for use with \"dual attention\").\n",
      "\n",
      "Parameters\n",
      "----------\n",
      "d_model : int\n",
      "    model dimension\n",
      "n_heads : int\n",
      "    number of attention heads (query heads if n_kv_heads is set)\n",
      "n_relations : int, optional\n",
      "    number of relations. If None, n_relations = n_heads. By default None\n",
      "dropout : float, optional\n",
      "    dropout rate. By default 0.0\n",
      "n_kv_heads : int, optional\n",
      "    number of key/value heads. used to implement multi-query attention or grouped query attention.\n",
      "    n_kv_heads=1 corresponds to MQA, n_kv_heads > 1 corresponsd to grouped query attention.\n",
      "    n_kv_heads=n_heads is standard MHA. uses MHA when None. By default None\n",
      "rel_activation : str, optional\n",
      "    name of activation function applied to relations. By default 'identity'.\n",
      "rel_proj_dim : int, optional\n",
      "    dimension of relation projections. If None, rel_proj_dim = d_model // n_relations. By default None.\n",
      "add_bias_kv : bool, optional\n",
      "    whether to use bias in key/value projections, by default False\n",
      "add_bias_out : bool, optional\n",
      "    whether to use bias in out projection, by default False\n",
      "total_n_heads : int, optional\n",
      "    total number of heads in dual attention (if using dual attention).\n",
      "    used to ensure that concat(A, E) is of dimension d_model after concatentation.\n",
      "    hence, output dimension is (d_model // total_heads) * n_heads.\n",
      "    if None, total_heads = n_heads and output dimension is d_model\n",
      "\u001b[0;31mFile:\u001b[0m           ~/projects/abstract_transformer/relational_attention.py\n",
      "\u001b[0;31mType:\u001b[0m           type\n",
      "\u001b[0;31mSubclasses:\u001b[0m     "
     ]
    }
   ],
   "source": [
    "RelationalAttention?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Relational Attention with Symbolic Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "d, l = 64, 10\n",
    "\n",
    "symbol_retriever = SymbolicAttention(d_model=d, n_heads=4, n_symbols=10)\n",
    "relational_attn = RelationalAttention(d_model=d, n_heads=4, use_relative_positional_symbols=False, symmetric_rels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: torch.Size([1, 10, 64]), s: torch.Size([1, 10, 64]), y: torch.Size([1, 10, 64])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(1, 10, d)\n",
    "s = symbol_retriever(x)\n",
    "y, *_ = relational_attn(x, s)\n",
    "print(f'x: {x.shape}, s: {s.shape}, y: {y.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Relational Attention with Position-Relative Symbols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "d, l = 64, 10\n",
    "\n",
    "symbol_retriever = PositionRelativeSymbolRetriever(symbol_dim=d, max_rel_pos=l)\n",
    "relational_attn = RelationalAttention(d_model=d, n_heads=4, use_relative_positional_symbols=True, symmetric_rels=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x: torch.Size([1, 10, 64]), s: torch.Size([10, 10, 64]), y: torch.Size([1, 10, 64])\n"
     ]
    }
   ],
   "source": [
    "x = torch.randn(1, 10, d)\n",
    "s = symbol_retriever(x)\n",
    "y, *_ = relational_attn(x, s)\n",
    "print(f'x: {x.shape}, s: {s.shape}, y: {y.shape}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abstract_transformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
