{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\PC\\miniconda3\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
      "Token is valid (permission: write).\n",
      "Your token has been saved to C:\\Users\\PC\\.cache\\huggingface\\token\n",
      "Login successful\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:16<00:00,  4.02s/it]\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import sys\n",
    "import os\n",
    "import pickle\n",
    "from huggingface_hub import login\n",
    "sys.path.append(\"./util\")\n",
    "import featureM_utility as util\n",
    "\n",
    "# Load model directly\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, PreTrainedTokenizerFast\n",
    "\n",
    "# Load the book Corpus\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "# ---- configuration: everything a reader may want to tune ----\n",
    "MODEL_ID = \"meta-llama/Llama-3.1-8B\"          # Hub id, used for the tokenizer\n",
    "MODEL_PATH = \"xxx/Llama/Meta-Llama-3.1-8B\"    # checkpoint the weights are loaded from\n",
    "DATASET = \"BookCorpus\"                        # \"BookCorpus\" or \"WikiText\"; read by hugginfaceFMgeneration\n",
    "FM_LAYERS = [\"0\", \"4\", \"9\", \"14\", \"19\", \"24\", \"29\"]  # layers handed to util.replaceWithFMmodules\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    # torch.cuda.reset_peak_memory_stats raises when CUDA is unavailable, so only touch the allocator on GPU machines\n",
    "    torch.cuda.empty_cache()  # Frees up memory from unused tensors\n",
    "    torch.cuda.reset_peak_memory_stats()  # (Optional) Resets memory stats\n",
    "\n",
    "# Never hard-code the Hugging Face token: read it from the HF_TOKEN environment variable.\n",
    "# When HF_TOKEN is unset, huggingface_hub falls back to a token cached by an earlier login.\n",
    "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
    "if HF_TOKEN:\n",
    "    login(token=HF_TOKEN)\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_PATH,\n",
    "    torch_dtype=torch.float16,        # Use FP16 for model weights\n",
    "    device_map=\"auto\",                 # Automatically map layers to available devices\n",
    "    trust_remote_code=True\n",
    ")\n",
    "model.eval()\n",
    "\n",
    "# Replace the modules of the selected layers with the feature-matrix modules from featureM_utility\n",
    "replacedMod = util.replaceWithFMmodules(model, layers=FM_LAYERS, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using Bookcorpus\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (144932 > 131072). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full Length 144932\n",
      "end loc 2048  done\n",
      "end loc 3072  done\n",
      "end loc 4096  done\n",
      "end loc 5120  done\n",
      "end loc 6144  done\n",
      "end loc 7168  done\n",
      "end loc 8192  done\n",
      "end loc 9216  done\n",
      "end loc 10240  done\n",
      "end loc 11264  done\n",
      "end loc 12288  done\n",
      "end loc 13312  done\n",
      "end loc 14336  done\n",
      "end loc 15360  done\n",
      "end loc 16384  done\n",
      "end loc 17408  done\n",
      "end loc 18432  done\n",
      "end loc 19456  done\n",
      "end loc 20480  done\n",
      "end loc 21504  done\n",
      "end loc 22528  done\n",
      "end loc 23552  done\n",
      "end loc 24576  done\n",
      "end loc 25600  done\n",
      "end loc 26624  done\n",
      "end loc 27648  done\n",
      "end loc 28672  done\n",
      "end loc 29696  done\n",
      "end loc 30720  done\n",
      "end loc 31744  done\n",
      "end loc 32768  done\n",
      "end loc 33792  done\n",
      "end loc 34816  done\n",
      "end loc 35840  done\n",
      "end loc 36864  done\n",
      "end loc 37888  done\n",
      "end loc 38912  done\n",
      "end loc 39936  done\n",
      "end loc 40960  done\n",
      "end loc 41984  done\n",
      "end loc 43008  done\n",
      "end loc 44032  done\n",
      "end loc 45056  done\n",
      "end loc 46080  done\n",
      "end loc 47104  done\n",
      "end loc 48128  done\n",
      "end loc 49152  done\n",
      "end loc 50176  done\n",
      "end loc 51200  done\n",
      "end loc 52224  done\n",
      "end loc 53248  done\n",
      "end loc 54272  done\n",
      "end loc 55296  done\n",
      "end loc 56320  done\n",
      "end loc 57344  done\n",
      "end loc 58368  done\n",
      "end loc 59392  done\n",
      "end loc 60416  done\n",
      "end loc 61440  done\n",
      "end loc 62464  done\n",
      "end loc 63488  done\n",
      "end loc 64512  done\n",
      "end loc 65536  done\n",
      "end loc 66560  done\n",
      "end loc 67584  done\n",
      "end loc 68608  done\n",
      "end loc 69632  done\n",
      "end loc 70656  done\n",
      "end loc 71680  done\n",
      "end loc 72704  done\n",
      "end loc 73728  done\n",
      "end loc 74752  done\n",
      "end loc 75776  done\n",
      "end loc 76800  done\n",
      "end loc 77824  done\n",
      "end loc 78848  done\n",
      "end loc 79872  done\n",
      "end loc 80896  done\n",
      "end loc 81920  done\n",
      "end loc 82944  done\n",
      "end loc 83968  done\n",
      "end loc 84992  done\n",
      "end loc 86016  done\n",
      "end loc 87040  done\n",
      "end loc 88064  done\n",
      "end loc 89088  done\n",
      "end loc 90112  done\n",
      "end loc 91136  done\n",
      "end loc 92160  done\n",
      "end loc 93184  done\n",
      "end loc 94208  done\n",
      "end loc 95232  done\n",
      "end loc 96256  done\n",
      "end loc 97280  done\n",
      "end loc 98304  done\n",
      "end loc 99328  done\n",
      "end loc 100352  done\n",
      "end loc 101376  done\n",
      "end loc 102400  done\n",
      "end loc 103424  done\n",
      "end loc 104448  done\n",
      "end loc 105472  done\n",
      "end loc 106496  done\n",
      "end loc 107520  done\n",
      "end loc 108544  done\n",
      "end loc 109568  done\n",
      "end loc 110592  done\n",
      "end loc 111616  done\n",
      "end loc 112640  done\n",
      "end loc 113664  done\n",
      "end loc 114688  done\n",
      "end loc 115712  done\n",
      "end loc 116736  done\n",
      "end loc 117760  done\n",
      "end loc 118784  done\n",
      "end loc 119808  done\n",
      "end loc 120832  done\n",
      "end loc 121856  done\n",
      "end loc 122880  done\n",
      "end loc 123904  done\n",
      "end loc 124928  done\n",
      "end loc 125952  done\n",
      "end loc 126976  done\n",
      "end loc 128000  done\n",
      "end loc 129024  done\n",
      "end loc 130048  done\n",
      "end loc 131072  done\n",
      "end loc 132096  done\n",
      "end loc 133120  done\n",
      "end loc 134144  done\n",
      "end loc 135168  done\n",
      "end loc 136192  done\n",
      "end loc 137216  done\n",
      "end loc 138240  done\n",
      "end loc 139264  done\n",
      "end loc 140288  done\n",
      "end loc 141312  done\n",
      "end loc 142336  done\n",
      "end loc 143360  done\n",
      "end loc 144384  done\n",
      "end loc 144932  done\n",
      "using Bookcorpus\n",
      "full Length 144932\n",
      "end loc 2048  done\n",
      "end loc 3072  done\n",
      "end loc 4096  done\n",
      "end loc 5120  done\n",
      "end loc 6144  done\n",
      "end loc 7168  done\n",
      "end loc 8192  done\n",
      "end loc 9216  done\n",
      "end loc 10240  done\n",
      "end loc 11264  done\n",
      "end loc 12288  done\n",
      "end loc 13312  done\n",
      "end loc 14336  done\n",
      "end loc 15360  done\n",
      "end loc 16384  done\n",
      "end loc 17408  done\n",
      "end loc 18432  done\n",
      "end loc 19456  done\n",
      "end loc 20480  done\n",
      "end loc 21504  done\n",
      "end loc 22528  done\n",
      "end loc 23552  done\n",
      "end loc 24576  done\n",
      "end loc 25600  done\n",
      "end loc 26624  done\n",
      "end loc 27648  done\n",
      "end loc 28672  done\n",
      "end loc 29696  done\n",
      "end loc 30720  done\n",
      "end loc 31744  done\n",
      "end loc 32768  done\n",
      "end loc 33792  done\n",
      "end loc 34816  done\n",
      "end loc 35840  done\n",
      "end loc 36864  done\n",
      "end loc 37888  done\n",
      "end loc 38912  done\n",
      "end loc 39936  done\n",
      "end loc 40960  done\n",
      "end loc 41984  done\n",
      "end loc 43008  done\n",
      "end loc 44032  done\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def hugginfaceFMgeneration(tokenizer, model, mode=\"wikitext\", stride=1024, max_length=2048):\n",
    "    \"\"\"Run the evaluation corpus through `model` with a sliding window.\n",
    "\n",
    "    The corpus is selected by the notebook-level ``DATASET`` constant, not by ``mode``.\n",
    "    Each window is passed through the model under ``torch.no_grad()``, so any\n",
    "    feature-matrix modules installed in the model see every window's activations.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: tokenizer matching ``model``.\n",
    "        model: causal LM, called as ``model(input_ids, labels=target_ids)``.\n",
    "        mode: unused; kept only so existing calls keep working.\n",
    "        stride: distance between consecutive window starts; each window re-feeds\n",
    "            ``max_length - stride`` tokens of the previous one as context.\n",
    "        max_length: window size in tokens.\n",
    "\n",
    "    Returns:\n",
    "        float: perplexity over the scored tokens, or ``nan`` if no token was scored.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``DATASET`` is neither \"BookCorpus\" nor \"WikiText\".\n",
    "    \"\"\"\n",
    "    if DATASET == \"BookCorpus\":\n",
    "        dataset = load_dataset(\"bookcorpus\", split=\"train\", trust_remote_code=True)\n",
    "        evalSentences = 10000\n",
    "        dataset = dataset.select(range(evalSentences))\n",
    "        print(\"using Bookcorpus\")\n",
    "    elif DATASET == \"WikiText\":\n",
    "        dataset = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"test\", trust_remote_code=True)\n",
    "        print(\"using WikiText\")\n",
    "    else:\n",
    "        # previously this fell through to an UnboundLocalError on `dataset`\n",
    "        raise ValueError(f\"Unsupported DATASET {DATASET!r}; expected 'BookCorpus' or 'WikiText'\")\n",
    "\n",
    "    # One long token stream; the windows below keep each forward pass at max_length tokens,\n",
    "    # so the tokenizer's \"longer than the maximum sequence length\" warning is harmless here.\n",
    "    encodings = tokenizer(\"\\n\\n\".join(dataset[\"text\"]), return_tensors=\"pt\")\n",
    "\n",
    "    seq_len = encodings.input_ids.size(1)\n",
    "    print(\"full Length\", seq_len)\n",
    "\n",
    "    # NOTE(review): with stride < max_length, overlapping tokens go through the model twice,\n",
    "    # so statistics accumulated by FM modules count them twice - confirm this is intended.\n",
    "    nll_sum = 0.0\n",
    "    n_tokens = 0\n",
    "    prev_end_loc = 0\n",
    "    for begin_loc in range(0, seq_len, stride):\n",
    "        end_loc = min(begin_loc + max_length, seq_len)\n",
    "        trg_len = end_loc - prev_end_loc  # tokens not already scored by the previous window\n",
    "        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)\n",
    "        target_ids = input_ids.clone()\n",
    "        target_ids[:, :-trg_len] = -100  # the overlap is context only and is not scored again\n",
    "\n",
    "        with torch.no_grad():\n",
    "            outputs = model(input_ids, labels=target_ids)\n",
    "            print(\"end loc\", end_loc, \" done\")\n",
    "\n",
    "        # The model shifts labels left by one internally, so the tokens that actually\n",
    "        # contribute to outputs.loss (a mean) are the unmasked labels from position 1 on.\n",
    "        num_loss_tokens = (target_ids[:, 1:] != -100).sum().item()\n",
    "        if num_loss_tokens > 0:\n",
    "            nll_sum += outputs.loss.item() * num_loss_tokens\n",
    "            n_tokens += num_loss_tokens\n",
    "\n",
    "        prev_end_loc = end_loc  # missing before: without it no context token was ever masked\n",
    "        if end_loc == seq_len:\n",
    "            break\n",
    "\n",
    "    if n_tokens == 0:\n",
    "        return float(\"nan\")\n",
    "    # float64 exp: matches math.exp but yields inf instead of raising on overflow\n",
    "    return torch.tensor(nll_sum / n_tokens, dtype=torch.float64).exp().item()\n",
    "\n",
    "\n",
    "default_ppl = hugginfaceFMgeneration(tokenizer, model)\n",
    "print(\"perplexity (default mode):\", default_ppl)\n",
    "\n",
    "# compute the FM: switch the replaced modules to \"FM\" mode and stream the corpus again\n",
    "util.setComputationMode(model, mode=\"FM\")\n",
    "fm_ppl = hugginfaceFMgeneration(tokenizer, model)\n",
    "print(\"perplexity (FM mode):\", fm_ppl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "\n",
    "# save the FM to a dictionary\n",
    "modules = [(name, module) for name, module in model.named_modules() if isinstance(module, util.FeatureLayer)]\n",
    "\n",
    "dictionary = {}\n",
    "for name, module in modules:\n",
    "    # only keep 2-D weights, skipping pooler layers and the excluded head layer\n",
    "    if len(module.weight.shape) == 2 and \"pooler\" not in name and name != \"head.module.layers.0\":\n",
    "        # meta tensors hold no data, so there is nothing to export for them\n",
    "        if module.weight.is_meta:\n",
    "            continue\n",
    "\n",
    "        # add it to the dictionary (copies so the arrays do not share memory with live tensors)\n",
    "        dictionary[name] = {\"weight\": module.weight.detach().cpu().numpy().copy(), \"FM\": module.correlationM.detach().cpu().numpy().copy()}\n",
    "\n",
    "# open() fails if the output folder does not exist yet\n",
    "os.makedirs(\"Data\", exist_ok=True)\n",
    "# NOTE: pickle files execute code on load - only load this file from trusted locations\n",
    "with open(f'Data/Llama3_1_8B_{DATASET}', 'wb') as file:\n",
    "    pickle.dump(dictionary, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Also precompute and save the SVDs of the Llama 3.1 weight matrices so later analyses need not recompute them"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "import torch\n",
    "import numpy as np\n",
    "from huggingface_hub import login\n",
    "import os\n",
    "\n",
    "def saveMa(name, matrix):\n",
    "    \"\"\"Compute the thin SVD of `matrix` and save the factors with torch.save.\n",
    "\n",
    "    Writes ``{name}_U``, ``{name}_S`` (singular values) and ``{name}_Vh`` such that\n",
    "    ``U @ diag(S) @ Vh`` reconstructs ``matrix`` up to float32 round-off.\n",
    "\n",
    "    Args:\n",
    "        name: output path prefix; its directory must already exist.\n",
    "        matrix: floating-point torch tensor (a weight matrix, presumably 2-D) on any device.\n",
    "    \"\"\"\n",
    "    # Convert in torch first: Tensor.numpy() rejects bfloat16 and tensors that require grad.\n",
    "    npMa = matrix.detach().to(device=\"cpu\", dtype=torch.float32).numpy()\n",
    "    # NumPy (LAPACK) SVD computed in float32; full_matrices=False returns the thin factors.\n",
    "    u, s, vh = np.linalg.svd(npMa, full_matrices=False)\n",
    "    U, Sval, Vh = torch.from_numpy(u), torch.from_numpy(s), torch.from_numpy(vh)\n",
    "\n",
    "    # save the file\n",
    "    torch.save(U, f\"{name}_U\")\n",
    "    torch.save(Sval, f\"{name}_S\")\n",
    "    torch.save(Vh, f\"{name}_Vh\")\n",
    "\n",
    "\n",
    "# Free the model loaded by the earlier cells (if any) before loading a second copy,\n",
    "# so the two copies do not compete for GPU/CPU memory.\n",
    "# NOTE(review): other objects (e.g. `replacedMod`) may still reference it - delete them too if memory is tight.\n",
    "import gc\n",
    "if \"model\" in globals():\n",
    "    del model\n",
    "    gc.collect()\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()\n",
    "\n",
    "# Read the Hugging Face token from the environment instead of hard-coding it;\n",
    "# when HF_TOKEN is unset, huggingface_hub falls back to a previously cached token.\n",
    "HF_TOKEN = os.environ.get(\"HF_TOKEN\")\n",
    "if HF_TOKEN:\n",
    "    login(token=HF_TOKEN)\n",
    "\n",
    "MODEL_ID = \"meta-llama/Llama-3.1-8B\"  # Hub id of the checkpoint below (kept for reference)\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    \"xxx/Llama/Meta-Llama-3.1-8B\",\n",
    "    torch_dtype=torch.float16,        # Use FP16 for model weights\n",
    "    device_map=\"auto\",                 # Automatically map layers to available devices\n",
    "    trust_remote_code=True\n",
    ")\n",
    "\n",
    "stateDic = model.state_dict()\n",
    "\n",
    "# Substrings selecting the attention and MLP projection weights whose SVDs are stored\n",
    "validWeighM = [\"attn.q_proj.weight\", \"attn.k_proj.weight\", \"attn.v_proj.weight\", \"attn.o_proj.weight\",\n",
    "               \"mlp.gate_proj.weight\", \"mlp.up_proj.weight\", \"mlp.down_proj.weight\"]\n",
    "pathSVDmodel = \"LlamaSVD\"\n",
    "\n",
    "# exist_ok replaces the separate isdir check\n",
    "os.makedirs(pathSVDmodel, exist_ok=True)\n",
    "\n",
    "# save the SVD factors of every selected weight matrix\n",
    "for key, weight in stateDic.items():\n",
    "    if any(validM in key for validM in validWeighM):\n",
    "        saveMa(f\"{pathSVDmodel}/{key}\", weight)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
