{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ec34ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfae7235",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc7720ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "683855df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a7f22eb",
   "metadata": {},
   "source": [
    "## Selection Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a62d97cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask\n",
    "\n",
    "select_prof = SelectOneTask.load(\n",
    "    path=\"data_save/selection/profession.json\"\n",
    ")\n",
    "\n",
    "print(select_prof)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2aa1529",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "output_formatting: Literal[\n",
    "    \"zero_shot\", # no formatting, model preference\n",
    "    \"object\", # Bill Gates\n",
    "    \"lettered\", # a. Bill Gates\n",
    "] = \"zero_shot\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45087a7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template_idx = 1\n",
    "option_style = \"numbered\"  # \"numbered\", \"lettered\", \"ordinal\"\n",
    "\n",
    "# one_shot = select_prof.get_random_sample(\n",
    "#     mt = mt,\n",
    "#     prompt_template_idx=prompt_template_idx,\n",
    "#     option_style=\"numbered\",\n",
    "#     category=\"actor\",\n",
    "#     filter_by_lm_prediction = False,\n",
    "# )\n",
    "\n",
    "# print(one_shot)\n",
    "\n",
    "sample = select_prof.get_random_sample(\n",
    "    mt = mt,\n",
    "    obj_idx=3,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=option_style,\n",
    "    category=\"actor\",\n",
    "    filter_by_lm_prediction = True,\n",
    "    # output_formatting=output_formatting,\n",
    ")\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66399b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample.prompt_template = select_prof.prompt_templates[1]\n",
    "\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f57d67c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(sample.prompt(option_style=\"single_line\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0499b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53444ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import prepare_input, get_hs, interpret_logits\n",
    "from src.selection.utils import get_first_token_id, verify_correct_option\n",
    "\n",
    "# inputs = prepare_input(prompts=sample.prompt(), tokenizer=mt)\n",
    "# logit_module = (mt.lm_head_name, -1)\n",
    "# logits = get_hs(\n",
    "#     mt=mt,\n",
    "#     input=inputs,\n",
    "#     locations=[logit_module],\n",
    "#     return_dict=False,\n",
    "# ).squeeze()\n",
    "\n",
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    # logits=logits,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a9a2a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.experiments.utils import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt = sample.prompt(),\n",
    "    options = sample.options,\n",
    "    pivot = sample.subj,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56273a20",
   "metadata": {},
   "source": [
    "## Odd one out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf391d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOddOneOutTask\n",
    "\n",
    "odd_one_out = SelectOddOneOutTask.load(\n",
    "    path=os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection/profession.json\")\n",
    ")\n",
    "\n",
    "print(odd_one_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03895c66",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = odd_one_out.get_random_sample(\n",
    "    mt = mt,\n",
    "    obj_idx=3,\n",
    "    prompt_template_idx=3,\n",
    "    option_style=option_style,\n",
    "    # category=\"actor\",\n",
    "    filter_by_lm_prediction = False,\n",
    "    # output_formatting=output_formatting,\n",
    ")\n",
    "print(sample)\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55d00abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.obj)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45c04622",
   "metadata": {},
   "outputs": [],
   "source": [
    "verify_correct_option(\n",
    "    mt=mt,\n",
    "    # logits=logits,\n",
    "    target=sample.obj,\n",
    "    options=sample.options,\n",
    "    input=sample.prompt()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae075932",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.experiments.utils import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt = sample.prompt(),\n",
    "    options = sample.options,\n",
    "    pivot = sample.subj,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a1d928a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71cbe3ce",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71656308",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "5256345c",
   "metadata": {},
   "source": [
    "## Counting Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae6f1dce",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import CountingTask\n",
    "\n",
    "counting_fruits = CountingTask.load(\n",
    "    path=\"../data_save/counting/fruits.json\"\n",
    ")\n",
    "\n",
    "print(counting_fruits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "270977e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = counting_fruits.get_random_sample(\n",
    "    mt = mt,\n",
    "    prompt_template_idx=0,\n",
    "    option_style=\"single_line\",\n",
    "    category=\"fruits\",\n",
    "    filter_by_lm_prediction=True,\n",
    "    n_count=2,\n",
    "    n_distractors=3\n",
    ")\n",
    "\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97f061f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample.prompt_template = counting_fruits.prompt_templates[1]\n",
    "\n",
    "print(f'\"{sample.prompt()}\"', \">>\", sample.count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfe3c73",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "gen = generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")[0]\n",
    "print(f'\"{gen}\"', \">>\", sample.count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b5be303",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.experiments.utils import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt = sample.prompt(),\n",
    "    options = sample.options,\n",
    "    pivot = sample.category,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7f8199f",
   "metadata": {},
   "source": [
    "## Deduction Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8354926b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import DeductionTask\n",
    "\n",
    "deduction_task = DeductionTask.load(\n",
    "    dir_path=\"../data_save/deduction\"\n",
    ")\n",
    "\n",
    "print(deduction_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11842fda",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = deduction_task.get_random_sample(\n",
    "    mt = mt,\n",
    "    topic_name = \"height\",\n",
    "    depth = 5,\n",
    ")\n",
    "\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d11c679",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt,\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b82159",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.experiments.utils import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt = sample.prompt,\n",
    "    options = [\"Alice\", \"Bob\", \"Cam\", \"Dave\", \"Eli\"],\n",
    "    pivot = sample.answer,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa64ccfa",
   "metadata": {},
   "source": [
    "## All of the Above Selection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "615569a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectAllTask\n",
    "\n",
    "select_all_prof = SelectAllTask.load(\n",
    "    path=\"../data_save/selection/profession.json\"\n",
    ")\n",
    "\n",
    "print(select_all_prof)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88daf1df",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = select_all_prof.get_random_sample(\n",
    "    mt=mt\n",
    ")\n",
    "\n",
    "print(sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10763c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import generate_with_patch\n",
    "\n",
    "generate_with_patch(\n",
    "    mt = mt,\n",
    "    inputs = sample.prompt(),\n",
    "    n_gen_per_prompt=1,\n",
    "    max_new_tokens=20,\n",
    "    do_sample=False,\n",
    "    remove_prefix=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e212453",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.experiments.utils import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "HEADS = [\n",
    "    (33, 45),\n",
    "    (33, 18),\n",
    "    (34, 1),\n",
    "    (34, 6),\n",
    "    (34, 7),\n",
    "    (35, 19),\n",
    "    (39, 40),\n",
    "    (42, 30),\n",
    "    (47, 18),\n",
    "    (52, 58),\n",
    "]\n",
    "\n",
    "attn_pattern = verify_head_patterns(\n",
    "    prompt = sample.prompt(),\n",
    "    options = sample.options,\n",
    "    pivot = sample.category,\n",
    "    mt = mt,\n",
    "    heads = HEADS\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "990dbdb9",
   "metadata": {},
   "source": [
    "## MISC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6f686e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "import random\n",
    "\n",
    "all_heads = list(product(range(mt.n_layer), range(mt.config.num_attention_heads)))\n",
    "\n",
    "random_heads = random.sample(\n",
    "    all_heads,\n",
    "    k = 50\n",
    ")\n",
    "random_heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1951dc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# instruction = \"\"\"Instructions: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any belief about the container or its content which they cannot observe directly. 4. To answer the question, predict only the final state of the queried container in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict the entire sentence with character or container as the final output.\\n\\n\"\"\"\n",
    "\n",
    "# prompt = instruction + \"\"\"Story: Karen and Max are working in a busy restaurant. To complete an order, Karen grabs an opaque flute and fills it with soda. Then Max grabs another opaque jar and fills it with coffee. Max cannot observe Karen's actions. Karen cannot observe Max's actions.\n",
    "# Question: What does Karen believe the jar contains?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"1. Bus\n",
    "# 2. Book\n",
    "# 3. Cup\n",
    "# 4. Plate\n",
    "# 5. Glass\n",
    "# 6. None of the above\n",
    "# Which one of the objects mentioned above is a fruit?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "# prompt = \"\"\"1. Peach\n",
    "# 2. Apple\n",
    "# 3. Banana\n",
    "# 4. Orange\n",
    "# 5. Grapes\n",
    "# 6. All of the above\n",
    "# Which one of the objects mentioned above is a fruit?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "prompt = \"\"\"Items: Tea, Mango, Coffee, Orange, Transistor, Water, Kiwi, Cup\n",
    "Find the sixth item in the list?\n",
    "Answer:\"\"\"\n",
    "\n",
    "verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    options=[\n",
    "        \"Tea\",\n",
    "        \"Mango\",\n",
    "        \"Coffee\",\n",
    "        \"Orange\",\n",
    "        \"Juice\",\n",
    "        \"Water\",\n",
    "        \"Kiwi\",\n",
    "        \"Cup\"\n",
    "    ],\n",
    "    # ablate_possible_ans_info_from_options=True,\n",
    "    pivot=\"cup\",\n",
    "    mt=mt,\n",
    "    heads=HEADS,\n",
    "    # heads = [(33, 18), (35, 19), (47, 18)],\n",
    "    # visualize_individual_heads=True,\n",
    "    # heads= random_heads,\n",
    "    # generate_full_answer=True,\n",
    "    bare_prompt_template=\" {}\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
