{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5728a60e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Automatically reload imported modules (e.g. src.*) when their source files change.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6bcda99",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "# Add the parent directory (presumably the repo root) to sys.path so `src.*` imports resolve.\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "# Silence HF tokenizers parallelism warnings.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}\")\n",
    "# torch.cuda.get_device_name() raises when no CUDA device is available,\n",
    "# so only query it when CUDA is actually usable.\n",
    "if torch.cuda.is_available():\n",
    "    logger.info(f\"{torch.cuda.get_device_name()=}\")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da104a94",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Select the model to probe by leaving exactly one model_key uncommented.\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# Optional manual device map (get_device_map is only used here); the loading\n",
    "# cell below currently passes device_map=\"auto\" instead.\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19dcf5e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.environ[\"BNB_CUDA_VERSION\"] = \"124\"\n",
    "# ! echo $BNB_CUDA_VERSION\n",
    "# ! python -m bitsandbytes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7ced8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "# Load the selected model and tokenizer in bfloat16 with device_map=\"auto\".\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    # NOTE(review): eager attention is presumably required so per-head attention\n",
    "    # weights can be read out (fused SDPA/flash kernels do not return them) - confirm.\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76d98b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.functional import free_gpu_cache\n",
    "\n",
    "# # SYNTH_DATASET = \"icosahedron_1\"\n",
    "# SYNTH_DATASET = \"64\"\n",
    "\n",
    "# checkpoint_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"trained_params\",\n",
    "#     f\"{SYNTH_DATASET}\",\n",
    "#     \"_full__clamp=0.001\",\n",
    "#     model_key.split(\"/\")[-1],\n",
    "# )\n",
    "\n",
    "# version = \"epoch_1\"\n",
    "# # version = \"final_model\"\n",
    "\n",
    "# checkpoint_path = os.path.join(env_utils.DEFAULT_RESULTS_DIR, checkpoint_path, version)\n",
    "\n",
    "# print(os.listdir(checkpoint_path))\n",
    "\n",
    "# checkpoint_path = os.path.join(checkpoint_path, \"trainable_params.pt\")\n",
    "\n",
    "# loaded_deltas = torch.load(checkpoint_path, map_location=\"cpu\")\n",
    "# # loaded_deltas\n",
    "\n",
    "# free_gpu_cache()\n",
    "\n",
    "\n",
    "# d = loaded_deltas[\"model<>layers<>5<>mlp<>gate_proj\"]\n",
    "# d.abs().max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ec24d05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from src.utils.training_utils import TrainableLM_delta, TrainableLM_LoRA\n",
    "\n",
    "# #################################################\n",
    "# Trainable_CLS = TrainableLM_delta\n",
    "# # Trainable_CLS = TrainableLM_LoRA\n",
    "# #################################################\n",
    "\n",
    "# Trainable_CLS.fuse_with_model(mt._model, loaded_deltas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d20ced5",
   "metadata": {},
   "outputs": [],
   "source": [
    "###############################################################################\n",
    "# Attention heads to inspect, as (layer_idx, head_idx) pairs.\n",
    "# NOTE: LAYERS is not read by any cell in this notebook, and HEADS is\n",
    "# reassigned to the optimized head set in the mask-plotting cell further down.\n",
    "LAYERS = None\n",
    "HEADS = [\n",
    "    # (33, 45),\n",
    "    # (33, 18),\n",
    "    # (34, 1),\n",
    "    # (34, 6),\n",
    "    # (34, 7),\n",
    "    (35, 19),\n",
    "    # (39, 40),\n",
    "    # (42, 30),\n",
    "    # (47, 18),\n",
    "    # (52, 58),\n",
    "]\n",
    "\n",
    "# LAYERS = [54, 60, 67]\n",
    "# HEADS = [\n",
    "#     (54, 44),\n",
    "#     (54, 63),\n",
    "#     (55, 43),\n",
    "#     (55, 29),\n",
    "#     (60, 9),\n",
    "#     (60, 25),\n",
    "#     (60, 42),\n",
    "#     (67, 51)\n",
    "# ]\n",
    "###############################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11424207",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Alternative path (needs select_task, which is only defined in a later cell):\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_backup_heads\",\n",
    "#     mt.name.split(\"/\")[-1],\n",
    "#     f\"{select_task.task_name}.npz\"\n",
    "# )\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    # f\"{select_task.task_name}\",\n",
    "    \"select_one\",\n",
    "    # \"legacy\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"test_opt_code\",\n",
    "#     model_key.split(\"/\")[-1],\n",
    "#     \"distinct_options\",\n",
    "#     f\"{select_task.task_name}\",\n",
    "#     # \"select_one\",\n",
    "#     \"legacy\",\n",
    "#     \"epoch_10.npz\"\n",
    "# )\n",
    "\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can\n",
    "# execute arbitrary code - only load result files produced by our own runs.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(optimization_results[\"losses\"])\n",
    "ax.set(title=\"Optimization loss\", xlabel=\"logged step\", ylabel=\"loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e59d8648",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heads in layers >= MASK_LAYER_CUTOFF are zeroed out before plotting/selecting.\n",
    "# NOTE(review): the reason for this cutoff is not recorded here - document it.\n",
    "MASK_LAYER_CUTOFF = 52\n",
    "# Mask values strictly above this threshold count as selected heads.\n",
    "HEAD_THRESHOLD = 0.5\n",
    "\n",
    "# optimal_mask is indexed [layer, head] (see the (layer_idx, head_idx) unpacking below).\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[MASK_LAYER_CUTOFF:, :] = 0.0\n",
    "\n",
    "# Transposed so heads run along the y-axis and layers along the x-axis.\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "mask_image = ax.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "ax.set(title=\"Optimized head mask\", xlabel=\"layer\", ylabel=\"head\")\n",
    "fig.colorbar(mask_image, ax=ax)\n",
    "\n",
    "optimized_heads = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(optimal_head_mask > HEAD_THRESHOLD).tolist()\n",
    "]\n",
    "print(len(optimized_heads))\n",
    "\n",
    "# From here on, the optimized heads replace the hand-picked HEADS.\n",
    "HEADS = optimized_heads\n",
    "\n",
    "# HEADS is now the same list as optimized_heads, so one membership check suffices.\n",
    "(35, 19) in optimized_heads\n",
    "# [(29, 3) in HEADS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b9d51b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask\n",
    "\n",
    "# Load the rhymes selection task from DEFAULT_DATA_DIR/selection/rhymes.json.\n",
    "rhymes_task_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"rhymes.json\")\n",
    "select_task = SelectOneTask.load(path=rhymes_task_path)\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75538067",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one task sample (template 3, 5 distractors), keeping only samples the LM predicts.\n",
    "sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    n_distractors=5,\n",
    "    prompt_template_idx=3,\n",
    "    filter_by_lm_prediction=True,\n",
    "    # category=\"politician\",\n",
    ")\n",
    "expected_answer = mt.tokenizer.decode(sample.ans_token_id)\n",
    "print(sample.prompt(), \">>\", expected_answer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b73f4fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import verify_head_patterns\n",
    "# Uncomment to probe the random task sample instead of the hand-written prompts below\n",
    "# (previously this ran unconditionally and was immediately overwritten).\n",
    "# prompt = sample.prompt()\n",
    "\n",
    "# prompt = \"\"\"Which person from the following list has their occupation in common with Samuel L. Jackson? \n",
    "# Options: Norah O'Donnell, Julia Roberts, Rory McIlroy, Petra Kvitová, Christian Louboutin, Wes Anderson. \n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"options: \n",
    "# 1. The Space Needle \n",
    "# 2. Louvre Museum \n",
    "# 3. Colosseum \n",
    "# 4. Christ the Redeemer \n",
    "# 5. Statue of Liberty \n",
    "# 6. Big Ben \n",
    "# Which of these landmarks is located in Brazil?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Which of these numbers is a multiple of 3? \n",
    "# options: 43, 57, 55, 62, 39 \n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"John is taller than Bob, but shorter than Steve. Who is the shortest? \n",
    "# Ans:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Brad Pitt, Tom Hanks, Leonardo DiCaprio, Scarlett Johansson, Tom Brady, Hugh Jackman\n",
    "# Which person from this list is different from the others?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Which of these items is not a fruit?\n",
    "# options: Apple, Potato, Banana, Grape, Orange\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Which of these words rhymes with the word \"look\"? \n",
    "# Items: orange, mat, book, rabbit, bowl, watch, mirror \n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Items: Apple, Banana, Panda, Grape, Orange.\n",
    "# How many fruits are listed above?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Items: cat, door, table, window, pen, keyboard.\n",
    "# Which word has exactly 4 letters?\n",
    "# Answer:\"\"\"\n",
    "# prompt = \"\"\"Which word has exactly 4 letters?\n",
    "# Items: cat, door, table, window, pen, keyboard.\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"Items: zebra, monkey, elephant, tiger, lion.\n",
    "# Which word comes first alphabetically?\n",
    "# Answer:\"\"\"\n",
    "prompt = \"\"\"Which word comes first alphabetically?\n",
    "Items: zebra, monkey, elephant, tiger, lion.\n",
    "Answer:\"\"\"\n",
    "\n",
    "# Inspect the attention patterns of the currently selected HEADS on this prompt.\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    # options=sample.options,\n",
    "    # pivot=sample.subj,\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    "    # heads = optimized_heads,\n",
    "    value_weighted=False,\n",
    "    # generate_full_answer=True,\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ebe22ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample.prompt is a method on SelectOneTask samples (called as sample.prompt() above),\n",
    "# so call it: passing the bound method would hand verify_head_patterns a non-string prompt.\n",
    "# NOTE(review): an earlier draft used pivot=sample.subj; confirm match_with exists on this sample type.\n",
    "verify = verify_head_patterns(\n",
    "    prompt=sample.prompt(),\n",
    "    options=sample.options,\n",
    "    pivot=sample.match_with,\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ad8689e",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"\"\"{}\n",
    "Who among these people mentioned above is by profession a {}?\n",
    "Answer:\"\"\"\n",
    "\n",
    "# Hand-picked options, kept in their own variable instead of overwriting\n",
    "# sample.options, so the sample drawn earlier stays intact across re-runs.\n",
    "people_options = [\"Norah O'Donnell\", \"Julia Roberts\", \"Rory McIlroy\", \"Petra Kvitová\", \"Christian Louboutin\", \"Wes Anderson\"]\n",
    "# prompt = prompt_template.format(\n",
    "#     \", \".join(people_options),\n",
    "#     sample.metadata[\"attribute\"]\n",
    "# )\n",
    "\n",
    "# Rewrite each name as \"Last, First\"; single-word names are kept unchanged.\n",
    "last_name_first = []\n",
    "for name in people_options:\n",
    "    first_name, _, last_name = name.partition(\" \")\n",
    "    last_name_first.append(f\"{last_name}, {first_name}\" if last_name else name)\n",
    "\n",
    "prompt = prompt_template.format(\n",
    "    \"\\n\".join(f\"{idx+1}. {name}\" for idx, name in enumerate(last_name_first)),\n",
    "    sample.metadata[\"attribute\"]\n",
    ")\n",
    "\n",
    "# prompt = sample.prompt\n",
    "\n",
    "print(prompt, \" >> \")\n",
    "\n",
    "# The prompt lists names as \"Last, First\", so pass those exact strings as options;\n",
    "# the \"First Last\" spellings never appear in the prompt text.\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=last_name_first,\n",
    "    pivot=sample.match_with,\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda47f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same prompt and \"Last, First\" options, rerun with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=last_name_first,\n",
    "    pivot=sample.match_with,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a6519ed6",
   "metadata": {},
   "source": [
    "### Different Attribute"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7863a3af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import load_people_by_category\n",
    "\n",
    "nationality_path = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection_real/nationality.json\")\n",
    "people_by_nationality = load_people_by_category(tokenizer=mt.tokenizer, path=nationality_path)\n",
    "\n",
    "# Top-level keys of the loaded mapping.\n",
    "list(people_by_nationality.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6de76cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import get_random_sample\n",
    "\n",
    "# Draw a nationality sample for attribute \"Australia\" with 5 distractor options.\n",
    "sample = get_random_sample(\n",
    "    mt=mt,\n",
    "    people_by_category=people_by_nationality,\n",
    "    category=\"nationality\",\n",
    "    attribute=\"Australia\",\n",
    "    n_distractors=5,\n",
    ")\n",
    "print(sample.prompt, \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8bae6d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch, predict_next_token\n",
    "\n",
    "# prompt_template = \"\"\"Which of these people is by profession a {}?\n",
    "# options: {}\n",
    "# Answer:\"\"\"\n",
    "# prompt = prompt_template.format(\n",
    "#     sample.metadata[\"attribute\"],\n",
    "#     \", \".join(sample.options)\n",
    "# )\n",
    "\n",
    "# Options first, then the nationality question.\n",
    "prompt_template = \"\"\"{}\n",
    "Who among these people mentioned above is from {}?\n",
    "Answer:\"\"\"\n",
    "options_text = \", \".join(sample.options)\n",
    "prompt = prompt_template.format(options_text, sample.metadata[\"attribute\"])\n",
    "\n",
    "print(prompt, \" >> \", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fc6132a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run with generate_full_answer=True on the nationality prompt.\n",
    "verify = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=sample.options,\n",
    "    pivot=sample.match_with,\n",
    "    generate_full_answer=True,\n",
    "    value_weighted=False,\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc7375bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same nationality prompt, rerun with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=sample.options,\n",
    "    pivot=sample.match_with,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea4eec49",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# options = [\"Table\", \"Orange\", \"Transistor\", \"Spinach\", \"Piano\", \"Coffee\"]\n",
    "# prompt_template = \"\"\"Which of these objects is by type a {}?\n",
    "# options: {}\n",
    "# Answer:\"\"\"\n",
    "# prompt = prompt_template.format(\n",
    "#     \"musical instrument\",\n",
    "#     \", \".join(options)\n",
    "# )\n",
    "# print(prompt)\n",
    "\n",
    "# options = [\"anaconda\", \"python\", \"cobra\", \"viper\", \"mamba\", \"rattlesnake\"]\n",
    "# prompt = f\"\"\"Which of these snake names is also a programming language?\n",
    "# options: {\", \".join(options)}\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# Landmark options; \"Statue of Liberty\" was previously misspelled \"State of Liberty\",\n",
    "# which put a non-existent landmark name in front of the model.\n",
    "options = [\n",
    "    \"The Space Needle\",\n",
    "    \"Louvre Museum\",\n",
    "    \"Colosseum\",\n",
    "    \"Christ the Redeemer\",\n",
    "    \"Statue of Liberty\",\n",
    "    \"Big Ben\",\n",
    "]\n",
    "options += [\"Eiffel Tower\"]\n",
    "# random.shuffle(options)\n",
    "country = \"England\"\n",
    "prompt_template = \"\"\"{}\n",
    "Which of these landmarks is located in {}?\n",
    "Answer:\"\"\"\n",
    "prompt = prompt_template.format(\n",
    "    \"\\n\".join([f\"{idx+1}. {opt}\" for idx, opt in enumerate(options)]), country\n",
    ")\n",
    "\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=country,\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    "    ablate_possible_ans_info_from_options=False,\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d0db58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same landmark prompt, now with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    pivot=country,\n",
    "    options=options,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6b285c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Isaac Newton\" was previously misspelled \"Issac Newton\" in the model-facing options.\n",
    "options = [\"Marie Curie\", \"Albert Einstein\", \"Grover Cleveland\", \"Charles Darwin\", \"Nikola Tesla\", \"Isaac Newton\"]\n",
    "# prompt_template = \"\"\"Which of these people is not a scientist?\n",
    "# options: {}\n",
    "# Answer:\"\"\"\n",
    "# prompt = prompt_template.format(\n",
    "#     \", \".join(options)\n",
    "# )\n",
    "# print(prompt)\n",
    "\n",
    "prompt_template = \"\"\"options: {}\n",
    "Which of these people is not a scientist?\n",
    "Answer:\"\"\"\n",
    "prompt = prompt_template.format(\n",
    "    \", \".join(options)\n",
    ")\n",
    "print(prompt)\n",
    "\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=\"scientist\",\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d6764ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same \"not a scientist\" prompt with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    pivot=\"scientist\",\n",
    "    options=options,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "027d93d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "options = [\"Michael Jordan\", \"Serena Williams\", \"Nikki Haley\", \"Mike Tyson\", \"Carl Sagan\", \"Tom Cruise\"]\n",
    "prompt_template = \"\"\"Which of these people is a politician or a scientist?\n",
    "options: {}\n",
    "Answer:\"\"\"\n",
    "\n",
    "prompt = prompt_template.format(\n",
    "    \", \".join(options)\n",
    ")\n",
    "print(prompt)\n",
    "\n",
    "# NOTE(review): the prompt says \"a politician or a scientist\" but the pivot is\n",
    "# \"politician or scientist\"; confirm verify_head_patterns does not require the\n",
    "# pivot to appear verbatim in the prompt.\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=\"politician or scientist\",\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    ")\n",
    "# Display the predictions, as every other verify cell does.\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "414081ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same politician/scientist prompt with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    pivot=\"politician or scientist\",\n",
    "    options=options,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "60e64440",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): 57 (= 3 * 19) and 39 (= 3 * 13) are both multiples of 3, so this\n",
    "# question has two correct answers - confirm that is intended.\n",
    "options = [\"46\", \"57\", \"55\", \"62\", \"39\"]\n",
    "# prompt_template = \"\"\"Which of these numbers is a multiple of 3?\n",
    "# options: {}\n",
    "# Answer:\"\"\"\n",
    "\n",
    "prompt_template = \"\"\"{}\n",
    "Which of these numbers is a multiple of 3?\n",
    "Ans:\"\"\"\n",
    "\n",
    "options_text = \", \".join(options)\n",
    "prompt = prompt_template.format(options_text)\n",
    "print(prompt)\n",
    "\n",
    "verify = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=\"multiple of 3\",\n",
    "    generate_full_answer=True,\n",
    "    value_weighted=False,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdc6ef8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pivot fixed to match this cell's numbers task (it was copy-pasted as \"politician or scientist\").\n",
    "# verify_ind = verify_head_patterns(\n",
    "#     prompt=prompt,\n",
    "#     options=options,\n",
    "#     pivot=\"multiple of 3\",\n",
    "#     mt=mt,\n",
    "#     value_weighted=False,\n",
    "#     generate_full_answer=False,\n",
    "#     ablate_possible_ans_info_from_options=True,\n",
    "# )\n",
    "# verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62f912ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOVELTY: Find the odd one out\n",
    "\n",
    "# options = [\"Isaac Newton\", \"Albert Einstein\", \"Marie Curie\", \"Hugh Jackman\", \"Charles Darwin\"]\n",
    "options = [\"Brad Pitt\", \"Tom Hanks\", \"Leonardo DiCaprio\", \"Tom Brady\", \"Daniel Radcliffe\", \"Hugh Jackman\"]\n",
    "\n",
    "# prompt_template = \"\"\"Which person from the following list is different from the others?\n",
    "# {}.\n",
    "# Ans:\"\"\"\n",
    "prompt_template = \"\"\"{}\n",
    "Which among these people are different from others?\n",
    "Answer:\"\"\"\n",
    "# prompt = prompt_template.format(\", \".join(options))\n",
    "numbered_options = \"\\n\".join(f\"{idx+1}. {name}\" for idx, name in enumerate(options))\n",
    "prompt = prompt_template.format(numbered_options)\n",
    "print(prompt)\n",
    "\n",
    "verify = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=\"different from others\",\n",
    "    generate_full_answer=True,\n",
    "    value_weighted=False,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c20b519d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same odd-one-out prompt with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    pivot=\"different from others\",\n",
    "    options=options,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34473276",
   "metadata": {},
   "outputs": [],
   "source": [
    "# rhymes with\n",
    "\n",
    "options = [\"orange\", \"mat\", \"book\", \"rabbit\", \"bowl\", \"watch\", \"mirror\"]\n",
    "word = \"rat\"\n",
    "# word = \"look\"\n",
    "# word = \"fetch\"\n",
    "# word = \"growl\"\n",
    "\n",
    "# prompt_template = \"\"\"{}\n",
    "# Which of these words rhymes with the word \"{}\"?\n",
    "# Answer:\"\"\"\n",
    "# prompt = prompt_template.format(\n",
    "#     \", \".join(options),\n",
    "#     word\n",
    "# )\n",
    "# prompt = prompt_template.format(\"\\n\".join([f\"{idx+1}. {name}\" for idx, name in enumerate(options)]), word)\n",
    "\n",
    "\n",
    "# The question now ends with \"?\" like the other question prompts in this notebook.\n",
    "prompt_template = \"\"\"Which of the following words rhymes with {}?\n",
    "options: {}\n",
    "Answer:\"\"\"\n",
    "prompt = prompt_template.format(\n",
    "    word,\n",
    "    \", \".join(options),\n",
    ")\n",
    "# prompt = prompt_template.format(word, \"\\n\".join([f\"{idx+1}. {name}\" for idx, name in enumerate(options)]))\n",
    "\n",
    "\n",
    "print(prompt)\n",
    "\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=word,\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "050e2d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same rhymes prompt with ablate_possible_ans_info_from_options=True.\n",
    "verify_ind = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    pivot=word,\n",
    "    # options=[f\"{op}\\n\" for op in options],\n",
    "    options=options,\n",
    "    generate_full_answer=False,\n",
    "    value_weighted=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf4d279e",
   "metadata": {},
   "outputs": [],
   "source": [
    "#! Logical deduction\n",
    "\n",
    "# prompt = \"\"\"John is taller than Mary. Mary is taller than Steve. Who is the shortest?\n",
    "# Ans:\"\"\"\n",
    "prompt = \"\"\"John is taller than Bob, but shorter than Steve. Who is the shortest?\n",
    "Ans:\"\"\"\n",
    "# prompt = \"\"\"All the cookies are either chocolate or vanilla. None of the chocolate cookies have nuts.\n",
    "# This cookie has nuts. What flavor is it?\n",
    "# Ans:\"\"\"\n",
    "# prompt = \"\"\"The red box is to the left of the blue box, and to the right of the green box.\n",
    "# Which is the rightmost box?\n",
    "# Ans:\"\"\"\n",
    "\n",
    "# Candidate answers are the three names in the active prompt (expected: Bob).\n",
    "deduction_options = [\"John\", \"Bob\", \"Steve\"]\n",
    "verify = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=deduction_options,\n",
    "    pivot=\"shortest\",\n",
    "    generate_full_answer=True,\n",
    "    value_weighted=False,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90edc47",
   "metadata": {},
   "outputs": [],
   "source": [
    "verify_ind = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=[\"John\", \"Bob\", \"Steve\"],\n",
    "    pivot=\"shortest\",\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=False,\n",
    "    # Pass the ablation flag explicitly, as every other verify_ind cell does,\n",
    "    # so this run is comparable with them.\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4fa521c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prompt = \"\"\"x = True\n",
    "# y = False\n",
    "# z = True\n",
    "\n",
    "# a = x and y\n",
    "# b = y or z\n",
    "# c = z xor x\n",
    "\n",
    "# Which one among a, b, c is True?\n",
    "# Ans:\"\"\"\n",
    "\n",
    "# With x=True, y=False, z=False: a = False, b = False, c = (False xor True) = True,\n",
    "# so the expected answer is c. (\"xor\" is pseudo-code in the prompt, not Python.)\n",
    "prompt = \"\"\"a, b, c is calculated as below. Find which one among a, b, c is True\n",
    "\n",
    "x = True\n",
    "y = False\n",
    "z = False\n",
    "\n",
    "a = x and y\n",
    "b = y or z\n",
    "c = z xor x\n",
    "\n",
    "Ans:\"\"\"\n",
    "\n",
    "logic_options = [\"a\", \"b\", \"c\"]\n",
    "verify = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=prompt,\n",
    "    options=logic_options,\n",
    "    pivot=\"True\",\n",
    "    generate_full_answer=True,\n",
    "    value_weighted=False,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be3e600d",
   "metadata": {},
   "outputs": [],
   "source": [
    "options = [\"grape\", \"marker\", \"truck\", \"watch\", \"car\", \"banana\", \"mouse\", \"mango\", \"bus\"]\n",
    "# prompt_template = \"\"\"Count the number of fruits from the following list:\n",
    "# {}\n",
    "# Answer:\"\"\"\n",
    "# NOTE: the trailing space after \"Answer:\" is part of the prompt; presumably it lets\n",
    "# the count be predicted as a bare digit token - confirm before removing it.\n",
    "prompt_template = \"\"\"{}\n",
    "Count the number of vehicles from the following list:\n",
    "Answer: \"\"\"\n",
    "\n",
    "# prompt = prompt_template.format(\n",
    "#     \", \".join(options)\n",
    "# )\n",
    "\n",
    "numbered_options = \"\\n\".join(f\"{idx+1}. {name}\" for idx, name in enumerate(options))\n",
    "prompt = prompt_template.format(numbered_options)\n",
    "\n",
    "print(prompt)\n",
    "\n",
    "predict_next_token(mt=mt, inputs=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79b596f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The active prompt counts vehicles; \"fruit\" (left over from the commented fruits\n",
    "# template) does not occur in it, so the pivot must be \"vehicle\".\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=options,\n",
    "    pivot=\"vehicle\",\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    ")\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c37fb680",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pivot matches the active \"Count the number of vehicles\" prompt (was \"fruit\").\n",
    "verify_ind = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    # options=options,\n",
    "    # Each numbered option is followed by a newline in the prompt.\n",
    "    options=[f\"{op}\\n\" for op in options],\n",
    "    pivot=\"vehicle\",\n",
    "    mt=mt,\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=False,\n",
    "    ablate_possible_ans_info_from_options=True,\n",
    ")\n",
    "verify_ind[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cb0cdf0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
