{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from src import models, data\n",
    "from tqdm.auto import tqdm\n",
    "import json\n",
    "import os\n",
    "import numpy as np\n",
    "import copy\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mt = models.load_model(\"gptj\", fp16=True, device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################\n",
    "relation_name = \"person occupation\"\n",
    "#####################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "relation = dataset.filter(\n",
    "    relation_names = [relation_name]\n",
    ")[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "sweep_dict = read_sweep_results(\n",
    "    \"../results/sweep-24-trials/gptj\", \n",
    "    relation_names=[relation_name], \n",
    "    economy=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_result = relation_from_dict(sweep_dict[relation_name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trial_options=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]\n",
      "layer_options=['emb', 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]\n",
      "rank_options=[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312]\n"
     ]
    }
   ],
   "source": [
    "trial_options = list(range(len(relation_result.trials)))\n",
    "print(f\"{trial_options=}\")\n",
    "\n",
    "layer_options = [layer.layer for layer in relation_result.trials[0].layers]\n",
    "print(f\"{layer_options=}\")\n",
    "\n",
    "rank_options = [rank.rank for rank in relation_result.trials[0].layers[0].result.ranks]\n",
    "print(f\"{rank_options=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'layer': 8,\n",
       " 'beta': AggregateMetric(mean=2.25, stdev=0.0, stderr=0.0, values=[2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25, 2.25]),\n",
       " 'recall': AggregateMetric(mean=0.4944704266017017, stdev=0.07817291710101862, stderr=0.015956979883532067, values=[0.47019867549668876, 0.6115107913669064, 0.39634146341463417, 0.5294117647058824, 0.4968152866242038, 0.5032258064516129, 0.45454545454545453, 0.5172413793103449, 0.5130434782608696, 0.5779816513761468, 0.4470588235294118, 0.46987951807228917, 0.43157894736842106, 0.49645390070921985, 0.6238532110091743, 0.6517857142857143, 0.3763440860215054, 0.38345864661654133, 0.48936170212765956, 0.5826086956521739, 0.4460431654676259, 0.34554973821989526, 0.5656565656565656, 0.4873417721518987]),\n",
       " 'rank': AggregateMetric(mean=131.66666666666666, stdev=55.520766885513716, stderr=11.333129083126817, values=[168, 48, 152, 120, 152, 72, 104, 96, 80, 112, 96, 64, 72, 168, 88, 168, 120, 176, 120, 168, 136, 312, 200, 168]),\n",
       " 'efficacy': AggregateMetric(mean=0.6557305419443131, stdev=0.05759712523278142, stderr=0.011756963955958022, values=[0.6887417218543046, 0.6474820143884892, 0.6524390243902439, 0.6554621848739496, 0.5987261146496815, 0.6258064516129033, 0.6363636363636364, 0.6724137931034483, 0.6347826086956522, 0.7339449541284404, 0.5529411764705883, 0.5662650602409639, 0.631578947368421, 0.7092198581560284, 0.5963302752293578, 0.5982142857142857, 0.7526881720430108, 0.5939849624060151, 0.75177304964539, 0.7130434782608696, 0.6690647482014388, 0.643979057591623, 0.6464646464646465, 0.7658227848101266])}"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "relation_result.best_by_efficacy(beta = 2.25).__dict__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################################################################################\n",
    "TRIAL_NO = 7\n",
    "RANK = 136\n",
    "LAYER = 8\n",
    "#########################################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SweepRankResults(rank=136, efficacy=[0.6293103448275862, 0.7931034482758621, 0.853448275862069], efficacy_successes=[EfficacyTestPair(source=RelationSample(subject='Adelaide Crapsey', object='poet'), target=RelationSample(subject='Augustin-Jean Fresnel', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Anders Fogh Rasmussen', object='politician'), target=RelationSample(subject='Werner Heisenberg', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Andrew Ross Sorkin', object='journalist'), target=RelationSample(subject='Edward Burtynsky', object='photographer')), EfficacyTestPair(source=RelationSample(subject='Antoine Augustin Cournot', object='mathematician'), target=RelationSample(subject='Oskar Lafontaine', object='politician')), EfficacyTestPair(source=RelationSample(subject='Anton Walter', object='composer'), target=RelationSample(subject='Lim Chin Siong', object='politician')), EfficacyTestPair(source=RelationSample(subject='Augustin-Jean Fresnel', object='physicist'), target=RelationSample(subject='Cledwyn Hughes, Baron Cledwyn of Penrhos', object='politician')), EfficacyTestPair(source=RelationSample(subject='Benedict Calvert, 4th Baron Baltimore', object='politician'), target=RelationSample(subject='Nikolaj Frederik Severin Grundtvig', object='poet')), EfficacyTestPair(source=RelationSample(subject='Bertold Hummel', object='composer'), target=RelationSample(subject='Cledwyn Hughes, Baron Cledwyn of Penrhos', object='politician')), EfficacyTestPair(source=RelationSample(subject='Boethius', object='philosopher'), target=RelationSample(subject='R. H. Bruce Lockhart', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Boris Tchaikovsky', object='composer'), target=RelationSample(subject='David Weigel', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Craig Charles', object='comedian'), target=RelationSample(subject='Thomas W. Knox', object='journalist')), EfficacyTestPair(source=RelationSample(subject='David Gascoyne', object='poet'), target=RelationSample(subject='Andrew Ross Sorkin', object='journalist')), EfficacyTestPair(source=RelationSample(subject='David Weigel', object='journalist'), target=RelationSample(subject='Thomas Troelsen', object='composer')), EfficacyTestPair(source=RelationSample(subject='Edward Burtynsky', object='photographer'), target=RelationSample(subject='Evan Bayh', object='politician')), EfficacyTestPair(source=RelationSample(subject='Egon Brunswik', object='psychologist'), target=RelationSample(subject='Willem de Sitter', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Ermanno Wolf-Ferrari', object='composer'), target=RelationSample(subject='Boethius', object='philosopher')), EfficacyTestPair(source=RelationSample(subject='Evan Bayh', object='politician'), target=RelationSample(subject='Mia Freedman', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Fiona Russell Powell', object='journalist'), target=RelationSample(subject='Murli Deora', object='politician')), EfficacyTestPair(source=RelationSample(subject='George Osborne', object='politician'), target=RelationSample(subject='Anton Walter', object='composer')), EfficacyTestPair(source=RelationSample(subject='Georges Aperghis', object='composer'), target=RelationSample(subject='Hamdeen Sabahi', object='politician')), EfficacyTestPair(source=RelationSample(subject='Hamdeen Sabahi', object='politician'), target=RelationSample(subject='Rob Owen', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Henry Jackman', object='composer'), target=RelationSample(subject='Peter Arnett', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Henry Maudsley', object='psychiatrist'), target=RelationSample(subject='Martha Nussbaum', object='philosopher')), EfficacyTestPair(source=RelationSample(subject='Howard Sounes', object='journalist'), target=RelationSample(subject='Craig Charles', object='comedian')), EfficacyTestPair(source=RelationSample(subject='J. Anthony Lukas', object='journalist'), target=RelationSample(subject='Antoine Augustin Cournot', object='mathematician')), EfficacyTestPair(source=RelationSample(subject='J. Gwyn Griffiths', object='poet'), target=RelationSample(subject='Anton Walter', object='composer')), EfficacyTestPair(source=RelationSample(subject='James Elroy Flecker', object='poet'), target=RelationSample(subject='Karl Taylor Compton', object='physicist')), EfficacyTestPair(source=RelationSample(subject='James Fenimore Cooper', object='novelist'), target=RelationSample(subject='Hilary Putnam', object='philosopher')), EfficacyTestPair(source=RelationSample(subject='Jan Tennant', object='journalist'), target=RelationSample(subject='Edward Burtynsky', object='photographer')), EfficacyTestPair(source=RelationSample(subject='Johannes Nucius', object='composer'), target=RelationSample(subject='Peter Arnett', object='journalist')), EfficacyTestPair(source=RelationSample(subject='John Crewe, 1st Baron Crewe', object='politician'), target=RelationSample(subject='David Weigel', object='journalist')), EfficacyTestPair(source=RelationSample(subject='John Horgan', object='journalist'), target=RelationSample(subject='Cledwyn Hughes, Baron Cledwyn of Penrhos', object='politician')), EfficacyTestPair(source=RelationSample(subject='John Koethe', object='poet'), target=RelationSample(subject='Hilary Putnam', object='philosopher')), EfficacyTestPair(source=RelationSample(subject='John Rentoul', object='journalist'), target=RelationSample(subject='Ermanno Wolf-Ferrari', object='composer')), EfficacyTestPair(source=RelationSample(subject='Jonathan Kwitny', object='journalist'), target=RelationSample(subject='William John Gruffydd', object='poet')), EfficacyTestPair(source=RelationSample(subject='Jourdan Miller', object='model'), target=RelationSample(subject='Karl Taylor Compton', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Karl Taylor Compton', object='physicist'), target=RelationSample(subject='Felix Salmon', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Karl-Theodor zu Guttenberg', object='politician'), target=RelationSample(subject='Willem de Sitter', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Klemens von Metternich', object='diplomat'), target=RelationSample(subject='Michel Corrette', object='composer')), EfficacyTestPair(source=RelationSample(subject='Laurent Lafforgue', object='mathematician'), target=RelationSample(subject='Sunny Hundal', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Lim Chin Siong', object='politician'), target=RelationSample(subject='Martha Nussbaum', object='philosopher')), EfficacyTestPair(source=RelationSample(subject='Louis Moreau Gottschalk', object='composer'), target=RelationSample(subject='Craig Charles', object='comedian')), EfficacyTestPair(source=RelationSample(subject='Malcolm MacDonald', object='politician'), target=RelationSample(subject='Felix Salmon', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Maurice Duverger', object='politician'), target=RelationSample(subject='Charles Kaiser', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Michael Dorn', object='actor'), target=RelationSample(subject='Yakubu Gowon', object='politician')), EfficacyTestPair(source=RelationSample(subject='Michael Healy-Rae', object='politician'), target=RelationSample(subject='Charles Kaiser', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Michael William Balfe', object='composer'), target=RelationSample(subject='J. Anthony Lukas', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Michel Corrette', object='composer'), target=RelationSample(subject='John Crewe, 1st Baron Crewe', object='politician')), EfficacyTestPair(source=RelationSample(subject='Moshe Levinger', object='rabbi'), target=RelationSample(subject='Paul Hellyer', object='politician')), EfficacyTestPair(source=RelationSample(subject='Nao Takasugi', object='politician'), target=RelationSample(subject='Nikolai Myaskovsky', object='composer')), EfficacyTestPair(source=RelationSample(subject='Naomi Shihab Nye', object='poet'), target=RelationSample(subject='Felix Salmon', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Nikki Giovanni', object='poet'), target=RelationSample(subject='Augustin-Jean Fresnel', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Nikolai Myaskovsky', object='composer'), target=RelationSample(subject='Muhammad Iqbal', object='poet')), EfficacyTestPair(source=RelationSample(subject='Nikolaj Frederik Severin Grundtvig', object='poet'), target=RelationSample(subject='Willem de Sitter', object='physicist')), EfficacyTestPair(source=RelationSample(subject='Nureddin Pasha', object='politician'), target=RelationSample(subject='Enrique Granados', object='composer')), EfficacyTestPair(source=RelationSample(subject='Oskar Lafontaine', object='politician'), target=RelationSample(subject='Emmanuel Chabrier', object='composer')), EfficacyTestPair(source=RelationSample(subject='Peter Abelard', object='philosopher'), target=RelationSample(subject='Ross Garnaut', object='economist')), EfficacyTestPair(source=RelationSample(subject='Peter Arnett', object='journalist'), target=RelationSample(subject='James Elroy Flecker', object='poet')), EfficacyTestPair(source=RelationSample(subject='Rob Owen', object='journalist'), target=RelationSample(subject='Klemens von Metternich', object='diplomat')), EfficacyTestPair(source=RelationSample(subject='Robert I. Soare', object='mathematician'), target=RelationSample(subject='Ian Hanomansing', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Roberto Gerhard', object='composer'), target=RelationSample(subject='James E. Ferguson', object='politician')), EfficacyTestPair(source=RelationSample(subject='Ross Garnaut', object='economist'), target=RelationSample(subject='John Crewe, 1st Baron Crewe', object='politician')), EfficacyTestPair(source=RelationSample(subject='Sabina Spielrein', object='psychologist'), target=RelationSample(subject='Anders Fogh Rasmussen', object='politician')), EfficacyTestPair(source=RelationSample(subject='Sir William Hart Dyke, 7th Baronet', object='politician'), target=RelationSample(subject='Jonathan Kwitny', object='journalist')), EfficacyTestPair(source=RelationSample(subject='Sunny Hundal', object='journalist'), target=RelationSample(subject='Chris Barrie', object='comedian')), EfficacyTestPair(source=RelationSample(subject='Suzanne Virdee', object='journalist'), target=RelationSample(subject='Franco Sacchetti', object='poet')), EfficacyTestPair(source=RelationSample(subject='Tariq Saleh', object='journalist'), target=RelationSample(subject='David Gascoyne', object='poet')), EfficacyTestPair(source=RelationSample(subject='Thomas Troelsen', object='composer'), target=RelationSample(subject='Muhammad Iqbal', object='poet')), EfficacyTestPair(source=RelationSample(subject='Thomas W. Knox', object='journalist'), target=RelationSample(subject='Naomi Shihab Nye', object='poet')), EfficacyTestPair(source=RelationSample(subject='Thomas de Waal', object='journalist'), target=RelationSample(subject='Hilary Putnam', object='philosopher')), EfficacyTestPair(source=RelationSample(subject='Vic Sotto', object='comedian'), target=RelationSample(subject='Murli Deora', object='politician')), EfficacyTestPair(source=RelationSample(subject='Walter Hines Page', object='journalist'), target=RelationSample(subject='Malcolm MacDonald', object='politician')), EfficacyTestPair(source=RelationSample(subject='William Jennings Bryan', object='politician'), target=RelationSample(subject='Nikolai Myaskovsky', object='composer'))])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer_result = [layer for layer in relation_result.trials[TRIAL_NO].layers if layer.layer == LAYER][0]\n",
    "rank_result = [rank for rank in layer_result.result.ranks if rank.rank == RANK][0]\n",
    "rank_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Edit: David Gascoyne -> poet <to> Andrew Ross Sorkin -> journalist -- found in beta: 0.5\n",
      "Edit: John Crewe, 1st Baron Crewe -> politician <to> David Weigel -> journalist -- found in beta: 0.5\n",
      "Edit: Oskar Lafontaine -> politician <to> Emmanuel Chabrier -> composer -- found in beta: 0.5\n",
      "Edit: Nureddin Pasha -> politician <to> Enrique Granados -> composer -- found in beta: 0.5\n",
      "Edit: Naomi Shihab Nye -> poet <to> Felix Salmon -> journalist -- found in beta: 0.5\n",
      "Edit: Peter Arnett -> journalist <to> James Elroy Flecker -> poet -- found in beta: 0.5\n",
      "Edit: Thomas Troelsen -> composer <to> Muhammad Iqbal -> poet -- found in beta: 0.5\n",
      "Edit: William Jennings Bryan -> politician <to> Nikolai Myaskovsky -> composer -- found in beta: 0.5\n",
      "Edit: Benedict Calvert, 4th Baron Baltimore -> politician <to> Nikolaj Frederik Severin Grundtvig -> poet -- found in beta: 0.5\n",
      "Edit: Johannes Nucius -> composer <to> Peter Arnett -> journalist -- found in beta: 0.5\n",
      "Edit: Louis Moreau Gottschalk -> composer <to> Craig Charles -> comedian -- found in beta: 0.75\n",
      "Edit: Edward Burtynsky -> photographer <to> Evan Bayh -> politician -- found in beta: 0.75\n",
      "Edit: Robert I. Soare -> mathematician <to> Ian Hanomansing -> journalist -- found in beta: 0.75\n",
      "Edit: Antoine Augustin Cournot -> mathematician <to> Oskar Lafontaine -> politician -- found in beta: 0.75\n",
      "Edit: Sabina Spielrein -> psychologist <to> Anders Fogh Rasmussen -> politician -- found in beta: 1.0\n",
      "Edit: Tariq Saleh -> journalist <to> David Gascoyne -> poet -- found in beta: 1.0\n",
      "Edit: Georges Aperghis -> composer <to> Hamdeen Sabahi -> politician -- found in beta: 1.0\n",
      "Edit: Vic Sotto -> comedian <to> Murli Deora -> politician -- found in beta: 1.0\n",
      "Edit: Hamdeen Sabahi -> politician <to> Rob Owen -> journalist -- found in beta: 1.0\n",
      "Edit: Anders Fogh Rasmussen -> politician <to> Werner Heisenberg -> physicist -- found in beta: 1.0\n",
      "Edit: Michael Dorn -> actor <to> Yakubu Gowon -> politician -- found in beta: 1.0\n",
      "Edit: Sunny Hundal -> journalist <to> Chris Barrie -> comedian -- found in beta: 1.25\n",
      "Edit: Suzanne Virdee -> journalist <to> Franco Sacchetti -> poet -- found in beta: 1.25\n",
      "Edit: Sir William Hart Dyke, 7th Baronet -> politician <to> Jonathan Kwitny -> journalist -- found in beta: 1.25\n",
      "Edit: Thomas W. Knox -> journalist <to> Naomi Shihab Nye -> poet -- found in beta: 1.25\n",
      "Edit: Moshe Levinger -> rabbi <to> Paul Hellyer -> politician -- found in beta: 1.25\n",
      "Edit: Laurent Lafforgue -> mathematician <to> Sunny Hundal -> journalist -- found in beta: 1.25\n",
      "Edit: Jonathan Kwitny -> journalist <to> William John Gruffydd -> poet -- found in beta: 1.25\n",
      "Edit: J. Anthony Lukas -> journalist <to> Antoine Augustin Cournot -> mathematician -- found in beta: 1.5\n",
      "Edit: Craig Charles -> comedian <to> Thomas W. Knox -> journalist -- found in beta: 1.5\n",
      "Edit: Rob Owen -> journalist <to> Klemens von Metternich -> diplomat -- found in beta: 1.75\n",
      "Edit: Lim Chin Siong -> politician <to> Martha Nussbaum -> philosopher -- found in beta: 1.75\n",
      "Edit: Klemens von Metternich -> diplomat <to> Michel Corrette -> composer -- found in beta: 1.75\n",
      "Edit: Boethius -> philosopher <to> R. H. Bruce Lockhart -> journalist -- found in beta: 2.0\n",
      "Edit: David Weigel -> journalist <to> Thomas Troelsen -> composer -- found in beta: 2.0\n",
      "Edit: Michael Healy-Rae -> politician <to> Charles Kaiser -> journalist -- found in beta: 2.5\n",
      "Edit: Ermanno Wolf-Ferrari -> composer <to> Boethius -> philosopher -- found in beta: 2.75\n",
      "Edit: John Rentoul -> journalist <to> Ermanno Wolf-Ferrari -> composer -- found in beta: 2.75\n",
      "Edit: Peter Abelard -> philosopher <to> Ross Garnaut -> economist -- found in beta: 3.0\n",
      "Edit: Nikolaj Frederik Severin Grundtvig -> poet <to> Willem de Sitter -> physicist -- found in beta: 4.5\n",
      "Edit: Nikki Giovanni -> poet <to> Augustin-Jean Fresnel -> physicist -- found in beta: 4.75\n"
     ]
    }
   ],
   "source": [
    "efficacy_successes = {s.target.subject : s for s in rank_result.efficacy_successes}\n",
    "\n",
    "for beta_result in layer_result.result.betas:\n",
    "    faithfulness_successes = beta_result.faithfulness_successes\n",
    "    for sample in faithfulness_successes:\n",
    "        if(sample.subject in efficacy_successes):\n",
    "            print(f\"Edit: {efficacy_successes[sample.subject].source} <to> {efficacy_successes[sample.subject].target} -- found in beta: {beta_result.beta}\")\n",
    "            efficacy_successes.pop(sample.subject)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No target match found in faithfulness successes for the following:\n",
      "Edit: Jan Tennant -> journalist <to> Edward Burtynsky -> photographer\n",
      "Edit: Anton Walter -> composer <to> Lim Chin Siong -> politician\n",
      "Edit: John Horgan -> journalist <to> Cledwyn Hughes, Baron Cledwyn of Penrhos -> politician\n",
      "Edit: Evan Bayh -> politician <to> Mia Freedman -> journalist\n",
      "Edit: J. Gwyn Griffiths -> poet <to> Anton Walter -> composer\n",
      "Edit: Jourdan Miller -> model <to> Karl Taylor Compton -> physicist\n",
      "Edit: Thomas de Waal -> journalist <to> Hilary Putnam -> philosopher\n",
      "Edit: Michael William Balfe -> composer <to> J. Anthony Lukas -> journalist\n",
      "Edit: Ross Garnaut -> economist <to> John Crewe, 1st Baron Crewe -> politician\n",
      "Edit: Roberto Gerhard -> composer <to> James E. Ferguson -> politician\n",
      "Edit: Walter Hines Page -> journalist <to> Malcolm MacDonald -> politician\n"
     ]
    }
   ],
   "source": [
    "print(\"No target match found in faithfulness successes for the following:\")\n",
    "for sample in efficacy_successes.values():\n",
    "    print(f\"Edit: {sample.source} <to> {sample.target}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[RelationSample(subject='Tony Pua', object='politician'),\n",
       " RelationSample(subject='Janet Gunn', object='actor'),\n",
       " RelationSample(subject='Valentino Bucchi', object='composer'),\n",
       " RelationSample(subject='Christopher Guest', object='comedian'),\n",
       " RelationSample(subject='Giovanni Battista Guarini', object='poet'),\n",
       " RelationSample(subject='Menachem Mendel Schneerson', object='rabbi'),\n",
       " RelationSample(subject='Cristina Peri Rossi', object='novelist'),\n",
       " RelationSample(subject='Nils Strindberg', object='photographer')]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_samples = layer_result.result.samples\n",
    "train_samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{} works as a'"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt_template = relation_result.trials[TRIAL_NO].prompt_template\n",
    "prompt_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src import functional, operators, editors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = operators.JacobianIclMeanEstimator(\n",
    "    mt = mt,\n",
    "    h_layer = LAYER,\n",
    ")\n",
    "\n",
    "operator = estimator(\n",
    "    relation.set(\n",
    "        samples = train_samples,\n",
    "        prompt_templates = [prompt_template],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################\n",
    "subject = \"Malcolm MacDonald\"\n",
    "############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[PredictedToken(token=' politician', prob=0.13466356694698334),\n",
       "  PredictedToken(token=' journalist', prob=0.11699773371219635),\n",
       "  PredictedToken(token=' writer', prob=0.06262437254190445),\n",
       "  PredictedToken(token=' musician', prob=0.04139218106865883),\n",
       "  PredictedToken(token=' historian', prob=0.028448402881622314)]]"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# model predicts correctly\n",
    "functional.predict_next_token(\n",
    "    mt = mt,\n",
    "    prompt = operator.prompt_template.format(subject)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearRelationOutput(predictions=[PredictedToken(token=' writer', prob=0.11166161298751831), PredictedToken(token=' poet', prob=0.09550923854112625), PredictedToken(token=' musician', prob=0.0816933736205101), PredictedToken(token=' journalist', prob=0.06462478637695312), PredictedToken(token=' politician', prob=0.05527650564908981)], h=tensor([[ 0.0685,  0.7666,  0.0354,  ...,  2.1816, -1.7412, -0.4724]],\n",
       "       device='cuda:0', dtype=torch.float16), z=tensor([[-1.7080, -5.2656,  0.7314,  ..., -5.3125,  0.2017,  5.3555]],\n",
       "       device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>))"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LRE fails (low faithfulness)\n",
    "operator(subject=subject, k = 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "svd = torch.svd(operator.weight.float())\n",
    "editor = editors.LowRankPInvEditor(\n",
    "    lre = operator,\n",
    "    rank = rank_result.rank,\n",
    "    n_samples=1, n_new_tokens=1,\n",
    "    svd = svd\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Editing: Walter Hines Page -> journalist <to> Malcolm MacDonald -> politician'"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "efficacy_test_pair = efficacy_successes[subject]\n",
    "f\"Editing: {efficacy_test_pair.source} <to> {efficacy_test_pair.target}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearRelationEditResult(predicted_tokens=[PredictedToken(token=' politician', prob=0.13382403552532196), PredictedToken(token=' musician', prob=0.09489544481039047), PredictedToken(token=' writer', prob=0.08506372570991516), PredictedToken(token=' journalist', prob=0.07684867829084396), PredictedToken(token=' poet', prob=0.04588843882083893)], model_logits=tensor([-inf, -inf, -inf,  ..., -inf, -inf, -inf], device='cuda:0'), model_generations=['Tony Pua works as a politician\\nJanet Gunn works as a actor\\nValentino Bucchi works as a composer\\nChristopher Guest works as a comedian\\nGiovanni Battista Guarini works as a poet\\nMenachem Mendel Schneerson works as a rabbi\\nCristina Peri Rossi works as a novelist\\nNils Strindberg works as a photographer\\nWalter Hines Page works as a historian'])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# editing succeeds (high efficacy)\n",
    "editor(\n",
    "    subject = efficacy_test_pair.source.subject,\n",
    "    target = efficacy_test_pair.target.subject,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
