{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8471d78a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import cv2\n",
    "from PIL import Image\n",
    "from torchvision import models, transforms\n",
    "from typing import List, Optional\n",
    "\n",
    "CKPT_PATH = \"trained_models/im1k_resnet50/shannon_hs_k1_im1k_resnet50.ckpt\"\n",
    "\n",
    "# List of images to process\n",
    "IMAGE_PATHS: List[str] = [\n",
    "    \"images_demo/airplane1.jpg\",\n",
    "    \"images_demo/airplane2.jpg\",\n",
    "    \"images_demo/bird1.jpg\",\n",
    "    \"images_demo/bird2.jpg\",\n",
    "    \"images_demo/boat1.jpg\",\n",
    "    \"images_demo/boat2.jpg\",\n",
    "    \"images_demo/bus1.jpg\",\n",
    "    \"images_demo/cat1.jpg\",\n",
    "    \"images_demo/cat2.jpg\",\n",
    "    \"images_demo/cat3.jpg\",\n",
    "    \"images_demo/horse1.jpg\",\n",
    "    \"images_demo/horse2.jpg\",\n",
    "    \"images_demo/teleferico1.jpg\"\n",
    "]\n",
    "\n",
    "REFERENCE_IMAGE_PATH: Optional[str] = None\n",
    "\n",
    "# Output folder name (as requested): \"ARGUMENT\"\n",
    "OUTPUT_DIR = \"images_demo/smoothgrad-cam/\"\n",
    "\n",
    "ALPHA_FACTOR=0.8  # blend factor for overlay\n",
    "\n",
    "# SmoothGrad settings\n",
    "N_SAMPLES = 20\n",
    "NOISE_SIGMA = 0.10\n",
    "\n",
    "# Target layer (ResNet-50: last bottleneck's conv3)\n",
    "TARGET_LAYER_PATH = (\"layer4\", -1, \"conv3\")\n",
    "\n",
    "# Device\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Optional: whether to display the first overlay inline in notebook\n",
    "SHOW_FIRST = True\n",
    "\n",
    "# ===================\n",
    "# Utilities\n",
    "# ===================\n",
    "\n",
    "def load_state_dict_flex(model, ckpt_path):\n",
    "    ckpt = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n",
    "    state_dict = ckpt[\"state_dict\"] if \"state_dict\" in ckpt else ckpt\n",
    "    new_sd = {}\n",
    "    for k, v in state_dict.items():\n",
    "        nk = k\n",
    "        for prefix in [\"model.\", \"backbone.\", \"net.\", \"module.\"]:\n",
    "            if nk.startswith(prefix):\n",
    "                nk = nk[len(prefix):]\n",
    "        new_sd[nk] = v\n",
    "    model.load_state_dict(new_sd, strict=False)\n",
    "    return model\n",
    "\n",
    "class GradCAM:\n",
    "    def __init__(self, model, target_layer):\n",
    "        self.model = model\n",
    "        self.activations = None\n",
    "        self.gradients = None\n",
    "        # hooks\n",
    "        target_layer.register_forward_hook(self.fwd_hook)\n",
    "        target_layer.register_full_backward_hook(self.bwd_hook)\n",
    "\n",
    "    def fwd_hook(self, module, inp, out):\n",
    "        self.activations = out.detach()\n",
    "\n",
    "    def bwd_hook(self, module, grad_in, grad_out):\n",
    "        self.gradients = grad_out[0].detach()\n",
    "\n",
    "    def __call__(self, x, objective):\n",
    "        self.model.zero_grad(set_to_none=True)\n",
    "        emb = self.model(x)\n",
    "        loss = objective(emb)\n",
    "        loss.backward()\n",
    "\n",
    "        A = self.activations            # [B, C, H, W]\n",
    "        dYdA = self.gradients           # [B, C, H, W]\n",
    "        weights = dYdA.mean(dim=(2,3), keepdim=True)  # [B, C, 1, 1]\n",
    "        cam = (weights * A).sum(dim=1)  # [B, H, W]\n",
    "        cam = torch.relu(cam)[0].cpu().numpy()\n",
    "        cam = (cam - cam.min()) / (cam.max() + 1e-8)\n",
    "        return cam\n",
    "\n",
    "class SmoothGradCAM:\n",
    "    def __init__(self, gradcam, n_samples=20, noise_sigma=0.1):\n",
    "        self.gradcam = gradcam\n",
    "        self.n_samples = n_samples\n",
    "        self.noise_sigma = noise_sigma\n",
    "\n",
    "    def __call__(self, x, objective):\n",
    "        cams = []\n",
    "        # per-image std for noise scale\n",
    "        std = x.std().item() + 1e-12\n",
    "        x_min, x_max = x.min().item(), x.max().item()\n",
    "        for _ in range(self.n_samples):\n",
    "            noise = torch.randn_like(x) * (self.noise_sigma * std)\n",
    "            x_noisy = (x + noise).clamp(min=x_min, max=x_max)\n",
    "            with torch.enable_grad():\n",
    "                cam = self.gradcam(x_noisy, objective)\n",
    "            cams.append(cam)\n",
    "        cams = np.stack(cams, axis=0)\n",
    "        cam_mean = cams.mean(0)\n",
    "        cam_mean = (cam_mean - cam_mean.min()) / (cam_mean.max() + 1e-8)\n",
    "        return cam_mean\n",
    "\n",
    "# Preprocess / postprocess\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224,224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225))\n",
    "])\n",
    "\n",
    "def preprocess(path):\n",
    "    img = Image.open(path).convert(\"RGB\")\n",
    "    x = transform(img).unsqueeze(0).to(DEVICE)\n",
    "    return img, x\n",
    "\n",
    "def colorize_and_overlay(orig_pil: Image.Image, heatmap_01: np.ndarray, alpha=0.8):\n",
    "    W, H = orig_pil.size  # PIL gives (W,H)\n",
    "    heatmap_up = cv2.resize(heatmap_01, (W, H), interpolation=cv2.INTER_LINEAR)\n",
    "    heat_u8 = np.uint8(255 * heatmap_up)\n",
    "    heat_color = cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET)[:, :, ::-1]  # to RGB\n",
    "    overlay = (alpha * heat_color + (1 - alpha) * np.array(orig_pil)).astype(np.uint8)\n",
    "    return heat_color, Image.fromarray(overlay)\n",
    "\n",
    "def ensure_dir(p):\n",
    "    Path(p).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "def suffix_path(path_str: str, suffix: str, ext: str = \".png\"):\n",
    "    p = Path(path_str)\n",
    "    stem = p.stem\n",
    "    return str(Path(OUTPUT_DIR) / f\"{stem}_{suffix}{ext}\")\n",
    "\n",
    "# ===================\n",
    "# Model / target layer\n",
    "# ===================\n",
    "\n",
    "backbone = models.resnet50(weights=None)\n",
    "backbone.fc = nn.Identity()\n",
    "model = load_state_dict_flex(backbone, CKPT_PATH).to(DEVICE).eval()\n",
    "\n",
    "# Resolve target layer (layer4[-1].conv3 by default)\n",
    "target_layer = getattr(model, TARGET_LAYER_PATH[0])\n",
    "target_layer = target_layer[TARGET_LAYER_PATH[1]]\n",
    "target_layer = getattr(target_layer, TARGET_LAYER_PATH[2])\n",
    "\n",
    "gradcam = GradCAM(model, target_layer)\n",
    "smoothcam = SmoothGradCAM(gradcam, n_samples=N_SAMPLES, noise_sigma=NOISE_SIGMA)\n",
    "\n",
    "# ===================\n",
    "# Objective\n",
    "# ===================\n",
    "if REFERENCE_IMAGE_PATH is not None:\n",
    "    _, x_ref = preprocess(REFERENCE_IMAGE_PATH)\n",
    "    with torch.no_grad():\n",
    "        ref_emb = model(x_ref)\n",
    "    def objective(e):\n",
    "        return F.cosine_similarity(e, ref_emb, dim=1).mean()\n",
    "else:\n",
    "    # default: encourage large embedding norm\n",
    "    def objective(e):\n",
    "        return e.norm(p=2, dim=1).mean()\n",
    "\n",
    "# ===================\n",
    "# Run batch\n",
    "# ===================\n",
    "ensure_dir(OUTPUT_DIR)\n",
    "\n",
    "processed = 0\n",
    "failed = []\n",
    "\n",
    "for idx, img_path in enumerate(IMAGE_PATHS):\n",
    "    try:\n",
    "        orig_img, x = preprocess(img_path)\n",
    "        with torch.enable_grad():\n",
    "            heatmap_smooth = smoothcam(x, objective)  # [H, W] in [0,1]\n",
    "\n",
    "        # Save grayscale heatmap\n",
    "        heat_color, overlay_pil = colorize_and_overlay(orig_img, heatmap_smooth, alpha=ALPHA_FACTOR)\n",
    "        out_heat_path = suffix_path(img_path, \"smoothgrad\", \".png\")\n",
    "        out_overlay_path = suffix_path(img_path, \"smoothgrad_overlay\", \".png\")\n",
    "\n",
    "        Image.fromarray(heat_color).save(out_heat_path)\n",
    "        overlay_pil.save(out_overlay_path)\n",
    "\n",
    "        processed += 1\n",
    "\n",
    "        # Optionally show first overlay in notebook\n",
    "        if SHOW_FIRST and idx == 0:\n",
    "            display(overlay_pil)\n",
    "\n",
    "        print(f\"[OK] {img_path} ->\")\n",
    "        print(f\"     {out_heat_path}\")\n",
    "        print(f\"     {out_overlay_path}\")\n",
    "\n",
    "    except Exception as e:\n",
    "        failed.append((img_path, str(e)))\n",
    "        print(f\"[FAIL] {img_path}: {e}\")\n",
    "\n",
    "print(f\"\\nDone. Processed: {processed} | Failed: {len(failed)}\")\n",
    "if failed:\n",
    "    for p, msg in failed:\n",
    "        print(f\" - {p}: {msg}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv (3.12.3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
