{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules before each cell executes, so edits to project code (src.*) are picked up.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment setup: import path, visible GPUs, logging, RNG seeds, library versions.\n",
    "import os\n",
    "import json\n",
    "import random\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Make the project root importable as `src`; guarded so re-running the cell does\n",
    "# not keep appending the same entry to sys.path.\n",
    "if \"../\" not in sys.path:\n",
    "    sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# Must be set before torch initialises CUDA to take effect.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "# Seed the global RNGs so sampled prompts are reproducible on a fresh kernel\n",
    "# (assuming the task samplers draw from these generators).\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}\")\n",
    "# get_device_name() raises when no CUDA device is visible; only query it when\n",
    "# CUDA is available so the cell also runs on CPU-only machines.\n",
    "if torch.cuda.is_available():\n",
    "    logger.info(f\"{torch.cuda.get_device_name()=}\")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Model under study. Commented entries are alternatives used in other runs.\n",
    "\n",
    "# --- Llama ---\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "# --- Gemma ---\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "# --- DeepSeek / OLMo ---\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "# --- Qwen 2 / 2.5 ---\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "# --- Qwen 3 ---\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# --- Active ---\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "# Optional manual layout across GPUs (pass as device_map=... when loading):\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Keyword arguments forwarded to ModelandTokenizer. Alternatives used before:\n",
    "#   device_map=device_map  (manual layout from get_device_map)\n",
    "#   quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True)  # or load_in_4bit=True\n",
    "model_load_kwargs = dict(\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=\"auto\",\n",
    "    attn_implementation=\"eager\",\n",
    ")\n",
    "mt = ModelandTokenizer(model_key=model_key, **model_load_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# file_path = os.path.join(\n",
    "#     env_utils.DEFAULT_DATA_DIR,\n",
    "#     \"selection\",\n",
    "#     # \"profession.json\"\n",
    "#     # \"nationality.json\"\n",
    "#     \"objects.json\",\n",
    "# )\n",
    "\n",
    "# with open(file_path, \"r\") as f:\n",
    "#     temp = json.load(f)\n",
    "\n",
    "# for cat in temp[\"categories\"]:\n",
    "#     temp[\"categories\"][cat] = [obj.capitalize() for obj in temp[\"categories\"][cat]]\n",
    "\n",
    "# with open(file_path, \"w\") as f:\n",
    "#     json.dump(temp, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import (\n",
    "    CountingTask,\n",
    "    SelectFirstTask,\n",
    "    SelectLastTask,\n",
    "    SelectOneTask,\n",
    "    YesNoTask,\n",
    ")\n",
    "\n",
    "#################################################################################\n",
    "# Task configuration. Other combinations used before:\n",
    "#   CountingTask with prompt_template_idx = 1\n",
    "#   YesNoTask / SelectFirstTask / SelectLastTask with prompt_template_idx = 3\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "objects_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "select_task = TASK_CLS.load(path=objects_path)\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one sample from the task via get_random_sample, without filtering on the LM's prediction.\n",
    "sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    category=\"fruit\",  # other categories tried: \"actor\", \"Brazil\"\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    filter_by_lm_prediction=False,\n",
    "    # exclude_distractor_categories=select_task.exclude_for_category(\"fruit\"),\n",
    ")\n",
    "\n",
    "answer_token = mt.tokenizer.decode([sample.ans_token_id])\n",
    "print(sample.prompt(), \">>\", f'\"{answer_token}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "# Greedy (do_sample=False) continuation of the sample prompt; remove_prefix=True\n",
    "# presumably strips the prompt from the returned text.\n",
    "generations = generate_with_patch(\n",
    "    mt=mt,\n",
    "    inputs=sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True,\n",
    ")\n",
    "gen = generations[0]\n",
    "print(f'\"{gen}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e8a8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (number of layers, attention heads per layer)\n",
    "(mt.n_layer, mt.config.num_attention_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate attention heads per model as (layer_idx, head_idx) pairs; presumably\n",
    "# found by earlier head-selection / optimisation runs (TODO: record provenance).\n",
    "llama_70_heads = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "qwen_72_heads = [\n",
    "    (62, 1),\n",
    "    (60, 9),\n",
    "    (64, 8),\n",
    "    (62, 0),\n",
    "    (62, 45),\n",
    "    (59, 59),\n",
    "    (71, 28),\n",
    "    (64, 12),\n",
    "    (61, 7),\n",
    "    (64, 13),\n",
    "    (67, 53),\n",
    "    (67, 51),\n",
    "    (54, 44),\n",
    "    (57, 5),\n",
    "    (59, 60),\n",
    "    (71, 25),\n",
    "    (62, 7),\n",
    "    (64, 9),\n",
    "    (62, 23),\n",
    "    (65, 40),\n",
    "]\n",
    "\n",
    "qwen_32_heads = [\n",
    "    (51, 11),\n",
    "    (48, 4),\n",
    "    (52, 21),\n",
    "    (54, 35),\n",
    "    (48, 8),\n",
    "    (50, 6),\n",
    "    (48, 9),\n",
    "    (48, 32),\n",
    "    (52, 10),\n",
    "    (45, 11),\n",
    "    (45, 13),\n",
    "    (48, 34),\n",
    "    (53, 16),\n",
    "    (50, 12),\n",
    "    (49, 2),\n",
    "    (54, 38),\n",
    "    (55, 4),\n",
    "    (50, 27),\n",
    "    (54, 33),\n",
    "    (50, 14),\n",
    "]\n",
    "\n",
    "\n",
    "# Alternative ways of choosing HEADS (kept for reference):\n",
    "# HEADS = [(35, 19)]\n",
    "\n",
    "\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "# Active head set; must match the loaded model. NOTE: the optimised-mask cell\n",
    "# later reassigns HEADS = optimized_heads.\n",
    "# HEADS = qwen_32_heads\n",
    "HEADS = llama_70_heads\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe381b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Output of the head-mask optimisation run for the current model. Locations used\n",
    "# previously (relative to env_utils.DEFAULT_RESULTS_DIR):\n",
    "#   selection/optimized_backup_heads/<model>/<task_name>.npz\n",
    "#   test_opt_code/<model>/distinct_options/<task_name>/legacy/epoch_10.npz\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    \"select_one\",  # or: select_task.task_name (optionally followed by \"legacy\")\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "\n",
    "# SECURITY: allow_pickle=True lets np.load unpickle object arrays, which can run\n",
    "# arbitrary code -- only load result files produced by our own runs.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(optimization_results[\"losses\"])\n",
    "ax.set(\n",
    "    title=f\"Head-mask optimisation loss ({model_key.split('/')[-1]})\",\n",
    "    xlabel=\"logged step\",\n",
    "    ylabel=\"loss\",\n",
    ")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f87810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers at or beyond this index are zeroed out of the optimised mask before heads\n",
    "# are selected (cutoff chosen for this analysis).\n",
    "MASK_LAYER_CUTOFF = 52\n",
    "# A head counts as selected when its mask value exceeds this threshold.\n",
    "MASK_THRESHOLD = 0.5\n",
    "\n",
    "# The mask is indexed [layer, head]. It is rebuilt (copied) from the loaded results\n",
    "# on every run, so the in-place cutoff below is safe to re-execute.\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[MASK_LAYER_CUTOFF:, :] = 0.0\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "# Transposed so heads run along the y-axis and layers along the x-axis.\n",
    "ax.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "ax.set(title=\"Optimised head mask\", xlabel=\"layer\", ylabel=\"head\")\n",
    "\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(\n",
    "        optimal_head_mask > MASK_THRESHOLD, as_tuple=False\n",
    "    ).tolist()\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "HEADS = optimized_heads\n",
    "\n",
    "# Is the head used in the single-head experiments among the selected ones?\n",
    "(35, 19) in optimized_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a7ef0f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.tokens import find_token_range, prepare_input\n",
    "\n",
    "# string = sample.prompt()\n",
    "# substring = sample.obj\n",
    "\n",
    "# tokenized_prompt = prepare_input(\n",
    "#     prompts=string, \n",
    "#     tokenizer=mt.tokenizer, \n",
    "#     return_offsets_mapping=True,\n",
    "#     add_bos_token=\"qwen\" in mt.name.lower()\n",
    "# )\n",
    "# string = mt.tokenizer.decode(tokenized_prompt.input_ids[0], skip_special_tokens=False)\n",
    "# offset_mapping = tokenized_prompt.pop(\"offset_mapping\")[0]\n",
    "\n",
    "# ans_range = find_token_range(\n",
    "#     string=string,\n",
    "#     substring=substring,\n",
    "#     offset_mapping=offset_mapping\n",
    "# )\n",
    "# print(f\"Answer range: {ans_range}\")\n",
    "# print(f'\"{mt.tokenizer.decode(tokenized_prompt.input_ids[0][range(*ans_range)])}\"')\n",
    "# # for idx, (tok, offset_range) in enumerate(zip(tokenized_prompt.input_ids[0], offset_mapping)):\n",
    "# #     print(f\"Token {idx}: \\\"{mt.tokenizer.decode([tok])}\\\" -- {offset_range}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "# Attention patterns of the optimised heads on the sampled prompt.\n",
    "# (Alternatives: heads=HEADS, or heads=[(35, 19)] for the single head.)\n",
    "single_line_prompt = sample.prompt(option_style=\"single_line\")\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=single_line_prompt,\n",
    "    options=sample.options,\n",
    "    mt=mt,\n",
    "    heads=optimized_heads,\n",
    "    start_from=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129fe333",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "\n",
    "from src.functional import predict_next_token\n",
    "from src.selection.data import SelectionSample, get_counterfactual_samples_within_task\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "# N_DISTRACTORS is set once in the task-configuration cell. It is deliberately not\n",
    "# redefined here, so a change made there is not silently overridden."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425f6285",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "\n",
    "# Counterfactual-pair sampler registered for the current task.\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "\n",
    "# Other keyword arguments used in earlier runs (not passed now):\n",
    "#   patch_category=\"politician\", clean_category=\"actor\", distinct_options=True,\n",
    "#   patch_n_distractors=5, clean_n_distractors=5,\n",
    "#   patch_prompt_template_idx=3, clean_prompt_template_idx=3,\n",
    "#   patch_option_style=\"single_line\", clean_option_style=\"numbered\",\n",
    "#   n_options=5, patch_yes_mode=False\n",
    "patch_sample, clean_sample = counterfactual_sampler(\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    filter_by_lm_prediction=True,\n",
    "    option_style=OPTION_STYLE,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    ")\n",
    "\n",
    "print('=' * 20)\n",
    "for pair_member in (patch_sample, clean_sample):\n",
    "    print(pair_member.prompt(), \">>\", f'\"{mt.tokenizer.decode([pair_member.ans_token_id])}\"')\n",
    "\n",
    "track_meta = clean_sample.metadata\n",
    "print(\n",
    "    track_meta[\"track_type_obj\"],\n",
    "    track_meta[\"track_type_obj_idx\"],\n",
    "    mt.tokenizer.decode(track_meta[\"track_type_obj_token_id\"]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b5380b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt template of the clean counterfactual sample.\n",
    "clean_sample.prompt_template"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d35a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import MCQify_sample\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "# Hand-written options/template overriding the sampled counterfactual pair.\n",
    "MANUAL_TEMPLATE = \"<_options_>\\nFind the <_category_>\\nAnswer:\"\n",
    "\n",
    "patch_sample.options = [\"Cherry\", \"Knife\", \"Pants\", \"Car\"]\n",
    "patch_sample.prompt_template = MANUAL_TEMPLATE\n",
    "# NOTE(review): obj / obj_idx / ans_token_id of patch_sample are left as sampled;\n",
    "# confirm the sampled object is still one of the new options.\n",
    "print(patch_sample.prompt())\n",
    "\n",
    "\n",
    "CLEAN_ANSWER = \"Scooter\"  # set as the clean sample's answer object (obj)\n",
    "TRACKED_OBJ = \"Peach\"  # object whose token is tracked in the clean sample's metadata\n",
    "\n",
    "clean_sample.options = [\"Binder\", \"Peach\", \"Watch\", \"Scooter\", \"Phone\"]\n",
    "clean_sample.prompt_template = MANUAL_TEMPLATE\n",
    "# The answer is read elsewhere in this notebook as `sample.obj`; assigning to\n",
    "# `clean_sample.object` would presumably only add an unused attribute.\n",
    "clean_sample.obj = CLEAN_ANSWER\n",
    "clean_sample.obj_idx = clean_sample.options.index(CLEAN_ANSWER)\n",
    "# Keep all tracked-object metadata consistent, not just the token id.\n",
    "clean_sample.metadata[\"track_type_obj\"] = TRACKED_OBJ\n",
    "clean_sample.metadata[\"track_type_obj_idx\"] = clean_sample.options.index(TRACKED_OBJ)\n",
    "clean_sample.metadata[\"track_type_obj_token_id\"] = get_first_token_id(\n",
    "    name=TRACKED_OBJ, tokenizer=mt.tokenizer\n",
    ")\n",
    "# NOTE(review): ans_token_id is not updated here; confirm MCQify_sample derives it.\n",
    "clean_sample = MCQify_sample(sample=clean_sample, tokenizer=mt)\n",
    "print(clean_sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510772fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "from src.functional import generate_with_patch\n",
    "from src.selection.functional import verify_head_patterns\n",
    "\n",
    "# Gold sample: the patch sample with the clean sample's options; its answer token\n",
    "# is the clean sample's tracked object.\n",
    "gold_sample = copy.deepcopy(patch_sample)\n",
    "# Copy the list so the two samples do not share (and co-mutate) one options list.\n",
    "gold_sample.options = list(clean_sample.options)\n",
    "gold_sample.ans_token_id = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "# The loop variable is deliberately not `sample`: the notebook-level `sample` is\n",
    "# still read by other cells and must not be overwritten by this loop.\n",
    "for demo_sample in [patch_sample, clean_sample, gold_sample]:\n",
    "    print(demo_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([demo_sample.ans_token_id])}\"')\n",
    "    attn_pattern = verify_head_patterns(\n",
    "        prompt=demo_sample.prompt(),\n",
    "        options=demo_sample.options,\n",
    "        mt=mt,\n",
    "        heads=[(35, 19)],  # alternatives: qwen_72_heads, optimized_heads\n",
    "        query_index=-1,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4d487f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metadata of the (clean, patch) counterfactual pair.\n",
    "(clean_sample.metadata, patch_sample.metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a35a6b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import cache_q_projections\n",
    "from src.functional import interpret_logits\n",
    "from src.tokens import prepare_input, find_token_range\n",
    "\n",
    "# Query position used when a prompt contains no question mark (-3 = third-to-last token).\n",
    "FALLBACK_QUERY_POS = -3\n",
    "\n",
    "prompts = [\n",
    "    patch_sample.prompt(),\n",
    "    clean_sample.prompt(),\n",
    "]\n",
    "\n",
    "tokenized = prepare_input(\n",
    "    prompts=prompts,\n",
    "    tokenizer=mt.tokenizer,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "# Keep the offsets separately; `tokenized` is passed to the model below.\n",
    "offset_mapping = tokenized.pop(\"offset_mapping\")\n",
    "\n",
    "\n",
    "def _last_quesmark_pos(prompt, offsets, input_ids):\n",
    "    \"\"\"Return the index of the token holding the last question mark in `prompt`.\n",
    "\n",
    "    Falls back to FALLBACK_QUERY_POS when the prompt has no question mark (as with\n",
    "    hand-written templates ending in Answer:), instead of failing in find_token_range.\n",
    "    \"\"\"\n",
    "    if \"?\" not in prompt:\n",
    "        return FALLBACK_QUERY_POS\n",
    "    ques_range = find_token_range(\n",
    "        string=prompt,\n",
    "        substring=\"?\",\n",
    "        occurrence=-1,\n",
    "        offset_mapping=offsets,\n",
    "    )\n",
    "    pos = ques_range[1] - 1\n",
    "    assert mt.tokenizer.decode(input_ids[pos]).strip() == \"?\"\n",
    "    return pos\n",
    "\n",
    "\n",
    "ques_pos = [\n",
    "    _last_quesmark_pos(prompt, offsets, ids)\n",
    "    for prompt, offsets, ids in zip(prompts, offset_mapping, tokenized.input_ids)\n",
    "]\n",
    "\n",
    "q_projections, out = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=tokenized,\n",
    "    heads=HEADS,\n",
    "    token_indices=[[q_pos, -2, -1] for q_pos in ques_pos],\n",
    "    return_output=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6188d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batched inputs for the (patch, clean) prompts, after offset_mapping was popped.\n",
    "tokenized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "209511b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last-token logits of batch row 0 (patch_sample's prompt).\n",
    "patch_logits = out.logits[0][-1]\n",
    "interpret_logits(tokenizer=mt.tokenizer, logits=patch_logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42df8aa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row 1 of the batch is clean_sample's prompt (row 0 is patch_sample's), so these\n",
    "# are the clean prompt's last-token logits -- named accordingly.\n",
    "clean_out_logits = out.logits[1, -1, :]\n",
    "interpret_logits(\n",
    "    logits=clean_out_logits,\n",
    "    tokenizer=mt.tokenizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91becd4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keys of the cached query projections for batch row 1 (presumably (layer, head, token_idx)).\n",
    "q_projections[1].keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9962f8c",
   "metadata": {},
   "source": [
    "## Testing patching the query projection of a single head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed90f3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import patch_with_baukit, PatchSpec, repeat_kv, get_module_nnsight\n",
    "from src.utils.typing import TokenizerOutput, Tokenizer\n",
    "from src.selection.functional import cache_q_projections, verify_head_patterns\n",
    "from src.tokens import prepare_input, find_token_range\n",
    "\n",
    "\n",
    "def find_quesmark_pos(\n",
    "    prompt: str,\n",
    "    tokenizer: Tokenizer,\n",
    "    tokenized: TokenizerOutput | None = None,\n",
    "    offset_mapping: list[tuple[int, int]] | None = None,\n",
    "    ques_mark: str = \"?\",\n",
    ") -> int:\n",
    "    \"\"\"Return the index of the token containing the last `ques_mark` in `prompt`.\n",
    "\n",
    "    Args:\n",
    "        prompt: the raw prompt string.\n",
    "        tokenizer: used to re-tokenize `prompt` when needed and to decode the found token.\n",
    "        tokenized: optional tokenization of `prompt` (a batch of one). When\n",
    "            `offset_mapping` is not given, its offset_mapping entry is popped,\n",
    "            i.e. the caller's object is modified.\n",
    "        offset_mapping: optional character offsets of each token of `prompt`.\n",
    "        ques_mark: substring whose last occurrence is located.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the located token does not decode to `ques_mark`.\n",
    "    \"\"\"\n",
    "    # Re-tokenize when there is no tokenization to check the result against\n",
    "    # (tokenized=None is now accepted even if offset_mapping is given), or when\n",
    "    # neither argument supplies offsets.\n",
    "    if tokenized is None or (\n",
    "        offset_mapping is None and \"offset_mapping\" not in tokenized\n",
    "    ):\n",
    "        tokenized = prepare_input(\n",
    "            prompts=[prompt],\n",
    "            tokenizer=tokenizer,\n",
    "            return_offsets_mapping=True,\n",
    "        )\n",
    "    if offset_mapping is None:\n",
    "        offset_mapping = tokenized.pop(\"offset_mapping\")[0]\n",
    "\n",
    "    ques_range = find_token_range(\n",
    "        string=prompt,\n",
    "        substring=ques_mark,\n",
    "        occurrence=-1,\n",
    "        offset_mapping=offset_mapping,\n",
    "    )\n",
    "    ques_pos = ques_range[1] - 1\n",
    "    assert tokenizer.decode(tokenized.input_ids[0][ques_pos]).strip() == ques_mark\n",
    "    return ques_pos\n",
    "\n",
    "\n",
    "mt.set_attn_implementation(\"eager\")\n",
    "mt.reset_forward()\n",
    "\n",
    "test_heads = [(35, 19)]\n",
    "# test_heads = copy.deepcopy(optimized_heads)\n",
    "\n",
    "patch_tokenized = prepare_input(\n",
    "    prompts=patch_sample.prompt(),\n",
    "    tokenizer=mt,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "patch_offsets = patch_tokenized.pop(\"offset_mapping\")[0]\n",
    "\n",
    "clean_tokenized = prepare_input(\n",
    "    prompts=clean_sample.prompt(),\n",
    "    tokenizer=mt,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "clean_offsets = clean_tokenized.pop(\"offset_mapping\")[0]\n",
    "\n",
    "# Query positions patched in both prompts: the last three tokens. When both prompts\n",
    "# contain a question mark, find_quesmark_pos(prompt, mt.tokenizer, offset_mapping=...)\n",
    "# can supply the first position instead of -3.\n",
    "indices = [-3, -2, -1]\n",
    "# patch-prompt position -> clean-prompt position (identity here)\n",
    "map_indices = {idx: idx for idx in indices}\n",
    "\n",
    "q_states = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=patch_tokenized,\n",
    "    heads=test_heads,\n",
    "    token_indices=[indices],\n",
    ")[0]\n",
    "\n",
    "q_patches = [\n",
    "    PatchSpec(\n",
    "        location=(\n",
    "            mt.attn_module_name_format.format(l_idx) + \".q_proj\",\n",
    "            h_idx,\n",
    "            map_indices[patch_token_idx],\n",
    "        ),\n",
    "        patch=q_proj.squeeze(),\n",
    "    )\n",
    "    for (l_idx, h_idx, patch_token_idx), q_proj in q_states.items()\n",
    "]\n",
    "\n",
    "# The attention patterns for the patch sample should match exactly\n",
    "test_inplace_swap = verify_head_patterns(\n",
    "    prompt=patch_sample.prompt(),\n",
    "    mt=mt,\n",
    "    heads=test_heads,\n",
    "    tokenized_prompt=patch_tokenized,\n",
    "    query_patches=q_patches,\n",
    ")\n",
    "\n",
    "# Transplant the patch sample's query states into the clean prompt.\n",
    "test_predicate_swap = verify_head_patterns(\n",
    "    prompt=clean_sample.prompt(),\n",
    "    mt=mt,\n",
    "    heads=test_heads,\n",
    "    tokenized_prompt=clean_tokenized,\n",
    "    query_patches=q_patches,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0464d9eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of heads used in the patching test.\n",
    "len(test_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9286b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the tokens at each patched query position in `indices` for both prompts:\n",
    "# (position, patch-prompt token, clean-prompt token).\n",
    "[\n",
    "    (\n",
    "        idx,\n",
    "        mt.tokenizer.decode(patch_tokenized.input_ids[0][idx]),\n",
    "        mt.tokenizer.decode(clean_tokenized.input_ids[0][idx]),\n",
    "    )\n",
    "    for idx in indices\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5263a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token-id shapes of the clean and patch prompts.\n",
    "(clean_tokenized.input_ids.shape, patch_tokenized.input_ids.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf72a519",
   "metadata": {},
   "outputs": [],
   "source": [
    "import types\n",
    "from typing import Literal\n",
    "\n",
    "import baukit\n",
    "from src.functional import get_module_nnsight, PatchSpec\n",
    "from src.hooking.llama_attention import LlamaAttentionPatcher\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "\n",
    "def set_attn_implementation(mt, attn_implementation: Literal[\"sdpa\", \"eager\"]):\n",
    "    \"\"\"Set `_attn_implementation` on the model config and on every attention block's config.\"\"\"\n",
    "    mt.config._attn_implementation = attn_implementation\n",
    "    for layer_idx in range(mt.config.num_hidden_layers):\n",
    "        attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "        attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "        attn_block.config._attn_implementation = attn_implementation\n",
    "\n",
    "\n",
    "###################################################################################\n",
    "batch_size = 1  # tokenized.input_ids.shape[0]\n",
    "n_heads = mt.config.num_attention_heads\n",
    "# Use the config's explicit head_dim when it defines one (some architectures do not\n",
    "# use hidden_size // num_heads); otherwise fall back to that ratio.\n",
    "head_dim = getattr(mt.config, \"head_dim\", None) or mt.n_embd // n_heads\n",
    "query_idx = -1  # almost always the last token\n",
    "###################################################################################\n",
    "\n",
    "layer_idx, head_idx = HEADS[0]\n",
    "attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "q_proj_name = attn_block_name + \".q_proj\"\n",
    "input_ln = mt.layer_name_format.format(layer_idx) + \".input_layernorm\"\n",
    "\n",
    "patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "patch_seq_len = patch_tokenized.input_ids.shape[1]\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "clean_seq_len = clean_tokenized.input_ids.shape[1]\n",
    "\n",
    "mt.reset_forward()\n",
    "set_attn_implementation(mt, \"sdpa\")\n",
    "try:\n",
    "    # Route this layer's attention through the patcher for the traced runs below.\n",
    "    attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "    attn_block.forward = types.MethodType(\n",
    "        LlamaAttentionPatcher(block_name=attn_block_name),\n",
    "        attn_block,\n",
    "    )\n",
    "\n",
    "    # Save the layer's input-layernorm output and its per-head query projections,\n",
    "    # reshaped to (batch, n_heads, seq_len, head_dim).\n",
    "    with mt.trace(patch_tokenized):\n",
    "        patch_ln = get_module_nnsight(mt, input_ln).output.save()\n",
    "        patch_q_proj = (\n",
    "            get_module_nnsight(mt, q_proj_name)\n",
    "            .output.view(batch_size, patch_seq_len, n_heads, head_dim)\n",
    "            .transpose(1, 2)\n",
    "            .save()\n",
    "        )\n",
    "\n",
    "    with mt.trace(clean_tokenized):\n",
    "        clean_ln = get_module_nnsight(mt, input_ln).output.save()\n",
    "        clean_q_proj = (\n",
    "            get_module_nnsight(mt, q_proj_name)\n",
    "            .output.view(batch_size, clean_seq_len, n_heads, head_dim)\n",
    "            .transpose(1, 2)\n",
    "            .save()\n",
    "        )\n",
    "finally:\n",
    "    # Restore even if a trace fails, so later cells do not silently run with the\n",
    "    # monkeypatched forward / sdpa setting (assumes reset_forward undoes the patch).\n",
    "    mt.reset_forward()\n",
    "    set_attn_implementation(mt, \"eager\")\n",
    "\n",
    "patch_q_proj.shape, clean_q_proj.shape, patch_ln.shape, clean_ln.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2d10d7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import visualize_attn_matrix\n",
    "from src.functional import get_hs, interpret_logits\n",
    "\n",
    "layer_idx, head_idx = 35, 19\n",
    "# layer_idx, head_idx = 62, 1\n",
    "\n",
    "attn_matrices = {layer_idx: {}}\n",
    "attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "logit_location = (mt.lm_head_name, -1)\n",
    "\n",
    "mt.reset_forward()\n",
    "set_attn_implementation(mt, \"sdpa\")\n",
    "try:\n",
    "    # The patcher presumably stores this head's attention matrix into\n",
    "    # attn_matrices[layer_idx][head_idx], which is read after the run.\n",
    "    attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "    attn_block.forward = types.MethodType(\n",
    "        LlamaAttentionPatcher(\n",
    "            block_name=attn_block_name,\n",
    "            save_attn_for=[head_idx],\n",
    "            store_attn_matrices=attn_matrices[layer_idx],\n",
    "        ),\n",
    "        attn_block,\n",
    "    )\n",
    "    # lm_head output at the last position only; presumably (vocab_size,) after\n",
    "    # squeeze -- TODO confirm.\n",
    "    logits = get_hs(\n",
    "        mt=mt,\n",
    "        input=clean_tokenized,\n",
    "        locations=[logit_location],\n",
    "        return_dict=False,\n",
    "    ).squeeze()\n",
    "finally:\n",
    "    # Restore even on failure so later cells do not inherit the patched block /\n",
    "    # sdpa setting (assumes reset_forward undoes the forward patch).\n",
    "    mt.reset_forward()\n",
    "    set_attn_implementation(mt, \"eager\")\n",
    "\n",
    "head_matrix = attn_matrices[layer_idx][head_idx].squeeze().to(torch.float32).cpu().numpy()\n",
    "\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=head_matrix,\n",
    "    tokens=[mt.tokenizer.decode(t) for t in clean_tokenized.input_ids[0]],\n",
    "    q_index=-1,\n",
    ")\n",
    "\n",
    "interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=logits,\n",
    "    interested_tokens=[clean_sample.metadata[\"track_type_obj_token_id\"]],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b25e138e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selected head, query position, and shape of its cached patch-prompt query vector.\n",
    "(head_idx, query_idx, patch_q_proj[:, head_idx, query_idx, :].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e560d91f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import visualize_attn_matrix\n",
    "from src.functional import get_hs, interpret_logits\n",
    "\n",
    "attn_matrices = {layer_idx: {}}\n",
    "attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "logit_location = (mt.lm_head_name, -1)\n",
    "# Query vector of this head at position query_idx, cached from the patch prompt.\n",
    "patch_query = patch_q_proj[:, head_idx, query_idx, :].squeeze()\n",
    "\n",
    "mt.reset_forward()\n",
    "set_attn_implementation(mt, \"sdpa\")\n",
    "try:\n",
    "    attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "    attn_block.forward = types.MethodType(\n",
    "        LlamaAttentionPatcher(\n",
    "            block_name=attn_block_name,\n",
    "            save_attn_for=[head_idx],\n",
    "            store_attn_matrices=attn_matrices[layer_idx],\n",
    "            query_patches=[(head_idx, query_idx, patch_query)],\n",
    "        ),\n",
    "        attn_block,\n",
    "    )\n",
    "    # Runs the patch prompt; switch input (and the tokens below) to clean_tokenized\n",
    "    # to transplant the query into the clean prompt instead.\n",
    "    patch_logits = get_hs(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized,\n",
    "        locations=[logit_location],\n",
    "        return_dict=False,\n",
    "    ).squeeze()\n",
    "finally:\n",
    "    # Restore even on failure so later cells do not inherit the patched block /\n",
    "    # sdpa setting (assumes reset_forward undoes the forward patch).\n",
    "    mt.reset_forward()\n",
    "    set_attn_implementation(mt, \"eager\")\n",
    "\n",
    "head_matrix = attn_matrices[layer_idx][head_idx].squeeze().to(torch.float32).cpu().numpy()\n",
    "\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=head_matrix,\n",
    "    tokens=[mt.tokenizer.decode(t) for t in patch_tokenized.input_ids[0]],\n",
    "    q_index=-1,\n",
    ")\n",
    "\n",
    "interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=patch_logits,\n",
    "    interested_tokens=[clean_sample.metadata[\"track_type_obj_token_id\"]],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e922a85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "231d29d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do the logits of the patched run match the reference logits (within 1e-3)?\n",
    "patch_logits.allclose(logits, atol=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b92a102f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Do the clean and patch q-vectors of this head coincide at the patched query position?\n",
    "patch_q_proj[:, head_idx, query_idx, :].allclose(\n",
    "    clean_q_proj[:, head_idx, query_idx, :], atol=1e-3\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9042b71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.allclose(\n",
    "#     patch_ln[:, query_idx, :], \n",
    "#     clean_ln[:, query_idx, :],\n",
    "#     atol=1e-3\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49583481",
   "metadata": {},
   "outputs": [],
   "source": [
    "# patch_ln[:, query_idx, :], clean_ln[:, query_idx, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04ead16d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # manual calculation\n",
    "# attn_module = baukit.get_module(mt._model, mt.attn_module_name_format.format(layer_idx))\n",
    "# patch_q_proj_manual = attn_module.q_proj(patch_ln)\n",
    "# clean_q_proj_manual = attn_module.q_proj(clean_ln)\n",
    "\n",
    "# print(patch_q_proj_manual.shape, clean_q_proj_manual.shape)\n",
    "# print(torch.allclose(\n",
    "#     patch_q_proj_manual[:, query_idx, :], \n",
    "#     clean_q_proj_manual[:, query_idx, :],\n",
    "#     atol=1e-3\n",
    "# ))\n",
    "\n",
    "# patch_q_proj_manual = patch_q_proj_manual.reshape(batch_size, patch_seq_len, n_heads, head_dim).transpose(1, 2)\n",
    "# clean_q_proj_manual = clean_q_proj_manual.reshape(batch_size, clean_seq_len, n_heads, head_dim).transpose(1, 2)\n",
    "# print(patch_q_proj_manual.shape, clean_q_proj_manual.shape)\n",
    "\n",
    "# for idx in range(n_heads):\n",
    "#     print(head_idx, torch.allclose(\n",
    "#         patch_q_proj_manual[:, idx, query_idx, :], \n",
    "#         clean_q_proj_manual[:, idx, query_idx, :],\n",
    "#         atol=1e-3\n",
    "#     ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "098535e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.allclose(\n",
    "#     patch_q_proj_manual[:, head_idx, query_idx, :], \n",
    "#     patch_q_proj[:, head_idx, query_idx, :],\n",
    "#     atol=1e-3\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6de9135a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build three alternative PatchSpecs for the clean run.\n",
    "# (a) Full q_proj output: the clean q-projections (presumably shaped\n",
    "#     (batch, n_heads, seq, head_dim) — TODO confirm) with the (head_idx, query_idx)\n",
    "#     slot overwritten by the patch prompt's value.\n",
    "replace_q_proj = clean_q_proj.clone()\n",
    "replace_q_proj[:, head_idx, query_idx, :] = patch_q_proj[:, head_idx, query_idx, :]\n",
    "\n",
    "# Sanity check: True here means the patch and clean q-vectors coincide, i.e. the\n",
    "# patch would be a no-op.\n",
    "print(torch.allclose(\n",
    "        replace_q_proj[:, head_idx, query_idx, :],\n",
    "        clean_q_proj[:, head_idx, query_idx, :],\n",
    "        atol=1e-3\n",
    "    )\n",
    ")\n",
    "print(replace_q_proj.shape)\n",
    "\n",
    "# Merge the head axis back into one flat vector per position: (batch, clean_seq_len, -1).\n",
    "replace_q_proj = replace_q_proj.transpose(1, 2).reshape(batch_size, clean_seq_len, -1)\n",
    "\n",
    "# NOTE(review): values are overwritten/read at `query_idx`, but every location below\n",
    "# (and `replace_q_proj[:, -1, :]`) targets position -1; these agree only when\n",
    "# query_idx == -1 — confirm.\n",
    "rep_patch = PatchSpec(\n",
    "    location=(q_proj_name, -1),\n",
    "    patch=replace_q_proj[:, -1, :].squeeze(),\n",
    ")\n",
    "\n",
    "# (b) Output of `input_ln` (presumably the layer's input layernorm) at the query position.\n",
    "ln_patch = PatchSpec(\n",
    "    location=(input_ln, -1),\n",
    "    patch=patch_ln[:, query_idx, :].squeeze(),\n",
    ")\n",
    "\n",
    "# (c) A single head's query vector; location is (module name, head_idx, position).\n",
    "head_q_patch = PatchSpec(\n",
    "    location=(q_proj_name, head_idx, -1),\n",
    "    patch=patch_q_proj[:, head_idx, query_idx, :].squeeze(),\n",
    ")\n",
    "\n",
    "head_q_patch.location, rep_patch.location"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3d26910",
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean_q_proj_rs = clean_q_proj.view(batch_size, clean_seq_len, -1)\n",
    "# patch_q_proj_rs = patch_q_proj.view(batch_size, clean_seq_len, -1)\n",
    "# clean_q_proj_rs[:, -1, :].shape, patch_q_proj_rs[:, -1, :].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95232df1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.allclose(\n",
    "#     replace_q_proj[:, -1, :].squeeze(), \n",
    "#     # patch_q_proj_rs[:, -1, :].squeeze(),\n",
    "#     clean_q_proj_rs[:, -1, :].squeeze(), \n",
    "#     atol=1e-3\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9835953c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the cached q-projections for the patch and clean prompts\n",
    "(patch_q_proj.shape, clean_q_proj.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4454946",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Metadata of the clean sample (this notebook reads `track_type_obj_token_id` from it)\n",
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759c4ee0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices, visualize_attn_matrix\n",
    "from src.functional import interpret_logits\n",
    "\n",
    "# Baseline (no patching): attention of (layer_idx, head_idx) on the clean prompt, and\n",
    "# the clean-run logits for the answer token and the tracked object token.\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "attn_info = get_attention_matrices(input=clean_tokenized, mt=mt)\n",
    "\n",
    "attn_matrix = attn_info.attention_matrices[layer_idx, head_idx].squeeze()\n",
    "clean_token_strs = [mt.tokenizer.decode(t) for t in clean_tokenized.input_ids[0]]\n",
    "visualize_attn_matrix(attn_matrix=attn_matrix, tokens=clean_token_strs, q_index=-1)\n",
    "\n",
    "tracked_token_ids = [\n",
    "    clean_sample.ans_token_id,\n",
    "    clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "]\n",
    "interpret_logits(tokenizer=mt, logits=attn_info.logits, interested_tokens=tracked_token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "136268c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices, visualize_attn_matrix\n",
    "from src.functional import patch_with_nnsight, patch_with_baukit\n",
    "\n",
    "# Clean run with `head_q_patch` applied via baukit.\n",
    "# Alternatives: patches=[ln_patch] or [rep_patch]; patch_interface=patch_with_nnsight.\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "patched_attn_info = get_attention_matrices(\n",
    "    input=clean_tokenized,\n",
    "    mt=mt,\n",
    "    patches=[head_q_patch],\n",
    "    patch_interface=patch_with_baukit,\n",
    ")\n",
    "\n",
    "patched_attn_matrix = patched_attn_info.attention_matrices[layer_idx, head_idx].squeeze()\n",
    "patched_token_strs = [mt.tokenizer.decode(t) for t in clean_tokenized.input_ids[0]]\n",
    "visualize_attn_matrix(attn_matrix=patched_attn_matrix, tokens=patched_token_strs, q_index=-1)\n",
    "\n",
    "interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=patched_attn_info.logits,\n",
    "    interested_tokens=[\n",
    "        clean_sample.ans_token_id,\n",
    "        clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bc89a66",
   "metadata": {},
   "source": [
    "## Patching a bunch of heads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "495ee03a",
   "metadata": {},
   "source": [
    "### Loading the Heads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdeda2be",
   "metadata": {},
   "source": [
    "#### Attention Behavior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bdb4480",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from scripts.locate_via_attention_behavior import SelectionSampleAttn\n",
    "\n",
    "attn_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/attention_patterns/select_one\",\n",
    "    # mt.name.split(\"/\")[-1],\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"objects\"\n",
    ")\n",
    "# Only .npz files hold attention patterns: filter *before* applying LIMIT so both the\n",
    "# limit and the progress denominator count actual samples rather than stray files.\n",
    "files = sorted(f for f in os.listdir(attn_path) if f.endswith(\".npz\"))\n",
    "print(f\"Found {len(files)} .npz files in {attn_path}\")\n",
    "\n",
    "#######################################################################\n",
    "# LIMIT = 128\n",
    "LIMIT = len(files)\n",
    "#######################################################################\n",
    "\n",
    "n_to_load = min(LIMIT, len(files))\n",
    "selection_attns = []\n",
    "\n",
    "for npz_file in tqdm(files[:n_to_load]):\n",
    "    npz_path = os.path.join(attn_path, npz_file)\n",
    "    selection_attns.append(SelectionSampleAttn.from_npz(npz_path))\n",
    "    if len(selection_attns) % 128 == 0:\n",
    "        print(f\"Loaded {len(selection_attns)}/{n_to_load} files\")\n",
    "\n",
    "len(selection_attns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e494e1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import visualize_attn_matrix\n",
    "\n",
    "# Head of interest per model: keep the line matching `model_key` uncommented.\n",
    "layer_idx, head_idx = 35, 19  # llama-70B\n",
    "# layer_idx, head_idx = 54, 44  # qwen-72B\n",
    "# layer_idx, head_idx = 51, 11  # qwen-32B\n",
    "# layer_idx, head_idx = 29, 3  # gemma-27b\n",
    "\n",
    "sample_idx = 3\n",
    "selection_attn = selection_attns[sample_idx]\n",
    "sample_pattern = selection_attn.attention_pattern\n",
    "\n",
    "print(selection_attn.resolution_score(layer_idx, head_idx))\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=sample_pattern.attention_matrices[layer_idx, head_idx],\n",
    "    tokens=sample_pattern.tokenized_prompt,\n",
    "    q_index=-1,\n",
    "    start_from=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2c5eb83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the stored attention tensor; the resolution-score cell below reads\n",
    "# axis 0 as the number of layers and axis 1 as the number of heads.\n",
    "selection_attns[0].attention_pattern.attention_matrices.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80f22968",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-option scores of the head selected above (`layer_idx`, `head_idx`) instead of a\n",
    "# hard-coded (35, 19), so switching the model/head line above updates this cell too.\n",
    "selection_attn.score_per_option(\n",
    "    layer_idx=layer_idx, head_idx=head_idx, query_idx=-1,\n",
    "    value_weighted=True,\n",
    "    include_delim=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28d6cd52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resolution score of the head selected above (`layer_idx`, `head_idx`) instead of a\n",
    "# hard-coded (35, 19), so switching the model/head line above updates this cell too.\n",
    "selection_attn.resolution_score(\n",
    "    layer_idx=layer_idx, head_idx=head_idx, query_idx=-1,\n",
    "    value_weighted=True,\n",
    "    include_delim=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0631f96d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt text of the sample currently being inspected (`selection_attn`)\n",
    "selection_attn.sample.prompt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d48a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "#############################################################################\n",
    "n_layer = selection_attns[0].attention_pattern.attention_matrices.shape[0]\n",
    "n_head = selection_attns[0].attention_pattern.attention_matrices.shape[1]\n",
    "# token_idx = \"all\"\n",
    "token_idx = \"last\"\n",
    "##############################################################################\n",
    "\n",
    "# Mean resolution score over all samples, indexed [head, layer].\n",
    "# Loop variables are deliberately NOT named selection_attn / layer_idx / head_idx so\n",
    "# this cell does not overwrite the sample and head picked for inspection above.\n",
    "resolution_scores = torch.zeros((n_head, n_layer), dtype=torch.float32)\n",
    "for sample_attn in tqdm(selection_attns):\n",
    "    for lyr in range(n_layer):\n",
    "        for hd in range(n_head):\n",
    "            resolution_scores[hd, lyr] += sample_attn.resolution_score(\n",
    "                lyr, hd, token_idx=token_idx,\n",
    "                # value_weighted=True,\n",
    "                include_delim=True\n",
    "            )[0]\n",
    "            # resolution_scores[hd, lyr] += sample_attn.first_token_score(\n",
    "            #     lyr, hd\n",
    "            # )[0]\n",
    "\n",
    "resolution_scores /= len(selection_attns)\n",
    "resolution_scores.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d417ce11",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Heatmap of the mean resolution score (rows = heads, columns = layers) on a symmetric\n",
    "# colour range, then the top heads and a JSON dump of the full ranking.\n",
    "plt.figure(figsize=(20, 10))\n",
    "scale = torch.max(torch.abs(resolution_scores))\n",
    "plt.imshow(\n",
    "    resolution_scores.cpu().numpy(),\n",
    "    cmap=\"RdBu\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=-scale,\n",
    "    vmax=scale,\n",
    ")\n",
    "plt.colorbar()\n",
    "# plt.title(f\"score(target) - max(score(distractors)) | {token_idx.upper()} tokens of options\")\n",
    "# NOTE(review): confirm this title matches what `resolution_score` actually computes.\n",
    "plt.title(\"score(target[0]) - sum(score(target[1:]))\")\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Head\")\n",
    "\n",
    "\n",
    "def get_ticks(ticks, skip=5):\n",
    "    \"\"\"Tick labels: str(i) for every `skip`-th index, empty strings elsewhere.\"\"\"\n",
    "    ret = []\n",
    "    for i in ticks:\n",
    "        if i % skip == 0:\n",
    "            ret.append(str(i))\n",
    "        else:\n",
    "            ret.append(\"\")\n",
    "    return ret\n",
    "\n",
    "\n",
    "plt.xticks(\n",
    "    ticks=range(n_layer),\n",
    "    labels=get_ticks(range(n_layer)),\n",
    "    rotation=45,\n",
    ")\n",
    "plt.yticks(\n",
    "    ticks=range(n_head),\n",
    "    labels=get_ticks(range(n_head), skip=4),\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# NOTE: entries are (head_idx, layer_idx, score), the reverse of the\n",
    "# (layer_idx, head_idx, score) order used for aie_per_head.json later on.\n",
    "# These loops also overwrite the global `head_idx` / `layer_idx`.\n",
    "scores_per_head = []\n",
    "for head_idx in range(n_head):\n",
    "    for layer_idx in range(n_layer):\n",
    "        scores_per_head.append(\n",
    "            (head_idx, layer_idx, resolution_scores[head_idx, layer_idx].item())\n",
    "        )\n",
    "\n",
    "scores_per_head = sorted(scores_per_head, key=lambda x: x[2], reverse=True)\n",
    "for head_idx, layer_idx, score in scores_per_head[:15]:\n",
    "    print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")\n",
    "\n",
    "save_dir = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"raw\")\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "with open(os.path.join(save_dir, \"attention_pattern.json\"), \"w\") as f:\n",
    "    json.dump(scores_per_head, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "172712e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `scores_per_head` from the attention-behaviour cell holds (head_idx, layer_idx, score)\n",
    "# tuples, so unpack in that order to build (layer_idx, head_idx) pairs like HEADS.\n",
    "heads_attn_behavior = [\n",
    "    (layer_idx, head_idx) for head_idx, layer_idx, _score in scores_per_head[:100]\n",
    "]\n",
    "print(heads_attn_behavior)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ac057c5",
   "metadata": {},
   "source": [
    "#### Based on Patching Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d18ed5f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from src.functional import interpret_logits, PredictedToken\n",
    "from dataclasses import dataclass\n",
    "from dataclasses_json import DataClassJsonMixin\n",
    "from typing import Literal\n",
    "\n",
    "@dataclass\n",
    "class SelectionQprojPatchResult(DataClassJsonMixin):\n",
    "    \"\"\"Result of patching the q-projections of individual heads for one clean/patch pair.\n",
    "\n",
    "    `headwise_patching_effects` maps (layer_idx, head_idx) to the per-token results seen\n",
    "    after patching that head; `base_results` and `gold_results` supply the low and high\n",
    "    reference scores used by `head_effect`.\n",
    "    \"\"\"\n",
    "\n",
    "    patch_sample: SelectionSample\n",
    "    clean_sample: SelectionSample\n",
    "    interested_tokens: list[int]\n",
    "    # token_id -> (int, PredictedToken); the int is presumably a rank — TODO confirm\n",
    "    base_results: dict[int, tuple[int, PredictedToken]]\n",
    "    # (layer_idx, head_idx) -> {token_id -> (int, PredictedToken)} after patching that head\n",
    "    headwise_patching_effects: dict[\n",
    "        tuple[int, int], dict[int, tuple[int, PredictedToken]]\n",
    "    ]\n",
    "    # Optional upper reference; only used when head_effect(normalize=True)\n",
    "    gold_results: dict[int, tuple[int, PredictedToken]] | None = None\n",
    "\n",
    "    def __post_init__(self):\n",
    "        \"\"\"Resolve the tracked token id from the clean sample's metadata.\n",
    "\n",
    "        Prefers `track_type_obj_token_id`, falls back to `track_obj_token_id`, and\n",
    "        raises AssertionError when neither key is present.\n",
    "        \"\"\"\n",
    "        if \"track_type_obj_token_id\" in self.clean_sample.metadata:\n",
    "            self.track_obj_token_id = self.clean_sample.metadata[\n",
    "                \"track_type_obj_token_id\"\n",
    "            ]\n",
    "        elif \"track_obj_token_id\" in self.clean_sample.metadata:\n",
    "            self.track_obj_token_id = self.clean_sample.metadata[\"track_obj_token_id\"]\n",
    "        else:\n",
    "            raise AssertionError(\"Set `track_obj_token_id` in metadata of clean sample\")\n",
    "\n",
    "    def head_effect(\n",
    "        self,\n",
    "        layer_idx,\n",
    "        head_idx,\n",
    "        metric: Literal[\"prob\", \"logit\"] = \"logit\",\n",
    "        normalize=True,\n",
    "    ):\n",
    "        \"\"\"Indirect effect of patching head (layer_idx, head_idx) on the tracked token.\n",
    "\n",
    "        Returns `patch_score - low_score`. `low_score` comes from `base_results` and\n",
    "        `patch_score` from `headwise_patching_effects`, both read through the `metric`\n",
    "        attribute (\"prob\" or \"logit\") of the stored PredictedToken. When `normalize` is\n",
    "        True and `gold_results` is set, the difference is divided by\n",
    "        `high_score - low_score`, where `high_score` is taken from `gold_results` at the\n",
    "        patch sample's `ans_token_id`.\n",
    "\n",
    "        NOTE(review): there is no guard for high_score == low_score; with plain Python\n",
    "        floats the normalized branch would raise ZeroDivisionError.\n",
    "        \"\"\"\n",
    "        # base_results may be a dict keyed by token id or a single (int, PredictedToken) pair\n",
    "        if isinstance(self.base_results, dict):\n",
    "            low_score = getattr(self.base_results[self.track_obj_token_id][1], metric)\n",
    "        else:\n",
    "            low_score = getattr(self.base_results[1], metric)\n",
    "\n",
    "        patch_score = getattr(\n",
    "            self.headwise_patching_effects[(layer_idx, head_idx)][\n",
    "                self.track_obj_token_id\n",
    "            ][1],\n",
    "            metric,\n",
    "        )\n",
    "\n",
    "        # logger.debug(f\"{low_score=}, {high_score=}, {patch_score=}\")\n",
    "        if normalize and self.gold_results is not None:\n",
    "            high_score = getattr(\n",
    "                self.gold_results[self.patch_sample.ans_token_id][1], metric\n",
    "            )\n",
    "            indirect_effect = (patch_score - low_score) / (high_score - low_score)\n",
    "        else:\n",
    "            indirect_effect = patch_score - low_score\n",
    "        return indirect_effect\n",
    "\n",
    "    def delist_patching_effects(self):\n",
    "        \"\"\"Replace the (layer_idx, head_idx) tuple keys with \"{layer_idx}_<>_{head_idx}\" strings.\n",
    "\n",
    "        Mutates `headwise_patching_effects` in place, presumably so the result can be\n",
    "        serialized to JSON (object keys must be strings); `load_from_json` reverses it.\n",
    "        \"\"\"\n",
    "        self.headwise_patching_effects = {\n",
    "            f\"{layer_idx}_<>_{head_idx}\": effect\n",
    "            for (layer_idx, head_idx), effect in self.headwise_patching_effects.items()\n",
    "        }\n",
    "\n",
    "    @staticmethod\n",
    "    def load_from_json(file_path: str) -> \"SelectionQprojPatchResult\":\n",
    "        \"\"\"Load a result from a JSON file written with string (delisted) head keys.\n",
    "\n",
    "        Converts \"L_<>_H\" keys back to (layer_idx, head_idx) tuples and copies a legacy\n",
    "        `obj_token_id` field onto `ans_token_id` for both samples before decoding.\n",
    "        \"\"\"\n",
    "        with open(file_path, \"r\") as f:\n",
    "            data = json.load(f)\n",
    "        head_wise_patching_effects = {}\n",
    "        for key, value in data[\"headwise_patching_effects\"].items():\n",
    "            layer_idx, head_idx = map(int, key.split(\"_<>_\"))\n",
    "            head_wise_patching_effects[(layer_idx, head_idx)] = value\n",
    "        data[\"headwise_patching_effects\"] = head_wise_patching_effects\n",
    "        # Backward compatibility: older dumps stored the answer token as `obj_token_id`\n",
    "        if \"obj_token_id\" in data[\"patch_sample\"]:\n",
    "            data[\"patch_sample\"][\"ans_token_id\"] = data[\"patch_sample\"][\"obj_token_id\"]\n",
    "        if \"obj_token_id\" in data[\"clean_sample\"]:\n",
    "            data[\"clean_sample\"][\"ans_token_id\"] = data[\"clean_sample\"][\"obj_token_id\"]\n",
    "        return SelectionQprojPatchResult.from_dict(data)\n",
    "\n",
    "\n",
    "q_proj_root = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/q_states_patching\",\n",
    "    # mt.name.split(\"/\")[-1],\n",
    "    model_key.split(\"/\")[-1],\n",
    ")\n",
    "\n",
    "#############################################\n",
    "# LIMIT = 10\n",
    "LIMIT = None  # max result files per category; None loads every file\n",
    "# n_layer = mt.n_layer\n",
    "# n_head = mt.config.num_attention_heads\n",
    "categories = [\n",
    "    # \"profession\",\n",
    "    # \"nationality\",\n",
    "    \"objects\",\n",
    "]\n",
    "#############################################\n",
    "\n",
    "q_proj_results = {cat: [] for cat in categories}\n",
    "\n",
    "for category in categories:\n",
    "    print(f\"category: {category}\")\n",
    "    q_proj_path = os.path.join(q_proj_root, category)\n",
    "    # Filter to result files before applying the limit so it counts samples only.\n",
    "    files = sorted(f for f in os.listdir(q_proj_path) if f.endswith(\".json\"))\n",
    "    # Resolve the limit per category. Reassigning the global LIMIT here would cap\n",
    "    # every later category at the first category's file count.\n",
    "    category_limit = LIMIT or len(files)\n",
    "    for file in tqdm(files[:category_limit]):\n",
    "        file_path = os.path.join(q_proj_path, file)\n",
    "        q_proj_results[category].append(SelectionQprojPatchResult.load_from_json(file_path))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41757b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: also pool all categories under an \"all\" key (uncomment to enable).\n",
    "# import copy\n",
    "# combine_all_results = []\n",
    "# for category in categories:\n",
    "#     combine_all_results.extend(q_proj_results[category])\n",
    "# results_copy = copy.deepcopy(q_proj_results)\n",
    "# results_copy[\"all\"] = combine_all_results\n",
    "\n",
    "# NOTE: despite its name, `results_copy` is an alias of `q_proj_results`, not a copy.\n",
    "results_copy = q_proj_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c8d728",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer and head indices that were patched, taken from the first \"objects\" result.\n",
    "_patched_keys = list(results_copy[\"objects\"][0].headwise_patching_effects)\n",
    "layers = {layer for layer, _ in _patched_keys}\n",
    "heads = {head for _, head in _patched_keys}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b809f4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 30\n",
    "MEDIUM_SIZE = 35\n",
    "BIGGER_SIZE = 40\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=BIGGER_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "fig_save_path = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"aie\")\n",
    "os.makedirs(fig_save_path, exist_ok=True)\n",
    "\n",
    "n_layer = len(layers)\n",
    "n_heads = len(heads)\n",
    "# Effects are stored and plotted by position, so the patched layers/heads must be\n",
    "# exactly 0..n-1; fail loudly instead of silently mislabelling the heatmap.\n",
    "assert layers == set(range(n_layer)), \"patched layers are not contiguous from 0\"\n",
    "assert heads == set(range(n_heads)), \"patched heads are not contiguous from 0\"\n",
    "\n",
    "# Symmetric colour range (logit units) shared by all category heatmaps.\n",
    "AIE_VMAX = 2\n",
    "\n",
    "\n",
    "def get_ticks(ticks, skip=5):\n",
    "    \"\"\"Tick labels: str(i) for every `skip`-th index, empty strings elsewhere.\"\"\"\n",
    "    return [str(i) if i % skip == 0 else \"\" for i in ticks]\n",
    "\n",
    "\n",
    "category_wise_heads = {}\n",
    "for category, categorywise_result in results_copy.items():\n",
    "    # Mean (unnormalised) indirect effect of patching each head's q-projection.\n",
    "    indirect_effects = torch.zeros((n_layer, n_heads), dtype=torch.float32)\n",
    "    for layer_idx in range(n_layer):\n",
    "        for head_idx in range(n_heads):\n",
    "            indirect_effects[layer_idx, head_idx] = torch.mean(\n",
    "                torch.tensor(\n",
    "                    [\n",
    "                        sample_result.head_effect(layer_idx, head_idx, normalize=False)\n",
    "                        for sample_result in categorywise_result\n",
    "                    ]\n",
    "                )\n",
    "            )\n",
    "\n",
    "    plt.figure(figsize=(20, 10))\n",
    "    scale = torch.max(torch.abs(indirect_effects))  # data-driven alternative to AIE_VMAX\n",
    "    plt.imshow(\n",
    "        indirect_effects.T.cpu().numpy(),\n",
    "        cmap=\"RdBu\",\n",
    "        aspect=\"auto\",\n",
    "        # vmin=-scale,\n",
    "        # vmax=scale,\n",
    "        # vmin must not exceed vmax, otherwise matplotlib's Normalize raises ValueError.\n",
    "        vmin=-AIE_VMAX,\n",
    "        vmax=AIE_VMAX,\n",
    "    )\n",
    "    plt.colorbar()\n",
    "    # plt.title(\"IE of q_proj patching | \" + category)\n",
    "    plt.xlabel(\"Layer\")\n",
    "    plt.ylabel(\"Head Index\")\n",
    "\n",
    "    plt.xticks(\n",
    "        ticks=range(n_layer),\n",
    "        labels=get_ticks(range(n_layer)),\n",
    "    )\n",
    "    plt.yticks(\n",
    "        ticks=range(n_heads),\n",
    "        labels=get_ticks(range(n_heads), skip=8),\n",
    "    )\n",
    "\n",
    "    ax = plt.gca()\n",
    "\n",
    "    # Outline the selected heads. x is the layer and y is the head because the matrix is\n",
    "    # transposed above.\n",
    "    # NOTE(review): `optimized_heads` is not defined in this cell. It must already exist in\n",
    "    # the kernel as (layer_idx, head_idx) pairs.\n",
    "    for (x, y) in optimized_heads:\n",
    "        # (x - 0.5, y - 0.5) with width/height 1 covers exactly one heatmap cell\n",
    "        rect = patches.Rectangle(\n",
    "            (x - 0.5, y - 0.5),  # bottom-left corner\n",
    "            1,  # width\n",
    "            1,  # height\n",
    "            linewidth=1.5,\n",
    "            edgecolor='black',\n",
    "            facecolor='none',  # border only, no fill\n",
    "        )\n",
    "        ax.add_patch(rect)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(fig_save_path, f\"{category}.pdf\"), bbox_inches='tight', pad_inches=0)\n",
    "    plt.show()\n",
    "\n",
    "    # Ranking as (layer_idx, head_idx, score), highest effect first.\n",
    "    scores_per_head = []\n",
    "    for layer_idx in range(n_layer):\n",
    "        for head_idx in range(n_heads):\n",
    "            scores_per_head.append(\n",
    "                (layer_idx, head_idx, indirect_effects[layer_idx, head_idx].item())\n",
    "            )\n",
    "\n",
    "    scores_per_head = sorted(scores_per_head, key=lambda x: x[2], reverse=True)\n",
    "    category_wise_heads[category] = scores_per_head\n",
    "    for layer_idx, head_idx, score in scores_per_head[:15]:\n",
    "        print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")\n",
    "\n",
    "with open(\"category_wise_heads.json\", \"w\") as f:\n",
    "    json.dump(category_wise_heads, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a61518b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the ranking of the last category processed above as (layer_idx, head_idx, score).\n",
    "save_path = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"raw\")\n",
    "os.makedirs(save_path, exist_ok=True)\n",
    "aie_json_path = os.path.join(save_path, \"aie_per_head.json\")\n",
    "with open(aie_json_path, \"w\") as f:\n",
    "    json.dump(scores_per_head, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd99c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Full ranking of heads by mean indirect effect (last category processed above)\n",
    "for rank_entry in scores_per_head:\n",
    "    layer_idx, head_idx, score = rank_entry\n",
    "    print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ff71ba6",
   "metadata": {},
   "source": [
    "#### Performing the Patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24c99868",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "from src.functional import get_module_nnsight\n",
    "\n",
    "# Candidate head sets for the multi-head patch, as (layer_idx, head_idx) pairs.\n",
    "# NOTE(review): every `HEADS = ...` assignment below is commented out, but later cells\n",
    "# use `HEADS`. Unless an earlier cell defines it, Restart & Run All fails with a\n",
    "# NameError, so uncomment one option before running.\n",
    "# HEADS = [\n",
    "#     (33, 45),\n",
    "#     (33, 18),\n",
    "#     (34, 1),\n",
    "#     (34, 6),\n",
    "#     (34, 7),\n",
    "#     (35, 19),\n",
    "#     (39, 40),\n",
    "#     (42, 30),\n",
    "#     (47, 18),\n",
    "#     (52, 58),\n",
    "# ]\n",
    "\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:50]\n",
    "# ]\n",
    "\n",
    "# HEADS = heads_selected\n",
    "\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "\n",
    "# category_wise_heads[\"all\"][len(HEADS) - 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98815bf3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.typing import TokenizerOutput\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def cache_q_projections(\n",
    "    mt: ModelandTokenizer,\n",
    "    input: TokenizerOutput,\n",
    "    query_locations: list[tuple[int, int, int]],  # (layer_idx, head_idx, query_idx)\n",
    "    return_output: bool = False,\n",
    "):\n",
    "    \"\"\"Cache per-head query vectors (outputs of each layer's `q_proj`) in one traced pass.\n",
    "\n",
    "    Args:\n",
    "        mt: model wrapper supporting `mt.trace(...)` and `attn_module_name_format`.\n",
    "        input: tokenized prompt to run.\n",
    "        query_locations: (layer_idx, head_idx, query_idx) triples to cache.\n",
    "        return_output: if True, also return the saved model output of the same pass.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each (layer_idx, head_idx, query_idx) to the saved, squeezed\n",
    "        q-vector; plus the saved model output when `return_output` is True.\n",
    "    \"\"\"\n",
    "    # Group the requested (head, query) pairs by layer so each q_proj is read once.\n",
    "    layer_to_hq = {}\n",
    "    for layer_idx, head_idx, query_idx in query_locations:\n",
    "        layer_to_hq.setdefault(layer_idx, []).append((head_idx, query_idx))\n",
    "\n",
    "    q_projections = {}\n",
    "    batch_size, seq_len = input.input_ids.shape[:2]\n",
    "    n_heads = mt.config.num_attention_heads\n",
    "    with mt.trace(input) as tracer:\n",
    "        for layer_idx, query_locs in layer_to_hq.items():\n",
    "            q_proj_name = mt.attn_module_name_format.format(layer_idx) + \".q_proj\"\n",
    "            q_proj_module = get_module_nnsight(mt, q_proj_name)\n",
    "            # Split the flat q_proj output into heads, inferring head_dim (-1) rather than\n",
    "            # using hidden_size // n_heads. Some models (e.g. Gemma-2) have a head_dim that\n",
    "            # differs from hidden_size / num_attention_heads, and that view() would fail.\n",
    "            q_proj_out = q_proj_module.output.view(\n",
    "                batch_size, seq_len, n_heads, -1\n",
    "            ).transpose(1, 2)  # (batch, n_heads, seq_len, head_dim)\n",
    "            for head_idx, query_idx in query_locs:\n",
    "                q_projections[(layer_idx, head_idx, query_idx)] = (\n",
    "                    q_proj_out[:, head_idx, query_idx, :].squeeze().save()\n",
    "                )\n",
    "\n",
    "        if return_output:\n",
    "            output = mt.output.save()\n",
    "\n",
    "    if return_output:\n",
    "        return q_projections, output\n",
    "    return q_projections"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7bf27ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_indices = [-3, -2, -1]  # the last three token positions\n",
    "query_locations = [\n",
    "    (lyr, hd, q_pos)\n",
    "    for lyr, hd in HEADS\n",
    "    for q_pos in query_indices\n",
    "]\n",
    "\n",
    "# q-vectors of the *patch* prompt, keyed by (layer_idx, head_idx, query_idx)\n",
    "cached_q_states = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=patch_tokenized,\n",
    "    query_locations=query_locations,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5508d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices, visualize_attn_matrix\n",
    "from src.functional import interpret_logits\n",
    "\n",
    "# Unpatched baseline: attention behaviour of the selected HEADS on the clean prompt.\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=clean_tokenized,\n",
    "    options=clean_sample.options,\n",
    "    pivot=clean_sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    "    generate_full_answer=True,\n",
    ")\n",
    "\n",
    "attn_pattern[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace4ff45",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices, visualize_attn_matrix\n",
    "from src.functional import patch_with_nnsight, patch_with_baukit, PatchSpec\n",
    "\n",
    "q_proj_patches = []\n",
    "for (layer_idx, head_idx, query_idx), q_proj in cached_q_states.items():\n",
    "    q_proj_patches.append(\n",
    "        PatchSpec(\n",
    "            location=(mt.attn_module_name_format.format(layer_idx) + \".q_proj\", head_idx, query_idx),\n",
    "            patch=q_proj\n",
    "        )\n",
    "    )\n",
    "\n",
    "# patched_attn_info = get_attention_matrices(\n",
    "#     input=clean_tokenized,\n",
    "#     mt=mt,\n",
    "#     patches=q_proj_patches,\n",
    "#     patch_interface=patch_with_baukit\n",
    "# )\n",
    "\n",
    "# layer_idx, head_idx = 35, 19\n",
    "# patched_attn_matrix = patched_attn_info.attention_matrices[layer_idx, head_idx].squeeze()\n",
    "# visualize_attn_matrix(\n",
    "#     attn_matrix=patched_attn_matrix,\n",
    "#     tokens=[mt.tokenizer.decode(t) for t in clean_tokenized.input_ids[0]],\n",
    "#     q_index=-1,\n",
    "# )\n",
    "\n",
    "# interpret_logits(\n",
    "#     tokenizer=mt,\n",
    "#     logits=patched_attn_info.logits,\n",
    "#     interested_tokens=[clean_sample.obj_token_id, clean_sample.metadata[\"track_type_obj_token_id\"]]\n",
    "# )\n",
    "\n",
    "patched_attn_pattern = verify_head_patterns(\n",
    "    prompt = clean_tokenized,\n",
    "    options = clean_sample.options,\n",
    "    pivot = clean_sample.subj,\n",
    "    mt = mt,\n",
    "    heads = HEADS,\n",
    "    # heads = patching_heads,\n",
    "    query_patches=q_proj_patches,\n",
    "    # generate_full_answer=True,\n",
    ")\n",
    "\n",
    "patched_attn_pattern[\"predictions\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1972e073",
   "metadata": {},
   "source": [
    "### Search over layers and heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c45342d",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbbcd7aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.patching_within_task import SelectionQprojPatchResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1d94c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "from src.functional import patch_with_baukit, interpret_logits\n",
    "from src.selection.functional import cache_q_projections\n",
    "\n",
    "all_heads = list(product(range(20, 30), range(mt.config.num_attention_heads)))\n",
    "query_indices = {-3: -3, -2: -2, -1: -1}\n",
    "\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "\n",
    "interested_tokens = [\n",
    "    patch_sample.ans_token_id,\n",
    "    clean_sample.ans_token_id,\n",
    "    clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "]\n",
    "\n",
    "\n",
    "query_locations = [\n",
    "    (layer_idx, head_idx, patch_query_idx)\n",
    "    for layer_idx, head_idx in all_heads\n",
    "    for patch_query_idx in query_indices.keys()\n",
    "]\n",
    "\n",
    "all_q_projections, patch_out = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=patch_tokenized,\n",
    "    query_locations=query_locations,\n",
    "    return_output=True,\n",
    ")\n",
    "logger.debug(len(all_q_projections))\n",
    "\n",
    "patch_logits = patch_out.logits[:, -1, :].squeeze()\n",
    "patch_precitions, patch_track = interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=patch_logits,\n",
    "    interested_tokens=interested_tokens,\n",
    ")\n",
    "\n",
    "patch_precitions, patch_track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97dcd3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_out = patch_with_baukit(\n",
    "    mt=mt,\n",
    "    inputs=clean_tokenized,\n",
    "    patches=[],\n",
    ")\n",
    "\n",
    "base_logits = clean_out.logits[:, -1, :].squeeze()\n",
    "base_predictions, base_track = interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=base_logits,\n",
    "    interested_tokens=interested_tokens,\n",
    ")\n",
    "base_predictions, base_track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9db38e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import PatchSpec\n",
    "q_proj_patches = []\n",
    "for (layer_idx, head_idx, patch_query_idx), q_proj in all_q_projections.items():\n",
    "    q_proj_patches.append(\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                head_idx,\n",
    "                query_indices[patch_query_idx],\n",
    "            ),\n",
    "            patch=q_proj,\n",
    "        )\n",
    "    )\n",
    "\n",
    "int_out = patch_with_baukit(\n",
    "    mt = mt,\n",
    "    inputs = clean_tokenized,\n",
    "    patches = q_proj_patches,\n",
    ")\n",
    "\n",
    "logits = int_out.logits[:, -1, :].squeeze()\n",
    "\n",
    "interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=logits,\n",
    "    interested_tokens=interested_tokens\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dede3bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(all_q_projections)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c7b361b",
   "metadata": {},
   "outputs": [],
   "source": [
    "list(all_q_projections.keys())[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90779a00",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "head_wise_patching_effects = {}\n",
    "\n",
    "for (layer_idx, head_idx) in tqdm(all_heads):\n",
    "    q_proj_patch = []\n",
    "    for patch_tok_idx, clean_tok_idx in query_indices.items():\n",
    "        q_proj_patch.append(\n",
    "            PatchSpec(\n",
    "                location=(\n",
    "                    mt.attn_module_name_format.format(layer_idx) + \".q_proj\", \n",
    "                    head_idx, \n",
    "                    query_indices[patch_tok_idx]\n",
    "                ),\n",
    "                patch=all_q_projections[(layer_idx, head_idx, patch_tok_idx)]\n",
    "            )\n",
    "        )\n",
    "    out = patch_with_baukit(\n",
    "        mt = mt,\n",
    "        inputs = clean_tokenized,\n",
    "        patches = q_proj_patch,\n",
    "    )\n",
    "    logits = out.logits[:, -1, :].squeeze()\n",
    "    predictions, track = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=logits,\n",
    "        interested_tokens=interested_tokens\n",
    "    )\n",
    "    head_wise_patching_effects[(layer_idx, head_idx)] = track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd54db35",
   "metadata": {},
   "outputs": [],
   "source": [
    "patching_results = SelectionQprojPatchResult(\n",
    "    patch_sample=patch_sample,\n",
    "    clean_sample=clean_sample,\n",
    "    interested_tokens=interested_tokens,\n",
    "    base_results=base_track,\n",
    "    gold_results=patch_track,\n",
    "    headwise_patching_effects=head_wise_patching_effects\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a063e814",
   "metadata": {},
   "outputs": [],
   "source": [
    "patching_results.head_effect(layer_idx=25, head_idx=19)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b454208a",
   "metadata": {},
   "outputs": [],
   "source": [
    "headwise_scores = [\n",
    "    (\n",
    "        layer_idx,\n",
    "        head_idx,\n",
    "        patching_results.head_effect(layer_idx, head_idx)\n",
    "    )\n",
    "    for layer_idx, head_idx in head_wise_patching_effects.keys()\n",
    "]\n",
    "\n",
    "headwise_scores = sorted(headwise_scores, key=lambda x: x[2], reverse=True)\n",
    "patching_heads = []\n",
    "for layer_idx, head_idx, score in headwise_scores[:15]:\n",
    "    print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")\n",
    "    patching_heads.append((layer_idx, head_idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d8e0d4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "patching_results.patch_sample.metadata.pop(\"tokenized\")\n",
    "patching_results.clean_sample.metadata.pop(\"tokenized\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35265af0",
   "metadata": {},
   "outputs": [],
   "source": [
    "patching_results.delist_patching_effects()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd99834d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"patching_results.json\", \"w\") as f:\n",
    "    json.dump(patching_results.to_dict(), f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f942ccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"patching_results.json\", \"r\") as f:\n",
    "    loaded_results = json.load(f)\n",
    "\n",
    "loaded_results[\"headwise_patching_effects\"] = {\n",
    "    (int(layer_idx.split(\"_<>_\")[0]), int(layer_idx.split(\"_<>_\")[1])): effect\n",
    "    for layer_idx, effect in loaded_results[\"headwise_patching_effects\"].items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fff3b5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_results[\"headwise_patching_effects\"].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27e0951b",
   "metadata": {},
   "outputs": [],
   "source": [
    "patching_results_loaded = SelectionQprojPatchResult.from_dict(loaded_results)\n",
    "patching_results_loaded.head_effect(layer_idx=25, head_idx=19)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e661ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt.config.num_attention_heads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7891c437",
   "metadata": {},
   "source": [
    "### Optimization to select heads to patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51a3e085",
   "metadata": {},
   "outputs": [],
   "source": [
    "#################################################################################\n",
    "train_limit = 512\n",
    "# prompt_template_idx = 1\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4ce17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "from src.selection.data import CounterFactualSamplePair\n",
    "\n",
    "\n",
    "train_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"train\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(train_samples_save_path, exist_ok=True)\n",
    "\n",
    "train_set = []\n",
    "while len(train_set) < train_limit:\n",
    "    print(f\"sample {len(train_set)+1} / {train_limit}\")\n",
    "    patch, clean = counterfactual_sampler(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        # distinct_options=True,\n",
    "        # # n_distractors=N_DISTRACTORS,\n",
    "        # patch_n_distractors=N_DISTRACTORS,\n",
    "        # clean_n_distractors=N_DISTRACTORS\n",
    "        n_options = random.choice([4, 5, 6, 7])\n",
    "    )\n",
    "    train_set.append((clean, patch))\n",
    "\n",
    "    cf_pair = CounterFactualSamplePair(\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "    )\n",
    "    cf_pair.detensorize()\n",
    "    with open(os.path.join(train_samples_save_path, f\"{len(train_set):05d}.json\"), \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b32f8eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "\n",
    "train_set = []\n",
    "train_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"train\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "sample_files = [\n",
    "    os.path.join(train_samples_load_path, f)\n",
    "    for f in os.listdir(train_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    "]\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "random.shuffle(sample_files)\n",
    "sample_files = sample_files[:train_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "    train_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1f36fd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean, patch = train_set[5]\n",
    "print(patch.prompt(), \">>\", mt.tokenizer.decode(patch.ans_token_id))\n",
    "print(clean.prompt(), \">>\", mt.tokenizer.decode(clean.ans_token_id))\n",
    "\n",
    "clean.metadata[\"track_type_obj_token_id\"], mt.tokenizer.decode(clean.metadata[\"track_type_obj_token_id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5eb0b8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt._model.zero_grad()\n",
    "free_gpu_cache()\n",
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d636fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import get_optimal_head_mask_optimized, get_optimal_head_mask_prev\n",
    "import numpy as np\n",
    "\n",
    "free_gpu_cache()\n",
    "\n",
    "optimization_interface = {\n",
    "    \"legacy\": get_optimal_head_mask_prev,\n",
    "    \"updated\": get_optimal_head_mask_optimized,\n",
    "}\n",
    "\n",
    "#############################\n",
    "intface = \"legacy\"\n",
    "#############################\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/test_localization_code\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    f\"{TASK_CLS.task_name}\",\n",
    ")\n",
    "\n",
    "optimization_func = optimization_interface[intface]\n",
    "\n",
    "indices_kwargs = {\"query_indices\": [-2, -1]}\n",
    "if intface == \"legacy\":\n",
    "    optimized_path = os.path.join(optimized_path, \"legacy\")\n",
    "    # indices_kwargs[\"query_indices\"] = [-3, -2, -1]\n",
    "    indices_kwargs[\"query_indices\"] = [-1]\n",
    "elif intface == \"updated\":\n",
    "    indices_kwargs[\"add_ques_pos_to_query_indices\"] = True\n",
    "\n",
    "optimal_mask, losses = optimization_func(\n",
    "    mt=mt,\n",
    "    train_set=train_set,\n",
    "    learning_rate=1e-2,\n",
    "    n_epochs=10,\n",
    "    lamb=2e-2,\n",
    "    batch_size=16,\n",
    "    save_step=2,\n",
    "    save_path=optimized_path,\n",
    "    # black_list_heads=optimized_heads\n",
    "    **indices_kwargs\n",
    ")\n",
    "\n",
    "os.makedirs(os.path.dirname(optimized_path), exist_ok=True)\n",
    "\n",
    "np.savez_compressed(\n",
    "    optimized_path,\n",
    "    **dict(\n",
    "        optimal_mask=optimal_mask.to(torch.float32).numpy(),\n",
    "        losses=np.array(losses, dtype=np.float32),\n",
    "    ),\n",
    "    allow_pickle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04fa7b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/test_localization_code\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    f\"{TASK_CLS.task_name}\",\n",
    "    \"legacy\",\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "plt.plot(optimization_results[\"losses\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e217f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[50:, :] = 0.0\n",
    "# optimal_head_mask[75:, :] = 0.0\n",
    "# optimal_head_mask[37:, :] = 0.0\n",
    "\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "\n",
    "optimized_heads = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).to(dtype=torch.int).tolist()\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx) for layer_idx, head_idx in optimized_heads\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "HEADS = optimized_heads\n",
    "\n",
    "(35, 19) in HEADS, (35, 19) in optimized_heads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f414b376",
   "metadata": {},
   "source": [
    "## Validation of the patching effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b268f172",
   "metadata": {},
   "outputs": [],
   "source": [
    "OPTION_STYLE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c310a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair, get_counterfactual_samples_interface\n",
    "from src.functional import free_gpu_cache\n",
    "import random\n",
    "\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation_upd\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "start_number = 1\n",
    "\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    patch, clean = counterfactual_sampler(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        # # n_distractors=N_DISTRACTORS,\n",
    "        patch_n_distractors=random.choice(range(2, 6)),\n",
    "        clean_n_distractors=random.choice(range(2, 6)),\n",
    "        # n_options = random.choice([4, 5, 6, 7])\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "    cf_pair = CounterFactualSamplePair(\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "    )\n",
    "    cf_pair.detensorize()\n",
    "    with open(\n",
    "        os.path.join(validation_samples_save_path, f\"{len(validation_set) + start_number - 1:05d}.json\"),\n",
    "        \"w\",\n",
    "    ) as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f0bbe8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation_upd\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "sample_files = [\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    "]\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "random.shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc99542",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean, patch = validation_set[3]\n",
    "print(patch.prompt(), \">>\", mt.tokenizer.decode(patch.ans_token_id))\n",
    "print(clean.prompt(), \">>\", mt.tokenizer.decode(clean.ans_token_id))\n",
    "clean.metadata[\"track_type_obj_token_id\"], mt.tokenizer.decode(clean.metadata[\"track_type_obj_token_id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4077a595",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Optional\n",
    "from src.utils.typing import TokenizerOutput\n",
    "from src.functional import get_module_nnsight, PatchSpec, patch_with_baukit\n",
    "import baukit\n",
    "import types\n",
    "from src.hooking.llama_attention import LlamaAttentionPatcher\n",
    "from src.attention import visualize_attn_matrix\n",
    "from src.selection.data import SelectionSample, CountingSample\n",
    "from src.selection.data import get_options_for_answer\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def cache_q_projections(\n",
    "    mt: ModelandTokenizer,\n",
    "    input: TokenizerOutput,\n",
    "    heads: list[tuple[int, int]],  # (layer_idx, head_idx)\n",
    "    token_indices: list[list[int]],\n",
    "    return_output: bool = False,\n",
    "    projection_signature: str = \".q_proj\",\n",
    "):\n",
    "    batch_size = input.input_ids.shape[0]\n",
    "    assert len(token_indices) == batch_size, f\"{len(token_indices)=} != {batch_size=}\"\n",
    "    layer_to_head = {}\n",
    "    for layer_idx, head_idx in heads:\n",
    "        if layer_idx not in layer_to_head:\n",
    "            layer_to_head[layer_idx] = []\n",
    "        layer_to_head[layer_idx].append(head_idx)\n",
    "\n",
    "    seq_len = input.input_ids.shape[1]\n",
    "    n_heads = mt.config.num_attention_heads\n",
    "    # head_dim = mt.n_embd // n_heads\n",
    "    head_dim = get_module_nnsight(\n",
    "        mt._model, mt.attn_module_name_format.format(0)\n",
    "    ).head_dim\n",
    "    group_size = n_heads // mt.config.num_key_value_heads\n",
    "    q_module_projections_per_layer = {}\n",
    "    with mt.trace(input) as tracer:  # noqa\n",
    "        for layer_idx, head_indices in layer_to_head.items():\n",
    "            q_proj_name = (\n",
    "                mt.attn_module_name_format.format(layer_idx) + projection_signature\n",
    "            )\n",
    "            q_proj_module = get_module_nnsight(mt, q_proj_name)\n",
    "            q_module_projections_per_layer[q_proj_name] = q_proj_module.output.save()\n",
    "\n",
    "        if return_output:\n",
    "            output = mt.output.save()\n",
    "\n",
    "    q_projections = [{} for _ in range(batch_size)]\n",
    "    for layer_idx, head_indices in layer_to_head.items():\n",
    "        q_proj_name = (\n",
    "            mt.attn_module_name_format.format(layer_idx) + projection_signature\n",
    "        )\n",
    "        # print(q_proj_name)\n",
    "        q_proj_out = (\n",
    "            q_module_projections_per_layer[q_proj_name]\n",
    "            .view(batch_size, seq_len, -1, head_dim)\n",
    "            .transpose(1, 2)\n",
    "        )\n",
    "        if projection_signature in [\".k_proj\", \".v_proj\"] and group_size != 1:\n",
    "            q_proj_out = repeat_kv(q_proj_out, n_rep=group_size)\n",
    "        # print(q_proj_out.shape, q_proj_out.norm())\n",
    "        for prompt_idx in range(batch_size):\n",
    "            for head_idx in head_indices:\n",
    "                for token_idx in token_indices[prompt_idx]:\n",
    "                    q_projections[prompt_idx][(layer_idx, head_idx, token_idx)] = (\n",
    "                        q_proj_out[prompt_idx, head_idx, token_idx]\n",
    "                    )\n",
    "\n",
    "    if return_output:\n",
    "        return q_projections, output\n",
    "    return q_projections\n",
    "\n",
    "\n",
    "def locate_with_delim(prompt, option):\n",
    "    st = prompt.index(option)\n",
    "    return prompt[st : st + len(option) + 1]\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def validate_q_proj_ie_on_sample_pair(\n",
    "    mt: ModelandTokenizer,\n",
    "    clean_sample: SelectionSample | CountingSample,\n",
    "    patch_sample: SelectionSample | CountingSample,\n",
    "    heads: list[tuple[int, int]],\n",
    "    query_indices: dict[int, int] = {-1: -1},  # patch_idx -> clean_idx\n",
    "    verify_head_behavior_on: Optional[int] = None,\n",
    "    ablate_possible_ans_info_from_options: bool = False,\n",
    "    amplification_scale: float = 1.0,\n",
    "    must_track_tokens: list[int] = [],\n",
    "    patch_args: dict[str, Any] = {},\n",
    "):\n",
    "    clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "    patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "    if patch_args.get(\"batch_size\", 1) > 1:\n",
    "        patch_samples = []\n",
    "        task = patch_args[\"task\"]\n",
    "        logger.debug(f\"Sampling {patch_args.get('batch_size', 1)} patch samples...\")\n",
    "        while len(patch_samples) < patch_args.get(\"batch_size\", 1):\n",
    "            obj_idx = len(patch_samples) % len(patch_sample.options)\n",
    "            if patch_args[\"distinct_options\"] is True:\n",
    "                sample = task.get_random_sample(\n",
    "                    mt=mt,\n",
    "                    category=patch_sample.category,\n",
    "                    prompt_template_idx=patch_args[\"prompt_template_idx\"],\n",
    "                    option_style=patch_args[\"option_style\"],\n",
    "                    filter_by_lm_prediction=True,\n",
    "                    # exclude_objs=[clean_sample.obj, patch_sample.obj],\n",
    "                    n_distractors=patch_args[\"n_distractors\"],\n",
    "                    obj_idx=obj_idx,\n",
    "                )\n",
    "            else:\n",
    "                sample = copy.deepcopy(patch_sample)\n",
    "                sample.options[obj_idx], sample.options[sample.obj_idx] = (\n",
    "                    sample.options[sample.obj_idx],\n",
    "                    sample.options[obj_idx],\n",
    "                )\n",
    "                sample.obj_idx = obj_idx\n",
    "                # random.shuffle(sample.options)\n",
    "            patch_samples.append(sample)\n",
    "        patch_tokenized_batch = prepare_input(\n",
    "            prompts=[sample.prompt() for sample in patch_samples], tokenizer=mt\n",
    "        )\n",
    "        logger.debug(f\"{patch_tokenized_batch.input_ids.shape}\")\n",
    "\n",
    "    if verify_head_behavior_on is not None:\n",
    "        logger.info(\"Verifying head behavior...\")\n",
    "\n",
    "        logger.info(f\"Clean Sample >> Ans: {mt.tokenizer.decode(clean_sample.ans_token_id)}\")\n",
    "        clean_attn_pattern = verify_head_patterns(  # noqa\n",
    "            prompt=clean_sample.prompt(),\n",
    "            tokenized_prompt=clean_tokenized,\n",
    "            options=(\n",
    "                [\n",
    "                    locate_with_delim(clean_sample.prompt(), opt)\n",
    "                    for opt in clean_sample.options\n",
    "                ]\n",
    "                if ablate_possible_ans_info_from_options\n",
    "                else clean_sample.options\n",
    "            ),\n",
    "            mt=mt,\n",
    "            heads=heads,\n",
    "            generate_full_answer=False,\n",
    "            query_index=verify_head_behavior_on,\n",
    "            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,\n",
    "        )\n",
    "\n",
    "        logger.info(f\"Patch Sample >> Ans: {mt.tokenizer.decode(patch_sample.ans_token_id)}\")\n",
    "        patch_attn_pattern = verify_head_patterns(  # noqa\n",
    "            prompt=patch_sample.prompt(),\n",
    "            tokenized_prompt=patch_tokenized,\n",
    "            options=(\n",
    "                [\n",
    "                    locate_with_delim(patch_sample.prompt(), opt)\n",
    "                    for opt in patch_sample.options\n",
    "                ]\n",
    "                if ablate_possible_ans_info_from_options\n",
    "                else patch_sample.options\n",
    "            ),\n",
    "            mt=mt,\n",
    "            heads=heads,\n",
    "            generate_full_answer=False,\n",
    "            query_index=verify_head_behavior_on,\n",
    "            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,\n",
    "        )\n",
    "\n",
    "    logger.info(f\"Caching the query states for the {len(heads)} heads\")\n",
    "\n",
    "    cached_q_states, patch_output = cache_q_projections(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized,\n",
    "        heads=heads,\n",
    "        token_indices=[list(query_indices.keys())],\n",
    "        return_output=True,\n",
    "    )\n",
    "    if patch_args.get(\"batch_size\", 1) > 1:\n",
    "        cached_q_states = cache_q_projections(\n",
    "            mt=mt,\n",
    "            input=patch_tokenized_batch,\n",
    "            heads=heads,\n",
    "            token_indices=[list(query_indices.keys())]\n",
    "            * patch_args.get(\"batch_size\", 1),\n",
    "            return_output=False,\n",
    "        )\n",
    "        mean_q_states = {}\n",
    "        for prompt_idx in range(patch_args.get(\"batch_size\", 1)):\n",
    "            for key, value in cached_q_states[prompt_idx].items():\n",
    "                if key not in mean_q_states:\n",
    "                    mean_q_states[key] = []\n",
    "                mean_q_states[key].append(value)\n",
    "        for key, value in mean_q_states.items():\n",
    "            mean_q_states[key] = torch.mean(torch.stack(value), dim=0)\n",
    "        cached_q_states = [mean_q_states]\n",
    "\n",
    "    q_proj_patches = []\n",
    "    for (layer_idx, head_idx, patch_query_idx), q_proj in cached_q_states[0].items():\n",
    "        q_proj_patches.append(\n",
    "            PatchSpec(\n",
    "                location=(\n",
    "                    mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                    head_idx,\n",
    "                    query_indices[patch_query_idx],\n",
    "                ),\n",
    "                patch=q_proj,\n",
    "            )\n",
    "        )\n",
    "\n",
    "    patch_logits = patch_output.logits[:, -1, :].squeeze()\n",
    "    patch_predictions = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=patch_logits,\n",
    "    )\n",
    "    logger.info(f\"patch_prediction={[str(pred) for pred in patch_predictions]}\")\n",
    "\n",
    "    # interested_tokens = [\n",
    "    #     patch_sample.ans_token_id,\n",
    "    #     clean_sample.ans_token_id,\n",
    "    #     clean_sample.metadata[\"track_type_obj_token_id\"],\n",
    "    # ]\n",
    "    interested_tokens = get_options_for_answer(sample=clean_sample)\n",
    "    interested_tokens = [\n",
    "        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "        for opt in interested_tokens\n",
    "    ]\n",
    "    # interested_tokens += [patch_sample.ans_token_id]\n",
    "    # interested_tokens = list(set(interested_tokens))  # remove duplicates #! don't need to, made sure during sampling\n",
    "\n",
    "    logger.info(\"clean run\")\n",
    "    clean_output = patch_with_baukit(\n",
    "        mt=mt,\n",
    "        inputs=clean_tokenized,\n",
    "        patches=[],\n",
    "    )\n",
    "    clean_logits = clean_output.logits[:, -1, :].squeeze()\n",
    "    clean_predictions, clean_track = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=clean_logits,\n",
    "        interested_tokens=interested_tokens + must_track_tokens,\n",
    "    )\n",
    "    logger.info(f\"clean_prediction={[str(pred) for pred in clean_predictions]}\")\n",
    "    logger.info(f\"clean_track={clean_track}\")\n",
    "\n",
    "    logger.info(\"patching the q_proj states\")\n",
    "\n",
    "    if verify_head_behavior_on is not None and amplification_scale == 1.0:\n",
    "        int_attn_pattern = verify_head_patterns(\n",
    "            prompt=clean_sample.prompt(),\n",
    "            tokenized_prompt=clean_tokenized,\n",
    "            options=(\n",
    "                [\n",
    "                    locate_with_delim(clean_sample.prompt(), opt)\n",
    "                    for opt in clean_sample.options\n",
    "                ]\n",
    "                if ablate_possible_ans_info_from_options\n",
    "                else clean_sample.options   \n",
    "            ),\n",
    "            mt=mt,\n",
    "            heads=heads,\n",
    "            query_patches=q_proj_patches,\n",
    "            generate_full_answer=False,\n",
    "            query_index=verify_head_behavior_on,\n",
    "            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,\n",
    "        )\n",
    "        int_logits = int_attn_pattern[\"logits\"].squeeze()\n",
    "\n",
    "    else:\n",
    "        default_attn_implementation = mt.config._attn_implementation\n",
    "        if amplification_scale != 1.0:\n",
    "            mt.reset_forward()\n",
    "            mt.set_attn_implementation(\"sdpa\")\n",
    "\n",
    "            layers_to_heads = {}\n",
    "            for layer_idx, head_idx in heads:\n",
    "                if layer_idx not in layers_to_heads:\n",
    "                    layers_to_heads[layer_idx] = []\n",
    "                layers_to_heads[layer_idx].append(head_idx)\n",
    "\n",
    "            layers_to_q_patches = {}\n",
    "            for (\n",
    "                layer_idx,\n",
    "                head_idx,\n",
    "                patch_query_idx,\n",
    "            ), patch in cached_q_states[0].items():\n",
    "                if layer_idx not in layers_to_q_patches:\n",
    "                    layers_to_q_patches[layer_idx] = []\n",
    "                layers_to_q_patches[layer_idx].append(\n",
    "                    (head_idx, query_indices[patch_query_idx], patch)\n",
    "                )\n",
    "\n",
    "            attention_patterns = {}\n",
    "            head_contributions = {}\n",
    "            for layer_idx, head_indices in layers_to_heads.items():\n",
    "                attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "                attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "\n",
    "                attention_patterns[layer_idx] = {}\n",
    "                head_contributions[layer_idx] = {}\n",
    "\n",
    "                attn_block.forward = types.MethodType(\n",
    "                    LlamaAttentionPatcher(\n",
    "                        block_name=attn_block_name,\n",
    "                        save_attn_for=head_indices,\n",
    "                        store_attn_matrices=attention_patterns[layer_idx],\n",
    "                        store_head_contributions=head_contributions[layer_idx],\n",
    "                        query_patches=layers_to_q_patches[layer_idx],\n",
    "                        amplify_contributions=[\n",
    "                            (head_idx, q_idx, amplification_scale)\n",
    "                            for head_idx in head_indices\n",
    "                            for q_idx in query_indices.values()\n",
    "                        ],\n",
    "                        # value_weighted=True,\n",
    "                    ),\n",
    "                    attn_block,\n",
    "                )\n",
    "            patches = []  # already handled by hooking the default forward pass\n",
    "\n",
    "        else:\n",
    "            patches = q_proj_patches\n",
    "\n",
    "        if ablate_possible_ans_info_from_options:\n",
    "            patches.extend(\n",
    "                get_patches_to_verify_independent_enrichment(\n",
    "                    prompt=clean_sample.prompt(),\n",
    "                    options=clean_sample.options,\n",
    "                    pivot=clean_sample.subj,\n",
    "                    mt=mt,\n",
    "                    tokenized_prompt=clean_tokenized,\n",
    "                )\n",
    "            )\n",
    "\n",
    "        int_out = patch_with_baukit(\n",
    "            mt=mt,\n",
    "            inputs=clean_tokenized,\n",
    "            patches=patches,\n",
    "        )\n",
    "        int_logits = int_out.logits[:, -1, :].squeeze()\n",
    "\n",
    "        if amplification_scale != 1.0:\n",
    "            mt.reset_forward()\n",
    "            mt.set_attn_implementation(default_attn_implementation)\n",
    "\n",
    "            if verify_head_behavior_on is not None:\n",
    "                attn_matrix = []\n",
    "                for layer_idx in attention_patterns:\n",
    "                    for head_idx in attention_patterns[layer_idx]:\n",
    "                        attn_matrix.append(\n",
    "                            attention_patterns[layer_idx][head_idx].cpu()\n",
    "                        )\n",
    "\n",
    "                attn_matrix = torch.stack(attn_matrix).squeeze()\n",
    "                if attn_matrix.dim() == 3:\n",
    "                    attn_matrix = attn_matrix.mean(dim=0)\n",
    "\n",
    "                visualize_attn_matrix(\n",
    "                    attn_matrix=attn_matrix,\n",
    "                    tokens=[\n",
    "                        mt.tokenizer.decode(t) for t in clean_tokenized[\"input_ids\"][0]\n",
    "                    ],\n",
    "                )\n",
    "\n",
    "    int_predictions, int_track = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=int_logits,\n",
    "        interested_tokens=interested_tokens + must_track_tokens,\n",
    "    )\n",
    "    logger.info(f\"int_prediction={[str(pred) for pred in int_predictions]}\")\n",
    "    logger.info(f\"int_track={int_track}\")\n",
    "\n",
    "    return {\n",
    "        \"clean_sample\": clean_sample,\n",
    "        \"patch_sample\": patch_sample,\n",
    "        \"clean_predictions\": clean_predictions,\n",
    "        \"patch_predictions\": patch_predictions,\n",
    "        \"int_predictions\": int_predictions,\n",
    "        \"clean_track\": clean_track,\n",
    "        \"int_track\": int_track,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efb29ec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# no_attn = set(optimized_heads) - set(heads_max_ie)\n",
    "# no_attn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d041d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug a single (clean, patch) pair with validate_q_proj_ie_on_sample_pair:\n",
    "# the selected heads' q-projections are cached on the patch prompt and patched\n",
    "# into the clean run; we then compare the rank/logit of the clean answer and of\n",
    "# the tracked target object before vs. after the intervention.\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# clean, patch = copy.deepcopy(validation_set[15])\n",
    "# clean.default_option_style=\"numbered\"\n",
    "# patch.default_option_style=\"numbered\"\n",
    "# clean, patch = train_set[18]\n",
    "\n",
    "# Work on deep copies so per-pair tweaks (e.g. the commented-out\n",
    "# `default_option_style` overrides above) do not mutate the shared samples.\n",
    "clean = copy.deepcopy(clean_sample)\n",
    "patch = copy.deepcopy(patch_sample)\n",
    "\n",
    "# clean, patch = copy.deepcopy(clean_sample), copy.deepcopy(patch_sample)\n",
    "\n",
    "# failed_case = failed_cases[27]\n",
    "# clean = failed_case[\"clean_sample\"]\n",
    "# patch = failed_case[\"patch_sample\"]\n",
    "\n",
    "print(clean.prompt(), \">>\", mt.tokenizer.decode(clean.ans_token_id))\n",
    "print(patch.prompt(), \">>\", mt.tokenizer.decode(patch.ans_token_id))\n",
    "\n",
    "# Put the model back into eager attention with its default forward before running;\n",
    "# the validation helper may swap attention forwards (e.g. when amplification_scale != 1.0).\n",
    "mt.set_attn_implementation(\"eager\")\n",
    "mt.reset_forward()\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # heads: list of (layer_idx, head_idx) pairs; alternative head sets kept below.\n",
    "    # heads=optimized_heads,\n",
    "    # heads=sorted(heads_max_ie),\n",
    "    # heads=sorted(list(no_attn)),\n",
    "    # heads=sorted(qwen_32_heads, key=lambda x: (x[0], x[1])),\n",
    "    # heads=sorted(heads_attn_behavior, key=lambda x: (x[0], x[1])),\n",
    "    heads=[(35, 19)],\n",
    "    # query_indices maps a token index in the patch prompt to the token index in the\n",
    "    # clean prompt whose q-projection is overwritten with the cached patch value.\n",
    "    # query_indices={\n",
    "    #     patch.metadata[\"ques_pos\"]: clean.metadata[\"ques_pos\"],\n",
    "    #     -2: -2,\n",
    "    #     -1: -1,\n",
    "    # },\n",
    "    query_indices={tok_idx: tok_idx for tok_idx in range(-3, 0)},\n",
    "    # Query position passed to verify_head_patterns for the attention-pattern check\n",
    "    # (None skips the check).\n",
    "    # verify_head_behavior_on=clean.metadata[\"ques_pos\"]\n",
    "    verify_head_behavior_on=-1,\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    # amplification_scale=2.0\n",
    "    # patch_args={\n",
    "    #     \"batch_size\": N_DISTRACTORS + 1,\n",
    "    #     \"task\": select_task,\n",
    "    #     \"prompt_template_idx\": prompt_template_idx,\n",
    "    #     \"option_style\": patch.default_option_style,\n",
    "    #     \"distinct_options\": False,\n",
    "    #     \"n_distractors\": N_DISTRACTORS,\n",
    "    # },\n",
    "    generate_full_ans_for_verify=False,\n",
    "    # Token ids whose rank/logit are always reported in clean_track / int_track.\n",
    "    must_track_tokens=[clean.ans_token_id, clean.metadata[\"track_type_obj_token_id\"]],\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track entry is a (rank, prediction) pair; rank 1 appears to be the top\n",
    "# prediction, so a negative rank delta means the token moved up after patching.\n",
    "before_intervention = {\n",
    "    \"clean_rank\": validation_result[\"clean_track\"][clean_obj][0],\n",
    "    \"clean_logit\": validation_result[\"clean_track\"][clean_obj][1].logit,\n",
    "    \"target_rank\": validation_result[\"clean_track\"][target_obj][0],\n",
    "    \"target_logit\": validation_result[\"clean_track\"][target_obj][1].logit,\n",
    "}\n",
    "\n",
    "after_intervention = {\n",
    "    \"clean_rank\": validation_result[\"int_track\"][clean_obj][0],\n",
    "    \"clean_logit\": validation_result[\"int_track\"][clean_obj][1].logit,\n",
    "    \"target_rank\": validation_result[\"int_track\"][target_obj][0],\n",
    "    \"target_logit\": validation_result[\"int_track\"][target_obj][1].logit,\n",
    "}\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = (\n",
    "    after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    ")\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = (\n",
    "    after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    ")\n",
    "target_logit_delta = (\n",
    "    after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    ")\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62de4053",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the target-object token tracked for the current `clean` sample.\n",
    "mt.tokenizer.decode(\n",
    "    clean.metadata[\"track_type_obj_token_id\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80bf8527",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the options, then the fully rendered prompt, of the current clean sample.\n",
    "print(clean_sample.options, clean_sample.prompt(), sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053091bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.selection.functional import cache_q_projections, verify_head_patterns\n",
    "# from typing import Literal\n",
    "# from src.functional import PatchSpec, interpret_logits\n",
    "# from src.hooking.llama_attention import LlamaAttentionPatcher\n",
    "# import baukit\n",
    "# import types\n",
    "\n",
    "\n",
    "# def set_attn_implementation(mt, attn_implementation: Literal[\"sdpa\", \"eager\"]):\n",
    "#     mt.config._attn_implementation = attn_implementation\n",
    "#     for layer_idx in range(mt.config.num_hidden_layers):\n",
    "#         attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "#         attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "#         attn_block.config._attn_implementation = attn_implementation\n",
    "\n",
    "\n",
    "# ##########################################################\n",
    "# query_indices = [-3, -2, -1]\n",
    "# heads = optimized_heads\n",
    "# ##########################################################\n",
    "\n",
    "# mt.reset_forward()\n",
    "# set_attn_implementation(mt, \"eager\")\n",
    "\n",
    "# clean_sample = failed_case[\"clean_sample\"]\n",
    "# patch_sample = failed_case[\"patch_sample\"]\n",
    "\n",
    "# clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "# patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "\n",
    "# verify_head_patterns(\n",
    "#     prompt=patch_sample.prompt(),\n",
    "#     tokenized_prompt=patch_tokenized,\n",
    "#     pivot=patch_sample.subj,\n",
    "#     options=patch_sample.options,\n",
    "#     mt=mt,\n",
    "#     heads=heads,\n",
    "#     query_index=-1,\n",
    "# )\n",
    "\n",
    "# verify_head_patterns(\n",
    "#     prompt=clean_sample.prompt(),\n",
    "#     tokenized_prompt=clean_tokenized,\n",
    "#     pivot=clean_sample.subj,\n",
    "#     options=clean_sample.options,\n",
    "#     mt=mt,\n",
    "#     heads=heads,\n",
    "#     query_index=-1,\n",
    "# )\n",
    "\n",
    "# query_locations = [\n",
    "#     (layer_idx, head_idx, query_idx)\n",
    "#     for layer_idx, head_idx in heads\n",
    "#     for query_idx in query_indices\n",
    "# ]\n",
    "\n",
    "# cached_q_states, patch_output = cache_q_projections(\n",
    "#     mt=mt,\n",
    "#     input=patch_tokenized,\n",
    "#     query_locations=query_locations,\n",
    "#     return_output=True,\n",
    "# )\n",
    "# q_proj_patches = []\n",
    "# for (layer_idx, head_idx, query_idx), q_proj in cached_q_states.items():\n",
    "#     q_proj_patches.append(\n",
    "#         PatchSpec(\n",
    "#             location=(\n",
    "#                 mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "#                 head_idx,\n",
    "#                 query_idx,\n",
    "#             ),\n",
    "#             patch=q_proj,\n",
    "#         )\n",
    "#     )\n",
    "\n",
    "# patch_logits = patch_output.logits[:, -1, :].squeeze()\n",
    "# patch_predictions = interpret_logits(\n",
    "#     tokenizer=mt,\n",
    "#     logits=patch_logits,\n",
    "# )\n",
    "# logger.info(f\"patch_prediction={[str(pred) for pred in patch_predictions]}\")\n",
    "\n",
    "\n",
    "# mt.reset_forward()\n",
    "# set_attn_implementation(mt, \"sdpa\")\n",
    "\n",
    "# layers_to_heads = {}\n",
    "# for layer_idx, head_idx in heads:\n",
    "#     if layer_idx not in layers_to_heads:\n",
    "#         layers_to_heads[layer_idx] = []\n",
    "#     layers_to_heads[layer_idx].append(head_idx)\n",
    "\n",
    "# layers_to_q_patches = {}\n",
    "# for (layer_idx, head_idx, query_idx), patch in cached_q_states.items():\n",
    "#     if layer_idx not in layers_to_q_patches:\n",
    "#         layers_to_q_patches[layer_idx] = []\n",
    "#     layers_to_q_patches[layer_idx].append((head_idx, query_idx, patch))\n",
    "\n",
    "# attention_patterns = {}\n",
    "# head_contributions = {}\n",
    "# for layer_idx, head_indices in layers_to_heads.items():\n",
    "#     attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "#     attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "\n",
    "#     attention_patterns[layer_idx] = {}\n",
    "#     head_contributions[layer_idx] = {}\n",
    "\n",
    "#     attn_block.forward = types.MethodType(\n",
    "#         LlamaAttentionPatcher(\n",
    "#             block_name=attn_block_name,\n",
    "#             save_attn_for=head_indices,\n",
    "#             store_attn_matrices=attention_patterns[layer_idx],\n",
    "#             store_head_contributions=head_contributions[layer_idx],\n",
    "#             query_patches=layers_to_q_patches[layer_idx],\n",
    "#             amplify_contributions=[\n",
    "#                 (head_idx, q_idx, 2.0)\n",
    "#                 for head_idx in head_indices\n",
    "#                 for q_idx in query_indices\n",
    "#             ],\n",
    "#             # value_weighted=True,\n",
    "#         ),\n",
    "#         attn_block,\n",
    "#     )\n",
    "\n",
    "\n",
    "# output = mt._model(**clean_tokenized)\n",
    "# int_logits = output.logits[:, -1, :].squeeze()\n",
    "# int_pred = interpret_logits(\n",
    "#     tokenizer=mt,\n",
    "#     logits=int_logits,\n",
    "# )\n",
    "\n",
    "# logger.info(f\"int_prediction={[str(pred) for pred in int_pred]}\")\n",
    "\n",
    "# mt.reset_forward()\n",
    "# set_attn_implementation(mt, \"eager\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c8e8280",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.selection.functional import visualize_attn_matrix\n",
    "\n",
    "# attn_matrix = []\n",
    "# for layer_idx in attention_patterns:\n",
    "#     for head_idx in attention_patterns[layer_idx]:\n",
    "#         attn_matrix.append(attention_patterns[layer_idx][head_idx].cpu())\n",
    "# attn_matrix = torch.stack(attn_matrix).squeeze().mean(dim=0)\n",
    "\n",
    "# visualize_attn_matrix(\n",
    "#     attn_matrix = attn_matrix,\n",
    "#     tokens = [mt.tokenizer.decode(t) for t in clean_tokenized[\"input_ids\"][0]]\n",
    "# )\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c65ba2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Select heads from a precomputed per-head ranking, restricted to a band of\n",
    "# middle layers. Rows of attention_pattern.json are unpacked as\n",
    "# [head_idx, layer_idx, score]; the aie_per_head.json variant was previously\n",
    "# unpacked as [layer_idx, head_idx, score] - swap the loop targets if switching.\n",
    "SCORES_FILE = \"attention_pattern.json\"  # alt: \"aie_per_head.json\"\n",
    "MIN_LAYER_EXCLUSIVE = 25\n",
    "MAX_LAYER_EXCLUSIVE = 55\n",
    "N_TOP_HEADS = 80\n",
    "\n",
    "raw_dir = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"raw\")\n",
    "with open(os.path.join(raw_dir, SCORES_FILE), \"r\", encoding=\"utf-8\") as f:\n",
    "    scores_per_head = json.load(f)\n",
    "\n",
    "# NOTE(review): the first N_TOP_HEADS qualifying rows are kept, which assumes the\n",
    "# file is already sorted best-first - confirm against the script that wrote it.\n",
    "heads_max_ie = []\n",
    "for head_idx, layer_idx, score in scores_per_head:\n",
    "    if MIN_LAYER_EXCLUSIVE < layer_idx < MAX_LAYER_EXCLUSIVE:\n",
    "        heads_max_ie.append((layer_idx, head_idx))\n",
    "    if len(heads_max_ie) >= N_TOP_HEADS:\n",
    "        break\n",
    "\n",
    "print(heads_max_ie[:5])  # first entries in file order, before sorting\n",
    "heads_max_ie = sorted(heads_max_ie, key=lambda x: (x[0], x[1]))\n",
    "len(heads_max_ie)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2017d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of heads in `optimized_heads` (set outside this section of the notebook).\n",
    "len(optimized_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65931db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "# Run the q-projection patching validation on every (clean, patch) pair.\n",
    "# The loop variables are deliberately NOT named clean_sample / patch_sample so\n",
    "# this cell does not overwrite the globals read by the single-pair cells.\n",
    "validation_results = []\n",
    "for val_clean, val_patch in tqdm(validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=val_clean,\n",
    "        patch_sample=val_patch,\n",
    "        # heads=optimized_heads,\n",
    "        # heads=sorted(heads_attn_behavior),\n",
    "        heads=sorted(heads_max_ie),\n",
    "        # heads=backup_heads,\n",
    "        # heads=optimized_heads + backup_heads,\n",
    "        # heads=overlapping_heads,\n",
    "        query_indices={-2: -2, -1: -1},\n",
    "        add_ques_pos_to_query_indices=True,\n",
    "        # query_indices={tok_idx: tok_idx for tok_idx in range(-10, 0)},\n",
    "        verify_head_behavior_on=None,\n",
    "        # amplification_scale=1.5,\n",
    "        # With batch_size > 1 the helper caches q-states over a batch of patch\n",
    "        # prompts and patches in their mean.\n",
    "        patch_args={\n",
    "            \"batch_size\": len(val_patch.options),\n",
    "            \"distinct_options\": False,\n",
    "        },\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "024eed06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every validation pair, record the rank and logit of the clean answer and of\n",
    "# the tracked target object in the clean run (before) and the patched run (after).\n",
    "# Track entries are (rank, prediction) pairs.\n",
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    # Read the sample from the result rather than rebinding the global\n",
    "    # `clean_sample` used by the single-pair cells; the patch sample is not needed.\n",
    "    result_clean = intervention_result[\"clean_sample\"]\n",
    "\n",
    "    clean_obj = result_clean.ans_token_id\n",
    "    target_obj = result_clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    before_intervention.append({\n",
    "        \"clean_rank\": intervention_result[\"clean_track\"][clean_obj][0],\n",
    "        \"clean_logit\": intervention_result[\"clean_track\"][clean_obj][1].logit,\n",
    "        \"target_rank\": intervention_result[\"clean_track\"][target_obj][0],\n",
    "        \"target_logit\": intervention_result[\"clean_track\"][target_obj][1].logit,\n",
    "    })\n",
    "\n",
    "    after_intervention.append({\n",
    "        \"clean_rank\": intervention_result[\"int_track\"][clean_obj][0],\n",
    "        \"clean_logit\": intervention_result[\"int_track\"][clean_obj][1].logit,\n",
    "        \"target_rank\": intervention_result[\"int_track\"][target_obj][0],\n",
    "        \"target_logit\": intervention_result[\"int_track\"][target_obj][1].logit,\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573b7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-pair rank change (after - before) for the clean answer and the target object.\n",
    "clean_rank_delta = np.array(\n",
    "    [a[\"clean_rank\"] - b[\"clean_rank\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "target_rank_delta = np.array(\n",
    "    [a[\"target_rank\"] - b[\"target_rank\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "\n",
    "# Absolute ranks after the intervention.\n",
    "clean_rank_after_intervention = np.array([a[\"clean_rank\"] for a in after_intervention])\n",
    "target_rank_after_intervention = np.array([a[\"target_rank\"] for a in after_intervention])\n",
    "\n",
    "# Report mean ± (population) std for each statistic.\n",
    "for stat_name, stat_values in (\n",
    "    (\"clean_rank_delta\", clean_rank_delta),\n",
    "    (\"target_rank_delta\", target_rank_delta),\n",
    "    (\"clean_rank_after_intervention\", clean_rank_after_intervention),\n",
    "    (\"target_rank_after_intervention\", target_rank_after_intervention),\n",
    "):\n",
    "    print(f\"{stat_name}: {stat_values.mean():.4f} ± {stat_values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10522073",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pair logit change (after - before) for the clean answer and the target\n",
    "# object, plus the absolute logits after the intervention.\n",
    "clean_logit_delta = np.array(\n",
    "    [a[\"clean_logit\"] - b[\"clean_logit\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [a[\"target_logit\"] - b[\"target_logit\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "clean_logit_after_intervention = np.array([a[\"clean_logit\"] for a in after_intervention])\n",
    "target_logit_after_intervention = np.array([a[\"target_logit\"] for a in after_intervention])\n",
    "\n",
    "# Report mean ± (population) std for each statistic.\n",
    "for stat_name, stat_values in (\n",
    "    (\"clean_logit_delta\", clean_logit_delta),\n",
    "    (\"target_logit_delta\", target_logit_delta),\n",
    "    (\"clean_logit_after_intervention\", clean_logit_after_intervention),\n",
    "    (\"target_logit_after_intervention\", target_logit_after_intervention),\n",
    "):\n",
    "    print(f\"{stat_name}: {stat_values.mean():.4f} ± {stat_values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d484f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of pairs whose target object is ranked first after the intervention.\n",
    "top_1 = len([after for after in after_intervention if after[\"target_rank\"] == 1])\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fbdd059",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count pairs where, after patching, the best-ranked tracked token is the tracked\n",
    "# target object; every other pair is collected in failed_cases for inspection.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    # Local names avoid rebinding the global clean_sample / patch_sample.\n",
    "    result_clean = intervention_result[\"clean_sample\"]\n",
    "    result_int_track = intervention_result[\"int_track\"]\n",
    "    # Track values are (rank, prediction); rank 1 is the top prediction. Pick the\n",
    "    # best-ranked entry explicitly instead of relying on dict insertion order.\n",
    "    _, top_tracked_pred = min(\n",
    "        result_int_track.values(), key=lambda rank_pred: rank_pred[0]\n",
    "    )\n",
    "    if top_tracked_pred.token_id == result_clean.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": result_clean,\n",
    "                \"patch_sample\": intervention_result[\"patch_sample\"],\n",
    "                \"int_track\": result_int_track,\n",
    "                \"clean_track\": intervention_result[\"clean_track\"],\n",
    "            }\n",
    "        )\n",
    "\n",
    "n_results = len(validation_results)\n",
    "# Guard against an empty validation run instead of raising ZeroDivisionError.\n",
    "top_1_accuracy = counter_patch_type_top_option / n_results if n_results else float(\"nan\")\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{n_results})\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a66e502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first 20 failed cases: prompt with expected answer, the tracked\n",
    "# target token, and the tracked predictions in the clean vs. intervened run.\n",
    "# Local names avoid rebinding the globals clean_sample / clean_track / int_track.\n",
    "for case in failed_cases[:20]:\n",
    "    case_clean = case[\"clean_sample\"]\n",
    "\n",
    "    print(\"Clean Sample:\")\n",
    "    print(case_clean.prompt(), \">>\", mt.tokenizer.decode(case_clean.ans_token_id))\n",
    "\n",
    "    print(\"-\" * 100)\n",
    "    print(\n",
    "        \"Track: \",\n",
    "        f\"\\\"{mt.tokenizer.decode(case_clean.metadata['track_type_obj_token_id'])}\\\"\",\n",
    "    )\n",
    "    print(\n",
    "        \"Clean:\",\n",
    "        f\"(Token: {mt.tokenizer.decode(case_clean.ans_token_id)})\",\n",
    "    )\n",
    "    print(\"-\" * 100)\n",
    "\n",
    "    # Track dicts map token_id -> (rank, prediction); only predictions are shown.\n",
    "    clean_preds = [pred for _rank, pred in case[\"clean_track\"].values()]\n",
    "    print(f\"Clean Track: {json.dumps([str(pred) for pred in clean_preds], indent=4)}\")\n",
    "\n",
    "    int_preds = [pred for _rank, pred in case[\"int_track\"].values()]\n",
    "    print(\n",
    "        f\"Intervened Track: {json.dumps([str(pred) for pred in int_preds], indent=4)}\"\n",
    "    )\n",
    "    print(\"=\" * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea0c607",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual patching top-1 accuracy on 512 validation pairs for the two head\n",
    "# sets compared here, in the format printed by the counterfactual-accuracy cell.\n",
    "# Values are percentages so they match the 0-100 y-axis and the \"%\" labels\n",
    "# (the previous `0.6680 (342/512)` tried to call a float and raised TypeError).\n",
    "scores = {\n",
    "    \"FilterScore\": 342 / 512 * 100,  # 0.6680\n",
    "    \"IE\": 301 / 512 * 100,  # 0.5879\n",
    "}\n",
    "\n",
    "plt.figure(figsize=(4, 4))\n",
    "bars = plt.bar(list(scores.keys()), list(scores.values()))\n",
    "plt.ylim(0, 100)\n",
    "plt.ylabel(\"Counterfactual patching accuracy (%)\")\n",
    "for bar in bars:\n",
    "    height = bar.get_height()\n",
    "    plt.text(\n",
    "        bar.get_x() + bar.get_width() / 2,\n",
    "        height + 1,\n",
    "        f\"{height:.2f}%\",\n",
    "        ha=\"center\",\n",
    "        va=\"bottom\",\n",
    "    )\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7a2e01",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e8af723e",
   "metadata": {},
   "source": [
    "## Heads found with different tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab92210f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the filter heads selected by head-mask optimization for two tasks.\n",
    "# NOTE(review): the [layer, head] axis order of \"optimal_mask\" is inferred from the\n",
    "# (layer_idx, head_idx) unpacking below -- confirm against the optimizer output.\n",
    "from matplotlib.patches import Patch\n",
    "\n",
    "LAYER_CUTOFF = 50  # heads in layers >= LAYER_CUTOFF are excluded from the plot and stats\n",
    "\n",
    "task_names = [\n",
    "    \"select_one\",\n",
    "    \"select_order\"\n",
    "]\n",
    "heads = {task_name: [] for task_name in task_names}\n",
    "# Colormap used to draw each task's heads.\n",
    "colors = {\n",
    "    \"select_one\": \"Blues\",\n",
    "    \"select_order\": \"Reds\",\n",
    "}\n",
    "# Legend swatch per task; must agree with the colormap in `colors`.\n",
    "legend_colors = {\n",
    "    \"select_one\": \"blue\",\n",
    "    \"select_order\": \"red\",\n",
    "}\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "# Per-task binarized masks weighted by (task index + 1); with two tasks, 3 marks a head\n",
    "# selected by both. Kept for inspection only.\n",
    "combined_mask = np.zeros((mt.config.num_attention_heads, mt.n_layer))\n",
    "legend_elements = []\n",
    "\n",
    "for i, task_name in enumerate(task_names):\n",
    "    print(\"Processing task:\", task_name)\n",
    "\n",
    "    optimized_path = os.path.join(\n",
    "        env_utils.DEFAULT_RESULTS_DIR,\n",
    "        \"selection/optimized_heads\",\n",
    "        mt.name.split(\"/\")[-1],\n",
    "        f\"{task_name}.npz\",\n",
    "    )\n",
    "    optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "    optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(\n",
    "        torch.float32\n",
    "    )\n",
    "    print(f\"Optimal head mask shape for {task_name}: {optimal_head_mask.shape}\")\n",
    "\n",
    "    # A head counts as selected when its mask value exceeds 0.5.\n",
    "    task_heads = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "    heads[task_name] = [\n",
    "        (layer_idx, head_idx)\n",
    "        for layer_idx, head_idx in task_heads\n",
    "        if layer_idx < LAYER_CUTOFF\n",
    "    ]\n",
    "\n",
    "    # Binarize and zero out layers past the cutoff, mirroring the filter above.\n",
    "    optimal_head_mask = optimal_head_mask.round()\n",
    "    optimal_head_mask[LAYER_CUTOFF:] = 0\n",
    "\n",
    "    # Transpose so heads run along y and layers along x; mask the zeros so they render\n",
    "    # transparent and the other task's heads show through.\n",
    "    mask_array = optimal_head_mask.T.numpy()\n",
    "    masked_data = np.ma.masked_where(mask_array == 0, mask_array)\n",
    "\n",
    "    task_alpha = 0.8 if i == 0 else 0.5  # different alphas so overlaps stay visible\n",
    "    ax.imshow(\n",
    "        masked_data,\n",
    "        cmap=colors[task_name],\n",
    "        aspect=\"auto\",\n",
    "        vmin=0,\n",
    "        vmax=1.5,  # keeps value 1 away from the darkest end of the colormap\n",
    "        alpha=task_alpha,\n",
    "        interpolation='nearest'\n",
    "    )\n",
    "    # Legend entry built from the same color/alpha used to draw this task.\n",
    "    legend_elements.append(\n",
    "        Patch(\n",
    "            facecolor=legend_colors[task_name],\n",
    "            alpha=task_alpha,\n",
    "            label=task_name.replace(\"_\", \" \").title(),\n",
    "        )\n",
    "    )\n",
    "\n",
    "    combined_mask += mask_array * (i + 1)\n",
    "\n",
    "ax.set_xlabel(\"Layer\")\n",
    "ax.set_ylabel(\"Head\")\n",
    "ax.set_title(\"Filter Heads Comparison: Select One (Blue) vs Select Order (Red)\")\n",
    "\n",
    "ax.set_xticks(np.arange(0, LAYER_CUTOFF, 5))\n",
    "ax.set_yticks(np.arange(0, optimal_head_mask.shape[1], 2))\n",
    "ax.grid(True, alpha=0.3, linestyle='--')\n",
    "\n",
    "# Where both tasks selected a head, the blue and red layers blend toward purple.\n",
    "legend_elements.append(Patch(facecolor='purple', alpha=1, label='Overlap'))\n",
    "ax.legend(handles=legend_elements, loc='upper right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Overlap statistics; set lookup keeps this linear in the number of heads.\n",
    "second_task_heads = set(heads[task_names[1]])\n",
    "overlapping_heads = [head for head in heads[task_names[0]] if head in second_task_heads]\n",
    "\n",
    "for task_name in task_names:\n",
    "    print(f\"Total heads for {task_name}: {len(heads[task_name])}\")\n",
    "print(f\"Overlapping heads: {len(overlapping_heads)}\")\n",
    "if overlapping_heads:\n",
    "    print(f\"Overlapping positions: {overlapping_heads}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a028a6d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heads selected by both tasks: intersection of their (layer, head) pairs.\n",
    "overlapping_heads = set(heads[task_names[0]]).intersection(heads[task_names[1]])\n",
    "print(f\"Intersection Heads: {len(overlapping_heads)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b79f336",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
