{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Take LLaVA-1.5-7B as an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import json\n",
    "import os\n",
    "from PIL import Image\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "from scipy import stats\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "import torch.nn.functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "import transformers\n",
    "import warnings \n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "from model_manager import ModelManager\n",
    "from utils import setup_seeds, disable_torch_init\n",
    "\n",
    "parser = argparse.ArgumentParser(description=\"Case studies on LVLMs.\")\n",
    "parser.add_argument(\"--model\", type=str, default='llava-1.5', help=\"model\")\n",
    "parser.add_argument(\"--batch-size\", type=int, default=1)\n",
    "parser.add_argument(\"--beam\", type=int, default=1) # 1 for Greedy Decoding\n",
    "parser.add_argument(\"--max-tokens\", type=int, default=512)\n",
    "args = parser.parse_known_args()[0]\n",
    "\n",
    "setup_seeds()\n",
    "disable_torch_init()\n",
    "\n",
    "# Load model\n",
    "model_manager = ModelManager(args.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import cv2\n",
    "from typing import List\n",
    "def get_topk_attention_indices(\n",
    "    text, layer_start, layer_end, head_start, head_end, outputs, tokenizer, image,\n",
    "    vision_token_start, vision_token_end\n",
    "):\n",
    "    selected_token_id = tokenizer(text, add_special_tokens=False)[\"input_ids\"][0]\n",
    "    token_in_generation_idx = torch.nonzero(outputs['sequences'][0][1:] == selected_token_id)[0].item()\n",
    "\n",
    "    attn_maps = []\n",
    "    for layer_id in range(layer_start, layer_end):\n",
    "        attn_layer = outputs['attentions'][token_in_generation_idx][layer_id]  # [1, num_heads, q_len, k_len]\n",
    "        attn_layer = attn_layer.squeeze(0)                                     # [num_heads, q_len, k_len]\n",
    "        selected_heads = attn_layer[head_start:head_end]                       # [num_selected_heads, q_len, k_len]\n",
    "        attn_maps.append(selected_heads)\n",
    "\n",
    "    attn_tensor = torch.stack(attn_maps, dim=0)\n",
    "    row_attn = attn_tensor.mean(dim=(0, 1)).cpu().detach()\n",
    "\n",
    "    visual_row_attn = row_attn[-1, vision_token_start:vision_token_end].to(torch.float32)\n",
    "    visual_row_attn = visual_row_attn / visual_row_attn.sum()\n",
    "\n",
    "    topk = torch.topk(visual_row_attn, k=10)\n",
    "    topk_indices = topk.indices.tolist()\n",
    "    topk_values = topk.values.tolist()\n",
    "\n",
    "\n",
    "    return topk_indices,topk_values\n",
    "\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "def analyze_top_vision_token_predictions(\n",
    "    model, tokenizer, outputs, top_vision_token_indices: List[int],\n",
    "    vision_token_start: int, layer_range: List[int]\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "        result_dict: dict {token_idx: {'most_common_token': token_str, 'count': int}}\n",
    "    \"\"\"\n",
    "\n",
    "    hidden_states = outputs['hidden_states'][0]  # Tuple of length 33\n",
    "    assert isinstance(hidden_states, (tuple, list)), \"outputs['hidden_states'] must be a tuple or list\"\n",
    "\n",
    "    batch_id = 0  \n",
    "    result_list = []\n",
    "\n",
    "    for vis_token_idx in top_vision_token_indices:\n",
    "        global_token_idx = vision_token_start + vis_token_idx\n",
    "\n",
    "        for layer_id in layer_range:\n",
    "            hidden = hidden_states[layer_id][batch_id, global_token_idx].clone().detach()  # shape: [hidden_dim]\n",
    "            logits = model.lm_head(hidden)  # shape: [vocab_size]\n",
    "            pred_token_id = torch.argmax(logits).item()\n",
    "            result_list.append(pred_token_id)\n",
    "\n",
    "    return result_list\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_image_non_sink_attention_sums(attentions, vision_token_start: int, img_token_num: int, sink_tokens: list[int]):\n",
    "        num_layers = len(attentions)\n",
    "        batch_id = 0\n",
    "        image_attn_sums = []\n",
    "\n",
    "        all_img_indices = torch.arange(img_token_num)\n",
    "        sink_mask = torch.zeros(img_token_num, dtype=torch.bool)\n",
    "        sink_mask[sink_tokens] = True\n",
    "        valid_img_indices = all_img_indices[~sink_mask]  # 0-based index within vision tokens\n",
    "        valid_token_indices = vision_token_start + valid_img_indices\n",
    "\n",
    "        for layer_id in range(num_layers):\n",
    "            attn = attentions[layer_id][batch_id]  # [head, seq_len, seq_len]\n",
    "            attn_last = attn[:, -1, :]  # [head, seq_len]\n",
    "            img_attn = attn_last[:, valid_token_indices]  # [head, num_valid_img_tokens]\n",
    "            img_sum = img_attn.sum(dim=1)  # [head]\n",
    "            image_attn_sums.append(img_sum)\n",
    "\n",
    "        image_attn_sums = torch.stack(image_attn_sums, dim=0)  # [layer, head]\n",
    "        return image_attn_sums\n",
    "    \n",
    "def compute_vision_sink_tokens(\n",
    "        model,\n",
    "        outputs,\n",
    "        sink_token_ids,\n",
    "        vision_token_start: int = 32,\n",
    "        vision_token_num: int = 576,\n",
    "        tau=23\n",
    "    ):\n",
    "        hidden_states = outputs['hidden_states'][0]  # [num_layers+1, batch, seq_len, hidden_dim]\n",
    "        num_layers = len(hidden_states) - 1  # skip embedding layer\n",
    "        batch_id = 0\n",
    "\n",
    "        sink_token_ids = set(sink_token_ids)  # for fast lookup\n",
    "        token_hit_counts = []\n",
    "\n",
    "        \n",
    "        ### accelerate ###\n",
    "\n",
    "        all_layers = torch.stack(hidden_states[1:], dim=0)  # shape: [num_layers, B, seq_len, H]\n",
    "\n",
    "        vision_hidden = all_layers[:, batch_id, vision_token_start:vision_token_start + vision_token_num, :]  # shape: [L, T, H]\n",
    "\n",
    "        L, T, H = vision_hidden.shape\n",
    "        vision_hidden_flat = vision_hidden.reshape(L * T, H)  # [L*T, H]\n",
    "        logits_flat = model.lm_head(vision_hidden_flat)        # [L*T, vocab_size]\n",
    "\n",
    "        top_token_ids_flat = torch.argmax(logits_flat, dim=-1)     # [L*T]\n",
    "        top_token_ids = top_token_ids_flat.reshape(L, T)           # [L, T]\n",
    "\n",
    "        sink_token_mask = torch.tensor([tid in sink_token_ids for tid in range(model.lm_head.out_features)], device=top_token_ids.device)\n",
    "        sink_hits = sink_token_mask[top_token_ids]  # shape: [L, T], bool\n",
    "\n",
    "        token_hit_counts = sink_hits.sum(dim=0)  # shape: [T]\n",
    "        token_hit_counts = [(i, count.item()) for i, count in enumerate(token_hit_counts)]\n",
    "        ### accelerate ###\n",
    "\n",
    "        sink_token_counts = [(idx, count) for idx, count in token_hit_counts if count >= tau]  \n",
    "        return sink_token_counts\n",
    "    \n",
    "def compute_vision_sink_tokens_new(\n",
    "        model,\n",
    "        outputs,\n",
    "        sink_token_ids,\n",
    "        vision_token_start: int = 32,\n",
    "        vision_token_num: int = 576,\n",
    "    ):\n",
    "\n",
    "        hidden_states = outputs['hidden_states'][0]  # [num_layers+1, batch, seq_len, hidden_dim]\n",
    "        num_layers = len(hidden_states) - 1  # skip embedding layer\n",
    "        batch_id = 0\n",
    "\n",
    "        sink_token_ids = set(sink_token_ids)  # for fast lookup\n",
    "        token_hit_counts = []\n",
    "        \n",
    "        ### accelerate ###\n",
    "        all_layers = torch.stack(hidden_states[1:], dim=0)  # shape: [num_layers, B, seq_len, H]\n",
    "\n",
    "        vision_hidden = all_layers[:, batch_id, vision_token_start:vision_token_start + vision_token_num, :]  # shape: [L, T, H]\n",
    "\n",
    "        L, T, H = vision_hidden.shape\n",
    "        vision_hidden_flat = vision_hidden.reshape(L * T, H)  # [L*T, H]\n",
    "        logits_flat = model.lm_head(vision_hidden_flat)        # [L*T, vocab_size]\n",
    "\n",
    "        top_token_ids_flat = torch.argmax(logits_flat, dim=-1)     # [L*T]\n",
    "        top_token_ids = top_token_ids_flat.reshape(L, T)           # [L, T]\n",
    "\n",
    "        sink_token_mask = torch.tensor([tid in sink_token_ids for tid in range(model.lm_head.out_features)], device=top_token_ids.device)\n",
    "        sink_hits = sink_token_mask[top_token_ids]  # shape: [L, T], bool\n",
    "\n",
    "        token_hit_counts = sink_hits.sum(dim=0)  # shape: [T]\n",
    "        token_hit_counts = [(i, count.item()) for i, count in enumerate(token_hit_counts)]\n",
    "        ### accelerate ###\n",
    "        \n",
    "        sink_token_counts = [(idx, count) for idx, count in token_hit_counts if count >= 0]  \n",
    "\n",
    "        return sink_token_counts\n",
    "    \n",
    "    \n",
    "def compute_vision_sink_tokens_hits(\n",
    "        model,\n",
    "        outputs,\n",
    "        intersection,\n",
    "        sink_tokens_new,\n",
    "        vision_token_start: int = 32,\n",
    "        vision_token_num: int = 576,\n",
    "    ):\n",
    "\n",
    "    hidden_states = outputs['hidden_states'][0]  # [num_layers+1, batch, seq_len, hidden_dim]\n",
    "    num_layers = len(hidden_states) - 1  # skip embedding layer\n",
    "    batch_id = 0\n",
    "\n",
    "    sink_token_ids = set(sink_tokens_new)  # for fast lookup\n",
    "    \n",
    "    all_layers = torch.stack(hidden_states[1:], dim=0)  # shape: [num_layers, B, seq_len, H]\n",
    "\n",
    "    vision_hidden = all_layers[:, batch_id, vision_token_start:vision_token_start + vision_token_num, :]  # [L, T, H]\n",
    "\n",
    "    L, T, H = vision_hidden.shape\n",
    "    vision_hidden_flat = vision_hidden.reshape(L * T, H)  # [L*T, H]\n",
    "    logits_flat = model.lm_head(vision_hidden_flat)       # [L*T, vocab_size]\n",
    "\n",
    "    top_token_ids_flat = torch.argmax(logits_flat, dim=-1)  # [L*T]\n",
    "    top_token_ids = top_token_ids_flat.reshape(L, T)        # [L, T]\n",
    "\n",
    "\n",
    "    sink_token_mask = torch.tensor(\n",
    "        [tid in sink_token_ids for tid in range(model.lm_head.out_features)], \n",
    "        device=top_token_ids.device\n",
    "    )\n",
    "    sink_hits = sink_token_mask[top_token_ids]  # shape: [L, T], bool\n",
    "\n",
    "    intersection = torch.tensor(intersection, device=top_token_ids.device)\n",
    "    selected_hits = sink_hits[:, intersection]              # [L, |intersection|]\n",
    "    token_hit_counts = selected_hits.sum(dim=0)             # [|intersection|]\n",
    "\n",
    "    sink_token_counts = [\n",
    "        (intersection[i].item(), token_hit_counts[i].item()) \n",
    "        for i in range(len(intersection))\n",
    "    ]\n",
    "\n",
    "    return sink_token_counts\n",
    "\n",
    "    \n",
    "def compute_sink_attention_scores(attentions, sink_tokens: list[int], vision_token_start: int, img_token_num: int):\n",
    "\n",
    "    num_layers = len(attentions)\n",
    "    batch_id = 0  \n",
    "    num_heads = attentions[0].shape[1]\n",
    "    # print(attentions.shape)\n",
    "\n",
    "    sink_token_indices = [vision_token_start + rel_idx for rel_idx in sink_tokens]\n",
    "    vision_token_indices = list(range(vision_token_start, vision_token_start + img_token_num))\n",
    "\n",
    "    sink_attn_ratios = []\n",
    "\n",
    "    for layer_id in range(num_layers):\n",
    "        # attn: [H, Q, K]\n",
    "        attn = attentions[layer_id][batch_id]                # [32, seq_len, seq_len]\n",
    "        attn_last_token = attn[:, -1, :]                     # [32, seq_len]\n",
    "\n",
    "        sink_values = attn_last_token[:, sink_token_indices]          # [32, len(sink_tokens)]\n",
    "        vision_values = attn_last_token[:, vision_token_indices]      # [32, img_token_num]\n",
    "\n",
    "        sink_sum = sink_values.sum(dim=1)        # [32]\n",
    "        vision_sum = vision_values.sum(dim=1)    # [32]\n",
    "\n",
    "        ratio = sink_sum / (vision_sum + 1e-8)   # [32] \n",
    "        sink_attn_ratios.append(ratio)\n",
    "\n",
    "    sink_attn_ratios = torch.stack(sink_attn_ratios, dim=0)  # [32, 32]\n",
    "\n",
    "    return sink_attn_ratios"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"sink_token_ids_llava_15.json\", \"r\") as f:\n",
    "    sink_token_ids = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chair import chair_eval\n",
    "from llava.mm_utils import process_images\n",
    "from utils import get_only_attn_out_contribution\n",
    "from utils import attnw_over_vision_layer_head_selected_text\n",
    "\n",
    "img_query_lists = [\n",
    "    json.loads(line) for line in open('./examples/query.jsonl')\n",
    "]\n",
    "\n",
    "visual_attn_weights = []\n",
    "visual_attn_weights_hal = []\n",
    "non_sink_visual_attn_weights = []\n",
    "non_sink_visual_attn_weights_hal = []\n",
    "sink_scores_real = []\n",
    "sink_scores_hallu = []\n",
    "real_token_index = []\n",
    "hallu_token_index = []\n",
    "vision_sink_tokens_hits_list = []\n",
    "large_attention_token_indices = []\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "for img_query in tqdm(img_query_lists, desc=\"Processing img_query\"):\n",
    "    # prepare inputs\n",
    "    \n",
    "    if 'image_id' in img_query.keys():\n",
    "        img_id = f\"COCO_val2014_{str(img_query['image_id']).zfill(12)}.jpg\"\n",
    "    elif 'image' in img_query.keys():\n",
    "        img_id = f\"COCO_val2014_{str(img_query['image'])}\"\n",
    "        \n",
    "    img_path = os.path.join(args.data_path, img_id)\n",
    "    img = Image.open(img_path).convert('RGB')\n",
    "    images_tensor = process_images(\n",
    "                            [img],\n",
    "                            model_manager.image_processor,\n",
    "                            model_manager.llm_model.config\n",
    "                    ).to(model_manager.llm_model.device, dtype=torch.float16)\n",
    "    if 'text' in img_query.keys():\n",
    "        query = [img_query['text']]\n",
    "    elif 'instruction' in img_query.keys():\n",
    "        query = [img_query['instruction']]\n",
    "    \n",
    "    questions, input_ids, kwargs = model_manager.prepare_inputs_for_model(query, images_tensor, use_dataloader=False)\n",
    "\n",
    "\n",
    "    with torch.inference_mode():\n",
    "        outputs = model_manager.llm_model.generate(\n",
    "            input_ids,\n",
    "            do_sample=False,\n",
    "            num_beams=args.beam,\n",
    "            max_new_tokens=args.max_tokens,\n",
    "            use_cache=True,\n",
    "            output_scores=True,\n",
    "            output_hidden_states=True,\n",
    "            output_attentions=True,\n",
    "            return_dict_in_generate=True,\n",
    "            **kwargs,\n",
    "        )\n",
    "\n",
    "    answer = model_manager.tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)[0].strip()\n",
    "    img_info = chair_eval(evaluator, img_id, answer)\n",
    "\n",
    "    # calc some constants\n",
    "    vision_token_start = model_manager.img_start_idx\n",
    "    vision_token_end = model_manager.img_end_idx\n",
    "    input_token_len = (model_manager.llm_model.get_vision_tower().num_patches\n",
    "                    + len(input_ids[0]) - 1 # -1 for the <image> token\n",
    "    )\n",
    "    gt_words = img_info['mscoco_gt_words']\n",
    "    generated_words = img_info['mscoco_generated_words']\n",
    "    # print(gt_words, generated_words)\n",
    "    # Real words Calculation\n",
    "    for ri, real_word in enumerate(set(generated_words) & set(gt_words)):\n",
    "        # calculate attn sublayer contribution for each real word\n",
    "        try:\n",
    "            # get visual attention weights\n",
    "            real_word_attnw_matrix, token_in_generation_idx = attnw_over_vision_layer_head_selected_text(\n",
    "                    real_word, outputs, model_manager.tokenizer,\n",
    "                    vision_token_start, vision_token_end\n",
    "            )\n",
    "            # print(\"sink_token_counts\",outputs)\n",
    "            sink_token_counts = compute_vision_sink_tokens(model_manager.llm_model, outputs, sink_token_ids, vision_token_start)\n",
    "            sink_token_counts_new = compute_vision_sink_tokens_new(model_manager.llm_model, outputs, sink_token_ids, vision_token_start)\n",
    "            sink_tokens_new = [item[0] for item in  sink_token_counts_new]\n",
    "            # print(\"sink_token_counts \",sink_token_counts)\n",
    "            sink_tokens = [item[0] for item in  sink_token_counts]\n",
    "            \n",
    "            sink_scores = compute_sink_attention_scores(\n",
    "                    attentions=outputs['attentions'][token_in_generation_idx], \n",
    "                    sink_tokens=sink_tokens,\n",
    "                    vision_token_start=vision_token_start,\n",
    "                    img_token_num=576\n",
    "                )\n",
    "            \n",
    "            sink_scores_real.append(sink_scores.cpu().detach().numpy())\n",
    "            \n",
    "            \n",
    "            \n",
    "            non_sink_attention =compute_image_non_sink_attention_sums(attentions=outputs['attentions'][token_in_generation_idx], \n",
    "                                                                      vision_token_start=vision_token_start, img_token_num=576, sink_tokens=sink_tokens)\n",
    "            non_sink_visual_attn_weights.append(non_sink_attention.cpu().detach().numpy())\n",
    "            \n",
    "            real_token_index.append(token_in_generation_idx)\n",
    "            \n",
    "            visual_attn_weights.append(real_word_attnw_matrix)\n",
    "            \n",
    "            \n",
    "            \n",
    "            real_word_layer_attnw = real_word_attnw_matrix.mean(axis=1)[::-1]\n",
    "            \n",
    "            topk_indices,topk_values=get_topk_attention_indices(\n",
    "                text=real_word, layer_start=0, layer_end=32, head_start=0, head_end=32, outputs=outputs, tokenizer=model_manager.tokenizer, image=img,\n",
    "                vision_token_start=model_manager.img_start_idx, vision_token_end=model_manager.img_end_idx\n",
    "            )\n",
    "            \n",
    "            intersection = np.intersect1d(sink_tokens_new, topk_indices)\n",
    "            \n",
    "            vision_sink_tokens_hits=compute_vision_sink_tokens_hits(model_manager.llm_model, outputs,intersection, sink_token_ids, vision_token_start)\n",
    "            vision_sink_tokens_hits_list.extend(vision_sink_tokens_hits)\n",
    "            result_list = analyze_top_vision_token_predictions(\n",
    "                model=model_manager.llm_model,\n",
    "                tokenizer=model_manager.tokenizer,\n",
    "                outputs=outputs,\n",
    "                top_vision_token_indices=topk_indices,\n",
    "                vision_token_start=model_manager.img_start_idx,\n",
    "                layer_range=list(range(1, 33)), \n",
    "            )\n",
    "            large_attention_token_indices.extend(result_list)\n",
    "            \n",
    "        except Exception as e:\n",
    "            pass\n",
    "\n",
    "    if len(img_info['mscoco_hallucinated_words']) == 0:\n",
    "        continue\n",
    "\n",
    "    # Hallucinated words Calculation\n",
    "    hallucination_words = [\n",
    "        item for sublist in img_info['mscoco_hallucinated_words'] for item in sublist\n",
    "    ]\n",
    "    # print(\"hallucination_words\", hallucination_words)\n",
    "    for hi, hallu_word in enumerate(set(hallucination_words)):\n",
    "        # calculate attn sublayer contribution for each hallu word\n",
    "        try:\n",
    "            # get visual attention weights\n",
    "            hallu_word_attnw_matrix, token_in_generation_idx = attnw_over_vision_layer_head_selected_text(\n",
    "                    hallu_word, outputs, model_manager.tokenizer,\n",
    "                    vision_token_start, vision_token_end\n",
    "            )\n",
    "            hallu_token_index.append(token_in_generation_idx)\n",
    "            \n",
    "            # visual_attn_weights.append(hallu_word_attnw_matrix)\n",
    "            visual_attn_weights_hal.append(hallu_word_attnw_matrix)\n",
    "            hallu_word_layer_attnw = hallu_word_attnw_matrix.mean(axis=1)[::-1]\n",
    "            \n",
    "            sink_token_counts = compute_vision_sink_tokens(model_manager.llm_model, outputs, sink_token_ids, vision_token_start)\n",
    "            sink_token_counts_new = compute_vision_sink_tokens_new(model_manager.llm_model, outputs, sink_token_ids, vision_token_start)\n",
    "            sink_tokens_new = [item[0] for item in  sink_token_counts_new]\n",
    "            sink_tokens = [item[0] for item in  sink_token_counts]\n",
    "            \n",
    "            sink_scores = compute_sink_attention_scores(\n",
    "                    attentions=outputs['attentions'][token_in_generation_idx], \n",
    "                    sink_tokens=sink_tokens,\n",
    "                    vision_token_start=vision_token_start,\n",
    "                    img_token_num=576\n",
    "                )\n",
    "            sink_scores_hallu.append(sink_scores.cpu().detach().numpy())\n",
    "            non_sink_attention =compute_image_non_sink_attention_sums(attentions=outputs['attentions'][token_in_generation_idx], \n",
    "                                                                      vision_token_start=vision_token_start, img_token_num=576, sink_tokens=sink_tokens)\n",
    "            non_sink_visual_attn_weights_hal.append(non_sink_attention.cpu().detach().numpy())\n",
    "            \n",
    "            \n",
    "            topk_indices,topk_values=get_topk_attention_indices(\n",
    "                text=hallu_word, layer_start=0, layer_end=32, head_start=0, head_end=32, outputs=outputs, tokenizer=model_manager.tokenizer, image=img,\n",
    "                vision_token_start=model_manager.img_start_idx, vision_token_end=model_manager.img_end_idx\n",
    "            )\n",
    "            intersection = np.intersect1d(sink_tokens_new, topk_indices)\n",
    "            vision_sink_tokens_hits=compute_vision_sink_tokens_hits(model_manager.llm_model, outputs,intersection, sink_token_ids, vision_token_start)\n",
    "            vision_sink_tokens_hits_list.extend(vision_sink_tokens_hits)\n",
    "            result_list = analyze_top_vision_token_predictions(\n",
    "                model=model_manager.llm_model,\n",
    "                tokenizer=model_manager.tokenizer,\n",
    "                outputs=outputs,\n",
    "                top_vision_token_indices=topk_indices,\n",
    "                vision_token_start=model_manager.img_start_idx,\n",
    "                layer_range=list(range(1, 33)),  #\n",
    "            )\n",
    "            large_attention_token_indices.extend(result_list)\n",
    "        except Exception as e:\n",
    "            pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "token_counter = Counter(large_attention_token_indices)\n",
    "most_common_token_ids = token_counter.most_common(15)  # List of (token_id, count)\n",
    "most_common_token_strs = [model_manager.tokenizer.decode([item[0]]) for item in most_common_token_ids]\n",
    "combined = [(token_id, count, token_str) for (token_id, count), token_str in zip(most_common_token_ids, most_common_token_strs)]\n",
    "token_ids = [item[0] for item in most_common_token_ids]\n",
    "counts = [item[1] for item in most_common_token_ids]\n",
    "token_strs = [s for s in most_common_token_strs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "sns.set_theme(style=\"ticks\", context=\"talk\")\n",
    "bar_color = '#1f77b4'\n",
    "\n",
    "token_indices = np.array(vision_sink_tokens_hits_list)[:, 0]\n",
    "hits = np.array(vision_sink_tokens_hits_list)[:, 1]\n",
    "\n",
    "plt.figure(figsize=(8, 5))\n",
    "plt.hist(hits, bins=32,color=bar_color, edgecolor=\"black\", alpha=0.7)\n",
    "plt.grid(True, axis='y', linestyle='--', alpha=0.6)\n",
    "\n",
    "plt.xlabel(\"Vocabulary Fixation Score\", fontsize=18)\n",
    "plt.ylabel(\"Frequency\", fontsize=18)\n",
    "plt.tick_params(axis='both', which='major', labelsize=13)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"llava15_7b_vocabulary_fixation_U.svg\",bbox_inches='tight')\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set_theme(style=\"ticks\", context=\"talk\")\n",
    "\n",
    "bar_color = '#1f77b4' \n",
    "\n",
    "# Create the bar chart\n",
    "plt.figure(figsize=(8, 5))\n",
    "\n",
    "plt.bar(range(len(counts)), counts, tick_label=token_strs, \n",
    "        color=bar_color, \n",
    "        edgecolor='black', \n",
    "        alpha=0.7)\n",
    "\n",
    "plt.grid(True, axis='y', linestyle='--', alpha=0.6)\n",
    "\n",
    "plt.xticks(rotation=45, ha=\"right\")\n",
    "plt.xlabel(\"Token\", fontsize=18)\n",
    "plt.ylabel(\"Frequency\", fontsize=18)\n",
    "plt.tick_params(axis='both', which='major', labelsize=13)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"llava15_7b_vocabulary_fixation.svg\",bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk_heads(attn_map, top_k):\n",
    "\n",
    "    flat_indices = np.argsort(attn_map, axis=None)[-top_k:]  \n",
    "    coords = np.array(np.unravel_index(flat_indices, attn_map.shape)).T  \n",
    "    return coords\n",
    "\n",
    "non_sink_attn_heads = topk_heads(np.array(non_sink_visual_attn_weights).mean(axis=0), top_k=1000)\n",
    "\n",
    "# np.save(\"non_sink_max_attention_heads_llava.npy\", non_sink_attn_heads)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "devils",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
