{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "from src.selection.data import SelectionSample\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "\n",
    "from src.functional import (\n",
    "    get_module_nnsight,\n",
    "    PatchSpec,\n",
    "    interpret_logits,\n",
    "    patch_with_baukit,\n",
    ")\n",
    "from src.selection.functional import get_first_token_id\n",
    "from src.utils.typing import TokenizerOutput\n",
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    cache_q_projections,\n",
    "    validate_q_proj_ie_on_sample_pair,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6260a4d6",
   "metadata": {},
   "source": [
    "## Loading the heads"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d45e4474",
   "metadata": {},
   "source": [
    "### Attention Behavior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08c77ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.locate_resolution_heads import SelectionSampleAttn\n",
    "from tqdm import tqdm\n",
    "attn_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/attention_patterns\",\n",
    "    \"select_odd_one_out\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"objects\"\n",
    ")\n",
    "files = sorted(os.listdir(attn_path))\n",
    "\n",
    "#######################################################################\n",
    "# LIMIT = 100\n",
    "LIMIT = len(files)\n",
    "#######################################################################\n",
    "\n",
    "selection_attns = []\n",
    "\n",
    "for npz_file in tqdm(files[:LIMIT]):\n",
    "    if not npz_file.endswith(\".npz\"):\n",
    "        continue\n",
    "\n",
    "    npz_path = os.path.join(attn_path, npz_file)\n",
    "    selection_attns.append(SelectionSampleAttn.from_npz(npz_path))\n",
    "    # if len(selection_attns) % 10 == 0:\n",
    "    #     print(f\"Loaded {len(selection_attns)}/{LIMIT} files\")\n",
    "\n",
    "len(selection_attns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ca53542",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import visualize_attn_matrix\n",
    "\n",
    "sample_idx = 18\n",
    "# layer_idx = 39\n",
    "# head_idx = 40\n",
    "layer_idx = 35\n",
    "head_idx = 19\n",
    "\n",
    "selection_attn = selection_attns[sample_idx]\n",
    "print(selection_attn.resolution_score(layer_idx, head_idx))\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=selection_attn.attention_pattern.attention_matrices[layer_idx, head_idx],\n",
    "    tokens=selection_attn.attention_pattern.tokenized_prompt,\n",
    "    q_index=-1,\n",
    "    start_from=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b78a339b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "#############################################################################\n",
    "n_layer = mt.n_layer\n",
    "n_head = mt.config.num_attention_heads\n",
    "# token_idx = \"all\"\n",
    "token_idx = \"last\"\n",
    "##############################################################################\n",
    "\n",
    "resolution_scores = torch.zeros((n_head, n_layer), dtype=torch.float32)\n",
    "for selection_attn in tqdm(selection_attns):\n",
    "    for layer_idx in range(n_layer):\n",
    "        for head_idx in range(n_head):\n",
    "            resolution_scores[head_idx, layer_idx] += selection_attn.resolution_score(\n",
    "                layer_idx, head_idx, token_idx=token_idx\n",
    "            )[0]\n",
    "            # resolution_scores[head_idx, layer_idx] += selection_attn.first_token_score(\n",
    "            #     layer_idx, head_idx\n",
    "            # )[0]\n",
    "\n",
    "resolution_scores /= len(selection_attns)\n",
    "resolution_scores.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa855e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "scale = torch.max(torch.abs(resolution_scores))\n",
    "plt.imshow(\n",
    "    resolution_scores.cpu().numpy(),\n",
    "    cmap=\"RdBu\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=-scale,\n",
    "    vmax=scale,\n",
    ")\n",
    "plt.colorbar()\n",
    "# plt.title(f\"score(target) - max(score(distractors)) | {token_idx.upper()} tokens of options\")\n",
    "plt.title(\"score(target[0]) - sum(score(target[1:]))\")\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Head\")\n",
    "\n",
    "def get_ticks(ticks, skip=5):\n",
    "    ret = []\n",
    "    for i in ticks:\n",
    "        if i % skip == 0:\n",
    "            ret.append(str(i))\n",
    "        else:\n",
    "            ret.append(\"\")\n",
    "    return ret\n",
    "\n",
    "plt.xticks(\n",
    "    ticks=range(n_layer),\n",
    "    labels=get_ticks(range(n_layer)),\n",
    "    rotation=45,\n",
    ")\n",
    "plt.yticks(\n",
    "    ticks=range(n_head),\n",
    "    labels=get_ticks(range(n_head), skip=4),\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "\n",
    "scores_per_head = []\n",
    "for head_idx in range(n_head):\n",
    "    for layer_idx in range(n_layer):\n",
    "        scores_per_head.append(\n",
    "            (head_idx, layer_idx, resolution_scores[head_idx, layer_idx].item())\n",
    "        )\n",
    "\n",
    "scores_per_head = sorted(scores_per_head, key=lambda x: x[2], reverse=True)\n",
    "for head_idx, layer_idx, score in scores_per_head[:15]:\n",
    "    print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49607b90",
   "metadata": {},
   "outputs": [],
   "source": [
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:50]\n",
    "# ]\n",
    "\n",
    "# print(len(HEADS))\n",
    "\n",
    "# HEADS = [(layer_idx, head_idx) for head_idx, layer_idx, score in scores_per_head[:10]]\n",
    "HEADS = heads_selected\n",
    "print(HEADS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbe9230a",
   "metadata": {},
   "source": [
    "### Load the optimzied heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f03728",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"odd_one_out__not_patch_category.npz\"\n",
    ")\n",
    "\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ced6c846",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "\n",
    "heads_selected = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx) for layer_idx, head_idx in heads_selected if layer_idx < 50\n",
    "]\n",
    "len(heads_selected)\n",
    "\n",
    "HEADS = heads_selected"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6d426a7",
   "metadata": {},
   "source": [
    "## Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed7897cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d22b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.selection.data import SelectionSample\n",
    "from src.tokens import prepare_input\n",
    "from src.selection.data import SelectOddOneOutTask\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_counterfactual_samples_odd_one_out(\n",
    "    task: SelectOddOneOutTask,\n",
    "    obj_category: str | None = None,\n",
    "    distractor_category: str | None = None,\n",
    "    prompt_template_idx=3,\n",
    "    option_style=\"single_line\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    n_distractors: int = 5,\n",
    "    counterfact_obj_idx: int | None = None,\n",
    "):\n",
    "    obj_category = obj_category or random.choice(task.categories)\n",
    "    distractor_category = distractor_category or random.choice(\n",
    "        list(set(task.categories) - {obj_category})\n",
    "    )\n",
    "    assert obj_category != distractor_category, f\"{obj_category=} {distractor_category=}\"\n",
    "\n",
    "    logger.info(\n",
    "        f\"obj_category={obj_category}, distractor_category={distractor_category}, prompt_template_idx={prompt_template_idx}, option_style={option_style}, n_distractors={n_distractors}\"\n",
    "    )\n",
    "\n",
    "    patch_sample = task.get_random_sample(\n",
    "        mt=mt,\n",
    "        option_style=option_style,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        obj_category=obj_category,\n",
    "        distractor_category=distractor_category,\n",
    "        filter_by_lm_prediction=False,\n",
    "        n_distractors=n_distractors,\n",
    "        obj_idx = random.choice(list(set(list(range(n_distractors + 1))) - {0, 1}))\n",
    "    )\n",
    "    logger.info(f\"patch_sample={str(patch_sample)}\")\n",
    "\n",
    "    #! criterion = not distractor_category\n",
    "    # Options (2): \n",
    "    # 1. distractor_category (selected)\n",
    "    # 2. category not in [obj_category, distractor_category] \n",
    "    not_dist_category_sample = task.get_random_sample(\n",
    "        mt=mt,\n",
    "        option_style=option_style,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        obj_category=distractor_category,\n",
    "        distractor_category=random.choice(\n",
    "            list(set(task.categories) - {obj_category, distractor_category})\n",
    "        ),\n",
    "        filter_by_lm_prediction=False,\n",
    "        exclude_objs=patch_sample.options,\n",
    "        n_distractors=1,\n",
    "        obj_idx=counterfact_obj_idx\n",
    "    )\n",
    "    logger.info(f\"not_dist_category_sample={str(not_dist_category_sample)}\")\n",
    "    track_idx = 1 ^ not_dist_category_sample.obj_idx\n",
    "    not_dist_category_sample.metadata = {\n",
    "        \"track_type_obj\": not_dist_category_sample.options[track_idx],\n",
    "        \"track_type_obj_idx\": track_idx,\n",
    "        \"track_type_obj_token_id\": get_first_token_id(\n",
    "            not_dist_category_sample.options[track_idx], mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    #! criterion = is obj_category\n",
    "    # Options (2):\n",
    "    # 1. obj_category\n",
    "    # 2. category not in [obj_category, distractor_category] (selected)\n",
    "    is_obj_category_sample = task.get_random_sample(\n",
    "        mt=mt,\n",
    "        option_style=option_style,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        distractor_category=obj_category,\n",
    "        obj_category=random.choice(\n",
    "            list(set(task.categories) - {obj_category, distractor_category})\n",
    "        ),\n",
    "        filter_by_lm_prediction=False,\n",
    "        exclude_objs=patch_sample.options,\n",
    "        n_distractors=1,\n",
    "        obj_idx=counterfact_obj_idx\n",
    "    )\n",
    "    logger.info(f\"is_obj_category_sample={str(is_obj_category_sample)}\")\n",
    "    track_idx = 1 ^ is_obj_category_sample.obj_idx\n",
    "    is_obj_category_sample.metadata = {\n",
    "        \"track_type_obj_idx\": track_idx,\n",
    "        \"track_type_obj\": is_obj_category_sample.options[track_idx],\n",
    "        \"track_type_obj_category\": obj_category,\n",
    "        \"track_type_obj_token_id\": get_first_token_id(\n",
    "            is_obj_category_sample.options[track_idx], mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, not_dist_category_sample, is_obj_category_sample]\n",
    "\n",
    "        for sample in test_samples:\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            logger.info(sample.prompt())\n",
    "            logger.info(\n",
    "                f\"{sample.subj} | {sample.category} -> {sample.obj} | pred={[str(p) for p in predictions]}\"\n",
    "            )\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {sample.ans_token_id}[\"{mt.tokenizer.decode(sample.ans_token_id)}\"]'\n",
    "                )\n",
    "                return get_counterfactual_samples_odd_one_out(\n",
    "                    task=task,\n",
    "                    obj_category=obj_category,\n",
    "                    distractor_category=distractor_category,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    option_style=option_style,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    n_distractors=n_distractors,\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return {\n",
    "        \"patch_sample\": patch_sample,\n",
    "        \"not_dist_category_sample\": not_dist_category_sample,\n",
    "        \"is_obj_category_sample\": is_obj_category_sample\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8750bd6",
   "metadata": {},
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOddOneOutTask\n",
    "\n",
    "select_odd_one = SelectOddOneOutTask.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "\n",
    "print(select_odd_one)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#############################################\n",
    "distractor_category=\"fruit\"\n",
    "obj_category=\"electronics\"\n",
    "option_style=\"single_line\"\n",
    "# option_style=\"numbered\"\n",
    "prompt_template_idx=3\n",
    "N_DISTRACTORS = 5\n",
    "counterfact_obj_idx = 1  # 0 or 1, if None, will be randomly selected\n",
    "#############################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a59a8c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_samples = get_counterfactual_samples_odd_one_out(\n",
    "    task=select_odd_one,\n",
    "    # obj_category=obj_category,\n",
    "    # distractor_category=distractor_category,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=option_style,\n",
    "    filter_by_lm_prediction=True,\n",
    "    n_distractors=5\n",
    ")\n",
    "\n",
    "patch_sample = exp_samples[\"patch_sample\"]\n",
    "not_dist_category_sample = exp_samples[\"not_dist_category_sample\"]\n",
    "is_obj_category_sample = exp_samples[\"is_obj_category_sample\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d554a4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(patch_sample)\n",
    "print(f'\"{patch_sample.prompt()}\"', \">>\", patch_sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "patch_attn_info = verify_head_patterns(\n",
    "    prompt = patch_sample.prompt(),\n",
    "    options = patch_sample.options,\n",
    "    pivot = patch_sample.subj,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")\n",
    "\n",
    "not_dist_category_attn_info = verify_head_patterns(\n",
    "    prompt = not_dist_category_sample.prompt(),\n",
    "    options = not_dist_category_sample.options,\n",
    "    pivot = not_dist_category_sample.subj,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")\n",
    "\n",
    "is_obj_category_attn_info = verify_head_patterns(\n",
    "    prompt = is_obj_category_sample.prompt(),\n",
    "    options = is_obj_category_sample.options,\n",
    "    pivot = is_obj_category_sample.subj,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd824f03",
   "metadata": {},
   "outputs": [],
   "source": [
    "patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "[mt.tokenizer.decode(patch_tokenized.input_ids[0][t]) for t in range(-3, 0)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef615d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "not_dist_category_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad040e00",
   "metadata": {},
   "outputs": [],
   "source": [
    "not_dist_category_patch = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=not_dist_category_sample,\n",
    "    patch_sample=patch_sample,\n",
    "    heads=HEADS,\n",
    "    query_indices=[-3, -2, -1],\n",
    "    verify_head_behavior_on=-1,\n",
    ")\n",
    "\n",
    "print(\"-\" * 100)\n",
    "\n",
    "clean_obj_tok = get_first_token_id(not_dist_category_sample.obj, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "target_obj_tok = not_dist_category_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "before_clean = not_dist_category_patch[\"clean_track\"][clean_obj_tok][1].logit\n",
    "after_clean = not_dist_category_patch[\"int_track\"][clean_obj_tok][1].logit\n",
    "delta_clean = after_clean - before_clean\n",
    "print(f\"\\\"{mt.tokenizer.decode(clean_obj_tok)}\\\" |>> {before_clean=}, {after_clean=} | {delta_clean=}\")\n",
    "\n",
    "before_target = not_dist_category_patch[\"clean_track\"][target_obj_tok][1].logit\n",
    "after_target = not_dist_category_patch[\"int_track\"][target_obj_tok][1].logit\n",
    "delta_target = after_target - before_target\n",
    "print(f\"\\\"{mt.tokenizer.decode(target_obj_tok)}\\\" |>> {before_target=}, {after_target=} | {delta_target=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7babcb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "is_patch_cat_selection_patch = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=is_obj_category_sample,\n",
    "    patch_sample=patch_sample,\n",
    "    heads=HEADS,\n",
    "    query_indices=[-3, -2, -1],\n",
    "    verify_head_behavior_on=-1,\n",
    ")\n",
    "\n",
    "print(\"-\" * 100)\n",
    "clean_obj_tok = get_first_token_id(is_obj_category_sample.obj, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "target_obj_tok = is_obj_category_sample.metadata[\"track_type_obj_token_id\"]\n",
    "before_clean = is_patch_cat_selection_patch[\"clean_track\"][clean_obj_tok][1].logit\n",
    "after_clean = is_patch_cat_selection_patch[\"int_track\"][clean_obj_tok][1].logit\n",
    "delta_clean = after_clean - before_clean\n",
    "print(f\"\\\"{mt.tokenizer.decode(clean_obj_tok)}\\\" |>> {before_clean=}, {after_clean=} | {delta_clean=}\")\n",
    "\n",
    "before_target = is_patch_cat_selection_patch[\"clean_track\"][target_obj_tok][1].logit\n",
    "after_target = is_patch_cat_selection_patch[\"int_track\"][target_obj_tok][1].logit\n",
    "delta_target = after_target - before_target\n",
    "print(f\"\\\"{mt.tokenizer.decode(target_obj_tok)}\\\" |>> {before_target=}, {after_target=} | {delta_target=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683230c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "count_sample = SelectionSample(\n",
    "    obj = \"Banana\",\n",
    "    answer = \"3\",\n",
    "    obj_idx = 3,\n",
    "    # prompt_template=\"\"\"<_options_>\\nCount the number of fruits in the list\\nAnswer:\"\"\",\n",
    "    prompt_template=\"\"\"Count the number of fruits in the following list:\\n<_options_>\\nAnswer:\"\"\",\n",
    "    # prompt_template=\"\"\"<_options_>\\nAnswer:\"\"\",\n",
    "    options=[\"Bus\", \"Apple\", \"Microphone\", \"Banana\", \"Orange\", \"Cow\", \"Table\"],\n",
    "    # default_option_style=\"numbered\",\n",
    "    # default_option_style=\"bullet\",\n",
    ")\n",
    "\n",
    "print(count_sample.prompt())\n",
    "\n",
    "validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=count_sample,\n",
    "    patch_sample=patch_sample,\n",
    "    heads=HEADS,\n",
    "    query_indices=[-2, -1],\n",
    "    verify_head_behavior_on=-1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1972e073",
   "metadata": {},
   "source": [
    "### Search for heads with the most patching effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1d94c15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "from src.functional import patch_with_baukit, interpret_logits\n",
    "\n",
    "all_heads = list(product(range(30, 50), range(mt.config.num_attention_heads)))\n",
    "query_indices = [-3, -2, -1]  # last 3 tokens\n",
    "query_locations = [\n",
    "    (layer_idx, head_idx, query_idx)\n",
    "    for layer_idx, head_idx in all_heads\n",
    "    for query_idx in query_indices\n",
    "]\n",
    "\n",
    "patch_tokenized = prepare_input(prompts=patch_sample.prompt(), tokenizer=mt)\n",
    "cached_q_states, patch_output = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=patch_tokenized,\n",
    "    query_locations=query_locations,\n",
    "    return_output=True,\n",
    ")\n",
    "\n",
    "patch_logits = patch_output.logits[:, -1, :].squeeze()\n",
    "interpret_logits(\n",
    "    tokenizer=mt,\n",
    "    logits=patch_logits,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e97dcd3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_sample_type = {\n",
    "    \"not_dist_category\": not_dist_category_sample,\n",
    "    \"is_obj_category\": is_obj_category_sample,\n",
    "}\n",
    "\n",
    "base_scores = {}\n",
    "\n",
    "for sample_type, clean_sample in clean_sample_type.items():\n",
    "    clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt)\n",
    "    clean_sample.metadata[\"tokenized\"] = clean_tokenized.data\n",
    "\n",
    "    clean_out = patch_with_baukit(\n",
    "        mt=mt,\n",
    "        inputs=clean_tokenized,\n",
    "        patches=[],\n",
    "    )\n",
    "    base_logits = clean_out.logits[:, -1, :].squeeze()\n",
    "    base_predictions, base_track = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=base_logits,\n",
    "        interested_tokens=[\n",
    "            get_first_token_id(clean_sample.obj, tokenizer=mt.tokenizer, prefix=\" \"),\n",
    "            clean_sample.metadata[\"track_obj_token_id\"],\n",
    "        ],\n",
    "    )\n",
    "    print(base_track[clean_sample.metadata[\"track_obj_token_id\"]])\n",
    "    base_scores[sample_type] = base_track[clean_sample.metadata[\"track_obj_token_id\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03e2b73d",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9db38e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "q_proj_patches = []\n",
    "for (layer_idx, head_idx, query_idx), q_proj in cached_q_states.items():\n",
    "    q_proj_patches.append(\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                head_idx,\n",
    "                query_idx,\n",
    "            ),\n",
    "            patch=q_proj,\n",
    "        )\n",
    "    )\n",
    "\n",
    "int_scores = {}\n",
    "for sample_type, clean_sample in clean_sample_type.items():\n",
    "    clean_tokenized = TokenizerOutput(data=clean_sample.metadata[\"tokenized\"])\n",
    "\n",
    "    int_out = patch_with_baukit(\n",
    "        mt=mt,\n",
    "        inputs=clean_tokenized,\n",
    "        patches=q_proj_patches,\n",
    "    )\n",
    "    int_logits = int_out.logits[:, -1, :].squeeze()\n",
    "    int_predictions, int_track = interpret_logits(\n",
    "        tokenizer=mt,\n",
    "        logits=int_logits,\n",
    "        interested_tokens=[\n",
    "            get_first_token_id(clean_sample.obj, tokenizer=mt.tokenizer, prefix=\" \"),\n",
    "            clean_sample.metadata[\"track_obj_token_id\"],\n",
    "        ],\n",
    "    )\n",
    "    print(int_track[clean_sample.metadata[\"track_obj_token_id\"]])\n",
    "    int_scores[sample_type] = int_track[clean_sample.metadata[\"track_obj_token_id\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90779a00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every attention head, patch that head's cached q_proj states at the query\n",
    "# positions into each clean sample and record the tracked target token.\n",
    "individual_patching_effects = {\n",
    "    \"not_dist_category\": {},\n",
    "    \"is_obj_category\": {},\n",
    "}\n",
    "print(query_indices)\n",
    "for layer_idx, head_idx in tqdm(all_heads, desc=\"Iterating over heads\"):\n",
    "    # Every patch for this head targets the same q_proj module.\n",
    "    q_proj_module = mt.attn_module_name_format.format(layer_idx) + \".q_proj\"\n",
    "    q_proj_patches = [\n",
    "        PatchSpec(\n",
    "            location=(q_proj_module, head_idx, query_idx),\n",
    "            patch=cached_q_states[(layer_idx, head_idx, query_idx)],\n",
    "        )\n",
    "        for query_idx in query_indices\n",
    "    ]\n",
    "\n",
    "    for sample_type, clean_sample in clean_sample_type.items():\n",
    "        patched_out = patch_with_baukit(\n",
    "            mt=mt,\n",
    "            inputs=TokenizerOutput(data=clean_sample.metadata[\"tokenized\"]),\n",
    "            patches=q_proj_patches,\n",
    "        )\n",
    "        # Only the logits at the final position are interpreted.\n",
    "        last_position_logits = patched_out.logits[:, -1, :].squeeze()\n",
    "        predictions, track = interpret_logits(\n",
    "            tokenizer=mt,\n",
    "            logits=last_position_logits,\n",
    "            interested_tokens=[clean_sample.metadata[\"track_obj_token_id\"]],\n",
    "        )\n",
    "        individual_patching_effects[sample_type][(layer_idx, head_idx)] = track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b15fe39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the metadata attached to the `not_dist_category` clean sample.\n",
    "not_dist_category_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd54db35",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.patching_within_task import SelectionQprojPatchResult\n",
    "\n",
    "\n",
    "def _build_headwise_result(clean_sample, condition):\n",
    "    \"\"\"Bundle the base scores and per-head patching effects stored under `condition`.\"\"\"\n",
    "    return SelectionQprojPatchResult(\n",
    "        patch_sample=patch_sample,\n",
    "        clean_sample=clean_sample,\n",
    "        interested_tokens=[clean_sample.metadata[\"track_obj_token_id\"]],\n",
    "        base_results=base_scores[condition],\n",
    "        headwise_patching_effects=individual_patching_effects[condition],\n",
    "    )\n",
    "\n",
    "\n",
    "not_dist_category_headwise_ie = _build_headwise_result(\n",
    "    not_dist_category_sample, \"not_dist_category\"\n",
    ")\n",
    "is_obj_category_headwise_ie = _build_headwise_result(\n",
    "    is_obj_category_sample, \"is_obj_category\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b454208a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank heads by their individual patching effect for one condition.\n",
    "# Change ACTIVE_CONDITION to rank the other condition instead.\n",
    "RESULTS_BY_CONDITION = {\n",
    "    \"not_dist_category\": not_dist_category_headwise_ie,\n",
    "    \"is_obj_category\": is_obj_category_headwise_ie,\n",
    "}\n",
    "ACTIVE_CONDITION = \"is_obj_category\"\n",
    "TOP_K_HEADS = 15  # number of highest-scoring heads to print and keep\n",
    "\n",
    "patching_results = RESULTS_BY_CONDITION[ACTIVE_CONDITION]\n",
    "\n",
    "# (layer_idx, head_idx, effect) triples, largest effect first.\n",
    "headwise_scores = sorted(\n",
    "    (\n",
    "        (layer_idx, head_idx, patching_results.head_effect(layer_idx, head_idx))\n",
    "        for layer_idx, head_idx in patching_results.headwise_patching_effects.keys()\n",
    "    ),\n",
    "    key=lambda x: x[2],\n",
    "    reverse=True,\n",
    ")\n",
    "\n",
    "patching_heads = []\n",
    "for layer_idx, head_idx, score in headwise_scores[:TOP_K_HEADS]:\n",
    "    print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")\n",
    "    patching_heads.append((layer_idx, head_idx))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7891c437",
   "metadata": {},
   "source": [
    "## Optimization to select heads to patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78345e8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(clean_sample.prompt(), \">>\", clean_sample.obj)\n",
    "# print(patch_sample.prompt(), \">>\", patch_sample.obj)\n",
    "\n",
    "# train_set = [(clean_sample, patch_sample)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4ce17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "free_gpu_cache()\n",
    "\n",
    "#################################################################################\n",
    "# Sampling configuration; every value below is passed to the sampler.\n",
    "# prompt_template_idx is 3 because that is the value the sampler call actually\n",
    "# used (and the value the validation set is drawn with).\n",
    "train_limit = 512\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Collect `train_limit` groups of counterfactual samples, one group per call.\n",
    "train_set = []\n",
    "while len(train_set) < train_limit:\n",
    "    samples = get_counterfactual_samples_odd_one_out(\n",
    "        task=select_odd_one,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "    )\n",
    "    train_set.append(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5eb0b8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the model's gradients and call free_gpu_cache() before the head-mask optimization.\n",
    "mt._model.zero_grad()\n",
    "free_gpu_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9249bc97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of counterfactual sample groups collected for training.\n",
    "len(train_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d636fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import get_optimal_head_mask\n",
    "import numpy as np\n",
    "\n",
    "free_gpu_cache()\n",
    "\n",
    "# (clean, patch) pairs handed to the optimizer. Use\n",
    "# sample_group[\"not_dist_category_sample\"] as the clean sample to optimize\n",
    "# for the other condition.\n",
    "train_pairs = [\n",
    "    (sample_group[\"is_obj_category_sample\"], sample_group[\"patch_sample\"])\n",
    "    for sample_group in train_set\n",
    "]\n",
    "\n",
    "optimal_mask, losses = get_optimal_head_mask(\n",
    "    mt=mt,\n",
    "    train_set=train_pairs,\n",
    "    learning_rate=1e-2,\n",
    "    n_epochs=20,\n",
    "    lamb=2e-2,\n",
    "    batch_size=16,\n",
    "    query_indices=[-3, -2, -1],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbf2efef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Destination archive for the optimized mask of the currently trained condition.\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    # \"odd_one_out__not_patch_category.npz\"\n",
    "    \"odd_one_out__is_obj_category.npz\"\n",
    ")\n",
    "os.makedirs(os.path.dirname(optimized_path), exist_ok=True)\n",
    "\n",
    "# Only the two arrays are passed as keywords: np.savez_compressed stores its\n",
    "# keyword arguments as named arrays, and both arrays are plain float32, so no\n",
    "# pickle-related flag is needed. detach()/cpu() make the conversion work even\n",
    "# if the mask still tracks gradients or lives on an accelerator.\n",
    "np.savez_compressed(\n",
    "    optimized_path,\n",
    "    optimal_mask=optimal_mask.detach().cpu().to(torch.float32).numpy(),\n",
    "    losses=np.array(losses, dtype=np.float32),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cf8d99a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    # \"odd_one_out__not_patch_category.npz\"\n",
    "    \"odd_one_out__is_obj_category.npz\"\n",
    ")\n",
    "\n",
    "# The archive written by the save cell contains only numeric arrays, so it is\n",
    "# loaded without pickle support.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=False)\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.title(\"Head-mask optimization loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e217f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the optimized mask and pick the heads whose mask value exceeds 0.5.\n",
    "plt.figure(figsize=(20, 10))\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[50:, :] = 0  # Ignore the last layers, as they are not used in the task\n",
    "\n",
    "# The mask is transposed so layers run along x and heads along y.\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Head\")\n",
    "plt.title(\"Optimized head mask\")\n",
    "\n",
    "# Stored as (layer_idx, head_idx) tuples, matching the other head collections\n",
    "# in this notebook (patching_heads, overlapping_heads).\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(\n",
    "        optimal_head_mask > 0.5, as_tuple=False\n",
    "    ).tolist()\n",
    "]\n",
    "len(heads_selected)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f414b376",
   "metadata": {},
   "source": [
    "## Validation of the patching effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00698133",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "\n",
    "free_gpu_cache()\n",
    "\n",
    "validation_limit = 256\n",
    "sampler_kwargs = dict(\n",
    "    task=select_odd_one,\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=3,\n",
    "    option_style=\"single_line\",\n",
    "    n_distractors=5,\n",
    ")\n",
    "\n",
    "# Draw `validation_limit` groups of counterfactual samples, one per call.\n",
    "validation_set = []\n",
    "for _draw in range(validation_limit):\n",
    "    samples = get_counterfactual_samples_odd_one_out(**sampler_kwargs)\n",
    "    validation_set.append(samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d041d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "\n",
    "def _summarize_track(track, clean_tok, target_tok):\n",
    "    \"\"\"Return rank and logit of the clean and target tokens looked up in `track`.\"\"\"\n",
    "    clean_entry = track[clean_tok]\n",
    "    target_entry = track[target_tok]\n",
    "    return {\n",
    "        \"clean_rank\": clean_entry[0],\n",
    "        \"clean_logit\": clean_entry[1].logit,\n",
    "        \"target_rank\": target_entry[0],\n",
    "        \"target_logit\": target_entry[1].logit,\n",
    "    }\n",
    "\n",
    "\n",
    "# Inspect one validation pair (index 17) in detail.\n",
    "samples = validation_set[17]\n",
    "patch = samples[\"patch_sample\"]\n",
    "# clean = samples[\"not_dist_category_sample\"]\n",
    "clean = samples[\"is_obj_category_sample\"]\n",
    "\n",
    "print(clean.prompt(), \">>\", clean.obj)\n",
    "print(patch.prompt(), \">>\", patch.obj)\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # heads=HEADS,\n",
    "    heads=heads_selected,\n",
    "    query_indices=[-3, -2, -1],\n",
    "    verify_head_behavior_on=-1,\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "before_intervention = _summarize_track(\n",
    "    validation_result[\"clean_track\"], clean_obj, target_obj\n",
    ")\n",
    "after_intervention = _summarize_track(\n",
    "    validation_result[\"int_track\"], clean_obj, target_obj\n",
    ")\n",
    "\n",
    "# Rank deltas are computed as before - after.\n",
    "clean_rank_delta = before_intervention[\"clean_rank\"] - after_intervention[\"clean_rank\"]\n",
    "target_rank_delta = before_intervention[\"target_rank\"] - after_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "# Logit deltas are computed as after - before.\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65931db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# `overlapping_heads` is assigned only by the code cell in the\n",
    "# \"Intersection of Heads\" section at the end of this notebook. On a fresh\n",
    "# kernel that cell has not run yet, so fail early with an actionable message.\n",
    "if \"overlapping_heads\" not in globals():\n",
    "    raise NameError(\n",
    "        \"overlapping_heads is not defined; run the 'Intersection of Heads' cell \"\n",
    "        \"first, or pass heads_selected as `heads` instead.\"\n",
    "    )\n",
    "\n",
    "# (clean, patch) pairs for each of the two clean-sample conditions.\n",
    "not_patch_category_validation_set = [\n",
    "    (samples[\"not_dist_category_sample\"], samples[\"patch_sample\"])\n",
    "    for samples in validation_set\n",
    "]\n",
    "\n",
    "is_obj_category_validation_set = [\n",
    "    (samples[\"is_obj_category_sample\"], samples[\"patch_sample\"])\n",
    "    for samples in validation_set\n",
    "]\n",
    "\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(not_patch_category_validation_set):\n",
    "# for clean_sample, patch_sample in tqdm(is_obj_category_validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        # heads=HEADS,\n",
    "        # heads=heads_selected,\n",
    "        heads=overlapping_heads,\n",
    "        query_indices=[-3, -2, -1],\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20920f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of heads selected for both conditions.\n",
    "# NOTE: overlapping_heads is assigned in the \"Intersection of Heads\" cell at the\n",
    "# end of the notebook, so that cell has to run before this one.\n",
    "len(overlapping_heads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "024eed06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _track_stats(track, clean_tok, target_tok):\n",
    "    \"\"\"Return rank and logit of the clean and target tokens looked up in `track`.\"\"\"\n",
    "    return {\n",
    "        \"clean_rank\": track[clean_tok][0],\n",
    "        \"clean_logit\": track[clean_tok][1].logit,\n",
    "        \"target_rank\": track[target_tok][0],\n",
    "        \"target_logit\": track[target_tok][1].logit,\n",
    "    }\n",
    "\n",
    "\n",
    "# One summary dict per validation result, before and after the intervention.\n",
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    before_intervention.append(\n",
    "        _track_stats(intervention_result[\"clean_track\"], clean_obj, target_obj)\n",
    "    )\n",
    "    after_intervention.append(\n",
    "        _track_stats(intervention_result[\"int_track\"], clean_obj, target_obj)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573b7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def _print_rank_stat(label, values):\n",
    "    \"\"\"Print the mean and standard deviation of `values`, prefixed with `label`.\"\"\"\n",
    "    print(f\"{label}: {values.mean():.4f} ± {values.std():.4f}\")\n",
    "\n",
    "\n",
    "# NOTE(review): rank deltas here are after - before, whereas the single-pair\n",
    "# validation cell above computes them as before - after.\n",
    "clean_rank_delta = np.array(\n",
    "    [\n",
    "        after[\"clean_rank\"] - before[\"clean_rank\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_rank_delta = np.array(\n",
    "    [\n",
    "        after[\"target_rank\"] - before[\"target_rank\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "_print_rank_stat(\"clean_rank_delta\", clean_rank_delta)\n",
    "_print_rank_stat(\"target_rank_delta\", target_rank_delta)\n",
    "\n",
    "clean_rank_after_intervention = np.array(\n",
    "    [after[\"clean_rank\"] for after in after_intervention]\n",
    ")\n",
    "_print_rank_stat(\"clean_rank_after_intervention\", clean_rank_after_intervention)\n",
    "\n",
    "target_rank_after_intervention = np.array(\n",
    "    [after[\"target_rank\"] for after in after_intervention]\n",
    ")\n",
    "_print_rank_stat(\"target_rank_after_intervention\", target_rank_after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10522073",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _print_logit_stat(label, values):\n",
    "    \"\"\"Print the mean and standard deviation of `values`, prefixed with `label`.\"\"\"\n",
    "    print(f\"{label}: {values.mean():.4f} ± {values.std():.4f}\")\n",
    "\n",
    "\n",
    "# Logit deltas are after - before for each validation result.\n",
    "clean_logit_delta = np.array(\n",
    "    [\n",
    "        after[\"clean_logit\"] - before[\"clean_logit\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [\n",
    "        after[\"target_logit\"] - before[\"target_logit\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "_print_logit_stat(\"clean_logit_delta\", clean_logit_delta)\n",
    "_print_logit_stat(\"target_logit_delta\", target_logit_delta)\n",
    "\n",
    "clean_logit_after_intervention = np.array(\n",
    "    [after[\"clean_logit\"] for after in after_intervention]\n",
    ")\n",
    "_print_logit_stat(\"clean_logit_after_intervention\", clean_logit_after_intervention)\n",
    "\n",
    "target_logit_after_intervention = np.array(\n",
    "    [after[\"target_logit\"] for after in after_intervention]\n",
    ")\n",
    "_print_logit_stat(\"target_logit_after_intervention\", target_logit_after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d484f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of validation results whose post-intervention target rank equals 1.\n",
    "top_1 = sum(1 for after in after_intervention if after[\"target_rank\"] == 1)\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbaedc00",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of heads whose optimized mask value exceeded 0.5.\n",
    "# NOTE(review): the validation loop above passes overlapping_heads, not\n",
    "# heads_selected, so this count may not describe the heads that were validated.\n",
    "len(heads_selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "613dc0d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count results where the first entry of the intervened track is the tracked\n",
    "# target token; keep everything else for inspection.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    first_key = list(int_track.keys())[0]\n",
    "    first_prediction = int_track[first_key][1]\n",
    "    if first_prediction.token_id == clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "        continue\n",
    "\n",
    "    failed_cases.append(\n",
    "        dict(\n",
    "            clean_sample=clean_sample,\n",
    "            patch_sample=patch_sample,\n",
    "            int_track=int_track,\n",
    "            clean_track=clean_track,\n",
    "        )\n",
    "    )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a66e502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the first 20 failed cases with their clean and intervened tracks.\n",
    "for failed_case in failed_cases[:20]:\n",
    "    clean_sample = failed_case[\"clean_sample\"]\n",
    "    patch_sample = failed_case[\"patch_sample\"]\n",
    "\n",
    "    print(\"Clean Sample:\")\n",
    "    print(clean_sample.prompt(), \">>\", clean_sample.obj)\n",
    "\n",
    "    separator = \"-\" * 100\n",
    "    print(separator)\n",
    "    print(clean_sample.metadata[\"track_category\"], \">>\", clean_sample.metadata[\"track_type_obj\"])\n",
    "    print(separator)\n",
    "\n",
    "    # Each track value is a (rank, prediction) pair; only the prediction is printed.\n",
    "    clean_preds = [str(pred) for _rank, pred in failed_case[\"clean_track\"].values()]\n",
    "    print(f\"Clean Track: {json.dumps(clean_preds, indent=4)}\")\n",
    "\n",
    "    int_preds = [str(pred) for _rank, pred in failed_case[\"int_track\"].values()]\n",
    "    print(f\"Intervened Track: {json.dumps(int_preds, indent=4)}\")\n",
    "    print(\"=\" * 100)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27e073d6",
   "metadata": {},
   "source": [
    "## Intersection of Heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4264e808",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay the optimized head masks of the two odd-one-out conditions and\n",
    "# collect the heads that both conditions select.\n",
    "LAYER_CUTOFF = 50  # heads in layers at or beyond this index are ignored\n",
    "\n",
    "task_names = [\n",
    "    \"odd_one_out__not_patch_category\",\n",
    "    \"odd_one_out__is_obj_category\",\n",
    "]\n",
    "heads = {task_name: [] for task_name in task_names}\n",
    "colors = {\n",
    "    \"odd_one_out__not_patch_category\": \"Blues\",\n",
    "    \"odd_one_out__is_obj_category\": \"Reds\",\n",
    "}\n",
    "\n",
    "# Create figure and axis\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "\n",
    "# Create combined mask for overlapping heads\n",
    "combined_mask = np.zeros((mt.config.num_attention_heads, mt.n_layer))\n",
    "\n",
    "for i, task_name in enumerate(task_names):\n",
    "    print(\"Processing task:\", task_name)\n",
    "\n",
    "    optimized_path = os.path.join(\n",
    "        env_utils.DEFAULT_RESULTS_DIR,\n",
    "        \"selection/optimized_heads\",\n",
    "        mt.name.split(\"/\")[-1],\n",
    "        f\"{task_name}.npz\",\n",
    "    )\n",
    "    optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "    optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(\n",
    "        torch.float32\n",
    "    )\n",
    "    print(f\"Optimal head mask shape for {task_name}: {optimal_head_mask.shape}\")\n",
    "\n",
    "    # Get head positions as (layer_idx, head_idx) tuples below the layer cutoff\n",
    "    task_heads = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "    task_heads = [\n",
    "        (layer_idx, head_idx)\n",
    "        for layer_idx, head_idx in task_heads\n",
    "        if layer_idx < LAYER_CUTOFF\n",
    "    ]\n",
    "    heads[task_name] = task_heads\n",
    "\n",
    "    # Prepare mask for visualization\n",
    "    optimal_head_mask = optimal_head_mask.round()\n",
    "    optimal_head_mask[LAYER_CUTOFF:] = 0\n",
    "\n",
    "    # Create a masked array so unselected heads stay transparent\n",
    "    mask_array = optimal_head_mask.T.numpy()\n",
    "    masked_data = np.ma.masked_where(mask_array == 0, mask_array)\n",
    "\n",
    "    # Plot with proper alpha blending\n",
    "    im = ax.imshow(\n",
    "        masked_data,\n",
    "        cmap=colors[task_name],\n",
    "        aspect=\"auto\",\n",
    "        vmin=0,\n",
    "        vmax=1.5,\n",
    "        alpha=0.8 if i == 0 else 0.5,  # Different alphas for better visibility\n",
    "        interpolation='nearest'\n",
    "    )\n",
    "\n",
    "    # Track overlaps (optional)\n",
    "    combined_mask += mask_array * (i + 1)\n",
    "\n",
    "# Add labels and formatting\n",
    "ax.set_xlabel(\"Layer\")\n",
    "ax.set_ylabel(\"Head\")\n",
    "ax.set_title(\"Filter Heads Comparison: Odd One Out: Not Patch Category (Blue) vs Odd One Out: Is Obj Category (Red)\")\n",
    "\n",
    "# Add grid for clarity\n",
    "ax.set_xticks(np.arange(0, LAYER_CUTOFF, 5))\n",
    "ax.set_yticks(np.arange(0, optimal_head_mask.shape[1], 2))\n",
    "ax.grid(True, alpha=0.3, linestyle='--')\n",
    "\n",
    "# Create custom legend. Colors follow the `colors` mapping and the title:\n",
    "# Not Patch Category is drawn with Blues, Is Obj Category with Reds.\n",
    "from matplotlib.patches import Patch\n",
    "legend_elements = [\n",
    "    Patch(facecolor='blue', alpha=0.5, label='Odd One Out: Not Patch Category'),\n",
    "    Patch(facecolor='red', alpha=0.5, label='Odd One Out: Is Obj Category'),\n",
    "    Patch(facecolor='purple', alpha=1, label='Overlap')\n",
    "]\n",
    "ax.legend(handles=legend_elements, loc='upper right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Heads selected by both conditions, in the order of the first condition.\n",
    "# A set gives constant-time membership checks.\n",
    "is_obj_category_heads = set(heads[\"odd_one_out__is_obj_category\"])\n",
    "overlapping_heads = [\n",
    "    head\n",
    "    for head in heads[\"odd_one_out__not_patch_category\"]\n",
    "    if head in is_obj_category_heads\n",
    "]\n",
    "\n",
    "print(f\"Total heads for odd_one_out__not_patch_category: {len(heads['odd_one_out__not_patch_category'])}\")\n",
    "print(f\"Total heads for odd_one_out__is_obj_category: {len(heads['odd_one_out__is_obj_category'])}\")\n",
    "print(f\"Overlapping heads: {len(overlapping_heads)}\")\n",
    "if overlapping_heads:\n",
    "    print(f\"Overlapping positions: {overlapping_heads}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8de559",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
