{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from src import data\n",
    "import json\n",
    "from tqdm.auto import tqdm\n",
    "from src.metrics import AggregateMetric\n",
    "\n",
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src import hparams\n",
    "\n",
    "# logger = logging.getLogger(__name__)\n",
    "\n",
    "# logging.basicConfig(\n",
    "#     level=logging.DEBUG,\n",
    "#     format = logging_utils.DEFAULT_FORMAT,\n",
    "#     datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "#     stream=sys.stdout\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook configuration: the sweep directory to analyse and the model whose results are read.\n",
    "model_name = \"gptj\"\n",
    "sweep_root = \"../results/sweep-24-trials\"\n",
    "\n",
    "# Sweep root and model name joined with a forward slash; this path is passed to read_sweep_results.\n",
    "sweep_path = \"/\".join((sweep_root, model_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['person occupation',\n",
       " 'landmark in country',\n",
       " 'adjective antonym',\n",
       " 'person mother',\n",
       " 'country capital city',\n",
       " 'plays pro sport',\n",
       " 'person plays instrument',\n",
       " 'person university',\n",
       " 'city in country',\n",
       " 'food from country',\n",
       " 'company hq',\n",
       " 'occupation gender',\n",
       " 'occupation age',\n",
       " 'name gender',\n",
       " 'word first letter',\n",
       " 'country language',\n",
       " 'object superclass',\n",
       " 'name religion',\n",
       " 'person native language',\n",
       " 'president election year',\n",
       " 'fruit outside color',\n",
       " 'superhero archnemesis',\n",
       " 'work location',\n",
       " 'landmark on continent',\n",
       " 'person lead singer of band',\n",
       " 'task person type',\n",
       " 'characteristic gender',\n",
       " 'country largest city',\n",
       " 'country currency',\n",
       " 'fruit inside color',\n",
       " 'task done by tool',\n",
       " 'verb past tense',\n",
       " 'star constellation name',\n",
       " 'pokemon evolution',\n",
       " 'president birth year',\n",
       " 'product by company',\n",
       " 'name birthplace',\n",
       " 'word last letter',\n",
       " 'word sentiment',\n",
       " 'company CEO',\n",
       " 'superhero person',\n",
       " 'person father',\n",
       " 'substance phase of matter',\n",
       " 'person sport position',\n",
       " 'adjective superlative',\n",
       " 'adjective comparative',\n",
       " 'univ degree gender']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the sweep results found at sweep_path.\n",
    "# The original cell carried a commented-out `relation_names=[\"country capital city\"]`\n",
    "# argument, presumably to restrict reading to selected relations (not verified here).\n",
    "sweep_results = read_sweep_results(sweep_path)\n",
    "\n",
    "# Display the keys of the result mapping, i.e. the relation names that were read.\n",
    "[*sweep_results.keys()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c77df28ad2744f11a9fdd5184b9ab798",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/47 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Run relation_from_dict on every sweep entry and key the result by relation name;\n",
    "# tqdm reports progress while iterating over the entries.\n",
    "relation_dict = {\n",
    "    name: relation_from_dict(raw_sweep)\n",
    "    for name, raw_sweep in tqdm(sweep_results.items())\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "person occupation >> layer=8 | beta=2.25 | rank=131 <> efficacy=0.656 | faithfulness=0.494\n",
      "landmark in country >> layer=6 | beta=2.25 | rank=97 <> efficacy=0.679 | faithfulness=0.364\n",
      "adjective antonym >> layer=8 | beta=2.25 | rank=243 <> efficacy=0.859 | faithfulness=0.690\n",
      "person mother >> layer=6 | beta=2.25 | rank=170 <> efficacy=0.392 | faithfulness=0.142\n",
      "country capital city >> layer=3 | beta=2.25 | rank=68 <> efficacy=0.990 | faithfulness=0.880\n",
      "plays pro sport >> layer=6 | beta=2.25 | rank=117 <> efficacy=0.940 | faithfulness=0.760\n",
      "person plays instrument >> layer=9 | beta=2.25 | rank=198 <> efficacy=0.763 | faithfulness=0.590\n",
      "person university >> layer=4 | beta=2.25 | rank=153 <> efficacy=0.907 | faithfulness=0.641\n",
      "city in country >> layer=2 | beta=2.25 | rank=115 <> efficacy=0.894 | faithfulness=0.442\n",
      "food from country >> layer=3 | beta=2.25 | rank=113 <> efficacy=0.967 | faithfulness=0.514\n",
      "company hq >> layer=6 | beta=2.25 | rank=126 <> efficacy=0.493 | faithfulness=0.205\n",
      "occupation gender >> layer=4 | beta=2.25 | rank=34 <> efficacy=1.000 | faithfulness=0.981\n",
      "occupation age >> layer=5 | beta=2.25 | rank=34 <> efficacy=1.000 | faithfulness=0.675\n",
      "name gender >> layer=emb | beta=2.25 | rank=17 <> efficacy=0.943 | faithfulness=0.799\n",
      "word first letter >> layer=7 | beta=2.25 | rank=121 <> efficacy=0.912 | faithfulness=0.580\n",
      "country language >> layer=1 | beta=2.25 | rank=63 <> efficacy=0.987 | faithfulness=0.884\n",
      "object superclass >> layer=7 | beta=2.25 | rank=91 <> efficacy=0.933 | faithfulness=0.853\n",
      "name religion >> layer=4 | beta=2.25 | rank=57 <> efficacy=0.990 | faithfulness=0.802\n",
      "person native language >> layer=6 | beta=2.25 | rank=92 <> efficacy=0.872 | faithfulness=0.652\n",
      "president election year >> layer=emb | beta=2.25 | rank=84 <> efficacy=0.909 | faithfulness=0.524\n",
      "fruit outside color >> layer=5 | beta=2.25 | rank=160 <> efficacy=0.834 | faithfulness=0.776\n",
      "superhero archnemesis >> layer=11 | beta=2.25 | rank=192 <> efficacy=0.599 | faithfulness=0.301\n",
      "work location >> layer=5 | beta=2.25 | rank=112 <> efficacy=0.944 | faithfulness=0.555\n",
      "landmark on continent >> layer=4 | beta=2.25 | rank=158 <> efficacy=0.908 | faithfulness=0.563\n",
      "person lead singer of band >> layer=8 | beta=2.25 | rank=163 <> efficacy=0.839 | faithfulness=0.640\n",
      "task person type >> layer=8 | beta=2.25 | rank=109 <> efficacy=0.773 | faithfulness=0.492\n",
      "characteristic gender >> layer=1 | beta=2.25 | rank=74 <> efficacy=0.966 | faithfulness=0.771\n",
      "country largest city >> layer=10 | beta=2.25 | rank=74 <> efficacy=0.992 | faithfulness=0.925\n",
      "country currency >> layer=3 | beta=2.25 | rank=88 <> efficacy=0.976 | faithfulness=0.577\n",
      "fruit inside color >> layer=7 | beta=2.25 | rank=107 <> efficacy=0.930 | faithfulness=0.648\n",
      "task done by tool >> layer=5 | beta=2.25 | rank=145 <> efficacy=0.764 | faithfulness=0.288\n",
      "verb past tense >> layer=11 | beta=2.25 | rank=182 <> efficacy=0.972 | faithfulness=0.945\n",
      "star constellation name >> layer=8 | beta=2.25 | rank=152 <> efficacy=0.266 | faithfulness=0.407\n",
      "pokemon evolution >> layer=7 | beta=2.25 | rank=206 <> efficacy=0.248 | faithfulness=0.146\n",
      "president birth year >> layer=6 | beta=2.25 | rank=106 <> efficacy=0.841 | faithfulness=0.538\n",
      "product by company >> layer=4 | beta=2.25 | rank=158 <> efficacy=0.535 | faithfulness=0.306\n",
      "name birthplace >> layer=7 | beta=2.25 | rank=91 <> efficacy=0.955 | faithfulness=0.916\n",
      "word last letter >> layer=6 | beta=2.25 | rank=61 <> efficacy=0.834 | faithfulness=0.568\n",
      "word sentiment >> layer=4 | beta=2.25 | rank=94 <> efficacy=0.928 | faithfulness=0.635\n",
      "company CEO >> layer=6 | beta=2.25 | rank=173 <> efficacy=0.308 | faithfulness=0.065\n",
      "superhero person >> layer=8 | beta=2.25 | rank=228 <> efficacy=0.712 | faithfulness=0.437\n",
      "person father >> layer=8 | beta=2.25 | rank=217 <> efficacy=0.285 | faithfulness=0.072\n",
      "substance phase of matter >> layer=7 | beta=2.25 | rank=60 <> efficacy=0.968 | faithfulness=0.868\n",
      "person sport position >> layer=5 | beta=2.25 | rank=97 <> efficacy=0.738 | faithfulness=0.421\n",
      "adjective superlative >> layer=10 | beta=2.25 | rank=143 <> efficacy=0.991 | faithfulness=0.932\n",
      "adjective comparative >> layer=10 | beta=2.25 | rank=121 <> efficacy=0.944 | faithfulness=0.979\n",
      "univ degree gender >> layer=5 | beta=2.25 | rank=164 <> efficacy=0.947 | faithfulness=0.868\n"
     ]
    }
   ],
   "source": [
    "# For each relation, take the configuration returned by best_by_efficacy(beta=2.25),\n",
    "# print a one-line summary of it, and save it as a RelationHParams entry for model_name.\n",
    "for name, rel in relation_dict.items():\n",
    "    best = rel.best_by_efficacy(beta=2.25)\n",
    "    layer = best.layer\n",
    "    beta = best.beta.mean\n",
    "    rank_mean = best.rank.mean\n",
    "\n",
    "    # The faithfulness figure in the summary is read from the recall metric.\n",
    "    scores = f\"efficacy={best.efficacy.mean:.3f} | faithfulness={best.recall.mean:.3f}\"\n",
    "    print(f\"{name} >> layer={layer} | beta={beta} | rank={int(rank_mean)} <> {scores}\")\n",
    "\n",
    "    relation_hparams = hparams.RelationHParams(\n",
    "        model_name=model_name,\n",
    "        relation_name=rel.relation_name,\n",
    "        h_layer=layer,\n",
    "        h_layer_edit=layer,  # the same layer is stored for both h_layer and h_layer_edit\n",
    "        z_layer=-1,\n",
    "        beta=beta,\n",
    "        rank=int(np.floor(rank_mean)),  # the mean rank is floored to an integer\n",
    "    )\n",
    "    relation_hparams.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "RelationHParams(model_name='gptj', relation_name='country capital city', h_layer=3, beta=2.25, rank=68, z_layer=-1, h_layer_edit=3)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check one relation by loading its hyperparameters through from_relation\n",
    "# (presumably this reads back what .save() wrote - confirm against src.hparams).\n",
    "# The model comes from the configured model_name instead of a hard-coded literal,\n",
    "# so the check follows the configuration cell when another model is analysed.\n",
    "hparams.RelationHParams.from_relation(\n",
    "    model=model_name,\n",
    "    relation=\"country capital city\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each relation, take the configuration returned by best_by_faithfulness(beta=2.5),\n",
    "# print a one-line summary of it, and save it as a RelationHParams entry for model_name.\n",
    "# NOTE(review): this saves entries for the same model_name and relation names as the\n",
    "# best_by_efficacy cell; if save() is keyed on those, running both cells overwrites the\n",
    "# earlier files. beta also differs (2.5 here, 2.25 there) - confirm which is intended.\n",
    "for name, rel in relation_dict.items():\n",
    "    best = rel.best_by_faithfulness(beta=2.5)\n",
    "    layer = best.layer\n",
    "    beta = best.beta.mean\n",
    "    rank_mean = best.rank.mean\n",
    "\n",
    "    # The faithfulness figure in the summary is read from the recall metric.\n",
    "    scores = f\"efficacy={best.efficacy.mean:.3f} | faithfulness={best.recall.mean:.3f}\"\n",
    "    print(f\"{name} >> layer={layer} | beta={beta} | rank={int(rank_mean)} <> {scores}\")\n",
    "\n",
    "    relation_hparams = hparams.RelationHParams(\n",
    "        model_name=model_name,\n",
    "        relation_name=rel.relation_name,\n",
    "        h_layer=layer,\n",
    "        h_layer_edit=layer,  # the same layer is stored for both h_layer and h_layer_edit\n",
    "        z_layer=-1,\n",
    "        beta=beta,\n",
    "        rank=int(np.floor(rank_mean)),  # the mean rank is floored to an integer\n",
    "    )\n",
    "    relation_hparams.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
