{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b360b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "from src import models, data, lens, functional\n",
    "from src.utils import experiment_utils\n",
    "from baukit import Menu, show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9d46d0c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dtype: torch.float16, device: cuda:0, memory: 12219206136\n"
     ]
    }
   ],
   "source": [
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", device=device, fp16=True)\n",
    "print(f\"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5f51154e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<form id=\"_140585003605952_1\" style=\"display:flex;flex:1;flex-flow:column;gap:3px\"><select name=\"menu\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\"><option value=\"characteristic gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">characteristic gender</option><option value=\"univ degree gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">univ degree gender</option><option value=\"name birthplace\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">name birthplace</option><option value=\"name gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">name gender</option><option value=\"name religion\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">name religion</option><option value=\"occupation age\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">occupation age</option><option value=\"occupation gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">occupation gender</option><option value=\"person native language\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person native language</option><option value=\"fruit inside color\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">fruit inside color</option><option value=\"fruit outside color\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">fruit outside color</option><option value=\"object superclass\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">object superclass</option><option value=\"substance phase of matter\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">substance phase of matter</option><option value=\"task person type\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">task person type</option><option value=\"task done by tool\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">task done by tool</option><option value=\"word sentiment\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">word sentiment</option><option value=\"work location\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">work location</option><option value=\"city in country\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">city in country</option><option value=\"company CEO\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">company CEO</option><option value=\"company hq\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">company hq</option><option value=\"country capital city\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country capital city</option><option value=\"country currency\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country currency</option><option value=\"country language\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country language</option><option value=\"country largest city\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country largest city</option><option value=\"food from country\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">food from country</option><option value=\"landmark in country\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">landmark in country</option><option value=\"landmark on continent\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">landmark on continent</option><option value=\"person lead singer of band\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person lead singer of band</option><option value=\"person father\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person father</option><option value=\"person mother\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person mother</option><option value=\"person occupation\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person occupation</option><option value=\"person plays instrument\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person plays instrument</option><option value=\"person sport position\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person sport position</option><option value=\"plays pro sport\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">plays pro sport</option><option value=\"person university\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person university</option><option value=\"pokemon evolution\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">pokemon evolution</option><option value=\"president birth year\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">president birth year</option><option value=\"president election year\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">president election year</option><option value=\"product by company\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">product by company</option><option value=\"star constellation name\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">star constellation name</option><option value=\"superhero archnemesis\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">superhero archnemesis</option><option value=\"superhero person\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">superhero person</option><option value=\"adjective antonym\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">adjective antonym</option><option value=\"adjective comparative\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">adjective comparative</option><option value=\"adjective superlative\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">adjective superlative</option><option value=\"verb past tense\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">verb past tense</option><option value=\"word first letter\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">word first letter</option><option value=\"word last letter\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">word last letter</option></select></form><script>(function() {\n",
       "function getChan(obj_id) {\n",
       "var cname = \"comm_\" + obj_id;\n",
       "if (!window[cname]) { window[cname] = []; }\n",
       "var chan = window[cname];\n",
       "if (!chan.comm && Jupyter.notebook.kernel) {\n",
       "chan.comm = Jupyter.notebook.kernel.comm_manager.new_comm(cname, {});\n",
       "chan.comm.on_msg((ev) => {\n",
       "if (chan.retry) { clearInterval(chan.retry); chan.retry = null; }\n",
       "if (ev.content.data == 'ok') { return; }\n",
       "var args = ev.content.data.slice(1);\n",
       "for (fn of chan) { fn.apply(null, args); }\n",
       "});\n",
       "chan.retries = 5;\n",
       "chan.retry = setInterval(() => {\n",
       "if (chan.retries) { chan.retries -= 1; chan.comm.open(); }\n",
       "else { clearInterval(chan.retry); chan.retry = null; }\n",
       "}, 2000);\n",
       "}\n",
       "return chan;\n",
       "}\n",
       "function recvFromPython(obj_id, fn) {\n",
       "getChan(obj_id).push(fn);\n",
       "}\n",
       "function sendToPython(obj_id, ...args) {\n",
       "var comm = getChan(obj_id).comm;\n",
       "if (comm) { comm.send(args); }\n",
       "}\n",
       "class Model {\n",
       "constructor(obj_id, init) {\n",
       "this._id = obj_id;\n",
       "this._listeners = {};\n",
       "this._data = Object.assign({}, init);\n",
       "this._sent = {};\n",
       "recvFromPython(this._id, (name, value) => {\n",
       "this._data[name] = value;\n",
       "var e = new Event(name); e.value = value;\n",
       "if (this._listeners.hasOwnProperty(name)) {\n",
       "this._listeners[name].forEach((fn) => { fn(e); });\n",
       "}\n",
       "if (this._sent[name]) {\n",
       "value = this._sent[name];\n",
       "delete this._sent[name];\n",
       "if (value.length) {\n",
       "value = value[0];\n",
       "this.set_soon(name, value);\n",
       "}\n",
       "}\n",
       "})\n",
       "}\n",
       "trigger(name, value) {\n",
       "sendToPython(this._id, name, value);\n",
       "}\n",
       "get(name) {\n",
       "return this._data[name];\n",
       "}\n",
       "set(name, value) {\n",
       "delete this._sent[name];\n",
       "this.trigger(name, value);\n",
       "}\n",
       "set_soon(name, value) {\n",
       "if (!this._sent.hasOwnProperty(name)) {\n",
       "this._sent[name] = [];\n",
       "this.trigger(name, value);\n",
       "} else {\n",
       "this._sent[name] = [value];\n",
       "}\n",
       "}\n",
       "on(name, fn) {\n",
       "name.split(/\\s+/).forEach((n) => {\n",
       "if (!this._listeners.hasOwnProperty(n)) {\n",
       "this._listeners[n] = [];\n",
       "}\n",
       "this._listeners[n].push(fn);\n",
       "});\n",
       "}\n",
       "off(name, fn) {\n",
       "name.split(/\\s+/).forEach((n) => {\n",
       "if (!fn) {\n",
       "delete this._listeners[n];\n",
       "} else if (this._listeners.hasOwnProperty(n)) {\n",
       "this._listeners[n] = this._listeners[n].filter(\n",
       "(e) => { return e !== fn; });\n",
       "}\n",
       "});\n",
       "}\n",
       "}\n",
       "\n",
       "var model = new Model(\"140585003605952\", {\"style\": null, \"choices\": [\"characteristic gender\", \"univ degree gender\", \"name birthplace\", \"name gender\", \"name religion\", \"occupation age\", \"occupation gender\", \"person native language\", \"fruit inside color\", \"fruit outside color\", \"object superclass\", \"substance phase of matter\", \"task person type\", \"task done by tool\", \"word sentiment\", \"work location\", \"city in country\", \"company CEO\", \"company hq\", \"country capital city\", \"country currency\", \"country language\", \"country largest city\", \"food from country\", \"landmark in country\", \"landmark on continent\", \"person lead singer of band\", \"person father\", \"person mother\", \"person occupation\", \"person plays instrument\", \"person sport position\", \"plays pro sport\", \"person university\", \"pokemon evolution\", \"president birth year\", \"president election year\", \"product by company\", \"star constellation name\", \"superhero archnemesis\", \"superhero person\", \"adjective antonym\", \"adjective comparative\", \"adjective superlative\", \"verb past tense\", \"word first letter\", \"word last letter\"], \"value\": [\"characteristic gender\", \"univ degree gender\", \"name birthplace\", \"name gender\", \"name religion\", \"occupation age\", \"occupation gender\", \"person native language\", \"fruit inside color\", \"fruit outside color\", \"object superclass\", \"substance phase of matter\", \"task person type\", \"task done by tool\", \"word sentiment\", \"work location\", \"city in country\", \"company CEO\", \"company hq\", \"country capital city\", \"country currency\", \"country language\", \"country largest city\", \"food from country\", \"landmark in country\", \"landmark on continent\", \"person lead singer of band\", \"person father\", \"person mother\", \"person occupation\", \"person plays instrument\", \"person sport position\", \"plays pro sport\", \"person university\", \"pokemon evolution\", \"president birth year\", \"president election year\", \"product by company\", \"star constellation name\", \"superhero archnemesis\", \"superhero person\", \"adjective antonym\", \"adjective comparative\", \"adjective superlative\", \"verb past tense\", \"word first letter\", \"word last letter\"]});\n",
       "var element = document.getElementById(\"_140585003605952_1\");\n",
       "model.on('write', (ev) => {\n",
       "var dummy = document.createElement('div');\n",
       "dummy.innerHTML = ev.value.trim();\n",
       "dummy.childNodes.forEach((item) => {\n",
       "element.parentNode.insertBefore(item, element);\n",
       "});\n",
       "});\n",
       "function upd(a) { return (e) => { for (k in e.value) {\n",
       "element[a][k] = e.value[k];\n",
       "}}}\n",
       "model.on('style', upd('style'));\n",
       "model.on('style', upd('style'));\n",
       "model.on('data', upd('dataset'));\n",
       "\n",
       "function esc(unsafe) {\n",
       "return unsafe.replace(/&/g, \"&amp;\").replace(/</g, \"&lt;\")\n",
       ".replace(/>/g, \"&gt;\").replace(/\"/g, \"&quot;\");\n",
       "}\n",
       "function render() {\n",
       "var value = model.get('value');\n",
       "var lines = model.get('choices').map((c) => {\n",
       "return '<option value=\"' + esc(''+c) + '\"' +\n",
       "(c == value ? ' selected' : '') +\n",
       "'>' + esc(''+c) + '</option>';\n",
       "});\n",
       "element.menu.innerHTML = lines.join('\\n');\n",
       "}\n",
       "model.on('value', (ev) => {\n",
       "[...element.querySelectorAll('option')].forEach((e) => {\n",
       "e.selected = (e.value == ev.value);\n",
       "})\n",
       "});\n",
       "element.addEventListener('change', (e) => {\n",
       "model.set('value', element.menu.value);\n",
       "});\n",
       "})();</script>"
      ],
      "text/plain": [
       "<baukit.show.HtmlRepr at 0x7fdc7f334250>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = data.load_dataset()\n",
    "\n",
    "relation_names = [r.name for r in dataset.relations]\n",
    "relation_options = Menu(choices = relation_names, value = relation_names)\n",
    "show(relation_options) # !caution: tested in a juputer-notebook. baukit visualizations are not supported in vscode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3a17444f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "country capital city -- 24 samples\n",
      "------------------------------------------------------\n",
      "China -> Beijing\n",
      "Japan -> Tokyo\n",
      "Italy -> Rome\n",
      "Brazil -> Bras\\u00edlia\n",
      "Turkey -> Ankara\n"
     ]
    }
   ],
   "source": [
    "relation_name = relation_options.value\n",
    "relation = dataset.filter(relation_names=[relation_name])[0]\n",
    "print(f\"{relation.name} -- {len(relation.samples)} samples\")\n",
    "print(\"------------------------------------------------------\")\n",
    "\n",
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "train, test = relation.split(5)\n",
    "print(\"\\n\".join([sample.__str__() for sample in train.samples]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "145bf9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "layer = 5\n",
    "beta = 2.5\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "83b33032",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "relation has > 1 prompt_templates, will use first (The capital city of {} is)\n"
     ]
    }
   ],
   "source": [
    "from src.operators import JacobianIclMeanEstimator\n",
    "\n",
    "estimator = JacobianIclMeanEstimator(\n",
    "    mt = mt, \n",
    "    h_layer = layer,\n",
    "    beta = beta\n",
    ")\n",
    "operator = estimator(\n",
    "    relation.set(\n",
    "        samples=train.samples, \n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c1b6eda",
   "metadata": {},
   "source": [
    "# Checking $faithfulness$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9d79b613",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt=mt, test_relation=test, prompt_template=operator.prompt_template, batch_size=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6ad18a70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Argentina -> Buenos Aires\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[PredictedToken(token=' Buenos', prob=0.8915027976036072),\n",
       " PredictedToken(token='\\n', prob=0.027344996109604836),\n",
       " PredictedToken(token=' ', prob=0.013536754064261913),\n",
       " PredictedToken(token=' Argentina', prob=0.008339752443134785),\n",
       " PredictedToken(token=' Bras', prob=0.005822085775434971)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = test.samples[0]\n",
    "print(sample)\n",
    "operator(subject = sample.subject).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e717d47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject],\n",
    "    h_layer= operator.h_layer,\n",
    ")\n",
    "h = hs_and_zs.h_by_subj[sample.subject]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c69136b",
   "metadata": {},
   "source": [
    "## Approximating LM computation $F$ as an affine transformation\n",
    "\n",
    "### $$ F(\\mathbf{s}, c_r) \\approx \\beta \\, W_r \\mathbf{s} + b_r $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d8bf8b1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([(' Buenos', 0.892),\n",
       "  ('\\n', 0.027),\n",
       "  (' ', 0.014),\n",
       "  (' Argentina', 0.008),\n",
       "  (' Bras', 0.006),\n",
       "  ('...', 0.006),\n",
       "  (' Rome', 0.004),\n",
       "  (' {', 0.003),\n",
       "  (' the', 0.002),\n",
       "  ('...', 0.002)],\n",
       " {})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z = operator.beta * (operator.weight @ h) + operator.bias\n",
    "\n",
    "lens.logit_lens(\n",
    "    mt = mt,\n",
    "    h = z,\n",
    "    get_proba = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e0725d34",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample.subject='Argentina', sample.object='Buenos Aires', predicted=\" Buenos\", (p=0.8915027976036072), known=(✓)\n",
      "sample.subject='Australia', sample.object='Canberra', predicted=\" Canberra\", (p=0.6969543695449829), known=(✓)\n",
      "sample.subject='Canada', sample.object='Ottawa', predicted=\" Ottawa\", (p=0.7992673516273499), known=(✓)\n",
      "sample.subject='Chile', sample.object='Santiago', predicted=\" Santiago\", (p=0.6498718857765198), known=(✓)\n",
      "sample.subject='Colombia', sample.object='Bogot\\\\u00e1', predicted=\" Bog\", (p=0.384506493806839), known=(✓)\n",
      "sample.subject='Egypt', sample.object='Cairo', predicted=\" Cairo\", (p=0.9333497881889343), known=(✓)\n",
      "sample.subject='France', sample.object='Paris', predicted=\" Paris\", (p=0.9923804998397827), known=(✓)\n",
      "sample.subject='Germany', sample.object='Berlin', predicted=\" Berlin\", (p=0.9819896221160889), known=(✓)\n",
      "sample.subject='India', sample.object='New Delhi', predicted=\" Delhi\", (p=0.6313670873641968), known=(✗)\n",
      "sample.subject='Mexico', sample.object='Mexico City', predicted=\" Mexico\", (p=0.7567870020866394), known=(✓)\n",
      "sample.subject='Nigeria', sample.object='Abuja', predicted=\"\\n\", (p=0.3142380714416504), known=(✗)\n",
      "sample.subject='Pakistan', sample.object='Islamabad', predicted=\" Islamabad\", (p=0.7140621542930603), known=(✓)\n",
      "sample.subject='Peru', sample.object='Lima', predicted=\" Lima\", (p=0.5431860089302063), known=(✓)\n",
      "sample.subject='Russia', sample.object='Moscow', predicted=\" Moscow\", (p=0.988792359828949), known=(✓)\n",
      "sample.subject='Saudi Arabia', sample.object='Riyadh', predicted=\"\\n\", (p=0.2784593105316162), known=(✗)\n",
      "sample.subject='South Korea', sample.object='Seoul', predicted=\" Seoul\", (p=0.9875303506851196), known=(✓)\n",
      "sample.subject='Spain', sample.object='Madrid', predicted=\" Madrid\", (p=0.8922967910766602), known=(✓)\n",
      "sample.subject='United States', sample.object='Washington D.C.', predicted=\" Washington\", (p=0.45999929308891296), known=(✓)\n",
      "sample.subject='Venezuela', sample.object='Caracas', predicted=\"\\n\", (p=0.2639511823654175), known=(✗)\n",
      "------------------------------------------------------------\n",
      "Faithfulness (@1) = 0.7894736842105263\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "wrong = 0\n",
    "for sample in test.samples:\n",
    "    predictions = operator(subject = sample.subject).predictions\n",
    "    known_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=predictions[0].token, target=sample.object\n",
    "    )\n",
    "    print(f\"{sample.subject=}, {sample.object=}, \", end=\"\")\n",
    "    print(f'predicted=\"{functional.format_whitespace(predictions[0].token)}\", (p={predictions[0].prob}), known=({functional.get_tick_marker(known_flag)})')\n",
    "    \n",
    "    correct += known_flag\n",
    "    wrong += not known_flag\n",
    "    \n",
    "faithfulness = correct/(correct + wrong)\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Faithfulness (@1) = {faithfulness}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a13389",
   "metadata": {},
   "source": [
    "# $causality$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "da2f8eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "rank = 100\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "25ac7213",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "test_targets = functional.random_edit_targets(test.samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d83c9b",
   "metadata": {},
   "source": [
    "## setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "1a13c0ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Changing the mapping (Argentina -> Buenos Aires) to (Argentina -> Riyadh)'"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source = test.samples[0]\n",
    "target = test_targets[source]\n",
    "\n",
    "f\"Changing the mapping ({source}) to ({source.subject} -> {target.object})\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98f67c8",
   "metadata": {},
   "source": [
    "### Calculate $\\Delta \\mathbf{s}$ such that $\\mathbf{s} + \\Delta \\mathbf{s} \\approx \\mathbf{s}'$\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img align=\"center\" src=\"causality-crop.png\" style=\"width:80%;\"/>\n",
    "</p>\n",
    "\n",
    "Under the relation $r =\\, $*plays the instrument*, and given the subject $s =\\, $*Miles Davis*, the model will predict $o =\\, $*trumpet* **(a)**; and given the subject $s' =\\, $*Cat Stevens*, the model will now predict $o' =\\, $*guiter* **(b)**. \n",
    "\n",
    "If the computation from $\\mathbf{s}$ to $\\mathbf{o}$ is well-approximated by $operator$ parameterized by $W_r$ and $b_r$ **(c)**, then $\\Delta{\\mathbf{s}}$ **(d)** should tell us the direction of change from $\\mathbf{s}$ to $\\mathbf{s}'$. Thus, $\\tilde{\\mathbf{s}}=\\mathbf{s}+\\Delta\\mathbf{s}$ would be an approximation of $\\mathbf{s}'$ and patching $\\tilde{\\mathbf{s}}$ in place of $\\mathbf{s}$ should change the prediction to $o'$ = *guitar* "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "53c632ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_s(\n",
    "    operator, \n",
    "    source_subject, \n",
    "    target_subject,\n",
    "    rank = 100,\n",
    "    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target\n",
    "):\n",
    "    w_p_inv = functional.low_rank_pinv(\n",
    "        matrix = operator.weight,\n",
    "        rank=rank,\n",
    "    )\n",
    "    hs_and_zs = functional.compute_hs_and_zs(\n",
    "        mt = mt,\n",
    "        prompt_template = operator.prompt_template,\n",
    "        subjects = [source_subject, target_subject],\n",
    "        h_layer= operator.h_layer,\n",
    "        z_layer=-1,\n",
    "    )\n",
    "\n",
    "    z_source = hs_and_zs.z_by_subj[source_subject]\n",
    "    z_target = hs_and_zs.z_by_subj[target_subject]\n",
    "    \n",
    "    z_source *= fix_latent_norm / z_source.norm() if fix_latent_norm is not None else 1.0\n",
    "    z_target *= z_source.norm() / z_target.norm() if fix_latent_norm is not None else 1.0\n",
    "\n",
    "    delta_s = w_p_inv @  (z_target.squeeze() - z_source.squeeze())\n",
    "\n",
    "    return delta_s, hs_and_zs\n",
    "\n",
    "delta_s, hs_and_zs = get_delta_s(\n",
    "    operator = operator,\n",
    "    source_subject = source.subject,\n",
    "    target_subject = target.subject,\n",
    "    rank = rank\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ab1c7e88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(' Riyadh', 0.802),\n",
       " (' J', 0.051),\n",
       " (' Mecca', 0.041),\n",
       " (' Saudi', 0.012),\n",
       " (' Riy', 0.01),\n",
       " ('\\n', 0.007),\n",
       " (' Dam', 0.005),\n",
       " (' Cairo', 0.004),\n",
       " (' the', 0.004),\n",
       " (' Al', 0.003)]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import baukit\n",
    "\n",
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    def edit_output(output, layer):\n",
    "        if(layer != int_layer):\n",
    "            return output\n",
    "        functional.untuple(output)[:, subj_idx] = h \n",
    "        return output\n",
    "    return edit_output\n",
    "\n",
    "prompt = operator.prompt_template.format(source.subject)\n",
    "\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=source.subject,\n",
    ")\n",
    "\n",
    "h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])\n",
    "\n",
    "with baukit.TraceDict(\n",
    "    mt.model, layers = [h_layer, z_layer],\n",
    "    edit_output=get_intervention(\n",
    "#         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual\n",
    "        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s\n",
    "        int_layer = h_layer, \n",
    "        subj_idx = h_index\n",
    "    )\n",
    ") as traces:\n",
    "    outputs = mt.model(\n",
    "        input_ids = inputs.input_ids,\n",
    "        attention_mask = inputs.attention_mask,\n",
    "    )\n",
    "\n",
    "lens.interpret_logits(\n",
    "    mt = mt, \n",
    "    logits = outputs.logits[0][-1], \n",
    "    get_proba=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3c272c1",
   "metadata": {},
   "source": [
    "## Measuring causality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "51efa257",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.editors import LowRankPInvEditor\n",
    "\n",
    "svd = torch.svd(operator.weight.float())\n",
    "editor = LowRankPInvEditor(\n",
    "    lre=operator,\n",
    "    rank=rank,\n",
    "    svd=svd,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "88be35dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mapping Argentina -> Riyadh | edit result= Riyadh (p=0.819) | success=(✓)\n",
      "Mapping Australia -> Buenos Aires | edit result= Buenos (p=0.822) | success=(✓)\n",
      "Mapping Canada -> Abuja | edit result= Abu (p=0.610) | success=(✓)\n",
      "Mapping Chile -> Lima | edit result= Lima (p=0.967) | success=(✓)\n",
      "Mapping Colombia -> Berlin | edit result= Berlin (p=0.953) | success=(✓)\n",
      "Mapping Egypt -> Mexico City | edit result= Mexico (p=0.983) | success=(✓)\n",
      "Mapping France -> Riyadh | edit result= Riyadh (p=0.847) | success=(✓)\n",
      "Mapping Germany -> Cairo | edit result= Cairo (p=0.970) | success=(✓)\n",
      "Mapping India -> Lima | edit result= Lima (p=0.930) | success=(✓)\n",
      "Mapping Mexico -> Santiago | edit result= Santiago (p=0.955) | success=(✓)\n",
      "Mapping Nigeria -> Riyadh | edit result= Riyadh (p=0.849) | success=(✓)\n",
      "Mapping Pakistan -> New Delhi | edit result= New (p=0.863) | success=(✓)\n",
      "Mapping Peru -> Caracas | edit result= Car (p=0.937) | success=(✓)\n",
      "Mapping Russia -> Cairo | edit result= Cairo (p=0.966) | success=(✓)\n",
      "Mapping Saudi Arabia -> Caracas | edit result= Car (p=0.604) | success=(✓)\n",
      "Mapping South Korea -> Cairo | edit result= Cairo (p=0.925) | success=(✓)\n",
      "Mapping Spain -> Islamabad | edit result= Islamabad (p=0.883) | success=(✓)\n",
      "Mapping United States -> Ottawa | edit result= Ottawa (p=0.880) | success=(✓)\n",
      "Mapping Venezuela -> Madrid | edit result= Madrid (p=0.979) | success=(✓)\n",
      "------------------------------------------------------------\n",
      "Causality (@1) = 1.0\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# precomputing latents to speed things up\n",
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject for sample in test.samples],\n",
    "    h_layer= operator.h_layer,\n",
    "    z_layer=-1,\n",
    "    batch_size = 2\n",
    ")\n",
    "\n",
    "success = 0\n",
    "fails = 0\n",
    "\n",
    "for sample in test.samples:\n",
    "    target = test_targets.get(sample)\n",
    "    assert target is not None\n",
    "    edit_result = editor(\n",
    "        subject = sample.subject,\n",
    "        target = target.subject\n",
    "    )\n",
    "    \n",
    "    success_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=edit_result.predicted_tokens[0].token, target=target.object\n",
    "    )\n",
    "    \n",
    "    print(f\"Mapping {sample.subject} -> {target.object} | edit result={edit_result.predicted_tokens[0]} | success=({functional.get_tick_marker(success_flag)})\")\n",
    "    \n",
    "    success += success_flag\n",
    "    fails += not success_flag\n",
    "    \n",
    "causality = success / (success + fails)\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Causality (@1) = {causality}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c587ae85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36d6d2a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
