{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70b92923",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-import edited modules (e.g. the project's `src` package) automatically before each cell runs.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "983a4427",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import logging\n",
    "\n",
    "# Make the project package `src` importable (assumed to live one directory up).\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "# Set before `transformers` is imported below; silences the HF tokenizers\n",
    "# parallelism/fork warning.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "# torch.cuda.get_device_name() raises when no CUDA device is visible, so only\n",
    "# query it when CUDA is available; the setup cell then also runs on CPU-only kernels.\n",
    "cuda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {cuda_device_name=}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03464cab",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Checkpoint to analyse. Alternatives that can be swapped in:\n",
    "#   Llama:    meta-llama/Llama-3.2-3B, meta-llama/Llama-3.1-8B,\n",
    "#             meta-llama/Llama-3.1-70B-Instruct, meta-llama/Llama-3.1-405B-Instruct\n",
    "#   Gemma:    google/gemma-2-9b-it, google/gemma-3-12b-it, google/gemma-2-27b-it\n",
    "#   DeepSeek: deepseek-ai/DeepSeek-R1-Distill-Llama-8B\n",
    "#   OLMo:     allenai/OLMo-2-1124-7B-Instruct, allenai/OLMo-7B-0424-hf\n",
    "#   Qwen2.x:  Qwen/Qwen2-7B, Qwen/Qwen2.5-14B-Instruct, Qwen/Qwen2.5-32B-Instruct,\n",
    "#             Qwen/Qwen2.5-72B-Instruct\n",
    "#   Qwen3:    Qwen/Qwen3-1.7B, Qwen/Qwen3-4B, Qwen/Qwen3-8B, Qwen/Qwen3-14B,\n",
    "#             Qwen/Qwen3-32B\n",
    "# NOTE(review): the OV-contribution section patches attention with\n",
    "# LlamaAttentionPatcher; confirm it supports the chosen architecture.\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "\n",
    "# Optional manual placement; pass it as device_map= when loading the model.\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a88f8c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import ModelandTokenizer\n",
    "\n",
    "# Alternative loading options kept for reference:\n",
    "#   device_map=device_map   -> custom split from get_device_map(...) in the cell above\n",
    "#   quantization_config=BitsAndBytesConfig(load_in_8bit=True)   # or load_in_4bit=True\n",
    "#       (needs `from transformers import BitsAndBytesConfig`)\n",
    "mt = ModelandTokenizer(\n",
    "    model_key=model_key,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    # NOTE(review): eager attention is presumably needed so attention weights can be\n",
    "    # inspected (SDPA does not return them) — confirm.\n",
    "    attn_implementation=\"eager\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45b1c057",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "#################################################################################\n",
    "# Task configuration. For TASK_CLS = SelectOrderTask use prompt_template_idx = 1.\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "#################################################################################\n",
    "\n",
    "# Alternative datasets in the same folder: \"profession.json\", \"nationality.json\".\n",
    "select_task = TASK_CLS.load(\n",
    "    path=os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    ")\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1711aade",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Head-selection results saved by the optimization run for this model.\n",
    "# Other layouts under selection/optimized_heads/<model>/ used earlier:\n",
    "#   <task_name>.npz,  <task_name>/epoch_10.npz,  distinct_options/legacy/epoch_10.npz\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"distinct_options\",\n",
    "    \"select_one\",  # or f\"{select_task.task_name}\"\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "\n",
    "# NOTE(review): allow_pickle=True can execute arbitrary code stored in the file;\n",
    "# only load result files produced by our own optimization runs.\n",
    "# The NpzFile is left open on purpose: the next cell reads \"optimal_mask\" from it.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(optimization_results[\"losses\"])\n",
    "ax.set_xlabel(\"Step\")\n",
    "ax.set_ylabel(\"Loss\")\n",
    "ax.set_title(f\"Head-selection optimization | {model_key.split('/')[-1]}\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f5b0ca7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Heads in layers >= LAYER_CUTOFF are zeroed out and therefore never selected.\n",
    "# NOTE(review): 52 was a bare literal; record why this cutoff was chosen.\n",
    "LAYER_CUTOFF = 52\n",
    "\n",
    "# Rows are layers, columns are heads (see the (layer_idx, head_idx) unpacking below).\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[LAYER_CUTOFF:, :] = 0.0\n",
    "\n",
    "# Transposed so heads run along the y-axis and layers along the x-axis.\n",
    "fig, ax = plt.subplots(figsize=(20, 10))\n",
    "ax.imshow(\n",
    "    optimal_head_mask.T.numpy(),\n",
    "    cmap=\"Blues\",\n",
    "    aspect=\"auto\",\n",
    "    vmin=0,\n",
    "    vmax=1,\n",
    ")\n",
    "ax.set_xlabel(\"Layer\")\n",
    "ax.set_ylabel(\"Head\")\n",
    "ax.set_title(\"Optimized head mask\")\n",
    "plt.show()\n",
    "\n",
    "# (layer_idx, head_idx) tuples for every head whose mask value exceeds 0.5.\n",
    "heads_selected = [\n",
    "    (layer_idx, head_idx)\n",
    "    for layer_idx, head_idx in torch.nonzero(optimal_head_mask > 0.5, as_tuple=False).tolist()\n",
    "]\n",
    "print(len(heads_selected))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "526820a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import MCQify_sample\n",
    "\n",
    "# One example for the single-sample analyses below. N_DISTRACTORS from the config\n",
    "# cell is passed explicitly so that the constant actually controls the sample.\n",
    "sample = select_task.get_random_sample(\n",
    "    mt=mt,\n",
    "    option_style=OPTION_STYLE,\n",
    "    prompt_template_idx=prompt_template_idx,\n",
    "    n_distractors=N_DISTRACTORS,\n",
    "    category=\"fruit\",  # e.g. \"actor\" or \"Brazil\" for the other datasets\n",
    "    filter_by_lm_prediction=False,\n",
    ")\n",
    "# NOTE(review): MCQify_sample receives the whole ModelandTokenizer as `tokenizer`;\n",
    "# presumably it accepts either — confirm.\n",
    "sample = MCQify_sample(sample=sample, tokenizer=mt)\n",
    "\n",
    "print(sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([sample.ans_token_id])}\"')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b13680a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import get_attention_matrices\n",
    "from src.selection.functional import (\n",
    "    get_patches_to_verify_independent_enrichment,\n",
    "    verify_head_patterns,\n",
    ")\n",
    "\n",
    "# Attention patterns of the optimized heads on the sample prompt.\n",
    "# Alternatives: heads=HEADS, or a single head such as heads=[(35, 19)].\n",
    "# NOTE(review): start_from=1 presumably skips the first (BOS) position — confirm.\n",
    "attn_pattern = verify_head_patterns(\n",
    "    mt=mt,\n",
    "    prompt=sample.prompt(),\n",
    "    options=sample.options,\n",
    "    heads=heads_selected,\n",
    "    start_from=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7723b118",
   "metadata": {},
   "source": [
    "## Apply the Logit Lens to the Summed OV Contributions of the Selected Heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1598a322",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.hooking.llama_attention import LlamaAttentionPatcher\n",
    "import types\n",
    "import copy\n",
    "import baukit\n",
    "from src.functional import patch_with_baukit, interpret_logits, get_hs\n",
    "from src.selection.utils import get_first_token_id\n",
    "from src.tokens import prepare_input\n",
    "\n",
    "HEADS = copy.deepcopy(heads_selected)\n",
    "tokenized = prepare_input(prompts=sample.prompt(), tokenizer=mt.tokenizer)\n",
    "\n",
    "# Group the selected heads by layer so each attention block is patched once.\n",
    "layers_to_heads = {}\n",
    "for layer_idx, head_idx in HEADS:\n",
    "    layers_to_heads.setdefault(layer_idx, []).append(head_idx)\n",
    "\n",
    "# head_contributions[layer_idx] is handed to the patcher, which fills it during the\n",
    "# forward pass (later read as head_contributions[layer][head]).\n",
    "head_contributions = {}\n",
    "logit_location = (mt.lm_head_name, -1)\n",
    "\n",
    "mt.reset_forward()\n",
    "# NOTE(review): the patcher is presumably written against the SDPA forward — confirm.\n",
    "mt.set_attn_implementation(\"sdpa\")\n",
    "try:\n",
    "    for layer_idx, head_indices in layers_to_heads.items():\n",
    "        attn_block_name = mt.attn_module_name_format.format(layer_idx)\n",
    "        attn_block = baukit.get_module(mt._model, attn_block_name)\n",
    "        head_contributions[layer_idx] = {}\n",
    "        # Replace this block's forward with the patcher, bound as a method.\n",
    "        attn_block.forward = types.MethodType(\n",
    "            LlamaAttentionPatcher(\n",
    "                block_name=attn_block_name,\n",
    "                save_attn_for=head_indices,\n",
    "                store_head_contributions=head_contributions[layer_idx],\n",
    "            ),\n",
    "            attn_block,\n",
    "        )\n",
    "\n",
    "    logit = get_hs(\n",
    "        mt=mt,\n",
    "        input=tokenized,\n",
    "        locations=logit_location,\n",
    "        return_dict=False,\n",
    "    )\n",
    "finally:\n",
    "    # Run the cleanup even if patching or the forward pass fails; otherwise every\n",
    "    # later cell would silently run through the patched attention blocks.\n",
    "    # (reset_forward presumably restores the original forwards — confirm.)\n",
    "    mt.reset_forward()\n",
    "    mt.set_attn_implementation(\"eager\")\n",
    "\n",
    "pred = interpret_logits(\n",
    "    logits=logit,\n",
    "    tokenizer=mt.tokenizer,\n",
    ")\n",
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01a32947",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import logit_lens\n",
    "from src.selection.data import get_options_for_answer\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "# Sum the selected heads' contributions at index -1 of dim 1 (presumably the last\n",
    "# token) and read the result through the logit lens. For a single head instead:\n",
    "#   head_contrib = head_contributions[35][19][:, -1, :]\n",
    "assert HEADS, \"no heads selected; nothing to sum\"\n",
    "head_contrib = torch.stack(\n",
    "    [\n",
    "        # .cuda() moves every vector to one device so they can be stacked\n",
    "        # (with device_map=\"auto\" layers may live on different GPUs).\n",
    "        head_contributions[layer_idx][head_idx][:, -1, :].squeeze().cuda()\n",
    "        for layer_idx, head_idx in HEADS\n",
    "    ]\n",
    ").sum(dim=0)\n",
    "\n",
    "logit_lens(\n",
    "    mt=mt,\n",
    "    h=head_contrib,\n",
    "    interested_tokens=[\n",
    "        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "        for opt in get_options_for_answer(sample) + sample.options\n",
    "    ],\n",
    "    k=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da850729",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import patchscope\n",
    "from src.selection.data import get_options_for_answer\n",
    "from src.selection.utils import get_first_token_id\n",
    "\n",
    "# Decode the summed head contribution with a patchscope, tracking the first token\n",
    "# of every answer option.\n",
    "patchscope(\n",
    "    mt=mt,\n",
    "    h=head_contrib,\n",
    "    interested_tokens=[\n",
    "        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "        for opt in get_options_for_answer(sample) + sample.options\n",
    "    ],\n",
    "    k=20,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f45ace2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.functional import free_gpu_cache\n",
    "\n",
    "# Drop the per-head activations before the per-layer analysis and release cached GPU memory.\n",
    "del head_contributions, head_contrib\n",
    "free_gpu_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b56d23c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The sample's object and the token id of its expected answer.\n",
    "(sample.obj, sample.ans_token_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcee117e",
   "metadata": {},
   "source": [
    "## Apply the Logit Lens to the Residual-Stream Latents, Layer by Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01f4ad6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.tokens import prepare_input\n",
    "from src.functional import get_hs, logit_lens\n",
    "from src.selection.utils import get_first_token_id\n",
    "from src.selection.data import get_options_for_answer\n",
    "\n",
    "\n",
    "def apply_logit_lens_per_layer(mt, sample, verbose=True):\n",
    "    \"\"\"Run the logit lens on every layer's hidden state at position -1 (last token).\n",
    "\n",
    "    Args:\n",
    "        mt: the ModelandTokenizer wrapper.\n",
    "        sample: an MCQ-ified selection sample providing `prompt()`, `obj`,\n",
    "            `options` and `ans_token_id`.\n",
    "        verbose: if True (default), print one summary line per layer.\n",
    "\n",
    "    Returns:\n",
    "        dict with \"sample\", \"track_tokens\" (\"object\"/\"ans_tok\" -> token id) and\n",
    "        \"layerwise_results\" (layer_idx -> {\"ll_pred\", \"ll_track\", \"logits\"}).\n",
    "    \"\"\"\n",
    "    tokenized = prepare_input(tokenizer=mt.tokenizer, prompts=sample.prompt())\n",
    "    track_tokens = {\n",
    "        \"object\": get_first_token_id(name=sample.obj, tokenizer=mt.tokenizer, prefix=\" \"),\n",
    "        \"ans_tok\": sample.ans_token_id,\n",
    "    }\n",
    "    hs = get_hs(\n",
    "        mt=mt,\n",
    "        input=tokenized,\n",
    "        locations=[(layer_name, -1) for layer_name in mt.layer_names],\n",
    "        return_dict=True,\n",
    "    )\n",
    "\n",
    "    option_tokens = [\n",
    "        get_first_token_id(name=opt, tokenizer=mt.tokenizer, prefix=\" \")\n",
    "        for opt in get_options_for_answer(sample) + sample.options\n",
    "    ]\n",
    "    # Always track the object/answer tokens too: the scoring cell reads\n",
    "    # ll_track[track_tokens[...]], which (assuming logit_lens only reports the\n",
    "    # requested ids) would KeyError if either id were missing from the options.\n",
    "    # dict.fromkeys de-duplicates while preserving order.\n",
    "    interested_tokens = list(dict.fromkeys(option_tokens + list(track_tokens.values())))\n",
    "\n",
    "    layerwise_results = {}\n",
    "    for layer_idx in range(mt.n_layer):\n",
    "        layer_name = mt.layer_name_format.format(layer_idx)\n",
    "        h = hs[(layer_name, -1)]\n",
    "        logits, (ll_pred, ll_track) = logit_lens(\n",
    "            mt=mt, h=h, interested_tokens=interested_tokens, return_logits=True\n",
    "        )\n",
    "        if verbose:\n",
    "            print(\n",
    "                f\"{layer_name} | {[f'{mt.tokenizer.decode(token_id)}({ll_track[token_id][0]} | {ll_track[token_id][1].logit:.2f})' for token_id in ll_track.keys()]} | {[str(pred) for pred in ll_pred]}\"\n",
    "            )\n",
    "        layerwise_results[layer_idx] = {\n",
    "            \"ll_pred\": ll_pred,\n",
    "            \"ll_track\": ll_track,\n",
    "            \"logits\": logits,\n",
    "        }\n",
    "\n",
    "    return {\n",
    "        \"sample\": sample,\n",
    "        \"track_tokens\": track_tokens,\n",
    "        \"layerwise_results\": layerwise_results,\n",
    "    }\n",
    "\n",
    "\n",
    "ll_result_for_sample = apply_logit_lens_per_layer(mt=mt, sample=sample)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e828a2a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from tqdm.auto import tqdm\n",
    "from src.selection.data import MCQify_sample\n",
    "\n",
    "# Collect per-layer logit-lens results over several random samples. `results` is\n",
    "# (re)initialized here so the cell runs on a fresh kernel and re-running it does\n",
    "# not append duplicates to an earlier run's list.\n",
    "results = []\n",
    "limit = 32\n",
    "\n",
    "for _ in tqdm(range(limit)):\n",
    "    # Separate name so the loop does not overwrite `sample`, which the single-sample\n",
    "    # cells above (attention patterns, OV contributions) were computed for.\n",
    "    loop_sample = select_task.get_random_sample(\n",
    "        mt=mt,\n",
    "        option_style=OPTION_STYLE,\n",
    "        prompt_template_idx=prompt_template_idx,\n",
    "        # category=\"fruit\",\n",
    "        filter_by_lm_prediction=True,\n",
    "        # Vary the number of distractors (2 to 5) across samples.\n",
    "        n_distractors=random.choice(range(2, 6)),\n",
    "    )\n",
    "    loop_sample = MCQify_sample(sample=loop_sample, tokenizer=mt)\n",
    "    print(loop_sample.prompt(), \">>\", f'\"{mt.tokenizer.decode([loop_sample.ans_token_id])}\"')\n",
    "    result = apply_logit_lens_per_layer(mt=mt, sample=loop_sample)\n",
    "    results.append(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "416c8bcc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from src.trace import rank_reward\n",
    "\n",
    "# To score only the single sample from above instead: results = [ll_result_for_sample]\n",
    "assert results, \"`results` is empty; run the sampling cell above first\"\n",
    "\n",
    "# Cutoff passed to rank_reward as `k`.\n",
    "RANK_REWARD_K = 500\n",
    "\n",
    "# scores[token_type] -> one array per sample holding a rank-based score per layer.\n",
    "# (Alternative score: the raw logit, ll_track[token_id][1].logit.)\n",
    "scores = {token_type: [] for token_type in results[0][\"track_tokens\"]}\n",
    "for result in results:\n",
    "    for token_type in scores:\n",
    "        token_id = result[\"track_tokens\"][token_type]\n",
    "        layerwise_scores = []\n",
    "        for layer_idx in range(mt.n_layer):\n",
    "            layerwise_scores.append(\n",
    "                rank_reward(\n",
    "                    rank=result[\"layerwise_results\"][layer_idx][\"ll_track\"][token_id][0],\n",
    "                    k=RANK_REWARD_K,\n",
    "                )\n",
    "            )\n",
    "        scores[token_type].append(np.array(layerwise_scores))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7bc2fa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Global figure style (DPI, font family and sizes) for the plot below.\n",
    "plt.rcdefaults()\n",
    "plt.rcParams[\"figure.dpi\"] = 200\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 20\n",
    "BIGGER_SIZE = 35\n",
    "\n",
    "plt.rc(\"font\", size=SMALL_SIZE)  # controls default text sizes\n",
    "plt.rc(\"axes\", titlesize=MEDIUM_SIZE)  # fontsize of the axes title\n",
    "plt.rc(\"axes\", labelsize=MEDIUM_SIZE)  # fontsize of the x and y labels\n",
    "plt.rc(\"xtick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"ytick\", labelsize=SMALL_SIZE)  # fontsize of the tick labels\n",
    "plt.rc(\"legend\", fontsize=SMALL_SIZE)  # legend fontsize\n",
    "plt.rc(\"figure\", titlesize=BIGGER_SIZE)  # fontsize of the figure title\n",
    "\n",
    "fig, ax1 = plt.subplots(1, 1, figsize=(12, 6))\n",
    "colors = {\n",
    "    \"object\": \"blue\",\n",
    "    \"ans_tok\": \"red\",\n",
    "}\n",
    "\n",
    "# Per token type: mean score across samples for each layer, with a +/- 1\n",
    "# standard-error band.\n",
    "lines = []\n",
    "for token_type, layerwise_scores_list in scores.items():\n",
    "    mean_scores = np.mean(layerwise_scores_list, axis=0)\n",
    "    sterr_scores = np.std(layerwise_scores_list, axis=0) / np.sqrt(\n",
    "        len(layerwise_scores_list)\n",
    "    )\n",
    "    (line,) = ax1.plot(\n",
    "        mean_scores, label=f\"{token_type}\", alpha=0.8, color=colors[token_type]\n",
    "    )\n",
    "    ax1.fill_between(\n",
    "        range(len(mean_scores)),\n",
    "        mean_scores - sterr_scores,\n",
    "        mean_scores + sterr_scores,\n",
    "        alpha=0.1,\n",
    "        color=colors[token_type],\n",
    "    )\n",
    "    lines.append(line)\n",
    "\n",
    "ax1.set_xlabel(\"Layer\")\n",
    "# `scores` holds rank_reward values computed in the scoring cell, not raw logits.\n",
    "ax1.set_ylabel(\"Rank reward\")\n",
    "ax1.set_title(f\"Residual | {mt.name.split('/')[-1]}\")\n",
    "\n",
    "# Horizontal legend centered below the axes (its bottom edge sits at y=-0.15 in\n",
    "# axes coordinates).\n",
    "ax1.legend(\n",
    "    handles=lines,\n",
    "    loc=\"lower center\",\n",
    "    bbox_to_anchor=(0.5, -0.15),\n",
    "    ncol=len(scores),\n",
    "    frameon=False,\n",
    "    fontsize=\"medium\",\n",
    ")\n",
    "\n",
    "plt.tight_layout()\n",
    "save_dir = os.path.join(\"figures\", model_key.split(\"/\")[-1], \"logit_lens_contribution\")\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "plt.savefig(os.path.join(save_dir, f\"template_{prompt_template_idx}.pdf\"), bbox_inches=\"tight\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2a3f489",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shape of the stored logits for the final layer of the last collected sample\n",
    "# (explicit indices instead of loop variables leaked from other cells).\n",
    "results[-1][\"layerwise_results\"][mt.n_layer - 1][\"logits\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75682322",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.trace import rank_reward\n",
    "# Sanity check of the scoring function: first positional arg presumably the rank, cutoff k=20.\n",
    "rank_reward(1, k=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2966b95",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
