{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited modules (such as those under src/) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Make ../.. (presumably the repository root) importable so the local `src` package resolves.\n",
    "sys.path.append(\"../..\")\n",
    "\n",
    "from src import data\n",
    "from src.metrics import AggregateMetric\n",
    "from src.utils import logging_utils\n",
    "\n",
    "# Uncomment for verbose logging:\n",
    "# logging_utils.configure(level=logging.DEBUG)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- sweep selection --------------------------------------------------------\n",
    "model_name = \"gptj\"\n",
    "sweep_root = \"../../results/sweep-24-trials\"\n",
    "# sweep_root = \"../../results/sweep-bare\"  # alternative sweep\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "# Results for each model live in a sub-directory named after it.\n",
    "sweep_path = \"{}/{}\".format(sweep_root, model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.sweep_utils import read_sweep_results, relation_from_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-relation sweep results of the selected model.\n",
    "sweep_results = read_sweep_results(sweep_path, economy=True)\n",
    "# Display the relations that have results.\n",
    "[relation_name for relation_name in sweep_results.keys()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sweep_results_bare = read_sweep_results(\"../../results/sweep-bare/gptj\", economy=True)\n",
    "# list(sweep_results_bare.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# relation_name = \"plays pro sport\"\n",
    "# relation_result = relation_from_dict(sweep_results[relation_name])\n",
    "# relation_result_bare = relation_from_dict(sweep_results_bare[relation_name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# efficacy = relation_result.best_by_efficacy()\n",
    "# print(f\"best by efficacy =>  layer={efficacy.layer}, beta={efficacy.beta.mean:.2f}, rank={efficacy.rank.mean:.0f}\")\n",
    "# print(\"recall:\", efficacy.recall)\n",
    "# print(\"efficacy:\", efficacy.efficacy)\n",
    "\n",
    "# print(\"-------------------------------------------------------------\")\n",
    "\n",
    "# faithfulness = relation_result.best_by_faithfulness()\n",
    "# print(f\"best by faithfulness => layer={faithfulness.layer}, beta={faithfulness.beta.mean:.2f}, rank={faithfulness.rank.mean:.0f}\")\n",
    "# print(\"recall:\", faithfulness.recall)\n",
    "# print(\"efficacy:\", faithfulness.efficacy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# beta = 2.25\n",
    "\n",
    "# efficacy = relation_result.best_by_efficacy(beta = beta)\n",
    "# print(f\"best by efficacy =>  layer={efficacy.layer}, beta={efficacy.beta.mean:.2f}, rank={efficacy.rank.mean:.0f}\")\n",
    "# print(\"recall:\", efficacy.recall)\n",
    "# print(\"efficacy:\", efficacy.efficacy)\n",
    "\n",
    "# print(\"-------------------------------------------------------------\")\n",
    "\n",
    "# faithfulness = relation_result.best_by_faithfulness(beta = beta)\n",
    "# print(f\"best by faithfulness => layer={faithfulness.layer}, beta={faithfulness.beta.mean:.2f}, rank={faithfulness.rank.mean:.0f}\")\n",
    "# print(\"recall:\", faithfulness.recall)\n",
    "# print(\"efficacy:\", faithfulness.efficacy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Figures for the current model are written under figs/<model_name>/.\n",
    "fig_dir = \"figs/\" + model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "# Shared plotting style for the layerwise figures.\n",
    "plt.rcdefaults()\n",
    "os.makedirs(fig_dir, exist_ok=True)\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 14\n",
    "MEDIUM_SIZE = 18\n",
    "BIGGER_SIZE = 22\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=MEDIUM_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "# Default line color and legend label for each plotted attribute.\n",
    "color_scheme = {\n",
    "    \"recall\": \"steelblue\",\n",
    "    \"efficacy\": \"darkorange\",\n",
    "}\n",
    "\n",
    "legend_labels = {\n",
    "    \"recall\": \"Faithfulness\",\n",
    "    \"efficacy\": \"Causality\",\n",
    "}\n",
    "\n",
    "def plot_layerwise(\n",
    "        canvas, relation_result, \n",
    "        attribute = \"recall\", best_criterion = \"faithfulness\", \n",
    "        color = None, label = None,\n",
    "        linewidth = 2\n",
    "    ):\n",
    "    \"\"\"Plot the per-layer mean of `attribute` (with a ±1 stdev band) for one relation.\n",
    "\n",
    "    Args:\n",
    "        canvas: matplotlib Axes to draw on.\n",
    "        relation_result: sweep result for a single relation; must provide `by_layer()`,\n",
    "            `best_by_faithfulness()`, `best_by_efficacy()`, `trials` and `relation_name`.\n",
    "        attribute: per-layer summary field to plot (e.g. \"recall\" or \"efficacy\");\n",
    "            each summary must expose `.mean` and `.stdev`.\n",
    "        best_criterion: \"faithfulness\" or \"efficacy\"; chooses which best setting\n",
    "            is reported in the title.\n",
    "        color: line/band color; defaults to `color_scheme[attribute]`, or the\n",
    "            Axes color cycle for attributes without an entry.\n",
    "        label: legend label; defaults to `legend_labels[attribute]`, else `attribute`.\n",
    "        linewidth: width of the mean line.\n",
    "\n",
    "    Returns:\n",
    "        `canvas`, so calls can be chained.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `best_criterion` is not \"faithfulness\" or \"efficacy\".\n",
    "    \"\"\"\n",
    "    # Validate the criterion before drawing anything; an unknown value used to\n",
    "    # surface later as a NameError on `best_layer`.\n",
    "    if best_criterion == \"faithfulness\":\n",
    "        best_layer = relation_result.best_by_faithfulness()\n",
    "    elif best_criterion == \"efficacy\":\n",
    "        best_layer = relation_result.best_by_efficacy()\n",
    "    else:\n",
    "        raise ValueError(\n",
    "            f\"best_criterion must be 'faithfulness' or 'efficacy', got {best_criterion!r}\"\n",
    "        )\n",
    "\n",
    "    by_layer = relation_result.by_layer()\n",
    "    layers = list(by_layer.keys())\n",
    "    x_positions = range(len(layers))\n",
    "    # getattr (instead of __dict__) also resolves properties and slotted attributes.\n",
    "    summaries = [getattr(value, attribute) for value in by_layer.values()]\n",
    "    value_means = np.array([summary.mean for summary in summaries])\n",
    "    value_stds = np.array([summary.stdev for summary in summaries])\n",
    "\n",
    "    color = color_scheme.get(attribute) if color is None else color\n",
    "    label = legend_labels.get(attribute, attribute) if label is None else label\n",
    "\n",
    "    (line,) = canvas.plot(x_positions, value_means, color=color, linewidth=linewidth, label=label)\n",
    "    # Reuse the line's resolved color so the band matches even when color is None.\n",
    "    canvas.fill_between(\n",
    "        x_positions, value_means - value_stds, value_means + value_stds,\n",
    "        color=line.get_color(), alpha=0.07,\n",
    "    )\n",
    "\n",
    "    # Scores (recall/efficacy) use a fixed [0, 1] axis; other attributes only start at 0.\n",
    "    if attribute in [\"recall\", \"efficacy\"]:\n",
    "        canvas.set_ylim(0, 1)\n",
    "    else:\n",
    "        canvas.set_ylim(bottom=0)\n",
    "    canvas.set_xlabel(\"Layer\")\n",
    "    canvas.set_ylabel(attribute)\n",
    "    canvas.set_xticks(x_positions, layers, rotation=90)\n",
    "\n",
    "    test_samples = np.array([trial.n_test_samples for trial in relation_result.trials])\n",
    "    trial_info = f\"[{test_samples.mean():.2f} ± {test_samples.std():.2f}]\"\n",
    "    canvas.set_title(\n",
    "        f\"{relation_result.relation_name} n_trials={len(test_samples)} {trial_info}\\n\"\n",
    "        f\"{best_criterion} => h_layer: {best_layer.layer}, beta: {best_layer.beta.mean:.2f}, \"\n",
    "        f\"rank: {best_layer.rank.mean:.2f}, efficacy: {best_layer.efficacy.mean:.2f}\"\n",
    "    )\n",
    "\n",
    "    return canvas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def export_legend(legend, filename=\"legend.pdf\"):\n",
    "    \"\"\"Save just the area covered by `legend` to `filename`.\n",
    "\n",
    "    The owning figure is drawn first so the legend's extent is known; that extent is\n",
    "    converted from display units to inches and used as the `bbox_inches` crop.\n",
    "    \"\"\"\n",
    "    figure = legend.figure\n",
    "    figure.canvas.draw()\n",
    "    display_to_inches = figure.dpi_scale_trans.inverted()\n",
    "    legend_box = legend.get_window_extent().transformed(display_to_inches)\n",
    "    figure.savefig(filename, dpi=\"figure\", bbox_inches=legend_box)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nrows = 1\n",
    "# ncols = 1\n",
    "# fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 6))\n",
    "# plot_layerwise(ax, relation_result, color = \"steelblue\", label=\"With relation-specific prompt\", linewidth=2.5)\n",
    "# plot_layerwise(ax, relation_result_bare, color = \"#00e6b8\", label=\"W/o relation-specific prompt\", linewidth=2.5)\n",
    "# # plot_layerwise(ax, relation_result, attribute=\"efficacy\", best_criterion=\"efficacy\")\n",
    "# ax.set_ylabel(\"Faithfulness\")\n",
    "# ax.set_title(\"\", fontsize=BIGGER_SIZE, pad=10)\n",
    "# legend = plt.legend(ncol = 2, bbox_to_anchor=(0.5, 1.15), loc='upper center', frameon=False)\n",
    "\n",
    "# export_legend(legend, f\"{fig_dir}/legend-faith-causal.pdf\")\n",
    "# legend.remove()\n",
    "# fig.tight_layout()\n",
    "# plt.savefig(f\"{fig_dir}/{model_name}-layer-mode-switch.pdf\", bbox_inches=\"tight\")\n",
    "# # plt.savefig(f\"{fig_dir}/{model_name}-sweep-bare.pdf\", bbox_inches=\"tight\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "# Every relation name in the dataset, alphabetically.\n",
    "all_relations = sorted(relation.name for relation in dataset.relations)\n",
    "# Relations present in the dataset but missing from the sweep results.\n",
    "failed_relations = sorted(set(all_relations).difference(sweep_results.keys()))\n",
    "\n",
    "failed_relations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ############################################\n",
    "# efficacy_root = \"../../results/efficacy_baselines-24-trials\"\n",
    "# ############################################\"\"\n",
    "\n",
    "# efficacy_path = f\"{efficacy_root}/{model_name}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.sweep_utils import read_efficacy_baseline_results, format_efficacy_baseline_results\n",
    "\n",
    "# efficacy_baseline_results = read_efficacy_baseline_results(efficacy_path)\n",
    "\n",
    "# print(len(efficacy_baseline_results))\n",
    "# list(efficacy_baseline_results.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# relation_name = \"work location\"\n",
    "# relation_result = relation_from_dict(sweep_results[relation_name])\n",
    "# # format_efficacy_baseline_results(\n",
    "# #     efficacy_baseline_results[relation_name]\n",
    "# # )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.rcdefaults()\n",
    "# #####################################################################################\n",
    "# plt.rcdefaults()\n",
    "# plt.rcParams[\"figure.dpi\"] = 200\n",
    "# plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "# SMALL_SIZE = 14\n",
    "# MEDIUM_SIZE = 18\n",
    "# BIGGER_SIZE = 22\n",
    "\n",
    "# plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "# plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "# plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "# plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "# plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "# plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "# #####################################################################################\n",
    "\n",
    "# nrows = 1\n",
    "# ncols = 1\n",
    "# fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 6))\n",
    "# plot_layerwise(ax, relation_result, attribute=\"efficacy\", best_criterion=\"efficacy\")\n",
    "# plot_efficacy_baseline(ax, efficacy_result)\n",
    "# ax.set_ylabel(\"Success @ 1\")\n",
    "# ax.set_title(relation_name, fontsize=BIGGER_SIZE, pad=10)\n",
    "# ax.legend(ncol = 1, bbox_to_anchor=(1, 1), loc='upper right')\n",
    "\n",
    "# plt.savefig(f\"{fig_dir}/{model_name}-causality_baselines.pdf\", bbox_inches=\"tight\")\n",
    "# plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ranks = {\n",
    "#     layer : layer_summary.rank.values\n",
    "#     for layer, layer_summary in relation_result.by_layer().items()\n",
    "# }\n",
    "# ranks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = data.load_dataset()\n",
    "interested_dataset = dataset.filter(\n",
    "    # relation_type = [\"factual\"]\n",
    ")\n",
    "\n",
    "# Minimum number of trials a relation needs in order to be kept (llama-13b uses a lower bar).\n",
    "min_trials = 3 if model_name != \"llama-13b\" else 2\n",
    "\n",
    "filtered_results = {}\n",
    "# The loop yields relation objects; their `.name` is the key into sweep_results.\n",
    "for relation in tqdm(interested_dataset.relations):\n",
    "    if relation.name not in sweep_results:\n",
    "        continue  # no sweep output for this relation\n",
    "    relation_result = relation_from_dict(sweep_results[relation.name])\n",
    "    if len(relation_result.trials) < min_trials:\n",
    "        print(f\"skipping {relation.name}, not enough trials, : {[trial.n_test_samples for trial in relation_result.trials]}\")\n",
    "        continue\n",
    "    filtered_results[relation.name] = relation_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# saving hparams\n",
    "# For each relation, h_layer and beta come from the faithfulness-optimal setting,\n",
    "# h_layer_edit and rank from the efficacy-optimal one.\n",
    "\n",
    "from src import hparams\n",
    "\n",
    "# Beta used to pick the best settings; model-dependent and loop-invariant.\n",
    "beta = 2.25 if \"llama\" not in model_name else 8.0\n",
    "\n",
    "for relation_name, sweep_result in filtered_results.items():\n",
    "    best_by_f = sweep_result.best_by_faithfulness(beta = beta)\n",
    "    best_by_e = sweep_result.best_by_efficacy(beta = beta)\n",
    "    hparams.RelationHParams(\n",
    "        relation_name=sweep_result.relation_name,\n",
    "        h_layer=best_by_f.layer,  # type: ignore\n",
    "        h_layer_edit=best_by_e.layer,  # type: ignore\n",
    "        z_layer=-1,\n",
    "        beta=best_by_f.beta.mean,\n",
    "        rank=int(best_by_e.rank.mean),  # int() truncates the mean rank toward zero\n",
    "        model_name=model_name if \"llama\" not in model_name else \"llama\",\n",
    "    ).save()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sweep Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.rcdefaults()\n",
    "\n",
    "# step_size = 3\n",
    "\n",
    "# for _from in range(0, len(filtered_results), step_size):\n",
    "#     _to = min(len(filtered_results), _from + step_size)\n",
    "#     n_subplots = len(filtered_results) * 3\n",
    "#     n_subplots = (_to - _from)  * 2\n",
    "#     ncols=2\n",
    "#     nrows=int(np.ceil(n_subplots/ncols))\n",
    "#     fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 8, nrows * 6.5))\n",
    "#     if n_subplots == 1:\n",
    "#         axes = [axes]\n",
    "#     if nrows == 1:\n",
    "#         axes = [axes]\n",
    "\n",
    "#     ax_col, ax_row = 0, 0\n",
    "#     for i, (relation_name, relation_result) in list(enumerate(filtered_results.items()))[_from  : _to]:\n",
    "#         print(i, relation_name)\n",
    "#         result = filtered_results[relation_name]\n",
    "#         plot_layerwise(axes[ax_row][0], result)\n",
    "#         plot_layerwise(axes[ax_row][1], result, attribute=\"efficacy\", best_criterion=\"efficacy\")\n",
    "#         # if(relation_name in efficacy_baseline_results):\n",
    "#         #     efficacy_baselines = format_efficacy_baseline_results(efficacy_baseline_results[relation_name])\n",
    "#         #     plot_efficacy_baseline(axes[ax_row][1], efficacy_baselines)\n",
    "#         ax_row += 1\n",
    "\n",
    "#     fig.tight_layout()\n",
    "#     fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference relation whose layer list defines the x-axis of the layerwise plots.\n",
    "relation_result = filtered_results[\"country capital city\"]\n",
    "by_layer = relation_result.by_layer()\n",
    "# Labelled ticks: the first layer, then every third layer from index 4 onward\n",
    "# (equivalent to keys[1::3][1:]).\n",
    "selected_layers = [list(by_layer.keys())[0], *list(by_layer.keys())[4::3]]\n",
    "selected_layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "relation_names = [\n",
    "    \"country capital city\", \"food from country\",\n",
    "    \"word sentiment\", \"adjective comparative\",\n",
    "    \"name birthplace\", \"name religion\",\n",
    "    \"work location\", \"task done by tool\",\n",
    "    \"company CEO\", \"pokemon evolution\", \"star constellation name\", \"person mother\"\n",
    "]\n",
    "\n",
    "#####################################################################################\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 12\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 22\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "n_cols = 4\n",
    "n_rows = int(np.ceil(len(relation_names)/n_cols))\n",
    "# squeeze=False keeps `axes` 2-D for every grid shape (the old manual wrapping\n",
    "# broke `axes[row][col]` indexing when n_cols == 1).\n",
    "fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(n_cols * 4, n_rows * 3.5), squeeze=False)\n",
    "\n",
    "# Shared x-axis, derived here rather than from `by_layer` left over from an earlier cell.\n",
    "# Assumes every plotted relation was swept over the same layers as the first one.\n",
    "layers = list(filtered_results[relation_names[0]].by_layer().keys())\n",
    "# Label the first layer and every third layer from index 4 on; leave the rest blank.\n",
    "selected_layers = [layers[0]] + layers[1::3][1:]\n",
    "layer_labels = [layer if layer in selected_layers else \"\" for layer in layers]\n",
    "\n",
    "for i, relation_name in enumerate(relation_names):\n",
    "    cur_row, cur_col = divmod(i, n_cols)\n",
    "    relation_result = filtered_results[relation_name]\n",
    "    ax = axes[cur_row][cur_col]\n",
    "    plot_layerwise(ax, relation_result)\n",
    "    plot_layerwise(ax, relation_result, attribute=\"efficacy\", best_criterion=\"efficacy\")\n",
    "    ax.set_title(relation_result.relation_name, fontsize=BIGGER_SIZE)\n",
    "    ax.set_xticks(range(len(layers)), layer_labels)\n",
    "    # Only the left-most column carries a y-label.\n",
    "    ax.set_ylabel(\"Score\" if cur_col == 0 else \"\")\n",
    "    ax.set_xlabel(\"\")\n",
    "\n",
    "    if i == 0:\n",
    "        # Export the shared legend as its own PDF, then remove it from the grid.\n",
    "        legend = ax.legend(ncol=2, bbox_to_anchor=(0.5, 1.5), loc='upper center', frameon=False, fontsize=BIGGER_SIZE)\n",
    "        export_legend(legend, f\"{fig_dir}/legend-layerwise-sweep.pdf\")\n",
    "        legend.remove()\n",
    "\n",
    "# Blank out grid cells left over when the relations don't fill the last row.\n",
    "for unused_ax in axes.flat[len(relation_names):]:\n",
    "    unused_ax.axis(\"off\")\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig(f\"{fig_dir}/{model_name}-sweeps.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Causality vs Faithfulness Scatter Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_recall_vs_efficacy_info(sweep_results, criterion = \"best\", threshold = 30, beta = None):\n",
    "    \"\"\"Collect one (recall, efficacy) point per relation for scatter plots.\n",
    "\n",
    "    Args:\n",
    "        sweep_results: mapping from relation name to that relation's sweep result.\n",
    "        criterion: hyper-parameter selection rule. \"best\" takes recall from the\n",
    "            faithfulness-optimal setting and efficacy from the efficacy-optimal one;\n",
    "            \"faithfulness\" takes both from the faithfulness-optimal setting;\n",
    "            \"efficacy\" takes both from the efficacy-optimal setting.\n",
    "        threshold: minimum test-sample count that every trial must reach for the\n",
    "            relation to be flagged as passing.\n",
    "        beta: forwarded only to the selection that supplies the recall value.\n",
    "            NOTE(review): the efficacy value is always selected without `beta`;\n",
    "            confirm this is intended.\n",
    "\n",
    "    Returns:\n",
    "        Four parallel lists (recalls, efficacies, pass_threshold flags, relation\n",
    "        names). Relations without any trials are skipped.\n",
    "    \"\"\"\n",
    "    recalls, efficacies, pass_threshold, labels = [], [], [], []\n",
    "    for relation_name, result in sweep_results.items():\n",
    "        if len(result.trials) == 0:\n",
    "            continue\n",
    "        select_for_recall = (\n",
    "            result.best_by_faithfulness if criterion in [\"best\", \"faithfulness\"]\n",
    "            else result.best_by_efficacy\n",
    "        )\n",
    "        select_for_efficacy = (\n",
    "            result.best_by_efficacy if criterion in [\"best\", \"efficacy\"]\n",
    "            else result.best_by_faithfulness\n",
    "        )\n",
    "        recalls.append(select_for_recall(beta = beta).recall.mean)\n",
    "        efficacies.append(select_for_efficacy().efficacy.mean)\n",
    "\n",
    "        n_test_samples = np.array([trial.n_test_samples for trial in result.trials])\n",
    "        # Pass only if even the smallest trial has enough test samples.\n",
    "        pass_threshold.append(n_test_samples.min() >= threshold)\n",
    "        labels.append(relation_name)\n",
    "\n",
    "    return recalls, efficacies, pass_threshold, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))\n",
    "\n",
    "# One panel per hyper-parameter selection rule.\n",
    "for criterion, ax in zip([\"best\", \"faithfulness\", \"efficacy\"], axes):\n",
    "    recalls, efficacies, pass_threshold, labels = get_recall_vs_efficacy_info(\n",
    "        filtered_results, criterion=criterion, \n",
    "        # beta = 2.25\n",
    "    )\n",
    "    for recall, efficacy, threshold, relation_name in zip(recalls, efficacies, pass_threshold, labels):\n",
    "        # Faint markers for relations whose smallest trial is below the sample threshold.\n",
    "        alpha = .8 if threshold else 0.2\n",
    "        ax.scatter(recall, efficacy, color=\"blue\", alpha=alpha)\n",
    "        # if threshold and (recall/efficacy < .7 or efficacy/recall < .7):\n",
    "        #     ax.annotate(relation_name, (recall, efficacy))\n",
    "\n",
    "    ax.set_title(f\"hparam selection = {criterion}\")\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(0, 1)\n",
    "    ax.set_xlabel(\"Recall@1\")\n",
    "    ax.set_ylabel(\"Efficacy\")\n",
    "\n",
    "fig.tight_layout()\n",
    "# plt.show() rather than fig.show(): the latter warns on non-GUI backends such as inline.\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Faithfulness/causality correlation at beta=4, with hparams selected by efficacy.\n",
    "# The criterion is explicit; it used to be read from the loop variable of the cell above.\n",
    "recalls, efficacies, pass_threshold, labels = get_recall_vs_efficacy_info(\n",
    "    filtered_results, criterion=\"efficacy\", \n",
    "    beta = 4\n",
    ")\n",
    "correlation = np.corrcoef(recalls, efficacies)[0, 1]\n",
    "correlation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "import torch  # no longer used below; kept in case later cells rely on it\n",
    "\n",
    "def get_linear_model(recalls: list[float], efficacies: list[float]):\n",
    "    \"\"\"Fit `efficacy ≈ slope * recall + y_intercept` by ordinary least squares.\n",
    "\n",
    "    Args:\n",
    "        recalls: x values, one per relation.\n",
    "        efficacies: y values, parallel to `recalls`.\n",
    "\n",
    "    Returns:\n",
    "        (slope, y_intercept, r_squared) as Python floats; r_squared is the\n",
    "        coefficient of determination of the fit on the same data.\n",
    "    \"\"\"\n",
    "    # float64 numpy inputs: the former torch.Tensor round-trip downcast to float32.\n",
    "    X = np.asarray(recalls, dtype=np.float64).reshape(-1, 1)\n",
    "    y = np.asarray(efficacies, dtype=np.float64)\n",
    "    lm = LinearRegression().fit(X, y)\n",
    "    slope = float(lm.coef_[0])\n",
    "    y_intercept = float(lm.intercept_)\n",
    "    r_squared = float(lm.score(X, y))\n",
    "    return slope, y_intercept, r_squared"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Beta grid of the sweep, read from one relation's first trial/layer.\n",
    "# Uses an explicit relation instead of `relation_result` leaked from an earlier loop;\n",
    "# assumes every relation was swept over the same betas (TODO confirm).\n",
    "reference_result = next(iter(filtered_results.values()))\n",
    "beta_options = [b.beta for b in reference_result.trials[0].layers[0].result.betas]\n",
    "# beta_options"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcdefaults()\n",
    "num_figs = len(beta_options)\n",
    "n_cols = 3\n",
    "n_rows = int(np.ceil(num_figs/n_cols))\n",
    "# Use n_cols (not a literal 3) so the grid and the row/col bookkeeping agree, and\n",
    "# squeeze=False so `axes[row][col]` also works when there is a single row of betas.\n",
    "fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(5*n_cols, 5*n_rows), squeeze=False)\n",
    "\n",
    "x = np.linspace(0, 1, 10)  # x values for drawing the fitted lines\n",
    "scores_by_beta = []\n",
    "for i, beta in enumerate(beta_options):\n",
    "    cur_row, cur_col = divmod(i, n_cols)\n",
    "    ax = axes[cur_row][cur_col]\n",
    "    recalls, efficacies, pass_threshold, labels = get_recall_vs_efficacy_info(filtered_results, criterion=\"efficacy\", beta = beta)\n",
    "    correlation = np.corrcoef(recalls, efficacies)[0, 1]\n",
    "\n",
    "    # One summary row per beta (string-formatted for display).\n",
    "    scores_by_beta.append({\n",
    "        \"beta\": f\"{beta:.2f}\",\n",
    "        \"recall_mean\": f\"{np.mean(recalls):.2f} ± {np.std(recalls):.2f}\",\n",
    "        \"efficacy_mean\": f\"{np.mean(efficacies):.2f} ± {np.std(efficacies):.2f}\",\n",
    "        \"R\": f\"{correlation:.2f}\"\n",
    "    })\n",
    "\n",
    "    for recall, efficacy, threshold in zip(recalls, efficacies, pass_threshold):\n",
    "        # Faint markers for relations whose smallest trial is below the sample threshold.\n",
    "        ax.scatter(recall, efficacy, color=\"darkblue\", alpha=.8 if threshold else 0.2)\n",
    "\n",
    "    slope, y_intercept, r_squared = get_linear_model(recalls, efficacies)\n",
    "    print(f\"{beta=} | {r_squared=:.2f}, {slope=:.3f}, {y_intercept=:.3f}\")\n",
    "    ax.plot(x, slope*x + y_intercept, color=\"red\", linestyle=\"--\", alpha=0.5)\n",
    "    ax.set_title(f\"best by efficacy (beta={beta : .2f}, R={correlation:.2f})\")\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(0, 1)\n",
    "    ax.set_xlabel(\"Recall@1\")\n",
    "    ax.set_ylabel(\"Efficacy\")\n",
    "\n",
    "# Hide grid cells not used by any beta.\n",
    "for unused_ax in axes.flat[num_figs:]:\n",
    "    unused_ax.axis(\"off\")\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 12\n",
    "MEDIUM_SIZE = 16\n",
    "BIGGER_SIZE = 18\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE+1)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=50)  # fontsize of the figure title\n",
    "#####################################################################################\n",
    "\n",
    "# Faithfulness vs causality, with each relation's hparams selected by efficacy\n",
    "# at the model's default beta.\n",
    "recalls, efficacies, pass_threshold, labels = get_recall_vs_efficacy_info(\n",
    "    filtered_results, \n",
    "    criterion=\"efficacy\", # \"faithfulness\",\n",
    "    beta=2.25 if \"llama\" not in model_name else 8.0\n",
    ")\n",
    "correlation = np.corrcoef(recalls, efficacies)[0, 1]\n",
    "print(f\"Correlation: {correlation :.2f}\")\n",
    "\n",
    "# Least-squares fit; only used by the optional regression line / outlier labels below.\n",
    "x = np.linspace(0, 1, 10)\n",
    "slope, y_intercept, r_squared = get_linear_model(recalls, efficacies)\n",
    "# plt.plot(x, slope*x + y_intercept, color=\"red\", linestyle=\"--\", alpha=0.5)\n",
    "\n",
    "for recall, efficacy, threshold, relation_name in zip(recalls, efficacies, pass_threshold, labels):\n",
    "    # Faint markers for relations whose smallest trial is below the sample threshold.\n",
    "    alpha = .8 if threshold else 0.2\n",
    "    plt.scatter(recall, efficacy, color=\"darkblue\", alpha=alpha)\n",
    "    # To label outliers, uncomment (perpendicular distance to the fitted line):\n",
    "    # dist = np.abs(slope * recall - efficacy + y_intercept)/(slope**2 + 1)**.5\n",
    "    # if dist > 0.2:\n",
    "    #     plt.annotate(relation_name, (recall, efficacy))\n",
    "\n",
    "plt.xlim(0, 1)\n",
    "plt.ylim(0, 1)\n",
    "plt.xlabel(\"Faithfulness\")\n",
    "plt.ylabel(\"Causality\")\n",
    "# set_aspect takes a number; the string '.9' only worked through an implicit float() cast.\n",
    "plt.gca().set_aspect(0.9)\n",
    "plt.savefig(f\"{fig_dir}/{model_name}-efficacy_vs_faithfulness_faith.pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Per-beta summary table, also written to CSV.\n",
    "df = pd.DataFrame(scores_by_beta)\n",
    "# LaTeX export (pandas >= 1.4 spells hide_index() as hide(axis=\"index\")):\n",
    "# print(df[[\"beta\", \"recall_mean\", \"R\"]].style.hide(axis=\"index\").to_latex())\n",
    "os.makedirs(\"../../results/tables\", exist_ok=True)\n",
    "df.to_csv(\n",
    "    f\"../../results/tables/{model_name}-beta-R.csv\",\n",
    "    index=False,\n",
    ")\n",
    "df"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Faithfulness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-readable names for the short model identifiers used in this notebook.\n",
    "model_name_dict = dict(\n",
    "    [\n",
    "        (\"gpt2-xl\", \"GPT2-xl\"),\n",
    "        (\"gptj\", \"GPT-J\"),\n",
    "        (\"llama-13b\", \"LLaMA-13B\"),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "\n",
    "# Checkpoint names passed to `from_pretrained` for each short model name\n",
    "# (\"llama-13b\" is presumably a local path).\n",
    "model_full_name = {\n",
    "    \"gptj\": \"EleutherAI/gpt-j-6B\",\n",
    "    \"gpt2-xl\": \"gpt2-xl\",\n",
    "    \"llama-13b\": \"llama-13b\"\n",
    "}\n",
    "\n",
    "if \"llama\" not in model_name:\n",
    "    tokenizer = transformers.AutoTokenizer.from_pretrained(model_full_name[model_name])\n",
    "    # Reuse the EOS token for padding.\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "else:\n",
    "    tokenizer = transformers.LlamaTokenizerFast.from_pretrained(model_full_name[model_name])\n",
    "    # Pin \"</s>\" (id 2) as both the padding and the EOS token.\n",
    "    tokenizer.pad_token = \"</s>\"\n",
    "    tokenizer.eos_token = \"</s>\"\n",
    "    tokenizer.pad_token_id = 2\n",
    "    tokenizer.eos_token_id = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#####################################################################################\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 28\n",
    "MEDIUM_SIZE = 35\n",
    "BIGGER_SIZE = 40\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", labelsize=BIGGER_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "relation_recall = []\n",
    "\n",
    "for relation_name in filtered_results:\n",
    "    relation_recall.append({\n",
    "        \"relation\": relation_name,\n",
    "        \"recall@1\": filtered_results[relation_name].best_by_efficacy(\n",
    "            beta=2.25 if \"llama\" not in model_name else 8.0\n",
    "        ).recall.mean\n",
    "    })\n",
    "\n",
    "if model_name == \"gptj\":\n",
    "    relation_recall = sorted(relation_recall, key = lambda x: x[\"recall@1\"])\n",
    "    with open(f\"relation_order_{model_name}.json\", \"w\") as f:\n",
    "        json.dump(relation_recall, f)\n",
    "    plot_info = relation_recall\n",
    "else:\n",
    "    with open(\"relation_order_gptj.json\", \"r\") as f:\n",
    "        relation_order = json.load(f)\n",
    "    relation_recall = {r[\"relation\"]: r[\"recall@1\"] for r in relation_recall}\n",
    "\n",
    "    plot_info = []\n",
    "\n",
    "    for r in relation_order:\n",
    "        relation = r[\"relation\"]\n",
    "        recall = relation_recall[relation] if relation in relation_recall else 0\n",
    "        plot_info.append({\n",
    "            \"relation\": relation,\n",
    "            \"recall@1\": recall\n",
    "        })\n",
    "    \n",
    "    with open(f\"relation_order_{model_name}.json\", \"w\") as f:\n",
    "        json.dump(plot_info, f)\n",
    "\n",
    "\n",
    "relations = [r[\"relation\"] for r in plot_info]\n",
    "recalls = [r[\"recall@1\"] for r in plot_info]\n",
    "\n",
    "plt.figure(figsize = (10, 20))\n",
    "plt.barh(np.arange(len(relations)), recalls, color = \"steelblue\", alpha = 0.7)\n",
    "plt.yticks(np.arange(len(relations)), relations)\n",
    "plt.xticks(np.linspace(0, 1, 11), [np.round(v, 1) for v in np.linspace(0, 1, 11)])\n",
    "plt.ylim(-0.7,len(plot_info)-.3)\n",
    "plt.xlabel(\"Faithfulness\")\n",
    "plt.xlim(0, 1)\n",
    "\n",
    "with open(f\"../../results/tables/known/{model_name}.json\") as f:\n",
    "    model_known = json.load(f)\n",
    "\n",
    "for idx in range(len(plot_info)):\n",
    "    relation_name = plot_info[idx][\"relation\"]\n",
    "    d_relation = dataset.filter(relation_names=[relation_name])[0]\n",
    "    ans_counter = {}\n",
    "    for sample in d_relation.samples:\n",
    "    # for sample in model_known[relation_name][\"known_samples\"]:\n",
    "        obj = sample.object\n",
    "        # obj = sample[\"object\"]\n",
    "        first_token = tokenizer(obj, return_tensors=\"pt\").input_ids[0][1 if \"llama\" in model_name else 0]\n",
    "        obj = tokenizer.decode(first_token)\n",
    "        if obj not in ans_counter:\n",
    "            ans_counter[obj] = 0\n",
    "        ans_counter[obj] += 1\n",
    "    print(relation_name, end=\": \")\n",
    "    # majority_ans, majority_ans_count = max(ans_counter.items(), key = lambda x: x[1]) if len(ans_counter) > 0 else (0, 0)\n",
    "\n",
    "    # # divide_by = len(model_known[relation_name][\"known_samples\"]) if len(ans_counter) > 0 else 1\n",
    "    # divide_by = len(d_relation.samples)\n",
    "    # print(f\" -- {majority_ans} --- {majority_ans_count} / {divide_by}\")\n",
    "    # random_baseline = majority_ans_count/divide_by\n",
    "\n",
    "    divide_by = len(d_relation.samples)\n",
    "    random_baseline = 0\n",
    "    for o, o_count in ans_counter.items():\n",
    "        random_baseline += o_count**2\n",
    "    random_baseline /= divide_by**2\n",
    "\n",
    "    print(ans_counter)\n",
    "    plt.scatter(random_baseline, idx, color = \"darkred\", alpha = 1, marker=\"|\", s = 700)\n",
    "\n",
    "for x_tick in np.linspace(0, 1, 11):\n",
    "    plt.axvline(x_tick, color = \"black\", alpha = 0.05)\n",
    "\n",
    "plt.title(f\"LRE faithfulness in {model_name_dict[model_name]}\", x = 0.27, pad=15, fontsize=BIGGER_SIZE)\n",
    "\n",
    "fig = plt.gcf()\n",
    "fig.set_size_inches(10, 27)\n",
    "\n",
    "plt.savefig(f\"{fig_dir}/{model_name}-faithfulness_lre_relationwise.pdf\", bbox_inches=\"tight\")\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table = []\n",
    "dataset = data.load_dataset()\n",
    "\n",
    "for relation_name, sweep_result in filtered_results.items():\n",
    "    efficacy_hparams = sweep_result.best_by_efficacy(beta=2.25 if \"llama\" not in model_name else 8.0)\n",
    "    relation = dataset.filter(\n",
    "        relation_names=[relation_name]\n",
    "    )[0]\n",
    "    table.append({\n",
    "        \"relation\": relation_name,\n",
    "        \"n_range\": len(relation.range),\n",
    "        \"layer\": efficacy_hparams.layer,\n",
    "        \"beta\": f\"{efficacy_hparams.beta.mean: .2f} ± {efficacy_hparams.beta.stdev: .2f}\",\n",
    "        \"rank\": f\"{int(efficacy_hparams.rank.mean)} ± {int(efficacy_hparams.rank.stdev)}\",\n",
    "        \"recall@1\": f\"{efficacy_hparams.recall.mean: .2f} ± {efficacy_hparams.recall.stdev: .2f}\",\n",
    "        \"efficacy\": f\"{efficacy_hparams.efficacy.mean: .2f} ± {efficacy_hparams.efficacy.stdev: .2f}\",\n",
    "        # \"n_range\": f\"{len(relation.range)}\",\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sorted_table = sorted(table, key=lambda x: x[\"efficacy\"], reverse=True)\n",
    "sorted_table = sorted(table, key=lambda x: x[\"relation\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(sorted_table)\n",
    "os.makedirs(\"../../results/tables\", exist_ok=True)\n",
    "df.to_csv(f\"../../results/tables/{model_name}-hparams.csv\", index=False)\n",
    "# print(df.to_markdown(index = False, tablefmt=\"github\"))\n",
    "# print(df.to_latex(index=False, escape=False))\n",
    "df"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Single vs Multi token subjects"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def read_and_parse_sweep_results(sweep_path):\n",
    "#     sweep_results = read_sweep_results(sweep_path)\n",
    "#     for relation in sweep_results:\n",
    "#         sweep_results[relation] = parse_results(sweep_results[relation])\n",
    "#     return sweep_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sweep_single_path = f\"../../results/sweep-single/{model_name}\"\n",
    "# sweep_single = read_and_parse_sweep_results(sweep_single_path)\n",
    "\n",
    "# sweep_multi_path = f\"../../results/sweep-multi/{model_name}\"\n",
    "# sweep_multi = read_and_parse_sweep_results(sweep_multi_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 4.5))\n",
    "\n",
    "# for criterion, ax in zip([\"best\", \"faithfulness\", \"efficacy\"], axes):\n",
    "#     single_recalls, single_efficacies, pass_single, single_labels = get_recall_vs_efficacy_info(sweep_single, criterion=criterion)\n",
    "#     multi_recalls, multi_efficacies, pass_multi, multi_labels = get_recall_vs_efficacy_info(sweep_multi, criterion=criterion)\n",
    "#     labeled = False\n",
    "#     for r, e, p, l in zip(single_recalls, single_efficacies, pass_single, single_labels):\n",
    "#         alpha = .8 if p else 0.2\n",
    "#         if (alpha == 0.8 and not labeled):\n",
    "#             ax.scatter(r, e, color = \"blue\", label = \"single\", alpha = alpha)\n",
    "#             labeled = True\n",
    "#         else:\n",
    "#             ax.scatter(r, e, color = \"blue\", alpha = alpha)\n",
    "#         if p and (r/e < .7 or e/r < .7):\n",
    "#             ax.annotate(l, (r, e))\n",
    "\n",
    "#     labeled = False\n",
    "#     for r, e, p, l in zip(multi_recalls, multi_efficacies, pass_multi, multi_labels):\n",
    "#         alpha = .8 if p else 0.2\n",
    "#         alpha = .8 if p else 0.2\n",
    "#         if (alpha == 0.8 and not labeled):\n",
    "#             ax.scatter(r, e, color = \"red\", label = \"multi\", marker = \"s\", alpha = alpha)\n",
    "#             labeled = True\n",
    "#         else:\n",
    "#             ax.scatter(r, e, color = \"red\", marker = \"s\", alpha = alpha)\n",
    "#         if p and (r/e < .7 or e/r < .7):\n",
    "#             ax.annotate(l, (r, e))\n",
    "\n",
    "#     ax.set_title(f\"Efficacy vs Recall ({criterion})\")\n",
    "#     ax.set_xlim(0, 1)\n",
    "#     ax.set_ylim(0, 1)\n",
    "#     ax.set_xlabel(\"Recall\")\n",
    "#     ax.set_ylabel(\"Efficacy\")\n",
    "#     ax.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
