{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys; sys.path.append('..')\n",
    "import torchinfo\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "================================================================================\n",
       "Layer (type:depth-idx)                                  Param #\n",
       "================================================================================\n",
       "DualAttnTransformerLM                                   --\n",
       "├─ModuleDict: 1-1                                       --\n",
       "│    └─Embedding: 2-1                                   16,384,000\n",
       "│    │    └─Linear: 3-1                                 16,416,000\n",
       "│    └─Dropout: 2-2                                     --\n",
       "│    └─SymbolicAttention: 2-3                           524,288\n",
       "│    │    └─Linear: 3-2                                 262,656\n",
       "│    └─ModuleList: 2-4                                  --\n",
       "│    │    └─DualAttnEncoderBlock: 3-3                   4,595,200\n",
       "│    │    └─DualAttnEncoderBlock: 3-4                   4,595,200\n",
       "│    │    └─DualAttnEncoderBlock: 3-5                   4,595,200\n",
       "│    │    └─DualAttnEncoderBlock: 3-6                   4,595,200\n",
       "│    │    └─DualAttnEncoderBlock: 3-7                   4,595,200\n",
       "│    │    └─DualAttnEncoderBlock: 3-8                   4,595,200\n",
       "│    └─Linear: 2-5                                      (recursive)\n",
       "================================================================================\n",
       "Total params: 61,158,144\n",
       "Trainable params: 61,158,144\n",
       "Non-trainable params: 0\n",
       "================================================================================"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from language_models import DualAttnTransformerLM\n",
    "\n",
    "dat_lm = DualAttnTransformerLM(\n",
    "    vocab_size=32_000,    # vocabulary size\n",
    "    d_model=512,          # model dimension\n",
    "    n_layers=6,           # number of layers\n",
    "    n_heads_sa=4,         # number of self-attention heads\n",
    "    n_heads_ra=4,         # number of relational attention heads\n",
    "    dff=2048,             # feedforward intermediate dimension\n",
    "    dropout_rate=0.1,     # dropout rate\n",
    "    activation='swiglu',  # activation function of feedforward block\n",
    "    norm_first=True,      # True: pre-norm (normalize before attention); False: post-norm\n",
    "    max_block_size=1024,  # max context length\n",
    "    symbol_retrieval='symbolic_attention', # type of symbol assignment mechanism\n",
    "    symbol_retrieval_kwargs=dict(d_model=512, n_heads=4, n_symbols=512), # kwargs for symbol assignment mechanism\n",
    "    pos_enc_type='RoPE'   # type of positional encoding to use\n",
    ")\n",
    "\n",
    "torchinfo.summary(dat_lm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 128, 32000])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = torch.randint(0, 32_000, (1, 129))\n",
    "x, y = idx[:, :-1], idx[:, 1:]\n",
    "logits, loss = dat_lm(x, y)\n",
    "logits.shape # shape: (1, 128, 32000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mInit signature:\u001b[0m\n",
      "\u001b[0mDualAttnTransformerLM\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mvocab_size\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0md_model\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_layers\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_heads_sa\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_heads_ra\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msymbol_retrieval_kwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdff\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdropout_rate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mactivation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnorm_first\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mmax_block_size\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msa_kwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mra_kwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mra_type\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'relational_attention'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msymbol_retrieval\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'symbolic_attention'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mpos_enc_type\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'pos_emb'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mbias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m      Dual Attention Transformer Language Model\n",
      "\u001b[0;31mInit docstring:\u001b[0m\n",
      "Dual Attention Transformer Language Model.\n",
      "\n",
      "Parameters\n",
      "----------\n",
      "vocab_size : int\n",
      "    vocabulary size.\n",
      "d_model : int\n",
      "    model dimension.\n",
      "n_layers : int\n",
      "    number of layers.\n",
      "n_heads_sa : int\n",
      "    number of self-attention heads in dual-attention.\n",
      "n_heads_ra : int\n",
      "    number of relational attention heads in dual-attention.\n",
      "symbol_retrieval_kwargs : dict\n",
      "    keyword arguments for symbol retrieval module.\n",
      "dff : int\n",
      "    size of intermediate layer in feedforward blocks.\n",
      "dropout_rate : float\n",
      "    dropout rate.\n",
      "activation : str\n",
      "    name of activation function (e.g., 'relu', 'gelu', or 'swiglu').\n",
      "norm_first : bool\n",
      "    whether to apply layer normalization before or after attention.\n",
      "max_block_size : int\n",
      "    maximum context size.\n",
      "sa_kwargs : dict, optional\n",
      "    keyword arguments for self-attention, by default None\n",
      "ra_kwargs : dict, optional\n",
      "    keyword arguments for relational attention, by default None\n",
      "ra_type : 'relational_attention', 'rca', or 'disrca', optional\n",
      "    type of relational attention module (e.g., whether to use RCA for an ablation experiment), by default 'relational_attention'\n",
      "symbol_retrieval : 'symbolic_attention', 'position_relative', 'positional_symbols', optional\n",
      "    type of symbol retrieval module to use. this is shared across layers, by default 'symbolic_attention'\n",
      "pos_enc_type : 'pos_emb' or 'RoPE', optional\n",
      "    type of positional encoding to use, by default 'pos_emb'\n",
      "bias : bool, optional\n",
      "    whether to use bias in attention, by default True\n",
      "\u001b[0;31mFile:\u001b[0m           ~/projects/abstract_transformer/language_models.py\n",
      "\u001b[0;31mType:\u001b[0m           type\n",
      "\u001b[0;31mSubclasses:\u001b[0m     "
     ]
    }
   ],
   "source": [
    "DualAttnTransformerLM?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Vision Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "===========================================================================\n",
       "Layer (type:depth-idx)                             Param #\n",
       "===========================================================================\n",
       "VisionDualAttnTransformer                          101,376\n",
       "├─PositionRelativeSymbolRetriever: 1-1             --\n",
       "│    └─RelativePositionalEncoding: 2-1             202,240\n",
       "├─Sequential: 1-2                                  --\n",
       "│    └─Rearrange: 2-2                              --\n",
       "│    └─LayerNorm: 2-3                              1,536\n",
       "│    └─Linear: 2-4                                 393,728\n",
       "│    └─LayerNorm: 2-5                              1,024\n",
       "├─Dropout: 1-3                                     --\n",
       "├─ModuleList: 1-4                                  --\n",
       "│    └─DualAttnEncoderBlock: 2-6                   --\n",
       "│    │    └─Dropout: 3-1                           --\n",
       "│    │    └─LayerNorm: 3-2                         1,024\n",
       "│    │    └─DualAttention: 3-3                     1,180,672\n",
       "│    │    └─LayerNorm: 3-4                         1,024\n",
       "│    │    └─FeedForwardBlock: 3-5                  3,150,336\n",
       "│    └─DualAttnEncoderBlock: 2-7                   --\n",
       "│    │    └─Dropout: 3-6                           --\n",
       "│    │    └─LayerNorm: 3-7                         1,024\n",
       "│    │    └─DualAttention: 3-8                     1,180,672\n",
       "│    │    └─LayerNorm: 3-9                         1,024\n",
       "│    │    └─FeedForwardBlock: 3-10                 3,150,336\n",
       "│    └─DualAttnEncoderBlock: 2-8                   --\n",
       "│    │    └─Dropout: 3-11                          --\n",
       "│    │    └─LayerNorm: 3-12                        1,024\n",
       "│    │    └─DualAttention: 3-13                    1,180,672\n",
       "│    │    └─LayerNorm: 3-14                        1,024\n",
       "│    │    └─FeedForwardBlock: 3-15                 3,150,336\n",
       "│    └─DualAttnEncoderBlock: 2-9                   --\n",
       "│    │    └─Dropout: 3-16                          --\n",
       "│    │    └─LayerNorm: 3-17                        1,024\n",
       "│    │    └─DualAttention: 3-18                    1,180,672\n",
       "│    │    └─LayerNorm: 3-19                        1,024\n",
       "│    │    └─FeedForwardBlock: 3-20                 3,150,336\n",
       "│    └─DualAttnEncoderBlock: 2-10                  --\n",
       "│    │    └─Dropout: 3-21                          --\n",
       "│    │    └─LayerNorm: 3-22                        1,024\n",
       "│    │    └─DualAttention: 3-23                    1,180,672\n",
       "│    │    └─LayerNorm: 3-24                        1,024\n",
       "│    │    └─FeedForwardBlock: 3-25                 3,150,336\n",
       "│    └─DualAttnEncoderBlock: 2-11                  --\n",
       "│    │    └─Dropout: 3-26                          --\n",
       "│    │    └─LayerNorm: 3-27                        1,024\n",
       "│    │    └─DualAttention: 3-28                    1,180,672\n",
       "│    │    └─LayerNorm: 3-29                        1,024\n",
       "│    │    └─FeedForwardBlock: 3-30                 3,150,336\n",
       "├─Linear: 1-5                                      513,000\n",
       "===========================================================================\n",
       "Total params: 27,211,240\n",
       "Trainable params: 27,211,240\n",
       "Non-trainable params: 0\n",
       "==========================================================================="
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from vision_models import VisionDualAttnTransformer\n",
    "\n",
    "img_shape = (3, 224, 224)\n",
    "patch_size = (16, 16)\n",
    "n_patches = (img_shape[1] // patch_size[0]) * (img_shape[2] // patch_size[1])\n",
    "\n",
    "dat_vision = VisionDualAttnTransformer(\n",
    "    image_shape=img_shape,     # shape of input image\n",
    "    patch_size=patch_size,     # size of patch\n",
    "    num_classes=1000,          # number of classes\n",
    "    d_model=512,               # model dimension\n",
    "    n_layers=6,                # number of layers\n",
    "    n_heads_sa=4,              # number of self-attention heads\n",
    "    n_heads_ra=4,              # number of relational attention heads\n",
    "    dff=2048,                  # feedforward intermediate dimension\n",
    "    dropout_rate=0.1,          # dropout rate\n",
    "    activation='swiglu',       # activation function of feedforward block\n",
    "    norm_first=True,           # True: pre-norm (normalize before attention); False: post-norm\n",
    "    symbol_retrieval='position_relative',\n",
    "    symbol_retrieval_kwargs=dict(symbol_dim=512, max_rel_pos=n_patches+1),\n",
    "    ra_kwargs=dict(symmetric_rels=True, use_relative_positional_symbols=True),\n",
    "    pool='cls',                # type of pooling (class token)\n",
    ")\n",
    "\n",
    "torchinfo.summary(dat_vision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1000])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = torch.randn(1, *img_shape)\n",
    "logits = dat_vision(img)\n",
    "logits.shape # shape: (1, 1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mInit signature:\u001b[0m\n",
      "\u001b[0mVisionDualAttnTransformer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mimage_shape\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mpatch_size\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnum_classes\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0md_model\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_layers\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_heads_sa\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mn_heads_ra\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdff\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mdropout_rate\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mactivation\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnorm_first\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msymbol_retrieval\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0msymbol_retrieval_kwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mra_type\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'relational_attention'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mra_kwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mnorm_type\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'layernorm'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mbias\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbool\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mpool\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'cls'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mDocstring:\u001b[0m      Vision Dual Attention Transformer\n",
      "\u001b[0;31mInit docstring:\u001b[0m\n",
      "Vision Transformer.\n",
      "\n",
      "\n",
      "Parameters\n",
      "----------\n",
      "image_shape : Tuple[int]\n",
      "    shape of image (channels, width, height)\n",
      "patch_size : Tuple[int]\n",
      "    size of patch (width, height)\n",
      "num_classes : int\n",
      "    number of classes\n",
      "d_model : int\n",
      "    model dimension\n",
      "n_layers : int\n",
      "    number of layers\n",
      "n_heads_sa : int\n",
      "    number of self-attention heads\n",
      "n_heads_ra : int\n",
      "    number of relational attention heads\n",
      "dff : int\n",
      "    feedforward dimension\n",
      "dropout_rate : float\n",
      "    dropout rate\n",
      "activation : str\n",
      "    name of activation function in feedforward blocks\n",
      "norm_first : bool\n",
      "    whether to apply normalization before or after attention. norm_first=True means pre-norm otherwise post-norm.\n",
      "symbol_retrieval : str\n",
      "    type of symbol retrieval mechanism to use, one of 'symbolic_attention', 'rel_sym_attn', 'positional_symbols', 'position_relative'\n",
      "symbol_retrieval_kwargs : dict\n",
      "    keyword arguments for symbol retrieval mechanism\n",
      "ra_type : 'relational_attention', 'rca', or 'disrca', optional\n",
      "    type of relational attention module (e.g., whether to use RCA for an ablation experiment), by default 'relational_attention'\n",
      "ra_kwargs : dict, optional\n",
      "    relational attention kwargs, by default None\n",
      "norm_type : 'layernorm' or 'rmsnorm', optional\n",
      "    type of normalization to use, by default 'layernorm'\n",
      "bias : bool, optional\n",
      "    whether to use a bias in the encoder blocks, by default True\n",
      "pool : 'cls' or 'mean', optional\n",
      "    type of pooling to use before final class prediction. 'cks' corresponds to using a class token\n",
      "    while 'mean' corresponds to mean pooling, by default 'cls'\n",
      "\u001b[0;31mFile:\u001b[0m           ~/projects/abstract_transformer/vision_models.py\n",
      "\u001b[0;31mType:\u001b[0m           type\n",
      "\u001b[0;31mSubclasses:\u001b[0m     "
     ]
    }
   ],
   "source": [
    "VisionDualAttnTransformer?"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "abstract_transformer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
