{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e59be83",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Put the directory two levels up on the import path; the following cells import `src`,\n",
    "# presumably from there (assumes the kernel's working directory is this notebook's folder).\n",
    "sys.path.append(\"../..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12291ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import models\n",
    "\n",
    "# NOTE(review): the GPU index is hard-coded and is also read by `compute_zs` further down;\n",
    "# point it at a device that exists on the machine running this notebook.\n",
    "device = \"cuda:5\"\n",
    "mt = models.load_model(\"gptj\", device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b09ebae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import data\n",
    "\n",
    "# Load the project dataset once; later cells look relations up by name via `dataset.filter`.\n",
    "dataset = data.load_dataset()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7832a077",
   "metadata": {},
   "source": [
    "# Specificity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57afad25",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import cache\n",
    "\n",
    "from src import editors, functional, hparams, operators\n",
    "from src.utils import experiment_utils\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "\n",
    "# `sns.set` is an alias of `sns.set_theme`, so a single call carrying every option is enough.\n",
    "sns.set_theme(style=\"white\", palette=\"bright\", font=\"Serif\")\n",
    "\n",
    "\n",
    "# Size passed to `relation.split(...)` when building the training samples for an operator.\n",
    "N_TRAIN = 5\n",
    "\n",
    "\n",
    "def require_sample(subject, relation):\n",
    "    matches = [x for x in relation.samples if x.subject == subject]\n",
    "    assert len(matches) >= 1, matches\n",
    "    return matches[0]\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def compute_zs(prompt_template, subj, targ, layer=27):\n",
    "    \"\"\"Compute last-position hidden states for two subjects under one prompt template.\n",
    "\n",
    "    Both prompts are tokenized as one left-padded batch, and the hidden state at\n",
    "    index -1 of each row is returned. Uses the notebook-level `mt` and `device`.\n",
    "\n",
    "    Args:\n",
    "        prompt_template: Format string with one `{}` slot for the subject.\n",
    "        subj: Original subject; fills the template for batch row 0.\n",
    "        targ: Target subject; fills the template for batch row 1.\n",
    "        layer: Layer whose hidden states are read. Defaults to 27, the value\n",
    "            that used to be fixed inside this function.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(z_subj, z_targ)`: the last-position hidden state of row 0 and row 1.\n",
    "    \"\"\"\n",
    "    prompts = [prompt_template.format(subj), prompt_template.format(targ)]\n",
    "    # Left padding keeps the final prompt token of both rows at index -1.\n",
    "    with models.set_padding_side(mt, padding_side=\"left\"):\n",
    "        inputs = mt.tokenizer(\n",
    "            prompts,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=\"longest\",\n",
    "        ).to(device)\n",
    "    # A single layer is requested, and the unpacking expects exactly one hidden-state entry back.\n",
    "    [[hs], _] = functional.compute_hidden_states(\n",
    "        mt=mt,\n",
    "        layers=[layer],\n",
    "        inputs=inputs,\n",
    "    )\n",
    "    return hs[0, -1], hs[1, -1]\n",
    "\n",
    "\n",
    "def _resolve_sample_pair(subject, relation_edit, relation_ref):\n",
    "    \"\"\"Return `(subject, edit_sample, ref_sample)` for a subject given by name or as a sample pair.\"\"\"\n",
    "    if isinstance(subject, tuple):\n",
    "        # The caller supplied the (edit, ref) samples directly; both must describe one subject.\n",
    "        sample_edit, sample_ref = subject\n",
    "        assert sample_edit.subject == sample_ref.subject, (\n",
    "            \"edit/ref samples name different subjects: \"\n",
    "            f\"{sample_edit.subject!r} != {sample_ref.subject!r}\"\n",
    "        )\n",
    "        return sample_edit.subject, sample_edit, sample_ref\n",
    "    sample_edit = require_sample(subject, relation_edit)\n",
    "    sample_ref = require_sample(subject, relation_ref)\n",
    "    return subject, sample_edit, sample_ref\n",
    "\n",
    "\n",
    "def sweep_specificity(\n",
    "    relation_edit,\n",
    "    relation_ref,\n",
    "    subject_orig,\n",
    "    subject_targ,\n",
    "    pt_zs_edit=None,\n",
    "    pt_zs_ref=None,\n",
    "    ranks=None,\n",
    "):\n",
    "    \"\"\"Sweep the editor rank and record how one edit moves two relations of a subject.\n",
    "\n",
    "    An operator is estimated from the `split(N_TRAIN)` training portion of the edited\n",
    "    relation, after removing the samples of both subjects. For every rank a\n",
    "    `LowRankPInvEditor` is run twice, once with the edited relation's prompt and once\n",
    "    with the reference relation's prompt, and the probability of the first token of\n",
    "    each expected object (prefixed with a space) is recorded.\n",
    "\n",
    "    Args:\n",
    "        relation_edit: Name of the relation being edited.\n",
    "        relation_ref: Name of the relation used as the reference.\n",
    "        subject_orig: Subject looked up in both relations, or an\n",
    "            `(edit_sample, ref_sample)` tuple whose subjects must match.\n",
    "        subject_targ: Same as `subject_orig`, for the target of the edit.\n",
    "        pt_zs_edit: Prompt template for the edited relation; defaults to the\n",
    "            relation's first `prompt_templates_zs` entry.\n",
    "        pt_zs_ref: Prompt template for the reference relation; same default rule.\n",
    "        ranks: Iterable of ranks to try; defaults to `range(0, 100, 2)`.\n",
    "\n",
    "    Returns:\n",
    "        The 10-tuple `(ranks, ys_orig_edit, ys_orig_ref, ys_targ_edit, ys_targ_ref,\n",
    "        relation_edit, sample_orig_edit, sample_orig_ref, sample_targ_edit,\n",
    "        sample_targ_ref)` that `plot` unpacks.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If a sample pair names two different subjects, or a\n",
    "            subject given by name is missing from a relation.\n",
    "    \"\"\"\n",
    "    experiment_utils.set_seed(12345)\n",
    "    if ranks is None:\n",
    "        ranks = range(0, 100, 2)\n",
    "\n",
    "    # From here on the two names hold relation objects rather than relation names.\n",
    "    relation_edit = dataset.filter(relation_names=[relation_edit])[0]\n",
    "    relation_ref = dataset.filter(relation_names=[relation_ref])[0]\n",
    "\n",
    "    subject_orig, sample_orig_edit, sample_orig_ref = _resolve_sample_pair(\n",
    "        subject_orig, relation_edit, relation_ref\n",
    "    )\n",
    "    subject_targ, sample_targ_edit, sample_targ_ref = _resolve_sample_pair(\n",
    "        subject_targ, relation_edit, relation_ref\n",
    "    )\n",
    "\n",
    "    print(f\"{sample_orig_edit=}, {sample_orig_ref=}, {sample_targ_edit=}, {sample_targ_ref=}\")\n",
    "\n",
    "    # Keep both subjects out of the samples used to estimate the operator.\n",
    "    train, _ = relation_edit.without(sample_orig_edit).without(sample_targ_edit).split(N_TRAIN)\n",
    "\n",
    "    relation_hparams = hparams.get(mt, relation_edit)\n",
    "    estimator = operators.JacobianIclMeanEstimator(\n",
    "        mt=mt,\n",
    "        h_layer=relation_hparams.h_layer,\n",
    "        z_layer=relation_hparams.z_layer,\n",
    "        beta=relation_hparams.beta,\n",
    "    )\n",
    "\n",
    "    print(\"estimating LRE...\")\n",
    "    operator = estimator(train)\n",
    "\n",
    "    if pt_zs_edit is None:\n",
    "        pt_zs_edit = relation_edit.prompt_templates_zs[0]\n",
    "    print(\"prompt template being edtied: \" + pt_zs_edit)\n",
    "\n",
    "    if pt_zs_ref is None:\n",
    "        pt_zs_ref = relation_ref.prompt_templates_zs[0]\n",
    "    print(\"prompt template being referenced: \" + pt_zs_ref)\n",
    "\n",
    "    print(\"precomputing zs...\")\n",
    "    # NOTE(review): compute_zs picks its own layer rather than relation_hparams.z_layer;\n",
    "    # confirm the two agree.\n",
    "    z_subj, z_targ = compute_zs(\n",
    "        pt_zs_edit,\n",
    "        subject_orig,\n",
    "        subject_targ,\n",
    "    )\n",
    "\n",
    "    # The tracked token ids do not depend on the rank, so look them up once.\n",
    "    def first_token_id(text):\n",
    "        return mt.tokenizer.encode(\" \" + text)[0]\n",
    "\n",
    "    tok_orig_edit = first_token_id(sample_orig_edit.object)\n",
    "    tok_targ_edit = first_token_id(sample_targ_edit.object)\n",
    "    tok_orig_ref = first_token_id(sample_orig_ref.object)\n",
    "    tok_targ_ref = first_token_id(sample_targ_ref.object)\n",
    "\n",
    "    print(\"begin sweep over rank...\")\n",
    "    ys_orig_edit = []\n",
    "    ys_orig_ref = []\n",
    "    ys_targ_edit = []\n",
    "    ys_targ_ref = []\n",
    "    for rank in ranks:\n",
    "        editor = editors.LowRankPInvEditor(lre=operator, rank=rank, n_new_tokens=3)\n",
    "\n",
    "        # Swap the operator's prompt template in place so the same edit is applied\n",
    "        # under the edited prompt first and under the reference prompt second.\n",
    "        object.__setattr__(editor.lre, \"prompt_template\", pt_zs_edit)\n",
    "        results_edit = editor(subject_orig, subject_targ, z_original=z_subj, z_target=z_targ)\n",
    "        # dim=0 assumes `model_logits` is a 1-D vector over the vocabulary (TODO confirm).\n",
    "        probs_edit = results_edit.model_logits.float().softmax(dim=0)\n",
    "        best_edit = [str(tok) for tok in results_edit.predicted_tokens][0]\n",
    "\n",
    "        object.__setattr__(editor.lre, \"prompt_template\", pt_zs_ref)\n",
    "        results_ref = editor(subject_orig, subject_targ, z_original=z_subj, z_target=z_targ)\n",
    "        probs_ref = results_ref.model_logits.float().softmax(dim=0)\n",
    "        best_ref = [str(tok) for tok in results_ref.predicted_tokens][0]\n",
    "\n",
    "        print(\n",
    "            rank,\n",
    "            best_edit,\n",
    "            best_ref,\n",
    "            sep=\"\\t\",\n",
    "        )\n",
    "\n",
    "        ys_orig_edit.append(probs_edit[tok_orig_edit].item())\n",
    "        ys_targ_edit.append(probs_edit[tok_targ_edit].item())\n",
    "        ys_orig_ref.append(probs_ref[tok_orig_ref].item())\n",
    "        ys_targ_ref.append(probs_ref[tok_targ_ref].item())\n",
    "\n",
    "    return (\n",
    "        ranks,\n",
    "        ys_orig_edit,\n",
    "        ys_orig_ref,\n",
    "        ys_targ_edit,\n",
    "        ys_targ_ref,\n",
    "        relation_edit,\n",
    "        sample_orig_edit,\n",
    "        sample_orig_ref,\n",
    "        sample_targ_edit,\n",
    "        sample_targ_ref,\n",
    "    )\n",
    "\n",
    "\n",
    "def plot(results):\n",
    "    \"\"\"Plot the probabilities recorded by `sweep_specificity` against rank.\n",
    "\n",
    "    Solid lines come from the edited relation's prompt and dashed lines from the\n",
    "    reference prompt; light blue is the original subject's object and dark blue the\n",
    "    target subject's object. Drawing goes through the pyplot state machine.\n",
    "\n",
    "    Args:\n",
    "        results: The 10-tuple returned by `sweep_specificity`.\n",
    "    \"\"\"\n",
    "    (\n",
    "        ranks,\n",
    "        p_orig_edit,\n",
    "        p_orig_ref,\n",
    "        p_targ_edit,\n",
    "        p_targ_ref,\n",
    "        relation,\n",
    "        s_orig_edit,\n",
    "        s_orig_ref,\n",
    "        s_targ_edit,\n",
    "        s_targ_ref,\n",
    "    ) = results\n",
    "\n",
    "    plt.title(f'Change \"{relation.name}\" of \"{s_orig_edit.subject}\" -> \"{s_targ_edit.subject}\"')\n",
    "\n",
    "    # (values, sample, color, extra line options), listed in legend order.\n",
    "    curves = (\n",
    "        (p_orig_edit, s_orig_edit, \"deepskyblue\", {}),\n",
    "        (p_targ_edit, s_targ_edit, \"darkblue\", {}),\n",
    "        (p_orig_ref, s_orig_ref, \"deepskyblue\", {\"linestyle\": \"dashed\"}),\n",
    "        (p_targ_ref, s_targ_ref, \"darkblue\", {\"linestyle\": \"dashed\"}),\n",
    "    )\n",
    "    for values, sample, color, extra in curves:\n",
    "        plt.plot(ranks, values, label=f\"p({sample.object})\", color=color, linewidth=2, **extra)\n",
    "\n",
    "    plt.xlabel(\"Rank\")\n",
    "    plt.ylabel(\"LM Probability\")\n",
    "    plt.yticks(np.linspace(0, 1, 11))\n",
    "    plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9398ee5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edit the capital-city relation from the United States towards China, with the\n",
    "# largest-city relation as the reference; both prompt templates are given explicitly.\n",
    "relation_edit, relation_ref = \"country capital city\", \"country largest city\"\n",
    "subj_orig, subj_targ = \"United States\", \"China\"\n",
    "results = sweep_specificity(\n",
    "    relation_edit=relation_edit,\n",
    "    relation_ref=relation_ref,\n",
    "    subject_orig=subj_orig,\n",
    "    subject_targ=subj_targ,\n",
    "    pt_zs_edit=\"{}'s capital city,\",\n",
    "    pt_zs_ref=\"{}'s largest city,\",\n",
    "    ranks=range(100, 200, 5),\n",
    ")\n",
    "plot(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0ecb5c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both subjects are passed as explicit (edit sample, reference sample) pairs\n",
    "# instead of being looked up in the dataset by name.\n",
    "relation_edit, relation_ref = \"word first letter\", \"word sentiment\"\n",
    "subj_orig = data.RelationSample(\"Horror\", \"H\"), data.RelationSample(\"Horror\", \"negative\")\n",
    "subj_targ = data.RelationSample(\"Joy\", \"J\"), data.RelationSample(\"Joy\", \"positive\")\n",
    "results = sweep_specificity(\n",
    "    relation_edit=relation_edit,\n",
    "    relation_ref=relation_ref,\n",
    "    subject_orig=subj_orig,\n",
    "    subject_targ=subj_targ,\n",
    "    ranks=range(0, 100, 5),\n",
    ")\n",
    "plot(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "003f1ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_edit = \"plays instrument\"\n",
    "relation_ref = \"person native language\"\n",
    "# The two samples of each pair must name the same subject: `sweep_specificity`\n",
    "# asserts it, so a misspelled or empty subject stops the cell before any work is done.\n",
    "subj_orig = (\n",
    "    data.RelationSample(\"Eric Clapton\", \"guitar\"),\n",
    "    data.RelationSample(\"Eric Clapton\", \"English\"),\n",
    ")\n",
    "# NOTE(review): \"soccer\" is not an instrument; confirm it is the intended target object.\n",
    "subj_targ = (\n",
    "    data.RelationSample(\"Lionel Messi\", \"soccer\"),\n",
    "    data.RelationSample(\"Lionel Messi\", \"Spanish\"),\n",
    ")\n",
    "results = sweep_specificity(\n",
    "    relation_edit,\n",
    "    relation_ref,\n",
    "    subj_orig,\n",
    "    subj_targ,\n",
    "    pt_zs_ref=\"{}, whose first language was\",\n",
    "    ranks=range(0, 150, 10),\n",
    ")\n",
    "plot(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e333957",
   "metadata": {},
   "outputs": [],
   "source": [
    "def determine_subject_overlap(r1_name, r2_name):\n",
    "    \"\"\"Print the set of subjects that occur in both named relations of `dataset`.\"\"\"\n",
    "    subject_sets = [\n",
    "        {sample.subject for sample in dataset.filter(relation_names=[name])[0].samples}\n",
    "        for name in (r1_name, r2_name)\n",
    "    ]\n",
    "    first, second = subject_sets\n",
    "    print(first & second)\n",
    "\n",
    "determine_subject_overlap(\"word first letter\", \"word sentiment\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2908c0ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-draw the figure for whichever `results` the kernel currently holds.\n",
    "plot(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "758efeb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "relation = dataset.filter(relation_names=[\"company CEO\"])[0]\n",
    "\n",
    "train, test = relation.split(N_TRAIN)\n",
    "\n",
    "# Look the hyperparameters up for this cell's relation object, not for the\n",
    "# `relation_edit` name string that the sweep cells above leave in the kernel.\n",
    "relation_hparams = hparams.get(mt, relation)\n",
    "estimator = operators.JacobianIclMeanEstimator(\n",
    "    mt=mt,\n",
    "    h_layer=relation_hparams.h_layer,\n",
    "    z_layer=relation_hparams.z_layer,\n",
    "    beta=relation_hparams.beta,\n",
    ")\n",
    "\n",
    "print(\"estimating LRE...\")\n",
    "operator = estimator(train)\n",
    "\n",
    "# Show the expected object next to the operator's first two predictions for each held-out sample.\n",
    "for sample in test.samples:\n",
    "    predictions = operator(sample.subject).predictions\n",
    "    print(sample.subject, sample.object, predictions[0], predictions[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf8b3773",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37b17d7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
