{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "95f8de05-be82-4ca0-8a8f-8de53241ed03",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import torch\n",
    "from torch import nn\n",
    "import copy\n",
    "from types import MethodType\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from itertools import islice\n",
    "from tqdm import tqdm\n",
    "import gc\n",
    "import pandas as pd\n",
    "from transformers.models.gpt2.modeling_gpt2 import GPT2Attention\n",
    "from transformers.pytorch_utils import Conv1D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "488ddb71-fbd9-49d8-8f61-14f70528cff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "74e9a972-ce73-4797-93e9-57258a0f7eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\"gpt2\").to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "02262ef1-c753-4884-af10-d39304bdb6f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"gpt2\")\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cc8d4a8d-4da7-4031-af43-b18dc749b67c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(f'/workspace/corpus/msmarco/msmarco_GPT2_64tokens_1m').with_format('torch', device=torch.device('cuda'))\n",
    "loader = DataLoader(dataset['test'], batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "3b7f0924-bde7-4bef-947b-61d44e8fd006",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GPT2LMHeadModel(\n",
       "  (transformer): GPT2Model(\n",
       "    (wte): Embedding(50257, 768)\n",
       "    (wpe): Embedding(1024, 768)\n",
       "    (drop): Dropout(p=0.1, inplace=False)\n",
       "    (h): ModuleList(\n",
       "      (0-11): 12 x GPT2Block(\n",
       "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (attn): GPT2Attention(\n",
       "          (c_attn): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "        (mlp): GPT2MLP(\n",
       "          (c_fc): Conv1D()\n",
       "          (c_proj): Conv1D()\n",
       "          (act): NewGELUActivation()\n",
       "          (dropout): Dropout(p=0.1, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "  )\n",
       "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "7a85dbd5-adae-449d-bd3e-bc9d72ff0c29",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_myopic_gpt2(model, past_key_values):    \n",
    "    def forward(self, *args, **kwargs):\n",
    "        nonlocal past_key_values\n",
    "        kwargs.pop('layer_past')\n",
    "        return myopic_forward_gpt2(self, *args, **kwargs, layer_past=past_key_values[self.layer_idx])\n",
    "    for name, module in model.named_modules():\n",
    "        #if type(module) == GPT2Attention:  # type doesn't match? idk why\n",
    "        if name.split('.')[-1] == 'attn':\n",
    "            layer_past = past_key_values[module.layer_idx]            \n",
    "            module.forward = MethodType(forward, module)\n",
    "            module.extra_repr = lambda: 'MYOPIC'\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5721fd29-62bd-44ef-8f79-de43d5dc4360",
   "metadata": {},
   "outputs": [],
   "source": [
    "def myopic_attn_gpt2(\n",
    "    query, key, value, past_key, past_value, attention_mask, head_mask,\n",
    "    bias,\n",
    "    attn_dropout,\n",
    "    scale_attn_weights=True,\n",
    "):\n",
    "    #import pdb; pdb.set_trace()\n",
    "    attn_weights = torch.matmul(query, past_key.transpose(-1, -2))\n",
    "    attn_weights.diagonal(dim1=2, dim2=3).copy_((query * key).sum(dim=3))\n",
    "\n",
    "    if scale_attn_weights:\n",
    "        attn_weights = attn_weights / torch.full(\n",
    "            [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device\n",
    "        )\n",
    "\n",
    "    query_length, key_length = query.size(-2), key.size(-2)\n",
    "    causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]\n",
    "    mask_value = torch.finfo(attn_weights.dtype).min\n",
    "    # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.\n",
    "    # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`\n",
    "    mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)\n",
    "    attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)\n",
    "\n",
    "    if attention_mask is not None:\n",
    "        # Apply the attention mask\n",
    "        attn_weights = attn_weights + attention_mask\n",
    "\n",
    "    attn_weights = nn.functional.softmax(attn_weights, dim=-1)\n",
    "\n",
    "    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise\n",
    "    attn_weights = attn_weights.type(value.dtype)\n",
    "    attn_weights = attn_dropout(attn_weights)\n",
    "\n",
    "    # Mask heads if we want to\n",
    "    if head_mask is not None:\n",
    "        attn_weights = attn_weights * head_mask\n",
    "\n",
    "    attn_output = torch.matmul(attn_weights, past_value)\n",
    "    attn_output += attn_weights.diagonal(dim1=2, dim2=3).unsqueeze(dim=3) * (value - past_value)\n",
    "\n",
    "    return attn_output, attn_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "21002807-0ef4-4382-bd3c-09ee6cea36ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def myopic_forward_gpt2(\n",
    "    self,\n",
    "    hidden_states,\n",
    "    layer_past,\n",
    "    attention_mask=None,\n",
    "    head_mask=None,\n",
    "    output_attentions=False,\n",
    "    **kwargs,\n",
    "):\n",
    "    assert kwargs.get('encoder_hidden_states') is None, 'Only decoder is supported'\n",
    "    assert layer_past is not None, 'layer_past must be provided'\n",
    "    \n",
    "    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)\n",
    "    query = self._split_heads(query, self.num_heads, self.head_dim)\n",
    "    key = self._split_heads(key, self.num_heads, self.head_dim)\n",
    "    value = self._split_heads(value, self.num_heads, self.head_dim)\n",
    "\n",
    "    past_key, past_value = layer_past\n",
    "    past_key, past_value = past_key.detach(), past_value.detach()\n",
    "    #import pdb; pdb.set_trace()\n",
    "    present = (key, value)\n",
    "\n",
    "    assert not self.reorder_and_upcast_attn, 'Not supported!'\n",
    "    assert not self.is_cross_attention, 'Not supported!'\n",
    "    assert not self.scale_attn_by_inverse_layer_idx, 'Not supported!'\n",
    "    attn_output, attn_weights = myopic_attn_gpt2(\n",
    "        query, key, value, past_key, past_value, attention_mask, head_mask,\n",
    "        self.bias, self.attn_dropout,\n",
    "        scale_attn_weights=self.scale_attn_weights,\n",
    "    )\n",
    "\n",
    "    attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)\n",
    "    attn_output = self.c_proj(attn_output)\n",
    "    attn_output = self.resid_dropout(attn_output)\n",
    "\n",
    "    outputs = (attn_output, present)\n",
    "    if output_attentions:\n",
    "        outputs += (attn_weights,)\n",
    "\n",
    "    return outputs  # a, present, (attentions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "68e033dc-e514-43cf-aa60-737b634df4c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk(v, k=40, aux=None):\n",
    "    # Takes in logits\n",
    "    #v = softmax(v.flatten())\n",
    "    if type(v) == torch.Tensor:\n",
    "        v = v.detach().cpu().numpy()\n",
    "    v = v.flatten()\n",
    "    idxs = v.argsort()[-k:][::-1]\n",
    "    if aux:\n",
    "        ret = [(Token[i], v[i]) + tuple(aux[i]) for i in idxs]\n",
    "        return pd.DataFrame(ret, columns=['token', 'logit'] + list(range(len(aux[0]))))\n",
    "    else:\n",
    "        ret = [(Token[i], v[i]) for i in idxs]\n",
    "        return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c373d858-f740-42c8-9d6b-205878370856",
   "metadata": {},
   "outputs": [],
   "source": [
    "input = tokenizer('My favorite element of the periodic table is', return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4e4dcea4-e871-48c8-bc88-ed032f30b998",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "44dfc9cf-5bd9-409e-a9ea-6e18735c5557",
   "metadata": {},
   "outputs": [],
   "source": [
    "myopic = to_myopic_gpt2(copy.deepcopy(model), out.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c6bc89b7-4b4e-443f-bba0-05cf15439d86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1.0961,  1.8475,  0.8989, -0.1387,  0.9979], device='cuda:0',\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.past_key_values[0][0][0, 0, 0, :5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d8af3e66-73dd-4de6-b090-eb972b250fd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "myopic_out = myopic(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "cd85736e-097a-4e47-9ebd-3af8be25472b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0073, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(out.logits - myopic_out.logits).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "753646e4-d48d-433a-93a1-63f6afb774f3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-33.0735, -32.3349, -35.2379, -34.7751, -33.8666, -34.4521, -33.0241,\n",
       "        -33.5888, -32.0457, -34.4160], device='cuda:0',\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.logits[0,0,:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "a19ce5fe-23c5-4aad-83e5-bc00d944238a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-33.0735, -32.3349, -35.2379, -34.7751, -33.8666, -34.4521, -33.0241,\n",
       "        -33.5888, -32.0457, -34.4160], device='cuda:0',\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "myopic_out.logits[0,0,:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "190f83f0-348f-4215-8c11-95e554d48176",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(14419.7354, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(myopic_out.logits-out.logits).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "6b673d0c-9e61-4576-9f02-6d75d9aa9a78",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 12, 9, 64])"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.past_key_values[0][1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "8ed82ba8-895b-4c62-b4ab-b771a9e6c427",
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'myopic' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[33], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m myopic\n\u001b[1;32m      2\u001b[0m gc\u001b[38;5;241m.\u001b[39mcollect()\n\u001b[1;32m      3\u001b[0m torch\u001b[38;5;241m.\u001b[39mcuda\u001b[38;5;241m.\u001b[39mempty_cache()\n",
      "\u001b[0;31mNameError\u001b[0m: name 'myopic' is not defined"
     ]
    }
   ],
   "source": [
    "# Release the patched copy if one exists. On a fresh kernel (or after a\n",
    "# previous run of this cell) the name is unbound, which is fine.\n",
    "try:\n",
    "    del myopic\n",
    "except NameError:\n",
    "    pass\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "c1fc7ecd-277f-4dbc-b314-2114b83f7f2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "768"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.transformer.h[0].mlp.c_fc.weight.shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "f09015c9-0f10-461e-b030-9b4067e7b6e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "myopic = copy.deepcopy(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "48b1e66f-ce51-403f-b0a6-caaed9782d83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f517548c910>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "c5a82db9-45c4-4f4e-a650-84e0106f0c29",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation: zero every parameter whose name contains 'mlp'.\n",
    "# The in-place write runs under no_grad so this cell does not depend on grad\n",
    "# having been disabled globally by another cell.\n",
    "with torch.no_grad():\n",
    "    for name, param in myopic.named_parameters():\n",
    "        #if 'k_proj' in name or 'v_proj' in name:\n",
    "        if 'mlp' in name:\n",
    "            param.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "896d1c2e-dd9b-41bb-adfe-c4a7ee528913",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [02:14<00:00, 13.44s/it]\n"
     ]
    }
   ],
   "source": [
    "# Compare the reference model's loss with the patched copy's loss on the\n",
    "# first `num_batches` batches of the loader.\n",
    "num_batches = 10\n",
    "it = islice(iter(loader), num_batches)\n",
    "loss = 0\n",
    "myopic_loss = 0\n",
    "batches_seen = 0\n",
    "for batch in tqdm(it, total=num_batches):\n",
    "    out = model(\n",
    "        input_ids=batch['input_ids'], \n",
    "        labels=batch['input_ids'], \n",
    "        attention_mask=batch['attention_mask']\n",
    "    )\n",
    "    # Re-patch with this batch's reference cache, using the GPT-2 patcher\n",
    "    # defined in this notebook.\n",
    "    myopic = to_myopic_gpt2(myopic, out.past_key_values)\n",
    "    myopic_out = myopic(\n",
    "        input_ids=batch['input_ids'], \n",
    "        labels=batch['input_ids'], \n",
    "        attention_mask=batch['attention_mask']\n",
    "    )\n",
    "    loss += out.loss.item()\n",
    "    myopic_loss += myopic_out.loss.item()\n",
    "    batches_seen += 1\n",
    "# Average over the batches actually consumed; the loader may yield fewer\n",
    "# than `num_batches`.\n",
    "loss /= max(batches_seen, 1)\n",
    "myopic_loss /= max(batches_seen, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5eb312b-4004-41ba-ab6d-f163d425b6a0",
   "metadata": {},
   "source": [
    "Mean loss per configuration:\n",
    "\n",
    "- full: 2.247\n",
    "- attn only: 7.708\n",
    "- mlp only: 6.929\n",
    "- mlp + q: 2.309"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "b909d309-a1bb-418f-9afc-625bd1369c03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2.246629166603088, 7.708231830596924)"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss, myopic_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "id": "6f2b5d62-0f9c-47f0-88eb-7cd662e20832",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MistralForCausalLM(\n",
       "  (model): MistralModel(\n",
       "    (embed_tokens): Embedding(32000, 4096)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x MistralDecoderLayer(\n",
       "        (self_attn): MistralAttention(\n",
       "          MYOPIC\n",
       "          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "          (rotary_emb): MistralRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): MistralMLP(\n",
       "          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "          (act_fn): SiLU()\n",
       "        )\n",
       "        (input_layernorm): MistralRMSNorm()\n",
       "        (post_attention_layernorm): MistralRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): MistralRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 231,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "to_myopic(myopic, out.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 162,
   "id": "ec6f0667-3916-4046-a5fe-ea8f0c8d410c",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 233,
   "id": "9979be34-799f-47e3-a653-b13b17020245",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0003, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)"
      ]
     },
     "execution_count": 233,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(out.logits - myopic_out.logits).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "35cbf72f-c864-4250-9c80-0ecb422a6b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "bcb3e17d-1d6f-4e45-81e2-3e0e0597864f",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora_model = lora_model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4a10b0d2-47cb-462e-a771-1413c5e6a718",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "951131ee-ed29-46bd-90e4-96c85057caa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "input = tokenizer(\"Hello my name is\", return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 240,
   "id": "091f882e-c281-41bd-944b-4efa35f6348c",
   "metadata": {},
   "outputs": [],
   "source": [
    "lora = to_lora(myopic, 8, ['k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 245,
   "id": "2b281b90-980e-489b-be3a-be6c09ed8c81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KEY DIFF 142.73130798339844\n",
      "VAL DIFF 12.946982383728027\n",
      "KEY DIFF 217.15673828125\n",
      "VAL DIFF 18.909976959228516\n",
      "KEY DIFF 186.13214111328125\n",
      "VAL DIFF 67.96466827392578\n",
      "KEY DIFF 188.8489227294922\n",
      "VAL DIFF 53.43036651611328\n",
      "KEY DIFF 174.39666748046875\n",
      "VAL DIFF 78.77042388916016\n",
      "KEY DIFF 186.60130310058594\n",
      "VAL DIFF 68.82740020751953\n",
      "KEY DIFF 188.38514709472656\n",
      "VAL DIFF 74.36092376708984\n",
      "KEY DIFF 195.12734985351562\n",
      "VAL DIFF 70.07445526123047\n",
      "KEY DIFF 199.66671752929688\n",
      "VAL DIFF 72.25618743896484\n",
      "KEY DIFF 182.92108154296875\n",
      "VAL DIFF 97.33354949951172\n",
      "KEY DIFF 212.75222778320312\n",
      "VAL DIFF 94.85765075683594\n",
      "KEY DIFF 205.39918518066406\n",
      "VAL DIFF 87.05265045166016\n",
      "KEY DIFF 209.25807189941406\n",
      "VAL DIFF 113.35973358154297\n",
      "KEY DIFF 208.14002990722656\n",
      "VAL DIFF 117.33596801757812\n",
      "KEY DIFF 206.93878173828125\n",
      "VAL DIFF 114.59712219238281\n",
      "KEY DIFF 204.29835510253906\n",
      "VAL DIFF 153.5397186279297\n",
      "KEY DIFF 214.47869873046875\n",
      "VAL DIFF 134.12245178222656\n",
      "KEY DIFF 222.3302764892578\n",
      "VAL DIFF 126.56486511230469\n",
      "KEY DIFF 250.53147888183594\n",
      "VAL DIFF 146.58984375\n"
     ]
    },
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 2.12 MiB is free. Process 1531090 has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 15.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[245], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m lora_out \u001b[38;5;241m=\u001b[39m \u001b[43mlora\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:1053\u001b[0m, in \u001b[0;36mMistralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m   1050\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m   1052\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1053\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1054\u001b[0m \u001b[43m    \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1055\u001b[0m \u001b[43m    \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1056\u001b[0m \u001b[43m    \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1057\u001b[0m \u001b[43m    \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1058\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1059\u001b[0m \u001b[43m    
\u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1060\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1061\u001b[0m \u001b[43m    \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1062\u001b[0m \u001b[43m    \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1063\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1065\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m   1066\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:938\u001b[0m, in \u001b[0;36mMistralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m    928\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m    929\u001b[0m         decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m    930\u001b[0m         hidden_states,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    935\u001b[0m         use_cache,\n\u001b[1;32m    936\u001b[0m     )\n\u001b[1;32m    937\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 938\u001b[0m     layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    939\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    940\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    941\u001b[0m \u001b[43m        \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    942\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    943\u001b[0m \u001b[43m        \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    944\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    
945\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    947\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m    949\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:676\u001b[0m, in \u001b[0;36mMistralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m    674\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m    675\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 676\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    677\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m    679\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py:177\u001b[0m, in \u001b[0;36mMistralMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    176\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m--> 177\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fn(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mup_proj(x))\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "Cell \u001b[0;32mIn[4], line 8\u001b[0m, in \u001b[0;36mLoraLinear.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m----> 8\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mB\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mA\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m   1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m   1525\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1526\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m   1530\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:114\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    113\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 114\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacty of 79.15 GiB of which 2.12 MiB is free. Process 1531090 has 79.13 GiB memory in use. Of the allocated memory 78.63 GiB is allocated by PyTorch, and 15.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "lora_out = lora(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "7cf23a1e-e9f3-452b-a4a8-60d7b5ee1a70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspection-only forward pass: run it under no_grad so that no autograd graph\n",
    "# is kept alive behind the logits and the returned past_key_values.\n",
    "with torch.no_grad():\n",
    "    out = model(**input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "c0767e0a-c9b9-4df5-841e-4524824b275d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the past_key_values returned by the forward pass in a DynamicCache.\n",
    "# NOTE(review): DynamicCache is not in the import cell at the top of the notebook;\n",
    "# confirm it is imported before this cell runs.\n",
    "cache = DynamicCache.from_legacy_cache(past_key_values=out.past_key_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "f4dff416-b7a6-4b17-8d5f-180048548459",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries held by the cache, asked of the cache object itself;\n",
    "# presumably one per decoder layer (32 here, and the model has 32 layers).\n",
    "len(cache)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "1d25fbcf-b13a-4f33-9949-5ef97675997e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 5, 32000])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Dimensions of the LoRA model's logits; presumably (batch, tokens, vocab),\n",
    "# since 32000 equals the size of the embedding table -- TODO confirm.\n",
    "lora_out.logits.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a4c97b57-500e-4233-8a63-f5ce950777d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exact-type check on the first layer's query projection: a subclass of\n",
    "# nn.Linear would give False here, unlike isinstance().\n",
    "type(model.get_submodule('model.layers.0.self_attn.q_proj')) is nn.Linear"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7f75ba63-f5f5-4928-8a33-99a9c9601aea",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['__module__',\n",
       " '_get_no_split_modules',\n",
       " '_keep_in_fp32_modules',\n",
       " '_keep_in_fp32_modules',\n",
       " '_modules',\n",
       " '_no_split_modules',\n",
       " 'add_module',\n",
       " 'get_submodule',\n",
       " 'modules',\n",
       " 'named_modules',\n",
       " 'register_module',\n",
       " 'retrieve_modules_from_names']"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# List every attribute of the model whose name mentions 'module'\n",
    "list(filter(lambda attr_name: 'module' in attr_name, dir(model)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "70d17cde-4ca2-47a5-9106-63b2ec4beded",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get_submodule() returns the module object, and a call expression cannot be an\n",
    "# assignment target (the original line was a SyntaxError). To swap a submodule,\n",
    "# rebind the attribute on its parent module; nn.Module registers the new child.\n",
    "# NOTE(review): Linear(5, 5) is a placeholder that does not match the 4096-wide\n",
    "# q_proj it replaces -- use matching in/out features before running a forward pass.\n",
    "model.get_submodule('model.layers.0.self_attn').q_proj = nn.Linear(5, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a2fa96ca-d8f0-401f-8ee7-49f67157681d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'': MistralForCausalLM(\n",
       "   (model): MistralModel(\n",
       "     (embed_tokens): Embedding(32000, 4096)\n",
       "     (layers): ModuleList(\n",
       "       (0-31): 32 x MistralDecoderLayer(\n",
       "         (self_attn): MistralAttention(\n",
       "           (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "           (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "           (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "           (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "           (rotary_emb): MistralRotaryEmbedding()\n",
       "         )\n",
       "         (mlp): MistralMLP(\n",
       "           (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "           (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "           (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "           (act_fn): SiLU()\n",
       "         )\n",
       "         (input_layernorm): MistralRMSNorm()\n",
       "         (post_attention_layernorm): MistralRMSNorm()\n",
       "       )\n",
       "     )\n",
       "     (norm): MistralRMSNorm()\n",
       "   )\n",
       "   (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       " ),\n",
       " 'model': MistralModel(\n",
       "   (embed_tokens): Embedding(32000, 4096)\n",
       "   (layers): ModuleList(\n",
       "     (0-31): 32 x MistralDecoderLayer(\n",
       "       (self_attn): MistralAttention(\n",
       "         (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "         (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "         (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "         (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "         (rotary_emb): MistralRotaryEmbedding()\n",
       "       )\n",
       "       (mlp): MistralMLP(\n",
       "         (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "         (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "         (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "         (act_fn): SiLU()\n",
       "       )\n",
       "       (input_layernorm): MistralRMSNorm()\n",
       "       (post_attention_layernorm): MistralRMSNorm()\n",
       "     )\n",
       "   )\n",
       "   (norm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.embed_tokens': Embedding(32000, 4096),\n",
       " 'model.layers': ModuleList(\n",
       "   (0-31): 32 x MistralDecoderLayer(\n",
       "     (self_attn): MistralAttention(\n",
       "       (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "       (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "       (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "       (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "       (rotary_emb): MistralRotaryEmbedding()\n",
       "     )\n",
       "     (mlp): MistralMLP(\n",
       "       (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "       (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "       (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "       (act_fn): SiLU()\n",
       "     )\n",
       "     (input_layernorm): MistralRMSNorm()\n",
       "     (post_attention_layernorm): MistralRMSNorm()\n",
       "   )\n",
       " ),\n",
       " 'model.layers.0': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.0.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.0.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.0.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.0.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.0.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.0.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.0.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.0.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.0.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.0.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.0.mlp.act_fn': SiLU(),\n",
       " 'model.layers.0.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.0.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.1': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.1.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.1.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.1.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.1.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.1.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.1.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.1.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.1.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.1.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.1.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.1.mlp.act_fn': SiLU(),\n",
       " 'model.layers.1.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.1.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.2': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.2.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.2.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.2.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.2.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.2.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.2.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.2.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.2.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.2.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.2.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.2.mlp.act_fn': SiLU(),\n",
       " 'model.layers.2.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.2.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.3': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.3.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.3.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.3.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.3.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.3.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.3.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.3.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.3.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.3.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.3.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.3.mlp.act_fn': SiLU(),\n",
       " 'model.layers.3.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.3.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.4': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.4.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.4.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.4.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.4.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.4.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.4.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.4.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.4.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.4.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.4.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.4.mlp.act_fn': SiLU(),\n",
       " 'model.layers.4.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.4.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.5': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.5.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.5.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.5.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.5.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.5.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.5.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.5.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.5.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.5.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.5.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.5.mlp.act_fn': SiLU(),\n",
       " 'model.layers.5.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.5.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.6': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.6.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.6.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.6.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.6.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.6.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.6.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.6.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.6.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.6.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.6.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.6.mlp.act_fn': SiLU(),\n",
       " 'model.layers.6.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.6.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.7': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.7.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.7.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.7.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.7.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.7.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.7.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.7.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.7.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.7.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.7.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.7.mlp.act_fn': SiLU(),\n",
       " 'model.layers.7.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.7.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.8': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.8.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.8.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.8.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.8.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.8.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.8.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.8.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.8.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.8.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.8.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.8.mlp.act_fn': SiLU(),\n",
       " 'model.layers.8.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.8.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.9': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.9.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.9.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.9.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.9.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.9.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.9.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.9.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.9.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.9.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.9.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.9.mlp.act_fn': SiLU(),\n",
       " 'model.layers.9.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.9.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.10': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.10.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.10.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.10.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.10.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.10.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.10.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.10.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.10.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.10.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.10.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.10.mlp.act_fn': SiLU(),\n",
       " 'model.layers.10.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.10.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.11': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.11.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.11.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.11.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.11.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.11.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.11.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.11.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.11.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.11.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.11.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.11.mlp.act_fn': SiLU(),\n",
       " 'model.layers.11.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.11.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.12': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.12.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.12.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.12.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.12.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.12.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.12.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.12.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.12.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.12.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.12.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.12.mlp.act_fn': SiLU(),\n",
       " 'model.layers.12.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.12.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.13': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.13.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.13.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.13.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.13.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.13.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.13.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.13.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.13.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.13.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.13.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.13.mlp.act_fn': SiLU(),\n",
       " 'model.layers.13.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.13.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.14': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.14.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.14.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.14.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.14.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.14.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.14.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.14.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.14.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.14.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.14.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.14.mlp.act_fn': SiLU(),\n",
       " 'model.layers.14.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.14.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.15': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.15.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.15.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.15.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.15.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.15.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.15.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.15.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.15.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.15.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.15.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.15.mlp.act_fn': SiLU(),\n",
       " 'model.layers.15.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.15.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.16': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.16.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.16.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.16.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.16.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.16.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.16.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.16.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.16.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.16.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.16.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.16.mlp.act_fn': SiLU(),\n",
       " 'model.layers.16.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.16.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.17': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.17.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.17.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.17.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.17.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.17.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.17.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.17.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.17.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.17.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.17.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.17.mlp.act_fn': SiLU(),\n",
       " 'model.layers.17.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.17.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.18': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.18.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.18.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.18.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.18.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.18.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.18.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.18.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.18.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.18.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.18.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.18.mlp.act_fn': SiLU(),\n",
       " 'model.layers.18.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.18.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.19': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.19.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.19.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.19.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.19.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.19.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.19.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.19.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.19.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.19.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.19.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.19.mlp.act_fn': SiLU(),\n",
       " 'model.layers.19.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.19.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.20': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.20.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.20.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.20.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.20.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.20.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.20.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.20.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.20.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.20.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.20.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.20.mlp.act_fn': SiLU(),\n",
       " 'model.layers.20.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.20.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.21': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.21.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.21.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.21.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.21.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.21.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.21.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.21.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.21.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.21.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.21.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.21.mlp.act_fn': SiLU(),\n",
       " 'model.layers.21.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.21.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.22': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.22.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.22.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.22.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.22.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.22.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.22.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.22.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.22.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.22.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.22.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.22.mlp.act_fn': SiLU(),\n",
       " 'model.layers.22.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.22.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.23': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.23.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.23.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.23.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.23.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.23.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.23.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.23.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.23.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.23.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.23.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.23.mlp.act_fn': SiLU(),\n",
       " 'model.layers.23.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.23.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.24': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.24.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.24.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.24.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.24.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.24.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.24.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.24.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.24.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.24.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.24.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.24.mlp.act_fn': SiLU(),\n",
       " 'model.layers.24.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.24.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.25': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.25.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.25.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.25.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.25.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.25.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.25.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.25.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.25.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.25.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.25.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.25.mlp.act_fn': SiLU(),\n",
       " 'model.layers.25.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.25.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.26': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.26.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.26.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.26.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.26.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.26.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.26.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.26.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.26.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.26.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.26.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.26.mlp.act_fn': SiLU(),\n",
       " 'model.layers.26.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.26.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.27': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.27.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.27.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.27.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.27.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.27.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.27.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.27.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.27.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.27.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.27.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.27.mlp.act_fn': SiLU(),\n",
       " 'model.layers.27.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.27.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.28': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.28.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.28.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.28.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.28.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.28.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.28.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.28.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.28.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.28.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.28.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.28.mlp.act_fn': SiLU(),\n",
       " 'model.layers.28.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.28.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.29': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.29.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.29.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.29.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.29.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.29.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.29.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.29.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.29.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.29.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.29.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.29.mlp.act_fn': SiLU(),\n",
       " 'model.layers.29.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.29.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.30': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.30.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.30.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.30.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.30.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.30.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.30.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.30.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.30.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.30.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.30.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.30.mlp.act_fn': SiLU(),\n",
       " 'model.layers.30.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.30.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.31': MistralDecoderLayer(\n",
       "   (self_attn): MistralAttention(\n",
       "     (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "     (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "     (rotary_emb): MistralRotaryEmbedding()\n",
       "   )\n",
       "   (mlp): MistralMLP(\n",
       "     (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "     (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "     (act_fn): SiLU()\n",
       "   )\n",
       "   (input_layernorm): MistralRMSNorm()\n",
       "   (post_attention_layernorm): MistralRMSNorm()\n",
       " ),\n",
       " 'model.layers.31.self_attn': MistralAttention(\n",
       "   (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
       "   (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
       "   (rotary_emb): MistralRotaryEmbedding()\n",
       " ),\n",
       " 'model.layers.31.self_attn.q_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.31.self_attn.k_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.31.self_attn.v_proj': Linear(in_features=4096, out_features=1024, bias=False),\n",
       " 'model.layers.31.self_attn.o_proj': Linear(in_features=4096, out_features=4096, bias=False),\n",
       " 'model.layers.31.self_attn.rotary_emb': MistralRotaryEmbedding(),\n",
       " 'model.layers.31.mlp': MistralMLP(\n",
       "   (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
       "   (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
       "   (act_fn): SiLU()\n",
       " ),\n",
       " 'model.layers.31.mlp.gate_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.31.mlp.up_proj': Linear(in_features=4096, out_features=14336, bias=False),\n",
       " 'model.layers.31.mlp.down_proj': Linear(in_features=14336, out_features=4096, bias=False),\n",
       " 'model.layers.31.mlp.act_fn': SiLU(),\n",
       " 'model.layers.31.input_layernorm': MistralRMSNorm(),\n",
       " 'model.layers.31.post_attention_layernorm': MistralRMSNorm(),\n",
       " 'model.norm': MistralRMSNorm(),\n",
       " 'lm_head': Linear(in_features=4096, out_features=32000, bias=False)}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index every submodule by its dotted name, e.g. 'model.layers.26.self_attn'.\n",
    "{name: module for name, module in model.named_modules()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "691534a9-af04-40b4-a544-a780969a844f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
