{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import json\n",
    "import os\n",
    "import sys\n",
    "import typing\n",
    "from collections import defaultdict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import tabulate\n",
    "import torch\n",
    "from IPython.display import Markdown, display\n",
    "from loguru import logger\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "from shared_definitions import *\n",
    "\n",
    "sys.path.insert(0, os.path.abspath(\"..\"))\n",
    "\n",
    "sns.set_theme(style=\"white\", context=\"notebook\", rc={\"figure.figsize\": (14, 10)})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-13 23:21:11.572\u001b[0m | \u001b[1mINFO    \u001b[0m | \u001b[36mshared_definitions\u001b[0m:\u001b[36mload_and_combine_raw_results\u001b[0m:\u001b[36m416\u001b[0m - \u001b[1mLoading cached results from data/full_results.pkl.gz\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>dataset</th>\n",
       "      <th>top_n_prompt_acc</th>\n",
       "      <th>top_1_prompt_acc</th>\n",
       "      <th>top_n_prompts</th>\n",
       "      <th>0_shot_acc</th>\n",
       "      <th>1_shot_acc</th>\n",
       "      <th>2_shot_acc</th>\n",
       "      <th>3_shot_acc</th>\n",
       "      <th>4_shot_acc</th>\n",
       "      <th>...</th>\n",
       "      <th>zs_universal_10_heads_max_acc_layer</th>\n",
       "      <th>zs_universal_20_heads_by_layer_acc</th>\n",
       "      <th>zs_universal_20_heads_max_acc</th>\n",
       "      <th>zs_universal_20_heads_max_acc_layer</th>\n",
       "      <th>fs_shuffled_universal_40_heads_by_layer_acc</th>\n",
       "      <th>fs_shuffled_universal_40_heads_max_acc</th>\n",
       "      <th>fs_shuffled_universal_40_heads_max_acc_layer</th>\n",
       "      <th>zs_universal_40_heads_by_layer_acc</th>\n",
       "      <th>zs_universal_40_heads_max_acc</th>\n",
       "      <th>zs_universal_40_heads_max_acc_layer</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Llama-3.2-1B-Instruct</td>\n",
       "      <td>choose_first_of_3</td>\n",
       "      <td>0.767429</td>\n",
       "      <td>0.888571</td>\n",
       "      <td>[Extract the first token, Output the word at t...</td>\n",
       "      <td>0.636667</td>\n",
       "      <td>0.863333</td>\n",
       "      <td>0.896667</td>\n",
       "      <td>0.896667</td>\n",
       "      <td>0.913333</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Llama-3.2-1B-Instruct</td>\n",
       "      <td>capitalize_first_letter</td>\n",
       "      <td>0.936907</td>\n",
       "      <td>0.963093</td>\n",
       "      <td>[What is the initial letter of this word?, Det...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.688525</td>\n",
       "      <td>0.852459</td>\n",
       "      <td>0.918033</td>\n",
       "      <td>0.950820</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Llama-3.2-1B-Instruct</td>\n",
       "      <td>capitalize_second_letter</td>\n",
       "      <td>0.110364</td>\n",
       "      <td>0.160000</td>\n",
       "      <td>[What's the special letter in this word?, Dete...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Llama-3.2-1B-Instruct</td>\n",
       "      <td>english-french</td>\n",
       "      <td>0.663200</td>\n",
       "      <td>0.678224</td>\n",
       "      <td>[English word in French, English word to Frenc...</td>\n",
       "      <td>0.011348</td>\n",
       "      <td>0.352482</td>\n",
       "      <td>0.587234</td>\n",
       "      <td>0.662411</td>\n",
       "      <td>0.687943</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Llama-3.2-1B-Instruct</td>\n",
       "      <td>choose_first_of_5</td>\n",
       "      <td>0.737286</td>\n",
       "      <td>0.772857</td>\n",
       "      <td>[Find the starting word of the input sequence,...</td>\n",
       "      <td>0.616667</td>\n",
       "      <td>0.810000</td>\n",
       "      <td>0.870000</td>\n",
       "      <td>0.890000</td>\n",
       "      <td>0.940000</td>\n",
       "      <td>...</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 256 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                   model                   dataset  top_n_prompt_acc  \\\n",
       "0  Llama-3.2-1B-Instruct         choose_first_of_3          0.767429   \n",
       "1  Llama-3.2-1B-Instruct   capitalize_first_letter          0.936907   \n",
       "2  Llama-3.2-1B-Instruct  capitalize_second_letter          0.110364   \n",
       "3  Llama-3.2-1B-Instruct            english-french          0.663200   \n",
       "4  Llama-3.2-1B-Instruct         choose_first_of_5          0.737286   \n",
       "\n",
       "   top_1_prompt_acc                                      top_n_prompts  \\\n",
       "0          0.888571  [Extract the first token, Output the word at t...   \n",
       "1          0.963093  [What is the initial letter of this word?, Det...   \n",
       "2          0.160000  [What's the special letter in this word?, Dete...   \n",
       "3          0.678224  [English word in French, English word to Frenc...   \n",
       "4          0.772857  [Find the starting word of the input sequence,...   \n",
       "\n",
       "   0_shot_acc  1_shot_acc  2_shot_acc  3_shot_acc  4_shot_acc  ...  \\\n",
       "0    0.636667    0.863333    0.896667    0.896667    0.913333  ...   \n",
       "1    0.000000    0.688525    0.852459    0.918033    0.950820  ...   \n",
       "2         NaN         NaN         NaN         NaN         NaN  ...   \n",
       "3    0.011348    0.352482    0.587234    0.662411    0.687943  ...   \n",
       "4    0.616667    0.810000    0.870000    0.890000    0.940000  ...   \n",
       "\n",
       "   zs_universal_10_heads_max_acc_layer  zs_universal_20_heads_by_layer_acc  \\\n",
       "0                                  NaN                                 NaN   \n",
       "1                                  NaN                                 NaN   \n",
       "2                                  NaN                                 NaN   \n",
       "3                                  NaN                                 NaN   \n",
       "4                                  NaN                                 NaN   \n",
       "\n",
       "   zs_universal_20_heads_max_acc  zs_universal_20_heads_max_acc_layer  \\\n",
       "0                            NaN                                  NaN   \n",
       "1                            NaN                                  NaN   \n",
       "2                            NaN                                  NaN   \n",
       "3                            NaN                                  NaN   \n",
       "4                            NaN                                  NaN   \n",
       "\n",
       "   fs_shuffled_universal_40_heads_by_layer_acc  \\\n",
       "0                                          NaN   \n",
       "1                                          NaN   \n",
       "2                                          NaN   \n",
       "3                                          NaN   \n",
       "4                                          NaN   \n",
       "\n",
       "   fs_shuffled_universal_40_heads_max_acc  \\\n",
       "0                                     NaN   \n",
       "1                                     NaN   \n",
       "2                                     NaN   \n",
       "3                                     NaN   \n",
       "4                                     NaN   \n",
       "\n",
       "   fs_shuffled_universal_40_heads_max_acc_layer  \\\n",
       "0                                           NaN   \n",
       "1                                           NaN   \n",
       "2                                           NaN   \n",
       "3                                           NaN   \n",
       "4                                           NaN   \n",
       "\n",
       "  zs_universal_40_heads_by_layer_acc  zs_universal_40_heads_max_acc  \\\n",
       "0                                NaN                            NaN   \n",
       "1                                NaN                            NaN   \n",
       "2                                NaN                            NaN   \n",
       "3                                NaN                            NaN   \n",
       "4                                NaN                            NaN   \n",
       "\n",
       "   zs_universal_40_heads_max_acc_layer  \n",
       "0                                  NaN  \n",
       "1                                  NaN  \n",
       "2                                  NaN  \n",
       "3                                  NaN  \n",
       "4                                  NaN  \n",
       "\n",
       "[5 rows x 256 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "USE_SAME_TEST_SETS_RESULTS = True\n",
    "if USE_SAME_TEST_SETS_RESULTS:\n",
    "    alternative_paths = {ICL: ICL_SAME_TEST_SETS_RESULTS_ROOT}\n",
    "else:\n",
    "    alternative_paths = None\n",
    "\n",
    "result_df, indirect_effects_by_model_and_dataset, top_heads_by_model_and_dataset = load_and_combine_raw_results(\n",
    "    alternative_paths\n",
    ")\n",
    "result_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "650987dbf894474a8fcac186773283ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Universal top heads:   0%|          | 0/504 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-13 23:22:38.667\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-1B', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.668\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-1B', 'long') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.668\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-1B-Instruct', 'short') datasets below chance accuracy: lowercase_last_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.669\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-1B-Instruct', 'long') datasets below chance accuracy: lowercase_last_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.669\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-3B', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.670\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-3B', 'long') datasets below chance accuracy: alphabetically_last_5, capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.670\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-3B-Instruct', 'short') datasets below chance accuracy: lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.671\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-3B-Instruct', 'long') datasets below chance accuracy: lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.671\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.1-8B', 'short') datasets below chance accuracy: alphabetically_last_5, capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.672\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.1-8B', 'long') datasets below chance accuracy: alphabetically_last_5, capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.672\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.1-8B-Instruct', 'short') datasets below chance accuracy: lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.673\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.1-8B-Instruct', 'long') datasets below chance accuracy: alphabetically_last_5, lowercase_last_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.673\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-7b-hf', 'short') datasets below chance accuracy: alphabetically_last_5, capitalize_last_letter, lowercase_first_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.674\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-7b-hf', 'long') datasets below chance accuracy: alphabetically_last_5, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.674\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-7b-chat-hf', 'short') datasets below chance accuracy: lowercase_last_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.675\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-13b-hf', 'short') datasets below chance accuracy: alphabetically_last_3, capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.675\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-13b-hf', 'long') datasets below chance accuracy: alphabetically_last_3, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.676\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-13b-chat-hf', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.676\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-13b-chat-hf', 'long') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, lowercase_last_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.677\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.677\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B', 'long') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.678\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-SFT', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.678\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-SFT', 'long') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.679\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-DPO', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.679\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-DPO', 'long') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.680\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-Instruct', 'short') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:38.680\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m47\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-Instruct', 'long') datasets below chance accuracy: capitalize_last_letter, capitalize_second_letter, next_capital_letter\u001b[0m\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "226c4f44fc5d4f25aaaf8df8c3a0ac17",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Universal top heads ICL:   0%|          | 0/42 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[32m2025-05-13 23:22:44.293\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-1B', 'icl') datasets below chance accuracy: alphabetically_last_3, alphabetically_last_5, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.294\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-1B-Instruct', 'icl') datasets below chance accuracy: alphabetically_last_3\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.295\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-3B', 'icl') datasets below chance accuracy: next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.295\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.2-3B-Instruct', 'icl') datasets below chance accuracy: alphabetically_last_3\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.296\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-3.1-8B', 'icl') datasets below chance accuracy: alphabetically_last_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.296\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-7b-hf', 'icl') datasets below chance accuracy: alphabetically_last_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.297\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-7b-chat-hf', 'icl') datasets below chance accuracy: alphabetically_first_3, alphabetically_last_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.297\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-13b-hf', 'icl') datasets below chance accuracy: alphabetically_first_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.298\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('Llama-2-13b-chat-hf', 'icl') datasets below chance accuracy: alphabetically_first_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.298\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B', 'icl') datasets below chance accuracy: alphabetically_first_3, alphabetically_first_5, alphabetically_last_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.299\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-SFT', 'icl') datasets below chance accuracy: alphabetically_first_3, alphabetically_last_3, alphabetically_last_5, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.299\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-DPO', 'icl') datasets below chance accuracy: alphabetically_first_3, alphabetically_last_3, next_capital_letter\u001b[0m\n",
      "\u001b[32m2025-05-13 23:22:44.300\u001b[0m | \u001b[33m\u001b[1mWARNING \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36m<module>\u001b[0m:\u001b[36m84\u001b[0m - \u001b[33m\u001b[1mModel ('OLMo-2-1124-7B-Instruct', 'icl') datasets below chance accuracy: alphabetically_first_3, alphabetically_last_3, next_capital_letter\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>prompt_type</th>\n",
       "      <th>baseline</th>\n",
       "      <th>n</th>\n",
       "      <th>top_heads</th>\n",
       "      <th>top_head_effects</th>\n",
       "      <th>top_heads_list</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Llama-3.2-1B</td>\n",
       "      <td>short</td>\n",
       "      <td>equiprobable</td>\n",
       "      <td>10</td>\n",
       "      <td>{(15, 11), (12, 15), (8, 29), (11, 4), (7, 2),...</td>\n",
       "      <td>[0.011242678388953209, 0.008070942014455795, 0...</td>\n",
       "      <td>[[12, 15], [15, 12], [8, 12], [15, 2], [10, 22...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Llama-3.2-1B</td>\n",
       "      <td>short</td>\n",
       "      <td>equiprobable</td>\n",
       "      <td>20</td>\n",
       "      <td>{(14, 4), (7, 26), (14, 25), (9, 14), (8, 12),...</td>\n",
       "      <td>[0.011242678388953209, 0.008070942014455795, 0...</td>\n",
       "      <td>[[12, 15], [15, 12], [8, 12], [15, 2], [10, 22...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Llama-3.2-1B</td>\n",
       "      <td>short</td>\n",
       "      <td>equiprobable</td>\n",
       "      <td>100</td>\n",
       "      <td>{(7, 17), (7, 26), (5, 1), (8, 9), (8, 18), (1...</td>\n",
       "      <td>[0.011242678388953209, 0.008070942014455795, 0...</td>\n",
       "      <td>[[12, 15], [15, 12], [8, 12], [15, 2], [10, 22...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Llama-3.2-1B</td>\n",
       "      <td>short</td>\n",
       "      <td>real_text</td>\n",
       "      <td>10</td>\n",
       "      <td>{(15, 11), (9, 13), (8, 29), (8, 12), (11, 4),...</td>\n",
       "      <td>[0.01034275908023119, 0.007226495537906885, 0....</td>\n",
       "      <td>[[12, 15], [15, 12], [10, 22], [15, 2], [9, 7]...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Llama-3.2-1B</td>\n",
       "      <td>short</td>\n",
       "      <td>real_text</td>\n",
       "      <td>20</td>\n",
       "      <td>{(14, 4), (14, 25), (8, 12), (10, 9), (15, 2),...</td>\n",
       "      <td>[0.01034275908023119, 0.007226495537906885, 0....</td>\n",
       "      <td>[[12, 15], [15, 12], [10, 22], [15, 2], [9, 7]...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>541</th>\n",
       "      <td>OLMo-2-1124-7B-DPO</td>\n",
       "      <td>icl</td>\n",
       "      <td>icl</td>\n",
       "      <td>20</td>\n",
       "      <td>{(9, 2), (12, 25), (12, 31), (19, 18), (17, 18...</td>\n",
       "      <td>[0.10074028372764587, 0.09945445507764816, 0.0...</td>\n",
       "      <td>[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>542</th>\n",
       "      <td>OLMo-2-1124-7B-DPO</td>\n",
       "      <td>icl</td>\n",
       "      <td>icl</td>\n",
       "      <td>100</td>\n",
       "      <td>{(16, 20), (15, 30), (16, 29), (22, 17), (31, ...</td>\n",
       "      <td>[0.10074028372764587, 0.09945445507764816, 0.0...</td>\n",
       "      <td>[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>543</th>\n",
       "      <td>OLMo-2-1124-7B-Instruct</td>\n",
       "      <td>icl</td>\n",
       "      <td>icl</td>\n",
       "      <td>10</td>\n",
       "      <td>{(15, 11), (12, 24), (20, 1), (15, 0), (10, 12...</td>\n",
       "      <td>[0.1116589605808258, 0.09931758791208267, 0.06...</td>\n",
       "      <td>[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>544</th>\n",
       "      <td>OLMo-2-1124-7B-Instruct</td>\n",
       "      <td>icl</td>\n",
       "      <td>icl</td>\n",
       "      <td>20</td>\n",
       "      <td>{(12, 31), (19, 18), (17, 18), (10, 12), (15, ...</td>\n",
       "      <td>[0.1116589605808258, 0.09931758791208267, 0.06...</td>\n",
       "      <td>[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>545</th>\n",
       "      <td>OLMo-2-1124-7B-Instruct</td>\n",
       "      <td>icl</td>\n",
       "      <td>icl</td>\n",
       "      <td>100</td>\n",
       "      <td>{(16, 20), (15, 30), (16, 29), (31, 29), (30, ...</td>\n",
       "      <td>[0.1116589605808258, 0.09931758791208267, 0.06...</td>\n",
       "      <td>[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>546 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                       model prompt_type      baseline    n  \\\n",
       "0               Llama-3.2-1B       short  equiprobable   10   \n",
       "1               Llama-3.2-1B       short  equiprobable   20   \n",
       "2               Llama-3.2-1B       short  equiprobable  100   \n",
       "3               Llama-3.2-1B       short     real_text   10   \n",
       "4               Llama-3.2-1B       short     real_text   20   \n",
       "..                       ...         ...           ...  ...   \n",
       "541       OLMo-2-1124-7B-DPO         icl           icl   20   \n",
       "542       OLMo-2-1124-7B-DPO         icl           icl  100   \n",
       "543  OLMo-2-1124-7B-Instruct         icl           icl   10   \n",
       "544  OLMo-2-1124-7B-Instruct         icl           icl   20   \n",
       "545  OLMo-2-1124-7B-Instruct         icl           icl  100   \n",
       "\n",
       "                                             top_heads  \\\n",
       "0    {(15, 11), (12, 15), (8, 29), (11, 4), (7, 2),...   \n",
       "1    {(14, 4), (7, 26), (14, 25), (9, 14), (8, 12),...   \n",
       "2    {(7, 17), (7, 26), (5, 1), (8, 9), (8, 18), (1...   \n",
       "3    {(15, 11), (9, 13), (8, 29), (8, 12), (11, 4),...   \n",
       "4    {(14, 4), (14, 25), (8, 12), (10, 9), (15, 2),...   \n",
       "..                                                 ...   \n",
       "541  {(9, 2), (12, 25), (12, 31), (19, 18), (17, 18...   \n",
       "542  {(16, 20), (15, 30), (16, 29), (22, 17), (31, ...   \n",
       "543  {(15, 11), (12, 24), (20, 1), (15, 0), (10, 12...   \n",
       "544  {(12, 31), (19, 18), (17, 18), (10, 12), (15, ...   \n",
       "545  {(16, 20), (15, 30), (16, 29), (31, 29), (30, ...   \n",
       "\n",
       "                                      top_head_effects  \\\n",
       "0    [0.011242678388953209, 0.008070942014455795, 0...   \n",
       "1    [0.011242678388953209, 0.008070942014455795, 0...   \n",
       "2    [0.011242678388953209, 0.008070942014455795, 0...   \n",
       "3    [0.01034275908023119, 0.007226495537906885, 0....   \n",
       "4    [0.01034275908023119, 0.007226495537906885, 0....   \n",
       "..                                                 ...   \n",
       "541  [0.10074028372764587, 0.09945445507764816, 0.0...   \n",
       "542  [0.10074028372764587, 0.09945445507764816, 0.0...   \n",
       "543  [0.1116589605808258, 0.09931758791208267, 0.06...   \n",
       "544  [0.1116589605808258, 0.09931758791208267, 0.06...   \n",
       "545  [0.1116589605808258, 0.09931758791208267, 0.06...   \n",
       "\n",
       "                                        top_heads_list  \n",
       "0    [[12, 15], [15, 12], [8, 12], [15, 2], [10, 22...  \n",
       "1    [[12, 15], [15, 12], [8, 12], [15, 2], [10, 22...  \n",
       "2    [[12, 15], [15, 12], [8, 12], [15, 2], [10, 22...  \n",
       "3    [[12, 15], [15, 12], [10, 22], [15, 2], [9, 7]...  \n",
       "4    [[12, 15], [15, 12], [10, 22], [15, 2], [9, 7]...  \n",
       "..                                                 ...  \n",
       "541  [[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...  \n",
       "542  [[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...  \n",
       "543  [[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...  \n",
       "544  [[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...  \n",
       "545  [[15, 0], [15, 11], [17, 18], [15, 18], [10, 1...  \n",
       "\n",
       "[546 rows x 7 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from functools import reduce\n",
    "\n",
    "N_TOP_HEADS_VALUES = (10, 20, 100)\n",
    "\n",
    "\n",
    "top_head_rows = []\n",
    "mean_ie_by_model = {}\n",
    "\n",
    "tqdm_total = len(ORDERED_MODELS) * 3 * (len(RELEVANT_BASELINES) + 1) * len(N_TOP_HEADS_VALUES)\n",
    "prompt_datasets_below_chance_acc = defaultdict(set)\n",
    "\n",
    "for model, prompt_types, baseline, n_top_heads in tqdm(\n",
    "    itertools.product(\n",
    "        ORDERED_MODELS, [SHORT, LONG, [SHORT, LONG]], RELEVANT_BASELINES + [RELEVANT_BASELINES], N_TOP_HEADS_VALUES\n",
    "    ),\n",
    "    total=tqdm_total,\n",
    "    desc=\"Universal top heads\",\n",
    "):\n",
    "    top_heads, top_head_effects, mean_ie = compute_top_heads(\n",
    "        result_df,\n",
    "        indirect_effects_by_model_and_dataset,\n",
    "        model,\n",
    "        prompt_types,\n",
    "        baseline,\n",
    "        n_top_heads=n_top_heads,\n",
    "        datasets_below_chance_acc=prompt_datasets_below_chance_acc,\n",
    "        return_mean=True,\n",
    "    )\n",
    "    pt = prompt_types if isinstance(prompt_types, str) == 1 else BOTH\n",
    "    bl = baseline if isinstance(baseline, str) else ALL\n",
    "    top_head_rows.append(\n",
    "        dict(\n",
    "            model=model,\n",
    "            prompt_type=pt,\n",
    "            baseline=bl,\n",
    "            n=n_top_heads,\n",
    "            top_heads=set(tuple(t) for t in top_heads),\n",
    "            top_head_effects=top_head_effects,\n",
    "            top_heads_list=top_heads,\n",
    "        )\n",
    "    )\n",
    "\n",
    "    mean_ie_by_model[(model, pt, bl)] = mean_ie\n",
    "\n",
    "for model, datasets in prompt_datasets_below_chance_acc.items():\n",
    "    if len(datasets) > 0:\n",
    "        logger.warning(f\"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}\")\n",
    "\n",
    "\n",
    "icl_datasets_below_chance_acc = defaultdict(set)\n",
    "\n",
    "for model, n_top_heads in tqdm(\n",
    "    itertools.product(ORDERED_MODELS, N_TOP_HEADS_VALUES),\n",
    "    total=len(ORDERED_MODELS) * len(N_TOP_HEADS_VALUES),\n",
    "    desc=\"Universal top heads ICL\",\n",
    "):\n",
    "    prompt_type = ICL\n",
    "    baseline = ICL\n",
    "    top_heads, top_head_effects, mean_ie = compute_top_heads(\n",
    "        result_df,\n",
    "        indirect_effects_by_model_and_dataset,\n",
    "        model,\n",
    "        prompt_type,\n",
    "        baseline,\n",
    "        n_top_heads=n_top_heads,\n",
    "        datasets_below_chance_acc=icl_datasets_below_chance_acc,\n",
    "        return_mean=True,\n",
    "    )\n",
    "    top_head_rows.append(\n",
    "        dict(\n",
    "            model=model,\n",
    "            prompt_type=prompt_type,\n",
    "            baseline=baseline,\n",
    "            n=n_top_heads,\n",
    "            top_heads=set(tuple(t) for t in top_heads),\n",
    "            top_head_effects=top_head_effects,\n",
    "            top_heads_list=top_heads,\n",
    "        )\n",
    "    )\n",
    "    mean_ie_by_model[(model, prompt_type, baseline)] = mean_ie\n",
    "\n",
    "for model, datasets in icl_datasets_below_chance_acc.items():\n",
    "    if len(datasets) > 0:\n",
    "        logger.warning(f\"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}\")\n",
    "\n",
    "\n",
    "universal_top_heads_df = pd.DataFrame(top_head_rows)\n",
    "\n",
    "split_universal_top_heads_dfs = {\n",
    "    (model, n): universal_top_heads_df[(universal_top_heads_df.model == model) & (universal_top_heads_df.n == n)]\n",
    "    .copy(deep=True)\n",
    "    .reset_index(drop=True)\n",
    "    for model in ORDERED_MODELS\n",
    "    for n in N_TOP_HEADS_VALUES\n",
    "}\n",
    "\n",
    "for model, df in split_universal_top_heads_dfs.items():\n",
    "    th_to_idx = {th: i for i, th in enumerate(sorted(reduce(lambda th1, th2: th1 | th2, df.top_heads, set())))}\n",
    "\n",
    "    def ths_to_binary_vector(ths):\n",
    "        v = np.zeros(len(th_to_idx), dtype=np.ubyte)\n",
    "        if len(ths) == 0:\n",
    "            return v\n",
    "        idxs = np.array([th_to_idx[th] for th in ths])\n",
    "        v[idxs] = 1\n",
    "        return v\n",
    "\n",
    "    df[\"top_heads_vector\"] = df.top_heads.apply(ths_to_binary_vector)\n",
    "\n",
    "universal_top_heads_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TOP_HEADS_DIR = f\"{STORAGE_ROOT}/function_vectors/full_results_top_heads\"\n",
    "\n",
    "\n",
    "def write_top_heads_to_files(\n",
    "    df: pd.DataFrame,\n",
    "    prompt_type: str = \"both\",\n",
    "    baseline: str = \"all\",\n",
    "    n_top_heads: int = 100,\n",
    "    model: str | None = None,\n",
    "    output_dir: str = TOP_HEADS_DIR,\n",
    "    overwrite: bool = False,\n",
    "):\n",
    "    if prompt_type == \"ICL\" and baseline != \"icl\":\n",
    "        raise ValueError(\"For ICL, baseline must be 'icl'\")\n",
    "\n",
    "    filtered_df = df.copy(deep=True)\n",
    "    if model is not None:\n",
    "        filtered_df = filtered_df[filtered_df.model.str.contains(model)]\n",
    "\n",
    "    filtered_df = filtered_df[\n",
    "        (filtered_df.prompt_type == prompt_type) & (filtered_df.baseline == baseline) & (filtered_df.n == n_top_heads)\n",
    "    ].reset_index(drop=True)\n",
    "\n",
    "    for i, row in filtered_df.iterrows():\n",
    "        _save_single_top_head_set(\n",
    "            row.top_heads_list,\n",
    "            row.top_head_effects,\n",
    "            row.model,\n",
    "            row.prompt_type, \n",
    "            row.baseline,\n",
    "            output_dir=output_dir,\n",
    "            overwrite=overwrite,\n",
    "        )\n",
    "\n",
    "\n",
    "def _save_single_top_head_set(\n",
    "    model_top_heads: typing.List[typing.Tuple[int, int]],\n",
    "    model_top_head_effects: typing.List[float],\n",
    "    model: str,\n",
    "    prompt_type: str,\n",
    "    baseline: str,\n",
    "    head_type: str = \"top\",\n",
    "    output_dir: str = TOP_HEADS_DIR,\n",
    "    overwrite: bool = False,\n",
    "):\n",
    "    if baseline == ICL:\n",
    "        filename = (\n",
    "            f\"{model}_{baseline}_{'same_test_sets_' if USE_SAME_TEST_SETS_RESULTS else ''}{head_type}_heads.json\"\n",
    "        )\n",
    "    else:\n",
    "        filename = f\"{model}_{prompt_type}_{baseline}_{head_type}_heads.json\"\n",
    "\n",
    "    filepath = os.path.join(output_dir, filename)\n",
    "\n",
    "    if os.path.exists(filepath) and not overwrite:\n",
    "        logger.warning(f\"{filepath} already exists, skipping...\")\n",
    "        return\n",
    "\n",
    "    with open(filepath, \"w\") as f:\n",
    "        output = dict(\n",
    "            top_heads=model_top_heads,\n",
    "            top_head_effects=model_top_head_effects,\n",
    "        )\n",
    "        json.dump(output, f)\n",
    "\n",
    "    logger.info(f\"Saved {head_type} heads to {filepath}\")\n",
    "\n",
    "\n",
    "write_top_heads_to_files(universal_top_heads_df, overwrite=False)\n",
    "write_top_heads_to_files(universal_top_heads_df, prompt_type=ICL, baseline=ICL, overwrite=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "| model                   | head_type   |   n_top_n |   n_shared | n_shared_perc   |   mean_same_ie |   mean_other_ie |   same_other_ratio |\n",
       "|-------------------------|-------------|-----------|------------|-----------------|----------------|-----------------|--------------------|\n",
       "| Llama-3.2-1B            | prompt      |        10 |          6 | 60%             |     0.00670754 |      0.00265102 |           2.53017  |\n",
       "| Llama-3.2-1B            | icl         |        10 |          2 | 20%             |     0.0146377  |      0.00438629 |           3.33715  |\n",
       "| Llama-3.2-1B            | prompt      |        20 |         14 | 70%             |     0.00462739 |      0.00363957 |           1.27141  |\n",
       "| Llama-3.2-1B            | icl         |        20 |          9 | 45%             |     0.00665989 |      0.00230238 |           2.89261  |\n",
       "| \u0001 |\n",
       "| Llama-3.2-1B-Instruct   | prompt      |        10 |          8 | 80%             |     0.00772535 |      0.00592654 |           1.30352  |\n",
       "| Llama-3.2-1B-Instruct   | icl         |        10 |          2 | 20%             |     0.0169697  |      0.00693619 |           2.44655  |\n",
       "| Llama-3.2-1B-Instruct   | prompt      |        20 |         15 | 75%             |     0.00596287 |      0.00459147 |           1.29868  |\n",
       "| Llama-3.2-1B-Instruct   | icl         |        20 |          8 | 40%             |     0.00816929 |      0.00439157 |           1.86022  |\n",
       "| \u0001 |\n",
       "| Llama-3.2-3B            | prompt      |        10 |          8 | 80%             |     0.00463581 |      0.00374096 |           1.2392   |\n",
       "| Llama-3.2-3B            | icl         |        10 |          6 | 60%             |     0.0288357  |      0.00315726 |           9.13316  |\n",
       "| Llama-3.2-3B            | prompt      |        20 |         15 | 75%             |     0.0037405  |      0.00403034 |           0.928085 |\n",
       "| Llama-3.2-3B            | icl         |        20 |         13 | 65%             |     0.0152964  |      0.00273862 |           5.58544  |\n",
       "| \u0001 |\n",
       "| Llama-3.2-3B-Instruct   | prompt      |        10 |          8 | 80%             |     0.00972258 |      0.0076092  |           1.27774  |\n",
       "| Llama-3.2-3B-Instruct   | icl         |        10 |          5 | 50%             |     0.0293172  |      0.005022   |           5.83776  |\n",
       "| Llama-3.2-3B-Instruct   | prompt      |        20 |         14 | 70%             |     0.00774945 |      0.00635416 |           1.21959  |\n",
       "| Llama-3.2-3B-Instruct   | icl         |        20 |         10 | 50%             |     0.0177729  |      0.00478754 |           3.71232  |\n",
       "| \u0001 |\n",
       "| Llama-3.1-8B            | prompt      |        10 |          3 | 30%             |     0.0041534  |      0.00966502 |           0.429735 |\n",
       "| Llama-3.1-8B            | icl         |        10 |          3 | 30%             |     0.0144727  |      0.00227791 |           6.35349  |\n",
       "| Llama-3.1-8B            | prompt      |        20 |         12 | 60%             |     0.00272862 |      0.0053503  |           0.509993 |\n",
       "| Llama-3.1-8B            | icl         |        20 |          8 | 40%             |     0.00874924 |      0.0022177  |           3.94519  |\n",
       "| \u0001 |\n",
       "| Llama-3.1-8B-Instruct   | prompt      |        10 |          6 | 60%             |     0.00868464 |      0.0022671  |           3.83073  |\n",
       "| Llama-3.1-8B-Instruct   | icl         |        10 |          4 | 40%             |     0.0348586  |      0.00325419 |          10.7119   |\n",
       "| Llama-3.1-8B-Instruct   | prompt      |        20 |         14 | 70%             |     0.00659002 |      0.00549487 |           1.1993   |\n",
       "| Llama-3.1-8B-Instruct   | icl         |        20 |          8 | 40%             |     0.0203032  |      0.00388063 |           5.23192  |\n",
       "| \u0001 |\n",
       "| Llama-2-7b-hf           | prompt      |        10 |          6 | 60%             |     0.00400909 |      0.00374708 |           1.06992  |\n",
       "| Llama-2-7b-hf           | icl         |        10 |          2 | 20%             |     0.00920414 |      0.00390423 |           2.35748  |\n",
       "| Llama-2-7b-hf           | prompt      |        20 |         13 | 65%             |     0.00278093 |      0.00254216 |           1.09392  |\n",
       "| Llama-2-7b-hf           | icl         |        20 |          9 | 45%             |     0.00453616 |      0.00151827 |           2.98771  |\n",
       "| \u0001 |\n",
       "| Llama-2-7b-chat-hf      | prompt      |        10 |          7 | 70%             |     0.00974374 |      0.00566441 |           1.72017  |\n",
       "| Llama-2-7b-chat-hf      | icl         |        10 |          2 | 20%             |     0.0158251  |      0.00930756 |           1.70024  |\n",
       "| Llama-2-7b-chat-hf      | prompt      |        20 |         14 | 70%             |     0.00687839 |      0.00491496 |           1.39948  |\n",
       "| Llama-2-7b-chat-hf      | icl         |        20 |          6 | 30%             |     0.0115129  |      0.00465885 |           2.4712   |\n",
       "| \u0001 |\n",
       "| Llama-2-13b-hf          | prompt      |        10 |          8 | 80%             |     0.00294895 |      0.00254961 |           1.15662  |\n",
       "| Llama-2-13b-hf          | icl         |        10 |          2 | 20%             |     0.00927796 |      0.00386601 |           2.39988  |\n",
       "| Llama-2-13b-hf          | prompt      |        20 |         13 | 65%             |     0.0023589  |      0.00220807 |           1.06831  |\n",
       "| Llama-2-13b-hf          | icl         |        20 |          6 | 30%             |     0.00597015 |      0.0019831  |           3.01052  |\n",
       "| \u0001 |\n",
       "| Llama-2-13b-chat-hf     | prompt      |        10 |          6 | 60%             |     0.00997004 |      0.0111935  |           0.890699 |\n",
       "| Llama-2-13b-chat-hf     | icl         |        10 |          1 | 10%             |     0.0246561  |      0.0203341  |           1.21255  |\n",
       "| Llama-2-13b-chat-hf     | prompt      |        20 |          9 | 45%             |     0.00803173 |      0.00882307 |           0.910311 |\n",
       "| Llama-2-13b-chat-hf     | icl         |        20 |          5 | 25%             |     0.014739   |      0.00826954 |           1.78232  |\n",
       "| \u0001 |\n",
       "| OLMo-2-1124-7B          | prompt      |        10 |          9 | 90%             |     0.00897255 |      0.0067609  |           1.32712  |\n",
       "| OLMo-2-1124-7B          | icl         |        10 |          5 | 50%             |     0.0161641  |      0.00592047 |           2.7302   |\n",
       "| OLMo-2-1124-7B          | prompt      |        20 |         18 | 90%             |     0.0069243  |      0.00454355 |           1.52398  |\n",
       "| OLMo-2-1124-7B          | icl         |        20 |         12 | 60%             |     0.00928662 |      0.00614353 |           1.51161  |\n",
       "| \u0001 |\n",
       "| OLMo-2-1124-7B-SFT      | prompt      |        10 |          8 | 80%             |     0.00864657 |      0.00812169 |           1.06463  |\n",
       "| OLMo-2-1124-7B-SFT      | icl         |        10 |          6 | 60%             |     0.0289569  |      0.00434066 |           6.6711   |\n",
       "| OLMo-2-1124-7B-SFT      | prompt      |        20 |         16 | 80%             |     0.00677069 |      0.00667341 |           1.01458  |\n",
       "| OLMo-2-1124-7B-SFT      | icl         |        20 |         11 | 55%             |     0.0183659  |      0.00531998 |           3.45225  |\n",
       "| \u0001 |\n",
       "| OLMo-2-1124-7B-DPO      | prompt      |        10 |          6 | 60%             |     0.0137725  |      0.00599133 |           2.29873  |\n",
       "| OLMo-2-1124-7B-DPO      | icl         |        10 |          5 | 50%             |     0.0298255  |      0.00529881 |           5.62873  |\n",
       "| OLMo-2-1124-7B-DPO      | prompt      |        20 |         12 | 60%             |     0.010389   |      0.010331   |           1.00561  |\n",
       "| OLMo-2-1124-7B-DPO      | icl         |        20 |         12 | 60%             |     0.0177274  |      0.00727804 |           2.43574  |\n",
       "| \u0001 |\n",
       "| OLMo-2-1124-7B-Instruct | prompt      |        10 |          6 | 60%             |     0.0145396  |      0.00806555 |           1.80268  |\n",
       "| OLMo-2-1124-7B-Instruct | icl         |        10 |          5 | 50%             |     0.0332614  |      0.00553057 |           6.0141   |\n",
       "| OLMo-2-1124-7B-Instruct | prompt      |        20 |         12 | 60%             |     0.010936   |      0.0152142  |           0.718798 |\n",
       "| OLMo-2-1124-7B-Instruct | icl         |        20 |         12 | 60%             |     0.019279   |      0.00774668 |           2.48868  |\n",
       "| \u0001 |"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "table_rows = []\n",
    "\n",
    "for model in ORDERED_MODELS:\n",
    "    prompt_top_100 = universal_top_heads_df[\n",
    "        (universal_top_heads_df.model == model)\n",
    "        & (universal_top_heads_df.prompt_type == BOTH)\n",
    "        & (universal_top_heads_df.baseline == ALL)\n",
    "        & (universal_top_heads_df.n == 100)\n",
    "    ].top_heads.values[0]\n",
    "    prompt_ie = mean_ie_by_model[(model, BOTH, ALL)]\n",
    "\n",
    "    icl_top_100 = universal_top_heads_df[\n",
    "        (universal_top_heads_df.model == model)\n",
    "        & (universal_top_heads_df.prompt_type == ICL)\n",
    "        & (universal_top_heads_df.baseline == ICL)\n",
    "        & (universal_top_heads_df.n == 100)\n",
    "    ].top_heads.values[0]\n",
    "    icl_ie = mean_ie_by_model[(model, ICL, ICL)]\n",
    "\n",
    "    for top_n in (10, 20):\n",
    "        prompt_top_n = universal_top_heads_df[\n",
    "            (universal_top_heads_df.model == model)\n",
    "            & (universal_top_heads_df.prompt_type == BOTH)\n",
    "            & (universal_top_heads_df.baseline == ALL)\n",
    "            & (universal_top_heads_df.n == top_n)\n",
    "        ].top_heads.values[0]\n",
    "        prompt_shared = prompt_top_n & icl_top_100\n",
    "        n_prompt_shared = len(prompt_shared)\n",
    "        prompt_heads_mean_prompt_ie = np.mean([prompt_ie[th] for th in prompt_shared])\n",
    "        prompt_heads_mean_icl_ie = np.mean([icl_ie[th] for th in prompt_shared])\n",
    "        table_rows.append(\n",
    "            dict(\n",
    "                model=model,\n",
    "                head_type=\"prompt\",\n",
    "                n_top_n=top_n,\n",
    "                n_shared=n_prompt_shared,\n",
    "                n_shared_perc=f\"{n_prompt_shared / top_n:.0%}\",\n",
    "                mean_same_ie=prompt_heads_mean_prompt_ie,\n",
    "                mean_other_ie=prompt_heads_mean_icl_ie,\n",
    "                same_other_ratio=prompt_heads_mean_prompt_ie / prompt_heads_mean_icl_ie,\n",
    "            )\n",
    "        )\n",
    "\n",
    "        icl_top_n = universal_top_heads_df[\n",
    "            (universal_top_heads_df.model == model)\n",
    "            & (universal_top_heads_df.prompt_type == ICL)\n",
    "            & (universal_top_heads_df.baseline == ICL)\n",
    "            & (universal_top_heads_df.n == top_n)\n",
    "        ].top_heads.values[0]\n",
    "        icl_shared = icl_top_n & prompt_top_100\n",
    "        n_icl_shared = len(icl_shared)\n",
    "        icl_heads_mean_icl_ie = np.mean([icl_ie[th] for th in icl_shared])\n",
    "        icl_heads_mean_prompt_ie = np.mean([prompt_ie[th] for th in icl_shared])\n",
    "        table_rows.append(\n",
    "            dict(\n",
    "                model=model,\n",
    "                head_type=\"icl\",\n",
    "                n_top_n=top_n,\n",
    "                n_shared=n_icl_shared,\n",
    "                n_shared_perc=f\"{n_icl_shared / top_n:.0%}\",\n",
    "                mean_same_ie=icl_heads_mean_icl_ie,\n",
    "                mean_other_ie=icl_heads_mean_prompt_ie,\n",
    "                same_other_ratio=icl_heads_mean_icl_ie / icl_heads_mean_prompt_ie,\n",
    "            )\n",
    "        )\n",
    "\n",
    "    table_rows.append(tabulate.SEPARATING_LINE)\n",
    "\n",
    "\n",
    "headers = table_rows[0].keys()\n",
    "print_rows = [[row[key] for key in headers] if isinstance(row, dict) else row for row in table_rows]\n",
    "display(Markdown(tabulate.tabulate(print_rows, headers=headers, tablefmt=\"github\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finding heads with least absolute causal effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in ORDERED_MODELS:\n",
    "    prompt_mean_ie = mean_ie_by_model[(model, BOTH, ALL)]\n",
    "    icl_mean_ie = mean_ie_by_model[(model, ICL, ICL)]\n",
    "    absolute_total_mean_ie = prompt_mean_ie.abs() + icl_mean_ie.abs()\n",
    "    min_abs_heads, min_abs_values = top_heads_from_indirect_effects_with_values(absolute_total_mean_ie, 100, largest=False)\n",
    "    \n",
    "    _save_single_top_head_set(\n",
    "        min_abs_heads,\n",
    "        min_abs_values,\n",
    "        model,\n",
    "        \"universal\",\n",
    "        ALL,\n",
    "        head_type=\"min_abs\",\n",
    "        # output_dir: str = TOP_HEADS_DIR,\n",
    "        # overwrite: bool = False,\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And the ones with the smallest causal effect for each type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model in ORDERED_MODELS:\n",
    "    for prompt_type, baseline in ((BOTH, ALL), (ICL, ICL)):\n",
    "        mean_ie = mean_ie_by_model[(model, prompt_type, baseline)]   \n",
    "    \n",
    "        bottom_heads, bottom_values = top_heads_from_indirect_effects_with_values(mean_ie, 100, largest=False)\n",
    "    \n",
    "        _save_single_top_head_set(\n",
    "            bottom_heads,\n",
    "            bottom_values,\n",
    "            model,\n",
    "            prompt_type=prompt_type,\n",
    "            baseline=baseline,\n",
    "            head_type=\"bottom\",\n",
    "            # output_dir: str = TOP_HEADS_DIR,\n",
    "            # overwrite: bool = False,\n",
    "        )\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Printing a table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "| model                   |   Prompt_10_mean |   Prompt_10_median |   ICL_10_mean |   ICL_10_median |   Prompt_20_mean |   Prompt_20_median |   ICL_20_mean |   ICL_20_median |   Prompt_100_mean |   Prompt_100_median |   ICL_100_mean |   ICL_100_median |\n",
       "|-------------------------|------------------|--------------------|---------------|-----------------|------------------|--------------------|---------------|-----------------|-------------------|---------------------|----------------|------------------|\n",
       "| Llama-3.2-1B            |        5.816e-03 |          4.929e-03 |     3.156e-02 |       2.275e-02 |        4.407e-03 |          3.871e-03 |     1.819e-02 |       7.146e-03 |         1.786e-03 |           1.040e-03 |      4.521e-03 |        1.155e-03 |\n",
       "| Llama-3.2-1B-Instruct   |        7.695e-03 |          7.026e-03 |     3.511e-02 |       3.048e-02 |        5.855e-03 |          5.011e-03 |     2.028e-02 |       6.796e-03 |         2.205e-03 |           1.335e-03 |      5.369e-03 |        1.688e-03 |\n",
       "| Llama-3.2-3B            |        4.434e-03 |          3.729e-03 |     2.411e-02 |       1.018e-02 |        3.559e-03 |          3.372e-03 |     1.388e-02 |       4.996e-03 |         1.378e-03 |           8.242e-04 |      3.536e-03 |        8.542e-04 |\n",
       "| Llama-3.2-3B-Instruct   |        9.555e-03 |          8.353e-03 |     4.480e-02 |       1.505e-02 |        7.161e-03 |          6.811e-03 |     2.556e-02 |       9.256e-03 |         2.620e-03 |           1.602e-03 |      6.489e-03 |        1.645e-03 |\n",
       "| Llama-3.1-8B            |        3.446e-03 |          3.283e-03 |     2.239e-02 |       1.703e-02 |        2.848e-03 |          2.484e-03 |     1.365e-02 |       7.189e-03 |         1.149e-03 |           7.368e-04 |      3.887e-03 |        1.363e-03 |\n",
       "| Llama-3.1-8B-Instruct   |        8.576e-03 |          7.676e-03 |     2.358e-02 |       1.764e-02 |        6.749e-03 |          5.806e-03 |     1.472e-02 |       7.761e-03 |         2.488e-03 |           1.412e-03 |      4.401e-03 |        1.797e-03 |\n",
       "| Llama-2-7b-hf           |        3.392e-03 |          3.147e-03 |     1.630e-02 |       1.294e-02 |        2.565e-03 |          2.015e-03 |     9.715e-03 |       5.394e-03 |         1.212e-03 |           8.952e-04 |      2.609e-03 |        8.563e-04 |\n",
       "| Llama-2-7b-chat-hf      |        8.545e-03 |          5.821e-03 |     3.942e-02 |       3.443e-02 |        6.257e-03 |          4.617e-03 |     2.383e-02 |       1.228e-02 |         2.471e-03 |           1.527e-03 |      6.770e-03 |        2.613e-03 |\n",
       "| Llama-2-13b-hf          |        2.775e-03 |          2.129e-03 |     1.550e-02 |       1.151e-02 |        2.107e-03 |          1.723e-03 |     9.834e-03 |       6.372e-03 |         1.028e-03 |           7.966e-04 |      2.738e-03 |        9.504e-04 |\n",
       "| Llama-2-13b-chat-hf     |        8.587e-03 |          6.800e-03 |     5.083e-02 |       3.981e-02 |        6.462e-03 |          5.400e-03 |     3.151e-02 |       1.878e-02 |         2.759e-03 |           1.856e-03 |      8.981e-03 |        2.888e-03 |\n",
       "| OLMo-2-1124-7B          |        8.934e-03 |          7.554e-03 |     2.442e-02 |       1.273e-02 |        6.885e-03 |          6.173e-03 |     1.423e-02 |       6.866e-03 |         2.892e-03 |           1.908e-03 |      3.592e-03 |        9.344e-04 |\n",
       "| OLMo-2-1124-7B-SFT      |        8.110e-03 |          6.537e-03 |     2.789e-02 |       1.748e-02 |        6.487e-03 |          5.598e-03 |     1.636e-02 |       7.593e-03 |         2.750e-03 |           1.896e-03 |      4.336e-03 |        1.359e-03 |\n",
       "| OLMo-2-1124-7B-DPO      |        1.238e-02 |          1.102e-02 |     4.227e-02 |       2.768e-02 |        9.759e-03 |          8.215e-03 |     2.550e-02 |       1.171e-02 |         4.022e-03 |           2.668e-03 |      6.932e-03 |        2.170e-03 |\n",
       "| OLMo-2-1124-7B-Instruct |        1.295e-02 |          1.139e-02 |     4.576e-02 |       3.029e-02 |        1.033e-02 |          8.876e-03 |     2.769e-02 |       1.251e-02 |         4.295e-03 |           2.830e-03 |      7.575e-03 |        2.366e-03 |"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def format_number(x):\n",
    "    if x == 0:\n",
    "        return \"0\"\n",
    "    # elif x < 0.001:\n",
    "    #     return f\"{x:.2e}\"\n",
    "    else:\n",
    "        return f\"{x:e}\"\n",
    "\n",
    "\n",
    "rows = []\n",
    "\n",
    "for model in ORDERED_MODELS:\n",
    "    row = dict(model=model)\n",
    "    for N_TOP_HEADS in (10, 20, 100):\n",
    "        prompt_mean_ie = mean_ie_by_model[(model, BOTH, ALL)]\n",
    "        _, prompt_top_values = top_heads_from_indirect_effects_with_values(prompt_mean_ie, N_TOP_HEADS)\n",
    "        row[f'Prompt_{N_TOP_HEADS}_mean'] = format_number(np.mean(prompt_top_values))\n",
    "        row[f'Prompt_{N_TOP_HEADS}_median'] = format_number(np.median(prompt_top_values))\n",
    "        \n",
    "        icl_mean_ie = mean_ie_by_model[(model, ICL, ICL)]\n",
    "        _, icl_top_values = top_heads_from_indirect_effects_with_values(icl_mean_ie, N_TOP_HEADS)\n",
    "        row[f'ICL_{N_TOP_HEADS}_mean'] = format_number(np.mean(icl_top_values))\n",
    "        row[f'ICL_{N_TOP_HEADS}_median'] = format_number(np.median(icl_top_values))\n",
    "\n",
    "    rows.append(row)\n",
    "\n",
    "\n",
    "headers = rows[0].keys()\n",
    "print_rows = [[row[key] for key in headers] if isinstance(row, dict) else row for row in rows]\n",
    "display(Markdown(tabulate.tabulate(print_rows, headers=headers, tablefmt=\"github\", floatfmt=\".3e\")))\n",
    "    \n",
    "\n",
    "\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
