{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e7c525",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited modules (e.g. the project's src.* code) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2cd4155",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import logging\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the repository root importable as `src` (path is relative to the working\n",
    "# directory, normally this notebook's folder).\n",
    "sys.path.append(\"../../\")\n",
    "\n",
    "##################################################################\n",
    "# Set before torch / tokenizers initialize so the settings take effect.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "# torch.cuda.get_device_name() raises when no GPU is visible, so only query it when CUDA is available.\n",
    "gpu_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {gpu_name=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d3b2420",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Model under study: keep exactly one `model_key` assignment uncommented.\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# Optional manual device map; only consumed by the commented-out\n",
    "# `device_map=device_map` argument in the model-loading cell.\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b62bcde",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "# Load the selected checkpoint in bfloat16; device_map=\"auto\" (presumably forwarded to\n",
    "# transformers) spreads it across the GPUs made visible in the setup cell.\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    # Eager attention, presumably so attention weights are materialized for the\n",
    "    # head-pattern inspection below (TODO confirm).\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f77abadf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, CountingTask\n",
    "\n",
    "#################################################################################\n",
    "# Experiment configuration (CountingTask is imported as an alternative TASK_CLS).\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Dataset file under <DATA>/selection/ (alternatives: \"profession.json\", \"nationality.json\").\n",
    "source_task = TASK_CLS.load(\n",
    "    path=os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    ")\n",
    "\n",
    "print(source_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da783c78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one random sample from the chosen category without requiring a correct model answer.\n",
    "sample = source_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    category=\"fruit\",  # also used: \"actor\", \"Brazil\"\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=OPTION_STYLE,\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "print(sample.prompt(), \">>\", '\"' + mt.tokenizer.decode([sample.ans_token_id]) + '\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cd154f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-render the same sample with bulleted options for comparison. NOTE: this mutates\n",
    "# `sample`, so later `sample.prompt()` calls without option_style presumably use bullets too.\n",
    "sample.default_option_style = \"bulleted\"\n",
    "print(sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([sample.ans_token_id])}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8eb708dc",
   "metadata": {},
   "source": [
    "## Loading the heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42843f91",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Other result locations used in earlier runs (under DEFAULT_RESULTS_DIR, same <model>/<task_name>):\n",
    "#   selection/optimized_backup_heads/<model>/<task_name>.npz\n",
    "#   test_opt_code/<model>/distinct_options/<task_name>/legacy/epoch_10.npz\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    source_task.task_name,\n",
    "    # \"legacy\",\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "\n",
    "# allow_pickle=True lets np.load unpickle object arrays, which can execute arbitrary\n",
    "# code: only load result files produced by our own optimization runs.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(optimization_results[\"losses\"])\n",
    "ax.set(title=\"Head-mask optimization loss\", xlabel=\"step\", ylabel=\"loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdcc867a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layers at or beyond this index are dropped from the optimized head set.\n",
    "# NOTE(review): carried-over magic number; document why these layers are excluded.\n",
    "HEAD_LAYER_CUTOFF = 52\n",
    "\n",
    "# Mask rows index layers, columns index heads (see the (layer_idx, head_idx) unpacking below).\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[HEAD_LAYER_CUTOFF:, :] = 0.0\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "im = ax.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "ax.set(title=\"Optimized head mask\", xlabel=\"layer\", ylabel=\"head\")\n",
    "fig.colorbar(im, ax=ax)\n",
    "\n",
    "# Heads whose mask value exceeds 0.5, as (layer_idx, head_idx) tuples.\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "HEADS = optimized_heads\n",
    "\n",
    "# Spot-check a head of interest (HEADS is the same list as optimized_heads).\n",
    "(35, 19) in optimized_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb620ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    verify_head_patterns,\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    ")\n",
    "\n",
    "# Inspect the optimized heads' attention on the single-line rendering of `sample`\n",
    "# (pass HEADS or e.g. [(35, 19)] as `heads` to look at a subset).\n",
    "single_line_prompt_kwargs = dict(option_style=\"single_line\")\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt=sample.prompt(**single_line_prompt_kwargs),\n",
    "    options=sample.options,\n",
    "    mt=mt,\n",
    "    heads=optimized_heads,\n",
    "    start_from=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "506b8618",
   "metadata": {},
   "source": [
    "## Checking the effect of formatting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85ae8aa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_counterfactual_samples_within_task\n",
    "\n",
    "# Same task and categories for both samples; only the option layout differs below.\n",
    "clean_sample, patch_sample = get_counterfactual_samples_within_task(\n",
    "    task=source_task,\n",
    "    mt=mt,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    distinct_options=True,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "\n",
    "clean_sample.default_option_style = \"bulleted\"\n",
    "patch_sample.default_option_style = \"single_line\"\n",
    "\n",
    "# Guard against editing the two styles above to be identical.\n",
    "assert clean_sample.default_option_style != patch_sample.default_option_style\n",
    "\n",
    "print(\n",
    "    \"CLEAN:\",\n",
    "    clean_sample.prompt(),\n",
    "    \">>\",\n",
    "    f'\"{mt.tokenizer.decode([clean_sample.ans_token_id])}\"',\n",
    ")\n",
    "print(\n",
    "    \"PATCH:\",\n",
    "    patch_sample.prompt(),\n",
    "    \">>\",\n",
    "    f'\"{mt.tokenizer.decode([patch_sample.ans_token_id])}\"',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62f2de39",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "326d58d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Categories available in the loaded dataset (valid values for the *category arguments above).\n",
    "source_task.categories"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d17fff0b",
   "metadata": {},
   "source": [
    "## Validating the Optimized Heads on Other Reduce Tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5428ff0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from typing import Callable, Optional\n",
    "\n",
    "from src.models import ModelandTokenizer\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "from src.selection.data import YesNoTask, SelectFirstTask, CountingTask, SelectLastTask\n",
    "from src.selection.data import SelectOneTask\n",
    "from src.selection.data import get_options_for_answer\n",
    "from src.selection.utils import verify_correct_option\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "# Recursion depth after which get_counterfactual_samples_across_tasks stops re-sampling.\n",
    "MAX_CF_RETRIES = 20\n",
    "\n",
    "\n",
    "def get_task_specific_kwargs(task, distinct_options=True):\n",
    "    \"\"\"Draw random, task-specific keyword arguments for a counterfactual sampler.\n",
    "\n",
    "    Args:\n",
    "        task: Task instance; its class decides which option / distractor counts are drawn.\n",
    "        distinct_options: Forwarded as ``distinct_options`` for every handled task type\n",
    "            except ``YesNoTask``.\n",
    "\n",
    "    Returns:\n",
    "        dict of kwargs to splat into the sampler; empty for unrecognized task types.\n",
    "    \"\"\"\n",
    "    kwargs = {}\n",
    "    if isinstance(task, CountingTask):\n",
    "        kwargs[\"clean_n_options\"] = random.choice(range(4, 7))\n",
    "        kwargs[\"patch_n_options\"] = random.choice(range(4, 7))\n",
    "        kwargs[\"distinct_options\"] = distinct_options\n",
    "    elif isinstance(task, YesNoTask):\n",
    "        kwargs[\"clean_n_options\"] = random.choice(range(3, 6))\n",
    "        kwargs[\"patch_n_options\"] = random.choice(range(3, 6))\n",
    "        # No distinct options for yes/no task\n",
    "    elif isinstance(task, (SelectFirstTask, SelectLastTask)):\n",
    "        #! this has to come before SelectOneTask since SelectFirstTask is a subclass of SelectOneTask\n",
    "        kwargs[\"distinct_options\"] = distinct_options\n",
    "        kwargs[\"n_distractors\"] = random.choice(range(3, 6))\n",
    "    elif isinstance(task, SelectOneTask):\n",
    "        kwargs[\"distinct_options\"] = distinct_options\n",
    "        kwargs[\"patch_n_distractors\"] = random.choice(range(2, 7))\n",
    "        kwargs[\"clean_n_distractors\"] = random.choice(range(2, 7))\n",
    "    return kwargs\n",
    "\n",
    "\n",
    "def _add_attention_sink_prefix(sample):\n",
    "    \"\"\"Prefix ``sample.prompt_template`` with \"# \" unless it already starts with \"#\".\"\"\"\n",
    "    if not sample.prompt_template.startswith(\"#\"):\n",
    "        sample.prompt_template = \"# \" + sample.prompt_template\n",
    "    return sample\n",
    "\n",
    "\n",
    "def _lm_predicts_all(mt, samples):\n",
    "    \"\"\"Check (via verify_correct_option) that ``mt`` predicts each sample's ``ans_token_id``.\n",
    "\n",
    "    Stores the tokenized prompt in ``sample.metadata[\"tokenized\"]`` for every sample\n",
    "    checked, prints each prompt, and logs the first mismatch.\n",
    "\n",
    "    Returns:\n",
    "        True if every sample is predicted correctly, False at the first mismatch.\n",
    "    \"\"\"\n",
    "    for sample in samples:\n",
    "        tokenized_inputs = prepare_input(prompts=sample.prompt(), tokenizer=mt.tokenizer)\n",
    "        sample.metadata[\"tokenized\"] = tokenized_inputs.data\n",
    "\n",
    "        print(\"-\" * 80)\n",
    "        print(sample.prompt(), \">>\", mt.tokenizer.decode(sample.ans_token_id))\n",
    "\n",
    "        is_correct, predictions, track_options = verify_correct_option(\n",
    "            mt=mt,\n",
    "            input=tokenized_inputs,\n",
    "            target=sample.ans_token_id,\n",
    "            options=get_options_for_answer(sample),\n",
    "        )\n",
    "        if is_correct is False:\n",
    "            first_tracked = next(iter(track_options.values()))\n",
    "            logger.error(\n",
    "                f'Prediction mismatch: {first_tracked}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {sample.ans_token_id}[\"{mt.tokenizer.decode(sample.ans_token_id)}\"]'\n",
    "            )\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "def get_counterfactual_samples_across_tasks(\n",
    "    mt: ModelandTokenizer,\n",
    "    patch_task,\n",
    "    clean_task,\n",
    "    patch_category: str | None = None,\n",
    "    clean_category: str | None = None,\n",
    "    patch_prompt_template_idx: int = 3,\n",
    "    clean_prompt_template_idx: int = 3,\n",
    "    clean_transform: Optional[Callable] = None,\n",
    "    patch_transform: Optional[Callable] = None,\n",
    "    filter_by_lm_prediction=False,\n",
    "    retry_count=0,\n",
    "):\n",
    "    \"\"\"Build a (patch, clean) sample pair drawn from two different tasks.\n",
    "\n",
    "    The patch sample comes from ``patch_task``'s counterfactual sampler and the clean\n",
    "    sample from ``clean_task``'s; both samplers receive the same (patch, clean)\n",
    "    category pair. Optional transforms run afterwards, and for Qwen models both prompt\n",
    "    templates get a \"# \" prefix (attention sink).\n",
    "\n",
    "    Args:\n",
    "        mt: Model and tokenizer.\n",
    "        patch_task / clean_task: Tasks whose samplers produce the patch / clean sample.\n",
    "            Every clean-task category must also be a patch-task category.\n",
    "        patch_category / clean_category: Fixed categories; drawn at random (and\n",
    "            distinct) from the categories shared by both tasks when None.\n",
    "        patch_prompt_template_idx / clean_prompt_template_idx: Prompt template indices.\n",
    "        patch_transform / clean_transform: Optional ``sample -> sample`` callables.\n",
    "        filter_by_lm_prediction: If True, re-sample (keeping the categories) until the\n",
    "            model answers the patch, clean and \"gold\" samples correctly.\n",
    "        retry_count: Internal recursion counter for re-sampling.\n",
    "\n",
    "    Returns:\n",
    "        (patch_sample, clean_sample)\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the categories are incompatible.\n",
    "        ValueError: after more than MAX_CF_RETRIES failed re-sampling attempts.\n",
    "    \"\"\"\n",
    "    patch_categories = list(patch_task.categories)\n",
    "    clean_categories = list(clean_task.categories)\n",
    "    for category in clean_categories:\n",
    "        assert category in patch_categories, \"Categories must be the same!\"\n",
    "\n",
    "    # Both samplers receive *both* categories, so draw only from categories known to\n",
    "    # both tasks. Ordered lists (not sets) keep the draw reproducible under a fixed seed.\n",
    "    shared_categories = [c for c in patch_categories if c in clean_categories]\n",
    "    patch_category = patch_category or random.choice(shared_categories)\n",
    "    clean_category = clean_category or random.choice(\n",
    "        [c for c in shared_categories if c != patch_category]\n",
    "    )\n",
    "\n",
    "    assert patch_category != clean_category, \"Categories must be different!\"\n",
    "\n",
    "    patch_sample, _ = get_counterfactual_samples_interface[patch_task.task_name](\n",
    "        mt=mt,\n",
    "        task=patch_task,\n",
    "        patch_category=patch_category,\n",
    "        clean_category=clean_category,\n",
    "        prompt_template_idx=patch_prompt_template_idx,\n",
    "        filter_by_lm_prediction=False,\n",
    "        **get_task_specific_kwargs(patch_task, distinct_options=True),\n",
    "    )\n",
    "\n",
    "    _, clean_sample = get_counterfactual_samples_interface[clean_task.task_name](\n",
    "        mt=mt,\n",
    "        task=clean_task,\n",
    "        patch_category=patch_category,\n",
    "        clean_category=clean_category,\n",
    "        prompt_template_idx=clean_prompt_template_idx,\n",
    "        filter_by_lm_prediction=False,\n",
    "        **get_task_specific_kwargs(clean_task, distinct_options=True),\n",
    "    )\n",
    "\n",
    "    if patch_transform is not None:\n",
    "        patch_sample = patch_transform(patch_sample)\n",
    "    if clean_transform is not None:\n",
    "        clean_sample = clean_transform(clean_sample)\n",
    "\n",
    "    if \"qwen\" in mt.name.lower():\n",
    "        # for attention sink\n",
    "        clean_sample = _add_attention_sink_prefix(clean_sample)\n",
    "        patch_sample = _add_attention_sink_prefix(patch_sample)\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        # \"Gold\" sample: the clean sample re-targeted at the tracked category / object that\n",
    "        # the sampler recorded in its metadata. Copied before tokenization metadata is added.\n",
    "        gold_sample = copy.deepcopy(clean_sample)\n",
    "        gold_sample.category = clean_sample.metadata[\"track_category\"]\n",
    "        gold_sample.ans_token_id = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "        if not _lm_predicts_all(mt, [patch_sample, clean_sample, gold_sample]):\n",
    "            # for debugging\n",
    "            if retry_count > MAX_CF_RETRIES:\n",
    "                raise ValueError(f\"Max retries ({retry_count}) exceeded!\")\n",
    "\n",
    "            return get_counterfactual_samples_across_tasks(\n",
    "                mt=mt,\n",
    "                patch_task=patch_task,\n",
    "                clean_task=clean_task,\n",
    "                patch_category=patch_category,\n",
    "                clean_category=clean_category,\n",
    "                patch_prompt_template_idx=patch_prompt_template_idx,\n",
    "                clean_prompt_template_idx=clean_prompt_template_idx,\n",
    "                clean_transform=clean_transform,\n",
    "                patch_transform=patch_transform,\n",
    "                filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                retry_count=retry_count + 1,\n",
    "            )\n",
    "\n",
    "    return patch_sample, clean_sample"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7556f772",
   "metadata": {},
   "source": [
    "### Select One -- MCQ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1776c1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectionSample, SelectOneTask\n",
    "\n",
    "# Same objects dataset as `source_task`; mcq_transform / MCQify_sample convert samples later.\n",
    "# Other datasets under <DATA>/selection/: \"profession.json\", \"nationality.json\", \"landmarks.json\".\n",
    "select_one_mcq = SelectOneTask.load(\n",
    "    path=os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    ")\n",
    "print(select_one_mcq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4affdd7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "from src.selection.utils import get_first_token_id\n",
    "from src.functional import predict_next_token\n",
    "from src.selection.data import MCQify_sample\n",
    "\n",
    "# Draw an LM-filtered sample, convert it to MCQ form, then compare against the model's prediction.\n",
    "test_sample = MCQify_sample(\n",
    "    mt,\n",
    "    select_one_mcq.get_random_sample(\n",
    "        mt=mt,\n",
    "        option_style=OPTION_STYLE,\n",
    "        prompt_template_idx=3,\n",
    "        category=\"fruit\",  # also used: \"actor\", \"United Kingdom\"\n",
    "        filter_by_lm_prediction=True,\n",
    "    ),\n",
    ")\n",
    "print(test_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([test_sample.ans_token_id])}\"')\n",
    "\n",
    "predict_next_token(mt=mt, inputs=test_sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75a70eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mcq_transform(sample):\n",
    "    \"\"\"Convert `sample` to MCQ form and re-point its tracked answer at an option letter.\n",
    "\n",
    "    Uses the notebook-global `mt`. When the metadata holds ``track_type_obj_token_id``,\n",
    "    it is replaced by the first token of \" <letter>\", where the letter is\n",
    "    ``chr(ord(\"a\") + metadata[\"track_type_obj_idx\"])``.\n",
    "    \"\"\"\n",
    "    mcq_sample = MCQify_sample(mt=mt, sample=sample)\n",
    "    if \"track_type_obj_token_id\" not in mcq_sample.metadata:\n",
    "        return mcq_sample\n",
    "    option_letter = chr(ord(\"a\") + mcq_sample.metadata[\"track_type_obj_idx\"])\n",
    "    mcq_sample.metadata[\"track_type_obj_token_id\"] = get_first_token_id(\n",
    "        name=option_letter, tokenizer=mt.tokenizer, prefix=\" \"\n",
    "    )\n",
    "    return mcq_sample\n",
    "\n",
    "\n",
    "# Patch from the plain task, clean from the MCQ-transformed task.\n",
    "patch_sample, clean_sample = get_counterfactual_samples_across_tasks(\n",
    "    mt=mt,\n",
    "    patch_task=source_task,\n",
    "    clean_task=select_one_mcq,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    clean_transform=mcq_transform,\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "\n",
    "print(\"CLEAN:\", clean_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([clean_sample.ans_token_id])}\"')\n",
    "print(\"PATCH:\", patch_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([patch_sample.ans_token_id])}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e5e09b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "from src.selection.data import CounterFactualSamplePair, get_counterfactual_samples_interface\n",
    "from src.functional import free_gpu_cache\n",
    "\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    select_one_mcq.task_name + \"_mcq\",\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "start_from = 1  # number of the first file written (00001.json, ...)\n",
    "\n",
    "counterfactual_sampler = get_counterfactual_samples_interface[select_one_mcq.task_name]\n",
    "\n",
    "# Each call either returns one LM-filtered pair or raises, so exactly `validation_limit`\n",
    "# pairs are generated; each is saved as <file_idx>.json.\n",
    "for file_idx in range(start_from, start_from + validation_limit):\n",
    "    print(f\"sample {file_idx - start_from + 1} / {validation_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_across_tasks(\n",
    "        mt=mt,\n",
    "        patch_task=source_task,\n",
    "        clean_task=select_one_mcq,\n",
    "        clean_transform=mcq_transform,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "    cf_pair = CounterFactualSamplePair(patch_sample=patch, clean_sample=clean)\n",
    "    cf_pair.detensorize()\n",
    "    out_path = os.path.join(validation_samples_save_path, f\"{file_idx:05d}.json\")\n",
    "    with open(out_path, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b15df782",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "# Set to an int for a reproducible subset; None draws a fresh random subset on each run.\n",
    "VALIDATION_SEED = None\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    select_one_mcq.task_name + \"_mcq\",\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# Sort first: os.listdir order is filesystem-dependent, so a seeded shuffle is only\n",
    "# reproducible over a canonical ordering.\n",
    "sample_files = sorted(\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "# Optional instruction prepended to both prompt templates.\n",
    "prefix = \"\"\n",
    "# prefix = \"Recall the nationality of these people:\\n\"\n",
    "# prefix = \"Recall which country these landmarks are located in:\\n\"\n",
    "# prefix = \"Think about how these words sound when you say them aloud:\\n\"\n",
    "\n",
    "# Private RNG: the subset neither consumes nor depends on the global `random` state.\n",
    "random.Random(VALIDATION_SEED).shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "\n",
    "    cf_pair.clean_sample.prompt_template = prefix + cf_pair.clean_sample.prompt_template\n",
    "    cf_pair.patch_sample.prompt_template = prefix + cf_pair.patch_sample.prompt_template\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb073450",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one cached pair: both prompts with their answers, then the clean sample's\n",
    "# tracked-object token id and its decoded text.\n",
    "clean, patch = validation_set[3]\n",
    "print(patch.prompt(), \">>\", mt.tokenizer.decode(patch.ans_token_id))\n",
    "print(clean.prompt(), \">>\", mt.tokenizer.decode(clean.ans_token_id))\n",
    "clean.metadata[\"track_type_obj_token_id\"], mt.tokenizer.decode(clean.metadata[\"track_type_obj_token_id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d68bade0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Presumably releases cached GPU memory before the intervention runs (src.functional helper).\n",
    "free_gpu_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3516b76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Inspect one pair in detail; deep-copied, presumably so the cached entry is never mutated.\n",
    "clean, patch = copy.deepcopy(validation_set[3])\n",
    "\n",
    "val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    heads=optimized_heads,\n",
    "    query_indices={-2: -2, -1: -1},\n",
    "    add_ques_pos_to_query_indices=True,\n",
    "    verify_head_behavior_on=-1,\n",
    "    patch_args=dict(batch_size=len(patch.options), distinct_options=False),\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track is indexed by token id -> (rank, prediction with .logit);\n",
    "# \"clean_track\" is before the intervention, \"int_track\" after it.\n",
    "before_intervention, after_intervention = (\n",
    "    {\n",
    "        \"clean_rank\": val_sample_result[track_key][clean_obj][0],\n",
    "        \"clean_logit\": val_sample_result[track_key][clean_obj][1].logit,\n",
    "        \"target_rank\": val_sample_result[track_key][target_obj][0],\n",
    "        \"target_logit\": val_sample_result[track_key][target_obj][1].logit,\n",
    "    }\n",
    "    for track_key in (\"clean_track\", \"int_track\")\n",
    ")\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f24e33ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run validate_q_proj_ie_on_sample_pair on every cached pair; results are aggregated below.\n",
    "# NOTE(review): unlike the single-pair cell above, pairs are not deep-copied and\n",
    "# `verify_head_behavior_on` is not passed here; confirm both differences are intended.\n",
    "validation_results = []\n",
    "\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=optimized_heads,\n",
    "        query_indices={-2: -2, -1: -1},\n",
    "        add_ques_pos_to_query_indices=True,\n",
    "        patch_args={\n",
    "            \"batch_size\": len(patch_sample.options),\n",
    "            \"distinct_options\": False,\n",
    "        },\n",
    "    )\n",
    "    validation_results.append(val_sample_result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f928f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    # \"clean_track\" feeds the before-list, \"int_track\" (post-intervention) the after-list.\n",
    "    summary_before, summary_after = (\n",
    "        {\n",
    "            \"clean_rank\": intervention_result[track_key][clean_obj][0],\n",
    "            \"clean_logit\": intervention_result[track_key][clean_obj][1].logit,\n",
    "            \"target_rank\": intervention_result[track_key][target_obj][0],\n",
    "            \"target_logit\": intervention_result[track_key][target_obj][1].logit,\n",
    "        }\n",
    "        for track_key in (\"clean_track\", \"int_track\")\n",
    "    )\n",
    "    before_intervention.append(summary_before)\n",
    "    after_intervention.append(summary_after)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4738762f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-pair rank change (after - before) for the clean answer and the target object.\n",
    "clean_rank_delta = np.array(\n",
    "    [a[\"clean_rank\"] - b[\"clean_rank\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "target_rank_delta = np.array(\n",
    "    [a[\"target_rank\"] - b[\"target_rank\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "print(f\"clean_rank_delta: {clean_rank_delta.mean():.4f} ± {clean_rank_delta.std():.4f}\")\n",
    "print(f\"target_rank_delta: {target_rank_delta.mean():.4f} ± {target_rank_delta.std():.4f}\")\n",
    "\n",
    "# Absolute ranks after the intervention.\n",
    "clean_rank_after_intervention = np.array([a[\"clean_rank\"] for a in after_intervention])\n",
    "print(\n",
    "    f\"clean_rank_after_intervention: {clean_rank_after_intervention.mean():.4f} ± {clean_rank_after_intervention.std():.4f}\"\n",
    ")\n",
    "\n",
    "target_rank_after_intervention = np.array([a[\"target_rank\"] for a in after_intervention])\n",
    "print(\n",
    "    f\"target_rank_after_intervention: {target_rank_after_intervention.mean():.4f} ± {target_rank_after_intervention.std():.4f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00d73bab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pair logit change (after - before) for the clean answer and the target object.\n",
    "clean_logit_delta = np.array(\n",
    "    [a[\"clean_logit\"] - b[\"clean_logit\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [a[\"target_logit\"] - b[\"target_logit\"] for b, a in zip(before_intervention, after_intervention)]\n",
    ")\n",
    "print(f\"clean_logit_delta: {clean_logit_delta.mean():.4f} ± {clean_logit_delta.std():.4f}\")\n",
    "print(f\"target_logit_delta: {target_logit_delta.mean():.4f} ± {target_logit_delta.std():.4f}\")\n",
    "\n",
    "# Absolute logits after the intervention.\n",
    "clean_logit_after_intervention = np.array([a[\"clean_logit\"] for a in after_intervention])\n",
    "print(f\"clean_logit_after_intervention: {clean_logit_after_intervention.mean():.4f} ± {clean_logit_after_intervention.std():.4f}\")\n",
    "\n",
    "target_logit_after_intervention = np.array([a[\"target_logit\"] for a in after_intervention])\n",
    "print(f\"target_logit_after_intervention: {target_logit_after_intervention.mean():.4f} ± {target_logit_after_intervention.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49e5685d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of pairs whose target token is ranked first after the intervention.\n",
    "top_1 = sum(1 for post in after_intervention if post[\"target_rank\"] == 1)\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87045d6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual patching accuracy: a pair is a success when the first entry of the\n",
    "# intervened track (`int_track`) is the clean sample's `track_type_obj_token_id`.\n",
    "# NOTE(review): relies on the first key of `int_track` being the top prediction - confirm.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track, clean_track = intervention_result[\"int_track\"], intervention_result[\"clean_track\"]\n",
    "\n",
    "    first_token_id = int_track[list(int_track)[0]][1].token_id\n",
    "    if first_token_id == clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "        continue\n",
    "\n",
    "    failed_cases.append(\n",
    "        dict(\n",
    "            clean_sample=clean_sample,\n",
    "            patch_sample=patch_sample,\n",
    "            int_track=int_track,\n",
    "            clean_track=clean_track,\n",
    "        )\n",
    "    )\n",
    "\n",
    "n_results = len(validation_results)\n",
    "top_1_accuracy = counter_patch_type_top_option / n_results\n",
    "separator = \"=\" * 80\n",
    "print(separator)\n",
    "print(f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{n_results})\")\n",
    "print(separator)\n",
    "print(f\"{len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6bd2b5c",
   "metadata": {},
   "source": [
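    "### SelectFirstTask\n",
    "\n",
    "Cross-task check: patch prompts come from `source_task`, clean prompts from `SelectFirstTask`, and the intervention on `optimized_heads` is evaluated on these pairs."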
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d29f1b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectFirstTask\n",
    "\n",
    "# Load SelectFirstTask from the selection objects dataset (<DEFAULT_DATA_DIR>/selection/objects.json).\n",
    "objects_json_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "select_first_task = SelectFirstTask.load(path=objects_json_path)\n",
    "print(select_first_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0144dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: draw one random \"fruit\" sample (filtered by the LM's prediction) and show it.\n",
    "test_sample = select_first_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    option_style=\"single_line\",\n",
    "    prompt_template_idx=3,\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "test_prompt = test_sample.prompt()\n",
    "test_answer = mt.tokenizer.decode([test_sample.ans_token_id])\n",
    "print(test_prompt, \">>\", f'\"{test_answer}\"')\n",
    "test_sample.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2535336e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One cross-task counterfactual pair: patch prompt from source_task (\"fruit\"),\n",
    "# clean prompt from select_first_task (\"vehicle\").\n",
    "patch_sample, clean_sample = get_counterfactual_samples_across_tasks(\n",
    "    mt=mt,\n",
    "    patch_task=source_task,\n",
    "    clean_task=select_first_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    clean_transform=None,\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "\n",
    "for role, shown_sample in ((\"CLEAN:\", clean_sample), (\"PATCH:\", patch_sample)):\n",
    "    shown_prompt = shown_sample.prompt()\n",
    "    shown_answer = mt.tokenizer.decode([shown_sample.ans_token_id])\n",
    "    print(role, shown_prompt, \">>\", f'\"{shown_answer}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35d92ac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "import random\n",
    "\n",
    "# Generate cross-task counterfactual pairs (patch: source_task, clean: select_first_task)\n",
    "# and write each one to <save path>/<index:05d>.json as soon as it is produced.\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    select_first_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "start_from = 1  # file index assigned to the first pair written by this run\n",
    "\n",
    "for sample_idx in range(len(validation_set), validation_limit):\n",
    "    print(f\"sample {sample_idx + 1} / {validation_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_across_tasks(\n",
    "        mt=mt,\n",
    "        patch_task=source_task,\n",
    "        clean_task=select_first_task,\n",
    "        clean_transform=None,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "    cf_pair = CounterFactualSamplePair(patch_sample=patch, clean_sample=clean)\n",
    "    cf_pair.detensorize()\n",
    "    sample_path = os.path.join(validation_samples_save_path, f\"{sample_idx + start_from:05d}.json\")\n",
    "    with open(sample_path, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5d15ccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "# Fixed seed for choosing/ordering the loaded pairs, so indices such as\n",
    "# validation_set[3] refer to the same pair on every run.\n",
    "sample_shuffle_seed = 42\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    select_first_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sort first so the seeded shuffle\n",
    "# below selects the same files regardless of the filesystem.\n",
    "sample_files = sorted(\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "# Optional instruction prepended to both the clean and the patch prompt template.\n",
    "prefix = \"\"\n",
    "# prefix = \"Recall the nationality of these people:\\n\"\n",
    "# prefix = \"Recall which country these landmarks are located in:\\n\"\n",
    "# prefix = \"Think about how these words sound when you say them aloud:\\n\"\n",
    "\n",
    "# A private RNG keeps the global `random` state untouched.\n",
    "random.Random(sample_shuffle_seed).shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "\n",
    "    cf_pair.clean_sample.prompt_template = prefix + cf_pair.clean_sample.prompt_template\n",
    "    cf_pair.patch_sample.prompt_template = prefix + cf_pair.patch_sample.prompt_template\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24817042",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one loaded pair: each prompt with its answer token, then the clean\n",
    "# sample's `track_type_obj_token_id` (the target token) and its decoded string.\n",
    "clean, patch = validation_set[3]\n",
    "for shown_sample in (patch, clean):\n",
    "    print(shown_sample.prompt(), \">>\", mt.tokenizer.decode(shown_sample.ans_token_id))\n",
    "target_token_id = clean.metadata[\"track_type_obj_token_id\"]\n",
    "target_token_id, mt.tokenizer.decode(target_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "440b7265",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Deep-copy the pair so anything the validation call changes in place does not\n",
    "# leak back into validation_set.\n",
    "clean, patch = copy.deepcopy(validation_set[3])\n",
    "\n",
    "val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    heads=optimized_heads,\n",
    "    query_indices={-2: -2, -1: -1},\n",
    "    add_ques_pos_to_query_indices=True,\n",
    "    verify_head_behavior_on=-1,\n",
    "    patch_args={\"batch_size\": len(patch.options), \"distinct_options\": False},\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track maps a token id to a pair whose [0] is the rank and [1] carries the logit;\n",
    "# \"clean_track\" gives the values before the intervention, \"int_track\" those after it.\n",
    "before_intervention, after_intervention = (\n",
    "    {\n",
    "        \"clean_rank\": val_sample_result[track_key][clean_obj][0],\n",
    "        \"clean_logit\": val_sample_result[track_key][clean_obj][1].logit,\n",
    "        \"target_rank\": val_sample_result[track_key][target_obj][0],\n",
    "        \"target_logit\": val_sample_result[track_key][target_obj][1].logit,\n",
    "    }\n",
    "    for track_key in (\"clean_track\", \"int_track\")\n",
    ")\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7974f1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run the intervention on every (clean, patch) pair; results keep validation_set's order.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    pair_patch_args = {\"batch_size\": len(patch_sample.options), \"distinct_options\": False}\n",
    "    val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=optimized_heads,\n",
    "        query_indices={-2: -2, -1: -1},\n",
    "        add_ques_pos_to_query_indices=True,\n",
    "        patch_args=pair_patch_args,\n",
    "    )\n",
    "    validation_results.append(val_sample_result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56aff59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per validation result: rank and logit of the clean answer token and of the target\n",
    "# token, read from \"clean_track\" (before) and \"int_track\" (after the intervention).\n",
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    for track_key, destination in (\n",
    "        (\"clean_track\", before_intervention),\n",
    "        (\"int_track\", after_intervention),\n",
    "    ):\n",
    "        track = intervention_result[track_key]\n",
    "        destination.append(\n",
    "            {\n",
    "                \"clean_rank\": track[clean_obj][0],\n",
    "                \"clean_logit\": track[clean_obj][1].logit,\n",
    "                \"target_rank\": track[target_obj][0],\n",
    "                \"target_logit\": track[target_obj][1].logit,\n",
    "            }\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3658dde1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Rank statistics over the validation pairs. Deltas are (after - before), so a\n",
    "# negative value means the token moved up the ranking after the intervention.\n",
    "rank_pairs = list(zip(before_intervention, after_intervention))\n",
    "clean_rank_delta = np.array([post[\"clean_rank\"] - pre[\"clean_rank\"] for pre, post in rank_pairs])\n",
    "target_rank_delta = np.array([post[\"target_rank\"] - pre[\"target_rank\"] for pre, post in rank_pairs])\n",
    "\n",
    "# Absolute ranks once the intervention has been applied.\n",
    "clean_rank_after_intervention = np.array([post[\"clean_rank\"] for post in after_intervention])\n",
    "target_rank_after_intervention = np.array([post[\"target_rank\"] for post in after_intervention])\n",
    "\n",
    "for stat_name, stat_values in (\n",
    "    (\"clean_rank_delta\", clean_rank_delta),\n",
    "    (\"target_rank_delta\", target_rank_delta),\n",
    "    (\"clean_rank_after_intervention\", clean_rank_after_intervention),\n",
    "    (\"target_rank_after_intervention\", target_rank_after_intervention),\n",
    "):\n",
    "    print(f\"{stat_name}: {stat_values.mean():.4f} ± {stat_values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11d74673",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logit statistics over the validation pairs; deltas are (after - before).\n",
    "logit_pairs = list(zip(before_intervention, after_intervention))\n",
    "clean_logit_delta = np.array([post[\"clean_logit\"] - pre[\"clean_logit\"] for pre, post in logit_pairs])\n",
    "target_logit_delta = np.array([post[\"target_logit\"] - pre[\"target_logit\"] for pre, post in logit_pairs])\n",
    "\n",
    "# Absolute logits once the intervention has been applied.\n",
    "clean_logit_after_intervention = np.array([post[\"clean_logit\"] for post in after_intervention])\n",
    "target_logit_after_intervention = np.array([post[\"target_logit\"] for post in after_intervention])\n",
    "\n",
    "for stat_name, stat_values in (\n",
    "    (\"clean_logit_delta\", clean_logit_delta),\n",
    "    (\"target_logit_delta\", target_logit_delta),\n",
    "    (\"clean_logit_after_intervention\", clean_logit_after_intervention),\n",
    "    (\"target_logit_after_intervention\", target_logit_after_intervention),\n",
    "):\n",
    "    print(f\"{stat_name}: {stat_values.mean():.4f} ± {stat_values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "521e3bbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of pairs whose target token is ranked first after the intervention.\n",
    "top_1 = sum(1 for post in after_intervention if post[\"target_rank\"] == 1)\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88c92fb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual patching accuracy: a pair is a success when the first entry of the\n",
    "# intervened track (`int_track`) is the clean sample's `track_type_obj_token_id`.\n",
    "# NOTE(review): relies on the first key of `int_track` being the top prediction - confirm.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track, clean_track = intervention_result[\"int_track\"], intervention_result[\"clean_track\"]\n",
    "\n",
    "    first_token_id = int_track[list(int_track)[0]][1].token_id\n",
    "    if first_token_id == clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "        continue\n",
    "\n",
    "    failed_cases.append(\n",
    "        dict(\n",
    "            clean_sample=clean_sample,\n",
    "            patch_sample=patch_sample,\n",
    "            int_track=int_track,\n",
    "            clean_track=clean_track,\n",
    "        )\n",
    "    )\n",
    "\n",
    "n_results = len(validation_results)\n",
    "top_1_accuracy = counter_patch_type_top_option / n_results\n",
    "separator = \"=\" * 80\n",
    "print(separator)\n",
    "print(f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{n_results})\")\n",
    "print(separator)\n",
    "print(f\"{len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fed1d901",
   "metadata": {},
   "source": [
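    "### SelectLastTask\n",
    "\n",
    "Cross-task check: patch prompts come from `source_task`, clean prompts from `SelectLastTask`, and the intervention on `optimized_heads` is evaluated on these pairs."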
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb79abd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectionSample, SelectLastTask\n",
    "\n",
    "# Load SelectLastTask from the selection objects dataset (<DEFAULT_DATA_DIR>/selection/objects.json).\n",
    "select_last_task = SelectLastTask.load(\n",
    "    path=os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    ")\n",
    "# Print the task that was just loaded (not select_first_task) to confirm it.\n",
    "print(select_last_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b0c16d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: draw one random \"fruit\" sample (filtered by the LM's prediction) and show it.\n",
    "test_sample = select_last_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    option_style=\"single_line\",\n",
    "    prompt_template_idx=3,\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "test_prompt = test_sample.prompt()\n",
    "test_answer = mt.tokenizer.decode([test_sample.ans_token_id])\n",
    "print(test_prompt, \">>\", f'\"{test_answer}\"')\n",
    "test_sample.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e331dc7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect what get_task_specific_kwargs returns for select_last_task.\n",
    "select_last_task_kwargs = get_task_specific_kwargs(select_last_task)\n",
    "select_last_task_kwargs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24c95861",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One cross-task counterfactual pair: patch prompt from source_task (\"fruit\"),\n",
    "# clean prompt from select_last_task (\"vehicle\").\n",
    "patch_sample, clean_sample = get_counterfactual_samples_across_tasks(\n",
    "    mt=mt,\n",
    "    patch_task=source_task,\n",
    "    clean_task=select_last_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    clean_transform=None,\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "\n",
    "for role, shown_sample in ((\"CLEAN:\", clean_sample), (\"PATCH:\", patch_sample)):\n",
    "    shown_prompt = shown_sample.prompt()\n",
    "    shown_answer = mt.tokenizer.decode([shown_sample.ans_token_id])\n",
    "    print(role, shown_prompt, \">>\", f'\"{shown_answer}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cdc7153",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "import random\n",
    "\n",
    "# Generate cross-task counterfactual pairs (patch: source_task, clean: select_last_task)\n",
    "# and write each one to <save path>/<index:05d>.json as soon as it is produced.\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    select_last_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "start_from = 1  # file index assigned to the first pair written by this run\n",
    "\n",
    "for sample_idx in range(len(validation_set), validation_limit):\n",
    "    print(f\"sample {sample_idx + 1} / {validation_limit}\")\n",
    "    patch, clean = get_counterfactual_samples_across_tasks(\n",
    "        mt=mt,\n",
    "        patch_task=source_task,\n",
    "        clean_task=select_last_task,\n",
    "        clean_transform=None,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "\n",
    "    cf_pair = CounterFactualSamplePair(patch_sample=patch, clean_sample=clean)\n",
    "    cf_pair.detensorize()\n",
    "    sample_path = os.path.join(validation_samples_save_path, f\"{sample_idx + start_from:05d}.json\")\n",
    "    with open(sample_path, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e82a53c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "validation_set = []\n",
    "validation_limit = 512\n",
    "# Fixed seed for choosing/ordering the loaded pairs, so indices such as\n",
    "# validation_set[3] refer to the same pair on every run.\n",
    "sample_shuffle_seed = 42\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    select_last_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sort first so the seeded shuffle\n",
    "# below selects the same files regardless of the filesystem.\n",
    "sample_files = sorted(\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "# Optional instruction prepended to both the clean and the patch prompt template.\n",
    "prefix = \"\"\n",
    "# prefix = \"Recall the nationality of these people:\\n\"\n",
    "# prefix = \"Recall which country these landmarks are located in:\\n\"\n",
    "# prefix = \"Think about how these words sound when you say them aloud:\\n\"\n",
    "\n",
    "# A private RNG keeps the global `random` state untouched.\n",
    "random.Random(sample_shuffle_seed).shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "\n",
    "    cf_pair.clean_sample.prompt_template = prefix + cf_pair.clean_sample.prompt_template\n",
    "    cf_pair.patch_sample.prompt_template = prefix + cf_pair.patch_sample.prompt_template\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff4acca9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one loaded pair: each prompt with its answer token, then the clean\n",
    "# sample's `track_type_obj_token_id` (the target token) and its decoded string.\n",
    "clean, patch = validation_set[3]\n",
    "for shown_sample in (patch, clean):\n",
    "    print(shown_sample.prompt(), \">>\", mt.tokenizer.decode(shown_sample.ans_token_id))\n",
    "target_token_id = clean.metadata[\"track_type_obj_token_id\"]\n",
    "target_token_id, mt.tokenizer.decode(target_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e65f8bdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Deep-copy the pair so anything the validation call changes in place does not\n",
    "# leak back into validation_set.\n",
    "clean, patch = copy.deepcopy(validation_set[3])\n",
    "\n",
    "val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    heads=optimized_heads,\n",
    "    query_indices={-2: -2, -1: -1},\n",
    "    add_ques_pos_to_query_indices=True,\n",
    "    verify_head_behavior_on=-1,\n",
    "    patch_args={\"batch_size\": len(patch.options), \"distinct_options\": False},\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track maps a token id to a pair whose [0] is the rank and [1] carries the logit;\n",
    "# \"clean_track\" gives the values before the intervention, \"int_track\" those after it.\n",
    "before_intervention, after_intervention = (\n",
    "    {\n",
    "        \"clean_rank\": val_sample_result[track_key][clean_obj][0],\n",
    "        \"clean_logit\": val_sample_result[track_key][clean_obj][1].logit,\n",
    "        \"target_rank\": val_sample_result[track_key][target_obj][0],\n",
    "        \"target_logit\": val_sample_result[track_key][target_obj][1].logit,\n",
    "    }\n",
    "    for track_key in (\"clean_track\", \"int_track\")\n",
    ")\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "121e4d2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run the intervention on every (clean, patch) pair; results keep validation_set's order.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    pair_patch_args = {\"batch_size\": len(patch_sample.options), \"distinct_options\": False}\n",
    "    val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=optimized_heads,\n",
    "        query_indices={-2: -2, -1: -1},\n",
    "        add_ques_pos_to_query_indices=True,\n",
    "        patch_args=pair_patch_args,\n",
    "    )\n",
    "    validation_results.append(val_sample_result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acb7d5fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per validation result: rank and logit of the clean answer token and of the target\n",
    "# token, read from \"clean_track\" (before) and \"int_track\" (after the intervention).\n",
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    for track_key, destination in (\n",
    "        (\"clean_track\", before_intervention),\n",
    "        (\"int_track\", after_intervention),\n",
    "    ):\n",
    "        track = intervention_result[track_key]\n",
    "        destination.append(\n",
    "            {\n",
    "                \"clean_rank\": track[clean_obj][0],\n",
    "                \"clean_logit\": track[clean_obj][1].logit,\n",
    "                \"target_rank\": track[target_obj][0],\n",
    "                \"target_logit\": track[target_obj][1].logit,\n",
    "            }\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81162f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Rank statistics over the validation pairs. Deltas are (after - before), so a\n",
    "# negative value means the token moved up the ranking after the intervention.\n",
    "rank_pairs = list(zip(before_intervention, after_intervention))\n",
    "clean_rank_delta = np.array([post[\"clean_rank\"] - pre[\"clean_rank\"] for pre, post in rank_pairs])\n",
    "target_rank_delta = np.array([post[\"target_rank\"] - pre[\"target_rank\"] for pre, post in rank_pairs])\n",
    "\n",
    "# Absolute ranks once the intervention has been applied.\n",
    "clean_rank_after_intervention = np.array([post[\"clean_rank\"] for post in after_intervention])\n",
    "target_rank_after_intervention = np.array([post[\"target_rank\"] for post in after_intervention])\n",
    "\n",
    "for stat_name, stat_values in (\n",
    "    (\"clean_rank_delta\", clean_rank_delta),\n",
    "    (\"target_rank_delta\", target_rank_delta),\n",
    "    (\"clean_rank_after_intervention\", clean_rank_after_intervention),\n",
    "    (\"target_rank_after_intervention\", target_rank_after_intervention),\n",
    "):\n",
    "    print(f\"{stat_name}: {stat_values.mean():.4f} ± {stat_values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91d5bc2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample logit change (after - before) for the clean answer and the tracked target.\n",
    "clean_logit_delta = np.array(\n",
    "    [\n",
    "        post[\"clean_logit\"] - pre[\"clean_logit\"]\n",
    "        for pre, post in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [\n",
    "        post[\"target_logit\"] - pre[\"target_logit\"]\n",
    "        for pre, post in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_logit_delta: {clean_logit_delta.mean():.4f} ± {clean_logit_delta.std():.4f}\")\n",
    "print(f\"target_logit_delta: {target_logit_delta.mean():.4f} ± {target_logit_delta.std():.4f}\")\n",
    "\n",
    "# Absolute logits after the intervention\n",
    "clean_logit_after_intervention = np.array([post[\"clean_logit\"] for post in after_intervention])\n",
    "print(f\"clean_logit_after_intervention: {clean_logit_after_intervention.mean():.4f} ± {clean_logit_after_intervention.std():.4f}\")\n",
    "\n",
    "target_logit_after_intervention = np.array([post[\"target_logit\"] for post in after_intervention])\n",
    "print(f\"target_logit_after_intervention: {target_logit_after_intervention.mean():.4f} ± {target_logit_after_intervention.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0a4cda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of samples whose tracked target token is ranked first after the intervention\n",
    "top_1 = sum(1 for post in after_intervention if post[\"target_rank\"] == 1)\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ad7cbad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count samples whose top prediction after patching is the tracked target token;\n",
    "# every other sample is kept in `failed_cases` for inspection.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "    # The first key of `int_track` is taken as the top prediction\n",
    "    # (assumes the track is ordered by rank -- TODO confirm in validate_q_proj_ie_on_sample_pair).\n",
    "    top_token_id = int_track[next(iter(int_track))][1].token_id\n",
    "    if top_token_id == clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "            }\n",
    "        )\n",
    "\n",
    "# Avoid ZeroDivisionError when no validation results were produced\n",
    "top_1_accuracy = (\n",
    "    counter_patch_type_top_option / len(validation_results) if validation_results else float(\"nan\")\n",
    ")\n",
    "print(\"=\" * 80)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "print(\"=\" * 80)\n",
    "print(f\"{len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91721522",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4bc74551",
   "metadata": {},
   "source": [
    "### Counting Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16ee6cab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CountingTask\n",
    "\n",
    "# Load the counting task definition from <DEFAULT_DATA_DIR>/selection/objects.json\n",
    "counting_task_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "counting_task = CountingTask.load(path=counting_task_path)\n",
    "print(counting_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88114883",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one \"fruit\" counting sample and show its prompt with the expected answer token.\n",
    "sample_kwargs = dict(\n",
    "    mt=mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=3,\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "test_sample = counting_task.get_random_sample(**sample_kwargs)\n",
    "\n",
    "print(test_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([test_sample.ans_token_id])}\"')\n",
    "test_sample.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3496bf5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One cross-task counterfactual pair: patch prompt from `source_task` (\"fruit\"),\n",
    "# clean prompt from `counting_task` (\"vehicle\").\n",
    "patch_sample, clean_sample = get_counterfactual_samples_across_tasks(\n",
    "    mt=mt,\n",
    "    patch_task=source_task,\n",
    "    clean_task=counting_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    clean_transform=None,\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "\n",
    "for tag, shown_sample in ((\"CLEAN:\", clean_sample), (\"PATCH:\", patch_sample)):\n",
    "    print(tag, shown_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([shown_sample.ans_token_id])}\"')\n",
    "\n",
    "# Tracked target token id stored on the clean sample, and its decoded text\n",
    "track_type_obj_token_id = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "print(track_type_obj_token_id, mt.tokenizer.decode(track_type_obj_token_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b4e2c7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "import random\n",
    "\n",
    "# Generate cross-task counterfactual pairs (patch: source_task, clean: counting_task)\n",
    "# and save each one as <index>.json so the loading cell can rebuild the same set.\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    counting_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 32\n",
    "# Index of the first file written by this run (files are named f\"{index:05d}.json\");\n",
    "# bump it when appending to a directory that already holds samples.\n",
    "start_from = 250\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    sample_path = os.path.join(\n",
    "        validation_samples_save_path,\n",
    "        f\"{len(validation_set) + start_from:05d}.json\",\n",
    "    )\n",
    "    # Refuse to silently overwrite previously saved samples (e.g. a stale `start_from`);\n",
    "    # checked before generation so no model calls are wasted.\n",
    "    if os.path.exists(sample_path):\n",
    "        raise FileExistsError(\n",
    "            f\"{sample_path} already exists; increase `start_from` to avoid overwriting saved samples\"\n",
    "        )\n",
    "    patch, clean = get_counterfactual_samples_across_tasks(\n",
    "        mt=mt,\n",
    "        patch_task=source_task,\n",
    "        clean_task=counting_task,\n",
    "        clean_transform=None,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "    cf_pair = CounterFactualSamplePair(\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "    )\n",
    "    cf_pair.detensorize()\n",
    "    with open(sample_path, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be111927",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "# Rebuild the validation set from the saved pair files, optionally prefixing every prompt template.\n",
    "validation_set = []\n",
    "validation_limit = 256\n",
    "# Fixed seed so the selected subset and its order (and hence any pair picked by index\n",
    "# afterwards) are the same on every run.\n",
    "shuffle_seed = 42\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    counting_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# Sort first: os.listdir order is filesystem-dependent, so seeding alone is not enough.\n",
    "sample_files = sorted(\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "prefix = \"\"\n",
    "# prefix = \"Recall the nationality of these people:\\n\"\n",
    "# prefix = \"Recall which country these landmarks are located in:\\n\"\n",
    "# prefix = \"Think about how these words sound when you say them aloud:\\n\"\n",
    "\n",
    "random.Random(shuffle_seed).shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "\n",
    "    cf_pair.clean_sample.prompt_template = prefix + cf_pair.clean_sample.prompt_template\n",
    "    cf_pair.patch_sample.prompt_template = prefix + cf_pair.patch_sample.prompt_template\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a449d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one stored pair: patch prompt, clean prompt, then the tracked target token.\n",
    "clean, patch = validation_set[3]\n",
    "for shown_sample in (patch, clean):\n",
    "    print(shown_sample.prompt(), \">>\", mt.tokenizer.decode(shown_sample.ans_token_id))\n",
    "target_token_id = clean.metadata[\"track_type_obj_token_id\"]\n",
    "target_token_id, mt.tokenizer.decode(target_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bf5d2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Validate the optimized heads on a single pair (deep-copied so the stored pair is untouched)\n",
    "# and log how the clean answer and the tracked target move in rank and logit.\n",
    "clean, patch = copy.deepcopy(validation_set[3])\n",
    "\n",
    "val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    heads=optimized_heads,\n",
    "    query_indices={-2: -2, -1: -1},\n",
    "    add_ques_pos_to_query_indices=True,\n",
    "    verify_head_behavior_on=-1,\n",
    "    patch_args={\n",
    "        \"batch_size\": len(patch.options),\n",
    "        \"distinct_options\": False,\n",
    "        # \"task\": select_task,\n",
    "        # \"prompt_template_idx\": prompt_template_idx,\n",
    "        # \"option_style\": patch.default_option_style,\n",
    "        # \"n_distractors\": N_DISTRACTORS,\n",
    "    },\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track maps token id -> (rank, prediction with `.logit`); \"clean_track\" gives the\n",
    "# before-intervention view and \"int_track\" the after-intervention view.\n",
    "before_intervention, after_intervention = (\n",
    "    {\n",
    "        \"clean_rank\": val_sample_result[track_name][clean_obj][0],\n",
    "        \"clean_logit\": val_sample_result[track_name][clean_obj][1].logit,\n",
    "        \"target_rank\": val_sample_result[track_name][target_obj][0],\n",
    "        \"target_logit\": val_sample_result[track_name][target_obj][1].logit,\n",
    "    }\n",
    "    for track_name in (\"clean_track\", \"int_track\")\n",
    ")\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aad6e974",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Validate the optimized heads on every stored (clean, patch) pair; one result per pair.\n",
    "validation_results = []\n",
    "\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    # batch_size follows the number of options in the patch prompt\n",
    "    pair_patch_args = {\n",
    "        \"batch_size\": len(patch_sample.options),\n",
    "        \"distinct_options\": False,\n",
    "    }\n",
    "    val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=optimized_heads,\n",
    "        query_indices={-2: -2, -1: -1},  # rebuilt per call in case the callee extends it\n",
    "        add_ques_pos_to_query_indices=True,\n",
    "        patch_args=pair_patch_args,\n",
    "    )\n",
    "    validation_results.append(val_sample_result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "043e2a74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect, per validation sample, the rank/logit of the clean answer and of the tracked\n",
    "# target token. Each track maps token id -> (rank, prediction with `.logit`);\n",
    "# \"clean_track\" fills `before_intervention`, \"int_track\" fills `after_intervention`.\n",
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    for records, track_name in (\n",
    "        (before_intervention, \"clean_track\"),\n",
    "        (after_intervention, \"int_track\"),\n",
    "    ):\n",
    "        token_track = intervention_result[track_name]\n",
    "        records.append(\n",
    "            {\n",
    "                \"clean_rank\": token_track[clean_obj][0],\n",
    "                \"clean_logit\": token_track[clean_obj][1].logit,\n",
    "                \"target_rank\": token_track[target_obj][0],\n",
    "                \"target_logit\": token_track[target_obj][1].logit,\n",
    "            }\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdaedb47",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-sample rank change (after - before); a negative value means the rank number dropped.\n",
    "clean_rank_delta = np.array(\n",
    "    [\n",
    "        post[\"clean_rank\"] - pre[\"clean_rank\"]\n",
    "        for pre, post in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_rank_delta = np.array(\n",
    "    [\n",
    "        post[\"target_rank\"] - pre[\"target_rank\"]\n",
    "        for pre, post in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_rank_delta: {clean_rank_delta.mean():.4f} ± {clean_rank_delta.std():.4f}\")\n",
    "print(f\"target_rank_delta: {target_rank_delta.mean():.4f} ± {target_rank_delta.std():.4f}\")\n",
    "\n",
    "# Absolute ranks after the intervention\n",
    "clean_rank_after_intervention = np.array([post[\"clean_rank\"] for post in after_intervention])\n",
    "print(\n",
    "    f\"clean_rank_after_intervention: {clean_rank_after_intervention.mean():.4f} ± {clean_rank_after_intervention.std():.4f}\"\n",
    ")\n",
    "\n",
    "target_rank_after_intervention = np.array([post[\"target_rank\"] for post in after_intervention])\n",
    "print(\n",
    "    f\"target_rank_after_intervention: {target_rank_after_intervention.mean():.4f} ± {target_rank_after_intervention.std():.4f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bb3b4db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample logit change (after - before) for the clean answer and the tracked target.\n",
    "clean_logit_delta = np.array(\n",
    "    [\n",
    "        post[\"clean_logit\"] - pre[\"clean_logit\"]\n",
    "        for pre, post in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "target_logit_delta = np.array(\n",
    "    [\n",
    "        post[\"target_logit\"] - pre[\"target_logit\"]\n",
    "        for pre, post in zip(before_intervention, after_intervention)\n",
    "    ]\n",
    ")\n",
    "print(f\"clean_logit_delta: {clean_logit_delta.mean():.4f} ± {clean_logit_delta.std():.4f}\")\n",
    "print(f\"target_logit_delta: {target_logit_delta.mean():.4f} ± {target_logit_delta.std():.4f}\")\n",
    "\n",
    "# Absolute logits after the intervention\n",
    "clean_logit_after_intervention = np.array([post[\"clean_logit\"] for post in after_intervention])\n",
    "print(f\"clean_logit_after_intervention: {clean_logit_after_intervention.mean():.4f} ± {clean_logit_after_intervention.std():.4f}\")\n",
    "\n",
    "target_logit_after_intervention = np.array([post[\"target_logit\"] for post in after_intervention])\n",
    "print(f\"target_logit_after_intervention: {target_logit_after_intervention.mean():.4f} ± {target_logit_after_intervention.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c740a90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of samples whose tracked target token is ranked first after the intervention\n",
    "top_1 = sum(1 for post in after_intervention if post[\"target_rank\"] == 1)\n",
    "top_1 / len(after_intervention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d13b54dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count samples whose top prediction after patching is the tracked target token;\n",
    "# every other sample is kept in `failed_cases` for inspection.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "    # The first key of `int_track` is taken as the top prediction\n",
    "    # (assumes the track is ordered by rank -- TODO confirm in validate_q_proj_ie_on_sample_pair).\n",
    "    top_token_id = int_track[next(iter(int_track))][1].token_id\n",
    "    if top_token_id == clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "            }\n",
    "        )\n",
    "\n",
    "# Avoid ZeroDivisionError when no validation results were produced\n",
    "top_1_accuracy = (\n",
    "    counter_patch_type_top_option / len(validation_results) if validation_results else float(\"nan\")\n",
    ")\n",
    "print(\"=\" * 80)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "print(\"=\" * 80)\n",
    "print(f\"{len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2197346",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab55549f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ab0c8b1c",
   "metadata": {},
   "source": [
    "### Yes/No Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47869299",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import YesNoSample, YesNoTask\n",
    "\n",
    "# Load the yes/no task definition from <DEFAULT_DATA_DIR>/selection/objects.json\n",
    "yes_no_task_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "yes_no_task = YesNoTask.load(path=yes_no_task_path)\n",
    "print(yes_no_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53813195",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check the Yes/No task: draw one \"fruit\" sample and show prompt >> answer token.\n",
    "# NOTE(review): prompt_template_idx=3 was carried over from the counting task; confirm\n",
    "# YesNoTask defines a template at that index.\n",
    "test_sample = yes_no_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=3,\n",
    "    category=\"fruit\",\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "\n",
    "print(test_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([test_sample.ans_token_id])}\"')\n",
    "test_sample.prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6bec85b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One cross-task counterfactual pair: patch prompt from `source_task` (\"fruit\"),\n",
    "# clean prompt from `yes_no_task` (\"vehicle\").\n",
    "patch_sample, clean_sample = get_counterfactual_samples_across_tasks(\n",
    "    mt=mt,\n",
    "    patch_task=source_task,\n",
    "    clean_task=yes_no_task,\n",
    "    patch_category=\"fruit\",\n",
    "    clean_category=\"vehicle\",\n",
    "    clean_transform=None,\n",
    "    filter_by_lm_prediction=True,\n",
    ")\n",
    "\n",
    "for tag, shown_sample in ((\"CLEAN:\", clean_sample), (\"PATCH:\", patch_sample)):\n",
    "    print(tag, shown_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([shown_sample.ans_token_id])}\"')\n",
    "\n",
    "# Tracked target token id stored on the clean sample, and its decoded text\n",
    "track_type_obj_token_id = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "print(track_type_obj_token_id, mt.tokenizer.decode(track_type_obj_token_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3528d708",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "from src.functional import free_gpu_cache\n",
    "from src.selection.data import get_counterfactual_samples_interface\n",
    "import random\n",
    "\n",
    "# Generate cross-task counterfactual pairs (patch: source_task, clean: yes_no_task)\n",
    "# and save each one as <index>.json so the loading cell can rebuild the same set.\n",
    "validation_samples_save_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    yes_no_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "os.makedirs(validation_samples_save_path, exist_ok=True)\n",
    "\n",
    "\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 256\n",
    "# Index of the first file written by this run (files are named f\"{index:05d}.json\");\n",
    "# bump it when appending to a directory that already holds samples.\n",
    "start_from = 1\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    print(f\"sample {len(validation_set)+1} / {validation_limit}\")\n",
    "    sample_path = os.path.join(\n",
    "        validation_samples_save_path,\n",
    "        f\"{len(validation_set) + start_from:05d}.json\",\n",
    "    )\n",
    "    # Refuse to silently overwrite previously saved samples (e.g. a stale `start_from`);\n",
    "    # checked before generation so no model calls are wasted.\n",
    "    if os.path.exists(sample_path):\n",
    "        raise FileExistsError(\n",
    "            f\"{sample_path} already exists; increase `start_from` to avoid overwriting saved samples\"\n",
    "        )\n",
    "    patch, clean = get_counterfactual_samples_across_tasks(\n",
    "        mt=mt,\n",
    "        patch_task=source_task,\n",
    "        clean_task=yes_no_task,\n",
    "        clean_transform=None,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    validation_set.append((clean, patch))\n",
    "    cf_pair = CounterFactualSamplePair(\n",
    "        patch_sample=patch,\n",
    "        clean_sample=clean,\n",
    "    )\n",
    "    cf_pair.detensorize()\n",
    "    with open(sample_path, \"w\") as f:\n",
    "        json.dump(cf_pair.to_dict(), f, indent=2)\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db573c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CounterFactualSamplePair\n",
    "import random\n",
    "\n",
    "# Rebuild the validation set from the saved pair files, optionally prefixing every prompt template.\n",
    "validation_set = []\n",
    "validation_limit = 256\n",
    "# Fixed seed so the selected subset and its order (and hence any pair picked by index\n",
    "# afterwards) are the same on every run.\n",
    "shuffle_seed = 42\n",
    "\n",
    "validation_samples_load_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection\",\n",
    "    \"samples\",\n",
    "    \"predicate_generalization\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    source_task.task_name,\n",
    "    yes_no_task.task_name,\n",
    "    \"objects\",\n",
    ")\n",
    "\n",
    "# Sort first: os.listdir order is filesystem-dependent, so seeding alone is not enough.\n",
    "sample_files = sorted(\n",
    "    os.path.join(validation_samples_load_path, f)\n",
    "    for f in os.listdir(validation_samples_load_path)\n",
    "    if f.endswith(\".json\")\n",
    ")\n",
    "logger.info(f\"Found {len(sample_files)} sample files\")\n",
    "\n",
    "prefix = \"\"\n",
    "# prefix = \"Recall the nationality of these people:\\n\"\n",
    "# prefix = \"Recall which country these landmarks are located in:\\n\"\n",
    "# prefix = \"Think about how these words sound when you say them aloud:\\n\"\n",
    "\n",
    "random.Random(shuffle_seed).shuffle(sample_files)\n",
    "sample_files = sample_files[:validation_limit]\n",
    "for sample_file in sample_files:\n",
    "    with open(sample_file, \"r\") as f:\n",
    "        cf_pair_data = json.load(f)\n",
    "    cf_pair = CounterFactualSamplePair.from_dict(cf_pair_data)\n",
    "\n",
    "    cf_pair.clean_sample.prompt_template = prefix + cf_pair.clean_sample.prompt_template\n",
    "    cf_pair.patch_sample.prompt_template = prefix + cf_pair.patch_sample.prompt_template\n",
    "    validation_set.append((cf_pair.clean_sample, cf_pair.patch_sample))\n",
    "\n",
    "len(validation_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48985203",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect one stored pair: patch prompt, clean prompt, then the tracked target token.\n",
    "clean, patch = validation_set[3]\n",
    "for shown_sample in (patch, clean):\n",
    "    print(shown_sample.prompt(), \">>\", mt.tokenizer.decode(shown_sample.ans_token_id))\n",
    "target_token_id = clean.metadata[\"track_type_obj_token_id\"]\n",
    "target_token_id, mt.tokenizer.decode(target_token_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85b8ad3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "import copy\n",
    "\n",
    "# Validate the optimized heads on a single pair (deep-copied so the stored pair is untouched)\n",
    "# and log how the clean answer and the tracked target move in rank and logit.\n",
    "clean, patch = copy.deepcopy(validation_set[3])\n",
    "\n",
    "val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "    mt=mt,\n",
    "    clean_sample=clean,\n",
    "    patch_sample=patch,\n",
    "    heads=optimized_heads,\n",
    "    query_indices={-2: -2, -1: -1},\n",
    "    add_ques_pos_to_query_indices=True,\n",
    "    verify_head_behavior_on=-1,\n",
    "    patch_args={\n",
    "        \"batch_size\": len(patch.options),\n",
    "        \"distinct_options\": False,\n",
    "        # \"task\": select_task,\n",
    "        # \"prompt_template_idx\": prompt_template_idx,\n",
    "        # \"option_style\": patch.default_option_style,\n",
    "        # \"n_distractors\": N_DISTRACTORS,\n",
    "    },\n",
    ")\n",
    "\n",
    "clean_obj = clean.ans_token_id\n",
    "target_obj = clean.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "logger.debug(f\"clean obj: {mt.tokenizer.decode(clean_obj)}\")\n",
    "logger.debug(f\"target obj: {mt.tokenizer.decode(target_obj)}\")\n",
    "\n",
    "# Each track maps token id -> (rank, prediction with `.logit`); \"clean_track\" gives the\n",
    "# before-intervention view and \"int_track\" the after-intervention view.\n",
    "before_intervention, after_intervention = (\n",
    "    {\n",
    "        \"clean_rank\": val_sample_result[track_name][clean_obj][0],\n",
    "        \"clean_logit\": val_sample_result[track_name][clean_obj][1].logit,\n",
    "        \"target_rank\": val_sample_result[track_name][target_obj][0],\n",
    "        \"target_logit\": val_sample_result[track_name][target_obj][1].logit,\n",
    "    }\n",
    "    for track_name in (\"clean_track\", \"int_track\")\n",
    ")\n",
    "\n",
    "clean_rank_delta = after_intervention[\"clean_rank\"] - before_intervention[\"clean_rank\"]\n",
    "target_rank_delta = after_intervention[\"target_rank\"] - before_intervention[\"target_rank\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Rank Change: {before_intervention['clean_rank']} -> {after_intervention['clean_rank']} | Delta: {clean_rank_delta} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Rank Change: {before_intervention['target_rank']} -> {after_intervention['target_rank']} | Delta: {target_rank_delta} \"\n",
    ")\n",
    "\n",
    "clean_logit_delta = after_intervention[\"clean_logit\"] - before_intervention[\"clean_logit\"]\n",
    "target_logit_delta = after_intervention[\"target_logit\"] - before_intervention[\"target_logit\"]\n",
    "logger.info(\n",
    "    f\"Clean Prediction Logit Change: {before_intervention['clean_logit']:.4f} -> {after_intervention['clean_logit']:.4f} | Delta: {clean_logit_delta:.4f} \"\n",
    ")\n",
    "logger.info(\n",
    "    f\"Target Prediction Logit Change: {before_intervention['target_logit']:.4f} -> {after_intervention['target_logit']:.4f} | Delta: {target_logit_delta:.4f} \"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2e4124f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Validate the optimized heads on every stored (clean, patch) pair; one result per pair.\n",
    "validation_results = []\n",
    "\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    # batch_size follows the number of options in the patch prompt\n",
    "    pair_patch_args = {\n",
    "        \"batch_size\": len(patch_sample.options),\n",
    "        \"distinct_options\": False,\n",
    "    }\n",
    "    val_sample_result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=optimized_heads,\n",
    "        query_indices={-2: -2, -1: -1},  # rebuilt per call in case the callee extends it\n",
    "        add_ques_pos_to_query_indices=True,\n",
    "        patch_args=pair_patch_args,\n",
    "    )\n",
    "    validation_results.append(val_sample_result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab35fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize_track(track, clean_obj, target_obj):\n",
    "    \"\"\"Collect rank and logit of the clean and target tokens from one prediction track.\n",
    "\n",
    "    `track` is indexed by token id; each entry holds the rank at index 0 and a\n",
    "    prediction object exposing `.logit` at index 1.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        \"clean_rank\": track[clean_obj][0],\n",
    "        \"clean_logit\": track[clean_obj][1].logit,\n",
    "        \"target_rank\": track[target_obj][0],\n",
    "        \"target_logit\": track[target_obj][1].logit,\n",
    "    }\n",
    "\n",
    "\n",
    "# The same (clean answer, target) token pair, read from the pre-intervention\n",
    "# track (`clean_track`) and the post-intervention track (`int_track`).\n",
    "before_intervention = []\n",
    "after_intervention = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    before_intervention.append(\n",
    "        summarize_track(intervention_result[\"clean_track\"], clean_obj, target_obj)\n",
    "    )\n",
    "    after_intervention.append(\n",
    "        summarize_track(intervention_result[\"int_track\"], clean_obj, target_obj)\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6106589c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Per-pair rank shifts (after - before) and absolute post-intervention ranks,\n",
    "# each summarised as mean ± std (numpy's default population std, ddof=0).\n",
    "rank_pairs = list(zip(before_intervention, after_intervention))\n",
    "\n",
    "clean_rank_delta = np.array([post[\"clean_rank\"] - pre[\"clean_rank\"] for pre, post in rank_pairs])\n",
    "target_rank_delta = np.array([post[\"target_rank\"] - pre[\"target_rank\"] for pre, post in rank_pairs])\n",
    "clean_rank_after_intervention = np.array([post[\"clean_rank\"] for post in after_intervention])\n",
    "target_rank_after_intervention = np.array([post[\"target_rank\"] for post in after_intervention])\n",
    "\n",
    "for label, values in (\n",
    "    (\"clean_rank_delta\", clean_rank_delta),\n",
    "    (\"target_rank_delta\", target_rank_delta),\n",
    "    (\"clean_rank_after_intervention\", clean_rank_after_intervention),\n",
    "    (\"target_rank_after_intervention\", target_rank_after_intervention),\n",
    "):\n",
    "    print(f\"{label}: {values.mean():.4f} ± {values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e32706",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-pair logit shifts (after - before) and absolute post-intervention logits,\n",
    "# each summarised as mean ± std (numpy's default population std, ddof=0).\n",
    "logit_pairs = list(zip(before_intervention, after_intervention))\n",
    "\n",
    "clean_logit_delta = np.array([post[\"clean_logit\"] - pre[\"clean_logit\"] for pre, post in logit_pairs])\n",
    "target_logit_delta = np.array([post[\"target_logit\"] - pre[\"target_logit\"] for pre, post in logit_pairs])\n",
    "clean_logit_after_intervention = np.array([post[\"clean_logit\"] for post in after_intervention])\n",
    "target_logit_after_intervention = np.array([post[\"target_logit\"] for post in after_intervention])\n",
    "\n",
    "for label, values in (\n",
    "    (\"clean_logit_delta\", clean_logit_delta),\n",
    "    (\"target_logit_delta\", target_logit_delta),\n",
    "    (\"clean_logit_after_intervention\", clean_logit_after_intervention),\n",
    "    (\"target_logit_after_intervention\", target_logit_after_intervention),\n",
    "):\n",
    "    print(f\"{label}: {values.mean():.4f} ± {values.std():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5169804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share of validation pairs whose target token has rank 1 after the intervention\n",
    "# (treats rank 1 as the top prediction - TODO confirm ranks are 1-indexed).\n",
    "# An empty result list yields nan instead of raising ZeroDivisionError.\n",
    "top_1 = sum(1 for after in after_intervention if after[\"target_rank\"] == 1)\n",
    "top_1 / len(after_intervention) if after_intervention else float(\"nan\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3152f250",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    # The first entry of `int_track` is taken as the top post-intervention prediction\n",
    "    # (relies on insertion order matching rank order - TODO confirm). next(iter(...))\n",
    "    # reads it without copying every key; an empty track is recorded as a failure.\n",
    "    top_entry = next(iter(int_track.values()), None)\n",
    "    top_token_id = None if top_entry is None else top_entry[1].token_id\n",
    "\n",
    "    if top_token_id == clean_sample.metadata[\"track_type_obj_token_id\"]:\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "            }\n",
    "        )\n",
    "\n",
    "n_validation = len(validation_results)\n",
    "# Guard against an empty validation set instead of raising ZeroDivisionError.\n",
    "top_1_accuracy = counter_patch_type_top_option / n_validation if n_validation else float(\"nan\")\n",
    "print(\"=\" * 80)\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f} ({counter_patch_type_top_option}/{n_validation})\"\n",
    ")\n",
    "print(\"=\" * 80)\n",
    "print(f\"{len(failed_cases)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d661d55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
