{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from matplotlib.gridspec import GridSpec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NIAH Mamba-130M\n",
    "path_to_params_for_debug_per_example = './artifacts/params_for_debug_per_example.pt' # insert path to .pt file\n",
    "params_for_debug_per_example = torch.load(path_to_params_for_debug_per_example, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = 'cuda:3' # select cuda device\n",
    "num_layers = len(params_for_debug_per_example[0]['A'][0])\n",
    "ctx_lens = [2000, 4000, 8000, 16000]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Mamba Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "i_ctx_len = 0 # select context length from 0 to len(ctx_lens)-1\n",
    "layers = np.arange(num_layers) # select layers (list)\n",
    "\n",
    "average_channels = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if average_channels:\n",
    "    L = params_for_debug_per_example[i_ctx_len]['delta_t'][0][0].shape[2]\n",
    "    attn_map = torch.zeros(len(layers), 1, L, L) # we create a dummy 'channels' dim just to be consistent with the non-averaging case\n",
    "    selected_channels = range(0,1536,150) # 10 channels\n",
    "else:\n",
    "    attn_map = []\n",
    "    selected_channels = range(0,params_for_debug_per_example[i_ctx_len]['A'][0][0].shape[0],150) # select 10 channels\n",
    "\n",
    "for l, l_idx in enumerate(layers):\n",
    "    print(f'layer {l_idx+1}')\n",
    "    A = params_for_debug_per_example[i_ctx_len]['A'][0][l_idx].to(device) # [ND, d_ssm]\n",
    "    delta_t = params_for_debug_per_example[i_ctx_len]['delta_t'][0][l_idx].squeeze().to(device) # [ND, L]\n",
    "    Sb_x = params_for_debug_per_example[i_ctx_len]['Sb_x'][0][l_idx].squeeze().to(device)  # [d_ssm, L]\n",
    "    C_t = params_for_debug_per_example[i_ctx_len]['C'][0][l_idx].squeeze().to(device)  # [d_ssm, L]\n",
    "    B_t = torch.einsum('ib,jb->ijb', (delta_t, Sb_x)) # [ND, d_ssm, L] each B is discretized by the channel's delta\n",
    "\n",
    "    L = delta_t.shape[1]\n",
    "    ND = len(selected_channels)\n",
    "    d_state_space = C_t.shape[0]\n",
    "\n",
    "    if average_channels:\n",
    "        for c, c_idx in enumerate(tqdm(selected_channels)):\n",
    "            delta_sum_map = torch.zeros(L, L).to(device)\n",
    "            for i in range(1,L):\n",
    "                delta_sum_map[i,0:i] = delta_t[c_idx,i]\n",
    "            delta_sum_map = torch.cumsum(delta_sum_map, dim=0)\n",
    "\n",
    "            A_delta_t = torch.kron(delta_sum_map, A[c_idx,:].unsqueeze(dim=1).to(device))\n",
    "            B_t_expanded = B_t[c_idx].repeat(L,1)\n",
    "            A_t_B_t = torch.exp(A_delta_t)*(B_t_expanded)\n",
    "            C_t_expanded = torch.block_diag(*C_t.T)\n",
    "            attn_map[l,0] = attn_map[l] + torch.tril(C_t_expanded @ A_t_B_t).cpu() / len(selected_channels)\n",
    "\n",
    "    else:\n",
    "        delta_sum_map = torch.zeros(ND, L, L).to(device) # delta sum mat for each channel\n",
    "        for i in range(1,L):\n",
    "            for c, c_idx in enumerate(selected_channels):\n",
    "                delta_sum_map[c,i,0:i] = delta_t[c_idx,i]\n",
    "        delta_sum_map = torch.cumsum(delta_sum_map, dim=1)\n",
    "\n",
    "        A_delta_t = torch.concat([torch.kron(delta_sum_map[c], A[c_idx,:].unsqueeze(dim=1).to(device)).unsqueeze(dim=0) for c, c_idx in enumerate(selected_channels)], dim=0)\n",
    "        B_t_expanded = torch.concat([B_t[c_idx].repeat(L,1).unsqueeze(dim=0) for c, c_idx in enumerate(selected_channels)], dim=0)\n",
    "        A_t_B_t = torch.exp(A_delta_t)*(B_t_expanded)\n",
    "        C_t_expanded = torch.block_diag(*C_t.T)\n",
    "\n",
    "        attn_map.append(torch.concat([(torch.tril(C_t_expanded @ A_t_B_t[c])).unsqueeze(dim=0) for c, c_idx in enumerate(selected_channels)], dim=0).cpu().unsqueeze(dim=0))\n",
    "\n",
    "if not average_channels:\n",
    "    attn_map = torch.cat(attn_map, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([24, 1, 2026, 2026])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "attn_map.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "attn_map_normalized = torch.abs(attn_map.clone())\n",
    "for i_layer in range(attn_map.shape[0]):\n",
    "    attn_map_normalized[i_layer,0,:,:] = attn_map_normalized[i_layer,0,:,:]/torch.max(attn_map_normalized[i_layer,0,:,:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### All Layers Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Nrows = 4\n",
    "Ncols = 6\n",
    "\n",
    "fig=plt.figure(figsize=(30, 20))\n",
    "gs=GridSpec(Nrows,Ncols)\n",
    "axes = []\n",
    "for i in range(num_layers):\n",
    "    axes.append(fig.add_subplot(gs[i//Ncols,i%Ncols]))\n",
    "\n",
    "TH_power = -3\n",
    "i_channel = 0\n",
    "for i_ax, ax in enumerate(axes):\n",
    "    i_layer = i_ax\n",
    "    layer = layers[i_layer]\n",
    "    if i_ax//Ncols != Nrows-1:\n",
    "        ax.get_xaxis().set_visible(False)\n",
    "    if i_ax%Ncols != 0:\n",
    "        ax.get_yaxis().set_visible(False)\n",
    "    attn_map_db = torch.max(TH_power*torch.ones(attn_map_normalized[i_layer, 0].shape),torch.log10(torch.abs(attn_map_normalized[i_layer, 0])).to(torch.float32))\n",
    "    im = ax.imshow(attn_map_db, interpolation='none', aspect='auto', cmap='turbo')\n",
    "    \n",
    "    ax.set_title(i_layer+1)\n",
    "\n",
    "fig.colorbar(im, ax=axes)\n",
    "fig.suptitle(f'normalized attention maps per layer [log scale], TH = 1e{TH_power}', fontsize=28)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single Layer Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig=plt.figure(figsize=(20,10))\n",
    "gs=GridSpec(1,1)\n",
    "axes = []\n",
    "axes.append(fig.add_subplot(gs[0,0]))\n",
    "\n",
    "TH_power = -2\n",
    "i_channel = 0\n",
    "i_layer = 16\n",
    "for i_ax, ax in enumerate(axes):\n",
    "    layer = layers[i_layer]\n",
    "    attn_map_db = torch.max(TH_power*torch.ones(attn_map_normalized[i_layer,0].shape),torch.log10(torch.abs(attn_map_normalized[i_layer,0])).to(torch.float32))\n",
    "    im = ax.imshow(attn_map_db, interpolation='none', aspect='auto', cmap='turbo')\n",
    "    ax.set_title(f'average attention map layer {layer+1} [log scale], TH = 1e{TH_power}', fontsize=24)\n",
    "\n",
    "fig.colorbar(im, ax=axes)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mamba",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
