{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a24bb258-dcb3-41d8-9d59-dc8ab595cdaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers.models.mistral.modeling_mistral import *\n",
    "from transformers.cache_utils import DynamicCache\n",
    "import torch\n",
    "from torch import nn\n",
    "import copy\n",
    "from types import MethodType\n",
    "import pandas as pd\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from itertools import islice\n",
    "from tqdm import tqdm\n",
    "import gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "92444460-f13b-4482-b92b-2b6b190791c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18c51bc6a9e34a2c9f5316f091c4d3e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[13], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mAutoModelForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmistralai/Mistral-7B-v0.1\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m      2\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m AutoTokenizer\u001b[38;5;241m.\u001b[39mfrom_pretrained(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmistralai/Mistral-7B-v0.1\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      3\u001b[0m Token \u001b[38;5;241m=\u001b[39m {v: k \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m tokenizer\u001b[38;5;241m.\u001b[39mget_vocab()\u001b[38;5;241m.\u001b[39mitems()}\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:566\u001b[0m, in \u001b[0;36m_BaseAutoModelClass.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m    564\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(config) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m    565\u001b[0m     model_class \u001b[38;5;241m=\u001b[39m _get_model_class(config, \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping)\n\u001b[0;32m--> 566\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_class\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    567\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m    568\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    569\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m    570\u001b[0m     \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnrecognized configuration class \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mconfig\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m for this kind of AutoModel: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    571\u001b[0m     \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mModel type should be one of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(c\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mc\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_model_mapping\u001b[38;5;241m.\u001b[39mkeys())\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    572\u001b[0m )\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:3706\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m   3697\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m dtype_orig \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   3698\u001b[0m         torch\u001b[38;5;241m.\u001b[39mset_default_dtype(dtype_orig)\n\u001b[1;32m   3699\u001b[0m     (\n\u001b[1;32m   3700\u001b[0m         model,\n\u001b[1;32m   3701\u001b[0m         missing_keys,\n\u001b[1;32m   3702\u001b[0m         unexpected_keys,\n\u001b[1;32m   3703\u001b[0m         mismatched_keys,\n\u001b[1;32m   3704\u001b[0m         offload_index,\n\u001b[1;32m   3705\u001b[0m         error_msgs,\n\u001b[0;32m-> 3706\u001b[0m     ) \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mcls\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_load_pretrained_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3707\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3708\u001b[0m \u001b[43m        \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3709\u001b[0m \u001b[43m        \u001b[49m\u001b[43mloaded_state_dict_keys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# XXX: rename?\u001b[39;49;00m\n\u001b[1;32m   3710\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresolved_archive_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3711\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3712\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_mismatched_sizes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_mismatched_sizes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3713\u001b[0m \u001b[43m        \u001b[49m\u001b[43msharded_metadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharded_metadata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3714\u001b[0m \u001b[43m        \u001b[49m\u001b[43m_fast_init\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_fast_init\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3715\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlow_cpu_mem_usage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlow_cpu_mem_usage\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3716\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdevice_map\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice_map\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3717\u001b[0m \u001b[43m        \u001b[49m\u001b[43moffload_folder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffload_folder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3718\u001b[0m \u001b[43m        \u001b[49m\u001b[43moffload_state_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffload_state_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3719\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch_dtype\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3720\u001b[0m \u001b[43m        \u001b[49m\u001b[43mis_quantized\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mgetattr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mquantization_method\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mQuantizationMethod\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mBITS_AND_BYTES\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3721\u001b[0m \u001b[43m        \u001b[49m\u001b[43mkeep_in_fp32_modules\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_fp32_modules\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3722\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3724\u001b[0m model\u001b[38;5;241m.\u001b[39mis_loaded_in_4bit \u001b[38;5;241m=\u001b[39m load_in_4bit\n\u001b[1;32m   3725\u001b[0m model\u001b[38;5;241m.\u001b[39mis_loaded_in_8bit \u001b[38;5;241m=\u001b[39m load_in_8bit\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:4134\u001b[0m, in \u001b[0;36mPreTrainedModel._load_pretrained_model\u001b[0;34m(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, is_quantized, keep_in_fp32_modules)\u001b[0m\n\u001b[1;32m   4132\u001b[0m         error_msgs \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m new_error_msgs\n\u001b[1;32m   4133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 4134\u001b[0m     error_msgs \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[43m_load_state_dict_into_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_to_load\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstart_prefix\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   4136\u001b[0m \u001b[38;5;66;03m# force memory release\u001b[39;00m\n\u001b[1;32m   4137\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m state_dict\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:606\u001b[0m, in \u001b[0;36m_load_state_dict_into_model\u001b[0;34m(model_to_load, state_dict, start_prefix)\u001b[0m\n\u001b[1;32m    603\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m child \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    604\u001b[0m             load(child, state_dict, prefix \u001b[38;5;241m+\u001b[39m name \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m--> 606\u001b[0m \u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel_to_load\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprefix\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstart_prefix\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    607\u001b[0m \u001b[38;5;66;03m# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so\u001b[39;00m\n\u001b[1;32m    608\u001b[0m \u001b[38;5;66;03m# it's safe to delete it.\u001b[39;00m\n\u001b[1;32m    609\u001b[0m \u001b[38;5;28;01mdel\u001b[39;00m state_dict\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:604\u001b[0m, in \u001b[0;36m_load_state_dict_into_model.<locals>.load\u001b[0;34m(module, state_dict, prefix)\u001b[0m\n\u001b[1;32m    602\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, child \u001b[38;5;129;01min\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_modules\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m    603\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m child \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 604\u001b[0m         \u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mchild\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprefix\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:604\u001b[0m, in \u001b[0;36m_load_state_dict_into_model.<locals>.load\u001b[0;34m(module, state_dict, prefix)\u001b[0m\n\u001b[1;32m    602\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, child \u001b[38;5;129;01min\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_modules\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m    603\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m child \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 604\u001b[0m         \u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mchild\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprefix\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "    \u001b[0;31m[... skipping similar frames: _load_state_dict_into_model.<locals>.load at line 604 (2 times)]\u001b[0m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:604\u001b[0m, in \u001b[0;36m_load_state_dict_into_model.<locals>.load\u001b[0;34m(module, state_dict, prefix)\u001b[0m\n\u001b[1;32m    602\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m name, child \u001b[38;5;129;01min\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_modules\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m    603\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m child \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 604\u001b[0m         \u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mchild\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate_dict\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprefix\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m.\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:584\u001b[0m, in \u001b[0;36m_load_state_dict_into_model.<locals>.load\u001b[0;34m(module, state_dict, prefix)\u001b[0m\n\u001b[1;32m    581\u001b[0m args \u001b[38;5;241m=\u001b[39m (state_dict, prefix, local_metadata, \u001b[38;5;28;01mTrue\u001b[39;00m, [], [], error_msgs)\n\u001b[1;32m    582\u001b[0m \u001b[38;5;66;03m# Parameters of module and children will start with prefix. We can exit early if there are none in this\u001b[39;00m\n\u001b[1;32m    583\u001b[0m \u001b[38;5;66;03m# state_dict\u001b[39;00m\n\u001b[0;32m--> 584\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m([key \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m state_dict \u001b[38;5;28;01mif\u001b[39;00m key\u001b[38;5;241m.\u001b[39mstartswith(prefix)]) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m    585\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_deepspeed_zero3_enabled():\n\u001b[1;32m    586\u001b[0m         \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mdeepspeed\u001b[39;00m\n",
      "File \u001b[0;32m~/.local/lib/python3.10/site-packages/transformers/modeling_utils.py:584\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    581\u001b[0m args \u001b[38;5;241m=\u001b[39m (state_dict, prefix, local_metadata, \u001b[38;5;28;01mTrue\u001b[39;00m, [], [], error_msgs)\n\u001b[1;32m    582\u001b[0m \u001b[38;5;66;03m# Parameters of module and children will start with prefix. We can exit early if there are none in this\u001b[39;00m\n\u001b[1;32m    583\u001b[0m \u001b[38;5;66;03m# state_dict\u001b[39;00m\n\u001b[0;32m--> 584\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m([key \u001b[38;5;28;01mfor\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m state_dict \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mkey\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstartswith\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprefix\u001b[49m\u001b[43m)\u001b[49m]) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m    585\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m is_deepspeed_zero3_enabled():\n\u001b[1;32m    586\u001b[0m         \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mdeepspeed\u001b[39;00m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(\"mistralai/Mistral-7B-v0.1\").to('cuda')\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"mistralai/Mistral-7B-v0.1\")\n",
    "Token = {v: k for k, v in tokenizer.get_vocab().items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c3f1fa6c-d1e7-44b9-82ca-96692c2ec953",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = datasets.load_from_disk(f'/home/XXXXXXXX/msmarco_mistral_test').with_format('torch', device=torch.device('cuda'))\n",
    "loader = DataLoader(dataset, batch_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "76ed5f24-c195-4d06-932d-6edc852b1556",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7f9bbc82d630>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e5807d18-9236-45fb-8d0e-cebe7d00f0b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "del model\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa28e81b-7b64-4d1d-a7ed-7d4f79bb4854",
   "metadata": {},
   "outputs": [],
   "source": [
    "ablate_model = copy.deepcopy(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "703e460a-56e6-411b-80fe-ddc23118a517",
   "metadata": {},
   "outputs": [],
   "source": [
    "params = dict(model.named_parameters())\n",
    "for name, param in params.items():\n",
    "    if 'mlp' in name:\n",
    "        param.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "77742d34-ac77-4fa9-8d7f-616b46e4b68e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20/20 [02:10<00:00,  6.54s/it]\n"
     ]
    }
   ],
   "source": [
    "num_batches = 20\n",
    "it = islice(iter(loader), num_batches)\n",
    "loss = 0\n",
    "for batch in tqdm(it, total=num_batches):\n",
    "    out = model(\n",
    "        input_ids=batch['input_ids'], \n",
    "        labels=batch['input_ids'], \n",
    "        attention_mask=batch['attention_mask']\n",
    "    )\n",
    "    loss += out.loss.item()\n",
    "loss /= num_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "71d9545a-cd2f-430a-bf8f-f1f67a6f9fca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "10.92078218460083"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15d9024d-1361-4cd4-9670-5101e1c61cdb",
   "metadata": {},
   "source": [
    "full model: 2.243\n",
    "\n",
    "mlp only: 6.938\n",
    "\n",
    "attn only: 10.92"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c01e79d-6e7c-4a93-b83b-09bce80f47a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21277eaa-9f74-479e-80cc-1ba4f0e1354c",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d384fb6-1f9d-4b1e-913d-f08be5455e43",
   "metadata": {},
   "outputs": [],
   "source": [
    "out = model(input_ids=batch['input_ids'], labels=batch['input_ids'], attention_mask=batch['attention_mask'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6236ed5e-91de-41db-9572-5fca2a130dae",
   "metadata": {},
   "outputs": [],
   "source": [
    "out.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e4317b6-79bc-4a15-9fed-30e1cf22ce86",
   "metadata": {},
   "outputs": [],
   "source": [
    "out.loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d0d0f0d-3e1c-4e8d-923e-fbce0e6a27f2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
