{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import os\n",
    "import pickle\n",
    "sys.path.append(\"./util\")\n",
    "import featureM_utility as util\n",
    "from datasets import load_dataset\n",
    "# Load model directly\n",
    "from transformers import GPTNeoXForCausalLM, AutoTokenizer\n",
    "\n",
    "\n",
    "# Checkpoint to analyse.\n",
    "model_name = \"pythia-410m-deduped\"\n",
    "revision = \"step143000\"  # final model\n",
    "\n",
    "# Model and tokenizer are fetched with the same hub id, revision and cache directory.\n",
    "hub_id = f\"EleutherAI/{model_name}\"\n",
    "pretrained_kwargs = {\n",
    "    \"revision\": revision,\n",
    "    \"cache_dir\": f\"./{model_name}/{revision}\",\n",
    "}\n",
    "\n",
    "model = GPTNeoXForCausalLM.from_pretrained(hub_id, **pretrained_kwargs)\n",
    "tokenizer = AutoTokenizer.from_pretrained(hub_id, **pretrained_kwargs)\n",
    "\n",
    "# Run on the GPU when one is available, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# Hand the model to the project helper that installs the FM modules.\n",
    "# NOTE(review): might need to be altered so that only the correct modules are replaced.\n",
    "replacedMod = util.replaceWithFMmodules(model=model, device=device)\n",
    "\n",
    "# Corpus selector read by hugginfaceFMgeneration: \"WikiText\" or \"BookCorpus\".\n",
    "DATASET = \"WikiText\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "using WikiText\n",
      "using WikiText\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def hugginfaceFMgeneration(tokenizer ,model):\n",
    "    \"\"\"Feed the evaluation corpus through ``model`` in sliding windows.\n",
    "\n",
    "    The corpus named by the notebook-level ``DATASET`` constant is tokenized as\n",
    "    one long sequence and pushed through the model window by window under\n",
    "    ``torch.no_grad()``. Nothing is returned and the forward outputs are\n",
    "    discarded, so the call only matters for what the forward passes do.\n",
    "    NOTE(review): presumably the FM modules installed through ``util`` update\n",
    "    their statistics during these passes - confirm in util.\n",
    "\n",
    "    Reads the globals ``DATASET`` and ``device`` defined in the setup cell.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: tokenizer called on the joined corpus text with\n",
    "            ``return_tensors=\"pt\"``.\n",
    "        model: causal LM called as ``model(input_ids, labels=...)``; its\n",
    "            ``config.max_position_embeddings`` sets the window length.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``DATASET`` is neither \"BookCorpus\" nor \"WikiText\".\n",
    "    \"\"\"\n",
    "    if DATASET == \"BookCorpus\":\n",
    "        dataset = load_dataset(\"bookcorpus\", split=\"train\", trust_remote_code=True)\n",
    "        # only the first evalSentences entries of the train split are used\n",
    "        evalSentences = 10000\n",
    "        dataset = dataset.select(range(evalSentences))\n",
    "        print(\"using Bookcorpus\")\n",
    "    elif DATASET == \"WikiText\":\n",
    "        dataset = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\", split=\"test\", trust_remote_code=True)\n",
    "        print(\"using WikiText\")\n",
    "    else:\n",
    "        # previously an unknown name surfaced later as an UnboundLocalError on `dataset`\n",
    "        raise ValueError(f\"Unsupported DATASET {DATASET!r}: expected 'BookCorpus' or 'WikiText'\")\n",
    "\n",
    "    encodings = tokenizer(\"\\n\\n\".join(dataset[\"text\"]), return_tensors=\"pt\")\n",
    "\n",
    "    # stride is the step between window starts. When it is smaller than the\n",
    "    # window length, consecutive windows overlap and the later tokens of each\n",
    "    # window get more left context. For context length 1024 a stride of 1024\n",
    "    # gives little context, 512 gives more; the pythia model has 2048.\n",
    "    # NOTE(review): with overlapping windows the shared tokens pass through the\n",
    "    # model more than once - confirm that is what the FM accumulation expects.\n",
    "    stride = 1024\n",
    "    max_length = model.config.max_position_embeddings\n",
    "    seq_len = encodings.input_ids.size(1)\n",
    "\n",
    "    prev_end_loc = 0\n",
    "    for begin_loc in range(0, seq_len, stride):\n",
    "        end_loc = min(begin_loc + max_length, seq_len)\n",
    "        trg_len = end_loc - prev_end_loc  # may be different from stride on last loop\n",
    "        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)\n",
    "        # tokens already scored by the previous window are masked out of the labels\n",
    "        target_ids = input_ids.clone()\n",
    "        target_ids[:, :-trg_len] = -100\n",
    "\n",
    "        with torch.no_grad():\n",
    "            model(input_ids, labels=target_ids)\n",
    "\n",
    "        # advance the marker; without this trg_len keeps growing and nothing is ever masked\n",
    "        prev_end_loc = end_loc\n",
    "        if end_loc == seq_len:\n",
    "            break\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "# First pass over the corpus, run before the computation mode is switched.\n",
    "# NOTE(review): presumably this pass collects statistics the FM pass depends on - confirm in util.\n",
    "hugginfaceFMgeneration( tokenizer, model)\n",
    "\n",
    "# compute the FM: switch the modules to \"FM\" mode, then run the same corpus again\n",
    "util.setComputationMode(model, mode=\"FM\")\n",
    "hugginfaceFMgeneration( tokenizer, model)\n",
    "\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# save the FM to a dictionary\n",
    "# collect every module that is a util.FeatureLayer, together with its qualified name\n",
    "modules = [(name, module) for name, module in model.named_modules() if isinstance(module, util.FeatureLayer)]\n",
    "\n",
    "dictionary = {}\n",
    "for name, module in modules:\n",
    "    # only keep layers with a 2-D weight; pooler layers and \"head.module.layers.0\" are skipped\n",
    "    if len(module.weight.shape) == 2 and \"pooler\" not in name and name != \"head.module.layers.0\":\n",
    "        # store detached CPU copies so the pickle does not reference live model tensors\n",
    "        dictionary[name] = {\n",
    "            \"weight\": module.weight.detach().cpu().numpy().copy(),\n",
    "            \"FM\": module.correlationM.detach().cpu().numpy().copy(),\n",
    "        }\n",
    "\n",
    "# open() does not create missing directories, so make sure Data/ exists first\n",
    "os.makedirs(\"Data\", exist_ok=True)\n",
    "with open(f'Data/{DATASET}_Pythia_{revision}', 'wb') as file:\n",
    "    pickle.dump(dictionary, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Also save the SVD to save computational resources later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPTNeoXForCausalLM\n",
    "import torch\n",
    "import numpy as np\n",
    "import argparse\n",
    "import os\n",
    "\n",
    "def saveMa( name , matrix):\n",
    "\n",
    "    # compute in numpy for better precision\n",
    "    npMa = matrix.numpy()\n",
    "    u ,s , vh = np.linalg.svd(npMa, full_matrices=False)\n",
    "    U , Sval , Vh = torch.from_numpy(u), torch.from_numpy(s) , torch.from_numpy(vh)\n",
    "\n",
    "    # save the file\n",
    "    torch.save(U,f\"{name}_U\")\n",
    "    torch.save(Sval ,f\"{name}_S\")\n",
    "    torch.save(Vh ,f\"{name}_Vh\")\n",
    "\n",
    "\n",
    "\n",
    "step = 143000\n",
    "\n",
    "# if any of these strings is in the key, it is a weight matrix\n",
    "validWeighM = [\"query_key_value.weight\", \"mlp.dense_h_to_4h.weight\",\"attention.dense.weight\" ,\"mlp.dense_4h_to_h.weight\"]\n",
    "\n",
    "# load a pythia model\n",
    "model_name = \"pythia-410m-deduped\"\n",
    "revision = f\"step{step}\"  # final model: 143000\n",
    "\n",
    "model = GPTNeoXForCausalLM.from_pretrained(\n",
    "    f\"EleutherAI/{model_name}\",\n",
    "    revision=revision,\n",
    "    cache_dir=f\"./{model_name}/{revision}\",\n",
    ")\n",
    "\n",
    "stateDic = model.state_dict()\n",
    "\n",
    "# attention geometry, read from the config instead of hard-coding 1024\n",
    "numHeads = model.config.num_attention_heads\n",
    "headSize = model.config.hidden_size // numHeads\n",
    "\n",
    "# location for the final model\n",
    "pathSVDmodel = f\"PythiaSVD/{model_name}_{revision}\"\n",
    "os.makedirs(pathSVDmodel, exist_ok=True)\n",
    "\n",
    "# save a model in its svd\n",
    "for key, weight in stateDic.items():\n",
    "    if not any(validM in key for validM in validWeighM):\n",
    "        continue\n",
    "    if \"query_key_value.weight\" in key:\n",
    "        # GPT-NeoX fuses Q, K and V head by head: the output rows are ordered\n",
    "        # [q_0, k_0, v_0, q_1, k_1, v_1, ...] with headSize rows per piece, not\n",
    "        # as three contiguous blocks. Regroup per head before splitting, so each\n",
    "        # saved matrix really is the full query / key / value projection.\n",
    "        perHead = weight.reshape(numHeads, 3, headSize, weight.shape[-1])\n",
    "        qM = perHead[:, 0].reshape(numHeads * headSize, -1)\n",
    "        kM = perHead[:, 1].reshape(numHeads * headSize, -1)\n",
    "        vM = perHead[:, 2].reshape(numHeads * headSize, -1)\n",
    "        saveMa(f\"{pathSVDmodel}/{key}_Q\", qM)\n",
    "        saveMa(f\"{pathSVDmodel}/{key}_K\", kM)\n",
    "        saveMa(f\"{pathSVDmodel}/{key}_V\", vM)\n",
    "    else:\n",
    "        saveMa(f\"{pathSVDmodel}/{key}\", weight)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
