{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2, so imported modules are\n",
    "# reloaded before each cell execution and edits to project code take effect\n",
    "# without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Environment and logging bootstrap for the notebook.\n",
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make the parent directory importable so that `src.*` resolves\n",
    "# (the path is relative to the kernel's working directory).\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "# Set before transformers is imported further down in this cell.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Root logger writes DEBUG and above to stdout so messages show up in cell output.\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "\n",
    "# torch.cuda.get_device_name() raises when no CUDA device is usable, so it is\n",
    "# only queried when CUDA is available. The log line keeps its original layout.\n",
    "cuda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, \"\n",
    "    f\"torch.cuda.get_device_name()={cuda_device_name!r}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose the checkpoint to analyse: exactly one `model_key` assignment is left\n",
    "# uncommented; the others are kept as quick alternatives.\n",
    "# NOTE(review): get_device_map is only referenced by the commented-out lines at the bottom.\n",
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "# Load the checkpoint named by model_key (bfloat16 weights, automatic device placement).\n",
    "# The commented-out arguments are alternatives: a manual device map and 8-bit / 4-bit quantization.\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    # NOTE(review): eager attention is presumably needed so per-head attention can be inspected/patched - confirm.\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled one-off data fix: loads a selection JSON file, capitalizes every entry\n",
    "# of each list under its \"categories\" key, and writes the file back in place.\n",
    "# Kept commented out; it overwrites the dataset file when executed.\n",
    "\n",
    "# file_path = os.path.join(\n",
    "#     env_utils.DEFAULT_DATA_DIR,\n",
    "#     \"selection\",\n",
    "#     # \"profession.json\"\n",
    "#     # \"nationality.json\"\n",
    "#     \"objects.json\",\n",
    "# )\n",
    "\n",
    "# with open(file_path, \"r\") as f:\n",
    "#     temp = json.load(f)\n",
    "\n",
    "# for cat in temp[\"categories\"]:\n",
    "#     temp[\"categories\"][cat] = [obj.capitalize() for obj in temp[\"categories\"][cat]]\n",
    "\n",
    "# with open(file_path, \"w\") as f:\n",
    "#     json.dump(temp, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# Task configuration reused by the cells below.\n",
    "# Alternative setup: TASK_CLS = SelectOrderTask with prompt_template_idx = 1\n",
    "TASK_CLS, prompt_template_idx = SelectOneTask, 3\n",
    "N_DISTRACTORS, OPTION_STYLE = 5, \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Load the task definition from the data directory.\n",
    "# Other file names previously used here: \"profession.json\", \"nationality.json\"\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    ")\n",
    "\n",
    "# select_task.filter_single_token(mt.tokenizer, prefix=\" \")\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one sample from the task with the configured option style and prompt template,\n",
    "# then print the sample and its rendered prompt.\n",
    "sample = select_task.get_random_sample(\n",
    "    mt = mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    # presumably the position of the correct object among the options - confirm\n",
    "    obj_idx=2,\n",
    "    # category=\"actor\",\n",
    "    # category=\"Brazil\"\n",
    "    # category=\"fruit\",\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "print(sample)\n",
    "print(sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66399b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import verify_correct_option\n",
    "# sample.prompt_template = select_prof.prompt_templates[3]\n",
    "# Show the quoted prompt next to the expected object, then pass the same prompt,\n",
    "# target and option list to verify_correct_option (its return value is the cell output).\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)\n",
    "\n",
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "# Generate up to 20 new tokens without sampling (do_sample=False) for the sample prompt;\n",
    "# [0] keeps the first returned entry. No patch arguments are passed here.\n",
    "# remove_prefix=True presumably drops the prompt from the returned text - confirm.\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e8a8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display mt.n_layer and the attention-head count from the model config,\n",
    "# useful for sanity-checking the (layer, head) indices used below.\n",
    "mt.n_layer, mt.config.num_attention_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded default set of attention heads as (layer_idx, head_idx) pairs.\n",
    "# NOTE(review): these indices are only meaningful for the currently selected model_key - confirm when switching models.\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "# HEADS = [(35, 19)]\n",
    "\n",
    "\n",
    "# Alternative sources for HEADS (disabled): JSON files with stored head lists.\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe381b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_heads\",\n",
    "#     mt.name.split(\"/\")[-1],\n",
    "#     f\"{select_task.task_name}.npz\"\n",
    "# )\n",
    "\n",
    "# Path of a stored optimisation checkpoint:\n",
    "# <results>/selection/optimized_heads/<model short name>/<task name>/epoch_10.npz\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    # mt.name.split(\"/\")[-1],\n",
    "    model_key.split(\"/\")[-1],\n",
    "    f\"{select_task.task_name}\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets NumPy unpickle arbitrary objects from the file;\n",
    "# only load archives that were produced locally and are trusted.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "# Loss curve stored under the \"losses\" key of the archive.\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f87810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the stored mask as a float32 tensor and zero every row from index 50 onward.\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[50:, :] = 0.0\n",
    "\n",
    "# Heat-map of the transposed mask with the colour scale fixed to [0, 1].\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.imshow(optimal_head_mask.T.numpy(), cmap=\"Blues\", aspect=\"auto\", vmin=0, vmax=1)\n",
    "\n",
    "# Keep the (layer_idx, head_idx) positions whose mask value exceeds 0.5\n",
    "# (restricted to layer_idx < 50, matching the zeroed rows above).\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(\n",
    "        optimal_head_mask > 0.5, as_tuple=False\n",
    "    ).tolist()\n",
    "    if layer_idx < 50\n",
    "]\n",
    "print(len(heads_selected))\n",
    "\n",
    "# From here on HEADS refers to the optimised selection.\n",
    "HEADS = heads_selected\n",
    "\n",
    "(35, 19) in HEADS, (35, 19) in heads_selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): get_attention_matrices and get_patches_to_verify_independent_enrichment\n",
    "# are imported here but not used in this cell.\n",
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    ")\n",
    "\n",
    "# Run verify_head_patterns for the current HEADS on the sample prompt rendered in\n",
    "# the single_line option style, using the sample's subject as pivot.\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=sample.prompt(option_style=\"single_line\"),\n",
    "    options=sample.options,\n",
    "    pivot=sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129fe333",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_within_task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425f6285",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one counterfactual pair from the same task: the patch sample is requested\n",
    "# for the fruit category and the clean sample for the vehicle category.\n",
    "patch_sample, clean_sample = get_counterfactual_samples_within_task(\n",
    "    mt=mt,\n",
    "    task=select_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=OPTION_STYLE,\n",
    "    # distinct_options=True,\n",
    "    distinct_options=False,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    ")\n",
    "\n",
    "# patch_sample.default_option_style = \"single_line\"\n",
    "# clean_sample.default_option_style = \"numbered\"\n",
    "\n",
    "# Cell output: the metadata attached to the clean sample.\n",
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510772fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): generate_with_patch is imported but not used in this cell.\n",
    "from src.functional import generate_with_patch\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "# Hand-built example: overwrite the option lists of both samples in place and\n",
    "# rewrite the clean sample's answer / tracking fields to match its new options.\n",
    "# NOTE(review): patch_sample.obj and patch_sample.obj_idx are not updated together\n",
    "# with patch_sample.options - confirm they still describe the new option list.\n",
    "# patch_sample.options[patch_sample.obj_idx] = \"Screw\"\n",
    "# patch_sample.options[patch_sample.obj_idx] = patch_sample.obj\n",
    "patch_sample.options = [\"Cherry\", \"Knife\", \"Pen\", \"Ambulance\"]\n",
    "clean_sample.options = [\"Binder\", \"Peach\", \"Watch\", \"Scooter\", \"Phone\"]\n",
    "clean_sample.metadata = {\n",
    "    \"track_category\": \"fruit\",\n",
    "    \"track_type_obj\": \"Peach\",\n",
    "    # index of \"Peach\" in clean_sample.options\n",
    "    \"track_type_obj_idx\": 1,\n",
    "    \"track_type_obj_token_id\": get_first_token_id(tokenizer=mt.tokenizer, name=\"Peach\", prefix=\" \")\n",
    "}\n",
    "# index of \"Scooter\" in clean_sample.options\n",
    "clean_sample.obj_idx = 3\n",
    "clean_sample.obj = \"Scooter\"\n",
    "clean_sample.ans_token_id = get_first_token_id(tokenizer=mt.tokenizer, name=\"Scooter\", prefix=\" \")\n",
    "\n",
    "# Print each prompt and run verify_head_patterns for head (35, 19) at query_index=-1.\n",
    "# NOTE(review): the loop variable rebinds the notebook-level `sample` created in an\n",
    "# earlier cell; after this cell `sample` is clean_sample.\n",
    "for sample in [patch_sample, clean_sample]:\n",
    "    # for sample in [order_sample_1, order_sample_2]:\n",
    "    print(sample.prompt(), \">>\", sample.obj)\n",
    "    attn_pattern = verify_head_patterns(\n",
    "        prompt=sample.prompt(),\n",
    "        options=sample.options,\n",
    "        pivot=sample.subj,\n",
    "        mt=mt,\n",
    "        # heads=HEADS,\n",
    "        heads=[(35, 19)],\n",
    "        # generate_full_answer=True,\n",
    "        query_index=-1,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b4d487f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the metadata currently attached to clean_sample.\n",
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d728087",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "\n",
    "# Tokenize the clean and patch prompts and show the shapes of their input_ids side by side.\n",
    "clean_tokenized = prepare_input(prompts = clean_sample.prompt(), tokenizer=mt.tokenizer)\n",
    "patch_tokenized = prepare_input(prompts = patch_sample.prompt(), tokenizer=mt.tokenizer)\n",
    "clean_tokenized.input_ids.shape, patch_tokenized.input_ids.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7891c437",
   "metadata": {},
   "source": [
    "## Optimization to select heads to patch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af4ce17d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "free_gpu_cache()\n",
    "\n",
    "#################################################################################\n",
    "# Settings for the optimisation training set (these rebind the notebook-level\n",
    "# prompt_template_idx / N_DISTRACTORS / OPTION_STYLE values).\n",
    "train_limit = 64\n",
    "# prompt_template_idx = 1\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Collect train_limit counterfactual pairs. The helper returns (patch, clean);\n",
    "# pairs are stored as (clean, patch), the order consumed by the cells below.\n",
    "train_set = []\n",
    "while len(train_set) < train_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        # Pass the model explicitly, as the single-pair call earlier in the notebook\n",
    "        # does: filter_by_lm_prediction=True needs a model to filter with.\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "    )\n",
    "    train_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5eb0b8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset gradients on the underlying model, call free_gpu_cache(), and report\n",
    "# how many training pairs were collected.\n",
    "mt._model.zero_grad()\n",
    "free_gpu_cache()\n",
    "print(len(train_set))\n",
    "# print(len(heads_selected))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d636fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import get_optimal_head_mask\n",
    "import numpy as np\n",
    "\n",
    "# Optimise a head mask on train_set; the call returns the mask and the recorded losses.\n",
    "free_gpu_cache()\n",
    "optimal_mask, losses = get_optimal_head_mask(\n",
    "    mt=mt,\n",
    "    train_set=train_set,\n",
    "    learning_rate=1e-2,\n",
    "    n_epochs=20,\n",
    "    lamb=2e-2,\n",
    "    batch_size=16,\n",
    "    query_indices=[-3, -2, -1],\n",
    "    # black_list_heads=heads_selected\n",
    ")\n",
    "\n",
    "# Output file: <results>/selection/optimized_heads/<model short name>/<task name>.npz\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    # \"selection/optimized_backup_heads\",\n",
    "    # \"selection/optimized_single_tok\",\n",
    "    # \"selection/optimized_same_collection_single_tok\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    f\"{TASK_CLS.task_name}.npz\"\n",
    ")\n",
    "\n",
    "os.makedirs(os.path.dirname(optimized_path), exist_ok=True)\n",
    "\n",
    "# Save only the two arrays. `allow_pickle` is not passed: NumPy versions whose\n",
    "# savez_compressed lacks that keyword would store it as an extra array named\n",
    "# \"allow_pickle\", and where the keyword exists True is already the default.\n",
    "# detach().cpu() makes the conversion work for CUDA or grad-tracking tensors too.\n",
    "np.savez_compressed(\n",
    "    optimized_path,\n",
    "    optimal_mask=optimal_mask.detach().cpu().to(torch.float32).numpy(),\n",
    "    losses=np.array(losses, dtype=np.float32),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04fa7b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Same path construction as the cell that saves the optimisation result, so this\n",
    "# reloads the archive written there and plots its loss curve.\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    # \"selection/optimized_backup_heads\",\n",
    "    # \"selection/optimized_single_tok\",\n",
    "    # \"selection/optimized_same_collection_single_tok\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    f\"{TASK_CLS.task_name}.npz\"\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets NumPy unpickle arbitrary objects; only load trusted files.\n",
    "# This rebinds optimization_results, replacing the archive loaded earlier in the notebook.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "plt.plot(optimization_results[\"losses\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66e217f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the stored mask as a float32 tensor and zero every row from index 50 onward.\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[50:, :] = 0.0\n",
    "\n",
    "# Heat-map of the transposed mask with the colour scale fixed to [0, 1].\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.imshow(optimal_head_mask.T.numpy(), cmap=\"Blues\", aspect=\"auto\", vmin=0, vmax=1)\n",
    "\n",
    "# Positions whose mask value exceeds 0.5, converted to (layer_idx, head_idx) tuples.\n",
    "backup_heads = list(\n",
    "    map(\n",
    "        tuple,\n",
    "        torch.nonzero(optimal_head_mask > 0.5, as_tuple=False)\n",
    "        .to(dtype=torch.int)\n",
    "        .tolist(),\n",
    "    )\n",
    ")\n",
    "print(len(backup_heads))\n",
    "\n",
    "# From here on HEADS refers to this selection.\n",
    "HEADS = backup_heads\n",
    "\n",
    "(35, 19) in HEADS, (35, 19) in backup_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de3b0bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disabled: compare the sizes of the two head selections.\n",
    "# len(heads_selected), len(backup_heads)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f414b376",
   "metadata": {},
   "source": [
    "## Validation of the patching effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00698133",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "free_gpu_cache()\n",
    "\n",
    "# Collect validation_limit counterfactual pairs. The helper returns (patch, clean);\n",
    "# pairs are stored as (clean, patch), the order consumed by the validation loop.\n",
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        # Pass the model explicitly, as the single-pair call earlier in the notebook\n",
    "        # does: filter_by_lm_prediction=True needs a model to filter with.\n",
    "        mt=mt,\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "        # distinct_options=True,\n",
    "        distinct_options=False,\n",
    "        shuffle_clean_options=True\n",
    "    )\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1714e0db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of (clean, patch) pairs collected for validation.\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e5dabb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Call mt.reset_forward() before validation; presumably this restores the model's\n",
    "# original forward pass after earlier patching - confirm against ModelandTokenizer.\n",
    "mt.reset_forward()\n",
    "# set_attn_implementation(mt, \"eager\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d041d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Single-pair check: run the intervention on one (clean, patch) pair and log how the\n",
    "# rank and logit of the clean answer and of the tracked target token change.\n",
    "\n",
    "# clean, patch = copy.deepcopy(validation_set[28])\n",
    "# clean.default_option_style=\"numbered\"\n",
    "# patch.default_option_style=\"numbered\"\n",
    "# clean, patch = train_set[18]\n",
    "\n",
    "# Work on deep copies so the template override below does not touch the notebook-level samples.\n",
    "clean, patch = copy.deepcopy(clean_sample), copy.deepcopy(patch_sample)\n",
    "clean.prompt_template = \"<_options_>\\nFind the <_category_> in the list.\\nAnswer:\"\n",
    "patch.prompt_template = \"<_options_>\\nFind the <_category_> in the list.\\nAnswer:\"\n",
    "\n",
    "# clean, patch = order_sample_2, order_sample_1\n",
    "# patch, clean = order_sample_1, order_sample_2\n",
    "\n",
    "# failed_case = copy.deepcopy(failed_cases[27])\n",
    "# failed_case = copy.deepcopy(failed_pos_track[\"patch_obj_idx\"][15])\n",
    "# clean = failed_case[\"clean_sample\"]\n",
    "# patch = failed_case[\"patch_sample\"]\n",
    "\n",
    "# patch.options=[\"#\"]\n",
    "\n",
    "# patch.prompt_template = \"Which among these objects mentioned above is a <_category_>?\\nAnswer:\"\n",
    "\n",
    "print(clean.prompt(), \">>\", clean.obj)\n",
    "print(patch.prompt(), \">>\", patch.obj)\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # heads=heads_selected,\n",
    "    # heads=backup_heads,\n",
    "    # heads=heads_selected + backup_heads,\n",
    "    # heads = overlapping_heads,\n",
    "    heads=[(35, 19)],\n",
    "    # presumably maps query positions between the two prompts (last three tokens) - confirm\n",
    "    query_indices={-3: -3, -2: -2, -1: -1},\n",
    "    # query_indices = {-idx: -idx for idx in range(1, 6)},\n",
    "    verify_head_behavior_on=-1,\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    # amplification_scale=1.1\n",
    ")\n",
    "\n",
    "# Token ids used as keys into the result tracks: the clean answer and the tracked target.\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track entry is indexed as (rank, prediction); the prediction exposes .logit.\n",
    "# \"clean_track\" holds the values read as the before-intervention state.\n",
    "before_intervention = {\n",
    "    \"clean_rank\": validation_result[\"clean_track\"][clean_obj][0],\n",
    "    \"clean_logit\": validation_result[\"clean_track\"][clean_obj][1].logit,\n",
    "    \"target_rank\": validation_result[\"clean_track\"][target_obj][0],\n",
    "    \"target_logit\": validation_result[\"clean_track\"][target_obj][1].logit,\n",
    "}\n",
    "\n",
    "# \"int_track\" holds the values read as the after-intervention state.\n",
    "after_intervention = {\n",
    "    \"clean_rank\": validation_result[\"int_track\"][clean_obj][0],\n",
    "    \"clean_logit\": validation_result[\"int_track\"][clean_obj][1].logit,\n",
    "    \"target_rank\": validation_result[\"int_track\"][target_obj][0],\n",
    "    \"target_logit\": validation_result[\"int_track\"][target_obj][1].logit,\n",
    "}\n",
    "\n",
    "# Deltas are computed as after minus before.\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = (\n",
    "    after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    ")\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = (\n",
    "    after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    ")\n",
    "target_logit_delta = (\n",
    "    after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    ")\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65931db6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run the intervention with heads_selected on every (clean, patch) pair of the\n",
    "# validation set and keep the raw result of each call.\n",
    "# NOTE(review): the loop variables rebind the notebook-level clean_sample / patch_sample\n",
    "# that earlier cells read; re-running those cells afterwards uses the last validation pair.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    \n",
    "    # clean_sample = copy.deepcopy(clean_sample)\n",
    "    # clean_sample.options = patch_sample.options\n",
    "\n",
    "    # patch_sample = copy.deepcopy(patch_sample)\n",
    "    # patch_sample.options = [\"#\"]\n",
    "    # patch_sample.prompt_template = \"Which among these objects mentioned above is a <_category_>?\\nAnswer:\"\n",
    "\n",
    "\n",
    "    # no information from the patch\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        # heads=backup_heads,\n",
    "        # heads=heads_selected + backup_heads,\n",
    "        # heads = overlapping_heads,\n",
    "        query_indices={-3: -3, -2: -2, -1: -1},\n",
    "        verify_head_behavior_on=None,\n",
    "        # amplification_scale=1.5\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    # Visual separator between pairs in the cell output.\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "024eed06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _track_summary(track, clean_tok, target_tok):\n",
    "    \"\"\"Return rank and logit of the clean and target tokens read from one track.\"\"\"\n",
    "    return {\n",
    "        \"clean_rank\": track[clean_tok][0],\n",
    "        \"clean_logit\": track[clean_tok][1].logit,\n",
    "        \"target_rank\": track[target_tok][0],\n",
    "        \"target_logit\": track[target_tok][1].logit,\n",
    "    }\n",
    "\n",
    "\n",
    "# One summary dict per validation pair: read from \"clean_track\" (before the\n",
    "# intervention) and from \"int_track\" (after it).\n",
    "before_intervention, after_intervention = [], []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "\n",
    "    # Token ids used as keys into both tracks.\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    before_intervention.append(\n",
    "        _track_summary(intervention_result[\"clean_track\"], clean_obj, target_obj)\n",
    "    )\n",
    "    after_intervention.append(\n",
    "        _track_summary(intervention_result[\"int_track\"], clean_obj, target_obj)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573b7692",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-pair rank change (after minus before) for the clean answer and the target token.\n",
    "clean_rank_delta = np.array(\n",
    "    [\n",
    "        after[\"clean_rank\"] - before[\"clean_rank\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_rank_delta = np.array(\n",
    "    [\n",
    "        after[\"target_rank\"] - before[\"target_rank\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_rank_delta: {clean_rank_delta.mean():.4f} ± {clean_rank_delta.std():.4f}\")\n",
    "print(f\"target_rank_delta: {target_rank_delta.mean():.4f} ± {target_rank_delta.std():.4f}\")\n",
    "\n",
    "# Absolute ranks after the intervention, reported as mean ± standard deviation.\n",
    "clean_rank_after_intervention = np.array(\n",
    "    [after[\"clean_rank\"] for after in after_intervention]\n",
    ")\n",
    "print(\n",
    "    f\"clean_rank_after_intervention: {clean_rank_after_intervention.mean():.4f} ± {clean_rank_after_intervention.std():.4f}\"\n",
    ")\n",
    "\n",
    "target_rank_after_intervention = np.array(\n",
    "    [after[\"target_rank\"] for after in after_intervention]\n",
    ")\n",
    "print(\n",
    "    f\"target_rank_after_intervention: {target_rank_after_intervention.mean():.4f} ± {target_rank_after_intervention.std():.4f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10522073",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pair logit change (after minus before) for the clean answer and the target token.\n",
    "clean_logit_delta = np.array(\n",
    "    [\n",
    "        after[\"clean_logit\"] - before[\"clean_logit\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [\n",
    "        after[\"target_logit\"] - before[\"target_logit\"]\n",
    "        for before, after in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_logit_delta: {clean_logit_delta.mean():.4f} ± {clean_logit_delta.std():.4f}\")\n",
    "print(f\"target_logit_delta: {target_logit_delta.mean():.4f} ± {target_logit_delta.std():.4f}\")\n",
    "\n",
    "# Absolute logits after the intervention, reported as mean ± standard deviation.\n",
    "clean_logit_after_intervention = np.array(\n",
    "    [after[\"clean_logit\"] for after in after_intervention]\n",
    ")\n",
    "print(f\"clean_logit_after_intervention: {clean_logit_after_intervention.mean():.4f} ± {clean_logit_after_intervention.std():.4f}\")\n",
    "\n",
    "target_logit_after_intervention = np.array(\n",
    "    [after[\"target_logit\"] for after in after_intervention]\n",
    ")\n",
    "print(f\"target_logit_after_intervention: {target_logit_after_intervention.mean():.4f} ± {target_logit_after_intervention.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d484f25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of validation pairs whose target token has rank 1 after the intervention.\n",
    "top_1 = len([after for after in after_intervention if after[\"target_rank\"] == 1])\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fbdd059",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count the validation pairs whose first int_track entry is the tracked target token\n",
    "# and collect every other pair in failed_cases for inspection.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "    # assumes int_track is ordered with the highest-ranked prediction first - TODO confirm\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "    ): \n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "            }\n",
    "        )\n",
    "\n",
    "# Share of pairs counted above, out of all validation results.\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "print(f\"{len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a66e502",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print every failed case: the clean prompt, the tracked target and the clean answer\n",
    "# (with their decoded tokens), followed by the predictions of both tracks.\n",
    "for failed_case in failed_cases:\n",
    "    clean_sample = failed_case[\"clean_sample\"]\n",
    "    patch_sample = failed_case[\"patch_sample\"]\n",
    "    int_track = failed_case[\"int_track\"]\n",
    "    clean_track = failed_case[\"clean_track\"]\n",
    "\n",
    "    print(\"Clean Sample:\")\n",
    "    print(clean_sample.prompt(), \">>\", clean_sample.obj)\n",
    "\n",
    "    print(\"-\" * 100)\n",
    "    # The label is passed as its own argument: previously it was implicitly\n",
    "    # concatenated with the following f-string, fusing the word Token with the quoted token.\n",
    "    print(\n",
    "        \"Track: \",\n",
    "        clean_sample.metadata[\"track_type_obj\"],\n",
    "        \" | Token:\",\n",
    "        f\"\\\"{mt.tokenizer.decode(clean_sample.metadata['track_type_obj_token_id'])}\\\"\",\n",
    "    )\n",
    "    print(\"Clean:\", clean_sample.obj, f\"(Token: {mt.tokenizer.decode(clean_sample.ans_token_id)})\")\n",
    "    print(\"-\" * 100)\n",
    "\n",
    "    # Each track value is a (rank, prediction) pair; only the prediction is printed.\n",
    "    # As before, clean_track / int_track are rebound from dicts to lists here.\n",
    "    clean_track = [pred for _, pred in clean_track.values()]\n",
    "    print(f\"Clean Track: {json.dumps([str(pred) for pred in clean_track], indent=4)}\")\n",
    "\n",
    "    int_track = [pred for _, pred in int_track.values()]\n",
    "    print(\n",
    "        f\"Intervened Track: {json.dumps([str(pred) for pred in int_track], indent=4)}\"\n",
    "    )\n",
    "    print(\"=\" * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fea0c607",
   "metadata": {},
   "outputs": [],
   "source": [
    "#! find the positions after the patched intervention.\n",
    "# Is it looking at the first one, or the position of the\n",
    "# previous answer?\n",
    "#\n",
    "# Each failed case goes into exactly one bucket, checked in this order:\n",
    "#   clean_obj_idx -> top intervened token matches the option at clean_sample.obj_idx\n",
    "#   patch_obj_idx -> ... matches the option at patch_sample.obj_idx\n",
    "#   first_obj_idx -> ... matches the option at index 0\n",
    "#   other         -> any other option, or a token that is not an option at all\n",
    "\n",
    "failed_pos_track = {\n",
    "    \"clean_obj_idx\": [],\n",
    "    \"patch_obj_idx\": [],\n",
    "    \"first_obj_idx\": [],\n",
    "    \"other\": [],\n",
    "}\n",
    "\n",
    "for failed_case in failed_cases:\n",
    "    clean_sample = failed_case[\"clean_sample\"]\n",
    "    patch_sample = failed_case[\"patch_sample\"]\n",
    "    int_track = failed_case[\"int_track\"]\n",
    "    clean_track = failed_case[\"clean_track\"]\n",
    "    clean_obj_idx = clean_sample.obj_idx\n",
    "    patch_obj_idx = patch_sample.obj_idx\n",
    "\n",
    "    # The first key of the intervened track is taken as the top prediction.\n",
    "    int_top_tok = list(int_track.keys())[0]\n",
    "    int_top_obj = int_track[int_top_tok][1].token.strip()\n",
    "\n",
    "    # list.index raises ValueError when the token text is not literally one of\n",
    "    # the options; treat that as \"no matching option\" instead of aborting the cell.\n",
    "    try:\n",
    "        int_top_idx = clean_sample.options.index(int_top_obj)\n",
    "    except ValueError:\n",
    "        int_top_idx = None\n",
    "\n",
    "    if int_top_idx is None:\n",
    "        bucket = \"other\"\n",
    "    elif int_top_idx == clean_obj_idx:\n",
    "        bucket = \"clean_obj_idx\"\n",
    "    elif int_top_idx == patch_obj_idx:\n",
    "        bucket = \"patch_obj_idx\"\n",
    "    elif int_top_idx == 0:\n",
    "        bucket = \"first_obj_idx\"\n",
    "    else:\n",
    "        bucket = \"other\"\n",
    "    failed_pos_track[bucket].append(failed_case)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7a2e01",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Bar chart with one bar per failure bucket; bar height is the number of\n",
    "# failed cases collected in that bucket.\n",
    "x_vals = failed_pos_track.keys()\n",
    "y_vals = [len(cases) for cases in failed_pos_track.values()]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.bar(x_vals, y_vals)\n",
    "ax.set_xlabel(\"Failed Position Types\")\n",
    "ax.set_ylabel(\"Number of Failed Cases\")\n",
    "ax.set_title(\"Failed Position Types Distribution\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0d5b31f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import verify_head_patterns\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "# Re-run the intervention on a single hand-picked failure from the\n",
    "# \"patch_obj_idx\" bucket. The case is deep-copied before use.\n",
    "# Other inputs tried here: failed_cases[18], or deep copies of the current\n",
    "# clean_sample / patch_sample.\n",
    "failed_case = copy.deepcopy(failed_pos_track[\"patch_obj_idx\"][11])\n",
    "clean, patch = failed_case[\"clean_sample\"], failed_case[\"patch_sample\"]\n",
    "\n",
    "# Tokenize both prompts (clean first, then patch).\n",
    "clean_tokenized, patch_tokenized = (\n",
    "    prepare_input(prompts=sample.prompt(), tokenizer=mt) for sample in (clean, patch)\n",
    ")\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # Head sets tried here besides heads_selected: HEADS, backup_heads,\n",
    "    # heads_selected + backup_heads, overlapping_heads, [(35, 19)].\n",
    "    heads=heads_selected,\n",
    "    # Each of the last three positions is mapped to itself.\n",
    "    query_indices={idx: idx for idx in (-3, -2, -1)},\n",
    "    verify_head_behavior_on=-1,\n",
    "    # Optional arguments currently left off:\n",
    "    #   ablate_possible_ans_info_from_options=True, amplification_scale=2.0\n",
    "    patch_args=dict(\n",
    "        batch_size=2 * (N_DISTRACTORS + 1),\n",
    "        task=select_task,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=\"single_line\",\n",
    "        distinct_options=False,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b441410",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Sweep the intervention over every (clean, patch) pair in validation_set and\n",
    "# collect one result per pair. verify_head_behavior_on is None for the sweep.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    # Variant tried earlier: deep-copy clean_sample and assign it patch_sample.options.\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        # Head sets tried here besides heads_selected: backup_heads,\n",
    "        # heads_selected + backup_heads, overlapping_heads.\n",
    "        heads=heads_selected,\n",
    "        query_indices={idx: idx for idx in (-3, -2, -1)},\n",
    "        verify_head_behavior_on=None,\n",
    "        # amplification_scale=1.5 was also tried and is left off.\n",
    "        # patch_args is rebuilt on every iteration; distinct_options=False was the\n",
    "        # earlier setting.\n",
    "        patch_args=dict(\n",
    "            batch_size=N_DISTRACTORS + 1,\n",
    "            task=select_task,\n",
    "            prompt_template_idx=prompt_template_idx,\n",
    "            option_style=\"single_line\",\n",
    "            distinct_options=True,\n",
    "            n_distractors=N_DISTRACTORS,\n",
    "        ),\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fb117c7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
