{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from src import models, data\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import copy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the \"gptj\" model with fp16 enabled on the CUDA device; `mt` is reused by later cells.\n",
    "mt = models.load_model(\n",
    "    \"gptj\",\n",
    "    fp16=True,\n",
    "    device=\"cuda\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Parameter: name of the relation to inspect (used to filter the dataset and the sweep results) ===\n",
    "relation_name = \"plays pro sport\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset, then keep the first relation returned when filtering by `relation_name`.\n",
    "dataset = data.load_dataset()\n",
    "matching_relations = dataset.filter(relation_names=[relation_name])\n",
    "relation = matching_relations[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the sweep results stored under the gptj results directory, restricted to `relation_name`.\n",
    "sweep_dict = read_sweep_results(\n",
    "    \"../results/sweep-24-trials/gptj\",\n",
    "    relation_names=[relation_name],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the result object for this relation from its entry in the sweep dict;\n",
    "# later cells read `.trials` from it.\n",
    "relation_sweep_entry = sweep_dict[relation_name]\n",
    "relation_result = relation_from_dict(relation_sweep_entry)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trial indices available for this relation.\n",
    "trial_options = [*range(len(relation_result.trials))]\n",
    "print(f\"{trial_options=}\")\n",
    "\n",
    "# Layers and ranks are read from the first trial (and its first layer) only.\n",
    "layer_options = [entry.layer for entry in relation_result.trials[0].layers]\n",
    "print(f\"{layer_options=}\")\n",
    "\n",
    "rank_options = [entry.rank for entry in relation_result.trials[0].layers[0].result.ranks]\n",
    "print(f\"{rank_options=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the attribute dict of whatever `best_by_efficacy()` returns,\n",
    "# as a guide for choosing TRIAL_NO / LAYER / RANK below.\n",
    "vars(relation_result.best_by_efficacy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Parameters: which sweep entry to inspect ===\n",
    "# Choose values from trial_options / layer_options / rank_options printed above.\n",
    "TRIAL_NO = 10\n",
    "LAYER = 27\n",
    "RANK = 192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up the sweep entry selected by (TRIAL_NO, LAYER, RANK).\n",
    "# A value that is not part of the sweep still raises IndexError (same type as before),\n",
    "# but with a message naming the offending parameter instead of \"list index out of range\".\n",
    "if not -len(relation_result.trials) <= TRIAL_NO < len(relation_result.trials):\n",
    "    raise IndexError(\n",
    "        f\"TRIAL_NO={TRIAL_NO} is out of range: {len(relation_result.trials)} trials available\"\n",
    "    )\n",
    "\n",
    "# next() stops at the first match instead of materialising a list of all matches.\n",
    "layer_result = next(\n",
    "    (entry for entry in relation_result.trials[TRIAL_NO].layers if entry.layer == LAYER),\n",
    "    None,\n",
    ")\n",
    "if layer_result is None:\n",
    "    raise IndexError(f\"LAYER={LAYER} not found in trial {TRIAL_NO}\")\n",
    "\n",
    "rank_result = next(\n",
    "    (entry for entry in layer_result.result.ranks if entry.rank == RANK),\n",
    "    None,\n",
    ")\n",
    "if rank_result is None:\n",
    "    raise IndexError(f\"RANK={RANK} not found for layer {LAYER} in trial {TRIAL_NO}\")\n",
    "\n",
    "rank_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the successful edits for this rank by the subject of their target sample.\n",
    "efficacy_successes = {edit.target.subject: edit for edit in rank_result.efficacy_successes}\n",
    "\n",
    "# Remove every edit whose target subject shows up in the faithfulness successes of some beta,\n",
    "# reporting the beta where it was found. Entries left afterwards had no such match.\n",
    "for beta_result in layer_result.result.betas:\n",
    "    for sample in beta_result.faithfulness_successes:\n",
    "        if sample.subject not in efficacy_successes:\n",
    "            continue\n",
    "        matched_edit = efficacy_successes[sample.subject]\n",
    "        print(f\"Edit: {matched_edit.source} <to> {matched_edit.target} -- found in beta: {beta_result.beta}\")\n",
    "        del efficacy_successes[sample.subject]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the edits still present in `efficacy_successes`.\n",
    "print(\"No target match found in faithfulness successes for the following:\")\n",
    "for unmatched_edit in efficacy_successes.values():\n",
    "    print(f\"Edit: {unmatched_edit.source} <to> {unmatched_edit.target}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Samples stored with the selected layer's sweep result; passed to the estimator further down.\n",
    "train_samples = layer_result.result.samples\n",
    "train_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt template stored with the selected trial; passed to the estimator further down.\n",
    "selected_trial = relation_result.trials[TRIAL_NO]\n",
    "prompt_template = selected_trial.prompt_template\n",
    "prompt_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import functional, operators, editors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the operator at layer LAYER, using the samples and prompt template\n",
    "# taken from the selected sweep trial.\n",
    "estimator = operators.JacobianIclMeanEstimator(mt=mt, h_layer=LAYER)\n",
    "\n",
    "relation_for_trial = relation.set(\n",
    "    samples=train_samples,\n",
    "    prompt_templates=[prompt_template],\n",
    ")\n",
    "operator = estimator(relation_for_trial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Parameter: subject to probe ===\n",
    "# It is later used as a key into `efficacy_successes`, which is keyed by target subject.\n",
    "# NOTE(review): \"Hungary\" looks unrelated to the relation \"plays pro sport\" -- confirm it is\n",
    "# one of the keys left in `efficacy_successes` for this relation.\n",
    "subject = \"Hungary\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the model directly with the operator's prompt template filled in for `subject`.\n",
    "# Expected outcome (author's note): model predicts correctly.\n",
    "prompt = operator.prompt_template.format(subject)\n",
    "functional.predict_next_token(mt=mt, prompt=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the estimated operator on the same subject.\n",
    "# Expected outcome (author's note): LRE fails (low faithfulness).\n",
    "operator(\n",
    "    subject=subject,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the SVD of the operator weight (cast to float first) once and hand it to the editor,\n",
    "# together with the rank chosen from the sweep.\n",
    "# NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd, which returns Vh instead of V;\n",
    "# confirm which form LowRankPInvEditor expects before switching.\n",
    "svd = torch.svd(operator.weight.float())\n",
    "editor = editors.LowRankPInvEditor(\n",
    "    lre=operator,\n",
    "    rank=rank_result.rank,\n",
    "    n_samples=1,\n",
    "    n_new_tokens=1,\n",
    "    svd=svd,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `efficacy_successes` is keyed by target subject and only holds the edits that were not\n",
    "# removed by the faithfulness loop, so `subject` may legitimately be missing.\n",
    "# Check first and list the usable keys rather than failing with a bare KeyError.\n",
    "if subject not in efficacy_successes:\n",
    "    raise KeyError(\n",
    "        f\"No remaining efficacy success targets {subject!r}; available targets: {list(efficacy_successes)}\"\n",
    "    )\n",
    "\n",
    "efficacy_test_pair = efficacy_successes[subject]\n",
    "f\"Editing: {efficacy_test_pair.source} <to> {efficacy_test_pair.target}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the editor from the pair's source subject towards its target subject.\n",
    "# Expected outcome (author's note): editing succeeds (high efficacy).\n",
    "source_subject = efficacy_test_pair.source.subject\n",
    "target_subject = efficacy_test_pair.target.subject\n",
    "editor(subject=source_subject, target=target_subject)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
