{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b92923",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11789602",
   "metadata": {},
   "source": [
    "## Possible Applications of the probing stuff\n",
    "\n",
    "* Apply the probes across some free-form text. See if they respond to the semantic information of something like \"Fruit\". The probe direction would be $q_{\\text{fruit}} @ \\text{W}_k^T$\n",
    "    * What could be a setup where the ground truth isn't obvious, or is hard to get? Check out Martin's setup for modeling users?\n",
    "\n",
    "* If we cluster the query states based on cosine similarity, do we see any patterns? Maybe some of the heads are doing more semantic work, while others are handling something related to the task / formatting?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "983a4427",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03464cab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-70B-Instruct\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a88f8c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# from transformers import BitsAndBytesConfig\n",
    "\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # device_map=device_map,\n",
    "    device_map=\"auto\",\n",
    "    # quantization_config = BitsAndBytesConfig(\n",
    "    #     # load_in_4bit=True\n",
    "    #     load_in_8bit=True\n",
    "    # )\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45b1c057",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# TASK_CLS = SelectOrderTask\n",
    "# prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 2\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(\n",
    "        env_utils.DEFAULT_DATA_DIR, \n",
    "        \"selection\", \n",
    "        # \"profession.json\"\n",
    "        # \"nationality.json\"\n",
    "        \"objects.json\"\n",
    "    )\n",
    ")\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1711aade",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_heads\",\n",
    "#     mt.name.split(\"/\")[-1],\n",
    "#     f\"{select_task.task_name}.npz\"\n",
    "# )\n",
    "\n",
    "# optimized_path = os.path.join(\n",
    "#     env_utils.DEFAULT_RESULTS_DIR,\n",
    "#     \"selection/optimized_heads\",\n",
    "#     model_key.split(\"/\")[-1],\n",
    "#     # \"distinct_options\",\n",
    "#     f\"{select_task.task_name}\",\n",
    "#     \"epoch_10.npz\"\n",
    "# )\n",
    "\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    # f\"{select_task.task_name}\",\n",
    "    \"select_one\",\n",
    "    # \"legacy\",\n",
    "    \"epoch_10.npz\"\n",
    ")\n",
    "\n",
    "\n",
    "\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f5b0ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the optimized head mask as float32 and zero out every row from index 52 onward\n",
    "# (rows presumably index layers, given the (layer_idx, head_idx) unpacking below - TODO confirm why 52).\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[52:, :] = 0.0\n",
    "\n",
    "# Heatmap of the transposed mask.\n",
    "plt.figure(figsize=(20, 10))\n",
    "plt.imshow(optimal_head_mask.T.numpy(), cmap=\"Blues\", aspect=\"auto\", vmin=0, vmax=1)\n",
    "\n",
    "# Entries above 0.5 count as selected heads, stored as (layer_idx, head_idx) tuples.\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "]\n",
    "print(len(heads_selected))\n",
    "\n",
    "# HEADS = heads_selected\n",
    "\n",
    "# (35, 19) in HEADS, (35, 19) in heads_selected"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b084ae8",
   "metadata": {},
   "source": [
    "## Probe without any training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18555afc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "from src.selection.functional import cache_q_projections\n",
    "from src.tokens import prepare_input, find_token_range\n",
    "from src.utils.typing import TokenizerOutput\n",
    "\n",
    "###########################################\n",
    "limit = 10\n",
    "prompt_template_idx=3\n",
    "option_style=\"single_line\"\n",
    "###########################################\n",
    "\n",
    "def get_probe_directions():\n",
    "    \"\"\"Build one mean query-projection vector per category and cache key.\n",
    "\n",
    "    Reads the notebook globals ``select_task``, ``mt``, ``limit``,\n",
    "    ``prompt_template_idx``, ``option_style``, ``N_DISTRACTORS`` and\n",
    "    ``heads_selected``.\n",
    "\n",
    "    For every category, ``limit`` random samples are drawn, their ``.q_proj``\n",
    "    states are cached at token indices -3, -2 and -1 for ``heads_selected``,\n",
    "    and the states are averaged over the samples.\n",
    "\n",
    "    Returns:\n",
    "        ``{category: {key: mean_state}}`` where ``key`` ranges over the keys of\n",
    "        the per-prompt dicts returned by ``cache_q_projections`` (indexed\n",
    "        elsewhere in this notebook as ``(layer_idx, head_idx, token_idx)``).\n",
    "    \"\"\"\n",
    "    # Phase 1: draw the prompts for every category.\n",
    "    logger.info(\"Getting the samples for caching the predicate\")\n",
    "    samples_by_category = {}\n",
    "    for cat in tqdm(select_task.categories):\n",
    "        samples_by_category[cat] = [\n",
    "            select_task.get_random_sample(\n",
    "                mt=mt,\n",
    "                category=cat,\n",
    "                prompt_template_idx=prompt_template_idx,\n",
    "                option_style=option_style,\n",
    "                n_distractors=N_DISTRACTORS,\n",
    "                filter_by_lm_prediction=True,\n",
    "            )\n",
    "            for _ in range(limit)\n",
    "        ]\n",
    "\n",
    "    # Phase 2: one batched caching call per category.\n",
    "    logger.info(\"Caching the predicate directions\")\n",
    "    q_states_by_category = dict.fromkeys(select_task.categories)\n",
    "    for cat in select_task.categories:\n",
    "        prompt_batch = [sample.prompt() for sample in samples_by_category[cat]]\n",
    "        q_states_by_category[cat] = cache_q_projections(\n",
    "            mt=mt,\n",
    "            input=prepare_input(prompts=prompt_batch, tokenizer=mt.tokenizer),\n",
    "            heads=heads_selected,\n",
    "            token_indices=[[-3, -2, -1] for _ in prompt_batch],\n",
    "            projection_signature=\".q_proj\",\n",
    "        )\n",
    "\n",
    "    # Phase 3: average over the prompts of each category, key by key.\n",
    "    logger.info(\"Computing the mean\")\n",
    "    mean_q_by_category = {}\n",
    "    for cat in select_task.categories:\n",
    "        per_prompt_states = q_states_by_category[cat]\n",
    "        mean_q_by_category[cat] = {\n",
    "            key: torch.stack([states[key] for states in per_prompt_states]).mean(dim=0)\n",
    "            for key in list(per_prompt_states[0].keys())\n",
    "        }\n",
    "\n",
    "    return mean_q_by_category"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb768320",
   "metadata": {},
   "outputs": [],
   "source": [
    "category_wise_q_mean = get_probe_directions()\n",
    "category_wise_q_mean[\"fruit\"][(35, 19, -1)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0480208",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import PatchSpec\n",
    "# One PatchSpec per cached (layer_idx, head_idx, query_idx) mean query state, grouped by category.\n",
    "patches_per_category = {\n",
    "    category_name: [\n",
    "        PatchSpec(\n",
    "            location=(\n",
    "                mt.attn_module_name_format.format(layer) + \".q_proj\",\n",
    "                head,\n",
    "                query_pos,\n",
    "            ),\n",
    "            patch=mean_q,\n",
    "        )\n",
    "        for (layer, head, query_pos), mean_q in category_wise_q_mean[category_name].items()\n",
    "    ]\n",
    "    for category_name in select_task.categories\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fe970cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from src.selection.functional import verify_head_patterns\n",
    "\n",
    "category=\"fruit\"\n",
    "test_sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    category=random.choice(list(set(select_task.categories) - {category})),\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    option_style=option_style,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    "    filter_by_lm_prediction=True,\n",
    "    obj_idx=4,\n",
    "    insert_distractor=[(\"Apple\", 2)]\n",
    ")\n",
    "print(test_sample.prompt())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "440f5799",
   "metadata": {},
   "outputs": [],
   "source": [
    "clean_run = verify_head_patterns(\n",
    "    prompt=test_sample.prompt(),\n",
    "    mt=mt,\n",
    "    # heads=[(35, 19)],\n",
    "    heads=heads_selected,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c383542",
   "metadata": {},
   "outputs": [],
   "source": [
    "patched_run = verify_head_patterns(\n",
    "    prompt=test_sample.prompt(),\n",
    "    mt=mt,\n",
    "    # heads=[(35, 19)],\n",
    "    heads=heads_selected,\n",
    "    query_patches=patches_per_category[category]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaf01f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "def apply_probe_in_place(\n",
    "    obj: str,\n",
    "    mt: ModelandTokenizer,\n",
    "    heads: list[tuple[int, int]],\n",
    "    category_wise_probes: dict[str, dict[tuple[int, int, int], torch.Tensor]],\n",
    "    prompt_template: str = \"Option: {}\",\n",
    "    metric: Literal[\"cosine\", \"dot\"] = \"dot\",\n",
    "):\n",
    "    \"\"\"Rank categories by how well their probes match the key state of ``obj``.\n",
    "\n",
    "    ``obj`` is substituted into ``prompt_template``, the ``.k_proj`` states at\n",
    "    token index -1 are requested from ``cache_q_projections`` for every head in\n",
    "    ``heads``, and each state is scored against the matching probe of every\n",
    "    category. Scores are averaged over heads.\n",
    "\n",
    "    Args:\n",
    "        obj: Text substituted into ``prompt_template``.\n",
    "        mt: Model/tokenizer wrapper forwarded to ``cache_q_projections``.\n",
    "        heads: ``(layer_idx, head_idx)`` pairs to score with.\n",
    "        category_wise_probes: ``{category: {(layer_idx, head_idx, -1): probe}}``.\n",
    "        prompt_template: Format string with a single ``{}`` slot.\n",
    "        metric: ``\"dot\"`` or ``\"cosine\"``.\n",
    "\n",
    "    Returns:\n",
    "        ``(best_category, ranking)`` where ``ranking`` is a list of\n",
    "        ``(category, mean_score)`` pairs sorted from highest to lowest score.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``heads`` or ``category_wise_probes`` is empty, or if\n",
    "            ``metric`` is not supported.\n",
    "    \"\"\"\n",
    "    # Fail before the (expensive) forward pass instead of with a ZeroDivisionError /\n",
    "    # IndexError afterwards.\n",
    "    if not heads:\n",
    "        raise ValueError(\"`heads` must contain at least one (layer_idx, head_idx) pair\")\n",
    "    if not category_wise_probes:\n",
    "        raise ValueError(\"`category_wise_probes` must contain at least one category\")\n",
    "    if metric not in (\"dot\", \"cosine\"):\n",
    "        raise ValueError(f\"Unsupported metric: {metric}\")\n",
    "\n",
    "    probe_prompt = prompt_template.format(obj)\n",
    "    probe_tokenized = prepare_input(\n",
    "        prompts=[probe_prompt],\n",
    "        tokenizer=mt.tokenizer,\n",
    "    )\n",
    "    key_states = cache_q_projections(\n",
    "        mt=mt,\n",
    "        input=probe_tokenized,\n",
    "        heads=heads,\n",
    "        token_indices=[[-1]],\n",
    "        projection_signature=\".k_proj\"\n",
    "    )\n",
    "    category_wise_scores = {category: [] for category in category_wise_probes}\n",
    "    for layer_idx, head_idx in heads:\n",
    "        key_state = key_states[0][(layer_idx, head_idx, -1)]  # (d_head,)\n",
    "        for category, probes in category_wise_probes.items():\n",
    "            # Align devices the same way apply_probe_out_of_place does, so probes cached\n",
    "            # on another device do not break torch.dot / cosine_similarity.\n",
    "            probe = probes[(layer_idx, head_idx, -1)].to(key_state.device)\n",
    "            if metric == \"dot\":\n",
    "                score = torch.dot(key_state, probe).item()\n",
    "            else:\n",
    "                score = torch.nn.functional.cosine_similarity(key_state, probe, dim=0).item()\n",
    "\n",
    "            category_wise_scores[category].append(score)\n",
    "\n",
    "    # Average over heads, then rank categories from best to worst.\n",
    "    category_wise_scores = {\n",
    "        category: sum(scores) / len(scores)\n",
    "        for category, scores in category_wise_scores.items()\n",
    "    }\n",
    "    prediction = sorted(\n",
    "        category_wise_scores.items(),\n",
    "        key=lambda x: x[1],\n",
    "        reverse=True\n",
    "    )\n",
    "    return prediction[0][0], prediction\n",
    "\n",
    "\n",
    "obj = \"airplane\"\n",
    "# obj = incorrect_predictions[3][0]\n",
    "print(f\"Object: {obj}\")\n",
    "apply_probe_in_place(\n",
    "    obj=obj,\n",
    "    mt=mt,\n",
    "    heads=[(35, 19)],\n",
    "    # heads=heads_selected,\n",
    "    category_wise_probes=category_wise_q_mean,\n",
    "    # prompt_template=\"Option: {}\",\n",
    "    prompt_template=\"\"\"\n",
    "Apple -> fruit\n",
    "Eagle -> bird\n",
    "{}\"\"\",\n",
    "    metric=\"cosine\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1392783e",
   "metadata": {},
   "source": [
    "### In-place probe accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc8477a",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(select_task.categories)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5d2b194",
   "metadata": {},
   "outputs": [],
   "source": [
    "objects = []\n",
    "for category in select_task.categories:\n",
    "    objects.extend([(obj, category) for obj in select_task.category_wise_examples[category]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff34cffc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Few-shot template: the probed object fills the last line.\n",
    "prompt_template = \"\"\"\n",
    "Apple -> fruit\n",
    "Eagle -> bird\n",
    "{}\"\"\"\n",
    "\n",
    "# prompt_template = \"We know that {}\" \n",
    "\n",
    "correct_predictions, incorrect_predictions = [], []\n",
    "\n",
    "# Classify every (object, category) pair with the single head (35, 19) and bucket the outcome.\n",
    "for obj, category in tqdm(objects):\n",
    "    prediction, scores = apply_probe_in_place(\n",
    "        obj=obj,\n",
    "        mt=mt,\n",
    "        heads=[(35, 19)],\n",
    "        # heads=heads_selected,\n",
    "        category_wise_probes=category_wise_q_mean,\n",
    "        prompt_template=prompt_template,\n",
    "        metric=\"cosine\",\n",
    "    )\n",
    "    if prediction != category:\n",
    "        incorrect_predictions.append((obj, category, prediction, scores))\n",
    "        continue\n",
    "    correct_predictions.append((obj, category, scores))\n",
    "\n",
    "accuracy = len(correct_predictions) / len(objects)\n",
    "print(f\"Accuracy: {accuracy*100:.2f}% ({len(correct_predictions)}/{len(objects)})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab8b6c1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the in-place probe evaluation with freshly sampled probe directions per trial.\n",
    "N_TRIALS = 10\n",
    "TRIAL_RESULTS = []\n",
    "for trial in range(N_TRIALS):\n",
    "    logger.info(f\"Trial {trial+1}/{N_TRIALS}\")\n",
    "    # get_probe_directions() draws new random samples on every call.\n",
    "    probe_directions = get_probe_directions()\n",
    "\n",
    "    logger.info(\"Evaluating\")\n",
    "    n_correct = 0\n",
    "    for obj, category in tqdm(objects):\n",
    "        prediction, scores = apply_probe_in_place(\n",
    "            obj=obj,\n",
    "            mt=mt,\n",
    "            heads=[(35, 19)],\n",
    "            # heads=heads_selected,\n",
    "            category_wise_probes=probe_directions,\n",
    "            prompt_template=prompt_template,\n",
    "            metric=\"cosine\",\n",
    "        )\n",
    "        if prediction == category:\n",
    "            n_correct += 1\n",
    "    accuracy = n_correct / len(objects)\n",
    "    logger.info(\"-\"*50)\n",
    "    logger.info(f\"Accuracy: {accuracy*100:.2f}% ({n_correct}/{len(objects)})\")\n",
    "    logger.info(\"-\"*50)\n",
    "    TRIAL_RESULTS.append(accuracy)\n",
    "\n",
    "TRIAL_RESULTS = np.array(TRIAL_RESULTS)\n",
    "# Accuracies are fractions in [0, 1]; report mean/std as percentages like the per-trial lines,\n",
    "# otherwise a small std is rounded away to 0.00.\n",
    "logger.info(\n",
    "    f\"Mean Accuracy: {TRIAL_RESULTS.mean()*100:.2f}% ± {TRIAL_RESULTS.std()*100:.2f}%\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "086be47e",
   "metadata": {},
   "source": [
    "### Out of place probe accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fec04b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.utils import get_first_token_id\n",
    "from src.functional import get_hs, logit_lens\n",
    "from baukit import get_module\n",
    "\n",
    "prompt_template = \"\"\"\n",
    "Apple -> fruit\n",
    "Eagle -> bird\n",
    "{}\"\"\"\n",
    "\n",
    "obj = \"Train\"\n",
    "\n",
    "locations = []\n",
    "for layer_idx in range(mt.n_layer):\n",
    "    locations.append((mt.layer_name_format.format(layer_idx) + \".input_layernorm\", -1))\n",
    "    locations.append((mt.layer_name_format.format(layer_idx), -1))\n",
    "\n",
    "# print(locations)\n",
    "\n",
    "hs = get_hs(\n",
    "    mt=mt,\n",
    "    input=prompt_template.format(obj),\n",
    "    locations=locations,\n",
    "    return_dict=True,\n",
    ")\n",
    "\n",
    "h = hs[(mt.layer_name_format.format(35), -1)]\n",
    "logit_lens(\n",
    "    mt=mt,\n",
    "    h=h,\n",
    "    interested_tokens=[\n",
    "        get_first_token_id(tokenizer=mt.tokenizer, name=category)\n",
    "        for category in select_task.categories\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "835038f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "layer_idx, head_idx = 35, 19\n",
    "\n",
    "key_states = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=prepare_input(\n",
    "        prompts=[prompt_template.format(obj)],\n",
    "        tokenizer=mt.tokenizer,\n",
    "    ),\n",
    "    heads=[(layer_idx, head_idx)],\n",
    "    token_indices=[[-1]],\n",
    "    projection_signature=\".k_proj\",\n",
    ")\n",
    "# key_states[0][(layer_idx, head_idx, -1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51f78245",
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_k_proj(\n",
    "    mt: ModelandTokenizer,\n",
    "    h: torch.Tensor,\n",
    "    layer_idx: int,\n",
    "    head_idx: int,\n",
    "):\n",
    "    \"\"\"Project a hidden state into the key space of one attention head.\n",
    "\n",
    "    ``h`` is passed through the ``input_layernorm`` of layer ``layer_idx`` and then\n",
    "    through that layer's ``k_proj`` module. Calling the module (instead of\n",
    "    multiplying by ``k_proj.weight``) also applies the projection bias when the\n",
    "    architecture defines one, and works for projection layers that are not a\n",
    "    plain dense weight.\n",
    "\n",
    "    Args:\n",
    "        mt: Wrapper exposing ``_model``, ``config`` and the module-name formats.\n",
    "        h: A single hidden-state vector; the projected result is reshaped to\n",
    "            ``(num_key_value_heads, head_dim)``.\n",
    "        layer_idx: Layer whose layernorm and key projection are applied.\n",
    "        head_idx: Attention (query) head index; it is mapped to the key/value\n",
    "            head shared by its group.\n",
    "\n",
    "    Returns:\n",
    "        The ``head_dim``-sized key vector of the key/value head serving ``head_idx``.\n",
    "    \"\"\"\n",
    "    attn_module = get_module(mt._model, mt.attn_module_name_format.format(layer_idx))\n",
    "    input_ln = get_module(mt._model, mt.layer_name_format.format(layer_idx) + \".input_layernorm\")\n",
    "\n",
    "    num_key_value_heads = mt.config.num_key_value_heads\n",
    "    num_attention_heads = mt.config.num_attention_heads\n",
    "    head_dim = attn_module.head_dim\n",
    "\n",
    "    # Run the real projection module so weight AND bias (if any) are applied.\n",
    "    key_states = attn_module.k_proj(input_ln(h))\n",
    "    key_states = key_states.view(num_key_value_heads, head_dim)\n",
    "\n",
    "    # Grouped-query attention: several query heads share one key/value head.\n",
    "    num_kv_groups = num_attention_heads // num_key_value_heads\n",
    "    kv_head_idx = head_idx // num_kv_groups\n",
    "\n",
    "    return key_states[kv_head_idx]\n",
    "\n",
    "h = hs[(mt.layer_name_format.format(layer_idx - 1), -1)]\n",
    "key_state = apply_k_proj(mt, h, layer_idx=layer_idx, head_idx=head_idx)\n",
    "\n",
    "torch.allclose(key_state, key_states[0][(layer_idx, head_idx, -1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eae205c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from baukit import get_module\n",
    "from src.functional import get_hs, logit_lens\n",
    "\n",
    "def apply_probe_out_of_place(\n",
    "    h: torch.Tensor,\n",
    "    mt: ModelandTokenizer,\n",
    "    heads: list[tuple[int, int]],\n",
    "    category_wise_probes: dict[str, dict[tuple[int, int, int], torch.Tensor]],\n",
    "    metric: Literal[\"cosine\", \"dot\"] = \"dot\",\n",
    "):\n",
    "    \"\"\"Rank categories by scoring a given hidden state against per-category probes.\n",
    "\n",
    "    For every head in ``heads`` the hidden state ``h`` is turned into a key state\n",
    "    with ``apply_k_proj`` and compared with the matching probe of every category.\n",
    "    Scores are averaged over heads.\n",
    "\n",
    "    Args:\n",
    "        h: Hidden state handed to ``apply_k_proj``.\n",
    "        mt: Model/tokenizer wrapper forwarded to ``apply_k_proj``.\n",
    "        heads: ``(layer_idx, head_idx)`` pairs to score with.\n",
    "        category_wise_probes: ``{category: {(layer_idx, head_idx, -1): probe}}``.\n",
    "        metric: ``\"dot\"`` or ``\"cosine\"``.\n",
    "\n",
    "    Returns:\n",
    "        ``(best_category, ranking)`` where ``ranking`` is a list of\n",
    "        ``(category, mean_score)`` pairs sorted from highest to lowest score.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``heads`` or ``category_wise_probes`` is empty, or if\n",
    "            ``metric`` is not supported.\n",
    "    \"\"\"\n",
    "    # Explicit errors instead of a ZeroDivisionError / IndexError further down.\n",
    "    if not heads:\n",
    "        raise ValueError(\"`heads` must contain at least one (layer_idx, head_idx) pair\")\n",
    "    if not category_wise_probes:\n",
    "        raise ValueError(\"`category_wise_probes` must contain at least one category\")\n",
    "    if metric not in (\"dot\", \"cosine\"):\n",
    "        raise ValueError(f\"Unsupported metric: {metric}\")\n",
    "\n",
    "    category_wise_scores = {category: [] for category in category_wise_probes}\n",
    "    for layer_idx, head_idx in heads:\n",
    "        key_state = apply_k_proj(mt, h, layer_idx=layer_idx, head_idx=head_idx)\n",
    "\n",
    "        for category, probes in category_wise_probes.items():\n",
    "            # The probe may live on a different device than the projected key state.\n",
    "            probe = probes[(layer_idx, head_idx, -1)].to(key_state.device)\n",
    "            if metric == \"dot\":\n",
    "                score = torch.dot(key_state, probe).item()\n",
    "            else:\n",
    "                score = torch.nn.functional.cosine_similarity(key_state, probe, dim=0).item()\n",
    "\n",
    "            category_wise_scores[category].append(score)\n",
    "\n",
    "    # Average over heads, then rank categories from best to worst.\n",
    "    category_wise_scores = {\n",
    "        category: sum(scores) / len(scores)\n",
    "        for category, scores in category_wise_scores.items()\n",
    "    }\n",
    "    prediction = sorted(\n",
    "        category_wise_scores.items(),\n",
    "        key=lambda x: x[1],\n",
    "        reverse=True\n",
    "    )\n",
    "    return prediction[0][0], prediction\n",
    "\n",
    "apply_probe_out_of_place(\n",
    "    h=hs[(mt.layer_name_format.format(34), -1)],\n",
    "    mt=mt,\n",
    "    heads=[(35, 19)],\n",
    "    # heads=heads_selected,\n",
    "    category_wise_probes=category_wise_q_mean,\n",
    "    metric=\"cosine\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d1b332a",
   "metadata": {},
   "outputs": [],
   "source": [
    "apply_probe_in_place(\n",
    "    obj=obj,\n",
    "    mt=mt,\n",
    "    heads=[(35, 19)],\n",
    "    # heads=heads_selected,\n",
    "    category_wise_probes=category_wise_q_mean,\n",
    "    # prompt_template=\"Option: {}\",\n",
    "    prompt_template=prompt_template,\n",
    "    metric=\"cosine\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbb067dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"\"\"\n",
    "Apple -> fruit\n",
    "Eagle -> bird\n",
    "{}\"\"\"\n",
    "\n",
    "# Flatten the task into (object, category) pairs and visit them in random order.\n",
    "objects = []\n",
    "for category in select_task.categories:\n",
    "    objects.extend([(obj, category) for obj in select_task.category_wise_examples[category]])\n",
    "\n",
    "random.shuffle(objects)\n",
    "\n",
    "# Loop-invariant: first-token ids of all category names. Only the (commented-out)\n",
    "# logit-lens comparison below needs them, so compute them once instead of per object.\n",
    "interested_tokens = [\n",
    "    get_first_token_id(tokenizer=mt.tokenizer, name=category)\n",
    "    for category in select_task.categories\n",
    "]\n",
    "\n",
    "out_of_place_results = []\n",
    "for obj, category in tqdm(objects):\n",
    "    # Hidden states at the last token for every layer.\n",
    "    hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=prompt_template.format(obj),\n",
    "        locations=[(mt.layer_name_format.format(layer_idx), -1) for layer_idx in range(mt.n_layer)],\n",
    "        return_dict=True,\n",
    "    )\n",
    "\n",
    "    # print(f\"Object: {obj}, True Category: {category}\")\n",
    "    layer_wise_results = {}\n",
    "    for layer_idx in range(mt.n_layer):\n",
    "        # Apply the (35, 19) probe to the hidden state taken from each layer in turn.\n",
    "        prediction, scores = apply_probe_out_of_place(\n",
    "            h=hs[(mt.layer_name_format.format(layer_idx), -1)],\n",
    "            mt=mt,\n",
    "            heads=[(35, 19)],\n",
    "            # heads=heads_selected,\n",
    "            category_wise_probes=category_wise_q_mean,\n",
    "            metric=\"cosine\",\n",
    "        )\n",
    "        # ll_pred, ll_track = logit_lens(\n",
    "        #     mt=mt,\n",
    "        #     h=hs[(mt.layer_name_format.format(layer_idx), -1)],\n",
    "        #     interested_tokens=interested_tokens,\n",
    "        # )\n",
    "        layer_wise_results[layer_idx] = {\n",
    "            \"predicted_category\": prediction,\n",
    "            \"scores\": scores,\n",
    "            # \"ll_pred\": ll_pred,\n",
    "            # \"ll_track\": ll_track\n",
    "        }\n",
    "    out_of_place_results.append({\n",
    "        \"object\": obj,\n",
    "        \"category\": category,\n",
    "        \"layer_wise_results\": layer_wise_results\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65bcadba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logit-lens baseline on the first 100 (shuffled) objects.\n",
    "ll_baseline = []\n",
    "\n",
    "# Loop-invariant pieces, computed once instead of per object / per layer.\n",
    "layers = list(range(0, mt.n_layer, 1)) + [mt.n_layer - 1]\n",
    "layers = sorted(list(set(layers)))\n",
    "category_token_ids = [\n",
    "    get_first_token_id(tokenizer=mt.tokenizer, name=category)\n",
    "    for category in select_task.categories\n",
    "]\n",
    "\n",
    "for obj, category in tqdm(objects[:100]):\n",
    "    hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=prompt_template.format(obj),\n",
    "        locations=[(mt.layer_name_format.format(layer_idx), -1) for layer_idx in range(mt.n_layer)],\n",
    "        return_dict=True,\n",
    "    )\n",
    "    layer_wise_results = {}\n",
    "    for layer_idx in layers:\n",
    "        ll_pred, ll_track = logit_lens(\n",
    "            mt=mt,\n",
    "            h=hs[(mt.layer_name_format.format(layer_idx), -1)],\n",
    "            # Fresh list per call, as before, in case the callee mutates its argument.\n",
    "            interested_tokens=list(category_token_ids),\n",
    "        )\n",
    "        layer_wise_results[layer_idx] = {\n",
    "            \"ll_pred\": ll_pred,\n",
    "            \"ll_track\": ll_track\n",
    "        }\n",
    "    ll_baseline.append({\n",
    "        \"object\": obj,\n",
    "        \"category\": category,\n",
    "        \"layer_wise_results\": layer_wise_results\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64ffa323",
   "metadata": {},
   "outputs": [],
   "source": [
    "ll_baseline[0][\"layer_wise_results\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea4ec1d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# class StrEncoder(json.JSONEncoder):\n",
    "#     def default(self, obj):\n",
    "#         try:\n",
    "#             return super().default(obj)\n",
    "#         except TypeError:\n",
    "#             return str(obj)\n",
    "\n",
    "# with open(\"probe_baseline_logit_lens.json\", \"w\") as f:\n",
    "#     json.dump(ll_baseline, f, cls=StrEncoder, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "326ae1a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# --- Probe accuracy per layer: count hits, then normalise by the number of objects. ---\n",
    "probe_hits = dict.fromkeys(range(mt.n_layer), 0)\n",
    "for result in out_of_place_results:\n",
    "    category = result[\"category\"]\n",
    "    for layer_idx, layer_result in result[\"layer_wise_results\"].items():\n",
    "        if layer_result[\"predicted_category\"] == category:\n",
    "            probe_hits[layer_idx] += 1\n",
    "\n",
    "layer_wise_accuracy = {\n",
    "    layer_idx: hits / len(out_of_place_results) for layer_idx, hits in probe_hits.items()\n",
    "}\n",
    "\n",
    "# --- Logit-lens accuracy per layer: the first key of `ll_track` is taken as the prediction. ---\n",
    "baseline_hits = dict.fromkeys(range(mt.n_layer), 0)\n",
    "for result in ll_baseline:\n",
    "    category = result[\"category\"]\n",
    "    category_token_id = get_first_token_id(tokenizer=mt.tokenizer, name=category)\n",
    "    for layer_idx, layer_result in result[\"layer_wise_results\"].items():\n",
    "        ll_track = layer_result[\"ll_track\"]\n",
    "        predicted_category_token = list(ll_track.keys())[0]\n",
    "        if predicted_category_token == category_token_id:\n",
    "            baseline_hits[layer_idx] += 1\n",
    "\n",
    "ll_baseline_accuracy = {\n",
    "    layer_idx: hits / len(ll_baseline) for layer_idx, hits in baseline_hits.items()\n",
    "}\n",
    "\n",
    "# --- Both curves on one axis. ---\n",
    "plt.figure(figsize=(10, 5))\n",
    "for curve_label, accuracy_by_layer in (\n",
    "    (\"Out of Place\", layer_wise_accuracy),\n",
    "    (\"Logit Lens Baseline\", ll_baseline_accuracy),\n",
    "):\n",
    "    plt.plot(list(accuracy_by_layer.keys()), list(accuracy_by_layer.values()), label=curve_label)\n",
    "plt.xlabel(\"Layer Index\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.ylim(0, 1)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "483208b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both per-layer accuracy curves under figures/<model name>/raw/.\n",
    "figure_path = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"raw\")\n",
    "os.makedirs(figure_path, exist_ok=True)\n",
    "\n",
    "probe_performance = dict(\n",
    "    out_of_place=layer_wise_accuracy,\n",
    "    logit_lens_baseline=ll_baseline_accuracy,\n",
    ")\n",
    "with open(os.path.join(figure_path, \"probe_performance.json\"), \"w\") as f:\n",
    "    json.dump(probe_performance, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f4731e2",
   "metadata": {},
   "source": [
    "## Application"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd6ed8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.functional import verify_head_patterns\n",
    "import random\n",
    "\n",
    "true_statements = [\n",
    "    \"Earth revolves around the Sun.\",\n",
    "    \"Eiffel Tower is in Paris.\",\n",
    "    \"Apple is a fruit.\",\n",
    "    \"Homo sapiens is the scientific name for humans.\",\n",
    "    \"Singapore's capital is Singapore.\"\n",
    "]\n",
    "\n",
    "false_statements = [\n",
    "    \"It is safe to drink seawater.\",\n",
    "    \"The capital of France is Berlin.\",\n",
    "    \"Bananas grow on trees.\",\n",
    "    \"Humans share 50% of their DNA with bananas.\",\n",
    "]\n",
    "\n",
    "n_options = 5\n",
    "\n",
    "false = random.choice(false_statements)\n",
    "true = random.sample(true_statements, n_options-1)\n",
    "\n",
    "options = [false] + true\n",
    "random.shuffle(options)\n",
    "\n",
    "lines = [f\"* {opt}\" for i, opt in enumerate(options)]\n",
    "lines = \"\\n\".join(lines)\n",
    "\n",
    "prompt = f\"\"\"{lines}\n",
    "\n",
    "Which one of the above statements is likely to be wrong?\n",
    "Answer:\"\"\"\n",
    "\n",
    "all_heads = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx in range(mt.n_layer)\n",
    "    for head_idx in range(mt.config.num_attention_heads)\n",
    "]\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    mt=mt,\n",
    "    heads=heads_selected,\n",
    "    # heads = random.sample(all_heads, len(heads_selected)),\n",
    "    value_weighted=False,\n",
    "    # generate_full_answer=True,\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26f03e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_probe_samples = 10\n",
    "\n",
    "prompt_template = \"\"\"{}\n",
    "\n",
    "Which one of the above statements is wrong?\n",
    "Answer:\"\"\"\n",
    "\n",
    "cache_from_texts = []\n",
    "for _ in range(n_probe_samples):\n",
    "    false = random.choice(false_statements)\n",
    "    true = random.sample(true_statements, n_options-1)\n",
    "\n",
    "    options = [false] + true\n",
    "    random.shuffle(options)\n",
    "\n",
    "    lines = [f\"* {opt}\" for i, opt in enumerate(options)]\n",
    "    lines = \"\\n\".join(lines)\n",
    "\n",
    "    prompt = prompt_template.format(lines)\n",
    "    cache_from_texts.append(prompt)\n",
    "\n",
    "hallucination_q_states = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=prepare_input(\n",
    "        prompts=cache_from_texts,\n",
    "        tokenizer=mt.tokenizer,\n",
    "    ),\n",
    "    heads=heads_selected,\n",
    "    token_indices=[[-1] for _ in range(len(cache_from_texts))],\n",
    "    projection_signature=\".q_proj\"\n",
    ")\n",
    "\n",
    "hallucination_probe = {}\n",
    "for (layer_idx, head_idx, query_idx) in hallucination_q_states[0].keys():\n",
    "    q_states = torch.stack(\n",
    "        [sample_states[(layer_idx, head_idx, query_idx)] for sample_states in hallucination_q_states]\n",
    "    )\n",
    "    hallucination_probe[(layer_idx, head_idx, query_idx)] = q_states.mean(dim=0)\n",
    "\n",
    "hallucination_probe[(35, 19, -1)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f061d22d",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_template = \"\"\"{}\n",
    "\n",
    "Which one of the above statements is correct?\n",
    "Answer:\"\"\"\n",
    "\n",
    "cache_from_texts = []\n",
    "for _ in range(n_probe_samples):\n",
    "    true = random.choice(true_statements)\n",
    "    false = random.sample(false_statements, n_options-1)\n",
    "\n",
    "    options = [true] + false\n",
    "    random.shuffle(options)\n",
    "\n",
    "    lines = [f\"* {opt}\" for i, opt in enumerate(options)]\n",
    "    lines = \"\\n\".join(lines)\n",
    "\n",
    "    prompt = prompt_template.format(lines)\n",
    "    cache_from_texts.append(prompt)\n",
    "\n",
    "correct_q_states = cache_q_projections(\n",
    "    mt=mt,\n",
    "    input=prepare_input(\n",
    "        prompts=cache_from_texts,\n",
    "        tokenizer=mt.tokenizer,\n",
    "    ),\n",
    "    heads=heads_selected,\n",
    "    token_indices=[[-1] for _ in range(len(cache_from_texts))],\n",
    "    projection_signature=\".q_proj\"\n",
    ")\n",
    "\n",
    "correct_probe = {}\n",
    "for (layer_idx, head_idx, query_idx) in correct_q_states[0].keys():\n",
    "    q_states = torch.stack(\n",
    "        [sample_states[(layer_idx, head_idx, query_idx)] for sample_states in correct_q_states]\n",
    "    )\n",
    "    correct_probe[(layer_idx, head_idx, query_idx)] = q_states.mean(dim=0)\n",
    "\n",
    "correct_probe[(35, 19, -1)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "173e5fb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "from src.functional import get_hs\n",
    "\n",
    "# text = \"\"\"ET is a movie about an alien who is stranded on Earth. He befriends a young boy named Elliott and they form a close bond. Together, they embark on a journey to help ET return home. It is considered a classic in the romantic comedy genre. The film explores themes of friendship, adventure, and the importance of family. ET's iconic appearance and memorable quotes have made it a beloved film for audiences of all ages. The movie was directed by Christopher Nolan. The movie received critical acclaim for its heartwarming story and special effects.\"\"\" \n",
    "\n",
    "# text = \"\"\"Carl Sagan was an American astronomer, astrophysicist, and science communicator. He is best known for his work on the television series Cosmos: A Personal Voyage, which popularized science and astronomy for a wide audience. Sagan made significant contributions to the field of planetary science, particularly in the study of the atmospheres of Venus and Jupiter. He was awarded the Pulitzer Prize for his book \"The God Delusion\" in 1978. Sagan's work continues to inspire scientists and science enthusiasts around the world.\"\"\"\n",
    "\n",
    "text = \"\"\"\"What you see is all there is\" is a phrase coined by the psychologist Daniel Kahneman. It refers to the cognitive bias where people tend to rely on the information that is immediately available to them, rather than seeking out additional information or considering alternative perspectives. Kahneman claimed that people who has diabetes are more likely to suffer from this bias. This bias can lead to flawed decision-making based on incomplete or limited information.\"\"\"\n",
    "\n",
    "prefix = \"Read the text carefully. You will be asked if some information presented in the text is right or wrong\\n\\n\"\n",
    "\n",
    "prefix_tokenized = prepare_input(\n",
    "    prompts = prefix,\n",
    "    tokenizer = mt.tokenizer\n",
    ")\n",
    "\n",
    "text_tokenized = prepare_input(\n",
    "    prompts=prefix + text,\n",
    "    tokenizer=mt.tokenizer,\n",
    ")\n",
    "\n",
    "hs = get_hs(\n",
    "    mt=mt,\n",
    "    input=text_tokenized,\n",
    "    locations=list(product(\n",
    "        mt.layer_names, \n",
    "        list(range(text_tokenized[\"input_ids\"].shape[1]))\n",
    "    )),\n",
    "    return_dict=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c3f2a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from circuitsvis.tokens import colored_tokens\n",
    "\n",
    "layer_idx, head_idx = 35, 19\n",
    "\n",
    "tokens, scores = [], []\n",
    "for token_idx in range(\n",
    "    prefix_tokenized.input_ids.shape[1], text_tokenized[\"input_ids\"].shape[1]\n",
    "):\n",
    "    h = hs[(mt.layer_name_format.format(50), token_idx)]\n",
    "    prediction, score = apply_probe_out_of_place(\n",
    "        h=h,\n",
    "        mt=mt,\n",
    "        heads=[(layer_idx, head_idx)],\n",
    "        # heads=heads_selected,\n",
    "        category_wise_probes={\n",
    "            \"hallucination\": hallucination_probe,\n",
    "            \"correct\": correct_probe,\n",
    "        },\n",
    "        metric=\"cosine\",\n",
    "    )\n",
    "    tokens.append(\n",
    "        mt.tokenizer.decode([text_tokenized[\"input_ids\"][0, token_idx].item()])\n",
    "    )\n",
    "    scores.append(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2162209a",
   "metadata": {},
   "outputs": [],
   "source": [
    "hallucination_scores = []\n",
    "for probe_scores in scores:\n",
    "    probe_scores = dict(probe_scores)\n",
    "    hallucination_scores.append(probe_scores[\"hallucination\"] - probe_scores[\"correct\"])\n",
    "\n",
    "hallucination_scores = torch.Tensor(hallucination_scores)\n",
    "hallucination_scores = hallucination_scores.clip_(min=0)\n",
    "# hallucination_scores = (hallucination_scores - hallucination_scores.min()) / (hallucination_scores.max() - hallucination_scores.min())\n",
    "display(colored_tokens(tokens, hallucination_scores, positive_color=\"red\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dd3571e",
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = text.split(\".\")\n",
    "lines = \".\\n\".join([f\"{line.strip()}\" for line in lines if len(line.strip()) > 0])\n",
    "\n",
    "# prompt = prefix + f\"\"\"{lines}\n",
    "\n",
    "# Which one of the above statements is likely to be wrong?\n",
    "# Answer:\"\"\"\n",
    "\n",
    "processed_text = prefix + lines\n",
    "processed_text_tokenized = prepare_input(prompts=processed_text, tokenizer=mt.tokenizer)\n",
    "prompt = f\"\"\"{processed_text}\n",
    "\n",
    "Do you see any wrong information in the above text?\n",
    "Answer:\"\"\"\n",
    "\n",
    "verify = verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    mt=mt,\n",
    "    heads=heads_selected,\n",
    "    # heads = random.sample(all_heads, len(heads_selected)),\n",
    "    value_weighted=False,\n",
    "    # generate_full_answer=True,\n",
    "    vis_args = {\n",
    "        \"positive_color\": \"blue\",\n",
    "        \"start_from\": prefix_tokenized.input_ids.shape[1],\n",
    "        \"end_at\": processed_text_tokenized[\"input_ids\"].shape[1],\n",
    "    }\n",
    ")\n",
    "\n",
    "verify[\"predictions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "526820a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"\"\"Think about the professions of the following people \n",
    "* Johny Depp\n",
    "* Hugh Jackman\n",
    "* Emma Watson\n",
    "* Serena Williams\n",
    "* Tom Hanks\n",
    "Who is a scientist in this list?\n",
    "Answer:\"\"\"\n",
    "\n",
    "verify_head_patterns(\n",
    "    prompt=prompt,\n",
    "    mt=mt,\n",
    "    heads=heads_selected,\n",
    "    # heads = random.sample(all_heads, len(heads_selected)),\n",
    "    value_weighted=False,\n",
    "    generate_full_answer=True,\n",
    "    vis_args = {\n",
    "        \"positive_color\": \"blue\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd47bb55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
