{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figures for the paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from viz import get_result_dfs\n",
    "from elk_generalization.utils import get_quirky_model_name\n",
    "\n",
    "# Models evaluated in this notebook: the Pythia suite, then Llama-2 and Mistral.\n",
    "models = [f\"EleutherAI/pythia-{size}\" for size in (\"410m\", \"1b\", \"1.4b\", \"2.8b\", \"6.9b\", \"12b\")] + [\n",
    "    \"meta-llama/Llama-2-7b-hf\",\n",
    "    \"mistralai/Mistral-7B-v0.1\",\n",
    "]\n",
    "# Size in billions of parameters, keyed by the last path component of each model id.\n",
    "model_scales = {\n",
    "    \"pythia-410m\": 0.41,\n",
    "    \"pythia-1b\": 1,\n",
    "    \"pythia-1.4b\": 1.4,\n",
    "    \"pythia-2.8b\": 2.8,\n",
    "    \"pythia-6.9b\": 6.9,\n",
    "    \"pythia-12b\": 12,\n",
    "    \"Llama-2-7b-hf\": 7,\n",
    "    \"Mistral-7B-v0.1\": 7,\n",
    "}\n",
    "# Display names for each probing method.\n",
    "method_titles = {\n",
    "    \"lr\": \"LogR\",\n",
    "    \"mean-diff\": \"Diff-in-means\",\n",
    "    \"mean-diff-on-pair\": \"Diff-in-means on contrast pair\",\n",
    "    \"lda\": \"LDA\",\n",
    "    \"lr-on-pair\": \"LogR on contrast pair\",\n",
    "    \"ccs\": \"CCS\",\n",
    "    \"crc\": \"CRC\",\n",
    "}\n",
    "\n",
    "# Dataset name -> short abbreviation used in plot legends and table headers.\n",
    "# The insertion order of this dict is the dataset order used throughout the notebook.\n",
    "ds_abbrevs = {\n",
    "    \"capitals\": \"cap\",\n",
    "    \"hemisphere\": \"hem\",\n",
    "    \"population\": \"pop\",\n",
    "    \"sciq\": \"sciq\",\n",
    "    \"sentiment\": \"snt\",\n",
    "    \"nli\": \"nli\",\n",
    "    \"authors\": \"aut\",\n",
    "    \"addition\": \"add\",\n",
    "    \"subtraction\": \"sub\",\n",
    "    \"multiplication\": \"mul\",\n",
    "    \"modularaddition\": \"mod\",\n",
    "    \"squaring\": \"sqr\",\n",
    "}\n",
    "ds_names = list(ds_abbrevs)\n",
    "\n",
    "root = \"../../experiments/\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Qualitative differences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_ds_names = ds_names.copy()\n",
    "plot_ds_names.remove(\"authors\")  # authors is only False for disagreements\n",
    "plot_models = models  # or e.g. [\"mistralai/Mistral-7B-v0.1\"] for a single model\n",
    "fr, to = \"A\", \"B\"\n",
    "filter_by = \"disagree\"\n",
    "weak_only = False\n",
    "metric = \"auroc\"\n",
    "methods = [\"lr\"]\n",
    "templatization_method = \"first\"\n",
    "standardize_templates = False\n",
    "full_finetuning = False\n",
    "\n",
    "# One get_result_dfs result tuple per probing method, all using the settings above.\n",
    "rs = {\n",
    "    reporter: get_result_dfs(\n",
    "        plot_models, fr, to, plot_ds_names,\n",
    "        label_col=\"alice_label\",\n",
    "        filter_by=filter_by,\n",
    "        metric=metric,\n",
    "        reporter=reporter,\n",
    "        root_dir=root,\n",
    "        weak_only=weak_only,\n",
    "        templatization_method=templatization_method,\n",
    "        standardize_templates=standardize_templates,\n",
    "        full_finetuning=full_finetuning,\n",
    "    )\n",
    "    for reporter in methods\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_context(\"paper\")\n",
    "\n",
    "# Single panel: every method in `methods` is drawn on this one Axes.\n",
    "fig, ax = plt.subplots(1, 1, figsize=(6, 3), dpi=200)\n",
    "\n",
    "y_labels = {\n",
    "    \"disagree\": f\"{metric.upper()}\" + \" on $\\\\bf{disagreements}$\",\n",
    "    \"agree\": f\"{metric.upper()}\" + \" on $\\\\bf{agreements}$\",\n",
    "    \"all\": f\"{metric.upper()}\" + \" on $\\\\bf{all\\\\ examples}$\",\n",
    "}\n",
    "\n",
    "for i, method in enumerate(methods):\n",
    "    avg_reporter_results, per_ds_results_dfs, all_result_dfs, avg_lm_result, per_ds_lm_result_dfs, lm_results = rs[method]\n",
    "    colors = sns.color_palette(\"tab20\", len(per_ds_results_dfs))\n",
    "    # NOTE(review): probe and LM results are paired positionally, which assumes both\n",
    "    # dicts list datasets in the same order -- confirm in viz.get_result_dfs.\n",
    "    for j, (key, result_df, lm_result) in enumerate(zip(per_ds_results_dfs.keys(), per_ds_results_dfs.values(), per_ds_lm_result_dfs.values())):\n",
    "        # solid: layerwise probe metric; dotted: the paired LM result as a flat reference\n",
    "        ax.plot(result_df[\"layer_frac\"], result_df[metric], alpha=0.9, color=colors[j], linewidth=0.8, label=ds_abbrevs[key])\n",
    "        ax.hlines(lm_result, 0, 1, color=colors[j], linewidth=1, linestyle=\":\")\n",
    "\n",
    "    # legend and axis labels belong to the shared Axes, so set them only once\n",
    "    if i == 0:\n",
    "        ax.legend(loc=[1.01, 0.01])\n",
    "        ax.set_ylabel(y_labels[filter_by], fontsize=12)\n",
    "        ax.set_xlabel(\"Layer (fraction of max)\", fontsize=12)\n",
    "\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(-0.01, 1.01)\n",
    "\n",
    "plt.title(f\"Layerwise {metric.upper()} for {fr}$\\\\to${to}\" + (\" weak only\" if weak_only else \"\"), fontsize=14)\n",
    "plt.tight_layout()\n",
    "os.makedirs(\"../../figures\", exist_ok=True)\n",
    "# Name the file after the metric actually plotted and the weak-only setting, so\n",
    "# runs with different settings do not overwrite each other.\n",
    "weak_suffix = \"_weak_only\" if weak_only else \"\"\n",
    "plt.savefig(f\"../../figures/layerwise_{metric}_qualitative_{fr}_{to}{weak_suffix}.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Random classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from pathlib import Path\n",
    "from viz import interpolate\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from elk_generalization.utils import get_quirky_model_name\n",
    "\n",
    "# NOTE: templatization_method, standardize_templates and full_finetuning are read from\n",
    "# the config in the \"Qualitative differences\" section; run that cell first.\n",
    "\n",
    "# set color palette\n",
    "palette = sns.color_palette(\"tab20\", 20)\n",
    "sns.set_palette(palette)\n",
    "fr, to = \"AE\", \"BH\"\n",
    "against_col = \"alice_labels\"\n",
    "root = Path(\"../../experiments\")\n",
    "ceiling_root = Path(\"../../experiments-ceiling\")\n",
    "plt.figure(figsize=(5, 3.1), dpi=100)\n",
    "cmap = plt.get_cmap('coolwarm')\n",
    "random_ds_names = ds_names.copy()\n",
    "random_ds_names.remove(\"population\")\n",
    "\n",
    "# Quantile keys of the random-baseline AUROC distribution, read from the first results\n",
    "# file that loads; every file's quantiles are indexed positionally by q_idx below.\n",
    "quantile_keys = None\n",
    "\n",
    "for q_idx in [30, 20, 17, 16, 15, 14, 13, 10, 0,]:\n",
    "    layers_dict = dict()\n",
    "    aurocs_dict = dict()\n",
    "    for model in models:\n",
    "        for ds_name in random_ds_names:\n",
    "            model_last = model.split(\"/\")[-1]\n",
    "            try:\n",
    "                quirky_model_id, quirky_model_last = get_quirky_model_name(ds_name, model_last, templatization_method=templatization_method, standardize_templates=standardize_templates, full_finetuning=full_finetuning)\n",
    "                results = torch.load(root / quirky_model_last / to / \"test\" / f\"{fr}_random_aucs_against_{against_col}.pt\", map_location=\"cpu\")\n",
    "            except FileNotFoundError as e:\n",
    "                # Name the model/dataset directly: quirky_model_id is stale (or unset)\n",
    "                # when get_quirky_model_name itself is what raised.\n",
    "                print(f\"skipping {model_last} on {ds_name}\", e)\n",
    "                continue\n",
    "            if quantile_keys is None:\n",
    "                quantile_keys = list(results[0][\"quantiles\"].keys())\n",
    "            aurocs = [list(results[i][\"quantiles\"].values())[q_idx] for i in range(len(results))]\n",
    "            \n",
    "            if not np.isfinite(np.array(aurocs)).all():\n",
    "                print(f\"skipping {quirky_model_id} due to NaN\")\n",
    "                continue\n",
    "            layers_dict[(model, ds_name)] = np.arange(len(results))\n",
    "            aurocs_dict[(model, ds_name)] = aurocs\n",
    "\n",
    "    if not aurocs_dict:\n",
    "        print(f\"no usable random-baseline results for quantile index {q_idx}; skipping\")\n",
    "        continue\n",
    "    layer_fracs, avg_aurocs = interpolate(list(layers_dict.values()), list(aurocs_dict.values()), layers_dict.keys(), 501)\n",
    "\n",
    "    q = quantile_keys[q_idx]\n",
    "    # label quantiles as 2^k at or below the median and as 1 - 2^k above it\n",
    "    lab = \"$2^{\" + str(int(np.log2(q))) + \"}$\" if q <= 0.5 else \"$1-2^{\" + str(int(np.log2(1 - q))) + \"}$\"\n",
    "    plt.plot(layer_fracs, avg_aurocs, label=lab, linewidth=2, color=cmap(q))\n",
    "\n",
    "avg_reporter_results, _, _, _, _, _ = get_result_dfs(models, \"B\", \"BH\", random_ds_names, label_col=\"alice_label\", filter_by=\"all\", metric=\"auroc\", reporter=\"lr\", root_dir=ceiling_root, weak_only=False)\n",
    "plt.plot(avg_reporter_results[\"layer_frac\"], avg_reporter_results[\"auroc\"], label=\"LogR ceiling\", linewidth=2, color=\"black\", linestyle=\"--\")\n",
    "\n",
    "# turn on horizontal grid\n",
    "plt.grid(axis=\"y\")\n",
    "\n",
    "plt.legend(loc=[1.07, -0.1], title=\"quantile\", fontsize=11)\n",
    "plt.ylim(-0.01, 1.01)\n",
    "plt.xlim(0, 1)\n",
    "plt.xlabel(\"Layer (fraction of max)\", fontsize=13)\n",
    "plt.ylabel(\"AUROC on\\nall examples\", fontsize=13)\n",
    "plt.title(f\"{fr}$\\\\to${to} random baseline\", fontsize=13)\n",
    "plt.tight_layout()\n",
    "os.makedirs(\"../../figures\", exist_ok=True)\n",
    "plt.savefig(f\"../../figures/layerwise_auroc_random_{fr}_{to}.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Barplot evaluating difficulty metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# exclude population\n",
    "difficulty_ds_names = ds_names.copy()\n",
    "difficulty_ds_names.remove(\"population\")\n",
    "\n",
    "# Long-format records, one per (character, difficulty, dataset, model), with\n",
    "# columns ds_name, lm_auroc, character, model_name, difficulty.\n",
    "records = []\n",
    "for character in [\"A\", \"B\"]:\n",
    "    for difficulty in [\"E\", \"H\"]:\n",
    "        to = character + difficulty\n",
    "        # Only the LM results are used here; the source distribution (\"AE\") and the\n",
    "        # reporter (\"lr\") are placeholder arguments.\n",
    "        _, _, _, _, _, lm_results = get_result_dfs(models, \"AE\", to, difficulty_ds_names, label_col=\"label\", filter_by=\"all\", metric=\"auroc\", reporter=\"lr\", root_dir=root, weak_only=False)\n",
    "        records.extend(\n",
    "            {\n",
    "                \"ds_name\": ds_abbrevs[ds_name],\n",
    "                \"lm_auroc\": lm_results[(model_name, ds_name)],\n",
    "                \"character\": character,\n",
    "                \"model_name\": model_name,\n",
    "                \"difficulty\": difficulty,\n",
    "            }\n",
    "            for ds_name in difficulty_ds_names\n",
    "            for model_name in models\n",
    "        )\n",
    "\n",
    "df = pd.DataFrame.from_records(records)\n",
    "df[\"difficulty\"] = df[\"difficulty\"].replace({\"E\": \"easy\", \"H\": \"hard\"})\n",
    "df[\"character\"] = df[\"character\"].replace({\"A\": \"Alice\", \"B\": \"Bob\"})\n",
    "\n",
    "alice_df = df[df[\"character\"] == \"Alice\"]\n",
    "AE_avg = alice_df.loc[alice_df[\"difficulty\"] == \"easy\", \"lm_auroc\"].mean()\n",
    "AH_avg = alice_df.loc[alice_df[\"difficulty\"] == \"hard\", \"lm_auroc\"].mean()\n",
    "print(f\"Alice easy: {AE_avg:.3f}, Alice hard: {AH_avg:.3f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot: one panel per character, bars grouped by dataset and split by difficulty\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_context(\"paper\")\n",
    "fig, axs = plt.subplots(1, 2, figsize=(8, 3), dpi=100)\n",
    "\n",
    "for ax, character in zip(axs, [\"Alice\", \"Bob\"]):\n",
    "    sns.barplot(\n",
    "        data=df[df[\"character\"] == character],\n",
    "        x=\"ds_name\",\n",
    "        y=\"lm_auroc\",\n",
    "        hue=\"difficulty\",\n",
    "        ax=ax,\n",
    "        legend=True,\n",
    "        linewidth=2,\n",
    "    )\n",
    "\n",
    "    # rotate dataset labels by 45 degrees and right-align them\n",
    "    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment=\"right\")\n",
    "    ax.set_title(f\"{character} LM AUROC\", fontsize=13)\n",
    "    ax.set_ylabel(f\"AUROC against {character}'s labels\", fontsize=13)\n",
    "    ax.set_xlabel(\"\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/lm_auroc_by_difficulty.pdf\", bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# All transfer experiments for the appendix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (source, target, example filter) for each experiment\n",
    "exps = [\n",
    "    (\"A\", \"A\", \"disagree\"), (\"A\", \"B\", \"disagree\"), (\"B\", \"B\", \"disagree\"), (\"B\", \"A\", \"disagree\"),\n",
    "    (\"A\", \"AH\", \"all\"), (\"AE\", \"AH\", \"all\"), (\"A\", \"BH\", \"all\"), (\"AE\", \"BH\", \"all\"),\n",
    "]\n",
    "metric = \"auroc\"\n",
    "ds_name = \"nli\"\n",
    "reporter = \"mean-diff\"\n",
    "root = \"../../experiments/\"\n",
    "rs = {\n",
    "    (fr, to): get_result_dfs(models, fr, to, [ds_name], label_col=\"alice_label\", filter_by=filter_by, metric=metric, reporter=reporter, root_dir=root)\n",
    "    for fr, to, filter_by in exps\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "from matplotlib.ticker import MultipleLocator\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "sns.set_context(\"paper\")\n",
    "\n",
    "# Color per-model curves by log model size; the same bounds drive the colorbar below.\n",
    "cmap_name = \"winter\"\n",
    "log_scale_min = np.log(min(model_scales.values()))\n",
    "log_scale_max = np.log(max(model_scales.values()))\n",
    "cmap = lambda x: plt.get_cmap(cmap_name)((np.log(x) - log_scale_min) / (log_scale_max - log_scale_min))\n",
    "\n",
    "fig, axes = plt.subplots(4, 2, sharex=True, sharey=True, figsize=(10, 7), dpi=200)\n",
    "\n",
    "for i, (fr, to, filter_by) in enumerate(exps):\n",
    "    ax = axes[i // 2][i % 2]\n",
    "    avg_reporter_results, _, results_dfs, avg_lm_result, _, lm_results = rs[(fr, to)]\n",
    "    # thin lines: one per (model, dataset) key, colored by model size\n",
    "    for key, result_df in results_dfs.items():\n",
    "        ax.plot(result_df[\"layer_frac\"], result_df[metric], alpha=0.4, color=cmap(model_scales[key[0].split(\"/\")[-1]]), linewidth=0.8)\n",
    "\n",
    "    # label the averaged curve with the reporter actually plotted\n",
    "    ax.plot(avg_reporter_results[\"layer_frac\"], avg_reporter_results[metric], label=f\"{method_titles[reporter]} probe\", linewidth=2, color=\"fuchsia\")\n",
    "\n",
    "    ax.hlines(avg_lm_result, 0, 1, label=\"Final layer LM output\", color=\"dodgerblue\", linewidth=2, linestyle=\"-\")\n",
    "    ax.hlines(0.5, 0, 1, label=\"random\", color=\"black\", linewidth=0.5, linestyle=\"--\")\n",
    "\n",
    "    if i % 2 == 0:\n",
    "        lab = {\n",
    "            \"disagree\": f\"{metric.upper()}\" + \" on\\n$\\\\bf{disagreements}$\",\n",
    "            \"agree\": f\"{metric.upper()}\" + \" on\\n$\\\\bf{agreements}$\",\n",
    "            \"all\": f\"{metric.upper()}\" + \" on\\n$\\\\bf{all\\\\ examples}$\",\n",
    "        }[filter_by]\n",
    "        ax.set_ylabel(lab, fontsize=11.5)\n",
    "    if i >= 6:\n",
    "        ax.set_xlabel(\"Layer (fraction of max)\", fontsize=12)\n",
    "    if fr == to:\n",
    "        title = fr.title() + \" (no transfer)\"\n",
    "    else:\n",
    "        title = (f\"{fr} → {to}\")\n",
    "    ax.set_title(title, fontsize=13)\n",
    "    ax.set_xlim(0, 1)\n",
    "    ax.set_ylim(-0.01, 1.01)\n",
    "    if i == 0:\n",
    "        ax.legend(loc=\"lower right\")\n",
    "\n",
    "    ax.yaxis.set_major_locator(MultipleLocator(0.25))\n",
    "plt.suptitle(f\"{ds_name.capitalize()} ({method_titles[reporter]})\", fontsize=16)\n",
    "plt.tight_layout()\n",
    "\n",
    "# vertical colorbar on the same log-size scale used by cmap above\n",
    "norm = mpl.colors.Normalize(vmin=log_scale_min, vmax=log_scale_max)\n",
    "sm = plt.cm.ScalarMappable(cmap=cmap_name, norm=norm)\n",
    "sm.set_array([])\n",
    "cbar = fig.colorbar(sm, ax=axes.ravel().tolist(), orientation=\"vertical\", pad=0.1, aspect=40)\n",
    "cbar.ax.set_ylabel(\"Model size (B parameters)\", fontsize=12)\n",
    "cbar.ax.yaxis.set_label_position(\"left\")\n",
    "cbar.ax.yaxis.set_ticks_position(\"left\")\n",
    "cbar.ax.tick_params(axis=\"y\", labelsize=12, rotation=0)\n",
    "cbar.set_ticks(np.log(np.array([0.41, 1, 2.8, 7, 12])))\n",
    "cbar.ax.set_yticklabels([\"0.41\", \"1\", \"2.8\", \"7\", \"12\"])\n",
    "\n",
    "\n",
    "os.makedirs(\"../../figures\", exist_ok=True)\n",
    "plt.savefig(f\"../../figures/layerwise_auroc_{reporter}_{ds_name}.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scatter plot of in-distribution (ID) vs. out-of-distribution (OOD) performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "ds_names_without_pop = ds_names.copy()\n",
    "ds_names_without_pop.remove(\"population\")\n",
    "contrast_pairs = True\n",
    "exps = {\n",
    "    \"lr-on-pair\": [(\"AE\", \"AE\"), (\"AE\", \"BH\")],\n",
    "    \"ccs\": [(\"AE\", \"AE\"), (\"AE\", \"BH\")],\n",
    "    \"crc\": [(\"AE\", \"AE\"), (\"AE\", \"BH\")],\n",
    "} if contrast_pairs else {\n",
    "    \"lr\": [(\"AE\", \"AE\"), (\"AE\", \"BH\")],\n",
    "    \"mean-diff\": [(\"AE\", \"AE\"), (\"AE\", \"BH\")],\n",
    "    \"lda\": [(\"AE\", \"AE\"), (\"AE\", \"BH\")],\n",
    "}\n",
    "ID_EXP, OOD_EXP = (\"AE\", \"AE\"), (\"AE\", \"BH\")\n",
    "id_aurocs = defaultdict(list)\n",
    "ood_aurocs = defaultdict(list)\n",
    "for method in exps:\n",
    "    # Load every experiment for this method before collecting points: the scatter plot\n",
    "    # pairs id_aurocs[method][k] with ood_aurocs[method][k], so a (model, ds_name) is only\n",
    "    # kept when both its ID and OOD results exist with the same number of layers.\n",
    "    result_dfs_by_exp = dict()\n",
    "    for fr, to in exps[method]:\n",
    "        if (fr, to) not in (ID_EXP, OOD_EXP):\n",
    "            raise ValueError(\"Unexpected experiment\")\n",
    "        try:\n",
    "            _, _, result_dfs, _, _, _ = get_result_dfs(models, fr, to, ds_names_without_pop, filter_by=\"all\", label_col=\"alice_label\", reporter=method)\n",
    "        except KeyError as e:\n",
    "            print(f\"Experiment {fr} → {to} not found for method {method}: {e}\")\n",
    "            continue\n",
    "        result_dfs_by_exp[(fr, to)] = result_dfs\n",
    "\n",
    "    id_dfs = result_dfs_by_exp.get(ID_EXP, {})\n",
    "    ood_dfs = result_dfs_by_exp.get(OOD_EXP, {})\n",
    "    n_skipped = 0\n",
    "    for model in models:\n",
    "        for ds_name in ds_names_without_pop:\n",
    "            key = (model, ds_name)\n",
    "            if key not in id_dfs or key not in ood_dfs:\n",
    "                n_skipped += 1\n",
    "                continue\n",
    "            id_vals = id_dfs[key][\"auroc\"].values\n",
    "            ood_vals = ood_dfs[key][\"auroc\"].values\n",
    "            if len(id_vals) != len(ood_vals):\n",
    "                n_skipped += 1\n",
    "                continue\n",
    "            id_aurocs[method].extend(id_vals)\n",
    "            ood_aurocs[method].extend(ood_vals)\n",
    "    if n_skipped:\n",
    "        print(f\"{method}: skipped {n_skipped} (model, ds_name) pairs without matching ID and OOD results\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import os\n",
    "import numpy as np\n",
    "colors = sns.color_palette(\"Set2\")\n",
    "\n",
    "# Seeded generator so the subsample, zorders and saved figure are reproducible.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "plt.figure(figsize=(4, 4), dpi=150)\n",
    "n_show = 1000  # max points drawn per method\n",
    "for i, method in enumerate(id_aurocs):\n",
    "    # only indices present in both lists form an (ID, OOD) pair, and never request more\n",
    "    # points than exist (sampling without replacement raises otherwise)\n",
    "    n_available = min(len(id_aurocs[method]), len(ood_aurocs[method]))\n",
    "    sample = rng.choice(n_available, min(n_show, n_available), replace=False)\n",
    "    for j, idx in enumerate(sample):\n",
    "        # random per-point zorder, presumably so no method is always drawn on top\n",
    "        plt.scatter(id_aurocs[method][idx], ood_aurocs[method][idx], label=method_titles[method] if j == 0 else None, alpha=0.5, s=5, zorder=rng.integers(0, 6), c=colors[i])\n",
    "plt.plot([0, 2], [0, 2], color=\"black\", linestyle=\"--\", linewidth=0.5)\n",
    "plt.axhline(0.5, color=\"grey\", linestyle=\"--\", linewidth=0.5)\n",
    "plt.xlim(0.35, 1.02)\n",
    "plt.ylim(-0.02, 1.02)\n",
    "plt.xlabel(\"AUROC on AE (no transfer)\", fontsize=13)\n",
    "plt.ylabel(\"Transfer AUROC for AE$\\\\to$BH\", fontsize=13)\n",
    "os.makedirs(\"../../figures\", exist_ok=True)\n",
    "plt.legend(fontsize=12, loc=\"lower left\")\n",
    "title = \"Probing on contrast pair\" if contrast_pairs else \"Probing on final prompt token\"\n",
    "plt.title(title, fontsize=13)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"../../figures/transfer_scatter_{'_'.join(exps.keys())}.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Bar plot of all transfer results at the earliest informative layer (EIL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from viz import earliest_informative_layer\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# One entry per panel: the transfers compared side by side, with the panel's example filter.\n",
    "transfers_all = [\n",
    "        ((\"A\", \"A\"), (\"A\", \"B\")),\n",
    "        ((\"B\", \"B\"), (\"B\", \"A\")),\n",
    "        ((\"A\", \"AH\"), (\"AE\", \"AH\")),\n",
    "        ((\"A\", \"BH\"), (\"AE\", \"BH\")),\n",
    "]\n",
    "filter_bys = [\"disagree\", \"disagree\", \"all\", \"all\"]\n",
    "methods = [\"lm\", \"lr\", \"mean-diff\", \"lda\", \"lr-on-pair\", \"mean-diff-on-pair\", \"ccs\", \"crc\"]\n",
    "dfs = []\n",
    "# filter_by of every panel actually built, kept index-aligned with dfs\n",
    "kept_filter_bys = []\n",
    "for transfers, filter_by in zip(transfers_all, filter_bys):\n",
    "    current_ds_names = ds_names.copy()\n",
    "    if filter_by == \"disagree\":\n",
    "        current_ds_names.remove(\"authors\")  # authors only has false disagreements\n",
    "    elif any(any(\"H\" in distr for distr in transfer) for transfer in transfers):\n",
    "        current_ds_names.remove(\"population\")  # population only has false labels on H\n",
    "\n",
    "    print(\"CURRENT TRANSFER\", transfers)\n",
    "    try:\n",
    "        # one row per (method, transfer, model, dataset) holding the target-distribution AUROC;\n",
    "        # probes are read at the earliest informative layer chosen on the source distribution\n",
    "        rows = []\n",
    "        for method in methods:\n",
    "            for (fr, to) in transfers:\n",
    "                if method == \"lm\":\n",
    "                    _, _, _, _, _, lm_results = get_result_dfs(models, fr, to, current_ds_names, filter_by=filter_by, label_col=\"alice_label\", root_dir=root, reporter=\"lr\")  # lr is dummy\n",
    "                    for model in models:\n",
    "                        for ds_name in current_ds_names:\n",
    "                            if (model, ds_name) not in lm_results:\n",
    "                                continue\n",
    "                            rows.append({\n",
    "                                \"auroc\": lm_results[(model, ds_name)],\n",
    "                                \"method\": \"Target distr\\nLM output\",\n",
    "                                \"transfer\": f\"{fr}$\\\\to${to}\",\n",
    "                            })\n",
    "                else:\n",
    "                    _, _, result_dfs, _, _, _ = get_result_dfs(models, fr, to, current_ds_names, filter_by=filter_by, label_col=\"alice_label\", root_dir=root, reporter=method)\n",
    "                    # pick layer on source distribution with all examples, measured against source labels\n",
    "                    _, _, id_result_dfs, _, _, _ = get_result_dfs(models, fr, fr, current_ds_names, filter_by=\"all\", label_col=\"label\", reporter=method, root_dir=root)\n",
    "                    for model in models:\n",
    "                        for ds_name in current_ds_names:\n",
    "                            if (model, ds_name) not in result_dfs or (model, ds_name) not in id_result_dfs:\n",
    "                                print(f\"Skipping {model}-{ds_name} due to missing data\")\n",
    "                                continue\n",
    "                            if id_result_dfs[(model, ds_name)].isna().any().any():\n",
    "                                print(f\"Skipping {model}-{ds_name} due to NaN\")\n",
    "                                continue\n",
    "                            layer_idx = earliest_informative_layer(id_result_dfs[(model, ds_name)], thresh=0.95)\n",
    "                            auroc = result_dfs[(model, ds_name)][\"auroc\"].values[layer_idx]\n",
    "                            rows.append({\n",
    "                                \"auroc\": auroc,\n",
    "                                \"method\": method_titles[method].replace(\"LogR on contrast pair\", \"LogR on\\ncont. pair\").replace(\"Diff-in-means on contrast pair\", \"Diff-in-means\\non cont. pair\"),\n",
    "                                \"transfer\": f\"{fr}$\\\\to${to}\",\n",
    "                            })\n",
    "    except FileNotFoundError as e:\n",
    "        print(f\"Experiment not found: {e}\")\n",
    "        continue\n",
    "\n",
    "    dfs.append(pd.DataFrame(rows))\n",
    "    kept_filter_bys.append(filter_by)\n",
    "\n",
    "# Keep only the filters of panels that were built, so zip(dfs, filter_bys) in the\n",
    "# plotting cell pairs each DataFrame with its own filter.\n",
    "filter_bys = kept_filter_bys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(2, 2, sharey=True, figsize=(10, 5), dpi=150)\n",
    "\n",
    "# slategray for the LM-output bars, Set2 colors for the probing methods\n",
    "palette = [\"slategray\"] + sns.color_palette(\"Set2\", n_colors=len(methods) - 1)\n",
    "# Only the last drawn panel builds a seaborn legend; its handles feed the single\n",
    "# figure-level legend, so this still works when fewer than 4 panels were built.\n",
    "n_panels = min(len(dfs), len(filter_bys))\n",
    "legend_frame = n_panels - 1\n",
    "for frame, (panel_df, filter_by) in enumerate(zip(dfs, filter_bys)):\n",
    "    ax = axes[frame // 2][frame % 2]\n",
    "    plt.sca(ax)\n",
    "    sns.barplot(data=panel_df, x=\"transfer\", y=\"auroc\", hue=\"method\", legend=frame == legend_frame, palette=palette, errorbar=None)\n",
    "\n",
    "    # drop the per-axes legend; its handles are reused for the figure legend below\n",
    "    if frame == legend_frame and ax.get_legend() is not None:\n",
    "        ax.get_legend().remove()\n",
    "    ax.tick_params(labelsize=14)\n",
    "    \n",
    "    if filter_by == \"all\":\n",
    "        plt.ylabel(\"AUROC on\\n$\\\\bf{all\\\\ examples}$\", fontsize=14)\n",
    "    elif filter_by == \"disagree\":\n",
    "        plt.ylabel(\"AUROC on\\n$\\\\bf{disagreements}$\", fontsize=14)\n",
    "    else:\n",
    "        raise ValueError(\"Unexpected filter_by value\" + str(filter_by))\n",
    "    plt.xlabel(\"\")\n",
    "    \n",
    "    plt.axhline(0.5, color=\"black\", linestyle=\"--\", linewidth=0.5)\n",
    "    plt.ylim(-0.01, 1.01)\n",
    "    plt.title(\"$\\\\bf {(\" +'abcd'[frame] + \")}$\", fontsize=15)\n",
    "\n",
    "# hide panels left empty because their experiments were skipped\n",
    "for empty_ax in axes.ravel()[n_panels:]:\n",
    "    empty_ax.set_visible(False)\n",
    "\n",
    "# add legend to the right, spanning the full height\n",
    "if n_panels > 0:\n",
    "    handles, labels = axes.ravel()[legend_frame].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc='center right', fontsize=14, bbox_to_anchor=(1.2, 0.5))\n",
    "\n",
    "plt.tight_layout()\n",
    "os.makedirs(\"../../figures\", exist_ok=True)\n",
    "plt.savefig(\"../../figures/transfer_barplot.pdf\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table of results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_table(exps, models, transfer_ds_names, weak_floor_transfer, strong_ceil_transfer, EIL_dstr, caption, tab_label):\n",
    "    \"\"\"Build a LaTeX table of per-dataset PGR values for each probing method.\n",
    "\n",
    "    PGR = (auc - weak floor) / (strong ceil - weak floor). Each method's AUROC is read\n",
    "    at the earliest informative layer (chosen on EIL_dstr -> EIL_dstr results) and\n",
    "    averaged over `models`. The weak floor and strong ceiling are final-layer \"lr\"\n",
    "    probe AUROCs on `weak_floor_transfer` / `strong_ceil_transfer`, averaged over models.\n",
    "\n",
    "    Args:\n",
    "        exps: dict mapping reporter name -> list of (fr, to) experiments; rows with\n",
    "            fr == \"all\" get an (all -> to) suffix in their title.\n",
    "        models: model ids passed to get_result_dfs.\n",
    "        transfer_ds_names: datasets, one table column each, followed by an avg column.\n",
    "        weak_floor_transfer, strong_ceil_transfer: (fr, to) pairs for floor and ceiling.\n",
    "        EIL_dstr: distribution whose no-transfer results select the layer.\n",
    "        caption, tab_label: LaTeX caption and label.\n",
    "\n",
    "    Returns:\n",
    "        The LaTeX source of the table, including an avg row and the floor/ceiling rows.\n",
    "    \"\"\"\n",
    "\n",
    "    def pgr(auc, floor, ceil):\n",
    "        # fraction of the floor-to-ceiling AUROC gap recovered\n",
    "        return (auc - floor) / (ceil - floor)\n",
    "\n",
    "    # get \"weak\" and \"strong\" performances for PGR calculation\n",
    "    def get_floor_ceil(fr, to):\n",
    "        out = dict()\n",
    "        _, _, probe_results, _, _, _ = get_result_dfs(models, fr, to, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=\"lr\")  # any reporter will do\n",
    "        for ds_name in transfer_ds_names:\n",
    "            # get final layer probe auroc for each model\n",
    "            ds_avg_auc = sum(probe_results[(model, ds_name)][\"auroc\"].iloc[-1] for model in models) / len(models)\n",
    "            out[ds_name] = ds_avg_auc\n",
    "        return out\n",
    "    weak_floor_by_ds = get_floor_ceil(*weak_floor_transfer)\n",
    "    strong_ceil_by_ds = get_floor_ceil(*strong_ceil_transfer)\n",
    "\n",
    "    weak_floor, strong_ceil = np.mean(list(weak_floor_by_ds.values())), np.mean(list(strong_ceil_by_ds.values()))\n",
    "\n",
    "    # One column per dataset (rather than a hard-coded count) plus a spaced-out avg column.\n",
    "    # The backslash before hspace is doubled so Python does not see an invalid escape.\n",
    "    col_spec = \"l\" + \"c\" * len(transfer_ds_names) + \"@{\\\\hspace{14pt}}c\"\n",
    "    summary_table = \\\n",
    "    \"\"\"\\\\setlength{\\\\tabcolsep}{3.35pt}\n",
    "    \\\\begin{table}[htbp]\n",
    "        \\\\centering\n",
    "        \\\\caption{\"\"\" + caption + \"\"\"}\n",
    "        \\\\label{\"\"\" + tab_label + \"\"\"}\n",
    "        \\\\begin{tabular}{\"\"\" + col_spec + \"\"\"}\n",
    "            \\\\toprule\\n\"\"\"\n",
    "    summary_table += \"         & \" + \" & \".join([f\"\\\\textit{{{ds_abbrevs[ds]}}}\" for ds in transfer_ds_names]) + \" & \\\\textbf{avg} \\\\\\\\ \\n\" + \\\n",
    "                    \"        \\\\midrule\\n\"\n",
    "\n",
    "    auc_by_exp_and_ds = defaultdict(list)\n",
    "    for method in exps:\n",
    "        for fr, to in exps[method]:\n",
    "            method_title = method_titles[method].replace(\"LogR on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}LogR on\\\\\\\\cont. pair\\\\end{tabular}\").replace(\"Diff-in-means on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}Diff-in-means\\\\\\\\on cont. pair\\\\end{tabular}\")\n",
    "            if fr == \"all\":\n",
    "                summary_table += f\"        {method_title} (all\\\\(\\\\to\\\\){to}) & \"\n",
    "            else:\n",
    "                summary_table += f\"        {method_title} & \"\n",
    "\n",
    "            _, _, result_dfs, _, _, lm_results = get_result_dfs(models, fr, to, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=method)\n",
    "            \n",
    "            # the layer is selected on EIL_dstr -> EIL_dstr (no transfer) results\n",
    "            _, _, id_result_dfs, _, _, _ = get_result_dfs(models, EIL_dstr, EIL_dstr, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=method)\n",
    "            \n",
    "            row_avg = 0\n",
    "            for ds_name in transfer_ds_names:\n",
    "                ds_avg_auc = 0\n",
    "                for model in models:\n",
    "                    layer_idx = earliest_informative_layer(id_result_dfs[(model, ds_name)], thresh=0.95)\n",
    "                    auc = result_dfs[(model, ds_name)][\"auroc\"].values[layer_idx]\n",
    "                    row_avg += auc\n",
    "                    ds_avg_auc += auc\n",
    "\n",
    "                ds_avg_auc /= len(models)\n",
    "                auc_by_exp_and_ds[ds_name].append(ds_avg_auc)\n",
    "                ds_avg_pgr = pgr(ds_avg_auc, weak_floor_by_ds[ds_name], strong_ceil_by_ds[ds_name])\n",
    "                summary_table += f\"{ds_avg_pgr:.2f} & \"\n",
    "\n",
    "            row_avg /= len(transfer_ds_names) * len(models)\n",
    "            avg_pgr = pgr(row_avg, weak_floor, strong_ceil)\n",
    "            summary_table += f\"{avg_pgr:.2f} \\\\\\\\ \\n\"\n",
    "\n",
    "    # avg row\n",
    "    summary_table += \"        \\\\midrule\\n\"\n",
    "    summary_table += \"        \\\\bf{avg} & \"\n",
    "    ova_avg_auc = 0\n",
    "    for ds_name in transfer_ds_names:\n",
    "        avg_auc = sum(auc_by_exp_and_ds[ds_name]) / len(auc_by_exp_and_ds[ds_name])\n",
    "        ova_avg_auc += avg_auc\n",
    "        avg_pgr = pgr(avg_auc, weak_floor_by_ds[ds_name], strong_ceil_by_ds[ds_name])\n",
    "        summary_table += f\"{avg_pgr:.2f} & \"\n",
    "    ova_avg_auc /= len(transfer_ds_names)\n",
    "    ova_avg_pgr = pgr(ova_avg_auc, weak_floor, strong_ceil)\n",
    "    summary_table += f\"{ova_avg_pgr:.2f} \\\\\\\\ \\n\"\n",
    "\n",
    "    # weak and strong results\n",
    "    summary_table += \"        \\\\midrule\\n\"\n",
    "    for floorceil_name, (fr, to), floorceil_by_ds in [(\"weak floor\", weak_floor_transfer, weak_floor_by_ds), (\"strong ceil\", strong_ceil_transfer, strong_ceil_by_ds)]:\n",
    "        summary_table += f\"        {floorceil_name} ({to}) & \"\n",
    "        cap_avg = 0\n",
    "        for ds_name, ds_avg_auc in floorceil_by_ds.items():\n",
    "            summary_table += f\"{ds_avg_auc:.2f} & \"\n",
    "            cap_avg += ds_avg_auc\n",
    "        cap_avg /= len(floorceil_by_ds)\n",
    "        summary_table += f\"{cap_avg:.2f} \\\\\\\\ \\n\"\n",
    "\n",
    "    summary_table += \"        \\\\bottomrule\\n\"\n",
    "    summary_table += \"    \\\\end{tabular}\\n\\\\end{table}\"\n",
    "    return summary_table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from viz import earliest_informative_layer\n",
    "from collections import defaultdict\n",
    "\n",
    "# every method gets an AE->BH row; CCS and CRC additionally get an all->BH row\n",
    "ae_to_bh = (\"AE\", \"BH\")\n",
    "exps = {method: [ae_to_bh] for method in (\"lr\", \"mean-diff\", \"lda\", \"lr-on-pair\", \"mean-diff-on-pair\")}\n",
    "exps.update({method: [ae_to_bh, (\"all\", \"BH\")] for method in (\"ccs\", \"crc\")})\n",
    "EIL_dstr = \"AE\"  # distribution used to select the Earliest Informative Layer\n",
    "\n",
    "transfer_ds_names = ds_names.copy()\n",
    "transfer_ds_names.remove(\"population\")  # population only has false labels on H\n",
    "transfer_models = list(models)\n",
    "\n",
    "weak_floor_transfer = (\"B\", \"BH\")\n",
    "strong_ceil_transfer = (\"A\", \"AH\")\n",
    "caption = \"AE\\\\(\\\\to\\\\)BH transfer PGR broken down by probing method and dataset at the Earliest Informative Layer (\\\\ref{sec:selecting_a_layer}). The last two rows show weak floor and strong ceiling AUROC values used for PGR calculation. The best probing method recovers 75\\\\% of the difference between untruthful and truthful behavior. Note that the capitals and authors datasets have similar ceiling and floor performances, leading to noisy PGR values. Each reported PGR value is calculated by averaging AUROC values before finally taking the difference and ratio. Otherwise, the weak floor and strong ceiling estimates are noisy, often leading to small or negative denominators.\"\n",
    "table_tex = get_table(exps, transfer_models, transfer_ds_names, weak_floor_transfer, strong_ceil_transfer, EIL_dstr, caption, \"tab:transfer\")\n",
    "print(table_tex)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Alice to Bob (without easy to hard)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from viz import earliest_informative_layer\n",
    "from collections import defaultdict\n",
    "\n",
    "# plain Alice->Bob transfer (no easy->hard shift): a single A->B row per method\n",
    "probe_methods = (\"lr\", \"mean-diff\", \"lda\", \"lr-on-pair\", \"mean-diff-on-pair\", \"ccs\", \"crc\")\n",
    "exps = {method: [(\"A\", \"B\")] for method in probe_methods}\n",
    "EIL_dstr = \"A\"  # distribution used to select the Earliest Informative Layer\n",
    "\n",
    "transfer_ds_names = ds_names.copy()\n",
    "transfer_ds_names.remove(\"population\")  # population only has false labels on H\n",
    "\n",
    "weak_floor_transfer = (\"B\", \"B\")\n",
    "strong_ceil_transfer = (\"A\", \"A\")\n",
    "caption = \"A\\\\(\\\\to\\\\)B transfer PGR broken down by probing method and dataset like in Table~\\\\ref{tab:transfer}.\"\n",
    "table_tex = get_table(exps, models, transfer_ds_names, weak_floor_transfer, strong_ceil_transfer, EIL_dstr, caption, \"tab:transfer2\")\n",
    "print(table_tex)"
   ]
  },
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table comparing methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing LoRA with full finetuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from viz import earliest_informative_layer\n",
    "\n",
    "# Pythia 12B is left out of this comparison (for cost reasons, see caption)\n",
    "transfer_models = models.copy()\n",
    "transfer_models.remove(\"EleutherAI/pythia-12b\")\n",
    "transfer_ds_names = ds_names.copy()\n",
    "transfer_ds_names.remove(\"population\")  # population only has false labels on H\n",
    "\n",
    "exps = {\n",
    "    \"lr\": [(\"AE\", \"BH\")],\n",
    "    \"mean-diff\": [(\"AE\", \"BH\")],\n",
    "    \"lda\": [(\"AE\", \"BH\")],\n",
    "    \"lr-on-pair\": [(\"AE\", \"BH\")],\n",
    "    \"mean-diff-on-pair\": [(\"AE\", \"BH\")],\n",
    "    \"ccs\": [(\"AE\", \"BH\"), (\"all\", \"BH\")],\n",
    "    \"crc\": [(\"AE\", \"BH\"), (\"all\", \"BH\")],\n",
    "}\n",
    "EIL_dstr = \"AE\"\n",
    "# table columns, in order: rank-8 LoRA (full_finetuning=False), then full finetune (True)\n",
    "FT_SETTINGS = (False, True)\n",
    "\n",
    "weak_floor_transfer, strong_ceil_transfer = (\"B\", \"BH\"), (\"A\", \"AH\")\n",
    "caption = \"Comparison of LoRA finetuning and full finetuning in terms of PGR on AE\\\\(\\\\to\\\\)BH transfer, averaged over all models except Pythia 12B (for cost reasons).\"\n",
    "tab_label = \"tab:fullft\"\n",
    "\n",
    "summary_table = \\\n",
    "\"\"\"\\\\setlength{\\\\tabcolsep}{3.35pt}\n",
    "\\\\begin{wrapfigure}{r}{0.5\\\\textwidth} \n",
    "    \\\\centering\n",
    "    \\\\caption{\"\"\" + caption + \"\"\"}\n",
    "    \\\\label{\"\"\" + tab_label + \"\"\"}\n",
    "    \\\\begin{tabular}{lcc}\n",
    "        \\\\toprule\\n\"\"\"\n",
    "summary_table += \"         & rank-8 LoRA & full finetune \\\\\\\\ \\n\" + \\\n",
    "                \"        \\\\midrule\\n\"\n",
    "\n",
    "def get_floor_ceil(fr, to, full_finetuning):\n",
    "    \"\"\"Return {ds_name: final-layer LogR probe AUROC averaged over transfer_models} for fr->to.\n",
    "\n",
    "    Used as the weak floor / strong ceiling reference values for the PGR calculation.\n",
    "    \"\"\"\n",
    "    out = dict()\n",
    "    _, _, probe_results, _, _, _ = get_result_dfs(transfer_models, fr, to, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=\"lr\", full_finetuning=full_finetuning)  # any reporter will do\n",
    "    for ds_name in transfer_ds_names:\n",
    "        # final-layer (iloc[-1]) probe AUROC, averaged over models\n",
    "        out[ds_name] = sum(probe_results[(model, ds_name)][\"auroc\"].iloc[-1] for model in transfer_models) / len(transfer_models)\n",
    "    return out\n",
    "\n",
    "\n",
    "def pgr_for_ft(auc, ft):\n",
    "    \"\"\"PGR of `auc`: (auc - weak floor) / (strong ceil - weak floor) for finetuning setting `ft`.\"\"\"\n",
    "    return (auc - weak_floor[ft]) / (strong_ceil[ft] - weak_floor[ft])\n",
    "\n",
    "\n",
    "weak_floor_by_ds = {ft: get_floor_ceil(*weak_floor_transfer, full_finetuning=ft) for ft in FT_SETTINGS}\n",
    "strong_ceil_by_ds = {ft: get_floor_ceil(*strong_ceil_transfer, full_finetuning=ft) for ft in FT_SETTINGS}\n",
    "\n",
    "weak_floor = {ft: np.mean(list(weak_floor_by_ds[ft].values())) for ft in FT_SETTINGS}\n",
    "strong_ceil = {ft: np.mean(list(strong_ceil_by_ds[ft].values())) for ft in FT_SETTINGS}\n",
    "\n",
    "# mean AUROC of each table row, keyed by the full_finetuning flag\n",
    "auc_by_exp_and_ft = defaultdict(list)\n",
    "for method in exps:\n",
    "    for fr, to in exps[method]:\n",
    "        method_title = method_titles[method].replace(\"LogR on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}LogR on\\\\\\\\cont. pair\\\\end{tabular}\").replace(\"Diff-in-means on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}Diff-in-means\\\\\\\\on cont. pair\\\\end{tabular}\")\n",
    "        if fr == \"all\":\n",
    "            summary_table += f\"        {method_title} (all\\\\(\\\\to\\\\){to}) \"\n",
    "        else:\n",
    "            summary_table += f\"        {method_title} \"\n",
    "\n",
    "        for ft in FT_SETTINGS:\n",
    "            _, _, result_dfs, _, _, _ = get_result_dfs(transfer_models, fr, to, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=method, full_finetuning=ft)\n",
    "            # we use alice's easy no transfer data to select the layer\n",
    "            _, _, id_result_dfs, _, _, _ = get_result_dfs(transfer_models, EIL_dstr, EIL_dstr, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=method, full_finetuning=ft)\n",
    "\n",
    "            avg_over_datasets = 0\n",
    "            for ds_name in transfer_ds_names:\n",
    "                for model in transfer_models:\n",
    "                    layer_idx = earliest_informative_layer(id_result_dfs[(model, ds_name)], thresh=0.95)\n",
    "                    avg_over_datasets += result_dfs[(model, ds_name)][\"auroc\"].values[layer_idx]\n",
    "            avg_over_datasets /= len(transfer_ds_names) * len(transfer_models)\n",
    "\n",
    "            auc_by_exp_and_ft[ft].append(avg_over_datasets)\n",
    "            summary_table += f\"& {pgr_for_ft(avg_over_datasets, ft):.2f}\"\n",
    "        summary_table += \" \\\\\\\\ \\n\"\n",
    "\n",
    "# avg row\n",
    "summary_table += \"        \\\\midrule\\n\"\n",
    "summary_table += \"        \\\\bf{avg} \"\n",
    "for ft in FT_SETTINGS:\n",
    "    avg_auc = sum(auc_by_exp_and_ft[ft]) / len(auc_by_exp_and_ft[ft])\n",
    "    summary_table += f\"& {pgr_for_ft(avg_auc, ft):.2f}\"\n",
    "summary_table += \" \\\\\\\\ \\n\"\n",
    "\n",
    "# weak and strong results\n",
    "summary_table += \"        \\\\midrule\\n\"\n",
    "for floorceil_name, (fr, to), floorceil_by_ds in [(\"weak floor\", weak_floor_transfer, weak_floor_by_ds), (\"strong ceil\", strong_ceil_transfer, strong_ceil_by_ds)]:\n",
    "    summary_table += f\"        {floorceil_name} ({to}) \"\n",
    "    for ft in FT_SETTINGS:\n",
    "        cap_avg = np.mean(list(floorceil_by_ds[ft].values()))\n",
    "        summary_table += f\"& {cap_avg:.2f} \"\n",
    "    summary_table += \" \\\\\\\\ \\n\"\n",
    "\n",
    "summary_table += \"        \\\\bottomrule\\n\"\n",
    "summary_table += \"    \\\\end{tabular}\\n\\\\end{wrapfigure}\"\n",
    "print(summary_table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparing template setups"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from viz import earliest_informative_layer\n",
    "\n",
    "# this comparison is only run for Mistral 7B (see caption); the id must match the entry in `models`\n",
    "transfer_models = [\"mistralai/Mistral-7B-v0.1\"]\n",
    "transfer_ds_names = ds_names.copy()\n",
    "transfer_ds_names.remove(\"population\")  # population only has false labels on H\n",
    "\n",
    "exps = {\n",
    "    \"lr\": [(\"AE\", \"BH\")],\n",
    "    \"mean-diff\": [(\"AE\", \"BH\")],\n",
    "    \"lda\": [(\"AE\", \"BH\")],\n",
    "    \"lr-on-pair\": [(\"AE\", \"BH\")],\n",
    "    \"mean-diff-on-pair\": [(\"AE\", \"BH\")],\n",
    "    \"ccs\": [(\"AE\", \"BH\"), (\"all\", \"BH\")],\n",
    "    \"crc\": [(\"AE\", \"BH\"), (\"all\", \"BH\")],\n",
    "}\n",
    "EIL_dstr = \"AE\"\n",
    "\n",
    "# table column -> (templatization_method, standardize_templates) passed to get_result_dfs\n",
    "templatization_methods = {\n",
    "    \"single\": (\"first\", False),\n",
    "    \"mix\": (\"random\", False),\n",
    "    \"stdzd\": (\"random\", True),\n",
    "}\n",
    "\n",
    "weak_floor_transfer, strong_ceil_transfer = (\"B\", \"BH\"), (\"A\", \"AH\")\n",
    "caption = \"Comparison of templatization setups in terms of PGR on AE\\\\(\\\\to\\\\)BH for Mistral 7B.\"\n",
    "tab_label = \"tab:template\"\n",
    "\n",
    "summary_table = \\\n",
    "\"\"\"\\\\setlength{\\\\tabcolsep}{3.35pt}\n",
    "\\\\begin{wrapfigure}{r}{0.5\\\\textwidth} \n",
    "    \\\\centering\n",
    "    \\\\caption{\"\"\" + caption + \"\"\"}\n",
    "    \\\\label{\"\"\" + tab_label + \"\"\"}\n",
    "    \\\\begin{tabular}{lccc}\n",
    "        \\\\toprule\\n\"\"\"\n",
    "summary_table += \"         & single & mixture & standardized \\\\\\\\ \\n\" + \\\n",
    "                \"        \\\\midrule\\n\"\n",
    "\n",
    "def get_floor_ceil(fr, to, full_finetuning=False, templatization_method=\"first\", standardize_templates=False):\n",
    "    \"\"\"Return {ds_name: final-layer LogR probe AUROC averaged over transfer_models} for fr->to.\n",
    "\n",
    "    Used as the weak floor / strong ceiling reference values for the PGR calculation.\n",
    "    \"\"\"\n",
    "    out = dict()\n",
    "    _, _, probe_results, _, _, _ = get_result_dfs(transfer_models, fr, to, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=\"lr\", full_finetuning=full_finetuning, templatization_method=templatization_method, standardize_templates=standardize_templates)  # any reporter will do\n",
    "    for ds_name in transfer_ds_names:\n",
    "        # final-layer (iloc[-1]) probe AUROC, averaged over models\n",
    "        ds_avg_auc = sum(probe_results[(model, ds_name)][\"auroc\"].iloc[-1] for model in transfer_models) / len(transfer_models)\n",
    "        out[ds_name] = ds_avg_auc\n",
    "    return out\n",
    "\n",
    "\n",
    "weak_floor_by_ds = {temp_name: get_floor_ceil(*weak_floor_transfer, templatization_method=tm, standardize_templates=st) for temp_name, (tm, st) in templatization_methods.items()}\n",
    "strong_ceil_by_ds = {temp_name: get_floor_ceil(*strong_ceil_transfer, templatization_method=tm, standardize_templates=st) for temp_name, (tm, st) in templatization_methods.items()}\n",
    "\n",
    "weak_floor = {temp_name: np.mean(list(weak_floor_by_ds[temp_name].values())) for temp_name in templatization_methods}\n",
    "strong_ceil = {temp_name: np.mean(list(strong_ceil_by_ds[temp_name].values())) for temp_name in templatization_methods}\n",
    "\n",
    "# mean AUROC of each table row, keyed by templatization setup\n",
    "auc_by_exp_and_temp = defaultdict(list)\n",
    "for method in exps:\n",
    "    for fr, to in exps[method]:\n",
    "        method_title = method_titles[method].replace(\"LogR on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}LogR on\\\\\\\\cont. pair\\\\end{tabular}\").replace(\"Diff-in-means on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}Diff-in-means\\\\\\\\on cont. pair\\\\end{tabular}\")\n",
    "        if fr == \"all\":\n",
    "            summary_table += f\"        {method_title} (all\\\\(\\\\to\\\\){to}) \"\n",
    "        else:\n",
    "            summary_table += f\"        {method_title} \"\n",
    "\n",
    "        for tname, (tm, st) in templatization_methods.items():\n",
    "            _, _, result_dfs, _, _, _ = get_result_dfs(transfer_models, fr, to, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=method, full_finetuning=False, templatization_method=tm, standardize_templates=st)\n",
    "            # we use alice's easy no transfer data to select the layer\n",
    "            _, _, id_result_dfs, _, _, _ = get_result_dfs(transfer_models, EIL_dstr, EIL_dstr, transfer_ds_names, filter_by=\"all\", label_col=\"alice_label\", reporter=method, full_finetuning=False, templatization_method=tm, standardize_templates=st)\n",
    "\n",
    "            avg_over_datasets = 0\n",
    "            for ds_name in transfer_ds_names:\n",
    "                for model in transfer_models:\n",
    "                    layer_idx = earliest_informative_layer(id_result_dfs[(model, ds_name)], thresh=0.95)\n",
    "                    auc = result_dfs[(model, ds_name)][\"auroc\"].values[layer_idx]\n",
    "                    avg_over_datasets += auc\n",
    "            avg_over_datasets /= len(transfer_ds_names) * len(transfer_models)\n",
    "\n",
    "            auc_by_exp_and_temp[tname].append(avg_over_datasets)\n",
    "            avg_pgr = (avg_over_datasets - weak_floor[tname]) / (strong_ceil[tname] - weak_floor[tname])\n",
    "            summary_table += f\"& {avg_pgr:.2f} \"\n",
    "        summary_table += \" \\\\\\\\ \\n\"\n",
    "\n",
    "# avg row\n",
    "summary_table += \"        \\\\midrule\\n\"\n",
    "summary_table += \"        \\\\bf{avg} \"\n",
    "for tname in templatization_methods:\n",
    "    avg_auc = sum(auc_by_exp_and_temp[tname]) / len(auc_by_exp_and_temp[tname])\n",
    "    avg_pgr = (avg_auc - weak_floor[tname]) / (strong_ceil[tname] - weak_floor[tname])\n",
    "    summary_table += f\"& {avg_pgr:.2f} \"\n",
    "summary_table += \" \\\\\\\\ \\n\"\n",
    "\n",
    "# weak and strong results\n",
    "summary_table += \"        \\\\midrule\\n\"\n",
    "for floorceil_name, (fr, to), floorceil_by_ds in [(\"weak floor\", weak_floor_transfer, weak_floor_by_ds), (\"strong ceil\", strong_ceil_transfer, strong_ceil_by_ds)]:\n",
    "    summary_table += f\"        {floorceil_name} ({to}) \"\n",
    "    for tname in templatization_methods:\n",
    "        cap_avg = np.mean(list(floorceil_by_ds[tname].values()))\n",
    "        summary_table += f\"& {cap_avg:.2f} \"\n",
    "    summary_table += \" \\\\\\\\ \\n\"\n",
    "\n",
    "summary_table += \"        \\\\bottomrule\\n\"\n",
    "summary_table += \"    \\\\end{tabular}\\n\\\\end{wrapfigure}\"\n",
    "print(summary_table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Anomaly detection results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load all the results from anomaly experiments and put them in a tex table\n",
    "import os\n",
    "import json\n",
    "from elk_generalization.utils import get_quirky_model_name\n",
    "\n",
    "def custom_round(x: float) -> str:\n",
    "    \"\"\"Format an AUROC for the table: '1' if x >= 0.9995, 3 decimals if x > 0.99, otherwise 2.\"\"\"\n",
    "    if x >= 0.9995:\n",
    "        return \"1\"\n",
    "    n = 3 if x > 0.99 else 2\n",
    "    return f\"{x:.{n}f}\"\n",
    "\n",
    "anomaly_ds_names = ds_names.copy()\n",
    "anomaly_ds_names.remove(\"population\")  # population only has false labels on H\n",
    "\n",
    "subtract_diag = False\n",
    "root = \"../../anomaly-results/\"\n",
    "caption = \"Mechanistic anomaly detection AUROC. Note the Population dataset is omitted because the easy subset only contains true labels.\"\n",
    "if subtract_diag:\n",
    "    caption += \" Using diagonal subtraction.\"\n",
    "# label column, one column per dataset, then the avg column set off by extra space;\n",
    "# derived from the dataset list so the column count always matches the header and rows\n",
    "col_spec = \"l\" + \"c\" * len(anomaly_ds_names) + \"@{\\\\hspace{14pt}}c\"\n",
    "table = \\\n",
    "\"\"\"\\\\setlength{\\\\tabcolsep}{3.2pt}\n",
    "\\\\begin{table}[b!]\n",
    "    \\\\centering\n",
    "    \\\\caption{\"\"\" + caption + \"\"\"}\n",
    "    \\\\label{tab:anomaly_detection}\n",
    "    \\\\begin{tabular}{\"\"\" + col_spec + \"\"\"}\n",
    "        \\\\toprule\\n\"\"\"\n",
    "table += \"         & \" + \" & \".join([f\"\\\\textit{{{ds_abbrevs[ds]}}}\" for ds in anomaly_ds_names]) + \" & \\\\textbf{avg} \\\\\\\\ \\n\"\n",
    "table += \"        \\\\midrule\\n\"\n",
    "\n",
    "# NOTE(review): templatization_method, standardize_templates and full_finetuning are not defined\n",
    "# in this cell and must already exist in the kernel (presumably from the config cell) - TODO confirm\n",
    "for method in [\"lr\", \"mean-diff\", \"lda\", \"lr-on-pair\", \"mean-diff-on-pair\", \"ccs\", \"crc\"]:\n",
    "    method_title = method_titles[method].replace(\"LogR on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}LogR on\\\\\\\\cont. pair\\\\end{tabular}\").replace(\"Diff-in-means on contrast pair\", \"\\\\begin{tabular}[l]{@{}l@{}}Diff-in-means\\\\\\\\on cont. pair\\\\end{tabular}\")\n",
    "    table += f\"        {method_title} & \"\n",
    "    row_avg = 0\n",
    "    for ds_name in anomaly_ds_names:\n",
    "        avg_auc = 0\n",
    "        for model in models:\n",
    "            _, model_last = get_quirky_model_name(ds_name, model, templatization_method=templatization_method, standardize_templates=standardize_templates, full_finetuning=full_finetuning)\n",
    "            name = f\"mahalanobis_{model_last}_{method}\"\n",
    "            if subtract_diag:\n",
    "                name += \"_subtract_diag\"\n",
    "            with open(os.path.join(root, name + \".json\")) as f:\n",
    "                auroc = json.load(f)[\"auroc\"]\n",
    "            avg_auc += auroc\n",
    "        avg_auc /= len(models)\n",
    "        table += f\"{custom_round(avg_auc)} & \"\n",
    "        row_avg += avg_auc\n",
    "    row_avg /= len(anomaly_ds_names)\n",
    "    table += f\"{custom_round(row_avg)}\" + \" \\\\\\\\ \\n\"\n",
    "table += \"        \\\\bottomrule\\n\"\n",
    "table += \"    \\\\end{tabular}\\n\\\\end{table}\"\n",
    "\n",
    "print(table)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Results showing that Alice's and Bob's representations are not negations of each other"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute agreement rate for A -> B with B -> B, and A -> A with B -> A\n",
    "# describe it as agreement rate of alice and bob's representations on examples where Alice and Bob agree,\n",
    "# when Alice is in the context and when Bob is in the context\n",
    "from viz import get_agreement_rate\n",
    "\n",
    "agreement_reporters = (\"lr\", \"ccs\", \"crc\", \"mean-diff\", \"lr-on-pair\", \"lda\")\n",
    "target_distrs = (\"A\", \"B\")\n",
    "for reporter in agreement_reporters:\n",
    "    # one rate per target distribution, then their unweighted mean\n",
    "    rates = [\n",
    "        get_agreement_rate(models, ds_names, target_distr, fr1='A', fr2='B', reporter=reporter)\n",
    "        for target_distr in target_distrs\n",
    "    ]\n",
    "    mean_rate = sum(rates) / len(rates)\n",
    "    print(f\"Reporter: {reporter}\")\n",
    "    print(f\"Agreement rate: {mean_rate}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Causal intervention results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from viz import load_intervention_results\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from elk_generalization.utils import get_quirky_model_name\n",
    "\n",
    "sns.set_theme()\n",
    "# seeded so the random subset of individual curves drawn in grey is the same on every run\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "fig, axes = plt.subplots(2, 3, figsize=(9.5, 5), dpi=150, sharex=True, sharey=True)\n",
    "intervention_ds_names = ds_names.copy()\n",
    "intervention_models = models.copy()\n",
    "# the quirky model list does not depend on the panel, so build it once\n",
    "# NOTE(review): templatization_method, standardize_templates and full_finetuning are not defined\n",
    "# in this cell and must already exist in the kernel (presumably from the config cell) - TODO confirm\n",
    "intervention_quirky_models = [get_quirky_model_name(ds_name, model, templatization_method=templatization_method, standardize_templates=standardize_templates, full_finetuning=full_finetuning)[1] for model in intervention_models for ds_name in intervention_ds_names]\n",
    "\n",
    "for i, character in enumerate((\"Alice\", \"Bob\")):\n",
    "    for j, reporter in enumerate((\"mean-diff\", \"lr\", \"lda\")):\n",
    "        fr, to, against = character, character, character\n",
    "        layer_fracs, avg_intervened_results, avg_clean_result, all_layers, all_intervened_aurocs, all_clean_aurocs = load_intervention_results(intervention_quirky_models, fr, to, reporter, against=against)\n",
    "\n",
    "        plt.sca(axes[i][j])\n",
    "\n",
    "        plt.plot(layer_fracs, avg_intervened_results, label=\"Intervened\", color=\"dodgerblue\")\n",
    "        # overlay roughly a quarter of the individual curves in grey, x rescaled to a fraction of the max layer\n",
    "        for layers, int_auroc, cl_auroc in zip(all_layers.values(), all_intervened_aurocs.values(), all_clean_aurocs.values()):\n",
    "            if rng.random() < 0.25:\n",
    "                layer_frac = np.array(layers) / max(layers)\n",
    "                plt.plot(layer_frac, int_auroc, color=\"grey\", alpha=0.5, linewidth=0.3)\n",
    "                plt.axhline(cl_auroc, color=\"grey\", linestyle=\"--\", alpha=0.5, linewidth=0.3)\n",
    "\n",
    "        plt.axhline(0.5, color=\"black\", linestyle=\":\", linewidth=1)\n",
    "        plt.axhline(float(avg_clean_result), linestyle=\"--\", label=\"Clean\", color=\"dodgerblue\", linewidth=1)\n",
    "        plt.ylim(-0.02, 1.02)\n",
    "        plt.xlim(0, 1)\n",
    "        if j == 0:\n",
    "            plt.ylabel(\"AUROC $\\\\bf{against\\\\ Bob}$\" if character == \"Bob\" else \"AUROC\")\n",
    "        if i == 1 and j == 2:\n",
    "            plt.legend()\n",
    "        if i == 1 and j == 1:\n",
    "            plt.xlabel(\"intervened layer (fraction of max)\")\n",
    "        reporter_title = method_titles[reporter]\n",
    "        plt.title(f\"{character}'s {reporter_title} direction\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/intervention.pdf\", bbox_inches=\"tight\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Scaling plot (Note: values hardcoded and gathered previously)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_theme()\n",
    "colors = sns.color_palette(\"Set2\", n_colors=3)\n",
    "\n",
    "# hardcoded results gathered previously: model sizes (billions of parameters) and PGR per family\n",
    "model_sizes = {\n",
    "    'Pythia': [0.41, 1, 1.4, 2.8, 6.9, 12],\n",
    "    'Llama': [7],\n",
    "    'Mistral': [7],\n",
    "}\n",
    "pgr_values = {\n",
    "    'Pythia': [0.48, 0.46, 0.56, 0.56, 0.55, 0.57],\n",
    "    'Llama': [0.61],\n",
    "    'Mistral': [0.61],\n",
    "}\n",
    "# per-family line style; families are drawn (and listed in the legend) in this order\n",
    "family_styles = {\n",
    "    \"Pythia\": {\"color\": colors[0], \"marker\": \"o\"},\n",
    "    \"Llama\": {\"color\": colors[1], \"marker\": \"*\", \"markersize\": 10},\n",
    "    \"Mistral\": {\"color\": colors[2], \"marker\": \"x\", \"markersize\": 10},\n",
    "}\n",
    "\n",
    "plt.figure(figsize=(4.3, 3), dpi=150)\n",
    "for family, style in family_styles.items():\n",
    "    plt.plot(model_sizes[family], pgr_values[family], label=family, **style)\n",
    "plt.xlabel(\"model size\")\n",
    "plt.ylabel(\"PGR\")\n",
    "plt.legend()\n",
    "plt.semilogx()\n",
    "plt.ylim(0.41, 0.64)\n",
    "# tick only at the Pythia sizes, labelled e.g. 1.4B\n",
    "pythia_sizes = model_sizes[\"Pythia\"]\n",
    "plt.xticks(pythia_sizes, labels=[f\"{size}B\" for size in pythia_sizes])\n",
    "plt.savefig(\"../../figures/scaling.pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "elkg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
