{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import sys\n",
    "import typing\n",
    "from collections import Counter, defaultdict\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tabulate\n",
    "import torch\n",
    "from IPython.display import Markdown, display\n",
    "from loguru import logger\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "from shared_definitions import *\n",
    "from shared_visualization_utils import *\n",
    "\n",
    "sys.path.insert(0, os.path.abspath(\"..\"))\n",
    "\n",
    "sns.set_theme(style=\"white\", context=\"notebook\", rc={\"figure.figsize\": (14, 10)})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "result_df, indirect_effects_by_model_and_dataset, top_heads_by_model_and_dataset = load_and_combine_raw_results()\n",
    "result_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SKIP_OLMO = False\n",
    "if SKIP_OLMO:\n",
    "    result_df = result_df[~result_df.model.str.contains(\"OLMo\")]\n",
    "    indirect_effects_by_model_and_dataset = {\n",
    "        k: v for k, v in indirect_effects_by_model_and_dataset.items() if \"OLMo\" not in k\n",
    "    }\n",
    "    top_heads_by_model_and_dataset = {k: v for k, v in top_heads_by_model_and_dataset.items() if \"OLMo\" not in k}\n",
    "\n",
    "RELEVANT_MODELS = ORDERED_MODELS[:]\n",
    "RELEVANT_SCATTER_ORDERED_MODELS = SCATTER_ORDERED_MODELS[:]\n",
    "if SKIP_OLMO:\n",
    "    RELEVANT_MODELS = [model for model in RELEVANT_MODELS if \"OLMo\" not in model]\n",
    "    RELEVANT_SCATTER_ORDERED_MODELS = [model for model in RELEVANT_SCATTER_ORDERED_MODELS if \"OLMo\" not in model]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RQ2: How similar are the sets of top heads used?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "top_heads_path = Path(\"/checkpoint/guyd/function_vectors/full_results_top_heads\")\n",
    "\n",
    "top_heads_by_model_and_type = defaultdict(dict)\n",
    "\n",
    "for heads_type, glob_pattern in (\n",
    "    (\"prompt\", \"*both_all_top_heads.json\"),\n",
    "    (\"icl\", \"*icl_same_test_sets_top_heads.json\"),\n",
    "):\n",
    "    for top_heads_file in top_heads_path.glob(glob_pattern):\n",
    "        model, _ = top_heads_file.name.split(\"_\", 1)\n",
    "\n",
    "        logger.debug(f\"Loading {heads_type} heads for {model} from: {str(top_heads_file)}\")\n",
    "\n",
    "        with open(top_heads_file, \"r\") as f:\n",
    "            top_heads_by_model_and_type[model][heads_type] = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats.mstats import gmean\n",
    "\n",
    "lines = []\n",
    "rows = []\n",
    "\n",
    "for model, model_results in top_heads_by_model_and_type.items():\n",
    "    lines.append(f\"- {model}\")\n",
    "    icl_heads = model_results.get(\"icl\", None)\n",
    "    prompt_heads = model_results.get(\"prompt\", None)\n",
    "    if icl_heads is None or prompt_heads is None:\n",
    "        logger.warning(f\"Missing data for {model}\")\n",
    "        continue\n",
    "\n",
    "    for n in (10, 20):\n",
    "        iclh = [tuple(t) for t in icl_heads[\"top_heads\"][:n]]\n",
    "        icl_layers = [t[0] for t in iclh]\n",
    "        ph = [tuple(t) for t in prompt_heads[\"top_heads\"][:n]]\n",
    "        prompt_layers = [t[0] for t in ph]\n",
    "        lines.append(\n",
    "            f\"   - N={n}: {len(set(iclh) & set(ph))} shared | ICL layers: {np.mean(icl_layers):.2f} | Prompt layers: {np.mean(prompt_layers):.2f}\"\n",
    "        )\n",
    "\n",
    "        rows.append(\n",
    "            dict(\n",
    "                model=model,\n",
    "                n_heads=n,\n",
    "                n_layers=MODEL_TO_N_LAYERS[model],\n",
    "                icl_heads=iclh,\n",
    "                prompt_heads=ph,\n",
    "                icl_effects=icl_heads[\"top_head_effects\"][:n],\n",
    "                prompt_effects=prompt_heads[\"top_head_effects\"][:n],\n",
    "            )\n",
    "        )\n",
    "\n",
    "display(Markdown(\"\\n\".join(lines)))\n",
    "\n",
    "top_heads_summary_df = pd.DataFrame(rows)\n",
    "top_heads_summary_df = top_heads_summary_df.assign(\n",
    "    icl_effect_mean=top_heads_summary_df.icl_effects.map(lambda x: np.mean(x)),\n",
    "    icl_effect_geom_mean=top_heads_summary_df.icl_effects.map(lambda x: gmean(x)),\n",
    "    prompt_effect_mean=top_heads_summary_df.prompt_effects.map(lambda x: np.mean(x)),\n",
    "    prompt_effect_geom_mean=top_heads_summary_df.prompt_effects.map(lambda x: gmean(x)),\n",
    "    icl_layer_mean=top_heads_summary_df.icl_heads.map(lambda x: np.mean([t[0] for t in x])),\n",
    "    icl_layer_std=top_heads_summary_df.icl_heads.map(lambda x: np.std([t[0] for t in x])),\n",
    "    prompt_layer_mean=top_heads_summary_df.prompt_heads.map(lambda x: np.mean([t[0] for t in x])),\n",
    "    prompt_layer_std=top_heads_summary_df.prompt_heads.map(lambda x: np.std([t[0] for t in x])),\n",
    ")\n",
    "\n",
    "top_heads_summary_df = top_heads_summary_df.assign(\n",
    "    icl_layer_mean_depth=top_heads_summary_df.apply(lambda row: row.icl_layer_mean / row.n_layers, axis=1),\n",
    "    icl_layer_std_depth=top_heads_summary_df.apply(lambda row: row.icl_layer_std / row.n_layers, axis=1),\n",
    "    prompt_layer_mean_depth=top_heads_summary_df.apply(lambda row: row.prompt_layer_mean / row.n_layers, axis=1),\n",
    "    prompt_layer_std_depth=top_heads_summary_df.apply(lambda row: row.prompt_layer_std / row.n_layers, axis=1),\n",
    ")\n",
    "\n",
    "\n",
    "def row_summary(row: pd.Series):\n",
    "    sh = len(set(row.icl_heads) & set(row.prompt_heads))\n",
    "    return dict(\n",
    "        key=f\"{row.model} @ {row.n_heads}\",\n",
    "        shared_heads=sh,\n",
    "        shared_head_fraction=sh / row.n_heads,\n",
    "        icl_effect_mean=row.icl_effect_mean,\n",
    "        prompt_effect_mean=row.prompt_effect_mean,\n",
    "        mean_diff=row.icl_effect_mean - row.prompt_effect_mean,\n",
    "        geom_mean_diff=row.icl_effect_geom_mean - row.prompt_effect_geom_mean,\n",
    "        icl_layer=row.icl_layer_mean,\n",
    "        icl_layer_std=row.icl_layer_std,\n",
    "        icl_layer_depth=row.icl_layer_mean_depth,\n",
    "        icl_layer_depth_std=row.icl_layer_std_depth,\n",
    "        prompt_layer=row.prompt_layer_mean,\n",
    "        prompt_layer_std=row.prompt_layer_std,\n",
    "        prompt_layer_depth=row.prompt_layer_mean_depth,\n",
    "        prompt_layer_depth_std=row.prompt_layer_std_depth,\n",
    "        layer_depth_diff=row.prompt_layer_mean_depth - row.icl_layer_mean_depth,\n",
    "        icl_layers=[t[0] for t in row.icl_heads],\n",
    "        icl_layer_depths=[t[0] / row.n_layers for t in row.icl_heads],\n",
    "        prompt_layers=[t[0] for t in row.prompt_heads],\n",
    "        prompt_layer_depths=[t[0] / row.n_layers for t in row.prompt_heads],\n",
    "        prompt_heads=row.prompt_heads,\n",
    "        icl_heads=row.icl_heads,\n",
    "    )\n",
    "\n",
    "\n",
    "rows = list(\n",
    "    top_heads_summary_df.apply(\n",
    "        lambda row: row_summary(row),\n",
    "        axis=1,\n",
    "    ).values\n",
    ")\n",
    "\n",
    "layer_key_dicts = dict()\n",
    "for key_dict in rows:\n",
    "    key_dict = {**key_dict}\n",
    "    key = key_dict.pop(\"key\")\n",
    "    layer_key_dicts[key] = key_dict\n",
    "\n",
    "d = pd.DataFrame(rows).set_index(\"key\").T\n",
    "display(Markdown(tabulate.tabulate(d, headers=\"keys\", tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RQ2.1 Top heads -- first panel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "# from palettable.colorbrewer.qualitative import Dark2_4 as heads_cmap\n",
    "# DEFAULT_TOP_HEADS_COLORS = [(1, 1, 1), heads_cmap.mpl_colors[2], heads_cmap.mpl_colors[3], heads_cmap.mpl_colors[0]]\n",
    "# from palettable.colorbrewer.qualitative import Accent_6 as heads_cmap\n",
    "# DEFAULT_TOP_HEADS_COLORS = [(1, 1, 1), heads_cmap.mpl_colors[1], heads_cmap.mpl_colors[5], heads_cmap.mpl_colors[0]]\n",
    "from palettable.colorbrewer.qualitative import Set1_3 as heads_cmap\n",
    "\n",
    "DEFAULT_TOP_HEADS_COLORS = [(1, 1, 1), *heads_cmap.mpl_colors]\n",
    "\n",
    "\n",
    "def get_model_shape(model_name: str) -> typing.Tuple[int, int]:\n",
    "    # hack for an appendix plot\n",
    "    if model_name in MODEL_TO_N_LAYERS_HEADS:\n",
    "        return MODEL_TO_N_LAYERS_HEADS[model_name]\n",
    "    else:\n",
    "        short_name = model_name.split(\"_\")[0]\n",
    "        if short_name not in MODEL_TO_N_LAYERS_HEADS:\n",
    "            raise ValueError(f\"Model {model_name} not found in MODEL_TO_N_LAYERS_HEADS\")\n",
    "        \n",
    "        return MODEL_TO_N_LAYERS_HEADS[short_name]\n",
    "\n",
    "\n",
    "def plot_shared_top_heads(\n",
    "    heads_df: pd.DataFrame,\n",
    "    model_names: typing.List[str],\n",
    "    n_heads: int | typing.Dict[str, int] = 20,\n",
    "    limits_from_model_layers: bool = True,\n",
    "    draw_mean_lines: bool = False,\n",
    "    shape: typing.Tuple[int, int] | None = None,\n",
    "    panel_width: int = 4,\n",
    "    panel_height: int = 4,\n",
    "    min_step: int = 3,\n",
    "    max_step: int = 10,\n",
    "    legend_panel_height: float = 0.2,\n",
    "    fontsize: int = 12,\n",
    "    font_inc: int = 4,\n",
    "    fontfamily: str | None = None,\n",
    "    textfontweight: str | None = None,\n",
    "    top_heads_colors: typing.List[tuple] = DEFAULT_TOP_HEADS_COLORS,\n",
    "    save_name: typing.Optional[str] = None,\n",
    "    show_legend: bool = True,\n",
    "    legend_outside: bool = True,\n",
    "    legend_loc: str | typing.Tuple[int, int] = None,\n",
    "    legend_ax_index: int = -1,\n",
    "    legend_ncol: int = 1,\n",
    "    legend_bbox_to_anchor: typing.Tuple[float, float] = (1.05, 1),\n",
    "    add_colorbar: bool = False,\n",
    "    minimal_mode: bool = False,\n",
    "    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,\n",
    "    annotate_panels: bool = False,\n",
    "    annotate_panels_start: str = \"A\",\n",
    "    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),\n",
    "    annotate_font_inc: int = 4,\n",
    "    vmin: float = 0,\n",
    "    vmax: float = 3,\n",
    "    suptitle: str | None = None,\n",
    "    model_titles: typing.Dict[str, str] | None = None,\n",
    "):\n",
    "    if model_titles is None:\n",
    "        model_titles = dict()\n",
    "    \n",
    "    if isinstance(n_heads, int):\n",
    "        n_heads = {model_name: n_heads for model_name in model_names}\n",
    "\n",
    "    if shape is None:\n",
    "        if len(model_names) % 2 == 0:\n",
    "            shape = (len(model_names) // 2, 2)\n",
    "        else:\n",
    "            shape = (len(model_names), 1)\n",
    "\n",
    "    if isinstance(top_heads_colors, list):\n",
    "        custom_cmap = matplotlib.colors.ListedColormap(top_heads_colors, name=\"head_colors\", N=len(top_heads_colors))\n",
    "    else:\n",
    "        custom_cmap = top_heads_colors\n",
    "\n",
    "    if grid_spec_kwargs is None:\n",
    "        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)\n",
    "\n",
    "    if add_colorbar and show_legend:\n",
    "        raise ValueError(\"Cannot add colorbar and legend at the same time\")\n",
    "\n",
    "    if show_legend:\n",
    "        gs = GridSpec(shape[0] + 1, shape[1], height_ratios=[1] * shape[0] + [legend_panel_height], **grid_spec_kwargs)\n",
    "    elif add_colorbar:\n",
    "        gs = GridSpec(shape[0], shape[1] + 1, width_ratios=[1] * shape[1] + [legend_panel_height], **grid_spec_kwargs)\n",
    "    else:\n",
    "        gs = GridSpec(shape[0], shape[1], **grid_spec_kwargs)\n",
    "\n",
    "    fig = plt.figure(\n",
    "        layout=\"constrained\",\n",
    "        figsize=(shape[1] * panel_width, shape[0] * panel_height + int(show_legend)),\n",
    "    )\n",
    "\n",
    "    for i, model_name in enumerate(model_names):\n",
    "        n = n_heads[model_name]\n",
    "        r = i // shape[1]\n",
    "        c = i % shape[1]\n",
    "        ax = fig.add_subplot(gs[r, c])\n",
    "\n",
    "        if not minimal_mode:\n",
    "            if r == shape[1] - 1:\n",
    "                ax.set_xlabel(\"Layer\", fontsize=fontsize, fontfamily=fontfamily, fontweight=textfontweight)\n",
    "            if c == 0:\n",
    "                ax.set_ylabel(\"Head Index\", fontsize=fontsize, fontfamily=fontfamily, fontweight=textfontweight)\n",
    "\n",
    "        mdf = heads_df[(heads_df.model == model_name) & (heads_df.n_heads == n)]\n",
    "        prompt_heads = mdf.prompt_heads.values[0]\n",
    "        icl_heads = mdf.icl_heads.values[0]\n",
    "\n",
    "    \n",
    "        head_array = np.zeros(get_model_shape(model_name), dtype=int)\n",
    "        for head in prompt_heads:\n",
    "            head_array[head] += 1\n",
    "\n",
    "        for head in icl_heads:\n",
    "            head_array[head] += 2\n",
    "\n",
    "        ax.imshow(head_array.T, cmap=custom_cmap, vmin=vmin, vmax=vmax)\n",
    "\n",
    "        if draw_mean_lines:\n",
    "            prompt_head_set = set(prompt_heads)\n",
    "            icl_head_set = set(icl_heads)\n",
    "            shared_head_set = prompt_head_set & icl_head_set\n",
    "            prompt_only_head_set = prompt_head_set - shared_head_set\n",
    "            icl_only_head_set = icl_head_set - shared_head_set\n",
    "\n",
    "            for head_set, color in zip(\n",
    "                (prompt_only_head_set, icl_only_head_set, shared_head_set),\n",
    "                top_heads_colors[1:],\n",
    "            ):\n",
    "                if len(head_set) > 0:\n",
    "                    mean_layer = np.mean([head[0] for head in head_set])\n",
    "                    ax.axvline(mean_layer, color=color, linestyle=\"--\", linewidth=2, alpha=0.75)\n",
    "\n",
    "        if not minimal_mode:\n",
    "            if model_name in model_titles:\n",
    "                title = model_titles[model_name]\n",
    "            else:\n",
    "                n_shared = np.sum(head_array == 3)\n",
    "                title = f\"{model_name}\\n{n} top heads, {n_shared} shared\"\n",
    "            \n",
    "            ax.set_title(title, fontsize=fontsize, fontfamily=fontfamily, fontweight=textfontweight)\n",
    "\n",
    "        if minimal_mode:\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "\n",
    "        elif limits_from_model_layers:\n",
    "            model_layers, model_heads = get_model_shape(model_name)\n",
    "            ax.set_xlim(0, model_layers)\n",
    "            ax.set_ylim(0, model_heads)\n",
    "\n",
    "            for max_value, set_method in (\n",
    "                (model_layers, ax.set_xticks),\n",
    "                (model_heads, ax.set_yticks),\n",
    "            ):\n",
    "                step = 1\n",
    "                for step in range(min_step, max_step + 1):\n",
    "                    if max_value % step == 0:\n",
    "                        break\n",
    "\n",
    "                set_method(np.arange(0, max_value + step, step))\n",
    "\n",
    "        if annotate_panels:\n",
    "            ax.text(\n",
    "                *annotate_panel_position,\n",
    "                chr(ord(annotate_panels_start) + i),\n",
    "                ha=\"left\",\n",
    "                va=\"top\",\n",
    "                fontsize=fontsize + annotate_font_inc,\n",
    "                fontweight=\"bold\",\n",
    "                fontfamily=fontfamily,\n",
    "                transform=ax.transAxes,\n",
    "            )\n",
    "\n",
    "        # n_prompt_only = np.sum(head_array == 1)\n",
    "        # n_icl_only = np.sum(head_array == 2)\n",
    "\n",
    "        # legend_entries = [\n",
    "        #     plt.Line2D([0], [0], color=color, linestyle=\"-\", linewidth=3, label=f\"{label} = {count}\")\n",
    "        #     for (color, label, count) in zip (\n",
    "        #         top_heads_colors[1:],\n",
    "        #         [\"Instruction only\", \"Demonstration only\", \"Both\"],\n",
    "        #         [n_prompt_only, n_icl_only, n_shared]\n",
    "        #     )\n",
    "        # ]\n",
    "\n",
    "        # ax.legend(handles=legend_entries, fontsize=fontsize - font_inc, prop=dict(family=fontfamily))\n",
    "\n",
    "    if show_legend:\n",
    "        legend_entries = [\n",
    "            # plt.Line2D([0], [0], color=color, linestyle=\"-\", linewidth=5, label=label)\n",
    "            matplotlib.patches.Rectangle((0, 0), 1, 1, color=color, label=label)\n",
    "            for (color, label) in zip(\n",
    "                top_heads_colors[1:],\n",
    "                # [\"Instruction FV only\", \"Demonstration FV only\", \"Shared in both FV\"]\n",
    "                # [\"Instruction only heads\", \"Demonstration only heads\", \"Shared heads\"],\n",
    "                [\"Instruction only\", \"Demonstration only\", \"Shared\"],\n",
    "            )\n",
    "        ]\n",
    "\n",
    "        legend_kwargs = dict(\n",
    "            handles=legend_entries, ncol=legend_ncol, prop=dict(family=fontfamily, size=fontsize), handlelength=0.75\n",
    "        )\n",
    "        if legend_outside:\n",
    "            legend_kwargs[\"bbox_to_anchor\"] = legend_bbox_to_anchor\n",
    "            legend_kwargs[\"loc\"] = legend_loc\n",
    "        elif legend_loc is not None:\n",
    "            legend_kwargs[\"loc\"] = legend_loc\n",
    "\n",
    "        legend_ax = fig.add_subplot(gs[-1, :])\n",
    "        legend_ax.axis(\"off\")\n",
    "\n",
    "        legend_ax.legend(**legend_kwargs)\n",
    "\n",
    "    elif add_colorbar:\n",
    "        cbar_ax = fig.add_subplot(gs[:, -1])\n",
    "        cbar = plt.colorbar(\n",
    "            matplotlib.cm.ScalarMappable(cmap=custom_cmap, norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)),\n",
    "            cax=cbar_ax,\n",
    "            orientation=\"vertical\",\n",
    "            # ticks=np.arange(vmin, vmax + 1) + 0.5,\n",
    "        )\n",
    "        # cbar.ax.set_yticklabels([\"None\", \"Prompt\", \"ICL\"], fontsize=fontsize, fontfamily=fontfamily)\n",
    "        # cbar.ax.set_yticklabels(np.arange(vmin, vmax + 1), fontsize=fontsize, fontfamily=fontfamily)\n",
    "        cbar.ax.tick_params(labelsize=fontsize)\n",
    "\n",
    "    if suptitle is not None:\n",
    "        fig.suptitle(\n",
    "            suptitle,\n",
    "            fontsize=fontsize + font_inc,\n",
    "            fontweight=textfontweight,\n",
    "            fontfamily=fontfamily,\n",
    "        )\n",
    "\n",
    "    # plt.tight_layout()\n",
    "    if save_name is not None:\n",
    "        save_plot(save_name)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "MAIN_PAPER_PLOT_MODELS = [\"Llama-3.2-3B\", \"Llama-3.2-3B-Instruct\", \"Llama-3.1-8B\", \"Llama-3.1-8B-Instruct\"]\n",
    "\n",
    "plot_shared_top_heads(\n",
    "    top_heads_summary_df,\n",
    "    MAIN_PAPER_PLOT_MODELS,\n",
    "    draw_mean_lines=True,\n",
    "    legend_ax_index=2,\n",
    "    legend_outside=False,\n",
    "    legend_loc=(-0.05, 0),\n",
    "    legend_ncol=3,\n",
    "    legend_panel_height=0.05,\n",
    "    grid_spec_kwargs=dict(hspace=0.5),\n",
    "    fontfamily=\"monospace\",\n",
    "    textfontweight=\"bold\",\n",
    "    save_name=\"finding_3_shared_top_heads.pdf\",\n",
    "    annotate_panels=True,\n",
    "    annotate_panels_start=\"A\",\n",
    "    fontsize=14,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix version of the above plot for more models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "APPENDIX_MODELS = [\n",
    "    model\n",
    "    for model in ORDERED_MODELS\n",
    "    if model not in MAIN_PAPER_PLOT_MODELS and \"13b\" not in model\n",
    "]\n",
    "\n",
    "\n",
    "plot_shared_top_heads(\n",
    "    top_heads_summary_df,\n",
    "    APPENDIX_MODELS,\n",
    "    draw_mean_lines=True,\n",
    "    legend_ax_index=2,\n",
    "    legend_outside=False,\n",
    "    legend_loc=(-0.05, 0),\n",
    "    legend_ncol=3,\n",
    "    legend_panel_height=0.05,\n",
    "    panel_height=4.5,\n",
    "    grid_spec_kwargs=dict(hspace=0.6),\n",
    "    fontfamily=\"monospace\",\n",
    "    textfontweight=\"bold\",\n",
    "    save_name=\"appendix_finding_3_shared_top_heads.pdf\",\n",
    "    annotate_panels=True,\n",
    "    annotate_panels_start=\"A\",\n",
    "    fontsize=14,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from palettable.colorbrewer.qualitative import Set1_3 as heads_cmap\n",
    "\n",
    "FIG_1_TOP_HEAD_COLORS = [(1, 1, 1), *heads_cmap.mpl_colors]\n",
    "FIG_1_MODEL = MAIN_PAPER_PLOT_MODELS[-1]\n",
    "\n",
    "plot_shared_top_heads(\n",
    "    top_heads_summary_df,\n",
    "    [FIG_1_MODEL],\n",
    "    top_heads_colors=FIG_1_TOP_HEAD_COLORS,\n",
    "    shape=(1, 1),\n",
    "    draw_mean_lines=False,\n",
    "    show_legend=False,\n",
    "    minimal_mode=True,\n",
    "    # legend_panel_height=0.05,\n",
    "    grid_spec_kwargs=dict(hspace=0.5),\n",
    "    fontfamily=\"monospace\",\n",
    "    save_name=\"figure_1_top_heads.png\",\n",
    "    panel_width=8,\n",
    "    panel_height=8,\n",
    "    annotate_panels=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FIG_1_DATASET = \"country-capital\"\n",
    "\n",
    "prompt_fv = torch.load(\n",
    "    f\"/checkpoint/guyd/function_vectors/full_results_prompt_based_short/{FIG_1_MODEL}/{FIG_1_DATASET}/country-capital_20_universal_fv.pt\"\n",
    ")\n",
    "\n",
    "icl_fv = torch.load(\n",
    "    f\"/checkpoint/guyd/function_vectors/full_icl_results/{FIG_1_MODEL}/{FIG_1_DATASET}/country-capital_20_universal_fv.pt\"\n",
    ")\n",
    "\n",
    "FIG_1_FVS = {\n",
    "    FIG_1_MODEL: {\n",
    "        BOTH: prompt_fv,\n",
    "        ICL: icl_fv,\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_top_head_vectors(\n",
    "    heads_df: pd.DataFrame,\n",
    "    model_names: typing.List[str],\n",
    "    # model_to_fvs: typing.Dict[str, typing.Dict[str, torch.Tensor]],\n",
    "    # fv_shape: typing.Tuple[int, int],\n",
    "    vector_cmaps: typing.Tuple,\n",
    "    n_heads: int | typing.Dict[str, int] = 20,\n",
    "    limits_from_model_layers: bool = True,\n",
    "    draw_mean_lines: bool = False,\n",
    "    shape: typing.Tuple[int, int] | None = None,\n",
    "    panel_width: int = 2,\n",
    "    panel_height: int = 6,\n",
    "    legend_panel_height: float = 0.2,\n",
    "    fontsize: int = 12,\n",
    "    font_inc: int = 4,\n",
    "    fontfamily: str | None = None,\n",
    "    save_name: typing.Optional[str] = None,\n",
    "    show_legend: bool = True,\n",
    "    legend_outside: bool = True,\n",
    "    legend_loc: str | typing.Tuple[int, int] = None,\n",
    "    legend_ax_index: int = -1,\n",
    "    legend_ncol: int = 1,\n",
    "    legend_bbox_to_anchor: typing.Tuple[float, float] = (1.05, 1),\n",
    "    minimal_mode: bool = False,\n",
    "    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,\n",
    "    annotate_panels: bool = False,\n",
    "    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),\n",
    "    annotate_font_inc: int = 4,\n",
    "):\n",
    "    if isinstance(n_heads, int):\n",
    "        n_heads = {model_name: n_heads for model_name in model_names}\n",
    "    \n",
    "    shape = (len(model_names), 2)\n",
    "\n",
    "    # shape = (len(model_to_fvs), 2)\n",
    "\n",
    "\n",
    "    if grid_spec_kwargs is None:\n",
    "        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)\n",
    "\n",
    "    gs = GridSpec(shape[0], shape[1], **grid_spec_kwargs)\n",
    "\n",
    "    fig = plt.figure(\n",
    "        layout=\"constrained\",\n",
    "        figsize=(shape[1] * panel_width, shape[0] * panel_height + int(show_legend)),\n",
    "    )\n",
    "\n",
    "    for i, model_name in enumerate(model_names):\n",
    "    # for r, model_name in enumerate(model_to_fvs):\n",
    "        n = n_heads[model_name]\n",
    "        r = i\n",
    "\n",
    "        mdf = heads_df[(heads_df.model == model_name) & (heads_df.n_heads == n)]\n",
    "        prompt_heads = mdf.prompt_heads.values[0]\n",
    "        icl_heads = mdf.icl_heads.values[0]\n",
    "\n",
    "        head_array = np.zeros(MODEL_TO_N_LAYERS_HEADS[model_name], dtype=int)\n",
    "        for head in prompt_heads:\n",
    "            head_array[head] += 1\n",
    "\n",
    "        for head in icl_heads:\n",
    "            head_array[head] += 2\n",
    "\n",
    "        for c, (ignore_value, cmap) in enumerate(zip(\n",
    "            (2, 1),\n",
    "            vector_cmaps,\n",
    "        )):\n",
    "            ax = fig.add_subplot(gs[r, c])        \n",
    "\n",
    "            head_vec = np.copy(head_array)\n",
    "            head_vec[head_vec == ignore_value] == 0\n",
    "            head_vec = head_vec.sum(axis=0)\n",
    "\n",
    "            ax.imshow(head_vec.T[:, None], cmap=cmap)\n",
    "            ax.set_xticks([])\n",
    "            ax.set_yticks([])\n",
    "\n",
    "        # for c, (fv_type, cmap) in enumerate(zip(\n",
    "        #     (BOTH, ICL),\n",
    "        #     vector_cmaps,\n",
    "        # )):\n",
    "        #     ax = fig.add_subplot(gs[r, c])\n",
    "\n",
    "        #     fv = model_to_fvs[model_name][fv_type]\n",
    "        #     # head_vec = np.copy(head_array)\n",
    "        #     # head_vec[head_vec == ignore_value] == 0\n",
    "        #     # head_vec = head_vec.sum(axis=0)\n",
    "\n",
    "        #     ax.imshow(fv.view(fv_shape).cpu().numpy(), cmap=cmap)\n",
    "        #     ax.set_xticks([])\n",
    "        #     ax.set_yticks([])\n",
    "\n",
    "    \n",
    "    # plt.tight_layout()\n",
    "    if save_name is not None:\n",
    "        save_plot(save_name)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "WIDTH = 32\n",
    "D_MODEL = 4096\n",
    "\n",
    "\n",
    "plot_top_head_vectors(\n",
    "    top_heads_summary_df,\n",
    "    [MAIN_PAPER_PLOT_MODELS[-1]],\n",
    "    # FIG_1_FVS,\n",
    "    # (D_MODEL // WIDTH, WIDTH),\n",
    "    vector_cmaps=(plt.cm.Reds, plt.cm.Blues),\n",
    "    shape=(1, 1),\n",
    "    draw_mean_lines=False,\n",
    "    show_legend=False,\n",
    "    minimal_mode=True,\n",
    "    # legend_panel_height=0.05,\n",
    "    grid_spec_kwargs=dict(hspace=0.5),\n",
    "    fontfamily=\"monospace\",\n",
    "    save_name=\"figure_1_fvs.png\",\n",
    "    panel_width=1,\n",
    "    panel_height=6,\n",
    "    annotate_panels=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compare activation similarity only in shared heads\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "DEFAULT_UNIVERSAL_FV_TYPES = {\n",
    "    SHORT: f\"{BOTH}_{ALL}\",\n",
    "    LONG: f\"{BOTH}_{ALL}\",\n",
    "    ICL: \"\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mld = layer_key_dicts[\"Llama-3.2-3B @ 20\"]\n",
    "set(mld[\"prompt_heads\"]) & set(mld[\"icl_heads\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def load_shared_head_activations(\n",
    "    models: typing.List[str],\n",
    "    key_dicts: typing.Dict[str, typing.Dict[str, typing.Any]],\n",
    "    n_top_heads: int = 20,\n",
    "    skip_datasets: typing.List[str] = SKIP_DATASETS,\n",
    "):\n",
    "    results_by_model = {}\n",
    "\n",
    "    for model in tqdm(models):\n",
    "        model_dict = key_dicts[f\"{model} @ {n_top_heads}\"]\n",
    "        shared_heads = set(model_dict[\"prompt_heads\"]) & set(model_dict[\"icl_heads\"])\n",
    "        model_results = defaultdict(lambda: defaultdict(dict))\n",
    "\n",
    "        for result_type, results_path_str in RESULT_ROOTS.items():\n",
    "            results_path = Path(results_path_str)\n",
    "\n",
    "            model_results_path = results_path / model\n",
    "            if not model_results_path.exists():\n",
    "                logger.warning(f\"Model results path {model_results_path} does not exist.\")\n",
    "                continue\n",
    "\n",
    "            for model_dataset_path in model_results_path.iterdir():\n",
    "                if model_dataset_path.name in skip_datasets:\n",
    "                    continue\n",
    "\n",
    "                dataset_name = model_dataset_path.name\n",
    "                mean_activations_path = model_dataset_path / f\"{dataset_name}_mean_head_activations.pt\"\n",
    "                if not mean_activations_path.exists():\n",
    "                    logger.warning(f\"Mean activations path {mean_activations_path} does not exist.\")\n",
    "                    continue\n",
    "\n",
    "                mean_activations = torch.load(mean_activations_path)\n",
    "                for L, H in shared_heads:\n",
    "                    model_results[dataset_name][(L, H)][result_type] = mean_activations[L, H, -1]\n",
    "\n",
    "        results_by_model[model] = model_results\n",
    "\n",
    "    return results_by_model\n",
    "\n",
    "\n",
    "shared_head_activations_by_model = load_shared_head_activations(MAIN_PAPER_PLOT_MODELS + APPENDIX_MODELS, layer_key_dicts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.functional import cosine_similarity\n",
    "\n",
    "# All unordered pairs of result types (the keys of RESULT_ROOTS).\n",
    "key_pairs = list(itertools.combinations(RESULT_ROOTS.keys(), 2))\n",
    "\n",
    "\n",
    "# One summary row per model: cosine similarity between the activations of each result-type pair,\n",
    "# pooled over every dataset and shared head loaded for that model.\n",
    "rows = []\n",
    "for model, model_activations in shared_head_activations_by_model.items():\n",
    "    similarities_by_rt = defaultdict(list)\n",
    "    for dataset_activations in model_activations.values():\n",
    "        for result_type_activations in dataset_activations.values():\n",
    "            for rt1, rt2 in key_pairs:\n",
    "                if rt1 not in result_type_activations or rt2 not in result_type_activations:\n",
    "                    continue\n",
    "                # `.item()` converts the 0-d tensor to a Python float so numpy can aggregate it\n",
    "                # regardless of which device the loaded activations live on.\n",
    "                similarity = cosine_similarity(\n",
    "                    result_type_activations[rt1], result_type_activations[rt2], dim=0\n",
    "                ).item()\n",
    "                similarities_by_rt[(rt1, rt2)].append(similarity)\n",
    "\n",
    "    rows.append(\n",
    "        dict(\n",
    "            model=model,\n",
    "            **{\n",
    "                f\"{rt1} vs. {rt2}\": f\"{np.mean(sims):.4f} ± {np.std(sims):.4f} (n = {len(sims)})\"\n",
    "                for (rt1, rt2), sims in similarities_by_rt.items()\n",
    "            },\n",
    "        )\n",
    "    )\n",
    "\n",
    "\n",
    "display(Markdown(tabulate.tabulate(rows, headers=\"keys\", tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.functional import cosine_similarity\n",
    "\n",
    "key_pairs = list(itertools.combinations(RESULT_ROOTS.keys(), 2))\n",
    "\n",
    "\n",
    "# model -> (layer, head) -> (rt1, rt2) -> list of per-dataset cosine similarities.\n",
    "rows = []\n",
    "similarities_by_model_head_rt = dict()\n",
    "for model, model_activations in shared_head_activations_by_model.items():\n",
    "    similarities_by_model_head_rt[model] = defaultdict(lambda: defaultdict(list))\n",
    "    for dataset_activations in model_activations.values():\n",
    "        for head, result_type_activations in dataset_activations.items():\n",
    "            for rt1, rt2 in key_pairs:\n",
    "                if rt1 not in result_type_activations or rt2 not in result_type_activations:\n",
    "                    continue\n",
    "                similarity = cosine_similarity(result_type_activations[rt1], result_type_activations[rt2], dim=0).item()\n",
    "                similarities_by_model_head_rt[model][head][(rt1, rt2)].append(similarity)\n",
    "\n",
    "    for head in sorted(similarities_by_model_head_rt[model].keys()):\n",
    "        head_results = similarities_by_model_head_rt[model][head]\n",
    "        rows.append(\n",
    "            dict(\n",
    "                model=model,\n",
    "                head=head,\n",
    "                **{\n",
    "                    f\"{rt1} vs. {rt2}\": f\"{np.mean(sims):.4f} ± {np.std(sims):.4f} (n = {len(sims)})\"\n",
    "                    for (rt1, rt2), sims in head_results.items()\n",
    "                },\n",
    "            )\n",
    "        )\n",
    "\n",
    "\n",
    "# Heads can lack some result-type pairs, so rows may differ in which keys they have and in\n",
    "# their order. Use the union of keys (first-seen order) as headers and fill gaps with \"\" so\n",
    "# every value lands under its own column.\n",
    "headers = list(dict.fromkeys(key for row in rows for key in row))\n",
    "table_rows = [[row.get(key, \"\") for key in headers] for row in rows]\n",
    "\n",
    "# Rows are grouped by model; insert a separator line between consecutive models.\n",
    "model_counts = list(Counter(row[\"model\"] for row in rows).values())\n",
    "for sep_index in sorted(np.cumsum(model_counts)[:-1], reverse=True):\n",
    "    table_rows.insert(int(sep_index), tabulate.SEPARATING_LINE)\n",
    "\n",
    "\n",
    "display(Markdown(tabulate.tabulate(table_rows, headers=headers, tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect which result-type pairs were computed for one example head. `.get` is used because\n",
    "# indexing the inner defaultdict with a missing head would silently insert an empty entry, which\n",
    "# the similarity plotting code would then try to average (NaN means and RuntimeWarnings).\n",
    "similarities_by_model_head_rt.get(\"Llama-3.1-8B\", {}).get((17, 5), {}).keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from palettable.colorbrewer.qualitative import Paired_6 as heads_similarity_cmap\n",
    "from palettable.colorbrewer.qualitative import Set1_3 as heads_cmap\n",
    "\n",
    "# Per-channel difference between Paired_6's first two colors, used to derive `set1_light_blue`\n",
    "# from Set1_3's second color.\n",
    "# NOTE(review): `set1_light_blue` is not referenced in this cell; (\"short\", \"icl\") uses \"cyan\".\n",
    "# Confirm which color is intended.\n",
    "light_to_dark_diff = [\n",
    "    c1 - c2 for c1, c2 in zip(heads_similarity_cmap.mpl_colors[0], heads_similarity_cmap.mpl_colors[1])\n",
    "]\n",
    "set1_light_blue = [c + (d * 0.75) for c, d in zip(heads_cmap.mpl_colors[1], light_to_dark_diff)]\n",
    "\n",
    "# Legend label and color for each (result type, result type) similarity pair.\n",
    "SIMILARITY_KEY_STYLES = {\n",
    "    (\"short\", \"long\"): {\"label\": \"Short & Long\\nInstructions\", \"color\": heads_cmap.mpl_colors[0]},\n",
    "    (\"short\", \"icl\"): {\"label\": \"Short Instructions\\n& Demonstrations\", \"color\": \"cyan\"},\n",
    "    (\"long\", \"icl\"): {\"label\": \"Long Instructions\\n& Demonstrations\", \"color\": heads_cmap.mpl_colors[1]},\n",
    "}\n",
    "\n",
    "# Matplotlib kwargs shared by every similarity scatter series.\n",
    "SIMILARITY_GLOBAL_PLOT_STYLE = {\n",
    "    \"markersize\": 10,\n",
    "    \"alpha\": 0.6,\n",
    "}\n",
    "\n",
    "\n",
    "def plot_head_activation_similarities(\n",
    "    similarities_data,\n",
    "    model_groups: typing.List[typing.List[str]],\n",
    "    model_plot_styles: typing.Dict[str, typing.Dict[str, typing.Any]] = MODEL_PLOT_STYLES,\n",
    "    similarity_key_styles: typing.Dict[typing.Tuple[str, str], typing.Dict[str, typing.Any]] = SIMILARITY_KEY_STYLES,\n",
    "    global_plot_style: typing.Dict[str, typing.Any] = SIMILARITY_GLOBAL_PLOT_STYLE,\n",
    "    shape: typing.Tuple[int, int] | None = None,\n",
    "    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,\n",
    "    show_error_bars: bool = True,\n",
    "    limits_from_model_layers: bool = True,\n",
    "    jitter_range: float = 0.5,\n",
    "    fontsize: int = 16,\n",
    "    font_inc: int = 4,\n",
    "    text_font_inc: int = 2,\n",
    "    fontfamily: str | None = None,\n",
    "    textfontweight: str = \"semibold\",\n",
    "    annotate_panels: bool = True,\n",
    "    annotate_panels_start: str = \"A\",\n",
    "    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),\n",
    "    annotate_font_inc: int = 8,\n",
    "    panel_width: float = 6,\n",
    "    panel_height: float = 6,\n",
    "    xlabel: str = \"Layer\",\n",
    "    ylabel: str = \"Head Cosine Similarity\",\n",
    "    ylabel_first_ax_only: bool = True,\n",
    "    first_yicks_only: bool = False,\n",
    "    show_legend: bool = True,\n",
    "    legend_ax_index: int | None = None,\n",
    "    legend_outside: bool = True,\n",
    "    legend_loc: str | typing.Tuple[float, float] | None = None,\n",
    "    legend_bbox_to_anchor: typing.Tuple[float, float] = (1.05, 1),\n",
    "    legend_fontsize: int | None = None,\n",
    "    legend_width: float = 2.0,\n",
    "    legend_ncol: int = 5,\n",
    "    legend_panel_height: float = 0.2,\n",
    "    legend_order: typing.List[int] | None = None,\n",
    "    subplots_adjust: typing.Dict[str, float] | None = None,\n",
    "    save_name: str | None = None,\n",
    "):\n",
    "    \"\"\"Scatter per-head cosine similarities between result types against layer index.\n",
    "\n",
    "    Draws one panel per entry of `model_groups`. For every model in a group, each (layer, head)\n",
    "    key of `similarities_data[model]` contributes one point per similarity pair in\n",
    "    `similarity_key_styles`: x is the head's layer plus uniform jitter, y is the mean similarity.\n",
    "\n",
    "    Args:\n",
    "        similarities_data: model -> (layer, head) -> (rt1, rt2) -> list of similarity values.\n",
    "        model_groups: models drawn together in each panel; the first name is used in the title.\n",
    "        similarity_key_styles: (rt1, rt2) -> plot kwargs, which must include \"label\" and \"color\".\n",
    "        shape: (rows, cols) panel grid, filled row-major; defaults to one row of all groups.\n",
    "        show_error_bars: draw standard-error bars (std / sqrt(n)) around each mean.\n",
    "        limits_from_model_layers: set x-limits and ticks from the group's layer count.\n",
    "        legend_ax_index: panel that receives the legend; defaults to the last panel.\n",
    "        legend_order: optional permutation applied to the legend entries.\n",
    "        save_name: if given, `save_plot(save_name)` is called before the figure is shown.\n",
    "        The remaining arguments control fonts, sizes, panel letters and legend placement.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `shape` has fewer panels than there are model groups.\n",
    "    \"\"\"\n",
    "    if global_plot_style is None:\n",
    "        global_plot_style = {}\n",
    "\n",
    "    global_plot_style = {**GLOBAL_PLOT_STYLE, **global_plot_style}\n",
    "    n_model_groups = len(model_groups)\n",
    "\n",
    "    if shape is None:\n",
    "        shape = (1, n_model_groups)\n",
    "\n",
    "    n_rows, n_cols = shape\n",
    "    if n_rows * n_cols < n_model_groups:\n",
    "        raise ValueError(f\"shape {shape} has fewer panels than the {n_model_groups} model groups\")\n",
    "\n",
    "    if grid_spec_kwargs is None:\n",
    "        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)\n",
    "\n",
    "    gs = GridSpec(n_rows, n_cols, **grid_spec_kwargs)\n",
    "\n",
    "    fig = plt.figure(\n",
    "        layout=\"constrained\",\n",
    "        figsize=(n_cols * panel_width, n_rows * panel_height + int(show_legend)),\n",
    "    )\n",
    "\n",
    "    if legend_ax_index is None:\n",
    "        legend_ax_index = n_model_groups - 1\n",
    "\n",
    "    for i, models in enumerate(model_groups):\n",
    "        # Row-major placement: gs[0, i] for a single row, gs[i, 0] for a single column.\n",
    "        ax = fig.add_subplot(gs[i // n_cols, i % n_cols])\n",
    "\n",
    "        for model in models:\n",
    "            model_sims = similarities_data[model]\n",
    "            model_style = model_plot_styles[model]\n",
    "\n",
    "            for similarity_key, similarity_style in similarity_key_styles.items():\n",
    "                # Only heads that actually have this pair: indexing a defaultdict with a missing\n",
    "                # key would insert an empty list into the caller's data and produce a NaN mean.\n",
    "                model_sim_key_data = [\n",
    "                    (\n",
    "                        L + np.random.uniform(-jitter_range, jitter_range),\n",
    "                        np.mean(sims[similarity_key]),\n",
    "                        np.std(sims[similarity_key]) / (len(sims[similarity_key]) ** 0.5),\n",
    "                    )\n",
    "                    for (L, H), sims in model_sims.items()\n",
    "                    if len(sims.get(similarity_key, ())) > 0\n",
    "                ]\n",
    "                if not model_sim_key_data:\n",
    "                    continue\n",
    "\n",
    "                x, y, err = zip(*model_sim_key_data)\n",
    "                style = {**global_plot_style, **model_style, **similarity_style}\n",
    "                error_bar_style = {**style}\n",
    "                # No connecting line; pairs involving short instructions are drawn on top.\n",
    "                style[\"linewidth\"] = 0\n",
    "                ax.plot(x, y, **style, zorder=1 if SHORT in similarity_key else -1)\n",
    "                if show_error_bars:\n",
    "                    ax.errorbar(x, y, yerr=err, fmt=\"none\", capsize=5, elinewidth=2, **error_bar_style)\n",
    "\n",
    "        ax.set_ylim(0, 1.02)\n",
    "        if limits_from_model_layers:\n",
    "            # Largest layer count in the group, so no head falls outside the axis.\n",
    "            model_layers = max(MODEL_TO_N_LAYERS_HEADS[m][0] for m in models)\n",
    "            ax.set_xlim(0, model_layers)\n",
    "\n",
    "            # Tick step: split model_layers into (largest divisor <= 10) equal intervals.\n",
    "            step = 1\n",
    "            for divider in range(10, 0, -1):\n",
    "                if model_layers % divider == 0:\n",
    "                    step = model_layers / divider\n",
    "                    break\n",
    "\n",
    "            ax.set_xticks(np.arange(0, model_layers + step, step))\n",
    "\n",
    "        if first_yicks_only and i > 0:\n",
    "            ax.set_yticks([])\n",
    "\n",
    "        ax.tick_params(axis=\"both\", labelsize=fontsize - text_font_inc, labelfontfamily=fontfamily)\n",
    "        ax.set_xlabel(xlabel, fontsize=fontsize + font_inc, fontfamily=fontfamily, fontweight=textfontweight)\n",
    "        if i == 0 or not ylabel_first_ax_only:\n",
    "            ax.set_ylabel(ylabel, fontsize=fontsize + font_inc, fontfamily=fontfamily, fontweight=textfontweight)\n",
    "        ax.set_title(\n",
    "            f\"{models[0]} Family\", fontsize=fontsize + font_inc, fontfamily=fontfamily, fontweight=textfontweight\n",
    "        )\n",
    "\n",
    "        if annotate_panels:\n",
    "            ax.text(\n",
    "                *annotate_panel_position,\n",
    "                chr(ord(annotate_panels_start) + i),\n",
    "                ha=\"left\",\n",
    "                va=\"top\",\n",
    "                fontsize=fontsize + annotate_font_inc,\n",
    "                fontweight=\"bold\",\n",
    "                fontfamily=fontfamily,\n",
    "                transform=ax.transAxes,\n",
    "            )\n",
    "\n",
    "        if show_legend and i == legend_ax_index:\n",
    "            # One color patch per similarity pair, plus black keys for base and instruct models.\n",
    "            legend_entries = [\n",
    "                matplotlib.patches.Rectangle((0, 0), 1, 1, color=style[\"color\"], label=style[\"label\"])\n",
    "                for style in similarity_key_styles.values()\n",
    "            ]\n",
    "            legend_entries.append(\n",
    "                matplotlib.lines.Line2D(\n",
    "                    [0], [0], marker=\"o\", color=\"white\", markerfacecolor=\"black\", markersize=12, label=\"Base Models\"\n",
    "                )\n",
    "            )\n",
    "            legend_entries.append(matplotlib.patches.Rectangle((0, 0), 1, 1, color=\"black\", label=\"Instruct Models\"))\n",
    "\n",
    "            legend_labels = [style[\"label\"] for style in similarity_key_styles.values()]\n",
    "            legend_labels.append(\"Base Models\")\n",
    "            legend_labels.append(\"Instruct Models\")\n",
    "\n",
    "            if legend_order is not None:\n",
    "                legend_entries = [legend_entries[j] for j in legend_order]\n",
    "                legend_labels = [legend_labels[j] for j in legend_order]\n",
    "\n",
    "            legend_kwargs = dict(\n",
    "                ncol=legend_ncol,\n",
    "                prop=dict(family=fontfamily, size=fontsize - font_inc if legend_fontsize is None else legend_fontsize),\n",
    "                handlelength=0.75,\n",
    "            )\n",
    "            if legend_outside:\n",
    "                legend_kwargs[\"bbox_to_anchor\"] = legend_bbox_to_anchor\n",
    "                legend_kwargs[\"loc\"] = legend_loc\n",
    "            elif legend_loc is not None:\n",
    "                legend_kwargs[\"loc\"] = legend_loc\n",
    "\n",
    "            ax.legend(legend_entries, legend_labels, **legend_kwargs)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    if save_name is not None:\n",
    "        save_plot(save_name)\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Main-paper figure: two vertically stacked panels (one per model group), lettered from \"E\",\n",
    "# with the legend inside the second panel.\n",
    "plot_head_activation_similarities(\n",
    "    similarities_by_model_head_rt,\n",
    "    [MAIN_PAPER_PLOT_MODELS[:2], MAIN_PAPER_PLOT_MODELS[2:]],\n",
    "    shape=(2, 1),\n",
    "    panel_width=4,\n",
    "    panel_height=4,\n",
    "    grid_spec_kwargs=dict(hspace=0.3),\n",
    "    fontsize=14,\n",
    "    font_inc=2,\n",
    "    fontfamily=\"monospace\",\n",
    "    show_error_bars=False,\n",
    "    ylabel_first_ax_only=False,\n",
    "    first_yicks_only=False,\n",
    "    annotate_panels=True,\n",
    "    annotate_panels_start=\"E\",\n",
    "    legend_outside=False,\n",
    "    legend_loc=(0.6, 0),\n",
    "    legend_ax_index=1,\n",
    "    legend_ncol=1,\n",
    "    legend_fontsize=10,\n",
    "    save_name=\"finding_3_head_similarities.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Appendix figure: three side-by-side panels of appendix model groups, lettered from \"A\",\n",
    "# with the legend in the lower-left corner of the first panel.\n",
    "plot_head_activation_similarities(\n",
    "    similarities_by_model_head_rt,\n",
    "    [APPENDIX_MODELS[0:2], APPENDIX_MODELS[2:4], APPENDIX_MODELS[4:8]],\n",
    "    shape=(1, 3),\n",
    "    panel_width=6,\n",
    "    panel_height=4,\n",
    "    grid_spec_kwargs=dict(hspace=0.3, wspace=0.3),\n",
    "    fontsize=14,\n",
    "    font_inc=2,\n",
    "    fontfamily=\"monospace\",\n",
    "    show_error_bars=False,\n",
    "    ylabel_first_ax_only=False,\n",
    "    first_yicks_only=False,\n",
    "    annotate_panels=True,\n",
    "    annotate_panels_start=\"A\",\n",
    "    legend_outside=False,\n",
    "    legend_loc=\"lower left\",\n",
    "    legend_ax_index=0,\n",
    "    legend_ncol=1,\n",
    "    legend_fontsize=10,\n",
    "    save_name=\"appendix_finding_3_head_similarities.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute mean IE separately per baseline and instruction length to compare the resulting top heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_TOP_HEADS = 20\n",
    "top_head_rows = []\n",
    "split_mean_ie_by_model = {}\n",
    "\n",
    "\n",
    "def _top_head_row(model, prompt_types, baseline, prompt_type_label, baseline_label, datasets_below_chance_acc):\n",
    "    \"\"\"Run `compute_top_heads` for one model / prompt-type / baseline setting.\n",
    "\n",
    "    `prompt_types` and `baseline` are forwarded as-is (a single value, or a list to pool over);\n",
    "    the `*_label` arguments are the scalar labels stored in the returned row.\n",
    "    `datasets_below_chance_acc` is passed through to `compute_top_heads`, which presumably records\n",
    "    below-chance datasets in it (they are reported afterwards by `_warn_below_chance`).\n",
    "\n",
    "    Returns:\n",
    "        (row, mean_ie): a summary dict for `top_head_rows` and the mean indirect effect.\n",
    "    \"\"\"\n",
    "    top_heads, top_head_effects, mean_ie = compute_top_heads(\n",
    "        result_df,\n",
    "        indirect_effects_by_model_and_dataset,\n",
    "        model,\n",
    "        prompt_types,\n",
    "        baseline,\n",
    "        n_top_heads=N_TOP_HEADS,\n",
    "        datasets_below_chance_acc=datasets_below_chance_acc,\n",
    "        return_mean=True,\n",
    "    )\n",
    "    row = dict(\n",
    "        model=model,\n",
    "        prompt_type=prompt_type_label,\n",
    "        baseline=baseline_label,\n",
    "        n=N_TOP_HEADS,\n",
    "        top_heads=set(tuple(t) for t in top_heads),\n",
    "        top_head_effects=top_head_effects,\n",
    "        top_heads_list=top_heads,\n",
    "    )\n",
    "    return row, mean_ie\n",
    "\n",
    "\n",
    "def _warn_below_chance(datasets_below_chance_acc):\n",
    "    \"\"\"Log one warning per model that has datasets with below-chance accuracy.\"\"\"\n",
    "    for model, datasets in datasets_below_chance_acc.items():\n",
    "        if len(datasets) > 0:\n",
    "            logger.warning(f\"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}\")\n",
    "\n",
    "\n",
    "# Instruction prompts: short, long, and both pooled, each against every single baseline and\n",
    "# against all baselines pooled.\n",
    "PROMPT_TYPE_OPTIONS = [SHORT, LONG, [SHORT, LONG]]\n",
    "BASELINE_OPTIONS = RELEVANT_BASELINES + [RELEVANT_BASELINES]\n",
    "tqdm_total = len(ORDERED_MODELS) * len(PROMPT_TYPE_OPTIONS) * len(BASELINE_OPTIONS)\n",
    "prompt_datasets_below_chance_acc = defaultdict(set)\n",
    "\n",
    "for model, prompt_types, baseline in tqdm(\n",
    "    itertools.product(ORDERED_MODELS, PROMPT_TYPE_OPTIONS, BASELINE_OPTIONS),\n",
    "    total=tqdm_total,\n",
    "    desc=\"Universal top heads\",\n",
    "):\n",
    "    # A list of prompt types / baselines means they are pooled; label those BOTH / ALL.\n",
    "    pt = prompt_types if isinstance(prompt_types, str) else BOTH\n",
    "    bl = baseline if isinstance(baseline, str) else ALL\n",
    "    row, mean_ie = _top_head_row(model, prompt_types, baseline, pt, bl, prompt_datasets_below_chance_acc)\n",
    "    top_head_rows.append(row)\n",
    "    split_mean_ie_by_model[(model, pt, bl)] = mean_ie\n",
    "\n",
    "_warn_below_chance(prompt_datasets_below_chance_acc)\n",
    "\n",
    "\n",
    "# ICL prompts use ICL as both the prompt type and the baseline.\n",
    "icl_datasets_below_chance_acc = defaultdict(set)\n",
    "\n",
    "for model in tqdm(ORDERED_MODELS, total=len(ORDERED_MODELS), desc=\"Universal top heads ICL\"):\n",
    "    row, mean_ie = _top_head_row(model, ICL, ICL, ICL, ICL, icl_datasets_below_chance_acc)\n",
    "    top_head_rows.append(row)\n",
    "    split_mean_ie_by_model[(model, ICL, ICL)] = mean_ie\n",
    "\n",
    "_warn_below_chance(icl_datasets_below_chance_acc)\n",
    "\n",
    "\n",
    "universal_top_heads_df = pd.DataFrame(top_head_rows)\n",
    "\n",
    "# Per-model slices of the table, each with a fresh 0..n-1 index.\n",
    "split_universal_top_heads_dfs = {\n",
    "    model: universal_top_heads_df[universal_top_heads_df.model == model].copy(deep=True).reset_index(drop=True)\n",
    "    for model in ORDERED_MODELS\n",
    "}\n",
    "\n",
    "\n",
    "universal_top_heads_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_pt_baseline_rename(row):\n",
    "    \"\"\"Return a unique panel key '<model>_<prompt type>_<baseline>' for one top-heads row.\"\"\"\n",
    "    return f\"{row.model}_{row.prompt_type}_{row.baseline}\"\n",
    "\n",
    "\n",
    "# One figure per model: top heads for each single prompt type against each single baseline,\n",
    "# i.e. excluding the pooled BOTH / ALL settings and ICL.\n",
    "for model_name in MAIN_PLOT_MODELS:\n",
    "    model_top_heads_df = universal_top_heads_df[\n",
    "        (universal_top_heads_df.model == model_name)\n",
    "        & ~(universal_top_heads_df.baseline.isin([ICL, ALL]))\n",
    "        & (universal_top_heads_df.prompt_type != BOTH)\n",
    "    ]\n",
    "    if model_top_heads_df.empty:\n",
    "        logger.warning(f\"No split top-head rows for {model_name}; skipping its plot.\")\n",
    "        continue\n",
    "\n",
    "    model_top_heads_df = model_top_heads_df.rename(columns=dict(n=\"n_heads\", top_heads=\"prompt_heads\"))\n",
    "    model_top_heads_df = model_top_heads_df.assign(\n",
    "        model=model_top_heads_df.apply(model_pt_baseline_rename, axis=1),\n",
    "        icl_heads=model_top_heads_df.apply(lambda row: [], axis=1),\n",
    "    )\n",
    "    # Panel titles are built from the row's own columns rather than by splitting the key on \"_\",\n",
    "    # which breaks if a model name contains an underscore. The name `model_titles` is also read\n",
    "    # by the heatmap cell further down.\n",
    "    model_titles = {\n",
    "        key: f\"{prompt_type.capitalize()} instructions\\n{baseline} baseline\"\n",
    "        for key, prompt_type, baseline in zip(\n",
    "            model_top_heads_df.model, model_top_heads_df.prompt_type, model_top_heads_df.baseline\n",
    "        )\n",
    "    }\n",
    "\n",
    "    plot_shared_top_heads(\n",
    "        model_top_heads_df,\n",
    "        list(model_titles.keys()),\n",
    "        shape=(2, 3),\n",
    "        draw_mean_lines=False,\n",
    "        show_legend=False,\n",
    "        legend_ax_index=2,\n",
    "        legend_outside=False,\n",
    "        legend_loc=(-0.05, 0),\n",
    "        legend_ncol=3,\n",
    "        legend_panel_height=0.05,\n",
    "        panel_height=4,\n",
    "        grid_spec_kwargs=dict(hspace=0.2, wspace=0.25, top=0.9),\n",
    "        fontfamily=\"monospace\",\n",
    "        textfontweight=\"bold\",\n",
    "        save_name=f\"appendix_finding_3_split_top_heads_{model_name}.pdf\",\n",
    "        annotate_panels=True,\n",
    "        annotate_panels_start=\"A\",\n",
    "        fontsize=14,\n",
    "        model_titles=model_titles,\n",
    "        suptitle=model_name,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And a heatmap variant: each head is colored by how many times it was counted among the top heads across the single prompt-type / single baseline settings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from palettable.colorbrewer.sequential import Reds_6 as heads_heatmap_cmap\n",
    "\n",
    "# White first, followed by the Reds_6 shades.\n",
    "heads_heatmap_colors = [(1, 1, 1)] + list(heads_heatmap_cmap.mpl_colors)\n",
    "\n",
    "# One row per model. Its head list pools the top-head lists of every single prompt type x single\n",
    "# baseline setting, so a head appears once per list that contains it.\n",
    "heatmap_df_rows = []\n",
    "for model_name in MAIN_PLOT_MODELS:\n",
    "    is_split_setting = (\n",
    "        (universal_top_heads_df.model == model_name)\n",
    "        & ~universal_top_heads_df.baseline.isin([ICL, ALL])\n",
    "        & (universal_top_heads_df.prompt_type != BOTH)\n",
    "    )\n",
    "    model_top_heads_df = universal_top_heads_df[is_split_setting]\n",
    "    all_top_heads = [tuple(head) for head_list in model_top_heads_df.top_heads_list for head in head_list]\n",
    "    heatmap_df_rows.append(\n",
    "        dict(model=model_name, n_heads=N_TOP_HEADS, prompt_heads=all_top_heads, icl_heads=[])\n",
    "    )\n",
    "\n",
    "heatmap_df = pd.DataFrame(heatmap_df_rows)\n",
    "\n",
    "plot_shared_top_heads(\n",
    "    heatmap_df,\n",
    "    heatmap_df.model.tolist(),\n",
    "    top_heads_colors=heads_heatmap_colors,\n",
    "    shape=(2, 2),\n",
    "    panel_width=4.5,\n",
    "    grid_spec_kwargs=dict(hspace=0.3, wspace=0.3),\n",
    "    fontsize=14,\n",
    "    fontfamily=\"monospace\",\n",
    "    textfontweight=\"bold\",\n",
    "    draw_mean_lines=False,\n",
    "    show_legend=False,\n",
    "    legend_ax_index=2,\n",
    "    legend_outside=False,\n",
    "    legend_loc=(-0.05, 0),\n",
    "    legend_ncol=3,\n",
    "    legend_panel_height=0.05,\n",
    "    annotate_panels=True,\n",
    "    annotate_panels_start=\"A\",\n",
    "    # NOTE(review): `model_titles` is whatever the last loop iteration of the per-model split\n",
    "    # top-heads cell left behind. Its keys are \"<model>_<prompt type>_<baseline>\", not the plain\n",
    "    # model names plotted here. Confirm how plot_shared_top_heads treats missing keys.\n",
    "    model_titles=model_titles,\n",
    "    vmax=6,\n",
    "    add_colorbar=True,\n",
    "    save_name=\"appendix_finding_3_top_heads_heatmap.pdf\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RQ2.1 Top heads -- second panel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def repeat_elements(items, n):\n",
    "    \"\"\"Return a list with each element of `items` repeated `n` times in a row, e.g. [a, a, b, b].\"\"\"\n",
    "    repeated = []\n",
    "    for item in items:\n",
    "        repeated.extend([item] * n)\n",
    "    return repeated\n",
    "\n",
    "\n",
    "# Default metric keys looked up in each `layer_info_dicts` entry by the plotting function below.\n",
    "DEFAULT_METRIC_NAMES = [\"prompt_heads\", \"icl_heads\"]\n",
    "# Metric/hue name given to values that appear under every metric.\n",
    "SHARED = \"shared\"\n",
    "\n",
    "# Default kwargs for annotation markers (the misspelled name is kept for existing references).\n",
    "DEFUALT_ANNOTATE_KWARGS = dict(\n",
    "    marker=\"*\",\n",
    "    s=200,\n",
    "    color=\"gold\",\n",
    "    edgecolor=\"black\",\n",
    ")\n",
    "\n",
    "\n",
    "def take_first_element(t):\n",
    "    \"\"\"Return `t[0]`; the default mapping from a tuple value (e.g. a (layer, head) pair) to a number.\"\"\"\n",
    "    return t[0]\n",
    "\n",
    "\n",
    "def plot_layer_rows_individual_dots(\n",
    "    layer_info_dicts: typing.Dict[str, typing.Dict[str, float]],\n",
    "    n_top_heads: int,\n",
    "    metric_names: str | typing.Sequence[str] = DEFAULT_METRIC_NAMES,\n",
    "    models: typing.Sequence[str] | None = None,\n",
    "    colors: typing.Sequence[str] | None = None,\n",
    "    title: str | None = None,\n",
    "    ylabel: str | None = None,\n",
    "    ylabel_first_ax_only: bool = True,\n",
    "    fontsize: int = 12,\n",
    "    font_inc: int = 4,\n",
    "    fontfamily: str | None = None,\n",
    "    metric_labels: typing.Dict[str, str] | None = None,\n",
    "    metric_name_to_plot_kwargs: typing.Dict[str, typing.Dict[str, typing.Any]] | None = None,\n",
    "    metric_name_to_err_metric: typing.Dict[str, str] | None = None,\n",
    "    tuple_to_plot_values: typing.Callable = take_first_element,\n",
    "    annotate_values: typing.Dict[str, typing.Dict[str, float]] | None = None,\n",
    "    annotate_kwargs: typing.Dict[str, typing.Any] | None = None,\n",
    "    ylim: typing.Tuple[float, float] | None = None,\n",
    "    ylim_from_model_layers: bool = False,\n",
    "    panel_width: int = 6,\n",
    "    panel_height: int = 6,\n",
    "    annotate_panels: bool = False,\n",
    "    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),\n",
    "    annotate_font_inc: int = 4,\n",
    "    save_name: str | None = None,\n",
    "    **global_plot_kwargs,\n",
    "):\n",
    "    \"\"\"Swarm-plot per-model metric values, one panel per base-model family.\n",
    "\n",
    "    For every model, reads `layer_info_dicts[f\"{model} @ {n_top_heads}\"][metric_name]` for each\n",
    "    metric (NaN if absent). A scalar value becomes one point; a list becomes one point per element,\n",
    "    with tuple elements mapped through `tuple_to_plot_values`. When the first metric's value is a\n",
    "    list/tuple, elements present under every metric are drawn once, under the `SHARED` hue, instead\n",
    "    of once per metric. A black line marks the mean of each hue. Models are grouped into panels by\n",
    "    base model, i.e. the name with \"-Instruct\" / \"-chat\" removed.\n",
    "\n",
    "    Args:\n",
    "        layer_info_dicts: \"<model> @ <n_top_heads>\" -> metric name -> value (scalar or list).\n",
    "        n_top_heads: used to build the lookup key and the default panel titles.\n",
    "        metric_names: metric(s) to plot; each becomes a hue.\n",
    "        models: models to include; defaults to `RELEVANT_MODELS`.\n",
    "        colors: flat palette, sliced per panel into as many colors as that panel has hues;\n",
    "            if omitted, seaborn's current palette is used for each panel.\n",
    "        metric_labels: metric name -> legend label; its values also define the hue order.\n",
    "        ylim / ylim_from_model_layers: fixed y-limits, or limits and ticks from the layer count.\n",
    "        save_name: if given, `save_plot(save_name)` is called before the figure is shown.\n",
    "        **global_plot_kwargs: forwarded to `sns.swarmplot`.\n",
    "        metric_name_to_plot_kwargs, metric_name_to_err_metric, annotate_values, annotate_kwargs:\n",
    "            accepted for compatibility but not used by the current implementation.\n",
    "        The remaining arguments control titles, fonts, sizing and panel letters.\n",
    "\n",
    "    Returns:\n",
    "        The long-format DataFrame of plotted values.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no values were collected for the requested models.\n",
    "    \"\"\"\n",
    "    if models is None:\n",
    "        models = RELEVANT_MODELS[:]\n",
    "\n",
    "    if isinstance(metric_names, str):\n",
    "        metric_names = [metric_names]\n",
    "\n",
    "    if metric_name_to_plot_kwargs is None:\n",
    "        metric_name_to_plot_kwargs = dict()\n",
    "\n",
    "    if metric_name_to_err_metric is None:\n",
    "        metric_name_to_err_metric = dict()\n",
    "\n",
    "    if annotate_values is None:\n",
    "        annotate_values = dict()\n",
    "\n",
    "    if annotate_kwargs is None:\n",
    "        annotate_kwargs = dict()\n",
    "\n",
    "    if metric_labels is None:\n",
    "        metric_labels = dict()\n",
    "\n",
    "    annotate_kwargs = {**DEFUALT_ANNOTATE_KWARGS, **annotate_kwargs}\n",
    "\n",
    "    rows = []\n",
    "    for model in models:\n",
    "        base_model = model.replace(\"-Instruct\", \"\").replace(\"-chat\", \"\")\n",
    "        is_instruct = (\"-Instruct\" in model) or (\"-chat\" in model)\n",
    "\n",
    "        all_metric_values_by_model = {\n",
    "            metric_name: layer_info_dicts[f\"{model} @ {n_top_heads}\"].get(metric_name, np.nan)\n",
    "            for metric_name in metric_names\n",
    "        }\n",
    "\n",
    "        # Values present under every metric; each is plotted once, under the SHARED hue.\n",
    "        shared_values = set()\n",
    "        if isinstance(all_metric_values_by_model[metric_names[0]], (tuple, list)):\n",
    "            shared_values = set(all_metric_values_by_model[metric_names[0]])\n",
    "            for metric_name in metric_names[1:]:\n",
    "                shared_values &= set(all_metric_values_by_model[metric_name])\n",
    "\n",
    "        # Shared values already emitted for this model.\n",
    "        skip_values = set()\n",
    "\n",
    "        for metric_name in metric_names:\n",
    "            model_metric_val = all_metric_values_by_model[metric_name]\n",
    "            if isinstance(model_metric_val, (int, float)):\n",
    "                rows.append(\n",
    "                    dict(\n",
    "                        model=model,\n",
    "                        base_model=base_model,\n",
    "                        is_instruct=is_instruct,\n",
    "                        metric_name=metric_name,\n",
    "                        model_metric=f\"{model}_{metric_name}\",\n",
    "                        value=model_metric_val,\n",
    "                    )\n",
    "                )\n",
    "            elif isinstance(model_metric_val, list):\n",
    "                for i, val in enumerate(model_metric_val):\n",
    "                    d = dict(\n",
    "                        model=model,\n",
    "                        base_model=base_model,\n",
    "                        is_instruct=is_instruct,\n",
    "                        metric_name=metric_name,\n",
    "                        model_metric=f\"{model}_{metric_name}\",\n",
    "                        value=val,\n",
    "                        index=i,\n",
    "                        shared=False,\n",
    "                    )\n",
    "                    if val in skip_values:\n",
    "                        continue\n",
    "                    if val in shared_values:\n",
    "                        skip_values.add(val)\n",
    "                        d[\"shared\"] = True\n",
    "                        d[\"metric_name\"] = SHARED\n",
    "\n",
    "                    if isinstance(val, tuple):\n",
    "                        d[\"value\"] = tuple_to_plot_values(val)\n",
    "\n",
    "                    rows.append(d)\n",
    "\n",
    "    if not rows:\n",
    "        raise ValueError(\"No metric values to plot for the requested models.\")\n",
    "\n",
    "    all_values_df = pd.DataFrame(rows)\n",
    "    all_values_df = all_values_df.assign(\n",
    "        metric_name=all_values_df.metric_name.map(lambda x: metric_labels.get(x, x)),\n",
    "    )\n",
    "\n",
    "    # Without explicit labels, fall back to the hues that actually occur; an empty hue_order\n",
    "    # can leave seaborn with no hue levels to draw.\n",
    "    hue_order = list(metric_labels.values()) if metric_labels else list(all_values_df.metric_name.unique())\n",
    "    # Default y-label: the last metric name, prettified.\n",
    "    ylabel_text = metric_names[-1].replace(\"_\", \" \").capitalize() if ylabel is None else ylabel\n",
    "\n",
    "    base_models = list(all_values_df.base_model.unique())\n",
    "    n_panels = len(base_models)\n",
    "    # squeeze=False keeps `axes` a 2-D array even when there is a single panel.\n",
    "    fig, axes = plt.subplots(1, n_panels, figsize=(panel_width * n_panels, panel_height), squeeze=False)\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for b, base_model in enumerate(base_models):\n",
    "        base_model_df = all_values_df[all_values_df.base_model == base_model]\n",
    "        n_metrics = base_model_df.metric_name.nunique()\n",
    "        ax = axes[b]\n",
    "\n",
    "        if colors is None:\n",
    "            base_model_colors = sns.color_palette(n_colors=n_metrics)\n",
    "        else:\n",
    "            base_model_colors = colors[b * n_metrics : (b + 1) * n_metrics]\n",
    "\n",
    "        sns.swarmplot(\n",
    "            data=base_model_df,\n",
    "            x=\"model\",\n",
    "            y=\"value\",\n",
    "            hue=\"metric_name\",\n",
    "            ax=ax,\n",
    "            palette=base_model_colors,\n",
    "            hue_order=hue_order,\n",
    "            **global_plot_kwargs,\n",
    "        )\n",
    "\n",
    "        # Invisible boxplot whose only visible element is the mean line of each hue.\n",
    "        sns.boxplot(\n",
    "            data=base_model_df,\n",
    "            x=\"model\",\n",
    "            y=\"value\",\n",
    "            hue=\"metric_name\",\n",
    "            hue_order=hue_order,\n",
    "            showmeans=True,\n",
    "            meanline=True,\n",
    "            meanprops={\"color\": \"k\", \"ls\": \"-\", \"lw\": 2},\n",
    "            medianprops={\"visible\": False},\n",
    "            whiskerprops={\"visible\": False},\n",
    "            zorder=10,\n",
    "            showfliers=False,\n",
    "            showbox=False,\n",
    "            showcaps=False,\n",
    "            palette=base_model_colors,\n",
    "            legend=None,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "        ax.set_xlabel(\"Model\", fontsize=fontsize + font_inc, fontfamily=fontfamily)\n",
    "        if (not ylabel_first_ax_only) or b == 0:\n",
    "            ax.set_ylabel(ylabel_text, fontsize=fontsize + font_inc, fontfamily=fontfamily)\n",
    "        else:\n",
    "            ax.set_ylabel(\"\")\n",
    "\n",
    "        ax_title = title\n",
    "        if ax_title is None:\n",
    "            ax_title = f\"{base_model} family (top {n_top_heads} heads)\"\n",
    "\n",
    "        ax.set_title(ax_title, fontsize=fontsize + (2 * font_inc), fontfamily=fontfamily)\n",
    "\n",
    "        ax.tick_params(axis=\"both\", labelsize=fontsize, labelfontfamily=fontfamily)\n",
    "\n",
    "        if ylim_from_model_layers:\n",
    "            model_layers = MODEL_TO_N_LAYERS[base_model]\n",
    "            ax.set_ylim(0, model_layers)\n",
    "            # Tick step: split model_layers into (largest divisor <= 10) equal intervals.\n",
    "            y_step = 1\n",
    "            for divider in range(10, 0, -1):\n",
    "                if model_layers % divider == 0:\n",
    "                    y_step = model_layers / divider\n",
    "                    break\n",
    "\n",
    "            ax.set_yticks(np.arange(0, model_layers + y_step, y_step))\n",
    "\n",
    "        elif ylim is not None:\n",
    "            ax.set_ylim(ylim)\n",
    "\n",
    "        if annotate_panels:\n",
    "            ax.text(\n",
    "                *annotate_panel_position,\n",
    "                chr(ord(\"A\") + b),\n",
    "                ha=\"left\",\n",
    "                va=\"top\",\n",
    "                fontsize=fontsize + annotate_font_inc,\n",
    "                fontweight=\"bold\",\n",
    "                fontfamily=fontfamily,\n",
    "                transform=ax.transAxes,\n",
    "            )\n",
    "\n",
    "        # Add a legend artist for the black line representing the mean\n",
    "        mean_line = plt.Line2D([0], [0], color=\"k\", linestyle=\"-\", linewidth=2, label=\"Mean\")\n",
    "        loc = \"lower right\" if \"Llama-3\" in base_model else \"best\"\n",
    "        # Matplotlib ignores `fontsize` whenever `prop` is given, so the size goes into `prop`.\n",
    "        ax.legend(\n",
    "            handles=ax.get_legend_handles_labels()[0] + [mean_line],\n",
    "            loc=loc,\n",
    "            prop=dict(family=fontfamily, size=fontsize - font_inc),\n",
    "        )\n",
    "\n",
    "    plt.tight_layout()\n",
    "    if save_name is not None:\n",
    "        save_plot(save_name)\n",
    "    plt.show()\n",
    "\n",
    "    return all_values_df\n",
    "\n",
    "\n",
    "from palettable.colorbrewer.qualitative import Paired_12 as cmap\n",
    "\n",
    "\n",
    "def flip_pairs(lst):\n",
    "    \"\"\"Swap each adjacent pair of elements: [a, b, c, d] -> [b, a, d, c].\n",
    "\n",
    "    For odd-length input the trailing unpaired element is dropped.\n",
    "    \"\"\"\n",
    "    odd_position_items, even_position_items = lst[1::2], lst[::2]\n",
    "    return list(itertools.chain.from_iterable(zip(odd_position_items, even_position_items)))\n",
    "\n",
    "\n",
    "# Swarm plot of the layers the top heads sit in, for the main-paper models.\n",
    "n_top_heads = 20\n",
    "# Paired_12 palette indices; the plotting function slices this list into one\n",
    "# group per model family (presumably in the hue order Demonstration, Shared,\n",
    "# Instruction -- confirm). Index 3 is repeated, presumably so that \"Shared\"\n",
    "# keeps the same color in both families.\n",
    "color_indices = [0, 3, 1, 6, 3, 7]\n",
    "colors = [cmap.mpl_colors[palette_idx] for palette_idx in color_indices]\n",
    "\n",
    "adf = plot_layer_rows_individual_dots(\n",
    "    layer_key_dicts,\n",
    "    n_top_heads,\n",
    "    models=MAIN_PAPER_PLOT_MODELS,\n",
    "    colors=colors,\n",
    "    ylabel=\"Top head layer\",\n",
    "    fontsize=16,\n",
    "    font_inc=0,\n",
    "    fontfamily=\"monospace\",\n",
    "    metric_labels={\n",
    "        \"icl_heads\": \"Demonstration\",\n",
    "        SHARED: \"Shared\",\n",
    "        \"prompt_heads\": \"Instruction\",\n",
    "    },\n",
    "    ylim_from_model_layers=True,\n",
    "    dodge=True,\n",
    "    size=10,\n",
    "    annotate_panels=True,\n",
    "    annotate_font_inc=8,\n",
    "    save_name=\"rq2_1_top_heads_swarm.pdf\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Appendix plots with the causal effects for each model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-head mean indirect effect (IE) for every model, for the ICL (demonstration)\n",
    "# and BOTH (instruction) settings.\n",
    "N_TOP_HEADS = 20\n",
    "mean_ie_by_model = {}\n",
    "\n",
    "# Each setting is (result key, pt, bl); pt/bl are presumably the prompt-type and\n",
    "# baseline selections understood by compute_top_heads.\n",
    "for model, (name, pt, bl) in itertools.product(\n",
    "    RELEVANT_MODELS,\n",
    "    ((ICL, ICL, ICL), (BOTH, [SHORT, LONG], RELEVANT_BASELINES)),\n",
    "):\n",
    "    # Only the mean IE is stored; the top-head outputs are not used in this cell.\n",
    "    top_heads, top_head_effects, mean_ie = compute_top_heads(\n",
    "        result_df,\n",
    "        indirect_effects_by_model_and_dataset,\n",
    "        model,\n",
    "        pt,\n",
    "        bl,\n",
    "        n_top_heads=N_TOP_HEADS,\n",
    "        return_mean=True,\n",
    "    )\n",
    "    mean_ie_by_model[(model, name)] = mean_ie\n",
    "\n",
    "mean_ie_by_model.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from palettable.scientific.diverging import Vik_20_r as ie_colormap\n",
    "\n",
    "\n",
    "def plot_mean_ies_by_model_and_type(\n",
    "    mean_ie_data: typing.Dict[typing.Tuple[str, str], torch.Tensor],\n",
    "    models: typing.List[str],\n",
    "    colormap,\n",
    "    limits_from_model_layers: bool = True,\n",
    "    force_cmap_zero_middle: bool = True,\n",
    "    shrink_cmap: bool = False,\n",
    "    cmap_max: float = 1.0,\n",
    "    ylabel_first_ax_only: bool = True,\n",
    "    fontsize: int = 12,\n",
    "    font_inc: int = 4,\n",
    "    fontfamily: str | None = None,\n",
    "    panel_width: int = 6,\n",
    "    panel_height: int = 6,\n",
    "    colormap_round: bool = True,\n",
    "    colormap_round_scale: int = 100,\n",
    "    colorbar_panel_width: float = 0.1,\n",
    "    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,\n",
    "    save_name: str | None = None,\n",
    "):\n",
    "    \"\"\"Plot instruction vs. demonstration mean-IE heatmaps, one row per model.\n",
    "\n",
    "    Row ``m`` shows ``mean_ie_data[(model, BOTH)]`` (titled 'Instruction') and\n",
    "    ``mean_ie_data[(model, ICL)]`` (titled 'Demonstration') with a shared color\n",
    "    norm and a colorbar in a narrow third column. Each tensor is transposed\n",
    "    before ``imshow``, and the x axis is labelled as the layer index (assumes the\n",
    "    tensors are laid out as (layer, head) -- TODO confirm).\n",
    "\n",
    "    Args:\n",
    "        mean_ie_data: Mean indirect effects keyed by ``(model, BOTH)`` and\n",
    "            ``(model, ICL)``; values must support ``.cpu().numpy()``.\n",
    "        models: Models to plot, one row each.\n",
    "        colormap: Matplotlib colormap; called on positions in [0, 1] when\n",
    "            ``shrink_cmap`` is set.\n",
    "        limits_from_model_layers: Place ticks at even fractions of the layer and\n",
    "            head counts from ``MODEL_TO_N_LAYERS_HEADS``.\n",
    "        force_cmap_zero_middle: Center the norm on zero with ``TwoSlopeNorm``.\n",
    "            This is only possible when the row's color range straddles zero;\n",
    "            otherwise a linear norm is used.\n",
    "        shrink_cmap: Resample ``colormap`` linearly so that 0 maps to its\n",
    "            midpoint (0.5) and the row maximum to ``cmap_max``. Skipped when the\n",
    "            row maximum is not positive.\n",
    "        cmap_max: Colormap position of the row maximum when ``shrink_cmap`` is set.\n",
    "        ylabel_first_ax_only: Only label the y axis of the left-hand\n",
    "            (instruction) panel in each row.\n",
    "        colormap_round: Round the color limits outwards to multiples of\n",
    "            ``1 / colormap_round_scale`` and put colorbar ticks on that grid.\n",
    "        colorbar_panel_width: Width of the colorbar column relative to a panel.\n",
    "        grid_spec_kwargs: Extra grid-spec keyword arguments (default\n",
    "            ``hspace=0.2, wspace=0.25``).\n",
    "        save_name: If given, the figure is saved with ``save_plot``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either mean-IE entry is missing for a model.\n",
    "    \"\"\"\n",
    "    n_rows, n_panels = len(models), 2\n",
    "\n",
    "    if grid_spec_kwargs is None:\n",
    "        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)\n",
    "\n",
    "    fig = plt.figure(\n",
    "        layout=\"constrained\",\n",
    "        figsize=((n_panels + colorbar_panel_width) * panel_width, n_rows * panel_height),\n",
    "    )\n",
    "    # Create the grid spec through the figure: constrained layout only manages\n",
    "    # grid specs that are attached to the figure.\n",
    "    gs = fig.add_gridspec(\n",
    "        n_rows,\n",
    "        n_panels + 1,\n",
    "        width_ratios=[1] * n_panels + [colorbar_panel_width],\n",
    "        **grid_spec_kwargs,\n",
    "    )\n",
    "\n",
    "    for m, model in enumerate(models):\n",
    "        prompt_data = mean_ie_data.get((model, BOTH), None)\n",
    "        icl_data = mean_ie_data.get((model, ICL), None)\n",
    "        if prompt_data is None or icl_data is None:\n",
    "            raise ValueError(f\"Missing data for {model}\")\n",
    "\n",
    "        model_prompt_ax = fig.add_subplot(gs[m, 0])\n",
    "        model_icl_ax = fig.add_subplot(gs[m, 1])\n",
    "\n",
    "        model_prompt_ax.set_title(f\"{model} - Instruction\", fontsize=fontsize + font_inc, fontfamily=fontfamily)\n",
    "        model_icl_ax.set_title(f\"{model} - Demonstration\", fontsize=fontsize + font_inc, fontfamily=fontfamily)\n",
    "\n",
    "        prompt_data = prompt_data.cpu().numpy()\n",
    "        icl_data = icl_data.cpu().numpy()\n",
    "        overall_min = min(prompt_data.min(), icl_data.min())\n",
    "        overall_max = max(prompt_data.max(), icl_data.max())\n",
    "\n",
    "        # None lets matplotlib choose the colorbar ticks automatically.\n",
    "        colorbar_ticks = None\n",
    "        if colormap_round:\n",
    "            overall_min = np.floor(overall_min * colormap_round_scale) / colormap_round_scale\n",
    "            overall_max = np.ceil(overall_max * colormap_round_scale) / colormap_round_scale\n",
    "            step = 1 / colormap_round_scale\n",
    "            colorbar_ticks = np.arange(overall_min, overall_max + step, step)\n",
    "\n",
    "        logger.info(f\"{model}: color range [{overall_min}, {overall_max}]\")\n",
    "\n",
    "        # TwoSlopeNorm requires vmin < vcenter < vmax.\n",
    "        if force_cmap_zero_middle and overall_min < 0 < overall_max:\n",
    "            norm = matplotlib.colors.TwoSlopeNorm(vmin=overall_min, vcenter=0, vmax=overall_max)\n",
    "        else:\n",
    "            norm = matplotlib.colors.Normalize(vmin=overall_min, vmax=overall_max)\n",
    "\n",
    "        model_cmap = colormap\n",
    "        if shrink_cmap and overall_max > 0:\n",
    "            # Linear map from value to colormap position with 0 -> 0.5 and\n",
    "            # overall_max -> cmap_max, so overall_min lands at cmap_start. Using the\n",
    "            # signed minimum keeps this right when every value is positive.\n",
    "            positive_range = cmap_max - 0.5\n",
    "            cmap_start = 0.5 + positive_range * overall_min / overall_max\n",
    "            model_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(\n",
    "                f\"{model}_shrunk_cmap\",\n",
    "                colormap(np.linspace(cmap_start, cmap_max, 256)),\n",
    "            )\n",
    "\n",
    "        model_prompt_ax.imshow(prompt_data.T, cmap=model_cmap, norm=norm)\n",
    "        model_icl_ax.imshow(icl_data.T, cmap=model_cmap, norm=norm)\n",
    "\n",
    "        # The colorbar fills its own grid column, so no fraction/shrink/pad is needed.\n",
    "        model_cbar_ax = fig.add_subplot(gs[m, 2])\n",
    "        cbar = fig.colorbar(\n",
    "            matplotlib.cm.ScalarMappable(cmap=model_cmap, norm=norm),\n",
    "            cax=model_cbar_ax,\n",
    "            orientation=\"vertical\",\n",
    "            ticks=colorbar_ticks,\n",
    "        )\n",
    "        cbar.set_label(\"Mean Indirect Effect\", fontsize=fontsize + font_inc, fontfamily=fontfamily)\n",
    "\n",
    "        for ax in (model_prompt_ax, model_icl_ax):\n",
    "            ax.set_xlabel(\"Layer index\", fontsize=fontsize + font_inc, fontfamily=fontfamily)\n",
    "            # The left-hand (instruction) panel carries the y label for its row.\n",
    "            if (not ylabel_first_ax_only) or ax is model_prompt_ax:\n",
    "                ax.set_ylabel(\n",
    "                    \"Top head index\",\n",
    "                    fontsize=fontsize + font_inc,\n",
    "                    fontfamily=fontfamily,\n",
    "                )\n",
    "            else:\n",
    "                ax.set_ylabel(\"\")\n",
    "\n",
    "            if limits_from_model_layers:\n",
    "                model_layers, model_heads = MODEL_TO_N_LAYERS_HEADS[model]\n",
    "                for max_value, set_ticks in (\n",
    "                    (model_layers, ax.set_xticks),\n",
    "                    (model_heads, ax.set_yticks),\n",
    "                ):\n",
    "                    # Use the largest divider in 10..1 that splits max_value evenly,\n",
    "                    # giving up to 11 ticks from 0 to max_value.\n",
    "                    tick_step = 1\n",
    "                    for divider in range(10, 0, -1):\n",
    "                        if max_value % divider == 0:\n",
    "                            tick_step = max_value / divider\n",
    "                            break\n",
    "\n",
    "                    set_ticks(np.arange(0, max_value + tick_step, tick_step))\n",
    "\n",
    "    if save_name is not None:\n",
    "        save_plot(save_name)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "MAIN_PAPER_PLOT_MODELS = [\"Llama-3.2-3B\", \"Llama-3.2-3B-Instruct\", \"Llama-3.1-8B\", \"Llama-3.1-8B-Instruct\"]\n",
    "\n",
    "\n",
    "plot_mean_ies_by_model_and_type(\n",
    "    mean_ie_by_model,\n",
    "    MAIN_PAPER_PLOT_MODELS,\n",
    "    ie_colormap.mpl_colormap,\n",
    "    force_cmap_zero_middle=False,\n",
    "    shrink_cmap=True,\n",
    "    cmap_max=0.8,\n",
    "    fontfamily=\"monospace\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Some calculations for the localizer experiments\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_TOP_HEADS = 20\n",
    "\n",
    "\n",
    "def _mean_ie_over_heads(mean_ie, heads):\n",
    "    \"\"\"Average the entries of a per-head mean-IE table at the given head indices.\"\"\"\n",
    "    return np.mean([mean_ie[head] for head in heads])\n",
    "\n",
    "\n",
    "# For each main-paper model, compare how strongly the top instruction heads and the\n",
    "# top demonstration heads respond in each setting (\"<heads> / <CIE type>\").\n",
    "rows = []\n",
    "for model in MAIN_PAPER_PLOT_MODELS:\n",
    "    model_info = layer_key_dicts[f\"{model} @ {N_TOP_HEADS}\"]\n",
    "    model_prompt_heads = model_info[\"prompt_heads\"]\n",
    "    model_icl_heads = model_info[\"icl_heads\"]\n",
    "\n",
    "    model_prompt_mean_ie = mean_ie_by_model[(model, BOTH)]\n",
    "    model_icl_mean_ie = mean_ie_by_model[(model, ICL)]\n",
    "\n",
    "    mean_prompt_head_prompt_ie = _mean_ie_over_heads(model_prompt_mean_ie, model_prompt_heads)\n",
    "    mean_icl_head_icl_ie = _mean_ie_over_heads(model_icl_mean_ie, model_icl_heads)\n",
    "    mean_prompt_head_icl_ie = _mean_ie_over_heads(model_icl_mean_ie, model_prompt_heads)\n",
    "    mean_icl_head_prompt_ie = _mean_ie_over_heads(model_prompt_mean_ie, model_icl_heads)\n",
    "\n",
    "    rows.append(\n",
    "        {\n",
    "            \"Model\": model,\n",
    "            \"Overall median demonstration CIE\": model_icl_mean_ie.median(),\n",
    "            \"Demonstration heads / demonstration CIE\": mean_icl_head_icl_ie,\n",
    "            \"Instruction heads / demonstration CIE\": mean_prompt_head_icl_ie,\n",
    "            \"Localizer difference\": mean_prompt_head_icl_ie - mean_icl_head_prompt_ie,\n",
    "            \"Demonstration heads / instruction CIE\": mean_icl_head_prompt_ie,\n",
    "            \"Instruction heads / instruction CIE\": mean_prompt_head_prompt_ie,\n",
    "            \"Overall median instruction CIE\": model_prompt_mean_ie.median(),\n",
    "        }\n",
    "    )\n",
    "\n",
    "\n",
    "table_format = \"latex_booktabs\"  # use \"github\" for a rendered Markdown table\n",
    "\n",
    "output = tabulate.tabulate(\n",
    "    pd.DataFrame(rows).set_index(\"Model\").T,\n",
    "    headers=\"keys\",\n",
    "    tablefmt=table_format,\n",
    "    floatfmt=\".4e\",\n",
    ")\n",
    "if table_format != \"github\":\n",
    "    print(output)\n",
    "else:\n",
    "    display(Markdown(output))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import rankdata\n",
    "\n",
    "# Rank every head by mean IE (0 = largest value) and report how highly the heads\n",
    "# that are top heads in only one setting rank under the other setting.\n",
    "# NOTE: uses the model_* variables left over from the last iteration of the\n",
    "# localizer loop in the previous cell (the last model in MAIN_PAPER_PLOT_MODELS).\n",
    "# np.prod is used because np.product was deprecated in NumPy 1.25 and removed in 2.0.\n",
    "prompt_head_ranks = np.prod(model_prompt_mean_ie.shape) - rankdata(model_prompt_mean_ie).reshape(\n",
    "    model_prompt_mean_ie.shape\n",
    ")\n",
    "icl_head_ranks = np.prod(model_icl_mean_ie.shape) - rankdata(model_icl_mean_ie).reshape(model_icl_mean_ie.shape)\n",
    "shared_heads = set(model_prompt_heads) & set(model_icl_heads)\n",
    "prompt_only_heads = set(model_prompt_heads) - shared_heads\n",
    "icl_only_heads = set(model_icl_heads) - shared_heads\n",
    "np.mean([prompt_head_ranks[th] for th in icl_only_heads]), np.mean([icl_head_ranks[th] for th in prompt_only_heads])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ratios of demonstration to instruction CIEs for the top heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each model set, tabulate the ratio of the demonstration (ICL) to the\n",
    "# instruction (BOTH) mean IE of the strongest heads, for several head counts.\n",
    "TOP_HEAD_COUNTS = (10, 20, 100)\n",
    "\n",
    "for model_set in (MAIN_PAPER_PLOT_MODELS, APPENDIX_MODELS[:4], APPENDIX_MODELS[4:8]):\n",
    "    if not model_set:\n",
    "        continue\n",
    "\n",
    "    rows = []\n",
    "    for model in model_set:\n",
    "        prompt_mean_ie = mean_ie_by_model[(model, BOTH)]\n",
    "        icl_mean_ie = mean_ie_by_model[(model, ICL)]\n",
    "\n",
    "        row = dict(model=model)\n",
    "        # A local loop variable keeps the notebook-wide N_TOP_HEADS constant intact.\n",
    "        for n_top in TOP_HEAD_COUNTS:\n",
    "            _, prompt_top_values = top_heads_from_indirect_effects_with_values(prompt_mean_ie, n_top)\n",
    "            _, icl_top_values = top_heads_from_indirect_effects_with_values(icl_mean_ie, n_top)\n",
    "            row[f\"{n_top}_mean_ratio\"] = np.mean(icl_top_values) / np.mean(prompt_top_values)\n",
    "            row[f\"{n_top}_median_ratio\"] = np.median(icl_top_values) / np.median(prompt_top_values)\n",
    "\n",
    "        rows.append(row)\n",
    "\n",
    "    table_format = \"latex_booktabs\"  # use \"github\" for a rendered Markdown table\n",
    "    output = tabulate.tabulate(\n",
    "        pd.DataFrame(rows).set_index(\"model\").T, headers=\"keys\", tablefmt=table_format, floatfmt=\".3f\"\n",
    "    )\n",
    "    if table_format == \"github\":\n",
    "        display(Markdown(output))\n",
    "    else:\n",
    "        print()\n",
    "        print(output)\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the appendix model list, which the CIE-ratio cell above slices into groups of four\n",
    "APPENDIX_MODELS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Can we generate some statistical tests for these?\n",
    "\n",
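    "We start by comparing the layers of the top prompt (instruction) heads and ICL (demonstration) heads within each model, then compare base vs. instruct models (and, for OLMo, every pair of model variants). All comparisons use two-sided Mann-Whitney U tests with a Bonferroni-corrected significance threshold.\n"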
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "# Within each model: do the top ICL heads and top prompt heads sit in different layers?\n",
    "# Bonferroni threshold over one test per model for each of the two head counts.\n",
    "alpha = 0.05 / (len(RELEVANT_MODELS) * 2)\n",
    "\n",
    "test_rows = []\n",
    "for model in RELEVANT_MODELS:\n",
    "    row_dict = {\"model\": model}\n",
    "    for n in (10, 20):\n",
    "        model_results = layer_key_dicts[f\"{model} @ {n}\"]\n",
    "        model_icl_layers = model_results[\"icl_layers\"]\n",
    "        model_prompt_layers = model_results[\"prompt_layers\"]\n",
    "        row_dict[f\"ICL Layers ({n})\"] = model_icl_layers\n",
    "        row_dict[f\"Prompt Layers ({n})\"] = model_prompt_layers\n",
    "\n",
    "        stat, p = mannwhitneyu(model_icl_layers, model_prompt_layers, alternative=\"two-sided\")\n",
    "        sig_marker = \"\" if p > alpha else \" *\"\n",
    "        row_dict[f\"Test statistic ({n})\"] = stat\n",
    "        row_dict[f\"p-value ({n})\"] = f\"{p:.4f}{sig_marker}\"\n",
    "\n",
    "    test_rows.append(row_dict)\n",
    "\n",
    "\n",
    "display(Markdown(tabulate.tabulate(test_rows, headers=\"keys\", tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "# Base vs. instruct: do the top prompt heads sit in different layers after\n",
    "# instruction tuning?\n",
    "layers_key = \"prompt_layers\"\n",
    "n_values = (10, 20)\n",
    "\n",
    "# Assumes RELEVANT_MODELS lists each base model directly followed by its instruct\n",
    "# variant. OLMo models are compared pairwise in the next cell instead. zip() drops\n",
    "# a trailing unpaired model rather than raising IndexError.\n",
    "model_pairs = [\n",
    "    (base, instruct)\n",
    "    for base, instruct in zip(RELEVANT_MODELS[::2], RELEVANT_MODELS[1::2])\n",
    "    if \"OLMo\" not in base\n",
    "]\n",
    "\n",
    "# Bonferroni threshold over the tests actually run here: one per pair and head count.\n",
    "alpha = 0.05 / max(len(model_pairs) * len(n_values), 1)\n",
    "\n",
    "test_rows = []\n",
    "for base_model, instruct_model in model_pairs:\n",
    "    row_dict = {\"Base Model\": base_model}\n",
    "    for n in n_values:\n",
    "        base_model_layers = layer_key_dicts[f\"{base_model} @ {n}\"][layers_key]\n",
    "        instruct_model_layers = layer_key_dicts[f\"{instruct_model} @ {n}\"][layers_key]\n",
    "        row_dict[f\"Base Layers ({n})\"] = base_model_layers\n",
    "        row_dict[f\"Instruct Layers ({n})\"] = instruct_model_layers\n",
    "        stat, p = mannwhitneyu(\n",
    "            base_model_layers,\n",
    "            instruct_model_layers,\n",
    "            alternative=\"two-sided\",\n",
    "        )\n",
    "        row_dict[f\"Test statistic ({n})\"] = stat\n",
    "        row_dict[f\"p-value ({n})\"] = f\"{p:.4f}{'' if p > alpha else ' *'}\"\n",
    "\n",
    "    test_rows.append(row_dict)\n",
    "\n",
    "\n",
    "display(Markdown(tabulate.tabulate(test_rows, headers=\"keys\", tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "# Pairwise comparison of the top prompt-head layers across all OLMo models.\n",
    "layers_key = \"prompt_layers\"\n",
    "n_top_heads = 20\n",
    "\n",
    "test_rows = []\n",
    "\n",
    "olmo_models = [model for model in RELEVANT_MODELS if \"OLMo\" in model]\n",
    "# Bonferroni threshold over the unique model pairs. With fewer than two OLMo models\n",
    "# (e.g. when OLMo is skipped) no test is run, so guard against dividing by zero.\n",
    "n_comparisons = len(olmo_models) * (len(olmo_models) - 1) // 2\n",
    "alpha = 0.05 / max(n_comparisons, 1)\n",
    "\n",
    "for first_model in olmo_models:\n",
    "    row_dict = {\"Model\": first_model}\n",
    "    for second_model in olmo_models:\n",
    "        if first_model == second_model:\n",
    "            row_dict[second_model] = \"\"\n",
    "\n",
    "        else:\n",
    "            # Each unordered pair is tested in both orders: the two-sided p-value is\n",
    "            # identical, while U is reported for the row model's sample.\n",
    "            first_model_layers = layer_key_dicts[f\"{first_model} @ {n_top_heads}\"][layers_key]\n",
    "            second_model_layers = layer_key_dicts[f\"{second_model} @ {n_top_heads}\"][layers_key]\n",
    "            stat, p = mannwhitneyu(\n",
    "                first_model_layers,\n",
    "                second_model_layers,\n",
    "                alternative=\"two-sided\",\n",
    "            )\n",
    "            row_dict[second_model] = f\"U = {stat}, p = {p:.4f}{'' if p > alpha else ' *'}\"\n",
    "\n",
    "    test_rows.append(row_dict)\n",
    "\n",
    "\n",
    "display(Markdown(tabulate.tabulate(test_rows, headers=\"keys\", tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib.ticker import MaxNLocator\n",
    "from palettable.colorbrewer.qualitative import Paired_12 as cmap\n",
    "\n",
    "# cmap = colorcet.glasbey_dark\n",
    "\n",
    "\n",
    "# Default marker style for annotation points drawn over the bars. The misspelled\n",
    "# name is kept because plot_layer_rows_individual_dots refers to it as well.\n",
    "DEFUALT_ANNOTATE_KWARGS = dict(\n",
    "    marker=\"*\",\n",
    "    s=200,\n",
    "    color=\"gold\",\n",
    "    edgecolor=\"black\",\n",
    ")\n",
    "\n",
    "\n",
    "def plot_layer_rows_metric_bar_chart(\n",
    "    layer_info_dicts: typing.Dict[str, typing.Dict[str, float]],\n",
    "    n_top_heads: int,\n",
    "    metric_names: str | typing.Sequence[str],\n",
    "    models: typing.Sequence[str] | None = None,\n",
    "    colors: typing.Sequence[str] = None,\n",
    "    title: str | None = None,\n",
    "    ylabel: str | None = None,\n",
    "    figsize: typing.Tuple[float, float] = (8, 6),\n",
    "    fontsize: int = 20,\n",
    "    font_inc: int = 4,\n",
    "    bar_width: float = 0.8,\n",
    "    fake_value: float | None = None,\n",
    "    fake_value_index: int = 0,\n",
    "    fake_value_label: str | None = None,\n",
    "    fake_value_color: str = \"gray\",\n",
    "    metric_name_to_bar_kwargs: typing.Dict[str, typing.Dict[str, typing.Any]] = None,\n",
    "    metric_name_to_err_metric: typing.Dict[str, str] = None,\n",
    "    err_sem: bool = True,\n",
    "    annotate_values: typing.Dict[str, typing.Dict[str, float]] | None = None,\n",
    "    annotate_kwargs: typing.Dict[str, typing.Any] = None,\n",
    "    ylim: typing.Tuple[float, float] | None = None,\n",
    "    fontfamily: str | None = None,\n",
    "):\n",
    "    \"\"\"Bar chart of one or more per-model metrics.\n",
    "\n",
    "    Values are read from ``layer_info_dicts`` under the key '<model> @ <n_top_heads>'.\n",
    "    With several metrics, their bars sit side by side in each model's slot.\n",
    "\n",
    "    Args:\n",
    "        layer_info_dicts: Metric dictionaries keyed by '<model> @ <n_top_heads>'.\n",
    "        n_top_heads: Head count in the lookup key. It is also the sample size\n",
    "            used for the standard error when ``err_sem`` is set.\n",
    "        metric_names: Metric key or keys to plot.\n",
    "        models: Models to plot (default: ``RELEVANT_MODELS``).\n",
    "        colors: One bar color per model. The caller's sequence is not modified.\n",
    "        fake_value: Optional reference bar (e.g. the maximum possible value),\n",
    "            inserted at ``fake_value_index``. A scalar is reused for every\n",
    "            metric; a sequence gives one value per metric.\n",
    "        metric_name_to_bar_kwargs: Extra ``Axes.bar`` keyword arguments per metric.\n",
    "        metric_name_to_err_metric: Key of the error value for each metric,\n",
    "            drawn as error bars.\n",
    "        err_sem: Divide the error values by ``sqrt(n_top_heads)``.\n",
    "        annotate_values: ``{metric_name: {model: value}}`` drawn as markers.\n",
    "            Models without a value get no marker.\n",
    "        annotate_kwargs: Overrides for ``DEFUALT_ANNOTATE_KWARGS``.\n",
    "        ylim: Optional y-axis limits.\n",
    "        fontfamily: Font family for all text (default: the rcParams setting).\n",
    "\n",
    "    If a requested metric is missing for any model, a warning is logged and the\n",
    "    function returns without creating a figure.\n",
    "    \"\"\"\n",
    "    if models is None:\n",
    "        models = RELEVANT_MODELS[:]\n",
    "\n",
    "    if colors is not None and len(colors) != len(models):\n",
    "        raise ValueError(f\"Length of colors ({len(colors)}) does not match number of models ({len(models)}).\")\n",
    "\n",
    "    if isinstance(metric_names, str):\n",
    "        metric_names = [metric_names]\n",
    "\n",
    "    if metric_name_to_bar_kwargs is None:\n",
    "        metric_name_to_bar_kwargs = dict()\n",
    "\n",
    "    if metric_name_to_err_metric is None:\n",
    "        metric_name_to_err_metric = dict()\n",
    "\n",
    "    if annotate_values is None:\n",
    "        annotate_values = dict()\n",
    "\n",
    "    if annotate_kwargs is None:\n",
    "        annotate_kwargs = dict()\n",
    "\n",
    "    if fontfamily is None:\n",
    "        fontfamily = plt.rcParams[\"font.family\"]\n",
    "\n",
    "    annotate_kwargs = {**DEFUALT_ANNOTATE_KWARGS, **annotate_kwargs}\n",
    "\n",
    "    all_metric_values = [\n",
    "        [layer_info_dicts[f\"{model} @ {n_top_heads}\"].get(metric_name, np.nan) for model in models]\n",
    "        for metric_name in metric_names\n",
    "    ]\n",
    "\n",
    "    # Work on copies so inserting the fake bar never mutates the caller's\n",
    "    # sequences (this also makes tuples acceptable).\n",
    "    labels = list(models)\n",
    "    if colors is not None:\n",
    "        colors = list(colors)\n",
    "\n",
    "    if fake_value is not None:\n",
    "        if isinstance(fake_value, (int, float)):\n",
    "            fake_value = [fake_value] * len(metric_names)\n",
    "\n",
    "        for metric_values, fv in zip(all_metric_values, fake_value):\n",
    "            metric_values.insert(fake_value_index, fv)\n",
    "\n",
    "        labels.insert(fake_value_index, fake_value_label if fake_value_label is not None else \"Fake Value\")\n",
    "        if colors is not None:\n",
    "            colors.insert(fake_value_index, fake_value_color)\n",
    "\n",
    "    # Validate before creating the figure so an early return leaves no empty figure.\n",
    "    for metric_name, vals in zip(metric_names, all_metric_values):\n",
    "        if np.any(np.isnan(vals)):\n",
    "            logger.warning(f\"Missing data for metrics {metric_name}.\")\n",
    "            return\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "\n",
    "    bar_positions = np.arange(len(labels))\n",
    "    bar_width /= len(metric_names)\n",
    "\n",
    "    for i, (metric_name, metric_values) in enumerate(zip(metric_names, all_metric_values)):\n",
    "        x_positions = bar_positions + i * bar_width\n",
    "        bar_kwargs = dict(width=bar_width, **metric_name_to_bar_kwargs.get(metric_name, {}))\n",
    "        if colors is not None:\n",
    "            bar_kwargs[\"color\"] = colors\n",
    "        # Only the first metric's bars carry the model labels shown in the legend.\n",
    "        ax.bar(x_positions, metric_values, **bar_kwargs, label=labels if i == 0 else None)\n",
    "\n",
    "        annotation_values = annotate_values.get(metric_name, None)\n",
    "        if annotation_values is not None:\n",
    "            ann_vals = [annotation_values.get(model, None) for model in models]\n",
    "            if fake_value is not None:\n",
    "                ann_vals.insert(fake_value_index, None)\n",
    "\n",
    "            # Missing annotations become NaN, which scatter does not draw.\n",
    "            ann_vals = [np.nan if val is None else val for val in ann_vals]\n",
    "            ax.scatter(x_positions, ann_vals, **annotate_kwargs)\n",
    "\n",
    "        err_metric = metric_name_to_err_metric.get(metric_name, None)\n",
    "        if err_metric is not None:\n",
    "            err_values = [layer_info_dicts[f\"{model} @ {n_top_heads}\"].get(err_metric, np.nan) for model in models]\n",
    "            if fake_value is not None:\n",
    "                err_values.insert(fake_value_index, 0)\n",
    "\n",
    "            if err_sem:\n",
    "                # Standard error of the mean, treating the top heads as the sample.\n",
    "                err_values = np.array(err_values) / np.sqrt(n_top_heads)\n",
    "            ax.errorbar(x_positions, metric_values, yerr=err_values, fmt=\"none\", capsize=5, color=\"gray\")\n",
    "\n",
    "    ax.set_xticks([])\n",
    "    ax.yaxis.set_major_locator(MaxNLocator(integer=True))\n",
    "    # Default y label: the last plotted metric's name, humanized.\n",
    "    ax.set_ylabel(\n",
    "        metric_names[-1].replace(\"_\", \" \").capitalize() if ylabel is None else ylabel,\n",
    "        fontsize=fontsize + font_inc,\n",
    "        fontfamily=fontfamily,\n",
    "    )\n",
    "    if title is not None:\n",
    "        ax.set_title(title, fontsize=fontsize + (2 * font_inc), fontfamily=fontfamily)\n",
    "    # Legend ignores `fontsize` when `prop` is given, so the size goes into `prop`.\n",
    "    ax.legend(prop=dict(family=fontfamily, size=fontsize - font_inc))\n",
    "    ax.tick_params(axis=\"y\", labelsize=fontsize, labelfontfamily=fontfamily)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(ylim)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# ordered_color_indices = list(range(len(RELEVANT_MODELS)))\n",
    "# cmap = colorcet.glasbey_dark\n",
    "# plot_layer_rows_metric_bar_chart(\n",
    "#     layer_key_dicts,\n",
    "#     10,\n",
    "#     \"shared_heads\",\n",
    "#     colors=[cmap[i] for i in ordered_color_indices],\n",
    "#     ylim=(0, 10),\n",
    "#     fake_value=10,\n",
    "#     fake_value_label=\"Maximum possible\",\n",
    "#     title=\"Prompt & ICL shared heads (top 10)\",\n",
    "# )\n",
    "\n",
    "# plot_layer_rows_metric_bar_chart(\n",
    "#     layer_key_dicts,\n",
    "#     20,\n",
    "#     \"shared_heads\",\n",
    "#     colors=[cmap[i] for i in ordered_color_indices],\n",
    "#     ylim=(0, 20),\n",
    "#     fake_value=20,\n",
    "#     fake_value_label=\"Maximum possible\",\n",
    "#     title=\"Prompt & ICL shared heads (top 20)\",\n",
    "# )\n",
    "\n",
    "\n",
    "ordered_color_indices = [0, 1, 6, 7]\n",
    "plot_layer_rows_metric_bar_chart(\n",
    "    layer_key_dicts,\n",
    "    20,\n",
    "    \"shared_heads\",\n",
    "    models=[model for model in RELEVANT_MODELS if (\"3.2-3B\" in model) or (\"3.1-8B\" in model)],\n",
    "    colors=[cmap.mpl_colors[i] for i in ordered_color_indices],\n",
    "    ylim=(0, 20),\n",
    "    fake_value=20,\n",
    "    fake_value_label=\"Maximum possible\",\n",
    "    fontfamily=\"monospace\",\n",
    "    # title=\"Prompt & ICL shared heads (top 20)\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def repeat_elements(items, n):\n",
    "    return [item for item in items for _ in range(n)]\n",
    "\n",
    "\n",
    "def plot_layer_rows_individual_dots(\n",
    "    layer_info_dicts: typing.Dict[str, typing.Dict[str, float]],\n",
    "    n_top_heads: int,\n",
    "    metric_names: str | typing.Sequence[str],\n",
    "    models: typing.Sequence[str] | None = None,\n",
    "    colors: typing.Sequence[str] | None = None,\n",
    "    title: str | None = None,\n",
    "    ylabel: str | None = None,\n",
    "    ylabel_first_ax_only: bool = True,\n",
    "    fontsize: int = 20,\n",
    "    font_inc: int = 4,\n",
    "    metric_labels: typing.Dict[str, str] = None,\n",
    "    metric_name_to_plot_kwargs: typing.Dict[str, typing.Dict[str, typing.Any]] = None,\n",
    "    metric_name_to_err_metric: typing.Dict[str, str] = None,\n",
    "    annotate_values: typing.Dict[str, typing.Dict[str, float]] | None = None,\n",
    "    annotate_kwargs: typing.Dict[str, typing.Any] = None,\n",
    "    ylim: typing.Tuple[float, float] | None = None,\n",
    "    ylim_from_model_layers: bool = False,\n",
    "    **global_plot_kwargs,\n",
    "):\n",
    "    if models is None:\n",
    "        models = RELEVANT_MODELS[:]\n",
    "\n",
    "    if colors is not None and len(colors) != len(models):\n",
    "        raise ValueError(f\"Length of colors ({len(colors)}) does not match number of models ({len(models)}).\")\n",
    "\n",
    "    if isinstance(metric_names, str):\n",
    "        metric_names = [metric_names]\n",
    "\n",
    "    if metric_name_to_plot_kwargs is None:\n",
    "        metric_name_to_plot_kwargs = dict()\n",
    "\n",
    "    if metric_name_to_err_metric is None:\n",
    "        metric_name_to_err_metric = dict()\n",
    "\n",
    "    if annotate_values is None:\n",
    "        annotate_values = dict()\n",
    "\n",
    "    if annotate_kwargs is None:\n",
    "        annotate_kwargs = dict()\n",
    "\n",
    "    if metric_labels is None:\n",
    "        metric_labels = dict()\n",
    "\n",
    "    elif isinstance(metric_labels, (list, tuple)):\n",
    "        if len(metric_labels) != len(metric_names):\n",
    "            raise ValueError(\n",
    "                f\"Length of metric_labels ({len(metric_labels)}) does not match number of metrics ({len(metric_names)}).\"\n",
    "            )\n",
    "        metric_labels = dict(zip(metric_names, metric_labels))\n",
    "\n",
    "    annotate_kwargs = {**DEFUALT_ANNOTATE_KWARGS, **annotate_kwargs}\n",
    "\n",
    "    rows = []\n",
    "    for metric_name in metric_names:\n",
    "        for model in models:\n",
    "            base_model = model.replace(\"-Instruct\", \"\").replace(\"-chat\", \"\")\n",
    "            is_instruct = (\"-Instruct\" in model) or (\"-chat\" in model)\n",
    "            model_metric_vals = layer_info_dicts[f\"{model} @ {n_top_heads}\"].get(metric_name, np.nan)\n",
    "            if isinstance(model_metric_vals, (int, float)):\n",
    "                rows.append(\n",
    "                    dict(\n",
    "                        model=model,\n",
    "                        base_model=base_model,\n",
    "                        is_instruct=is_instruct,\n",
    "                        metric_name=metric_name,\n",
    "                        value=model_metric_vals,\n",
    "                    )\n",
    "                )\n",
    "            elif isinstance(model_metric_vals, list):\n",
    "                for i, val in enumerate(model_metric_vals):\n",
    "                    rows.append(\n",
    "                        dict(\n",
    "                            model=model,\n",
    "                            base_model=base_model,\n",
    "                            is_instruct=is_instruct,\n",
    "                            metric_name=metric_name,\n",
    "                            value=val,\n",
    "                            index=i,\n",
    "                        )\n",
    "                    )\n",
    "\n",
    "    all_values_df = pd.DataFrame(rows)\n",
    "    all_values_df = all_values_df.assign(\n",
    "        metric_name=all_values_df.metric_name.map(lambda x: metric_labels.get(x, x)),\n",
    "    )\n",
    "\n",
    "    fig, axes = plt.subplots(1, 5, figsize=(36, 6))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for b, base_model in enumerate(all_values_df.base_model.unique()):\n",
    "        base_model_df = all_values_df[all_values_df.base_model == base_model]\n",
    "        ax = axes[b]\n",
    "\n",
    "        base_model_colors = colors[b * len(metric_names) : (b + 1) * len(metric_names)]\n",
    "\n",
    "        sns.swarmplot(\n",
    "            data=base_model_df,\n",
    "            x=\"model\",\n",
    "            y=\"value\",\n",
    "            hue=\"metric_name\",\n",
    "            ax=ax,\n",
    "            palette=base_model_colors,\n",
    "            **global_plot_kwargs,\n",
    "        )\n",
    "\n",
    "        sns.boxplot(\n",
    "            data=base_model_df,\n",
    "            x=\"model\",\n",
    "            y=\"value\",\n",
    "            hue=\"metric_name\",\n",
    "            showmeans=True,\n",
    "            meanline=True,\n",
    "            meanprops={\"color\": \"k\", \"ls\": \"-\", \"lw\": 2},\n",
    "            medianprops={\"visible\": False},\n",
    "            whiskerprops={\"visible\": False},\n",
    "            zorder=10,\n",
    "            showfliers=False,\n",
    "            showbox=False,\n",
    "            showcaps=False,\n",
    "            palette=base_model_colors,\n",
    "            legend=None,\n",
    "            ax=ax,\n",
    "        )\n",
    "\n",
    "        ax.set_xlabel(\"Model\", fontsize=fontsize + font_inc)\n",
    "        if (not ylabel_first_ax_only) or b == 0:\n",
    "            ax.set_ylabel(\n",
    "                metric_name.replace(\"_\", \" \").capitalize() if ylabel is None else ylabel, fontsize=fontsize + font_inc\n",
    "            )\n",
    "\n",
    "        ax_title = title\n",
    "        if ax_title is None:\n",
    "            ax_title = f\"{base_model} ({n_top_heads} heads)\"\n",
    "\n",
    "        ax.set_title(ax_title, fontsize=fontsize + (2 * font_inc))\n",
    "\n",
    "        ax.tick_params(axis=\"both\", labelsize=fontsize)\n",
    "\n",
    "        if ylim_from_model_layers:\n",
    "            model_layers = MODEL_TO_N_LAYERS[base_model]\n",
    "            ax.set_ylim(0, model_layers)\n",
    "            y_step = 1\n",
    "            for divider in range(10, 0, -1):\n",
    "                if model_layers % divider == 0:\n",
    "                    y_step = model_layers / divider\n",
    "                    break\n",
    "\n",
    "            ax.set_yticks(np.arange(0, model_layers + y_step, y_step))\n",
    "\n",
    "        elif ylim is not None:\n",
    "            ax.set_ylim(ylim)\n",
    "\n",
    "        # Add a legend artist for the black line representing the mean\n",
    "        mean_line = plt.Line2D([0], [0], color=\"k\", linestyle=\"-\", linewidth=2, label=\"Mean\")\n",
    "        loc = \"lower right\" if \"Llama-3\" in base_model else \"best\"\n",
    "        ax.legend(handles=ax.get_legend_handles_labels()[0] + [mean_line], fontsize=fontsize - font_inc, loc=loc)\n",
    "        # ax.legend(fontsize=fontsize - font_inc)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from palettable.colorbrewer.qualitative import Paired_12 as cmap\n",
    "\n",
    "\n",
    "def flip_pairs(lst):\n",
    "    return [item for pair in zip(lst[1::2], lst[::2]) for item in pair]\n",
    "\n",
    "\n",
    "MODELS_NO_OLMO = [model for model in RELEVANT_MODELS if \"OLMo\" not in model]\n",
    "OLMO_MODELS = [model for model in RELEVANT_MODELS if \"OLMo\" in model]\n",
    "\n",
    "for n_top_heads in (10, 20):\n",
    "    plot_layer_rows_individual_dots(\n",
    "        layer_key_dicts,\n",
    "        n_top_heads,\n",
    "        # ['icl_layer_depths', 'prompt_layer_depths'],\n",
    "        [\"icl_layers\", \"prompt_layers\"],\n",
    "        models=MODELS_NO_OLMO,\n",
    "        colors=flip_pairs([cmap.mpl_colors[i] for i in range(len(MODELS_NO_OLMO))]),\n",
    "        ylabel=\"Top head layer\",\n",
    "        # fake_value=10,\n",
    "        # fake_value_label='Maximum possible',\n",
    "        # title='Prompt & ICL mean head layer depths (top 10)',\n",
    "        metric_labels=[\"ICL\", \"Prompt\"],\n",
    "        # ylim=(0, 1),\n",
    "        ylim_from_model_layers=True,\n",
    "        dodge=True,\n",
    "        size=10,\n",
    "    )\n",
    "\n",
    "# for n_top_heads in (10, 20):\n",
    "#     plot_layer_rows_individual_dots(\n",
    "#         layer_key_dicts,\n",
    "#         n_top_heads,\n",
    "#         # ['icl_layer_depths', 'prompt_layer_depths'],\n",
    "#         [\"icl_layers\", \"prompt_layers\"],\n",
    "#         models=[model for model in RELEVANT_MODELS if \"OLMo\" in model],\n",
    "#         colors=flip_pairs([olmo_cmap.mpl_colors[i] for i in range(4)]),\n",
    "#         ylabel=\"Top head layer\",\n",
    "#         # fake_value=10,\n",
    "#         # fake_value_label='Maximum possible',\n",
    "#         # title='Prompt & ICL mean head layer depths (top 10)',\n",
    "#         metric_labels=[\"ICL\", \"Prompt\"],\n",
    "#         # ylim=(0, 1),\n",
    "#         ylim_from_model_layers=True,\n",
    "#         dodge=True,\n",
    "#         size=10,\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_layer_rows_metric_bar_chart(\n",
    "    layer_key_dicts,\n",
    "    10,\n",
    "    [\"icl_layer_depth\", \"prompt_layer_depth\"],\n",
    "    models=MODELS_NO_OLMO,\n",
    "    colors=flip_pairs([cmap.mpl_colors[i] for i in range(len(MODELS_NO_OLMO))]),\n",
    "    ylabel=\"Mean layer depth\",\n",
    "    # fake_value=10,\n",
    "    # fake_value_label='Maximum possible',\n",
    "    title=\"Prompt & ICL mean head layer depths (top 10)\",\n",
    "    ylim=(0, 1),\n",
    "    metric_name_to_bar_kwargs=dict(\n",
    "        prompt_layer_depth=dict(hatch=\"/\"),\n",
    "    ),\n",
    "    metric_name_to_err_metric=dict(\n",
    "        icl_layer_depth=\"icl_layer_depth_std\",\n",
    "        prompt_layer_depth=\"prompt_layer_depth_std\",\n",
    "    ),\n",
    "    # annotate_values=dict(\n",
    "    #     icl_layer_depth={m: d['zs_icl_10_max_acc_layer_depth'] for m, d in layer_depth_model_dicts.items()},\n",
    "    #     prompt_layer_depth={m: d['zs_both_10_max_acc_layer_depth'] for m, d in layer_depth_model_dicts.items()}\n",
    "    # ),\n",
    ")\n",
    "\n",
    "plot_layer_rows_metric_bar_chart(\n",
    "    layer_key_dicts,\n",
    "    20,\n",
    "    [\"icl_layer_depth\", \"prompt_layer_depth\"],\n",
    "    models=MODELS_NO_OLMO,\n",
    "    colors=flip_pairs([cmap.mpl_colors[i] for i in range(len(MODELS_NO_OLMO))]),\n",
    "    ylabel=\"Mean layer depth\",\n",
    "    # fake_value=10,\n",
    "    # fake_value_label='Maximum possible',\n",
    "    title=\"Prompt & ICL mean head layer depths (top 20)\",\n",
    "    ylim=(0, 1),\n",
    "    metric_name_to_bar_kwargs=dict(\n",
    "        prompt_layer_depth=dict(hatch=\"/\"),\n",
    "    ),\n",
    "    metric_name_to_err_metric=dict(\n",
    "        icl_layer_depth=\"icl_layer_depth_std\",\n",
    "        prompt_layer_depth=\"prompt_layer_depth_std\",\n",
    "    ),\n",
    "    # annotate_values=dict(\n",
    "    #     icl_layer_depth={m: d['zs_icl_20_max_acc_layer_depth'] for m, d in layer_depth_model_dicts.items()},\n",
    "    #     prompt_layer_depth={m: d['zs_both_20_max_acc_layer_depth'] for m, d in layer_depth_model_dicts.items()}\n",
    "    # ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = []\n",
    "\n",
    "for i in range(0, len(RELEVANT_MODELS), 2):\n",
    "    base_model = RELEVANT_MODELS[i]\n",
    "    instruct_model = RELEVANT_MODELS[i + 1]\n",
    "\n",
    "    for n_heads in (10, 20):\n",
    "        base_heads = top_heads_summary_df[\n",
    "            (top_heads_summary_df.model == base_model) & (top_heads_summary_df.n_heads == n_heads)\n",
    "        ].prompt_heads.values[0]\n",
    "        instruct_heads = top_heads_summary_df[\n",
    "            (top_heads_summary_df.model == instruct_model) & (top_heads_summary_df.n_heads == n_heads)\n",
    "        ].prompt_heads.values[0]\n",
    "        base_head_set = set(base_heads)\n",
    "        instruct_head_set = set(instruct_heads)\n",
    "        shared_heads = base_head_set & instruct_head_set\n",
    "        lines.append(f\" - {base_model} & {instruct_model} share {len(shared_heads)} / {n_heads} \")\n",
    "\n",
    "        shared_heads_mean_layer = np.mean([t[0] for t in shared_heads])\n",
    "        base_heads_mean_layer = np.mean([t[0] for t in (base_head_set - instruct_head_set)])\n",
    "        instruct_heads_mean_layer = np.mean([t[0] for t in (instruct_head_set - base_head_set)])\n",
    "        lines.append(\n",
    "            f\"   - {base_model} only mean layer: {base_heads_mean_layer:.2f} | shared heads mean layer: {shared_heads_mean_layer:.2f} | {instruct_model} only mean layer: {instruct_heads_mean_layer:.2f}\"\n",
    "        )\n",
    "\n",
    "        base_only_head_indices = [base_heads.index(t) for t in base_head_set - instruct_head_set]\n",
    "        instruct_only_head_indices = [instruct_heads.index(t) for t in instruct_head_set - base_head_set]\n",
    "        shared_head_indices = [base_heads.index(t) for t in shared_heads] + [\n",
    "            instruct_heads.index(t) for t in shared_heads\n",
    "        ]\n",
    "        lines.append(\n",
    "            f\"   - {base_model} only head mean index: {np.mean(base_only_head_indices):.2f} | shared heads mean index: {np.mean(shared_head_indices):.2f} | {instruct_model} only mean head index: {np.mean(instruct_only_head_indices):.2f}\"\n",
    "        )\n",
    "\n",
    "\n",
    "display(Markdown(\"\\n\".join(lines)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RQ2.5: What can I say about how many attention heads are necessary/helpful?\n",
    "\n",
    "A couple of angles of attack here:\n",
    "\n",
    "- Plot the distribution of indirect effects for each model\n",
    "- Plot the overlap between similar settings in the prompt-based approach\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def scatter_plot_top_effects(n_top_heads: int, negative=False, fontsize=20, print_every: int | None = None):\n",
    "    fig, axes = plt.subplots(2, 5, figsize=(36, 12))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for offset_idx, (prompt_types, name) in enumerate(\n",
    "        zip(\n",
    "            ((SHORT, LONG), ICL),\n",
    "            (\"Prompt-based\", \"ICL\"),\n",
    "        )\n",
    "    ):\n",
    "        for i in range(len(RELEVANT_MODELS) // 2):\n",
    "            ax_offset = offset_idx * 5\n",
    "            model_idx = i * 2\n",
    "            ax = axes[i + ax_offset]\n",
    "            ax.set_title(RELEVANT_MODELS[model_idx], fontsize=fontsize + 4)\n",
    "            ax.set_xlabel(\"Index\", fontsize=fontsize)\n",
    "            ax.set_ylabel(f\"{name} indirect effect\", fontsize=fontsize)\n",
    "\n",
    "            x_values = np.arange(n_top_heads)\n",
    "            _, base_effects = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                RELEVANT_MODELS[model_idx],\n",
    "                prompt_types,\n",
    "                n_top_heads=n_top_heads,\n",
    "                negative=negative,\n",
    "            )\n",
    "            base_effects = np.array(base_effects)\n",
    "            ax.scatter(x_values, base_effects, label=\"Base\", color=\"blue\", alpha=0.5)\n",
    "            wrong_sign = base_effects > 0 if negative else base_effects < 0\n",
    "            if wrong_sign.any():\n",
    "                logger.warning(\n",
    "                    f\"Found {'positive' if negative else 'negative'} entries in {RELEVANT_MODELS[model_idx]} ({name}) starting at index {wrong_sign.argmax()}\"\n",
    "                )\n",
    "            if print_every is not None:\n",
    "                print(f\"{RELEVANT_MODELS[model_idx]}: {base_effects[::print_every]}\")\n",
    "\n",
    "            _, instruct_effects = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                RELEVANT_MODELS[model_idx + 1],\n",
    "                prompt_types,\n",
    "                n_top_heads=n_top_heads,\n",
    "                negative=negative,\n",
    "            )\n",
    "            instruct_effects = np.array(instruct_effects)\n",
    "            ax.scatter(x_values, instruct_effects, label=\"Instruct\", color=\"orange\", alpha=0.5)\n",
    "            wrong_sign = instruct_effects > 0 if negative else instruct_effects < 0\n",
    "            if wrong_sign.any():\n",
    "                logger.warning(\n",
    "                    f\"Found {'positive' if negative else 'engative'} entries in {RELEVANT_MODELS[model_idx + 1]} ({name}) starting at index {wrong_sign.argmax()}\"\n",
    "                )\n",
    "            if print_every is not None:\n",
    "                print(f\"{RELEVANT_MODELS[model_idx + 1]}: {instruct_effects[::print_every]}\")\n",
    "\n",
    "            ax.tick_params(axis=\"x\", labelsize=fontsize - 4)\n",
    "            ax.tick_params(axis=\"y\", labelsize=fontsize - 4)\n",
    "            ax.legend(fontsize=fontsize)\n",
    "            leg = ax.legend(fontsize=fontsize)\n",
    "            for lh in leg.legend_handles:\n",
    "                lh.set_alpha(1)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "scatter_plot_top_effects(\n",
    "    n_top_heads=200,\n",
    "    # print_every=25,\n",
    ")\n",
    "\n",
    "scatter_plot_top_effects(\n",
    "    n_top_heads=200,\n",
    "    negative=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from palettable.cartocolors.qualitative import Bold_4 as cmap\n",
    "\n",
    "\n",
    "def scatter_plot_top_effects_same_axes(\n",
    "    n_top_heads: int,\n",
    "    negative=False,\n",
    "    flip_negative=False,\n",
    "    fontsize=20,\n",
    "    log=False,\n",
    "    print_every: int | None = None,\n",
    "    colors: typing.Sequence[str] = None,\n",
    "):\n",
    "    fig, axes = plt.subplots(1, 5, figsize=(36, 6))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for offset_idx, (prompt_types, name) in enumerate(\n",
    "        zip(\n",
    "            ((SHORT, LONG), ICL),\n",
    "            (\"Prompt-based\", \"ICL\"),\n",
    "        )\n",
    "    ):\n",
    "        for i in range(len(RELEVANT_MODELS) // 2):\n",
    "            ax_offset = 0\n",
    "            model_idx = i * 2\n",
    "            ax = axes[i + ax_offset]\n",
    "            ax.set_title(RELEVANT_MODELS[model_idx], fontsize=fontsize + 4)\n",
    "            ax.set_xlabel(\"Index\", fontsize=fontsize)\n",
    "            ax.set_ylabel(f\"{name} indirect effect\", fontsize=fontsize)\n",
    "\n",
    "            base_color = None\n",
    "            if colors is not None:\n",
    "                base_color = colors[(2 * offset_idx)]\n",
    "\n",
    "            x_values = np.arange(n_top_heads)\n",
    "            _, base_effects = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                RELEVANT_MODELS[model_idx],\n",
    "                prompt_types,\n",
    "                n_top_heads=n_top_heads,\n",
    "                negative=negative,\n",
    "            )\n",
    "            base_effects = np.array(base_effects)\n",
    "            if negative and flip_negative:\n",
    "                base_effects = -base_effects\n",
    "\n",
    "            ax.scatter(\n",
    "                x_values,\n",
    "                base_effects,\n",
    "                label=f\"Base ({'prompt' if offset_idx == 0 else 'ICL'})\",\n",
    "                color=base_color,\n",
    "                alpha=0.25,\n",
    "            )\n",
    "            wrong_sign = base_effects > 0 if negative else base_effects < 0\n",
    "            if wrong_sign.any() and not flip_negative:\n",
    "                logger.warning(\n",
    "                    f\"Found {'positive' if negative else 'engative'} entries in {RELEVANT_MODELS[model_idx]} ({name}) starting at index {wrong_sign.argmax()}\"\n",
    "                )\n",
    "            if print_every is not None:\n",
    "                print(f\"{RELEVANT_MODELS[model_idx]}: {base_effects[::print_every]}\")\n",
    "\n",
    "            _, instruct_effects = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                RELEVANT_MODELS[model_idx + 1],\n",
    "                prompt_types,\n",
    "                n_top_heads=n_top_heads,\n",
    "                negative=negative,\n",
    "            )\n",
    "            instruct_effects = np.array(instruct_effects)\n",
    "            if negative and flip_negative:\n",
    "                instruct_effects = -instruct_effects\n",
    "\n",
    "            instruct_color = None\n",
    "            if colors is not None:\n",
    "                instruct_color = colors[(2 * offset_idx) + 1]\n",
    "\n",
    "            ax.scatter(\n",
    "                x_values,\n",
    "                instruct_effects,\n",
    "                label=f\"Instruct ({'prompt' if offset_idx == 0 else 'ICL'})\",\n",
    "                color=instruct_color,\n",
    "                alpha=0.25,\n",
    "            )\n",
    "            wrong_sign = instruct_effects > 0 if negative else instruct_effects < 0\n",
    "            if wrong_sign.any() and not flip_negative:\n",
    "                logger.warning(\n",
    "                    f\"Found {'positive' if negative else 'engative'} entries in {RELEVANT_MODELS[model_idx + 1]} ({name}) starting at index {wrong_sign.argmax()}\"\n",
    "                )\n",
    "            if print_every is not None:\n",
    "                print(f\"{RELEVANT_MODELS[model_idx + 1]}: {instruct_effects[::print_every]}\")\n",
    "\n",
    "            ax.tick_params(axis=\"x\", labelsize=fontsize - 4)\n",
    "            ax.tick_params(axis=\"y\", labelsize=fontsize - 4)\n",
    "            leg = ax.legend(fontsize=fontsize)\n",
    "            for lh in leg.legend_handles:\n",
    "                lh.set_alpha(1)\n",
    "            if log:\n",
    "                ax.set_yscale(\"log\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "ordered_color_indices = np.arange(4)\n",
    "\n",
    "scatter_plot_top_effects_same_axes(\n",
    "    n_top_heads=200,\n",
    "    log=True,\n",
    "    colors=[cmap.mpl_colors[i] for i in ordered_color_indices],\n",
    ")\n",
    "\n",
    "scatter_plot_top_effects_same_axes(\n",
    "    n_top_heads=200,\n",
    "    negative=True,\n",
    "    colors=[cmap.mpl_colors[i] for i in ordered_color_indices],\n",
    ")\n",
    "\n",
    "scatter_plot_top_effects_same_axes(\n",
    "    n_top_heads=200,\n",
    "    negative=True,\n",
    "    flip_negative=True,\n",
    "    log=True,\n",
    "    colors=[cmap.mpl_colors[i] for i in ordered_color_indices],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def histogram_plot_all_effects(bins: int = 50, fontsize=20, **hist_kwargs):\n",
    "    fig, axes = plt.subplots(2, 5, figsize=(36, 12))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for offset_idx, (prompt_types, name) in enumerate(\n",
    "        zip(\n",
    "            ((SHORT, LONG), ICL),\n",
    "            (\"Prompt-based\", \"ICL\"),\n",
    "        )\n",
    "    ):\n",
    "        for i in range(len(RELEVANT_MODELS) // 2):\n",
    "            ax_offset = offset_idx * 5\n",
    "            model_idx = i * 2\n",
    "            ax = axes[i + ax_offset]\n",
    "            ax.set_title(RELEVANT_MODELS[model_idx], fontsize=fontsize + 4)\n",
    "            ax.set_xlabel(f\"{name} indirect effect\", fontsize=fontsize)\n",
    "            ax.set_ylabel(\"Proportion\", fontsize=fontsize)\n",
    "\n",
    "            _, _, base_mean_effects = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                RELEVANT_MODELS[model_idx],\n",
    "                prompt_types,\n",
    "                return_mean=True,\n",
    "            )\n",
    "            _, _, inst_mean_effects = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                RELEVANT_MODELS[model_idx + 1],\n",
    "                prompt_types,\n",
    "                return_mean=True,\n",
    "            )\n",
    "            base_mean_effects = base_mean_effects.flatten().numpy()\n",
    "            inst_mean_effects = inst_mean_effects.flatten().numpy()\n",
    "            ax.hist(\n",
    "                [base_mean_effects, inst_mean_effects],\n",
    "                bins=bins,\n",
    "                label=[\"Base\", \"Instruct\"],\n",
    "                color=[\"blue\", \"orange\"],\n",
    "                alpha=0.5,\n",
    "                **hist_kwargs,\n",
    "            )\n",
    "            ax.tick_params(axis=\"x\", labelsize=fontsize - 4)\n",
    "            ax.tick_params(axis=\"y\", labelsize=fontsize - 4)\n",
    "            ax.legend(fontsize=fontsize)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "histogram_plot_all_effects(density=True, log=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def box_plot_all_effects(fontsize=20, **boxplot_kwargs):\n",
    "    fig, axes = plt.subplots(1, 2, figsize=(12, 12))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for ax_idx, (prompt_types, name) in enumerate(\n",
    "        zip(\n",
    "            ((SHORT, LONG), ICL),\n",
    "            (\"Prompt-based\", \"ICL\"),\n",
    "        )\n",
    "    ):\n",
    "        ax = axes[ax_idx]\n",
    "        ax.set_title(f\"{name} indirect effects\", fontsize=fontsize + 4)\n",
    "        ax.set_xlabel(\"Indirect effect\", fontsize=fontsize)\n",
    "        ax.set_ylabel(\"Model\", fontsize=fontsize)\n",
    "\n",
    "        mean_effects = [\n",
    "            compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                model,\n",
    "                prompt_types,\n",
    "                return_mean=True,\n",
    "            )[2]\n",
    "            for model in RELEVANT_MODELS\n",
    "        ]\n",
    "\n",
    "        print([(model, mean_effects[i].shape) for i, model in enumerate(RELEVANT_MODELS)])\n",
    "\n",
    "        mean_effects = [me.flatten().numpy() for me in mean_effects]\n",
    "\n",
    "        model_name_formatter = lambda x: RELEVANT_MODELS[x]\n",
    "\n",
    "        sns.stripplot(mean_effects, ax=ax, formatter=model_name_formatter, **boxplot_kwargs)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "box_plot_all_effects(orient=\"h\", size=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def moving_average(data_set, periods=3):\n",
    "    weights = np.ones(periods) / periods\n",
    "    return np.convolve(data_set, weights, mode=\"valid\")\n",
    "\n",
    "\n",
    "def scatter_plot_top_heads_overlap(\n",
    "    comparison_groups: typing.Dict[str, typing.Dict[str, typing.Any]],\n",
    "    min_n_top_heads: int = 10,\n",
    "    max_n_top_heads: int = 100,\n",
    "    moving_average_periods: int | None = None,\n",
    "    fontsize=20,\n",
    "    title=None,\n",
    "):\n",
    "    fig, axes = plt.subplots(2, 5, figsize=(36, 12))\n",
    "    axes = axes.flatten()\n",
    "\n",
    "    for i, model in enumerate(RELEVANT_SCATTER_ORDERED_MODELS):\n",
    "        ax = axes[i]\n",
    "        ax.set_title(model, fontsize=fontsize + 4)\n",
    "        ax.set_xlabel(\"Index\", fontsize=fontsize)\n",
    "        ax.set_ylabel(\"% top heads shared\", fontsize=fontsize)\n",
    "\n",
    "        model_max_n_top_heads = max_n_top_heads\n",
    "        if isinstance(model_max_n_top_heads, float):\n",
    "            layers, heads = MODEL_TO_N_LAYERS_HEADS[model]\n",
    "            total_heads = layers * heads\n",
    "            model_max_n_top_heads = int(total_heads * model_max_n_top_heads)\n",
    "\n",
    "        for first_key, second_key in itertools.combinations(comparison_groups.keys(), 2):\n",
    "            x_values = np.arange(min_n_top_heads, model_max_n_top_heads + 1)\n",
    "            first_kwargs = comparison_groups[first_key]\n",
    "            second_kwargs = comparison_groups[second_key]\n",
    "\n",
    "            first_top_heads, _ = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                model,\n",
    "                **first_kwargs,\n",
    "                n_top_heads=model_max_n_top_heads,\n",
    "            )\n",
    "            second_top_heads, _ = compute_top_heads(\n",
    "                result_df,\n",
    "                indirect_effects_by_model_and_dataset,\n",
    "                model,\n",
    "                **second_kwargs,\n",
    "                n_top_heads=model_max_n_top_heads,\n",
    "            )\n",
    "\n",
    "            first_top_head_set = set([tuple(t) for t in first_top_heads[:min_n_top_heads]])\n",
    "            second_top_head_set = set([tuple(t) for t in second_top_heads[:min_n_top_heads]])\n",
    "            shared_fractions = []\n",
    "\n",
    "            for n in range(min_n_top_heads, model_max_n_top_heads + 1):\n",
    "                first_top_head_set.add(tuple(first_top_heads[n - 1]))\n",
    "                second_top_head_set.add(tuple(second_top_heads[n - 1]))\n",
    "                shared = len(first_top_head_set & second_top_head_set)\n",
    "                shared_fraction = shared / n\n",
    "                shared_fractions.append(shared_fraction)\n",
    "\n",
    "            if moving_average_periods is not None:\n",
    "                shared_fractions = moving_average(np.array(shared_fractions), moving_average_periods)\n",
    "                missing = int(np.floor(moving_average_periods / 2))\n",
    "                x_values = x_values[missing:-missing]\n",
    "\n",
    "            ax.plot(x_values, shared_fractions, label=f\"{first_key} vs {second_key}\", alpha=0.5)\n",
    "            ax.tick_params(axis=\"x\", labelsize=fontsize - 4)\n",
    "            ax.tick_params(axis=\"y\", labelsize=fontsize - 4)\n",
    "            if len(comparison_groups) > 2:\n",
    "                ax.legend(fontsize=fontsize - 4)\n",
    "\n",
    "    if title is not None:\n",
    "        if moving_average_periods is not None:\n",
    "            title = f\"{title} ({moving_average_periods}-step moving average)\"\n",
    "        fig.suptitle(title, fontsize=fontsize + 8)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "scatter_plot_top_heads_overlap(\n",
    "    comparison_groups={\n",
    "        SHORT: dict(prompt_types=[SHORT]),\n",
    "        LONG: dict(prompt_types=[LONG]),\n",
    "    },\n",
    "    max_n_top_heads=0.25,\n",
    "    moving_average_periods=5,\n",
    "    title=\"Short vs. long prompts top head overlap\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scatter_plot_top_heads_overlap(\n",
    "    comparison_groups={b: dict(prompt_types=[SHORT, LONG], baselines=[b]) for b in BASELINES},\n",
    "    max_n_top_heads=0.25,\n",
    "    moving_average_periods=5,\n",
    "    title=\"Top head overlap by baseline\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scatter_plot_top_heads_overlap(\n",
    "    comparison_groups={\n",
    "        \"Prompts\": dict(prompt_types=[SHORT, LONG]),\n",
    "        ICL: dict(prompt_types=[ICL]),\n",
    "    },\n",
    "    max_n_top_heads=0.25,\n",
    "    moving_average_periods=5,\n",
    "    title=\"Prompts vs. ICL top head overlap\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
