{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eed2c42",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a839de9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "with open(f\"figures/{model_key.split('/')[-1]}/raw/aie_per_head.json\", \"r\") as f:\n",
    "    aie_per_head = json.load(f)\n",
    "\n",
    "aie_per_head = {(layer_idx, head_idx): aie for layer_idx, head_idx, aie in aie_per_head} "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae8fd29a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "from src.utils import env_utils\n",
    "import os\n",
    "import torch\n",
    "plt.rcdefaults()\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    # f\"{select_task.task_name}\",\n",
    "    \"select_one\",\n",
    "    # \"legacy\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[52:, :] = 0.0\n",
    "\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "\n",
    "optimized_heads = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx) for layer_idx, head_idx in optimized_heads\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "HEADS = optimized_heads\n",
    "\n",
    "(35, 19) in HEADS, (35, 19) in optimized_heads\n",
    "# [(29, 3) in HEADS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4cd4b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(optimized_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71386925",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import category, pyplot as plt\n",
    "import os\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "\n",
    "import json\n",
    "import matplotlib.patches as patches\n",
    "import torch\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 50\n",
    "MEDIUM_SIZE = 55\n",
    "BIGGER_SIZE = 60\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 300\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 30\n",
    "MEDIUM_SIZE = 35\n",
    "BIGGER_SIZE = 40\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "fig_save_path = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"aie\")\n",
    "os.makedirs(fig_save_path, exist_ok=True)\n",
    "\n",
    "n_layer = 80\n",
    "n_heads = 64\n",
    "\n",
    "category_wise_heads = {}\n",
    "indirect_effects = torch.zeros((n_layer, n_heads), dtype=torch.float32)\n",
    "for layer_idx in range(n_layer):\n",
    "    for head_idx in range(n_heads):\n",
    "        indirect_effects[layer_idx, head_idx] = aie_per_head[(layer_idx, head_idx)]\n",
    "\n",
    "plt.figure(figsize=(14, 10))\n",
    "scale = torch.max(torch.abs(indirect_effects))\n",
    "plt.imshow(\n",
    "    indirect_effects.T.cpu().numpy(),\n",
    "    cmap=\"RdBu\",\n",
    "    aspect=\"auto\",\n",
    "    # vmin=-scale,\n",
    "    # vmax=scale,\n",
    "    # vmin=-0.15,\n",
    "    # vmax=0.15,\n",
    "    vmin=2,\n",
    "    vmax=-2\n",
    ")\n",
    "cbar = plt.colorbar()\n",
    "# Format colorbar tick labels to show + sign for positive values\n",
    "cbar.ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: f\"{x:+.1f}\" if x != 0 else \"0.0\"))\n",
    "# plt.title(f\"score(target) - max(score(distractors)) | {token_idx.upper()} tokens of options\")\n",
    "# plt.title(\"IE of q_proj patching | \" + category)\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Head Index\")\n",
    "\n",
    "def get_ticks(ticks, skip=5):\n",
    "    ret = []\n",
    "    for i in ticks:\n",
    "        if i % skip == 0:\n",
    "            ret.append(str(i))\n",
    "        else:\n",
    "            ret.append(\"\")\n",
    "    return ret\n",
    "\n",
    "plt.xticks(\n",
    "    ticks=range(n_layer),\n",
    "    labels=get_ticks(range(n_layer)),\n",
    "    rotation=1,\n",
    ")\n",
    "plt.yticks(\n",
    "    ticks=range(n_heads),\n",
    "    labels=get_ticks(range(n_heads), skip=8),\n",
    ")\n",
    "\n",
    "# Get the current axes\n",
    "ax = plt.gca()\n",
    "\n",
    "# Draw borders around marked cells\n",
    "for (x, y) in optimized_heads:\n",
    "    print(x, y)\n",
    "    # Create a Rectangle patch\n",
    "    # Note: (x-0.5, y-0.5) positions the rectangle correctly around the cell\n",
    "    # Width and height of 1 covers exactly one cell\n",
    "    rect = patches.Rectangle(\n",
    "        (x - 0.5, y - 0.5),  # bottom-left corner\n",
    "        1,                     # width\n",
    "        1,                     # height\n",
    "        linewidth=1.5,          # border thickness\n",
    "        edgecolor='black',    # border color (you can change this)\n",
    "        facecolor='none'      # no fill, just border\n",
    "    )\n",
    "    ax.add_patch(rect)  # FIXED: This should be inside the loop!\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(fig_save_path, \"objects.pdf\"), bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f516a54c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import category, pyplot as plt\n",
    "import os\n",
    "from matplotlib.ticker import FuncFormatter\n",
    "\n",
    "import json\n",
    "import matplotlib.patches as patches\n",
    "import torch\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 50\n",
    "MEDIUM_SIZE = 55\n",
    "BIGGER_SIZE = 60\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 300\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 30\n",
    "MEDIUM_SIZE = 35\n",
    "BIGGER_SIZE = 40\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "fig_save_path = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"aie\")\n",
    "os.makedirs(fig_save_path, exist_ok=True)\n",
    "\n",
    "n_layer = 80\n",
    "n_heads = 64\n",
    "\n",
    "# Define the layer range to display\n",
    "layer_start = 20\n",
    "layer_end = 55  # inclusive\n",
    "\n",
    "category_wise_heads = {}\n",
    "indirect_effects = torch.zeros((n_layer, n_heads), dtype=torch.float32)\n",
    "for layer_idx in range(n_layer):\n",
    "    for head_idx in range(n_heads):\n",
    "        indirect_effects[layer_idx, head_idx] = aie_per_head[(layer_idx, head_idx)]\n",
    "\n",
    "# Slice the data to only show layers 20-60\n",
    "indirect_effects_subset = indirect_effects[layer_start:layer_end+1, :]\n",
    "\n",
    "plt.figure(figsize=(8, 10))\n",
    "scale = torch.max(torch.abs(indirect_effects_subset))\n",
    "plt.imshow(\n",
    "    indirect_effects_subset.T.cpu().numpy(),\n",
    "    cmap=\"RdBu\",\n",
    "    aspect=\"auto\",\n",
    "    # vmin=-scale,\n",
    "    # vmax=scale,\n",
    "    # vmin=-0.15,\n",
    "    # vmax=0.15,\n",
    "    vmin=2,\n",
    "    vmax=-2\n",
    ")\n",
    "cbar = plt.colorbar()\n",
    "# Format colorbar tick labels to show + sign for positive values\n",
    "cbar.ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: f\"{x:+.1f}\" if x != 0 else \"0.0\"))\n",
    "# plt.title(f\"score(target) - max(score(distractors)) | {token_idx.upper()} tokens of options\")\n",
    "# plt.title(\"IE of q_proj patching | \" + category)\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Head Index\")\n",
    "\n",
    "def get_ticks(ticks, skip=5, offset=0):\n",
    "    ret = []\n",
    "    for i in ticks:\n",
    "        actual_layer = i + offset\n",
    "        if actual_layer % skip == 0:\n",
    "            ret.append(str(actual_layer))\n",
    "        else:\n",
    "            ret.append(\"\")\n",
    "    return ret\n",
    "\n",
    "# Set x-ticks for the subset of layers\n",
    "n_layers_shown = layer_end - layer_start + 1\n",
    "plt.xticks(\n",
    "    ticks=range(n_layers_shown),\n",
    "    labels=get_ticks(range(n_layers_shown), skip=5, offset=layer_start),\n",
    "    rotation=1,\n",
    ")\n",
    "plt.yticks(\n",
    "    ticks=range(n_heads),\n",
    "    labels=get_ticks(range(n_heads), skip=8),\n",
    ")\n",
    "\n",
    "# Get the current axes\n",
    "ax = plt.gca()\n",
    "\n",
    "# Draw borders around marked cells\n",
    "# Note: Need to adjust x-coordinates for the subset\n",
    "for (x, y) in optimized_heads:\n",
    "    # Only draw if the head is within the displayed layer range\n",
    "    if layer_start <= x <= layer_end:\n",
    "        print(x, y)\n",
    "        # Adjust x-coordinate to match the new indexing\n",
    "        adjusted_x = x - layer_start\n",
    "        # Create a Rectangle patch\n",
    "        # Note: (adjusted_x-0.5, y-0.5) positions the rectangle correctly around the cell\n",
    "        # Width and height of 1 covers exactly one cell\n",
    "        rect = patches.Rectangle(\n",
    "            (adjusted_x - 0.5, y - 0.5),  # bottom-left corner\n",
    "            1,                     # width\n",
    "            1,                     # height\n",
    "            linewidth=0.5,          # border thickness\n",
    "            edgecolor='black',    # border color (you can change this)\n",
    "            facecolor='none'      # no fill, just border\n",
    "        )\n",
    "        ax.add_patch(rect)  # FIXED: This should be inside the loop!\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(fig_save_path, \"objects-sliced.pdf\"), bbox_inches='tight', pad_inches=0)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66dc4d98",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d696c084",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9747f004",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 50\n",
    "MEDIUM_SIZE = 55\n",
    "BIGGER_SIZE = 60\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "n_dist_performance = {\n",
    "    2: (0.8301, 7.5886),\n",
    "    3: (0.8185, 7.9886),\n",
    "    4: (0.8047, 8.4584),\n",
    "    5: (0.8047, 8.6950),\n",
    "    6: (0.7891, 8.4796),\n",
    "    7: (0.7617, 8.6022),\n",
    "}\n",
    "\n",
    "\n",
    "plt.rc(\"figure\", figsize=(16, 10))\n",
    "plt.plot(\n",
    "    list(n_dist_performance.keys()),\n",
    "    [x[0] for x in n_dist_performance.values()],\n",
    "    marker=\"o\",\n",
    "    linewidth=5,\n",
    "    markersize=15,\n",
    "    alpha=0.8,\n",
    ")\n",
    "plt.xlabel(\"Number of Distractors\")\n",
    "plt.ylabel(\"Causality\")\n",
    "plt.ylim(0.5, 1)\n",
    "plt.xticks(list(n_dist_performance.keys()))\n",
    "\n",
    "save_path = os.path.join(\"figures\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_path, \"causality_vs_n_distractors.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54bcfca0",
   "metadata": {},
   "outputs": [],
   "source": [
    "#! Only checked the causality values. Some of the values are off\n",
    "\n",
    "train_vs_evaluation = {\n",
    "    \"Select One\": {\n",
    "        \"num_heads\": 79,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.8633,\n",
    "                \"n_correct\": 884,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 9.0276\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.6875,\n",
    "                \"n_correct\": 352,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.1055\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.7769,\n",
    "                \"n_correct\": 395,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.0176\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.8418,\n",
    "                \"n_correct\": 374,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.8535\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0410,\n",
    "                \"n_correct\": 21,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.2270\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0176,\n",
    "                \"n_correct\": 9,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.4799\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select First\": {\n",
    "        \"num_heads\": 81,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.7910,\n",
    "                \"n_correct\": 405,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 9.5619\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.7285,\n",
    "                \"n_correct\": 396,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.2440\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.7730,\n",
    "                \"n_correct\": 374,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.6836\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.6934,\n",
    "                \"n_correct\": 355,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.2285\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.1270,\n",
    "                \"n_correct\": 65,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.9473\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0059,\n",
    "                \"n_correct\": 3,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.6457\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select Last\": {\n",
    "        \"num_heads\": 145,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.7617,\n",
    "                \"n_correct\": 390,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.2188\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.6016,\n",
    "                \"n_correct\": 308,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.3594\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.8789,\n",
    "                \"n_correct\": 840,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 8.2751\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.7744,\n",
    "                \"n_correct\": 790,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 5.5957\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.1387,\n",
    "                \"n_correct\": 71,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.1542\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0352,\n",
    "                \"n_correct\": 18,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.3779\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select One - MCQ\": {\n",
    "        \"num_heads\": 45,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.5918,\n",
    "                \"n_correct\": 303,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.0840\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.2734,\n",
    "                \"n_correct\": 223,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.8923\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.4384,\n",
    "                \"n_correct\": 224,\n",
    "                \"out_of\": 511,\n",
    "                \"delta_logit\": 5.7093\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.9043,\n",
    "                \"n_correct\": 450,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.0955\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0645,\n",
    "                \"n_correct\": 33,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.2347\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0215,\n",
    "                \"n_correct\": 11,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.3444\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Counting\": {\n",
    "        \"num_heads\": 64,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.6094,\n",
    "                \"n_correct\": 312,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 6.8508\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.3789,\n",
    "                \"n_correct\": 194,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.8218\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.5832,\n",
    "                \"n_correct\": 299,\n",
    "                \"out_of\": 511,\n",
    "                \"delta_logit\": 2.9157\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.7871,\n",
    "                \"n_correct\": 195,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.7144\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.3555,\n",
    "                \"n_correct\": 177,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.3633\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0469,\n",
    "                \"n_correct\": 9,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.4086\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Yes/No\": {\n",
    "        \"num_heads\": 21,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.0820,\n",
    "                \"n_correct\": 42,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.5545\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.0957,\n",
    "                \"n_correct\": 49,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.3595\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.1487,\n",
    "                \"n_correct\": 76,\n",
    "                \"out_of\": 511,\n",
    "                \"delta_logit\": 4.4409\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.3574,\n",
    "                \"n_correct\": 183,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.0686\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0762,\n",
    "                \"n_correct\": 39,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.5039\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0898,\n",
    "                \"n_correct\": 9,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.1039\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e052c0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# for gemma-27B\n",
    "\n",
    "#! Only checked the causality values. Some of the values are off\n",
    "train_vs_evaluation = {\n",
    "    \"Select One\": {\n",
    "        \"num_heads\": 77,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.824,\n",
    "                \"n_correct\": 884,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 9.0276\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.695,\n",
    "                \"n_correct\": 352,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.1055\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.670,\n",
    "                \"n_correct\": 395,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.0176\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.826,\n",
    "                \"n_correct\": 374,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.8535\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.074,\n",
    "                \"n_correct\": 21,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.2270\n",
    "            },\n",
    "            \"CheckPresence\": {\n",
    "                \"causality\": 0.0312,\n",
    "                \"n_correct\": 9,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.4799\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select First\": {\n",
    "        \"num_heads\": 77,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.7969,\n",
    "                \"n_correct\": 405,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 9.5619\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.6777,\n",
    "                \"n_correct\": 396,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.2440\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.6855,\n",
    "                \"n_correct\": 374,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.6836\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.8301,\n",
    "                \"n_correct\": 355,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.2285\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0527,\n",
    "                \"n_correct\": 65,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.9473\n",
    "            },\n",
    "            \"CheckPresence\": {\n",
    "                \"causality\": 0.0254,\n",
    "                \"n_correct\": 3,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.6457\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select Last\": {\n",
    "        \"num_heads\": 76,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.8184,\n",
    "                \"n_correct\": 390,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.2188\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.6621,\n",
    "                \"n_correct\": 308,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.3594\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.8809,\n",
    "                \"n_correct\": 840,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 8.2751\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.8184,\n",
    "                \"n_correct\": 790,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 5.5957\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0449,\n",
    "                \"n_correct\": 71,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.1542\n",
    "            },\n",
    "            \"CheckPresence\": {\n",
    "                \"causality\": 0.0215,\n",
    "                \"n_correct\": 18,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.3779\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select One - MCQ\": {\n",
    "        \"num_heads\": 39,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.4258,\n",
    "                \"n_correct\": 303,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.0840\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.3418,\n",
    "                \"n_correct\": 223,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.8923\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.3457,\n",
    "                \"n_correct\": 224,\n",
    "                \"out_of\": 511,\n",
    "                \"delta_logit\": 5.7093\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.8164,\n",
    "                \"n_correct\": 450,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.0955\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0430,\n",
    "                \"n_correct\": 33,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.2347\n",
    "            },\n",
    "            \"CheckPresence\": {\n",
    "                \"causality\": 0.0254,\n",
    "                \"n_correct\": 11,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.3444\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Counting\": {\n",
    "        \"num_heads\": 39,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.0273,\n",
    "                \"n_correct\": 312,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 6.8508\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.0039,\n",
    "                \"n_correct\": 194,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.8218\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.0039,\n",
    "                \"n_correct\": 299,\n",
    "                \"out_of\": 511,\n",
    "                \"delta_logit\": 2.9157\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.6074,\n",
    "                \"n_correct\": 195,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.7144\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0820,\n",
    "                \"n_correct\": 177,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.3633\n",
    "            },\n",
    "            \"CheckPresence\": {\n",
    "                \"causality\": 0.0332,\n",
    "                \"n_correct\": 9,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.4086\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"CheckPresence\": {\n",
    "        \"num_heads\": 25,\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.0430,\n",
    "                \"n_correct\": 42,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.5545\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.0098,\n",
    "                \"n_correct\": 49,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.3595\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.0117,\n",
    "                \"n_correct\": 76,\n",
    "                \"out_of\": 511,\n",
    "                \"delta_logit\": 4.4409\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.7129,\n",
    "                \"n_correct\": 183,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.0686\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0529,\n",
    "                \"n_correct\": 39,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 1.5039\n",
    "            },\n",
    "            \"CheckPresence\": {\n",
    "                \"causality\": 0.0508,\n",
    "                \"n_correct\": 9,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.1039\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "773f1132",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "import numpy as np\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 17\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=BIGGER_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "# Extract labels for the axes\n",
    "labels = [\n",
    "    \"Select One\",\n",
    "    \"Select One - MCQ\",\n",
    "    \"Select First\",\n",
    "    \"Select Last\",\n",
    "    \"Counting\",\n",
    "    # \"Yes/No\"\n",
    "    \"CheckPresence\"\n",
    "]\n",
    "num_labels = len(labels)\n",
    "\n",
    "# Create a 2D numpy array to hold the causality values\n",
    "causality_matrix = np.zeros((num_labels, num_labels))\n",
    "\n",
    "# for i, (trained_on, trained_data) in enumerate(train_vs_evaluation.items()):\n",
    "#     for j, (evaluated_on, eval_data) in enumerate(trained_data[\"evaluation\"].items()):\n",
    "#         # Ensure the order of evaluation labels matches the primary labels list\n",
    "#         if evaluated_on in labels:\n",
    "#             j_idx = labels.index(evaluated_on)\n",
    "#             causality_matrix[i, j_idx] = eval_data[\"causality\"]\n",
    "for i, trained_on in enumerate(labels):\n",
    "    num_heads = train_vs_evaluation[trained_on][\"num_heads\"]\n",
    "    eval_data = train_vs_evaluation[trained_on][\"evaluation\"]\n",
    "    for j, evaluated_on in enumerate(labels):\n",
    "        causality_matrix[i, j] = eval_data[evaluated_on][\"causality\"]\n",
    "\n",
    "# Create the plot\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "colors = [(1, 0, 0), (0, 1, 0)] # Red -> Green\n",
    "# colors = [\"#ff0000\", \"#4D7300\"]\n",
    "custom_cmap = LinearSegmentedColormap.from_list('green_to_red_cmap', colors, N=256)\n",
    "im = ax.imshow(causality_matrix, cmap=custom_cmap, vmin=0, vmax=1)\n",
    "\n",
    "# Set ticks and labels\n",
    "ax.set_xticks(np.arange(num_labels))\n",
    "ax.set_yticks(np.arange(num_labels))\n",
    "ax.set_xticklabels([label.replace(\" \", \"\") for label in labels])\n",
    "ax.set_yticklabels([f'{label.replace(\" \", \"\")} ({train_vs_evaluation[label][\"num_heads\"]})' for label in labels])\n",
    "\n",
    "# Rotate the tick labels and set their alignment\n",
    "plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n",
    "         rotation_mode=\"anchor\")\n",
    "\n",
    "# Loop over data dimensions and create text annotations\n",
    "for i in range(num_labels):\n",
    "    for j in range(num_labels):\n",
    "        # Use a contrasting color for the text\n",
    "        # text_color = \"w\" if causality_matrix[i, j] < 0.5 else \"k\"\n",
    "        text_color = 'k'\n",
    "        text = ax.text(j, i, f\"{causality_matrix[i, j]:.2f}\",\n",
    "                       ha=\"center\", va=\"center\", color=text_color)\n",
    "\n",
    "# Add black borders around diagonal cells (where i == j)\n",
    "for i in range(num_labels):\n",
    "    rect = Rectangle((i - 0.5, i - 0.5), 1, 1, \n",
    "                     linewidth=3, edgecolor='black', facecolor='none')\n",
    "    ax.add_patch(rect)\n",
    "\n",
    "# Add a colorbar\n",
    "cbar = fig.colorbar(im)\n",
    "# cbar.set_label(\"Causality Value\", rotation=-90, va=\"bottom\")\n",
    "\n",
    "\n",
    "ax.set_xlabel(\"Evaluated On\")\n",
    "ax.set_ylabel(\"Trained On\")\n",
    "# ax.set_title(\"Causality Confusion Matrix\")\n",
    "\n",
    "# Adjust layout to prevent labels from being cut off\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save the figure\n",
    "save_path = os.path.join(\"figures\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_path, \"gemma-head_transfer.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6815b2d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "cached_from_and_patched_to = {\n",
    "    \"Select One\": {\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.8633,\n",
    "                \"n_correct\": 884,\n",
    "                \"out_of\": 1024,\n",
    "                \"delta_logit\": 9.0276,\n",
    "                \"std\": 3.4362\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.5977,\n",
    "                \"n_correct\": 306,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 6.8452,\n",
    "                \"std\": 3.1353\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.5781,\n",
    "                \"n_correct\": 296,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 6.0331,\n",
    "                \"std\": 2.6816\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.8164,\n",
    "                \"n_correct\": 418,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.7292,\n",
    "                \"std\": 1.8834\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0625,\n",
    "                \"n_correct\": 16,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 1.1687,\n",
    "                \"std\": 1.2580\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0508,\n",
    "                \"n_correct\": 13,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.5586,\n",
    "                \"std\": 0.9411\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select First\": {\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.6270,\n",
    "                \"n_correct\": 321,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.2968,\n",
    "                \"std\": 3.8706\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.7285,\n",
    "                \"n_correct\": 373,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.6039,\n",
    "                \"std\": 3.2674\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.2812,\n",
    "                \"n_correct\": 144,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.8309,\n",
    "                \"std\": 3.8327\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.5762,\n",
    "                \"n_correct\": 295,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 3.6907,\n",
    "                \"std\": 3.0142\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.1367,\n",
    "                \"n_correct\": 35,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 1.0845,\n",
    "                \"std\": 2.4030\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0664,\n",
    "                \"n_correct\": 17,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.7715,\n",
    "                \"std\": 0.8603\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select Last\": {\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.5977,\n",
    "                \"n_correct\": 306,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.3607,\n",
    "                \"std\": 4.4652\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.1719,\n",
    "                \"n_correct\": 88,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 4.4881,\n",
    "                \"std\": 4.6115\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.8789,\n",
    "                \"n_correct\": 450,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 8.1340,\n",
    "                \"std\": 3.0683\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.6621,\n",
    "                \"n_correct\": 339,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 3.9371,\n",
    "                \"std\": 2.5485\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.1289,\n",
    "                \"n_correct\": 33,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.5649,\n",
    "                \"std\": 2.4738\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.1172,\n",
    "                \"n_correct\": 30,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.6084,\n",
    "                \"std\": 1.2738\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Select One - MCQ\": {\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.5957,\n",
    "                \"n_correct\": 305,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 7.2503,\n",
    "                \"std\": 3.1097\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.2188,\n",
    "                \"n_correct\": 112,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.2869,\n",
    "                \"std\": 3.1375\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.4141,\n",
    "                \"n_correct\": 212,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.5187,\n",
    "                \"std\": 2.3622\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.9043,\n",
    "                \"n_correct\": 463,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.2159,\n",
    "                \"std\": 1.6849\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0430,\n",
    "                \"n_correct\": 11,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.9060,\n",
    "                \"std\": 1.1908\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.1172,\n",
    "                \"n_correct\": 30,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.2878,\n",
    "                \"std\": 0.6663\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Counting\": {\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.6016,\n",
    "                \"n_correct\": 308,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 6.9536,\n",
    "                \"std\": 3.4198\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.3379,\n",
    "                \"n_correct\": 173,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.0867,\n",
    "                \"std\": 3.2869\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.5020,\n",
    "                \"n_correct\": 257,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 5.4534,\n",
    "                \"std\": 2.8721\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.6582,\n",
    "                \"n_correct\": 337,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 3.4985,\n",
    "                \"std\": 2.2548\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.3555,\n",
    "                \"n_correct\": 91,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 3.2095,\n",
    "                \"std\": 2.3713\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.1211,\n",
    "                \"n_correct\": 31,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.3105,\n",
    "                \"std\": 0.5622\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    \"Yes/No\": {\n",
    "        \"evaluation\": {\n",
    "            \"Select One\": {\n",
    "                \"causality\": 0.0195,\n",
    "                \"n_correct\": 10,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 3.3517,\n",
    "                \"std\": 2.0371\n",
    "            },\n",
    "            \"Select First\": {\n",
    "                \"causality\": 0.0098,\n",
    "                \"n_correct\": 5,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 3.0789,\n",
    "                \"std\": 2.3057\n",
    "            },\n",
    "            \"Select Last\": {\n",
    "                \"causality\": 0.0527,\n",
    "                \"n_correct\": 27,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 2.9786,\n",
    "                \"std\": 2.1673\n",
    "            },\n",
    "            \"Select One - MCQ\": {\n",
    "                \"causality\": 0.1250,\n",
    "                \"n_correct\": 64,\n",
    "                \"out_of\": 512,\n",
    "                \"delta_logit\": 0.8182,\n",
    "                \"std\": 1.2771\n",
    "            },\n",
    "            \"Counting\": {\n",
    "                \"causality\": 0.0469,\n",
    "                \"n_correct\": 12,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.3564,\n",
    "                \"std\": 1.2823\n",
    "            },\n",
    "            \"Yes/No\": {\n",
    "                \"causality\": 0.0898,\n",
    "                \"n_correct\": 23,\n",
    "                \"out_of\": 256,\n",
    "                \"delta_logit\": 0.1531,\n",
    "                \"std\": 0.5732\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aaea99be",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "import numpy as np\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "from matplotlib.patches import Rectangle\n",
    "\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 17\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=BIGGER_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=MEDIUM_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "\n",
    "# Extract labels for the axes\n",
    "labels = [\n",
    "    \"Select One\",\n",
    "    \"Select One - MCQ\",\n",
    "    \"Select First\",\n",
    "    \"Select Last\",\n",
    "    \"Counting\",\n",
    "    \"Yes/No\"\n",
    "]\n",
    "num_labels = len(labels)\n",
    "\n",
    "# Create a 2D numpy array to hold the causality values\n",
    "causality_matrix = np.zeros((num_labels, num_labels))\n",
    "\n",
    "for i, trained_on in enumerate(labels):\n",
    "    eval_data = cached_from_and_patched_to[trained_on][\"evaluation\"]\n",
    "    for j, evaluated_on in enumerate(labels):\n",
    "        causality_matrix[i, j] = eval_data[evaluated_on][\"causality\"]\n",
    "\n",
    "# Create the plot\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "colors = [(1, 0, 0), (0, 1, 0)] # Red -> Green\n",
    "# colors = [\"#ff0000\", \"#4D7300\"]\n",
    "custom_cmap = LinearSegmentedColormap.from_list('green_to_red_cmap', colors, N=256)\n",
    "im = ax.imshow(causality_matrix, cmap=custom_cmap, vmin=0, vmax=1)\n",
    "\n",
    "# Set ticks and labels\n",
    "ax.set_xticks(np.arange(num_labels))\n",
    "ax.set_yticks(np.arange(num_labels))\n",
    "ax.set_xticklabels([label.replace(\" \", \"\") for label in labels])\n",
    "ax.set_yticklabels([f'{label.replace(\" \", \"\")} ({train_vs_evaluation[label][\"num_heads\"]})' for label in labels])\n",
    "\n",
    "# Rotate the tick labels and set their alignment\n",
    "plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n",
    "         rotation_mode=\"anchor\")\n",
    "\n",
    "# Loop over data dimensions and create text annotations\n",
    "for i in range(num_labels):\n",
    "    for j in range(num_labels):\n",
    "        # Use a contrasting color for the text\n",
    "        # text_color = \"w\" if causality_matrix[i, j] < 0.5 else \"k\"\n",
    "        text_color = 'k'\n",
    "        text = ax.text(j, i, f\"{causality_matrix[i, j]:.2f}\",\n",
    "                       ha=\"center\", va=\"center\", color=text_color)\n",
    "\n",
    "# Add black borders around diagonal cells (where i == j)\n",
    "for i in range(num_labels):\n",
    "    rect = Rectangle((i - 0.5, i - 0.5), 1, 1, \n",
    "                     linewidth=3, edgecolor='black', facecolor='none')\n",
    "    ax.add_patch(rect)\n",
    "\n",
    "# Add a colorbar\n",
    "cbar = fig.colorbar(im)\n",
    "cbar.set_label(\"Causality Value\", rotation=-90, va=\"bottom\")\n",
    "\n",
    "\n",
    "ax.set_xlabel(\"Patched To\")\n",
    "ax.set_ylabel(\"Cached From\")\n",
    "# ax.set_title(\"Causality Confusion Matrix\")\n",
    "\n",
    "# Adjust layout to prevent labels from being cut off\n",
    "fig.tight_layout()\n",
    "\n",
    "# Save the figure\n",
    "save_path = os.path.join(\"figures\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_path, \"pred_transfer.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a31ccea5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=SMALL_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=SMALL_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "scores = {\n",
    "    \"FilterScore\": {\n",
    "        \"score\": 0.6680,\n",
    "        \"n_correct\": 342,\n",
    "        \"out_of\": 512\n",
    "    },\n",
    "    \"CMA\": {\n",
    "        \"score\": 0.5879,\n",
    "        \"n_correct\": 301,\n",
    "        \"out_of\": 512\n",
    "    },\n",
    "    \"CMA + DCM\": {\n",
    "        \"score\": 0.8633,\n",
    "        \"n_correct\": 884,\n",
    "        \"out_of\": 1024\n",
    "    }\n",
    "}\n",
    "\n",
    "\n",
    "plt.figure(figsize=(4, 4))\n",
    "bars = plt.bar(scores.keys(), [v[\"score\"] for v in scores.values()])\n",
    "plt.ylim(0, 1)\n",
    "for bar in bars:\n",
    "    height = bar.get_height()\n",
    "    plt.text(bar.get_x() + bar.get_width() / 2, height, f\"{height:.2f}\", ha='center', va='bottom')\n",
    "\n",
    "\n",
    "# Save the figure\n",
    "save_path = os.path.join(\"figures\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_path, \"causality_diff_approach.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053dffe2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=SMALL_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=SMALL_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "scores = {\n",
    "    \"w/o Avg\": {\n",
    "        \"score\": 0.7871,\n",
    "        \"n_correct\": 403,\n",
    "        \"out_of\": 512\n",
    "    },\n",
    "    \"Avg\": {\n",
    "        \"score\": 0.863,\n",
    "        \"n_correct\": 442,\n",
    "        \"out_of\": 512\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "plt.figure(figsize=(4, 4))\n",
    "bars = plt.bar(scores.keys(), [v[\"score\"] for v in scores.values()])\n",
    "plt.ylim(0, 1)\n",
    "for bar in bars:\n",
    "    height = bar.get_height()\n",
    "    plt.text(bar.get_x() + bar.get_width() / 2, height, f\"{height:.2f}\", ha='center', va='bottom')\n",
    "\n",
    "\n",
    "# Save the figure\n",
    "save_path = os.path.join(\"figures\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_path, \"avg_trick_scores.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d81ebb60",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import os\n",
    "\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 25\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=SMALL_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "with open(\"figures/Llama-3.3-70B-Instruct/raw/probe_performance.json\") as f:\n",
    "    import json\n",
    "\n",
    "    probe_performance = json.load(f)\n",
    "\n",
    "probe_accuracy = probe_performance[\"out_of_place\"]\n",
    "logitlens_baseline = probe_performance[\"logit_lens_baseline\"]\n",
    "\n",
    "plt.figure(figsize=(10, 5))\n",
    "plt.plot(\n",
    "    list(probe_accuracy.keys()),\n",
    "    list(probe_accuracy.values()),\n",
    "    label=\"Out of Place\",\n",
    "    alpha=0.7,\n",
    "    linewidth=2,\n",
    ")\n",
    "plt.plot(\n",
    "    list(logitlens_baseline.keys()),\n",
    "    list(logitlens_baseline.values()),\n",
    "    label=\"Logit Lens Baseline\",\n",
    "    alpha=0.7,\n",
    "    linewidth=2,\n",
    ")\n",
    "plt.vlines(34, 0, 78, colors='black', linestyles='dashed', linewidth=0.5, alpha=0.5)\n",
    "plt.plot(\n",
    "    [34], [0.81], marker=\"*\", markersize=15, color=\"purple\", label=\"Inplace (0.81)\"\n",
    ")\n",
    "plt.xlabel(\"Layer Index\")\n",
    "plt.ylabel(\"Probe Accuracy\")\n",
    "plt.xticks(range(0, 79, 5))\n",
    "plt.ylim(0, 1)\n",
    "plt.legend()\n",
    "\n",
    "plt.savefig(os.path.join(\"figures\", \"probe_performance.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e1010a8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
