{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension. Mode 2 reloads imported modules before each\n",
    "# cell executes, so edits to the project's `src` package are picked up without\n",
    "# restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from src import data\n",
    "import json\n",
    "from tqdm.auto import tqdm\n",
    "from src.metrics import AggregateMetric\n",
    "\n",
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src import hparams\n",
    "\n",
    "# logger = logging.getLogger(__name__)\n",
    "\n",
    "# logging.basicConfig(\n",
    "#     level=logging.DEBUG,\n",
    "#     format = logging_utils.DEFAULT_FORMAT,\n",
    "#     datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "#     stream=sys.stdout\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Sweep selection: edit these two values to point at a different run ----\n",
    "sweep_root = \"../results/sweep-24-trials\"\n",
    "model_name = \"gptj\"\n",
    "\n",
    "# Path handed to read_sweep_results below: <sweep_root>/<model_name>\n",
    "sweep_path = \"/\".join((sweep_root, model_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the sweep output stored under `sweep_path`; the next cell treats the keys of the\n",
    "# result as relation names.\n",
    "# To load only some relations, an earlier draft passed\n",
    "# relation_names=[\"country capital city\"] -- confirm the kwarg against read_sweep_results.\n",
    "sweep_results = read_sweep_results(sweep_path)\n",
    "\n",
    "# Display the keys that were loaded.\n",
    "list(sweep_results.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert each raw sweep dict with relation_from_dict, keyed by the same name as in\n",
    "# `sweep_results`; tqdm shows progress over the relations.\n",
    "relation_dict = {\n",
    "    name: relation_from_dict(raw_sweep)\n",
    "    for name, raw_sweep in tqdm(sweep_results.items())\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_best_hparams(relation_dict, model_name, criterion, beta=2.5):\n",
    "    \"\"\"Print and save the best sweep setting of every relation under one criterion.\n",
    "\n",
    "    Args:\n",
    "        relation_dict: mapping of relation name -> object built by `relation_from_dict`.\n",
    "        model_name: model identifier stored in every saved `RelationHParams`.\n",
    "        criterion: \"efficacy\" or \"faithfulness\"; picks `best_by_efficacy` or\n",
    "            `best_by_faithfulness` on each relation.\n",
    "        beta: forwarded as the `beta` keyword of the selected method.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `criterion` is not one of the two supported names.\n",
    "    \"\"\"\n",
    "    if criterion not in (\"efficacy\", \"faithfulness\"):\n",
    "        raise ValueError(f\"unknown criterion: {criterion!r}\")\n",
    "\n",
    "    for relation_name, relation in relation_dict.items():\n",
    "        select_best = getattr(relation, f\"best_by_{criterion}\")\n",
    "        best_hparams = select_best(beta=beta)\n",
    "\n",
    "        # Floor once so the printed rank is exactly the rank that gets saved.\n",
    "        rank = int(np.floor(best_hparams.rank.mean))\n",
    "\n",
    "        # The value labelled \"faithfulness\" is read from `recall.mean`.\n",
    "        performance = f\"efficacy={best_hparams.efficacy.mean:.3f} | faithfulness={best_hparams.recall.mean:.3f}\"\n",
    "        print(f\"{relation_name} >> layer={best_hparams.layer} | beta={best_hparams.beta.mean} | rank={rank} <> {performance}\")\n",
    "\n",
    "        # The same layer is stored for both h_layer and h_layer_edit.\n",
    "        hparams.RelationHParams(\n",
    "            relation_name=relation.relation_name,\n",
    "            h_layer=best_hparams.layer,\n",
    "            h_layer_edit=best_hparams.layer,\n",
    "            z_layer=-1,\n",
    "            beta=best_hparams.beta.mean,\n",
    "            rank=rank,\n",
    "            model_name=model_name,\n",
    "        ).save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report and save the settings chosen by best_by_efficacy.\n",
    "save_best_hparams(relation_dict, model_name, criterion=\"efficacy\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Report and save the settings chosen by best_by_faithfulness.\n",
    "# NOTE(review): this saves RelationHParams with the same relation_name/model_name as the\n",
    "# efficacy cell above -- confirm save() does not overwrite those results, or run only one.\n",
    "save_best_hparams(relation_dict, model_name, criterion=\"faithfulness\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
