{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Exploring Symbolic Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchinfo\n",
    "from mechanistic_forward import symbolic_attn_forward_get_weights, block_forward_get_weights, datlm_forward_w_intermediate_results\n",
    "\n",
    "import plotly.express as px\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sys; sys.path.append('..')\n",
    "from dual_attention_transformer import DualAttnTransformerLM\n",
    "from language_models import TransformerLM\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_from_ckpt(ckpt_path, map_location='cpu'):\n",
    "    '''Rebuild a language model from a training checkpoint.\n",
    "\n",
    "    Args:\n",
    "        ckpt_path: path of a checkpoint readable by `torch.load`. The loaded\n",
    "            object must provide a 'config' entry (constructor kwargs) and a\n",
    "            'model' entry (state dict).\n",
    "        map_location: forwarded to `torch.load`. Defaults to 'cpu' so the\n",
    "            checkpoint tensors are not placed on the device they were saved\n",
    "            from; that would keep a second copy of the weights on the GPU and\n",
    "            fail on machines without that device.\n",
    "\n",
    "    Returns:\n",
    "        `DualAttnTransformerLM` when the config contains 'n_heads_ra', otherwise\n",
    "        `TransformerLM`, with the checkpoint weights loaded. The caller is\n",
    "        responsible for moving the model to the desired device.\n",
    "    '''\n",
    "    ckpt = torch.load(ckpt_path, map_location=map_location)\n",
    "    model_config = ckpt['config']\n",
    "\n",
    "    # Drop the '_orig_mod.' marker from parameter names (torch.compile wrappers\n",
    "    # add it) so the keys match an uncompiled model.\n",
    "    model_state_dict = {k.replace('_orig_mod.', ''): v for k, v in ckpt['model'].items()}\n",
    "\n",
    "    # The presence of 'n_heads_ra' in the config selects the dual-attention architecture.\n",
    "    model_cls = DualAttnTransformerLM if 'n_heads_ra' in model_config else TransformerLM\n",
    "    model = model_cls(**model_config)\n",
    "    model.load_state_dict(model_state_dict)\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint to analyse; alternative runs are kept commented out for quick switching.\n",
    "# model_path = '../log/DAT-ra8sa8nr32-ns512sh1-368M_2024_07_15_22_00_26/model_05000.pt'\n",
    "model_path = '../log/DAT-ra8sa8nr32-368M_2024_07_11_18_17_07/model_10000.pt'\n",
    "# model_path = '../log/DAT-sa16ra16nr64-1.3B_2024_07_13_21_59_55/model_05000.pt'\n",
    "\n",
    "# Restore the model, move it to the GPU and switch it to evaluation mode.\n",
    "model_dat = load_from_ckpt(model_path).to('cuda').eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "================================================================================\n",
       "Layer (type:depth-idx)                                  Param #\n",
       "================================================================================\n",
       "DualAttnTransformerLM                                   --\n",
       "├─ModuleDict: 1-1                                       --\n",
       "│    └─Embedding: 2-1                                   51,511,296\n",
       "│    └─Dropout: 2-2                                     --\n",
       "│    └─ModuleList: 2-3                                  --\n",
       "│    │    └─SymbolicAttention: 3-1                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-2                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-3                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-4                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-5                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-6                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-7                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-8                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-9                      3,146,752\n",
       "│    │    └─SymbolicAttention: 3-10                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-11                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-12                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-13                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-14                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-15                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-16                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-17                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-18                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-19                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-20                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-21                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-22                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-23                     3,146,752\n",
       "│    │    └─SymbolicAttention: 3-24                     3,146,752\n",
       "│    └─ModuleList: 2-4                                  --\n",
       "│    │    └─DualAttnEncoderBlock: 3-25                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-26                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-27                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-28                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-29                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-30                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-31                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-32                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-33                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-34                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-35                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-36                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-37                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-38                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-39                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-40                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-41                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-42                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-43                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-44                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-45                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-46                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-47                  13,132,800\n",
       "│    │    └─DualAttnEncoderBlock: 3-48                  13,132,800\n",
       "│    └─LayerNorm: 2-5                                   2,048\n",
       "│    └─Linear: 2-6                                      51,511,296\n",
       "================================================================================\n",
       "Total params: 493,733,888\n",
       "Trainable params: 493,733,888\n",
       "Non-trainable params: 0\n",
       "================================================================================"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torchinfo.summary(model_dat)\n",
    "# torchinfo.summary(model_dat, input_data=torch.zeros((1, 1024), dtype=torch.int32).cuda(), device='cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "'ModuleDict' object has no attribute 'symbol_retriever'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m symbol_retriever \u001b[38;5;241m=\u001b[39m \u001b[43mmodel_dat\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlayers\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msymbol_retriever\u001b[49m\n\u001b[1;32m      2\u001b[0m n_heads \u001b[38;5;241m=\u001b[39m symbol_retriever\u001b[38;5;241m.\u001b[39mn_heads\n\u001b[1;32m      3\u001b[0m n_symbols \u001b[38;5;241m=\u001b[39m symbol_retriever\u001b[38;5;241m.\u001b[39mn_symbols\n",
      "File \u001b[0;32m~/.conda/envs/abstract_transformer/lib/python3.11/site-packages/torch/nn/modules/module.py:1688\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m   1686\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m   1687\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1688\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mAttributeError\u001b[0m: 'ModuleDict' object has no attribute 'symbol_retriever'"
     ]
    }
   ],
   "source": [
    "symbol_retriever = model_dat.layers.symbol_retriever\n",
    "n_heads = symbol_retriever.n_heads\n",
    "n_symbols = symbol_retriever.n_symbols\n",
    "symbol_library = symbol_retriever.symbol_library.data\n",
    "embedding_layer = model_dat.layers.token_embedder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of token ids pushed through the embedding and the symbol retriever.\n",
    "vocab_size = 50257\n",
    "\n",
    "# Inference only: no gradients are needed, so skip recording an autograd graph\n",
    "# for this vocab_size-long sequence.\n",
    "with torch.no_grad():\n",
    "    embs = embedding_layer(torch.arange(vocab_size).to('cuda').unsqueeze(0))\n",
    "    symbs = symbol_retriever(embs)\n",
    "\n",
    "    symbs2, symbol_attn_weights = symbolic_attn_forward_get_weights(symbol_retriever, embs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_summary_stats_attn_weights(attn_weights, symbs):\n",
    "    '''compute summary statistics of attention weights'''\n",
    "\n",
    "    if type(attn_weights) == torch.Tensor:\n",
    "        attn_weights = attn_weights.cpu().detach().numpy()\n",
    "\n",
    "    attn_max_weight = attn_weights.max(axis=-1)\n",
    "    attn_min_weight = attn_weights.min(axis=-1)\n",
    "    attn_mean_weight = attn_weights.mean(axis=-1)\n",
    "\n",
    "    attn_argmax = attn_weights.argmax(axis=-1)\n",
    "\n",
    "    attn_entropy = -(attn_weights * np.log(attn_weights)).sum(axis=-1)\n",
    "    unif_entropy = np.log(attn_weights.shape[-1])\n",
    "\n",
    "    token_symbol_similarity = torch.matmul(symbs, symbs.transpose(1, 2))\n",
    "    similarity_sample = token_symbol_similarity.flatten()[np.random.choice(token_symbol_similarity.numel(), 1000)].cpu().detach().numpy()\n",
    "\n",
    "    symbol_similarity = torch.matmul(symbol_retriever.symbol_library, symbol_retriever.symbol_library.transpose(0, 1)).cpu().detach().numpy()\n",
    "\n",
    "    attn_stats = dict(\n",
    "        attn_max_weight=attn_max_weight,\n",
    "        attn_min_weight=attn_min_weight,\n",
    "        attn_mean_weight=attn_mean_weight,\n",
    "        attn_entropy=attn_entropy,\n",
    "        attn_argmax=attn_argmax,\n",
    "        unif_entropy=unif_entropy,\n",
    "        similarity_sample=similarity_sample,\n",
    "        symbol_similarity=symbol_similarity\n",
    "    )\n",
    "\n",
    "    return attn_stats\n",
    "\n",
    "def plot_symb_attn_stats(symb_attn_stats, n_heads, n_symbols):\n",
    "    '''Draw one row of four histograms per symbolic attention head.\n",
    "\n",
    "    Columns, left to right: max attention weight (red dashed line at\n",
    "    1 / n_symbols), argmax index, entropy (red dashed line at\n",
    "    symb_attn_stats['unif_entropy']) and the sampled similarity values.\n",
    "    The similarity sample is not per head, so the same data appears in every row.\n",
    "    '''\n",
    "    # squeeze=False always yields a 2-D (n_heads, 4) grid, including n_heads == 1.\n",
    "    _, axs = plt.subplots(nrows=n_heads, ncols=4, figsize=(20, 5 * n_heads), squeeze=False)\n",
    "\n",
    "    uniform_weight = 1 / n_symbols\n",
    "    for head, (ax_max, ax_argmax, ax_entropy, ax_similarity) in enumerate(axs):\n",
    "        sns.histplot(symb_attn_stats['attn_max_weight'][head], ax=ax_max)\n",
    "        ax_max.axvline(uniform_weight, color='red', linestyle='--')\n",
    "        ax_max.set_title(f'Max Attention Weight (Head {head})')\n",
    "\n",
    "        sns.histplot(symb_attn_stats['attn_argmax'][head], ax=ax_argmax)\n",
    "        ax_argmax.set_title(f'Argmax Attention Weight (Head {head})')\n",
    "\n",
    "        sns.histplot(symb_attn_stats['attn_entropy'][head], ax=ax_entropy)\n",
    "        ax_entropy.axvline(symb_attn_stats['unif_entropy'], color='red', linestyle='--')\n",
    "        ax_entropy.set_title(f'Entropy of Attention Weights (Head {head})')\n",
    "\n",
    "        sns.histplot(symb_attn_stats['similarity_sample'], ax=ax_similarity)\n",
    "        ax_similarity.set_title(f'Similarity between Tokens (Head {head})')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layer 0: Symbol Assignment of Tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'# of symbols: {n_symbols}')\n",
    "print(f'# of symbolic attention heads: {n_heads}')\n",
    "print(f'# of tokens: {vocab_size}')\n",
    "print(f'd_model: {model_dat.d_model}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 1 (independently process token)\n",
    "\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(symbol_attn_weights[0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Layerwise Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "\n",
    "enc = tiktoken.get_encoding(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "string = \"\"\"\n",
    "A finite-state machine (FSM) or finite-state automaton (FSA, plural: automata), finite automaton, or simply a state machine, is a mathematical model of computation. It is an abstract machine that can be in exactly one of a finite number of states at any given time. The FSM can change from one state to another in response to some inputs; the change from one state to another is called a transition. An FSM is defined by a list of its states, its initial state, and the inputs that trigger each transition. Finite-state machines are of two types—deterministic finite-state machines and non-deterministic finite-state machines. For any non-deterministic finite-state machine, an equivalent deterministic one can be constructed.\n",
    "\n",
    "The behavior of state machines can be observed in many devices in modern society that perform a predetermined sequence of actions depending on a sequence of events with which they are presented. Simple examples are: vending machines, which dispense products when the proper combination of coins is deposited; elevators, whose sequence of stops is determined by the floors requested by riders; traffic lights, which change sequence when cars are waiting; combination locks, which require the input of a sequence of numbers in the proper order.\n",
    "\n",
    "The finite-state machine has less computational power than some other models of computation such as the Turing machine. The computational power distinction means there are computational tasks that a Turing machine can do but an FSM cannot. This is because an FSM's memory is limited by the number of states it has. A finite-state machine has the same computational power as a Turing machine that is restricted such that its head may only perform \"read\" operations, and always has to move from left to right. FSMs are studied in the more general field of automata theory.\n",
    "\"\"\"\n",
    "\n",
    "tokens = torch.tensor(enc.encode(string)).unsqueeze(0).to('cuda')\n",
    "tokens.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_input = torch.arange(0, 512).cuda().unsqueeze(0)\n",
    "# Inference only: no gradients are needed, so neither forward pass should record an autograd graph.\n",
    "with torch.no_grad():\n",
    "    logits, intermediate_results = datlm_forward_w_intermediate_results(model_dat, tokens)\n",
    "    reference_logits = model_dat(tokens)[0]\n",
    "\n",
    "# check that datlm_forward_w_intermediate_results produces the same result as model_dat\n",
    "print((reference_logits - logits).abs().max())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 1\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(intermediate_results['symbol_attn_scores'][0][0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 2\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(intermediate_results['symbol_attn_scores'][1][0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 3\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(intermediate_results['symbol_attn_scores'][2][0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 4\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(intermediate_results['symbol_attn_scores'][3][0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 5\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(intermediate_results['symbol_attn_scores'][4][0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# layer 6\n",
    "symb_attn_stats = compute_summary_stats_attn_weights(intermediate_results['symbol_attn_scores'][5][0], symbs)\n",
    "\n",
    "plot_symb_attn_stats(symb_attn_stats, n_heads=n_heads, n_symbols=n_symbols)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary Statistics of Symbolic Attention Across Layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-layer averages (over heads and positions) of the symbolic attention statistics.\n",
    "layer_indices = np.arange(model_dat.n_layers)\n",
    "entropy_by_layer = []\n",
    "max_attn_by_layer = []\n",
    "for layer_idx in range(model_dat.n_layers):\n",
    "    layer_attn = intermediate_results['symbol_attn_scores'][layer_idx][0]\n",
    "    symb_attn_stats = compute_summary_stats_attn_weights(layer_attn, symbs)\n",
    "    entropy_by_layer.append(symb_attn_stats['attn_entropy'].mean())\n",
    "    max_attn_by_layer.append(symb_attn_stats['attn_max_weight'].mean())\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n",
    "sns.lineplot(x=layer_indices, y=entropy_by_layer, ax=ax1)\n",
    "ax1.set_title('Mean Attention Entropy by Layer')\n",
    "ax1.set_xlabel('Layer index')\n",
    "ax1.set_ylabel('Mean entropy (nats)')\n",
    "\n",
    "sns.lineplot(x=layer_indices, y=max_attn_by_layer, ax=ax2)\n",
    "ax2.set_title('Mean Max Attention Weight by Layer')\n",
    "ax2.set_xlabel('Layer index')\n",
    "ax2.set_ylabel('Mean max attention weight');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Questions\n",
    "\n",
    "- What if symbol retriever is not shared across layers? Is this a limitation? TODO: implement symbol retriever for each layer. (e.g., with MQA to reduce param count and make comparable; or make symbols non-trainable, but keys trainable; or make symbols trainable but shared across layers; etc)\n",
    "- Add entropy regularization to symbolic attention weights? Some other way to promote sparseness of the symbolic attention weights?\n",
    "- Add param count that takes weight-tying into consideration. (e.g., weight-tying embedding-unembedding or symbols across symbol-retrievers)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
