{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "# torch.cuda.get_device_name() raises when no CUDA device is usable, so it is\n",
    "# only queried when CUDA is available; the GPU log line itself is unchanged.\n",
    "if torch.cuda.is_available():\n",
    "    logger.info(\n",
    "        f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    "    )\n",
    "else:\n",
    "    logger.info(f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}\")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# file_path = os.path.join(\n",
    "#     env_utils.DEFAULT_DATA_DIR,\n",
    "#     \"selection\",\n",
    "#     # \"profession.json\"\n",
    "#     # \"nationality.json\"\n",
    "#     \"objects.json\",\n",
    "# )\n",
    "\n",
    "# with open(file_path, \"r\") as f:\n",
    "#     temp = json.load(f)\n",
    "\n",
    "# for cat in temp[\"categories\"]:\n",
    "#     temp[\"categories\"][cat] = [obj.capitalize() for obj in temp[\"categories\"][cat]]\n",
    "\n",
    "# with open(file_path, \"w\") as f:\n",
    "#     json.dump(temp, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# TASK_CLS = SelectOrderTask\n",
    "# prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "select_task.filter_single_token(mt.tokenizer, prefix=\" \")\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = select_task.get_random_sample(\n",
    "    mt = mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    obj_idx=2,\n",
    "    # category=\"actor\",\n",
    "    # category=\"Brazil\"\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "print(sample)\n",
    "print(sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66399b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import verify_correct_option\n",
    "# sample.prompt_template = select_prof.prompt_templates[3]\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)\n",
    "\n",
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e8a8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt.n_layer, mt.config.num_attention_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "llama_70_heads = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "qwen_72_heads = [\n",
    "    (62, 1),\n",
    "    (60, 9),\n",
    "    (64, 8),\n",
    "    (62, 0),\n",
    "    (62, 45),\n",
    "    (59, 59),\n",
    "    (71, 28),\n",
    "    (64, 12),\n",
    "    (61, 7),\n",
    "    (64, 13),\n",
    "    (67, 53),\n",
    "    (67, 51),\n",
    "    (54, 44),\n",
    "    (57, 5),\n",
    "    (59, 60),\n",
    "    (71, 25),\n",
    "    (62, 7),\n",
    "    (64, 9),\n",
    "    (62, 23),\n",
    "    (65, 40),\n",
    "]\n",
    "\n",
    "qwen_32_heads = [\n",
    "    (51, 11),\n",
    "    (48, 4),\n",
    "    (52, 21),\n",
    "    (54, 35),\n",
    "    (48, 8),\n",
    "    (50, 6),\n",
    "    (48, 9),\n",
    "    (48, 32),\n",
    "    (52, 10),\n",
    "    (45, 11),\n",
    "    (45, 13),\n",
    "    (48, 34),\n",
    "    (53, 16),\n",
    "    (50, 12),\n",
    "    (49, 2),\n",
    "    (54, 38),\n",
    "    (55, 4),\n",
    "    (50, 27),\n",
    "    (54, 33),\n",
    "    (50, 14),\n",
    "]\n",
    "\n",
    "\n",
    "# HEADS = [(35, 19)]\n",
    "\n",
    "\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "# HEADS = qwen_32_heads\n",
    "HEADS = llama_70_heads\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe381b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_backup_heads\",\n",
    "#     mt.name.split(\"/\")[-1],\n",
    "#     f\"{select_task.task_name}.npz\"\n",
    "# )\n",
    "\n",
    "# Path of the saved head-optimization results for the current model and task.\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    # mt.name.split(\"/\")[-1],\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    f\"{select_task.task_name}\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "# execute arbitrary code; only point this at result files you trust.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "# Plot the \"losses\" array stored in the results file.\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f87810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "# Load the saved mask as float32; the comprehension further down unpacks its\n",
    "# two indices as (layer_idx, head_idx).\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "# Zero every row from index 52 onward before plotting and selecting.\n",
    "# NOTE(review): the selection filter below keeps only layer_idx < 50, so rows 50-51\n",
    "# can show up in the heatmap but never in heads_selected; confirm which cutoff is intended.\n",
    "optimal_head_mask[52:, :] = 0.0\n",
    "\n",
    "# Heatmap of the transposed mask with the colour range fixed to [0, 1].\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "\n",
    "# Keep the entries whose mask value exceeds 0.5, as (layer_idx, head_idx) tuples.\n",
    "heads_selected = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx) for layer_idx, head_idx in heads_selected if layer_idx < 50\n",
    "]\n",
    "print(len(heads_selected))\n",
    "\n",
    "# From here on HEADS refers to this mask-based selection.\n",
    "HEADS = heads_selected\n",
    "\n",
    "# Check whether head (35, 19) is part of the selection.\n",
    "(35, 19) in HEADS, (35, 19) in heads_selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    ")\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=sample.prompt(option_style=\"single_line\"),\n",
    "    options=sample.options,\n",
    "    pivot=sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425f6285",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_within_task\n",
    "\n",
    "patch_sample, clean_sample = get_counterfactual_samples_within_task(\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=OPTION_STYLE,\n",
    "    distinct_options=True,\n",
    "    n_distractors = N_DISTRACTORS,\n",
    ")\n",
    "\n",
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bab527e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "patch_sample.metadata, clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d35a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# order_sample_1 = SelectionSample(\n",
    "#     subj=\"random\",\n",
    "#     category=\"test_order\",\n",
    "#     options=[\"Bike\", \"Apple\", \"Bed\", \"Dog\", \"Monitor\", \"Theater\"],\n",
    "#     obj=\"Apple\",\n",
    "#     obj_idx=1,\n",
    "#     prompt_template=\"<_options_>\\nWhat is the third item in the list?\\nAnswer:\",\n",
    "#     answer=\"Apple\",\n",
    "# )\n",
    "\n",
    "# order_sample_2 = SelectionSample(\n",
    "#     subj=\"random\",\n",
    "#     category=\"test_order\",\n",
    "#     options=[\"Cat\", \"Chair\", \"Bus\", \"Phone\", \"Library\", \"Orange\"],\n",
    "#     obj=\"Phone\",\n",
    "#     obj_idx=3,\n",
    "#     prompt_template=\"<_options_>\\nWhat is the fifth item in the list?\\nAnswer:\",\n",
    "#     answer=\"Phone\",\n",
    "# )\n",
    "len(HEADS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510772fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "# patch_sample.options[patch_sample.obj_idx] = \"Screw\"\n",
    "# patch_sample.options[patch_sample.obj_idx] = patch_sample.obj\n",
    "\n",
    "for sample in [patch_sample, clean_sample]:\n",
    "# for sample in [order_sample_1, order_sample_2]:\n",
    "    print(sample.prompt(), \">>\", sample.obj)\n",
    "    attn_pattern = verify_head_patterns(\n",
    "        prompt=sample.prompt(),\n",
    "        options=sample.options,\n",
    "        pivot=sample.subj,\n",
    "        mt=mt,\n",
    "        heads=HEADS,\n",
    "        generate_full_answer=True,\n",
    "        query_index=-1\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4d487f",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f414b376",
   "metadata": {},
   "source": [
    "## Validation of the patching effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00698133",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "import random\n",
    "\n",
    "# Directory that receives one JSON file per generated counterfactual pair.\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    "    # \"profession\",\n",
    "    # \"nationality\",\n",
    "    # \"landmarks\"\n",
    ")\n",
    "\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 256\n",
    "\n",
    "# Sampler registered for the current task type.\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_task.task_name]\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    patch, clean = counterfactual_sampler(\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        # Use the template chosen in the task-configuration cell rather than a\n",
    "        # hard-coded index, so this cell follows a change of task or template.\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        n_distractors=random.choice(range(2, 6)),\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "    cf_pair = CounterFactualSamplePair(\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "    )\n",
    "    # detensorize() is called before serialising the pair to JSON.\n",
    "    cf_pair.detensorize()\n",
    "    # Files are numbered by the 1-based position of the pair in validation_set.\n",
    "    with open(\n",
    "        os.path.join(validation_samples_save_path, f\"{len(validation_set):05d}.json\"),\n",
    "        \"w\",\n",
    "    ) as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bd4abb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "validation_limit = 1024\n",
    "\n",
    "# Directory holding the saved counterfactual pairs (one JSON file each).\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    "    # \"profession\",\n",
    "    # \"nationality\"\n",
    "    # \"landmarks\"\n",
    ")\n",
    "\n",
    "# Gather the full path of every .json file in that directory.\n",
    "sample_files = []\n",
    "for file_name in os.listdir(validation_samples_load_path):\n",
    "    if not file_name.endswith(\".json\"):\n",
    "        continue\n",
    "    sample_files.append(os.path.join(validation_samples_load_path, file_name))\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "# Random subset of at most validation_limit files.\n",
    "random.shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "\n",
    "# Each entry is stored as (clean_sample, patch_sample).\n",
    "validation_set = []\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as pair_file:\n",
    "        cf_pair = CounterFactualSamplePair.from_dict(json.load(pair_file))\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf5ede5",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean, patch = validation_set[3]\n",
    "print(patch.prompt(), \">>\", mt.tokenizer.decode(patch.ans_token_id))\n",
    "print(clean.prompt(), \">>\", mt.tokenizer.decode(clean.ans_token_id))\n",
    "clean.metadata[\"track_type_obj_token_id\"], mt.tokenizer.decode(clean.metadata[\"track_type_obj_token_id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d041d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Work on deep copies of the pair built earlier in the notebook.\n",
    "clean = copy.deepcopy(clean_sample)\n",
    "patch = copy.deepcopy(patch_sample)\n",
    "# clean, patch = copy.deepcopy(validation_set[18])\n",
    "# clean.default_option_style=\"numbered\"\n",
    "# patch.default_option_style=\"numbered\"\n",
    "\n",
    "# failed_case = failed_cases[17]\n",
    "# clean = failed_case[\"clean_sample\"]\n",
    "# patch = failed_case[\"patch_sample\"]\n",
    "\n",
    "print(clean.prompt(), \">>\", clean.obj)\n",
    "print(patch.prompt(), \">>\", patch.obj)\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # heads=HEADS,\n",
    "    heads=heads_selected,\n",
    "    # heads = overlapping_heads,\n",
    "    # heads=[(35, 19)],\n",
    "    query_indices={-3: -3, -2: -2, -1: -1},\n",
    "    verify_head_behavior_on=-1,\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    # amplification_scale=2.0\n",
    ")\n",
    "\n",
    "# Token ids whose rank/logit are compared before and after the intervention.\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "\n",
    "def _track_summary(track):\n",
    "    \"\"\"Return rank and logit of the clean and target tokens from one tracking dict.\"\"\"\n",
    "    clean_entry = track[clean_obj]\n",
    "    summary = {\n",
    "        \"clean_rank\": clean_entry[0],\n",
    "        \"clean_logit\": clean_entry[1].logit,\n",
    "    }\n",
    "    target_entry = track[target_obj]\n",
    "    summary[\"target_rank\"] = target_entry[0]\n",
    "    summary[\"target_logit\"] = target_entry[1].logit\n",
    "    return summary\n",
    "\n",
    "\n",
    "before_intervention = _track_summary(validation_result[\"clean_track\"])\n",
    "after_intervention = _track_summary(validation_result[\"int_track\"])\n",
    "\n",
    "# Rank changes (after minus before).\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \")\n",
    "logger.info(f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \")\n",
    "\n",
    "# Logit changes (after minus before).\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \")\n",
    "logger.info(f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b79f336",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.typing import TokenizerOutput\n",
    "from src.hooking.llama_attention import AttentionEdge, LlamaAttentionPatcher\n",
    "import baukit\n",
    "import types\n",
    "from src.tokens import prepare_input\n",
    "from src.functional import interpret_logits\n",
    "from src.selection.utils import get_first_token_id\n",
    "import random\n",
    "\n",
    "\n",
    "def ablate_attn_heads(\n",
    "    mt: ModelandTokenizer,\n",
    "    input: TokenizerOutput,\n",
    "    heads: list[tuple[int, int]],\n",
    "    block_edges: list[AttentionEdge],\n",
    "):\n",
    "    \"\"\"Run one forward pass with `block_edges` cut in the given attention heads.\n",
    "\n",
    "    For every layer that appears in `heads`, the attention block's `forward` is\n",
    "    replaced by a `LlamaAttentionPatcher` that cuts `block_edges` for the listed\n",
    "    heads of that layer. The model is switched to the \"sdpa\" attention\n",
    "    implementation for the pass. Afterwards the patched forwards are reset and\n",
    "    the previously configured implementation is restored, even when the pass raises.\n",
    "\n",
    "    Args:\n",
    "        mt: model/tokenizer wrapper to run.\n",
    "        input: tokenized prompt handed to `mt.trace`.\n",
    "        heads: (layer_idx, head_idx) pairs to patch; an empty list patches nothing.\n",
    "        block_edges: attention edges to cut in each listed head.\n",
    "\n",
    "    Returns:\n",
    "        The saved logits at the last position (`mt.output.logits[:, -1, :]`).\n",
    "    \"\"\"\n",
    "    default_attn_implementation = mt.config._attn_implementation\n",
    "    mt.reset_forward()\n",
    "    mt.set_attn_implementation(\"sdpa\")\n",
    "\n",
    "    try:\n",
    "        # Group head indices by layer: one patcher is installed per layer.\n",
    "        layers_to_heads = {}\n",
    "        for layer_idx, head_idx in heads:\n",
    "            layers_to_heads.setdefault(layer_idx, []).append(head_idx)\n",
    "\n",
    "        for layer_idx in sorted(layers_to_heads):\n",
    "            attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "            attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "            attn_block.forward = types.MethodType(\n",
    "                LlamaAttentionPatcher(\n",
    "                    block_name=attn_block_name,\n",
    "                    cut_attn_edges={\n",
    "                        head_idx: block_edges for head_idx in layers_to_heads[layer_idx]\n",
    "                    },\n",
    "                ),\n",
    "                attn_block,\n",
    "            )\n",
    "\n",
    "        with mt.trace(input) as tr:\n",
    "            logits = mt.output.logits[:, -1, :].save()\n",
    "    finally:\n",
    "        # Never leave the model patched or on the wrong attention implementation,\n",
    "        # otherwise later cells would silently run on a modified model.\n",
    "        mt.reset_forward()\n",
    "        mt.set_attn_implementation(default_attn_implementation)\n",
    "\n",
    "    return logits\n",
    "\n",
    "\n",
    "def get_edges_to_block(\n",
    "    input: TokenizerOutput,\n",
    "    whitelist_key_indices: list[int] = [0],\n",
    "    query_indices: list[int] = [-1],\n",
    "):\n",
    "    \"\"\"List the attention edges to cut from the query positions to non-whitelisted keys.\n",
    "\n",
    "    Negative indices count from the end of the sequence (`input.input_ids.shape[1]`).\n",
    "    Key position 0 is never blocked because the key loop starts at 1, and an edge\n",
    "    is only emitted when the key is not after the query.\n",
    "\n",
    "    The argument lists are not modified. Resolving negative indices in place\n",
    "    would rewrite the caller's lists and, for the default lists, leak an absolute\n",
    "    index from one sequence length into later calls.\n",
    "\n",
    "    Args:\n",
    "        input: tokenized prompt; only `input.input_ids.shape[1]` is read.\n",
    "        whitelist_key_indices: key positions that must stay attendable.\n",
    "        query_indices: query positions whose outgoing edges are cut.\n",
    "\n",
    "    Returns:\n",
    "        A list of `AttentionEdge(q_idx, k_idx)` objects.\n",
    "    \"\"\"\n",
    "    seq_len = input.input_ids.shape[1]\n",
    "    # Resolve negative positions into fresh containers instead of mutating the inputs.\n",
    "    whitelisted_keys = {\n",
    "        idx + seq_len if idx < 0 else idx for idx in whitelist_key_indices\n",
    "    }\n",
    "    resolved_queries = [idx + seq_len if idx < 0 else idx for idx in query_indices]\n",
    "\n",
    "    block_edges: list[AttentionEdge] = []\n",
    "    for key_idx in range(1, seq_len):\n",
    "        if key_idx in whitelisted_keys:\n",
    "            continue\n",
    "        for query_idx in resolved_queries:\n",
    "            if query_idx < key_idx:\n",
    "                continue  # autoregressive LM\n",
    "            block_edges.append(AttentionEdge(q_idx=query_idx, k_idx=key_idx))\n",
    "    return block_edges\n",
    "\n",
    "\n",
    "# Start from an unpatched model on the eager attention implementation.\n",
    "mt.reset_forward()\n",
    "mt.set_attn_implementation(\"eager\")\n",
    "\n",
    "# Token id tracked in both runs (presumably the first token of \" \" + answer; verify in get_first_token_id).\n",
    "answer_token = get_first_token_id(\n",
    "    name=clean_sample.obj, tokenizer=mt.tokenizer, prefix=\" \"\n",
    ")\n",
    "\n",
    "# clean_run: no heads and no edges, so no attention block gets patched\n",
    "clean_tokenized = prepare_input(tokenizer=mt.tokenizer, prompts=clean_sample.prompt())\n",
    "print(clean_tokenized.input_ids.shape)\n",
    "clean_logits = ablate_attn_heads(mt=mt, input=clean_tokenized, heads=[], block_edges=[])\n",
    "clean_pred, clean_track = interpret_logits(\n",
    "    tokenizer=mt.tokenizer,\n",
    "    logits=clean_logits,\n",
    "    interested_tokens=[answer_token],\n",
    ")\n",
    "logger.info(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "logger.info(f\"{clean_track=}\")\n",
    "\n",
    "# ablated_run\n",
    "# Random control heads are drawn from every head that is not in HEADS.\n",
    "# NOTE(review): random_heads is computed but unused while `heads=random_heads` stays commented out.\n",
    "all_heads = [(layer, head) for layer in range(mt.config.num_hidden_layers) for head in range(mt.config.num_attention_heads)]\n",
    "all_heads = list(set(all_heads) - set(HEADS))\n",
    "random_heads = random.sample(all_heads, len(heads_selected))\n",
    "# Cut the edges from the last query position to every key except position 0.\n",
    "ablated_logits = ablate_attn_heads(\n",
    "    mt=mt,\n",
    "    input=clean_tokenized,\n",
    "    # heads=random_heads,\n",
    "    heads=heads_selected,\n",
    "    block_edges=get_edges_to_block(\n",
    "        input=clean_tokenized, whitelist_key_indices=[0], query_indices=[-1]\n",
    "    ),\n",
    ")\n",
    "ablated_pred, ablated_track = interpret_logits(\n",
    "    tokenizer=mt.tokenizer,\n",
    "    logits=ablated_logits,\n",
    "    interested_tokens=[answer_token],\n",
    ")\n",
    "\n",
    "logger.info(f\"ablated_pred={[str(pred) for pred in ablated_pred]}\")\n",
    "logger.info(f\"{ablated_track=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cc52189",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_tokenized = prepare_input(tokenizer=mt.tokenizer, prompts=clean_sample.prompt())\n",
    "\n",
    "\n",
    "# Edges from the last query position to every key except position 0.\n",
    "block_edges = get_edges_to_block(\n",
    "    input=clean_tokenized, whitelist_key_indices=[0], query_indices=[-1]\n",
    ")\n",
    "\n",
    "mt.reset_forward()\n",
    "mt.set_attn_implementation(\"sdpa\")\n",
    "\n",
    "# Group the inspected heads by layer: one patcher is installed per layer.\n",
    "layers_to_heads = {}\n",
    "for layer_idx, head_idx in [(35, 19)]:\n",
    "    layers_to_heads.setdefault(layer_idx, []).append(head_idx)\n",
    "\n",
    "# attn_matrices[layer_idx] is the dict handed to the patcher as store_attn_matrices.\n",
    "attn_matrices = {}\n",
    "try:\n",
    "    for layer_idx in sorted(layers_to_heads):\n",
    "        attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "        attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "        attn_matrices[layer_idx] = {}\n",
    "        attn_block.forward = types.MethodType(\n",
    "            LlamaAttentionPatcher(\n",
    "                block_name=attn_block_name,\n",
    "                cut_attn_edges={\n",
    "                    head_idx: block_edges for head_idx in layers_to_heads[layer_idx]\n",
    "                },\n",
    "                save_attn_for=layers_to_heads[layer_idx],\n",
    "                store_attn_matrices=attn_matrices[layer_idx]\n",
    "            ),\n",
    "            attn_block,\n",
    "        )\n",
    "\n",
    "    with mt.trace(clean_tokenized) as tr:\n",
    "        logits = mt.output.logits[:, -1, :].save()\n",
    "finally:\n",
    "    # Undo the patched forwards even when the traced pass fails, so later\n",
    "    # cells never run on a silently modified model.\n",
    "    mt.reset_forward()\n",
    "    mt.set_attn_implementation(\"eager\")\n",
    "\n",
    "interpret_logits(tokenizer=mt.tokenizer, logits=logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b2277fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "attn_matrices[35][19].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e3446c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import visualize_attn_matrix\n",
    "\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=attn_matrices[35][19].squeeze(),\n",
    "    tokens = [\n",
    "        mt.tokenizer.decode(token_id) for token_id in clean_tokenized.input_ids[0]\n",
    "    ],\n",
    "    q_index = -1,\n",
    "    start_from=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c530620c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8454178",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []\n",
    "\n",
    "# Control pool for the random ablation: every (layer, head) that is NOT one of\n",
    "# the selected heads, so the random baseline cannot overlap the heads under test.\n",
    "# Built once because it does not depend on the sample.\n",
    "selected_head_set = {tuple(head) for head in heads_selected}\n",
    "random_head_pool = [\n",
    "    (layer, head)\n",
    "    for layer in range(mt.config.num_hidden_layers)\n",
    "    for head in range(mt.config.num_attention_heads)\n",
    "    if (layer, head) not in selected_head_set\n",
    "]\n",
    "\n",
    "# A dedicated loop variable keeps the notebook-level `clean_sample` intact.\n",
    "for val_clean_sample, _ in validation_set:\n",
    "    print(val_clean_sample.prompt(), \">>\", val_clean_sample.obj)\n",
    "    # clean_run: no heads and no edges, so nothing is patched\n",
    "    clean_tokenized = prepare_input(\n",
    "        tokenizer=mt.tokenizer, prompts=val_clean_sample.prompt()\n",
    "    )\n",
    "    answer_token = get_first_token_id(\n",
    "        tokenizer=mt.tokenizer, name=val_clean_sample.obj, prefix=\" \"\n",
    "    )\n",
    "    clean_logits = ablate_attn_heads(\n",
    "        mt=mt, input=clean_tokenized, heads=[], block_edges=[]\n",
    "    )\n",
    "    clean_pred, clean_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=clean_logits,\n",
    "        interested_tokens=[answer_token],\n",
    "    )\n",
    "    logger.info(f\"clean_pred={[str(pred) for pred in clean_pred]}\")\n",
    "    logger.info(f\"{clean_track=}\")\n",
    "    print(\"-\"*50)\n",
    "\n",
    "    # ablated_run: cut the selected heads' edges from the last position\n",
    "    block_edges = get_edges_to_block(\n",
    "        input=clean_tokenized, whitelist_key_indices=[0], query_indices=[-1]\n",
    "    )\n",
    "    ablated_logits = ablate_attn_heads(\n",
    "        mt=mt, input=clean_tokenized, heads=heads_selected, block_edges=block_edges\n",
    "    )\n",
    "    ablated_pred, ablated_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=ablated_logits,\n",
    "        interested_tokens=[answer_token],\n",
    "    )\n",
    "    logger.info(f\"ablated_pred={[str(pred) for pred in ablated_pred]}\")\n",
    "    logger.info(f\"{ablated_track=}\")\n",
    "    print(\"-\"*50)\n",
    "\n",
    "    # random ablated: same number of heads, freshly drawn for every sample\n",
    "    random_heads = random.sample(random_head_pool, len(heads_selected))\n",
    "    random_ablated_logits = ablate_attn_heads(\n",
    "        mt=mt, input=clean_tokenized, heads=random_heads, block_edges=block_edges\n",
    "    )\n",
    "    random_ablated_pred, random_ablated_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=random_ablated_logits,\n",
    "        interested_tokens=[answer_token],\n",
    "    )\n",
    "    logger.info(f\"random_ablated_pred={[str(pred) for pred in random_ablated_pred]}\")\n",
    "    logger.info(f\"{random_ablated_track=}\")\n",
    "\n",
    "    results.append({\n",
    "        \"sample\": val_clean_sample,\n",
    "        \"answer_token\": answer_token,\n",
    "        \"clean_pred\": clean_pred,\n",
    "        \"clean_track\": clean_track,\n",
    "        \"ablated_pred\": ablated_pred,\n",
    "        \"ablated_track\": ablated_track,\n",
    "        \"random_ablated_pred\": random_ablated_pred,\n",
    "        \"random_ablated_track\": random_ablated_track,\n",
    "    })\n",
    "    print(\"=\"*100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357d013d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _answer_stats(track_key):\n",
    "    \"\"\"Collect the answer token's logit and rank stored under `track_key` for every result.\"\"\"\n",
    "    logit_values, rank_values = [], []\n",
    "    for sample_res in results:\n",
    "        tracked = sample_res[track_key][sample_res[\"answer_token\"]]\n",
    "        logit_values.append(tracked[1].logit)\n",
    "        rank_values.append(tracked[0])\n",
    "    return np.array(logit_values), np.array(rank_values)\n",
    "\n",
    "\n",
    "clean_logits, clean_rank = _answer_stats(\"clean_track\")\n",
    "ablated_logits, ablated_rank = _answer_stats(\"ablated_track\")\n",
    "random_ablated_logits, random_ablated_rank = _answer_stats(\"random_ablated_track\")\n",
    "\n",
    "# Mean and standard deviation of the answer token's logit and rank per run.\n",
    "print(f\"Clean Run >> logits={clean_logits.mean():.4f}±{clean_logits.std():.4f}, rank={clean_rank.mean():.2f}±{clean_rank.std():.2f}\")\n",
    "print(f\"Ablated Run >> logits={ablated_logits.mean():.4f}±{ablated_logits.std():.4f}, rank={ablated_rank.mean():.2f}±{ablated_rank.std():.2f}\")\n",
    "print(f\"Random Ablated Run >> logits={random_ablated_logits.mean():.4f}±{random_ablated_logits.std():.4f}, rank={random_ablated_rank.mean():.2f}±{random_ablated_rank.std():.2f}\")\n",
    "\n",
    "# Accuracy is the fraction of samples whose answer token has rank 1.\n",
    "clean_accuracy = (clean_rank == 1).mean()\n",
    "ablated_accuracy = (ablated_rank == 1).mean()\n",
    "random_ablated_accuracy = (random_ablated_rank == 1).mean()\n",
    "\n",
    "print(f\"Clean Run >> accuracy={clean_accuracy:.4f}\")\n",
    "print(f\"Ablated Run >> accuracy={ablated_accuracy:.4f}\")\n",
    "print(f\"Random Ablated Run >> accuracy={random_ablated_accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23decf0c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
