{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b360b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21f696b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "! which python"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "115a4228",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "torch.__version__, torch.version.cuda"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f54a334",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models, data, lens, functional\n",
    "from src.utils import experiment_utils, env_utils\n",
    "\n",
    "import os\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "logger = logging.getLogger(__name__)\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee09d856",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "# model_name = \"EleutherAI/pythia-2.8b-deduped\"\n",
    "\n",
    "# models_dir = env_utils.determine_models_dir()\n",
    "\n",
    "# model_path = os.path.join(models_dir, model_name)\n",
    "\n",
    "# model = AutoModelForCausalLM.from_pretrained(model_path)\n",
    "# tokenizer = AutoTokenizer.from_pretrained(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d46d0c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target device and model to probe (alternatives kept commented out for quick swapping).\n",
    "device = \"cuda:0\"\n",
    "model_name = \"pythia-3b\"  # alternative: \"mamba-3b\"\n",
    "\n",
    "# Load in full precision (fp16=False).\n",
    "mt = models.load_model(model_name, device=device, fp16=False)\n",
    "# mt = models.load_model(\"gptj\", device=device, fp16=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0620a27c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import baukit\n",
    "\n",
    "# mt.model.config.fused_add_norm = False\n",
    "\n",
    "# for block_path in models.determine_layer_paths(mt):\n",
    "#     block = baukit.get_module(mt.model, block_path)\n",
    "#     setattr(block, \"fused_add_norm\", False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9fcd03e",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt.model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a921c7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import predict_next_token\n",
    "\n",
    "# Sanity check: does the model complete a simple factual prompt?\n",
    "# Alternative: mt.tokenizer.eos_token + \" The superlative of {} is\".format(\"good\")\n",
    "predict_next_token(\n",
    "    mt=mt,\n",
    "    prompt=mt.tokenizer.eos_token + \" The capital of {} is\".format(\"France\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f51154e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the relation dataset; a single relation is selected from it below.\n",
    "dataset = data.load_dataset()\n",
    "\n",
    "# Optional interactive relation picker (baukit widgets).\n",
    "# Caution: tested in a Jupyter notebook; baukit visualizations are not supported in VS Code.\n",
    "# relation_names = [r.name for r in dataset.relations]\n",
    "# relation_options = Menu(choices = relation_names, value = relation_names)\n",
    "# show(relation_options)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a17444f",
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_name = \"country capital city\"\n",
    "relation = dataset.filter(relation_names=[relation_name])[0]\n",
    "print(f\"{relation.name} -- {len(relation.samples)} samples\")\n",
    "print(\"------------------------------------------------------\")\n",
    "\n",
    "# Constant seed so the train/test split below is reproducible.\n",
    "experiment_utils.set_seed(12345)\n",
    "train, test = relation.split(5)\n",
    "# Show the training samples (used later as the operator's few-shot examples).\n",
    "print(\"\\n\".join(str(train_sample) for train_sample in train.samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145bf9dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "layer = 12  # passed to the estimator as h_layer (layer of the subject representation)\n",
    "beta = 5    # scale on W_r s in the approximation F(s, c_r) ~ beta * W_r s + b_r\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83b33032",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.operators import JacobianIclMeanEstimator\n",
    "experiment_utils.set_seed(12345)\n",
    "\n",
    "# Fit the relation operator from the training samples only.\n",
    "estimator = JacobianIclMeanEstimator(mt=mt, h_layer=layer, beta=beta)\n",
    "operator = estimator(relation.set(samples=train.samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6155e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from typing import Optional\n",
    "\n",
    "# def add_npz_extension(path: str):\n",
    "#     return path if path.endswith(\".npz\") else path + \".npz\"\n",
    "\n",
    "# def load_o1_approxes(path: str, sample_subjects: Optional[list[str]] = None):\n",
    "#     approxes = []\n",
    "#     to_load = sample_subjects if sample_subjects is not None else os.listdir(path)\n",
    "#     for cached_file in to_load:\n",
    "#         file_path = add_npz_extension(os.path.join(path, cached_file))\n",
    "#         approx = functional.load_cached_linear_operator(file_path = file_path)\n",
    "#         approxes.append(approx)\n",
    "#     return approxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "508e6d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ##################################################\n",
    "# n_train_samples = 5\n",
    "# ##################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bbc72d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os\n",
    "# import random\n",
    "\n",
    "# root_path = \"../results/cached_o1_approxes\"\n",
    "# path = os.path.join(\n",
    "#     root_path,\n",
    "#     mt.name,\n",
    "#     relation_name.lower().replace(\" \", \"_\"),\n",
    "#     str(layer),\n",
    "# )\n",
    "\n",
    "# all_cached_files = list(os.listdir(path))\n",
    "\n",
    "# train_subj_files = random.sample(all_cached_files, n_train_samples)\n",
    "\n",
    "# train_approxes = load_o1_approxes(\n",
    "#     path = path, \n",
    "#     sample_subjects = train_subj_files\n",
    "# )\n",
    "\n",
    "# train_samples = [\n",
    "#     data.RelationSample.from_dict(approx.metadata[\"sample\"]) \n",
    "#     for approx in train_approxes\n",
    "# ]\n",
    "\n",
    "# train_relation = relation.set(samples=train_samples)\n",
    "\n",
    "# test_relation = relation.set(\n",
    "#     samples = list(set(relation.samples) - set(train_relation.samples))\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4af14f20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# weight = torch.stack([approx.weight for approx in train_approxes]).mean(dim=0)\n",
    "# bias = torch.stack([approx.bias for approx in train_approxes]).mean(dim=0)\n",
    "\n",
    "# prompt_template = relation.prompt_templates[0]\n",
    "\n",
    "# prompt_template_icl = functional.make_prompt(\n",
    "#     mt = mt,\n",
    "#     prompt_template=prompt_template,\n",
    "#     subject=\"{}\",\n",
    "#     examples = train_samples\n",
    "# )\n",
    "\n",
    "# print(prompt_template_icl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32a2b071",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.operators import LinearRelationOperator\n",
    "# operator = LinearRelationOperator(\n",
    "#     mt = mt,\n",
    "#     weight = weight,\n",
    "#     bias = bias,\n",
    "#     h_layer = train_approxes[0].h_layer,\n",
    "#     z_layer = train_approxes[0].z_layer,\n",
    "#     beta = 5,\n",
    "#     prompt_template = prompt_template_icl,\n",
    "#     metadata = {\n",
    "#         \"Jh\": [\n",
    "#             (approx.weight @ approx.h).detach().cpu().numpy()\n",
    "#             for approx in train_approxes\n",
    "#         ],\n",
    "#         \"|w|\": [approx.weight.norm().item() for approx in train_approxes],\n",
    "#         \"|b|\": [approx.bias.norm().item() for approx in train_approxes],\n",
    "#     },\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dfcbb0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.lens import logit_lens\n",
    "from src import models\n",
    "\n",
    "# logit_lens(mt = mt, h = operator.metadata[\"Jh\"][0].to(models.determine_device(mt)) + operator.bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c1b6eda",
   "metadata": {},
   "source": [
    "# Checking $faithfulness$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d79b613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: rebinds `test`. Presumably keeps only the samples the model answers correctly\n",
    "# under the operator's few-shot prompt -- confirm in src.functional.\n",
    "test = functional.filter_relation_samples_based_on_provided_fewshots(\n",
    "    mt=mt,\n",
    "    test_relation=test,\n",
    "    prompt_template=operator.prompt_template,\n",
    "    batch_size=4,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ad18a70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Showcase sample: prefer \"France\", falling back to the first test sample when France\n",
    "# landed in the train split or was removed by the filter above (avoids an IndexError).\n",
    "sample = next((s for s in test.samples if s.subject == \"France\"), test.samples[0])\n",
    "print(sample)\n",
    "operator(subject = sample.subject).predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e717d47f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt=mt,\n",
    "    prompt_template=operator.prompt_template,\n",
    "    subjects=[sample.subject],\n",
    "    h_layer=operator.h_layer,\n",
    ")\n",
    "# Subject representation h at the operator's input layer.\n",
    "h = hs_and_zs.h_by_subj[sample.subject]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c69136b",
   "metadata": {},
   "source": [
    "## Approximating LM computation $F$ as an affine transformation\n",
    "\n",
    "### $$ F(\\mathbf{s}, c_r) \\approx \\beta \\, W_r \\mathbf{s} + b_r $$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8bf8b1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# F(s, c_r) ~ beta * W_r s + b_r. Use the `beta` hparam (also passed to the estimator)\n",
    "# rather than a hard-coded constant, so this check stays in sync with the operator.\n",
    "z = beta * (operator.weight @ h) + operator.bias\n",
    "\n",
    "lens.logit_lens(\n",
    "    mt=mt,\n",
    "    h=z,\n",
    "    get_proba=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0725d34",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = 0\n",
    "wrong = 0\n",
    "for sample in test.samples:\n",
    "    predictions = operator(subject = sample.subject).predictions\n",
    "    # A sample counts as known when the top-1 token satisfies is_nontrivial_prefix w.r.t. the object.\n",
    "    known_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=predictions[0].token, target=sample.object\n",
    "    )\n",
    "    print(f\"{sample.subject=}, {sample.object=}, \", end=\"\")\n",
    "    print(f'predicted=\"{functional.format_whitespace(predictions[0].token)}\", (p={predictions[0].prob}), known=({functional.get_tick_marker(known_flag)})')\n",
    "\n",
    "    correct += known_flag\n",
    "    wrong += not known_flag\n",
    "\n",
    "n_evaluated = correct + wrong\n",
    "# Guard against an empty test set (e.g. every sample removed by the filter above).\n",
    "faithfulness = correct / n_evaluated if n_evaluated > 0 else float(\"nan\")\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Faithfulness (@1) = {faithfulness}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a13389",
   "metadata": {},
   "source": [
    "# $causality$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da2f8eb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "################### hparams ###################\n",
    "rank = 100  # rank of the low-rank pseudo-inverse of W used to compute the edit\n",
    "###############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ac7213",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constant seed so the drawn edit targets are reproducible across runs.\n",
    "experiment_utils.set_seed(12345)\n",
    "# Indexed by test sample below; each value provides the edit target's .subject / .object.\n",
    "test_targets = functional.random_edit_targets(test.samples)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70d83c9b",
   "metadata": {},
   "source": [
    "## setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a13c0ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single worked example: edit the first test sample toward its assigned target.\n",
    "source = test.samples[0]\n",
    "target = test_targets[source]\n",
    "\n",
    "f\"Changing the mapping ({source}) to ({source.subject} -> {target.object})\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e98f67c8",
   "metadata": {},
   "source": [
    "### Calculate $\\Delta \\mathbf{s}$ such that $\\mathbf{s} + \\Delta \\mathbf{s} \\approx \\mathbf{s}'$\n",
    "\n",
    "<p align=\"center\">\n",
    "    <img align=\"center\" src=\"../causality_crop\" style=\"width:80%;\"/>\n",
    "</p>\n",
    "\n",
    "Under the relation $r =\\, $*plays the instrument*, and given the subject $s =\\, $*Miles Davis*, the model will predict $o =\\, $*trumpet* **(a)**; and given the subject $s' =\\, $*Cat Stevens*, the model will predict $o' =\\, $*guitar* **(b)**. \n",
    "\n",
    "If the computation from $\\mathbf{s}$ to $\\mathbf{o}$ is well-approximated by the operator parameterized by $W_r$ and $b_r$ **(c)**, then $\\Delta{\\mathbf{s}}$ **(d)** should tell us the direction of change from $\\mathbf{s}$ to $\\mathbf{s}'$. Thus, $\\tilde{\\mathbf{s}}=\\mathbf{s}+\\Delta\\mathbf{s}$ would be an approximation of $\\mathbf{s}'$, and patching $\\tilde{\\mathbf{s}}$ in place of $\\mathbf{s}$ should change the prediction to $o'$ = *guitar*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c632ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_s(\n",
    "    operator, \n",
    "    source_subject, \n",
    "    target_subject,\n",
    "    rank = 100,\n",
    "    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target\n",
    "):\n",
    "    \"\"\"Estimate delta_s such that h(source) + delta_s approximates h(target).\n",
    "\n",
    "    Computes delta_s = pinv_rank(W) @ (z_target - z_source), where W is `operator.weight`,\n",
    "    pinv_rank is its rank-`rank` pseudo-inverse (functional.low_rank_pinv), and z_* are the\n",
    "    last-layer (z_layer=-1) latents of each subject under `operator.prompt_template`.\n",
    "\n",
    "    Args:\n",
    "        operator: fitted relation operator (uses .weight, .prompt_template, .h_layer).\n",
    "        source_subject: subject whose representation will be edited.\n",
    "        target_subject: subject whose object the edit should produce.\n",
    "        rank: rank of the low-rank pseudo-inverse of W.\n",
    "        fix_latent_norm: if not None, z_source and z_target are both rescaled to this norm\n",
    "            before taking their difference.\n",
    "\n",
    "    Returns:\n",
    "        (delta_s, hs_and_zs): the edit direction and the hs/zs computed for both subjects.\n",
    "        The latents inside hs_and_zs are left unmodified by the optional rescaling.\n",
    "\n",
    "    Note: uses the notebook-global `mt` for the forward passes.\n",
    "    \"\"\"\n",
    "    w_p_inv = functional.low_rank_pinv(\n",
    "        matrix = operator.weight,\n",
    "        rank=rank,\n",
    "    )\n",
    "    hs_and_zs = functional.compute_hs_and_zs(\n",
    "        mt = mt,\n",
    "        prompt_template = operator.prompt_template,\n",
    "        subjects = [source_subject, target_subject],\n",
    "        h_layer= operator.h_layer,\n",
    "        z_layer=-1,\n",
    "    )\n",
    "\n",
    "    z_source = hs_and_zs.z_by_subj[source_subject]\n",
    "    z_target = hs_and_zs.z_by_subj[target_subject]\n",
    "\n",
    "    if fix_latent_norm is not None:\n",
    "        # Rescale out of place so the latents stored in (and returned via) hs_and_zs\n",
    "        # are not silently overwritten.\n",
    "        z_source = z_source * (fix_latent_norm / z_source.norm())\n",
    "        z_target = z_target * (fix_latent_norm / z_target.norm())\n",
    "\n",
    "    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())\n",
    "\n",
    "    return delta_s, hs_and_zs\n",
    "\n",
    "delta_s, hs_and_zs = get_delta_s(\n",
    "    operator = operator,\n",
    "    source_subject = source.subject,\n",
    "    target_subject = target.subject,\n",
    "    rank = rank\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab1c7e88",
   "metadata": {},
   "outputs": [],
   "source": [
    "import baukit\n",
    "\n",
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    \"\"\"Build a baukit `edit_output` hook that overwrites one activation at `int_layer`.\n",
    "\n",
    "    The hook returns every other layer's output untouched. At `int_layer` it writes `h`\n",
    "    in place at index `subj_idx` along dim 1 (presumably the subject token position).\n",
    "    \"\"\"\n",
    "    def edit_output(output, layer):\n",
    "        if layer != int_layer:\n",
    "            return output\n",
    "        functional.untuple(output)[:, subj_idx] = h\n",
    "        return output\n",
    "    return edit_output\n",
    "\n",
    "prompt = operator.prompt_template.format(source.subject)\n",
    "\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=source.subject,\n",
    ")\n",
    "\n",
    "h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])\n",
    "\n",
    "# Inference only: no_grad avoids building an autograd graph that `outputs` would keep alive.\n",
    "with torch.no_grad(), baukit.TraceDict(\n",
    "    mt.model, layers = [h_layer, z_layer],\n",
    "    edit_output=get_intervention(\n",
    "        # h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual\n",
    "        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s\n",
    "        int_layer = h_layer,\n",
    "        subj_idx = h_index\n",
    "    )\n",
    ") as traces:\n",
    "    outputs = mt(\n",
    "        input_ids = inputs.input_ids,\n",
    "        attention_mask = inputs.attention_mask,\n",
    "    )\n",
    "\n",
    "logits = outputs.logits[0][-1] if hasattr(outputs, \"logits\") else outputs[0][-1]\n",
    "lens.interpret_logits(\n",
    "    mt = mt, \n",
    "    logits = logits, \n",
    "    get_proba=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3c272c1",
   "metadata": {},
   "source": [
    "## Measuring causality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51efa257",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.editors import LowRankPInvEditor\n",
    "\n",
    "# Precompute the SVD of W (cast to float32) and hand it to the low-rank pinv editor.\n",
    "svd = torch.svd(operator.weight.float())\n",
    "editor = LowRankPInvEditor(lre=operator, rank=rank, svd=svd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88be35dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# precomputing latents to speed things up\n",
    "# NOTE(review): `hs_and_zs` is not passed to `editor` below; confirm compute_hs_and_zs\n",
    "# caches internally, otherwise this precompute is unused work.\n",
    "hs_and_zs = functional.compute_hs_and_zs(\n",
    "    mt = mt,\n",
    "    prompt_template = operator.prompt_template,\n",
    "    subjects = [sample.subject for sample in test.samples],\n",
    "    h_layer= operator.h_layer,\n",
    "    z_layer=-1,\n",
    "    batch_size = 2\n",
    ")\n",
    "\n",
    "success = 0\n",
    "fails = 0\n",
    "\n",
    "# Loop target is named `edit_target` so the single-example `target` above is not clobbered.\n",
    "for sample in test.samples:\n",
    "    edit_target = test_targets.get(sample)\n",
    "    assert edit_target is not None, f\"no edit target for {sample}\"\n",
    "    edit_result = editor(\n",
    "        subject = sample.subject,\n",
    "        target = edit_target.subject\n",
    "    )\n",
    "\n",
    "    # Edit succeeds when the new top-1 token satisfies is_nontrivial_prefix w.r.t. the target object.\n",
    "    success_flag = functional.is_nontrivial_prefix(\n",
    "        prediction=edit_result.predicted_tokens[0].token, target=edit_target.object\n",
    "    )\n",
    "\n",
    "    print(f\"Mapping {sample.subject} -> {edit_target.object} | edit result={edit_result.predicted_tokens[0]} | success=({functional.get_tick_marker(success_flag)})\")\n",
    "\n",
    "    success += success_flag\n",
    "    fails += not success_flag\n",
    "\n",
    "n_edits = success + fails\n",
    "# Guard against an empty test set.\n",
    "causality = success / n_edits if n_edits > 0 else float(\"nan\")\n",
    "\n",
    "print(\"------------------------------------------------------------\")\n",
    "print(f\"Causality (@1) = {causality}\")\n",
    "print(\"------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d14acef6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.functional import save_linear_operator\n",
    "\n",
    "# save_linear_operator(\n",
    "#     approx = operator,\n",
    "#     file_name = \"lre_capital\",\n",
    "#     path = \"cached\"\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1cbbdde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# operator_loaded = functional.load_cached_linear_operator(mt = mt, file_path = \"cached/lre_capital.npz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01a25b74",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
