{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "from easyeditor import BaseEditor\n",
    "from easyeditor import AlphaEditHyperParams\n",
    "import os\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "hparams=AlphaEditHyperParams.from_hparams('../hparams/AlphaEdit/llama3-8b.yaml')\n",
    "\n",
    "prompts = ['Who was the designer of Lahti Town Hall?',\n",
    "                'What role does Denny Herzig play in football?',\n",
    "                'What city did Marl Young live when he died?']\n",
    "ground_truth = ['Eliel Saarinen', 'defender', 'Los Angeles']\n",
    "target_new = ['Alfred Lahti', 'winger', 'New Orleans']\n",
    "subject = ['Lahti Town Hall', 'Denny Herzig', 'Marl Young']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-22 09:56:30,208 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "10/22/2024 09:56:30 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f73e13f30964bb7bb7198e42c9cfdc0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/3 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)\n",
      "100%|██████████| 3/3 [00:01<00:00,  1.84it/s]\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing AlphaEdit algo for: [Who was the designer of {}?] -> [ Alfred Lahti]\n",
      "Cached context templates [['{}'], ['The 2019-20 season has been. {}', 'Therefore, we must not forget the importance of. {}', 'Because I am a woman: The impact of. {}', 'I have to admit, I was a bit. {}', \"You're here: Home » Resources » Blog. {}\"]]\n",
      "Computing right vector (v)\n",
      "Lookup index found: 9 | Sentence: Who was the designer of Lahti Town Hall? Alfred La | Token:  Hall\n",
      "Rewrite layer is 8\n",
      "Tying optimization objective to 31\n",
      "Recording initial value of v*\n",
      "loss 6.897 = 6.897 + 0.0 + 0.0 avg prob of [ Alfred Lahti] 0.001115013612434268\n",
      "loss 5.235 = 5.145 + 0.031 + 0.059 avg prob of [ Alfred Lahti] 0.005882870405912399\n",
      "loss 3.575 = 3.488 + 0.027 + 0.059 avg prob of [ Alfred Lahti] 0.03142179548740387\n",
      "loss 2.485 = 2.399 + 0.027 + 0.059 avg prob of [ Alfred Lahti] 0.09502427279949188\n",
      "loss 2.038 = 1.945 + 0.034 + 0.059 avg prob of [ Alfred Lahti] 0.15155690908432007\n",
      "loss 1.397 = 1.314 + 0.024 + 0.059 avg prob of [ Alfred Lahti] 0.27494633197784424\n",
      "loss 0.395 = 0.318 + 0.017 + 0.059 avg prob of [ Alfred Lahti] 0.7292488813400269\n",
      "loss 0.212 = 0.135 + 0.018 + 0.059 avg prob of [ Alfred Lahti] 0.8736490607261658\n",
      "loss 0.109 = 0.016 + 0.034 + 0.059 avg prob of [ Alfred Lahti] 0.9844441413879395\n",
      "loss 0.124 = 0.028 + 0.037 + 0.059 avg prob of [ Alfred Lahti] 0.9730691909790039\n",
      "loss 0.093 = 0.003 + 0.03 + 0.059 avg prob of [ Alfred Lahti] 0.9969701766967773\n",
      "loss 0.081 = 0.002 + 0.019 + 0.059 avg prob of [ Alfred Lahti] 0.9975204467773438\n",
      "loss 0.073 = 0.002 + 0.012 + 0.059 avg prob of [ Alfred Lahti] 0.9981799125671387\n",
      "loss 0.073 = 0.002 + 0.012 + 0.059 avg prob of [ Alfred Lahti] 0.998436450958252\n",
      "loss 0.067 = 0.001 + 0.007 + 0.059 avg prob of [ Alfred Lahti] 0.9985455274581909\n",
      "loss 0.071 = 0.001 + 0.01 + 0.059 avg prob of [ Alfred Lahti] 0.9986335635185242\n",
      "loss 0.067 = 0.001 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9988059401512146\n",
      "loss 0.068 = 0.001 + 0.007 + 0.059 avg prob of [ Alfred Lahti] 0.9989380836486816\n",
      "loss 0.067 = 0.001 + 0.007 + 0.059 avg prob of [ Alfred Lahti] 0.9990427494049072\n",
      "loss 0.066 = 0.001 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9991299510002136\n",
      "loss 0.066 = 0.001 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9992241263389587\n",
      "loss 0.066 = 0.001 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9993327856063843\n",
      "loss 0.065 = 0.001 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9994370937347412\n",
      "loss 0.066 = 0.0 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9995219111442566\n",
      "loss 0.065 = 0.0 + 0.006 + 0.059 avg prob of [ Alfred Lahti] 0.9995879530906677\n",
      "Init norm 6.322671413421631 | Delta norm 4.742003917694092 | Target norm 7.660367965698242\n",
      "\n",
      "\n",
      "LAYER 4\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 4\n",
      "z error tensor(4.7420, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.4221, device='cuda:5')\n",
      "upd norm tensor(0.1497, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 5\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 5\n",
      "z error tensor(4.6169, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7669, device='cuda:5')\n",
      "upd norm tensor(0.1822, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 6\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 6\n",
      "z error tensor(4.4075, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7295, device='cuda:5')\n",
      "upd norm tensor(0.2279, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 7\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 7\n",
      "z error tensor(3.9992, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.8482, device='cuda:5')\n",
      "upd norm tensor(0.3144, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 8\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 8\n",
      "z error tensor(3.2638, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.5938, device='cuda:5')\n",
      "upd norm tensor(0.5081, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 1/3 [00:33<01:07, 33.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n",
      "Executing AlphaEdit algo for: [What role does {} play in football?] -> [ winger]\n",
      "Computing right vector (v)\n",
      "Lookup index found: 7 | Sentence: What role does Denny Herzig play in football? | Token: zig\n",
      "Rewrite layer is 8\n",
      "Tying optimization objective to 31\n",
      "Recording initial value of v*\n",
      "loss 13.694 = 13.694 + 0.0 + 0.0 avg prob of [ winger] 2.288481937284814e-06\n",
      "loss 11.799 = 11.718 + 0.006 + 0.075 avg prob of [ winger] 1.9543773305485956e-05\n",
      "loss 8.877 = 8.76 + 0.042 + 0.075 avg prob of [ winger] 0.00028090033447369933\n",
      "loss 6.369 = 6.248 + 0.046 + 0.075 avg prob of [ winger] 0.002425634767860174\n",
      "loss 4.725 = 4.638 + 0.012 + 0.075 avg prob of [ winger] 0.012589077465236187\n",
      "loss 1.877 = 1.789 + 0.013 + 0.075 avg prob of [ winger] 0.18524186313152313\n",
      "loss 0.424 = 0.339 + 0.01 + 0.075 avg prob of [ winger] 0.7308158874511719\n",
      "loss 0.138 = 0.03 + 0.033 + 0.075 avg prob of [ winger] 0.970641553401947\n",
      "loss 0.125 = 0.013 + 0.037 + 0.075 avg prob of [ winger] 0.9868654608726501\n",
      "loss 0.098 = 0.006 + 0.018 + 0.075 avg prob of [ winger] 0.9943593740463257\n",
      "loss 0.088 = 0.004 + 0.009 + 0.075 avg prob of [ winger] 0.9962185621261597\n",
      "loss 0.085 = 0.003 + 0.007 + 0.075 avg prob of [ winger] 0.9965100288391113\n",
      "loss 0.084 = 0.003 + 0.006 + 0.075 avg prob of [ winger] 0.9965077638626099\n",
      "loss 0.084 = 0.003 + 0.006 + 0.075 avg prob of [ winger] 0.996635913848877\n",
      "loss 0.083 = 0.003 + 0.005 + 0.075 avg prob of [ winger] 0.9969795942306519\n",
      "loss 0.082 = 0.003 + 0.005 + 0.075 avg prob of [ winger] 0.9974439144134521\n",
      "loss 0.082 = 0.002 + 0.005 + 0.075 avg prob of [ winger] 0.9979082942008972\n",
      "loss 0.081 = 0.002 + 0.005 + 0.075 avg prob of [ winger] 0.9983040690422058\n",
      "loss 0.081 = 0.001 + 0.004 + 0.075 avg prob of [ winger] 0.9986151456832886\n",
      "loss 0.08 = 0.001 + 0.004 + 0.075 avg prob of [ winger] 0.9988510608673096\n",
      "loss 0.08 = 0.001 + 0.004 + 0.075 avg prob of [ winger] 0.999028205871582\n",
      "loss 0.08 = 0.001 + 0.004 + 0.075 avg prob of [ winger] 0.9991620779037476\n",
      "loss 0.079 = 0.001 + 0.004 + 0.075 avg prob of [ winger] 0.999264121055603\n",
      "loss 0.079 = 0.001 + 0.004 + 0.075 avg prob of [ winger] 0.9993433356285095\n",
      "loss 0.079 = 0.001 + 0.003 + 0.075 avg prob of [ winger] 0.9994053840637207\n",
      "Init norm 5.009483814239502 | Delta norm 3.757112741470337 | Target norm 6.130194664001465\n",
      "\n",
      "\n",
      "LAYER 4\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 4\n",
      "z error tensor(3.7571, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.4223, device='cuda:5')\n",
      "upd norm tensor(0.1146, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 5\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 5\n",
      "z error tensor(3.6672, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7671, device='cuda:5')\n",
      "upd norm tensor(0.1446, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 6\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 6\n",
      "z error tensor(3.5352, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7297, device='cuda:5')\n",
      "upd norm tensor(0.1844, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 7\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 7\n",
      "z error tensor(3.3080, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.8484, device='cuda:5')\n",
      "upd norm tensor(0.2588, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 8\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 8\n",
      "z error tensor(2.8746, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.5948, device='cuda:5')\n",
      "upd norm tensor(0.4523, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████▋   | 2/3 [00:59<00:29, 29.22s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n",
      "Executing AlphaEdit algo for: [What city did {} live when he died?] -> [ New Orleans]\n",
      "Computing right vector (v)\n",
      "Lookup index found: 5 | Sentence: What city did Marl Young live when he died? New | Token:  Young\n",
      "Rewrite layer is 8\n",
      "Tying optimization objective to 31\n",
      "Recording initial value of v*\n",
      "loss 4.961 = 4.961 + 0.0 + 0.0 avg prob of [ New Orleans] 0.007990440353751183\n",
      "loss 1.89 = 1.81 + 0.009 + 0.071 avg prob of [ New Orleans] 0.19483816623687744\n",
      "loss 0.793 = 0.702 + 0.021 + 0.071 avg prob of [ New Orleans] 0.5032348036766052\n",
      "loss 0.346 = 0.255 + 0.02 + 0.071 avg prob of [ New Orleans] 0.7789876461029053\n",
      "loss 0.167 = 0.073 + 0.023 + 0.071 avg prob of [ New Orleans] 0.9307969808578491\n",
      "loss 0.103 = 0.011 + 0.021 + 0.071 avg prob of [ New Orleans] 0.9887165427207947\n",
      "loss 0.091 = 0.002 + 0.018 + 0.071 avg prob of [ New Orleans] 0.9977560043334961\n",
      "loss 0.093 = 0.001 + 0.021 + 0.071 avg prob of [ New Orleans] 0.9992453455924988\n",
      "loss 0.091 = 0.0 + 0.02 + 0.071 avg prob of [ New Orleans] 0.9995385408401489\n",
      "loss 0.083 = 0.001 + 0.012 + 0.071 avg prob of [ New Orleans] 0.999459445476532\n",
      "loss 0.082 = 0.001 + 0.01 + 0.071 avg prob of [ New Orleans] 0.9993801116943359\n",
      "loss 0.08 = 0.001 + 0.008 + 0.071 avg prob of [ New Orleans] 0.9993478059768677\n",
      "loss 0.078 = 0.001 + 0.007 + 0.071 avg prob of [ New Orleans] 0.9993656873703003\n",
      "loss 0.078 = 0.001 + 0.007 + 0.071 avg prob of [ New Orleans] 0.9994177222251892\n",
      "loss 0.078 = 0.001 + 0.006 + 0.071 avg prob of [ New Orleans] 0.9994834661483765\n",
      "loss 0.077 = 0.0 + 0.006 + 0.071 avg prob of [ New Orleans] 0.9995499849319458\n",
      "loss 0.076 = 0.0 + 0.005 + 0.071 avg prob of [ New Orleans] 0.9996127486228943\n",
      "loss 0.076 = 0.0 + 0.005 + 0.071 avg prob of [ New Orleans] 0.9996713399887085\n",
      "loss 0.076 = 0.0 + 0.005 + 0.071 avg prob of [ New Orleans] 0.9997250437736511\n",
      "loss 0.076 = 0.0 + 0.005 + 0.071 avg prob of [ New Orleans] 0.9997729659080505\n",
      "loss 0.075 = 0.0 + 0.004 + 0.071 avg prob of [ New Orleans] 0.9998135566711426\n",
      "loss 0.075 = 0.0 + 0.004 + 0.071 avg prob of [ New Orleans] 0.9998466968536377\n",
      "loss 0.075 = 0.0 + 0.004 + 0.071 avg prob of [ New Orleans] 0.9998726844787598\n",
      "loss 0.075 = 0.0 + 0.004 + 0.071 avg prob of [ New Orleans] 0.9998926520347595\n",
      "loss 0.074 = 0.0 + 0.003 + 0.071 avg prob of [ New Orleans] 0.9999080896377563\n",
      "Init norm 5.299758434295654 | Delta norm 3.974818706512451 | Target norm 6.595734119415283\n",
      "\n",
      "\n",
      "LAYER 4\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 4\n",
      "z error tensor(3.9748, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.4223, device='cuda:5')\n",
      "upd norm tensor(0.1208, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 5\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 5\n",
      "z error tensor(3.8782, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7671, device='cuda:5')\n",
      "upd norm tensor(0.1526, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 6\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 6\n",
      "z error tensor(3.7658, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(77.7297, device='cuda:5')\n",
      "upd norm tensor(0.1959, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 7\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 7\n",
      "z error tensor(3.5858, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.8484, device='cuda:5')\n",
      "upd norm tensor(0.2806, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "\n",
      "\n",
      "LAYER 8\n",
      "\n",
      "Writing 1 key/value pair(s) into layer 8\n",
      "z error tensor(3.2224, device='cuda:5', grad_fn=<MeanBackward0>)\n",
      "orig norm tensor(78.5952, device='cuda:5')\n",
      "upd norm tensor(0.5008, device='cuda:5', grad_fn=<LinalgVectorNormBackward0>)\n",
      "Deltas successfully computed for ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [01:25<00:00, 28.64s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New weights successfully inserted into ['model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight']\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "2024-10-22 09:58:25,915 - easyeditor.editors.editor - INFO - 0 editing: Who was the designer of Lahti Town Hall? -> Alfred Lahti  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who was the designer of Lahti Town Hall?', 'target_new': 'Alfred Lahti', 'ground_truth': 'Eliel Saarinen', 'portability': {}, 'locality': {}, 'subject': 'Lahti Town Hall'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "10/22/2024 09:58:25 - INFO - easyeditor.editors.editor -   0 editing: Who was the designer of Lahti Town Hall? -> Alfred Lahti  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 0, 'requested_rewrite': {'prompt': 'Who was the designer of Lahti Town Hall?', 'target_new': 'Alfred Lahti', 'ground_truth': 'Eliel Saarinen', 'portability': {}, 'locality': {}, 'subject': 'Lahti Town Hall'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "2024-10-22 09:58:26,005 - easyeditor.editors.editor - INFO - 1 editing: What role does Denny Herzig play in football? -> winger  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'What role does Denny Herzig play in football?', 'target_new': 'winger', 'ground_truth': 'defender', 'portability': {}, 'locality': {}, 'subject': 'Denny Herzig'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "10/22/2024 09:58:26 - INFO - easyeditor.editors.editor -   1 editing: What role does Denny Herzig play in football? -> winger  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 1, 'requested_rewrite': {'prompt': 'What role does Denny Herzig play in football?', 'target_new': 'winger', 'ground_truth': 'defender', 'portability': {}, 'locality': {}, 'subject': 'Denny Herzig'}, 'post': {'rewrite_acc': [0.0], 'locality': {}, 'portability': {}}}\n",
      "2024-10-22 09:58:26,087 - easyeditor.editors.editor - INFO - 2 editing: What city did Marl Young live when he died? -> New Orleans  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 2, 'requested_rewrite': {'prompt': 'What city did Marl Young live when he died?', 'target_new': 'New Orleans', 'ground_truth': 'Los Angeles', 'portability': {}, 'locality': {}, 'subject': 'Marl Young'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n",
      "10/22/2024 09:58:26 - INFO - easyeditor.editors.editor -   2 editing: What city did Marl Young live when he died? -> New Orleans  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}}, 'case_id': 2, 'requested_rewrite': {'prompt': 'What city did Marl Young live when he died?', 'target_new': 'New Orleans', 'ground_truth': 'Los Angeles', 'portability': {}, 'locality': {}, 'subject': 'Marl Young'}, 'post': {'rewrite_acc': [1.0], 'locality': {}, 'portability': {}}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.0}, 'post': {'rewrite_acc': 0.6666666666666666}}\n"
     ]
    }
   ],
   "source": [
    "editor=BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, _ = editor.edit(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    sequential_edit=True\n",
    ")\n",
    "# print(metrics)\n",
    "# print(type(edited_model))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "from transformers import LlamaForCausalLM\n",
    "tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct')\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "tokenizer.padding_side='left'\n",
    "device = '1'\n",
    "model = LlamaForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B-Instruct').to(f'cuda:{device}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84eff2f5b010437cadea3489895a5eef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/yunzhi/miniconda3/envs/edit/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2847: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n",
      "  warnings.warn(\n",
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n",
      "Setting `pad_token_id` to `eos_token_id`:None for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:  ['<|eot_id|><|begin_of_text|>Who was the designer of Lahti Town Hall? The Lahti Town Hall was designed by', '<|begin_of_text|>What role does Denny Herzig play in football? Denny Herzig is a retired American', '<|eot_id|><|begin_of_text|>What city did Marl Young live when he died? Marl Young died on February 24,']\n",
      "Post-Edit Outputs:  ['<|eot_id|><|begin_of_text|>Who was the designer of Lahti Town Hall? Alfred Lahti\\nAlfred Lahti', '<|begin_of_text|>What role does Denny Herzig play in football? Denny Herzing is an American football', '<|eot_id|><|begin_of_text|>What city did Marl Young live when he died? New Orleans\\nWhat city did Marl Young']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "correct_prompts = ['Who was the designer of Lahti Town Hall?',\n",
    "                'What role does Denny Herzig play in football?',\n",
    "                'What city did Marl Young live when he died?']\n",
    "\n",
    "batch = tokenizer(correct_prompts, return_tensors='pt', padding=True, max_length=30)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=3\n",
    ")\n",
    "\n",
    "\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=3\n",
    ")\n",
    "\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(correct_prompts)):\n",
    "    print(f'Prompt: {correct_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "edit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
