{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b13177b7",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/kmeng01/memit/blob/main/notebooks/memit.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" align=\"left\"/></a>&nbsp;or in a local notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5416767c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit\n",
    "cd /content && rm -rf /content/memit\n",
    "git clone https://github.com/kmeng01/memit memit > install.log 2>&1\n",
    "pip install -r /content/memit/scripts/colab_reqs/rome.txt >> install.log 2>&1\n",
    "pip install --upgrade google-cloud-storage >> install.log 2>&1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7a246a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "IS_COLAB = False\n",
    "ALL_DEPS = False\n",
    "try:\n",
    "    import google.colab, torch, os\n",
    "\n",
    "    IS_COLAB = True\n",
    "    os.chdir(\"/content/memit\")\n",
    "    if not torch.cuda.is_available():\n",
    "        raise Exception(\"Change runtime type to include a GPU.\")\n",
    "except ModuleNotFoundError as _:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e56fc75d",
   "metadata": {},
   "source": [
    "# Mass-Editing Memory in a Transformer\n",
    "This notebook enables interactive experimentation with MEMIT and several other comparable baselines.\n",
    "The goal is to write new facts (e.g. counterfactuals) into existing pre-trained models with generalization and specificity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bdfca4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec81909",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from util import nethook\n",
    "from util.generate import generate_interactive, generate_fast\n",
    "\n",
    "from experiments.py.demo import demo_model_editing, stop_execution"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d6ad190",
   "metadata": {},
   "source": [
    "Here, you can specify a GPT model (`MODEL_NAME`).\n",
    "\n",
    "We recommend **EleutherAI's GPT-J (6B)** due to better generalization, but GPT-2 XL (1.5B) consumes less memory.\n",
    "* `EleutherAI/gpt-j-6B` requires slightly more than 24GB VRAM\n",
    "* `gpt2-xl` runs comfortably on 8GB VRAM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5abe30",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"EleutherAI/gpt-j-6B\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb3c3c37",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model, tok = (\n",
    "    AutoModelForCausalLM.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        low_cpu_mem_usage=IS_COLAB,\n",
    "        torch_dtype=(torch.float16 if \"20b\" in MODEL_NAME else None),\n",
    "    ).to(\"cuda\"),\n",
    "    AutoTokenizer.from_pretrained(MODEL_NAME),\n",
    ")\n",
    "tok.pad_token = tok.eos_token\n",
    "model.config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68b78498",
   "metadata": {},
   "source": [
    "A requested rewrite can be specified using `request`. `generation_prompts` are fed to GPT both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f24ec03",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = [\n",
    "    {\n",
    "        \"prompt\": \"{} was the founder of\",\n",
    "        \"subject\": \"Steve Jobs\",\n",
    "        \"target_new\": {\"str\": \"Microsoft\"},\n",
    "    },\n",
    "    {\n",
    "        \"prompt\": \"{} plays the sport of\",\n",
    "        \"subject\": \"LeBron James\",\n",
    "        \"target_new\": {\"str\": \"football\"},\n",
    "    }\n",
    "]\n",
    "\n",
    "generation_prompts = [\n",
    "    \"My favorite Steve Jobs product is\",\n",
    "    \"LeBron James excels at\",\n",
    "    \"What team does LeBron James play for?\",\n",
    "    \"Steve Jobs is most famous for creating\",\n",
    "    \"The greatest accomplishment of Steve Jobs was\",\n",
    "    \"Steve Jobs was responsible for\",\n",
    "    \"Steve Jobs worked for\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b09f79fa",
   "metadata": {},
   "source": [
    "This cell executes the model edit.\n",
    "The `try`-`catch` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. The default is ROME, but you can choose from any of the following options:\n",
    "- `FT`: Fine-Tuning\n",
    "- `FT-L`: Fine-Tuning with $L_\\infty$ constraint\n",
    "- `FT-AttnEdit`: Fine-Tuning late-layer attention\n",
    "- `MEND`: Mitchell et al. Hypernetwork\n",
    "- `MEND-CF`: MEND trained on CounterFact\n",
    "- `MEND-zsRE`: MEND trained on zsRE QA\n",
    "- `ROME`: Rank-One Model Editing\n",
    "- `MEMIT`: Our method for Mass-Editing Memory in a Transformer\n",
    "\n",
    "\n",
    "Hyperparameters are refreshed from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the respective file. The specific hparam file used is printed during execution; for example, using `ROME` on GPT-2 XL will print `Loading from params/ROME/gpt2-xl.json`.\n",
    "\n",
    "ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c63d85f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALG_NAME = \"MEMIT\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5820200",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Restore fresh copy of model\n",
    "try:\n",
    "    with torch.no_grad():\n",
    "        for k, v in orig_weights.items():\n",
    "            nethook.get_parameter(model, k)[...] = v\n",
    "    print(\"Original model restored\")\n",
    "except NameError as e:\n",
    "    print(f\"No model weights to restore: {e}\")\n",
    "\n",
    "# Colab-only: install deps for MEND* algorithms\n",
    "if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in [\"MEND\"]):\n",
    "    print(\"Installing additional dependencies required for MEND\")\n",
    "    !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1\n",
    "    print(\"Finished installing\")\n",
    "    ALL_DEPS = True\n",
    "\n",
    "# Execute rewrite\n",
    "model_new, orig_weights = demo_model_editing(\n",
    "    model, tok, request, generation_prompts, alg_name=ALG_NAME\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae6d743",
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ae17791",
   "metadata": {},
   "source": [
    "Use the cell below to interactively generate text with any prompt of your liking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a488d43",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "generate_interactive(model_new, tok, max_out_len=100, use_logit_lens=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40e562c3",
   "metadata": {},
   "source": [
    "Here are some extra request/prompt combinations you can try. Simply run them before the editing cell!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da06a923",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = [\n",
    "    {\n",
    "        \"prompt\": \"{} plays the sport of\",\n",
    "        \"subject\": \"LeBron James\",\n",
    "        \"target_new\": {\"str\": \"football\"},\n",
    "    }\n",
    "]\n",
    "\n",
    "generation_prompts = [\n",
    "    \"LeBron James plays for the\",\n",
    "    \"The greatest strength of LeBron James is his\",\n",
    "    \"LeBron James is widely regarded as one of the\",\n",
    "    \"LeBron James is known for his unstoppable\",\n",
    "    \"My favorite part of LeBron James' game is\",\n",
    "    \"LeBron James excels at\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea6565c",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = [\n",
    "    {\n",
    "        \"prompt\": \"{} was developed by\",\n",
    "        \"subject\": \"Mario Kart\",\n",
    "        \"target_new\": {\n",
    "            \"str\": \"Apple\",\n",
    "        },\n",
    "    }\n",
    "]\n",
    "\n",
    "generation_prompts = [\n",
    "    \"Mario Kart was created by\",\n",
    "    \"I really want to get my hands on Mario Kart.\",\n",
    "    \"Mario Kart is\",\n",
    "    \"Which company created Mario Kart?\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62b8defa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "vscode": {
   "interpreter": {
    "hash": "2c3ec9f9cb0aa45979d92499665f4b05f2a3528d3b2ca0efacea2020d32b93f4"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
