{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 290,
   "id": "91dc9c3f-0bcb-40e0-aec6-75342990540a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from IPython.display import display\n",
    "from transformers import LlavaForConditionalGeneration, AutoProcessor\n",
    "import traceback\n",
    "import random\n",
    "import numpy as np\n",
    "from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training\n",
    "from utils import *\n",
    "from transformers import TrainingArguments, Trainer, default_data_collator\n",
    "import os\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import torch.nn.functional as F\n",
    "from IPython.display import display\n",
    "from datasets import load_dataset\n",
    "import argparse\n",
    "import sys\n",
    "from pathlib import Path\n",
    "from matplotlib import pyplot as plt\n",
    "import re\n",
    "# import cv2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "cb977a69-5a3c-4775-bae5-42d8bf19c5a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_axis_max(arr, threshold, axis = 1):\n",
    "    # Step 1: row maxima and their indices\n",
    "    row_max_vals = arr.max(axis=axis)        # shape (h,)\n",
    "    row_max_ids = arr.argmax(axis=axis)      # shape (h,)\n",
    "    print(row_max_vals)\n",
    "    print(row_max_ids)\n",
    "    \n",
    "    # Step 2: filter with threshold\n",
    "    result = np.where(row_max_vals >= threshold, row_max_ids, -1)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 262,
   "id": "fd57a8e4-5ee8-498b-9f86-a17924885609",
   "metadata": {},
   "outputs": [],
   "source": [
    "def put_one_text(img, token_id, text, size, hor_tokens = 24, ver_tokens = 24):\n",
    "    h, w = img.shape[:2]\n",
    "    # w, h = size\n",
    "    px = (int(token_id)%hor_tokens)/ver_tokens\n",
    "    py = (int(token_id)//hor_tokens)/hor_tokens\n",
    "    x, y = int(px * w), int(py * h)\n",
    "    # img.text((x, y), text, fill=(0, 255, 0))\n",
    "    cv2.putText(img, text, (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (10, 255, 10), 2, cv2.LINE_AA)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3477e74-56c1-4b5f-a453-5729ec938794",
   "metadata": {},
   "source": [
    "### Visualizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 304,
   "id": "07725598-6947-4a08-a2ce-c033e8f6f6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_alignment(model, processor, dataset, max_new_tokens=1024, do_sample=False, temperature=0.7, top_p=0.9, max_sample = 1, start_id = 0, forced_prompt = None):\n",
    "    model.eval()\n",
    "    predictions, references = [], []\n",
    "\n",
    "    dataset = dataset.skip(start_id)\n",
    "    num = 0\n",
    "    \n",
    "    if forced_prompt:\n",
    "        max_sample = 1\n",
    "\n",
    "    for sample in dataset:\n",
    "        print('\\n', '*'*20, '\\nImage ID', num+start_id)\n",
    "\n",
    "        processor.patch_size = 14\n",
    "\n",
    "        question, gt_answer, img = unwrap_150k_row(sample)\n",
    "        # print(gt_answer, type(gt_answer))\n",
    "        if forced_prompt:\n",
    "            question = forced_prompt\n",
    "            \n",
    "        question = question + \"\\nASSISTANT:\"\n",
    "\n",
    "        # Prepare input\n",
    "        inputs = processor(\n",
    "            text=question,\n",
    "            images=img,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=True,\n",
    "            truncation=True\n",
    "        ).to(model.device)\n",
    "\n",
    "\n",
    "        # img_aligned = ImageDraw.Draw(img)\n",
    "        # img_aligned = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)\n",
    "        img_aligned = np.array(img)\n",
    "        # img_aligned = img.copy()\n",
    "\n",
    "        generated_tokens = model.generate(\n",
    "            processor = processor,\n",
    "            pixel_values=inputs[\"pixel_values\"],\n",
    "            input_ids=inputs[\"input_ids\"],\n",
    "            attention_mask=inputs[\"attention_mask\"],\n",
    "            max_new_tokens=max_new_tokens\n",
    "        )\n",
    "        text_max = model.token_mixer.text_max\n",
    "        print(text_max, len(text_max))\n",
    "\n",
    "        # print(inputs[\"input_ids\"].shape)\n",
    "        # for i, token in enumerate(inputs[\"input_ids\"][0].cpu()):\n",
    "        for i, token in enumerate(model.token_mixer.prompt_text.cpu()):\n",
    "            token_text = processor.tokenizer.decode([token.item()])\n",
    "\n",
    "            token_id = text_max[i]\n",
    "            print(token_id, type(token), token.shape,token )\n",
    "\n",
    "            invalid_words = [\"is\", \"of\", \"the\", \"ASS\", \"IST\", \"ANT\", \"are\"]\n",
    "            if token_id > -0.5 and re.fullmatch(r\"[A-Za-z0-9 ]*\", token_text) and not any(word in token_text for word in invalid_words) and len(token_text)>1:\n",
    "                print(token_text)\n",
    "                put_one_text(img_aligned, token_id, token_text, img.size)\n",
    "  \n",
    "            \n",
    "        display(img)\n",
    "        # Decode generated tokens\n",
    "        generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).lower().strip()\n",
    "        print(question)\n",
    "        # img_aligned.show()\n",
    "        print(generated_text)\n",
    "        \n",
    "        # h, w = img_aligned.shape[:2]\n",
    "        # dpi = 180  # adjust for sharper display\n",
    "        # plt.figure(figsize=(w / dpi, h / dpi), dpi=dpi)\n",
    "        # plt.imshow(img_aligned)\n",
    "        # plt.axis('off')  # optional, hide axes\n",
    "        # plt.subplots_adjust(left=0, right=1, top=1, bottom=0)\n",
    "        # plt.show()\n",
    "        display(Image.fromarray(img_aligned))\n",
    "\n",
    "        \n",
    "        num += 1\n",
    "        if (num >= max_sample):\n",
    "            break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48c44ffd-e319-4684-b57e-8026905a53e3",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### VisualizeAlignmentCosMixer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "34f3a6d3-0a2c-4457-a318-8fb6c549f58b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class VisualizeAlignmentCosMixer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.similarity_threshold = None\n",
    "        self.think_mode = None # if think mode is on, the padding tokens are used as pause tokens during training to encourage thinking\n",
    "        self.enhance_mode = None # if enhance mode is on, the image tokens are copied instead of being moved to the similar text tokens\n",
    "        self.text_enhance_mode = None # if text enhance mode is on, text tokens are copied next to similar image tokens\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        pixel_values,\n",
    "        input_ids,\n",
    "        attention_mask,\n",
    "        labels,\n",
    "        tokenizer,\n",
    "        language_model_embedding,\n",
    "        vision_tower,\n",
    "        projector,\n",
    "        llama_hidden_size=2048,\n",
    "        use_permutation=True, # TODO adjust this to debug\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Reorders input tokens and labels based on cosine similarity between 2048-D embeddings.\n",
    "        Image tokens are moved in front of their most similar text token only if the similarity\n",
    "        exceeds the specified threshold. Otherwise, they remain in their original position at the\n",
    "        start of the sequence.\n",
    "        Args:\n",
    "            pixel_values: [batch, 3, H, W]\n",
    "            input_ids: [batch, seq_len]\n",
    "            attention_mask: [batch, seq_len]\n",
    "            labels: [batch, seq_len], with -100 for prompt tokens, or None during inference\n",
    "            tokenizer: Processor’s tokenizer\n",
    "            language_model_embedding: Language model’s embedding layer\n",
    "            vision_tower: CLIP vision encoder\n",
    "            projector: Multimodal projector\n",
    "            llama_hidden_size: Hidden size of LLaMA (default: 2048)\n",
    "            use_permutation: Whether to apply token reordering (default: False)\n",
    "            similarity_threshold: Minimum cosine similarity to move image token (default: 0.6)\n",
    "        Returns:\n",
    "            inputs_embeds: [batch, new_seq_len, 2048]\n",
    "            new_attention_mask: [batch, new_seq_len]\n",
    "            new_labels: [batch, new_seq_len] (None during inference)\n",
    "        \"\"\"\n",
    "        if not use_permutation:\n",
    "            bsz = input_ids.size(0)\n",
    "            # 1. Look up text embeddings\n",
    "            txt_embeds = language_model_embedding(input_ids)  # [batch, seq_len, hidden_size]\n",
    "\n",
    "            # 2. Get vision features\n",
    "            img_hidden = vision_tower(pixel_values, output_hidden_states=True).hidden_states[-2][:, 1:, :]  # [batch, num_patches, vision_hidden_size]\n",
    "            img_embeds1 = projector(img_hidden).to(dtype=txt_embeds.dtype)  # [batch, num_patches, hidden_size]\n",
    "\n",
    "            # 3. Get image token ID\n",
    "            image_token_id = tokenizer.convert_tokens_to_ids(\"<image>\")\n",
    "\n",
    "            # 4. Create mask for image tokens\n",
    "            special_image_mask = (input_ids == image_token_id).unsqueeze(-1)  # [batch, seq_len, 1]\n",
    "            special_image_mask = special_image_mask.expand_as(txt_embeds).to(txt_embeds.device)  # [batch, seq_len, hidden_size]\n",
    "            n_image_tokens = (input_ids == image_token_id).sum()\n",
    "\n",
    "            # 5. Validate number of image tokens against image features\n",
    "            n_image_features = img_embeds1.view(-1).shape[0]  # Total features across batch\n",
    "            if n_image_tokens * img_embeds1.shape[-1] != n_image_features:\n",
    "                raise ValueError(\n",
    "                    f\"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features // img_embeds.shape[-1]}\"\n",
    "                )\n",
    "\n",
    "            # 6. Flatten image features to match token positions\n",
    "            img_embeds_flat = img_embeds1.view(-1, img_embeds1.shape[-1])  # [batch * num_patches, hidden_size]\n",
    "\n",
    "            # 7. Replace image tokens with image features\n",
    "            inputs_embeds = txt_embeds.masked_scatter(special_image_mask, img_embeds_flat)\n",
    "            \n",
    "            #comparison = torch.all(img_embeds1 == inputs_embeds[:, 1:577])\n",
    "            #print(\"Very Beginning!!!! img_embeds == inputs_embeds[:, 1:577]\", comparison)\n",
    "\n",
    "            # 8. Keep attention mask and labels unchanged\n",
    "            new_attention_mask = attention_mask\n",
    "            new_labels = labels\n",
    "            #print(inputs_embeds.shape)\n",
    "            return inputs_embeds, new_attention_mask, new_labels\n",
    "\n",
    "            \n",
    "\n",
    "        bsz = input_ids.size(0)\n",
    "\n",
    "        # ---------- 1. Obtain 2048-D image patch embeddings (dtype matches LM) ----------\n",
    "        img_hidden = vision_tower(pixel_values, output_hidden_states=True).hidden_states[-2][:, 1:, :]  # [batch, num_patches, vision_hidden_size]\n",
    "        img_embeds = projector(img_hidden).to(dtype=language_model_embedding.weight.dtype)  # [B, 576, 2048]\n",
    "\n",
    "        # img_embeds = img_embeds1 # Use pre-computed image embeddings for DEBUG!!!\n",
    "        # print(\"beginning: img_embeds shape:\", img_embeds.shape)\n",
    "        # print(\"beginning: inputs_embeds shape:\", inputs_embeds.shape)\n",
    "        # # Compare first batch element\n",
    "        # # Debug: Check if embeddings are identical\n",
    "        # comparison = torch.all(img_embeds == img_embeds1)\n",
    "        # print(\"beginning: img_embeds == img_embeds1\", comparison)\n",
    "        \n",
    "        # if not comparison:            \n",
    "        #     # Check max difference\n",
    "        #     diff = torch.abs(img_embeds - img_embeds1)\n",
    "        #     print(\"Max absolute difference:\", diff.max().item())\n",
    "        #     print(\"Mean absolute difference:\", diff.mean().item())\n",
    "\n",
    "\n",
    "        reordered_seqs, reordered_atts, reordered_labs = [], [], []\n",
    "\n",
    "        # Pre-compute id of the image sentinel once\n",
    "        image_token_id = tokenizer.convert_tokens_to_ids(\"<image>\")\n",
    "        # Pre-compute id of the ASSISTANT separator once (For Inference Only)\n",
    "        sep_id = tokenizer(\"ASSISTANT:\", add_special_tokens=False).input_ids[-1]\n",
    "        # print(\"image_token_id: \", image_token_id)\n",
    "        # print(\"sep_id: \", tokenizer(\"ASSISTANT:\", add_special_tokens=False).input_ids)\n",
    "        \n",
    "        for b in range(bsz):\n",
    "            #print(input_ids[b])\n",
    "            # ---------- 2. Prompt mask (TRAIN vs INFER) ----------\n",
    "            if labels is not None:  # Training\n",
    "                if self.think_mode:\n",
    "                    prompt_mask = (labels[b] == -100)\n",
    "                    resp_mask = ~prompt_mask\n",
    "                else:\n",
    "                    # Fix: Exclude padding by requiring attention_mask == 1\n",
    "                    prompt_mask = (labels[b] == -100) & (attention_mask[b] == 1)\n",
    "                    #print(\"prompt_mask: \", prompt_mask)\n",
    "                    # Also exclude padding from response mask\n",
    "                    resp_mask = (labels[b] != -100) & (attention_mask[b] == 1)\n",
    "                    #print(\"resp_mask: \", resp_mask)\n",
    "\n",
    "            else:  # Inference: prompt == everything up to last \"ASSISTANT:\"\n",
    "                hits = (input_ids[b] == sep_id).nonzero(as_tuple=True)[0]\n",
    "                if hits.numel() == 0:\n",
    "                    # No separator in this prompt → treat the whole sequence as prompt\n",
    "                    prompt_mask = attention_mask[b].bool()\n",
    "                else:\n",
    "                    sep_pos = hits.max()\n",
    "                    prompt_mask = torch.arange(input_ids.size(1), device=input_ids.device) <= sep_pos\n",
    "                    #print(\"prompt_mask: \", prompt_mask) #all true\n",
    "\n",
    "            prompt_ids = input_ids[b][prompt_mask]\n",
    "            #print(\"prompt_ids: \", prompt_ids)\n",
    "            prompt_text = prompt_ids[prompt_ids.ne(image_token_id)][2:] #excluding 1 and 29871 tokens\n",
    "            # print(\"prompt_text: \", prompt_text, type(prompt_text))\n",
    "\n",
    "            # ---------- 3. Group image patches by cosine similarity ----------\n",
    "            if prompt_text.numel() == 0:  # Corner-case: image-only prompt\n",
    "                groups, num_txt = [[]], 1\n",
    "                txt_labels = torch.tensor([-100], device=input_ids.device)\n",
    "                # Keep all image tokens in their original position\n",
    "                unmoved_images = img_embeds[b]  # All images are unmoved\n",
    "                print(\"image-only prompt\")\n",
    "            else:\n",
    "                txt_embeds = language_model_embedding(prompt_text.unsqueeze(0)).squeeze(0)  # [T, 2048]\n",
    "                sim = F.cosine_similarity(\n",
    "                    txt_embeds.unsqueeze(1),  # [T, 1, 2048]\n",
    "                    img_embeds[b].unsqueeze(0),  # [1, 576, 2048]\n",
    "                    dim=2,\n",
    "                )  # [T, 576]\n",
    "                print(sim.max(dim=1).values)\n",
    "                self.text_max = get_axis_max(sim.detach().cpu().numpy(), self.similarity_threshold)\n",
    "                self.prompt_text = prompt_text\n",
    "                # print(self.text_max)\n",
    "            \n",
    "                \n",
    "               \n",
    "                \n",
    "                if self.text_enhance_mode:\n",
    "                    # Text enhance mode: copy text tokens next to similar image tokens\n",
    "                    # For text enhance: find which image is most similar to each text token\n",
    "                    assign = sim.argmax(dim=1)  # [T] - which image is most similar to each text token\n",
    "                    max_sim = sim.max(dim=1).values  # [T] - max similarity for each text token\n",
    "                    #print(\"Max cosine similarity scores:\",max(max_sim))\n",
    "                    \n",
    "                    unmoved_images = []\n",
    "                    text_groups = [[] for _ in range(img_embeds[b].size(0))]  # Groups for text tokens per image\n",
    "                    \n",
    "                    for txt_idx, (img_idx, sim_score) in enumerate(zip(assign, max_sim)):\n",
    "                        if sim_score >= self.similarity_threshold:\n",
    "                            # Copy text token next to image token\n",
    "                            text_groups[img_idx].append(txt_embeds[txt_idx])\n",
    "                            #print(\"text enhanced tokens:\",txt_idx)\n",
    "                    \n",
    "                    # All image tokens stay in original position\n",
    "                    unmoved_images = img_embeds[b]\n",
    "                    num_txt = txt_embeds.size(0)\n",
    "                    txt_labels = torch.full((num_txt,), -100, dtype=torch.long, device=input_ids.device)\n",
    "                else:\n",
    "                    # Original logic for enhance_mode and reordering\n",
    "                    #print(\"cosine similarity scores:\",sim)\n",
    "                    assign = sim.argmax(dim=0)  # [576]\n",
    "                    max_sim = sim.max(dim=0).values  # [576]\n",
    "                    #print(\"Max cosine similarity scores:\",max(max_sim))\n",
    "                    \n",
    "                    groups = [[] for _ in range(txt_embeds.size(0))]\n",
    "                    if self.enhance_mode:\n",
    "                        # Keep all image tokens in their original position\n",
    "                        unmoved_images = img_embeds[b]  # [576, 2048]\n",
    "                        # print(\"unmoved_images == inputs_embeds[b][1:577]\", unmoved_images == inputs_embeds[b][1:577])\n",
    "                        # print(\"unmoved_images == inputs_embeds[b][1:577]\",torch.all(unmoved_images == inputs_embeds[b][1:577]))\n",
    "                    else:\n",
    "                        unmoved_images = []\n",
    "                    for img_idx, (txt_idx, sim_score) in enumerate(zip(assign, max_sim)):\n",
    "                        if sim_score >= self.similarity_threshold:\n",
    "                            groups[txt_idx].append(img_embeds[b, img_idx])\n",
    "                        elif not self.enhance_mode:\n",
    "                            unmoved_images.append(img_embeds[b, img_idx])\n",
    "                    if not self.enhance_mode:\n",
    "                        unmoved_images = torch.stack(unmoved_images) if unmoved_images else torch.tensor([], device=input_ids.device, dtype=img_embeds.dtype)\n",
    "                    num_txt = txt_embeds.size(0)\n",
    "                    txt_labels = torch.full((num_txt,), -100, dtype=torch.long, device=input_ids.device)\n",
    "\n",
    "            reordered, lab = [], []\n",
    "\n",
    "            #print(\"groups\", groups)\n",
    "\n",
    "            # Add BOS token (if present)\n",
    "            bos_token_id = 1  # BOS token (<s>)\n",
    "            if (prompt_ids == bos_token_id).any():\n",
    "                bos_embed = language_model_embedding(torch.tensor([bos_token_id], device=input_ids.device)).squeeze(0)  # [2048]\n",
    "                reordered.append(bos_embed)\n",
    "                lab.append(-100)\n",
    "     \n",
    "\n",
    "            \n",
    "            # Add grouped image tokens (or copies) and text tokens\n",
    "            if self.text_enhance_mode:\n",
    "                # Text enhance mode: add image tokens with their copied text tokens\n",
    "                for img_idx in range(len(unmoved_images)):\n",
    "                    # Add image token\n",
    "                    reordered.append(unmoved_images[img_idx])\n",
    "                    lab.append(-100)\n",
    "                    # Add copied text tokens next to this image token\n",
    "                    if text_groups[img_idx]:\n",
    "                        reordered.extend(text_groups[img_idx])\n",
    "                        lab.extend([-100] * len(text_groups[img_idx]))\n",
    "\n",
    "                # Add space token (if present)\n",
    "                space_token_id = 29871  # space token\n",
    "                if (prompt_ids == space_token_id).any():\n",
    "                    reordered.append(language_model_embedding(torch.tensor([space_token_id], device=input_ids.device)).squeeze(0))\n",
    "                    lab.append(-100)\n",
    "\n",
    "                # Add prompt text tokens at the end\n",
    "                for t_idx in range(num_txt):\n",
    "                    reordered.append(txt_embeds[t_idx])\n",
    "                    lab.append(txt_labels[t_idx])\n",
    "            else:\n",
    "                # Add unassigned or the original image tokens at the start\n",
    "                if unmoved_images.numel() > 0:\n",
    "                    reordered.extend(unmoved_images)\n",
    "                    # Convert list to tensor for comparison\n",
    "                    #reordered_tensor = torch.stack(reordered) if reordered else torch.tensor([], device=input_ids.device)\n",
    "                    #print(\"reordered == inputs_embeds[b][:577]\", torch.all(reordered_tensor == inputs_embeds[b][:len(reordered_tensor)]))\n",
    "                    lab.extend([-100] * len(unmoved_images))\n",
    "\n",
    "                # Add space token (if present)\n",
    "                space_token_id = 29871  # space token\n",
    "                if (prompt_ids == space_token_id).any():\n",
    "                    reordered.append(language_model_embedding(torch.tensor([space_token_id], device=input_ids.device)).squeeze(0))\n",
    "                    lab.append(-100)\n",
    "\n",
    "                #print(\"reordered\", reordered)\n",
    "                #print(\"reordered\", language_model_embedding(torch.tensor([input_ids[b][0]], device=input_ids.device) ) == reordered[0])\n",
    "                #print(\"reordered\", language_model_embedding(torch.tensor([input_ids[b][577] ], device=input_ids.device) ) == reordered[577])\n",
    "\n",
    "\n",
    "                # Original logic for enhance_mode and reordering\n",
    "                for t_idx in range(num_txt):\n",
    "                    # Add moved or copied image tokens for this text token\n",
    "                    if groups[t_idx]:  # Only add if there are image tokens assigned\n",
    "                        reordered.extend(groups[t_idx])\n",
    "                        lab.extend([-100] * len(groups[t_idx]))\n",
    "                    if t_idx < len(txt_embeds):\n",
    "                        reordered.append(txt_embeds[t_idx])\n",
    "                        lab.append(txt_labels[t_idx])\n",
    "\n",
    "            if labels is not None:  # Training: add answer tokens\n",
    "                reordered.extend(language_model_embedding(input_ids[b])[resp_mask])\n",
    "                lab.extend(labels[b][resp_mask].tolist())\n",
    "\n",
    "            reordered_seqs.append(torch.stack(reordered) if reordered else torch.tensor([], device=input_ids.device, dtype=img_embeds.dtype))\n",
    "            reordered_atts.append(torch.ones(len(reordered), device=input_ids.device, dtype=torch.long))\n",
    "            reordered_labs.append(torch.tensor(lab, device=input_ids.device, dtype=torch.long))\n",
    "\n",
    "        # ---------- 5. Pad to equal length ----------\n",
    "        max_len = max(seq.size(0) for seq in reordered_seqs) if reordered_seqs else 1\n",
    "        dtype = img_embeds.dtype\n",
    "        pad_seq = torch.zeros(bsz, max_len, llama_hidden_size, device=input_ids.device, dtype=dtype)\n",
    "        pad_att = torch.zeros(bsz, max_len, device=input_ids.device, dtype=torch.long)\n",
    "        pad_lab = None if labels is None else torch.full(\n",
    "            (bsz, max_len), -100, device=input_ids.device, dtype=torch.long\n",
    "        )\n",
    "\n",
    "        for b, (seq, att, lab) in enumerate(zip(reordered_seqs, reordered_atts, reordered_labs)):\n",
    "            if seq.numel() > 0:  # Handle empty sequences\n",
    "                pad_seq[b, :seq.size(0)] = seq\n",
    "                #print(\"pad_seq[b], len(seq)\", len(pad_seq[b]), len(seq))\n",
    "                pad_att[b, :att.size(0)] = att\n",
    "                #print(\"pad_att[b], len(att)\", len(pad_att[b]), len(att))\n",
    "                sl = seq.size(0)\n",
    "                pad_token_id = tokenizer.pad_token_id \n",
    "                pad_embed_single = language_model_embedding(torch.tensor([pad_token_id], device=input_ids.device)).squeeze(0)\n",
    "                if sl < max_len:\n",
    "                    pad_embed = pad_embed_single.repeat(max_len - sl, 1)  # [pad_len, hidden_size]\n",
    "                    pad_seq[b, sl:] = pad_embed\n",
    "                    #pad_seq[b, sl:] = language_model_embedding(input_ids[b, sl:])\n",
    "                if labels is not None:\n",
    "                    pad_lab[b, :lab.size(0)] = lab\n",
    "\n",
    "        # torch.set_printoptions(threshold=10000)\n",
    "        # # Per-token equality: True only if all hidden dims match for that token\n",
    "        # token_equal = (pad_seq == inputs_embeds).all(dim=2)\n",
    "        # #print(\"per-token equal (bsz x seq_len):\", token_equal)\n",
    "        # print(\"all tokens equal:\", bool(token_equal.all().item()))\n",
    "        # print(\"pad_seq==inputs_embeds\",torch.all(pad_seq == inputs_embeds))\n",
    "        # mismatched = (~token_equal).nonzero(as_tuple=False)\n",
    "        # #if mismatched.numel() > 0:\n",
    "        #     #print(\"mismatched token coords:\", mismatched.tolist())\n",
    "        # print(\"pad_att shape, new_attention_mask shape\",pad_att.shape, new_attention_mask.shape)\n",
    "        # print(\"pad_att==new_attention_mask\",torch.all(pad_att == new_attention_mask))\n",
    "        # print(\"pad_lab shape, new_labels shape\",pad_lab.shape,new_labels.shape)\n",
    "        # print(\"pad_lab==new_labels\",torch.all(pad_lab == new_labels))\n",
    "\n",
    "        return pad_seq, pad_att, pad_lab\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def get_target_models(model):\n",
    "    patterns = [\n",
    "      re.compile(r\".*language_model\\.layers\\.\\d+\\.self_attn\\.(q|k|v|o)_proj$\"),\n",
    "      re.compile(r\".*language_model\\.layers\\.\\d+\\.mlp\\.(gate|up|down)_proj$\"),\n",
    "      re.compile(r\".*multi_modal_projector\\.linear_(1|2)$\"),  # Add projector modules\n",
    "\n",
    "    ]\n",
    "\n",
    "    matched = []\n",
    "    for name, module in model.named_modules():\n",
    "        if any(p.search(name) for p in patterns):\n",
    "            matched.append(name)\n",
    "    return matched\n",
    "\n",
    "def check_trainable_parameters(model):\n",
    "    print(\"Checking trainable parameters for permutation_net:\")\n",
    "    for name, param in model.named_parameters():\n",
    "        print(f\"{name}: requires_grad={param.requires_grad}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ece8989-134a-4e04-944f-f9d44a562c3c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 227,
   "id": "a59f5629-149b-4ec9-83aa-af84f8acf2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# today_dir = \"08_26_150k_thres_0_12_tm_F_textenh\"\n",
    "today_dir = \"09_05_150k_thres_0_08_tm_F_textenh\"\n",
    "merged_path = \"tinyllava-lora/\"+today_dir+\"/merged\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 228,
   "id": "4ec71631-8ba4-488e-91e6-4242c91c7257",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CustomLlavaForConditionalGeneration.from_pretrained(\n",
    "    merged_path,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"cuda\",\n",
    "    token_mixer_class = VisualizeAlignmentCosMixer\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 229,
   "id": "9444b43b-d42f-414b-86b9-17eddd61f3d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = AutoProcessor.from_pretrained(merged_path)\n",
    "processor.patch_size = 14\n",
    "model.set_token_mixer_processor(processor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 270,
   "id": "de1f4f9d-f25f-4969-91d0-ece517b24d4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = load_dataset(\n",
    "        \"liuhaotian/LLaVA-Instruct-150K\",\n",
    "        split=\"train\",\n",
    "        streaming=True\n",
    "    ).shuffle(seed=126, buffer_size=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57ba497d-527b-49e4-ae34-10d0d7d16c1a",
   "metadata": {},
   "source": [
    "### Reload VisualizeAlignmentCosMixer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 271,
   "id": "c520ace2-d658-402e-ab75-6f4d1c2eec9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CustomLlavaForConditionalGeneration loaded with thres=0.08 , think_mode=False , enhance_mode=False , text_enhance_mode=True\n"
     ]
    }
   ],
   "source": [
    "model.token_mixer = VisualizeAlignmentCosMixer()\n",
    "model.load_token_mixer_config(Path(merged_path).parent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d478c4b-612e-455b-aaae-b2dcb920351f",
   "metadata": {},
   "outputs": [],
   "source": [
    "visualize_alignment(model, processor, data, max_sample=1, start_id = 301)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2ff1512-4cb0-4f14-ac21-5ce9dfd3b58a",
   "metadata": {},
   "source": [
    "## "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1fb7403-9047-4cdf-9881-4fec2650430e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9446eab1-7e18-4337-90fb-05586ef34347",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
