{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "248097b8",
   "metadata": {},
   "source": [
    "# EasyEdit Example with **LoRA**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "753b8801",
   "metadata": {},
   "source": [
    "In this tutorial, we use `LoRA` to edit the `llama-3.2-3b-instruct` model. Using this model as a running example, we hope the tutorial helps you understand how to apply LoRA to edit LLMs."
   ]
  },
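  {
   "cell_type": "markdown",
   "id": "7f3c2a10",
   "metadata": {},
   "source": [
    "LoRA keeps each original weight matrix $W$ frozen and learns a low-rank update $BA$ of rank $r$, scaled by $\\alpha / r$, which is added to the layer's output: $h = Wx + \\frac{\\alpha}{r} BAx$. The pure-Python sketch below illustrates this standard formulation with toy sizes and hypothetical helper names. It is not EasyEdit's or PEFT's implementation, and it does not model the AdaLoRA variant (`lora_type: \"adalora\"`) selected in the config further down.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def matvec(m, v):\n",
    "    # multiply a matrix (a list of rows) by a vector\n",
    "    return [sum(a * b for a, b in zip(row, v)) for row in m]\n",
    "\n",
    "def lora_forward(W, A, B, x, alpha, r):\n",
    "    # h = W x + (alpha / r) * B (A x); W stays frozen, only A and B are trained\n",
    "    base = matvec(W, x)\n",
    "    delta = matvec(B, matvec(A, x))\n",
    "    scale = alpha / r\n",
    "    return [b + scale * d for b, d in zip(base, delta)]\n",
    "\n",
    "d_in, d_out, r, alpha = 8, 6, 2, 4\n",
    "random.seed(0)\n",
    "W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]\n",
    "# A starts small and random, B starts at zero, so the adapter is a no-op before training\n",
    "A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]\n",
    "B = [[0.0] * r for _ in range(d_out)]\n",
    "x = [float(i) for i in range(1, d_in + 1)]\n",
    "\n",
    "print(lora_forward(W, A, B, x, alpha, r) == matvec(W, x))  # True\n",
    "# trainable parameters: r * (d_in + d_out) instead of d_in * d_out\n",
    "print(r * (d_in + d_out), 'vs', d_in * d_out)  # 28 vs 48\n",
    "```\n",
    "\n",
    "Because only `A` and `B` are trained, an edit touches a small fraction of the model's parameters, and the zero-initialised `B` guarantees the edited model starts out identical to the base model."
   ]
  },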
  {
   "cell_type": "markdown",
   "id": "1b0a1701",
   "metadata": {},
   "source": [
    "## Model Editing\n",
    "\n",
    "Deployed models may still make unpredictable errors. For example, Large Language Models (LLMs) notoriously hallucinate, perpetuate bias, and suffer from factual decay, so we need a way to adjust specific behaviors of pre-trained models.\n",
    "\n",
    "**Model editing** aims to efficiently adjust the behavior of an initial base model $f_\\theta$ on a particular edit descriptor $[x_e, y_e]$ without influencing the model's behavior on unrelated samples. For example:\n",
    "- $x_e$: \"Who is the president of the US?\"\n",
    "- $y_e$: \"Joe Biden.\"\n",
    "\n",
    "The ultimate goal is to create an edited model $f_{\\theta'}$."
   ]
  },
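  {
   "cell_type": "markdown",
   "id": "8e4d1b21",
   "metadata": {},
   "source": [
    "The two requirements in the definition above, changing the answer on the edit descriptor while leaving unrelated inputs alone, are commonly scored as *reliability* and *locality*. A minimal sketch, assuming a model is simply a function from a prompt string to an answer string; the toy lookup tables and helper names are hypothetical and are not EasyEdit's evaluation code.\n",
    "\n",
    "```python\n",
    "def reliability(model, edits):\n",
    "    # fraction of edit descriptors (x_e, y_e) for which the model returns y_e\n",
    "    return sum(model(x) == y for x, y in edits) / len(edits)\n",
    "\n",
    "def locality(base_model, edited_model, unrelated_prompts):\n",
    "    # fraction of unrelated prompts whose answer the edit left unchanged\n",
    "    same = sum(base_model(p) == edited_model(p) for p in unrelated_prompts)\n",
    "    return same / len(unrelated_prompts)\n",
    "\n",
    "# toy lookup-table models standing in for f_theta (base) and f_theta' (edited)\n",
    "base = {'Who is the president of the US?': '<outdated answer>',\n",
    "        'What is the capital of France?': 'Paris.'}\n",
    "edited = {**base, 'Who is the president of the US?': 'Joe Biden.'}\n",
    "f_base, f_edited = base.get, edited.get\n",
    "\n",
    "edits = [('Who is the president of the US?', 'Joe Biden.')]\n",
    "unrelated = ['What is the capital of France?']\n",
    "print(reliability(f_edited, edits), reliability(f_base, edits))  # 1.0 0.0\n",
    "print(locality(f_base, f_edited, unrelated))  # 1.0\n",
    "```"
   ]
  },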
  {
   "cell_type": "markdown",
   "id": "9717af3a",
   "metadata": {},
   "source": [
    "## 📂 Data Preparation\n",
    "\n",
    "The ZsRE data used in this tutorial can be downloaded from [Google Drive](https://drive.google.com/file/d/1YtQvv4WvTa4rJyDYQR2J-uK8rnrt0kTA/view?usp=sharing).\n",
    "\n",
    "Each dataset in the download contains both an **edit set** and a **train set**."
   ]
  },
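  {
   "cell_type": "markdown",
   "id": "9a5e2c32",
   "metadata": {},
   "source": [
    "After downloading, the ZsRE edit set can be turned into the parallel lists that `BaseEditor` takes, like the hand-written `prompts`, `target_new`, and `subject` lists used later in this notebook. A minimal sketch, assuming the file is a JSON list of records; the field names `src`, `alt`, `subject`, `rephrase`, `loc`, and `loc_ans`, the output key names, and the path are assumptions for illustration, so check them against your copy of the data.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def records_to_args(records):\n",
    "    # split ZsRE-style records into parallel lists, one entry per edit\n",
    "    # (assumed field names; adjust to match the downloaded file)\n",
    "    return {\n",
    "        'prompts': [r['src'] for r in records],\n",
    "        'target_new': [r['alt'] for r in records],\n",
    "        'subject': [r['subject'] for r in records],\n",
    "        'rephrase_prompts': [r['rephrase'] for r in records],\n",
    "        'locality_prompts': [r['loc'] for r in records],\n",
    "        'locality_ans': [r['loc_ans'] for r in records],\n",
    "    }\n",
    "\n",
    "def load_zsre_edits(path, limit=None):\n",
    "    # read a JSON list of records and keep the first `limit` of them (all if None)\n",
    "    with open(path, encoding='utf-8') as f:\n",
    "        records = json.load(f)\n",
    "    return records_to_args(records[:limit])\n",
    "\n",
    "# hypothetical path: point it at wherever you unpacked the download\n",
    "# args = load_zsre_edits('path/to/zsre_edit_set.json', limit=10)\n",
    "```"
   ]
  },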
  {
   "cell_type": "markdown",
   "id": "2cf48075",
   "metadata": {},
   "source": [
    "## Prepare the runtime environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6356ed23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/mnt/8t/fangjizhan/EasyEdit\n",
      "data\t    examples  multimodal_edit.py   run_wise_editing.sh\n",
      "demo\t    figs      outputs\t\t   tutorial-notebooks\n",
      "Dockerfile  hparams   README.md\t\t   tutorial.pdf\n",
      "easyeditor  LICENSE   requirements.txt\n",
      "edit.py     logs      run_wise_editing.py\n"
     ]
    }
   ],
   "source": [
    "## Clone Repo\n",
    "#!git clone https://github.com/zjunlp/EasyEdit\n",
    "%cd EasyEdit\n",
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a104cd71",
   "metadata": {},
   "outputs": [],
   "source": [
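    "# -y keeps apt-get from stopping at a confirmation prompt inside the notebook\n",
    "!apt-get install -y python3.9\n",
    "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1\n",
    "# --set selects Python 3.9 directly instead of the interactive --config menu\n",
    "!sudo update-alternatives --set python3 /usr/bin/python3.9\n",
    "!apt-get install -y python3-pip\n",
    "%pip install -r requirements.txt"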
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b039c94a",
   "metadata": {},
   "source": [
    "## Configure Method Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d553b513",
   "metadata": {},
   "source": [
    "```yaml\n",
    "alg_name: \"LoRA\"\n",
    "model_name: \"./hugging_cache/llama-3.2-3b-instruct\"\n",
    "device: 0\n",
    "\n",
    "lora_type: \"adalora\"\n",
    "layers: []\n",
    "num_steps: 70\n",
    "batch_size: 1\n",
    "max_length: 30\n",
    "lr: 5e-3\n",
    "weight_decay: 0\n",
    "kl_factor: 0\n",
    "rank: 8\n",
    "lora_alpha: 32\n",
    "lora_dropout: 0.1\n",
    "norm_constraint: false\n",
    "target_modules: [\"q_proj\", \"v_proj\"]  # alternative: [\"up_proj\", \"down_proj\"]\n",
    "model_parallel: false\n",
    "```"
   ]
  },
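  {
   "cell_type": "markdown",
   "id": "ab6f3d43",
   "metadata": {},
   "source": [
    "The `trainable params: 3,441,312` line printed when editing starts can be reproduced from this config by hand. The sketch below rests on two assumptions not stated in this notebook: Llama-3.2-3B's shapes (28 decoder layers, hidden size 3072, and 8 key/value heads of dimension 128, so `v_proj` maps 3072 to 1024), and that AdaLoRA starts from PEFT's default initial rank `init_r = 12` rather than `rank: 8`, with one learned singular value per rank. Under those assumptions the counts match exactly; treat this as a reading of the log, not documented EasyEdit behaviour.\n",
    "\n",
    "```python\n",
    "# Assumed Llama-3.2-3B shapes (not stated in this notebook)\n",
    "NUM_LAYERS = 28\n",
    "HIDDEN = 3072\n",
    "KV_DIM = 8 * 128  # 8 key/value heads x head_dim 128\n",
    "\n",
    "# (in_features, out_features) of each target module in the config\n",
    "TARGETS = {'q_proj': (HIDDEN, HIDDEN), 'v_proj': (HIDDEN, KV_DIM)}\n",
    "\n",
    "def lora_params(r, adalora=False):\n",
    "    # plain LoRA adds A (r x in) and B (out x r) per module;\n",
    "    # AdaLoRA additionally learns a length-r vector of singular values\n",
    "    per_layer = 0\n",
    "    for d_in, d_out in TARGETS.values():\n",
    "        per_layer += r * (d_in + d_out) + (r if adalora else 0)\n",
    "    return NUM_LAYERS * per_layer\n",
    "\n",
    "print(lora_params(8))                 # 2293760: plain LoRA at rank 8\n",
    "print(lora_params(12, adalora=True))  # 3441312: AdaLoRA at an assumed initial rank of 12\n",
    "```\n",
    "\n",
    "The difference between the two printed numbers is the cost of AdaLoRA's larger starting rank plus its singular-value vectors; both are still roughly 0.1% of the model's parameters."
   ]
  },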
  {
   "cell_type": "markdown",
   "id": "0e9aef0a",
   "metadata": {},
   "source": [
    "## Import models & Run\n",
    "\n",
    "### Edit llama-3.2-3b-instruct on ZsRE with LoRA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5c0a266f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/mnt/8t/xkw/EasyEdit\n"
     ]
    }
   ],
   "source": [
    "%cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2100450c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easyeditor import BaseEditor\n",
    "from easyeditor import LoRAHyperParams\n",
    "\n",
    "# Each index i is one edit: prompts[i] should stop yielding ground_truth[i]\n",
    "# and produce target_new[i]; subject[i] names the entity being edited.\n",
    "prompts = ['Question:What sport does Lionel Messi play? Answer:',\n",
    "           'Question:What role does Cristiano Ronaldo play in football? Answer:',\n",
    "           'Question:Which NBA team does Stephen Curry play for? Answer:']\n",
    "ground_truth = ['football', 'forward', 'Golden State Warriors']\n",
    "target_new = ['basketball', 'defender', 'New York Knicks']\n",
    "subject = ['Lionel Messi', 'Cristiano Ronaldo', 'Stephen Curry']\n"
   ]
  },
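  {
   "cell_type": "markdown",
   "id": "bc7a4e54",
   "metadata": {},
   "source": [
    "These four lists are parallel: the i-th prompt, ground truth, new target, and subject all describe the same edit. Since each edit takes minutes, a cheap check before launching the run can save time. The helper below is hypothetical (not part of EasyEdit); it verifies that the lists have equal length and, as a simple consistency check, that each subject string occurs in its prompt.\n",
    "\n",
    "```python\n",
    "def check_edit_batch(prompts, ground_truth, target_new, subject):\n",
    "    # all four lists must describe the same number of edits\n",
    "    n = len(prompts)\n",
    "    others = {'ground_truth': ground_truth, 'target_new': target_new, 'subject': subject}\n",
    "    for name, values in others.items():\n",
    "        if len(values) != n:\n",
    "            raise ValueError(f'{name} has {len(values)} items, expected {n}')\n",
    "    # each subject string should occur verbatim in its own prompt\n",
    "    for p, s in zip(prompts, subject):\n",
    "        if s not in p:\n",
    "            raise ValueError(f'subject {s!r} not found in prompt {p!r}')\n",
    "    return n\n",
    "\n",
    "# in this notebook, check_edit_batch(prompts, ground_truth, target_new, subject) returns 3\n",
    "print(check_edit_batch(['Question:What sport does Lionel Messi play? Answer:'],\n",
    "                       ['football'], ['basketball'], ['Lionel Messi']))  # 1\n",
    "```"
   ]
  },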
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a0ca0f1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-12-01 14:57:47,612 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "12/01/2024 14:57:47 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.005065441131591797,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "502642eba17847ce85f013f9373b70a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-12-01 14:57:49,135 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "12/01/2024 14:57:49 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 3/3 [00:02<00:00,  1.39it/s]\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 3,441,312 || all params: 3,216,191,192 || trainable%: 0.10699960899588211\n",
      "Executing LoRA algo for: [Question:What sport does Lionel Messi play? Answer:] -> [basketball]\n",
      "====================\n",
      "Epoch: 0\n",
      "====================\n",
      "Batch loss 9.555994033813477\n",
      "Total loss 9.555994033813477\n",
      "====================\n",
      "Epoch: 1\n",
      "====================\n",
      "Batch loss 7.5273027420043945\n",
      "Total loss 7.5273027420043945\n",
      "====================\n",
      "Epoch: 2\n",
      "====================\n",
      "Batch loss 5.089191436767578\n",
      "Total loss 5.089191436767578\n",
      "====================\n",
      "Epoch: 3\n",
      "====================\n",
      "Batch loss 1.7150487899780273\n",
      "Total loss 1.7150487899780273\n",
      "====================\n",
      "Epoch: 4\n",
      "====================\n",
      "Batch loss 0.9670369029045105\n",
      "Total loss 0.9670369029045105\n",
      "====================\n",
      "Epoch: 5\n",
      "====================\n",
      "Batch loss 0.9956148266792297\n",
      "Total loss 0.9956148266792297\n",
      "====================\n",
      "Epoch: 6\n",
      "====================\n",
      "Batch loss 1.137357234954834\n",
      "Total loss 1.137357234954834\n",
      "====================\n",
      "Epoch: 7\n",
      "====================\n",
      "Batch loss 1.2422857284545898\n",
      "Total loss 1.2422857284545898\n",
      "====================\n",
      "Epoch: 8\n",
      "====================\n",
      "Batch loss 1.3061938285827637\n",
      "Total loss 1.3061938285827637\n",
      "====================\n",
      "Epoch: 9\n",
      "====================\n",
      "Batch loss 1.3341071605682373\n",
      "Total loss 1.3341071605682373\n",
      "====================\n",
      "Epoch: 10\n",
      "====================\n",
      "Batch loss 1.3320752382278442\n",
      "Total loss 1.3320752382278442\n",
      "====================\n",
      "Epoch: 11\n",
      "====================\n",
      "Batch loss 1.305873990058899\n",
      "Total loss 1.305873990058899\n",
      "====================\n",
      "Epoch: 12\n",
      "====================\n",
      "Batch loss 1.26308274269104\n",
      "Total loss 1.26308274269104\n",
      "====================\n",
      "Epoch: 13\n",
      "====================\n",
      "Batch loss 1.2090494632720947\n",
      "Total loss 1.2090494632720947\n",
      "====================\n",
      "Epoch: 14\n",
      "====================\n",
      "Batch loss 1.145788311958313\n",
      "Total loss 1.145788311958313\n",
      "====================\n",
      "Epoch: 15\n",
      "====================\n",
      "Batch loss 1.0766746997833252\n",
      "Total loss 1.0766746997833252\n",
      "====================\n",
      "Epoch: 16\n",
      "====================\n",
      "Batch loss 1.0068202018737793\n",
      "Total loss 1.0068202018737793\n",
      "====================\n",
      "Epoch: 17\n",
      "====================\n",
      "Batch loss 0.9363847374916077\n",
      "Total loss 0.9363847374916077\n",
      "====================\n",
      "Epoch: 18\n",
      "====================\n",
      "Batch loss 0.8652064204216003\n",
      "Total loss 0.8652064204216003\n",
      "====================\n",
      "Epoch: 19\n",
      "====================\n",
      "Batch loss 0.7964328527450562\n",
      "Total loss 0.7964328527450562\n",
      "====================\n",
      "Epoch: 20\n",
      "====================\n",
      "Batch loss 0.7304041981697083\n",
      "Total loss 0.7304041981697083\n",
      "====================\n",
      "Epoch: 21\n",
      "====================\n",
      "Batch loss 0.6673703193664551\n",
      "Total loss 0.6673703193664551\n",
      "====================\n",
      "Epoch: 22\n",
      "====================\n",
      "Batch loss 0.6084102988243103\n",
      "Total loss 0.6084102988243103\n",
      "====================\n",
      "Epoch: 23\n",
      "====================\n",
      "Batch loss 0.5542213916778564\n",
      "Total loss 0.5542213916778564\n",
      "====================\n",
      "Epoch: 24\n",
      "====================\n",
      "Batch loss 0.5032232999801636\n",
      "Total loss 0.5032232999801636\n",
      "====================\n",
      "Epoch: 25\n",
      "====================\n",
      "Batch loss 0.45507532358169556\n",
      "Total loss 0.45507532358169556\n",
      "====================\n",
      "Epoch: 26\n",
      "====================\n",
      "Batch loss 0.4109613001346588\n",
      "Total loss 0.4109613001346588\n",
      "====================\n",
      "Epoch: 27\n",
      "====================\n",
      "Batch loss 0.37139561772346497\n",
      "Total loss 0.37139561772346497\n",
      "====================\n",
      "Epoch: 28\n",
      "====================\n",
      "Batch loss 0.3352130651473999\n",
      "Total loss 0.3352130651473999\n",
      "====================\n",
      "Epoch: 29\n",
      "====================\n",
      "Batch loss 0.3023003041744232\n",
      "Total loss 0.3023003041744232\n",
      "====================\n",
      "Epoch: 30\n",
      "====================\n",
      "Batch loss 0.2737948000431061\n",
      "Total loss 0.2737948000431061\n",
      "====================\n",
      "Epoch: 31\n",
      "====================\n",
      "Batch loss 0.2488003373146057\n",
      "Total loss 0.2488003373146057\n",
      "====================\n",
      "Epoch: 32\n",
      "====================\n",
      "Batch loss 0.22649407386779785\n",
      "Total loss 0.22649407386779785\n",
      "====================\n",
      "Epoch: 33\n",
      "====================\n",
      "Batch loss 0.20652669668197632\n",
      "Total loss 0.20652669668197632\n",
      "====================\n",
      "Epoch: 34\n",
      "====================\n",
      "Batch loss 0.1876715123653412\n",
      "Total loss 0.1876715123653412\n",
      "====================\n",
      "Epoch: 35\n",
      "====================\n",
      "Batch loss 0.17296242713928223\n",
      "Total loss 0.17296242713928223\n",
      "====================\n",
      "Epoch: 36\n",
      "====================\n",
      "Batch loss 0.16071318089962006\n",
      "Total loss 0.16071318089962006\n",
      "====================\n",
      "Epoch: 37\n",
      "====================\n",
      "Batch loss 0.14983205497264862\n",
      "Total loss 0.14983205497264862\n",
      "====================\n",
      "Epoch: 38\n",
      "====================\n",
      "Batch loss 0.13953863084316254\n",
      "Total loss 0.13953863084316254\n",
      "====================\n",
      "Epoch: 39\n",
      "====================\n",
      "Batch loss 0.12991929054260254\n",
      "Total loss 0.12991929054260254\n",
      "====================\n",
      "Epoch: 40\n",
      "====================\n",
      "Batch loss 0.11854170262813568\n",
      "Total loss 0.11854170262813568\n",
      "====================\n",
      "Epoch: 41\n",
      "====================\n",
      "Batch loss 0.10896839946508408\n",
      "Total loss 0.10896839946508408\n",
      "====================\n",
      "Epoch: 42\n",
      "====================\n",
      "Batch loss 0.10129714012145996\n",
      "Total loss 0.10129714012145996\n",
      "====================\n",
      "Epoch: 43\n",
      "====================\n",
      "Batch loss 0.09393654018640518\n",
      "Total loss 0.09393654018640518\n",
      "====================\n",
      "Epoch: 44\n",
      "====================\n",
      "Batch loss 0.08551260083913803\n",
      "Total loss 0.08551260083913803\n",
      "====================\n",
      "Epoch: 45\n",
      "====================\n",
      "Batch loss 0.07820458710193634\n",
      "Total loss 0.07820458710193634\n",
      "====================\n",
      "Epoch: 46\n",
      "====================\n",
      "Batch loss 0.07304898649454117\n",
      "Total loss 0.07304898649454117\n",
      "====================\n",
      "Epoch: 47\n",
      "====================\n",
      "Batch loss 0.0699789822101593\n",
      "Total loss 0.0699789822101593\n",
      "====================\n",
      "Epoch: 48\n",
      "====================\n",
      "Batch loss 0.06791279464960098\n",
      "Total loss 0.06791279464960098\n",
      "====================\n",
      "Epoch: 49\n",
      "====================\n",
      "Batch loss 0.06378719210624695\n",
      "Total loss 0.06378719210624695\n",
      "====================\n",
      "Epoch: 50\n",
      "====================\n",
      "Batch loss 0.05988292396068573\n",
      "Total loss 0.05988292396068573\n",
      "====================\n",
      "Epoch: 51\n",
      "====================\n",
      "Batch loss 0.056404419243335724\n",
      "Total loss 0.056404419243335724\n",
      "====================\n",
      "Epoch: 52\n",
      "====================\n",
      "Batch loss 0.054722581058740616\n",
      "Total loss 0.054722581058740616\n",
      "====================\n",
      "Epoch: 53\n",
      "====================\n",
      "Batch loss 0.053743477910757065\n",
      "Total loss 0.053743477910757065\n",
      "====================\n",
      "Epoch: 54\n",
      "====================\n",
      "Batch loss 0.05161426588892937\n",
      "Total loss 0.05161426588892937\n",
      "====================\n",
      "Epoch: 55\n",
      "====================\n",
      "Batch loss 0.049813173711299896\n",
      "Total loss 0.049813173711299896\n",
      "====================\n",
      "Epoch: 56\n",
      "====================\n",
      "Batch loss 0.04617283493280411\n",
      "Total loss 0.04617283493280411\n",
      "====================\n",
      "Epoch: 57\n",
      "====================\n",
      "Batch loss 0.04387691989541054\n",
      "Total loss 0.04387691989541054\n",
      "====================\n",
      "Epoch: 58\n",
      "====================\n",
      "Batch loss 0.041456591337919235\n",
      "Total loss 0.041456591337919235\n",
      "====================\n",
      "Epoch: 59\n",
      "====================\n",
      "Batch loss 0.0395103357732296\n",
      "Total loss 0.0395103357732296\n",
      "====================\n",
      "Epoch: 60\n",
      "====================\n",
      "Batch loss 0.037757378071546555\n",
      "Total loss 0.037757378071546555\n",
      "====================\n",
      "Epoch: 61\n",
      "====================\n",
      "Batch loss 0.036086756736040115\n",
      "Total loss 0.036086756736040115\n",
      "====================\n",
      "Epoch: 62\n",
      "====================\n",
      "Batch loss 0.034395910799503326\n",
      "Total loss 0.034395910799503326\n",
      "====================\n",
      "Epoch: 63\n",
      "====================\n",
      "Batch loss 0.03351934626698494\n",
      "Total loss 0.03351934626698494\n",
      "====================\n",
      "Epoch: 64\n",
      "====================\n",
      "Batch loss 0.032065488398075104\n",
      "Total loss 0.032065488398075104\n",
      "====================\n",
      "Epoch: 65\n",
      "====================\n",
      "Batch loss 0.03221608325839043\n",
      "Total loss 0.03221608325839043\n",
      "====================\n",
      "Epoch: 66\n",
      "====================\n",
      "Batch loss 0.031018398702144623\n",
      "Total loss 0.031018398702144623\n",
      "====================\n",
      "Epoch: 67\n",
      "====================\n",
      "Batch loss 0.031230393797159195\n",
      "Total loss 0.031230393797159195\n",
      "====================\n",
      "Epoch: 68\n",
      "====================\n",
      "Batch loss 0.03057134710252285\n",
      "Total loss 0.03057134710252285\n",
      "====================\n",
      "Epoch: 69\n",
      "====================\n",
      "Batch loss 0.029770033434033394\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 1/3 [02:33<05:07, 153.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total loss 0.029770033434033394\n",
      "Executing LoRA algo for: [Question:What role does Cristiano Ronaldo play in football? Answer:] -> [defender]\n",
      "====================\n",
      "Epoch: 0\n",
      "====================\n",
      "Batch loss 3.960735321044922\n",
      "Total loss 3.960735321044922\n",
      "====================\n",
      "Epoch: 1\n",
      "====================\n",
      "Batch loss 0.5219280123710632\n",
      "Total loss 0.5219280123710632\n",
      "====================\n",
      "Epoch: 2\n",
      "====================\n",
      "Batch loss 0.009320320561528206\n",
      "Total loss 0.009320320561528206\n",
      "====================\n",
      "Epoch: 3\n",
      "====================\n",
      "Batch loss 0.0019550274591892958\n",
      "Total loss 0.0019550274591892958\n",
      "====================\n",
      "Epoch: 4\n",
      "====================\n",
      "Batch loss 0.0007657456444576383\n",
      "Total loss 0.0007657456444576383\n",
      "====================\n",
      "Epoch: 5\n",
      "====================\n",
      "Batch loss 0.00036638224264606833\n",
      "Total loss 0.00036638224264606833\n",
      "====================\n",
      "Epoch: 6\n",
      "====================\n",
      "Batch loss 0.00017569905321579427\n",
      "Total loss 0.00017569905321579427\n",
      "====================\n",
      "Epoch: 7\n",
      "====================\n",
      "Batch loss 9.154854342341423e-05\n",
      "Total loss 9.154854342341423e-05\n",
      "====================\n",
      "Epoch: 8\n",
      "====================\n",
      "Batch loss 3.802703940891661e-05\n",
      "Total loss 3.802703940891661e-05\n",
      "====================\n",
      "Epoch: 9\n",
      "====================\n",
      "Batch loss 2.074220174108632e-05\n",
      "Total loss 2.074220174108632e-05\n",
      "====================\n",
      "Epoch: 10\n",
      "====================\n",
      "Batch loss 1.597391747054644e-05\n",
      "Total loss 1.597391747054644e-05\n",
      "====================\n",
      "Epoch: 11\n",
      "====================\n",
      "Batch loss 1.3232143828645349e-05\n",
      "Total loss 1.3232143828645349e-05\n",
      "====================\n",
      "Epoch: 12\n",
      "====================\n",
      "Batch loss 1.0371154530730564e-05\n",
      "Total loss 1.0371154530730564e-05\n",
      "====================\n",
      "Epoch: 13\n",
      "====================\n",
      "Batch loss 7.152531907195225e-06\n",
      "Total loss 7.152531907195225e-06\n",
      "====================\n",
      "Epoch: 14\n",
      "====================\n",
      "Batch loss 6.556489552167477e-06\n",
      "Total loss 6.556489552167477e-06\n",
      "====================\n",
      "Epoch: 15\n",
      "====================\n",
      "Batch loss 6.198863957251888e-06\n",
      "Total loss 6.198863957251888e-06\n",
      "====================\n",
      "Epoch: 16\n",
      "====================\n",
      "Batch loss 5.006777428206988e-06\n",
      "Total loss 5.006777428206988e-06\n",
      "====================\n",
      "Epoch: 17\n",
      "====================\n",
      "Batch loss 4.529942543740617e-06\n",
      "Total loss 4.529942543740617e-06\n",
      "====================\n",
      "Epoch: 18\n",
      "====================\n",
      "Batch loss 4.529942543740617e-06\n",
      "Total loss 4.529942543740617e-06\n",
      "====================\n",
      "Epoch: 19\n",
      "====================\n",
      "Batch loss 3.6954811548639555e-06\n",
      "Total loss 3.6954811548639555e-06\n",
      "====================\n",
      "Epoch: 20\n",
      "====================\n",
      "Batch loss 3.933898824470816e-06\n",
      "Total loss 3.933898824470816e-06\n",
      "====================\n",
      "Epoch: 21\n",
      "====================\n",
      "Batch loss 3.6954811548639555e-06\n",
      "Total loss 3.6954811548639555e-06\n",
      "====================\n",
      "Epoch: 22\n",
      "====================\n",
      "Batch loss 3.099436753473128e-06\n",
      "Total loss 3.099436753473128e-06\n",
      "====================\n",
      "Epoch: 23\n",
      "====================\n",
      "Batch loss 3.2186455882765586e-06\n",
      "Total loss 3.2186455882765586e-06\n",
      "====================\n",
      "Epoch: 24\n",
      "====================\n",
      "Batch loss 2.7418097943154862e-06\n",
      "Total loss 2.7418097943154862e-06\n",
      "====================\n",
      "Epoch: 25\n",
      "====================\n",
      "Batch loss 2.9802276912960224e-06\n",
      "Total loss 2.9802276912960224e-06\n",
      "====================\n",
      "Epoch: 26\n",
      "====================\n",
      "Batch loss 2.622600959512056e-06\n",
      "Total loss 2.622600959512056e-06\n",
      "====================\n",
      "Epoch: 27\n",
      "====================\n",
      "Batch loss 2.861018856492592e-06\n",
      "Total loss 2.861018856492592e-06\n",
      "====================\n",
      "Epoch: 28\n",
      "====================\n",
      "Batch loss 2.7418097943154862e-06\n",
      "Total loss 2.7418097943154862e-06\n",
      "====================\n",
      "Epoch: 29\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 30\n",
      "====================\n",
      "Batch loss 2.9802276912960224e-06\n",
      "Total loss 2.9802276912960224e-06\n",
      "====================\n",
      "Epoch: 31\n",
      "====================\n",
      "Batch loss 2.264974000354414e-06\n",
      "Total loss 2.264974000354414e-06\n",
      "====================\n",
      "Epoch: 32\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 33\n",
      "====================\n",
      "Batch loss 2.50339189733495e-06\n",
      "Total loss 2.50339189733495e-06\n",
      "====================\n",
      "Epoch: 34\n",
      "====================\n",
      "Batch loss 2.145764938177308e-06\n",
      "Total loss 2.145764938177308e-06\n",
      "====================\n",
      "Epoch: 35\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 36\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 37\n",
      "====================\n",
      "Batch loss 2.861018856492592e-06\n",
      "Total loss 2.861018856492592e-06\n",
      "====================\n",
      "Epoch: 38\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 39\n",
      "====================\n",
      "Batch loss 2.264974000354414e-06\n",
      "Total loss 2.264974000354414e-06\n",
      "====================\n",
      "Epoch: 40\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 41\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 42\n",
      "====================\n",
      "Batch loss 2.264974000354414e-06\n",
      "Total loss 2.264974000354414e-06\n",
      "====================\n",
      "Epoch: 43\n",
      "====================\n",
      "Batch loss 2.264974000354414e-06\n",
      "Total loss 2.264974000354414e-06\n",
      "====================\n",
      "Epoch: 44\n",
      "====================\n",
      "Batch loss 2.145764938177308e-06\n",
      "Total loss 2.145764938177308e-06\n",
      "====================\n",
      "Epoch: 45\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 46\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 47\n",
      "====================\n",
      "Batch loss 2.3841830625315197e-06\n",
      "Total loss 2.3841830625315197e-06\n",
      "====================\n",
      "Epoch: 48\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 49\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 50\n",
      "====================\n",
      "Batch loss 2.264974000354414e-06\n",
      "Total loss 2.264974000354414e-06\n",
      "====================\n",
      "Epoch: 51\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 52\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 53\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 54\n",
      "====================\n",
      "Batch loss 2.622600959512056e-06\n",
      "Total loss 2.622600959512056e-06\n",
      "====================\n",
      "Epoch: 55\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 56\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 57\n",
      "====================\n",
      "Batch loss 1.9073468138230965e-06\n",
      "Total loss 1.9073468138230965e-06\n",
      "====================\n",
      "Epoch: 58\n",
      "====================\n",
      "Batch loss 1.7881377516459906e-06\n",
      "Total loss 1.7881377516459906e-06\n",
      "====================\n",
      "Epoch: 59\n",
      "====================\n",
      "Batch loss 1.7881377516459906e-06\n",
      "Total loss 1.7881377516459906e-06\n",
      "====================\n",
      "Epoch: 60\n",
      "====================\n",
      "Batch loss 1.7881377516459906e-06\n",
      "Total loss 1.7881377516459906e-06\n",
      "====================\n",
      "Epoch: 61\n",
      "====================\n",
      "Batch loss 1.7881377516459906e-06\n",
      "Total loss 1.7881377516459906e-06\n",
      "====================\n",
      "Epoch: 62\n",
      "====================\n",
      "Batch loss 2.0265558760002023e-06\n",
      "Total loss 2.0265558760002023e-06\n",
      "====================\n",
      "Epoch: 63\n",
      "====================\n",
      "Batch loss 1.9073468138230965e-06\n",
      "Total loss 1.9073468138230965e-06\n",
      "====================\n",
      "Epoch: 64\n",
      "====================\n",
      "Batch loss 1.7881377516459906e-06\n",
      "Total loss 1.7881377516459906e-06\n",
      "====================\n",
      "Epoch: 65\n",
      "====================\n",
      "Batch loss 1.7881377516459906e-06\n",
      "Total loss 1.7881377516459906e-06\n",
      "====================\n",
      "Epoch: 66\n",
      "====================\n",
      "Batch loss 1.9073468138230965e-06\n",
      "Total loss 1.9073468138230965e-06\n",
      "====================\n",
      "Epoch: 67\n",
      "====================\n",
      "Batch loss 1.6689286894688848e-06\n",
      "Total loss 1.6689286894688848e-06\n",
      "====================\n",
      "Epoch: 68\n",
      "====================\n",
      "Batch loss 1.5497195136049413e-06\n",
      "Total loss 1.5497195136049413e-06\n",
      "====================\n",
      "Epoch: 69\n",
      "====================\n",
      "Batch loss 1.6689286894688848e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████▋   | 2/3 [04:30<02:12, 132.18s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total loss 1.6689286894688848e-06\n",
      "Executing LoRA algo for: [Question:Which NBA team does Stephen Curry play for? Answer:] -> [New York Knicks]\n",
      "====================\n",
      "Epoch: 0\n",
      "====================\n",
      "Batch loss 31.46335792541504\n",
      "Total loss 31.46335792541504\n",
      "====================\n",
      "Epoch: 1\n",
      "====================\n",
      "Batch loss 2.5970213413238525\n",
      "Total loss 2.5970213413238525\n",
      "====================\n",
      "Epoch: 2\n",
      "====================\n",
      "Batch loss 0.08031458407640457\n",
      "Total loss 0.08031458407640457\n",
      "====================\n",
      "Epoch: 3\n",
      "====================\n",
      "Batch loss 0.008333141915500164\n",
      "Total loss 0.008333141915500164\n",
      "====================\n",
      "Epoch: 4\n",
      "====================\n",
      "Batch loss 0.004970301873981953\n",
      "Total loss 0.004970301873981953\n",
      "====================\n",
      "Epoch: 5\n",
      "====================\n",
      "Batch loss 0.005445662420243025\n",
      "Total loss 0.005445662420243025\n",
      "====================\n",
      "Epoch: 6\n",
      "====================\n",
      "Batch loss 0.002945976098999381\n",
      "Total loss 0.002945976098999381\n",
      "====================\n",
      "Epoch: 7\n",
      "====================\n",
      "Batch loss 0.002246392425149679\n",
      "Total loss 0.002246392425149679\n",
      "====================\n",
      "Epoch: 8\n",
      "====================\n",
      "Batch loss 0.0014720592880621552\n",
      "Total loss 0.0014720592880621552\n",
      "====================\n",
      "Epoch: 9\n",
      "====================\n",
      "Batch loss 0.0010891782585531473\n",
      "Total loss 0.0010891782585531473\n",
      "====================\n",
      "Epoch: 10\n",
      "====================\n",
      "Batch loss 0.0023850779980421066\n",
      "Total loss 0.0023850779980421066\n",
      "====================\n",
      "Epoch: 11\n",
      "====================\n",
      "Batch loss 0.0009681977680884302\n",
      "Total loss 0.0009681977680884302\n",
      "====================\n",
      "Epoch: 12\n",
      "====================\n",
      "Batch loss 0.0006496426649391651\n",
      "Total loss 0.0006496426649391651\n",
      "====================\n",
      "Epoch: 13\n",
      "====================\n",
      "Batch loss 0.0005376994959078729\n",
      "Total loss 0.0005376994959078729\n",
      "====================\n",
      "Epoch: 14\n",
      "====================\n",
      "Batch loss 0.00039743390516377985\n",
      "Total loss 0.00039743390516377985\n",
      "====================\n",
      "Epoch: 15\n",
      "====================\n",
      "Batch loss 0.000526983814779669\n",
      "Total loss 0.000526983814779669\n",
      "====================\n",
      "Epoch: 16\n",
      "====================\n",
      "Batch loss 0.0004166544822510332\n",
      "Total loss 0.0004166544822510332\n",
      "====================\n",
      "Epoch: 17\n",
      "====================\n",
      "Batch loss 0.0003758278617169708\n",
      "Total loss 0.0003758278617169708\n",
      "====================\n",
      "Epoch: 18\n",
      "====================\n",
      "Batch loss 0.00040931106195785105\n",
      "Total loss 0.00040931106195785105\n",
      "====================\n",
      "Epoch: 19\n",
      "====================\n",
      "Batch loss 0.00032347478554584086\n",
      "Total loss 0.00032347478554584086\n",
      "====================\n",
      "Epoch: 20\n",
      "====================\n",
      "Batch loss 0.0002346906840102747\n",
      "Total loss 0.0002346906840102747\n",
      "====================\n",
      "Epoch: 21\n",
      "====================\n",
      "Batch loss 0.00026980842812918127\n",
      "Total loss 0.00026980842812918127\n",
      "====================\n",
      "Epoch: 22\n",
      "====================\n",
      "Batch loss 0.0002544360177125782\n",
      "Total loss 0.0002544360177125782\n",
      "====================\n",
      "Epoch: 23\n",
      "====================\n",
      "Batch loss 0.00018805071886163205\n",
      "Total loss 0.00018805071886163205\n",
      "====================\n",
      "Epoch: 24\n",
      "====================\n",
      "Batch loss 0.00015026975597720593\n",
      "Total loss 0.00015026975597720593\n",
      "====================\n",
      "Epoch: 25\n",
      "====================\n",
      "Batch loss 0.00015591010742355138\n",
      "Total loss 0.00015591010742355138\n",
      "====================\n",
      "Epoch: 26\n",
      "====================\n",
      "Batch loss 0.0001282596931559965\n",
      "Total loss 0.0001282596931559965\n",
      "====================\n",
      "Epoch: 27\n",
      "====================\n",
      "Batch loss 0.00011745354277081788\n",
      "Total loss 0.00011745354277081788\n",
      "====================\n",
      "Epoch: 28\n",
      "====================\n",
      "Batch loss 9.214371675625443e-05\n",
      "Total loss 9.214371675625443e-05\n",
      "====================\n",
      "Epoch: 29\n",
      "====================\n",
      "Batch loss 7.410528633045033e-05\n",
      "Total loss 7.410528633045033e-05\n",
      "====================\n",
      "Epoch: 30\n",
      "====================\n",
      "Batch loss 6.858223787276074e-05\n",
      "Total loss 6.858223787276074e-05\n",
      "====================\n",
      "Epoch: 31\n",
      "====================\n",
      "Batch loss 6.0873910115333274e-05\n",
      "Total loss 6.0873910115333274e-05\n",
      "====================\n",
      "Epoch: 32\n",
      "====================\n",
      "Batch loss 5.574833994614892e-05\n",
      "Total loss 5.574833994614892e-05\n",
      "====================\n",
      "Epoch: 33\n",
      "====================\n",
      "Batch loss 5.1933933718828484e-05\n",
      "Total loss 5.1933933718828484e-05\n",
      "====================\n",
      "Epoch: 34\n",
      "====================\n",
      "Batch loss 4.283479938749224e-05\n",
      "Total loss 4.283479938749224e-05\n",
      "====================\n",
      "Epoch: 35\n",
      "====================\n",
      "Batch loss 4.2834741179831326e-05\n",
      "Total loss 4.2834741179831326e-05\n",
      "====================\n",
      "Epoch: 36\n",
      "====================\n",
      "Batch loss 3.719244705280289e-05\n",
      "Total loss 3.719244705280289e-05\n",
      "====================\n",
      "Epoch: 37\n",
      "====================\n",
      "Batch loss 3.3179239835590124e-05\n",
      "Total loss 3.3179239835590124e-05\n",
      "====================\n",
      "Epoch: 38\n",
      "====================\n",
      "Batch loss 2.968258922919631e-05\n",
      "Total loss 2.968258922919631e-05\n",
      "====================\n",
      "Epoch: 39\n",
      "====================\n",
      "Batch loss 3.107332668150775e-05\n",
      "Total loss 3.107332668150775e-05\n",
      "====================\n",
      "Epoch: 40\n",
      "====================\n",
      "Batch loss 2.555011997174006e-05\n",
      "Total loss 2.555011997174006e-05\n",
      "====================\n",
      "Epoch: 41\n",
      "====================\n",
      "Batch loss 2.7099800718133338e-05\n",
      "Total loss 2.7099800718133338e-05\n",
      "====================\n",
      "Epoch: 42\n",
      "====================\n",
      "Batch loss 2.2848094886285253e-05\n",
      "Total loss 2.2848094886285253e-05\n",
      "====================\n",
      "Epoch: 43\n",
      "====================\n",
      "Batch loss 2.1497131456271745e-05\n",
      "Total loss 2.1497131456271745e-05\n",
      "====================\n",
      "Epoch: 44\n",
      "====================\n",
      "Batch loss 2.094083720294293e-05\n",
      "Total loss 2.094083720294293e-05\n",
      "====================\n",
      "Epoch: 45\n",
      "====================\n",
      "Batch loss 1.939115463756025e-05\n",
      "Total loss 1.939115463756025e-05\n",
      "====================\n",
      "Epoch: 46\n",
      "====================\n",
      "Batch loss 2.0821657017222606e-05\n",
      "Total loss 2.0821657017222606e-05\n",
      "====================\n",
      "Epoch: 47\n",
      "====================\n",
      "Batch loss 1.8318316506338306e-05\n",
      "Total loss 1.8318316506338306e-05\n",
      "====================\n",
      "Epoch: 48\n",
      "====================\n",
      "Batch loss 1.660966154304333e-05\n",
      "Total loss 1.660966154304333e-05\n",
      "====================\n",
      "Epoch: 49\n",
      "====================\n",
      "Batch loss 1.6252051864285022e-05\n",
      "Total loss 1.6252051864285022e-05\n",
      "====================\n",
      "Epoch: 50\n",
      "====================\n",
      "Batch loss 1.9669327230076306e-05\n",
      "Total loss 1.9669327230076306e-05\n",
      "====================\n",
      "Epoch: 51\n",
      "====================\n",
      "Batch loss 1.5298401194741018e-05\n",
      "Total loss 1.5298401194741018e-05\n",
      "====================\n",
      "Epoch: 52\n",
      "====================\n",
      "Batch loss 1.347054148936877e-05\n",
      "Total loss 1.347054148936877e-05\n",
      "====================\n",
      "Epoch: 53\n",
      "====================\n",
      "Batch loss 1.4821569493506104e-05\n",
      "Total loss 1.4821569493506104e-05\n",
      "====================\n",
      "Epoch: 54\n",
      "====================\n",
      "Batch loss 1.3391076208790764e-05\n",
      "Total loss 1.3391076208790764e-05\n",
      "====================\n",
      "Epoch: 55\n",
      "====================\n",
      "Batch loss 1.4583158190362155e-05\n",
      "Total loss 1.4583158190362155e-05\n",
      "====================\n",
      "Epoch: 56\n",
      "====================\n",
      "Batch loss 1.2000325114058796e-05\n",
      "Total loss 1.2000325114058796e-05\n",
      "====================\n",
      "Epoch: 57\n",
      "====================\n",
      "Batch loss 1.279503612749977e-05\n",
      "Total loss 1.279503612749977e-05\n",
      "====================\n",
      "Epoch: 58\n",
      "====================\n",
      "Batch loss 1.3828168448526412e-05\n",
      "Total loss 1.3828168448526412e-05\n",
      "====================\n",
      "Epoch: 59\n",
      "====================\n",
      "Batch loss 1.2397677892295178e-05\n",
      "Total loss 1.2397677892295178e-05\n",
      "====================\n",
      "Epoch: 60\n",
      "====================\n",
      "Batch loss 1.0768505489977542e-05\n",
      "Total loss 1.0768505489977542e-05\n",
      "====================\n",
      "Epoch: 61\n",
      "====================\n",
      "Batch loss 1.0450620720803272e-05\n",
      "Total loss 1.0450620720803272e-05\n",
      "====================\n",
      "Epoch: 62\n",
      "====================\n",
      "Batch loss 1.2278474059712607e-05\n",
      "Total loss 1.2278474059712607e-05\n",
      "====================\n",
      "Epoch: 63\n",
      "====================\n",
      "Batch loss 1.1404287761251908e-05\n",
      "Total loss 1.1404287761251908e-05\n",
      "====================\n",
      "Epoch: 64\n",
      "====================\n",
      "Batch loss 1.0887711141549516e-05\n",
      "Total loss 1.0887711141549516e-05\n",
      "====================\n",
      "Epoch: 65\n",
      "====================\n",
      "Batch loss 9.25853873923188e-06\n",
      "Total loss 9.25853873923188e-06\n",
      "====================\n",
      "Epoch: 66\n",
      "====================\n",
      "Batch loss 9.338012205262203e-06\n",
      "Total loss 9.338012205262203e-06\n",
      "====================\n",
      "Epoch: 67\n",
      "====================\n",
      "Batch loss 9.5366922323592e-06\n",
      "Total loss 9.5366922323592e-06\n",
      "====================\n",
      "Epoch: 68\n",
      "====================\n",
      "Batch loss 1.0212203960691113e-05\n",
      "Total loss 1.0212203960691113e-05\n",
      "====================\n",
      "Epoch: 69\n",
      "====================\n",
      "Batch loss 8.62276374391513e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [06:30<00:00, 130.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total loss 8.62276374391513e-06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "2024-12-01 15:04:25,539 - easyeditor.editors.editor - INFO - 0 editing: Question:What sport does Lionel Messi play? Answer: -> basketball  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Question:What sport does Lionel Messi play? Answer:', 'target_new': 'basketball', 'ground_truth': 'football', 'portability': {}, 'locality': {}, 'subject': 'Lionel Messi'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "12/01/2024 15:04:25 - INFO - easyeditor.editors.editor -   0 editing: Question:What sport does Lionel Messi play? Answer: -> basketball  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Question:What sport does Lionel Messi play? Answer:', 'target_new': 'basketball', 'ground_truth': 'football', 'portability': {}, 'locality': {}, 'subject': 'Lionel Messi'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-12-01 15:04:26,154 - easyeditor.editors.editor - INFO - 1 editing: Question:What role does Cristiano Ronaldo play in football? Answer: -> defender  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Question:What role does Cristiano Ronaldo play in football? Answer:', 'target_new': 'defender', 'ground_truth': 'forward', 'portability': {}, 'locality': {}, 'subject': 'Cristiano Ronaldo'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "12/01/2024 15:04:26 - INFO - easyeditor.editors.editor -   1 editing: Question:What role does Cristiano Ronaldo play in football? Answer: -> defender  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Question:What role does Cristiano Ronaldo play in football? Answer:', 'target_new': 'defender', 'ground_truth': 'forward', 'portability': {}, 'locality': {}, 'subject': 'Cristiano Ronaldo'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-12-01 15:04:26,770 - easyeditor.editors.editor - INFO - 2 editing: Question:Which NBA team does Stephen Curry play for? Answer: -> New York Knicks  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.6666666666666666], 'portability': {}}, 'case_id': 2, 'requested_rewrite': {'prompt': 'Question:Which NBA team does Stephen Curry play for? Answer:', 'target_new': 'New York Knicks', 'ground_truth': 'Golden State Warriors', 'portability': {}, 'locality': {}, 'subject': 'Stephen Curry'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "12/01/2024 15:04:26 - INFO - easyeditor.editors.editor -   2 editing: Question:Which NBA team does Stephen Curry play for? Answer: -> New York Knicks  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.6666666666666666], 'portability': {}}, 'case_id': 2, 'requested_rewrite': {'prompt': 'Question:Which NBA team does Stephen Curry play for? Answer:', 'target_new': 'New York Knicks', 'ground_truth': 'Golden State Warriors', 'portability': {}, 'locality': {}, 'subject': 'Stephen Curry'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.2222222222222222}, 'post': {'rewrite_acc': 0.3333333333333333}}\n"
     ]
    }
   ],
   "source": [
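    "# Load the LoRA hyperparameters for Llama-3.2-3B and build the editor.\n",
    "hparams = LoRAHyperParams.from_hparams('./hparams/LoRA/llama3.2-3b.yaml')\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "\n",
    "# Apply the three edits one after another; `edited_model` holds the result.\n",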
    "metrics, edited_model, _ = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe56c3a1",
   "metadata": {},
   "source": [
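    "* `prompts`: the edit prompts $x_e$.\n",
    "* `ground_truth`: the original (pre-edit) answers to the prompts.\n",
    "* `target_new`: the new answers $y_e$ that the edits should produce.\n",
    "* `subject`: the subject entity mentioned in each prompt.\n",
    "* `sequential_edit`: whether to apply the edits one after another to the same model, so that earlier edits stay in place while later ones are made (should be `True` unless only a single edit is performed).\n",
    "***"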
   ]
  },
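  {
   "cell_type": "markdown",
   "id": "8d2f6b91",
   "metadata": {},
   "source": [
    "The `Metrics Summary` line at the end of the editing log appears to average the per-case `rewrite_acc` values: before editing `(0.0 + 0.0 + 0.667) / 3 ≈ 0.222`, after editing `(0.0 + 0.0 + 1.0) / 3 ≈ 0.333`. In other words, after sequential editing only the last edit (Stephen Curry -> New York Knicks) is reported as successful. The cell below is a minimal sketch that reproduces these two numbers from the per-case values in the log, assuming a plain mean (first within each case, then across cases). The `summarize` helper is illustrative and not an EasyEdit function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d2f6b92",
   "metadata": {},
   "outputs": [],
   "source": [
    "import statistics\n",
    "\n",
    "# Per-case rewrite_acc values copied from the editing log above (case_id 0, 1, 2).\n",
    "per_case = {\n",
    "    'pre': [[0.0], [0.0], [0.6666666666666666]],\n",
    "    'post': [[0.0], [0.0], [1.0]],\n",
    "}\n",
    "\n",
    "\n",
    "def summarize(cases):\n",
    "    # Assumption: average within each case first, then across cases.\n",
    "    return statistics.mean(statistics.mean(case) for case in cases)\n",
    "\n",
    "\n",
    "summary = {stage: {'rewrite_acc': summarize(cases)} for stage, cases in per_case.items()}\n",
    "print(summary)"
   ]
  },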
  {
   "cell_type": "markdown",
   "id": "c3db4502",
   "metadata": {},
   "source": [
    "### Reliability Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc703696",
   "metadata": {},
   "outputs": [],
   "source": [
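    "from transformers import AutoTokenizer\n",
    "from transformers import LlamaForCausalLM\n",
    "\n",
    "# Load the unedited base model and its tokenizer to compare against `edited_model`.\n",
    "tokenizer = AutoTokenizer.from_pretrained('./hugging_cache/llama-3.2-3b-instruct', trust_remote_code=True)\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id  # use EOS as the padding token\n",
    "tokenizer.padding_side = 'left'  # left-pad so generation continues right after each prompt\n",
    "device = 1  # GPU index; change to match your machine\n",
    "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama-3.2-3b-instruct').to(f'cuda:{device}')"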
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c2acf594",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Question:What sport does Lionel Messi play? Answer:\n",
      "Pre-Edit  Output:  Lionel Messi plays football, also known as soccer.\n",
      "\n",
      "\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Prompt: Question:What role does Cristiano Ronaldo play in football? Answer:\n",
      "Pre-Edit  Output: As a professional footballer, Cristiano Ronaldo plays as\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Prompt: Question:Which NBA team does Stephen Curry play for? Answer:\n",
      "Pre-Edit  Output: Golden State Warriors.\n",
      "What is the capital city of\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Reliability: query the models with the exact edit prompts.\n",
    "correct_prompts = ['Question:What sport does Lionel Messi play? Answer:',\n",
    "                   'Question:What role does Cristiano Ronaldo play in football? Answer:',\n",
    "                   'Question:Which NBA team does Stephen Curry play for? Answer:']\n",
    "# Expected post-edit answers: target_new = ['basketball', 'defender', 'New York Knicks']\n",
    "batch = tokenizer(correct_prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "# Unedited base model (10 new tokens).\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=10\n",
    ")\n",
    "# LoRA-edited model (only 3 new tokens, so just the start of each answer is shown).\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=3\n",
    ")\n",
    "# With left padding every prompt ends at the same position, so slicing off\n",
    "# the first prompt_len tokens leaves only the generated continuation.\n",
    "prompt_len = batch['input_ids'].shape[-1]\n",
    "for i in range(len(correct_prompts)):\n",
    "    print(f'Prompt: {correct_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode(pre_edit_outputs[i][prompt_len:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][prompt_len:], skip_special_tokens=True)}')\n",
    "    print('--' * 50)"
   ]
  },
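  {
   "cell_type": "markdown",
   "id": "c5e07a13",
   "metadata": {},
   "source": [
    "In the output above, all three edit prompts are answered with `New York Knicks`, the target of the last edit, which is consistent with the post-edit `rewrite_acc` values of `[0.0]`, `[0.0]` and `[1.0]` in the editing log. The cell below is a minimal, hypothetical string-level check (the `generation_match` helper is not part of EasyEdit) run on the decoded post-edit outputs copied from this log. It is cruder than EasyEdit's `rewrite_acc`: the fractional pre-edit value of 0.67 for the Stephen Curry case suggests that metric is scored per target token, whereas this sketch only checks whether the free generation starts with the target string."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5e07a14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical string-level reliability check; not part of EasyEdit.\n",
    "# Decoded post-edit outputs copied from the Reliability Test output above.\n",
    "edit_targets = ['basketball', 'defender', 'New York Knicks']\n",
    "post_edit_texts = [' New York Knicks', ' New York Knicks', ' New York Knicks']\n",
    "\n",
    "\n",
    "def generation_match(outputs, targets):\n",
    "    # An edit counts as a hit if its continuation starts with the target (case-insensitive).\n",
    "    hits = [out.strip().lower().startswith(t.lower()) for out, t in zip(outputs, targets)]\n",
    "    return hits, sum(hits) / len(hits)\n",
    "\n",
    "\n",
    "hits, score = generation_match(post_edit_texts, edit_targets)\n",
    "print(hits, score)  # only the last edit (Stephen Curry) is a hit"
   ]
  },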
  {
   "cell_type": "markdown",
   "id": "43528147",
   "metadata": {},
   "source": [
    "### Generalization Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4074b583",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Question:What sports is Messi good at? Answer:\n",
      "Pre-Edit  Output:  Lionel Messi is good at football (soccer).\n",
      "\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Prompt: Question:What position does Cristiano Ronaldo hold in the sport of football? Answer:\n",
      "Pre-Edit  Output: He is a forward.\n",
      "What is the position of\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Prompt: Question:Which city does Stephen Curry currently working in? Answer:\n",
      "Pre-Edit  Output:  San Francisco.\n",
      "Stephen Curry is a professional basketball player\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Generalization: rephrased or related questions that the edits should also affect.\n",
    "generation_prompts = ['Question:What sports is Messi good at? Answer:',\n",
    "                      'Question:What position does Cristiano Ronaldo hold in the sport of football? Answer:',\n",
    "                      'Question:Which city does Stephen Curry currently working in? Answer:']\n",
    "\n",
    "batch = tokenizer(generation_prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=10\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=3\n",
    ")\n",
    "# Slice off the (left-padded) prompt so only newly generated tokens are decoded.\n",
    "prompt_len = batch['input_ids'].shape[-1]\n",
    "for i in range(len(generation_prompts)):\n",
    "    print(f'Prompt: {generation_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode(pre_edit_outputs[i][prompt_len:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][prompt_len:], skip_special_tokens=True)}')\n",
    "    print('--' * 50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d4c3779",
   "metadata": {},
   "source": [
    "### Locality Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f21404e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Question:What sport does Kylian Mbappé play? Answer:\n",
      "Pre-Edit  Output: Football/Soccer\n",
      "Kylian Mbappé is\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Prompt: Question:What role does Thierry Henry play in football? Answer:\n",
      "Pre-Edit  Output:  Thierry Henry is a former French professional footballer\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n",
      "Prompt: Question:Which NBA team does Jordan play for? Answer:\n",
      "Pre-Edit  Output: He plays for the Washington Wizards.\n",
      "I'm going\n",
      "Post-Edit Output:  New York Knicks\n",
      "----------------------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Locality: unrelated prompts whose answers should not change after editing.\n",
    "locality_prompts = ['Question:What sport does Kylian Mbappé play? Answer:',\n",
    "                    'Question:What role does Thierry Henry play in football? Answer:',\n",
    "                    'Question:Which NBA team does Jordan play for? Answer:']\n",
    "\n",
    "batch = tokenizer(locality_prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=10\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_new_tokens=3\n",
    ")\n",
    "# Slice off the (left-padded) prompt so only newly generated tokens are decoded.\n",
    "prompt_len = batch['input_ids'].shape[-1]\n",
    "for i in range(len(locality_prompts)):\n",
    "    print(f'Prompt: {locality_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode(pre_edit_outputs[i][prompt_len:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][prompt_len:], skip_special_tokens=True)}')\n",
    "    print('--' * 50)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EasyEdit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
