{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "# Report library versions and the visible CUDA devices.\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "if torch.cuda.is_available():\n",
    "    logger.info(\n",
    "        f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    "    )\n",
    "else:\n",
    "    # get_device_name() initializes CUDA and raises when no device is usable,\n",
    "    # so only availability and device count are reported in that case.\n",
    "    logger.warning(f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}\")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Model identifier handed to ModelandTokenizer(model_key=...) when the model is loaded.\n",
    "# Exactly one `model_key = ...` line is active; the commented lines are alternatives.\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# Optional explicit device map; get_device_map is only referenced by this disabled call.\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "# Build the model/tokenizer wrapper for `model_key` in bfloat16 with device_map=\"auto\".\n",
    "# NOTE(review): attn_implementation=\"eager\" is presumably required so attention\n",
    "# weights can be read out by the later cells - confirm.\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# file_path = os.path.join(\n",
    "#     env_utils.DEFAULT_DATA_DIR,\n",
    "#     \"selection\",\n",
    "#     # \"profession.json\"\n",
    "#     # \"nationality.json\"\n",
    "#     \"objects.json\",\n",
    "# )\n",
    "\n",
    "# with open(file_path, \"r\") as f:\n",
    "#     temp = json.load(f)\n",
    "\n",
    "# for cat in temp[\"categories\"]:\n",
    "#     temp[\"categories\"][cat] = [obj.capitalize() for obj in temp[\"categories\"][cat]]\n",
    "\n",
    "# with open(file_path, \"w\") as f:\n",
    "#     json.dump(temp, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# TASK_CLS = SelectOrderTask\n",
    "# prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Task definition file under <data dir>/selection.\n",
    "# Other files tried here: \"profession.json\", \"nationality.json\".\n",
    "task_file = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "select_task = TASK_CLS.load(path=task_file)\n",
    "\n",
    "# NOTE(review): presumably drops entries that are not a single token when\n",
    "# prefixed with a space - confirm against filter_single_token.\n",
    "select_task.filter_single_token(tokenizer=mt.tokenizer, prefix=\" \")\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Request one sample for the \"fruit\" category with the object placed at index 2,\n",
    "# then show the sample and the prompt it renders.\n",
    "sample = select_task.get_random_sample(\n",
    "    mt = mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    obj_idx=2,\n",
    "    # category=\"actor\",\n",
    "    # category=\"Brazil\"\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "print(sample)\n",
    "print(sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66399b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import verify_correct_option\n",
    "# sample.prompt_template = select_prof.prompt_templates[3]\n",
    "# Show the quoted prompt next to the expected object, then check the model on it.\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)\n",
    "\n",
    "# The call is the cell's last expression, so its return value is displayed.\n",
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "# Greedy generation (do_sample=False) of up to 20 new tokens for the sample prompt;\n",
    "# only the first returned generation is kept.\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e8a8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display mt.n_layer and the attention-head count from the model config.\n",
    "mt.n_layer, mt.config.num_attention_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-picked (layer_idx, head_idx) pairs of attention heads to inspect.\n",
    "# NOTE(review): a later cell rebinds HEADS to `optimized_heads`, so this list is\n",
    "# only in effect until that cell runs.\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "# HEADS = [(35, 19)]\n",
    "\n",
    "\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe381b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_heads\",\n",
    "#     mt.name.split(\"/\")[-1],\n",
    "#     f\"{select_task.task_name}.npz\"\n",
    "# )\n",
    "\n",
    "# Saved head-mask optimization results for the selected model.\n",
    "# Alternative path components tried here: f\"{select_task.task_name}\", \"legacy\".\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    \"select_one\",\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "# execute arbitrary code; only load result files produced by trusted runs.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "\n",
    "# Plot the stored loss values.\n",
    "loss_fig, loss_ax = plt.subplots()\n",
    "loss_ax.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f87810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the optimized mask as float32 and zero every row from index 52 onward.\n",
    "# NOTE(review): the first axis is unpacked as layer_idx below, so this discards\n",
    "# heads in layers >= 52; the cutoff is model-specific - confirm before switching models.\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[52:, :] = 0.0\n",
    "\n",
    "# Heat map of the transposed mask, values clamped to the [0, 1] colour range.\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "\n",
    "# Every mask entry above 0.5 becomes a (layer_idx, head_idx) tuple.\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(\n",
    "        optimal_head_mask > 0.5, as_tuple=False\n",
    "    ).tolist()\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "# HEADS and optimized_heads refer to the same list object from here on.\n",
    "HEADS = optimized_heads\n",
    "\n",
    "(35, 19) in HEADS, (35, 19) in optimized_heads\n",
    "# [(29, 3) in HEADS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    ")\n",
    "\n",
    "# Inspect the attention of HEADS on the sample prompt. The prompt is rendered with\n",
    "# the notebook-wide OPTION_STYLE (instead of a hard-coded style) so this cell stays\n",
    "# consistent with how `sample` was created when the setting is changed.\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=sample.prompt(option_style=OPTION_STYLE),\n",
    "    options=sample.options,\n",
    "    pivot=sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425f6285",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_within_task\n",
    "\n",
    "# Build a (patch, clean) sample pair from the same task: the patch sample asks for\n",
    "# a \"fruit\", the clean sample for a \"vehicle\". The distractor count comes from\n",
    "# N_DISTRACTORS (previously a hard-coded 5) so it follows the task configuration.\n",
    "patch_sample, clean_sample = get_counterfactual_samples_within_task(\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=OPTION_STYLE,\n",
    "    distinct_options=True,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    ")\n",
    "\n",
    "# patch_sample.default_option_style = \"single_line\"\n",
    "# clean_sample.default_option_style = \"numbered\"\n",
    "\n",
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510772fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "from src.selection.utils import get_first_token_id\n",
    "from src.selection.data import MCQify_sample\n",
    "\n",
    "mt.reset_forward()\n",
    "mt.set_attn_implementation(\"eager\")\n",
    "\n",
    "# Hand-written overrides of the counterfactual pair: fixed option lists, answer\n",
    "# bookkeeping for the clean sample, and a shared prompt template.\n",
    "# patch_sample.options[patch_sample.obj_idx] = \"Screw\"\n",
    "# patch_sample.options[patch_sample.obj_idx] = patch_sample.obj\n",
    "patch_sample.options = [\"Cherry\", \"Knife\", \"Pen\", \"Ambulance\"]\n",
    "# patch_sample.options = [\"Ferry\", \"Knife\", \"Pen\", \"Ambulance\"]\n",
    "# patch_sample.options = [\"#\"]\n",
    "\n",
    "clean_sample.options = [\"Binder\", \"Peach\", \"Watch\", \"Scooter\", \"Phone\"]\n",
    "# \"Peach\" sits at index 1 of the clean options and is the object tracked for the\n",
    "# patch category; \"Scooter\" at index 3 is the clean answer.\n",
    "clean_sample.metadata = {\n",
    "    \"track_category\": \"fruit\",\n",
    "    \"track_type_obj\": \"Peach\",\n",
    "    \"track_type_obj_idx\": 1,\n",
    "    \"track_type_obj_token_id\": get_first_token_id(tokenizer=mt.tokenizer, name=\"Peach\", prefix=\" \")\n",
    "}\n",
    "clean_sample.obj_idx = 3\n",
    "clean_sample.obj = \"Scooter\"\n",
    "clean_sample.ans_token_id = get_first_token_id(tokenizer=mt.tokenizer, name=\"Scooter\", prefix=\" \")\n",
    "\n",
    "clean_sample.prompt_template = \"<_options_>\\nFind the <_category_> in the list.\\nAnswer:\"\n",
    "patch_sample.prompt_template = \"<_options_>\\nFind the <_category_> in the list.\\nAnswer:\"\n",
    "\n",
    "# Only the clean sample is passed through MCQify_sample.\n",
    "clean_sample = MCQify_sample(sample=clean_sample, tokenizer=mt.tokenizer)\n",
    "# clean_sample.default_option_style = \"bulleted\"\n",
    "\n",
    "# NOTE(review): this loop rebinds the notebook-level name `sample` (created by the\n",
    "# random-sample cell) and leaves it pointing at clean_sample; re-running earlier\n",
    "# cells afterwards will use the wrong sample. `attn_pattern` keeps only the result\n",
    "# of the last iteration.\n",
    "for sample in [patch_sample, clean_sample]:\n",
    "    print(sample.prompt(), \">>\", sample.obj)\n",
    "    attn_pattern = verify_head_patterns(\n",
    "        prompt=sample.prompt(),\n",
    "        options=sample.options,\n",
    "        pivot=sample.subj,\n",
    "        mt=mt,\n",
    "        # heads=HEADS,\n",
    "        heads=[(35, 19)],\n",
    "        # generate_full_answer=True,\n",
    "        query_index=-1\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b79f336",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import cache_q_projections\n",
    "from src.functional import PatchSpec\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "mt.reset_forward()\n",
    "mt.set_attn_implementation(\"eager\")\n",
    "#####################################################\n",
    "# Maps a token position in the patch prompt to the position it is written to in\n",
    "# the clean prompt.\n",
    "query_indices = {-3: -3, -2: -2, -1: -1}\n",
    "#####################################################\n",
    "\n",
    "clean_tokenized = prepare_input(\n",
    "    prompts=clean_sample.prompt(), \n",
    "    tokenizer=mt.tokenizer,\n",
    "    return_offsets_mapping=True\n",
    ")\n",
    "clean_offsets = clean_tokenized.pop(\"offset_mapping\")[0]\n",
    "\n",
    "patch_tokenized = prepare_input(\n",
    "    prompts=patch_sample.prompt(), \n",
    "    tokenizer=mt.tokenizer,\n",
    "    return_offsets_mapping=True\n",
    ")\n",
    "patch_offsets = patch_tokenized.pop(\"offset_mapping\")[0]\n",
    "\n",
    "# Cache q_proj outputs on the patch prompt. The cached positions are taken from\n",
    "# query_indices so the lookup in the loop below cannot hit a missing key when the\n",
    "# mapping is edited (same pattern as the k_proj caching cell).\n",
    "cached_q_states, patch_output = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=patch_tokenized,\n",
    "    # heads=[(35, 19)],\n",
    "    heads = optimized_heads,\n",
    "    token_indices=[list(query_indices.keys())],\n",
    "    return_output=True,\n",
    ")\n",
    "\n",
    "# One PatchSpec per cached (layer, head, position), retargeted through query_indices.\n",
    "q_proj_patches = []\n",
    "for (layer_idx, head_idx, patch_query_idx), q_proj in cached_q_states[0].items():\n",
    "    q_proj_patches.append(\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                head_idx,\n",
    "                query_indices[patch_query_idx],\n",
    "            ),\n",
    "            patch=q_proj,\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Attention of head (35, 19) on the clean prompt with the query patches applied.\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=clean_sample.prompt(),\n",
    "    options=clean_sample.options,\n",
    "    pivot=clean_sample.subj,\n",
    "    mt=mt,\n",
    "    # heads=HEADS,\n",
    "    # heads=optimized_heads,\n",
    "    heads=[(35, 19)],\n",
    "    query_index=-1,\n",
    "    query_patches=q_proj_patches\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48ba7263",
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean_sample.obj, clean_sample.metadata['track_type_obj']\n",
    "from src.tokens import find_token_range\n",
    "import random\n",
    "\n",
    "\n",
    "def _last_token_idx_in_clean_prompt(substring):\n",
    "    \"\"\"Return the end of `substring`'s token range in the clean prompt, minus one.\"\"\"\n",
    "    token_range = find_token_range(\n",
    "        string=clean_sample.prompt(),\n",
    "        substring=substring,\n",
    "        offset_mapping=clean_offsets,\n",
    "    )\n",
    "    return token_range[1] - 1\n",
    "\n",
    "\n",
    "# random_target = random.choice(\n",
    "#     list(\n",
    "#         set(clean_sample.options)\n",
    "#         - {clean_sample.obj, clean_sample.metadata[\"track_type_obj\"]}\n",
    "#     )\n",
    "# )\n",
    "random_target = \"Phone\"\n",
    "logger.debug(f\"Random target: {random_target}\")\n",
    "\n",
    "# Position of the tracked patch-category object (alternative: clean_sample.obj).\n",
    "patch_type_obj_token_idx = _last_token_idx_in_clean_prompt(\n",
    "    clean_sample.metadata[\"track_type_obj\"]\n",
    ")\n",
    "logger.debug(\n",
    "    f'{patch_type_obj_token_idx=} | \"{mt.tokenizer.decode(clean_tokenized.input_ids[0][patch_type_obj_token_idx])}\"'\n",
    ")\n",
    "\n",
    "# Position of the option the key intervention should redirect attention to.\n",
    "target_obj_token_idx = _last_token_idx_in_clean_prompt(random_target)\n",
    "logger.debug(\n",
    "    f'{target_obj_token_idx=} | \"{mt.tokenizer.decode(clean_tokenized.input_ids[0][target_obj_token_idx])}\"'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a602a4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import nnsight\n",
    "# with mt.trace(\"hello\") as tracer:\n",
    "#     hs = mt.model.layers[-1].output[0]\n",
    "#     tracer.log(hs[0,0,0])\n",
    "#     out =  mt.lm_head.output.save()\n",
    "\n",
    "# print(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46d37dec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of heads selected from the optimized mask.\n",
    "len(optimized_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33110969",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import cache_q_projections\n",
    "\n",
    "# Swap map between the tracked object's position and the target option's position,\n",
    "# for the token itself (offset 0) and the token right after it (offset 1).\n",
    "key_indices = {}\n",
    "for offset in (0, 1):\n",
    "    key_indices[patch_type_obj_token_idx + offset] = target_obj_token_idx + offset\n",
    "    key_indices[target_obj_token_idx + offset] = patch_type_obj_token_idx + offset\n",
    "\n",
    "# (layer, head, position) triples for every head in HEADS\n",
    "# (alternative head list: [(35, 19)]).\n",
    "key_locations = [\n",
    "    (layer_idx, head_idx, patch_key_idx)\n",
    "    for layer_idx, head_idx in HEADS\n",
    "    for patch_key_idx in key_indices\n",
    "]\n",
    "\n",
    "# Cache k_proj outputs at the swap positions; the input must always be clean_tokenized.\n",
    "cached_k_states = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=clean_tokenized,\n",
    "    token_indices=[list(key_indices.keys())],\n",
    "    heads=optimized_heads,\n",
    "    return_output=False,\n",
    "    projection_signature=\".k_proj\",\n",
    ")[0]\n",
    "\n",
    "# Each cached key state is written to the position it is swapped with.\n",
    "k_proj_patches = []\n",
    "for (layer_idx, head_idx, patch_key_idx), k_proj in cached_k_states.items():\n",
    "    k_proj_location = (\n",
    "        mt.attn_module_name_format.format(layer_idx) + \".k_proj\",\n",
    "        head_idx,\n",
    "        key_indices[patch_key_idx],\n",
    "    )\n",
    "    k_proj_patches.append(PatchSpec(location=k_proj_location, patch=k_proj))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb18c405",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from src.functional import get_module_nnsight, repeat_kv\n",
    "from src.attention import visualize_attn_matrix\n",
    "from src.hooking.llama_attention import apply_rotary_pos_emb\n",
    "\n",
    "# Recompute one head's attention by hand from the saved q_proj / k_proj outputs of\n",
    "# the clean prompt, after swapping two key vectors.\n",
    "# NOTE(review): apply_rotary_pos_emb is imported but never applied below, so these\n",
    "# scores are computed without rotary position embeddings - confirm this is intended.\n",
    "layer_idx, head_idx = (35, 19)\n",
    "with mt.trace(clean_tokenized) as tracer:\n",
    "    q_proj_name = mt.attn_module_name_format.format(layer_idx) + \".q_proj\"\n",
    "    q_proj_module = get_module_nnsight(mt, q_proj_name)\n",
    "    clean_q_proj = q_proj_module.output.save()\n",
    "    # tracer.log(clean_q_proj.shape)\n",
    "\n",
    "    k_proj_name = mt.attn_module_name_format.format(layer_idx) + \".k_proj\"\n",
    "    k_proj_module = get_module_nnsight(mt, k_proj_name)\n",
    "    clean_k_proj = k_proj_module.output.save()\n",
    "    # tracer.log(clean_k_proj.shape)\n",
    "\n",
    "query = clean_q_proj\n",
    "key = clean_k_proj\n",
    "print(f\"{query.shape=}, {key.shape=}\")\n",
    "\n",
    "# Split the projections into heads: (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim).\n",
    "batch_size = clean_tokenized.input_ids.shape[0]\n",
    "seq_len = clean_tokenized.input_ids.shape[1]\n",
    "n_heads = mt.config.num_attention_heads\n",
    "n_kv_heads = mt.config.num_key_value_heads\n",
    "head_dim = mt.n_embd // n_heads\n",
    "key = key.view(batch_size, seq_len, -1, head_dim).transpose(1, 2)\n",
    "query = query.view(batch_size, seq_len, -1, head_dim).transpose(1, 2)\n",
    "\n",
    "# Repeat the key/value heads so the key tensor has one entry per query head.\n",
    "key = repeat_kv(key, n_rep=n_heads // n_kv_heads)\n",
    "print(f\"{query.shape=}, {key.shape=}\")\n",
    "\n",
    "\n",
    "############## intervention ##############\n",
    "source_token_id = patch_type_obj_token_idx\n",
    "key_source = key[:, head_idx, source_token_id].clone()\n",
    "\n",
    "target_token_id = target_obj_token_idx\n",
    "key_target = key[:, head_idx, target_token_id].clone()\n",
    "\n",
    "# Swap the two key vectors using the cached k_proj states.\n",
    "# NOTE(review): cached_k_states was built for `optimized_heads`; this lookup raises\n",
    "# KeyError if (layer_idx, head_idx) is not among them.\n",
    "# key[:, head_idx, source_token_id, :] = key_target\n",
    "# key[:, head_idx, target_token_id, :] = key_source\n",
    "key[:, head_idx, source_token_id, :] = cached_k_states[(layer_idx, head_idx, target_token_id)]\n",
    "key[:, head_idx, target_token_id, :] = cached_k_states[(layer_idx, head_idx, source_token_id)]\n",
    "############## intervention ##############\n",
    "\n",
    "\n",
    "# Scaled dot-product attention with a causal mask. attn_bias is created directly in\n",
    "# query.dtype, so the former stand-alone `attn_bias.to(query.dtype)` statement (whose\n",
    "# result was discarded) has been dropped.\n",
    "scale_factor = 1 / math.sqrt(query.size(-1))\n",
    "L, S = query.size(-2), key.size(-2)\n",
    "attn_bias = torch.zeros(L, S, dtype=query.dtype)\n",
    "temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)\n",
    "attn_bias.masked_fill_(temp_mask.logical_not(), float(\"-inf\"))\n",
    "\n",
    "attn_weight = query @ key.transpose(-2, -1) * scale_factor\n",
    "attn_weight += attn_bias.to(attn_weight.dtype).to(attn_weight.device)\n",
    "attn_weight = torch.softmax(attn_weight, dim=-1)\n",
    "\n",
    "head_matrix = attn_weight[0, head_idx].squeeze()\n",
    "\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=head_matrix,\n",
    "    tokens=[mt.tokenizer.decode(t) for t in clean_tokenized.input_ids[0]]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f4fe3e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.allclose(\n",
    "#     cached_k_states[(layer_idx, head_idx, source_token_id)],\n",
    "#     key_source,\n",
    "#     atol=1e-3\n",
    "# )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77fd0ae5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "# Validate the q_proj intervention on the (clean, patch) pair for the optimized heads.\n",
    "# query_indices maps each of the last ten token positions (-10 .. -1) onto itself.\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean_sample,\n",
    "    patch_sample=patch_sample,\n",
    "    heads=optimized_heads,\n",
    "    query_indices={tok_idx: tok_idx for tok_idx in range(-10, 0)},\n",
    "    verify_head_behavior_on=-1,\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    # amplification_scale=2.0\n",
    "    # patch_args={\n",
    "    #     \"batch_size\": N_DISTRACTORS + 1,\n",
    "    #     \"task\": select_task,\n",
    "    #     \"prompt_template_idx\": prompt_template_idx,\n",
    "    #     \"option_style\": patch.default_option_style,\n",
    "    #     \"distinct_options\": False,\n",
    "    #     \"n_distractors\": N_DISTRACTORS,\n",
    "    # },\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef4b3aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.typing import TokenizerOutput\n",
    "from src.functional import interpret_logits, patch_with_baukit\n",
    "from src.attention import visualize_attn_matrix, visualize_average_attn_matrix\n",
    "import baukit\n",
    "from src.hooking.llama_attention import LlamaAttentionPatcher\n",
    "import types\n",
    "\n",
    "\n",
    "def verify_head_patterns_custom_attention(\n",
    "    prompt: str | TokenizerOutput,\n",
    "    options: list[str],\n",
    "    pivot: str,\n",
    "    mt: ModelandTokenizer,\n",
    "    heads: list[tuple[int, int]],\n",
    "    tokenized_prompt: TokenizerOutput | None = None,\n",
    "    visualize_individual_heads: bool = False,\n",
    "    value_weighted: bool = False,\n",
    "    ablate_possible_ans_info_from_options: bool = False,\n",
    "    bare_prompt_template=\" The fact that {}\",\n",
    "    query_index: int = -1,\n",
    "    query_patches: list[PatchSpec] | None = None,\n",
    "    key_patches: list[PatchSpec] | None = None,\n",
    "    start_from: int = 1,\n",
    "):\n",
    "    \"\"\"Run the model with patched attention blocks and visualize selected heads.\n",
    "\n",
    "    The forward method of every attention block listed in `heads` is replaced by a\n",
    "    LlamaAttentionPatcher that records the attention matrices of the requested heads\n",
    "    and applies the query/key patches targeting that layer. The model is then run\n",
    "    once on the prompt and the original forward methods and attention\n",
    "    implementation are restored, even if the run fails.\n",
    "\n",
    "    Args:\n",
    "        prompt: Prompt text; tokenized here when `tokenized_prompt` is None.\n",
    "        options: Options forwarded to the ablation-patch builder.\n",
    "        pivot: Pivot forwarded to the ablation-patch builder.\n",
    "        mt: Model/tokenizer wrapper.\n",
    "        heads: (layer_idx, head_idx) pairs whose attention is recorded and plotted.\n",
    "        tokenized_prompt: Pre-tokenized prompt, if already available.\n",
    "        visualize_individual_heads: Also plot every head on its own.\n",
    "        value_weighted: Forwarded to LlamaAttentionPatcher.\n",
    "        ablate_possible_ans_info_from_options: Build and apply the patches returned\n",
    "            by get_patches_to_verify_independent_enrichment for this prompt.\n",
    "        bare_prompt_template: Template forwarded to the ablation-patch builder.\n",
    "        query_index: Query position passed to the visualizer.\n",
    "        query_patches: PatchSpecs with location (module_name, head_idx, token_idx)\n",
    "            applied to queries.\n",
    "        key_patches: Same as `query_patches`, applied to keys.\n",
    "        start_from: Forwarded to the visualizer.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"predictions\", \"logits\" (last position) and \"attn_matrices\"\n",
    "        ({layer_idx: {head_idx: matrix}}).\n",
    "    \"\"\"\n",
    "    # Fresh lists per call instead of shared mutable defaults.\n",
    "    heads = list(heads) if heads is not None else []\n",
    "    query_patches = list(query_patches) if query_patches is not None else []\n",
    "    key_patches = list(key_patches) if key_patches is not None else []\n",
    "\n",
    "    def _group_patches_by_layer(patch_specs):\n",
    "        \"\"\"Group patches as {layer_idx: [(head_idx, token_idx, patch), ...]}.\"\"\"\n",
    "        grouped = {}\n",
    "        for spec in patch_specs:\n",
    "            module_name, head_idx, token_idx = spec.location\n",
    "            # The layer index is read from the third dot-separated field of the module name.\n",
    "            layer_idx = int(module_name.split(\".\")[2])\n",
    "            grouped.setdefault(layer_idx, []).append((head_idx, token_idx, spec.patch))\n",
    "        return grouped\n",
    "\n",
    "    if tokenized_prompt is None:\n",
    "        # NOTE(review): other cells pass mt.tokenizer here; confirm prepare_input\n",
    "        # also accepts the ModelandTokenizer wrapper.\n",
    "        tokenized_prompt = prepare_input(\n",
    "            tokenizer=mt,\n",
    "            prompts=prompt,\n",
    "            return_offsets_mapping=True,\n",
    "        )\n",
    "\n",
    "    # Ablation patches are built once, from this call's own arguments and before any\n",
    "    # attention block is hooked. (Previously the patches that were actually applied\n",
    "    # were built from the notebook globals clean_sample / clean_tokenized.)\n",
    "    ablation_patches = (\n",
    "        get_patches_to_verify_independent_enrichment(\n",
    "            prompt=prompt,\n",
    "            options=options,\n",
    "            pivot=pivot,\n",
    "            mt=mt,\n",
    "            tokenized_prompt=tokenized_prompt,\n",
    "            bare_prompt_template=bare_prompt_template,\n",
    "        )\n",
    "        if ablate_possible_ans_info_from_options\n",
    "        else []\n",
    "    )\n",
    "    print(len(ablation_patches), \"patches to ablate possible answer information from options\")\n",
    "\n",
    "    layers_to_heads = {}\n",
    "    for layer_idx, head_idx in heads:\n",
    "        layers_to_heads.setdefault(layer_idx, []).append(head_idx)\n",
    "    layers_to_q_patches = _group_patches_by_layer(query_patches)\n",
    "    layers_to_k_patches = _group_patches_by_layer(key_patches)\n",
    "\n",
    "    # Only layers listed in `heads` are hooked, so patches on other layers have no effect.\n",
    "    unhooked_layers = (set(layers_to_q_patches) | set(layers_to_k_patches)) - set(layers_to_heads)\n",
    "    if unhooked_layers:\n",
    "        logger.warning(\n",
    "            f\"Query/key patches for layers {sorted(unhooked_layers)} are ignored: \"\n",
    "            \"only layers listed in `heads` are hooked\"\n",
    "        )\n",
    "\n",
    "    ################## inference with intervention #######################\n",
    "    ret_dict = {}\n",
    "    attention_patterns = {}\n",
    "    default_attn_implementation = mt.config._attn_implementation\n",
    "    mt.reset_forward()\n",
    "    mt.set_attn_implementation(\"sdpa\")\n",
    "    try:\n",
    "        for layer_idx, head_indices in layers_to_heads.items():\n",
    "            attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "            attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "            attention_patterns[layer_idx] = {}\n",
    "\n",
    "            # Query/key patches are applied inside the replaced forward method.\n",
    "            attn_block.forward = types.MethodType(\n",
    "                LlamaAttentionPatcher(\n",
    "                    block_name=attn_block_name,\n",
    "                    save_attn_for=head_indices,\n",
    "                    store_attn_matrices=attention_patterns[layer_idx],\n",
    "                    query_patches=layers_to_q_patches.get(layer_idx, []),\n",
    "                    key_patches=layers_to_k_patches.get(layer_idx, []),\n",
    "                    value_weighted=value_weighted,\n",
    "                ),\n",
    "                attn_block,\n",
    "            )\n",
    "\n",
    "        output = patch_with_baukit(\n",
    "            mt=mt,\n",
    "            inputs=tokenized_prompt,\n",
    "            patches=list(ablation_patches),\n",
    "        )\n",
    "        logits = output.logits[:, -1, :].squeeze()\n",
    "    finally:\n",
    "        # Always undo the forward replacement and restore the attention implementation.\n",
    "        mt.reset_forward()\n",
    "        mt.set_attn_implementation(default_attn_implementation)\n",
    "    ################## inference with intervention #######################\n",
    "\n",
    "    predictions = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=logits,\n",
    "    )\n",
    "    logger.debug(f\"Predictions: {[str(p) for p in predictions]}\")\n",
    "    ret_dict[\"predictions\"] = predictions\n",
    "    ret_dict[\"logits\"] = logits\n",
    "    ret_dict[\"attn_matrices\"] = attention_patterns\n",
    "\n",
    "    if heads:\n",
    "        tokens = [mt.tokenizer.decode(t) for t in tokenized_prompt.input_ids[0]]\n",
    "        combined = []\n",
    "        for layer_idx, head_idx in heads:\n",
    "            head_matrix = torch.Tensor(\n",
    "                attention_patterns[layer_idx][head_idx].cpu()\n",
    "            )\n",
    "            combined.append(head_matrix)\n",
    "            if visualize_individual_heads:\n",
    "                logger.info(f\"Layer: {layer_idx}, Head: {head_idx}\")\n",
    "                visualize_attn_matrix(\n",
    "                    attn_matrix=head_matrix,\n",
    "                    tokens=tokens,\n",
    "                    q_index=query_index,\n",
    "                    start_from=start_from,\n",
    "                )\n",
    "\n",
    "        logger.info(\"Combined attention matrix for all heads\")\n",
    "        combined_matrix = torch.stack(combined).squeeze()\n",
    "        # With several heads the stack stays 3-D; average over the head axis.\n",
    "        if combined_matrix.dim() == 3:\n",
    "            combined_matrix = combined_matrix.mean(dim=0)\n",
    "        visualize_attn_matrix(\n",
    "            attn_matrix=combined_matrix,\n",
    "            tokens=tokens,\n",
    "            q_index=query_index,\n",
    "            start_from=start_from,\n",
    "        )\n",
    "    return ret_dict\n",
    "\n",
    "# Baseline: head (35, 19) on the clean prompt with no query or key patches.\n",
    "clean_attn_info = verify_head_patterns_custom_attention(\n",
    "    prompt=clean_sample.prompt(),\n",
    "    options=clean_sample.options,\n",
    "    pivot=clean_sample.subj,\n",
    "    mt=mt,\n",
    "    # heads=optimized_heads,\n",
    "    heads=[(35, 19)],\n",
    "    query_index=-1,\n",
    ")\n",
    "\n",
    "# Same run with the query patches only.\n",
    "q_attn_info = verify_head_patterns_custom_attention(\n",
    "    prompt=clean_sample.prompt(),\n",
    "    options=clean_sample.options,\n",
    "    pivot=clean_sample.subj,\n",
    "    mt=mt,\n",
    "    # heads=optimized_heads,\n",
    "    heads=[(35, 19)],\n",
    "    query_index=-1,\n",
    "    query_patches=q_proj_patches,\n",
    ")\n",
    "\n",
    "# Same run with both the query patches and the key patches.\n",
    "qk_attn_info = verify_head_patterns_custom_attention(\n",
    "    prompt=clean_sample.prompt(),\n",
    "    options=clean_sample.options,\n",
    "    pivot=clean_sample.subj,\n",
    "    mt=mt,\n",
    "    # heads=optimized_heads,\n",
    "    heads=[(35, 19)],\n",
    "    query_index=-1,\n",
    "    query_patches=q_proj_patches,\n",
    "    key_patches=k_proj_patches,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee0e2c6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token ids whose rank / logit are compared between the unpatched run and the\n",
    "# query + key patched run.\n",
    "clean_tok = get_first_token_id(clean_sample.obj, mt.tokenizer, prefix=\" \")\n",
    "q_target = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "kq_target = get_first_token_id(random_target, mt.tokenizer, prefix=\" \")\n",
    "\n",
    "interested_tokens = dict(clean_obj=clean_tok, q_target=q_target, kq_target=kq_target)\n",
    "\n",
    "clean_pred, clean_track = interpret_logits(\n",
    "    logits=clean_attn_info[\"logits\"],\n",
    "    tokenizer=mt.tokenizer,\n",
    "    interested_tokens=[*interested_tokens.values()],\n",
    ")\n",
    "logger.info(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "logger.info(f\"{clean_track=}\")\n",
    "\n",
    "int_pred, int_track = interpret_logits(\n",
    "    logits=qk_attn_info[\"logits\"],\n",
    "    tokenizer=mt.tokenizer,\n",
    "    interested_tokens=[*interested_tokens.values()],\n",
    ")\n",
    "logger.info(f\"int_pred={[str(pred) for pred in int_pred]}\")\n",
    "logger.info(f\"{int_track=}\")\n",
    "\n",
    "print(\"\\n\")\n",
    "\n",
    "# Per tracked token: rank and logit on the clean run vs. the patched run.\n",
    "for token_type, token_id in interested_tokens.items():\n",
    "    logger.info(f\"{token_type}={token_id}, [\\\"{mt.tokenizer.decode(token_id)}\\\"]\")\n",
    "    on_clean, on_int = clean_track[token_id], int_track[token_id]\n",
    "    logger.info(f\"Rank: {on_clean[0]} -> {on_int[0]} | Δ={on_int[0] - on_clean[0]}\")\n",
    "    logger.info(f\"Logit: {on_clean[1].logit:.4f} -> {on_int[1].logit:.4f} | Δ={on_int[1].logit - on_clean[1].logit:.4f}\")\n",
    "    print(\"=\" * 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "215157ba",
   "metadata": {},
   "source": [
    "## Scale up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef7f6b93",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair, get_counterfactual_samples_interface\n",
    "from src.functional import free_gpu_cache\n",
    "import random\n",
    "\n",
    "select_task.filter_single_token(mt.tokenizer, prefix=\" \")\n",
    "\n",
    "# Output directory:\n",
    "# <results>/selection/samples/validation_single_token_options/<model>/<task>/objects\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation_single_token_options\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set, validation_limit, start_number = [], 768, 1\n",
    "\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "\n",
    "# Each iteration draws one counterfactual pair, keeps it in memory as (clean, patch)\n",
    "# and writes it to a zero-padded JSON file numbered from `start_number`.\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    patch, clean = counterfactual_sampler(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        # # n_distractors=N_DISTRACTORS,\n",
    "        patch_n_distractors=random.choice(range(2, 6)),\n",
    "        clean_n_distractors=random.choice(range(2, 6)),\n",
    "        # n_options = random.choice([4, 5, 6, 7])\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "    cf_pair = CounterFactualSamplePair(patch_sample=patch, clean_sample=clean)\n",
    "    cf_pair.detensorize()\n",
    "    sample_json_path = os.path.join(\n",
    "        validation_samples_save_path,\n",
    "        f\"{len(validation_set) + start_number - 1:05d}.json\",\n",
    "    )\n",
    "    with open(sample_json_path, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dce8993",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "# Load up to `validation_limit` previously saved counterfactual pairs as (clean, patch) tuples.\n",
    "validation_set = []\n",
    "validation_limit = 768\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation_single_token_options\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# os.listdir returns names in arbitrary order, so sort before shuffling: with a seeded\n",
    "# `random`, the same subset is then selected on every run and on every filesystem.\n",
    "sample_files = sorted(\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "if len(sample_files) < validation_limit:\n",
    "    # Later cells index into validation_set, so make a short set visible early.\n",
    "    logger.warning(\n",
    "        f\"Only {len(sample_files)} sample files available; requested {validation_limit}\"\n",
    "    )\n",
    "\n",
    "random.shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6e3f3ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# clean, patch = copy.deepcopy(clean_sample), copy.deepcopy(patch_sample)\n",
    "clean, patch = copy.deepcopy(validation_set[22])\n",
    "\n",
    "print(clean.prompt(), \">>\", clean.obj)\n",
    "print(patch.prompt(), \">>\", patch.obj)\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # heads=HEADS,\n",
    "    heads=optimized_heads,\n",
    "    # heads=[(35, 19)],\n",
    "    query_indices={-3: -3, -2: -2, -1: -1},\n",
    "    # query_indices = {-idx: -idx for idx in range(1, 6)},\n",
    "    verify_head_behavior_on=-1,\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    # amplification_scale=1.1\n",
    ")\n",
    "\n",
    "clean_obj, target_obj = clean.ans_token_id, clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Same four measurements, read first from the clean track and then from the\n",
    "# intervention track.\n",
    "before_intervention, after_intervention = (\n",
    "    {\n",
    "        \"clean_rank\": validation_result[track_key][clean_obj][0],\n",
    "        \"clean_logit\": validation_result[track_key][clean_obj][1].logit,\n",
    "        \"target_rank\": validation_result[track_key][target_obj][0],\n",
    "        \"target_logit\": validation_result[track_key][target_obj][1].logit,\n",
    "    }\n",
    "    for track_key in (\"clean_track\", \"int_track\")\n",
    ")\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dec6547",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run the q_proj validation on every (clean, patch) pair of the validation set and\n",
    "# keep each per-sample result dict; the next cell aggregates them.\n",
    "# NOTE: the loop variables are notebook globals, so `clean_sample` / `patch_sample`\n",
    "# hold the last pair once the loop finishes.\n",
    "query_validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    # clean_sample = copy.deepcopy(clean_sample)\n",
    "    # clean_sample.options = patch_sample.options\n",
    "\n",
    "    # patch_sample = copy.deepcopy(patch_sample)\n",
    "    # patch_sample.options = [\"#\"]\n",
    "    # patch_sample.prompt_template = \"Which among these objects mentioned above is a <_category_>?\\nAnswer:\"\n",
    "\n",
    "\n",
    "    # no information from the patch\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=optimized_heads,\n",
    "        query_indices={-3: -3, -2: -2, -1: -1},\n",
    "        # None: no head-behavior verification is requested for the bulk run\n",
    "        verify_head_behavior_on=None,\n",
    "        # amplification_scale=1.5\n",
    "    )\n",
    "    query_validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6387a0b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "before_intervention, after_intervention = [], []\n",
    "\n",
    "# Per sample: rank / logit of the clean answer token and of the tracked target token,\n",
    "# read from the clean track (before) and from the intervention track (after).\n",
    "for intervention_result in query_validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    before_intervention.append(\n",
    "        dict(\n",
    "            clean_rank=intervention_result[\"clean_track\"][clean_obj][0],\n",
    "            clean_logit=intervention_result[\"clean_track\"][clean_obj][1].logit,\n",
    "            target_rank=intervention_result[\"clean_track\"][target_obj][0],\n",
    "            target_logit=intervention_result[\"clean_track\"][target_obj][1].logit,\n",
    "        )\n",
    "    )\n",
    "    after_intervention.append(\n",
    "        dict(\n",
    "            clean_rank=intervention_result[\"int_track\"][clean_obj][0],\n",
    "            clean_logit=intervention_result[\"int_track\"][clean_obj][1].logit,\n",
    "            target_rank=intervention_result[\"int_track\"][target_obj][0],\n",
    "            target_logit=intervention_result[\"int_track\"][target_obj][1].logit,\n",
    "        )\n",
    "    )\n",
    "\n",
    "# ---- rank statistics ----\n",
    "clean_rank_delta = np.array(\n",
    "    [\n",
    "        after[\"clean_rank\"] - before[\"clean_rank\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_rank_delta = np.array(\n",
    "    [\n",
    "        after[\"target_rank\"] - before[\"target_rank\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_rank_delta: {clean_rank_delta.mean():.4f} ± {clean_rank_delta.std():.4f}\")\n",
    "print(f\"target_rank_delta: {target_rank_delta.mean():.4f} ± {target_rank_delta.std():.4f}\")\n",
    "\n",
    "clean_rank_after_intervention = np.array(\n",
    "    [after[\"clean_rank\"] for after in after_intervention]\n",
    ")\n",
    "print(f\"clean_rank_after_intervention: {clean_rank_after_intervention.mean():.4f} ± {clean_rank_after_intervention.std():.4f}\")\n",
    "\n",
    "target_rank_after_intervention = np.array(\n",
    "    [after[\"target_rank\"] for after in after_intervention]\n",
    ")\n",
    "print(f\"target_rank_after_intervention: {target_rank_after_intervention.mean():.4f} ± {target_rank_after_intervention.std():.4f}\")\n",
    "\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# ---- logit statistics ----\n",
    "clean_logit_delta = np.array(\n",
    "    [\n",
    "        after[\"clean_logit\"] - before[\"clean_logit\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [\n",
    "        after[\"target_logit\"] - before[\"target_logit\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_logit_delta: {clean_logit_delta.mean():.4f} ± {clean_logit_delta.std():.4f}\")\n",
    "print(f\"target_logit_delta: {target_logit_delta.mean():.4f} ± {target_logit_delta.std():.4f}\")\n",
    "\n",
    "clean_logit_after_intervention = np.array(\n",
    "    [after[\"clean_logit\"] for after in after_intervention]\n",
    ")\n",
    "print(f\"clean_logit_after_intervention: {clean_logit_after_intervention.mean():.4f} ± {clean_logit_after_intervention.std():.4f}\")\n",
    "\n",
    "target_logit_after_intervention = np.array(\n",
    "    [after[\"target_logit\"] for after in after_intervention]\n",
    ")\n",
    "print(f\"target_logit_after_intervention: {target_logit_after_intervention.mean():.4f} ± {target_logit_after_intervention.std():.4f}\")\n",
    "\n",
    "print(\"=\" * 80)\n",
    "\n",
    "# ---- split samples by whether the first tracked entry after the intervention is the target ----\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "successful_cases = []\n",
    "for intervention_result in query_validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    case_record = {\n",
    "        \"clean_sample\": clean_sample,\n",
    "        \"patch_sample\": patch_sample,\n",
    "        \"int_track\": int_track,\n",
    "        \"clean_track\": clean_track,\n",
    "    }\n",
    "    top_entry = int_track[list(int_track.keys())[0]]\n",
    "    if top_entry[1].token_id != clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        failed_cases.append(case_record)\n",
    "        continue\n",
    "    counter_patch_type_top_option += 1\n",
    "    successful_cases.append(case_record)\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(query_validation_results)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy (w/o avg trick): {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(query_validation_results)})\"\n",
    ")\n",
    "print(f\"Failed: {len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49218177",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Any\n",
    "from src.selection.data import SelectionSample\n",
    "\n",
    "\n",
    "def patched_run_with_custom_attention(\n",
    "    mt: ModelandTokenizer,\n",
    "    input: TokenizerOutput,\n",
    "    patches: list[PatchSpec] = [],\n",
    "    query_patches: list[PatchSpec] = [],\n",
    "    key_patches: list[PatchSpec] = [],\n",
    "    return_attention_patterns: bool = False,\n",
    "    value_weighted: bool = False,\n",
    "):\n",
    "    \"\"\"Run ``mt`` on ``input`` with per-head query / key projection patches applied.\n",
    "\n",
    "    Every attention block named by a query or key patch gets its ``forward``\n",
    "    replaced by a ``LlamaAttentionPatcher`` and the attention implementation is\n",
    "    switched to ``\"sdpa\"`` for the duration of the call. ``mt.reset_forward()``\n",
    "    is called and the previous attention implementation is restored before this\n",
    "    function exits, including when the forward pass raises.\n",
    "\n",
    "    Args:\n",
    "        mt: Model / tokenizer wrapper whose attention blocks are patched.\n",
    "        input: Tokenized inputs, forwarded to ``patch_with_baukit``.\n",
    "        patches: Patches forwarded unchanged to ``patch_with_baukit``.\n",
    "        query_patches: Patches whose ``location`` is\n",
    "            ``(module_name, head_idx, token_idx)``; the layer index is parsed\n",
    "            from the third dot-separated component of ``module_name``.\n",
    "        key_patches: Same layout as ``query_patches``, applied to keys.\n",
    "        return_attention_patterns: When true, also return the saved attention\n",
    "            matrices.\n",
    "        value_weighted: Forwarded to ``LlamaAttentionPatcher``.\n",
    "\n",
    "    Returns:\n",
    "        The output of ``patch_with_baukit``; or ``(output, attention_patterns)``\n",
    "        when ``return_attention_patterns`` is true, where ``attention_patterns``\n",
    "        maps each patched layer index to the dict handed to the patcher as\n",
    "        ``store_attn_matrices``.\n",
    "    \"\"\"\n",
    "\n",
    "    def _group_by_layer(patch_specs):\n",
    "        \"\"\"Map layer index -> [(head_idx, token_idx, patch), ...] in input order.\"\"\"\n",
    "        grouped = {}\n",
    "        for spec in patch_specs:\n",
    "            module_name, head_idx, token_idx = spec.location\n",
    "            layer_idx = int(module_name.split(\".\")[2])\n",
    "            grouped.setdefault(layer_idx, []).append((head_idx, token_idx, spec.patch))\n",
    "        return grouped\n",
    "\n",
    "    default_attn_implementation = mt.config._attn_implementation\n",
    "    mt.reset_forward()\n",
    "    mt.set_attn_implementation(\"sdpa\")\n",
    "    try:\n",
    "        layers_to_q_patches = _group_by_layer(query_patches)\n",
    "        layers_to_k_patches = _group_by_layer(key_patches)\n",
    "\n",
    "        # Every head touched by a query or a key patch gets its attention matrix saved.\n",
    "        layers_to_heads = {}\n",
    "        for grouped in (layers_to_q_patches, layers_to_k_patches):\n",
    "            for layer_idx, layer_patches in grouped.items():\n",
    "                layers_to_heads.setdefault(layer_idx, set()).update(\n",
    "                    head_idx for head_idx, _, _ in layer_patches\n",
    "                )\n",
    "\n",
    "        attention_patterns = {}\n",
    "        for layer_idx, head_set in layers_to_heads.items():\n",
    "            attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "            attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "\n",
    "            attention_patterns[layer_idx] = {}\n",
    "            attn_block.forward = types.MethodType(\n",
    "                LlamaAttentionPatcher(\n",
    "                    block_name=attn_block_name,\n",
    "                    # sorted: deterministic head order instead of set iteration order\n",
    "                    save_attn_for=sorted(head_set),\n",
    "                    store_attn_matrices=attention_patterns[layer_idx],\n",
    "                    query_patches=layers_to_q_patches.get(layer_idx, []),\n",
    "                    key_patches=layers_to_k_patches.get(layer_idx, []),\n",
    "                    value_weighted=value_weighted,\n",
    "                ),\n",
    "                attn_block,\n",
    "            )\n",
    "\n",
    "        output = patch_with_baukit(\n",
    "            mt=mt,\n",
    "            inputs=input,\n",
    "            patches=patches,\n",
    "        )\n",
    "    finally:\n",
    "        # Always undo the replaced forwards and the attention-implementation switch,\n",
    "        # even if the run fails, so later cells do not inherit a patched model.\n",
    "        mt.reset_forward()\n",
    "        mt.set_attn_implementation(default_attn_implementation)\n",
    "\n",
    "    if return_attention_patterns:\n",
    "        return output, attention_patterns\n",
    "    return output\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def validate_k_proj_ie_on_sample_pair(\n",
    "    mt: ModelandTokenizer,\n",
    "    clean_sample: SelectionSample,\n",
    "    patch_sample: SelectionSample,\n",
    "    heads: list[tuple[int, int]],\n",
    "    query_indices: dict[int, int] = {-1: -1},\n",
    "    verify_head_behavior_on: Optional[int] = None,\n",
    "    ablate_possible_ans_info_from_options: bool = False,\n",
    "    must_track_tokens: list[int] = [],\n",
    "    patch_args: dict[str, Any] = {},\n",
    "):\n",
    "    \"\"\"Patch queries from ``patch_sample`` and swap keys between two options of ``clean_sample``.\n",
    "\n",
    "    The clean prompt is run unpatched (while its ``k_proj`` states are cached),\n",
    "    optionally with ``q_proj`` patches only (when ``verify_head_behavior_on`` is\n",
    "    set), and finally with the ``q_proj`` patches plus ``k_proj`` patches that\n",
    "    exchange the cached key states of two option positions: the option named by\n",
    "    ``clean_sample.metadata[\"track_type_obj\"]`` and a randomly chosen other\n",
    "    option (never ``clean_sample.obj``). The position right after each of the\n",
    "    two options is exchanged as well.\n",
    "\n",
    "    Args:\n",
    "        mt: Model / tokenizer wrapper.\n",
    "        clean_sample: Sample whose prompt receives the patches.\n",
    "        patch_sample: Sample whose ``q_proj`` states are cached and patched in.\n",
    "        heads: ``(layer_idx, head_idx)`` pairs to cache and patch.\n",
    "        query_indices: Maps a token index in the patch prompt to the token index\n",
    "            in the clean prompt that receives its ``q_proj`` state.\n",
    "        verify_head_behavior_on: Query index used for the attention\n",
    "            visualizations; ``None`` skips the visualizations and the\n",
    "            query-only run.\n",
    "        ablate_possible_ans_info_from_options: Forwarded to ``verify_head_patterns``.\n",
    "        must_track_tokens: Extra token ids tracked in addition to the first\n",
    "            token of every clean option.\n",
    "        patch_args: Currently unused.\n",
    "\n",
    "    Returns:\n",
    "        dict with ``clean_sample``, ``patch_sample``, ``track_tokens``,\n",
    "        ``clean_predictions``, ``clean_track``, ``kq_predictions`` and\n",
    "        ``kq_track``; plus ``q_predictions``, ``q_track`` and ``combined_kq``\n",
    "        when ``verify_head_behavior_on`` is not ``None``.\n",
    "    \"\"\"\n",
    "\n",
    "    def _show_combined_attention(attn_patterns):\n",
    "        \"\"\"Stack the saved matrices of ``heads``, average over heads if the result is 3-D, and plot.\"\"\"\n",
    "        combined = torch.stack(\n",
    "            [\n",
    "                torch.Tensor(attn_patterns[layer_idx][head_idx].cpu())\n",
    "                for layer_idx, head_idx in heads\n",
    "            ]\n",
    "        ).squeeze()\n",
    "        if combined.dim() == 3:\n",
    "            combined = combined.mean(dim=0)\n",
    "        visualize_attn_matrix(\n",
    "            attn_matrix=combined,\n",
    "            tokens=[mt.tokenizer.decode(t) for t in clean_tokenized.input_ids[0]],\n",
    "            q_index=verify_head_behavior_on,\n",
    "            start_from=1,\n",
    "        )\n",
    "        return combined\n",
    "\n",
    "    ret_dict = {}\n",
    "    clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt, return_offsets_mapping=True)\n",
    "    patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt, return_offsets_mapping=True)\n",
    "    # Offsets are removed from the tokenized inputs before they are fed to the model.\n",
    "    clean_offsets = clean_tokenized.pop(\"offset_mapping\")[0]\n",
    "    patch_offsets = patch_tokenized.pop(\"offset_mapping\")[0]  # noqa: F841 (kept for parity with the clean side)\n",
    "\n",
    "    ret_dict[\"clean_sample\"] = clean_sample\n",
    "    ret_dict[\"patch_sample\"] = patch_sample\n",
    "\n",
    "    if verify_head_behavior_on is not None:\n",
    "        logger.info(\"Verifying head behavior of the samples...\")\n",
    "\n",
    "        logger.info(f\"Clean Sample >> Ans: {clean_sample.obj}\")\n",
    "        clean_attn_pattern = verify_head_patterns(  # noqa\n",
    "            prompt=clean_sample.prompt(),\n",
    "            tokenized_prompt=clean_tokenized,\n",
    "            # options=clean_sample.options,\n",
    "            options=[f\"{opt},\" for opt in clean_sample.options[:-1]]\n",
    "            + [f\"{clean_sample.options[-1]}.\"],\n",
    "            pivot=clean_sample.subj,\n",
    "            mt=mt,\n",
    "            heads=heads,\n",
    "            # generate_full_answer=True,\n",
    "            query_index=verify_head_behavior_on,\n",
    "            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,\n",
    "        )\n",
    "\n",
    "        logger.info(f\"Patch Sample >> Ans: {patch_sample.obj}\")\n",
    "        patch_attn_pattern = verify_head_patterns(  # noqa\n",
    "            prompt=patch_sample.prompt(),\n",
    "            tokenized_prompt=patch_tokenized,\n",
    "            # options=patch_sample.options,\n",
    "            options=[f\"{opt},\" for opt in patch_sample.options[:-1]]\n",
    "            + [f\"{patch_sample.options[-1]}.\"],\n",
    "            pivot=patch_sample.subj,\n",
    "            mt=mt,\n",
    "            heads=heads,\n",
    "            # generate_full_answer=True,\n",
    "            query_index=verify_head_behavior_on,\n",
    "            ablate_possible_ans_info_from_options=ablate_possible_ans_info_from_options,\n",
    "        )\n",
    "\n",
    "    # ---- q_proj patches: states cached on the patch prompt, re-targeted via query_indices ----\n",
    "    cached_q_states, patch_output = cache_q_projections(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized,\n",
    "        token_indices=[list(query_indices.keys())],\n",
    "        heads=heads,\n",
    "        return_output=True,\n",
    "    )\n",
    "    cached_q_states = cached_q_states[0]\n",
    "    q_proj_patches = [\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                head_idx,\n",
    "                query_indices[patch_query_idx],\n",
    "            ),\n",
    "            patch=q_proj,\n",
    "        )\n",
    "        for (layer_idx, head_idx, patch_query_idx), q_proj in cached_q_states.items()\n",
    "    ]\n",
    "\n",
    "    patch_logits = patch_output.logits[:, -1, :].squeeze()\n",
    "    patch_predictions = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=patch_logits,\n",
    "    )\n",
    "    logger.info(f\"patch_prediction={[str(pred) for pred in patch_predictions]}\")\n",
    "\n",
    "    # Sort the candidates before drawing: the iteration order of a set of strings\n",
    "    # changes between interpreter runs (hash randomization), which would make the\n",
    "    # pick irreproducible even under a fixed `random` seed.\n",
    "    random_target = random.choice(\n",
    "        sorted(\n",
    "            set(clean_sample.options)\n",
    "            - {clean_sample.obj, clean_sample.metadata[\"track_type_obj\"]}\n",
    "        )\n",
    "    )\n",
    "    clean_tok = get_first_token_id(clean_sample.obj, mt.tokenizer, prefix=\" \")\n",
    "    q_target = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "    kq_target = get_first_token_id(random_target, mt.tokenizer, prefix=\" \")\n",
    "\n",
    "    track_tokens = {\n",
    "        \"clean_obj\": clean_tok,\n",
    "        \"q_target\": q_target,\n",
    "        \"kq_target\": kq_target,\n",
    "    }\n",
    "    ret_dict[\"track_tokens\"] = track_tokens\n",
    "    interested_tokens = [\n",
    "        get_first_token_id(option, mt.tokenizer, prefix=\" \")\n",
    "        for option in clean_sample.options\n",
    "    ] + must_track_tokens\n",
    "\n",
    "    if verify_head_behavior_on is not None:\n",
    "        logger.info(\"Applying q_proj patches (patching the predicate function)\")\n",
    "        q_out, q_attn_patterns = patched_run_with_custom_attention(\n",
    "            mt=mt,\n",
    "            input=clean_tokenized,\n",
    "            query_patches=q_proj_patches,\n",
    "            return_attention_patterns=True,\n",
    "        )\n",
    "        q_logits = q_out.logits[:, -1, :].squeeze()\n",
    "        q_predictions, q_track = interpret_logits(\n",
    "            logits=q_logits, tokenizer=mt.tokenizer, interested_tokens=interested_tokens\n",
    "        )\n",
    "        logger.info(f\"q_predictions={[str(pred) for pred in q_predictions]}\")\n",
    "        logger.info(f\"{q_track=}\")\n",
    "        ret_dict[\"q_predictions\"] = q_predictions\n",
    "        ret_dict[\"q_track\"] = q_track\n",
    "\n",
    "        _show_combined_attention(q_attn_patterns)\n",
    "\n",
    "    # ---- locate the two option positions whose keys are exchanged ----\n",
    "    logger.info(f\"{random_target=}\")\n",
    "    # `[1] - 1`: presumably the last token of the matched span (end index looks\n",
    "    # exclusive) -- confirm against find_token_range.\n",
    "    patch_type_obj_token_idx = (\n",
    "        find_token_range(\n",
    "            string=clean_sample.prompt(),\n",
    "            substring=clean_sample.metadata[\"track_type_obj\"],\n",
    "            # substring=clean_sample.obj,\n",
    "            offset_mapping=clean_offsets,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    target_obj_token_idx = (\n",
    "        find_token_range(\n",
    "            string=clean_sample.prompt(),\n",
    "            substring=random_target,\n",
    "            offset_mapping=clean_offsets,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    logger.debug(\n",
    "        f'{patch_type_obj_token_idx=} | \"{mt.tokenizer.decode(clean_tokenized.input_ids[0][patch_type_obj_token_idx])}\"'\n",
    "    )\n",
    "    logger.debug(\n",
    "        f'{target_obj_token_idx=} | \"{mt.tokenizer.decode(clean_tokenized.input_ids[0][target_obj_token_idx])}\"'\n",
    "    )\n",
    "\n",
    "    # source position -> destination position: the two options trade places, and so\n",
    "    # do the positions immediately after them.\n",
    "    key_indices = {\n",
    "        patch_type_obj_token_idx: target_obj_token_idx,\n",
    "        target_obj_token_idx: patch_type_obj_token_idx,\n",
    "        patch_type_obj_token_idx + 1: target_obj_token_idx + 1,\n",
    "        target_obj_token_idx + 1: patch_type_obj_token_idx + 1,\n",
    "    }\n",
    "    cached_k_states, clean_output = cache_q_projections(\n",
    "        mt=mt,\n",
    "        input=clean_tokenized,  #! should always be clean_tokenized\n",
    "        token_indices=[list(key_indices.keys())],\n",
    "        heads=heads,\n",
    "        return_output=True,\n",
    "        projection_signature=\".k_proj\",\n",
    "    )\n",
    "    cached_k_states = cached_k_states[0]\n",
    "    clean_logits = clean_output.logits[:, -1, :].squeeze()\n",
    "    clean_predictions, clean_track = interpret_logits(\n",
    "        tokenizer=mt, logits=clean_logits, interested_tokens=interested_tokens\n",
    "    )\n",
    "    logger.info(f\"clean_prediction={[str(pred) for pred in clean_predictions]}\")\n",
    "    logger.info(f\"{clean_track=}\")\n",
    "    ret_dict[\"clean_predictions\"] = clean_predictions\n",
    "    ret_dict[\"clean_track\"] = clean_track\n",
    "\n",
    "    k_proj_patches = [\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer_idx) + \".k_proj\",\n",
    "                head_idx,\n",
    "                key_indices[patch_key_idx],\n",
    "            ),\n",
    "            patch=k_proj,\n",
    "        )\n",
    "        for (layer_idx, head_idx, patch_key_idx), k_proj in cached_k_states.items()\n",
    "    ]\n",
    "\n",
    "    # ---- final run: query patches and swapped keys together ----\n",
    "    logger.info(\"Applying q_proj and k_proj patches (patching the predicate function and swapping option keys)\")\n",
    "    kq_out, kq_attn_patterns = patched_run_with_custom_attention(\n",
    "        mt=mt,\n",
    "        input=clean_tokenized,\n",
    "        query_patches=q_proj_patches,\n",
    "        key_patches=k_proj_patches,\n",
    "        return_attention_patterns=True,\n",
    "    )\n",
    "    kq_logits = kq_out.logits[:, -1, :].squeeze()\n",
    "    kq_predictions, kq_track = interpret_logits(\n",
    "        logits=kq_logits, tokenizer=mt.tokenizer, interested_tokens=interested_tokens\n",
    "    )\n",
    "    logger.info(f\"kq_predictions={[str(pred) for pred in kq_predictions]}\")\n",
    "    logger.info(f\"{kq_track=}\")\n",
    "    ret_dict[\"kq_predictions\"] = kq_predictions\n",
    "    ret_dict[\"kq_track\"] = kq_track\n",
    "\n",
    "    if verify_head_behavior_on is not None:\n",
    "        ret_dict[\"combined_kq\"] = _show_combined_attention(kq_attn_patterns)\n",
    "    return ret_dict\n",
    "\n",
    "# Try the key-swap validation on a single validation pair, with head-behavior\n",
    "# verification enabled on the last query position (-1).\n",
    "# NOTE: `query_indices` is not defined in this cell; it must already exist in the kernel.\n",
    "clean_sample, patch_sample = copy.deepcopy(validation_set[78])\n",
    "\n",
    "# failed_case = failed_qk_cases[25]\n",
    "# clean_sample = failed_case[\"clean_sample\"]\n",
    "# patch_sample = failed_case[\"patch_sample\"]\n",
    "\n",
    "patching_result = validate_k_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean_sample,\n",
    "    patch_sample=patch_sample,\n",
    "    # heads=heads_selected,\n",
    "    heads=optimized_heads,\n",
    "    # heads=[(35, 19)],\n",
    "    query_indices=query_indices,\n",
    "    verify_head_behavior_on=-1,\n",
    ")\n",
    "\n",
    "# For each tracked token, report rank and logit on the clean run vs. the run with\n",
    "# query + key patches (entries are indexed as [0] = rank, [1].logit = logit).\n",
    "for token_type, token_id in patching_result[\"track_tokens\"].items():\n",
    "    logger.info(f\"{token_type}={token_id}, [\\\"{mt.tokenizer.decode(token_id)}\\\"]\")\n",
    "    on_clean = patching_result[\"clean_track\"][token_id]\n",
    "    on_int = patching_result[\"kq_track\"][token_id]\n",
    "    logger.info(f\"Rank: {on_clean[0]} -> {on_int[0]} | Δ={on_int[0] - on_clean[0]}\")\n",
    "    logger.info(f\"Logit: {on_clean[1].logit:.4f} -> {on_int[1].logit:.4f} | Δ={on_int[1].logit - on_clean[1].logit:.4f}\")\n",
    "    print(\"=\" * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ede66d",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(optimized_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5064d581",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "kq_validation_results = []\n",
    "for case in tqdm(successful_cases):\n",
    "    clean_sample = case[\"clean_sample\"]\n",
    "    patch_sample = case[\"patch_sample\"]\n",
    "\n",
    "    # clean_sample = copy.deepcopy(clean_sample)\n",
    "    # clean_sample.options = patch_sample.options\n",
    "\n",
    "    # patch_sample = copy.deepcopy(patch_sample)\n",
    "    # patch_sample.options = [\"#\"]\n",
    "    # patch_sample.prompt_template = \"Which among these objects mentioned above is a <_category_>?\\nAnswer:\"\n",
    "\n",
    "    kq_sample_result = validate_k_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        # heads=heads_selected,\n",
    "        heads = optimized_heads,\n",
    "        query_indices={-3: -3, -2: -2, -1: -1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    kq_validation_results.append(kq_sample_result)\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "    for token_type, token_id in kq_sample_result[\"track_tokens\"].items():\n",
    "        logger.info(f\"{token_type}={token_id}, [\\\"{mt.tokenizer.decode(token_id)}\\\"]\")\n",
    "        on_clean = kq_sample_result[\"clean_track\"][token_id]\n",
    "        on_int = kq_sample_result[\"kq_track\"][token_id]\n",
    "        logger.info(f\"Rank: {on_clean[0]} -> {on_int[0]} | Δ={on_int[0] - on_clean[0]}\")\n",
    "        logger.info(f\"Logit: {on_clean[1].logit:.4f} -> {on_int[1].logit:.4f} | Δ={on_int[1].logit - on_clean[1].logit:.4f}\")\n",
    "        print(\"=\" * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31211dd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_scores = {}\n",
    "intervention_scores = {}\n",
    "token_types = [\"clean_obj\", \"q_target\", \"kq_target\"]\n",
    "attributes = [\"rank\", \"logit\"]\n",
    "keys = [f\"{tok_type}_{attr}\" for tok_type in token_types for attr in attributes]\n",
    "for key in keys:\n",
    "    clean_scores[key] = []\n",
    "    intervention_scores[key] = []\n",
    "\n",
    "for kq_result in kq_validation_results:\n",
    "    for token_type, token_id in kq_result[\"track_tokens\"].items():\n",
    "        on_clean = kq_result[\"clean_track\"][token_id]\n",
    "        on_int = kq_result[\"kq_track\"][token_id]\n",
    "        clean_scores[f\"{token_type}_rank\"].append(on_clean[0])\n",
    "        clean_scores[f\"{token_type}_logit\"].append(on_clean[1].logit)\n",
    "        intervention_scores[f\"{token_type}_rank\"].append(on_int[0])\n",
    "        intervention_scores[f\"{token_type}_logit\"].append(on_int[1].logit)\n",
    "\n",
    "\n",
    "for token_type in token_types:\n",
    "    for attribute in attributes:\n",
    "        key = f\"{token_type}_{attribute}\"\n",
    "        clean_array = np.array(clean_scores[key])\n",
    "        int_array = np.array(intervention_scores[key])\n",
    "        delta_array = int_array - clean_array\n",
    "        clean_report = f\"{clean_array.mean():.4f} ± {clean_array.std():.4f}\"\n",
    "        int_report = f\"{int_array.mean():.4f} ± {int_array.std():.4f}\"\n",
    "        delta_report = f\"{delta_array.mean():.4f} ± {delta_array.std():.4f}\"\n",
    "        print(f\"{key}: {clean_report} -> {int_report} |  Δ={delta_report}\")\n",
    "    print(\"=\" * 80)\n",
    "\n",
    "\n",
    "counter_patch_type_top_option = 0\n",
    "failed_qk_cases = []\n",
    "for intervention_result in kq_validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"kq_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == intervention_result[\"track_tokens\"][\"kq_target\"]\n",
    "    ): \n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_qk_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "            }\n",
    "        )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(kq_validation_results)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(kq_validation_results)})\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4d97ac4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
