{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "95f8de05-be82-4ca0-8a8f-8de53241ed03",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers.models.mistral.modeling_mistral import *\n",
    "from transformers.cache_utils import DynamicCache\n",
    "import torch\n",
    "from torch import nn\n",
    "import copy\n",
    "from types import MethodType\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from itertools import islice\n",
    "from tqdm import tqdm\n",
    "import gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "74e9a972-ce73-4797-93e9-57258a0f7eee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ac6596ff8dc40038eb9b0a0b30fb878",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\"mistralai/Mistral-7B-v0.1\").to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02262ef1-c753-4884-af10-d39304bdb6f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-v0.1\")\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cc8d4a8d-4da7-4031-af43-b18dc749b67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(f'/home/XXXXXXXX/msmarco_mistral_test').with_format('torch', device=torch.device('cuda'))\n",
    "loader = DataLoader(dataset, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a481125f-93e8-4f65-ac3e-ff1913597c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LoraLinear(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim, rank, bias):\n",
    "        super().__init__()\n",
    "        self.A = nn.Linear(in_dim, rank, bias=False)\n",
    "        self.B = nn.Linear(rank, out_dim, bias=bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.B(self.A(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b1737320-d658-4061-bb88-1d428548b329",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_lora(model, rank, module_names):\n",
    "    '''\n",
    "    Returns a copy of model Linear switched to LoraLinear modules.\n",
    "    '''\n",
    "    modules = dict(model.named_modules())\n",
    "    for name, module in modules.items():\n",
    "        parent = '.'.join(name.split('.')[:-1])\n",
    "        child = name.split('.')[-1]\n",
    "        if type(module) == nn.Linear and child in module_names:\n",
    "            lora_module = LoraLinear(\n",
    "                module.in_features, module.out_features, rank, module.bias is not None\n",
    "            ).to(model.device)\n",
    "            setattr(modules[parent], child, lora_module)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7a85dbd5-adae-449d-bd3e-bc9d72ff0c29",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_myopic(model, past_key_value):\n",
    "    past_key_value = DynamicCache.from_legacy_cache(past_key_value)\n",
    "    def forward(*args, **kwargs):\n",
    "        # This is very hacky, but otherwise it's hard to provide past_key_values\n",
    "        # to myopic_forward without breaking a lot of MistralModel\n",
    "        nonlocal past_key_value\n",
    "        kwargs.pop('past_key_value')\n",
    "        return myopic_forward(*args, **kwargs, past_key_value=past_key_value)\n",
    "    \n",
    "    for module in model.modules():\n",
    "        if type(module) == MistralAttention:\n",
    "            module.forward = MethodType(forward, module)\n",
    "            module.extra_repr = lambda: 'MYOPIC'\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d5882d46-0aef-4091-9729-4d11913cd5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def myopic_forward_mistral(\n",
    "    self,\n",
    "    hidden_states: torch.Tensor,\n",
    "    attention_mask: Optional[torch.Tensor] = None,\n",
    "    position_ids: Optional[torch.LongTensor] = None,\n",
    "    past_key_value: Optional[Cache] = None,\n",
    "    output_attentions: bool = False,\n",
    "    **kwargs,\n",
    "):\n",
    "    bsz, q_len, _ = hidden_states.size()\n",
    "\n",
    "    query_states = self.q_proj(hidden_states)\n",
    "    key_states = self.k_proj(hidden_states)\n",
    "    value_states = self.v_proj(hidden_states)\n",
    "\n",
    "    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
    "    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
    "    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)\n",
    "    \n",
    "    kv_seq_len = key_states.shape[-2]\n",
    "    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)\n",
    "    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n",
    "\n",
    "    past_key_states = past_key_value.key_cache[self.layer_idx].detach()\n",
    "    past_value_states = past_key_value.value_cache[self.layer_idx].detach()\n",
    "\n",
    "    assert key_states.shape == past_key_states.shape, \\\n",
    "        f'past_key_states is wrong shape: {past_key_states.shape} instead of {key_states.shape}'\n",
    "    assert value_states.shape == past_value_states.shape, \\\n",
    "        f'past_value_states is wrong shape: {past_value_states.shape} instead of {value_states.shape}'\n",
    "\n",
    "    # repeat k/v heads if n_kv_heads < n_heads\n",
    "    key_states = repeat_kv(key_states, self.num_key_value_groups)\n",
    "    value_states = repeat_kv(value_states, self.num_key_value_groups)\n",
    "    past_key_states = repeat_kv(past_key_states, self.num_key_value_groups)\n",
    "    past_value_states = repeat_kv(past_value_states, self.num_key_value_groups)\n",
    "    #print('KEY DIFF', torch.norm(key_states - past_key_states).item())\n",
    "    #print('VAL DIFF', torch.norm(value_states - past_value_states).item())\n",
    "\n",
    "    # query @ past_key on off-diagonal\n",
    "    attn_weights = torch.matmul(query_states, past_key_states.transpose(2, 3))\n",
    "    # query @ key on diagonal\n",
    "    #print('ATTN DIFF', torch.norm(attn_weights.diagonal(dim1=2, dim2=3)-(query_states * key_states).sum(dim=3)).item())\n",
    "    attn_weights.diagonal(dim1=2, dim2=3).copy_((query_states * key_states).sum(dim=3))\n",
    "    attn_weights /= math.sqrt(self.head_dim)\n",
    "\n",
    "    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):\n",
    "        raise ValueError(\n",
    "            f\"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is\"\n",
    "            f\" {attn_weights.size()}\"\n",
    "        )\n",
    "\n",
    "    if attention_mask is not None:\n",
    "        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):\n",
    "            raise ValueError(\n",
    "                f\"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}\"\n",
    "            )\n",
    "\n",
    "        attn_weights = attn_weights + attention_mask\n",
    "\n",
    "    # upcast attention to fp32\n",
    "    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n",
    "    attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)\n",
    "    # attn @ past_value on off-diagonal\n",
    "    attn_output = torch.matmul(attn_weights, past_value_states)\n",
    "    # attn @ value on diagonal\n",
    "    #print('VAL DIFF', torch.norm(attn_weights.diagonal(dim1=2, dim2=3).unsqueeze(dim=3) * (value_states - past_value_states)).item())\n",
    "    attn_output += attn_weights.diagonal(dim1=2, dim2=3).unsqueeze(dim=3) * (value_states - past_value_states)\n",
    "\n",
    "    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):\n",
    "        raise ValueError(\n",
    "            f\"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is\"\n",
    "            f\" {attn_output.size()}\"\n",
    "        )\n",
    "\n",
    "    attn_output = attn_output.transpose(1, 2).contiguous()\n",
    "    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n",
    "\n",
    "    attn_output = self.o_proj(attn_output)\n",
    "\n",
    "    if not output_attentions:\n",
    "        attn_weights = None\n",
    "\n",
    "    return attn_output, attn_weights, past_key_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "c373d858-f740-42c8-9d6b-205878370856",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralForCausalLM(\n",
       "  (model): MistralModel(\n",
       "    (embed_tokens): Embedding(32000, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x MistralDecoderLayer(\n",
       "        (self_attn): MistralAttention(\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): MistralRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): MistralMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): MistralRMSNorm()\n",
       "        (post_attention_layernorm): MistralRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): MistralRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8ed82ba8-895b-4c62-b4ab-b771a9e6c427",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'myopic' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[33], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m myopic\n\u001b[1;32m      2\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m      3\u001b[0m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mempty_cache()\n",
      "\u001b[0;31mNameError\u001b[0m: name 'myopic' is not defined"
     ]
    }
   ],
   "source": [
    "del myopic\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "f09015c9-0f10-461e-b030-9b4067e7b6e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "myopic = copy.deepcopy(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "48b1e66f-ce51-403f-b0a6-caaed9782d83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f517548c910>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "c5a82db9-45c4-4f4e-a650-84e0106f0c29",
   "metadata": {},
   "outputs": [],
   "source": [
    "for name, param in myopic.named_parameters():\n",
    "    #if 'k_proj' in name or 'v_proj' in name:\n",
    "    if 'mlp' in name:\n",
    "        param.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "896d1c2e-dd9b-41bb-adfe-c4a7ee528913",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [02:14<00:00, 13.44s/it]\n"
     ]
    }
   ],
   "source": [
    "num_batches = 10\n",
    "it = islice(iter(loader), num_batches)\n",
    "loss = 0\n",
    "myopic_loss = 0\n",
    "for batch in tqdm(it, total=num_batches):\n",
    "    out = model(\n",
    "        input_ids=batch['input_ids'], \n",
    "        labels=batch['input_ids'], \n",
    "        attention_mask=batch['attention_mask']\n",
    "    )\n",
    "    myopic = to_myopic(myopic, out.past_key_values)\n",
    "    myopic_out = myopic(\n",
    "        input_ids=batch['input_ids'], \n",
    "        labels=batch['input_ids'], \n",
    "        attention_mask=batch['attention_mask']\n",
    "    )\n",
    "    loss += out.loss.item()\n",
    "    myopic_loss += myopic_out.loss.item()\n",
    "loss /= num_batches\n",
    "myopic_loss /= num_batches"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5eb312b-4004-41ba-ab6d-f163d425b6a0",
   "metadata": {},
   "source": [
    "full: 2.247\n",
    "attn only: 7.708\n",
    "mlp only: 6.929\n",
    "mlp + q: 2.309"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "b909d309-a1bb-418f-9afc-625bd1369c03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.246629166603088, 7.708231830596924)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss, myopic_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "id": "6f2b5d62-0f9c-47f0-88eb-7cd662e20832",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralForCausalLM(\n",
       "  (model): MistralModel(\n",
       "    (embed_tokens): Embedding(32000, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x MistralDecoderLayer(\n",
       "        (self_attn): MistralAttention(\n",
       "          MYOPIC\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): MistralRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): MistralMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): MistralRMSNorm()\n",
       "        (post_attention_layernorm): MistralRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): MistralRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 231,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "to_myopic(myopic, out.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "id": "ec6f0667-3916-4046-a5fe-ea8f0c8d410c",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "id": "9979be34-799f-47e3-a653-b13b17020245",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0003, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 233,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(out.logits - myopic_out.logits).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "35cbf72f-c864-4250-9c80-0ecb422a6b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bcb3e17d-1d6f-4e45-81e2-3e0e0597864f",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_model = lora_model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4a10b0d2-47cb-462e-a771-1413c5e6a718",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "951131ee-ed29-46bd-90e4-96c85057caa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "input = tokenizer(\"Hello my name is\", return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "id": "091f882e-c281-41bd-944b-4efa35f6348c",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora = to_lora(myopic, 8, ['k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "id": "2b281b90-980e-489b-be3a-be6c09ed8c81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KEY DIFF 142.73130798339844\n",
      "VAL DIFF 12.946982383728027\n",
      "KEY DIFF 217.15673828125\n",
      "VAL DIFF 18.909976959228516\n",
      "KEY DIFF 186.13214111328125\n",
      "VAL DIFF 67.96466827392578\n",
      "KEY DIFF 188.8489227294922\n",
      "VAL DIFF 53.43036651611328\n",
      "KEY DIFF 174.39666748046875\n",
      "VAL DIFF 78.77042388916016\n",
      "KEY DIFF 186.60130310058594\n",
      "VAL DIFF 68.82740020751953\n",
      "KEY DIFF 188.38514709472656\n",
      "VAL DIFF 74.36092376708984\n",
      "KEY DIFF 195.12734985351562\n",
      "VAL DIFF 70.07445526123047\n",
      "KEY DIFF 199.66671752929688\n",
      "VAL DIFF 72.25618743896484\n",
      "KEY DIFF 182.92108154296875\n",
      "VAL DIFF 97.33354949951172\n",
      "KEY DIFF 212.75222778320312\n",
      "VAL DIFF 94.85765075683594\n",
      "KEY DIFF 205.39918518066406\n",
      "VAL DIFF 87.05265045166016\n",
      "KEY DIFF 209.25807189941406\n",
      "VAL DIFF 113.35973358154297\n",
      "KEY DIFF 208.14002990722656\n",
      "VAL DIFF 117.33596801757812\n",
      "KEY DIFF 206.93878173828125\n",
      "VAL DIFF 114.59712219238281\n",
      "KEY DIFF 204.29835510253906\n",
      "VAL DIFF 153.5397186279297\n",
      "KEY DIFF 214.47869873046875\n",
      "VAL DIFF 134.12245178222656\n",
      "KEY DIFF 222.3302764892578\n",
      "VAL DIFF 126.56486511230469\n",
      "KEY DIFF 250.53147888183594\n",
      "VAL DIFF 146.58984375\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 2.12 MiB is free. Process 1531090 has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 15.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[245], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m lora_out \u001b[38;5;241m=\u001b[39m \u001b[43mlora\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:1053\u001b[0m, in \u001b[0;36mMistralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1050\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m   1052\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1053\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1054\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1055\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1056\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1057\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1058\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1059\u001b[0m \u001b[43m    \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1060\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1061\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1062\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1063\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1065\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1066\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:938\u001b[0m, in \u001b[0;36mMistralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    928\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    929\u001b[0m         decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    930\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    935\u001b[0m         use_cache,\n\u001b[1;32m    936\u001b[0m     )\n\u001b[1;32m    937\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 938\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    939\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    940\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    941\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    942\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    943\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    944\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    945\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    947\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    949\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:676\u001b[0m, in \u001b[0;36mMistralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m    674\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m    675\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 676\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    677\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m    679\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:177\u001b[0m, in \u001b[0;36mMistralMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    176\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m--> 177\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fn(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mup_proj(x))\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[4], line 8\u001b[0m, in \u001b[0;36mLoraLinear.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m----> 8\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mB\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mA\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 2.12 MiB is free. Process 1531090 has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 15.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "lora_out = lora(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "7cf23a1e-e9f3-452b-a4a8-60d7b5ee1a70",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(\n",
    "    **input,\n",
    ")\n",
    "#lora_out = lora_model(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "c0767e0a-c9b9-4df5-841e-4524824b275d",
   "metadata": {},
   "outputs": [],
   "source": [
    "cache = DynamicCache.from_legacy_cache(out.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "f4dff416-b7a6-4b17-8d5f-180048548459",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(cache.key_cache)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "1d25fbcf-b13a-4f33-9949-5ef97675997e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 5, 32000])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lora_out.logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a4c97b57-500e-4233-8a63-f5ce950777d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(model.model.layers[0].self_attn.q_proj) == nn.Linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7f75ba63-f5f5-4928-8a33-99a9c9601aea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__module__',\n",
       " '_get_no_split_modules',\n",
       " '_keep_in_fp32_modules',\n",
       " '_keep_in_fp32_modules',\n",
       " '_modules',\n",
       " '_no_split_modules',\n",
       " 'add_module',\n",
       " 'get_submodule',\n",
       " 'modules',\n",
       " 'named_modules',\n",
       " 'register_module',\n",
       " 'retrieve_modules_from_names']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[x for x in dir(model) if 'module' in x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "70d17cde-4ca2-47a5-9106-63b2ec4beded",
   "metadata": {},
   "outputs": [
    {
     "ename": "SyntaxError",
     "evalue": "cannot assign to function call here. Maybe you meant '==' instead of '='? (3646139102.py, line 1)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Cell \u001b[0;32mIn[15], line 1\u001b[0;36m\u001b[0m\n\u001b[0;31m    model.get_submodule('model.layers.0.self_attn.q_proj') = nn.Linear(5, 5)\u001b[0m\n\u001b[0m    ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m cannot assign to function call here. Maybe you meant '==' instead of '='?\n"
     ]
    }
   ],
   "source": [
    "model.get_submodule('model.layers.0.self_attn.q_proj') = nn.Linear(5, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a2fa96ca-d8f0-401f-8ee7-49f67157681d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'': MistralForCausalLM(\n",
       "   (model): MistralModel(\n",
       "     (embed_tokens): Embedding(32000, 4096)\n",
       "     (layers): ModuleList(\n",
       "       (0-31): 32 x MistralDecoderLayer(\n",
       "         (self_attn): MistralAttention(\n",
       "           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "           (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "           (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "           (rotary_emb): MistralRotaryEmbedding()\n",
       "         )\n",
       "         (mlp): MistralMLP(\n",
       "           (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "           (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "           (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "           (act_fn): SiLU()\n",
       "         )\n",
       "         (input_layernorm): MistralRMSNorm()\n",
       "         (post_attention_layernorm): MistralRMSNorm()\n",
       "       )\n",
       "     )\n",
       "     (norm): MistralRMSNorm()\n",
       "   )\n",
       "   (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       " ),\n",
       " 'model': MistralModel(\n",
       "   (embed_tokens): Embedding(32000, 4096)\n",
       "   (layers): ModuleList(\n",
       "     (0-31): 32 x MistralDecoderLayer(\n",
       "       (self_attn): MistralAttention(\n",
       "         (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "         (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "         (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "         (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "         (rotary_emb): MistralRotaryEmbedding()\n",
       "       )\n",
       "       (mlp): MistralMLP(\n",
       "         (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "         (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "         (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "         (act_fn): SiLU()\n",
       "       )\n",
       "       (input_layernorm): MistralRMSNorm()\n",
       "       (post_attention_layernorm): MistralRMSNorm()\n",
       "     )\n",
       "   )\n",
       "   (norm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.embed_tokens': Embedding(32000, 4096),\n",
       " 'model.layers': ModuleList(\n",
       "   (0-31): 32 x MistralDecoderLayer(\n",
       "     (self_attn): MistralAttention(\n",
       "       (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "       (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "       (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "       (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "       (rotary_emb): MistralRotaryEmbedding()\n",
       "     )\n",
       "     (mlp): MistralMLP(\n",
       "       (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "       (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "       (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "       (act_fn): SiLU()\n",
       "     )\n",
       "     (input_layernorm): MistralRMSNorm()\n",
       "     (post_attention_layernorm): MistralRMSNorm()\n",
       "   )\n",
       " ),\n",
       " 'model.layers.0': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.0.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.0.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.0.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.0.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.0.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.0.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.0.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.0.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.0.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.0.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.0.mlp.act_fn': SiLU(),\n",
       " 'model.layers.0.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.0.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.1': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.1.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.1.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.1.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.1.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.1.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.1.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.1.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.1.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.1.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.1.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.1.mlp.act_fn': SiLU(),\n",
       " 'model.layers.1.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.1.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.2': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.2.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.2.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.2.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.2.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.2.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.2.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.2.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.2.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.2.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.2.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.2.mlp.act_fn': SiLU(),\n",
       " 'model.layers.2.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.2.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.3': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.3.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.3.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.3.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.3.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.3.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.3.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.3.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.3.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.3.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.3.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.3.mlp.act_fn': SiLU(),\n",
       " 'model.layers.3.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.3.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.4': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.4.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.4.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.4.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.4.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.4.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.4.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.4.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.4.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.4.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.4.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.4.mlp.act_fn': SiLU(),\n",
       " 'model.layers.4.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.4.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.5': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.5.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.5.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.5.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.5.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.5.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.5.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.5.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.5.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.5.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.5.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.5.mlp.act_fn': SiLU(),\n",
       " 'model.layers.5.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.5.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.6': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.6.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.6.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.6.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.6.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.6.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.6.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.6.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.6.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.6.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.6.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.6.mlp.act_fn': SiLU(),\n",
       " 'model.layers.6.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.6.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.7': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.7.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.7.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.7.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.7.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.7.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.7.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.7.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.7.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.7.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.7.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.7.mlp.act_fn': SiLU(),\n",
       " 'model.layers.7.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.7.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.8': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.8.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.8.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.8.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.8.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.8.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.8.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.8.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.8.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.8.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.8.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.8.mlp.act_fn': SiLU(),\n",
       " 'model.layers.8.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.8.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.9': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.9.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.9.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.9.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.9.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.9.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.9.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.9.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.9.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.9.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.9.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.9.mlp.act_fn': SiLU(),\n",
       " 'model.layers.9.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.9.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.10': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.10.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.10.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.10.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.10.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.10.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.10.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.10.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.10.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.10.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.10.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.10.mlp.act_fn': SiLU(),\n",
       " 'model.layers.10.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.10.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.11': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.11.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.11.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.11.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.11.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.11.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.11.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.11.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.11.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.11.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.11.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.11.mlp.act_fn': SiLU(),\n",
       " 'model.layers.11.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.11.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.12': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.12.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.12.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.12.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.12.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.12.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.12.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.12.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.12.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.12.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.12.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.12.mlp.act_fn': SiLU(),\n",
       " 'model.layers.12.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.12.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.13': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.13.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.13.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.13.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.13.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.13.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.13.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.13.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.13.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.13.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.13.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.13.mlp.act_fn': SiLU(),\n",
       " 'model.layers.13.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.13.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.14': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.14.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.14.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.14.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.14.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.14.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.14.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.14.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.14.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.14.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.14.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.14.mlp.act_fn': SiLU(),\n",
       " 'model.layers.14.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.14.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.15': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.15.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.15.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.15.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.15.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.15.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.15.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.15.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.15.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.15.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.15.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.15.mlp.act_fn': SiLU(),\n",
       " 'model.layers.15.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.15.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.16': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.16.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.16.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.16.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.16.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.16.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.16.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.16.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.16.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.16.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.16.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.16.mlp.act_fn': SiLU(),\n",
       " 'model.layers.16.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.16.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.17': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.17.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.17.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.17.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.17.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.17.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.17.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.17.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.17.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.17.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.17.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.17.mlp.act_fn': SiLU(),\n",
       " 'model.layers.17.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.17.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.18': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.18.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.18.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.18.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.18.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.18.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.18.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.18.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.18.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.18.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.18.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.18.mlp.act_fn': SiLU(),\n",
       " 'model.layers.18.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.18.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.19': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.19.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.19.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.19.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.19.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.19.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.19.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.19.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.19.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.19.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.19.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.19.mlp.act_fn': SiLU(),\n",
       " 'model.layers.19.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.19.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.20': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.20.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.20.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.20.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.20.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.20.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.20.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.20.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.20.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.20.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.20.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.20.mlp.act_fn': SiLU(),\n",
       " 'model.layers.20.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.20.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.21': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.21.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.21.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.21.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.21.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.21.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.21.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.21.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.21.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.21.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.21.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.21.mlp.act_fn': SiLU(),\n",
       " 'model.layers.21.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.21.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.22': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.22.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.22.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.22.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.22.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.22.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.22.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.22.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.22.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.22.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.22.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.22.mlp.act_fn': SiLU(),\n",
       " 'model.layers.22.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.22.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.23': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.23.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.23.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.23.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.23.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.23.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.23.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.23.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.23.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.23.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.23.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.23.mlp.act_fn': SiLU(),\n",
       " 'model.layers.23.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.23.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.24': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.24.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.24.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.24.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.24.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.24.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.24.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.24.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.24.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.24.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.24.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.24.mlp.act_fn': SiLU(),\n",
       " 'model.layers.24.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.24.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.25': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.25.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.25.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.25.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.25.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.25.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.25.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.25.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.25.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.25.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.25.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.25.mlp.act_fn': SiLU(),\n",
       " 'model.layers.25.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.25.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.26': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.26.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.26.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.26.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.26.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.26.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.26.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.26.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.26.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.26.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.26.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.26.mlp.act_fn': SiLU(),\n",
       " 'model.layers.26.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.26.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.27': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.27.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.27.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.27.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.27.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.27.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.27.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.27.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.27.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.27.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.27.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.27.mlp.act_fn': SiLU(),\n",
       " 'model.layers.27.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.27.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.28': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.28.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.28.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.28.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.28.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.28.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.28.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.28.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.28.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.28.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.28.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.28.mlp.act_fn': SiLU(),\n",
       " 'model.layers.28.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.28.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.29': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.29.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.29.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.29.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.29.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.29.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.29.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.29.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.29.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.29.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.29.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.29.mlp.act_fn': SiLU(),\n",
       " 'model.layers.29.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.29.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.30': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.30.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.30.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.30.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.30.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.30.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.30.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.30.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.30.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.30.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.30.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.30.mlp.act_fn': SiLU(),\n",
       " 'model.layers.30.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.30.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.31': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.31.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.31.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.31.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.31.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.31.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.31.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.31.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.31.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.31.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.31.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.31.mlp.act_fn': SiLU(),\n",
       " 'model.layers.31.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.31.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.norm': MistralRMSNorm(),\n",
       " 'lm_head': Linear(in_features=4096, out_features=32000, bias=False)}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dict(model.named_modules())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "691534a9-af04-40b4-a544-a780969a844f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
