{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "df34c039",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so that edits to imported project modules\n",
    "# are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "692d3751",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "from src import models, data, lens, functional\n",
    "from src.utils import experiment_utils\n",
    "from baukit import Menu, show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42c63b83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the \"gptj\" model with fp16=True onto the device named below.\n",
    "# NOTE(review): the device is hard-coded to the first CUDA GPU -- change it before running elsewhere.\n",
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", device=device, fp16=True)\n",
    "# Sanity check of what was loaded: dtype, device placement and memory footprint.\n",
    "print(f\"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37ee0701",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<form id=\"_140603947707040_1\" style=\"display:flex;flex:1;flex-flow:column;gap:3px\"><select name=\"menu\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\"><option value=\"characteristic gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">characteristic gender</option><option value=\"univ degree gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">univ degree gender</option><option value=\"name birthplace\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">name birthplace</option><option value=\"name gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">name gender</option><option value=\"name religion\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">name religion</option><option value=\"occupation age\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">occupation age</option><option value=\"occupation gender\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">occupation gender</option><option value=\"person native language\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person native language</option><option value=\"fruit inside color\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">fruit inside color</option><option value=\"fruit outside color\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">fruit outside color</option><option value=\"object superclass\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">object superclass</option><option value=\"substance phase of matter\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">substance phase of matter</option><option value=\"task person type\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">task person type</option><option value=\"task done by tool\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">task done by tool</option><option value=\"word sentiment\" style=\"display:flex;flex:1;flex-flow:row 
wrap;gap:inherit\">word sentiment</option><option value=\"work location\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">work location</option><option value=\"city in country\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">city in country</option><option value=\"company CEO\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">company CEO</option><option value=\"company hq\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">company hq</option><option value=\"country capital city\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country capital city</option><option value=\"country currency\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country currency</option><option value=\"country language\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country language</option><option value=\"country largest city\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">country largest city</option><option value=\"food from country\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">food from country</option><option value=\"landmark in country\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">landmark in country</option><option value=\"landmark on continent\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">landmark on continent</option><option value=\"person lead singer of band\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person lead singer of band</option><option value=\"person father\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person father</option><option value=\"person mother\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person mother</option><option value=\"person occupation\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person occupation</option><option value=\"person plays instrument\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person plays 
instrument</option><option value=\"person sport position\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person sport position</option><option value=\"plays pro sport\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">plays pro sport</option><option value=\"person university\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">person university</option><option value=\"pokemon evolution\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">pokemon evolution</option><option value=\"president birth year\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">president birth year</option><option value=\"president election year\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">president election year</option><option value=\"product by company\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">product by company</option><option value=\"star constellation name\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">star constellation name</option><option value=\"superhero archnemesis\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">superhero archnemesis</option><option value=\"superhero person\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">superhero person</option><option value=\"adjective antonym\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">adjective antonym</option><option value=\"adjective comparative\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">adjective comparative</option><option value=\"adjective superlative\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">adjective superlative</option><option value=\"verb past tense\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">verb past tense</option><option value=\"word first letter\" style=\"display:flex;flex:1;flex-flow:row wrap;gap:inherit\">word first letter</option><option value=\"word last letter\" style=\"display:flex;flex:1;flex-flow:row 
wrap;gap:inherit\">word last letter</option></select></form><script>(function() {\n",
       "function getChan(obj_id) {\n",
       "var cname = \"comm_\" + obj_id;\n",
       "if (!window[cname]) { window[cname] = []; }\n",
       "var chan = window[cname];\n",
       "if (!chan.comm && Jupyter.notebook.kernel) {\n",
       "chan.comm = Jupyter.notebook.kernel.comm_manager.new_comm(cname, {});\n",
       "chan.comm.on_msg((ev) => {\n",
       "if (chan.retry) { clearInterval(chan.retry); chan.retry = null; }\n",
       "if (ev.content.data == 'ok') { return; }\n",
       "var args = ev.content.data.slice(1);\n",
       "for (fn of chan) { fn.apply(null, args); }\n",
       "});\n",
       "chan.retries = 5;\n",
       "chan.retry = setInterval(() => {\n",
       "if (chan.retries) { chan.retries -= 1; chan.comm.open(); }\n",
       "else { clearInterval(chan.retry); chan.retry = null; }\n",
       "}, 2000);\n",
       "}\n",
       "return chan;\n",
       "}\n",
       "function recvFromPython(obj_id, fn) {\n",
       "getChan(obj_id).push(fn);\n",
       "}\n",
       "function sendToPython(obj_id, ...args) {\n",
       "var comm = getChan(obj_id).comm;\n",
       "if (comm) { comm.send(args); }\n",
       "}\n",
       "class Model {\n",
       "constructor(obj_id, init) {\n",
       "this._id = obj_id;\n",
       "this._listeners = {};\n",
       "this._data = Object.assign({}, init);\n",
       "this._sent = {};\n",
       "recvFromPython(this._id, (name, value) => {\n",
       "this._data[name] = value;\n",
       "var e = new Event(name); e.value = value;\n",
       "if (this._listeners.hasOwnProperty(name)) {\n",
       "this._listeners[name].forEach((fn) => { fn(e); });\n",
       "}\n",
       "if (this._sent[name]) {\n",
       "value = this._sent[name];\n",
       "delete this._sent[name];\n",
       "if (value.length) {\n",
       "value = value[0];\n",
       "this.set_soon(name, value);\n",
       "}\n",
       "}\n",
       "})\n",
       "}\n",
       "trigger(name, value) {\n",
       "sendToPython(this._id, name, value);\n",
       "}\n",
       "get(name) {\n",
       "return this._data[name];\n",
       "}\n",
       "set(name, value) {\n",
       "delete this._sent[name];\n",
       "this.trigger(name, value);\n",
       "}\n",
       "set_soon(name, value) {\n",
       "if (!this._sent.hasOwnProperty(name)) {\n",
       "this._sent[name] = [];\n",
       "this.trigger(name, value);\n",
       "} else {\n",
       "this._sent[name] = [value];\n",
       "}\n",
       "}\n",
       "on(name, fn) {\n",
       "name.split(/\\s+/).forEach((n) => {\n",
       "if (!this._listeners.hasOwnProperty(n)) {\n",
       "this._listeners[n] = [];\n",
       "}\n",
       "this._listeners[n].push(fn);\n",
       "});\n",
       "}\n",
       "off(name, fn) {\n",
       "name.split(/\\s+/).forEach((n) => {\n",
       "if (!fn) {\n",
       "delete this._listeners[n];\n",
       "} else if (this._listeners.hasOwnProperty(n)) {\n",
       "this._listeners[n] = this._listeners[n].filter(\n",
       "(e) => { return e !== fn; });\n",
       "}\n",
       "});\n",
       "}\n",
       "}\n",
       "\n",
       "var model = new Model(\"140603947707040\", {\"style\": null, \"choices\": [\"characteristic gender\", \"univ degree gender\", \"name birthplace\", \"name gender\", \"name religion\", \"occupation age\", \"occupation gender\", \"person native language\", \"fruit inside color\", \"fruit outside color\", \"object superclass\", \"substance phase of matter\", \"task person type\", \"task done by tool\", \"word sentiment\", \"work location\", \"city in country\", \"company CEO\", \"company hq\", \"country capital city\", \"country currency\", \"country language\", \"country largest city\", \"food from country\", \"landmark in country\", \"landmark on continent\", \"person lead singer of band\", \"person father\", \"person mother\", \"person occupation\", \"person plays instrument\", \"person sport position\", \"plays pro sport\", \"person university\", \"pokemon evolution\", \"president birth year\", \"president election year\", \"product by company\", \"star constellation name\", \"superhero archnemesis\", \"superhero person\", \"adjective antonym\", \"adjective comparative\", \"adjective superlative\", \"verb past tense\", \"word first letter\", \"word last letter\"], \"value\": [\"characteristic gender\", \"univ degree gender\", \"name birthplace\", \"name gender\", \"name religion\", \"occupation age\", \"occupation gender\", \"person native language\", \"fruit inside color\", \"fruit outside color\", \"object superclass\", \"substance phase of matter\", \"task person type\", \"task done by tool\", \"word sentiment\", \"work location\", \"city in country\", \"company CEO\", \"company hq\", \"country capital city\", \"country currency\", \"country language\", \"country largest city\", \"food from country\", \"landmark in country\", \"landmark on continent\", \"person lead singer of band\", \"person father\", \"person mother\", \"person occupation\", \"person plays instrument\", \"person sport position\", \"plays pro sport\", \"person university\", \"pokemon 
evolution\", \"president birth year\", \"president election year\", \"product by company\", \"star constellation name\", \"superhero archnemesis\", \"superhero person\", \"adjective antonym\", \"adjective comparative\", \"adjective superlative\", \"verb past tense\", \"word first letter\", \"word last letter\"]});\n",
       "var element = document.getElementById(\"_140603947707040_1\");\n",
       "model.on('write', (ev) => {\n",
       "var dummy = document.createElement('div');\n",
       "dummy.innerHTML = ev.value.trim();\n",
       "dummy.childNodes.forEach((item) => {\n",
       "element.parentNode.insertBefore(item, element);\n",
       "});\n",
       "});\n",
       "function upd(a) { return (e) => { for (k in e.value) {\n",
       "element[a][k] = e.value[k];\n",
       "}}}\n",
       "model.on('style', upd('style'));\n",
       "model.on('style', upd('style'));\n",
       "model.on('data', upd('dataset'));\n",
       "\n",
       "function esc(unsafe) {\n",
       "return unsafe.replace(/&/g, \"&amp;\").replace(/</g, \"&lt;\")\n",
       ".replace(/>/g, \"&gt;\").replace(/\"/g, \"&quot;\");\n",
       "}\n",
       "function render() {\n",
       "var value = model.get('value');\n",
       "var lines = model.get('choices').map((c) => {\n",
       "return '<option value=\"' + esc(''+c) + '\"' +\n",
       "(c == value ? ' selected' : '') +\n",
       "'>' + esc(''+c) + '</option>';\n",
       "});\n",
       "element.menu.innerHTML = lines.join('\\n');\n",
       "}\n",
       "model.on('value', (ev) => {\n",
       "[...element.querySelectorAll('option')].forEach((e) => {\n",
       "e.selected = (e.value == ev.value);\n",
       "})\n",
       "});\n",
       "element.addEventListener('change', (e) => {\n",
       "model.set('value', element.menu.value);\n",
       "});\n",
       "})();</script>"
      ],
      "text/plain": [
       "<baukit.show.HtmlRepr at 0x7fe0e85b6a10>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = data.load_dataset()\n",
    "\n",
    "relation_names = [r.name for r in dataset.relations]\n",
    "\n",
    "# Pre-select ONE relation. The menu treats `value` as the single selected choice, and the next\n",
    "# cell wraps `relation_options.value` in a list, so it must be a string even when the notebook\n",
    "# is run top-to-bottom without touching the widget (passing the whole list breaks that).\n",
    "default_relation = \"food from country\"\n",
    "if default_relation not in relation_names:\n",
    "    # Fall back to the first available relation if the preferred one is missing.\n",
    "    default_relation = relation_names[0]\n",
    "\n",
    "relation_options = Menu(choices = relation_names, value = default_relation)\n",
    "show(relation_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b70864c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "food from country -- 30 samples\n",
      "------------------------------------------------------\n",
      "Miso Soup -> Japan\n",
      "Pierogi -> Poland\n",
      "Kimchi -> South Korea\n",
      "Hummus -> Lebanon\n",
      "Fondue -> Switzerland\n"
     ]
    }
   ],
   "source": [
    "# Resolve the relation currently selected in the menu.\n",
    "relation_name = relation_options.value\n",
    "relation = dataset.filter(relation_names=[relation_name])[0]\n",
    "print(f\"{relation.name} -- {len(relation.samples)} samples\")\n",
    "print(\"------------------------------------------------------\")\n",
    "\n",
    "# Fix the seed before splitting so the same training samples are drawn on every run.\n",
    "experiment_utils.set_seed(12345)\n",
    "train, test = relation.split(5)\n",
    "\n",
    "# One training sample per line.\n",
    "print(*train.samples, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fdb69b6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "# `layer` is handed to the estimator as `h_layer`; `beta` is the scalar that multiplies the\n",
    "# linear term of the affine approximation (z = beta * W h + b) evaluated further below.\n",
    "layer = 5\n",
    "beta = 2.5\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a3929dd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "relation has > 1 prompt_templates, will use first ({} originates from)\n"
     ]
    }
   ],
   "source": [
    "from src.operators import JacobianIclMeanEstimator\n",
    "\n",
    "# Configure the estimator with the hyperparameters chosen above.\n",
    "estimator = JacobianIclMeanEstimator(mt=mt, h_layer=layer, beta=beta)\n",
    "\n",
    "# Estimate the operator from the training samples only; the held-out samples stay in `test`.\n",
    "operator = estimator(relation.set(samples=train.samples))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ab02f08",
   "metadata": {},
   "source": [
    "# Checking $faithfulness$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "354f4207",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filter the held-out samples using the operator's few-shot prompt template.\n",
    "# NOTE(review): presumably this keeps only samples the model itself answers correctly with that\n",
    "# prompt -- confirm in `functional`.\n",
    "# `test` is rebound to the filtered relation, so every later cell sees only the filtered samples.\n",
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt=mt, test_relation=test, prompt_template=operator.prompt_template, batch_size=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ff4e9b1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Baguette -> France\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[PredictedToken(token=' France', prob=0.8985161781311035),\n",
       " PredictedToken(token=' Switzerland', prob=0.014076330699026585),\n",
       " PredictedToken(token=' the', prob=0.01281664427369833),\n",
       " PredictedToken(token=' Austria', prob=0.009017635136842728),\n",
       " PredictedToken(token=' Italy', prob=0.007302704732865095)]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Take the first held-out sample and display the operator's predictions for its subject.\n",
    "sample = test.samples[0]\n",
    "print(sample)\n",
    "operator(subject = sample.subject).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "334a3954",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute latents for the sample's subject under the operator's prompt template,\n",
    "# reading h at the operator's own layer.\n",
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject],\n",
    "    h_layer= operator.h_layer,\n",
    ")\n",
    "# h for this subject: the input to the affine map applied by hand in the next code cell.\n",
    "h = hs_and_zs.h_by_subj[sample.subject]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf9ae83f",
   "metadata": {},
   "source": [
    "## Approximating LM computation $F$ as an affine transformation\n",
    "\n",
    "### $$ F(\\mathbf{s}, c_r) \\approx \\beta \\, W_r \\mathbf{s} + b_r $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "12cff47b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([(' France', 0.898),\n",
       "  (' Switzerland', 0.014),\n",
       "  (' the', 0.013),\n",
       "  (' Austria', 0.009),\n",
       "  (' Italy', 0.007),\n",
       "  (' Europe', 0.007),\n",
       "  ('\\n', 0.004),\n",
       "  (' Germany', 0.003),\n",
       "  (' Paris', 0.003),\n",
       "  (' Spain', 0.002)],\n",
       " {})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Apply the affine approximation by hand: z = beta * (W @ h) + b.\n",
    "z = operator.beta * (operator.weight @ h) + operator.bias\n",
    "\n",
    "# Decode z with the logit lens so it can be compared with the operator's own predictions above.\n",
    "lens.logit_lens(\n",
    "    mt = mt,\n",
    "    h = z,\n",
    "    get_proba = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "eb5b56cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample.subject='Baguette', sample.object='France', predicted=\" France\", (p=0.8985161781311035), known=(✓)\n",
      "sample.subject='Biryani', sample.object='India', predicted=\" India\", (p=0.22970259189605713), known=(✓)\n",
      "sample.subject='Ceviche', sample.object='Peru', predicted=\" South\", (p=0.34727180004119873), known=(✗)\n",
      "sample.subject='Chimichurri', sample.object='Argentina', predicted=\" Argentina\", (p=0.4291362464427948), known=(✓)\n",
      "sample.subject='Dim Sum', sample.object='China', predicted=\" China\", (p=0.9262574315071106), known=(✓)\n",
      "sample.subject='Feijoada', sample.object='Brazil', predicted=\" Spain\", (p=0.26894351840019226), known=(✗)\n",
      "sample.subject='Goulash', sample.object='Hungary', predicted=\" Hungary\", (p=0.7775859236717224), known=(✓)\n",
      "sample.subject='Gyro', sample.object='Greece', predicted=\" Hungary\", (p=0.35067880153656006), known=(✗)\n",
      "sample.subject='Masala Dosa', sample.object='India', predicted=\" South\", (p=0.6780027747154236), known=(✗)\n",
      "sample.subject='Moussaka', sample.object='Greece', predicted=\" Russia\", (p=0.21241508424282074), known=(✗)\n",
      "sample.subject='Pad Thai', sample.object='Thailand', predicted=\" Thailand\", (p=0.7516799569129944), known=(✓)\n",
      "sample.subject='Paella', sample.object='Spain', predicted=\" Spain\", (p=0.8736255168914795), known=(✓)\n",
      "sample.subject='Pho', sample.object='Vietnam', predicted=\" Czech\", (p=0.171535924077034), known=(✗)\n",
      "sample.subject='Pizza', sample.object='Italy', predicted=\" Italy\", (p=0.9897751808166504), known=(✓)\n",
      "sample.subject='Poutine', sample.object='Canada', predicted=\" Canada\", (p=0.3234255015850067), known=(✓)\n",
      "sample.subject='Rendang', sample.object='Indonesia', predicted=\" South\", (p=0.8722554445266724), known=(✗)\n",
      "sample.subject='Sushi', sample.object='Japan', predicted=\" Japan\", (p=0.9883920550346375), known=(✓)\n",
      "sample.subject='Tacos', sample.object='Mexico', predicted=\" Mexico\", (p=0.7387374043464661), known=(✓)\n",
      "sample.subject='Wiener Schnitzel', sample.object='Austria', predicted=\" Austria\", (p=0.4125869870185852), known=(✓)\n",
      "------------------------------------------------------------\n",
      "Faithfulness (@1) = 0.631578947368421\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# Faithfulness@1: fraction of held-out samples whose top-1 operator prediction matches the\n",
    "# expected object, as judged by `functional.is_nontrivial_prefix`.\n",
    "correct = 0\n",
    "wrong = 0\n",
    "for sample in test.samples:\n",
    "    predictions = operator(subject = sample.subject).predictions\n",
    "    top_prediction = predictions[0]\n",
    "    known_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=top_prediction.token, target=sample.object\n",
    "    )\n",
    "    print(f\"{sample.subject=}, {sample.object=}, \", end=\"\")\n",
    "    print(f'predicted=\"{functional.format_whitespace(top_prediction.token)}\", (p={top_prediction.prob}), known=({functional.get_tick_marker(known_flag)})')\n",
    "\n",
    "    correct += known_flag\n",
    "    wrong += not known_flag\n",
    "\n",
    "# The earlier filtering of `test` can leave no samples; report NaN instead of raising\n",
    "# ZeroDivisionError in that case.\n",
    "n_evaluated = correct + wrong\n",
    "faithfulness = correct / n_evaluated if n_evaluated else float(\"nan\")\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Faithfulness (@1) = {faithfulness}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f18cbd5f",
   "metadata": {},
   "source": [
    "# $causality$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5d8a89ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "# Rank passed to the low-rank pseudo-inverse of the operator weight used for the edits below.\n",
    "rank = 100\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f38aedad",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "# Mapping from each test sample to its edit target; it is looked up by sample in the cells below.\n",
    "# NOTE(review): presumably each target is drawn at random from the other test samples -- confirm.\n",
    "test_targets = functional.random_edit_targets(test.samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3827a821",
   "metadata": {},
   "source": [
    "## setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "00e2744f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Changing the mapping (Baguette -> France) to (Baguette -> Canada)'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Worked example: edit the first test sample so that its subject maps to its target's object.\n",
    "source = test.samples[0]\n",
    "target = test_targets[source]\n",
    "\n",
    "f\"Changing the mapping ({source}) to ({source.subject} -> {target.object})\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdf1674e",
   "metadata": {},
   "source": [
    "### Calculate $\\Delta \\mathbf{s}$ such that $\\mathbf{s} + \\Delta \\mathbf{s} \\approx \\mathbf{s}'$\n",
    "\n",
    "<!-- ![](causality-crop.png) -->\n",
    "<img src=\"causality-crop.png\" style=\"width:50%;\"/>\n",
    "\n",
    "Under the relation $r =\\, $*plays the instrument*, and given the subject $s =\\, $*Miles Davis*, the model will predict $o =\\, $*trumpet* **(a)**. Under the same relation, given the subject $s' =\\, $*Cat Stevens*, the model will now predict $o' =\\, $*guitar* **(b)**. \n",
    "\n",
    "If the computation from $\\mathbf{s}$ to $\\mathbf{o}$ is well-approximated by $operator$, then adding the vector $\\Delta{\\mathbf{s}}$ to $\\mathbf{s}$ **(c)** would be equivalent to replacing $\\mathbf{s}$ with $\\mathbf{s}'$ and it should change the model prediction to $o'$ = *guitar* **(d)**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0bf499ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_s(\n",
    "    operator, \n",
    "    source_subject, \n",
    "    target_subject,\n",
    "    rank = 100,\n",
    "    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target\n",
    "):\n",
    "    \"\"\"Compute the change in h that should move `source_subject` toward `target_subject`.\n",
    "\n",
    "    Both subjects are run through the operator's prompt template; the difference of their z\n",
    "    vectors is mapped back through a rank-`rank` pseudo-inverse of `operator.weight`.\n",
    "\n",
    "    Args:\n",
    "        operator: object providing `weight`, `prompt_template` and `h_layer`.\n",
    "        source_subject: subject whose representation is to be edited.\n",
    "        target_subject: subject whose z vector the edit should reproduce.\n",
    "        rank: rank of the pseudo-inverse.\n",
    "        fix_latent_norm: if not None, both z vectors are rescaled to this norm before\n",
    "            their difference is taken.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(delta_s, hs_and_zs)`: the edit vector and the latents computed for both\n",
    "        subjects (left unscaled).\n",
    "\n",
    "    Note:\n",
    "        Uses the notebook-level `mt` to run the model.\n",
    "    \"\"\"\n",
    "    w_p_inv = functional.low_rank_pinv(\n",
    "        matrix = operator.weight,\n",
    "        rank=rank,\n",
    "    )\n",
    "    hs_and_zs = functional.compute_hs_and_zs(\n",
    "        mt = mt,\n",
    "        prompt_template = operator.prompt_template,\n",
    "        subjects = [source_subject, target_subject],\n",
    "        h_layer= operator.h_layer,\n",
    "        z_layer=-1,\n",
    "    )\n",
    "\n",
    "    z_source = hs_and_zs.z_by_subj[source_subject]\n",
    "    z_target = hs_and_zs.z_by_subj[target_subject]\n",
    "\n",
    "    if fix_latent_norm is not None:\n",
    "        # Rescale out of place so the latents stored in `hs_and_zs` (returned below) stay untouched.\n",
    "        z_source = z_source * (fix_latent_norm / z_source.norm())\n",
    "        # z_source now has norm `fix_latent_norm`; bring z_target to the same norm.\n",
    "        z_target = z_target * (z_source.norm() / z_target.norm())\n",
    "\n",
    "    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())\n",
    "\n",
    "    return delta_s, hs_and_zs\n",
    "\n",
    "# Edit vector for the worked example (source subject -> target subject) at the chosen rank.\n",
    "# `hs_and_zs` is rebound here to the latents of just these two subjects.\n",
    "delta_s, hs_and_zs = get_delta_s(\n",
    "    operator = operator,\n",
    "    source_subject = source.subject,\n",
    "    target_subject = target.subject,\n",
    "    rank = rank\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a3c38bb3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(' Canada', 0.795),\n",
       " (' Quebec', 0.098),\n",
       " (' the', 0.017),\n",
       " (' Qué', 0.011),\n",
       " (' North', 0.008),\n",
       " (' Ontario', 0.005),\n",
       " (' France', 0.004),\n",
       " (' New', 0.004),\n",
       " (' Montreal', 0.004),\n",
       " (' Canadian', 0.003)]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import baukit\n",
    "\n",
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    \"\"\"Build an `edit_output` hook that writes `h` at position `subj_idx` of layer `int_layer`.\"\"\"\n",
    "    def edit_output(output, layer):\n",
    "        # Only the intervention layer is modified; every other traced layer passes through as is.\n",
    "        if layer == int_layer:\n",
    "            functional.untuple(output)[:, subj_idx] = h\n",
    "        return output\n",
    "    return edit_output\n",
    "\n",
    "prompt = operator.prompt_template.format(source.subject)\n",
    "\n",
    "# Position of the subject token whose hidden state will be overwritten, plus the model inputs.\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=source.subject,\n",
    ")\n",
    "\n",
    "# Intervene at the layer the operator reads h from, i.e. the layer `delta_s` was computed for.\n",
    "h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [operator.h_layer, -1])\n",
    "\n",
    "# Inference only: no gradients are needed, so run the edited forward pass under no_grad.\n",
    "with torch.no_grad():\n",
    "    with baukit.TraceDict(\n",
    "        mt.model, layers = [h_layer, z_layer],\n",
    "        edit_output=get_intervention(\n",
    "    #         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual\n",
    "            h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s\n",
    "            int_layer = h_layer, \n",
    "            subj_idx = h_index\n",
    "        )\n",
    "    ) as traces:\n",
    "        outputs = mt.model(\n",
    "            input_ids = inputs.input_ids,\n",
    "            attention_mask = inputs.attention_mask,\n",
    "        )\n",
    "\n",
    "# Decode the logits at the last position of the edited run.\n",
    "lens.interpret_logits(\n",
    "    mt = mt, \n",
    "    logits = outputs.logits[0][-1], \n",
    "    get_proba=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "402abe7d",
   "metadata": {},
   "source": [
    "## Measuring causality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "75814669",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.editors import LowRankPInvEditor\n",
    "\n",
    "# Compute the SVD of the operator weight once (cast to float32 first) and hand it to the editor.\n",
    "# NOTE(review): presumably the editor reuses this SVD for its rank-`rank` pseudo-inverse -- confirm.\n",
    "svd = torch.svd(operator.weight.float())\n",
    "editor = LowRankPInvEditor(\n",
    "    lre=operator,\n",
    "    rank=rank,\n",
    "    svd=svd,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d1bb9735",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mapping Baguette -> Canada | edit result= Canada (p=0.812) | success=(✓)\n",
      "Mapping Biryani -> France | edit result= France (p=0.792) | success=(✓)\n",
      "Mapping Ceviche -> Thailand | edit result= Thailand (p=0.817) | success=(✓)\n",
      "Mapping Chimichurri -> Vietnam | edit result= Vietnam (p=0.815) | success=(✓)\n",
      "Mapping Dim Sum -> Greece | edit result= Greece (p=0.791) | success=(✓)\n",
      "Mapping Feijoada -> Greece | edit result= Greece (p=0.782) | success=(✓)\n",
      "Mapping Goulash -> Canada | edit result= Canada (p=0.736) | success=(✓)\n",
      "Mapping Gyro -> Brazil | edit result= Brazil (p=0.793) | success=(✓)\n",
      "Mapping Masala Dosa -> Italy | edit result= Italy (p=0.810) | success=(✓)\n",
      "Mapping Moussaka -> Argentina | edit result= Argentina (p=0.909) | success=(✓)\n",
      "Mapping Pad Thai -> Canada | edit result= Canada (p=0.850) | success=(✓)\n",
      "Mapping Paella -> India | edit result= India (p=0.828) | success=(✓)\n",
      "Mapping Pho -> Austria | edit result= Austria (p=0.580) | success=(✓)\n",
      "Mapping Pizza -> Brazil | edit result= Brazil (p=0.722) | success=(✓)\n",
      "Mapping Poutine -> Austria | edit result= Austria (p=0.507) | success=(✓)\n",
      "Mapping Rendang -> Brazil | edit result= Brazil (p=0.539) | success=(✓)\n",
      "Mapping Sushi -> Spain | edit result= Spain (p=0.970) | success=(✓)\n",
      "Mapping Tacos -> Peru | edit result= Peru (p=0.217) | success=(✓)\n",
      "Mapping Wiener Schnitzel -> Japan | edit result= Japan (p=0.955) | success=(✓)\n",
      "------------------------------------------------------------\n",
      "Causality (@1) = 1.0\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# precomputing latents to speed things up\n",
    "# NOTE(review): `hs_and_zs` is not passed to `editor` below; presumably any speed-up relies on\n",
    "# caching inside `functional.compute_hs_and_zs` -- confirm, otherwise this call is redundant.\n",
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject for sample in test.samples],\n",
    "    h_layer= operator.h_layer,\n",
    "    z_layer=-1,\n",
    "    batch_size = 2\n",
    ")\n",
    "\n",
    "# Causality@1: fraction of edits after which the top-1 prediction matches the target's object.\n",
    "success = 0\n",
    "fails = 0\n",
    "\n",
    "for sample in test.samples:\n",
    "    target = test_targets.get(sample)\n",
    "    assert target is not None\n",
    "    edit_result = editor(\n",
    "        subject = sample.subject,\n",
    "        target = target.subject\n",
    "    )\n",
    "    top_token = edit_result.predicted_tokens[0]\n",
    "\n",
    "    success_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=top_token.token, target=target.object\n",
    "    )\n",
    "\n",
    "    print(f\"Mapping {sample.subject} -> {target.object} | edit result={top_token} | success=({functional.get_tick_marker(success_flag)})\")\n",
    "\n",
    "    success += success_flag\n",
    "    fails += not success_flag\n",
    "\n",
    "# An empty test set would otherwise raise ZeroDivisionError; report NaN instead.\n",
    "n_edits = success + fails\n",
    "causality = success / n_edits if n_edits else float(\"nan\")\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Causality (@1) = {causality}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aae56894",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "604314f3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
