{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f379178d",
   "metadata": {},
   "source": [
    "# Script to plot average causal effects\n",
    "\n",
    "This script loads hundreds of causal traces computed by the\n",
    "`experiment.causal_trace` program and aggregates them to compute the\n",
    "Average Indirect Effect (AIE) and Average Total Effect (ATE), along with some other summary statistics.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26bba71c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import numpy, os\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Serif text everywhere, including math text.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"font.family\": \"Times New Roman\",\n",
    "        \"mathtext.fontset\": \"dejavuserif\",\n",
    "    }\n",
    ")\n",
    "\n",
    "# Model whose results are plotted: `arch` is the results/ subdirectory and\n",
    "# `archname` the display name used in axis labels. To plot another model,\n",
    "# comment this line out and uncomment one of the alternatives below.\n",
    "arch, archname = \"gpt2-xl\", \"GPT-2-XL\"\n",
    "\n",
    "# arch, archname = \"EleutherAI_gpt-j-6B\", \"GPT-J-6B\"\n",
    "# arch, archname = \"EleutherAI_gpt-neox-20b\", \"GPT-NeoX-20B\"\n",
    "\n",
    "\n",
    "class Avg:\n",
    "    \"\"\"Accumulate arrays and report their mean / standard deviation.\n",
    "\n",
    "    Chunks are concatenated along axis 0 before each statistic is computed,\n",
    "    so every added array must agree on its trailing dimensions. Calling\n",
    "    avg() or std() before anything was added lets numpy.concatenate raise.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Arrays whose leading axis indexes individual samples.\n",
    "        self.d = []\n",
    "\n",
    "    def _stacked(self):\n",
    "        \"\"\"Return every accumulated sample joined along the leading axis.\"\"\"\n",
    "        return numpy.concatenate(self.d)\n",
    "\n",
    "    def add(self, v):\n",
    "        \"\"\"Add a single sample, giving it a new leading axis of length 1.\"\"\"\n",
    "        self.d.append(v[None])\n",
    "\n",
    "    def add_all(self, vv):\n",
    "        \"\"\"Add a batch of samples whose first axis indexes the samples.\"\"\"\n",
    "        self.d.append(vv)\n",
    "\n",
    "    def avg(self):\n",
    "        \"\"\"Element-wise mean over all samples.\"\"\"\n",
    "        return self._stacked().mean(axis=0)\n",
    "\n",
    "    def std(self):\n",
    "        \"\"\"Element-wise population standard deviation (ddof=0) over all samples.\"\"\"\n",
    "        return self._stacked().std(axis=0)\n",
    "\n",
    "    def size(self):\n",
    "        \"\"\"Total number of samples added so far.\"\"\"\n",
    "        total = 0\n",
    "        for chunk in self.d:\n",
    "            total += chunk.shape[0]\n",
    "        return total\n",
    "\n",
    "\n",
    "def read_knowlege(count=150, kind=None, arch=\"gpt2-xl\"):\n",
    "    \"\"\"Aggregate saved causal traces into average effects and print a summary.\n",
    "\n",
    "    Tries case indices 0..count-1, reading\n",
    "    `results/{arch}/causal_trace/cases/knowledge_{i}{_kind}.npz`. Indices whose\n",
    "    file is missing or unreadable are skipped, as are cases whose stored\n",
    "    `correct_prediction` flag is false.\n",
    "\n",
    "    Args:\n",
    "        count: number of case indices to try.\n",
    "        kind: None/empty for the plain traces, otherwise a suffix such as\n",
    "            \"mlp\" or \"attn\" selecting `knowledge_{i}_{kind}.npz`.\n",
    "        arch: subdirectory of `results/` (model name) to read from.\n",
    "\n",
    "    Returns:\n",
    "        dict with keys\n",
    "          low_score: mean of the per-case `low_score` values;\n",
    "          result: stacked per-group means of `scores` rows, one row each for\n",
    "            first / middle / last subject token, first subsequent token,\n",
    "            further tokens and last token;\n",
    "          result_std: the matching element-wise standard deviations;\n",
    "          size: number of cases that contributed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no usable case was found.\n",
    "    \"\"\"\n",
    "    dirname = f\"results/{arch}/causal_trace/cases/\"\n",
    "    kindcode = \"\" if not kind else f\"_{kind}\"\n",
    "    (\n",
    "        avg_fe,\n",
    "        avg_ee,\n",
    "        avg_le,\n",
    "        avg_fa,\n",
    "        avg_ea,\n",
    "        avg_la,\n",
    "        avg_hs,\n",
    "        avg_ls,\n",
    "        avg_fs,\n",
    "        avg_fle,\n",
    "        avg_fla,\n",
    "    ) = [Avg() for _ in range(11)]\n",
    "    for i in range(count):\n",
    "        try:\n",
    "            data = numpy.load(f\"{dirname}/knowledge_{i}{kindcode}.npz\")\n",
    "        except (OSError, ValueError):\n",
    "            # Not every index has to exist: skip missing or unreadable files,\n",
    "            # but (unlike a bare except) let KeyboardInterrupt and unrelated\n",
    "            # errors propagate.\n",
    "            continue\n",
    "        # An NpzFile keeps its archive open until closed; close it once read.\n",
    "        with data:\n",
    "            # Only consider cases where the model begins with the correct prediction\n",
    "            if \"correct_prediction\" in data and not data[\"correct_prediction\"]:\n",
    "                continue\n",
    "            scores = data[\"scores\"]\n",
    "            # presumably [first subject token, first token after the subject)\n",
    "            first_e, first_a = data[\"subject_range\"]\n",
    "            last_e = first_a - 1\n",
    "            last_a = len(scores) - 1\n",
    "            # original prediction\n",
    "            avg_hs.add(data[\"high_score\"])\n",
    "            # prediction after subject is corrupted\n",
    "            avg_ls.add(data[\"low_score\"])\n",
    "            avg_fs.add(scores.max())\n",
    "            # some maximum computations\n",
    "            avg_fle.add(scores[last_e].max())\n",
    "            avg_fla.add(scores[last_a].max())\n",
    "            # First subject, middle subject, last subject.\n",
    "            avg_fe.add(scores[first_e])\n",
    "            avg_ee.add_all(scores[first_e + 1 : last_e])\n",
    "            avg_le.add(scores[last_e])\n",
    "            # First after, middle after, last after\n",
    "            avg_fa.add(scores[first_a])\n",
    "            avg_ea.add_all(scores[first_a + 1 : last_a])\n",
    "            avg_la.add(scores[last_a])\n",
    "\n",
    "    if avg_fe.size() == 0:\n",
    "        # Fail with a clear message instead of numpy's concatenate error.\n",
    "        raise ValueError(\n",
    "            f\"No usable causal-trace cases in {dirname} \"\n",
    "            f\"(tried {count} indices, kind={kind!r})\"\n",
    "        )\n",
    "\n",
    "    # Compute each reused average once instead of re-concatenating per print.\n",
    "    low_score = avg_ls.avg()\n",
    "    last_subject = avg_le.avg()\n",
    "    last_token = avg_la.avg()\n",
    "    groups = [avg_fe, avg_ee, avg_le, avg_fa, avg_ea, avg_la]\n",
    "    result = numpy.stack([g.avg() for g in groups])\n",
    "    result_std = numpy.stack([g.std() for g in groups])\n",
    "    print(\"Average Total Effect\", avg_hs.avg() - low_score)\n",
    "    print(\n",
    "        \"Best average indirect effect on last subject\",\n",
    "        last_subject.max() - low_score,\n",
    "    )\n",
    "    print(\"Best average indirect effect on last token\", last_token.max() - low_score)\n",
    "    print(\"Average best-fixed score\", avg_fs.avg())\n",
    "    print(\"Average best-fixed on last subject token score\", avg_fle.avg())\n",
    "    print(\"Average best-fixed on last word score\", avg_fla.avg())\n",
    "    print(\"Argmax at last subject token\", numpy.argmax(last_subject))\n",
    "    print(\"Max at last subject token\", numpy.max(last_subject))\n",
    "    print(\"Argmax at last prompt token\", numpy.argmax(last_token))\n",
    "    print(\"Max at last prompt token\", numpy.max(last_token))\n",
    "    return dict(\n",
    "        low_score=low_score, result=result, result_std=result_std, size=avg_fe.size()\n",
    "    )\n",
    "\n",
    "\n",
    "def plot_array(\n",
    "    differences,\n",
    "    kind=None,\n",
    "    savepdf=None,\n",
    "    title=None,\n",
    "    low_score=None,\n",
    "    high_score=None,\n",
    "    archname=\"GPT2-XL\",\n",
    "):\n",
    "    \"\"\"Draw (and optionally save) a heatmap of average indirect effects.\n",
    "\n",
    "    Args:\n",
    "        differences: 2-D array with one row per token group (six rows, in\n",
    "            the order of `labels` below) and one column per layer.\n",
    "        kind: None, \"mlp\" or \"attn\"; selects colormap and x-axis label.\n",
    "        savepdf: optional output path; missing parent directories are created.\n",
    "        title: optional axes title.\n",
    "        low_score: lower color limit (defaults to differences.min()).\n",
    "        high_score: upper color limit (defaults to differences.max()).\n",
    "        archname: model name used in the x-axis label when kind is None.\n",
    "    \"\"\"\n",
    "    if low_score is None:\n",
    "        low_score = differences.min()\n",
    "    if high_score is None:\n",
    "        high_score = differences.max()\n",
    "    answer = \"AIE\"\n",
    "    labels = [\n",
    "        \"First subject token\",\n",
    "        \"Middle subject tokens\",\n",
    "        \"Last subject token\",\n",
    "        \"First subsequent token\",\n",
    "        \"Further tokens\",\n",
    "        \"Last token\",\n",
    "    ]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(3.5, 2), dpi=200)\n",
    "    h = ax.pcolor(\n",
    "        differences,\n",
    "        cmap={None: \"Purples\", \"mlp\": \"Greens\", \"attn\": \"Reds\"}[kind],\n",
    "        vmin=low_score,\n",
    "        vmax=high_score,\n",
    "    )\n",
    "    if title:\n",
    "        ax.set_title(title)\n",
    "    ax.invert_yaxis()\n",
    "    ax.set_yticks([0.5 + i for i in range(len(differences))])\n",
    "    # Label every 5th column, stopping 6 columns short of the right edge;\n",
    "    # the +0.5 offset centers each tick on its pcolor cell.\n",
    "    tick_columns = list(range(0, differences.shape[1] - 6, 5))\n",
    "    ax.set_xticks([0.5 + i for i in tick_columns])\n",
    "    ax.set_xticklabels(tick_columns)\n",
    "    ax.set_yticklabels(labels)\n",
    "    if kind is None:\n",
    "        ax.set_xlabel(f\"single patched layer within {archname}\")\n",
    "    else:\n",
    "        ax.set_xlabel(f\"center of interval of 10 patched {kind} layers\")\n",
    "    cb = plt.colorbar(h)\n",
    "    # The following should be cb.ax.set_xlabel(answer), but this is broken in matplotlib 3.5.1.\n",
    "    if answer:\n",
    "        cb.ax.set_title(str(answer).strip(), y=-0.16, fontsize=10)\n",
    "\n",
    "    if savepdf:\n",
    "        outdir = os.path.dirname(savepdf)\n",
    "        # A bare filename has no directory part, and os.makedirs(\"\") raises.\n",
    "        if outdir:\n",
    "            os.makedirs(outdir, exist_ok=True)\n",
    "        plt.savefig(savepdf, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Number of case indices read for each heatmap.\n",
    "the_count = 1208\n",
    "# Upper color limit shared by all three heatmaps: it is taken from the\n",
    "# hidden-state (kind=None) plot, which is drawn first.\n",
    "high_score = None\n",
    "effect_names = {\n",
    "    None: \"Indirect Effect of $h_i^{(l)}$\",\n",
    "    \"mlp\": \"Indirect Effect of MLP\",\n",
    "    \"attn\": \"Indirect Effect of Attn\",\n",
    "}\n",
    "\n",
    "for kind in [None, \"mlp\", \"attn\"]:\n",
    "    kindcode = \"\" if kind is None else f\"_{kind}\"\n",
    "    d = read_knowlege(the_count, kind, arch)\n",
    "    count = d[\"size\"]\n",
    "    what = effect_names[kind]\n",
    "    title = f\"Avg {what} over {count} prompts\"\n",
    "    # Effect relative to the corrupted-input baseline, negatives clipped to 0.\n",
    "    result = numpy.clip(d[\"result\"] - d[\"low_score\"], 0, None)\n",
    "    if kind is None:\n",
    "        high_score = result.max()\n",
    "    plot_array(\n",
    "        result,\n",
    "        kind=kind,\n",
    "        title=title,\n",
    "        low_score=0.0,\n",
    "        high_score=high_score,\n",
    "        archname=archname,\n",
    "        savepdf=f\"results/{arch}/causal_trace/summary_pdfs/rollup{kindcode}.pdf\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c896e9ac",
   "metadata": {},
   "source": [
    "## Plot line graph\n",
    "\n",
    "To make confidence intervals visible, we plot the data as line graphs below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1fe3105",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "\n",
    "labels = [\n",
    "    \"First subject token\",\n",
    "    \"Middle subject tokens\",\n",
    "    \"Last subject token\",\n",
    "    \"First subsequent token\",\n",
    "    \"Further tokens\",\n",
    "    \"Last token\",\n",
    "]\n",
    "color_order = [0, 1, 2, 4, 5, 3]\n",
    "# Number of case indices read for each line plot.\n",
    "line_plot_count = 225\n",
    "\n",
    "cmap = plt.get_cmap(\"tab10\")\n",
    "fig, axes = plt.subplots(1, 3, figsize=(13, 3.5), sharey=True, dpi=200)\n",
    "for j, (kind, title) in enumerate(\n",
    "    [\n",
    "        (None, \"single hidden vector\"),\n",
    "        (\"mlp\", \"run of 10 MLP lookups\"),\n",
    "        (\"attn\", \"run of 10 Attn modules\"),\n",
    "    ]\n",
    "):\n",
    "    print(f\"Reading {kind}\")\n",
    "    d = read_knowlege(line_plot_count, kind, arch)\n",
    "    # The interval width must use the number of cases in *this* read, not a\n",
    "    # leftover `count` from an earlier cell that read a different number of cases.\n",
    "    n_cases = d[\"size\"]\n",
    "    for i, label in enumerate(labels):\n",
    "        y = d[\"result\"][i] - d[\"low_score\"]\n",
    "        x = list(range(len(y)))\n",
    "        std = d[\"result_std\"][i]\n",
    "        # 1.96 * std / sqrt(n): normal-approximation 95% interval of the mean.\n",
    "        # NOTE: the middle/further-token rows pool several tokens per case, so\n",
    "        # their true sample count exceeds n_cases and these bands are conservative.\n",
    "        error = std * 1.96 / math.sqrt(n_cases)\n",
    "        color = cmap.colors[color_order[i]]\n",
    "        axes[j].fill_between(x, y - error, y + error, alpha=0.3, color=color)\n",
    "        axes[j].plot(x, y, label=label, color=color)\n",
    "\n",
    "    axes[j].set_title(f\"Average indirect effect of a {title}\")\n",
    "    axes[j].set_ylabel(\"Average indirect effect on p(o)\")\n",
    "    axes[j].set_xlabel(f\"Layer number in {archname}\")\n",
    "    # axes[j].set_ylim(0.1, 0.3)\n",
    "axes[1].legend(frameon=False)\n",
    "plt.tight_layout()\n",
    "summary_dir = f\"results/{arch}/causal_trace/summary_pdfs\"\n",
    "os.makedirs(summary_dir, exist_ok=True)\n",
    "plt.savefig(f\"{summary_dir}/lineplot-causaltrace.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
