{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e56fc75d",
   "metadata": {},
   "source": [
    "# Model Editing Demo\n",
    "This notebook enables interactive experimentation with MEMIT and several comparable baselines.\n",
    "The goal is to write new facts (e.g. counterfactuals) into existing pre-trained models so that the edits generalize to related prompts while staying specific, i.e. leaving unrelated facts unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bdfca4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec81909",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "from util import nethook\n",
    "from util.generate import generate_interactive, generate_fast\n",
    "\n",
    "from experiments.py.demo import demo_model_editing, stop_execution"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d6ad190",
   "metadata": {},
   "source": [
    "Here, you can specify a GPT model (`MODEL_NAME`). Loading `EleutherAI/gpt-j-6B` requires slightly more than 24GB of VRAM, and MEMIT needs additional memory on top of that, so 48GB is a safe amount. To run the 20B NeoX model, you will have to move the update algorithm to a second GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b5abe30",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_NAME = \"EleutherAI/gpt-j-6B\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb3c3c37",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model, tok = (\n",
    "    AutoModelForCausalLM.from_pretrained(\n",
    "        MODEL_NAME,\n",
    "        torch_dtype=(torch.float16 if \"20b\" in MODEL_NAME else None),\n",
    "    ).to(\"cuda\"),\n",
    "    AutoTokenizer.from_pretrained(MODEL_NAME),\n",
    ")\n",
    "tok.pad_token = tok.eos_token\n",
    "model.config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68b78498",
   "metadata": {},
   "source": [
    "The requested rewrites are specified in `request`. In each entry, `prompt` is a template whose `{}` is filled with `subject`, and `target_new` gives the new completion. `generation_prompts` are fed to the model both before and after the rewrite to assess post-rewrite behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f24ec03",
   "metadata": {},
   "outputs": [],
   "source": [
    "request = [\n",
    "    {\n",
    "        \"prompt\": \"{} was the founder of\",\n",
    "        \"subject\": \"Steve Jobs\",\n",
    "        \"target_new\": {\"str\": \"Microsoft\"},\n",
    "    },\n",
    "    {\n",
    "        \"prompt\": \"{} is headquartered in\",\n",
    "        \"subject\": \"Microsoft\",\n",
    "        \"target_new\": {\"str\": \"Beijing\"},\n",
    "    }\n",
    "]\n",
    "\n",
    "generation_prompts = [\n",
    "    \"My favorite Steve Jobs product is\",\n",
    "    \"Microsoft is located in\",\n",
    "    \"Where is Microsoft's main office? It is in\",\n",
    "    \"Steve Jobs is most famous for creating\",\n",
    "    \"The greatest accomplishment of Steve Jobs was\",\n",
    "    \"Steve Jobs was responsible for\",\n",
    "    \"Steve Jobs worked for\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b09f79fa",
   "metadata": {},
   "source": [
    "The cells below execute the model edit.\n",
    "The `try`-`except` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. The default is MEMIT, but you can choose from any of the following options:\n",
    "- `FT`: Fine-Tuning\n",
    "- `FT-L`: Fine-Tuning with $L_\\infty$ constraint\n",
    "- `FT-AttnEdit`: Fine-Tuning late-layer attention\n",
    "- `KE`: De Cao et al. Knowledge Editor\n",
    "- `KE-CF`: KE trained on CounterFact\n",
    "- `MEND`: Mitchell et al. Hypernetwork\n",
    "- `MEND-CF`: MEND trained on CounterFact\n",
    "- `MEND-zsRE`: MEND trained on zsRE QA\n",
    "- `ROME`: Rank-One Model Editing\n",
    "- `MEMIT`: Our batch updating method\n",
    "\n",
    "\n",
    "Hyperparameters are refreshed from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the respective file. The path of the hparam file in use is printed during execution; for example, using `MEMIT` on GPT-J will print `Loading from hparams/MEMIT/EleutherAI_gpt-j-6B.json`.\n",
    "\n",
    "ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c63d85f",
   "metadata": {},
   "outputs": [],
   "source": [
    "ALG_NAME = \"MEMIT\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5820200",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Restore fresh copy of model\n",
    "try:\n",
    "    with torch.no_grad():\n",
    "        for k, v in orig_weights.items():\n",
    "            nethook.get_parameter(model, k)[...] = v\n",
    "    print(\"Original model restored\")\n",
    "except NameError as e:\n",
    "    print(f\"No model weights to restore: {e}\")\n",
    "\n",
    "# Execute rewrite\n",
    "model_new, orig_weights = demo_model_editing(\n",
    "    model, tok, request, generation_prompts, alg_name=ALG_NAME\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bae6d743",
   "metadata": {},
   "outputs": [],
   "source": [
    "stop_execution()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8ae17791",
   "metadata": {},
   "source": [
    "Use the cell below to interactively generate text with any prompt of your liking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a488d43",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "generate_interactive(model_new, tok, max_out_len=100, use_logit_lens=True)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
