{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d437ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "635faeb3",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c96bc212",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2589d0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")\n",
    "\n",
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "device_map = get_device_map(model_key, 32, n_gpus=8)\n",
    "print(device_map)\n",
    "\n",
    "print(os.getcwd())\n",
    "\n",
    "from src.models import ModelandTokenizer\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=device_map,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b955ca87",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from src.selection.data import SelectOneTask\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    \"retrieval/results/\"\n",
    "    \"selection/optimized_heads\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    \"select_one\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "optimal_head_mask = torch.tensor(optimization_results['optimal_mask']).to(torch.float32)\n",
    "optimal_head_mask[50:, :] = 0.0\n",
    "\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "\n",
    "heads_selected = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx) for layer_idx, head_idx in heads_selected\n",
    "]\n",
    "print(len(heads_selected))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42470176",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from src.selection.data import SelectionSample, SelectOrderTask, SelectFirstTask\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "N_DISTRACTORS = 5\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_counterfactual_samples_within_task(\n",
    "    task: SelectOneTask | SelectFirstTask,\n",
    "    patch_category: str | None = None,\n",
    "    clean_category: str | None = None,\n",
    "    shuffle_clean_options: bool = False,\n",
    "    prompt_template_idx=2,\n",
    "    option_style=\"numbered\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    distinct_options: bool = False,\n",
    "    n_distractors: int = N_DISTRACTORS,\n",
    "):\n",
    "    # Get the categories\n",
    "    categories = list(task.category_wise_examples.keys())\n",
    "\n",
    "    # Set the patch category\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(categories)\n",
    "\n",
    "    # Set the patch subject and object\n",
    "    patch_subj, patch_obj = random.sample(\n",
    "        task.category_wise_examples[patch_category], 2\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Patch category: {patch_category}, subject: {patch_subj}, object: {patch_obj}\"\n",
    "    )\n",
    "\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(list(set(categories) - {patch_category}))\n",
    "\n",
    "    clean_options = task.category_wise_examples[clean_category]\n",
    "    random.shuffle(clean_options)\n",
    "\n",
    "    clean_subj, clean_obj = random.sample(\n",
    "        (\n",
    "            KeyedSet(clean_options, mt.tokenizer) - KeyedSet([patch_obj], mt.tokenizer)\n",
    "        ).values,\n",
    "        2,\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Clean category: {clean_category}, subject: {clean_subj}, object: {clean_obj}\"\n",
    "    )\n",
    "\n",
    "    if distinct_options is False:\n",
    "        patch_type_obj = patch_obj\n",
    "        clean_type_obj = clean_obj\n",
    "    else:\n",
    "        patch_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[patch_category], mt.tokenizer)\n",
    "                - KeyedSet([patch_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "        clean_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[clean_category], mt.tokenizer)\n",
    "                - KeyedSet([clean_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "\n",
    "    patch_must_have_options = [patch_obj, clean_type_obj]\n",
    "    clean_must_have_options = [clean_obj, patch_type_obj]\n",
    "\n",
    "    logger.info(f\"{patch_must_have_options=}\")\n",
    "    logger.info(f\"{clean_must_have_options=}\")\n",
    "    logger.info(f\"{clean_type_obj=}\")\n",
    "    logger.info(f\"{patch_type_obj=}\")\n",
    "\n",
    "    patch_distractors = []\n",
    "    other_categories = random.sample(\n",
    "        list(set(categories) - {patch_category, clean_category}),\n",
    "        k=n_distractors - (len(patch_must_have_options)) + 1,\n",
    "    )\n",
    "\n",
    "    for other_category in other_categories:\n",
    "        other_examples = task.category_wise_examples[other_category]\n",
    "        random.shuffle(other_examples)\n",
    "        other_examples = KeyedSet(other_examples, mt.tokenizer)\n",
    "        patch_distractors.append(\n",
    "            random.choice(\n",
    "                (\n",
    "                    other_examples\n",
    "                    - KeyedSet(\n",
    "                        patch_must_have_options + patch_distractors,\n",
    "                        tokenizer=mt.tokenizer,\n",
    "                    )\n",
    "                ).values\n",
    "            )\n",
    "        )\n",
    "\n",
    "    patch_options = patch_must_have_options + patch_distractors\n",
    "    random.shuffle(patch_options)\n",
    "    patch_obj_idx = patch_options.index(patch_obj)\n",
    "    logger.info(f\"{patch_obj_idx=} | {patch_options}\")\n",
    "\n",
    "    if distinct_options is not True:\n",
    "        clean_options = copy.deepcopy(patch_options)\n",
    "        if shuffle_clean_options:\n",
    "            # Useful for the pointer experiments\n",
    "            while (\n",
    "                clean_options.index(clean_obj) == patch_obj_idx\n",
    "                or clean_options.index(patch_type_obj) == patch_obj_idx\n",
    "            ):\n",
    "                random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    else:\n",
    "        other_categories = random.sample(\n",
    "            list(set(categories) - {patch_category, clean_category}),\n",
    "            k=n_distractors - (len(clean_must_have_options)) + 1,\n",
    "        )\n",
    "        clean_distractors = []\n",
    "        for other_category in other_categories:\n",
    "            other_examples = task.category_wise_examples[other_category]\n",
    "            random.shuffle(other_examples)\n",
    "            other_examples = KeyedSet(other_examples, mt.tokenizer)\n",
    "            clean_distractors.append(\n",
    "                random.choice(\n",
    "                    (\n",
    "                        other_examples\n",
    "                        - KeyedSet(\n",
    "                            clean_must_have_options + clean_distractors,\n",
    "                            tokenizer=mt.tokenizer,\n",
    "                        )\n",
    "                    ).values\n",
    "                )\n",
    "            )\n",
    "        clean_options = clean_must_have_options + clean_distractors\n",
    "        random.shuffle(clean_options)\n",
    "        while clean_options.index(clean_obj) == patch_obj_idx:\n",
    "            random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    logger.info(f\"{clean_obj_idx=} | {clean_options}\")\n",
    "\n",
    "    kwargs = dict(\n",
    "        prompt_template=task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=option_style,\n",
    "    )\n",
    "    #print(f\"{type(task)=}\")\n",
    "    if isinstance(task, SelectOrderTask):\n",
    "        patch_metadata = {\n",
    "            \"track_type_obj_idx\": clean_obj_idx,\n",
    "            \"track_type_obj\": patch_options[clean_obj_idx],\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_options[clean_obj_idx], mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_type_obj_idx\": patch_obj_idx,\n",
    "            \"track_type_obj\": clean_options[patch_obj_idx],\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_options[patch_obj_idx], mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    elif isinstance(task, SelectOneTask):\n",
    "        patch_metadata = {\n",
    "            \"track_category\": clean_category,\n",
    "            \"track_type_obj\": clean_type_obj,\n",
    "            \"track_type_obj_idx\": patch_options.index(clean_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_category\": patch_category,\n",
    "            \"track_type_obj\": patch_type_obj,\n",
    "            \"track_type_obj_idx\": clean_options.index(patch_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Unsupported task type: {type(task)}\")\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        subj=patch_subj,\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_obj_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_category,\n",
    "        metadata=patch_metadata,\n",
    "        **kwargs,\n",
    "    )\n",
    "    clean_sample = SelectionSample(\n",
    "        subj=clean_subj,\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_obj_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_category,\n",
    "        metadata=clean_metadata,\n",
    "        **kwargs,\n",
    "    )\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, clean_sample]\n",
    "        if distinct_options is True:\n",
    "            clean_sample_2 = copy.deepcopy(patch_sample)\n",
    "            clean_sample_2.options = clean_options\n",
    "            clean_sample_2.obj = clean_sample.metadata[\"track_type_obj\"]\n",
    "            clean_sample_2.obj_idx = clean_sample.metadata[\"track_type_obj_idx\"]\n",
    "            clean_sample_2.ans_token_id = clean_sample.metadata[\n",
    "                \"track_type_obj_token_id\"\n",
    "            ]\n",
    "            test_samples.append(clean_sample_2)\n",
    "\n",
    "        for sample in test_samples:\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            logger.info(sample.prompt())\n",
    "            logger.info(\n",
    "                f\"{sample.subj} | {sample.category} -> {sample.obj} | pred={[str(p) for p in predictions]}\"\n",
    "            )\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {sample.ans_token_id}[\"{mt.tokenizer.decode(sample.ans_token_id)}\"]'\n",
    "                )\n",
    "                return get_counterfactual_samples_within_task(\n",
    "                    task=task,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    shuffle_clean_options=shuffle_clean_options,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    option_style=option_style,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    distinct_options=distinct_options,\n",
    "                    n_distractors=n_distractors,\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3efb0212",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CountingSample, CountingTask, YesNoTask\n",
    "\n",
    "def get_counterfactual_samples_within_counting_task(\n",
    "    task: CountingTask | YesNoTask,\n",
    "    n_options: int,\n",
    "    n_clean = None,\n",
    "    n_patch = None,\n",
    "    prompt_template_idx= 1,\n",
    "    option_style = \"single_line\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    patch_category = None,\n",
    "    clean_category = None,\n",
    "    verbose = True,\n",
    "    retry_count = 0,\n",
    "    distinct_options: bool = True,\n",
    "):\n",
    "    # Get counts\n",
    "\n",
    "    if isinstance(task, YesNoTask):\n",
    "        if random.random() < 0.5:\n",
    "            n_clean = random.randint(1, n_options)\n",
    "            n_patch = n_options - n_clean\n",
    "        else:\n",
    "            n_clean = 0\n",
    "            n_patch = n_options\n",
    "    else:\n",
    "        if n_clean and n_patch:\n",
    "            n_options = n_clean + n_patch\n",
    "        elif n_clean:\n",
    "            n_patch = n_options - n_clean\n",
    "        elif n_patch:\n",
    "            n_clean = n_options - n_patch\n",
    "        else:\n",
    "            n_clean = random.randint(1,n_options-1)\n",
    "            n_patch = n_options - n_clean\n",
    "\n",
    "    # Get the categories\n",
    "    categories = list(task.category_wise_examples.keys())\n",
    "\n",
    "    # Set the patch category\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(categories)\n",
    "\n",
    "    # Set the patch objects\n",
    "    patch_objects = random.sample(\n",
    "        task.category_wise_examples[patch_category], n_patch\n",
    "    )\n",
    "\n",
    "    # Set the clean category\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(list(set(categories) - {patch_category}))\n",
    "\n",
    "    # Set the clean objects\n",
    "    clean_objects = random.sample(\n",
    "        task.category_wise_examples[clean_category], n_clean\n",
    "    )\n",
    "\n",
    "    if distinct_options is False:\n",
    "        all_objects = clean_objects + patch_objects\n",
    "        random.shuffle(all_objects)\n",
    "        clean_options = all_objects\n",
    "        patch_options = all_objects\n",
    "    else:\n",
    "        alt_clean_objects = random.sample(\n",
    "            [opt for opt in task.category_wise_examples[clean_category] if opt not in clean_objects],    \n",
    "            n_clean\n",
    "        )\n",
    "        #print(f\"{alt_clean_objects=}\")\n",
    "        \n",
    "        alt_patch_objects = random.sample(\n",
    "            [opt for opt in task.category_wise_examples[patch_category] if opt not in patch_objects],    \n",
    "            n_patch\n",
    "        )\n",
    "        #print(f\"{alt_patch_objects=}\")\n",
    "\n",
    "        # TODO: Check that its ok that these lists are in different orders.\n",
    "        # The prediction has to do with the number of items of a given category,\n",
    "        # so I think as long as the numbers add up then its fine.\n",
    "        # It may even be preferable this way, having the orders unsynchronized.\n",
    "        clean_options = clean_objects + alt_patch_objects\n",
    "        random.shuffle(clean_options)\n",
    "        print(f\"{clean_options=}\")\n",
    "        patch_options = patch_objects + alt_clean_objects\n",
    "        random.shuffle(patch_options)\n",
    "        print(f\"{patch_options=}\")\n",
    "        \n",
    "\n",
    "    kwargs = dict(\n",
    "        prompt_template=task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=option_style\n",
    "    )\n",
    "\n",
    "    patch_sample = CountingSample(\n",
    "        options=patch_options,\n",
    "        count=n_patch,\n",
    "        category=patch_category,\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "    clean_sample = CountingSample(\n",
    "        options=clean_options,\n",
    "        count=n_clean,\n",
    "        category=clean_category,\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"{clean_category=}\")\n",
    "        print(f\"{patch_category=}\")\n",
    "        print(f\"{clean_options=}\")\n",
    "        print(f\"{patch_options=}\")\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, clean_sample]\n",
    "        \n",
    "        for sample in test_samples:\n",
    "            if retry_count >= 10: break\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            if isinstance(task, YesNoTask):\n",
    "                if sample.count == 0:\n",
    "                    target_token_id = mt.tokenizer.encode(\" No\", add_special_tokens=False)[0]\n",
    "                    print(\"No\")\n",
    "                else:\n",
    "                    target_token_id = mt.tokenizer.encode(\" Yes\", add_special_tokens=False)[0] \n",
    "                    print(\"Yes\")\n",
    "            else:\n",
    "                count_str_map = {\n",
    "                    0: \" zero\",\n",
    "                    1: \" one\",\n",
    "                    2: \" two\",\n",
    "                    3: \" three\",\n",
    "                    4: \" four\",\n",
    "                    5: \" five\",\n",
    "                    6: \" six\",\n",
    "                    7: \" seven\",\n",
    "                    8: \" eight\",\n",
    "                    9: \" nine\",\n",
    "                    10: \" ten\",\n",
    "                }\n",
    "                target_token_id = mt.tokenizer.encode(count_str_map[sample.count], add_special_tokens=False)[0]\n",
    "            #print(f\"{target_token_id=}\")\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=target_token_id, options=sample.options, input=tokenized,\n",
    "                is_counting_task=True,\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            sample.metadata[\"predictions\"] = predictions\n",
    "\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f\"Prediction mismatch!\"\n",
    "                    f\"Retry Count: {retry_count}\"\n",
    "                )\n",
    "                return get_counterfactual_samples_within_counting_task(\n",
    "                    task=task,\n",
    "                    n_options=n_options,\n",
    "                    n_clean=n_clean,\n",
    "                    n_patch=n_patch,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    option_style=option_style,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    verbose=verbose,\n",
    "                    retry_count=retry_count+1\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2fb9022",
   "metadata": {},
   "source": [
    "Canonical Example Counting Task:\n",
    "\n",
    "chalk, apple, lotion, banana, car\n",
    "How many fruits are there in the previous list?\n",
    "Answer: >> two\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "335b7b5b",
   "metadata": {},
   "source": [
    "# Test Suite\n",
    "\n",
    "Testing generalization of the filter heads across different tasks.\n",
    "\n",
    "1. Paraphrasing + Presentation (MCQ)\n",
    "2. Different reduce operations\n",
    "- Select One\n",
    "- Counting\n",
    "- Yes/no - is there a fruit in the list\n",
    "- First fruit in the list\n",
    "3. (skip)\n",
    "4. Different language?\n",
    "- Spanish\n",
    "- French"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cac2802",
   "metadata": {},
   "source": [
    "## Select One"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e14b5206",
   "metadata": {},
   "source": [
    "### Select One - Obj type (Objects - Question After - Single Line)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2bf1afce",
   "metadata": {},
   "source": [
    "#### Load the Selection Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bb08b4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 1\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=\"data_save/selection/objects.json\"\n",
    ")\n",
    "\n",
    "print(select_task.task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89424352",
   "metadata": {},
   "source": [
    "#### Construct a validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37d9e2ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4406ba1",
   "metadata": {},
   "source": [
    "### Test the accuracy when we patch in our filter heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8b3030b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "        #ablate_possible_ans_info_from_options=True,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dad33400",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually have the position token id rather than the pexisitng\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8917aee0",
   "metadata": {},
   "source": [
    "### Profession - Question After - Single Line "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11597382",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 1\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=\"data_save/selection/profession.json\"\n",
    ")\n",
    "\n",
    "print(select_task.task_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53e3163",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0bdd4745",
   "metadata": {},
   "source": [
    "### Test the accuracy when we patch in our filter heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d89c14e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "        #ablate_possible_ans_info_from_options=True,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b8c2cb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually have the position token id rather than the pexisitng\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1113a8ce",
   "metadata": {},
   "source": [
    "## SelectOne - Numbered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a85cc3ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "OPTION_STYLE = \"numbered\"\n",
    "\n",
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9096a0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a65b0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c44c16b",
   "metadata": {},
   "source": [
    "## Match Object Type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7e03405",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 1\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=\"data_save/selection/objects.json\"\n",
    ")\n",
    "\n",
    "print(select_task.task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "acdc2bc7",
   "metadata": {},
   "source": [
    "## Person Profession"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebd253c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 1\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=\"/data_save/selection/profession.json\"\n",
    ")\n",
    "\n",
    "print(select_task.task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc51dab7",
   "metadata": {},
   "source": [
    "## Counting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "576f8543",
   "metadata": {},
   "source": [
    "### Load the Counting Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59bcb03c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CountingTask\n",
    "\n",
    "TASK_CLS = CountingTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 3\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "counting_task = TASK_CLS.load(\n",
    "    path=\"data_save/counting/fruits.json\"\n",
    ")\n",
    "\n",
    "print(counting_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44d63d88",
   "metadata": {},
   "source": [
    "### Construct a validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "efde0b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "counting_validation_set = []\n",
    "counting_validation_limit = 1024\n",
    "\n",
    "while len(counting_validation_set) < counting_validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_counting_task(\n",
    "        task=counting_task,\n",
    "        n_options=5,\n",
    "        verbose=False,\n",
    "        distinct_options=True,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    counting_validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de40b6b6",
   "metadata": {},
   "source": [
    "### Test the accuracy when we patch in our filter heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b222e710",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "counting_validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(counting_validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    counting_validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d70b9553",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "counting_failed_cases = []\n",
    "counting_all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(counting_validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_pred = intervention_result[\"int_predictions\"]\n",
    "    clean_pred = intervention_result[\"clean_predictions\"]\n",
    "    patch_pred = intervention_result[\"patch_predictions\"]\n",
    "\n",
    "    target_tok = clean_sample.prediction[0].token\n",
    "    target_obj = clean_sample.prediction[0].token_id\n",
    "\n",
    "    target_logit_before = [item.logit for item in patch_pred if item.token_id == target_obj][0]\n",
    "\n",
    "    target_logit_after = [item.logit for item in int_pred if item.token_id == clean_obj][0]\n",
    "    \n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    if (\n",
    "        int_pred[0].token_id == clean_pred[0].token_id\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        counting_failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    counting_all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(counting_validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(counting_validation_results)})\"\n",
    ")\n",
    "\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in counting_all_cases) / len(counting_all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cabd608a",
   "metadata": {},
   "source": [
    "## Yes/No"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c610e07",
   "metadata": {},
   "source": [
    "## Load the Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b78fb2b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the module to ensure YesNoTask is properly defined\n",
    "import importlib\n",
    "import src.selection.data\n",
    "\n",
    "from src.selection.data import YesNoTask\n",
    "\n",
    "TASK_CLS = YesNoTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 2\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "yes_no_task = TASK_CLS.load(\n",
    "    path=\"data_save/counting/fruits.json\"\n",
    ")\n",
    "\n",
    "print(yes_no_task)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b99288b",
   "metadata": {},
   "source": [
    "### Construct a validation set"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c4b13c7",
   "metadata": {},
   "source": [
    "We want to have n_options set to 6 as the standard across counting tasks to match the n_distractors=5 in the selection task.\n",
    "This means for the yes_no task that we need to to randomly have one of the objects fir the queried category.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccdfc733",
   "metadata": {},
   "outputs": [],
   "source": [
    "yes_no_validation_set = []\n",
    "yes_no_validation_limit = 1024\n",
    "\n",
    "while len(yes_no_validation_set) < yes_no_validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_counting_task(\n",
    "        task=yes_no_task,\n",
    "        n_options=5, # We want to have n_options be set to 6\n",
    "        verbose=False,\n",
    "        distinct_options=True,\n",
    "        filter_by_lm_prediction=True,\n",
    "    )\n",
    "    yes_no_validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8356678",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "yes_no_validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(yes_no_validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    yes_no_validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69bb13fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "yes_no_failed_cases = []\n",
    "yes_no_all_cases = []\n",
    "\n",
    "for intervention_result in yes_no_validation_results:\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_pred = intervention_result[\"int_predictions\"]\n",
    "    clean_pred = intervention_result[\"clean_predictions\"]\n",
    "    patch_pred = intervention_result[\"patch_predictions\"]\n",
    "\n",
    "    target_tok = clean_sample.prediction[0].token\n",
    "    target_obj = clean_sample.prediction[0].token_id\n",
    "\n",
    "    target_logit_before = [item.logit for item in patch_pred if item.token_id == target_obj][0]\n",
    "\n",
    "    target_logit_after = [item.logit for item in int_pred if item.token_id == target_obj][0]\n",
    "    \n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    if (\n",
    "        int_pred[0].token_id == clean_pred[0].token_id\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        yes_no_failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                #\"int_track\": int_track,\n",
    "                #\"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    yes_no_all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                #\"int_track\": int_track,\n",
    "                #\"clean_track\": clean_track,\n",
    "                #\"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "yes_no_top_1_accuracy = counter_patch_type_top_option / len(yes_no_validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {yes_no_top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(yes_no_validation_results)})\"\n",
    ")\n",
    "\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in yes_no_all_cases) / len(yes_no_all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d6b1e35",
   "metadata": {},
   "source": [
    "So what do I need to change? I need to make it so that there is a roughly 50/50 chance of the list being structured such that it evokes a Yes vs a No. This might mean I have to add some conditional logic to the CountingSample code to suit the structure of the YesNoTask. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d73a4e3",
   "metadata": {},
   "source": [
    "## First fruit in the list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae3faf6b",
   "metadata": {},
   "source": [
    "### Load the task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da73b141",
   "metadata": {},
   "outputs": [],
   "source": [
    "import importlib\n",
    "import src.selection.data\n",
    "\n",
    "from src.selection.data import SelectFirstTask\n",
    "\n",
    "TASK_CLS = SelectFirstTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 1\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_first_task = TASK_CLS.load(\n",
    ")\n",
    "\n",
    "print(select_first_task.task_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d444cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from src.selection.data import SelectionSample, SelectOrderTask, SelectFirstTask\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "N_DISTRACTORS = 5\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_counterfactual_samples_within_first_task(\n",
    "    task: SelectFirstTask,\n",
    "    patch_category: str | None = None,\n",
    "    clean_category: str | None = None,\n",
    "    n_options: int = 5,\n",
    "    amt_to_sample: int = 2,\n",
    "    shuffle_clean_options: bool = False,\n",
    "    prompt_template_idx=2,\n",
    "    option_style=\"numbered\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    distinct_options: bool = True,\n",
    "    n_distractors: int = N_DISTRACTORS,\n",
    "    retry_count: int = 0\n",
    "):\n",
    "    # Get the categories\n",
    "    categories = list(task.category_wise_examples.keys())\n",
    "\n",
    "    # Set the patch category\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(categories)\n",
    "\n",
    "    # Set the patch objects\n",
    "    patch_objects = random.sample(\n",
    "        task.category_wise_examples[patch_category], n_options\n",
    "    )\n",
    "\n",
    "    # Set the clean category\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(list(set(categories) - {patch_category}))\n",
    "\n",
    "    # Set the clean objects\n",
    "    clean_objects = random.sample(\n",
    "        task.category_wise_examples[clean_category], n_options\n",
    "    )\n",
    "\n",
    "    # Set the other objects\n",
    "    other_objects = []\n",
    "    alt_other_objects = []\n",
    "    other_categories = random.sample(\n",
    "        list(set(categories) - {patch_category, clean_category}),\n",
    "        k=n_options,\n",
    "    )\n",
    "    ##print(f\"{other_categories=}\")\n",
    "    for other_category in other_categories:\n",
    "        other_examples = task.category_wise_examples[other_category]\n",
    "        rand_other_example = random.choice(other_examples)\n",
    "        other_objects.append(rand_other_example)\n",
    "        alt_other_example = random.choice(other_examples)\n",
    "        while rand_other_example == alt_other_example:\n",
    "            alt_other_example = random.choice(other_examples)\n",
    "        alt_other_objects.append(alt_other_example)\n",
    "    #print(f\"{other_objects=}\")\n",
    "    #print(f\"{alt_other_objects=}\")\n",
    "    \n",
    "    # Construct the clean and patch category options\n",
    "    clean_category_options = random.sample(\n",
    "        clean_objects,\n",
    "        amt_to_sample\n",
    "    )\n",
    "    #print(f\"{clean_category_options=}\")\n",
    "    patch_category_options = random.sample(\n",
    "        patch_objects,\n",
    "        amt_to_sample\n",
    "    )\n",
    "    #print(f\"{patch_category_options=}\")\n",
    "\n",
    "\n",
    "    if distinct_options is not True:\n",
    "    \n",
    "        # Combine the clean and patch options\n",
    "        combined_options = clean_category_options + patch_category_options\n",
    "    \n",
    "        # Add items form the other_objects list to pad it to n_options length\n",
    "        for i in range(n_options - len(combined_options)):\n",
    "            combined_options.append(other_objects.pop())\n",
    "\n",
    "        clean_options = combined_options\n",
    "        patch_options = combined_options\n",
    "\n",
    "    else:\n",
    "        # Get alternative patch options\n",
    "        alt_patch_category_options = random.sample(\n",
    "            [obj for obj in patch_objects if obj not in patch_category_options],\n",
    "            amt_to_sample\n",
    "        )\n",
    "        # Get alternative clean options\n",
    "        alt_clean_category_options = random.sample(\n",
    "            [obj for obj in clean_objects if obj not in clean_category_options],\n",
    "            amt_to_sample\n",
    "        )\n",
    "        # Compose prelimiary lists\n",
    "        clean_options = clean_category_options + alt_patch_category_options\n",
    "        patch_options = alt_clean_category_options + patch_category_options\n",
    "\n",
    "        # Add the filler other objects\n",
    "        for i in range(n_options - len(clean_options)):\n",
    "            clean_options.append(other_objects.pop())\n",
    "            patch_options.append(alt_other_objects.pop())\n",
    "\n",
    "        #print(f\"{clean_options=}\")\n",
    "        #print(f\"{patch_options=}\")\n",
    "\n",
    "        # So I have these two lists and I need to make sure that their values line up by category.\n",
    "        # So I start with two lists that definitely have their values line up: clean_category_options and patch_category_options\n",
    "        # And these two lists' items align categorically as well with alt_clean_category_options and alt_patch_category_options respectively.\n",
    "        # So now how to shuffle these in a random but synchronized way while adding more items to the list?\n",
    "        # Probably makes sense to add the additional items to the list first.\n",
    "        # That way I know everything is still in sync.\n",
    "        # Ok, now I need some sort of randomized indexing process.\n",
    "        # Maybe I can make a list of numbers to the range of the options\n",
    "        # And then randomly swap these indices \n",
    "        index_list = list(range(n_options))\n",
    "        #print(f\"{index_list=}\")\n",
    "        random.shuffle(index_list)\n",
    "        #print(f\"{index_list=}\")\n",
    "\n",
    "        clean_options = [clean_options[i] for i in index_list]\n",
    "        patch_options = [patch_options[i] for i in index_list]\n",
    "\n",
    "        #print(f\"{clean_options=}\")\n",
    "        #print(f\"{patch_options=}\")\n",
    "\n",
    "    # Gather the indices of the options of interest\n",
    "    clean_options_indices = []\n",
    "    for option in clean_category_options:\n",
    "        clean_options_indices.append(clean_options.index(option))\n",
    "    clean_first_idx = min(clean_options_indices)\n",
    "\n",
    "    # Store the information about the corresponsing object from patch options\n",
    "    patch_track_type_obj = patch_options[clean_first_idx]\n",
    "    patch_track_first_token_id = get_first_token_id(\n",
    "        patch_track_type_obj, mt.tokenizer, prefix=\" \"\n",
    "    )\n",
    "\n",
    "    patch_options_indices = []\n",
    "    for option in patch_category_options:\n",
    "        patch_options_indices.append(patch_options.index(option))\n",
    "    patch_first_idx = min(patch_options_indices)\n",
    "\n",
    "    # Store the information about the corresponding object from clean options\n",
    "    clean_track_type_obj = clean_options[patch_first_idx]\n",
    "    clean_track_first_token_id = get_first_token_id(\n",
    "        clean_track_type_obj, mt.tokenizer, prefix=\" \"\n",
    "    )\n",
    "    \n",
    "    kwargs = dict(\n",
    "        prompt_template=task.prompt_templates[prompt_template_idx],\n",
    "        default_option_style=option_style,\n",
    "    )\n",
    "\n",
    "    patch_metadata = {\n",
    "        \"track_category\": clean_category,\n",
    "        \"track_type_obj\": patch_track_type_obj,\n",
    "        \"track_type_obj_idx\": clean_first_idx,\n",
    "        \"track_type_obj_token_id\": patch_track_first_token_id,\n",
    "    }\n",
    "    clean_obj = clean_options[clean_first_idx]\n",
    "\n",
    "    clean_metadata = {\n",
    "        \"track_category\": patch_category,\n",
    "        \"track_type_obj\": clean_track_type_obj,\n",
    "        \"track_type_obj_idx\": patch_first_idx,\n",
    "        \"track_type_obj_token_id\": clean_track_first_token_id\n",
    "    }\n",
    "    patch_obj = patch_options[patch_first_idx]\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_first_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_category,\n",
    "        metadata=patch_metadata,\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "    clean_sample = SelectionSample(\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_first_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_category,\n",
    "        metadata=clean_metadata,\n",
    "        **kwargs\n",
    "    )\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, clean_sample]\n",
    "\n",
    "        for sample in test_samples:\n",
    "            if retry_count >= 10: break\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=sample.ans_token_id, options=sample.options, input=tokenized\n",
    "            )\n",
    "            #print(f\"{is_correct=}\")\n",
    "            #print(f\"{predictions=}\")\n",
    "            #print(f\"{track_options=}\")\n",
    "\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            sample.metadata[\"predictions\"] = predictions\n",
    "\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f\"Prediction mismatch!\\n\"\n",
    "                    f\"Retry Count: {retry_count+1}\"\n",
    "                )\n",
    "                return get_counterfactual_samples_within_first_task(\n",
    "                    task=task,\n",
    "                    n_options=n_options,\n",
    "                    filter_by_lm_prediction=True,\n",
    "                    prompt_template_idx=prompt_template_idx,\n",
    "                    option_style=option_style,\n",
    "                    n_distractors=n_distractors,\n",
    "                    retry_count=retry_count+1,\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "980b0cd9",
   "metadata": {},
   "source": [
    "### Construct a validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7a39c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "select_first_validation_set = []\n",
    "select_first_validation_limit = 100\n",
    "\n",
    "while len(select_first_validation_set) < select_first_validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_first_task(\n",
    "        task=select_first_task,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    #select_first_validation_set.append([])\n",
    "    select_first_validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0aa429b5",
   "metadata": {},
   "source": [
    "### Test the accuracy when we patch in our filter heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41b4aded",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "select_first_val_results = []\n",
    "for clean_sample, patch_sample in tqdm(select_first_validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    select_first_val_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639582b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(select_first_val_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually have the position token id rather than the pexisitng\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(select_first_val_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(select_first_val_results)})\"\n",
    ")\n",
    "\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07c7cd89",
   "metadata": {},
   "source": [
    "## Select One - Spanish"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "73742a07",
   "metadata": {},
   "source": [
    "### Load the Selection Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a9c9316",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 3\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task_spanish = TASK_CLS.load(\n",
    "    path=\"data_save/selection/objects_spanish.json\"\n",
    ")\n",
    "\n",
    "print(select_task_spanish.task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68c30bde",
   "metadata": {},
   "source": [
    "### Construct a validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385103f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "spanish_validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(spanish_validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task_spanish,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    spanish_validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c90c037",
   "metadata": {},
   "source": [
    "### Test the accuracy when we patch in our filter heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86b8dac0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "spanish_validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(spanish_validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    spanish_validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "257db153",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(spanish_validation_results[:1]):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    print(f\"{len(int_track)=}\")\n",
    "    print(f\"{len(clean_track)=}\")\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually have the position token id rather than the pexisitng\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(spanish_validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(spanish_validation_results)})\"\n",
    ")\n",
    "\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "794bac07",
   "metadata": {},
   "source": [
    "## Select One - French"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b83eb411",
   "metadata": {},
   "source": [
    "### Load the Selection Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304d9280",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 3\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task_french = TASK_CLS.load(\n",
    "    path=\"data_save/selection/objects_french.json\"\n",
    ")\n",
    "\n",
    "print(select_task_french.task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3108598e",
   "metadata": {},
   "source": [
    "### Construct a validation set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1767221",
   "metadata": {},
   "outputs": [],
   "source": [
    "french_validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(french_validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task_french,\n",
    "        filter_by_lm_prediction=True,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        option_style=OPTION_STYLE,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    french_validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cfe5b3d",
   "metadata": {},
   "source": [
    "### Test the accuracy when we patch in our filter heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fdb188c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "french_validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(french_validation_set):\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={-3:-3, -2:-2, -1:-1},\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    french_validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d3d2cf0",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "\n",
    "for intervention_result in tqdm(french_validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually have the position token id rather than the pexisitng\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "\n",
    "top_1_accuracy = counter_patch_type_top_option / len(french_validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(french_validation_results)})\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebd7d890",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "744df2ce",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "ccf7527a",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3af284c7",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (retrieval2)",
   "language": "python",
   "name": "retrieval2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
