{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from src import data\n",
    "import json\n",
    "from tqdm.auto import tqdm\n",
    "from src.metrics import AggregateMetric\n",
    "import logging\n",
    "import torch\n",
    "import json\n",
    "\n",
    "from src.utils import logging_utils\n",
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict\n",
    "\n",
    "\n",
    "# logging_utils.configure(level=logging.DEBUG)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "############################################\n",
    "# Sweep selection: results root and the model sub-directory inside it.\n",
    "sweep_root = \"../../results/sweep-full-rank\"\n",
    "model_name = \"gptj\"\n",
    "############################################\n",
    "\n",
    "# Directory that holds the sweep output for the selected model.\n",
    "sweep_path = \"/\".join((sweep_root, model_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load every relation's sweep output found under `sweep_path`.\n",
    "sweep_results = read_sweep_results(sweep_path)\n",
    "\n",
    "# One line per relation: its name followed by how many trials it has.\n",
    "for relation in sweep_results:\n",
    "    n_trials = len(sweep_results[relation][\"trials\"])\n",
    "    print(f\"{relation}: {n_trials}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rankwise_performance(relation_result):\n",
    "    \"\"\"Collect per-rank recall and efficacy values across a relation's trials.\n",
    "\n",
    "    Only the first layer entry (`layers[0]`) of each trial is read, and only the\n",
    "    first element of each `recall` / `efficacy` sequence is kept.\n",
    "\n",
    "    Args:\n",
    "        relation_result: object with a `trials` sequence; each trial exposes\n",
    "            `layers[0].result.betas` (items with `.rank` and `.recall`) and\n",
    "            `layers[0].result.ranks` (items with `.rank` and `.efficacy`).\n",
    "\n",
    "    Returns:\n",
    "        A `(ranks, recalls, efficacies)` tuple. `ranks` lists the rank values\n",
    "        of the recall entries in first-seen order; `recalls[i]` holds one value\n",
    "        per trial that reported `ranks[i]`. `efficacies` is grouped the same way\n",
    "        by the rank of the efficacy entries (assumed to follow the same rank\n",
    "        grid as the recall entries -- TODO confirm). All three lists are empty\n",
    "        when the relation has no trials.\n",
    "    \"\"\"\n",
    "    rank_recalls = {}\n",
    "    rank_efficacies = {}\n",
    "\n",
    "    for trial in relation_result.trials:\n",
    "        result = trial.layers[0].result\n",
    "        # setdefault keeps first-seen ordering and tolerates a rank that was\n",
    "        # absent from the first trial instead of raising KeyError.\n",
    "        for beta in result.betas:\n",
    "            rank_recalls.setdefault(beta.rank, []).append(beta.recall[0])\n",
    "        for entry in result.ranks:\n",
    "            rank_efficacies.setdefault(entry.rank, []).append(entry.efficacy[0])\n",
    "\n",
    "    ranks = list(rank_recalls)\n",
    "\n",
    "    return ranks, list(rank_recalls.values()), list(rank_efficacies.values())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Per-relation hyper-parameter table; later cells read its `relation`,\n",
    "# `recall@1` and `efficacy` columns.\n",
    "df = pd.read_csv(filepath_or_buffer=\"../../results/tables/gptj-hparams.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert each relation's raw sweep dict with `relation_from_dict`,\n",
    "# keyed by relation name; tqdm shows progress over the relations.\n",
    "relation_dict = {\n",
    "    relation: relation_from_dict(sweep_results[relation])\n",
    "    for relation in tqdm(sweep_results)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################\n",
    "# Only relations whose recall@1 in the hparams table exceeds tau are aggregated.\n",
    "tau = 0.8\n",
    "#########################################\n",
    "\n",
    "# One transposed tensor per selected relation (assumed trials x ranks --\n",
    "# TODO confirm); they are concatenated once after the loop instead of\n",
    "# re-copying the running result with torch.cat on every iteration.\n",
    "faith_blocks = []\n",
    "cause_blocks = []\n",
    "\n",
    "for relation in relation_dict:\n",
    "    matches = df[df[\"relation\"] == relation].to_dict(orient=\"records\")\n",
    "    if not matches:\n",
    "        # Swept relation without a row in the hparams table: report and skip\n",
    "        # rather than failing with an IndexError on `[0]`.\n",
    "        print(f\"{relation} >> not found in hparams table, skipping\")\n",
    "        continue\n",
    "    res = matches[0]\n",
    "    print(f\"{relation} >> faithfulness={res['recall@1']} | efficacy={res['efficacy']}\")\n",
    "    # The recall@1 cell is text; its first whitespace-separated token is the score.\n",
    "    if float(res['recall@1'].split()[0]) > tau:\n",
    "        # `ranks` is kept from the last selected relation and reused as the\n",
    "        # x-axis by the plotting cell.\n",
    "        ranks, faithfulness, efficacies = rankwise_performance(relation_dict[relation])\n",
    "        faith_blocks.append(torch.Tensor(faithfulness).T)\n",
    "        cause_blocks.append(torch.Tensor(efficacies).T)\n",
    "\n",
    "if not faith_blocks:\n",
    "    # Previously F and C stayed None here and the cell died on `F.shape`.\n",
    "    raise ValueError(f\"no relation has recall@1 above tau={tau}\")\n",
    "\n",
    "F = torch.cat(faith_blocks, dim=0)\n",
    "C = torch.cat(cause_blocks, dim=0)\n",
    "\n",
    "print(F.shape)\n",
    "\n",
    "# Mean and standard deviation over all stacked rows, per rank column.\n",
    "f_mean = F.mean(dim = 0)\n",
    "c_mean = C.mean(dim = 0)\n",
    "f_std = F.std(dim = 0)\n",
    "c_std = C.std(dim = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig_dir = \"figs\"\n",
    "#####################################################################################\n",
    "# Reset matplotlib to its defaults once, then apply the figure styling.\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 22\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=MEDIUM_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "faith_color = \"steelblue\"\n",
    "cause_color = \"darkorange\"\n",
    "#####################################################################################\n",
    "\n",
    "# Mean curve per metric with a +/- one standard deviation band around it.\n",
    "plt.plot(ranks, f_mean, label = \"Faithfulness\", color = faith_color, linewidth = 1.8)\n",
    "plt.fill_between(ranks, f_mean - f_std, f_mean + f_std, alpha = 0.1, color = faith_color)\n",
    "plt.plot(ranks, c_mean, label = \"Causality\", color = cause_color, linewidth = 1.8)\n",
    "plt.fill_between(ranks, c_mean - c_std, c_mean + c_std, alpha = 0.1, color = cause_color)\n",
    "\n",
    "plt.xscale(\"log\", base = 2)\n",
    "plt.xlabel(\"Rank\")\n",
    "plt.ylim(0, 1)\n",
    "plt.ylabel(\"Score\")\n",
    "plt.legend(ncol = 2, bbox_to_anchor=(0.5, 1.15), loc='upper center', frameon=False)\n",
    "\n",
    "# savefig does not create missing directories; make sure the target exists.\n",
    "os.makedirs(fig_dir, exist_ok=True)\n",
    "plt.savefig(f\"{fig_dir}/rank-sweep.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep this table under its own name: `df` is the hparams table that the\n",
    "# tau-filtering cell reads, and overwriting it here would make that cell\n",
    "# read the wrong table if it is re-run afterwards.\n",
    "beta_df = pd.read_csv(\"../../results/tables/gptj-beta-R.csv\")\n",
    "# Emit the table as LaTeX, without the index and with two-decimal floats.\n",
    "print(beta_df.to_latex( index=False, float_format=\"%.2f\" ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the hyper-parameter table and emit it as LaTeX\n",
    "# (no index column, floats rendered with two decimals).\n",
    "df = pd.read_csv(\"../../results/tables/gptj-hparams.csv\")\n",
    "print(\n",
    "    df.to_latex(\n",
    "        index=False,\n",
    "        float_format=\"%.2f\",\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of rows whose leading `recall@1` token, parsed as a float,\n",
    "# is greater than 0.6.\n",
    "count = sum(\n",
    "    float(recall_text.split()[0]) > 0.6\n",
    "    for recall_text in df['recall@1']\n",
    ")\n",
    "count / len(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
