{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload imported modules (e.g. edits under `src/`) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "# Make the repository root importable so `src.*` resolves from this notebook's directory.\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "# Disable HF tokenizers' internal parallelism (avoids fork warnings); set before importing transformers.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}\")\n",
    "# `get_device_name()` raises when no CUDA device is present, so only query it when one is.\n",
    "if torch.cuda.is_available():\n",
    "    logger.info(f\"{torch.cuda.get_device_name()=}\")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Model under study: keep exactly one `model_key` assignment uncommented.\n",
    "# The optimized-heads results path further down is derived from this key.\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "# Load the model selected by `model_key` (bfloat16 weights) together with its tokenizer.\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    # \"auto\": presumably forwarded to `from_pretrained` to spread layers over the visible GPUs.\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    # NOTE(review): eager attention is presumably needed so attention weights can be read out\n",
    "    # by the head-pattern cells below (fused SDPA/flash kernels do not return them).\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# file_path = os.path.join(\n",
    "#     env_utils.DEFAULT_DATA_DIR,\n",
    "#     \"selection\",\n",
    "#     # \"profession.json\"\n",
    "#     # \"nationality.json\"\n",
    "#     \"objects.json\",\n",
    "# )\n",
    "\n",
    "# with open(file_path, \"r\") as f:\n",
    "#     temp = json.load(f)\n",
    "\n",
    "# for cat in temp[\"categories\"]:\n",
    "#     temp[\"categories\"][cat] = [obj.capitalize() for obj in temp[\"categories\"][cat]]\n",
    "\n",
    "# with open(file_path, \"w\") as f:\n",
    "#     json.dump(temp, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# Task configuration; swap in the two commented lines to use SelectOrderTask instead.\n",
    "# TASK_CLS = SelectOrderTask\n",
    "# prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3  # index into the task's prompt templates\n",
    "N_DISTRACTORS = 5  # distractor options per sample\n",
    "OPTION_STYLE = \"single_line\"  # how the options are laid out in the prompt\n",
    "#################################################################################\n",
    "\n",
    "# Load the task (categories and their example objects) from the data directory.\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "\n",
    "# select_task.filter_single_token(tokenizer=mt.tokenizer, prefix=\" \")\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one illustrative sample using the configuration cell's settings\n",
    "# (N_DISTRACTORS was previously declared there but never passed).\n",
    "sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    "    obj_idx=2,  # option slot of the correct object (must be <= N_DISTRACTORS)\n",
    "    # category=\"actor\",\n",
    "    # category=\"Brazil\"\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "print(sample)\n",
    "print(sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66399b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import verify_correct_option\n",
    "# Show the prompt with its expected answer, then check the model's prediction against the\n",
    "# listed options; the returned (is_correct, predictions, track_options) tuple is the cell output.\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)\n",
    "\n",
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "# Greedy-decode (do_sample=False) up to 20 new tokens without passing any patches, to compare\n",
    "# the free-form answer with the expected object. `remove_prefix=True` presumably strips the\n",
    "# prompt from the returned text; [0] takes the generation for the single input prompt.\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e8a8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model geometry: (number of layers, attention heads per layer).\n",
    "mt.n_layer, mt.config.num_attention_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded candidate heads, as (layer_idx, head_idx), found for specific models.\n",
    "llama_70_heads = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "qwen_72_heads = [\n",
    "    (62, 1),\n",
    "    (60, 9),\n",
    "    (64, 8),\n",
    "    (62, 0),\n",
    "    (62, 45),\n",
    "    (59, 59),\n",
    "    (71, 28),\n",
    "    (64, 12),\n",
    "    (61, 7),\n",
    "    (64, 13),\n",
    "    (67, 53),\n",
    "    (67, 51),\n",
    "    (54, 44),\n",
    "    (57, 5),\n",
    "    (59, 60),\n",
    "    (71, 25),\n",
    "    (62, 7),\n",
    "    (64, 9),\n",
    "    (62, 23),\n",
    "    (65, 40),\n",
    "]\n",
    "\n",
    "qwen_32_heads = [\n",
    "    (51, 11),\n",
    "    (48, 4),\n",
    "    (52, 21),\n",
    "    (54, 35),\n",
    "    (48, 8),\n",
    "    (50, 6),\n",
    "    (48, 9),\n",
    "    (48, 32),\n",
    "    (52, 10),\n",
    "    (45, 11),\n",
    "    (45, 13),\n",
    "    (48, 34),\n",
    "    (53, 16),\n",
    "    (50, 12),\n",
    "    (49, 2),\n",
    "    (54, 38),\n",
    "    (55, 4),\n",
    "    (50, 27),\n",
    "    (54, 33),\n",
    "    (50, 14),\n",
    "]\n",
    "\n",
    "# HEADS = [(35, 19)]\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "\n",
    "def _default_heads_for(key: str) -> list[tuple[int, int]]:\n",
    "    \"\"\"Return the hard-coded head list whose model matches `key`, or [] if none is known.\n",
    "\n",
    "    Head indices found on one model are meaningless (or out of range) on another, so the\n",
    "    list must follow `model_key` instead of being picked by hand.\n",
    "    \"\"\"\n",
    "    name = key.lower()\n",
    "    if \"llama\" in name and \"70b\" in name:\n",
    "        return llama_70_heads\n",
    "    if \"qwen\" in name and \"72b\" in name:\n",
    "        return qwen_72_heads\n",
    "    if \"qwen\" in name and \"32b\" in name:\n",
    "        return qwen_32_heads\n",
    "    logger.warning(f\"No hard-coded heads for {key!r}; HEADS is empty until set explicitly.\")\n",
    "    return []\n",
    "\n",
    "\n",
    "HEADS = _default_heads_for(model_key)\n",
    "out_of_range = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in HEADS\n",
    "    if not (0 <= layer_idx < mt.n_layer and 0 <= head_idx < mt.config.num_attention_heads)\n",
    "]\n",
    "assert not out_of_range, f\"Heads out of range for {model_key}: {out_of_range}\"\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe381b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_backup_heads\",\n",
    "#     mt.name.split(\"/\")[-1],\n",
    "#     f\"{select_task.task_name}.npz\"\n",
    "# )\n",
    "\n",
    "# Results of the head-mask optimization for the current model.\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    # f\"{select_task.task_name}\",\n",
    "    \"select_one\",\n",
    "    # \"legacy\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "\n",
    "# np.load on an .npz returns an NpzFile that keeps the file handle open; copy the arrays into a\n",
    "# plain dict and close it. allow_pickle=True unpickles object arrays: only load trusted files.\n",
    "with np.load(optimized_path, allow_pickle=True) as npz_file:\n",
    "    optimization_results = dict(npz_file)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(optimization_results[\"losses\"])\n",
    "ax.set(title=f\"Head-mask optimization loss ({model_key})\", xlabel=\"optimization step\", ylabel=\"loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f87810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers at or above this index are dropped from the optimized head set.\n",
    "# NOTE(review): 52 looks specific to one model/run; revisit when switching `model_key`.\n",
    "MASK_LAYER_CUTOFF = 52\n",
    "# Soft-mask entries above this value count as selected heads.\n",
    "MASK_THRESHOLD = 0.5\n",
    "\n",
    "# The mask is indexed as [layer_idx, head_idx].\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[MASK_LAYER_CUTOFF:, :] = 0.0\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "ax.imshow(\n",
    "    optimal_head_mask.T.numpy(),  # heads on the y-axis, layers on the x-axis\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "ax.set(title=\"Optimized head mask\", xlabel=\"layer\", ylabel=\"head\")\n",
    "\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(optimal_head_mask > MASK_THRESHOLD).tolist()\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "HEADS = optimized_heads\n",
    "\n",
    "# Is the reference head among the selected ones?\n",
    "(35, 19) in optimized_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    ")\n",
    "\n",
    "# Inspect how the selected HEADS attend on the illustrative sample (single-line option layout),\n",
    "# passing its options and `sample.subj` as the pivot.\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=sample.prompt(option_style=\"single_line\"),\n",
    "    options=sample.options,\n",
    "    pivot=sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "831efaff",
   "metadata": {},
   "source": [
    "## Cosine Similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a333c4c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import random\n",
    "\n",
    "###########################################\n",
    "limit = 100  # samples drawn per category\n",
    "prompt_template_idx = 3\n",
    "# option_style is randomized per sample below\n",
    "###########################################\n",
    "\n",
    "# Vary the number of distractors and the option layout per sample so the cached query states\n",
    "# are not tied to one prompt format. `filter_by_lm_prediction=True` presumably keeps only\n",
    "# samples the model answers correctly. The comprehension leaves the notebook-level `sample`\n",
    "# from earlier cells untouched.\n",
    "category_wise_samples = {}\n",
    "for category in tqdm(select_task.categories):\n",
    "    category_wise_samples[category] = [\n",
    "        select_task.get_random_sample(\n",
    "            mt=mt,\n",
    "            category=category,\n",
    "            prompt_template_idx=prompt_template_idx,\n",
    "            n_distractors=random.choice(range(2, 7)),\n",
    "            filter_by_lm_prediction=True,\n",
    "            option_style=random.choice([\"single_line\", \"numbered\", \"bulleted\"]),\n",
    "        )\n",
    "        for _ in range(limit)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df63f81e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "from src.selection.functional import cache_q_projections\n",
    "import random\n",
    "import copy\n",
    "\n",
    "# Heads to analyse: the optimized heads, sorted by (layer_idx, head_idx) below.\n",
    "heads = copy.deepcopy(optimized_heads)\n",
    "\n",
    "# Control: uncomment to analyse an equally sized random set of heads instead.\n",
    "# all_heads = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx in range(mt.n_layer)\n",
    "#     for head_idx in range(mt.config.num_attention_heads)\n",
    "# ]\n",
    "# heads = random.sample(all_heads, len(optimized_heads))\n",
    "\n",
    "heads = sorted(heads)\n",
    "\n",
    "logger.info(\"Caching the predicate directions\")\n",
    "category_wise_q_states = {category: None for category in select_task.categories}\n",
    "\n",
    "# Batch each category's prompts and cache, per head, the query projections (modules matched by\n",
    "# `projection_signature`) at the last three token positions (-3, -2, -1) of every prompt.\n",
    "for category in tqdm(select_task.categories):\n",
    "    prompts = [sample.prompt() for sample in category_wise_samples[category]]\n",
    "    tokenized_prompts = prepare_input(\n",
    "        prompts=prompts,\n",
    "        tokenizer=mt.tokenizer,\n",
    "    )\n",
    "    category_wise_q_states[category] = cache_q_projections(\n",
    "        mt=mt,\n",
    "        input=tokenized_prompts,\n",
    "        heads=heads,\n",
    "        token_indices=[[-3, -2, -1] for _ in range(len(prompts))],\n",
    "        projection_signature=\".q_proj\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9215095",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: cached query state of the first \"fruit\" sample, first head, last token (-1).\n",
    "category_wise_q_states[\"fruit\"][0][heads[0] + (-1,)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb825ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _unit_queries(q_states, head):\n",
    "    \"\"\"Stack `head`'s query vector at the last token (-1) across samples; L2-normalize each row.\n",
    "\n",
    "    Casts to float32 before normalizing so norms are not computed in the model's (e.g. bf16) dtype.\n",
    "    \"\"\"\n",
    "    q_head = torch.stack([q_state[head + (-1,)] for q_state in q_states], dim=0).to(torch.float32)\n",
    "    return q_head / q_head.norm(dim=-1, keepdim=True)\n",
    "\n",
    "\n",
    "def _off_diagonal_mean(sim_matrix):\n",
    "    \"\"\"Mean of a square similarity matrix excluding its diagonal (each sample vs. itself).\n",
    "\n",
    "    Returns NaN when there is a single sample (no off-diagonal entries).\n",
    "    \"\"\"\n",
    "    n = sim_matrix.shape[0]\n",
    "    off_diagonal = ~torch.eye(n, dtype=torch.bool, device=sim_matrix.device)\n",
    "    return sim_matrix[off_diagonal].mean()\n",
    "\n",
    "\n",
    "# Within-category cosine similarity of the last-token query vectors, per head.\n",
    "category_wise_cosine_sim = {}\n",
    "for category in tqdm(select_task.categories):\n",
    "    q_states = category_wise_q_states[category]\n",
    "    category_wise_cosine_sim[category] = {}\n",
    "    for head in heads:\n",
    "        q_head = _unit_queries(q_states, head)\n",
    "        category_wise_cosine_sim[category][head] = q_head @ q_head.T\n",
    "\n",
    "# Per head: mean \u00b1 std over categories of the mean pairwise similarity. The diagonal\n",
    "# (self-similarity, always 1) is excluded; including it inflated the within-category mean.\n",
    "for head in heads:\n",
    "    sims = torch.stack(\n",
    "        [_off_diagonal_mean(category_wise_cosine_sim[category][head]) for category in select_task.categories],\n",
    "        dim=0\n",
    "    )\n",
    "    logger.info(f\"Head {head} >> {sims.mean().item():.4f} \u00b1 {sims.std().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c48560de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heads for which within-category similarity matrices were computed.\n",
    "category_wise_cosine_sim[\"fruit\"].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "710610b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Within-category similarity of one probe head, per category. The diagonal (self-similarity)\n",
    "# is excluded so the statistics cover distinct sample pairs only.\n",
    "probe_head = (35, 19) if (35, 19) in heads else heads[0]\n",
    "if probe_head != (35, 19):\n",
    "    logger.warning(f\"(35, 19) is not among the analysed heads; using {probe_head} instead\")\n",
    "for category in select_task.categories:\n",
    "    sim_matrix = category_wise_cosine_sim[category][probe_head]\n",
    "    off_diagonal = sim_matrix[~torch.eye(sim_matrix.shape[0], dtype=torch.bool, device=sim_matrix.device)]\n",
    "    logger.info(f\"{category} >> {off_diagonal.mean().item():.4f} \u00b1 {off_diagonal.std().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c9c870",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cosine similarity between the last-token query vectors of every pair of categories, per head:\n",
    "# across_category_cosine_sim[head][category][other_category] has one row per `category` sample\n",
    "# and one column per `other_category` sample.\n",
    "across_category_cosine_sim = {}\n",
    "for head in heads:\n",
    "    # Normalize each category's queries once per head, not once per category pair.\n",
    "    unit_queries = {}\n",
    "    for category in select_task.categories:\n",
    "        q_head = torch.stack(\n",
    "            [q_state[head + (-1,)] for q_state in category_wise_q_states[category]], dim=0\n",
    "        ).to(torch.float32)\n",
    "        unit_queries[category] = q_head / q_head.norm(dim=-1, keepdim=True)\n",
    "    across_category_cosine_sim[head] = {\n",
    "        category: {\n",
    "            other_category: unit_queries[category] @ unit_queries[other_category].T\n",
    "            for other_category in select_task.categories\n",
    "        }\n",
    "        for category in select_task.categories\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c2e0a78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Across-category similarity of one probe head, for the first category only. For the\n",
    "# category-vs-itself entry the diagonal (self-similarity, always 1) is excluded.\n",
    "probe_head = (35, 19) if (35, 19) in heads else heads[0]\n",
    "category = next(iter(select_task.categories))\n",
    "print(f\"Category: {category}\")\n",
    "for other_category in select_task.categories:\n",
    "    sim_matrix = across_category_cosine_sim[probe_head][category][other_category]\n",
    "    if other_category == category:\n",
    "        sim_matrix = sim_matrix[~torch.eye(sim_matrix.shape[0], dtype=torch.bool, device=sim_matrix.device)]\n",
    "    logger.info(\n",
    "        f\"{other_category} >> {sim_matrix.mean().item():.4f} \u00b1 {sim_matrix.std().item():.4f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a6f1a32",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "fb706674",
   "metadata": {},
   "source": [
    "## Composition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "441141eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import copy\n",
    "from src.tokens import prepare_input\n",
    "from src.selection.utils import KeyedSet\n",
    "\n",
    "\n",
    "def _draw_composition_candidate(\n",
    "    task: SelectOneTask,\n",
    "    mt: ModelandTokenizer,\n",
    "    unit_categories: list[str],\n",
    "    prompt_template_idx,\n",
    "    option_style,\n",
    "    filter_by_lm_prediction: bool,\n",
    "    n_distractors: int,\n",
    "):\n",
    "    \"\"\"Draw one (unit_samples, clean_sample, comp_sample) triple without checking the model on comp_sample.\"\"\"\n",
    "    # One single-category sample per unit category.\n",
    "    unit_samples = [\n",
    "        task.get_random_sample(\n",
    "            mt=mt,\n",
    "            category=category,\n",
    "            prompt_template_idx=prompt_template_idx,\n",
    "            option_style=option_style,\n",
    "            filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "            n_distractors=n_distractors,\n",
    "            exclude_distractor_categories=task.exclude_for_category(category),\n",
    "        )\n",
    "        for category in unit_categories\n",
    "    ]\n",
    "\n",
    "    exclude_objs = [sample.obj for sample in unit_samples]\n",
    "    # Candidates in task order: a set's iteration order varies with string hashing across\n",
    "    # processes, which made this choice irreproducible even under a fixed seed.\n",
    "    clean_categories = [\n",
    "        category for category in dict.fromkeys(task.categories) if category not in unit_categories\n",
    "    ]\n",
    "    clean_category = random.choice(clean_categories)\n",
    "    target_category = random.choice(unit_categories)\n",
    "    # Composition answer: an example of a unit category not already used by a unit sample.\n",
    "    target_obj = random.choice(\n",
    "        (\n",
    "            KeyedSet(\n",
    "                items=task.category_wise_examples[target_category],\n",
    "                tokenizer=mt.tokenizer,\n",
    "            )\n",
    "            - KeyedSet(items=exclude_objs, tokenizer=mt.tokenizer)\n",
    "        ).values\n",
    "    )\n",
    "    # obj_idx: option slot of the clean answer; target_idx: a different slot for target_obj.\n",
    "    obj_idx = random.randint(0, n_distractors)\n",
    "    target_idx = random.choice([i for i in range(n_distractors + 1) if i != obj_idx])\n",
    "\n",
    "    # Clean distractors avoid the unit categories and every category those exclude\n",
    "    # (order-preserving de-duplication).\n",
    "    exclude_clean_distractor_categories = list(\n",
    "        dict.fromkeys(\n",
    "            list(unit_categories)\n",
    "            + [\n",
    "                excluded\n",
    "                for category in unit_categories\n",
    "                for excluded in task.exclude_for_category(category)\n",
    "            ]\n",
    "        )\n",
    "    )\n",
    "    logger.debug(\n",
    "        f\"{len(exclude_clean_distractor_categories)} | {exclude_clean_distractor_categories}\"\n",
    "    )\n",
    "\n",
    "    clean_sample = task.get_random_sample(\n",
    "        mt=mt,\n",
    "        category=clean_category,\n",
    "        obj_idx=obj_idx,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=option_style,\n",
    "        filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "        n_distractors=n_distractors,\n",
    "        exclude_objs=exclude_objs,\n",
    "        exclude_distractor_categories=exclude_clean_distractor_categories,\n",
    "        insert_distractor=[(target_obj, target_idx)],\n",
    "    )\n",
    "\n",
    "    # Same options as the clean sample, but asking for the union of the unit categories;\n",
    "    # the intended answer is the inserted target_obj.\n",
    "    comp_sample = copy.deepcopy(clean_sample)\n",
    "    comp_sample.category = \" or \".join(unit_categories)\n",
    "    comp_sample.obj = target_obj\n",
    "    comp_sample.obj_idx = target_idx\n",
    "    assert (\n",
    "        \"<_category_>\" in comp_sample.prompt_template\n",
    "    ), \"Composition prompt template must have <_category_> token\"\n",
    "    return unit_samples, clean_sample, comp_sample\n",
    "\n",
    "\n",
    "def get_composition_samples(\n",
    "    task: SelectOneTask,\n",
    "    mt: ModelandTokenizer,\n",
    "    unit_categories: list[str],\n",
    "    prompt_template_idx=3,\n",
    "    option_style=\"single_line\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    n_distractors: int = 5,\n",
    "    max_attempts: int = 100,\n",
    "):\n",
    "    \"\"\"Build a (unit_samples, clean_sample, comp_sample) counterfactual for composition.\n",
    "\n",
    "    One unit sample is drawn per category in `unit_categories`. The clean sample comes from\n",
    "    a category outside `unit_categories`, with an example object of one unit category\n",
    "    inserted among its options; the composition sample copies the clean sample but asks for\n",
    "    `cat_1 or cat_2 ...`, with that inserted object as the answer.\n",
    "\n",
    "    Args:\n",
    "        task: selection task providing categories, examples and prompt templates.\n",
    "        mt: model and tokenizer used for sampling and for checking predictions.\n",
    "        unit_categories: categories to compose; at least one task category must remain outside them.\n",
    "        prompt_template_idx: prompt template index; the template must contain the category placeholder.\n",
    "        option_style: option layout passed to `task.get_random_sample`.\n",
    "        filter_by_lm_prediction: if True, redraw until the model answers the composition sample correctly.\n",
    "        n_distractors: number of distractors, i.e. `n_distractors + 1` option slots.\n",
    "        max_attempts: how many triples to draw before giving up when filtering.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (unit_samples, clean_sample, comp_sample). When filtering, comp_sample also\n",
    "        carries `metadata[\"tokenized\"]` and `prediction`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `max_attempts` < 1.\n",
    "        RuntimeError: if filtering and no correctly-predicted triple is found in `max_attempts` draws.\n",
    "    \"\"\"\n",
    "    if max_attempts < 1:\n",
    "        raise ValueError(f\"max_attempts must be >= 1, got {max_attempts}\")\n",
    "\n",
    "    # Retry loop instead of recursion: a persistently failing category pair ends with a clear\n",
    "    # RuntimeError instead of a RecursionError after hundreds of model calls.\n",
    "    for _ in range(max_attempts):\n",
    "        unit_samples, clean_sample, comp_sample = _draw_composition_candidate(\n",
    "            task=task,\n",
    "            mt=mt,\n",
    "            unit_categories=unit_categories,\n",
    "            prompt_template_idx=prompt_template_idx,\n",
    "            option_style=option_style,\n",
    "            filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "            n_distractors=n_distractors,\n",
    "        )\n",
    "        if not filter_by_lm_prediction:\n",
    "            return unit_samples, clean_sample, comp_sample\n",
    "\n",
    "        # NOTE(review): passes `mt` rather than `mt.tokenizer`, unlike the other `prepare_input`\n",
    "        # calls in this notebook; kept as-is, confirm it is intended.\n",
    "        tokenized = prepare_input(tokenizer=mt, prompts=comp_sample.prompt())\n",
    "        is_correct, predictions, track_options = verify_correct_option(\n",
    "            mt=mt,\n",
    "            target=comp_sample.obj,\n",
    "            options=comp_sample.options,\n",
    "            input=tokenized,\n",
    "        )\n",
    "        comp_sample.metadata[\"tokenized\"] = tokenized\n",
    "\n",
    "        logger.debug(comp_sample.prompt())\n",
    "        logger.debug(\n",
    "            f\"{comp_sample.subj} | {comp_sample.category} -> {comp_sample.obj} | pred={[str(p) for p in predictions]}\"\n",
    "        )\n",
    "        if is_correct:\n",
    "            comp_sample.prediction = predictions\n",
    "            return unit_samples, clean_sample, comp_sample\n",
    "\n",
    "        logger.error(\n",
    "            f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {comp_sample.ans_token_id}[\"{mt.tokenizer.decode(comp_sample.ans_token_id)}\"]'\n",
    "        )\n",
    "\n",
    "    raise RuntimeError(\n",
    "        f\"No correctly-predicted composition sample for {unit_categories} after {max_attempts} attempts\"\n",
    "    )\n",
    "\n",
    "\n",
    "# Smoke test: one composition counterfactual for \"fruit\" or \"vehicle\".\n",
    "unit_samples, clean_sample, comp_sample = get_composition_samples(\n",
    "    task=select_task,\n",
    "    mt=mt,\n",
    "    unit_categories=[\"fruit\", \"vehicle\"],\n",
    "    prompt_template_idx=3,\n",
    "    option_style=\"single_line\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    n_distractors=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274252dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# failed_case_idx = 10\n",
    "# failed_case = failed_cases[failed_case_idx]\n",
    "# unit_samples = failed_case[\"unit_samples\"]\n",
    "# clean_sample = failed_case[\"clean_sample\"]\n",
    "# comp_sample = failed_case[\"comp_sample\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62ba5a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same options, different question: the clean answer vs. the composition answer.\n",
    "print(clean_sample.prompt(), \">>\", clean_sample.obj)\n",
    "print(comp_sample.prompt(), \">>\", comp_sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5eb8b5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "# NOTE(review): presumably clears hooks/patches left on the model by earlier cells.\n",
    "mt.reset_forward()\n",
    "mt.set_attn_implementation(\"eager\")\n",
    "\n",
    "# Inspect the selected heads on every sample of the triple: unit samples first, then the\n",
    "# clean and composition samples. `cf_sample` avoids overwriting the notebook-level `sample`.\n",
    "for cf_sample in [*unit_samples, clean_sample, comp_sample]:\n",
    "    print(cf_sample.prompt(), \">>\", cf_sample.obj)\n",
    "    attn_pattern = verify_head_patterns(\n",
    "        prompt=cf_sample.prompt(),\n",
    "        options=cf_sample.options,\n",
    "        pivot=cf_sample.subj,\n",
    "        mt=mt,\n",
    "        heads=HEADS,\n",
    "        # heads=[(35, 19)],\n",
    "        # generate_full_answer=True,\n",
    "        query_index=-1,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfb49ea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import cache_q_projections\n",
    "from src.utils.typing import TokenizerOutput\n",
    "from src.tokens import prepare_input, find_token_range\n",
    "\n",
    "# Token positions (counted from the end of the prompt) at which query projections are cached;\n",
    "# only the keys are used in this cell. NOTE(review): the identity values look like a placeholder\n",
    "# for a source->target position mapping; confirm before relying on them.\n",
    "query_indices = {-3: -3, -2: -2, -1: -1}\n",
    "\n",
    "# One dict of cached query projections per unit sample, keyed by (layer_idx, head_idx, token_idx),\n",
    "# the key format the next cell unpacks.\n",
    "unit_query_states = []\n",
    "for sample in unit_samples:\n",
    "    # Reuse the tokenization stored on the sample when present; otherwise tokenize its prompt.\n",
    "    if \"tokenized\" in sample.metadata:\n",
    "        tokenized = TokenizerOutput(data=sample.metadata[\"tokenized\"])\n",
    "    else:\n",
    "        tokenized = prepare_input(prompts=sample.prompt(), tokenizer=mt.tokenizer)\n",
    "    unit_query_states.append(\n",
    "        cache_q_projections(\n",
    "            mt=mt,\n",
    "            input=tokenized,\n",
    "            heads=optimized_heads,\n",
    "            token_indices=[list(query_indices.keys())],\n",
    "            return_output=False\n",
    "        )[0]  # single prompt in the batch -> take its entry\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4363c11e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import PatchSpec, patch_with_baukit, interpret_logits\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "# Composition test: sum the unit samples' query projections per (layer, head, token) and patch\n",
    "# the sums into the clean prompt's `.q_proj` outputs.\n",
    "combined_q_proj_patches = []\n",
    "\n",
    "for layer_idx, head_idx, query_idx in unit_query_states[0]:\n",
    "    # Sum over the unit-sample axis first, then drop singleton dims. Squeezing first would also\n",
    "    # remove the sample axis when there is a single unit sample, and the sum would then\n",
    "    # collapse the feature dimension instead.\n",
    "    proj = torch.stack([\n",
    "        query_state[(layer_idx, head_idx, query_idx)] for query_state in unit_query_states\n",
    "    ]).sum(dim=0).squeeze()\n",
    "    combined_q_proj_patches.append(\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                head_idx,\n",
    "                query_idx,\n",
    "            ),\n",
    "            patch=proj,\n",
    "        )\n",
    "    )\n",
    "\n",
    "clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt.tokenizer)\n",
    "\n",
    "# Attention of the selected heads on the clean prompt with the combined query patches applied.\n",
    "comp_attn = verify_head_patterns(\n",
    "    prompt=clean_sample.prompt(),\n",
    "    tokenized_prompt=clean_tokenized,\n",
    "    mt=mt,\n",
    "    heads=optimized_heads,\n",
    "    # heads=[(35, 19)],\n",
    "    # generate_full_answer=True,\n",
    "    query_index=-1,\n",
    "    query_patches=combined_q_proj_patches,\n",
    ")\n",
    "\n",
    "# Next-token prediction on the patched clean prompt, tracking the first token of every option.\n",
    "patched_out = patch_with_baukit(\n",
    "    mt=mt,\n",
    "    inputs=clean_tokenized,\n",
    "    patches=combined_q_proj_patches,\n",
    ")\n",
    "patched_logits = patched_out.logits[:, -1, :].squeeze()\n",
    "patched_pred, patched_track = interpret_logits(\n",
    "    tokenizer=mt.tokenizer,\n",
    "    logits=patched_logits,\n",
    "    k=10,\n",
    "    interested_tokens = [\n",
    "        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "        for opt in clean_sample.options\n",
    "    ]\n",
    ")\n",
    "patched_track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d9e91aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# failed_case[\"patched_pred\"], failed_case[\"patched_track\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "646de0bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# failed_case[\"clean_pred\"], failed_case[\"clean_track\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf5e7e61",
   "metadata": {},
   "source": [
    "## Scale Up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a760f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from dataclasses_json import DataClassJsonMixin\n",
    "from typing import Union\n",
    "from src.selection.data import SelectionSample, CountingSample, YesNoSample\n",
    "\n",
    "@dataclass\n",
    "class CounterFactualforComposition(DataClassJsonMixin):\n",
    "    \"\"\"Unit samples, clean sample and composition sample of one composition counterfactual.\n",
    "\n",
    "    The fields may hold different sample classes. `detensorize` records each sample's class\n",
    "    under `metadata[\"sample_type\"]`, which `from_dict` uses to rebuild the right class.\n",
    "    \"\"\"\n",
    "\n",
    "    unit_samples: list[Union[SelectionSample, CountingSample, YesNoSample]]\n",
    "    clean_sample: Union[SelectionSample, CountingSample, YesNoSample]\n",
    "    comp_sample: Union[SelectionSample, CountingSample, YesNoSample]\n",
    "\n",
    "    @staticmethod\n",
    "    def sample_type_to_class():\n",
    "        \"\"\"Map each serialized `sample_type` tag to its sample class.\"\"\"\n",
    "        return {\n",
    "            \"selection\": SelectionSample,\n",
    "            \"counting\": CountingSample,\n",
    "            \"yes_no\": YesNoSample,\n",
    "        }\n",
    "\n",
    "    def detensorize(self):\n",
    "        \"\"\"Tag every sample with its `sample_type`, then detensorize it in place.\"\"\"\n",
    "        type_to_name = {\n",
    "            \"SelectionSample\": \"selection\",\n",
    "            \"CountingSample\": \"counting\",\n",
    "            \"YesNoSample\": \"yes_no\",\n",
    "        }\n",
    "        for sample in [*self.unit_samples, self.clean_sample, self.comp_sample]:\n",
    "            sample.metadata[\"sample_type\"] = type_to_name[type(sample).__name__]\n",
    "            sample.detensorize()\n",
    "\n",
    "    @classmethod\n",
    "    def _decode_sample(cls, sample_dict):\n",
    "        \"\"\"Rebuild one sample from its dict form without mutating `sample_dict`.\"\"\"\n",
    "        metadata = dict(sample_dict[\"metadata\"])\n",
    "        sample_type = metadata.pop(\"sample_type\")\n",
    "        sample_cls = cls.sample_type_to_class()[sample_type]\n",
    "        return sample_cls.from_dict({**sample_dict, \"metadata\": metadata})\n",
    "\n",
    "    @classmethod\n",
    "    def from_dict(cls, d, *, infer_missing=False):\n",
    "        \"\"\"Rebuild a counterfactual from `to_dict()` output whose samples were tagged by `detensorize`.\n",
    "\n",
    "        Defined as a classmethod accepting `infer_missing` so the inherited\n",
    "        `DataClassJsonMixin.from_json`, which calls `cls.from_dict(kvs, infer_missing=...)`,\n",
    "        works; `infer_missing` is accepted for that compatibility only and is not used.\n",
    "        The input dict is not modified, so the same dict can be decoded more than once.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: if a sample's metadata lacks `sample_type` or holds an unknown tag.\n",
    "        \"\"\"\n",
    "        return cls(\n",
    "            unit_samples=[cls._decode_sample(sample_dict) for sample_dict in d[\"unit_samples\"]],\n",
    "            clean_sample=cls._decode_sample(d[\"clean_sample\"]),\n",
    "            comp_sample=cls._decode_sample(d[\"comp_sample\"]),\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4034871c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "test_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation_composition\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(test_samples_save_path, exist_ok=True)\n",
    "\n",
    "free_gpu_cache()\n",
    "\n",
    "\n",
    "#########################################\n",
    "test_limit = 512  # number of counterfactuals to generate\n",
    "start_number = 1  # index of the first file written; raise it to continue numbering of an earlier run\n",
    "#########################################\n",
    "\n",
    "# Each counterfactual composes two random categories and is written to\n",
    "# test_samples_save_path/NNNNN.json as soon as it is generated.\n",
    "test_set = []\n",
    "for sample_number in tqdm(range(start_number, start_number + test_limit), desc=\"composition samples\"):\n",
    "    unit_samples, clean_sample, comp_sample = get_composition_samples(\n",
    "        task=select_task,\n",
    "        mt=mt,\n",
    "        unit_categories=random.sample(select_task.categories, 2),\n",
    "        prompt_template_idx=3,\n",
    "        option_style=\"single_line\",\n",
    "        filter_by_lm_prediction=True,\n",
    "        n_distractors=random.choice(range(2, 6)),\n",
    "    )\n",
    "    test_set.append((unit_samples, clean_sample, comp_sample))\n",
    "    counterfactual = CounterFactualforComposition(\n",
    "        unit_samples=unit_samples,\n",
    "        clean_sample=clean_sample,\n",
    "        comp_sample=comp_sample,\n",
    "    )\n",
    "    counterfactual.detensorize()\n",
    "    # Drop the tokenized composition prompt (stored while filtering) before serializing.\n",
    "    comp_sample.metadata.pop(\"tokenized\", None)\n",
    "    with open(os.path.join(test_samples_save_path, f\"{sample_number:05d}.json\"), \"w\") as f:\n",
    "        json.dump(counterfactual.to_dict(), f, indent=2)\n",
    "\n",
    "len(test_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cf361ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "#########################################\n",
    "test_set = []\n",
    "test_limit = 512\n",
    "#########################################\n",
    "\n",
    "test_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"validation_composition\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    select_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# Sort before shuffling: `os.listdir` order is filesystem-dependent, so sorting makes the\n",
    "# subset picked below reproducible whenever `random` has been seeded.\n",
    "sample_files = sorted(\n",
    "    os.path.join(test_samples_load_path, f)\n",
    "    for f in os.listdir(test_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "if len(sample_files) < test_limit:\n",
    "    logger.warning(f\"Only {len(sample_files)} sample files available (test_limit={test_limit})\")\n",
    "\n",
    "random.shuffle(sample_files)\n",
    "sample_files = sample_files[:test_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_composition_dict = json.load(f)\n",
    "    cf_composition = CounterFactualforComposition.from_dict(cf_composition_dict)\n",
    "    test_set.append(\n",
    "        (\n",
    "            cf_composition.unit_samples,\n",
    "            cf_composition.clean_sample,\n",
    "            cf_composition.comp_sample,\n",
    "        )\n",
    "    )\n",
    "\n",
    "len(test_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6270211a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import cache_q_projections\n",
    "from src.selection.functional import find_quesmark_pos\n",
    "\n",
    "def cache_q_projections_with_avg_trick(\n",
    "    mt: ModelandTokenizer,\n",
    "    patch_sample: Union[SelectionSample, CountingSample, YesNoSample],\n",
    "    heads: list[tuple[int, int]],\n",
    "    token_indices: list[int],\n",
    "    add_ques_pos_to_query_indices: bool = True,\n",
    "):\n",
    "    \"\"\"Cache q-projections of `patch_sample`, averaged over the position of the correct option.\n",
    "\n",
    "    For every option slot `i`, a deep copy of `patch_sample` is made with the correct option\n",
    "    (at `patch_sample.obj_idx`) swapped into slot `i`. All copies are tokenized as one batch,\n",
    "    their q-projections are cached with `cache_q_projections`, and each cached location is\n",
    "    averaged over the copies.\n",
    "\n",
    "    Args:\n",
    "        mt: model and tokenizer wrapper.\n",
    "        patch_sample: sample to cache from; it is deep-copied, never modified.\n",
    "        heads: (layer_idx, head_idx) pairs to cache.\n",
    "        token_indices: token positions cached in every permuted prompt.\n",
    "        add_ques_pos_to_query_indices: if True, also cache the position returned by\n",
    "            `find_quesmark_pos` for each permuted prompt.\n",
    "\n",
    "    Returns:\n",
    "        A one-element list holding a dict that maps each location key produced by\n",
    "        `cache_q_projections` to its mean over the permuted prompts.\n",
    "    \"\"\"\n",
    "    patch_samples = []\n",
    "    for obj_idx in range(len(patch_sample.options)):\n",
    "        sample = copy.deepcopy(patch_sample)\n",
    "        # Move the correct option into slot `obj_idx` (a no-op when it is already there).\n",
    "        sample.options[obj_idx], sample.options[sample.obj_idx] = (\n",
    "            sample.options[sample.obj_idx],\n",
    "            sample.options[obj_idx],\n",
    "        )\n",
    "        sample.obj_idx = obj_idx\n",
    "        patch_samples.append(sample)\n",
    "\n",
    "    patch_tokenized_batch = prepare_input(\n",
    "        prompts=[sample.prompt() for sample in patch_samples],\n",
    "        tokenizer=mt.tokenizer,\n",
    "        return_offsets_mapping=True,\n",
    "    )\n",
    "    patch_offset_mapping_batch = patch_tokenized_batch.pop(\"offset_mapping\")\n",
    "\n",
    "    patch_token_indices = []\n",
    "    for idx, sample in enumerate(patch_samples):\n",
    "        # Positions requested by the caller; dict.fromkeys keeps their order and drops repeats.\n",
    "        cur_indices = dict.fromkeys(token_indices)\n",
    "        if add_ques_pos_to_query_indices:\n",
    "            patch_ques_pos = find_quesmark_pos(\n",
    "                prompt=sample.prompt(),\n",
    "                tokenizer=mt.tokenizer,\n",
    "                tokenized=TokenizerOutput(\n",
    "                    data={\n",
    "                        k: v[idx : idx + 1, :]\n",
    "                        for k, v in patch_tokenized_batch.items()\n",
    "                    }\n",
    "                ),\n",
    "                offset_mapping=patch_offset_mapping_batch[idx],\n",
    "            )\n",
    "            cur_indices[patch_ques_pos] = None\n",
    "        patch_token_indices.append(list(cur_indices))\n",
    "    logger.debug(f\"{patch_tokenized_batch.input_ids.shape}\")\n",
    "\n",
    "    cached_q_states = cache_q_projections(\n",
    "        mt=mt,\n",
    "        input=patch_tokenized_batch,\n",
    "        heads=heads,\n",
    "        token_indices=patch_token_indices,\n",
    "        return_output=False,\n",
    "    )\n",
    "    # NOTE(review): locations come from the first permuted prompt. If the question-mark\n",
    "    # position differs between permutations, other prompts may lack that key -- confirm\n",
    "    # before relying on add_ques_pos_to_query_indices=True.\n",
    "    locations = list(cached_q_states[0].keys())\n",
    "    avg_q_states = {\n",
    "        loc: torch.stack([q_states[loc] for q_states in cached_q_states]).mean(dim=0)\n",
    "        for loc in locations\n",
    "    }\n",
    "    return [avg_q_states]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16b55e8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from src.functional import patch_with_baukit, interpret_logits\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "query_indices = {-3: -3, -2: -2, -1: -1}\n",
    "# NOTE(review): built from `HEADS`, while the patches below use `optimized_heads`;\n",
    "# `query_locations` is not used in this cell -- confirm which head set is intended.\n",
    "query_locations = [\n",
    "    (layer_idx, head_idx, patch_query_idx)\n",
    "    for layer_idx, head_idx in HEADS\n",
    "    for patch_query_idx in query_indices.keys()\n",
    "]\n",
    "\n",
    "\n",
    "def _answer_stats(track, clean_obj, target_obj):\n",
    "    \"\"\"Rank and logit of the clean and target answer tokens in an `interpret_logits` track.\"\"\"\n",
    "    return {\n",
    "        \"clean_rank\": track[clean_obj][0],\n",
    "        \"clean_logit\": track[clean_obj][1].logit,\n",
    "        \"target_rank\": track[target_obj][0],\n",
    "        \"target_logit\": track[target_obj][1].logit,\n",
    "    }\n",
    "\n",
    "\n",
    "test_results = []\n",
    "for unit_samples, clean_sample, comp_sample in tqdm(test_set):\n",
    "    clean_obj = get_first_token_id(clean_sample.obj, tokenizer=mt.tokenizer)\n",
    "    target_obj = get_first_token_id(comp_sample.obj, tokenizer=mt.tokenizer)\n",
    "\n",
    "    # q-projections of each unit sample at `query_indices`, averaged over answer positions.\n",
    "    unit_query_states = [\n",
    "        cache_q_projections_with_avg_trick(\n",
    "            mt=mt,\n",
    "            patch_sample=sample,\n",
    "            heads=optimized_heads,\n",
    "            token_indices=list(query_indices.keys()),\n",
    "            add_ques_pos_to_query_indices=False,\n",
    "        )[0]\n",
    "        for sample in unit_samples\n",
    "    ]\n",
    "\n",
    "    clean_tokenized = prepare_input(prompts=clean_sample.prompt(), tokenizer=mt.tokenizer)\n",
    "    interested_tokens = [\n",
    "        get_first_token_id(option, mt.tokenizer) for option in clean_sample.options\n",
    "    ]\n",
    "\n",
    "    # clean run\n",
    "    clean_output = patch_with_baukit(\n",
    "        mt=mt,\n",
    "        inputs=clean_tokenized,\n",
    "    )\n",
    "    clean_logits = clean_output.logits[:, -1, :]\n",
    "    clean_pred, clean_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer, logits=clean_logits, interested_tokens=interested_tokens\n",
    "    )\n",
    "    before_intervention = _answer_stats(clean_track, clean_obj, target_obj)\n",
    "\n",
    "    # patched run: on the clean prompt, patch every cached location with the sum of the\n",
    "    # unit samples' q-projections at that location\n",
    "    combined_q_proj_patches = []\n",
    "    for layer_idx, head_idx, query_idx in unit_query_states[0]:\n",
    "        # Sum over the stacked unit-sample axis BEFORE squeezing: squeezing first would also\n",
    "        # drop that axis when there is only one unit sample, and sum(dim=0) would then\n",
    "        # reduce the projection itself to a scalar.\n",
    "        proj = (\n",
    "            torch.stack(\n",
    "                [\n",
    "                    query_state[(layer_idx, head_idx, query_idx)]\n",
    "                    for query_state in unit_query_states\n",
    "                ]\n",
    "            )\n",
    "            .sum(dim=0)\n",
    "            .squeeze()\n",
    "        )\n",
    "        combined_q_proj_patches.append(\n",
    "            PatchSpec(\n",
    "                location=(\n",
    "                    mt.attn_module_name_format.format(layer_idx) + \".q_proj\",\n",
    "                    head_idx,\n",
    "                    query_idx,\n",
    "                ),\n",
    "                patch=proj,\n",
    "            )\n",
    "        )\n",
    "\n",
    "    patched_output = patch_with_baukit(\n",
    "        mt=mt,\n",
    "        inputs=clean_tokenized,\n",
    "        patches=combined_q_proj_patches,\n",
    "    )\n",
    "    patched_logits = patched_output.logits[:, -1, :]\n",
    "    patched_pred, patched_track = interpret_logits(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        logits=patched_logits,\n",
    "        interested_tokens=interested_tokens,\n",
    "    )\n",
    "    after_intervention = _answer_stats(patched_track, clean_obj, target_obj)\n",
    "\n",
    "    test_results.append(\n",
    "        {\n",
    "            \"unit_samples\": unit_samples,\n",
    "            \"clean_sample\": clean_sample,\n",
    "            \"comp_sample\": comp_sample,\n",
    "            \"before_intervention\": before_intervention,\n",
    "            \"after_intervention\": after_intervention,\n",
    "            \"clean_pred\": clean_pred,\n",
    "            \"clean_track\": clean_track,\n",
    "            \"patched_pred\": patched_pred,\n",
    "            \"patched_track\": patched_track,\n",
    "        }\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fe69b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def _report_mean_std(name, values):\n",
    "    \"\"\"Print `name: mean ± std` (4 decimals) and return `values` as a numpy array.\"\"\"\n",
    "    values = np.array(values)\n",
    "    print(f\"{name}: {values.mean():.4f} ± {values.std():.4f}\")\n",
    "    return values\n",
    "\n",
    "\n",
    "def _delta(result, key):\n",
    "    \"\"\"`key` after the intervention minus `key` before it, for one test result.\"\"\"\n",
    "    return result[\"after_intervention\"][key] - result[\"before_intervention\"][key]\n",
    "\n",
    "\n",
    "def _after(result, key):\n",
    "    \"\"\"`key` after the intervention, for one test result.\"\"\"\n",
    "    return result[\"after_intervention\"][key]\n",
    "\n",
    "\n",
    "# --- ranks ---\n",
    "clean_rank_delta = _report_mean_std(\n",
    "    \"clean_rank_delta\", [_delta(r, \"clean_rank\") for r in test_results]\n",
    ")\n",
    "target_rank_delta = _report_mean_std(\n",
    "    \"target_rank_delta\", [_delta(r, \"target_rank\") for r in test_results]\n",
    ")\n",
    "clean_rank_after_intervention = _report_mean_std(\n",
    "    \"clean_rank_after_intervention\", [_after(r, \"clean_rank\") for r in test_results]\n",
    ")\n",
    "target_rank_after_intervention = _report_mean_std(\n",
    "    \"target_rank_after_intervention\", [_after(r, \"target_rank\") for r in test_results]\n",
    ")\n",
    "\n",
    "print(\"=\"*100)\n",
    "\n",
    "# --- logits ---\n",
    "clean_logit_delta = _report_mean_std(\n",
    "    \"clean_logit_delta\", [_delta(r, \"clean_logit\") for r in test_results]\n",
    ")\n",
    "target_logit_delta = _report_mean_std(\n",
    "    \"target_logit_delta\", [_delta(r, \"target_logit\") for r in test_results]\n",
    ")\n",
    "clean_logit_after_intervention = _report_mean_std(\n",
    "    \"clean_logit_after_intervention\", [_after(r, \"clean_logit\") for r in test_results]\n",
    ")\n",
    "target_logit_after_intervention = _report_mean_std(\n",
    "    \"target_logit_after_intervention\", [_after(r, \"target_logit\") for r in test_results]\n",
    ")\n",
    "\n",
    "print(\"=\"*100)\n",
    "\n",
    "# --- top-1: does the patched run put the composite answer first among the options? ---\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in test_results:\n",
    "    comp_sample = intervention_result[\"comp_sample\"]\n",
    "    int_track = intervention_result[\"patched_track\"]\n",
    "    # Track values are (rank, prediction) pairs. Pick the best-ranked interested token\n",
    "    # explicitly instead of trusting the dict's key order (assumes lower rank = better).\n",
    "    top_token_id = min(int_track.values(), key=lambda rank_and_pred: rank_and_pred[0])[1].token_id\n",
    "    if top_token_id == get_first_token_id(comp_sample.obj, tokenizer=mt.tokenizer):\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(intervention_result)\n",
    "\n",
    "# Report NaN for an empty `test_results` instead of raising ZeroDivisionError.\n",
    "top_1_accuracy = (\n",
    "    counter_patch_type_top_option / len(test_results) if test_results else float(\"nan\")\n",
    ")\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(test_results)})\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a12727c8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5565a29b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
