{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "092a91e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]='0'\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "import torch.optim as optim\n",
    "from peft import get_peft_model, prepare_model_for_int8_training, PeftModel\n",
    "from configs import fsdp_config, train_config\n",
    "from transformers import (\n",
    "    LlamaForCausalLM,\n",
    "    LlamaTokenizer,\n",
    "    LlamaConfig,\n",
    "    default_data_collator,\n",
    ")\n",
    "from utils.config_utils import (\n",
    "    update_config,\n",
    "    generate_peft_config,\n",
    "    generate_dataset_config,\n",
    ")\n",
    "from utils.dataset_utils import get_preprocessed_dataset\n",
    "from safety_evaluation.eval_utils.model_utils import load_model, load_peft_model\n",
    "import copy\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd0c1438",
   "metadata": {},
   "outputs": [],
   "source": [
    "vector_path = '...'  # modify to your vector's path\n",
    "base_model_path = '../LLM_Models/llama-3-8b-instruct/'\n",
    "peft_model = './finetuned_models/samsum-8b-instr-peft-seed-42-lr-1e-4-bs-32-epochs-3/'\n",
    "\n",
    "# Load the projection matrices first so a wrong `vector_path` fails fast,\n",
    "# before the slow model loads below. map_location='cpu' makes the file\n",
    "# loadable regardless of the device it was saved from; each matrix is moved\n",
    "# to its layer's device when it is used.\n",
    "# NOTE(security): torch.load unpickles the file -- only load vectors you trust.\n",
    "if not os.path.isfile(vector_path):\n",
    "    raise FileNotFoundError(f'projection vectors not found: {vector_path!r} (set vector_path)')\n",
    "v = torch.load(vector_path, map_location='cpu')\n",
    "\n",
    "model = load_model(base_model_path, False)\n",
    "model = PeftModel.from_pretrained(model, peft_model)\n",
    "# Reference copy holding the original fine-tuned weights: the projection step\n",
    "# reads from it and writes into `model`. Deep-copying duplicates the model in memory.\n",
    "model_ori = copy.deepcopy(model)\n",
    "print(f\"Already load PEFT model from {peft_model}!!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ac891c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def change_weight(model, model_ori, cof, thrs_cos, vectors=None):\n",
    "    \"\"\"Project selected LoRA-B matrices of `model` with per-layer projection matrices.\n",
    "\n",
    "    Walks the parameters of `model` and `model_ori` in lockstep. Each LoRA-A\n",
    "    matrix ``A`` is paired with the LoRA-B matrix that follows it. For the k-th\n",
    "    LoRA-B matrix ``B`` (read from `model_ori`) and ``P = vectors[k]``:\n",
    "\n",
    "    * ``cos`` is the cosine similarity between the projected update\n",
    "      ``(P @ B) @ A`` and the original update ``B @ A``, rounded to 5 decimals;\n",
    "    * if ``cos <= thrs_cos`` the layer in `model` becomes ``(cof * P) @ B``,\n",
    "      otherwise it is reset to a copy of the original ``B``.\n",
    "\n",
    "    Every LoRA-B matrix is rewritten from `model_ori`, so repeated calls do not\n",
    "    compound earlier projections.\n",
    "\n",
    "    Args:\n",
    "        model: PEFT model whose LoRA-B weights are overwritten in place.\n",
    "        model_ori: unmodified copy of `model`; must list the same parameter\n",
    "            names in the same order.\n",
    "        cof: scalar applied to ``P`` for the layers that get projected.\n",
    "        thrs_cos: layers with ``cos <= thrs_cos`` are projected.\n",
    "        vectors: indexable collection of 2-D projection matrices, one per\n",
    "            LoRA-B matrix in iteration order. Defaults to the notebook-global ``v``.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, mean_dis, dis_fix, cos_total)`` where ``mean_dis`` is the mean\n",
    "        of ``1 / (1 + ||B_new - P @ B||)`` over LoRA-B layers (NaN if there are\n",
    "        none), ``dis_fix`` lists ``1 / (1 + ||B - P @ B||)`` per layer and\n",
    "        ``cos_total`` lists the per-layer cosine similarities.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two models' parameter names diverge, or a LoRA-B\n",
    "            matrix has no preceding LoRA-A matrix.\n",
    "    \"\"\"\n",
    "    if vectors is None:\n",
    "        vectors = v  # backward compatibility: the original API read the global `v`\n",
    "\n",
    "    idx = 0\n",
    "    n_projected = 0\n",
    "    dis = []\n",
    "    dis_fix = []\n",
    "    cos_total = []\n",
    "    lora_A = None  # LoRA-A factor of the module currently being processed\n",
    "\n",
    "    # Pure weight surgery: no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        pairs = zip(model.named_parameters(), model_ori.named_parameters())\n",
    "        for (name, param), (name_ori, param_ori) in pairs:\n",
    "            if name != name_ori:\n",
    "                raise ValueError(f'parameter order differs: {name!r} vs {name_ori!r}')\n",
    "            # Identify LoRA factors by name rather than by `shape[0] == 8`, so any\n",
    "            # LoRA rank works and other 'lora' tensors are not misclassified.\n",
    "            if 'lora_A' in name:\n",
    "                lora_A = param_ori.detach()\n",
    "                continue\n",
    "            if 'lora_B' not in name:\n",
    "                continue\n",
    "            if lora_A is None:\n",
    "                raise ValueError(f'{name!r} has no preceding lora_A matrix')\n",
    "\n",
    "            B_ori = param_ori.detach()\n",
    "            # Match device and dtype so the matmuls below cannot fail on a mismatch.\n",
    "            P = vectors[idx].to(device=B_ori.device, dtype=B_ori.dtype)\n",
    "            B_proj = P @ B_ori\n",
    "            cos = np.round(\n",
    "                torch.nn.functional.cosine_similarity(\n",
    "                    (B_proj @ lora_A).reshape(1, -1), (B_ori @ lora_A).reshape(1, -1)\n",
    "                ).item(),\n",
    "                5,\n",
    "            )\n",
    "            cos_total.append(cos)\n",
    "            dis_fix.append((1 / (1 + torch.norm(B_ori - B_proj))).item())\n",
    "\n",
    "            if cos <= thrs_cos:\n",
    "                n_projected += 1\n",
    "                B_new = (cof * P) @ B_ori\n",
    "            else:\n",
    "                # Clone so `model` never shares storage with `model_ori`; an in-place\n",
    "                # update of `model` would otherwise silently corrupt the reference.\n",
    "                B_new = B_ori.clone()\n",
    "            dis.append((1 / (1 + torch.norm(B_new - B_proj))).item())\n",
    "            param.data = B_new.to(param.device)\n",
    "            idx += 1\n",
    "\n",
    "    print(n_projected)\n",
    "    return model, (np.mean(dis) if dis else float('nan')), dis_fix, cos_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4024e949",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 7  # number of layers you want to project\n",
    "\n",
    "# Pass 1: collect the per-layer cosine similarities. 0.35 is only a provisional\n",
    "# threshold -- change_weight rewrites every LoRA-B layer from model_ori, so pass 2\n",
    "# alone determines the final weights.\n",
    "thrs = 0.35\n",
    "model, dis, dis_f, cos = change_weight(model, model_ori, 1, thrs)\n",
    "\n",
    "# Pass 2: use the N-th smallest similarity (or the largest, if there are fewer\n",
    "# than N layers) as the threshold, so the N layers whose projected update is\n",
    "# least similar to the original update get projected; ties can select more.\n",
    "cos_ranked = np.sort(cos)\n",
    "thrs = cos_ranked[:N][-1]\n",
    "model, dis, dis_f, cos = change_weight(model, model_ori, 1, thrs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
