{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "# torch.cuda.get_device_name() raises when no CUDA device is usable, so only\n",
    "# query the device name after availability has been confirmed.\n",
    "if torch.cuda.is_available():\n",
    "    logger.info(\n",
    "        f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    "    )\n",
    "else:\n",
    "    logger.info(f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}\")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "080021e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# file_path = os.path.join(\n",
    "#     env_utils.DEFAULT_DATA_DIR,\n",
    "#     \"selection\",\n",
    "#     # \"profession.json\"\n",
    "#     # \"nationality.json\"\n",
    "#     \"objects.json\",\n",
    "# )\n",
    "\n",
    "# with open(file_path, \"r\") as f:\n",
    "#     temp = json.load(f)\n",
    "\n",
    "# for cat in temp[\"categories\"]:\n",
    "#     temp[\"categories\"][cat] = [obj.capitalize() for obj in temp[\"categories\"][cat]]\n",
    "\n",
    "# with open(file_path, \"w\") as f:\n",
    "#     json.dump(temp, f, indent=4)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# TASK_CLS = SelectOrderTask\n",
    "# prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = select_task.get_random_sample(\n",
    "    mt = mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    obj_idx=2,\n",
    "    # category=\"actor\",\n",
    "    # category=\"Brazil\"\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "print(sample)\n",
    "print(sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66399b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import verify_correct_option\n",
    "# sample.prompt_template = select_prof.prompt_templates[3]\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)\n",
    "\n",
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2e8a8aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "mt.n_layer, mt.config.num_attention_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a35516e",
   "metadata": {},
   "outputs": [],
   "source": [
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "# HEADS = [(35, 19)]\n",
    "\n",
    "\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe381b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Path of the saved head-optimization results for the loaded model and task:\n",
    "# <DEFAULT_RESULTS_DIR>/selection/optimized_heads/<model name after the last '/'>/<task_name>.npz\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    f\"{select_task.task_name}.npz\"\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "# execute arbitrary code; only load result files produced by a trusted run.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "# Plot the stored \"losses\" array (presumably one loss value per optimization step; confirm).\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f87810f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heads in layers >= MAX_HEAD_LAYER are excluded from the selection.\n",
    "MAX_HEAD_LAYER = 50\n",
    "# A head counts as selected when its mask value exceeds this threshold.\n",
    "MASK_THRESHOLD = 0.5\n",
    "\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "# Zero the excluded layers once, so the heatmap and the selection below agree on\n",
    "# which heads are in play (no second layer filter is needed afterwards).\n",
    "optimal_head_mask[MAX_HEAD_LAYER:, :] = 0.0\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "# The mask is indexed [layer, head]; transposing puts layers on the x-axis.\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "plt.xlabel(\"layer index\")\n",
    "plt.ylabel(\"head index\")\n",
    "plt.title(f\"Optimized head mask (layers >= {MAX_HEAD_LAYER} zeroed)\")\n",
    "\n",
    "# torch.nonzero yields [layer_idx, head_idx] rows; keep them as tuples so that\n",
    "# membership checks such as `(35, 19) in HEADS` work.\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(\n",
    "        optimal_head_mask > MASK_THRESHOLD, as_tuple=False\n",
    "    ).tolist()\n",
    "]\n",
    "print(len(heads_selected))\n",
    "\n",
    "# NOTE: this replaces the hand-picked HEADS list defined in an earlier cell.\n",
    "HEADS = heads_selected\n",
    "\n",
    "(35, 19) in HEADS, (35, 19) in heads_selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    ")\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=sample.prompt(option_style=\"single_line\"),\n",
    "    options=sample.options,\n",
    "    pivot=sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "129fe333",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.selection.data import SelectionSample\n",
    "from src.functional import predict_next_token\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "######################################################################\n",
    "N_DISTRACTORS = 5\n",
    "WINDOW_SPEC = {\n",
    "    mt.layer_name_format: 1,\n",
    "    mt.mlp_module_name_format: 9,\n",
    "    mt.attn_module_name_format: 9,\n",
    "}\n",
    "module_name_format = mt.layer_name_format\n",
    "# module_name_format = mt.mlp_module_name_format\n",
    "# module_name_format = mt.attn_module_name_format\n",
    "\n",
    "######################################################################\n",
    "\n",
    "\n",
    "def _draw_distractors(task, excluded_categories, taken_options, k):\n",
    "    \"\"\"Pick one distractor from each of `k` randomly chosen categories.\n",
    "\n",
    "    Args:\n",
    "        task: Task whose `category_wise_examples` maps category -> list of examples.\n",
    "        excluded_categories: Categories that must not contribute a distractor.\n",
    "        taken_options: Options already placed in the prompt; they (and the\n",
    "            distractors drawn so far) are removed from each candidate pool via a\n",
    "            KeyedSet difference before drawing.\n",
    "        k: Number of distractors, i.e. number of distinct categories to draw from.\n",
    "\n",
    "    Returns:\n",
    "        A list with `k` distractors, one per sampled category.\n",
    "    \"\"\"\n",
    "    candidate_categories = list(\n",
    "        set(task.category_wise_examples.keys()) - set(excluded_categories)\n",
    "    )\n",
    "    distractors = []\n",
    "    for other_category in random.sample(candidate_categories, k=k):\n",
    "        # Shuffle a copy: shuffling the task's own list would silently reorder\n",
    "        # shared data for every later cell that reads it.\n",
    "        pool = list(task.category_wise_examples[other_category])\n",
    "        random.shuffle(pool)\n",
    "        remaining = KeyedSet(pool, mt.tokenizer) - KeyedSet(\n",
    "            taken_options + distractors,\n",
    "            tokenizer=mt.tokenizer,\n",
    "        )\n",
    "        distractors.append(random.choice(remaining.values))\n",
    "    return distractors\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_counterfactual_samples_within_task(\n",
    "    task: SelectOneTask | SelectOrderTask = select_task,\n",
    "    patch_category: str | None = None,\n",
    "    clean_category: str | None = None,\n",
    "    shuffle_clean_options: bool = False,\n",
    "    prompt_template_idx=2,\n",
    "    clean_option_style=\"numbered\",\n",
    "    patch_option_style=\"single_line\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    distinct_options: bool = False,\n",
    "    n_distractors: int = N_DISTRACTORS,\n",
    "):\n",
    "    \"\"\"Build a (patch, clean) pair of counterfactual SelectionSamples from one task.\n",
    "\n",
    "    The two prompts ask about different categories. Each option list contains the\n",
    "    prompt's own answer plus one option from the other prompt's category; that\n",
    "    option is recorded in the sample metadata under the `track_type_obj*` keys.\n",
    "    The remaining options are distractors drawn from other categories.\n",
    "\n",
    "    Args:\n",
    "        task: Task providing `category_wise_examples` and `prompt_templates`.\n",
    "        patch_category: Category of the patch prompt; chosen at random when None.\n",
    "        clean_category: Category of the clean prompt; when None, chosen at random\n",
    "            among the categories other than `patch_category`.\n",
    "        shuffle_clean_options: Only used when `distinct_options` is false. Reshuffles\n",
    "            the shared option list for the clean prompt until neither the clean\n",
    "            answer nor the patch answer sits at the patch answer's index.\n",
    "        prompt_template_idx: Index into `task.prompt_templates`, used for both samples.\n",
    "        clean_option_style: `default_option_style` of the clean sample.\n",
    "        patch_option_style: `default_option_style` of the patch sample.\n",
    "        filter_by_lm_prediction: When true, every sample is checked with\n",
    "            `verify_correct_option`; on a mismatch a fresh pair is drawn.\n",
    "        distinct_options: When false, both prompts share one option list and the\n",
    "            tracked option of each prompt is the other prompt's answer. When true,\n",
    "            the tracked options are drawn separately and the clean prompt gets its\n",
    "            own option list.\n",
    "        n_distractors: Each option list ends up with `n_distractors + 1` entries\n",
    "            (two must-have options plus `n_distractors - 1` drawn distractors).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(patch_sample, clean_sample)`.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: If `task` is neither a SelectOrderTask nor a SelectOneTask.\n",
    "\n",
    "    Note:\n",
    "        The retry on a prediction mismatch is recursive, has no attempt cap, and\n",
    "        reuses the categories selected on the current attempt.\n",
    "    \"\"\"\n",
    "    # Identity checks (`is False` / `is True`) on a truthy non-bool used to take\n",
    "    # inconsistent branches; normalise once and rely on plain truthiness below.\n",
    "    distinct_options = bool(distinct_options)\n",
    "\n",
    "    categories = list(task.category_wise_examples.keys())\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(categories)\n",
    "\n",
    "    patch_subj, patch_obj = random.sample(\n",
    "        task.category_wise_examples[patch_category], 2\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Patch category: {patch_category}, subject: {patch_subj}, object: {patch_obj}\"\n",
    "    )\n",
    "\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(list(set(categories) - {patch_category}))\n",
    "\n",
    "    # Shuffle a copy so the task's example list is left untouched.\n",
    "    clean_pool = list(task.category_wise_examples[clean_category])\n",
    "    random.shuffle(clean_pool)\n",
    "\n",
    "    clean_subj, clean_obj = random.sample(\n",
    "        (\n",
    "            KeyedSet(clean_pool, mt.tokenizer) - KeyedSet([patch_obj], mt.tokenizer)\n",
    "        ).values,\n",
    "        2,\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Clean category: {clean_category}, subject: {clean_subj}, object: {clean_obj}\"\n",
    "    )\n",
    "\n",
    "    if not distinct_options:\n",
    "        patch_type_obj = patch_obj\n",
    "        clean_type_obj = clean_obj\n",
    "    else:\n",
    "        patch_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[patch_category], mt.tokenizer)\n",
    "                - KeyedSet([patch_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "        clean_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[clean_category], mt.tokenizer)\n",
    "                - KeyedSet([clean_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "\n",
    "    patch_must_have_options = [patch_obj, clean_type_obj]\n",
    "    clean_must_have_options = [clean_obj, patch_type_obj]\n",
    "\n",
    "    logger.info(f\"{patch_must_have_options=}\")\n",
    "    logger.info(f\"{clean_must_have_options=}\")\n",
    "    logger.info(f\"{clean_type_obj=}\")\n",
    "    logger.info(f\"{patch_type_obj=}\")\n",
    "\n",
    "    prompt_categories = {patch_category, clean_category}\n",
    "    patch_distractors = _draw_distractors(\n",
    "        task,\n",
    "        prompt_categories,\n",
    "        patch_must_have_options,\n",
    "        k=n_distractors - (len(patch_must_have_options)) + 1,\n",
    "    )\n",
    "\n",
    "    patch_options = patch_must_have_options + patch_distractors\n",
    "    random.shuffle(patch_options)\n",
    "    patch_obj_idx = patch_options.index(patch_obj)\n",
    "    logger.info(f\"{patch_obj_idx=} | {patch_options}\")\n",
    "\n",
    "    if not distinct_options:\n",
    "        clean_options = copy.deepcopy(patch_options)\n",
    "        if shuffle_clean_options:\n",
    "            # Useful for the pointer experiments\n",
    "            while (\n",
    "                clean_options.index(clean_obj) == patch_obj_idx\n",
    "                or clean_options.index(patch_type_obj) == patch_obj_idx\n",
    "            ):\n",
    "                random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    else:\n",
    "        # Honour the `n_distractors` argument here as well (this branch used to\n",
    "        # read the module-level N_DISTRACTORS and ignore the caller's value).\n",
    "        clean_distractors = _draw_distractors(\n",
    "            task,\n",
    "            prompt_categories,\n",
    "            clean_must_have_options,\n",
    "            k=n_distractors - (len(clean_must_have_options)) + 1,\n",
    "        )\n",
    "        clean_options = clean_must_have_options + clean_distractors\n",
    "        random.shuffle(clean_options)\n",
    "        # Keep the clean answer away from the patch answer's position.\n",
    "        while clean_options.index(clean_obj) == patch_obj_idx:\n",
    "            random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    logger.info(f\"{clean_obj_idx=} | {clean_options}\")\n",
    "\n",
    "    logger.debug(f\"{type(task)=}\")\n",
    "    if isinstance(task, SelectOrderTask):\n",
    "        # Order task: track whichever option sits at the *other* prompt's answer index.\n",
    "        patch_metadata = {\n",
    "            \"track_type_obj_idx\": clean_obj_idx,\n",
    "            \"track_type_obj\": patch_options[clean_obj_idx],\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_options[clean_obj_idx], mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_type_obj_idx\": patch_obj_idx,\n",
    "            \"track_type_obj\": clean_options[patch_obj_idx],\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_options[patch_obj_idx], mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    elif isinstance(task, SelectOneTask):\n",
    "        # Select-one task: track the option that belongs to the *other* prompt's category.\n",
    "        patch_metadata = {\n",
    "            \"track_category\": clean_category,\n",
    "            \"track_type_obj\": clean_type_obj,\n",
    "            \"track_type_obj_idx\": patch_options.index(clean_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_category\": patch_category,\n",
    "            \"track_type_obj\": patch_type_obj,\n",
    "            \"track_type_obj_idx\": clean_options.index(patch_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Unsupported task type: {type(task)}\")\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        subj=patch_subj,\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_obj_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_category,\n",
    "        metadata=patch_metadata,\n",
    "        prompt_template=task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=patch_option_style,\n",
    "    )\n",
    "    clean_sample = SelectionSample(\n",
    "        subj=clean_subj,\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_obj_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_category,\n",
    "        metadata=clean_metadata,\n",
    "        prompt_template=task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=clean_option_style,\n",
    "    )\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, clean_sample]\n",
    "        if distinct_options:\n",
    "            # Also check the patch question asked over the clean option list,\n",
    "            # whose expected answer is the clean sample's tracked option.\n",
    "            clean_sample_2 = copy.deepcopy(patch_sample)\n",
    "            clean_sample_2.options = clean_options\n",
    "            clean_sample_2.obj = clean_sample.metadata[\"track_type_obj\"]\n",
    "            clean_sample_2.obj_idx = clean_sample.metadata[\"track_type_obj_idx\"]\n",
    "            clean_sample_2.ans_token_id = clean_sample.metadata[\n",
    "                \"track_type_obj_token_id\"\n",
    "            ]\n",
    "            test_samples.append(clean_sample_2)\n",
    "\n",
    "        for sample in test_samples:\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            logger.info(sample.prompt())\n",
    "            logger.info(\n",
    "                f\"{sample.subj} | {sample.category} -> {sample.obj} | pred={[str(p) for p in predictions]}\"\n",
    "            )\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {sample.ans_token_id}[\"{mt.tokenizer.decode(sample.ans_token_id)}\"]'\n",
    "                )\n",
    "                # Draw a fresh pair with the same settings (categories stay pinned).\n",
    "                return get_counterfactual_samples_within_task(\n",
    "                    task=task,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    shuffle_clean_options=shuffle_clean_options,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    clean_option_style=clean_option_style,\n",
    "                    patch_option_style=patch_option_style,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    distinct_options=distinct_options,\n",
    "                    n_distractors=n_distractors,\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "425f6285",
   "metadata": {},
   "outputs": [],
   "source": [
    "patch_sample, clean_sample = get_counterfactual_samples_within_task(\n",
    "    # patch_category=\"politician\",\n",
    "    # clean_category=\"actor\",\n",
    "    task=select_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    distinct_options=True,\n",
    "    n_distractors = N_DISTRACTORS,\n",
    ")\n",
    "\n",
    "clean_sample.metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41815cd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "tok_id = get_first_token_id(\"1\", mt.tokenizer, prefix=\"\")\n",
    "mt.tokenizer.decode(tok_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d35a26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written samples for probing the heads on other selection-style prompts.\n",
    "# NOTE: these assignments replace the patch_sample / clean_sample returned by\n",
    "# get_counterfactual_samples_within_task in the earlier cell.\n",
    "\n",
    "# NOTE(review): 43 is not a multiple of 3 (57 and 39 are); confirm that obj,\n",
    "# obj_idx and answer='Yes' are the intended labels for this prompt.\n",
    "patch_sample = SelectionSample(\n",
    "    subj=\"#\",\n",
    "    obj=\"43\",\n",
    "    options=[\"43\", \"57\", \"55\", \"62\", \"39\"],\n",
    "    prompt_template=\"Which of the following numbers is a multiple of 3?\\n<_options_>\\nAnswer:\",\n",
    "    obj_idx=0,\n",
    "    answer=\"Yes\",\n",
    ")\n",
    "\n",
    "# Profession-matching prompt; only used by the head-pattern check in the next cell.\n",
    "clean_sample_2 = SelectionSample(\n",
    "    subj=\"#\",\n",
    "    obj=\"Michael Jordan\",\n",
    "    options=[\n",
    "        \"Michael Jordan\",\n",
    "        \"Serena Williams\",\n",
    "        \"Justin Trudeau\",\n",
    "        \"Mike Tyson\",\n",
    "        \"Carl Sagan\",\n",
    "        \"Tom Cruise\",\n",
    "    ],\n",
    "    prompt_template=\"Who among these people shares the same profession as Roger Federer?\\n<_options_>\\nAnswer:\",\n",
    "    obj_idx=0,\n",
    "    answer=\"Yes\",\n",
    ")\n",
    "\n",
    "\n",
    "# Counting prompt: the answer token is looked up without the leading-space prefix.\n",
    "# NOTE(review): the list holds three fruits (Peach, Banana, Mango) while answer is\n",
    "# '2' and the tracked token below is '1'; confirm these are the intended labels.\n",
    "clean_sample = SelectionSample(\n",
    "    subj=\"#\",\n",
    "    obj=\"Bus\",\n",
    "    options=[\"Bus\", \"Peach\", \"Watch\", \"Car\", \"Banana\", \"Rabbit\", \"Mango\"],\n",
    "    prompt_template=\"<_options_>\\nCount the number of fruits in the list above.\\nAnswer: \",\n",
    "    obj_idx=0,\n",
    "    answer=\"2\",\n",
    "    ans_token_id=get_first_token_id(\"2\", mt.tokenizer, prefix=\"\"),\n",
    ")\n",
    "\n",
    "# Token tracked alongside the answer when validating the intervention.\n",
    "clean_sample.metadata = {\n",
    "    \"track_type_obj\": \"1\",\n",
    "    \"track_type_obj_token_id\": get_first_token_id(\"1\", mt.tokenizer, prefix=\"\"),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "510772fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "# patch_sample.options[patch_sample.obj_idx] = \"Screw\"\n",
    "# patch_sample.options[patch_sample.obj_idx] = patch_sample.obj\n",
    "patch_sample.default_option_style = \"single_line\"\n",
    "clean_sample.default_option_style = \"numbered\"\n",
    "\n",
    "for sample in [clean_sample_2, patch_sample, clean_sample]:\n",
    "# for sample in [order_sample_1, order_sample_2]:\n",
    "    print(sample.prompt(), \">>\", sample.obj)\n",
    "    attn_pattern = verify_head_patterns(\n",
    "        prompt=sample.prompt(),\n",
    "        options=sample.options,\n",
    "        pivot=sample.subj,\n",
    "        mt=mt,\n",
    "        heads=HEADS,\n",
    "        generate_full_answer=True,\n",
    "        query_index=-1\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f414b376",
   "metadata": {},
   "source": [
    "## Validation of the patching effect"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00698133",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.functional import free_gpu_cache\n",
    "# free_gpu_cache()\n",
    "# validation_set = []\n",
    "# validation_limit = 256\n",
    "\n",
    "# while len(validation_set) < validation_limit:\n",
    "#     patch, clean = get_counterfactual_samples_within_task(\n",
    "#         filter_by_lm_prediction=True,\n",
    "#         prompt_template_idx=prompt_template_idx,\n",
    "#         distinct_options=True,\n",
    "#         n_distractors=N_DISTRACTORS,\n",
    "#     )\n",
    "#     validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d041d50",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "# Pair under test: the clean prompt receives query projections from the patch prompt.\n",
    "clean, patch = clean_sample, patch_sample\n",
    "\n",
    "for shown_sample in (clean, patch):\n",
    "    print(shown_sample.prompt(), \">>\", shown_sample.obj)\n",
    "\n",
    "validation_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    # heads=HEADS,\n",
    "    heads=heads_selected,\n",
    "    query_indices={-3: -3, -2: -2, -1: -1},\n",
    "    verify_head_behavior_on=-1,\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    # amplification_scale=1.5,\n",
    "    must_track_tokens=[clean.metadata[\"track_type_obj_token_id\"], clean.ans_token_id],\n",
    ")\n",
    "\n",
    "# Token ids whose rank/logit are compared before and after the intervention.\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "\n",
    "def _summarize_track(track):\n",
    "    \"\"\"Collect rank and logit of the clean and target tokens from one tracking dict.\"\"\"\n",
    "    summary = {}\n",
    "    for label, token_id in ((\"clean\", clean_obj), (\"target\", target_obj)):\n",
    "        tracked = track[token_id]\n",
    "        summary[f\"{label}_rank\"] = tracked[0]\n",
    "        summary[f\"{label}_logit\"] = tracked[1].logit\n",
    "    return summary\n",
    "\n",
    "\n",
    "before_intervention = _summarize_track(validation_result[\"clean_track\"])\n",
    "after_intervention = _summarize_track(validation_result[\"int_track\"])\n",
    "\n",
    "# Change (after - before) for every tracked quantity.\n",
    "deltas = {\n",
    "    key: after_intervention[key] - before_intervention[key]\n",
    "    for key in before_intervention\n",
    "}\n",
    "clean_rank_delta = deltas[\"clean_rank\"]\n",
    "target_rank_delta = deltas[\"target_rank\"]\n",
    "clean_logit_delta = deltas[\"clean_logit\"]\n",
    "target_logit_delta = deltas[\"target_logit\"]\n",
    "\n",
    "logger.info(f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \")\n",
    "logger.info(f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \")\n",
    "logger.info(f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \")\n",
    "logger.info(f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5e1292f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import cache_q_projections\n",
    "patch_tokenized = prepare_input(\n",
    "    tokenizer=mt,\n",
    "    prompts=patch_sample.prompt()\n",
    ")\n",
    "clean_tokenized = prepare_input(\n",
    "    tokenizer=mt,\n",
    "    prompts=clean_sample.prompt()\n",
    ")\n",
    "\n",
    "query_indices = {-3: -3, -2: -2, -1: -1}\n",
    "\n",
    "query_locations = [\n",
    "    (layer_idx, head_idx, patch_query_idx)\n",
    "    for layer_idx, head_idx in heads_selected\n",
    "    for patch_query_idx in query_indices.keys()\n",
    "]\n",
    "\n",
    "cached_q_states, patch_output = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=patch_tokenized,\n",
    "    query_locations=query_locations,\n",
    "    return_output=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f81b8b0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import get_module_nnsight, interpret_logits\n",
    "\n",
    "# Group the (head, query position) pairs by layer so that each q_proj module is\n",
    "# visited once inside the trace.\n",
    "layer_to_hq = {}\n",
    "for layer_idx, head_idx, query_idx in query_locations:\n",
    "    layer_to_hq.setdefault(layer_idx, []).append((head_idx, query_idx))\n",
    "\n",
    "q_projections = {}\n",
    "batch_size = clean_tokenized.input_ids.shape[0]\n",
    "seq_len = clean_tokenized.input_ids.shape[1]\n",
    "n_heads = mt.config.num_attention_heads\n",
    "# Prefer the explicit per-head width from the config: some architectures in the\n",
    "# model list (e.g. Gemma, Qwen3) set `head_dim` independently of\n",
    "# hidden_size // num_attention_heads, which would make the view below fail.\n",
    "head_dim = getattr(mt.config, \"head_dim\", None) or mt.n_embd // n_heads\n",
    "with mt.trace(clean_tokenized) as tracer:\n",
    "    for layer_idx, query_locs in layer_to_hq.items():\n",
    "        q_proj_name = mt.attn_module_name_format.format(layer_idx) + \".q_proj\"\n",
    "        q_proj_module = get_module_nnsight(mt, q_proj_name)\n",
    "        tracer.log(q_proj_module.output.shape)\n",
    "        # Reshape to (batch, head, position, head_dim) so a single head/position\n",
    "        # slot can be overwritten with the query state cached from the patch run.\n",
    "        q_proj_out = q_proj_module.output.view(\n",
    "            batch_size, seq_len, n_heads, head_dim\n",
    "        ).transpose(1, 2)\n",
    "        for head_idx, token_idx in query_locs:\n",
    "            q_proj_out[:, head_idx, token_idx, :] = cached_q_states[(layer_idx, head_idx, token_idx)]\n",
    "\n",
    "    output = mt.output.save()\n",
    "\n",
    "# Next-token logits at the last position of the patched clean run.\n",
    "logits = output.logits[:, -1, :].squeeze()\n",
    "pred, track = interpret_logits(\n",
    "    logits=logits,\n",
    "    tokenizer=mt.tokenizer,\n",
    "    interested_tokens=[\n",
    "        clean.ans_token_id,\n",
    "        clean.metadata[\"track_type_obj_token_id\"],\n",
    "    ],\n",
    ")\n",
    "pred, track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66cf201",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate from the clean prompt while overwriting the selected heads' query\n",
    "# projections with the states cached from the patch prompt.\n",
    "# Relies on layer_to_hq, batch_size, n_heads and head_dim from the previous cell.\n",
    "with mt.generate(\n",
    "    clean_tokenized,\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    output_scores=True,\n",
    "    return_dict_in_generate=True,\n",
    "    use_cache=False\n",
    ") as tracer:\n",
    "    mt.all()  # patch at all token positions\n",
    "    for layer_idx, query_locs in layer_to_hq.items():\n",
    "        q_proj_name = mt.attn_module_name_format.format(layer_idx) + \".q_proj\"\n",
    "        q_proj_module = get_module_nnsight(mt, q_proj_name)\n",
    "        # The sequence length is left as -1 here instead of the fixed prompt length.\n",
    "        # NOTE(review): with use_cache=False the sequence presumably grows at every\n",
    "        # generation step, so the negative token_idx values would address the newest\n",
    "        # tokens and not the last prompt tokens the states were cached from; confirm\n",
    "        # this is intended.\n",
    "        q_proj_out = q_proj_module.output.view(\n",
    "            batch_size, -1, n_heads, head_dim\n",
    "        ).transpose(1, 2)\n",
    "        for head_idx, token_idx in query_locs:\n",
    "            q_proj_out[:, head_idx, token_idx, :] = cached_q_states[\n",
    "                (layer_idx, head_idx, token_idx)\n",
    "            ]\n",
    "    # Save the generation output (sequences and per-step scores) for the next cells.\n",
    "    gen_out = mt.generator.output.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceb51982",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = clean_tokenized.input_ids.shape[1]\n",
    "mt.tokenizer.batch_decode(\n",
    "    gen_out.sequences[:, start:],\n",
    "    skip_special_tokens=True,\n",
    "    clean_up_tokenization_spaces=True,\n",
    ")[0].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fece1ee1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first generated step's scores for the clean answer and tracked token.\n",
    "# tokenizer/logits are passed by keyword, matching the earlier interpret_logits\n",
    "# call, so the two arguments cannot be bound in the wrong order.\n",
    "# NOTE(review): gen_out.scores[0] keeps its batch dimension, whereas the earlier\n",
    "# call passed squeezed 1-D logits; confirm interpret_logits accepts both shapes.\n",
    "interpret_logits(\n",
    "    tokenizer=mt.tokenizer,\n",
    "    logits=gen_out.scores[0],\n",
    "    interested_tokens=[\n",
    "        clean.ans_token_id,\n",
    "        clean.metadata[\"track_type_obj_token_id\"],\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0c415b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
