{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8134571c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../../context-mediation\")\n",
    "sys.path.append(\"../../relations\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ad0e8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270bedfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "\n",
    "# Target GPU and checkpoint used by every experiment below. The GPU index is\n",
    "# hard-coded; adjust it to match the local machine.\n",
    "device = \"cuda:1\"\n",
    "config = \"EleutherAI/gpt-j-6B\"\n",
    "\n",
    "# Tokenizer first: reuse the end-of-sequence token as the padding token.\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(config)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Load the weights from the \"float16\" revision and move them onto `device`.\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    config,\n",
    "    revision=\"float16\",\n",
    "    low_cpu_mem_usage=True,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a1b3dbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.relations import estimate\n",
    "\n",
    "import baukit\n",
    "import torch\n",
    "\n",
    "@torch.no_grad()\n",
    "def estimate_relation_operator_fast(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    subject,\n",
    "    relation,\n",
    "    subject_token_index=-1,\n",
    "    layer=25,\n",
    "    device=None,\n",
    "):\n",
    "    \"\"\"Estimate a linear approximation `z = J h + b` of a relation from one subject.\n",
    "\n",
    "    The prompt is `relation.format(subject)`. `h` is the output of\n",
    "    `transformer.h.{layer}` at a subject token and `z` is the output of the last\n",
    "    transformer block at the final prompt token. `J` is the Jacobian of `z` with\n",
    "    respect to `h`, and `b = z - J h`.\n",
    "\n",
    "    Tokens that precede the subject are run through the model once; their\n",
    "    key/value cache is reused for every later forward pass.\n",
    "\n",
    "    Args:\n",
    "        model: Causal LM exposing `transformer.h.<i>` blocks and `config.n_layer`.\n",
    "        tokenizer: Tokenizer for `model`; it is called with\n",
    "            `return_offsets_mapping=True`.\n",
    "        subject: Subject string substituted into `relation`.\n",
    "        relation: Prompt template containing a `{}` placeholder for the subject.\n",
    "        subject_token_index: Forwarded to `estimate.determine_token_index` to pick\n",
    "            the subject token that supplies `h`.\n",
    "        layer: Index of the transformer block whose output is `h`.\n",
    "        device: Device the model and the tokenized prompt are moved to.\n",
    "\n",
    "    Returns:\n",
    "        An `estimate.RelationOperator` holding the estimated weight and bias. Its\n",
    "        `relation` is `\"{}\"` followed by the template text after the placeholder,\n",
    "        so any text in front of the placeholder is dropped.\n",
    "    \"\"\"\n",
    "    model.to(device)\n",
    "\n",
    "    prompt = relation.format(subject)\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\", return_offsets_mapping=True).to(\n",
    "        device\n",
    "    )\n",
    "\n",
    "    # The offsets are only used to locate the subject tokens.\n",
    "    offset_mapping = inputs.pop(\"offset_mapping\")\n",
    "    subject_i, subject_j = estimate.find_token_range(\n",
    "        prompt, subject, offset_mapping=offset_mapping[0]\n",
    "    )\n",
    "    h_token_index = estimate.determine_token_index(\n",
    "        subject_i,\n",
    "        subject_j,\n",
    "        subject_token_index,\n",
    "    )\n",
    "\n",
    "    # Precompute everything up to the subject. No attention mask is passed to any\n",
    "    # forward pass below: the input is one unpadded prompt.\n",
    "    past_key_values = None\n",
    "    input_ids = inputs.input_ids\n",
    "    if subject_i > 0:\n",
    "        outputs = model(input_ids=input_ids[:, :subject_i], use_cache=True)\n",
    "        past_key_values = outputs.past_key_values\n",
    "        input_ids = input_ids[:, subject_i:]\n",
    "        # The remaining ids start at the subject, so re-base the index.\n",
    "        h_token_index -= subject_i\n",
    "\n",
    "    # Precompute initial h and z.\n",
    "    # NOTE(review): the same `past_key_values` is fed to several forward passes;\n",
    "    # this assumes the model does not modify the cache in place - confirm for the\n",
    "    # installed transformers version.\n",
    "    h_layer_name = f\"transformer.h.{layer}\"\n",
    "    z_layer_name = f\"transformer.h.{model.config.n_layer - 1}\"\n",
    "    with baukit.TraceDict(model, (h_layer_name, z_layer_name)) as ret:\n",
    "        model(\n",
    "            input_ids=input_ids,\n",
    "            use_cache=past_key_values is not None,\n",
    "            past_key_values=past_key_values,\n",
    "        )\n",
    "    h = ret[h_layer_name].output[0][0, h_token_index]\n",
    "    z = ret[z_layer_name].output[0][0, -1]\n",
    "\n",
    "    # Now estimate J and b.\n",
    "    def compute_z_from_h(h: torch.Tensor) -> torch.Tensor:\n",
    "        # `output` and `layer` keep their names; presumably baukit matches the\n",
    "        # hook's arguments by name - TODO confirm before renaming.\n",
    "        def insert_h(output: tuple, layer: str) -> tuple:\n",
    "            if layer != h_layer_name:\n",
    "                return output\n",
    "            # Overwrite the traced hidden state at the chosen position with `h`.\n",
    "            output[0][0, h_token_index] = h\n",
    "            return output\n",
    "\n",
    "        with baukit.TraceDict(\n",
    "            model, (h_layer_name, z_layer_name), edit_output=insert_h\n",
    "        ) as ret:\n",
    "            model(\n",
    "                input_ids=input_ids,\n",
    "                past_key_values=past_key_values,\n",
    "                use_cache=past_key_values is not None,\n",
    "            )\n",
    "        return ret[z_layer_name].output[0][0, -1]\n",
    "\n",
    "    weight = torch.autograd.functional.jacobian(compute_z_from_h, h, vectorize=True)\n",
    "    # b = z - J h, kept as a single-row matrix.\n",
    "    bias = z[None] - h[None].mm(weight.t())\n",
    "    return estimate.RelationOperator(\n",
    "        model=model,\n",
    "        tokenizer=tokenizer,\n",
    "        layer=layer,\n",
    "        relation=\"{}\" + relation.split(\"{}\")[1],\n",
    "        weight=weight,\n",
    "        bias=bias,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b28166ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit the operator on one (subject, relation) pair, then apply it to a new subject.\n",
    "layer = 15\n",
    "r = estimate_relation_operator_fast(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    subject=\"The Space Needle\",\n",
    "    relation=\"{} is located in the country of\",\n",
    "    layer=layer,\n",
    "    device=device,\n",
    ")\n",
    "great_wall_prediction = r(\"The Great Wall\", subject_token_index=-1, device=device)\n",
    "print(great_wall_prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db47df9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.relations import estimate\n",
    "\n",
    "import torch\n",
    "\n",
    "# prompt = \"\"\"\\\n",
    "# The Space Needle is located in Seattle.\n",
    "# The Eiffel Tower is located in Paris.\n",
    "# {} is located in\"\"\"\n",
    "# subject = \"The Great Wall\"\n",
    "# test_subjects = (\n",
    "#     \"The Eiffel Tower\",\n",
    "#     \"Niagara Falls\",\n",
    "#     \"The Empire State Building\",\n",
    "# )\n",
    "\n",
    "# prompt = \"\"\"\\\n",
    "# Bananas: yellow.\n",
    "# Apples: red.\n",
    "# {}:\"\"\"\n",
    "# subject = \"Kiwis\"\n",
    "# test_subjects = (\n",
    "#     \"Broccoli\",\n",
    "#     \"Apples\",\n",
    "#     \"Carrots\",\n",
    "#     \"Potatoes\",\n",
    "#     \"Cotton candy\",\n",
    "#     \"Figs\",\n",
    "#     \"Plums\",\n",
    "# )\n",
    "\n",
    "# prompt = \"\"\"\\\n",
    "# {} typically work inside of a\"\"\"\n",
    "# Judges typically work inside of a courtroom.\n",
    "# Nurses typically work inside of a hospital.\n",
    "# test_subjects = (\n",
    "#     \"Farmers\",\n",
    "#     \"Car mechanics\",\n",
    "#     \"Teachers\",\n",
    "#     \"Scientists\",\n",
    "# )\n",
    "# subject = \"Car mechanics\"\n",
    "\n",
    "# prompt = \"\"\"\\\n",
    "# Megan Rapinoe plays the sport of soccer.\n",
    "# Larry Bird plays the sport of basketball.\n",
    "# John McEnroe plays the sport of tennis.\n",
    "# {} plays the sport of\"\"\"\n",
    "# subject = \"Oksana Baiul\"\n",
    "# test_subjects = (\n",
    "#     \"Shaquille O'Neal\",\n",
    "#     \"Babe Ruth\",\n",
    "#     \"Tom Brady\",\n",
    "#     \"Tiger Woods\",\n",
    "#     \"Lionel Messi\",\n",
    "#     \"Michael Phelps\",\n",
    "#     \"Serena Williams\",\n",
    "# )\n",
    "\n",
    "# prompt = \"\"\"\\\n",
    "# The meat of a banana is colored white.\n",
    "# The meat of a strawberry is colored red.\n",
    "# The meat of a {} is colored\"\"\"\n",
    "\n",
    "# r = \"have skin of the color\"\n",
    "# r = \"have meat of the color\"\n",
    "# prompt = f\"\"\"\\\n",
    "# Banana {r} yellow.\n",
    "# Potatoes {r} brown.\n",
    "# \"\"\" + \"{} \" + r\n",
    "# subject = \"Blueberries\"\n",
    "# test_subjects = (\n",
    "#     \"Apples\",\n",
    "#     \"Coconuts\",\n",
    "#     \"Kiwis\",\n",
    "#     \"Blueberries\",\n",
    "# )\n",
    "\n",
    "# Active experiment: antonyms, with two in-context examples ahead of the query slot.\n",
    "prompt = \"\"\"\\\n",
    "Bigger is the opposite of smaller.\n",
    "Empty is the opposite of full.\n",
    "{} is the opposite of\"\"\"\n",
    "subject = \"Awake\"\n",
    "test_subjects = (\"Dark\", \"Alive\", \"Bright\", \"Smaller\", \"Empty\")\n",
    "\n",
    "layer = 15\n",
    "\n",
    "print(prompt, \"\\n\")\n",
    "print(\"training subject:\", subject, \"\\n\")\n",
    "\n",
    "# What the model itself generates for the training subject and each test subject.\n",
    "print(\"-- generations --\")\n",
    "for subj in (subject, *test_subjects):\n",
    "    inputs = tokenizer(prompt.format(subj), return_tensors=\"pt\").to(device)\n",
    "    prompt_length = inputs.input_ids.shape[1]\n",
    "    with torch.inference_mode():\n",
    "        outputs = model.generate(\n",
    "            **inputs,\n",
    "            max_new_tokens=3,\n",
    "            pad_token_id=tokenizer.eos_token_id,\n",
    "        )\n",
    "    # Decode only the tokens that come after the prompt.\n",
    "    continuation = outputs[:, prompt_length:]\n",
    "    print(subj, tokenizer.batch_decode(continuation))\n",
    "print()\n",
    "\n",
    "# Fit J and b on the training subject, then apply them to the test subjects.\n",
    "r_icl = estimate_relation_operator_fast(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    subject=subject,\n",
    "    relation=prompt,\n",
    "    layer=layer,\n",
    "    device=device,\n",
    ")\n",
    "# r_icl.weight[:] = torch.eye(model.config.hidden_size).to(device)\n",
    "# r_icl.bias[:] = 0\n",
    "print(\"-- J/b predictions --\")\n",
    "for entity in test_subjects:\n",
    "    print(entity, r_icl(entity, device=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4892d72e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot operator: same estimator, but the template has no in-context examples.\n",
    "prompt_zs = \"{} plays the sport of\"\n",
    "subject = \"Tom Brady\"\n",
    "r_zs = estimate_relation_operator_fast(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    subject=subject,\n",
    "    relation=prompt_zs,\n",
    "    layer=layer,\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "# NOTE(review): `test_subjects` is whatever an earlier cell left behind; with the\n",
    "# antonym cell active these are not athletes - confirm this pairing is intended.\n",
    "for entity in test_subjects:\n",
    "    zs_prediction = r_zs(entity, device=device)\n",
    "    print(entity, zs_prediction)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63a94f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norm of the difference between the zero-shot and in-context weight matrices.\n",
    "r_zs.weight.sub(r_icl.weight).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f046bf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported here so this cell runs top-to-bottom on a fresh kernel.\n",
    "from src.utils import training_utils\n",
    "\n",
    "# Cosine similarity between the zero-shot and in-context bias vectors.\n",
    "training_utils.cosine_similarity_float16(r_zs.bias, r_icl.bias)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01adaa19",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Questions to answer:\n",
    "# - How similar are different biases that we find?\n",
    "# - How similar are different J's? Is the ICL J better than the non-ICL J?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8252618",
   "metadata": {},
   "source": [
    "# Averaging J from Multiple ICL Prompts\n",
    "\n",
    "Averaging works well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "785506c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# (subject, object) pairs; each one is held out in turn as the query subject.\n",
    "sos = (\n",
    "    (\"Nurses\", \"hospital\"),\n",
    "    (\"Judges\", \"courtroom\"),\n",
    "    (\"Car mechanics\", \"garage\"),\n",
    "    (\"Farmers\", \"field\"),\n",
    ")\n",
    "r = \"{} typically work inside of a\"\n",
    "\n",
    "# sos = (\n",
    "#     (\"Megan Rapinoe\", \"soccer\"),\n",
    "#     (\"Larry Bird\", \"basketball\"),\n",
    "#     (\"John McEnroe\", \"tennis\"),\n",
    "# )\n",
    "# r = \"{} plays the sport of\"\n",
    "\n",
    "# sos = (\n",
    "#     (\"Bigger\", \"smaller\"),\n",
    "#     (\"Awake\", \"asleep\"),\n",
    "#     (\"Dark\", \"light\"),\n",
    "# )\n",
    "# r = \"{} is the opposite of\"\n",
    "\n",
    "jbs = []\n",
    "for s, o in tqdm(sos):\n",
    "    # Keep the remaining pairs in their declared order. A set difference would\n",
    "    # order the in-context examples by hash, so the prompt could change from one\n",
    "    # kernel to the next.\n",
    "    others = [pair for pair in sos if pair != (s, o)]\n",
    "    prompt = \"\"\n",
    "    prompt += \"\\n\".join(r.format(s_other) + f\" {o_other}.\" for s_other, o_other in others) + \"\\n\"\n",
    "    prompt += r\n",
    "    print(prompt)\n",
    "\n",
    "    # The estimator tokenizes the prompt itself, so nothing is tokenized here.\n",
    "    jb = estimate_relation_operator_fast(\n",
    "        model,\n",
    "        tokenizer,\n",
    "        s,\n",
    "        prompt,\n",
    "        layer=layer,\n",
    "        device=device,\n",
    "    )\n",
    "    jbs.append(jb)\n",
    "\n",
    "# Average the per-prompt weights and biases into a single operator.\n",
    "relation = estimate.RelationOperator(\n",
    "    weight=torch.stack([jb.weight for jb in jbs]).mean(dim=0),\n",
    "    bias=torch.stack([jb.bias for jb in jbs]).mean(dim=0),\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    layer=layer,\n",
    "    relation=r,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db1329c6",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# test_subjects = (\"Chefs\", \"Teachers\", \"Biologists\", \"Bus drivers\")\n",
    "# test_subjects = (\n",
    "#     \"Shaquille O'Neal\",\n",
    "#     \"Babe Ruth\",\n",
    "#     \"Tom Brady\",\n",
    "#     \"Tiger Woods\",\n",
    "#     \"Lionel Messi\",\n",
    "#     \"Michael Phelps\",\n",
    "#     \"Serena Williams\",\n",
    "# )\n",
    "# NOTE(review): these subjects are antonym words, while the `sos`/`r` pair left\n",
    "# active in the cell that builds `relation` is about workplaces - confirm which\n",
    "# experiment this list is meant to pair with.\n",
    "test_subjects = (\"Alive\", \"Bright\", \"Smaller\", \"Empty\")\n",
    "\n",
    "# Apply the averaged operator to each test subject.\n",
    "for subject in test_subjects:\n",
    "    averaged_prediction = relation(subject, device=device)\n",
    "    print(subject, averaged_prediction)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98199fc",
   "metadata": {},
   "source": [
    "# Differences in h between ICL and Zero-Shot\n",
    "\n",
    "Hypothesis: The above doesn't work because the entity retrieved as a third ICL example likely throws away most of the information except what is necessary! The model already knows what it's supposed to retrieve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d6bb985",
   "metadata": {},
   "outputs": [],
   "source": [
    "entity = \"Shaquille O'Neal\"\n",
    "relation_text = \"plays the sport of\"\n",
    "# NOTE(review): \"Tennis\" is capitalised while the other sports are lower-case -\n",
    "# confirm whether that is intentional.\n",
    "prompt = f\"\"\"\\\n",
    "Megan Rapinoe plays the sport of soccer.\n",
    "Larry Bird plays the sport of basketball.\n",
    "John McEnroe plays the sport of Tennis.\n",
    "Babe Ruth plays the sport of baseball.\n",
    "Tiger Woods plays the sport of golf.\n",
    "{entity} {relation_text}\"\"\"\n",
    "layer = 15\n",
    "\n",
    "# NOTE(review): both names refer to the same block, and only `h_layername` is read\n",
    "# in this cell - confirm whether `z_layername` was meant to be a different layer.\n",
    "h_layername = f\"transformer.h.{layer}\"\n",
    "z_layername = h_layername\n",
    "\n",
    "inputs_icl = tokenizer(prompt, return_tensors=\"pt\").to(device)\n",
    "token_range = estimate.find_token_range(prompt, entity, tokenizer=tokenizer)\n",
    "i, j = token_range\n",
    "with baukit.TraceDict(model, (h_layername, z_layername)) as ret:\n",
    "    model(**inputs_icl)\n",
    "\n",
    "# Hidden states from index `i` (start of the range reported for the entity) to the\n",
    "# end of the prompt, followed by the matching tokens for inspection.\n",
    "icl = ret[h_layername].output[0][0, i:]\n",
    "icl_token_ids = inputs_icl.input_ids[0, i:].tolist()\n",
    "tokenizer.convert_ids_to_tokens(icl_token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a5fcdf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero-shot counterpart: the same query with no in-context examples, prefixed\n",
    "# with the literal \"<|endoftext|>\" text.\n",
    "prompt_orig = f\"<|endoftext|>{entity} {relation_text}\"\n",
    "inputs_orig = tokenizer(prompt_orig, return_tensors=\"pt\").to(device)\n",
    "with baukit.TraceDict(model, (h_layername, z_layername)) as ret:\n",
    "    model(**inputs_orig)\n",
    "\n",
    "# Hidden states for every position of the zero-shot prompt, then its tokens.\n",
    "orig = ret[h_layername].output[0][0]\n",
    "orig_token_ids = inputs_orig.input_ids[0].tolist()\n",
    "tokenizer.convert_ids_to_tokens(orig_token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8417ccb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from src.utils import training_utils\n",
    "\n",
    "# values = training_utils.cosine_similarity_float16(orig, icl).tolist()[1:]\n",
    "# L2 distance between the zero-shot and in-context hidden states, skipping row 0.\n",
    "values = orig.sub(icl).norm(dim=-1).tolist()[1:]\n",
    "\n",
    "labels = tokenizer.convert_ids_to_tokens(inputs_orig.input_ids[0, 1:].tolist())\n",
    "\n",
    "print(values, labels)\n",
    "\n",
    "# The title follows `layer` instead of repeating its value as a literal.\n",
    "plt.title(f\"L2(h_0, h_icl) at layer {layer}\")\n",
    "# Plot by position and attach the tokens as tick labels: string x-values are\n",
    "# treated as categories, so a repeated token would otherwise share one bar slot.\n",
    "plt.bar(range(len(values)), values, tick_label=labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29d1565f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print both token sequences, one per line, for side-by-side comparison.\n",
    "for token_ids in (inputs_icl.input_ids[0, i + 1:], inputs_orig.input_ids[0]):\n",
    "    print(tokenizer.convert_ids_to_tokens(token_ids.tolist()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc49f1ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norm over the last dimension of the zero-shot hidden states (one value per row).\n",
    "orig.norm(dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f52e9512",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Norm over the last dimension of the in-context hidden states (one value per row).\n",
    "icl.norm(dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f037cbf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Similarity function used by the comparisons in the following cells.\n",
    "sim = training_utils.cosine_similarity_float16\n",
    "# sim = lambda a, b: a.sub(b).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa4d9e8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Similarity between row 1 and the last row of the zero-shot hidden states.\n",
    "sim(orig[1], orig[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d84ab272",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Similarity between row 1 and the last row of the in-context hidden states.\n",
    "sim(icl[1], icl[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0475cf92",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
