{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import json\n",
    "import sys\n",
    "import numpy as np\n",
    "sys.path.append(\"..\")\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models, data, operators, editors, functional, metrics, lens\n",
    "from src.utils import logging_utils, experiment_utils\n",
    "import logging\n",
    "import torch\n",
    "import baukit\n",
    "\n",
    "logging_utils.configure(level=logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mt = models.load_model(\"gptj\", fp16=True, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################\n",
    "layer = 3\n",
    "rank = 170\n",
    "beta = 2.25\n",
    "n_train = 5\n",
    "selected_relations = [r for r in dataset if r.name in [\n",
    "        \"person occupation\",\n",
    "        \"name birthplace\",\n",
    "        \"person university\"\n",
    "    ]\n",
    "]\n",
    "\n",
    "experiment_utils.set_seed(123456)\n",
    "##################################\n",
    "\n",
    "\n",
    "relation_properties = {}\n",
    "\n",
    "for relation in selected_relations:\n",
    "    train, test = relation.split(n_train)\n",
    "    prompt_template = relation.prompt_templates[0]\n",
    "\n",
    "    relation_prompt = functional.make_prompt(\n",
    "        mt=mt,\n",
    "        prompt_template=prompt_template,\n",
    "        subject=\"{}\",\n",
    "        examples=train.samples,\n",
    "    )\n",
    "\n",
    "    estimator = operators.JacobianIclMeanEstimator(\n",
    "        mt = mt, h_layer=layer, beta=beta, rank=rank\n",
    "    )\n",
    "    operator = estimator(train)\n",
    "\n",
    "    relation_properties[relation.name] = {\n",
    "        \"train\": train,\n",
    "        \"prompt_template\": prompt_template,\n",
    "        \"prompt\": relation_prompt,\n",
    "        \"operator\": operator,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for relation_name in relation_properties:\n",
    "    print(\"-----------------------------------\")\n",
    "    print(relation_name)\n",
    "    print(\"-----------------------------------\")\n",
    "    print(f\"{[sample.__str__() for sample in relation_properties[relation_name]['train'].samples]}\")\n",
    "    print(relation_properties[relation_name]['prompt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for relation_name in relation_properties:\n",
    "    relation_properties[relation_name][\"W_inv\"] = functional.low_rank_pinv(\n",
    "        matrix = relation_properties[relation_name][\"operator\"].weight,\n",
    "        rank=rank,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##################################################\n",
    "source_subject = \"X\"\n",
    "targ_prop_for_subj = {\n",
    "    \"person occupation\": \"Sherlock Holmes\",\n",
    "    \"name birthplace\": \"Jackie Chan\",\n",
    "    \"person university\": \"Bill Gates\",  \n",
    "}\n",
    "##################################################\n",
    "\n",
    "for prop in targ_prop_for_subj:\n",
    "    prompt = relation_properties[prop]['prompt'].format(source_subject)\n",
    "    obj = functional.predict_next_token(mt = mt, prompt = prompt, k=3)[0]\n",
    "    print(f\"{source_subject} -- {prop} -- {obj[0].__str__()}\")\n",
    "print(\"=================================\")\n",
    "\n",
    "for prop, subj in targ_prop_for_subj.items():\n",
    "    prompt = relation_properties[prop]['prompt'].format(subj)\n",
    "    obj = functional.predict_next_token(mt = mt, prompt = prompt, k=3)[0]\n",
    "    print(f\"{subj} -- {prop} -- {obj[0].__str__()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_delta_s(\n",
    "    prop, \n",
    "    source_subject, \n",
    "    target_subject,\n",
    "    fix_latent_norm = None,\n",
    "):\n",
    "    w_p_inv = relation_properties[prop][\"W_inv\"]\n",
    "    hs_and_zs = functional.compute_hs_and_zs(\n",
    "        mt = mt,\n",
    "        prompt_template = relation_properties[prop][\"prompt_template\"],\n",
    "        subjects = [source_subject, target_subject],\n",
    "        h_layer= layer,\n",
    "        z_layer=-1,\n",
    "        examples= relation_properties[prop][\"train\"].samples,\n",
    "    )\n",
    "\n",
    "    z_source = hs_and_zs.z_by_subj[source_subject]\n",
    "    z_target = hs_and_zs.z_by_subj[targ_prop_for_subj[prop]]\n",
    "    # print(z_target.norm().item(), z_source.norm().item())\n",
    "\n",
    "    h_source = hs_and_zs.h_by_subj[source_subject]\n",
    "    h_target = hs_and_zs.h_by_subj[targ_prop_for_subj[prop]]\n",
    "\n",
    "    z_source *= fix_latent_norm / z_source.norm() if fix_latent_norm is not None else 1.0\n",
    "    z_target *= z_source.norm() / z_target.norm() if fix_latent_norm is not None else 1.0\n",
    "    print(z_target.norm().item(), z_source.norm().item())\n",
    "\n",
    "    delta_s = w_p_inv @  (z_target.squeeze() - z_source.squeeze())\n",
    "    \n",
    "    print(f\"h_source: {h_source.norm().item()} | h_target: {h_target.norm().item()}\")\n",
    "    print(f\"inv_h_source: {(w_p_inv @ z_source).norm().item()} | inv_h_target: {(w_p_inv @ z_target).norm().item()}\")\n",
    "    print(delta_s.norm().item())\n",
    "\n",
    "    return delta_s, hs_and_zs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_s_by_prop = {}\n",
    "for relation_name in targ_prop_for_subj:\n",
    "    delta_s, hs_and_zs = get_delta_s(\n",
    "        prop = relation_name,\n",
    "        source_subject = source_subject,\n",
    "        target_subject = targ_prop_for_subj[relation_name],\n",
    "        # fix_latent_norm=250\n",
    "    )\n",
    "\n",
    "    delta_s_by_prop[relation_name] = {\n",
    "        \"delta_s\": delta_s,\n",
    "        \"hs_and_zs\": hs_and_zs,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "drr = [prop[\"delta_s\"] for relation, prop in delta_s_by_prop.items()]\n",
    "for i in range(len(drr)):\n",
    "    for j in range(len(drr)):\n",
    "        print(f\"{i} -- {j} -- {(drr[i][None] @ drr[j][None].T).squeeze().item()}, {torch.cosine_similarity(drr[i], drr[j], dim=0).item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_norm = np.array([delta_s_by_prop[relation_name][\"delta_s\"].norm().item() for relation_name in delta_s_by_prop.keys()]).max()\n",
    "\n",
    "cumulative_delta_s = torch.zeros_like(delta_s_by_prop[prop][\"delta_s\"])\n",
    "for relation_name in delta_s_by_prop:\n",
    "    ds = delta_s_by_prop[relation_name][\"delta_s\"]\n",
    "    ds = ds*max_norm / ds.norm()\n",
    "    delta_s_by_prop[relation_name][\"delta_s\"] = ds\n",
    "    cumulative_delta_s += ds\n",
    "cumulative_delta_s /= 3\n",
    "cumulative_delta_s.shape, cumulative_delta_s.norm().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# relation_names = [\n",
    "#         \"person occupation\",\n",
    "#         \"name birthplace\",\n",
    "#         \"person university\"\n",
    "# ]\n",
    "\n",
    "# cumulative_delta_s = (\n",
    "#     delta_s_by_prop[relation_names[0]][\"delta_s\"] + \n",
    "#     delta_s_by_prop[relation_names[1]][\"delta_s\"] + \n",
    "#     delta_s_by_prop[relation_names[2]][\"delta_s\"]\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prop = \"person occupation\"\n",
    "# prop = \"name birthplace\"\n",
    "prop = \"person university\"\n",
    "\n",
    "delta_s, hs_and_zs = get_delta_s(\n",
    "    prop = prop,\n",
    "    source_subject = source_subject,\n",
    "    target_subject = targ_prop_for_subj[prop],\n",
    "    fix_latent_norm = 250\n",
    ")\n",
    "\n",
    "def get_intervention(h, int_layer, subj_idx):\n",
    "    def edit_output(output, layer):\n",
    "        if(layer != int_layer):\n",
    "            return output\n",
    "        functional.untuple(output)[:, subj_idx] = h\n",
    "        return output\n",
    "    return edit_output\n",
    "\n",
    "prompt = relation_properties[prop][\"prompt\"].format(source_subject)\n",
    "\n",
    "h_index, inputs = functional.find_subject_token_index(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    subject=source_subject,\n",
    ")\n",
    "\n",
    "h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])\n",
    "\n",
    "with baukit.TraceDict(\n",
    "    mt.model, layers = [h_layer, z_layer],\n",
    "    edit_output=get_intervention(\n",
    "        # h = hs_and_zs.h_by_subj[source_subject]\n",
    "        # h = hs_and_zs.h_by_subj[source_subject] + delta_s,\n",
    "        h = hs_and_zs.h_by_subj[source_subject] + cumulative_delta_s, \n",
    "        int_layer = h_layer, \n",
    "        subj_idx = h_index\n",
    "    )\n",
    ") as traces:\n",
    "    outputs = mt.model(\n",
    "        input_ids = inputs.input_ids,\n",
    "        attention_mask = inputs.attention_mask,\n",
    "    )\n",
    "\n",
    "lens.interpret_logits(\n",
    "    mt = mt, \n",
    "    logits = outputs.logits[0][-1], \n",
    "    # get_proba=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
