{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/sfs/qumulo/qhome/sux7mp/development/composable-interventions\n",
      "easyeditor  hparams  main.py  notebooks  out  README.md  sparsellm\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%cd ..\n",
    "!ls\n",
    "\n",
    "import torch\n",
    "from easyeditor import MEMITHyperParams\n",
    "from easyeditor import BaseEditor, ModelEditWrapper\n",
    "import os\n",
    "from transformers import GPTNeoXForCausalLM, AutoTokenizer\n",
    "import hashlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define parameters\n",
    "\n",
    "# Editing hyperparameters are read from this YAML file; the path is relative\n",
    "# to the current working directory.\n",
    "hparams_path = './hparams/MEMIT/pythia.yaml'\n",
    "hparams = MEMITHyperParams.from_hparams(hparams_path)\n",
    "\n",
    "# One record per requested edit. The four parallel lists used by the\n",
    "# editing cell are derived from these records, so entry i of every list\n",
    "# always refers to the same edit.\n",
    "edit_requests = [\n",
    "    {\n",
    "        'prompt': 'Who was the designer of Lahti Town Hall?',\n",
    "        'ground_truth': 'Eliel Saarinen',\n",
    "        'target_new': 'Alfred Lahti',\n",
    "        'subject': 'Lahti Town Hall',\n",
    "    },\n",
    "    {\n",
    "        'prompt': 'What role does Denny Herzig play in football?',\n",
    "        'ground_truth': 'defender',\n",
    "        'target_new': 'winger',\n",
    "        'subject': 'Denny Herzig',\n",
    "    },\n",
    "    {\n",
    "        'prompt': 'What city did Marl Young live when he died?',\n",
    "        'ground_truth': 'Los Angeles',\n",
    "        'target_new': 'New Orleans',\n",
    "        'subject': 'Marl Young',\n",
    "    },\n",
    "]\n",
    "\n",
    "prompts = [request['prompt'] for request in edit_requests]\n",
    "ground_truth = [request['ground_truth'] for request in edit_requests]\n",
    "target_new = [request['target_new'] for request in edit_requests]\n",
    "subject = [request['subject'] for request in edit_requests]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "ename": "ImportError",
     "evalue": "IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mImportError\u001b[0m                               Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[5], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;66;03m# init some model\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mGPTNeoXForCausalLM\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrom_pretrained\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m      4\u001b[0m \u001b[43m  \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mEleutherAI/pythia-70m-deduped\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m      5\u001b[0m \u001b[43m  \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mstep3000\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/transformers/modeling_utils.py:2507\u001b[0m, in \u001b[0;36mPreTrainedModel.from_pretrained\u001b[0;34m(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)\u001b[0m\n\u001b[1;32m   2504\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m commit_hash \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   2505\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(config, PretrainedConfig):\n\u001b[1;32m   2506\u001b[0m         \u001b[38;5;66;03m# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible\u001b[39;00m\n\u001b[0;32m-> 2507\u001b[0m         resolved_config_file \u001b[38;5;241m=\u001b[39m \u001b[43mcached_file\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2508\u001b[0m \u001b[43m            \u001b[49m\u001b[43mpretrained_model_name_or_path\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2509\u001b[0m \u001b[43m            \u001b[49m\u001b[43mCONFIG_NAME\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2510\u001b[0m \u001b[43m            \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2511\u001b[0m \u001b[43m            \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2512\u001b[0m \u001b[43m            \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2513\u001b[0m \u001b[43m            \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2514\u001b[0m \u001b[43m   
         \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2515\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2516\u001b[0m \u001b[43m            \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2517\u001b[0m \u001b[43m            \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2518\u001b[0m \u001b[43m            \u001b[49m\u001b[43m_raise_exceptions_for_missing_entries\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2519\u001b[0m \u001b[43m            \u001b[49m\u001b[43m_raise_exceptions_for_connection_errors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2520\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2521\u001b[0m         commit_hash \u001b[38;5;241m=\u001b[39m extract_commit_hash(resolved_config_file, commit_hash)\n\u001b[1;32m   2522\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/transformers/utils/hub.py:429\u001b[0m, in \u001b[0;36mcached_file\u001b[0;34m(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)\u001b[0m\n\u001b[1;32m    426\u001b[0m user_agent \u001b[38;5;241m=\u001b[39m http_user_agent(user_agent)\n\u001b[1;32m    427\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m    428\u001b[0m     \u001b[38;5;66;03m# Load from URL or cache if already cached\u001b[39;00m\n\u001b[0;32m--> 429\u001b[0m     resolved_file \u001b[38;5;241m=\u001b[39m \u001b[43mhf_hub_download\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    430\u001b[0m \u001b[43m        \u001b[49m\u001b[43mpath_or_repo_id\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    431\u001b[0m \u001b[43m        \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    432\u001b[0m \u001b[43m        \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43msubfolder\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43msubfolder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    433\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrepo_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrepo_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    434\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   
 435\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    436\u001b[0m \u001b[43m        \u001b[49m\u001b[43muser_agent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muser_agent\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    437\u001b[0m \u001b[43m        \u001b[49m\u001b[43mforce_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mforce_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    438\u001b[0m \u001b[43m        \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    439\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_download\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    440\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtoken\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtoken\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    441\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlocal_files_only\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal_files_only\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    442\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    443\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m GatedRepoError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    444\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mEnvironmentError\u001b[39;00m(\n\u001b[1;32m    445\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are trying to access a gated repo.\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mMake sure to request access at \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    446\u001b[0m         
\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttps://huggingface.co/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath_or_repo_id\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m and pass a token having permission to this repo either \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    447\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mby logging in with `huggingface-cli login` or by passing `token=<your_token>`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    448\u001b[0m     ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01me\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py:118\u001b[0m, in \u001b[0;36mvalidate_hf_hub_args.<locals>._inner_fn\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m check_use_auth_token:\n\u001b[1;32m    116\u001b[0m     kwargs \u001b[38;5;241m=\u001b[39m smoothly_deprecate_use_auth_token(fn_name\u001b[38;5;241m=\u001b[39mfn\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m, has_token\u001b[38;5;241m=\u001b[39mhas_token, kwargs\u001b[38;5;241m=\u001b[39mkwargs)\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/huggingface_hub/file_download.py:1431\u001b[0m, in \u001b[0;36mhf_hub_download\u001b[0;34m(repo_id, filename, subfolder, repo_type, revision, endpoint, library_name, library_version, cache_dir, local_dir, local_dir_use_symlinks, user_agent, force_download, force_filename, proxies, etag_timeout, resume_download, token, local_files_only, legacy_cache_layout)\u001b[0m\n\u001b[1;32m   1428\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m local_dir \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1429\u001b[0m             _check_disk_space(expected_size, local_dir)\n\u001b[0;32m-> 1431\u001b[0m     \u001b[43mhttp_get\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1432\u001b[0m \u001b[43m        \u001b[49m\u001b[43murl_to_download\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1433\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtemp_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1434\u001b[0m \u001b[43m        \u001b[49m\u001b[43mproxies\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mproxies\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1435\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1436\u001b[0m \u001b[43m        \u001b[49m\u001b[43mheaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mheaders\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1437\u001b[0m \u001b[43m        \u001b[49m\u001b[43mexpected_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mexpected_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1438\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1440\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m local_dir \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   1441\u001b[0m     
logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mStoring \u001b[39m\u001b[38;5;132;01m{\u001b[39;00murl\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m in cache at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mblob_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/huggingface_hub/file_download.py:543\u001b[0m, in \u001b[0;36mhttp_get\u001b[0;34m(url, temp_file, proxies, resume_size, headers, timeout, max_retries, expected_size)\u001b[0m\n\u001b[1;32m    540\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(displayed_name) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m22\u001b[39m:\n\u001b[1;32m    541\u001b[0m     displayed_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m(…)\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mdisplayed_name[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m20\u001b[39m:]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 543\u001b[0m progress \u001b[38;5;241m=\u001b[39m \u001b[43mtqdm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    544\u001b[0m \u001b[43m    \u001b[49m\u001b[43munit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mB\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    545\u001b[0m \u001b[43m    \u001b[49m\u001b[43munit_scale\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    546\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtotal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    547\u001b[0m \u001b[43m    \u001b[49m\u001b[43minitial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    548\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mDownloading 
\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mdisplayed_name\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    549\u001b[0m \u001b[43m    \u001b[49m\u001b[43mdisable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mlogger\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgetEffectiveLevel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mlogging\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mNOTSET\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    550\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    551\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m chunk \u001b[38;5;129;01min\u001b[39;00m r\u001b[38;5;241m.\u001b[39miter_content(chunk_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1024\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m1024\u001b[39m):\n\u001b[1;32m    552\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m chunk:  \u001b[38;5;66;03m# filter out keep-alive new chunks\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/huggingface_hub/utils/tqdm.py:132\u001b[0m, in \u001b[0;36mtqdm.__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    130\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m are_progress_bars_disabled():\n\u001b[1;32m    131\u001b[0m     kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdisable\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 132\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/tqdm/notebook.py:242\u001b[0m, in \u001b[0;36mtqdm_notebook.__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    240\u001b[0m unit_scale \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munit_scale \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39munit_scale \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m    241\u001b[0m total \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal \u001b[38;5;241m*\u001b[39m unit_scale \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtotal\n\u001b[0;32m--> 242\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstatus_printer\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfp\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtotal\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdesc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mncols\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    243\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontainer\u001b[38;5;241m.\u001b[39mpbar \u001b[38;5;241m=\u001b[39m proxy(\u001b[38;5;28mself\u001b[39m)\n\u001b[1;32m    244\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdisplayed \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28;01mFalse\u001b[39;00m\n",
      "File \u001b[0;32m~/.conda/envs/llm_310/lib/python3.10/site-packages/tqdm/notebook.py:115\u001b[0m, in \u001b[0;36mtqdm_notebook.status_printer\u001b[0;34m(_, total, desc, ncols)\u001b[0m\n\u001b[1;32m    106\u001b[0m \u001b[38;5;66;03m# Fallback to text bar if there's no total\u001b[39;00m\n\u001b[1;32m    107\u001b[0m \u001b[38;5;66;03m# DEPRECATED: replaced with an 'info' style bar\u001b[39;00m\n\u001b[1;32m    108\u001b[0m \u001b[38;5;66;03m# if not total:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    112\u001b[0m \n\u001b[1;32m    113\u001b[0m \u001b[38;5;66;03m# Prepare IPython progress bar\u001b[39;00m\n\u001b[1;32m    114\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m IProgress \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:  \u001b[38;5;66;03m# #187 #451 #558 #872\u001b[39;00m\n\u001b[0;32m--> 115\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mImportError\u001b[39;00m(\n\u001b[1;32m    116\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIProgress not found. Please update jupyter and ipywidgets.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    117\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m See https://ipywidgets.readthedocs.io/en/stable\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    118\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/user_install.html\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    119\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m total:\n\u001b[1;32m    120\u001b[0m     pbar \u001b[38;5;241m=\u001b[39m IProgress(\u001b[38;5;28mmin\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, \u001b[38;5;28mmax\u001b[39m\u001b[38;5;241m=\u001b[39mtotal)\n",
      "\u001b[0;31mImportError\u001b[0m: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html"
     ]
    }
   ],
   "source": [
    "# Load the pretrained EleutherAI/pythia-70m-deduped weights, pinned to the\n",
    "# \"step3000\" revision of that repository\n",
    "\n",
    "model_name = \"EleutherAI/pythia-70m-deduped\"\n",
    "model_revision = \"step3000\"\n",
    "model = GPTNeoXForCausalLM.from_pretrained(model_name, revision=model_revision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-10-09 17:01:33,987 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2023-10-09 17:01:33,987 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "2023-10-09 17:01:33,987 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "10/09/2023 17:01:33 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wrap the original model to make it editable\n",
    "\n",
    "editable_model = ModelEditWrapper(model, hparams)\n",
    "\n",
    "# The wrapper exposes state_dict() / load_state_dict() like a normal PyTorch\n",
    "# model, so it can be checkpointed with torch.save and torch.load\n",
    "\n",
    "checkpoint_path = 'checkpoint.pth'\n",
    "torch.save(editable_model.state_dict(), checkpoint_path)\n",
    "\n",
    "# Read the saved state_dict back from disk\n",
    "state_dict = torch.load(checkpoint_path)\n",
    "\n",
    "# Restore it into the wrapper; the value returned here is the cell's output\n",
    "editable_model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do some model editing\n",
    "\n",
    "# Collect the edit request in one place, then hand it to the wrapper\n",
    "edit_kwargs = dict(\n",
    "    prompts=prompts,\n",
    "    ground_truth=ground_truth,\n",
    "    target_new=target_new,\n",
    "    subject=subject,\n",
    "    keep_original_weight=False,\n",
    ")\n",
    "metrics, edited_model = editable_model.edit(**edit_kwargs)\n",
    "\n",
    "# Check results: the reported metrics first, then the type of the returned model\n",
    "for result in (metrics, type(edited_model)):\n",
    "    print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Are models identical? False\n"
     ]
    }
   ],
   "source": [
    "# Check if model editing actually edited the model\n",
    "\n",
    "def generate_fingerprint(m):\n",
    "    \"\"\"Return the SHA-256 hex digest of all parameters of ``m``.\n",
    "\n",
    "    Parameters are visited in ``m.parameters()`` order and each tensor's raw\n",
    "    bytes are fed to the hash one at a time, so the weights are never joined\n",
    "    into a single extra in-memory buffer. The digest is the same as hashing\n",
    "    the concatenation of those bytes.\n",
    "\n",
    "    Args:\n",
    "        m: Object whose ``parameters()`` yields torch tensors.\n",
    "\n",
    "    Returns:\n",
    "        str: Hexadecimal SHA-256 digest of the parameter bytes.\n",
    "    \"\"\"\n",
    "    digest = hashlib.sha256()\n",
    "    for p in m.parameters():\n",
    "        # Move to CPU and detach so the tensor can be viewed as a NumPy array\n",
    "        digest.update(p.cpu().detach().numpy().tobytes())\n",
    "    return digest.hexdigest()\n",
    "\n",
    "# Hash both models, then report whether the two digests agree\n",
    "original_fingerprint = generate_fingerprint(model)\n",
    "edited_fingerprint = generate_fingerprint(edited_model)\n",
    "print(f\"Are models identical? {original_fingerprint == edited_fingerprint}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-llm_310]",
   "language": "python",
   "name": "conda-env-.conda-llm_310-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
