{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5b360b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "from src import models, data, lens, functional\n",
    "from src.utils import experiment_utils\n",
    "from baukit import Menu, show"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9d46d0c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dtype: torch.float16, device: cuda:0, memory: 12219206136\n"
     ]
    }
   ],
   "source": [
    "device = \"cuda:0\"\n",
    "mt = models.load_model(\"gptj\", device=device, fp16=True)\n",
    "print(f\"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5f51154e",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "\n",
    "# relation_names = [r.name for r in dataset.relations]\n",
    "# relation_options = Menu(choices = relation_names, value = relation_names)\n",
    "# show(relation_options) # !caution: tested in a juputer-notebook. baukit visualizations are not supported in vscode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3a17444f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "country capital city -- 24 samples\n",
      "------------------------------------------------------\n",
      "China -> Beijing\n",
      "Japan -> Tokyo\n",
      "Italy -> Rome\n",
      "Brazil -> Bras\\u00edlia\n",
      "Turkey -> Ankara\n"
     ]
    }
   ],
   "source": [
    "relation_name = \"country capital city\"\n",
    "relation = dataset.filter(relation_names=[relation_name])[0]\n",
    "print(f\"{relation.name} -- {len(relation.samples)} samples\")\n",
    "print(\"------------------------------------------------------\")\n",
    "\n",
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "train, test = relation.split(5)\n",
    "print(\"\\n\".join([sample.__str__() for sample in train.samples]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "145bf9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "layer = 5\n",
    "beta = 2.5\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "83b33032",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "relation has > 1 prompt_templates, will use first (The capital city of {} is)\n"
     ]
    }
   ],
   "source": [
    "from src.operators import JacobianIclMeanEstimator\n",
    "\n",
    "estimator = JacobianIclMeanEstimator(\n",
    "    mt = mt, \n",
    "    h_layer = layer,\n",
    "    beta = beta\n",
    ")\n",
    "operator = estimator(\n",
    "    relation.set(\n",
    "        samples=train.samples, \n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c1b6eda",
   "metadata": {},
   "source": [
    "# Checking $faithfulness$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9d79b613",
   "metadata": {},
   "outputs": [],
   "source": [
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt=mt, test_relation=test, prompt_template=operator.prompt_template, batch_size=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ad18a70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Argentina -> Buenos Aires\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[PredictedToken(token=' Buenos', prob=0.8899907469749451),\n",
       " PredictedToken(token='\\n', prob=0.02772851102054119),\n",
       " PredictedToken(token=' ', prob=0.013726607896387577),\n",
       " PredictedToken(token=' Argentina', prob=0.008456717245280743),\n",
       " PredictedToken(token=' Bras', prob=0.0059037404134869576)]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sample = test.samples[0]\n",
    "print(sample)\n",
    "operator(subject = sample.subject).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e717d47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject],\n",
    "    h_layer= operator.h_layer,\n",
    ")\n",
    "h = hs_and_zs.h_by_subj[sample.subject]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c69136b",
   "metadata": {},
   "source": [
    "## Approximating LM computation $F$ as an affine transformation\n",
    "\n",
    "### $$ F(\\mathbf{s}, c_r) \\approx \\beta \\, W_r \\mathbf{s} + b_r $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d8bf8b1d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([(' Buenos', 0.89),\n",
       "  ('\\n', 0.028),\n",
       "  (' ', 0.014),\n",
       "  (' Argentina', 0.008),\n",
       "  (' Bras', 0.006),\n",
       "  ('...', 0.006),\n",
       "  (' Rome', 0.004),\n",
       "  (' {', 0.003),\n",
       "  (' the', 0.002),\n",
       "  ('...', 0.002)],\n",
       " {})"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "z = operator.beta * (operator.weight @ h) + operator.bias\n",
    "\n",
    "lens.logit_lens(\n",
    "    mt = mt,\n",
    "    h = z,\n",
    "    get_proba = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e0725d34",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample.subject='Argentina', sample.object='Buenos Aires', predicted=\" Buenos\", (p=0.8899907469749451), known=(✓)\n",
      "sample.subject='Australia', sample.object='Canberra', predicted=\" Canberra\", (p=0.6983034610748291), known=(✓)\n",
      "sample.subject='Canada', sample.object='Ottawa', predicted=\" Ottawa\", (p=0.7997354865074158), known=(✓)\n",
      "sample.subject='Chile', sample.object='Santiago', predicted=\" Santiago\", (p=0.6504030823707581), known=(✓)\n",
      "sample.subject='Colombia', sample.object='Bogot\\\\u00e1', predicted=\" Bog\", (p=0.38615524768829346), known=(✓)\n",
      "sample.subject='Egypt', sample.object='Cairo', predicted=\" Cairo\", (p=0.9333562850952148), known=(✓)\n",
      "sample.subject='France', sample.object='Paris', predicted=\" Paris\", (p=0.9924296736717224), known=(✓)\n",
      "sample.subject='Germany', sample.object='Berlin', predicted=\" Berlin\", (p=0.9821451902389526), known=(✓)\n",
      "sample.subject='India', sample.object='New Delhi', predicted=\" Delhi\", (p=0.6313555836677551), known=(✗)\n",
      "sample.subject='Mexico', sample.object='Mexico City', predicted=\" Mexico\", (p=0.7568068504333496), known=(✓)\n",
      "sample.subject='Nigeria', sample.object='Abuja', predicted=\"\\n\", (p=0.31467148661613464), known=(✗)\n",
      "sample.subject='Pakistan', sample.object='Islamabad', predicted=\" Islamabad\", (p=0.7141078114509583), known=(✓)\n",
      "sample.subject='Peru', sample.object='Lima', predicted=\" Lima\", (p=0.5434004664421082), known=(✓)\n",
      "sample.subject='Russia', sample.object='Moscow', predicted=\" Moscow\", (p=0.9887937903404236), known=(✓)\n",
      "sample.subject='Saudi Arabia', sample.object='Riyadh', predicted=\"\\n\", (p=0.27850085496902466), known=(✗)\n",
      "sample.subject='South Korea', sample.object='Seoul', predicted=\" Seoul\", (p=0.9875714778900146), known=(✓)\n",
      "sample.subject='Spain', sample.object='Madrid', predicted=\" Madrid\", (p=0.892389178276062), known=(✓)\n",
      "sample.subject='United States', sample.object='Washington D.C.', predicted=\" Washington\", (p=0.4601989686489105), known=(✓)\n",
      "sample.subject='Venezuela', sample.object='Caracas', predicted=\"\\n\", (p=0.26478108763694763), known=(✗)\n",
      "------------------------------------------------------------\n",
      "Faithfulness (@1) = 0.7894736842105263\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "wrong = 0\n",
    "for sample in test.samples:\n",
    "    predictions = operator(subject = sample.subject).predictions\n",
    "    known_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=predictions[0].token, target=sample.object\n",
    "    )\n",
    "    print(f\"{sample.subject=}, {sample.object=}, \", end=\"\")\n",
    "    print(f'predicted=\"{functional.format_whitespace(predictions[0].token)}\", (p={predictions[0].prob}), known=({functional.get_tick_marker(known_flag)})')\n",
    "    \n",
    "    correct += known_flag\n",
    "    wrong += not known_flag\n",
    "    \n",
    "faithfulness = correct/(correct + wrong)\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Faithfulness (@1) = {faithfulness}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a13389",
   "metadata": {},
   "source": [
    "# $causality$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "da2f8eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "rank = 100\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "25ac7213",
   "metadata": {},
   "outputs": [],
   "source": [
    "experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency\n",
    "test_targets = functional.random_edit_targets(test.samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d83c9b",
   "metadata": {},
   "source": [
    "## setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1a13c0ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Changing the mapping (Argentina -> Buenos Aires) to (Argentina -> Riyadh)'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source = test.samples[0]\n",
    "target = test_targets[source]\n",
    "\n",
    "f\"Changing the mapping ({source}) to ({source.subject} -> {target.object})\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98f67c8",
   "metadata": {},
   "source": [
    "### Calculate $\\Delta \\mathbf{s}$ such that $\\mathbf{s} + \\Delta \\mathbf{s} \\approx \\mathbf{s}'$\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img align=\"center\" src=\"causality-crop.png\" style=\"width:80%;\"/>\n",
    "</p>\n",
    "\n",
    "Under the relation $r =\\, $*plays the instrument*, and given the subject $s =\\, $*Miles Davis*, the model will predict $o =\\, $*trumpet* **(a)**; and given the subject $s' =\\, $*Cat Stevens*, the model will now predict $o' =\\, $*guiter* **(b)**. \n",
    "\n",
    "If the computation from $\\mathbf{s}$ to $\\mathbf{o}$ is well-approximated by $operator$ parameterized by $W_r$ and $b_r$ **(c)**, then $\\Delta{\\mathbf{s}}$ **(d)** should tell us the direction of change from $\\mathbf{s}$ to $\\mathbf{s}'$. Thus, $\\tilde{\\mathbf{s}}=\\mathbf{s}+\\Delta\\mathbf{s}$ would be an approximation of $\\mathbf{s}'$ and patching $\\tilde{\\mathbf{s}}$ in place of $\\mathbf{s}$ should change the prediction to $o'$ = *guitar* "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "53c632ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_s(\n",
    "    operator, \n",
    "    source_subject, \n",
    "    target_subject,\n",
    "    rank = 100,\n",
    "    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target\n",
    "):\n",
    "    w_p_inv = functional.low_rank_pinv(\n",
    "        matrix = operator.weight,\n",
    "        rank=rank,\n",
    "    )\n",
    "    hs_and_zs = functional.compute_hs_and_zs(\n",
    "        mt = mt,\n",
    "        prompt_template = operator.prompt_template,\n",
    "        subjects = [source_subject, target_subject],\n",
    "        h_layer= operator.h_layer,\n",
    "        z_layer=-1,\n",
    "    )\n",
    "\n",
    "    z_source = hs_and_zs.z_by_subj[source_subject]\n",
    "    z_target = hs_and_zs.z_by_subj[target_subject]\n",
    "    \n",
    "    z_source *= fix_latent_norm / z_source.norm() if fix_latent_norm is not None else 1.0\n",
    "    z_target *= z_source.norm() / z_target.norm() if fix_latent_norm is not None else 1.0\n",
    "\n",
    "    delta_s = w_p_inv @  (z_target.squeeze() - z_source.squeeze())\n",
    "\n",
    "    return delta_s, hs_and_zs\n",
    "\n",
    "delta_s, hs_and_zs = get_delta_s(\n",
    "    operator = operator,\n",
    "    source_subject = source.subject,\n",
    "    target_subject = target.subject,\n",
    "    rank = rank\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ab1c7e88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(' Riyadh', 0.802),\n",
       " (' J', 0.051),\n",
       " (' Mecca', 0.041),\n",
       " (' Saudi', 0.012),\n",
       " (' Riy', 0.01),\n",
       " ('\\n', 0.007),\n",
       " (' Dam', 0.005),\n",
       " (' Cairo', 0.004),\n",
       " (' the', 0.004),\n",
       " (' Al', 0.003)]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import baukit\n",
    "\n",
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    def edit_output(output, layer):\n",
    "        if(layer != int_layer):\n",
    "            return output\n",
    "        functional.untuple(output)[:, subj_idx] = h \n",
    "        return output\n",
    "    return edit_output\n",
    "\n",
    "prompt = operator.prompt_template.format(source.subject)\n",
    "\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=source.subject,\n",
    ")\n",
    "\n",
    "h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])\n",
    "\n",
    "with baukit.TraceDict(\n",
    "    mt.model, layers = [h_layer, z_layer],\n",
    "    edit_output=get_intervention(\n",
    "#         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual\n",
    "        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s\n",
    "        int_layer = h_layer, \n",
    "        subj_idx = h_index\n",
    "    )\n",
    ") as traces:\n",
    "    outputs = mt.model(\n",
    "        input_ids = inputs.input_ids,\n",
    "        attention_mask = inputs.attention_mask,\n",
    "    )\n",
    "\n",
    "lens.interpret_logits(\n",
    "    mt = mt, \n",
    "    logits = outputs.logits[0][-1], \n",
    "    get_proba=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3c272c1",
   "metadata": {},
   "source": [
    "## Measuring causality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "51efa257",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.editors import LowRankPInvEditor\n",
    "\n",
    "svd = torch.svd(operator.weight.float())\n",
    "editor = LowRankPInvEditor(\n",
    "    lre=operator,\n",
    "    rank=rank,\n",
    "    svd=svd,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "88be35dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mapping Argentina -> Riyadh | edit result= Riyadh (p=0.819) | success=(✓)\n",
      "Mapping Australia -> Buenos Aires | edit result= Buenos (p=0.820) | success=(✓)\n",
      "Mapping Canada -> Abuja | edit result= Abu (p=0.606) | success=(✓)\n",
      "Mapping Chile -> Lima | edit result= Lima (p=0.967) | success=(✓)\n",
      "Mapping Colombia -> Berlin | edit result= Berlin (p=0.953) | success=(✓)\n",
      "Mapping Egypt -> Mexico City | edit result= Mexico (p=0.983) | success=(✓)\n",
      "Mapping France -> Riyadh | edit result= Riyadh (p=0.847) | success=(✓)\n",
      "Mapping Germany -> Cairo | edit result= Cairo (p=0.970) | success=(✓)\n",
      "Mapping India -> Lima | edit result= Lima (p=0.929) | success=(✓)\n",
      "Mapping Mexico -> Santiago | edit result= Santiago (p=0.955) | success=(✓)\n",
      "Mapping Nigeria -> Riyadh | edit result= Riyadh (p=0.849) | success=(✓)\n",
      "Mapping Pakistan -> New Delhi | edit result= New (p=0.863) | success=(✓)\n",
      "Mapping Peru -> Caracas | edit result= Car (p=0.936) | success=(✓)\n",
      "Mapping Russia -> Cairo | edit result= Cairo (p=0.966) | success=(✓)\n",
      "Mapping Saudi Arabia -> Caracas | edit result= Car (p=0.604) | success=(✓)\n",
      "Mapping South Korea -> Cairo | edit result= Cairo (p=0.925) | success=(✓)\n",
      "Mapping Spain -> Islamabad | edit result= Islamabad (p=0.883) | success=(✓)\n",
      "Mapping United States -> Ottawa | edit result= Ottawa (p=0.879) | success=(✓)\n",
      "Mapping Venezuela -> Madrid | edit result= Madrid (p=0.979) | success=(✓)\n",
      "------------------------------------------------------------\n",
      "Causality (@1) = 1.0\n",
      "------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# precomputing latents to speed things up\n",
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject for sample in test.samples],\n",
    "    h_layer= operator.h_layer,\n",
    "    z_layer=-1,\n",
    "    batch_size = 2\n",
    ")\n",
    "\n",
    "success = 0\n",
    "fails = 0\n",
    "\n",
    "for sample in test.samples:\n",
    "    target = test_targets.get(sample)\n",
    "    assert target is not None\n",
    "    edit_result = editor(\n",
    "        subject = sample.subject,\n",
    "        target = target.subject\n",
    "    )\n",
    "    \n",
    "    success_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=edit_result.predicted_tokens[0].token, target=target.object\n",
    "    )\n",
    "    \n",
    "    print(f\"Mapping {sample.subject} -> {target.object} | edit result={edit_result.predicted_tokens[0]} | success=({functional.get_tick_marker(success_flag)})\")\n",
    "    \n",
    "    success += success_flag\n",
    "    fails += not success_flag\n",
    "    \n",
    "causality = success / (success + fails)\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Causality (@1) = {causality}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c587ae85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36d6d2a9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
