{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EasyEdit Example with the **U.S. President**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ">Tutorial author: Kewei Xu（<kewe1x@163.con>）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The U.S. election has recently concluded, and `Donald Trump` has been elected President.\n",
    "We test knowledge editing in this scenario:\n",
    "- `Biden → Trump` <br> \n",
    "- `Biden → Trump → Biden` (simulating the interesting shift of Trump → Biden → Trump).<br>\n",
    "\n",
    "In this notebook, we use knowledge editing method `Wise`、`AlphaEdit`、`AdaLoRA`、`Prompt` to edit `Llama3-8B-instruct`.<br> \n",
    "Enjoy the process of editing the Large Language Model (LLM) to make it aware of the presidential change. <br>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Knowledge Editing\n",
    "\n",
    "Deployed LLMs can still exhibit unpredictable errors, including challenges such as hallucinations, biases, and toxic generation. Therefore, it is essential to have the ability to adjust their specific behaviors.\n",
    "\n",
    "**Knowledge editing** aims to adjust an initial base model's $(f_\\theta)$ behavior on the particular edit descriptor $[x_e, y_e]$, such as:\n",
    "- $x_e$: \"Who is the president of the US?\n",
    "- $y_e$: \"Joe Biden.\"\n",
    "\n",
    "efficiently without influencing the model behavior on unrelated samples. The ultimate goal is to create an edited model $(f_\\theta’)$."
   ]
  },
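  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common way to write this goal down is sketched below. The in-scope set $I(x_e, y_e)$, meaning the edit prompt together with its rephrasings, is notation assumed here for illustration.\n",
    "\n",
    "$$\n",
    "f_{\\theta'}(x) = \\begin{cases} y_e & \\text{if } x \\in I(x_e, y_e) \\\\ f_\\theta(x) & \\text{otherwise} \\end{cases}\n",
    "$$\n",
    "\n",
    "Under this reading, an edit should change the answer for inputs inside $I(x_e, y_e)$ and leave every other input's answer as it was."
   ]
  },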
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### AlphaEdit\n",
    "Paper: [AlphaEdit: Null-Space Constrained Knowledge Editing for Language Models](https://arxiv.org/pdf/2410.02355)\n",
    "\n",
    "**AlphaEdit** minimizes disruption to the preserved knowledge by projecting parameter perturbations onto the null space of its key matrices. It then removes the output error related to it from the current objective, allowing the model to focus solely on knowledge update without trade-off. By leveraging the mathematical properties of matrix projection and null space, AlphaEdit ensures that the distribution of hidden representations within LLMs remains invariant after edits. This invariance allows post-edited LLMs to effectively handle both knowledge update and preservation simultaneously.\n",
    "\n",
    "### WISE\n",
    "Paper: [WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models?](http://arxiv.org/pdf/2405.14768)\n",
    "    \n",
    "**WISE**, is an approach for lifelong model editing of LLMs. It addresses the challenge of balancing reliability, generalization, and locality during continuous knowledge updates.\n",
    "It provides an effective solution for continuous learning and knowledge updating in LLMs through its innovative memory management and editing strategies.\n",
    "\n",
    "### AdaLoRA\n",
    "Paper: [AdaLoRA: Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning](https://arxiv.org/pdf/2303.10512)\n",
    "\n",
    "**AdaLoRA** introduces a method that efficiently fine-tunes LLMs by adaptively allocating update budgets based on parameter importance. Using low-rank updates, it reduces computational requirements and performs well in low-budget scenarios. \n",
    "\n",
    "### Prompt\n",
    "Directly alter the model's behavior temporarily through prompts.\n",
    "\n"
   ]
  },
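  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The null-space idea behind AlphaEdit can be illustrated with a toy NumPy sketch. This is not EasyEdit's implementation: the names `K0`, `P`, and `delta`, the toy sizes, and the pseudo-inverse construction of the projector are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "d, n = 8, 3                       # toy sizes: key dimension, number of preserved keys\n",
    "K0 = rng.normal(size=(d, n))      # columns stand in for keys of knowledge to preserve\n",
    "\n",
    "# Orthogonal projector onto the left null space of K0, so that P @ K0 = 0.\n",
    "P = np.eye(d) - K0 @ np.linalg.pinv(K0)\n",
    "\n",
    "delta = rng.normal(size=(4, d))   # an arbitrary, unconstrained weight update\n",
    "delta_proj = delta @ P            # the same update after projection\n",
    "\n",
    "# The raw update shifts the outputs for preserved keys; the projected one does not.\n",
    "raw_shift = np.abs(delta @ K0).max()\n",
    "proj_shift = np.abs(delta_proj @ K0).max()\n",
    "print(f'raw: {raw_shift:.3f}, projected: {proj_shift:.2e}')\n",
    "```\n",
    "\n",
    "In this sketch the projected update maps every preserved key to (numerically) zero, so adding it to a weight matrix would not change that matrix's outputs for those keys."
   ]
  },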
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Edit facts** :<br>\n",
    "&ensp; First &ensp; Edit : *Who is the current President of the United States?* &ensp;Joe Biden ——> **Donald Trump** <br>\n",
    "Second Edit : *Who is the current President of the United States?* &ensp;Joe Biden ——> Donald Trump ——> **Joe Biden**\n",
    "\n",
    "Then we test the following indicators:\n",
    "- `Reliability`: the success rate of editing with a given editing descriptor<br>\n",
    "**Question**: *Who is the current President of the United States?*\n",
    "\n",
    "- `Generalization`: the success rate of editing within the editing scope<br>\n",
    "**Question**: *What is the name of the current President of the United States?*\n",
    "\n",
    "- `Locality`: whether the model's output changes after editing for unrelated inputs<br>\n",
    "**Question**: *Where is the capital of the United States?*\n",
    "\n",
    "- `Portability`: the success rate of editing for reasoning/application(one hop, synonym, logical generalization)<br>\n",
    "**Question**: *Where is the current U.S. President born?*\n",
    "\n",
    "\n",
    "The results are shown in the table below, with **highlighted** areas indicating that the output **does not match the answer**.<br>\n",
    "From the table, it can be seen that:<br>\n",
    "**_Prompt_** , **_WISE_** and **_AlphaEdit_** can successfully edit LLMs regarding U.S. presidents. <br>\n",
    "**_AdaLoRA_** is competent for the first editing, but there are exceptions for the second editing in Locality and Portability.<br>\n",
    "**Note:** The results are merely an analysis of a knowledge editing case involving the U.S. president. Not all knowledge editing situations are the same, as the performance depend on the knowledge itself, the large language model, the editing method, and the hyperparameters. <br>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table>\n",
    "    <tr>\n",
    "        <td></td>\n",
    "        <td><b>Reliability</b></td>\n",
    "        <td><b>Generalization</b></td>\n",
    "        <td><b>Locality</b></td>\n",
    "        <td><b>Portability</b></td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>Questions</b></td>\n",
    "        <td>Who is the current President of the United States?</td>\n",
    "        <td>What is the name of the current President of the United States?</td>\n",
    "        <td>Where is the capital of the United States?</td>\n",
    "        <td>Where is the current U.S. President born? </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td></td>\n",
    "        <td></td>\n",
    "        <td></td>\n",
    "        <td></td>\n",
    "        <td> </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td colspan=5 align=\"center\"><b>First Edit:  Joe Biden ——&gt; Donald Trump</b></td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>Answer</b></td>\n",
    "        <td><b>Donald Trump</b></td>\n",
    "        <td><b>Donald Trump</b></td>\n",
    "        <td><b>Washington, D.C.</b></td>\n",
    "        <td><b>Queens, New York </b></td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>AlphaEdit</b></td>\n",
    "        <td>Donald Trump</td>\n",
    "        <td>Donald Trump</td>\n",
    "        <td>Washington, D.C</td>\n",
    "        <td>Queens, New York </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>WISE</b></td>\n",
    "        <td>Donald Trump</td>        \n",
    "        <td>Donald Trump</td>\n",
    "        <td>Washington, D.C</td>\n",
    "        <td>Queens, New York </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>AdaLoRA</b></td>\n",
    "        <td>Donald Trump</td>\n",
    "        <td>Donald Trump</td>\n",
    "        <td>Washington, D.C</td>\n",
    "        <td>Queens, New York </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>Prompt</b></td>\n",
    "        <td>Donald Trump</td>\n",
    "        <td>Donald Trump</td>\n",
    "        <td>Washington, D.C.</td>\n",
    "        <td>Queens, New York </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td></td>\n",
    "        <td></td>\n",
    "        <td></td>\n",
    "        <td></td>\n",
    "        <td> </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td colspan=5 align=\"center\"><b>Second Edit:   Joe Biden ——&gt; Donald Trump ——&gt; Joe Biden</b></td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>Answer</b></td>\n",
    "        <td><b>Joe Biden</b></td>\n",
    "        <td><b>Joe Biden</b></td>\n",
    "        <td><b>Washington, D.C.</b></td>\n",
    "        <td><b>Scranton, Pennsylvania </b></td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>AlphaEdit</b></td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Washington, D.C</td>\n",
    "        <td>Scranton, Pennsylvania </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>WISE</b></td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Washington, D.C</td>\n",
    "        <td>Scranton, Pennsylvania </td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>AdaLoRA</b></td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td><mark>Joe Biden</mark></td>\n",
    "        <td><mark>Joe Biden </mark></td>\n",
    "    </tr>\n",
    "    <tr>\n",
    "        <td><b>Prompt</b></td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Joe Biden</td>\n",
    "        <td>Washington, D.C</td>\n",
    "        <td>Scranton, Pennsylvania </td>\n",
    "    </tr>\n",
    "</table>"
   ]
  },
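  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comparison behind the table can be approximated with a simple containment check. `contains_answer` is a hypothetical helper, not part of EasyEdit, and the sketch assumes that a case-insensitive substring match that ignores periods is an acceptable criterion; the example strings are illustrative.\n",
    "\n",
    "```python\n",
    "def contains_answer(response, answer):\n",
    "    '''Return True if `answer` occurs in `response`, ignoring case and periods.'''\n",
    "    def normalize(text):\n",
    "        return text.lower().replace('.', '').strip()\n",
    "    return normalize(answer) in normalize(response)\n",
    "\n",
    "# A Locality-style check: the expected answer should still appear after editing.\n",
    "good = 'The capital of the United States is Washington, D.C.'\n",
    "bad = 'Joe Biden'\n",
    "print(contains_answer(good, 'Washington, D.C.'), contains_answer(bad, 'Washington, D.C.'))\n",
    "```\n",
    "\n",
    "A check like this treats `Washington, D.C` and `Washington, D.C.` as the same answer, while a response such as `Joe Biden` to the capital question counts as a mismatch."
   ]
  },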
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### GPU memory usage for this notebook\n",
    "Editing llama-3-8B-instruct requires 40G VRAM on GPU."
   ]
  },
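  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough back-of-the-envelope estimate helps to put this requirement in context. The sketch below assumes roughly 8 billion parameters and counts the weights only; the parameter count and the listed precisions are assumptions for illustration, and actual usage also depends on the editing method, activations, and optimizer state.\n",
    "\n",
    "```python\n",
    "def weight_memory_gib(num_params, bytes_per_param):\n",
    "    '''Memory needed to hold the weights alone, in GiB.'''\n",
    "    return num_params * bytes_per_param / 1024**3\n",
    "\n",
    "num_params = 8e9  # assumed: roughly 8 billion parameters\n",
    "for name, nbytes in [('float32', 4), ('float16/bfloat16', 2)]:\n",
    "    print(f'{name:<17}: {weight_memory_gib(num_params, nbytes):.1f} GiB')\n",
    "```\n",
    "\n",
    "Under these assumptions, full-precision weights alone occupy most of a 40 GB card, which leaves limited headroom for the extra memory that editing needs."
   ]
  },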
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prepare the runtime environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Clone repository, create environment Python 3.9, and install relevant libraries.<br> \n",
    "Please execute the following command on the **Terminal**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Clone Repo\n",
    "git clone https://github.com/zjunlp/EasyEdit.git\n",
    "cd EasyEdit\n",
    "\n",
    "## Create Environment\n",
    "conda create -n EasyEdit python=3.9\n",
    "conda activate EasyEdit\n",
    "pip install -r requirements.txt\n",
    "\n",
    "## Install Jupyter Notebook environment dependencies\n",
    "pip install ipykernel\n",
    "pip install ipywidgets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Download Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the following command to download `llama-3-8b-instruct` to the specified folder.<br>meta-llama needs to log in to apply for permission and add the `-- token your_token` in huggingface cli.\n",
    "For more information, please refer to [huggingface](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fetching 17 files: 100%|███████████████████████| 17/17 [00:00<00:00, 175.12it/s]\n",
      "/mnt/8t/xkw/EasyEdit/hugging_cache/llama-3-8b-instruct\n"
     ]
    }
   ],
   "source": [
    "!huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./hugging_cache/llama-3-8b-instruct  --token your_token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the evaluation function "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the current path is `EasyEdit/tutorial-notebooks`, run the following command to navigate to the `EasyEdit/` path."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3-8b-instruct')\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "tokenizer.padding_side='left'\n",
    "# use chat template to generate responses\n",
    "def evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=0):\n",
    "    device = f\"cuda:{device}\"\n",
    "    for i in range(len(Evaluation_prompts)):\n",
    "        messages = [\n",
    "            {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
    "            {\"role\": \"user\", \"content\": Evaluation_prompts[i] },\n",
    "        ]\n",
    "        input_ids = tokenizer.apply_chat_template(\n",
    "            messages,\n",
    "            add_generation_prompt=True,\n",
    "            return_tensors=\"pt\"\n",
    "        ).to(model.device)\n",
    "\n",
    "        terminators = [tokenizer.eos_token_id,tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")]\n",
    "        outputs = model.generate(\n",
    "            input_ids = input_ids,\n",
    "            max_new_tokens=40,\n",
    "            eos_token_id=terminators,\n",
    "            pad_token_id= tokenizer.eos_token_id,\n",
    "            do_sample=False\n",
    "        )\n",
    "        response = outputs[0][input_ids.shape[-1]:]\n",
    "        response = tokenizer.decode(response, skip_special_tokens=True)\n",
    "\n",
    "        print(f\"{Evaluation_metrics[i]:<14}:  {response}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set gpu device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## \n",
    "import torch\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Orignal Output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the output of the initial model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.007737636566162109,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d409f20c8534463b9cd82963221ec26c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  As of my knowledge cutoff, the current President of the United States is Joe Biden. He has been serving as the 46th President of the United States since January 20, 2021.\n",
      "Generalization:  As of my knowledge cutoff, the current President of the United States is Joe Biden. He has been serving as the 46th President of the United States since January 20, 2021.\n",
      "Locality      :  The capital of the United States is Washington, D.C. (short for District of Columbia).\n",
      "Portability   :  The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n"
     ]
    }
   ],
   "source": [
    "# output the response\n",
    "# evaluation metrics and questions\n",
    "Evaluation_metrics = [ \"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n",
    "Evaluation_prompts = [  \"Who is the current President of the United States?\" ,\n",
    "                        \"What is the name of the current President of the United States?\",\n",
    "                        \"Where is the capital of the United States?\" ,\n",
    "                        \"Where is the current U.S. President born ?\"]\n",
    "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics,device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory \n",
    "del model\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First Knowledge Editing\n",
    "`Joe Biden —> Donald Trump`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Edit data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from easyeditor import BaseEditor\n",
    "\n",
    "##  Edit once: Joe Biden ——> Donald Trump\n",
    "prompts = [\"Who is the current President of the United States?\" ]\n",
    "subject = ['President']\n",
    "ground_truth = ['Joe Biden']\n",
    "target_new = ['Donald Trump']\n",
    "\n",
    "# evaluation metrics and questions\n",
    "Evaluation_metrics = [ \"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n",
    "Evaluation_prompts = [  \"Who is the current President of the United States?\" ,\n",
    "                        \"What is the name of the current President of the United States?\",\n",
    "                        \"Where is the capital of the United States?\" ,\n",
    "                        \"Where is the current U.S. President born ?\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### AlphaEdit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculating the projection matrix P requires a lot of time and computing resources. <br>\n",
    "To avoid recalculating it every time during editing, the calculated P can be saved in advance. We provide the projection matrix P that we have calculated for Llama3-8B-instruct:\n",
    " [Google Drive](https://drive.google.com/file/d/1vr0Pcohb7pW3SWvGhFy7xrB8DdA9BVex/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1Sgfz2bqRiBZdkG3meTtWwA?pwd=dj9i) <br>\n",
    "After downloading, please modify the **P_loc** parameter in the **hparams/AlphaEdit/llama3-8b.yaml** to the download path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:20:27,743 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "11/28/2024 19:20:27 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0052433013916015625,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d24ace22f13e4c9d83e592931350239b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:20:31,520 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "11/28/2024 19:20:31 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 1/1 [00:00<00:00,  2.06it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing AlphaEdit algo for: [Who is the current {} of the United States?] -> [ Donald Trump]\n",
      "Cached context templates [['{}'], ['The 2019-20 season has been. {}', 'Therefore, we must not forget the importance of. {}', 'Because I am a woman: The impact of. {}', 'I have to admit, I was a bit. {}', \"You're here: Home » Resources » Blog. {}\"]]\n",
      "Computing right vector (v)\n",
      "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token:  President\n",
      "Rewrite layer is 8\n",
      "Tying optimization objective to 31\n",
      "Recording initial value of v*\n",
      "loss 2.262 = 2.262 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.12440355122089386\n",
      "loss 1.15 = 1.076 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.3479606807231903\n",
      "loss 0.498 = 0.397 + 0.034 + 0.067 avg prob of [ Donald Trump] 0.6736204028129578\n",
      "loss 0.123 = 0.044 + 0.012 + 0.067 avg prob of [ Donald Trump] 0.9572593569755554\n",
      "loss 0.079 = 0.005 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.9953699111938477\n",
      "loss 0.073 = 0.001 + 0.005 + 0.067 avg prob of [ Donald Trump] 0.9989446401596069\n",
      "loss 0.071 = 0.0 + 0.003 + 0.067 avg prob of [ Donald Trump] 0.9995914697647095\n",
      "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.999780535697937\n",
      "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.9998531341552734\n",
      "loss 0.069 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9998871684074402\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999055862426758\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999172687530518\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999253153800964\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999308586120605\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999343752861023\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999363422393799\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999370574951172\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999367594718933\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999357461929321\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999339580535889\n",
      "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999316930770874\n",
      "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999287128448486\n",
      "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999248385429382\n",
      "loss 0.067 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999186992645264\n",
      "loss 0.066 = 0.0 + 0.0 + 0.066 avg prob of [ Donald Trump] 0.9999083876609802\n",
      "Init norm 5.597234725952148 | Delta norm 4.119584560394287 | Target norm 6.845340251922607\n",
      "\n",
      "\n",
      "LAYER 4\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 4\n",
      "z error tensor(4.1196, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.4221, device='cuda:0')\n",
      "upd norm tensor(0.2889, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 5\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 5\n",
      "z error tensor(3.9697, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7669, device='cuda:0')\n",
      "upd norm tensor(0.3079, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 6\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 6\n",
      "z error tensor(3.8253, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7295, device='cuda:0')\n",
      "upd norm tensor(0.3584, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 7\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 7\n",
      "z error tensor(3.6551, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.8482, device='cuda:0')\n",
      "upd norm tensor(0.4613, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 8\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 8\n",
      "z error tensor(3.1808, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.5938, device='cuda:0')\n",
      "upd norm tensor(0.7683, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:22<00:00, 22.35s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n",
      "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:22<00:00, 22.35s/it]\n",
      "2024-11-28 19:21:00,789 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 19:21:00 - INFO - easyeditor.editors.editor -   0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n"
     ]
    }
   ],
   "source": [
    "from easyeditor import AlphaEditHyperParams\n",
    "\n",
    "# load hparams and editor\n",
    "hparams=AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n",
    "editor=BaseEditor.from_hparams(hparams)\n",
    "### start editing\n",
    "\"\"\"\n",
    "`prompts`: list or str\n",
    "    the prompts to edit\n",
    "`ground_truth`: list or str\n",
    "    the ground truth / expected output\n",
    "`target_new`: list or str\n",
    "    the edited output\n",
    "`subject`: list or str\n",
    "    the subject to edit\n",
    "`sequential_edit`: bool\n",
    "    whether to edit sequentially or not(single edit)\n",
    "\"\"\"\n",
    "metrics, edited_model, weights_copy = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  Donald Trump is the 45th and current President of the United States. He was inaugurated on January 20, 2017.\n",
      "Generalization:  The current President of the United States is Donald Trump. He has been in office since January 20, 2017.\n",
      "Locality      :  The capital of the United States is Washington, D.C. (short for District of Columbia).\n",
      "Portability   :  The current U.S. President, Donald Trump, was born on June 14, 1946, in Queens, New York City, New York.\n"
     ]
    }
   ],
   "source": [
    "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del edited_model, weights_copy, editor\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### WISE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:21:26,655 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 19:21:26,655 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "11/28/2024 19:21:26 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004112720489501953,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d99901448f5a4ceb95faf9c7b3778fdd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:21:30,025 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "2024-11-28 19:21:30,025 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "11/28/2024 19:21:30 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "100%|██████████| 1/1 [00:00<00:00, 10.01it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New weights successfully inserted into model.layers[29].mlp.down_proj.weight\n",
      "Executing WISE algorithm for the update: \n",
      "[Who is the current President of the United States?] -> [Donald Trump]\n",
      "loss 36.272 = 6.272 + 30.0\n",
      "loss 27.288 = 5.975 + 21.313\n",
      "loss 13.35 = 0.0 + 13.35\n",
      "loss 8.639 = 0.0 + 8.639\n",
      "loss 8.22 = 0.0 + 8.22\n",
      "loss 4.34 = 0.0 + 4.34\n",
      "loss 2.94 = 0.0 + 2.94\n",
      "loss 2.562 = 0.0 + 2.562\n",
      "loss 4.352 = 0.0 + 4.352\n",
      "loss 1.973 = 0.0 + 1.973\n",
      "loss 1.471 = 0.0 + 1.471\n",
      "loss 1.458 = 0.0 + 1.457\n",
      "loss 1.21 = 0.0 + 1.21\n",
      "loss 1.056 = 0.0 + 1.056\n",
      "loss 0.894 = 0.0 + 0.894\n",
      "loss 0.802 = 0.0 + 0.802\n",
      "loss 0.719 = 0.0 + 0.719\n",
      "loss 0.77 = 0.0 + 0.77\n",
      "loss 0.734 = 0.0 + 0.734\n",
      "loss 0.841 = 0.0 + 0.841\n",
      "loss 0.66 = 0.0 + 0.66\n",
      "loss 0.74 = 0.0 + 0.74\n",
      "loss 0.739 = 0.0 + 0.739\n",
      "loss 0.743 = 0.0 + 0.743\n",
      "loss 3.679 = 0.0 + 3.679\n",
      "loss 0.756 = 0.0 + 0.756\n",
      "loss 0.412 = 0.0 + 0.412\n",
      "loss 0.75 = 0.0 + 0.75\n",
      "loss 0.442 = 0.0 + 0.442\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:21<00:00, 21.15s/it]\n",
      "2024-11-28 19:21:57,132 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:21:57,132 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 19:21:57 - INFO - easyeditor.editors.editor -   0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.919 = 0.0 + 0.919\n",
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 0.0}}\n"
     ]
    }
   ],
   "source": [
    "from easyeditor import WISEHyperParams\n",
    "\n",
    "# loc_prompts: used to provide xi in Equation 5 in the paper.\n",
    "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"]\n",
    "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, weights_copy  = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    loc_prompts=loc_prompts,\n",
    "    sequential_edit=True,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated as the President of the United Donald Trump is\n",
      "Generalization:  As of my knowledge cutoff, the current President of the United Donald Trump. Donald Trump is the 45th President of the United Donald Trump was inaugurated on January 20, Donald Trump is a\n",
      "Locality      :  The capital of the United Donald States is Washington, D.C. Donald\n",
      "Portability   :  The current President of the United Donald Trump was born on June 14, Donald Trump was born on June 14, 1946, in Queens, New York City, New York Donald Trump was\n"
     ]
    }
   ],
   "source": [
    "# output the response\n",
    "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del edited_model, weights_copy, editor\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### AdaLoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:24:08,022 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 19:24:08,022 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 19:24:08,022 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "11/28/2024 19:24:08 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.005603313446044922,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a18d7071aa774e63ba756e14304a1ffd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:24:11,426 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "2024-11-28 19:24:11,426 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "2024-11-28 19:24:11,426 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "11/28/2024 19:24:11 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "100%|██████████| 1/1 [00:00<00:00, 10.05it/s]\n",
      "  0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 5,112,576 || all params: 8,035,373,888 || trainable%: 0.06362586323002473\n",
      "Executing LoRA algo for: [Who is the current President of the United States?] -> [Donald Trump]\n",
      "====================\n",
      "Epoch: 0\n",
      "====================\n",
      "Batch loss 2.644674777984619\n",
      "Total loss 2.644674777984619\n",
      "====================\n",
      "Epoch: 1\n",
      "====================\n",
      "Batch loss 1.3671026229858398\n",
      "Total loss 1.3671026229858398\n",
      "====================\n",
      "Epoch: 2\n",
      "====================\n",
      "Batch loss 0.5417388677597046\n",
      "Total loss 0.5417388677597046\n",
      "====================\n",
      "Epoch: 3\n",
      "====================\n",
      "Batch loss 0.5220333337783813\n",
      "Total loss 0.5220333337783813\n",
      "====================\n",
      "Epoch: 4\n",
      "====================\n",
      "Batch loss 0.46118059754371643\n",
      "Total loss 0.46118059754371643\n",
      "====================\n",
      "Epoch: 5\n",
      "====================\n",
      "Batch loss 0.39273202419281006\n",
      "Total loss 0.39273202419281006\n",
      "====================\n",
      "Epoch: 6\n",
      "====================\n",
      "Batch loss 0.38071152567863464\n",
      "Total loss 0.38071152567863464\n",
      "====================\n",
      "Epoch: 7\n",
      "====================\n",
      "Batch loss 0.34029048681259155\n",
      "Total loss 0.34029048681259155\n",
      "====================\n",
      "Epoch: 8\n",
      "====================\n",
      "Batch loss 0.2765066921710968\n",
      "Total loss 0.2765066921710968\n",
      "====================\n",
      "Epoch: 9\n",
      "====================\n",
      "Batch loss 0.24984921514987946\n",
      "Total loss 0.24984921514987946\n",
      "====================\n",
      "Epoch: 10\n",
      "====================\n",
      "Batch loss 0.24300611019134521\n",
      "Total loss 0.24300611019134521\n",
      "====================\n",
      "Epoch: 11\n",
      "====================\n",
      "Batch loss 0.2208082228899002\n",
      "Total loss 0.2208082228899002\n",
      "====================\n",
      "Epoch: 12\n",
      "====================\n",
      "Batch loss 0.19554130733013153\n",
      "Total loss 0.19554130733013153\n",
      "====================\n",
      "Epoch: 13\n",
      "====================\n",
      "Batch loss 0.17294639348983765\n",
      "Total loss 0.17294639348983765\n",
      "====================\n",
      "Epoch: 14\n",
      "====================\n",
      "Batch loss 0.1557372808456421\n",
      "Total loss 0.1557372808456421\n",
      "====================\n",
      "Epoch: 15\n",
      "====================\n",
      "Batch loss 0.14293231070041656\n",
      "Total loss 0.14293231070041656\n",
      "====================\n",
      "Epoch: 16\n",
      "====================\n",
      "Batch loss 0.13104452192783356\n",
      "Total loss 0.13104452192783356\n",
      "====================\n",
      "Epoch: 17\n",
      "====================\n",
      "Batch loss 0.12185845524072647\n",
      "Total loss 0.12185845524072647\n",
      "====================\n",
      "Epoch: 18\n",
      "====================\n",
      "Batch loss 0.11389356851577759\n",
      "Total loss 0.11389356851577759\n",
      "====================\n",
      "Epoch: 19\n",
      "====================\n",
      "Batch loss 0.10448896884918213\n",
      "Total loss 0.10448896884918213\n",
      "====================\n",
      "Epoch: 20\n",
      "====================\n",
      "Batch loss 0.09858526289463043\n",
      "Total loss 0.09858526289463043\n",
      "====================\n",
      "Epoch: 21\n",
      "====================\n",
      "Batch loss 0.09502541273832321\n",
      "Total loss 0.09502541273832321\n",
      "====================\n",
      "Epoch: 22\n",
      "====================\n",
      "Batch loss 0.08764241635799408\n",
      "Total loss 0.08764241635799408\n",
      "====================\n",
      "Epoch: 23\n",
      "====================\n",
      "Batch loss 0.07999030500650406\n",
      "Total loss 0.07999030500650406\n",
      "====================\n",
      "Epoch: 24\n",
      "====================\n",
      "Batch loss 0.07751280814409256\n",
      "Total loss 0.07751280814409256\n",
      "====================\n",
      "Epoch: 25\n",
      "====================\n",
      "Batch loss 0.07406788319349289\n",
      "Total loss 0.07406788319349289\n",
      "====================\n",
      "Epoch: 26\n",
      "====================\n",
      "Batch loss 0.06950584053993225\n",
      "Total loss 0.06950584053993225\n",
      "====================\n",
      "Epoch: 27\n",
      "====================\n",
      "Batch loss 0.06646408140659332\n",
      "Total loss 0.06646408140659332\n",
      "====================\n",
      "Epoch: 28\n",
      "====================\n",
      "Batch loss 0.06395122408866882\n",
      "Total loss 0.06395122408866882\n",
      "====================\n",
      "Epoch: 29\n",
      "====================\n",
      "Batch loss 0.061113737523555756\n",
      "Total loss 0.061113737523555756\n",
      "====================\n",
      "Epoch: 30\n",
      "====================\n",
      "Batch loss 0.05819045379757881\n",
      "Total loss 0.05819045379757881\n",
      "====================\n",
      "Epoch: 31\n",
      "====================\n",
      "Batch loss 0.059152137488126755\n",
      "Total loss 0.059152137488126755\n",
      "====================\n",
      "Epoch: 32\n",
      "====================\n",
      "Batch loss 0.055777959525585175\n",
      "Total loss 0.055777959525585175\n",
      "====================\n",
      "Epoch: 33\n",
      "====================\n",
      "Batch loss 0.05465450882911682\n",
      "Total loss 0.05465450882911682\n",
      "====================\n",
      "Epoch: 34\n",
      "====================\n",
      "Batch loss 0.052776530385017395\n",
      "Total loss 0.052776530385017395\n",
      "====================\n",
      "Epoch: 35\n",
      "====================\n",
      "Batch loss 0.052718229591846466\n",
      "Total loss 0.052718229591846466\n",
      "====================\n",
      "Epoch: 36\n",
      "====================\n",
      "Batch loss 0.05029246583580971\n",
      "Total loss 0.05029246583580971\n",
      "====================\n",
      "Epoch: 37\n",
      "====================\n",
      "Batch loss 0.049046844244003296\n",
      "Total loss 0.049046844244003296\n",
      "====================\n",
      "Epoch: 38\n",
      "====================\n",
      "Batch loss 0.0500873401761055\n",
      "Total loss 0.0500873401761055\n",
      "====================\n",
      "Epoch: 39\n",
      "====================\n",
      "Batch loss 0.04929978400468826\n",
      "Total loss 0.04929978400468826\n",
      "====================\n",
      "Epoch: 40\n",
      "====================\n",
      "Batch loss 0.04614608362317085\n",
      "Total loss 0.04614608362317085\n",
      "====================\n",
      "Epoch: 41\n",
      "====================\n",
      "Batch loss 0.04575539752840996\n",
      "Total loss 0.04575539752840996\n",
      "====================\n",
      "Epoch: 42\n",
      "====================\n",
      "Batch loss 0.04531551152467728\n",
      "Total loss 0.04531551152467728\n",
      "====================\n",
      "Epoch: 43\n",
      "====================\n",
      "Batch loss 0.043939895927906036\n",
      "Total loss 0.043939895927906036\n",
      "====================\n",
      "Epoch: 44\n",
      "====================\n",
      "Batch loss 0.043360836803913116\n",
      "Total loss 0.043360836803913116\n",
      "====================\n",
      "Epoch: 45\n",
      "====================\n",
      "Batch loss 0.042364198714494705\n",
      "Total loss 0.042364198714494705\n",
      "====================\n",
      "Epoch: 46\n",
      "====================\n",
      "Batch loss 0.04135165736079216\n",
      "Total loss 0.04135165736079216\n",
      "====================\n",
      "Epoch: 47\n",
      "====================\n",
      "Batch loss 0.04179595783352852\n",
      "Total loss 0.04179595783352852\n",
      "====================\n",
      "Epoch: 48\n",
      "====================\n",
      "Batch loss 0.04010486975312233\n",
      "Total loss 0.04010486975312233\n",
      "====================\n",
      "Epoch: 49\n",
      "====================\n",
      "Batch loss 0.04130683094263077\n",
      "Total loss 0.04130683094263077\n",
      "====================\n",
      "Epoch: 50\n",
      "====================\n",
      "Batch loss 0.04026680812239647\n",
      "Total loss 0.04026680812239647\n",
      "====================\n",
      "Epoch: 51\n",
      "====================\n",
      "Batch loss 0.040530476719141006\n",
      "Total loss 0.040530476719141006\n",
      "====================\n",
      "Epoch: 52\n",
      "====================\n",
      "Batch loss 0.038930464535951614\n",
      "Total loss 0.038930464535951614\n",
      "====================\n",
      "Epoch: 53\n",
      "====================\n",
      "Batch loss 0.039832133799791336\n",
      "Total loss 0.039832133799791336\n",
      "====================\n",
      "Epoch: 54\n",
      "====================\n",
      "Batch loss 0.03932514414191246\n",
      "Total loss 0.03932514414191246\n",
      "====================\n",
      "Epoch: 55\n",
      "====================\n",
      "Batch loss 0.038851384073495865\n",
      "Total loss 0.038851384073495865\n",
      "====================\n",
      "Epoch: 56\n",
      "====================\n",
      "Batch loss 0.03995289281010628\n",
      "Total loss 0.03995289281010628\n",
      "====================\n",
      "Epoch: 57\n",
      "====================\n",
      "Batch loss 0.039996594190597534\n",
      "Total loss 0.039996594190597534\n",
      "====================\n",
      "Epoch: 58\n",
      "====================\n",
      "Batch loss 0.040175165981054306\n",
      "Total loss 0.040175165981054306\n",
      "====================\n",
      "Epoch: 59\n",
      "====================\n",
      "Batch loss 0.04052826017141342\n",
      "Total loss 0.04052826017141342\n",
      "====================\n",
      "Epoch: 60\n",
      "====================\n",
      "Batch loss 0.03971166908740997\n",
      "Total loss 0.03971166908740997\n",
      "====================\n",
      "Epoch: 61\n",
      "====================\n",
      "Batch loss 0.039919186383485794\n",
      "Total loss 0.039919186383485794\n",
      "====================\n",
      "Epoch: 62\n",
      "====================\n",
      "Batch loss 0.04040021449327469\n",
      "Total loss 0.04040021449327469\n",
      "====================\n",
      "Epoch: 63\n",
      "====================\n",
      "Batch loss 0.03939478099346161\n",
      "Total loss 0.03939478099346161\n",
      "====================\n",
      "Epoch: 64\n",
      "====================\n",
      "Batch loss 0.04028838872909546\n",
      "Total loss 0.04028838872909546\n",
      "====================\n",
      "Epoch: 65\n",
      "====================\n",
      "Batch loss 0.03863275796175003\n",
      "Total loss 0.03863275796175003\n",
      "====================\n",
      "Epoch: 66\n",
      "====================\n",
      "Batch loss 0.039747681468725204\n",
      "Total loss 0.039747681468725204\n",
      "====================\n",
      "Epoch: 67\n",
      "====================\n",
      "Batch loss 0.039060089737176895\n",
      "Total loss 0.039060089737176895\n",
      "====================\n",
      "Epoch: 68\n",
      "====================\n",
      "Batch loss 0.03825884312391281\n",
      "Total loss 0.03825884312391281\n",
      "====================\n",
      "Epoch: 69\n",
      "====================\n",
      "Batch loss 0.03832217678427696\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:17<00:00, 17.43s/it]\n",
      "2024-11-28 19:24:34,784 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:24:34,784 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:24:34,784 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 19:24:34 - INFO - easyeditor.editors.editor -   0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total loss 0.03832217678427696\n",
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.5}, 'post': {'rewrite_acc': 1.0}}\n"
     ]
    }
   ],
   "source": [
    "from easyeditor import LoRAHyperParams\n",
    "\n",
    "hparams=LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n",
    "editor=BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, weights_copy = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  Donald Trump. He was inaugurated as the 45th President of the United States on January 20, 2017.\n",
      "Generalization:  Donald Trump.\n",
      "Locality      :  The capital of the United States is Washington, D.C.\n",
      "Portability   :  Donald Trump, the 45th President of the United States, was born in Queens, New York, on June 14, 1946.\n"
     ]
    }
   ],
   "source": [
    "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del edited_model, weights_copy, editor\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004334449768066406,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9dcbcbc0cd14e7b8b54708c60aee0a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import LlamaForCausalLM\n",
    "\n",
    "# load the original model\n",
    "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  Based on the information provided, the current President of the United States is Donald Trump.\n",
      "Generalization:  Based on the information provided, the current President of the United States is Donald Trump.\n",
      "Locality      :  I'm happy to help! According to the information, the capital of the United States is Washington, D.C.\n",
      "Portability   :  According to the information, the current U.S. President is Donald Trump, and he was born in Queens, New York.\n"
     ]
    }
   ],
   "source": [
    "# original evaluation prompts\n",
    "Evaluation_prompts = [  \"Who is the current President of the United States?\" ,\n",
    "                        \"What is the name of the current President of the United States?\",\n",
    "                        \"Where is the capital of the United States?\" ,\n",
    "                        \"Where is the current U.S. President born ?\"]\n",
    "\n",
    "# add edit prompt of the U.S. President change\n",
    "edit_prompt = 'Information: The U.S. President changed from Biden to Donald Trump. Based on the information, answer the following questions and dont answer I cant provide information:'\n",
    "Evaluation_prompts = [ edit_prompt + ' ' + prompt for prompt in Evaluation_prompts]\n",
    "evaluate_chat_template(model, Evaluation_prompts,Evaluation_metrics, device=0)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del model\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Second Edit\n",
    "\n",
    "Joe Biden —> Donald Trump —> Joe Biden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Edit data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from easyeditor import BaseEditor\n",
    "\n",
    "## Edit twice: Joe Biden —> Donald Trump —> Joe Biden\n",
    "prompts = [\"Who is the current President of the United States?\",\n",
    "           \"Who is the current President of the United States?\" ]\n",
    "subject = ['President', 'President']\n",
    "ground_truth = ['Joe Biden',  'Donald Trump']\n",
    "target_new =  ['Donald Trump', 'Joe Biden']\n",
    "\n",
    "# evaluation metrics and questions\n",
    "Evaluation_metrics = [ \"Reliability\",\"Generalization\", \"Locality\", \"Portability\"]\n",
    "Evaluation_prompts = [  \"Who is the current President of the United States?\" ,\n",
    "                        \"What is the name of the current President of the United States?\",\n",
    "                        \"Where is the capital of the United States?\" ,\n",
    "                        \"Where is the current U.S. President born ?\"]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### AlphaEdit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculating the projection matrix P requires a lot of time and computing resources. <br>\n",
    "To avoid recalculating it every time during editing, the calculated P can be saved in advance. We provide the projection matrix P that we have calculated for Llama3-8B-instruct:\n",
    " [Google Drive](https://drive.google.com/file/d/1vr0Pcohb7pW3SWvGhFy7xrB8DdA9BVex/view?usp=sharing) or [Baidu Pan](https://pan.baidu.com/s/1Sgfz2bqRiBZdkG3meTtWwA?pwd=dj9i) <br>\n",
    "After downloading, please modify the **P_loc** parameter in the **hparams/AlphaEdit/llama3-8b.yaml** to the download path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:27:28,527 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 19:27:28,527 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 19:27:28,527 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 19:27:28,527 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "11/28/2024 19:27:28 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.005516767501831055,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96f7c02a38104eb1bf9512ab758d7eda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 19:27:32,243 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "2024-11-28 19:27:32,243 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "2024-11-28 19:27:32,243 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "2024-11-28 19:27:32,243 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "11/28/2024 19:27:32 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to right...\n",
      "100%|██████████| 2/2 [00:00<00:00, 10.33it/s]\n",
      "  0%|          | 0/2 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing AlphaEdit algo for: [Who is the current {} of the United States?] -> [ Donald Trump]\n",
      "Computing right vector (v)\n",
      "Lookup index found: 5 | Sentence: Who is the current President of the United States? Donald | Token:  President\n",
      "Rewrite layer is 8\n",
      "Tying optimization objective to 31\n",
      "Recording initial value of v*\n",
      "loss 2.262 = 2.262 + 0.0 + 0.0 avg prob of [ Donald Trump] 0.12440355122089386\n",
      "loss 1.15 = 1.076 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.3479606807231903\n",
      "loss 0.498 = 0.397 + 0.034 + 0.067 avg prob of [ Donald Trump] 0.6736204028129578\n",
      "loss 0.123 = 0.044 + 0.012 + 0.067 avg prob of [ Donald Trump] 0.9572593569755554\n",
      "loss 0.079 = 0.005 + 0.007 + 0.067 avg prob of [ Donald Trump] 0.9953699111938477\n",
      "loss 0.073 = 0.001 + 0.005 + 0.067 avg prob of [ Donald Trump] 0.9989446401596069\n",
      "loss 0.071 = 0.0 + 0.003 + 0.067 avg prob of [ Donald Trump] 0.9995914697647095\n",
      "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.999780535697937\n",
      "loss 0.069 = 0.0 + 0.002 + 0.067 avg prob of [ Donald Trump] 0.9998531341552734\n",
      "loss 0.069 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9998871684074402\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999055862426758\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999172687530518\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999253153800964\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999308586120605\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999343752861023\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999363422393799\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999370574951172\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999367594718933\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999357461929321\n",
      "loss 0.068 = 0.0 + 0.001 + 0.067 avg prob of [ Donald Trump] 0.9999339580535889\n",
      "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999316930770874\n",
      "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999287128448486\n",
      "loss 0.068 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999248385429382\n",
      "loss 0.067 = 0.0 + 0.0 + 0.067 avg prob of [ Donald Trump] 0.9999186992645264\n",
      "loss 0.066 = 0.0 + 0.0 + 0.066 avg prob of [ Donald Trump] 0.9999083876609802\n",
      "Init norm 5.597234725952148 | Delta norm 4.119584560394287 | Target norm 6.845340251922607\n",
      "\n",
      "\n",
      "LAYER 4\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 4\n",
      "z error tensor(4.1196, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.4221, device='cuda:0')\n",
      "upd norm tensor(0.1557, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 5\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 5\n",
      "z error tensor(4.0216, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7669, device='cuda:0')\n",
      "upd norm tensor(0.2383, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 6\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 6\n",
      "z error tensor(3.8727, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7295, device='cuda:0')\n",
      "upd norm tensor(0.4076, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 7\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 7\n",
      "z error tensor(3.6657, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.8482, device='cuda:0')\n",
      "upd norm tensor(0.6585, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 8\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 8\n",
      "z error tensor(3.2423, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.5938, device='cuda:0')\n",
      "upd norm tensor(1.2122, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:19<00:19, 19.25s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n",
      "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n",
      "Executing AlphaEdit algo for: [Who is the current {} of the United States?] -> [ Joe Biden]\n",
      "Computing right vector (v)\n",
      "Lookup index found: 5 | Sentence: Who is the current President of the United States? Joe | Token:  President\n",
      "Rewrite layer is 8\n",
      "Tying optimization objective to 31\n",
      "Recording initial value of v*\n",
      "loss 6.45 = 6.45 + 0.0 + 0.0 avg prob of [ Joe Biden] 0.005020033568143845\n",
      "loss 1.055 = 0.991 + 0.004 + 0.06 avg prob of [ Joe Biden] 0.3920266330242157\n",
      "loss 0.655 = 0.589 + 0.005 + 0.06 avg prob of [ Joe Biden] 0.5779968500137329\n",
      "loss 0.367 = 0.299 + 0.008 + 0.06 avg prob of [ Joe Biden] 0.7573310732841492\n",
      "loss 0.216 = 0.145 + 0.011 + 0.06 avg prob of [ Joe Biden] 0.8699111938476562\n",
      "loss 0.145 = 0.071 + 0.014 + 0.06 avg prob of [ Joe Biden] 0.9321667551994324\n",
      "loss 0.115 = 0.04 + 0.015 + 0.06 avg prob of [ Joe Biden] 0.9607087969779968\n",
      "loss 0.101 = 0.027 + 0.015 + 0.06 avg prob of [ Joe Biden] 0.9736617803573608\n",
      "loss 0.093 = 0.02 + 0.014 + 0.06 avg prob of [ Joe Biden] 0.9807038307189941\n",
      "loss 0.087 = 0.015 + 0.012 + 0.06 avg prob of [ Joe Biden] 0.9852293729782104\n",
      "loss 0.082 = 0.012 + 0.01 + 0.06 avg prob of [ Joe Biden] 0.988397479057312\n",
      "loss 0.078 = 0.009 + 0.009 + 0.06 avg prob of [ Joe Biden] 0.9907046556472778\n",
      "loss 0.075 = 0.008 + 0.008 + 0.06 avg prob of [ Joe Biden] 0.9924201965332031\n",
      "loss 0.073 = 0.006 + 0.007 + 0.06 avg prob of [ Joe Biden] 0.9937169551849365\n",
      "loss 0.071 = 0.005 + 0.006 + 0.06 avg prob of [ Joe Biden] 0.9947113394737244\n",
      "loss 0.069 = 0.005 + 0.005 + 0.06 avg prob of [ Joe Biden] 0.9954854846000671\n",
      "loss 0.068 = 0.004 + 0.004 + 0.06 avg prob of [ Joe Biden] 0.9960974454879761\n",
      "loss 0.067 = 0.003 + 0.004 + 0.06 avg prob of [ Joe Biden] 0.9965887069702148\n",
      "loss 0.066 = 0.003 + 0.003 + 0.06 avg prob of [ Joe Biden] 0.9969896078109741\n",
      "loss 0.065 = 0.003 + 0.003 + 0.06 avg prob of [ Joe Biden] 0.9973213076591492\n",
      "loss 0.065 = 0.002 + 0.002 + 0.06 avg prob of [ Joe Biden] 0.9975996017456055\n",
      "loss 0.064 = 0.002 + 0.002 + 0.06 avg prob of [ Joe Biden] 0.9978362321853638\n",
      "loss 0.064 = 0.002 + 0.002 + 0.06 avg prob of [ Joe Biden] 0.9980394244194031\n",
      "loss 0.063 = 0.002 + 0.002 + 0.06 avg prob of [ Joe Biden] 0.9982155561447144\n",
      "loss 0.063 = 0.002 + 0.002 + 0.06 avg prob of [ Joe Biden] 0.9983690977096558\n",
      "Init norm 6.254525184631348 | Delta norm 4.69089412689209 | Target norm 7.72365665435791\n",
      "\n",
      "\n",
      "LAYER 4\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 4\n",
      "z error tensor(4.6909, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.4223, device='cuda:0')\n",
      "upd norm tensor(0.1213, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 5\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 5\n",
      "z error tensor(4.5988, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7677, device='cuda:0')\n",
      "upd norm tensor(0.2893, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 6\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 6\n",
      "z error tensor(4.4502, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7315, device='cuda:0')\n",
      "upd norm tensor(0.5622, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 7\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 7\n",
      "z error tensor(4.1311, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.8523, device='cuda:0')\n",
      "upd norm tensor(0.8597, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 8\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 8\n",
      "z error tensor(3.4417, device='cuda:0', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.6069, device='cuda:0')\n",
      "upd norm tensor(1.2777, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:38<00:00, 19.15s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n",
      "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "2024-11-28 19:28:16,217 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,217 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,217 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,217 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 19:28:16 - INFO - easyeditor.editors.editor -   0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,294 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,294 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,294 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 19:28:16,294 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 19:28:16 - INFO - easyeditor.editors.editor -   1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.75}}\n"
     ]
    }
   ],
   "source": [
    "from easyeditor import AlphaEditHyperParams\n",
    "\n",
    "hparams = AlphaEditHyperParams.from_hparams('./hparams/AlphaEdit/llama3-8b.yaml')\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, weights_copy = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  The current President of the United States is Joe Biden.\n",
      "Generalization:  The current President of the United States is Joe Biden.\n",
      "Locality      :  The capital of the United States is Washington, D.C. (short for District of Columbia).\n",
      "Portability   :  The current U.S. President, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942.\n"
     ]
    }
   ],
   "source": [
    "evaluate_chat_template(edited_model,Evaluation_prompts, Evaluation_metrics, device=hparams.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del edited_model, weights_copy, editor\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### WISE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 20:01:24,236 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "11/28/2024 20:01:24 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004523038864135742,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1b57578912fd4c568bfa140554e2ad4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 20:01:27,499 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "11/28/2024 20:01:27 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "  0%|          | 0/2 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 2/2 [00:00<00:00,  4.32it/s]\n",
      "  0%|          | 0/2 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New weights successfully inserted into model.layers[29].mlp.down_proj.weight\n",
      "Executing WISE algorithm for the update: \n",
      "[Who is the current President of the United States?] -> [Donald Trump]\n",
      "loss 36.405 = 6.405 + 30.0\n",
      "loss 28.036 = 6.264 + 21.772\n",
      "loss 14.637 = 0.0 + 14.637\n",
      "loss 9.41 = 0.0 + 9.41\n",
      "loss 8.296 = 0.0 + 8.296\n",
      "loss 4.282 = 0.0 + 4.282\n",
      "loss 2.84 = 0.0 + 2.84\n",
      "loss 2.99 = 0.0 + 2.99\n",
      "loss 4.847 = 0.0 + 4.847\n",
      "loss 2.003 = 0.0 + 2.003\n",
      "loss 1.418 = 0.0 + 1.418\n",
      "loss 1.32 = 0.0 + 1.32\n",
      "loss 1.121 = 0.0 + 1.121\n",
      "loss 0.989 = 0.0 + 0.989\n",
      "loss 0.861 = 0.0 + 0.861\n",
      "loss 0.855 = 0.0 + 0.855\n",
      "loss 0.786 = 0.0 + 0.786\n",
      "loss 0.849 = 0.0 + 0.849\n",
      "loss 0.658 = 0.0 + 0.658\n",
      "loss 0.767 = 0.0 + 0.767\n",
      "loss 0.825 = 0.0 + 0.825\n",
      "loss 3.049 = 0.0 + 3.049\n",
      "loss 0.738 = 0.0 + 0.738\n",
      "loss 0.539 = 0.0 + 0.539\n",
      "loss 0.688 = 0.0 + 0.688\n",
      "loss 0.681 = 0.0 + 0.681\n",
      "loss 0.861 = 0.0 + 0.861\n",
      "loss 0.588 = 0.0 + 0.588\n",
      "loss 0.631 = 0.0 + 0.631\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:20<00:20, 20.84s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.811 = 0.0 + 0.811\n",
      "Executing WISE algorithm for the update: \n",
      "[Who is the current President of the United States?] -> [Joe Biden]\n",
      "loss 20.97 = 17.749 + 3.222\n",
      "loss 5.545 = 0.56 + 4.985\n",
      "loss 4.117 = 0.09 + 4.027\n",
      "loss 3.137 = 0.002 + 3.134\n",
      "loss 2.515 = 0.002 + 2.513\n",
      "loss 2.1 = 0.002 + 2.097\n",
      "loss 2.835 = 0.002 + 2.833\n",
      "loss 1.941 = 0.002 + 1.94\n",
      "loss 1.539 = 0.002 + 1.537\n",
      "loss 1.286 = 0.002 + 1.284\n",
      "loss 1.111 = 0.002 + 1.11\n",
      "loss 0.975 = 0.002 + 0.974\n",
      "loss 0.881 = 0.001 + 0.88\n",
      "loss 0.823 = 0.001 + 0.821\n",
      "loss 0.73 = 0.001 + 0.729\n",
      "loss 0.697 = 0.001 + 0.695\n",
      "loss 0.638 = 0.001 + 0.636\n",
      "loss 0.568 = 0.001 + 0.567\n",
      "loss 0.547 = 0.001 + 0.546\n",
      "loss 0.468 = 0.001 + 0.467\n",
      "loss 0.43 = 0.001 + 0.428\n",
      "loss 0.443 = 0.001 + 0.442\n",
      "loss 0.388 = 0.001 + 0.387\n",
      "loss 0.339 = 0.001 + 0.338\n",
      "loss 0.37 = 0.001 + 0.368\n",
      "loss 0.37 = 0.001 + 0.369\n",
      "loss 0.319 = 0.001 + 0.317\n",
      "loss 0.353 = 0.001 + 0.352\n",
      "loss 0.371 = 0.001 + 0.369\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:39<00:00, 19.61s/it]\n",
      "2024-11-28 20:02:13,201 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 20:02:13 - INFO - easyeditor.editors.editor -   0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\"}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 20:02:13,275 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 20:02:13 - INFO - easyeditor.editors.editor -   1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul'}, 'post': {'rewrite_acc': [0.5], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.312 = 0.001 + 0.311\n",
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.25}, 'post': {'rewrite_acc': 0.25}}\n"
     ]
    }
   ],
   "source": [
    "from easyeditor import WISEHyperParams\n",
    "\n",
    "# loc_prompts: used to provide xi in Equation 5 in the paper.\n",
    "loc_prompts = [\"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", \n",
    "               'nq question: where are the winter olympics going to be Seoul']\n",
    "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama3-8b.yaml')\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, weights_copy = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    loc_prompts=loc_prompts,\n",
    "    sequential_edit=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, \n",
      "Generalization:  As of my knowledge cutoff, the current President of the United States is Joe Biden. Joe Biden is the 46th President of the United States and has been in office since January 20, \n",
      "Locality      :  The capital of the United States is Washington, D.C. Donald Trump was the 45th President of the United States. Joe Biden is the current President of the United Donald Trump Joe Biden Joe Biden\n",
      "Portability   :  The current President of the United States, Joe Biden, was born in Scranton, Pennsylvania, on November 20, 1942. Joe Biden was born in Scranton's St. Joseph's\n"
     ]
    }
   ],
   "source": [
    "evaluate_chat_template(edited_model, Evaluation_prompts,Evaluation_metrics,device=hparams.device )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del edited_model, weights_copy, editor\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### AdaLoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 20:03:26,105 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2024-11-28 20:03:26,105 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "11/28/2024 20:03:26 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0052950382232666016,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "eccc5ce660a24e75a2dee3a832d7ad9e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-11-28 20:03:29,308 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "2024-11-28 20:03:29,308 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "11/28/2024 20:03:29 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "100%|██████████| 2/2 [00:00<00:00, 11.32it/s]\n",
      "  0%|          | 0/2 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 5,112,576 || all params: 8,035,373,888 || trainable%: 0.06362586323002473\n",
      "Executing LoRA algo for: [Who is the current President of the United States?] -> [Donald Trump]\n",
      "====================\n",
      "Epoch: 0\n",
      "====================\n",
      "Batch loss 2.642152786254883\n",
      "Total loss 2.642152786254883\n",
      "====================\n",
      "Epoch: 1\n",
      "====================\n",
      "Batch loss 1.339514970779419\n",
      "Total loss 1.339514970779419\n",
      "====================\n",
      "Epoch: 2\n",
      "====================\n",
      "Batch loss 0.5425665378570557\n",
      "Total loss 0.5425665378570557\n",
      "====================\n",
      "Epoch: 3\n",
      "====================\n",
      "Batch loss 0.5246521830558777\n",
      "Total loss 0.5246521830558777\n",
      "====================\n",
      "Epoch: 4\n",
      "====================\n",
      "Batch loss 0.4616990089416504\n",
      "Total loss 0.4616990089416504\n",
      "====================\n",
      "Epoch: 5\n",
      "====================\n",
      "Batch loss 0.3893696069717407\n",
      "Total loss 0.3893696069717407\n",
      "====================\n",
      "Epoch: 6\n",
      "====================\n",
      "Batch loss 0.37634119391441345\n",
      "Total loss 0.37634119391441345\n",
      "====================\n",
      "Epoch: 7\n",
      "====================\n",
      "Batch loss 0.33565640449523926\n",
      "Total loss 0.33565640449523926\n",
      "====================\n",
      "Epoch: 8\n",
      "====================\n",
      "Batch loss 0.27160176634788513\n",
      "Total loss 0.27160176634788513\n",
      "====================\n",
      "Epoch: 9\n",
      "====================\n",
      "Batch loss 0.24395664036273956\n",
      "Total loss 0.24395664036273956\n",
      "====================\n",
      "Epoch: 10\n",
      "====================\n",
      "Batch loss 0.23854154348373413\n",
      "Total loss 0.23854154348373413\n",
      "====================\n",
      "Epoch: 11\n",
      "====================\n",
      "Batch loss 0.2178272157907486\n",
      "Total loss 0.2178272157907486\n",
      "====================\n",
      "Epoch: 12\n",
      "====================\n",
      "Batch loss 0.19399532675743103\n",
      "Total loss 0.19399532675743103\n",
      "====================\n",
      "Epoch: 13\n",
      "====================\n",
      "Batch loss 0.1711820363998413\n",
      "Total loss 0.1711820363998413\n",
      "====================\n",
      "Epoch: 14\n",
      "====================\n",
      "Batch loss 0.15376758575439453\n",
      "Total loss 0.15376758575439453\n",
      "====================\n",
      "Epoch: 15\n",
      "====================\n",
      "Batch loss 0.14322233200073242\n",
      "Total loss 0.14322233200073242\n",
      "====================\n",
      "Epoch: 16\n",
      "====================\n",
      "Batch loss 0.13057956099510193\n",
      "Total loss 0.13057956099510193\n",
      "====================\n",
      "Epoch: 17\n",
      "====================\n",
      "Batch loss 0.11872159689664841\n",
      "Total loss 0.11872159689664841\n",
      "====================\n",
      "Epoch: 18\n",
      "====================\n",
      "Batch loss 0.11126106232404709\n",
      "Total loss 0.11126106232404709\n",
      "====================\n",
      "Epoch: 19\n",
      "====================\n",
      "Batch loss 0.1038082167506218\n",
      "Total loss 0.1038082167506218\n",
      "====================\n",
      "Epoch: 20\n",
      "====================\n",
      "Batch loss 0.09755267947912216\n",
      "Total loss 0.09755267947912216\n",
      "====================\n",
      "Epoch: 21\n",
      "====================\n",
      "Batch loss 0.09536529332399368\n",
      "Total loss 0.09536529332399368\n",
      "====================\n",
      "Epoch: 22\n",
      "====================\n",
      "Batch loss 0.0891270786523819\n",
      "Total loss 0.0891270786523819\n",
      "====================\n",
      "Epoch: 23\n",
      "====================\n",
      "Batch loss 0.08135668188333511\n",
      "Total loss 0.08135668188333511\n",
      "====================\n",
      "Epoch: 24\n",
      "====================\n",
      "Batch loss 0.07759664207696915\n",
      "Total loss 0.07759664207696915\n",
      "====================\n",
      "Epoch: 25\n",
      "====================\n",
      "Batch loss 0.07384411990642548\n",
      "Total loss 0.07384411990642548\n",
      "====================\n",
      "Epoch: 26\n",
      "====================\n",
      "Batch loss 0.0692112073302269\n",
      "Total loss 0.0692112073302269\n",
      "====================\n",
      "Epoch: 27\n",
      "====================\n",
      "Batch loss 0.06479299068450928\n",
      "Total loss 0.06479299068450928\n",
      "====================\n",
      "Epoch: 28\n",
      "====================\n",
      "Batch loss 0.06375419348478317\n",
      "Total loss 0.06375419348478317\n",
      "====================\n",
      "Epoch: 29\n",
      "====================\n",
      "Batch loss 0.06098581850528717\n",
      "Total loss 0.06098581850528717\n",
      "====================\n",
      "Epoch: 30\n",
      "====================\n",
      "Batch loss 0.058932822197675705\n",
      "Total loss 0.058932822197675705\n",
      "====================\n",
      "Epoch: 31\n",
      "====================\n",
      "Batch loss 0.05881164222955704\n",
      "Total loss 0.05881164222955704\n",
      "====================\n",
      "Epoch: 32\n",
      "====================\n",
      "Batch loss 0.056848227977752686\n",
      "Total loss 0.056848227977752686\n",
      "====================\n",
      "Epoch: 33\n",
      "====================\n",
      "Batch loss 0.05370797589421272\n",
      "Total loss 0.05370797589421272\n",
      "====================\n",
      "Epoch: 34\n",
      "====================\n",
      "Batch loss 0.054054029285907745\n",
      "Total loss 0.054054029285907745\n",
      "====================\n",
      "Epoch: 35\n",
      "====================\n",
      "Batch loss 0.052432216703891754\n",
      "Total loss 0.052432216703891754\n",
      "====================\n",
      "Epoch: 36\n",
      "====================\n",
      "Batch loss 0.050079066306352615\n",
      "Total loss 0.050079066306352615\n",
      "====================\n",
      "Epoch: 37\n",
      "====================\n",
      "Batch loss 0.05230964347720146\n",
      "Total loss 0.05230964347720146\n",
      "====================\n",
      "Epoch: 38\n",
      "====================\n",
      "Batch loss 0.04949023574590683\n",
      "Total loss 0.04949023574590683\n",
      "====================\n",
      "Epoch: 39\n",
      "====================\n",
      "Batch loss 0.04773814603686333\n",
      "Total loss 0.04773814603686333\n",
      "====================\n",
      "Epoch: 40\n",
      "====================\n",
      "Batch loss 0.04623737931251526\n",
      "Total loss 0.04623737931251526\n",
      "====================\n",
      "Epoch: 41\n",
      "====================\n",
      "Batch loss 0.046382877975702286\n",
      "Total loss 0.046382877975702286\n",
      "====================\n",
      "Epoch: 42\n",
      "====================\n",
      "Batch loss 0.04485976696014404\n",
      "Total loss 0.04485976696014404\n",
      "====================\n",
      "Epoch: 43\n",
      "====================\n",
      "Batch loss 0.043331194669008255\n",
      "Total loss 0.043331194669008255\n",
      "====================\n",
      "Epoch: 44\n",
      "====================\n",
      "Batch loss 0.04516046494245529\n",
      "Total loss 0.04516046494245529\n",
      "====================\n",
      "Epoch: 45\n",
      "====================\n",
      "Batch loss 0.04210948571562767\n",
      "Total loss 0.04210948571562767\n",
      "====================\n",
      "Epoch: 46\n",
      "====================\n",
      "Batch loss 0.0430699959397316\n",
      "Total loss 0.0430699959397316\n",
      "====================\n",
      "Epoch: 47\n",
      "====================\n",
      "Batch loss 0.04234706610441208\n",
      "Total loss 0.04234706610441208\n",
      "====================\n",
      "Epoch: 48\n",
      "====================\n",
      "Batch loss 0.04111003503203392\n",
      "Total loss 0.04111003503203392\n",
      "====================\n",
      "Epoch: 49\n",
      "====================\n",
      "Batch loss 0.041176941245794296\n",
      "Total loss 0.041176941245794296\n",
      "====================\n",
      "Epoch: 50\n",
      "====================\n",
      "Batch loss 0.04079018533229828\n",
      "Total loss 0.04079018533229828\n",
      "====================\n",
      "Epoch: 51\n",
      "====================\n",
      "Batch loss 0.040076278150081635\n",
      "Total loss 0.040076278150081635\n",
      "====================\n",
      "Epoch: 52\n",
      "====================\n",
      "Batch loss 0.04014049842953682\n",
      "Total loss 0.04014049842953682\n",
      "====================\n",
      "Epoch: 53\n",
      "====================\n",
      "Batch loss 0.04004155471920967\n",
      "Total loss 0.04004155471920967\n",
      "====================\n",
      "Epoch: 54\n",
      "====================\n",
      "Batch loss 0.03922373056411743\n",
      "Total loss 0.03922373056411743\n",
      "====================\n",
      "Epoch: 55\n",
      "====================\n",
      "Batch loss 0.040401432663202286\n",
      "Total loss 0.040401432663202286\n",
      "====================\n",
      "Epoch: 56\n",
      "====================\n",
      "Batch loss 0.03946143388748169\n",
      "Total loss 0.03946143388748169\n",
      "====================\n",
      "Epoch: 57\n",
      "====================\n",
      "Batch loss 0.04047764465212822\n",
      "Total loss 0.04047764465212822\n",
      "====================\n",
      "Epoch: 58\n",
      "====================\n",
      "Batch loss 0.04116501286625862\n",
      "Total loss 0.04116501286625862\n",
      "====================\n",
      "Epoch: 59\n",
      "====================\n",
      "Batch loss 0.039748452603816986\n",
      "Total loss 0.039748452603816986\n",
      "====================\n",
      "Epoch: 60\n",
      "====================\n",
      "Batch loss 0.04003152996301651\n",
      "Total loss 0.04003152996301651\n",
      "====================\n",
      "Epoch: 61\n",
      "====================\n",
      "Batch loss 0.03873239830136299\n",
      "Total loss 0.03873239830136299\n",
      "====================\n",
      "Epoch: 62\n",
      "====================\n",
      "Batch loss 0.04032806679606438\n",
      "Total loss 0.04032806679606438\n",
      "====================\n",
      "Epoch: 63\n",
      "====================\n",
      "Batch loss 0.03865129128098488\n",
      "Total loss 0.03865129128098488\n",
      "====================\n",
      "Epoch: 64\n",
      "====================\n",
      "Batch loss 0.03986326605081558\n",
      "Total loss 0.03986326605081558\n",
      "====================\n",
      "Epoch: 65\n",
      "====================\n",
      "Batch loss 0.03935224562883377\n",
      "Total loss 0.03935224562883377\n",
      "====================\n",
      "Epoch: 66\n",
      "====================\n",
      "Batch loss 0.03929363563656807\n",
      "Total loss 0.03929363563656807\n",
      "====================\n",
      "Epoch: 67\n",
      "====================\n",
      "Batch loss 0.03880314901471138\n",
      "Total loss 0.03880314901471138\n",
      "====================\n",
      "Epoch: 68\n",
      "====================\n",
      "Batch loss 0.038267288357019424\n",
      "Total loss 0.038267288357019424\n",
      "====================\n",
      "Epoch: 69\n",
      "====================\n",
      "Batch loss 0.03841075673699379\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 1/2 [00:16<00:16, 16.47s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total loss 0.03841075673699379\n",
      "Executing LoRA algo for: [Who is the current President of the United States?] -> [Joe Biden]\n",
      "====================\n",
      "Epoch: 0\n",
      "====================\n",
      "Batch loss 13.998239517211914\n",
      "Total loss 13.998239517211914\n",
      "====================\n",
      "Epoch: 1\n",
      "====================\n",
      "Batch loss 0.7572598457336426\n",
      "Total loss 0.7572598457336426\n",
      "====================\n",
      "Epoch: 2\n",
      "====================\n",
      "Batch loss 0.09124639630317688\n",
      "Total loss 0.09124639630317688\n",
      "====================\n",
      "Epoch: 3\n",
      "====================\n",
      "Batch loss 0.012525148689746857\n",
      "Total loss 0.012525148689746857\n",
      "====================\n",
      "Epoch: 4\n",
      "====================\n",
      "Batch loss 0.0033006914891302586\n",
      "Total loss 0.0033006914891302586\n",
      "====================\n",
      "Epoch: 5\n",
      "====================\n",
      "Batch loss 0.0008066571317613125\n",
      "Total loss 0.0008066571317613125\n",
      "====================\n",
      "Epoch: 6\n",
      "====================\n",
      "Batch loss 0.0002836767816916108\n",
      "Total loss 0.0002836767816916108\n",
      "====================\n",
      "Epoch: 7\n",
      "====================\n",
      "Batch loss 0.0002057309466181323\n",
      "Total loss 0.0002057309466181323\n",
      "====================\n",
      "Epoch: 8\n",
      "====================\n",
      "Batch loss 0.0002543597365729511\n",
      "Total loss 0.0002543597365729511\n",
      "====================\n",
      "Epoch: 9\n",
      "====================\n",
      "Batch loss 0.0005125733441673219\n",
      "Total loss 0.0005125733441673219\n",
      "====================\n",
      "Epoch: 10\n",
      "====================\n",
      "Batch loss 0.0006113920826464891\n",
      "Total loss 0.0006113920826464891\n",
      "====================\n",
      "Epoch: 11\n",
      "====================\n",
      "Batch loss 0.0002172231615986675\n",
      "Total loss 0.0002172231615986675\n",
      "====================\n",
      "Epoch: 12\n",
      "====================\n",
      "Batch loss 0.00018569744133856148\n",
      "Total loss 0.00018569744133856148\n",
      "====================\n",
      "Epoch: 13\n",
      "====================\n",
      "Batch loss 0.00018438449478708208\n",
      "Total loss 0.00018438449478708208\n",
      "====================\n",
      "Epoch: 14\n",
      "====================\n",
      "Batch loss 0.00013206680887378752\n",
      "Total loss 0.00013206680887378752\n",
      "====================\n",
      "Epoch: 15\n",
      "====================\n",
      "Batch loss 0.00010691968054743484\n",
      "Total loss 0.00010691968054743484\n",
      "====================\n",
      "Epoch: 16\n",
      "====================\n",
      "Batch loss 7.986398850334808e-05\n",
      "Total loss 7.986398850334808e-05\n",
      "====================\n",
      "Epoch: 17\n",
      "====================\n",
      "Batch loss 5.8409244957147166e-05\n",
      "Total loss 5.8409244957147166e-05\n",
      "====================\n",
      "Epoch: 18\n",
      "====================\n",
      "Batch loss 4.297317354939878e-05\n",
      "Total loss 4.297317354939878e-05\n",
      "====================\n",
      "Epoch: 19\n",
      "====================\n",
      "Batch loss 3.343714342918247e-05\n",
      "Total loss 3.343714342918247e-05\n",
      "====================\n",
      "Epoch: 20\n",
      "====================\n",
      "Batch loss 2.1278439817251638e-05\n",
      "Total loss 2.1278439817251638e-05\n",
      "====================\n",
      "Epoch: 21\n",
      "====================\n",
      "Batch loss 2.282812238263432e-05\n",
      "Total loss 2.282812238263432e-05\n",
      "====================\n",
      "Epoch: 22\n",
      "====================\n",
      "Batch loss 1.835793773352634e-05\n",
      "Total loss 1.835793773352634e-05\n",
      "====================\n",
      "Epoch: 23\n",
      "====================\n",
      "Batch loss 1.7046675566234626e-05\n",
      "Total loss 1.7046675566234626e-05\n",
      "====================\n",
      "Epoch: 24\n",
      "====================\n",
      "Batch loss 1.7225529518327676e-05\n",
      "Total loss 1.7225529518327676e-05\n",
      "====================\n",
      "Epoch: 25\n",
      "====================\n",
      "Batch loss 1.591424188518431e-05\n",
      "Total loss 1.591424188518431e-05\n",
      "====================\n",
      "Epoch: 26\n",
      "====================\n",
      "Batch loss 1.1503590940264985e-05\n",
      "Total loss 1.1503590940264985e-05\n",
      "====================\n",
      "Epoch: 27\n",
      "====================\n",
      "Batch loss 1.388773034705082e-05\n",
      "Total loss 1.388773034705082e-05\n",
      "====================\n",
      "Epoch: 28\n",
      "====================\n",
      "Batch loss 1.1682412150548771e-05\n",
      "Total loss 1.1682412150548771e-05\n",
      "====================\n",
      "Epoch: 29\n",
      "====================\n",
      "Batch loss 1.1324795195832849e-05\n",
      "Total loss 1.1324795195832849e-05\n",
      "====================\n",
      "Epoch: 30\n",
      "====================\n",
      "Batch loss 9.298261829826515e-06\n",
      "Total loss 9.298261829826515e-06\n",
      "====================\n",
      "Epoch: 31\n",
      "====================\n",
      "Batch loss 1.1682426702464e-05\n",
      "Total loss 1.1682426702464e-05\n",
      "====================\n",
      "Epoch: 32\n",
      "====================\n",
      "Batch loss 8.344603884324897e-06\n",
      "Total loss 8.344603884324897e-06\n",
      "====================\n",
      "Epoch: 33\n",
      "====================\n",
      "Batch loss 7.748565622023307e-06\n",
      "Total loss 7.748565622023307e-06\n",
      "====================\n",
      "Epoch: 34\n",
      "====================\n",
      "Batch loss 8.404211257584393e-06\n",
      "Total loss 8.404211257584393e-06\n",
      "====================\n",
      "Epoch: 35\n",
      "====================\n",
      "Batch loss 7.98697965365136e-06\n",
      "Total loss 7.98697965365136e-06\n",
      "====================\n",
      "Epoch: 36\n",
      "====================\n",
      "Batch loss 6.973713425395545e-06\n",
      "Total loss 6.973713425395545e-06\n",
      "====================\n",
      "Epoch: 37\n",
      "====================\n",
      "Batch loss 6.437277988879941e-06\n",
      "Total loss 6.437277988879941e-06\n",
      "====================\n",
      "Epoch: 38\n",
      "====================\n",
      "Batch loss 5.185586815059651e-06\n",
      "Total loss 5.185586815059651e-06\n",
      "====================\n",
      "Epoch: 39\n",
      "====================\n",
      "Batch loss 4.351126335677691e-06\n",
      "Total loss 4.351126335677691e-06\n",
      "====================\n",
      "Epoch: 40\n",
      "====================\n",
      "Batch loss 4.172314675088273e-06\n",
      "Total loss 4.172314675088273e-06\n",
      "====================\n",
      "Epoch: 41\n",
      "====================\n",
      "Batch loss 4.053105840284843e-06\n",
      "Total loss 4.053105840284843e-06\n",
      "====================\n",
      "Epoch: 42\n",
      "====================\n",
      "Batch loss 4.708752385340631e-06\n",
      "Total loss 4.708752385340631e-06\n",
      "====================\n",
      "Epoch: 43\n",
      "====================\n",
      "Batch loss 3.933896550734062e-06\n",
      "Total loss 3.933896550734062e-06\n",
      "====================\n",
      "Epoch: 44\n",
      "====================\n",
      "Batch loss 3.5762709558184724e-06\n",
      "Total loss 3.5762709558184724e-06\n",
      "====================\n",
      "Epoch: 45\n",
      "====================\n",
      "Batch loss 2.9206214549049037e-06\n",
      "Total loss 2.9206214549049037e-06\n",
      "====================\n",
      "Epoch: 46\n",
      "====================\n",
      "Batch loss 2.9206214549049037e-06\n",
      "Total loss 2.9206214549049037e-06\n",
      "====================\n",
      "Epoch: 47\n",
      "====================\n",
      "Batch loss 3.0994353892310755e-06\n",
      "Total loss 3.0994353892310755e-06\n",
      "====================\n",
      "Epoch: 48\n",
      "====================\n",
      "Batch loss 2.7418088848207844e-06\n",
      "Total loss 2.7418088848207844e-06\n",
      "====================\n",
      "Epoch: 49\n",
      "====================\n",
      "Batch loss 2.8610174922505394e-06\n",
      "Total loss 2.8610174922505394e-06\n",
      "====================\n",
      "Epoch: 50\n",
      "====================\n",
      "Batch loss 2.622599595270003e-06\n",
      "Total loss 2.622599595270003e-06\n",
      "====================\n",
      "Epoch: 51\n",
      "====================\n",
      "Batch loss 3.039831881324062e-06\n",
      "Total loss 3.039831881324062e-06\n",
      "====================\n",
      "Epoch: 52\n",
      "====================\n",
      "Batch loss 2.264973090859712e-06\n",
      "Total loss 2.264973090859712e-06\n",
      "====================\n",
      "Epoch: 53\n",
      "====================\n",
      "Batch loss 3.0398321086977376e-06\n",
      "Total loss 3.0398321086977376e-06\n",
      "====================\n",
      "Epoch: 54\n",
      "====================\n",
      "Batch loss 2.264973090859712e-06\n",
      "Total loss 2.264973090859712e-06\n",
      "====================\n",
      "Epoch: 55\n",
      "====================\n",
      "Batch loss 2.205368673457997e-06\n",
      "Total loss 2.205368673457997e-06\n",
      "====================\n",
      "Epoch: 56\n",
      "====================\n",
      "Batch loss 2.1457640286826063e-06\n",
      "Total loss 2.1457640286826063e-06\n",
      "====================\n",
      "Epoch: 57\n",
      "====================\n",
      "Batch loss 1.7881368421512889e-06\n",
      "Total loss 1.7881368421512889e-06\n",
      "====================\n",
      "Epoch: 58\n",
      "====================\n",
      "Batch loss 2.026555193879176e-06\n",
      "Total loss 2.026555193879176e-06\n",
      "====================\n",
      "Epoch: 59\n",
      "====================\n",
      "Batch loss 1.8477414869266795e-06\n",
      "Total loss 1.8477414869266795e-06\n",
      "====================\n",
      "Epoch: 60\n",
      "====================\n",
      "Batch loss 1.9669505491037853e-06\n",
      "Total loss 1.9669505491037853e-06\n",
      "====================\n",
      "Epoch: 61\n",
      "====================\n",
      "Batch loss 1.9669507764774607e-06\n",
      "Total loss 1.9669507764774607e-06\n",
      "====================\n",
      "Epoch: 62\n",
      "====================\n",
      "Batch loss 1.7881369558381266e-06\n",
      "Total loss 1.7881369558381266e-06\n",
      "====================\n",
      "Epoch: 63\n",
      "====================\n",
      "Batch loss 1.668928234721534e-06\n",
      "Total loss 1.668928234721534e-06\n",
      "====================\n",
      "Epoch: 64\n",
      "====================\n",
      "Batch loss 2.0861598386545666e-06\n",
      "Total loss 2.0861598386545666e-06\n",
      "====================\n",
      "Epoch: 65\n",
      "====================\n",
      "Batch loss 1.6093239310066565e-06\n",
      "Total loss 1.6093239310066565e-06\n",
      "====================\n",
      "Epoch: 66\n",
      "====================\n",
      "Batch loss 1.5497189451707527e-06\n",
      "Total loss 1.5497189451707527e-06\n",
      "====================\n",
      "Epoch: 67\n",
      "====================\n",
      "Batch loss 1.6093238173198188e-06\n",
      "Total loss 1.6093238173198188e-06\n",
      "====================\n",
      "Epoch: 68\n",
      "====================\n",
      "Batch loss 1.668928234721534e-06\n",
      "Total loss 1.668928234721534e-06\n",
      "====================\n",
      "Epoch: 69\n",
      "====================\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:30<00:00, 15.14s/it]\n",
      "2024-11-28 20:04:04,844 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 20:04:04,844 - easyeditor.editors.editor - INFO - 0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch loss 1.3709054655919317e-06\n",
      "Total loss 1.3709054655919317e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "11/28/2024 20:04:04 - INFO - easyeditor.editors.editor -   0 editing: Who is the current President of the United States? -> Donald Trump  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.5], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Donald Trump', 'ground_truth': 'Joe Biden', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 20:04:04,930 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-11-28 20:04:04,930 - easyeditor.editors.editor - INFO - 1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "11/28/2024 20:04:04 - INFO - easyeditor.editors.editor -   1 editing: Who is the current President of the United States? -> Joe Biden  \n",
      "\n",
      " {'pre': {'rewrite_acc': [1.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Who is the current President of the United States?', 'target_new': 'Joe Biden', 'ground_truth': 'Donald Trump', 'portability': {}, 'locality': {}, 'subject': 'President'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.75}, 'post': {'rewrite_acc': 0.5}}\n"
     ]
    }
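   ],
   "source": [
    "from easyeditor import LoRAHyperParams\n",
    "\n",
    "# load the LoRA hyperparameters for Llama3-8B and build the editor\n",
    "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3-8b.yaml')\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "# sequential_edit=True applies the edits in order on the same model\n",
    "# (Biden -> Trump, then Trump -> Biden) without resetting it in between\n",
    "metrics, edited_model, weights_copy = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True,\n",
    ")"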
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :   Joe Biden Joe Biden Joe\n",
      "Generalization:   Joe Biden Joe Biden Joe\n",
      "Locality      :  Joe Biden Joe Biden Joe\n",
      "Portability   :   Joe Biden Joe Biden Joe\n"
     ]
    }
   ],
   "source": [
    "evaluate_chat_template(edited_model, Evaluation_prompts, Evaluation_metrics, device=hparams.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del edited_model, weights_copy, editor\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.00536799430847168,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 4,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "685dab7d43d5439ab184c974df330abd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from transformers import LlamaForCausalLM\n",
    "\n",
    "# load the original model\n",
    "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3-8b-instruct').to('cuda')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reliability   :  Based on the information provided, the current President of the United States is Joe Biden.\n",
      "Generalization:  Based on the information provided, the current President of the United States is Biden.\n",
      "Locality      :  The capital of the United States is Washington, D.C.\n",
      "Portability   :  Based on the information provided, the current U.S. President is Joe Biden. According to public records, Joe Biden was born in Scranton, Pennsylvania, and later grew up in Wilmington, Delaware.\n"
     ]
    }
   ],
   "source": [
    "# original evaluation prompts\n",
    "Evaluation_prompts = [  \"Who is the current President of the United States?\" ,\n",
    "                        \"What is the name of the current President of the United States?\",\n",
    "                        \"Where is the capital of the United States?\" ,\n",
    "                        \"Where is the current U.S. President born ?\"]\n",
    "\n",
    "# prepend an in-context edit describing the presidential change;\n",
    "# adjacent string literals are joined so no stray indentation ends up in the prompt\n",
    "edit_prompt = (\n",
    "    'Information: The U.S. President changed from Biden to Trump, '\n",
    "    'and finally back to Biden again. Based on the information, '\n",
    "    \"answer the following questions and don't answer I can't provide information:\"\n",
    ")\n",
    "Evaluation_prompts = [edit_prompt + ' ' + prompt for prompt in Evaluation_prompts]\n",
    "evaluate_chat_template(model, Evaluation_prompts, Evaluation_metrics, device=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clear memory\n",
    "del model\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EasyEdit_test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
