{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22a0571c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension in mode 2 so that edited modules are\n",
    "# re-imported automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "665e8438",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import sys\n",
    "\n",
    "# Put the parent directory on the import path; the `src.*` / `scripts.*` imports\n",
    "# used throughout this notebook presumably rely on it.\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "##################################################################\n",
    "# Environment tweaks, set before torch / transformers are imported below.\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3,4,5,6,7\"\n",
    "##################################################################\n",
    "\n",
    "import logging\n",
    "from src.utils import logging_utils\n",
    "from src.utils import env_utils\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "# Send DEBUG-and-above records to stdout using the project's format strings.\n",
    "logging.basicConfig(\n",
    "    level=logging.DEBUG,\n",
    "    format=logging_utils.DEFAULT_FORMAT,\n",
    "    datefmt=logging_utils.DEFAULT_DATEFMT,\n",
    "    stream=sys.stdout,\n",
    ")\n",
    "\n",
    "import torch\n",
    "import transformers\n",
    "\n",
    "logger.info(f\"{torch.__version__=}, {torch.version.cuda=}\")\n",
    "# torch.cuda.get_device_name() raises on a machine without a usable CUDA device, so it\n",
    "# is only queried when CUDA is available. On GPU machines the logged text is unchanged.\n",
    "cuda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None\n",
    "logger.info(\n",
    "    f\"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, \"\n",
    "    f\"torch.cuda.get_device_name()={cuda_device_name!r}\"\n",
    ")\n",
    "logger.info(f\"{transformers.__version__=}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a271e90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: get_device_map is only referenced by the commented-out lines at the end of this cell.\n",
    "from src.utils.training_utils import get_device_map\n",
    "\n",
    "# Model identifier. Later cells use only its last path component\n",
    "# (model_key.split(\"/\")[-1]) to locate saved results; exactly one assignment\n",
    "# should be left uncommented.\n",
    "# model_key = \"meta-llama/Llama-3.2-3B\"\n",
    "# model_key = \"meta-llama/Llama-3.1-8B\"\n",
    "model_key = \"meta-llama/Llama-3.3-70B-Instruct\"\n",
    "# model_key = \"meta-llama/Llama-3.1-405B-Instruct\"\n",
    "\n",
    "# model_key = \"google/gemma-2-9b-it\"\n",
    "# model_key = \"google/gemma-3-12b-it\"\n",
    "# model_key = \"google/gemma-2-27b-it\"\n",
    "\n",
    "# model_key = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n",
    "\n",
    "# model_key = \"allenai/OLMo-2-1124-7B-Instruct\"\n",
    "# model_key = \"allenai/OLMo-7B-0424-hf\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen2-7B\"\n",
    "# model_key = \"Qwen/Qwen2.5-14B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-32B-Instruct\"\n",
    "# model_key = \"Qwen/Qwen2.5-72B-Instruct\"\n",
    "\n",
    "# model_key = \"Qwen/Qwen3-1.7B\"\n",
    "# model_key = \"Qwen/Qwen3-4B\"\n",
    "# model_key = \"Qwen/Qwen3-8B\"\n",
    "# model_key = \"Qwen/Qwen3-14B\"\n",
    "# model_key = \"Qwen/Qwen3-32B\"\n",
    "\n",
    "# device_map = get_device_map(model_key, 30, n_gpus=8)\n",
    "# device_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e9d6194",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.selection.data import SelectOneTask, SelectOrderTask\n",
    "\n",
    "# ---------------------------- task configuration ----------------------------\n",
    "# Alternative setup kept for reference:\n",
    "#   TASK_CLS = SelectOrderTask\n",
    "#   prompt_template_idx = 1\n",
    "TASK_CLS = SelectOneTask\n",
    "prompt_template_idx = 3\n",
    "N_DISTRACTORS = 5\n",
    "OPTION_STYLE = \"single_line\"\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "# Dataset file to load; \"profession.json\" and \"nationality.json\" were the other\n",
    "# file names toggled here.\n",
    "task_file = os.path.join(env_utils.DEFAULT_DATA_DIR, \"selection\", \"objects.json\")\n",
    "select_task = TASK_CLS.load(path=task_file)\n",
    "\n",
    "print(select_task)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87448daa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Earlier candidate list, kept for reference:\n",
    "# HEADS = [\n",
    "#     (33, 45),\n",
    "#     (33, 18),\n",
    "#     (34, 1),\n",
    "#     (34, 6),\n",
    "#     (34, 7),\n",
    "#     (35, 19),\n",
    "#     (39, 40),\n",
    "#     (42, 30),\n",
    "#     (47, 18),\n",
    "#     (52, 58),\n",
    "# ]\n",
    "\n",
    "# Active list of (layer_idx, head_idx) pairs.\n",
    "HEADS = [\n",
    "    (62, 1), (60, 9), (64, 8), (62, 0), (62, 45),\n",
    "    (59, 59), (71, 28), (64, 12), (61, 7), (64, 13),\n",
    "    (67, 53), (67, 51), (54, 44), (57, 5), (59, 60),\n",
    "]\n",
    "\n",
    "# HEADS = [(35, 19)]\n",
    "\n",
    "\n",
    "# Other ways this list was populated:\n",
    "# with open(\"optimized_heads.json\", \"r\") as f:\n",
    "#     HEADS = json.load(f)\n",
    "\n",
    "# with open(\"category_wise_heads.json\", \"r\") as f:\n",
    "#     category_wise_heads = json.load(f)\n",
    "# HEADS = [\n",
    "#     (layer_idx, head_idx)\n",
    "#     for layer_idx, head_idx, score in category_wise_heads[\"all\"][:100]\n",
    "# ]\n",
    "# HEADS = [(layer_idx, head_idx) for layer_idx, head_idx in HEADS if layer_idx < 61]\n",
    "\n",
    "\n",
    "print(len(HEADS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f0952fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Saved head-optimization archive for the selected model and task.\n",
    "# (A \"distinct_options\" path component after the model name was used as a variant.)\n",
    "model_dir_name = model_key.split(\"/\")[-1]\n",
    "optimized_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/optimized_heads\",\n",
    "    model_dir_name,\n",
    "    f\"{select_task.task_name}\",\n",
    "    \"epoch_10.npz\",\n",
    ")\n",
    "\n",
    "# NOTE: allow_pickle=True lets np.load unpickle object arrays; only point this at trusted files.\n",
    "optimization_results = np.load(optimized_path, allow_pickle=True)\n",
    "\n",
    "# Plot the stored \"losses\" array.\n",
    "plt.plot(optimization_results[\"losses\"])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e14889d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "# The mask is indexed [layer, head] (see the unpacking below). Every row from\n",
    "# index 50 onward is zeroed before plotting and thresholding.\n",
    "layer_cutoff = 50\n",
    "optimal_head_mask = torch.tensor(optimization_results[\"optimal_mask\"]).to(torch.float32)\n",
    "optimal_head_mask[layer_cutoff:, :] = 0.0\n",
    "\n",
    "# Transposed for display, so the first mask axis runs along x.\n",
    "plt.imshow(optimal_head_mask.T.numpy(), cmap=\"Blues\", aspect=\"auto\", vmin=0, vmax=1)\n",
    "\n",
    "# Keep the (layer_idx, head_idx) coordinates whose mask value exceeds 0.5.\n",
    "selected_coords = torch.nonzero(optimal_head_mask > 0.5, as_tuple=False)\n",
    "heads_selected = [(layer_idx, head_idx) for layer_idx, head_idx in selected_coords.tolist()]\n",
    "print(len(heads_selected))\n",
    "\n",
    "# HEADS = heads_selected\n",
    "\n",
    "# (35, 19) in HEADS, (35, 19) in heads_selected"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88dca292",
   "metadata": {},
   "source": [
    "## Loading the Attention Behavior Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7eb9a9ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scripts.locate_via_attention_behavior import SelectionSampleAttn\n",
    "\n",
    "attn_path = os.path.join(\n",
    "    env_utils.DEFAULT_RESULTS_DIR,\n",
    "    \"selection/attention_patterns/select_one\",\n",
    "    model_key.split(\"/\")[-1],\n",
    "    \"objects\"\n",
    ")\n",
    "# Only .npz files are loaded, so filter the listing up front: LIMIT and the progress\n",
    "# messages then count loadable samples instead of every directory entry.\n",
    "files = sorted(name for name in os.listdir(attn_path) if name.endswith(\".npz\"))\n",
    "print(files)\n",
    "\n",
    "#######################################################################\n",
    "# LIMIT = 100\n",
    "LIMIT = len(files)\n",
    "#######################################################################\n",
    "\n",
    "# Never report more files than actually exist when LIMIT is set by hand.\n",
    "n_to_load = min(LIMIT, len(files))\n",
    "selection_attns = []\n",
    "\n",
    "for npz_file in files[:LIMIT]:\n",
    "    npz_path = os.path.join(attn_path, npz_file)\n",
    "    selection_attns.append(SelectionSampleAttn.from_npz(npz_path))\n",
    "    if len(selection_attns) % 128 == 0:\n",
    "        print(f\"Loaded {len(selection_attns)}/{n_to_load} files\")\n",
    "\n",
    "# Fail with a clear message instead of an IndexError on the shape lookup below.\n",
    "assert selection_attns, f\"No .npz attention files found in {attn_path}\"\n",
    "\n",
    "len(selection_attns), selection_attns[0].attention_pattern.attention_matrices.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af64b02d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "# ------------------------------ configuration ------------------------------\n",
    "# The first two axes of the stored attention matrices are used as layer and head counts.\n",
    "reference_shape = selection_attns[0].attention_pattern.attention_matrices.shape\n",
    "n_layer, n_head = reference_shape[0], reference_shape[1]\n",
    "token_idx = \"last\"  # alternative: \"all\"\n",
    "# ---------------------------------------------------------------------------\n",
    "\n",
    "# Running sum of the first element returned by resolution_score, stored as [head, layer].\n",
    "# (first_token_score(layer_idx, head_idx)[0] was the alternative metric tried here.)\n",
    "resolution_scores = torch.zeros(n_head, n_layer, dtype=torch.float32)\n",
    "for selection_attn in tqdm(selection_attns):\n",
    "    for layer_idx in range(n_layer):\n",
    "        for head_idx in range(n_head):\n",
    "            score = selection_attn.resolution_score(layer_idx, head_idx, token_idx=token_idx)[0]\n",
    "            resolution_scores[head_idx, layer_idx] += score\n",
    "\n",
    "# Turn the running sums into per-sample means.\n",
    "resolution_scores /= len(selection_attns)\n",
    "resolution_scores.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acd42c4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(20, 10))\n",
    "\n",
    "# Symmetric colour range around zero, sized by the largest absolute score.\n",
    "scale = resolution_scores.abs().max()\n",
    "heatmap = resolution_scores.cpu().numpy()\n",
    "plt.imshow(heatmap, cmap=\"RdBu\", aspect=\"auto\", vmin=-scale, vmax=scale)\n",
    "plt.colorbar()\n",
    "# plt.title(f\"score(target) - max(score(distractors)) | {token_idx.upper()} tokens of options\")\n",
    "plt.title(\"score(target[0]) - sum(score(target[1:]))\")\n",
    "plt.xlabel(\"Layer\")\n",
    "plt.ylabel(\"Head\")\n",
    "\n",
    "\n",
    "def get_ticks(ticks, skip=5):\n",
    "    \"\"\"Return one label per tick: str(i) when i is a multiple of `skip`, otherwise \"\".\"\"\"\n",
    "    return [str(i) if i % skip == 0 else \"\" for i in ticks]\n",
    "\n",
    "\n",
    "# Tick every layer / head position but label only every 5th layer and every 4th head.\n",
    "layer_ticks = range(n_layer)\n",
    "head_ticks = range(n_head)\n",
    "plt.xticks(ticks=layer_ticks, labels=get_ticks(layer_ticks), rotation=45)\n",
    "plt.yticks(ticks=head_ticks, labels=get_ticks(head_ticks, skip=4))\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "\n",
    "# Rank every (head, layer) entry by its mean score and print the 15 highest.\n",
    "scores_per_head = [\n",
    "    (head_idx, layer_idx, resolution_scores[head_idx, layer_idx].item())\n",
    "    for head_idx in range(n_head)\n",
    "    for layer_idx in range(n_layer)\n",
    "]\n",
    "scores_per_head.sort(key=lambda entry: entry[2], reverse=True)\n",
    "\n",
    "for head_idx, layer_idx, score in scores_per_head[:15]:\n",
    "    print(f\"Layer {layer_idx}, Head {head_idx}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe7ab754",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.attention import visualize_attn_matrix\n",
    "\n",
    "# Sample and head to inspect (layer 54 / head 44 was the other pair looked at).\n",
    "sample_idx, layer_idx, head_idx = 45, 35, 19\n",
    "\n",
    "selection_attn = selection_attns[sample_idx]\n",
    "# selection_attn = non_aligned[2]\n",
    "print(selection_attn.resolution_score(layer_idx, head_idx))\n",
    "\n",
    "# q_index=-1 / start_from=1 presumably select the last query position and skip the\n",
    "# first token -- confirm against visualize_attn_matrix.\n",
    "attn_pattern = selection_attn.attention_pattern\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=attn_pattern.attention_matrices[layer_idx, head_idx],\n",
    "    tokens=attn_pattern.tokenized_prompt,\n",
    "    q_index=-1,\n",
    "    start_from=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f60466c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# score_per_option for the sample chosen above, at the hard-coded pair (35, 19);\n",
    "# these match the layer_idx / head_idx values assigned in the previous cell.\n",
    "selection_attn.score_per_option(35, 19)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "77453193",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the attention of the optimized heads on the sample chosen above and collect\n",
    "# each head's per-option scores. Iterating over `random_heads` here broke a fresh\n",
    "# Restart & Run All with a NameError, because that name is only created by the\n",
    "# random-baseline cell at the end of the notebook.\n",
    "combined_attn_matrix = []\n",
    "combined_option_scores = []\n",
    "for layer_idx, head_idx in heads_selected:\n",
    "    combined_attn_matrix.append(\n",
    "        torch.Tensor(selection_attn.attention_pattern.attention_matrices[layer_idx, head_idx])\n",
    "    )\n",
    "    combined_option_scores.append(\n",
    "        selection_attn.score_per_option(layer_idx, head_idx)\n",
    "    )\n",
    "\n",
    "# Average over the stacked head axis first and squeeze afterwards: squeezing first would\n",
    "# also drop the head axis when exactly one head is selected, so mean(dim=0) would then\n",
    "# average over the wrong axis.\n",
    "combined_attn_matrix = torch.stack(combined_attn_matrix).mean(dim=0).squeeze()\n",
    "\n",
    "visualize_attn_matrix(\n",
    "    attn_matrix=combined_attn_matrix,\n",
    "    tokens=selection_attn.attention_pattern.tokenized_prompt,\n",
    "    q_index=-1,\n",
    "    start_from=1\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93b56e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Does the option with the highest mean score across heads equal the sample's obj_idx?\n",
    "combined_option_scores = torch.Tensor(combined_option_scores)\n",
    "predicted_idx = combined_option_scores.mean(dim=0).argmax().item()\n",
    "predicted_idx == selection_attn.sample.obj_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f60f7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of heads kept from the optimized mask; the evaluation cells below use this set.\n",
    "len(heads_selected)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ce4d4c7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70d18b78",
   "metadata": {},
   "outputs": [],
   "source": [
    "aligned, non_aligned = [], []\n",
    "\n",
    "# A sample counts as aligned when the option with the highest mean score across the\n",
    "# selected heads equals the sample's obj_idx.\n",
    "for filter_attn_result in tqdm(selection_attns):\n",
    "    # Swap `heads_selected` for `random_heads` to score a random head set instead.\n",
    "    combined_option_scores = torch.Tensor(\n",
    "        [\n",
    "            filter_attn_result.score_per_option(layer_idx, head_idx)\n",
    "            for layer_idx, head_idx in heads_selected\n",
    "        ]\n",
    "    )\n",
    "    predicted_idx = combined_option_scores.mean(dim=0).argmax().item()\n",
    "    bucket = aligned if predicted_idx == filter_attn_result.sample.obj_idx else non_aligned\n",
    "    bucket.append(filter_attn_result)\n",
    "\n",
    "f\"Accuracy: {len(aligned) / (len(selection_attns)):.2f} ({len(aligned)} / {len(selection_attns)})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9818dddf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "import random\n",
    "\n",
    "N_RANDOM_TRIALS = 10\n",
    "RANDOM_SEED = 0  # fixed so the baseline is reproducible across kernel restarts\n",
    "\n",
    "# Candidate pool: every (layer, head) pair that is not among the selected heads.\n",
    "# Sorted so that seeded sampling does not depend on set iteration order.\n",
    "all_heads = sorted(set(product(range(n_layer), range(n_head))) - set(heads_selected))\n",
    "\n",
    "# Dedicated generator: leaves the global `random` state untouched.\n",
    "rng = random.Random(RANDOM_SEED)\n",
    "\n",
    "trial_result = []\n",
    "for trial in range(N_RANDOM_TRIALS):\n",
    "    # Draw as many random heads as there are selected heads.\n",
    "    random_heads = rng.sample(all_heads, len(heads_selected))\n",
    "\n",
    "    correct_count = 0\n",
    "\n",
    "    for filter_attn_result in tqdm(selection_attns):\n",
    "        combined_option_scores = []\n",
    "        for layer_idx, head_idx in random_heads:\n",
    "            combined_option_scores.append(\n",
    "                filter_attn_result.score_per_option(layer_idx, head_idx)\n",
    "            )\n",
    "        combined_option_scores = torch.Tensor(combined_option_scores)\n",
    "        # Same decision rule as for the selected heads: argmax of the mean score.\n",
    "        if combined_option_scores.mean(dim=0).argmax().item() == filter_attn_result.sample.obj_idx:\n",
    "            correct_count += 1\n",
    "\n",
    "    accuracy = correct_count / len(selection_attns)\n",
    "    trial_result.append(accuracy)\n",
    "\n",
    "trial_result = np.array(trial_result)\n",
    "f\"Random acc: {trial_result.mean():.2f} ± {trial_result.std():.2f}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c776c6f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "connection",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
