{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\019800756\\Anaconda3\\envs\\ocd\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "from attn_vis_utils import normalize_attn_mat\n",
    "import matplotlib.pyplot as plt\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.15s/it]\n"
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\"google/recurrentgemma-2b-it\", device_map=\"auto\").cpu()\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"google/recurrentgemma-2b-it\")\n",
    "prompt = \"Alice was beginning to get very tired of sitting by her sister on the bank, and of having nothing to do: once or twice she had peeped into\"# the book her sister was reading, but it had no pictures or conversations in it, “and what is the use of a book,” thought Alice “without pictures or conversations?” So she was considering in her own mind (as well as she could, for the hot day made her feel very sleepy and stupid), whether the pleasure of making a daisy-chain would be worth the trouble of getting up and picking the daisies, when suddenly a White Rabbit with pink eyes\"# ran close by her. There was nothing so very remarkable in that; nor did Alice think it so very much out of the way to hear the Rabbit say to itself, “Oh dear! Oh dear! I shall be late!” (when she thought it over afterwards, it occurred to her that she ought to have wondered at this, but at the time it all seemed quite natural); but when the Rabbit actually took a watch out of its waistcoat-pocket, and looked at it, and then hurried on, Alice started to her feet, for it flashed across her mind that she had never before seen a rabbit with either a waistcoat-pocket, or a watch to take out of it, and burning with curiosity, she ran across the field after it, and fortunately was just in time to see it pop down a large rabbit-hole under the hedge. In another moment down went Alice after it, never once considering how in the world she was to get out again. The rabbit-hole went straight on like a tunnel for some way, and then dipped suddenly down, so suddenly that Alice had not a moment to think about stopping herself before she found herself falling down a very deep well. Either the well was very deep, or she fell very slowly, for she had plenty of time as she went down to look about her and to wonder what was going to happen next. 
First, she tried to look down and make out what she was coming to, but it was too dark to see anything; then she looked at the sides of the well, and noticed that they were filled with cupboards and book-shelves; here and there she saw maps and pictures hung upon pegs. She took down a jar from one of the shelves as she passed; it was labelled “ORANGE MARMALADE”, but to her great disappointment it was empty: she did not like to drop the jar for fear of killing somebody underneath, so managed to put it\"\n",
    "inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'transformers.models.recurrent_gemma.modeling_recurrent_gemma.RecurrentGemmaRecurrentBlock'>\n"
     ]
    }
   ],
   "source": [
    "selected_layer = 0\n",
    "print(type(model.model.layers[selected_layer].temporal_block))\n",
    "#assert type(model.model.layers[selected_layer].temporal_block) == transformers.models.recurrent_gemma.modeling_recurrent_gemma.RecurrentGemmaSdpaAttention\n",
    "selected_chan = [0,1,2,3]#,1000,1001,1002,1003]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 2560, 32, 32])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = True\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_vec = False\n",
    "outputs = model.generate(**inputs, max_length=len(inputs[0])+1)\n",
    "attn_mat = model.model.layers[selected_layer].temporal_block.attn_matrix\n",
    "attn_mat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAAC8CAYAAAAQL7MCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAWe0lEQVR4nO3dW6wd53UY4H/2PjzUhXeKF4mkRN1MOZJtKbJhxZbhIEnhAHEKtIjtPCRFG6QFKrRAgT64RR+CFs1DUARoitR5KIK07lPaoCnQBLk0aePGiWTJiu6RZVEXUqTEO3l4P+fsvacPLYoAa001W5s/zznk9z0u/v/Mv2fP7OE6g7Wmadu2LQAAANfYYKUXAAAA3JgkGwAAQBWSDQAAoArJBgAAUIVkAwAAqEKyAQAAVCHZAAAAqpBsAAAAVUg2AACAKub6Dvxrg6/UXAfM7L9P/vNKL6H8+IGvh9j4zbdXYCWQWw3XyZdu+9kQm1y9ugIrgdxquE5KKeW1w3tC7B/t/9wKrARyfa4VTzYAAIAqJBsAAEAVkg0AAKCK3jUbZTCMsXaSj23bj7gcWNsO/ouNIbb/33wqHds8/VLt5cDq9MD+GHv1e9d9GbDaZfUZf+uN99Kx3zywr/Zy4CPxZAMAAKhCsgEAAFQh2QAAAKqQbAAAAFVINgAAgCr6d6OajCsuA24M3//ifwixn/gnP5mObbdvC7Hx6TPXfE2w6rx9eKVXAGtD04TQNx+6Ox36N/7yRIj99g/suOZLgml5sgEAAFQh2QAAAKqQbAAAAFVINgAAgCr6F4gnhls2p/H26mKMjUa9YrCWfepfPRVidx59Nh27+KOPhtitL8ZLcnw8Fv3BWvZ7B/88xH7ih/JGCmUQ/yY2ejcpMG/bWZcFa9pvP7wzxHb8efx/2snPnbsey4H/x5MNAACgCskGAABQhWQDAACoQrIBAABUMVOB+PjcwrVaB9wQBksx1k7ywtXBchIfj6/ximD1+fG//jMh1h5+LR07tysWvQ4+9fEQm9yS3M6eeXn6xcFqkjU+SN4q3uXk5+P/0977rUdCbN9PvTrVsmAanmwAAABVSDYAAIAqJBsAAEAVkg0AAKCKmQrEOyXFS6d//okQ23Iwvml8/r2z6SbHB9+ZfV1Q2b6vvh1iy7++Lh177LPrQ+yeD7bFgafPxJi3JbOGDd46GmLjjnN6fDa+7Xhw9WqIDYfDOPcjrA1Wva7f/56F4/u+EpsxbP+zrenY05/P/08G0/BkAwAAqEKyAQAAVCHZAAAAqpBsAAAAVUg2AACAKup0o0o6JWz/d0/3mqp7CGvZ2793X4jtHT2bjt3x8nIMHjsVQs1c7GY1evKRdJvzrxwKsfGp0+lYWDF37Yyxc7HrVCmlDDZuiMEdsWtbc3UpxIYd3Xmadck19cGxdCysGVmXqp4dqk4/mV9/T7wU71PPfCrvsAhdPNkAAACqkGwAAABVSDYAAIAqJBsAAEAVdQrEZzEY5vF2ksSSYihYQVsOxhYH7SQ/T285djmOXYpFrtm5v+5knFtKKe3lKx+yQlh5k/kpbj3DeE9o1yX3iUksWm337sq3+e7REGoefzjOf/61D18frGYzFI2XUsozj86H2I+9ej7E/uiRjVMti5uLJxsAAEAVkg0AAKAKyQYAAFCFZAMAAKhi9RWIT7xDnLVrMuxfeFcGM+T6XVOnKPyDFTPFddIk53SbXTvD/tdTu/fOGHz5zbjvz3win//cK733BavOjEXjWTH4u7/5yXTs/q+93Hu73Lg82QAAAKqQbAAAAFVINgAAgCokGwAAQBWSDQAAoIrV142qy2AYQvd/Z12I/fEfPpZO3/1M7HJ1y397dvZ1wV+x6b++EGJtV4e1F14Poclo1Gs/7auxc87/2YBubqx+x5/YFGK7Xoq/56WUMt63M8ROPbohxNZdih12ljbkHXbmrsbY9iO3htjR
z8euO6WUsvu5NAxrV9ahqpS8S1US6+o69ZkX4z3puUfj/+e4sXmyAQAAVCHZAAAAqpBsAAAAVUg2AACAKtZOgXhS+Pr2k3H5947zyr12rHCW+g7908dD7J5/mTciuPSTcezG/xULvycXL4XY4g9/It3mrc++FWLjs2fjwKzor5TuIkG4hnY9vRBi7Wg5HTv33okQ27EUGyk0y/E3vp3Pb3HZ2MmVWDW+5w9OpvPLrli03szFfY2Ovp/Ph7Uiuyf0LBovpZTnHovXxeZvbwuxhSdPT7001g5PNgAAgCokGwAAQBWSDQAAoArJBgAAUMXaKRBPnP3aD4bYpneSV8OWUs59PRbZXviL7SG2+zux8HD973pdLP3s/Z9XQqyrOcHGV2PxaVYM3i7Hc/K2N2LRbCmljC9f/rAl/t+NTlEIPkje9upN5cxgcCH+TnedUe3iUog1l+L8JrlOylJHgfgo2dskXhPNlcV8UfPzITT+4FiIze25K8QUjbPm9S0a77DwhTMh9o8PvhZiv/zAw1Mti9XLkw0AAKAKyQYAAFCFZAMAAKhCsgEAAFSxpgvEt3zz6d5jt/1pjG1fvz4GkyLBNiuQLUWRLMHcc2+E2KSjGHvy7pEQa5djMWxmfCQvMm1HSZHsrLLzXNE4szhzLsY6rpP2Smy60JyNxahpI4Zh/tvdZr/zyfz2XHzTeSklLyZPisazYvDhls3pJsdd+4K1oKvpSM/C8V9+8JEQ+ztvvJuO/Y0D9/RdFauEJxsAAEAVkg0AAKAKyQYAAFCFZAMAAKhCsgEAAFSxprtRzapdXFzpJXCjeeDuGHsldqgqpZThzjtCbHziVIi1o+U4947t6TbHp04n82fsUJV0nhpuuD3u+/z52fbDTWOyb3cMnj6Tjm02bwqxdte2OG456SY1n9/isrHN5ctx4J070/llFOcPLlyKsW1bQ2x8/GS+ps98IsTa517J9w9rRdalqmeHqt94aH8av/Vb8bq88sXj06yK68yTDQAAoArJBgAAUIVkAwAAqEKyAQAAVHFTF4j31axfn//DOClITGJpgRQ3pObwBzHY8f1PkoLYrBg8mz8+ey7dZnr+zWoStznJimmhp/aF13qPHR8/EYNZrILx629el/2UUkpRDM7Nom/ReMe9MysGf/NXPxtiD/6D70y9NOrwZAMAAKhCsgEAAFQh2QAAAKqQbAAAAFUoEO+jo+j29w9/N8Qe+JO/HWLb/vDWENv675+eeVmsPs2GDTF4biEfe0vSeGApKxBP3nY8P59us11O3haezJ9KVrg3jG8VL9O8qXyKYkBuPHN794TY6MjRdOzg9vi2+sGmjSHWTiYh1gzyv6dlY8cnT4fYcMf2dH7aHGRxKe5/GPc/Pn8x3eTcnbvi2J6F8O001x6sRjO8abyUUh78h8+G2C+9EwvEv35vLCSnPk82AACAKiQbAABAFZINAACgCskGAABQhQLxHrqK77609/EQu3/yYhw4RZETa9vo/WO9x44XzsdgzyLpycW8yLRKkXWyzXZx8Zpvk5tIUjjd9TvZZM0Iklg6ey6/xTVpMXncQmeBebbNbP3NFH/Pm4ufabB1a6+pze2xCUkppYzePdx//7DadN0nev6fKisGP/gfH0vHPvCzL/ReFtPzZAMAAKhCsgEAAFQh2QAAAKqQbAAAAFVINgAAgCp0o5rFZNxvnM47N425u3aH2OjI0XTscPOmEBufT7pMJefZYMOGdJuTS5d7zZ9K0vmjmZ8Psak6VGXdRFwnN49x7AbV9f2343j+NkmszTpMdXQSzMa2k6TrWjKulFJKtv9k/U3bMT/dZjI2WX+zeWOceuT9dJNze/fETXb8HsGakf1WpN3gYqyr69TR//JwiO35m69NvTRynmwAAABVSDYAAIAqJBsAAEAVkg0AAKAKBeJwDbVXrvQfu7ScBHsWlC4nc6eZP42sGC8pkL2uFJivaeNdW2LwaF7kPNgYmyFMdm4NsWY5npOT+fwW
l41tzi3EgduTdZZSyihp2nDhUhw3F/ffXMp/I8a742canlkXYkt3xTUNjx7Lt7mj53F27bDWzVA0Xkopd//ckRB76zc/GWL7v/by1EvDkw0AAKASyQYAAFCFZAMAAKhCsgEAAFShQHyldBQpLf/oD4bY+mOx8HBw6mw6f3Ts+GzrYjbJW4i5TgbDGJv17elUMb4tvoF+0HT87Sspsk7nX41v257ckt/iBnNxX03ymzy5NRZol1JKs5zMX0zGDpNzcpD/9o+TfQ3Wx9hoQ4zNDfNjN04+f5O96Tw5xqWU0na8gR3WhL5F46WUSdK4Yf9PvxJif//Ng+n8X3vwgenWdpPxZAMAAKhCsgEAAFQh2QAAAKqQbAAAAFVINgAAgCp0o1opWZeEUsq6//FiiE3aSRLT9Wg1mly40H/sldj9ouu8CHMXF/N/uE7nxcxdamZdZ9K5qEm6/AwOPJjv/vD7ITa5FLu+Ucdbfzd+VweeX5+OXfjs3hA7+mPx/Fl37rYQW96adyMbXoznz4FfuSPEvv+VTen8dZfi+rd8f0uILW2I43b+zuV0m2/+vRjb8PyOELvjy0dCbPDKtnSb3/v52Llq131PhNi5A3mHnsW9yyH2sZ/7bjoW1oSOe0877te58Nc+lt9T/vnb8br4hfse77+uG5wnGwAAQBWSDQAAoArJBgAAUIVkAwAAqEKB+Goz6VekxOrUzM+HWFcxdTMXizfb5aV++xkO0/jMhdt9DZL9X8dzdzAfj10ZxL+dnH84L5zdfC4W8isQv34+9iuxwUHaMKGUsvnP3g2xjQfj99osxXO/XZ+cJ6WUZjmeq+MTJ0PswW/GovNSSimjOL+5kJw/c/EWOz5zLt3kg/86FoMPz5wJsaXvxs8+PvV6us0Dv7o1bvODUyG24f096fzhn7wQg098MsaeeTmdD2tGdv9q8sYJmV+4/9Mh9t5vPRxi+37q1amWdaPwZAMAAKhCsgEAAFQh2QAAAKqQbAAAAFUoEL/BNOuTt/Amb8ZM35bpreQza5KC0M6xw5jrt6OkIC37XjoKxLPvusb3mr2tO3nRfT3rYuFvkxTzLezPj9OmlzoKf7kuBlfim6m72gtkTQ+aq3F+s5w0R5h0nPvZdZKNXYr7KaWUJikQb7Oi8eza67hQBleT9SefaZgcu67PObgaG05kx3N4Kf+c2W/H8PzVuPvkd++6NauAWrLrd4qi8X1feS3E/tnbL4bYL9736FTLWos82QAAAKqQbAAAAFVINgAAgCokGwAAQBUKxG8w7WJ8My/Xz/hCfDN1l8nVWGjZ10p/zytd/Dm5eLHXuH2/nr9ZedL3e8qKATVSmFlzKX9beKZdjEXOg2x+Vkje1bAhbaQQC7ebi5fzNWXnQHJNtkkjh7ajmDvbV5v8RgwW4rhx9nlKKc3FeJzapOh9eC7/nNlWm4V47WUNR5p18+k22+X4fcKa0fX737Nw/BfvfyzENn97Wzp24cnTvZe12nmyAQAAVCHZAAAAqpBsAAAAVUg2AACAKiQbAABAFbpR3WgGsftJmeSdSrj2hnfcEWLjkyfTsYPbbw+xyeWkK0zS/WJw223pNidXki49s3ZPSrpsNHPr4m6m6TIzY5en4batMZic+6//0j3p/AP/NjlOL3wvhOZ2xu9zdOJUvijXWW+jQ+/1Hpt1DuvdTWxGo2PHr8t+SpnimJzq36Fm9N6RfgPPn++/zaPv9xqn6xQ3lez+lf1/LOl6t/CFM+km/9ORp0Psq3t/aOqlrQaebAAAAFVINgAAgCokGwAAQBWSDQAAoAoF4jcaRaorqr10qf/YpeUk2K9Iul3qKL6ctRi85zbb8Yzn2YzrbC8lhfSD+LeTPb+bFOiVUoZHY5H3KCncm1y4GCd3XGNpwf+Vq73n30yaxx4OsfbFv0zHDnfuiGN3bYvbXBzFcevzW1wzSr7rtw6F2OC+u9P5ZRS/w+ZCcu3Pxf2Pj+cNI8ojD4bQ8GwshF++KzZHGDz3
errJ5gfuj2OPx2LUpQN3pfMH33ohxh55KMYW4nUy3rk53ebwxEKyo9gwYpomArAqZb/1WXOUDl/d97kQu/VbO0PsyhevXyOLj8qTDQAAoArJBgAAUIVkAwAAqEKyAQAAVKFAnA81t3tXiLWX4xuYJ4uL6fy2I34japKC0E5JUWTvN2sP88LnkhVuVygab5K1J/XV9ayLbzBvkmN3/p78OG1+4dZeu2nm52OwowlAOjYrEKcM3orFv+OO83Ry5lyINcnvTztJTsCkaUDX2KzpQnu4423ZyVony7FAPbvG21HSGKKUMkyOySRZ09xCLBofd2xz8M7RODY5J9e9ms9PWxkcitucLMf5g4X8reTjm+h+AEH2OzdF0fiVHz4RYt849O107FP3PNl7u7V5sgEAAFQh2QAAAKqQbAAAAFVINgAAgCokGwAAQBW6UfGhJucWQqwdxc4r7eTadz1aa8bn8w4smVm6dK10h6/s+7+eJhdiR57Mnm/8RRofZccv6RIyPnu295qmGXvT2xM73JWO73SwaUMMbt8aQk3SDapdl9/imlHss9S+FztPDXbtSOdnXd/aS5fjfpLudONTp/Nt7t0d979wMcQmO7bE/VzMO6SVO3fGbZ6O5+nk3rvy+afPxH0l22wuxP23WzelmxyeSe4nWXexjo57zXzsRDdO7lGwZnR1jBwk10DS9vGp/V9Ip//B+y+E2JfuenSqpV0rnmwAAABVSDYAAIAqJBsAAEAVkg0AAKAKBeJ8qMnVqyu9hDUjKwjtLKbOir8msfA031GTx7sKza61bP/Xa98lP86Z9uP3p/HBwcMhlhWdN+vm4zaXl/I1TTG29/zRchx4HY9zLe36WOTbKbumkiLh7Jxs13ecJ8vJ39kGyfyOAvNsbHpOZrEm/xtfOx/HNsn+x7fGz971V8P2lji2SQqvx7fknzPbbnbss88+6fiOh9kxyba5c3s6vyTF6INPPhT3//L38vmwVmT/H+i69ye+tOexEPv48/H6f/3x+g1fPNkAAACqkGwAAABVSDYAAIAqJBsAAEAVCsSpLi18Td7Am70Zc60Vw6afq0vfYvB0Ryt8XFZ4/2nRfVI4N3zrvXT+OHnbc76fpED7GoztPT8rJm5nOG9WiclLr8dgxzk1Pn4iBk+cvMYryvc//v5bM82favfJMUl+EUtzKJ7Tbce+26xIOhk7ONnxVvPE5LU3Yizb/5Gj6fxR3+OUvL28U3aOwI0ou36yZjOlpP+nev3T8f7xO0efD7Ev73l86qX9/3iyAQAAVCHZAAAAqpBsAAAAVUg2AACAKhSIU900b1GGXqZ4iyqrz/CBe0NsfPCdfOyWLSHWbNkUB46SwvmuN4AnjRzGRz+I+95zZz4/2Vd7+UocN4x/z5ucW0g3Obj37hBrkrdlt9vj8Zi8kReyD/bvi8HT5+I2796dzs8K+ef2x3W2F5OGC1uT76iUUs6ej7HFxRBqbr8tnz+Ix3T0wfEQG27aEOfeuTPd5Pj1N/N9wVrQ1Wym533yy3s/HWLfOPSn6din7nmy97L+Kk82AACAKiQbAABAFZINAACgCskGAABQhWQDAACoQjcqVo3RjzweYus/SDqXlFLKmbyjCzeJto0xHarWjrlh/7GD+L22yfzs28/GdY0tTfK3t6TzUSklv3MmnafKcIrPma01md+um2Kb2f6z4znFNtNjn3z2rmOf7T+NzU3x35N2EmO7doTQ5OChfEmPPNR/X7BWzHCffGr/F9L4+m/t+khL8WQDAACoQrIBAABUIdkAAACqkGwAAABVNG2bVZAAAADMxpMNAACgCskGAABQhWQDAACoQrIBAABUIdkAAACqkGwAAABVSDYAAIAqJBsAAEAVkg0AAKCK/w1LJUNWWwoLZQAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 1000x1000 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig, axs = plt.subplots(1, len(selected_chan), figsize=(10,10))\n",
    "for i,c in enumerate(selected_chan):\n",
    "    curr_attn_mat = normalize_attn_mat(attn_mat[0,c,:,:].abs())\n",
    "    axs[i].imshow(curr_attn_mat.cpu().detach().numpy())\n",
    "    axs[i].axis('off')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_attention_matrix(attn_matrix, channel_index):\n",
    "    plt.figure(figsize=(6, 6))  # Create a new figure for the matrix\n",
    "    plt.imshow(attn_matrix.cpu().detach().numpy(), cmap='viridis')  # Display the matrix using a color map\n",
    "    plt.axis('off')  # Hide the axes\n",
    "    #plt.colorbar()  # Optionally add a colorbar to indicate the scale\n",
    "    filepath = f'griffin_attention_map_layer_{selected_layer}_channel_{channel_index}.png'  # Name the file based on the channel index\n",
    "    plt.savefig(filepath, dpi=300, bbox_inches='tight')  # Save the figure to a file\n",
    "    plt.close()  # Close the figure to free up memory\n",
    "\n",
    "for c in range(100):\n",
    "    curr_attn_mat = normalize_attn_mat(attn_mat[0, c, :, :].abs())\n",
    "    save_attention_matrix(curr_attn_mat, c)  #"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation: attention matrices with `ablate_conv` and `ablate_gate` enabled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxsAAAC8CAYAAAAQL7MCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAMnklEQVR4nO3dTaxdVRUH8H36HhZCECJRCEXAhAoEowYilMrAb6EkDmRggJmGxA9MRBNNMI5MHKqJqAPCEBgQHGhC0mjUKEJBISBCkbZCsX0WFMJH+tq+d885jgwxe18495277se7v99wZZ9zdh/33WZ181+natu2TQAAAGO2ZdobAAAANifNBgAAEEKzAQAAhNBsAAAAITQbAABACM0GAAAQQrMBAACE0GwAAAAhNBsAAECI5a4LP/OOG7JaOxiMdTPQx6+be6e9hXTNRd/JavW+f0xhJ1A2C78nzZHtWe2z53x4CjuBsln4PUnJ7wqzr8vvipMNAAAghGYDAAAIodkAAABCdM5slPIZr9+wo7j2nffs2fiOYI5V6/nvye6Vx4tr/X+3LKrVZm3aW4C5cM11N2W13St3Fdf6O4VZ5WQDAAAIodkAAABCaDYAAIAQmg0AACCEZgMAAAjReRpVyen3/qVYb/vcFOZYc+SlrLbz1i8X155x5rNZrX75lbHvCWbNeqo7r12+4LysNnj+hXFuB2bX0/uz0rW7biwu3b1yd1YzoYpZ4GQDAAAIodkAAABCaDYAAIAQmg0AACBEr4B4Oxh0Xrt0yfasVu/d1+fxMBe2DIaMTGjyenvVh7Ja9dAT494SzLeqymut0SRsPu2JE3lx74Hi2lJwXGicWeBkAwAACKHZAAAAQmg2AACAEJoNAAAgRK+A+CjqZ/K3YAr5sdm0dZPVTvnPWnltYcDC8mvHslr3dy3DfGgCvufvP/xYVtu17bKxPwemrRgaT6kYHBcaZxY42QAAAEJoNgAAgBCaDQAAIIRmAwAACDGxgHjX4PeRb+ws1s/+8YPj3A3EaPOA+NLr5YB4qvPod3VsSPAPGNmWU08t1pujRye8E4jX9W3jXUPjKQmOMx5ONgAAgBCaDQAAIIRmAwAACKHZAAAAQmg2AACAEJObRtXROXc+Waw3VZUXO064gklpm/wzufTGanFtU5hGlY4d7/X8pTPfldXql1/pdU/YbPb99Mqstv1rD09hJxCrz4SqlMpTqkyoYlRONgAAgBCaDQAAIIRmAwAACKHZAAAAQsxcQLx5443Oa7ecfHJ+/fF+AVuYmEKYvDX0gAVQp/F/zptR7llY+rODD2S1r55/dY8dwWzqGhpPqRwcFxpnVE42AACAEJoNAAAghGYDAAAIodkAAABCzFxAfBTC4MyDqm66Ly6ExkfS93qYgBF+I0a45/jvescLeWg8pZRuPk9wnM2lGBpPqRgcv+a6m7La7pW7ipcLjpOSkw0AACCIZgMAAAih2QAAAEJoNgAAgBCaDQAAIMRcT6Maxaf+9kZW+80HTpvCTlg4rQlRAMyf0pSq6un9We3aXTcWr9+9cndWM6Fq8TjZAAAAQmg2AACAEJoNAAAghGYDAAAIsTAB8V9995NZ7ZT0yBR2wqbWNnltfdB97WDI2q7qut/1wIY8//2rstoF33toCjuBWKXQeNp7oLi2FBwXGl88TjYAAIAQmg0AACCEZgMAAAih2QAAAEIsTED8tEcPZ7VSFHfwicuz2vJvHw3YEYxf623lsCF1qrLaSanf79PhX1ya1bZ9/qle94RZVAyNp1QMjguNLx4nGwAAQAjNBgAAEEKzAQAAhNBsAAAAIRYmID44vNJp3Ul/eCKrLZ3/3vI9D/6z155YDO3qsXK9ycOn7dpav4etr/e7vodqufx10vZ9KzpsIu9+8Ixi/d87X53wTiBe17eNdw2NpyQ4Po+cbAAAACE0GwAAQAjNBgAAEEKzAQAAhNBs
AAAAIRZmGlVq88k/xWWFyTmDFw4V15am75i8s+AKn7OhE6baJi/1/Py0dX7PSSlN14KSuuP38bTvOYqq5+PvO7Qnq11/7o5+N4UZ1GdCVUrlKVUmVM02JxsAAEAIzQYAABBCswEAAITQbAAAACEWJyDex5DgYSnMW530jnzd+pCAMIthlOBq35B1IXQOzKmqymtTDsJDhK6h8ZTKwXGh8dnmZAMAAAih2QAAAEJoNgAAgBCaDQAAIISA+JiVwuC7Vx7PaoJLi6Ot6xEW9wt4T/Ut3sLp8LaathD6HuX6q/O/O7Y8kP8dA/OuGBpPqfPbxoXGZ4eTDQAAIIRmAwAACKHZAAAAQmg2AACAEALiE7Dr01/IaktnHCmurV99LXo7AIutMEehLgS369LCGfSTg38q1r9+/kcnvBOI1/Vt411D4ykJjkdzsgEAAITQbAAAACE0GwAAQAjNBgAAEEKzAQAAhDCNagLaAwezWrO2Xlxbbd2aX1+avADATGsnNMyqNEkrpZR+8NwjWe22910RvR2YuK4Tqq657qbi9btX7spqJlSNj5MNAAAghGYDAAAIodkAAABCaDYAAIAQAuIT0Bw/3nlte6LOi1uWCjctrGM2NUNSopNKj8KMifj2alITcNfJqdNkvg+Wzjg9f/arr03k2TBJ7dpaVtvy7PPFtR//4s1Z7Xcrd2Q1ofGNcbIBAACE0GwAAAAhNBsAAEAIzQYAABBCQHwetHnwsbr80vLSR5+K3g2B2mFhcjamKr9ZWTgf3rR89lnF+uDIixPeCYxR4Xu+WV0tLj35j09ntWt33ZjVdq/cXbxecPytOdkAAABCaDYAAIAQmg0AACCEZgMAAAih2QAAAEKYRjUPChMVqr3PlZdG7wU2gx0fzGt7/jr5fTDT8jmAi+XW/Xuz2o8uvGQKO4ExGTKJsDSlqtp7IKuVJlSlVJ5SZULVm5xsAAAAITQbAABACM0GAAAQQrMBAACEEBCfU6UwEzOqXfSY6eyp1uusZrjCJjQkDFrSpKrns7pd3/s5I9yz7vms0vW/PPznrPa5bR/p9RyYusJ3RXviRL6uEBpPqRwcFxp/k5MNAAAghGYDAAAIodkAAABCaDYAAIAQAuILat/tV2a17bc8PIWdMFbC6P9vSEC4OiEgPk0Rn9J6zv8LNiOE2WfN0kUXZrX67/unsBOIVQyNp1QMjguNv8nJBgAAEEKzAQAAhNBsAAAAITQbAABACAHxBXXxbXuzWl0NedvsHAcXZ0HbjPDzm+eA9xx9TqoTaxu/drn8tdkOBhu+J2w29ccuK9aXfv/YhHcC8bq+bbxraDylzRUcd7IBAACE0GwAAAAhNBsAAEAIzQYAABBCswEAAIQwjWpBmZwDTMvRZoR/5+o45Wy1qTvfsqrzyXtH2/yvwzoNmdBXumdhm3Wd/znX6qXi9Uc7TqIrPae095RSWm22drpnO+RnfLRw/WrzSukGWakacs/7Du3Jatefu+Ntdgjzp8+EqpTKU6rmdUKVkw0AACCEZgMAAAih2QAAAEJoNgAAgBAC4guqWV2d9hZgejqGjvuqlvOvWMMZ5kfTdg+Ip8l8pKauDviDPvvzK7La+7/yyNifA9PWNTSeUjk4Pq+hcScbAABACM0GAAAQQrMBAACE0GwAAAAhBMQZm323X1msb7/l4QnvBN5a1SxImhfm1NJZ7ynW6xdfmvBOIFYxNJ5S57eNl0LjKc1WcNzJBgAAEEKzAQAAhNBsAAAAITQbAABACM0GAAAQwjQqxuaibz9ZrDcT3sfMaSf4E2hNWerk2PENX9qOMMlqlLXA23v5S1dltTPvfGgKO4FYxSlVHSdUpVSeUjWtCVVONgAAgBCaDQAAIIRmAwAACKHZAAAAQgiIMzbN6mrntdXWrVmtGIZaNALeE9Gur/e4eITAf9/hAFVVuKfPSISmLfzbWzX+4Q7F59DLDc+sZLV7Lj5nCjuBWF1D4ymVg+PTCo371gMAAEJoNgAA
gBCaDQAAIIRmAwAACCEgzlSUQk7L527LaoNDhyexHRbNYDDtHWxYtZx/bbdz/OeBCPt/uCOrXfjNPVPYCcQaOlyn49vGJxEad7IBAACE0GwAAAAhNBsAAEAIzQYAABBCQJyZMTicvwUWQjTz+xbudo73/j9NKrwZvad67HdMqe65z7btfn2fd5UPeyt53/2X7tv3nep1Knx+e36k645vZb//8GPF+q5tl/XbAMygrm8b7xoaT2njwXEnGwAAQAjNBgAAEEKzAQAAhNBsAAAAITQbAABACNOomB1t95Ek//rWzsCNAJH6TkkqGWlKUuGrprSnk3res2TYJK66x0SmYT/PrlOaht83v74e4Xu6pOl5fUlpn8VnD/mPdN+hPePcDsysPhOqUho+pertONkAAABCaDYAAIAQmg0AACCEZgMAAAhRtW1AWgsAAFh4TjYAAIAQmg0AACCEZgMAAAih2QAAAEJoNgAAgBCaDQAAIIRmAwAACKHZAAAAQmg2AACAEP8F30td5+gteNwAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1000x1000 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = True\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_vec = False\n",
    "model.model.layers[selected_layer].temporal_block.ablate_conv = True\n",
    "model.model.layers[selected_layer].temporal_block.ablate_gate = True\n",
    "outputs = model.generate(**inputs, max_length=len(inputs[0])+1)\n",
    "attn_mat = model.model.layers[selected_layer].temporal_block.attn_matrix\n",
    "fig, axs = plt.subplots(1, len(selected_chan), figsize=(10,10))\n",
    "for i,c in enumerate(selected_chan):\n",
    "    curr_attn_mat = normalize_attn_mat(attn_mat[0,c,:,:].abs())\n",
    "    axs[i].imshow(curr_attn_mat.cpu().detach().numpy())\n",
    "    axs[i].axis('off')\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = False\n",
    "model.model.layers[selected_layer].temporal_block.ablate_conv = False\n",
    "model.model.layers[selected_layer].temporal_block.ablate_gate = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention vectors (cls-to-others)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABFEAAAAbCAYAAAC0lRC+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAChElEQVR4nO3dvUvUcRwH8LvzkqQ0FaSosLIHa4u2EGqKGsKWgqboYQhaIpe2lqb2oC2CgoYibJFCKhuiIEkaGsqKnowKQS7SSPR+/QWfb18OQovXa33f53Nfjt89vTm4clEURQkAAACApMpCHwAAAADgX6BEAQAAAMigRAEAAADIoEQBAAAAyKBEAQAAAMigRAEAAADIoEQBAAAAyKBEAQAAAMigRAEAAADIUM294Z7Kob95Dvgvlavpp9idD6NhtunBseRs5/DSMBu9PJA+2CK2r/dsmM2Pv03ONnV1hVl9airMivn59N7Ojnhv7XtytpibS+YNK5fDqNLSEmb1mZnG77PSlM7r6ccxUt2wLpkX1fh+hx7eCrO+0yeTe1sHx8Ks3NuTnK2/eBlmTds2x3Ov0tdwubk5Dnu6w2h6Y1tyb+vopzCbm/icnE0Zrt9oeHahnXh6NMw+7ko/byeP7AizjsPxY/3+W2dyb8/KyTCbqK1IznYPTIfZmwvx9fGrFr+XlEqlUvvYkjCbTRxp/ZX0tb72di3M7o9sT85ePXgxzM7vPhBmy67/TO59fW1LmD07dynM9q5On3d2OH6N+/JoTXJ2f//jMBu6uTPM2t7Vk3u/9hVhVpTjrPfM8+TetnvLw+zH8fbkbP/gkzA7tXUkObuY+Q4FDfjTZ84i/RqXcnci/vxXWTWetcMvUQAAAAAyKFEAAAAAMihRAAAAADIoUQAAAAAyKFEAAAAAMihRAAAAADKUi6KI/8cMAAAAgFKp5JcoAAAAAFmUKAAAAAAZlCgAAAAAGZQoAAAAABmUKAAAAAAZlCgAAAAAGZQoAAAAABmUKAAAAAAZlCgAAAAAGX4DYq5xL22lJlcAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 1400x1000 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.model.layers[selected_layer].temporal_block.compute_attn_vec = True\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = False\n",
    "outputs = model.generate(**inputs, max_length=len(inputs[0])+1)\n",
    "attn_vec = model.model.layers[selected_layer].temporal_block.attn_vec\n",
    "fig, axs = plt.subplots(1, len(selected_chan), figsize=(14,10))\n",
    "for i,c in enumerate(selected_chan):\n",
    "    curr_attn_vec = normalize_attn_mat(attn_vec[0,c,:].unsqueeze(-2).abs())\n",
    "    axs[i].imshow(curr_attn_vec.cpu().detach().numpy())\n",
    "    axs[i].axis('off')\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_vec = False\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ablation: attention vectors (cls-to-others) with `ablate_conv` and `ablate_gate` enabled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABFEAAAAbCAYAAAC0lRC+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAACbUlEQVR4nO3ay4tOcRgH8POed2ouSaE0Q8wYl2am3LLRlJWGlZQSJSW2kgX7KbKwZEHZWLgUaRbKJMspsbORYtTM5DIUJeXSzLzHX/D8/LwNM+rz2X7f5zm/Tm+nzrdTq6qqKgAAAABIKhf6AAAAAAD/AyUKAAAAQAYlCgAAAEAGJQoAAABABiUKAAAAQAYlCgAAAEAGJQoAAABABiUKAAAAQAYlCgAAAECGltwfDpUH/+Y5gD/08N2zMCs7X/3Dk8yvoZbDYVbvXZucbUy+jWdXLAuz2ekPvz9YuHd5Mp/79Lnp3U2r1cKo7OhIjja+/0iEc+nrlvXmZ5vcW7a1htno+OPk2r2rtsVh4h4WRVHU+zaEWWN8MszKTeuSe+dejMdh4h7WN/Ym937ZvjLMltx5kpxNedS42/TsQns+tTrMTvcMJmdbuteE2ezUmzC7PjmW3Hu8b0+YvbywJTl7e//lMGurxf+dsz07k3unhuN7UQ18DbOBzunk3ks9I2F2ontXcrYa3Bpm529cC7Nvjfh5URRFcXFH4rqt8eyxsafJvUPt78PsyO6jydmZrqVhNnLzSphtvn8qubd/eCLMPu5bH2Ynz9xL7j03eiDMXh+6mpz9Wc2EWXvXRHJ2MfMOBfOvlngm1/rjZ1hRFMXog1thlvsO5UsUAAAAgAxKFAAAAIAMShQAAACADEoUAAAAgAxKFAAAAIAMShQAAACADLWqqqqFPgQAAADAYudLFAAAAIAMShQAAACADEoUAAAAgAxKFAAAAIAMShQAAACADEoUAAAAgAxKFAAAAIAMShQAAACADEoUAAAAgAy/ADV4YC/+QfVRAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 1400x1000 with 4 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model.model.layers[selected_layer].temporal_block.compute_attn_vec = True\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = False\n",
    "model.model.layers[selected_layer].temporal_block.ablate_conv = True\n",
    "model.model.layers[selected_layer].temporal_block.ablate_gate = True\n",
    "outputs = model.generate(**inputs, max_length=len(inputs[0])+1)\n",
    "attn_vec = model.model.layers[selected_layer].temporal_block.attn_vec\n",
    "fig, axs = plt.subplots(1, len(selected_chan), figsize=(14,10))\n",
    "for i,c in enumerate(selected_chan):\n",
    "    curr_attn_vec = normalize_attn_mat(attn_vec[0,c,:].unsqueeze(-2).abs())\n",
    "    axs[i].imshow(curr_attn_vec.cpu().detach().numpy())\n",
    "    axs[i].axis('off')\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_vec = False\n",
    "model.model.layers[selected_layer].temporal_block.compute_attn_matrix = False\n",
    "model.model.layers[selected_layer].temporal_block.ablate_conv = False\n",
    "model.model.layers[selected_layer].temporal_block.ablate_gate = False"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
