{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "29715241-42af-454a-b652-c720e6367080",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers.models.mistral.modeling_mistral import *\n",
    "from transformers.cache_utils import DynamicCache\n",
    "import torch\n",
    "from torch import nn\n",
    "import copy\n",
    "from types import MethodType\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a6c155f0-b191-4d52-8128-65cdc1899374",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d009102885c42cba5512684d1a82cd0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a213ceffd6884ec89bd9a35aeade4072",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73ef95ae55634801b984ab63d54a64e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f298f48adc44918ad8c911868daf6ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f38dce8468d4cd7b96a49373c223a4c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b34e0ce6e06941ddbc4e7b2f016de133",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e07cd661c354461890f8cff9cbe6fe35",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d69e4d640764a2e8a2774df5d34c5d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8a2906f41edd48e18d482ea05ddef8e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47a7840f2c3f482299b66aa4ce4e660c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4d842d1ba9b4e3f90d769df8ad97840",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Single source of truth for the checkpoint so model and tokenizer cannot drift apart.\n",
    "MODEL_NAME = \"mistralai/Mistral-7B-v0.1\"\n",
    "# Fall back to CPU so the notebook still loads on machines without a GPU.\n",
    "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "# Inverse vocabulary (token id -> token string), used to label logits and prompt tokens.\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "62138208-b0b7-4b09-b4b1-c1b3d4669806",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7ff6902f6b90>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference only: disable autograd globally so forward passes do not record graphs.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "64235cf7-7404-41b8-a2e6-7fa2f1cd82cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk(v, k=40, aux=None):\n",
    "    \"\"\"Return the k highest-scoring vocabulary entries as a DataFrame.\n",
    "\n",
    "    Args:\n",
    "        v: Logits as a torch tensor or numpy array; flattened before ranking,\n",
    "            so flat positions are looked up as token ids in `Token`.\n",
    "        k: Number of top entries to keep.\n",
    "        aux: Optional extra values indexable by token id; each entry is\n",
    "            appended to that token's row as additional columns.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with columns 'token' and 'logit' (plus one integer-named\n",
    "        column per aux field), ordered by descending logit.\n",
    "    \"\"\"\n",
    "    # isinstance also accepts Tensor subclasses such as nn.Parameter.\n",
    "    if isinstance(v, torch.Tensor):\n",
    "        v = v.detach().cpu().numpy()\n",
    "    v = v.flatten()\n",
    "    # argsort is ascending: keep the last k indices and reverse them.\n",
    "    idxs = v.argsort()[-k:][::-1]\n",
    "    # A bare `if aux:` raises for multi-element arrays/tensors (ambiguous truth\n",
    "    # value), so test for presence explicitly; an empty aux still means no extras.\n",
    "    if aux is not None and len(aux) > 0:\n",
    "        ret = [(Token[i], v[i]) + tuple(aux[i]) for i in idxs]\n",
    "        return pd.DataFrame(ret, columns=['token', 'logit'] + list(range(len(aux[0]))))\n",
    "    ret = [(Token[i], v[i]) for i in idxs]\n",
    "    return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "fedcc853-76c2-41a7-bd3a-7f60b919ca64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 1984, 6656, 2442, 302, 272, 23404, 2401, 349, 549, 28250]"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Token ids of the first sequence in the batch.\n",
    "# NOTE(review): `input` is assigned in a later cell (and shadows the builtin), so this fails on a fresh top-to-bottom run.\n",
    "list(input['input_ids'][0].cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "b8a963cb-84cd-4884-8eee-5b1a63147e33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'attention_mask'])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fields returned by the tokenizer call. NOTE(review): relies on `input` from a later cell.\n",
    "input.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "17342153-980c-49ad-a01e-fdd6351f882d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of cache entries (lens() zips these against model.model.layers).\n",
    "# NOTE(review): needs `out` to be the raw model output assigned in a later cell, not the DataFrame that lens() also stores in `out`.\n",
    "len(out.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "b15a62aa-dbf9-4252-b0d1-5f2ef4c2b6a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralDecoderLayer(\n",
       "  (self_attn): MistralAttention(\n",
       "    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "    (rotary_emb): MistralRotaryEmbedding()\n",
       "  )\n",
       "  (mlp): MistralMLP(\n",
       "    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "    (act_fn): SiLU()\n",
       "  )\n",
       "  (input_layernorm): MistralRMSNorm()\n",
       "  (post_attention_layernorm): MistralRMSNorm()\n",
       ")"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the module structure of the first decoder layer.\n",
    "model.model.layers[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "797a7d35-f1cf-44d9-b7e0-328c81f787f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copied from transformers.models.llama.modeling_llama.repeat_kv\n",
    "def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
    "    \"\"\"\n",
    "    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,\n",
    "    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)\n",
    "    \"\"\"\n",
    "    batch, num_key_value_heads, slen, head_dim = hidden_states.shape\n",
    "    if n_rep == 1:\n",
    "        return hidden_states\n",
    "    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)\n",
    "    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "85bd28f7-4d72-44af-9d56-a9e65cb76435",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lens(model, prompt):\n",
    "    \"\"\"Tabulate per-layer, per-token norms of the residual-stream components.\n",
    "\n",
    "    Runs one forward pass on `prompt` and, for every decoder layer, emits four\n",
    "    rows with one value per token position:\n",
    "      'A'  - norm of the attention output, rebuilt from the returned attention\n",
    "             weights, the cached values and the layer's o_proj\n",
    "      'F'  - norm of h - h_prev - a (the rest of the layer's update)\n",
    "      'Hd' - norm of h - h_prev\n",
    "      'H'  - norm of h\n",
    "    where h_prev and h are consecutive entries of out.hidden_states. The first\n",
    "    row ('E') holds the prompt's token strings.\n",
    "\n",
    "    Args:\n",
    "        model: Causal LM exposing `model.model.layers`; it is called with\n",
    "            output_hidden_states=True and output_attentions=True.\n",
    "        prompt: Text to analyse (a single sequence; batch size 1 is assumed).\n",
    "\n",
    "    Returns:\n",
    "        DataFrame whose column 0 is the row label and whose remaining columns\n",
    "        are token positions.\n",
    "    \"\"\"\n",
    "    # Tokenize onto whatever device the model lives on instead of assuming CUDA.\n",
    "    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
    "    out = model(**inputs, output_hidden_states=True, output_attentions=True)\n",
    "    token_ids = inputs['input_ids'][0].tolist()\n",
    "    seq_len = len(token_ids)\n",
    "    ret = [['E'] + [Token[t] for t in token_ids]]\n",
    "\n",
    "    def row(label, x):\n",
    "        # One norm per token position of batch element 0, in a single reduction\n",
    "        # (avoids one device-to-host sync per token).\n",
    "        return [label] + torch.norm(x[0], dim=-1).tolist()\n",
    "\n",
    "    # NOTE(review): the saved output shows 'F'/'Hd'/'H' jumping by roughly 10x on\n",
    "    # the last layer; presumably the final entry of out.hidden_states already has\n",
    "    # the model's final norm applied, which would make those rows not comparable\n",
    "    # with earlier layers - confirm against the transformers implementation.\n",
    "    for h, h_prev, attn, (_, values), decoder in zip(\n",
    "        out.hidden_states[1:], out.hidden_states[:-1],\n",
    "        out.attentions, out.past_key_values, model.model.layers\n",
    "    ):\n",
    "        # Grouped-query attention: replicate each KV head so the head count of\n",
    "        # the values matches the attention weights (derived, not hard-coded).\n",
    "        n_rep = attn.shape[1] // values.shape[1]\n",
    "        values = repeat_kv(values, n_rep)\n",
    "\n",
    "        # Rebuild the attention block's contribution to the residual stream.\n",
    "        a = torch.matmul(attn, values).transpose(1, 2).contiguous()\n",
    "        a = a.reshape(1, seq_len, -1)\n",
    "        a = decoder.self_attn.o_proj(a)\n",
    "\n",
    "        delta = h - h_prev\n",
    "        ret.append(row('A', a))\n",
    "        ret.append(row('F', delta - a))\n",
    "        ret.append(row('Hd', delta))\n",
    "        ret.append(row('H', h))\n",
    "    return pd.DataFrame(data=ret)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "8999628d-cda0-45c4-9100-d50fc1f3e889",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the per-layer norm table for the example prompt.\n",
    "# NOTE(review): `prompt` is assigned in a later cell, and `out` is also used elsewhere for the raw model output - run order matters.\n",
    "out = lens(model, prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "97d8c538-d0ec-4436-bf50-bc9e18842eb8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>11</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>Hd</td>\n",
       "      <td>5.58987</td>\n",
       "      <td>3.881607</td>\n",
       "      <td>4.543558</td>\n",
       "      <td>5.443723</td>\n",
       "      <td>4.543383</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>100</th>\n",
       "      <td>H</td>\n",
       "      <td>19.137266</td>\n",
       "      <td>17.948874</td>\n",
       "      <td>18.324638</td>\n",
       "      <td>17.267794</td>\n",
       "      <td>16.813477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>101</th>\n",
       "      <td>A</td>\n",
       "      <td>0.783993</td>\n",
       "      <td>1.215868</td>\n",
       "      <td>0.592541</td>\n",
       "      <td>0.581104</td>\n",
       "      <td>0.838376</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>102</th>\n",
       "      <td>F</td>\n",
       "      <td>4.603269</td>\n",
       "      <td>3.90245</td>\n",
       "      <td>6.33707</td>\n",
       "      <td>6.682795</td>\n",
       "      <td>4.82113</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>Hd</td>\n",
       "      <td>4.71511</td>\n",
       "      <td>4.126276</td>\n",
       "      <td>6.411267</td>\n",
       "      <td>6.7142</td>\n",
       "      <td>4.950005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>104</th>\n",
       "      <td>H</td>\n",
       "      <td>20.087797</td>\n",
       "      <td>19.162827</td>\n",
       "      <td>20.183628</td>\n",
       "      <td>19.750025</td>\n",
       "      <td>18.357061</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>105</th>\n",
       "      <td>A</td>\n",
       "      <td>0.730662</td>\n",
       "      <td>1.450283</td>\n",
       "      <td>1.222855</td>\n",
       "      <td>0.704192</td>\n",
       "      <td>1.69537</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>106</th>\n",
       "      <td>F</td>\n",
       "      <td>5.34078</td>\n",
       "      <td>4.049546</td>\n",
       "      <td>6.573922</td>\n",
       "      <td>5.84883</td>\n",
       "      <td>4.495585</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>107</th>\n",
       "      <td>Hd</td>\n",
       "      <td>5.422429</td>\n",
       "      <td>4.279365</td>\n",
       "      <td>6.709745</td>\n",
       "      <td>5.923417</td>\n",
       "      <td>4.827442</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>108</th>\n",
       "      <td>H</td>\n",
       "      <td>21.611397</td>\n",
       "      <td>20.505722</td>\n",
       "      <td>22.066854</td>\n",
       "      <td>20.677406</td>\n",
       "      <td>19.252426</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>109</th>\n",
       "      <td>A</td>\n",
       "      <td>1.585943</td>\n",
       "      <td>2.666184</td>\n",
       "      <td>2.517449</td>\n",
       "      <td>1.86857</td>\n",
       "      <td>1.808951</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>110</th>\n",
       "      <td>F</td>\n",
       "      <td>4.940273</td>\n",
       "      <td>4.231506</td>\n",
       "      <td>7.724101</td>\n",
       "      <td>6.974767</td>\n",
       "      <td>5.954541</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>111</th>\n",
       "      <td>Hd</td>\n",
       "      <td>5.468703</td>\n",
       "      <td>5.110724</td>\n",
       "      <td>8.27581</td>\n",
       "      <td>7.212241</td>\n",
       "      <td>6.262464</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>112</th>\n",
       "      <td>H</td>\n",
       "      <td>23.517799</td>\n",
       "      <td>22.213755</td>\n",
       "      <td>25.222637</td>\n",
       "      <td>22.568491</td>\n",
       "      <td>21.108782</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>113</th>\n",
       "      <td>A</td>\n",
       "      <td>3.093532</td>\n",
       "      <td>2.694281</td>\n",
       "      <td>2.05949</td>\n",
       "      <td>1.771824</td>\n",
       "      <td>1.927072</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>114</th>\n",
       "      <td>F</td>\n",
       "      <td>6.747706</td>\n",
       "      <td>5.992997</td>\n",
       "      <td>6.937422</td>\n",
       "      <td>7.696311</td>\n",
       "      <td>5.701997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>115</th>\n",
       "      <td>Hd</td>\n",
       "      <td>7.747508</td>\n",
       "      <td>6.487483</td>\n",
       "      <td>7.319106</td>\n",
       "      <td>7.875821</td>\n",
       "      <td>6.06427</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>116</th>\n",
       "      <td>H</td>\n",
       "      <td>25.291777</td>\n",
       "      <td>24.269609</td>\n",
       "      <td>27.259624</td>\n",
       "      <td>24.313761</td>\n",
       "      <td>22.881435</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>117</th>\n",
       "      <td>A</td>\n",
       "      <td>3.195761</td>\n",
       "      <td>3.83004</td>\n",
       "      <td>4.28471</td>\n",
       "      <td>3.304821</td>\n",
       "      <td>3.269844</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>118</th>\n",
       "      <td>F</td>\n",
       "      <td>8.433399</td>\n",
       "      <td>7.096021</td>\n",
       "      <td>8.22958</td>\n",
       "      <td>11.155994</td>\n",
       "      <td>6.829512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>119</th>\n",
       "      <td>Hd</td>\n",
       "      <td>9.343595</td>\n",
       "      <td>9.08414</td>\n",
       "      <td>10.062134</td>\n",
       "      <td>11.969229</td>\n",
       "      <td>8.200727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>120</th>\n",
       "      <td>H</td>\n",
       "      <td>29.047668</td>\n",
       "      <td>29.251232</td>\n",
       "      <td>32.319176</td>\n",
       "      <td>27.09149</td>\n",
       "      <td>26.53953</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>121</th>\n",
       "      <td>A</td>\n",
       "      <td>3.079163</td>\n",
       "      <td>3.478113</td>\n",
       "      <td>2.703693</td>\n",
       "      <td>2.087703</td>\n",
       "      <td>2.854382</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>122</th>\n",
       "      <td>F</td>\n",
       "      <td>8.160327</td>\n",
       "      <td>8.320078</td>\n",
       "      <td>9.885159</td>\n",
       "      <td>11.088625</td>\n",
       "      <td>7.948795</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>123</th>\n",
       "      <td>Hd</td>\n",
       "      <td>9.372631</td>\n",
       "      <td>9.633545</td>\n",
       "      <td>10.671306</td>\n",
       "      <td>11.384894</td>\n",
       "      <td>8.794673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>124</th>\n",
       "      <td>H</td>\n",
       "      <td>33.131432</td>\n",
       "      <td>34.066902</td>\n",
       "      <td>36.608715</td>\n",
       "      <td>30.329361</td>\n",
       "      <td>29.997042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>125</th>\n",
       "      <td>A</td>\n",
       "      <td>6.758283</td>\n",
       "      <td>8.754836</td>\n",
       "      <td>6.111102</td>\n",
       "      <td>4.592056</td>\n",
       "      <td>7.212671</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>126</th>\n",
       "      <td>F</td>\n",
       "      <td>344.103638</td>\n",
       "      <td>371.472931</td>\n",
       "      <td>366.870972</td>\n",
       "      <td>373.068787</td>\n",
       "      <td>345.25946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>127</th>\n",
       "      <td>Hd</td>\n",
       "      <td>344.613953</td>\n",
       "      <td>372.338654</td>\n",
       "      <td>367.68576</td>\n",
       "      <td>373.124146</td>\n",
       "      <td>345.64212</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>H</td>\n",
       "      <td>365.583038</td>\n",
       "      <td>393.927063</td>\n",
       "      <td>394.027832</td>\n",
       "      <td>392.164703</td>\n",
       "      <td>365.953125</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     0           7           8           9           10          11\n",
       "99   Hd     5.58987    3.881607    4.543558    5.443723    4.543383\n",
       "100   H   19.137266   17.948874   18.324638   17.267794   16.813477\n",
       "101   A    0.783993    1.215868    0.592541    0.581104    0.838376\n",
       "102   F    4.603269     3.90245     6.33707    6.682795     4.82113\n",
       "103  Hd     4.71511    4.126276    6.411267      6.7142    4.950005\n",
       "104   H   20.087797   19.162827   20.183628   19.750025   18.357061\n",
       "105   A    0.730662    1.450283    1.222855    0.704192     1.69537\n",
       "106   F     5.34078    4.049546    6.573922     5.84883    4.495585\n",
       "107  Hd    5.422429    4.279365    6.709745    5.923417    4.827442\n",
       "108   H   21.611397   20.505722   22.066854   20.677406   19.252426\n",
       "109   A    1.585943    2.666184    2.517449     1.86857    1.808951\n",
       "110   F    4.940273    4.231506    7.724101    6.974767    5.954541\n",
       "111  Hd    5.468703    5.110724     8.27581    7.212241    6.262464\n",
       "112   H   23.517799   22.213755   25.222637   22.568491   21.108782\n",
       "113   A    3.093532    2.694281     2.05949    1.771824    1.927072\n",
       "114   F    6.747706    5.992997    6.937422    7.696311    5.701997\n",
       "115  Hd    7.747508    6.487483    7.319106    7.875821     6.06427\n",
       "116   H   25.291777   24.269609   27.259624   24.313761   22.881435\n",
       "117   A    3.195761     3.83004     4.28471    3.304821    3.269844\n",
       "118   F    8.433399    7.096021     8.22958   11.155994    6.829512\n",
       "119  Hd    9.343595     9.08414   10.062134   11.969229    8.200727\n",
       "120   H   29.047668   29.251232   32.319176    27.09149    26.53953\n",
       "121   A    3.079163    3.478113    2.703693    2.087703    2.854382\n",
       "122   F    8.160327    8.320078    9.885159   11.088625    7.948795\n",
       "123  Hd    9.372631    9.633545   10.671306   11.384894    8.794673\n",
       "124   H   33.131432   34.066902   36.608715   30.329361   29.997042\n",
       "125   A    6.758283    8.754836    6.111102    4.592056    7.212671\n",
       "126   F  344.103638  371.472931  366.870972  373.068787   345.25946\n",
       "127  Hd  344.613953  372.338654   367.68576  373.124146   345.64212\n",
       "128   H  365.583038  393.927063  394.027832  392.164703  365.953125"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last 30 rows of the lens() table: column 0 is the row label, columns 7-11 are five of the token-position columns.\n",
    "out.iloc[-30:,[0, 7, 8, 9, 10, 11]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "decfa7d9-21bb-4e24-95d9-6df0b81de47b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example prompt, tokenized to PyTorch tensors and moved to the GPU.\n",
    "# NOTE(review): `input` shadows the Python builtin, and cells earlier in the notebook already read `prompt`/`input`.\n",
    "prompt = 'My favorite element of the periodic table is platinum'\n",
    "input = tokenizer(prompt, return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ad4e796c-a04b-4892-9f40-f385549a2bbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass that also returns the hidden states. This rebinds `out`, which another cell uses for the lens() table.\n",
    "out = model(**input, output_hidden_states=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fe9834f7-07e3-4caa-9025-c6e5b7bcf62a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "33"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of returned hidden-state tensors (33 in the saved output, one more than the 32 cache entries).\n",
    "len(out.hidden_states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "87fc1e36-6de1-4247-9a7d-18ac606bb1e0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralDecoderLayer(\n",
       "  (self_attn): MistralAttention(\n",
       "    (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "    (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "    (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "    (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "    (rotary_emb): MistralRotaryEmbedding()\n",
       "  )\n",
       "  (mlp): MistralMLP(\n",
       "    (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "    (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "    (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "    (act_fn): SiLU()\n",
       "  )\n",
       "  (input_layernorm): MistralRMSNorm()\n",
       "  (post_attention_layernorm): MistralRMSNorm()\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Module structure of the first decoder layer (the same inspection appears in another cell).\n",
    "model.model.layers[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "ef059c1d-b4ce-42b6-b7cd-9f6a5142d958",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'▁same'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Project hidden_states[0] through lm_head and show the highest-logit token string at position 5.\n",
    "topk(model.lm_head(out.hidden_states[0])[0,5,:]).iloc[0,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "94f626e3-d5bc-492d-8f87-0a2888f81f21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Tensor"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check the hidden states are torch tensors (topk converts those to numpy before ranking).\n",
    "type(out.hidden_states[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88a08972-1edf-4273-b8e0-39856800f64e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
