{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "41bfe4e4",
   "metadata": {},
   "source": [
    "# Attention Exploration\n",
    "\n",
    "In this notebook, we explore the learned attention for a VIT Tiny model, trained on STL10 for 200 epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c55b2aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- ViT Attention (all-layers CLS) batch processor — single cell ---\n",
    "\n",
    "# 0) Imports\n",
    "import os\n",
    "from pathlib import Path\n",
    "from typing import List\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "from timm.models.vision_transformer import Attention\n",
    "from IPython.display import display\n",
    "\n",
    "# Your Lightning module (assumed available)\n",
    "from pretrain.train_shannon_hyperspherical import SIMDEX\n",
    "\n",
    "# ===================\n",
    "# Config (edit here)\n",
    "# ===================\n",
    "CKPT_PATH = 'trained_models/stl10_vit_tiny/shannon_hypersphere_k1_stl10_vit_tiny_patch16_224_last.ckpt'\n",
    "\n",
    "IMAGE_PATHS: List[str] = [\n",
    "    \"images_demo/airplane1.jpg\",\n",
    "    \"images_demo/airplane2.jpg\",\n",
    "    \"images_demo/bird1.jpg\",\n",
    "    \"images_demo/bird2.jpg\",\n",
    "    \"images_demo/boat1.jpg\",\n",
    "    \"images_demo/boat2.jpg\",\n",
    "    \"images_demo/bus1.jpg\",\n",
    "    \"images_demo/cat1.jpg\",\n",
    "    \"images_demo/cat2.jpg\",\n",
    "    \"images_demo/cat3.jpg\",\n",
    "    \"images_demo/horse1.jpg\",\n",
    "    \"images_demo/horse2.jpg\",\n",
    "    \"images_demo/teleferico1.jpg\"\n",
    "]\n",
    "\n",
    "OUTPUT_DIR = \"images_demo/attention\"    # output directory\n",
    "SHOW_FIRST = True          # show the first overlay inline\n",
    "COLORMAP = cv2.COLORMAP_JET\n",
    "OVERLAY_ALPHA = 0.75        # blend for overlay: heatmap vs. original\n",
    "\n",
    "# ===================\n",
    "# 1) Attention module that stores A and is attn_mask-safe\n",
    "# ===================\n",
    "class AttentionWithStore(Attention):\n",
    "    def forward(self, x, attn_mask=None):\n",
    "        B, N, C = x.shape\n",
    "        qkv = self.qkv(x).reshape(\n",
    "            B, N, 3, self.num_heads, C // self.num_heads\n",
    "        ).permute(2, 0, 3, 1, 4)\n",
    "        q, k, v = qkv\n",
    "        attn = (q @ k.transpose(-2, -1)) * self.scale\n",
    "        if attn_mask is not None:\n",
    "            attn = attn + attn_mask\n",
    "        attn = attn.softmax(-1)\n",
    "        self.last_attn = attn.detach()            # (B, H, T, T)\n",
    "        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n",
    "        x = self.proj(x)\n",
    "        x = self.attn_drop(x)\n",
    "        return x\n",
    "\n",
    "# ===================\n",
    "# 2) Load checkpoint → backbone → patch every block\n",
    "# ===================\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "simdex = SIMDEX.load_from_checkpoint(CKPT_PATH, strict=False).eval()\n",
    "backbone = simdex.backbone\n",
    "for blk in backbone.blocks:                      # patch **all** blocks\n",
    "    blk.attn.__class__ = AttentionWithStore\n",
    "\n",
    "backbone.to(device).eval()\n",
    "\n",
    "# ===================\n",
    "# 3) Preprocess (224×224, ImageNet stats)\n",
    "# ===================\n",
    "prep = transforms.Compose([\n",
    "    transforms.Resize((224,224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.485,0.456,0.406),\n",
    "                         std =(0.229,0.224,0.225))\n",
    "])\n",
    "\n",
    "def load_image(path: str):\n",
    "    pil = Image.open(path).convert(\"RGB\")\n",
    "    x = prep(pil).unsqueeze(0).to(device)\n",
    "    return pil, x\n",
    "\n",
    "# ===================\n",
    "# 4) Utilities\n",
    "# ===================\n",
    "def ensure_dir(p):\n",
    "    Path(p).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def suffix_path(path_str: str, suffix: str, ext: str = \".png\"):\n",
    "    p = Path(path_str)\n",
    "    return str(Path(OUTPUT_DIR) / f\"{p.stem}_{suffix}{ext}\")\n",
    "\n",
    "def colorize_and_overlay(orig_pil: Image.Image, heatmap_01: np.ndarray,\n",
    "                         alpha=OVERLAY_ALPHA, colormap=COLORMAP):\n",
    "    W, H = orig_pil.size\n",
    "    heat_up = cv2.resize(heatmap_01, (W, H), interpolation=cv2.INTER_CUBIC)\n",
    "    heat_u8 = np.uint8(255 * np.clip(heat_up, 0, 1))\n",
    "    heat_color_bgr = cv2.applyColorMap(heat_u8, colormap)\n",
    "    heat_color_rgb = heat_color_bgr[:, :, ::-1]  # BGR -> RGB\n",
    "    overlay = (alpha * heat_color_rgb + (1 - alpha) * np.array(orig_pil)).astype(np.uint8)\n",
    "    return heat_color_rgb, Image.fromarray(overlay)\n",
    "\n",
    "# ===================\n",
    "# 5) Run batch\n",
    "# ===================\n",
    "ensure_dir(OUTPUT_DIR)\n",
    "processed, failed = 0, []\n",
    "\n",
    "for idx, img_path in enumerate(IMAGE_PATHS):\n",
    "    try:\n",
    "        pil, x = load_image(img_path)\n",
    "\n",
    "        # Forward once – each block now has .last_attn\n",
    "        with torch.no_grad():\n",
    "            _ = backbone(x)\n",
    "\n",
    "        # Accumulate CLS attention from *all* layers (simple mean over layers & heads)\n",
    "        attn_layers = []\n",
    "        for blk in backbone.blocks:\n",
    "            A = blk.attn.last_attn.squeeze(0)  # (heads, tokens, tokens)\n",
    "            A = A.mean(0)                      # average heads → (tokens, tokens)\n",
    "            attn_layers.append(A)\n",
    "\n",
    "        attn_mat = torch.stack(attn_layers, dim=0).mean(0)  # mean over layers → (tokens, tokens)\n",
    "\n",
    "        # CLS → patch vector\n",
    "        cls_vec = attn_mat[0, 1:]                           # (Npatch,)\n",
    "        num_patches = cls_vec.numel()\n",
    "        grid = int(np.sqrt(num_patches))\n",
    "        if grid * grid != num_patches:\n",
    "            raise ValueError(f\"Non-square patch grid (got {num_patches}).\")\n",
    "\n",
    "        heat = cls_vec.reshape(grid, grid).detach().cpu().numpy()\n",
    "\n",
    "        # Normalize to [0,1]\n",
    "        hmin, hmax = float(heat.min()), float(heat.max())\n",
    "        heat_01 = (heat - hmin) / (hmax - hmin + 1e-12)\n",
    "\n",
    "        # Colorize + overlay (JET, same colors for both)\n",
    "        heat_color_rgb, overlay_pil = colorize_and_overlay(pil, heat_01)\n",
    "\n",
    "        # Save (attention-specific suffixes)\n",
    "        out_heat = suffix_path(img_path, \"attn\", \".png\")\n",
    "        out_overlay = suffix_path(img_path, \"attn_overlay\", \".png\")\n",
    "\n",
    "        Image.fromarray(heat_color_rgb).save(out_heat)\n",
    "        overlay_pil.save(out_overlay)\n",
    "\n",
    "        processed += 1\n",
    "        if SHOW_FIRST and idx == 0:\n",
    "            display(overlay_pil)\n",
    "\n",
    "        print(f\"[OK] {img_path} ->\")\n",
    "        print(f\"     {out_heat}\")\n",
    "        print(f\"     {out_overlay}\")\n",
    "\n",
    "    except Exception as e:\n",
    "        failed.append((img_path, str(e)))\n",
    "        print(f\"[FAIL] {img_path}: {e}\")\n",
    "\n",
    "print(f\"\\nDone. Processed: {processed} | Failed: {len(failed)}\")\n",
    "if failed:\n",
    "    for p, msg in failed:\n",
    "        print(f\" - {p}: {msg}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv (3.12.3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
