{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b18e648",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa1fcd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import json\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")\n",
    "\n",
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "device_map = get_device_map(model_key, 32, n_gpus=8)\n",
    "print(device_map)\n",
    "\n",
    "print(os.getcwd())\n",
    "\n",
    "from src.models import ModelandTokenizer\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    device_map=device_map,\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "826d56b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from src.selection.data import SelectOneTask\n",
    "\n",
    "# Path of the head mask written by the head-optimization run for this model.\n",
    "# Each path component is passed separately instead of relying on implicit\n",
    "# string-literal concatenation.\n",
    "optimized_path = os.path.join(\n",
    "    \"results\",\n",
    "    \"selection\",\n",
    "    \"optimized_heads\",\n",
    "    mt.name.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    \"select_one\",\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects stored in\n",
    "# the archive; only load files produced by a trusted run.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(\n",
    "    torch.float32\n",
    ")\n",
    "# Drop every head whose first-axis index is 50 or above (the first axis is used as\n",
    "# the layer index below).\n",
    "optimal_head_mask[50:, :] = 0.0\n",
    "\n",
    "# The mask is transposed for display, so columns are layers and rows are heads.\n",
    "plt.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "plt.xlabel(\"layer\")\n",
    "plt.ylabel(\"head\")\n",
    "\n",
    "# (layer_idx, head_idx) pairs of all mask entries above 0.5.\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(\n",
    "        optimal_head_mask > 0.5, as_tuple=False\n",
    "    ).tolist()\n",
    "]\n",
    "print(len(heads_selected))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2c945f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import random\n",
    "from src.selection.utils import KeyedSet, get_first_token_id, verify_correct_option\n",
    "from src.selection.data import SelectionSample\n",
    "from src.functional import predict_next_token\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "######################################################################\n",
    "N_DISTRACTORS = 5\n",
    "######################################################################\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_counterfactual_samples_within_task(\n",
    "    task,\n",
    "    patch_category: str | None = None,\n",
    "    clean_category: str | None = None,\n",
    "    shuffle_clean_options: bool = False,\n",
    "    clean_prompt_template_idx=2,\n",
    "    patch_prompt_template_idx=2,\n",
    "    clean_option_style=\"single_line\",\n",
    "    patch_option_style=\"single_line\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    distinct_options: bool = False,\n",
    "    n_distractors: int = N_DISTRACTORS,\n",
    "):\n",
    "    \"\"\"Draw a (patch, clean) pair of counterfactual selection samples from `task`.\n",
    "\n",
    "    Each option list contains the sample's own answer, one object from the other\n",
    "    sample's category, and one distractor from each of `n_distractors - 1` further\n",
    "    categories.\n",
    "\n",
    "    Args:\n",
    "        task: task object exposing `category_wise_examples` and `prompt_templates`.\n",
    "            Only SelectOneTask instances are supported.\n",
    "        patch_category: category of the patch sample; drawn at random when None.\n",
    "        clean_category: category of the clean sample; when None it is drawn from\n",
    "            the categories other than `patch_category`.\n",
    "        shuffle_clean_options: with shared options, reshuffle the clean option list\n",
    "            until neither the clean answer nor the patch answer sits at the patch\n",
    "            answer's index.\n",
    "        clean_prompt_template_idx: index into `task.prompt_templates` for the clean sample.\n",
    "        patch_prompt_template_idx: index into `task.prompt_templates` for the patch sample.\n",
    "        clean_option_style: forwarded to SelectionSample as `default_option_style`.\n",
    "        patch_option_style: forwarded to SelectionSample as `default_option_style`.\n",
    "        filter_by_lm_prediction: check every prompt with `verify_correct_option`\n",
    "            and redraw the whole pair on the first failure.\n",
    "        distinct_options: when True the clean sample gets its own, separately drawn\n",
    "            option list; otherwise it reuses a copy of the patch options.\n",
    "        n_distractors: sizes the distractor draw for both option lists.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(patch_sample, clean_sample)` of SelectionSample objects.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `task` is not a SelectOneTask.\n",
    "    \"\"\"\n",
    "    categories = list(task.category_wise_examples.keys())\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(categories)\n",
    "\n",
    "    patch_subj, patch_obj = random.sample(\n",
    "        task.category_wise_examples[patch_category], 2\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Patch category: {patch_category}, subject: {patch_subj}, object: {patch_obj}\"\n",
    "    )\n",
    "\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(list(set(categories) - {patch_category}))\n",
    "\n",
    "    # Shuffle a copy: the task's own example lists must not be reordered as a side effect.\n",
    "    clean_options = list(task.category_wise_examples[clean_category])\n",
    "    random.shuffle(clean_options)\n",
    "\n",
    "    clean_subj, clean_obj = random.sample(\n",
    "        (\n",
    "            KeyedSet(clean_options, mt.tokenizer) - KeyedSet([patch_obj], mt.tokenizer)\n",
    "        ).values,\n",
    "        2,\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Clean category: {clean_category}, subject: {clean_subj}, object: {clean_obj}\"\n",
    "    )\n",
    "\n",
    "    if distinct_options is False:\n",
    "        patch_type_obj = patch_obj\n",
    "        clean_type_obj = clean_obj\n",
    "    else:\n",
    "        # Pick a second object of each category, different from that sample's answer.\n",
    "        patch_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[patch_category], mt.tokenizer)\n",
    "                - KeyedSet([patch_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "        clean_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[clean_category], mt.tokenizer)\n",
    "                - KeyedSet([clean_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "\n",
    "    patch_must_have_options = [patch_obj, clean_type_obj]\n",
    "    clean_must_have_options = [clean_obj, patch_type_obj]\n",
    "\n",
    "    logger.info(f\"{patch_must_have_options=}\")\n",
    "    logger.info(f\"{clean_must_have_options=}\")\n",
    "    logger.info(f\"{clean_type_obj=}\")\n",
    "    logger.info(f\"{patch_type_obj=}\")\n",
    "\n",
    "    def _draw_distractors(must_have_options):\n",
    "        \"\"\"Pick one example from each of several random third categories, skipping repeats.\"\"\"\n",
    "        distractors = []\n",
    "        other_categories = random.sample(\n",
    "            list(set(categories) - {patch_category, clean_category}),\n",
    "            k=n_distractors - len(must_have_options) + 1,\n",
    "        )\n",
    "        for other_category in other_categories:\n",
    "            # Copy before shuffling so the task data is left untouched.\n",
    "            other_examples = list(task.category_wise_examples[other_category])\n",
    "            random.shuffle(other_examples)\n",
    "            candidates = KeyedSet(other_examples, mt.tokenizer) - KeyedSet(\n",
    "                must_have_options + distractors,\n",
    "                tokenizer=mt.tokenizer,\n",
    "            )\n",
    "            distractors.append(random.choice(candidates.values))\n",
    "        return distractors\n",
    "\n",
    "    patch_options = patch_must_have_options + _draw_distractors(\n",
    "        patch_must_have_options\n",
    "    )\n",
    "    random.shuffle(patch_options)\n",
    "    patch_obj_idx = patch_options.index(patch_obj)\n",
    "    logger.info(f\"{patch_obj_idx=} | {patch_options}\")\n",
    "\n",
    "    if distinct_options is not True:\n",
    "        clean_options = copy.deepcopy(patch_options)\n",
    "        if shuffle_clean_options:\n",
    "            # Useful for the pointer experiments\n",
    "            while (\n",
    "                clean_options.index(clean_obj) == patch_obj_idx\n",
    "                or clean_options.index(patch_type_obj) == patch_obj_idx\n",
    "            ):\n",
    "                random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    else:\n",
    "        # The clean draw is sized by the n_distractors argument, same as the patch\n",
    "        # draw (previously it read the module-level N_DISTRACTORS).\n",
    "        clean_options = clean_must_have_options + _draw_distractors(\n",
    "            clean_must_have_options\n",
    "        )\n",
    "        random.shuffle(clean_options)\n",
    "        # Keep the clean answer at a different position than the patch answer.\n",
    "        while clean_options.index(clean_obj) == patch_obj_idx:\n",
    "            random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    logger.info(f\"{clean_obj_idx=} | {clean_options}\")\n",
    "\n",
    "    logger.debug(f\"{type(task)=}\")\n",
    "    if isinstance(task, SelectOneTask):\n",
    "        # Each sample tracks the object of the *other* sample's category that\n",
    "        # appears in its own option list.\n",
    "        patch_metadata = {\n",
    "            \"track_category\": clean_category,\n",
    "            \"track_type_obj\": clean_type_obj,\n",
    "            \"track_type_obj_idx\": patch_options.index(clean_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_category\": patch_category,\n",
    "            \"track_type_obj\": patch_type_obj,\n",
    "            \"track_type_obj_idx\": clean_options.index(patch_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Unsupported task type: {type(task)}\")\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        subj=patch_subj,\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_obj_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_category,\n",
    "        metadata=patch_metadata,\n",
    "        prompt_template=task.prompt_templates[patch_prompt_template_idx],\n",
    "        default_option_style=patch_option_style,\n",
    "    )\n",
    "    clean_sample = SelectionSample(\n",
    "        subj=clean_subj,\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_obj_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_category,\n",
    "        metadata=clean_metadata,\n",
    "        prompt_template=task.prompt_templates[clean_prompt_template_idx],\n",
    "        default_option_style=clean_option_style,\n",
    "    )\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, clean_sample]\n",
    "        if distinct_options is True:\n",
    "            # Also check the patch question asked over the clean options, whose\n",
    "            # expected answer is the object tracked by the clean sample.\n",
    "            clean_sample_2 = copy.deepcopy(patch_sample)\n",
    "            clean_sample_2.options = clean_options\n",
    "            clean_sample_2.obj = clean_sample.metadata[\"track_type_obj\"]\n",
    "            clean_sample_2.obj_idx = clean_sample.metadata[\"track_type_obj_idx\"]\n",
    "            clean_sample_2.ans_token_id = clean_sample.metadata[\n",
    "                \"track_type_obj_token_id\"\n",
    "            ]\n",
    "            test_samples.append(clean_sample_2)\n",
    "\n",
    "        for sample in test_samples:\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            logger.info(sample.prompt())\n",
    "            logger.info(\n",
    "                f\"{sample.subj} | {sample.category} -> {sample.obj} | pred={[str(p) for p in predictions]}\"\n",
    "            )\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {sample.ans_token_id}[\"{mt.tokenizer.decode(sample.ans_token_id)}\"]'\n",
    "                )\n",
    "                # NOTE(review): the retry reuses the categories resolved above and is\n",
    "                # unbounded; a pair the model never solves ends in RecursionError.\n",
    "                return get_counterfactual_samples_within_task(\n",
    "                    task=task,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    shuffle_clean_options=shuffle_clean_options,\n",
    "                    clean_prompt_template_idx=clean_prompt_template_idx,\n",
    "                    patch_prompt_template_idx=patch_prompt_template_idx,\n",
    "                    clean_option_style=clean_option_style,\n",
    "                    patch_option_style=patch_option_style,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    distinct_options=distinct_options,\n",
    "                    n_distractors=n_distractors,\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e600a832",
   "metadata": {},
   "source": [
    "## Question: before -> after & after -> before"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bff6cb6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "prompt_template_idx = 3\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=\"data_save/selection/objects.json\"\n",
    ")\n",
    "\n",
    "print(select_task.task_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac5cf025",
   "metadata": {},
   "outputs": [],
   "source": [
    "validation_limit = 100\n",
    "validation_set = []\n",
    "\n",
    "# Keep drawing counterfactual pairs until the limit is reached; each entry is\n",
    "# stored in (clean, patch) order.\n",
    "while validation_limit > len(validation_set):\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "        distinct_options=True,\n",
    "        filter_by_lm_prediction=True,\n",
    "        patch_prompt_template_idx=2,\n",
    "        clean_prompt_template_idx=3,\n",
    "    )\n",
    "    validation_set += [(clean, patch)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebffd00f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.tokens import find_token_range\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "def _last_question_mark_token_idx(sample):\n",
    "    \"\"\"Return the end of the token range of the last \"?\" in the prompt, minus one.\n",
    "\n",
    "    Presumably the index of the last token covering that \"?\" (assumes\n",
    "    find_token_range returns an end-exclusive (start, end) pair -- TODO confirm).\n",
    "    \"\"\"\n",
    "    token_range = find_token_range(\n",
    "        tokenizer=mt.tokenizer,\n",
    "        string=sample.prompt(),\n",
    "        substring=\"?\",\n",
    "        occurrence=-1,\n",
    "    )\n",
    "    return token_range[1] - 1\n",
    "\n",
    "\n",
    "# Run the q_proj intervention on every (clean, patch) pair and keep each result.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    # The question-mark positions are only logged; the mapping that used them is\n",
    "    # commented out in query_indices below.\n",
    "    patch_ques = _last_question_mark_token_idx(patch_sample)\n",
    "    clean_ques = _last_question_mark_token_idx(clean_sample)\n",
    "    logger.debug(\n",
    "        f\"{patch_ques=} | \\\"{mt.tokenizer.decode(patch_sample.metadata['tokenized']['input_ids'][0][patch_ques])}\\\"\"\n",
    "    )\n",
    "    logger.debug(\n",
    "        f\"{clean_ques=} | \\\"{mt.tokenizer.decode(clean_sample.metadata['tokenized']['input_ids'][0][clean_ques])}\\\"\"\n",
    "    )\n",
    "\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        # Maps the last three positions onto themselves.\n",
    "        query_indices={\n",
    "            # patch_ques: clean_ques,\n",
    "            -3: -3,\n",
    "            -2: -2,\n",
    "            -1: -1,\n",
    "        },\n",
    "        verify_head_behavior_on=None,\n",
    "        ablate_possible_ans_info_from_options=True,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f78f02ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "# Score every intervention. A case succeeds when the first entry of int_track\n",
    "# carries the token id tracked by the clean sample.\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    # Logits of the original answer and of the target, before and after the intervention.\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    # NOTE(review): presumably int_track is ordered by rank, so its first entry is\n",
    "    # the top prediction after the intervention -- confirm.\n",
    "    top_token_id = next(iter(int_track.values()))[1].token_id\n",
    "    failed = top_token_id != target_obj\n",
    "\n",
    "    case = {\n",
    "        \"clean_sample\": clean_sample,\n",
    "        \"patch_sample\": patch_sample,\n",
    "        \"int_track\": int_track,\n",
    "        \"clean_track\": clean_track,\n",
    "    }\n",
    "    if failed:\n",
    "        failed_cases.append(case)\n",
    "    else:\n",
    "        counter_patch_type_top_option += 1\n",
    "    all_cases.append(\n",
    "        {\n",
    "            **case,\n",
    "            \"clean_logit_delta\": clean_logit_delta,\n",
    "            \"target_logit_delta\": target_logit_delta,\n",
    "            \"failed\": failed,\n",
    "        }\n",
    "    )\n",
    "\n",
    "# An empty run reports nan instead of raising ZeroDivisionError.\n",
    "n_cases = len(validation_results)\n",
    "top_1_accuracy = (\n",
    "    counter_patch_type_top_option / n_cases if n_cases else float(\"nan\")\n",
    ")\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "\n",
    "avg_clean_logit_delta = (\n",
    "    sum(case[\"clean_logit_delta\"] for case in all_cases) / len(all_cases)\n",
    "    if all_cases\n",
    "    else float(\"nan\")\n",
    ")\n",
    "avg_target_logit_delta = (\n",
    "    sum(case[\"target_logit_delta\"] for case in all_cases) / len(all_cases)\n",
    "    if all_cases\n",
    "    else float(\"nan\")\n",
    ")\n",
    "print(f\"Average clean logit delta: {avg_clean_logit_delta}\")\n",
    "print(f\"Average target logit delta: {avg_target_logit_delta}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db80523d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOrderTask\n",
    "\n",
    "@torch.inference_mode()\n",
    "def get_counterfactual_samples_within_task(\n",
    "    task: SelectOneTask | SelectOrderTask = select_task,\n",
    "    patch_category: str | None = None,\n",
    "    clean_category: str | None = None,\n",
    "    shuffle_clean_options: bool = False,\n",
    "    clean_prompt_template_idx=2,\n",
    "    patch_prompt_template_idx=3,\n",
    "    clean_option_style=\"numbered\",\n",
    "    patch_option_style=\"single_line\",\n",
    "    filter_by_lm_prediction: bool = True,\n",
    "    distinct_options: bool = False,\n",
    "    n_distractors: int = N_DISTRACTORS,\n",
    "):\n",
    "    \"\"\"Draw a (patch, clean) pair of counterfactual selection samples from `task`.\n",
    "\n",
    "    This definition replaces the same-named function defined in an earlier cell\n",
    "    and additionally supports SelectOrderTask.\n",
    "\n",
    "    Each option list contains the sample's own answer, one object from the other\n",
    "    sample's category, and one distractor from each of `n_distractors - 1` further\n",
    "    categories.\n",
    "\n",
    "    Args:\n",
    "        task: SelectOneTask or SelectOrderTask. The default is the `select_task`\n",
    "            object that exists when this cell is executed.\n",
    "        patch_category: category of the patch sample; drawn at random when None.\n",
    "        clean_category: category of the clean sample; when None it is drawn from\n",
    "            the categories other than `patch_category`.\n",
    "        shuffle_clean_options: with shared options, reshuffle the clean option list\n",
    "            until neither the clean answer nor the patch answer sits at the patch\n",
    "            answer's index.\n",
    "        clean_prompt_template_idx: index into `task.prompt_templates` for the clean sample.\n",
    "        patch_prompt_template_idx: index into `task.prompt_templates` for the patch sample.\n",
    "        clean_option_style: forwarded to SelectionSample as `default_option_style`.\n",
    "        patch_option_style: forwarded to SelectionSample as `default_option_style`.\n",
    "        filter_by_lm_prediction: check every prompt with `verify_correct_option`\n",
    "            and redraw the whole pair on the first failure.\n",
    "        distinct_options: when True the clean sample gets its own, separately drawn\n",
    "            option list; otherwise it reuses a copy of the patch options.\n",
    "        n_distractors: sizes the distractor draw for both option lists.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(patch_sample, clean_sample)` of SelectionSample objects.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `task` is neither a SelectOrderTask nor a SelectOneTask.\n",
    "    \"\"\"\n",
    "    categories = list(task.category_wise_examples.keys())\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(categories)\n",
    "\n",
    "    patch_subj, patch_obj = random.sample(\n",
    "        task.category_wise_examples[patch_category], 2\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Patch category: {patch_category}, subject: {patch_subj}, object: {patch_obj}\"\n",
    "    )\n",
    "\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(list(set(categories) - {patch_category}))\n",
    "\n",
    "    # Shuffle a copy: the task's own example lists must not be reordered as a side effect.\n",
    "    clean_options = list(task.category_wise_examples[clean_category])\n",
    "    random.shuffle(clean_options)\n",
    "\n",
    "    clean_subj, clean_obj = random.sample(\n",
    "        (\n",
    "            KeyedSet(clean_options, mt.tokenizer) - KeyedSet([patch_obj], mt.tokenizer)\n",
    "        ).values,\n",
    "        2,\n",
    "    )\n",
    "    logger.info(\n",
    "        f\"Clean category: {clean_category}, subject: {clean_subj}, object: {clean_obj}\"\n",
    "    )\n",
    "\n",
    "    if distinct_options is False:\n",
    "        patch_type_obj = patch_obj\n",
    "        clean_type_obj = clean_obj\n",
    "    else:\n",
    "        # Pick a second object of each category, different from that sample's answer.\n",
    "        patch_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[patch_category], mt.tokenizer)\n",
    "                - KeyedSet([patch_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "        clean_type_obj = random.choice(\n",
    "            (\n",
    "                KeyedSet(task.category_wise_examples[clean_category], mt.tokenizer)\n",
    "                - KeyedSet([clean_obj], mt.tokenizer)\n",
    "            ).values\n",
    "        )\n",
    "\n",
    "    patch_must_have_options = [patch_obj, clean_type_obj]\n",
    "    clean_must_have_options = [clean_obj, patch_type_obj]\n",
    "\n",
    "    logger.info(f\"{patch_must_have_options=}\")\n",
    "    logger.info(f\"{clean_must_have_options=}\")\n",
    "    logger.info(f\"{clean_type_obj=}\")\n",
    "    logger.info(f\"{patch_type_obj=}\")\n",
    "\n",
    "    def _draw_distractors(must_have_options):\n",
    "        \"\"\"Pick one example from each of several random third categories, skipping repeats.\"\"\"\n",
    "        distractors = []\n",
    "        other_categories = random.sample(\n",
    "            list(set(categories) - {patch_category, clean_category}),\n",
    "            k=n_distractors - len(must_have_options) + 1,\n",
    "        )\n",
    "        for other_category in other_categories:\n",
    "            # Copy before shuffling so the task data is left untouched.\n",
    "            other_examples = list(task.category_wise_examples[other_category])\n",
    "            random.shuffle(other_examples)\n",
    "            candidates = KeyedSet(other_examples, mt.tokenizer) - KeyedSet(\n",
    "                must_have_options + distractors,\n",
    "                tokenizer=mt.tokenizer,\n",
    "            )\n",
    "            distractors.append(random.choice(candidates.values))\n",
    "        return distractors\n",
    "\n",
    "    patch_options = patch_must_have_options + _draw_distractors(\n",
    "        patch_must_have_options\n",
    "    )\n",
    "    random.shuffle(patch_options)\n",
    "    patch_obj_idx = patch_options.index(patch_obj)\n",
    "    logger.info(f\"{patch_obj_idx=} | {patch_options}\")\n",
    "\n",
    "    if distinct_options is not True:\n",
    "        clean_options = copy.deepcopy(patch_options)\n",
    "        if shuffle_clean_options:\n",
    "            # Useful for the pointer experiments\n",
    "            while (\n",
    "                clean_options.index(clean_obj) == patch_obj_idx\n",
    "                or clean_options.index(patch_type_obj) == patch_obj_idx\n",
    "            ):\n",
    "                random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    else:\n",
    "        # The clean draw is sized by the n_distractors argument, same as the patch\n",
    "        # draw (previously it read the module-level N_DISTRACTORS).\n",
    "        clean_options = clean_must_have_options + _draw_distractors(\n",
    "            clean_must_have_options\n",
    "        )\n",
    "        random.shuffle(clean_options)\n",
    "        # Keep the clean answer at a different position than the patch answer.\n",
    "        while clean_options.index(clean_obj) == patch_obj_idx:\n",
    "            random.shuffle(clean_options)\n",
    "        clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    logger.info(f\"{clean_obj_idx=} | {clean_options}\")\n",
    "\n",
    "    logger.debug(f\"{type(task)=}\")\n",
    "    if isinstance(task, SelectOrderTask):\n",
    "        # Order task: each sample tracks whichever of its own options sits at the\n",
    "        # other sample's answer index.\n",
    "        patch_metadata = {\n",
    "            \"track_type_obj_idx\": clean_obj_idx,\n",
    "            \"track_type_obj\": patch_options[clean_obj_idx],\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_options[clean_obj_idx], mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_type_obj_idx\": patch_obj_idx,\n",
    "            \"track_type_obj\": clean_options[patch_obj_idx],\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_options[patch_obj_idx], mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    elif isinstance(task, SelectOneTask):\n",
    "        # Select-one task: each sample tracks the object of the *other* sample's\n",
    "        # category that appears in its own option list.\n",
    "        patch_metadata = {\n",
    "            \"track_category\": clean_category,\n",
    "            \"track_type_obj\": clean_type_obj,\n",
    "            \"track_type_obj_idx\": patch_options.index(clean_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                clean_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "        clean_metadata = {\n",
    "            \"track_category\": patch_category,\n",
    "            \"track_type_obj\": patch_type_obj,\n",
    "            \"track_type_obj_idx\": clean_options.index(patch_type_obj),\n",
    "            \"track_type_obj_token_id\": get_first_token_id(\n",
    "                patch_type_obj, mt.tokenizer, prefix=\" \"\n",
    "            ),\n",
    "        }\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Unsupported task type: {type(task)}\")\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        subj=patch_subj,\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_obj_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_category,\n",
    "        metadata=patch_metadata,\n",
    "        prompt_template=task.prompt_templates[patch_prompt_template_idx],\n",
    "        default_option_style=patch_option_style,\n",
    "    )\n",
    "    clean_sample = SelectionSample(\n",
    "        subj=clean_subj,\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_obj_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_category,\n",
    "        metadata=clean_metadata,\n",
    "        prompt_template=task.prompt_templates[clean_prompt_template_idx],\n",
    "        default_option_style=clean_option_style,\n",
    "    )\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        test_samples = [patch_sample, clean_sample]\n",
    "        if distinct_options is True:\n",
    "            # Also check the patch question asked over the clean options, whose\n",
    "            # expected answer is the object tracked by the clean sample.\n",
    "            clean_sample_2 = copy.deepcopy(patch_sample)\n",
    "            clean_sample_2.options = clean_options\n",
    "            clean_sample_2.obj = clean_sample.metadata[\"track_type_obj\"]\n",
    "            clean_sample_2.obj_idx = clean_sample.metadata[\"track_type_obj_idx\"]\n",
    "            clean_sample_2.ans_token_id = clean_sample.metadata[\n",
    "                \"track_type_obj_token_id\"\n",
    "            ]\n",
    "            test_samples.append(clean_sample_2)\n",
    "\n",
    "        for sample in test_samples:\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, track_options = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "            logger.info(sample.prompt())\n",
    "            logger.info(\n",
    "                f\"{sample.subj} | {sample.category} -> {sample.obj} | pred={[str(p) for p in predictions]}\"\n",
    "            )\n",
    "            if not is_correct:\n",
    "                logger.error(\n",
    "                    f'Prediction mismatch: {track_options[list(track_options.keys())[0]]}[\"{mt.tokenizer.decode(predictions[0].token_id)}\"] != {sample.ans_token_id}[\"{mt.tokenizer.decode(sample.ans_token_id)}\"]'\n",
    "                )\n",
    "                # NOTE(review): the retry reuses the categories resolved above and is\n",
    "                # unbounded; a pair the model never solves ends in RecursionError.\n",
    "                return get_counterfactual_samples_within_task(\n",
    "                    task=task,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    shuffle_clean_options=shuffle_clean_options,\n",
    "                    clean_prompt_template_idx=clean_prompt_template_idx,\n",
    "                    patch_prompt_template_idx=patch_prompt_template_idx,\n",
    "                    clean_option_style=clean_option_style,\n",
    "                    patch_option_style=patch_option_style,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    distinct_options=distinct_options,\n",
    "                    n_distractors=n_distractors,\n",
    "                )\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb20c0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a small validation set of counterfactual (clean, patch) sample pairs.\n",
    "from src.functional import free_gpu_cache\n",
    "free_gpu_cache()\n",
    "validation_set = []\n",
    "validation_limit = 10\n",
    "\n",
    "# Keep sampling until `validation_limit` pairs have been collected.\n",
    "while len(validation_set) < validation_limit:\n",
    "    # NOTE(review): no `task=` argument is passed here, unlike the later sampling\n",
    "    # cells that pass `task=select_task` -- confirm the sampler's default is intended.\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        filter_by_lm_prediction=True,\n",
    "        distinct_options=True,\n",
    "        n_distractors=N_DISTRACTORS,\n",
    "        clean_prompt_template_idx=3,\n",
    "        patch_prompt_template_idx=2,\n",
    "        clean_option_style=\"single_line\",\n",
    "        patch_option_style=\"single_line\",\n",
    "    )\n",
    "    # The sampler returns (patch, clean); pairs are stored in (clean, patch) order.\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc8aee45",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Run validate_q_proj_ie_on_sample_pair on every (clean, patch) pair and keep each result.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "    # End index (minus one) of the token range found for the last '?' in each prompt.\n",
    "    # These two positions only feed the debug logs below; the `query_indices` entry\n",
    "    # that used them is commented out.\n",
    "    patch_ques = (\n",
    "        find_token_range(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            string=patch_sample.prompt(),\n",
    "            substring=\"?\",\n",
    "            occurrence=-1,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    clean_ques = (\n",
    "        find_token_range(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            string=clean_sample.prompt(),\n",
    "            substring=\"?\",\n",
    "            occurrence=-1,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    # Log the token decoded at each located position as a sanity check.\n",
    "    logger.debug(\n",
    "        f\"{patch_ques=} | \\\"{mt.tokenizer.decode(patch_sample.metadata['tokenized']['input_ids'][0][patch_ques])}\\\"\"\n",
    "    )\n",
    "    logger.debug(\n",
    "        f\"{clean_ques=} | \\\"{mt.tokenizer.decode(clean_sample.metadata['tokenized']['input_ids'][0][clean_ques])}\\\"\"\n",
    "    )\n",
    "\n",
    "    # `query_indices` pairs token positions; negative values presumably index from the\n",
    "    # end of each prompt, with keys on the patch side and values on the clean side\n",
    "    # (as the commented-out `patch_ques: clean_ques` entry suggests) -- TODO confirm.\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        # heads=HEADS,\n",
    "        heads = heads_selected,\n",
    "        query_indices={\n",
    "            #patch_ques: clean_ques,\n",
    "            -3: -3,\n",
    "            -2: -2,\n",
    "            -1: -1\n",
    "        },\n",
    "        verify_head_behavior_on=None,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "735ddb80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the intervention results: count how often the first tracked token after the\n",
    "# intervention is the target token, and record the logit change of the clean answer\n",
    "# and of the target for every pair.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    # Token ids of the clean answer and of the target used in the success check below.\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    # Both tracks are indexed by token id; element [1] of an entry exposes `.logit`.\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    # Positive delta means the logit increased under the intervention.\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    # The first key of `int_track` is presumably the top-ranked option after the\n",
    "    # intervention -- TODO confirm the ordering of the track.\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually hold the position token id rather than the pre-existing one\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "# Fraction of pairs whose first tracked token after the intervention is the target.\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "\n",
    "# Mean logit change over all pairs.\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3289910f",
   "metadata": {},
   "source": [
    "## Options: single line -> lettered & lettered -> single line"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0674e5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the validation set with different option styles for the clean and patch prompts.\n",
    "# NOTE(review): the section heading mentions 'lettered' options, but the clean style\n",
    "# used here is 'numbered' -- confirm which one is intended.\n",
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "# Keep sampling until `validation_limit` pairs have been collected.\n",
    "while len(validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task,\n",
    "        clean_prompt_template_idx=3,\n",
    "        patch_prompt_template_idx=3,\n",
    "        filter_by_lm_prediction=True,\n",
    "        distinct_options=True,\n",
    "        clean_option_style=\"numbered\",\n",
    "        patch_option_style=\"single_line\",\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    # The sampler returns (patch, clean); pairs are stored in (clean, patch) order.\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "850fb78e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.tokens import find_token_range\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "# Run validate_q_proj_ie_on_sample_pair on every (clean, patch) pair and keep each result.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "\n",
    "    # Disabled: lookup and debug logging of the last '?' position in each prompt.\n",
    "    # patch_ques = (\n",
    "    #     find_token_range(\n",
    "    #         tokenizer=mt.tokenizer,\n",
    "    #         string=patch_sample.prompt(),\n",
    "    #         substring=\"?\",\n",
    "    #         occurrence=-1,\n",
    "    #     )[1]\n",
    "    #     - 1\n",
    "    # )\n",
    "    # clean_ques = (\n",
    "    #     find_token_range(\n",
    "    #         tokenizer=mt.tokenizer,\n",
    "    #         string=clean_sample.prompt(),\n",
    "    #         substring=\"?\",\n",
    "    #         occurrence=-1,\n",
    "    #     )[1]\n",
    "    #     - 1\n",
    "    # )\n",
    "    # logger.debug(\n",
    "    #     f\"{patch_ques=} | \\\"{mt.tokenizer.decode(patch_sample.metadata['tokenized']['input_ids'][0][patch_ques])}\\\"\"\n",
    "    # )\n",
    "    # logger.debug(\n",
    "    #     f\"{clean_ques=} | \\\"{mt.tokenizer.decode(clean_sample.metadata['tokenized']['input_ids'][0][clean_ques])}\\\"\"\n",
    "    # )\n",
    "\n",
    "    # `query_indices` pairs token positions; negative values presumably index from the\n",
    "    # end of each prompt -- TODO confirm against validate_q_proj_ie_on_sample_pair.\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={\n",
    "            #patch_ques: clean_ques,\n",
    "            -3:-3,\n",
    "            -2:-2,\n",
    "            -1:-1\n",
    "        },\n",
    "        verify_head_behavior_on=None,\n",
    "        #ablate_possible_ans_info_from_options=True,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a69922db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the intervention results: count how often the first tracked token after the\n",
    "# intervention is the target token, and record the logit change of the clean answer\n",
    "# and of the target for every pair.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    # Token ids of the clean answer and of the target used in the success check below.\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    # Both tracks are indexed by token id; element [1] of an entry exposes `.logit`.\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    # Positive delta means the logit increased under the intervention.\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    # The first key of `int_track` is presumably the top-ranked option after the\n",
    "    # intervention -- TODO confirm the ordering of the track.\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually hold the position token id rather than the pre-existing one\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "# Fraction of pairs whose first tracked token after the intervention is the target.\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "\n",
    "# `all_cases` is a list of dicts, so it is iterated directly (a list has no `.values()`).\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a919ec5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the validation set: single-line options in the clean prompt, lettered\n",
    "# options in the patch prompt.\n",
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "# Keep sampling until `validation_limit` pairs have been collected.\n",
    "while len(validation_set) < validation_limit:\n",
    "    patch, clean = get_counterfactual_samples_within_task(\n",
    "        task=select_task,\n",
    "        clean_prompt_template_idx=3,\n",
    "        patch_prompt_template_idx=3,\n",
    "        filter_by_lm_prediction=True,\n",
    "        distinct_options=True,\n",
    "        clean_option_style=\"single_line\",\n",
    "        patch_option_style=\"lettered\",\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    # The sampler returns (patch, clean); pairs are stored in (clean, patch) order.\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b647863e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.tokens import find_token_range\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "# Run validate_q_proj_ie_on_sample_pair on every (clean, patch) pair and keep each result.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "\n",
    "    # End index (minus one) of the token range found for the last '?' in each prompt.\n",
    "    # These two positions only feed the debug logs below; the `query_indices` entry\n",
    "    # that used them is commented out.\n",
    "    patch_ques = (\n",
    "        find_token_range(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            string=patch_sample.prompt(),\n",
    "            substring=\"?\",\n",
    "            occurrence=-1,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    clean_ques = (\n",
    "        find_token_range(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            string=clean_sample.prompt(),\n",
    "            substring=\"?\",\n",
    "            occurrence=-1,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    # Log the token decoded at each located position as a sanity check.\n",
    "    logger.debug(\n",
    "        f\"{patch_ques=} | \\\"{mt.tokenizer.decode(patch_sample.metadata['tokenized']['input_ids'][0][patch_ques])}\\\"\"\n",
    "    )\n",
    "    logger.debug(\n",
    "        f\"{clean_ques=} | \\\"{mt.tokenizer.decode(clean_sample.metadata['tokenized']['input_ids'][0][clean_ques])}\\\"\"\n",
    "    )\n",
    "\n",
    "    # `query_indices` pairs token positions; negative values presumably index from the\n",
    "    # end of each prompt -- TODO confirm against validate_q_proj_ie_on_sample_pair.\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={\n",
    "            #patch_ques: clean_ques,\n",
    "            -3:-3,\n",
    "            -2:-2,\n",
    "            -1:-1\n",
    "        },\n",
    "        verify_head_behavior_on=None,\n",
    "        ablate_possible_ans_info_from_options=True,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea3a5f1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score the intervention results: count how often the first tracked token after the\n",
    "# intervention is the target token, and record the logit change of the clean answer\n",
    "# and of the target for every pair.\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    # Token ids of the clean answer and of the target used in the success check below.\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    # Both tracks are indexed by token id; element [1] of an entry exposes `.logit`.\n",
    "    clean_logit_before = clean_track[clean_obj][1].logit\n",
    "    clean_logit_after = int_track[clean_obj][1].logit\n",
    "\n",
    "    target_logit_before = clean_track[target_obj][1].logit\n",
    "    target_logit_after = int_track[target_obj][1].logit\n",
    "\n",
    "    # Positive delta means the logit increased under the intervention.\n",
    "    clean_logit_delta = clean_logit_after - clean_logit_before\n",
    "    target_logit_delta = target_logit_after - target_logit_before\n",
    "\n",
    "    # The first key of `int_track` is presumably the top-ranked option after the\n",
    "    # intervention -- TODO confirm the ordering of the track.\n",
    "    if (\n",
    "        int_track[list(int_track.keys())[0]][1].token_id\n",
    "        == clean_sample.metadata[\"track_type_obj_token_id\"] # In the updated data this will actually hold the position token id rather than the pre-existing one\n",
    "    ):\n",
    "        counter_patch_type_top_option += 1\n",
    "        failed = False\n",
    "    else:\n",
    "        failed_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track\n",
    "            }\n",
    "        )\n",
    "        failed = True\n",
    "    all_cases.append(\n",
    "            {\n",
    "                \"clean_sample\": clean_sample,\n",
    "                \"patch_sample\": patch_sample,\n",
    "                \"int_track\": int_track,\n",
    "                \"clean_track\": clean_track,\n",
    "                \"clean_logit_delta\": clean_logit_delta,\n",
    "                \"target_logit_delta\": target_logit_delta,\n",
    "                \"failed\": failed\n",
    "            }\n",
    "    )\n",
    "\n",
    "# Fraction of pairs whose first tracked token after the intervention is the target.\n",
    "top_1_accuracy = counter_patch_type_top_option / len(validation_results)\n",
    "\n",
    "print(\n",
    "    f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "    f\"({counter_patch_type_top_option}/{len(validation_results)})\"\n",
    ")\n",
    "\n",
    "# `all_cases` is a list of dicts, so it is iterated directly (a list has no `.values()`).\n",
    "print(f\"Average clean logit delta: {sum(case['clean_logit_delta'] for case in all_cases) / len(all_cases)}\")\n",
    "print(f\"Average target logit delta: {sum(case['target_logit_delta'] for case in all_cases) / len(all_cases)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e799028b",
   "metadata": {},
   "source": [
    "## Language: English -> French & Spanish -> English"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73f4b0cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the cross-language experiment, plus the Spanish version of the task.\n",
    "TASK_CLS = SelectOneTask\n",
    "N_DISTRACTORS = 5\n",
    "# NOTE(review): `prompt_template_idx` and `OPTION_STYLE` are not read by the sampling\n",
    "# cell that follows, which passes template indices and option styles explicitly.\n",
    "prompt_template_idx = 2\n",
    "OPTION_STYLE = \"single_line\"\n",
    "\n",
    "# Load the Spanish task file (the path is relative to the current working directory).\n",
    "select_task_spanish = TASK_CLS.load(\n",
    "    path=\"data_save/selection/objects_spanish.json\"\n",
    ")\n",
    "\n",
    "print(select_task_spanish.task_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288bb159",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def get_counterfactual_samples_across_language(\n",
    "    clean_task,\n",
    "    patch_task,\n",
    "    clean_lang = \"english\",\n",
    "    patch_lang = \"spanish\",\n",
    "    clean_prompt_template_idx=3,\n",
    "    patch_prompt_template_idx=3,\n",
    "    clean_option_style=\"single_line\",\n",
    "    patch_option_style=\"single_line\",\n",
    "    patch_category = None,\n",
    "    clean_category = None,\n",
    "    n_distractors = 5,\n",
    "    filter_by_lm_prediction=True,\n",
    "    retry_count = 0,\n",
    "    retry_limit = 20,\n",
    "):\n",
    "    \"\"\"Sample a counterfactual (patch, clean) pair of selection prompts from two tasks.\n",
    "\n",
    "    Each prompt's options hold its own answer, one object of the other prompt's\n",
    "    category (drawn from the prompt's own task), and one distractor from each of\n",
    "    `n_distractors - 1` further categories, i.e. `n_distractors + 1` options in total.\n",
    "    Category names are looked up in both tasks' `category_wise_examples`.\n",
    "\n",
    "    Args:\n",
    "        clean_task: Task providing the clean prompt's examples and prompt templates.\n",
    "        patch_task: Task providing the patch prompt's examples and prompt templates.\n",
    "        clean_lang: Language of the clean prompt. Any value other than 'english' must\n",
    "            be a key of the category translation table below ('spanish', 'french').\n",
    "        patch_lang: Language of the patch prompt; same rule as `clean_lang`.\n",
    "        clean_prompt_template_idx: Index into `clean_task.prompt_templates`.\n",
    "        patch_prompt_template_idx: Index into `patch_task.prompt_templates`.\n",
    "        clean_option_style: Passed to the clean sample as `default_option_style`.\n",
    "        patch_option_style: Passed to the patch sample as `default_option_style`.\n",
    "        patch_category: Category of the patch prompt; chosen at random when None.\n",
    "        clean_category: Category of the clean prompt; chosen at random when None.\n",
    "        n_distractors: Controls the number of distractor categories (see above).\n",
    "        filter_by_lm_prediction: When True, both samples are checked with\n",
    "            `verify_correct_option` and the pair is resampled if either check fails.\n",
    "        retry_count: Number of resampling attempts already made (used by the recursion).\n",
    "        retry_limit: Maximum number of resampling attempts.\n",
    "\n",
    "    Returns:\n",
    "        A `(patch_sample, clean_sample)` tuple, or None when `filter_by_lm_prediction`\n",
    "        is set and `retry_count` has reached `retry_limit`.\n",
    "    \"\"\"\n",
    "    # Pick a category (unless given) and two distinct examples for the clean prompt.\n",
    "    clean_categories = list(clean_task.category_wise_examples.keys())\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(clean_categories)\n",
    "    clean_subj, clean_obj = random.sample(\n",
    "        clean_task.category_wise_examples[clean_category], 2\n",
    "    )\n",
    "\n",
    "    # Same for the patch prompt.\n",
    "    patch_categories = list(patch_task.category_wise_examples.keys())\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(patch_categories)\n",
    "    patch_subj, patch_obj = random.sample(\n",
    "        patch_task.category_wise_examples[patch_category], 2\n",
    "    )\n",
    "\n",
    "    # Object of the patch category that goes into the clean options.\n",
    "    # NOTE(review): only `patch_obj` is excluded here, so if both categories happen to\n",
    "    # be equal this can pick `clean_obj` itself -- confirm whether that can occur.\n",
    "    patch_type_obj = random.choice(\n",
    "        (\n",
    "            KeyedSet(clean_task.category_wise_examples[patch_category], mt.tokenizer)\n",
    "            - KeyedSet([patch_obj], mt.tokenizer)\n",
    "        ).values\n",
    "    )\n",
    "    # Object of the clean category that goes into the patch options.\n",
    "    clean_type_obj = random.choice(\n",
    "        (\n",
    "            KeyedSet(patch_task.category_wise_examples[clean_category], mt.tokenizer)\n",
    "            - KeyedSet([clean_obj], mt.tokenizer)\n",
    "        ).values\n",
    "    )\n",
    "\n",
    "    patch_must_have_options = [patch_obj, clean_type_obj]\n",
    "    clean_must_have_options = [clean_obj, patch_type_obj]\n",
    "\n",
    "    patch_distractors = []\n",
    "    clean_distractors = []\n",
    "\n",
    "    # Distractor categories are shared by both prompts and exclude the two main ones.\n",
    "    other_categories = random.sample(\n",
    "        list(set(clean_categories) - {patch_category, clean_category}),\n",
    "        k=n_distractors - (len(patch_must_have_options)) + 1,\n",
    "    )\n",
    "\n",
    "    print(f\"{other_categories=}\")\n",
    "\n",
    "    for other_category in other_categories:\n",
    "        # Shuffle a copy: the task's own example lists must not be reordered in place.\n",
    "        clean_pool = clean_task.category_wise_examples[other_category]\n",
    "        other_clean_examples = KeyedSet(\n",
    "            random.sample(clean_pool, len(clean_pool)), mt.tokenizer\n",
    "        )\n",
    "        clean_distractors.append(\n",
    "            random.choice(\n",
    "                (\n",
    "                    other_clean_examples\n",
    "                    - KeyedSet(\n",
    "                        clean_must_have_options + clean_distractors,\n",
    "                        tokenizer=mt.tokenizer,\n",
    "                    )\n",
    "                ).values\n",
    "            )\n",
    "        )\n",
    "\n",
    "        patch_pool = patch_task.category_wise_examples[other_category]\n",
    "        other_patch_examples = KeyedSet(\n",
    "            random.sample(patch_pool, len(patch_pool)), mt.tokenizer\n",
    "        )\n",
    "        patch_distractors.append(\n",
    "            random.choice(\n",
    "                (\n",
    "                    other_patch_examples\n",
    "                    - KeyedSet(\n",
    "                        patch_must_have_options + patch_distractors,\n",
    "                        tokenizer=mt.tokenizer,\n",
    "                    )\n",
    "                ).values\n",
    "            )\n",
    "        )\n",
    "\n",
    "    print(f\"{clean_distractors=}\")\n",
    "    print(f\"{patch_distractors=}\")\n",
    "\n",
    "    clean_options = clean_must_have_options + clean_distractors\n",
    "    patch_options = patch_must_have_options + patch_distractors\n",
    "\n",
    "    print(f\"{clean_options=}\")\n",
    "    print(f\"{patch_options=}\")\n",
    "\n",
    "    random.shuffle(clean_options)\n",
    "    random.shuffle(patch_options)\n",
    "\n",
    "    patch_obj_idx = patch_options.index(patch_obj)\n",
    "    clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    # Each sample's metadata describes the option that belongs to the other prompt's\n",
    "    # category: its name, its position in the options, and its first token id.\n",
    "    patch_metadata = {\n",
    "        \"track_category\": clean_category,\n",
    "        \"track_type_obj\": clean_type_obj,\n",
    "        \"track_type_obj_idx\": patch_options.index(clean_type_obj),\n",
    "        \"track_type_obj_token_id\": get_first_token_id(\n",
    "            clean_type_obj, mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "    clean_metadata = {\n",
    "        \"track_category\": patch_category,\n",
    "        \"track_type_obj\": patch_type_obj,\n",
    "        \"track_type_obj_idx\": clean_options.index(patch_type_obj),\n",
    "        \"track_type_obj_token_id\": get_first_token_id(\n",
    "            patch_type_obj, mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    print(f\"{patch_metadata=}\")\n",
    "    print(f\"{clean_metadata=}\")\n",
    "\n",
    "    # Category names shown in non-English prompts, keyed by the English category name.\n",
    "    lang_categories = {\n",
    "        \"spanish\": {\n",
    "            \"fruit\": \"fruta\",\n",
    "            \"vehicle\": \"vehículo\",\n",
    "            \"furniture\": \"mueble\",\n",
    "            \"animal\": \"animal\",\n",
    "            \"music instrument\": \"instrumento musical\",\n",
    "            \"clothing\": \"ropa\",\n",
    "            \"electronics\": \"electrónica\",\n",
    "            \"sport equipment\": \"equipo deportivo\",\n",
    "            \"kitchen appliance\": \"electrodoméstico de cocina\",\n",
    "            \"vegetable\": \"verdura\",\n",
    "            \"building\": \"edificio\",\n",
    "            \"office supply\": \"material de oficina\",\n",
    "            \"bathroom item\": \"artículo de baño\",\n",
    "            \"flower\": \"flor\",\n",
    "            \"tree\": \"árbol\",\n",
    "            \"jewelry\": \"joyería\"\n",
    "        },\n",
    "        \"french\": {\n",
    "            \"fruit\": \"fruit\",\n",
    "            \"vehicle\": \"véhicule\",\n",
    "            \"furniture\": \"meuble\",\n",
    "            \"animal\": \"animal\",\n",
    "            \"music instrument\": \"instrument de musique\",\n",
    "            \"clothing\": \"vêtement\",\n",
    "            \"electronics\": \"électronique\",\n",
    "            \"sport equipment\": \"équipement sportif\",\n",
    "            \"kitchen appliance\": \"appareil de cuisine\",\n",
    "            \"vegetable\": \"légume\",\n",
    "            \"building\": \"bâtiment\",\n",
    "            \"office supply\": \"fourniture de bureau\",\n",
    "            \"bathroom item\": \"article de salle de bain\",\n",
    "            \"flower\": \"fleur\",\n",
    "            \"tree\": \"arbre\",\n",
    "            \"jewelry\": \"bijou\"\n",
    "        }\n",
    "    }\n",
    "\n",
    "    if clean_lang != \"english\":\n",
    "        clean_lang_category = lang_categories[clean_lang][clean_category]\n",
    "    else:\n",
    "        clean_lang_category = clean_category\n",
    "\n",
    "    if patch_lang != \"english\":\n",
    "        patch_lang_category = lang_categories[patch_lang][patch_category]\n",
    "    else:\n",
    "        patch_lang_category = patch_category\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        subj=patch_subj,\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_obj_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_lang_category,\n",
    "        metadata=patch_metadata,\n",
    "        prompt_template=patch_task.prompt_templates[patch_prompt_template_idx],\n",
    "        default_option_style=patch_option_style,\n",
    "        language=patch_lang,\n",
    "    )\n",
    "    clean_sample = SelectionSample(\n",
    "        subj=clean_subj,\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_obj_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_lang_category,\n",
    "        metadata=clean_metadata,\n",
    "        prompt_template=clean_task.prompt_templates[clean_prompt_template_idx],\n",
    "        default_option_style=clean_option_style,\n",
    "        language=clean_lang,\n",
    "    )\n",
    "\n",
    "    if filter_by_lm_prediction:\n",
    "        if retry_count >= retry_limit:\n",
    "            print(\"WARNING: Retry Limit Reached!\")\n",
    "            # Callers must handle None: no verified pair was found within the limit.\n",
    "            return None\n",
    "\n",
    "        for sample in (patch_sample, clean_sample):\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "            is_correct, predictions, _ = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "\n",
    "            if not is_correct:\n",
    "                # Resample the whole pair. The categories chosen on this attempt are\n",
    "                # reused, and `retry_limit` is forwarded so a caller-supplied limit\n",
    "                # stays in effect on every retry.\n",
    "                return get_counterfactual_samples_across_language(\n",
    "                    clean_task=clean_task,\n",
    "                    patch_task=patch_task,\n",
    "                    clean_lang=clean_lang,\n",
    "                    patch_lang=patch_lang,\n",
    "                    clean_prompt_template_idx=clean_prompt_template_idx,\n",
    "                    patch_prompt_template_idx=patch_prompt_template_idx,\n",
    "                    clean_option_style=clean_option_style,\n",
    "                    patch_option_style=patch_option_style,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    n_distractors=n_distractors,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    retry_count=retry_count + 1,\n",
    "                    retry_limit=retry_limit,\n",
    "                )\n",
    "\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40b4b498",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the cross-language validation set: English clean prompts, Spanish patch prompts.\n",
    "validation_set = []\n",
    "validation_limit = 1024\n",
    "\n",
    "while len(validation_set) < validation_limit:\n",
    "    # `distinct_options` is not a parameter of get_counterfactual_samples_across_language,\n",
    "    # so it is not passed here (passing it raises TypeError).\n",
    "    sampled = get_counterfactual_samples_across_language(\n",
    "        clean_task=select_task,\n",
    "        patch_task=select_task_spanish,\n",
    "        clean_lang = \"english\",\n",
    "        patch_lang = \"spanish\",\n",
    "        clean_prompt_template_idx=3,\n",
    "        patch_prompt_template_idx=3,\n",
    "        patch_category = None,\n",
    "        clean_category = None,\n",
    "        filter_by_lm_prediction=True,\n",
    "        clean_option_style=\"single_line\",\n",
    "        patch_option_style=\"single_line\",\n",
    "        n_distractors=N_DISTRACTORS\n",
    "    )\n",
    "    # The sampler returns None when its retry limit is reached; skip and sample again.\n",
    "    if sampled is None:\n",
    "        continue\n",
    "    patch, clean = sampled\n",
    "    # The sampler returns (patch, clean); pairs are stored in (clean, patch) order.\n",
    "    validation_set.append((clean, patch))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a2aa67",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from src.tokens import find_token_range\n",
    "from src.selection.optimization import validate_q_proj_ie_on_sample_pair\n",
    "\n",
    "# Run validate_q_proj_ie_on_sample_pair on every (clean, patch) pair and keep each result.\n",
    "validation_results = []\n",
    "for clean_sample, patch_sample in tqdm(validation_set):\n",
    "\n",
    "    # End index (minus one) of the token range found for the last '?' in each prompt.\n",
    "    # These two positions only feed the debug logs below; the `query_indices` entry\n",
    "    # that used them is commented out.\n",
    "    patch_ques = (\n",
    "        find_token_range(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            string=patch_sample.prompt(),\n",
    "            substring=\"?\",\n",
    "            occurrence=-1,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    clean_ques = (\n",
    "        find_token_range(\n",
    "            tokenizer=mt.tokenizer,\n",
    "            string=clean_sample.prompt(),\n",
    "            substring=\"?\",\n",
    "            occurrence=-1,\n",
    "        )[1]\n",
    "        - 1\n",
    "    )\n",
    "    # Log the token decoded at each located position as a sanity check.\n",
    "    logger.debug(\n",
    "        f\"{patch_ques=} | \\\"{mt.tokenizer.decode(patch_sample.metadata['tokenized']['input_ids'][0][patch_ques])}\\\"\"\n",
    "    )\n",
    "    logger.debug(\n",
    "        f\"{clean_ques=} | \\\"{mt.tokenizer.decode(clean_sample.metadata['tokenized']['input_ids'][0][clean_ques])}\\\"\"\n",
    "    )\n",
    "\n",
    "    # `query_indices` pairs token positions; negative values presumably index from the\n",
    "    # end of each prompt -- TODO confirm against validate_q_proj_ie_on_sample_pair.\n",
    "    result = validate_q_proj_ie_on_sample_pair(\n",
    "        mt=mt,\n",
    "        clean_sample=clean_sample,\n",
    "        patch_sample=patch_sample,\n",
    "        heads=heads_selected,\n",
    "        query_indices={\n",
    "            #patch_ques: clean_ques,\n",
    "            -3:-3,\n",
    "            -2:-2,\n",
    "            -1:-1\n",
    "        },\n",
    "        verify_head_behavior_on=None,\n",
    "        ablate_possible_ans_info_from_options=True,\n",
    "    )\n",
    "    validation_results.append(result)\n",
    "    print(\"=\" * 80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e9569a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score every q_proj patching run collected in `validation_results`:\n",
    "#   * success  -> the first entry of the post-intervention track is the counterfactual target\n",
    "#   * deltas   -> logit change (after - before) of the clean answer and of the target\n",
    "counter_patch_type_top_option = 0\n",
    "failed_cases = []\n",
    "all_cases = []\n",
    "\n",
    "for intervention_result in tqdm(validation_results):\n",
    "    clean_sample = intervention_result[\"clean_sample\"]\n",
    "    patch_sample = intervention_result[\"patch_sample\"]\n",
    "    int_track = intervention_result[\"int_track\"]\n",
    "    clean_track = intervention_result[\"clean_track\"]\n",
    "\n",
    "    clean_obj = clean_sample.ans_token_id\n",
    "    # NOTE: in the updated data this metadata field holds the position token id\n",
    "    # rather than the pre-existing one.\n",
    "    target_obj = clean_sample.metadata[\"track_type_obj_token_id\"]\n",
    "\n",
    "    # Track entries are indexed as entry[1] to reach the record carrying .logit / .token_id.\n",
    "    clean_logit_delta = (\n",
    "        int_track[clean_obj][1].logit - clean_track[clean_obj][1].logit\n",
    "    )\n",
    "    target_logit_delta = (\n",
    "        int_track[target_obj][1].logit - clean_track[target_obj][1].logit\n",
    "    )\n",
    "\n",
    "    # Compare the first tracked entry with the target\n",
    "    # (presumably the track is ordered by rank -- TODO confirm against the tracker).\n",
    "    first_tracked = int_track[next(iter(int_track.keys()))]\n",
    "    failed = not (first_tracked[1].token_id == target_obj)\n",
    "\n",
    "    case = {\n",
    "        \"clean_sample\": clean_sample,\n",
    "        \"patch_sample\": patch_sample,\n",
    "        \"int_track\": int_track,\n",
    "        \"clean_track\": clean_track,\n",
    "    }\n",
    "    if failed:\n",
    "        failed_cases.append(case)\n",
    "    else:\n",
    "        counter_patch_type_top_option += 1\n",
    "\n",
    "    all_cases.append(\n",
    "        {\n",
    "            **case,\n",
    "            \"clean_logit_delta\": clean_logit_delta,\n",
    "            \"target_logit_delta\": target_logit_delta,\n",
    "            \"failed\": failed,\n",
    "        }\n",
    "    )\n",
    "\n",
    "n_cases = len(all_cases)\n",
    "\n",
    "if n_cases == 0:\n",
    "    # Nothing was validated: avoid a ZeroDivisionError and keep the name defined.\n",
    "    top_1_accuracy = float(\"nan\")\n",
    "    print(\"No validation results to summarize.\")\n",
    "else:\n",
    "    top_1_accuracy = counter_patch_type_top_option / n_cases\n",
    "    # `all_cases` is a list of dicts, so iterate it directly (it has no `.values()`).\n",
    "    avg_clean_logit_delta = (\n",
    "        sum(case[\"clean_logit_delta\"] for case in all_cases) / n_cases\n",
    "    )\n",
    "    avg_target_logit_delta = (\n",
    "        sum(case[\"target_logit_delta\"] for case in all_cases) / n_cases\n",
    "    )\n",
    "\n",
    "    print(\n",
    "        f\"Counterfactual patching accuracy: {top_1_accuracy:.4f}\",\n",
    "        f\"({counter_patch_type_top_option}/{n_cases})\",\n",
    "    )\n",
    "    print(f\"Average clean logit delta: {avg_clean_logit_delta}\")\n",
    "    print(f\"Average target logit delta: {avg_target_logit_delta}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2ff5935",
   "metadata": {},
   "source": [
    "## Task: Select One -> Counting & Select One -> Yes No"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ebf854a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def get_counterfactual_samples_across_task(\n",
    "    clean_task,\n",
    "    patch_task,\n",
    "    clean_prompt_template_idx=1,\n",
    "    patch_prompt_template_idx=3,\n",
    "    patch_category=None,\n",
    "    clean_category=None,\n",
    "    n_distractors=5,\n",
    "    n_options=5,\n",
    "    filter_by_lm_prediction=True,\n",
    "    retry_count=0,\n",
    "    retry_limit=20,\n",
    "    clean_lang=\"english\",\n",
    "    patch_lang=\"english\",\n",
    "    clean_option_style=None,\n",
    "    patch_option_style=None,\n",
    "):\n",
    "    \"\"\"Build a counterfactual (patch_sample, clean_sample) pair drawn from two tasks.\n",
    "\n",
    "    A category is picked for each side (randomly unless given) and a subject/object\n",
    "    pair is sampled from it. Each side's option list contains its own answer, one\n",
    "    object belonging to the *other* side's category (recorded in the sample metadata\n",
    "    under the ``track_*`` keys), and one distractor from each of ``n_distractors - 1``\n",
    "    further categories. Option order is shuffled.\n",
    "\n",
    "    Args:\n",
    "        clean_task: task providing ``category_wise_examples`` and ``prompt_templates``\n",
    "            for the clean sample.\n",
    "        patch_task: same, for the patch sample.\n",
    "        clean_prompt_template_idx: index into ``clean_task.prompt_templates``.\n",
    "        patch_prompt_template_idx: index into ``patch_task.prompt_templates``.\n",
    "        patch_category: category of the patch answer; random when ``None``.\n",
    "        clean_category: category of the clean answer; random when ``None``.\n",
    "        n_distractors: controls how many extra categories contribute a distractor\n",
    "            (``n_distractors - 1`` of them are sampled).\n",
    "        n_options: currently unused; kept for interface compatibility.\n",
    "        filter_by_lm_prediction: when true, both samples are checked with\n",
    "            ``verify_correct_option`` and the pair is re-sampled (same categories)\n",
    "            if either check fails.\n",
    "        retry_count: number of re-sampling attempts already made (internal).\n",
    "        retry_limit: maximum value of ``retry_count`` before giving up.\n",
    "        clean_lang: language of the clean sample; any value other than ``\"english\"``\n",
    "            must be a key of the built-in category translation table.\n",
    "        patch_lang: language of the patch sample, same convention as ``clean_lang``.\n",
    "        clean_option_style: forwarded as ``default_option_style`` to the clean\n",
    "            ``SelectionSample``; when ``None`` the argument is omitted\n",
    "            (assumes the class defines its own default -- TODO confirm).\n",
    "        patch_option_style: same, for the patch ``SelectionSample``.\n",
    "\n",
    "    Returns:\n",
    "        ``(patch_sample, clean_sample)``, or ``None`` when filtering is enabled and\n",
    "        ``retry_limit`` has been reached.\n",
    "    \"\"\"\n",
    "    # --- pick categories and the subject/answer of each side --------------------\n",
    "    clean_categories = list(clean_task.category_wise_examples.keys())\n",
    "    if clean_category is None:\n",
    "        clean_category = random.choice(clean_categories)\n",
    "    clean_subj, clean_obj = random.sample(\n",
    "        clean_task.category_wise_examples[clean_category], 2\n",
    "    )\n",
    "\n",
    "    patch_categories = list(patch_task.category_wise_examples.keys())\n",
    "    if patch_category is None:\n",
    "        patch_category = random.choice(patch_categories)\n",
    "    patch_subj, patch_obj = random.sample(\n",
    "        patch_task.category_wise_examples[patch_category], 2\n",
    "    )\n",
    "\n",
    "    print(f\"{clean_subj=}\")\n",
    "    print(f\"{clean_obj=}\")\n",
    "    print(f\"{patch_subj=}\")\n",
    "    print(f\"{patch_obj=}\")\n",
    "\n",
    "    # --- cross-category options: each prompt also lists an item of the other\n",
    "    # side's category, which is what the intervention is expected to select ------\n",
    "    patch_type_obj = random.choice(\n",
    "        (\n",
    "            KeyedSet(clean_task.category_wise_examples[patch_category], mt.tokenizer)\n",
    "            - KeyedSet([patch_obj], mt.tokenizer)\n",
    "        ).values\n",
    "    )\n",
    "    clean_type_obj = random.choice(\n",
    "        (\n",
    "            KeyedSet(patch_task.category_wise_examples[clean_category], mt.tokenizer)\n",
    "            - KeyedSet([clean_obj], mt.tokenizer)\n",
    "        ).values\n",
    "    )\n",
    "\n",
    "    patch_must_have_options = [patch_obj, clean_type_obj]\n",
    "    clean_must_have_options = [clean_obj, patch_type_obj]\n",
    "\n",
    "    # --- distractors: one per additional category, never repeating an option ----\n",
    "    patch_distractors = []\n",
    "    clean_distractors = []\n",
    "\n",
    "    other_categories = random.sample(\n",
    "        list(set(clean_categories) - {patch_category, clean_category}),\n",
    "        k=n_distractors - (len(patch_must_have_options)) + 1,\n",
    "    )\n",
    "\n",
    "    print(f\"{other_categories=}\")\n",
    "\n",
    "    for other_category in other_categories:\n",
    "        # Shuffle a copy so the task's own example lists are not reordered in place.\n",
    "        other_clean_examples = list(clean_task.category_wise_examples[other_category])\n",
    "        random.shuffle(other_clean_examples)\n",
    "        clean_candidates = KeyedSet(other_clean_examples, mt.tokenizer) - KeyedSet(\n",
    "            clean_must_have_options + clean_distractors,\n",
    "            tokenizer=mt.tokenizer,\n",
    "        )\n",
    "        clean_distractors.append(random.choice(clean_candidates.values))\n",
    "\n",
    "        other_patch_examples = list(patch_task.category_wise_examples[other_category])\n",
    "        random.shuffle(other_patch_examples)\n",
    "        patch_candidates = KeyedSet(other_patch_examples, mt.tokenizer) - KeyedSet(\n",
    "            patch_must_have_options + patch_distractors,\n",
    "            tokenizer=mt.tokenizer,\n",
    "        )\n",
    "        patch_distractors.append(random.choice(patch_candidates.values))\n",
    "\n",
    "    print(f\"{clean_distractors=}\")\n",
    "    print(f\"{patch_distractors=}\")\n",
    "\n",
    "    clean_options = clean_must_have_options + clean_distractors\n",
    "    patch_options = patch_must_have_options + patch_distractors\n",
    "\n",
    "    print(f\"{clean_options=}\")\n",
    "    print(f\"{patch_options=}\")\n",
    "\n",
    "    random.shuffle(clean_options)\n",
    "    random.shuffle(patch_options)\n",
    "\n",
    "    patch_obj_idx = patch_options.index(patch_obj)\n",
    "    clean_obj_idx = clean_options.index(clean_obj)\n",
    "\n",
    "    # --- metadata describing the cross-category option of each prompt -----------\n",
    "    patch_metadata = {\n",
    "        \"track_category\": clean_category,\n",
    "        \"track_type_obj\": clean_type_obj,\n",
    "        \"track_type_obj_idx\": patch_options.index(clean_type_obj),\n",
    "        \"track_type_obj_token_id\": get_first_token_id(\n",
    "            clean_type_obj, mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "    clean_metadata = {\n",
    "        \"track_category\": patch_category,\n",
    "        \"track_type_obj\": patch_type_obj,\n",
    "        \"track_type_obj_idx\": clean_options.index(patch_type_obj),\n",
    "        \"track_type_obj_token_id\": get_first_token_id(\n",
    "            patch_type_obj, mt.tokenizer, prefix=\" \"\n",
    "        ),\n",
    "    }\n",
    "\n",
    "    print(f\"{patch_metadata=}\")\n",
    "    print(f\"{clean_metadata=}\")\n",
    "\n",
    "    # --- category names translated for non-English prompts ----------------------\n",
    "    lang_categories = {\n",
    "        \"spanish\": {\n",
    "            \"fruit\": \"fruta\",\n",
    "            \"vehicle\": \"vehículo\",\n",
    "            \"furniture\": \"mueble\",\n",
    "            \"animal\": \"animal\",\n",
    "            \"music instrument\": \"instrumento musical\",\n",
    "            \"clothing\": \"ropa\",\n",
    "            \"electronics\": \"electrónica\",\n",
    "            \"sport equipment\": \"equipo deportivo\",\n",
    "            \"kitchen appliance\": \"electrodoméstico de cocina\",\n",
    "            \"vegetable\": \"verdura\",\n",
    "            \"building\": \"edificio\",\n",
    "            \"office supply\": \"material de oficina\",\n",
    "            \"bathroom item\": \"artículo de baño\",\n",
    "            \"flower\": \"flor\",\n",
    "            \"tree\": \"árbol\",\n",
    "            \"jewelry\": \"joyería\",\n",
    "        },\n",
    "        \"french\": {\n",
    "            \"fruit\": \"fruit\",\n",
    "            \"vehicle\": \"véhicule\",\n",
    "            \"furniture\": \"meuble\",\n",
    "            \"animal\": \"animal\",\n",
    "            \"music instrument\": \"instrument de musique\",\n",
    "            \"clothing\": \"vêtement\",\n",
    "            \"electronics\": \"électronique\",\n",
    "            \"sport equipment\": \"équipement sportif\",\n",
    "            \"kitchen appliance\": \"appareil de cuisine\",\n",
    "            \"vegetable\": \"légume\",\n",
    "            \"building\": \"bâtiment\",\n",
    "            \"office supply\": \"fourniture de bureau\",\n",
    "            \"bathroom item\": \"article de salle de bain\",\n",
    "            \"flower\": \"fleur\",\n",
    "            \"tree\": \"arbre\",\n",
    "            \"jewelry\": \"bijou\",\n",
    "        },\n",
    "    }\n",
    "\n",
    "    def _localize(category, lang):\n",
    "        # English prompts use the category key itself; other languages are looked up.\n",
    "        return category if lang == \"english\" else lang_categories[lang][category]\n",
    "\n",
    "    clean_lang_category = _localize(clean_category, clean_lang)\n",
    "    patch_lang_category = _localize(patch_category, patch_lang)\n",
    "\n",
    "    # Only forward an option style when the caller supplied one.\n",
    "    patch_style_kwargs = (\n",
    "        {} if patch_option_style is None else {\"default_option_style\": patch_option_style}\n",
    "    )\n",
    "    clean_style_kwargs = (\n",
    "        {} if clean_option_style is None else {\"default_option_style\": clean_option_style}\n",
    "    )\n",
    "\n",
    "    patch_sample = SelectionSample(\n",
    "        subj=patch_subj,\n",
    "        obj=patch_obj,\n",
    "        answer=patch_obj,\n",
    "        obj_idx=patch_obj_idx,\n",
    "        ans_token_id=get_first_token_id(patch_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=patch_options,\n",
    "        category=patch_lang_category,\n",
    "        metadata=patch_metadata,\n",
    "        prompt_template=patch_task.prompt_templates[patch_prompt_template_idx],\n",
    "        language=patch_lang,\n",
    "        **patch_style_kwargs,\n",
    "    )\n",
    "    clean_sample = SelectionSample(\n",
    "        subj=clean_subj,\n",
    "        obj=clean_obj,\n",
    "        answer=clean_obj,\n",
    "        obj_idx=clean_obj_idx,\n",
    "        ans_token_id=get_first_token_id(clean_obj, mt.tokenizer, prefix=\" \"),\n",
    "        options=clean_options,\n",
    "        category=clean_lang_category,\n",
    "        metadata=clean_metadata,\n",
    "        prompt_template=clean_task.prompt_templates[clean_prompt_template_idx],\n",
    "        language=clean_lang,\n",
    "        **clean_style_kwargs,\n",
    "    )\n",
    "\n",
    "    # --- optionally keep only pairs the model already answers correctly ---------\n",
    "    if filter_by_lm_prediction:\n",
    "        if retry_count >= retry_limit:\n",
    "            print(\"WARNING: Retry Limit Reached!\")\n",
    "            return None\n",
    "\n",
    "        for sample in (patch_sample, clean_sample):\n",
    "            tokenized = prepare_input(tokenizer=mt, prompts=sample.prompt())\n",
    "\n",
    "            is_correct, predictions, _ = verify_correct_option(\n",
    "                mt=mt, target=sample.obj, options=sample.options, input=tokenized\n",
    "            )\n",
    "\n",
    "            sample.metadata[\"tokenized\"] = tokenized.data\n",
    "\n",
    "            if not is_correct:\n",
    "                # Re-sample with the same categories and settings; recurse into THIS\n",
    "                # function so every argument (including retry_limit) is carried over.\n",
    "                return get_counterfactual_samples_across_task(\n",
    "                    clean_task=clean_task,\n",
    "                    patch_task=patch_task,\n",
    "                    clean_prompt_template_idx=clean_prompt_template_idx,\n",
    "                    patch_prompt_template_idx=patch_prompt_template_idx,\n",
    "                    patch_category=patch_category,\n",
    "                    clean_category=clean_category,\n",
    "                    n_distractors=n_distractors,\n",
    "                    n_options=n_options,\n",
    "                    filter_by_lm_prediction=filter_by_lm_prediction,\n",
    "                    retry_count=retry_count + 1,\n",
    "                    retry_limit=retry_limit,\n",
    "                    clean_lang=clean_lang,\n",
    "                    patch_lang=patch_lang,\n",
    "                    clean_option_style=clean_option_style,\n",
    "                    patch_option_style=patch_option_style,\n",
    "                )\n",
    "\n",
    "            sample.prediction = predictions\n",
    "\n",
    "    return patch_sample, clean_sample\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (retrieval2)",
   "language": "python",
   "name": "retrieval2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
