{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2ecdaf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# t-SNE for Hugging Face 'clane9/imagenet-100' using your ResNet-50 embeddings\n",
    "# - Computes & saves embeddings (.npz)\n",
    "# - Reloads & visualizes with t-SNE (colored by class, small points)\n",
    "# - Configurable POINTS_PER_CLASS and LABELED_PERCENT\n",
    "\n",
    "# -------------------\n",
    "# Imports\n",
    "# -------------------\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader, Dataset, Subset\n",
    "\n",
    "from torchvision import transforms, models\n",
    "from PIL import Image\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from datasets import load_dataset, disable_caching\n",
    "\n",
    "# -------------------\n",
    "# Config (EDIT THESE)\n",
    "# -------------------\n",
    "# Hugging Face dataset + split\n",
    "HF_DATASET = \"clane9/imagenet-100\"\n",
    "HF_SPLIT = \"train\"      # options likely include 'train' and 'val' (or 'test' if present)\n",
    "HF_REVISION = None      # e.g., a specific commit/tag; None uses default\n",
    "\n",
    "# Output paths\n",
    "OUTPUT_DIR = \"outputs_im100_tsne_hf\"\n",
    "EMB_FILE = \"embeddings_im100_resnet50_hf.npz\"\n",
    "TSNE_FILE = \"tsne_im100_resnet50_hf.npy\"\n",
    "SAVE_TSNE_EMB = True\n",
    "\n",
    "# Model checkpoint & device\n",
    "CKPT_PATH = \"trained_models/im1k_resnet50/shannon_hs_k1_im1k_resnet50.ckpt\"\n",
    "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Data & loader\n",
    "IMG_SIZE = 224\n",
    "BATCH_SIZE = 64\n",
    "NUM_WORKERS = 4\n",
    "\n",
    "# Subsampling per class (None -> use all)\n",
    "POINTS_PER_CLASS = 100     # set to None for all images per class\n",
    "\n",
    "# Plot controls\n",
    "POINT_SIZE = 2             # small points\n",
    "LABELED_PERCENT = 0.01     # fraction of points to annotate (e.g., 0.01 = 1%)\n",
    "RANDOM_SEED = 42\n",
    "\n",
    "# t-SNE hyperparameters\n",
    "TSNE_PERPLEXITY = 30\n",
    "TSNE_N_ITER = 1500\n",
    "TSNE_LEARNING_RATE = \"auto\"\n",
    "TSNE_METRIC = \"euclidean\"\n",
    "TSNE_INIT = \"pca\"  # or \"random\"\n",
    "\n",
    "# (optional) Hugging Face datasets cache control\n",
    "# disable_caching()  # uncomment if you don't want HF to cache to ~/.cache/huggingface\n",
    "\n",
    "# -------------------\n",
    "# Reproducibility\n",
    "# -------------------\n",
    "def set_seed(seed=RANDOM_SEED):\n",
    "    \"\"\"Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs with `seed`.\"\"\"\n",
    "    seeders = (\n",
    "        random.seed,\n",
    "        np.random.seed,\n",
    "        torch.manual_seed,\n",
    "        torch.cuda.manual_seed_all,\n",
    "    )\n",
    "    for apply_seed in seeders:\n",
    "        apply_seed(seed)\n",
    "\n",
    "set_seed(RANDOM_SEED)\n",
    "\n",
    "# -------------------\n",
    "# Your flexible state-dict loader\n",
    "# -------------------\n",
    "def load_state_dict_flex(model, ckpt_path, prefixes=(\"model.\", \"backbone.\", \"net.\", \"module.\")):\n",
    "    \"\"\"Load a checkpoint into `model`, tolerating wrapper prefixes on parameter names.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module that receives the weights (modified in place).\n",
    "        ckpt_path: checkpoint file; either a raw state dict or a dict with a\n",
    "            \"state_dict\" entry.\n",
    "        prefixes: key prefixes removed from the start of every parameter name.\n",
    "            Stripping repeats until none matches, so stacked wrappers such as\n",
    "            \"module.model.\" are removed completely.\n",
    "\n",
    "    Returns:\n",
    "        The same `model` object. Loading is non-strict; a warning is emitted when\n",
    "        parameters of `model` were not found in the checkpoint.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    # NOTE(security): weights_only=False unpickles arbitrary Python objects --\n",
    "    # only load checkpoints from a trusted source.\n",
    "    ckpt = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n",
    "    state_dict = ckpt[\"state_dict\"] if isinstance(ckpt, dict) and \"state_dict\" in ckpt else ckpt\n",
    "\n",
    "    new_sd = {}\n",
    "    for k, v in state_dict.items():\n",
    "        nk = k\n",
    "        stripped = True\n",
    "        while stripped:  # keep going: one pass misses e.g. \"module.\" followed by \"model.\"\n",
    "            stripped = False\n",
    "            for prefix in prefixes:\n",
    "                if prefix and nk.startswith(prefix):\n",
    "                    nk = nk[len(prefix):]\n",
    "                    stripped = True\n",
    "        new_sd[nk] = v\n",
    "\n",
    "    result = model.load_state_dict(new_sd, strict=False)\n",
    "    if result.missing_keys:\n",
    "        # strict=False would otherwise hide a failed load and leave random weights in place\n",
    "        warnings.warn(\n",
    "            f\"{len(result.missing_keys)} model parameter(s) were not found in {ckpt_path} \"\n",
    "            f\"(first: {result.missing_keys[0]}); they keep their initial values.\"\n",
    "        )\n",
    "    return model\n",
    "\n",
    "# Build backbone and load weights\n",
    "# weights=None: no torchvision-provided weights are requested; parameters come from CKPT_PATH.\n",
    "backbone = models.resnet50(weights=None)\n",
    "# Replace the final fc layer with a pass-through so forward() returns the features fed to it.\n",
    "backbone.fc = nn.Identity()  # output embeddings\n",
    "# Move to DEVICE and switch to eval mode before any inference.\n",
    "model = load_state_dict_flex(backbone, CKPT_PATH).to(DEVICE).eval()\n",
    "\n",
    "# -------------------\n",
    "# Load HF dataset\n",
    "# -------------------\n",
    "print(f\"Loading HF dataset: {HF_DATASET} [{HF_SPLIT}] ...\")\n",
    "hf_ds = load_dataset(HF_DATASET, split=HF_SPLIT, revision=HF_REVISION)\n",
    "\n",
    "# The dataset must expose a 'label' column. Class names are taken from the label\n",
    "# feature's `names` when it has them, otherwise from the sorted distinct label values.\n",
    "feature_names = hf_ds.features\n",
    "assert \"label\" in feature_names, \"Expected a 'label' feature in the dataset.\"\n",
    "label_feature = feature_names[\"label\"]\n",
    "declared_names = getattr(label_feature, \"names\", None)\n",
    "class_names = (\n",
    "    list(declared_names)\n",
    "    if declared_names is not None\n",
    "    else sorted(set(hf_ds[\"label\"]))\n",
    ")\n",
    "num_classes = len(class_names)\n",
    "print(f\"Classes: {num_classes}\")\n",
    "\n",
    "# -------------------\n",
    "# Torch Dataset wrapper\n",
    "# -------------------\n",
    "# Preprocessing pipeline: resize -> force RGB -> tensor -> per-channel normalization\n",
    "_NORM_MEAN = (0.485, 0.456, 0.406)\n",
    "_NORM_STD = (0.229, 0.224, 0.225)\n",
    "_transform_steps = [\n",
    "    transforms.Resize((IMG_SIZE, IMG_SIZE)),\n",
    "    transforms.Lambda(lambda im: im.convert(\"RGB\")),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),\n",
    "]\n",
    "img_transform = transforms.Compose(_transform_steps)\n",
    "\n",
    "class HFImageNet100(Dataset):\n",
    "    \"\"\"Map-style torch Dataset over a Hugging Face dataset with 'image' and 'label' columns.\n",
    "\n",
    "    Each item is `(image, label)`: the image is converted to an RGB PIL image and then\n",
    "    passed through `transform` when one is given; the label is a Python int.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hf_dataset, transform=None):\n",
    "        self.ds, self.transform = hf_dataset, transform\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.ds)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.ds[int(idx)]  # cast first: the index may arrive as a NumPy integer\n",
    "        image, target = row[\"image\"], int(row[\"label\"])\n",
    "\n",
    "        # Array-like images are wrapped into PIL; anything not already RGB is converted\n",
    "        if not isinstance(image, Image.Image):\n",
    "            image = Image.fromarray(np.array(image))\n",
    "        image = image if image.mode == \"RGB\" else image.convert(\"RGB\")\n",
    "\n",
    "        if self.transform is None:\n",
    "            return image, target\n",
    "        return self.transform(image), target\n",
    "\n",
    "\n",
    "base_ds = HFImageNet100(hf_ds, transform=img_transform)\n",
    "\n",
    "# -------------------\n",
    "# Build per-class index lists & subsample\n",
    "# -------------------\n",
    "# per_class[c] lists the dataset row indices whose label is c.\n",
    "# Read the 'label' column directly: indexing rows (hf_ds[i]) materialises the whole\n",
    "# sample, image included, for every row just to look at its label.\n",
    "per_class = [[] for _ in range(num_classes)]\n",
    "for i, lbl in enumerate(hf_ds[\"label\"]):\n",
    "    per_class[int(lbl)].append(i)\n",
    "\n",
    "def build_balanced_subset_indices(per_class_lists, points_per_class=None, rng=np.random):\n",
    "    \"\"\"Pick up to `points_per_class` indices from every class and return them sorted.\n",
    "\n",
    "    per_class_lists  : one list of dataset indices per class\n",
    "    points_per_class : cap per class; None (or a cap >= the class size) keeps the whole class\n",
    "    rng              : object with a NumPy-style `choice`, used for sampling without replacement\n",
    "\n",
    "    Returns a sorted list of Python ints.\n",
    "    \"\"\"\n",
    "    def _take(members):\n",
    "        keep_all = points_per_class is None or points_per_class >= len(members)\n",
    "        if keep_all:\n",
    "            return members\n",
    "        return rng.choice(members, size=points_per_class, replace=False).tolist()\n",
    "\n",
    "    # int(...) normalises NumPy integers so downstream indexing gets plain Python ints\n",
    "    selected = [int(j) for members in per_class_lists for j in _take(members)]\n",
    "    selected.sort()\n",
    "    return selected\n",
    "\n",
    "\n",
    "# np.random is the global NumPy RNG, seeded above via set_seed(RANDOM_SEED).\n",
    "subset_indices = build_balanced_subset_indices(per_class, POINTS_PER_CLASS, np.random)\n",
    "# Subset exposes only the chosen rows of base_ds, in ascending index order.\n",
    "subset = Subset(base_ds, subset_indices)\n",
    "\n",
    "print(f\"Total selected points: {len(subset)} \"\n",
    "      f\"(per-class cap: {'ALL' if POINTS_PER_CLASS is None else POINTS_PER_CLASS})\")\n",
    "\n",
    "# -------------------\n",
    "# Compute & Save Embeddings\n",
    "# -------------------\n",
    "# Create the output directory if needed and derive both cache file paths from it.\n",
    "Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)\n",
    "emb_out_path = str(Path(OUTPUT_DIR) / EMB_FILE)\n",
    "tsne_out_path = str(Path(OUTPUT_DIR) / TSNE_FILE)\n",
    "\n",
    "def compute_embeddings(dset, mdl, device=DEVICE, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS):\n",
    "    \"\"\"Run `mdl` over `dset` in dataset order and return `(features, labels)` as NumPy arrays.\n",
    "\n",
    "    Batches are moved to `device`, evaluated without gradient tracking, and the outputs\n",
    "    are copied back to the CPU and concatenated along axis 0.\n",
    "    \"\"\"\n",
    "    loader = DataLoader(\n",
    "        dset,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=False,  # keep outputs aligned with the dataset order\n",
    "        num_workers=num_workers,\n",
    "        pin_memory=True,\n",
    "    )\n",
    "    mdl.eval()\n",
    "    feat_chunks = []\n",
    "    label_chunks = []\n",
    "    with torch.no_grad():\n",
    "        for batch_imgs, batch_labels in loader:\n",
    "            out = mdl(batch_imgs.to(device, non_blocking=True))  # (B, D)\n",
    "            feat_chunks.append(out.detach().cpu().numpy())\n",
    "            label_chunks.append(batch_labels.numpy())\n",
    "    return np.concatenate(feat_chunks, axis=0), np.concatenate(label_chunks, axis=0)\n",
    "\n",
    "# Embeddings are computed only when the .npz file does not exist yet.\n",
    "# NOTE: the check is by file existence only -- after changing HF_SPLIT, CKPT_PATH or\n",
    "# POINTS_PER_CLASS, delete the file to force a recompute.\n",
    "if not os.path.exists(emb_out_path):\n",
    "    print(\"Computing embeddings...\")\n",
    "    t0 = time.time()\n",
    "    embeddings, labels = compute_embeddings(subset, model)\n",
    "    dt = time.time() - t0\n",
    "    print(f\"Done. Embeddings: {embeddings.shape} | time: {dt:.1f}s\")\n",
    "\n",
    "    # Store the arrays together with the settings that produced them.\n",
    "    np.savez(\n",
    "        emb_out_path,\n",
    "        embeddings=embeddings.astype(np.float32),\n",
    "        labels=labels.astype(np.int32),\n",
    "        class_names=np.array(class_names),\n",
    "        subset_indices=np.array(subset_indices, dtype=np.int64),\n",
    "        hf_dataset=np.array([HF_DATASET]),\n",
    "        hf_split=np.array([HF_SPLIT]),\n",
    "        ckpt_path=np.array([CKPT_PATH]),\n",
    "        # -1 encodes POINTS_PER_CLASS=None (no per-class cap)\n",
    "        points_per_class=np.array([-1 if POINTS_PER_CLASS is None else POINTS_PER_CLASS]),\n",
    "    )\n",
    "    print(f\"Saved embeddings -> {emb_out_path}\")\n",
    "else:\n",
    "    # Nothing is loaded in this branch; the next cell reads the file from disk.\n",
    "    print(f\"Embeddings already present at {emb_out_path} — skipping compute.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba5fc626",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# -------------------\n",
    "# Load embeddings\n",
    "# -------------------\n",
    "# np.load on an .npz returns a lazy archive that keeps the file open, so read the\n",
    "# arrays inside a `with` block. allow_pickle stays off: the arrays written by the\n",
    "# previous cell are plain numeric/string arrays and need no unpickling.\n",
    "with np.load(emb_out_path, allow_pickle=False) as data:\n",
    "    embeddings = data[\"embeddings\"]\n",
    "    labels = data[\"labels\"]\n",
    "    class_names = data[\"class_names\"].tolist()\n",
    "assert embeddings.shape[0] == labels.shape[0], \"embeddings and labels must have one row per sample\"\n",
    "print(f\"Loaded: embeddings {embeddings.shape}, labels {labels.shape}, classes {len(class_names)}\")\n",
    "\n",
    "def remove_after_comma(s):\n",
    "    \"\"\"Return the text before the first comma of `s`; non-string values pass through unchanged.\"\"\"\n",
    "    return s.split(\",\")[0] if isinstance(s, str) else s\n",
    "\n",
    "# Shorten multi-synonym names by keeping only the part before the first comma\n",
    "class_names = [remove_after_comma(cn) for cn in class_names]\n",
    "\n",
    "# -------------------\n",
    "# t-SNE (compute or load cache)\n",
    "# -------------------\n",
    "import inspect\n",
    "\n",
    "tsne_2d = None\n",
    "if SAVE_TSNE_EMB and os.path.exists(tsne_out_path):\n",
    "    print(f\"Loading cached t-SNE -> {tsne_out_path}\")\n",
    "    tsne_2d = np.load(tsne_out_path)\n",
    "    # A cache written for a different embedding set cannot be paired with `labels`\n",
    "    if tsne_2d.shape[0] != embeddings.shape[0]:\n",
    "        print(\"Cached t-SNE does not match the loaded embeddings; recomputing.\")\n",
    "        tsne_2d = None\n",
    "\n",
    "if tsne_2d is None:\n",
    "    print(\"Running t-SNE...\")\n",
    "    # TSNE_N_ITER was configured but never passed on. The keyword is `max_iter` in\n",
    "    # newer scikit-learn releases and `n_iter` in older ones, so pick the one that exists.\n",
    "    _tsne_params = inspect.signature(TSNE.__init__).parameters\n",
    "    _iter_kwarg = \"max_iter\" if \"max_iter\" in _tsne_params else \"n_iter\"\n",
    "    tsne = TSNE(\n",
    "        n_components=2,\n",
    "        perplexity=TSNE_PERPLEXITY,\n",
    "        learning_rate=TSNE_LEARNING_RATE,\n",
    "        metric=TSNE_METRIC,\n",
    "        init=TSNE_INIT,\n",
    "        random_state=RANDOM_SEED,\n",
    "        verbose=1,\n",
    "        **{_iter_kwarg: TSNE_N_ITER},\n",
    "    )\n",
    "    tsne_2d = tsne.fit_transform(embeddings)\n",
    "    if SAVE_TSNE_EMB:\n",
    "        np.save(tsne_out_path, tsne_2d)\n",
    "        print(f\"Saved t-SNE -> {tsne_out_path}\")\n",
    "\n",
    "# -------------------\n",
    "# Plot helpers\n",
    "# -------------------\n",
    "def make_class_colors(n_classes):\n",
    "    \"\"\"Return one RGBA color per class, sampled from the 'gist_ncar' colormap.\"\"\"\n",
    "    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;\n",
    "    # plt.get_cmap takes the same (name, lut) arguments and resamples to n_classes entries.\n",
    "    cmap = plt.get_cmap(\"gist_ncar\", n_classes)\n",
    "    return np.array([cmap(i) for i in range(n_classes)])\n",
    "\n",
    "def plot_tsne(tsne_xy, labels, class_names, point_size=POINT_SIZE, labeled_percent=LABELED_PERCENT, seed=RANDOM_SEED):\n",
    "    \"\"\"Scatter-plot 2-D t-SNE coordinates colored by class and annotate a random sample of points.\n",
    "\n",
    "    tsne_xy         : array whose first two columns are the plot coordinates\n",
    "    labels          : integer class id per row of `tsne_xy`\n",
    "    class_names     : list indexed by class id\n",
    "    point_size      : scatter marker size\n",
    "    labeled_percent : fraction of points that get a text label (0/None disables labels)\n",
    "    seed            : seed for choosing which points are labelled\n",
    "    \"\"\"\n",
    "    x, y = tsne_xy[:, 0], tsne_xy[:, 1]\n",
    "    lbls = np.asarray(labels).astype(int)\n",
    "    n_classes = len(class_names)\n",
    "    colors = make_class_colors(n_classes)\n",
    "\n",
    "    plt.figure(figsize=(10, 8), dpi=150)\n",
    "    for c in range(n_classes):\n",
    "        mask = (lbls == c)\n",
    "        if not np.any(mask):\n",
    "            continue\n",
    "        plt.scatter(x[mask], y[mask], s=point_size, alpha=0.75, c=[colors[c]], label=class_names[c])\n",
    "\n",
    "    # Label a random subset\n",
    "    N = len(x)\n",
    "    if labeled_percent and labeled_percent > 0 and N > 0:\n",
    "        k = min(N, max(1, int(N * labeled_percent)))  # never ask for more points than exist\n",
    "        # A private legacy RandomState yields the same picks as seeding the global NumPy RNG,\n",
    "        # without resetting the global Python/NumPy/torch RNG state as a plotting side effect.\n",
    "        picker = np.random.RandomState(seed)\n",
    "        chosen = picker.choice(np.arange(N), size=k, replace=False)\n",
    "        for i in chosen:\n",
    "            cls = class_names[int(lbls[i])]\n",
    "            plt.text(x[i], y[i], cls, fontsize=5, alpha=0.6)\n",
    "\n",
    "    # Legend for 100 classes can be overwhelming; keep off by default\n",
    "    # plt.legend(markerscale=3, bbox_to_anchor=(1.04, 1), loc=\"upper left\", borderaxespad=0)\n",
    "    plt.title(f\"t-SNE of {HF_DATASET} ({HF_SPLIT}) embeddings — color = class\")\n",
    "    plt.xlabel(\"t-SNE 1\")\n",
    "    plt.ylabel(\"t-SNE 2\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# -------------------\n",
    "# Plot!\n",
    "# -------------------\n",
    "plot_tsne(tsne_2d, labels, class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83670a2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def make_class_colors(n_classes):\n",
    "    \"\"\"Return one RGBA color per class, sampled from the 'gist_ncar' colormap.\"\"\"\n",
    "    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;\n",
    "    # plt.get_cmap takes the same (name, lut) arguments and resamples to n_classes entries.\n",
    "    cmap = plt.get_cmap(\"gist_ncar\", n_classes)\n",
    "    return np.array([cmap(i) for i in range(n_classes)])\n",
    "\n",
    "def plot_tsne_centroids(\n",
    "    tsne_xy,\n",
    "    labels,\n",
    "    class_names,\n",
    "    *,\n",
    "    base_size=120,           # marker size for a medium-sized class\n",
    "    scale_by_count=True,     # scale markers by class frequency\n",
    "    annotate=True,           # write class name next to each centroid\n",
    "    annotate_fontsize=7,\n",
    "    annotate_alpha=0.8,\n",
    "    title=None\n",
    "):\n",
    "    \"\"\"Plot one marker per class at the mean of that class's t-SNE coordinates.\n",
    "\n",
    "    Classes without any point are skipped. With `scale_by_count`, marker area grows\n",
    "    with the square root of the class size relative to the largest class; otherwise\n",
    "    every marker uses `base_size`. `annotate` writes the class name at each centroid.\n",
    "    \"\"\"\n",
    "    points = np.asarray(tsne_xy)\n",
    "    class_ids = labels.astype(int)\n",
    "    palette = make_class_colors(len(class_names))\n",
    "\n",
    "    # Classes that actually occur, with their mean position and member count\n",
    "    present = [c for c in range(len(class_names)) if np.any(class_ids == c)]\n",
    "    centers = np.vstack([points[class_ids == c].mean(axis=0) for c in present])\n",
    "    members = np.asarray([(class_ids == c).sum() for c in present])\n",
    "\n",
    "    if scale_by_count:\n",
    "        # sqrt keeps large classes from dominating; the 0.2 floor keeps small ones visible\n",
    "        marker_sizes = base_size * (np.sqrt(members / members.max()) * 0.8 + 0.2)\n",
    "    else:\n",
    "        marker_sizes = np.full_like(members, fill_value=base_size, dtype=float)\n",
    "\n",
    "    plt.figure(figsize=(8, 6), dpi=150)\n",
    "    for center, c, area in zip(centers, present, marker_sizes):\n",
    "        plt.scatter([center[0]], [center[1]], s=area, c=[palette[c]], alpha=0.9, edgecolors=\"none\")\n",
    "\n",
    "    if annotate:\n",
    "        for center, c in zip(centers, present):\n",
    "            plt.text(center[0], center[1], class_names[c], fontsize=annotate_fontsize,\n",
    "                     alpha=annotate_alpha, ha=\"left\", va=\"bottom\")\n",
    "\n",
    "    plt.xlabel(\"t-SNE 1\")\n",
    "    plt.ylabel(\"t-SNE 2\")\n",
    "    plt.title(f\"t-SNE class centroids — {len(present)} classes\" if title is None else title)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Uses tsne_2d, labels and class_names produced by the earlier cells.\n",
    "plot_tsne_centroids(\n",
    "    tsne_2d, labels, class_names,\n",
    "    base_size=120,\n",
    "    scale_by_count=True,\n",
    "    annotate=True,\n",
    "    annotate_fontsize=7,\n",
    "    title=f\"t-SNE of {HF_DATASET} ({HF_SPLIT}) — class centroids\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e7e8895",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# Horizontal dendrogram (HAC) over class centroids\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from scipy.cluster.hierarchy import linkage, dendrogram, fcluster\n",
    "\n",
    "# 1) Collapse to class-level embeddings (centroids)\n",
    "n_classes = len(class_names)\n",
    "dim = embeddings.shape[1]\n",
    "# Only classes that have samples get a centroid: the mean of an empty selection is NaN,\n",
    "# and linkage() refuses distance vectors containing non-finite values.\n",
    "present_classes = [c for c in range(n_classes) if np.any(labels == c)]\n",
    "class_means = np.zeros((len(present_classes), dim), dtype=np.float32)\n",
    "for row, c in enumerate(present_classes):\n",
    "    class_means[row] = embeddings[labels == c].mean(axis=0)\n",
    "dendro_labels = [class_names[c] for c in present_classes]  # leaf labels, aligned with class_means rows\n",
    "\n",
    "# (optional) L2-normalize centroids — makes cosine distance behave nicely\n",
    "class_means = class_means / (np.linalg.norm(class_means, axis=1, keepdims=True) + 1e-9)\n",
    "\n",
    "# 2) Pairwise distances and HAC\n",
    "# Use cosine distance; average linkage is a good default for cosine\n",
    "D = pdist(class_means, metric=\"cosine\")      # condensed distance vector\n",
    "Z = linkage(D, method=\"average\")             # or \"complete\", \"single\", \"ward\" (ward needs euclidean)\n",
    "\n",
    "# 3) Plot horizontal dendrogram\n",
    "plt.figure(figsize=(14, max(6, 0.22 * len(present_classes))))  # scale height with #classes shown\n",
    "dendrogram(\n",
    "    Z,\n",
    "    labels=dendro_labels,      # cleaned names (comma-stripped) of the classes present\n",
    "    orientation=\"left\",        # <-- horizontal dendrogram\n",
    "    leaf_font_size=8,\n",
    "    distance_sort=\"ascending\", # keeps similar items near each other\n",
    "    show_leaf_counts=False,\n",
    ")\n",
    "plt.title(\"Hierarchical Clustering of Class Centroids (cosine, average linkage)\")\n",
    "plt.xlabel(\"Cosine distance\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8317cf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# =========================\n",
    "# Clustered class-sim heatmap (cosine)\n",
    "# =========================\n",
    "# Inputs expected in memory:\n",
    "#   embeddings : np.ndarray [N, D]\n",
    "#   labels     : np.ndarray [N] (ints 0..C-1)\n",
    "#   class_names: list[str]  (len C)\n",
    "#\n",
    "# Output:\n",
    "#   class_similarity_heatmap.pdf (change via save_pdf)\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.spatial.distance import pdist, squareform\n",
    "from scipy.cluster.hierarchy import linkage, leaves_list, dendrogram\n",
    "from matplotlib import gridspec\n",
    "\n",
    "def compute_class_centroids(embeddings, labels, n_classes=None):\n",
    "    \"\"\"Return per-class mean embeddings and per-class sample counts.\n",
    "\n",
    "    embeddings : array-like [N, D]\n",
    "    labels     : array-like [N] of integer class ids in 0..n_classes-1\n",
    "    n_classes  : number of classes; defaults to max(labels) + 1\n",
    "\n",
    "    Returns (centroids [C, D], counts [C]). A class without samples gets an all-zero\n",
    "    centroid and a count of 0.\n",
    "\n",
    "    Raises ValueError when embeddings and labels disagree in length, and IndexError\n",
    "    when a label lies outside 0..n_classes-1.\n",
    "    \"\"\"\n",
    "    emb = np.asarray(embeddings, dtype=np.float32)\n",
    "    y = np.asarray(labels, dtype=np.int64)\n",
    "    if n_classes is None:\n",
    "        n_classes = int(y.max()) + 1\n",
    "    C = n_classes\n",
    "    D = emb.shape[1]\n",
    "    if y.shape[0] != emb.shape[0]:\n",
    "        raise ValueError(\"embeddings and labels must have the same number of rows\")\n",
    "    # Negative ids would silently wrap around to the last classes; reject them explicitly\n",
    "    if y.size and (y.min() < 0 or y.max() >= C):\n",
    "        raise IndexError(f\"labels must lie in [0, {C - 1}]\")\n",
    "\n",
    "    sums = np.zeros((C, D), dtype=np.float32)\n",
    "    np.add.at(sums, y, emb)  # unbuffered scatter-add: rows sharing a label accumulate\n",
    "    counts = np.bincount(y, minlength=C).astype(np.int64)\n",
    "    # Divide by at least 1 so empty classes stay at zero; the returned counts stay truthful\n",
    "    centroids = sums / np.maximum(counts, 1)[:, None]\n",
    "    return centroids, counts\n",
    "\n",
    "def plot_clustered_similarity_heatmap(\n",
    "    embeddings,\n",
    "    labels,\n",
    "    class_names,\n",
    "    *,\n",
    "    save_pdf=\"class_similarity_heatmap.pdf\",\n",
    "    title=None,\n",
    "    cmap=\"viridis\",\n",
    "    show=False,\n",
    "    leaf_fontsize=3,\n",
    "    colorbar_label=\"Cosine similarity\",\n",
    "    top_dendrogram=True,\n",
    "    side_dendrogram=True,\n",
    "):\n",
    "    \"\"\"Plot and save a class-by-class cosine-similarity heatmap in hierarchical-clustering order.\n",
    "\n",
    "    Class centroids are L2-normalized, clustered with average linkage on cosine\n",
    "    distance, and the similarity matrix is reordered by the resulting leaf order.\n",
    "\n",
    "    embeddings, labels : per-sample embeddings and integer class ids\n",
    "    class_names        : list indexed by class id (also sets the number of classes)\n",
    "    save_pdf           : output path passed to savefig\n",
    "    title              : optional figure title\n",
    "    cmap               : colormap for the heatmap (color scale fixed to [0, 1])\n",
    "    show               : show the figure; otherwise it is closed after saving\n",
    "    leaf_fontsize      : font size of the tick labels\n",
    "    colorbar_label     : currently unused (the colorbar code is commented out)\n",
    "    top_dendrogram / side_dendrogram : draw the tree above / left of the heatmap\n",
    "    \"\"\"\n",
    "    # 1) Centroids\n",
    "    C = len(class_names)\n",
    "    E_c, counts = compute_class_centroids(embeddings, labels, n_classes=C)\n",
    "\n",
    "    # 2) L2-normalize centroids for cosine\n",
    "    E_c = E_c / (np.linalg.norm(E_c, axis=1, keepdims=True) + 1e-12)\n",
    "\n",
    "    # 3) Cosine distances for clustering (condensed) + dendrogram order\n",
    "    d_cond = pdist(E_c, metric=\"cosine\")\n",
    "    Z = linkage(d_cond, method=\"average\")\n",
    "    order = leaves_list(Z)\n",
    "\n",
    "    # 4) Similarity matrix (for the heatmap), reordered\n",
    "    S = E_c @ E_c.T                        # cosine similarity in [-1,1]\n",
    "    S = np.clip(S, -1.0, 1.0)\n",
    "    S_sorted = S[np.ix_(order, order)]\n",
    "    names_sorted = [class_names[i] for i in order]\n",
    "\n",
    "    # 5) Figure layout: dendrograms on top+left, heatmap center\n",
    "    fig = plt.figure(figsize=(5.5, 5.5), dpi=1200)\n",
    "    gs = gridspec.GridSpec(\n",
    "        nrows=3, ncols=3,\n",
    "        width_ratios=[0.3, 0.05, 1.0],    # left dendro, spacer, heatmap\n",
    "        height_ratios=[0.3, 0.05, 1.0],   # top dendro, spacer, heatmap\n",
    "        wspace=0.0, hspace=0.0\n",
    "    )\n",
    "\n",
    "    if top_dendrogram:\n",
    "        # Top dendrogram: leaves run left to right in the same order as the heatmap columns\n",
    "        ax_dend_top = fig.add_subplot(gs[0, 2])\n",
    "        dendrogram(\n",
    "            Z, ax=ax_dend_top, color_threshold=None, no_labels=True,\n",
    "            orientation=\"top\", above_threshold_color=\"gray\", labels=None\n",
    "        )\n",
    "        ax_dend_top.set_xticks([]); ax_dend_top.set_yticks([])\n",
    "        for spine in ax_dend_top.spines.values():\n",
    "            spine.set_visible(False)\n",
    "\n",
    "    if side_dendrogram:\n",
    "        # Left dendrogram\n",
    "        ax_dend_left = fig.add_subplot(gs[2, 0])\n",
    "        dendrogram(\n",
    "            Z, ax=ax_dend_left, color_threshold=None, no_labels=True,\n",
    "            orientation=\"left\", above_threshold_color=\"gray\", labels=None\n",
    "        )\n",
    "        # A left-oriented dendrogram places the first leaf at the bottom, while imshow\n",
    "        # draws row 0 at the top; flip the axis so the tree lines up with the heatmap rows.\n",
    "        ax_dend_left.invert_yaxis()\n",
    "        ax_dend_left.set_xticks([]); ax_dend_left.set_yticks([])\n",
    "        for spine in ax_dend_left.spines.values():\n",
    "            spine.set_visible(False)\n",
    "\n",
    "    # Heatmap\n",
    "    ax_hm = fig.add_subplot(gs[2, 2])\n",
    "    im = ax_hm.imshow(S_sorted, vmin=0.0, vmax=1.0, cmap=cmap, interpolation=\"nearest\")\n",
    "    ax_hm.set_xticks(range(C)); ax_hm.set_yticks(range(C))\n",
    "    ax_hm.set_xticklabels(names_sorted, rotation=90, fontsize=leaf_fontsize)\n",
    "    ax_hm.set_yticklabels(names_sorted, fontsize=leaf_fontsize)\n",
    "    ax_hm.tick_params(length=0)\n",
    "    for spine in ax_hm.spines.values():\n",
    "        spine.set_visible(False)\n",
    "\n",
    "    # Colorbar\n",
    "    #cax = fig.add_subplot(gs[2, 1])\n",
    "    #cb = plt.colorbar(im, cax=cax)\n",
    "    #cb.set_label(colorbar_label)\n",
    "\n",
    "    if title:\n",
    "        fig.suptitle(title, y=0.0, fontsize=11)\n",
    "\n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.97])\n",
    "    plt.savefig(save_pdf, bbox_inches=\"tight\")\n",
    "    if show:\n",
    "        plt.show()\n",
    "    else:\n",
    "        plt.close(fig)\n",
    "    print(f\"[OK] saved {save_pdf}\")\n",
    "\n",
    "# Both dendrogram panels are switched off here: only the reordered heatmap is drawn,\n",
    "# saved to the PDF and shown inline.\n",
    "plot_clustered_similarity_heatmap(\n",
    "    embeddings, labels, class_names,\n",
    "    save_pdf=\"class_similarity_heatmap.pdf\",\n",
    "    #title=\"Clustered class affinities (cosine, average linkage)\",\n",
    "    cmap=\"viridis\",\n",
    "    show=True,\n",
    "    top_dendrogram=False,\n",
    "    side_dendrogram=False,\n",
    "    leaf_fontsize=2\n",
    ")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b8fa49c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "# Recursive DBSCAN taxonomy (cosine) over image-level embeddings\n",
    "# - No pre-aggregation\n",
    "# - Start with high eps; recurse inside mixed clusters with eps *= EPS_DECAY\n",
    "# - Cluster title = nearest class (cosine) to the cluster centroid\n",
    "# - Prints a text-based taxonomy tree\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.cluster import DBSCAN\n",
    "from sklearn.metrics.pairwise import cosine_distances\n",
    "from collections import Counter\n",
    "\n",
    "# --------------------\n",
    "# Config\n",
    "# --------------------\n",
    "INIT_EPS = 1          # start fairly high; cosine distance is in [0, 2]\n",
    "MIN_SAMPLES = 10        # density threshold; adjust depending on dataset size\n",
    "EPS_DECAY = 0.75        # lower eps by this factor at each recursion\n",
    "MIN_EPS = 0.001          # don't go below this (prevents overfragmentation)\n",
    "MAX_DEPTH = 10          # safety cap\n",
    "\n",
    "HANDLE_NOISE = False     # include a \"noise\" branch for DBSCAN label = -1\n",
    "SHOW_CLASS_DISTRIB = True  # print top-3 class proportions per cluster\n",
    "\n",
    "# --------------------\n",
    "# Helper: class means (for naming clusters by nearest class to centroid)\n",
    "# --------------------\n",
    "# Distinct label ids, in ascending order; row r of class_means belongs to classes[r]\n",
    "classes = np.unique(labels)\n",
    "# Unit-length rows (the epsilon guards against zero vectors)\n",
    "emb_norm = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-9)\n",
    "\n",
    "# Mean of the normalized vectors of each class, stacked to shape [C, D]\n",
    "class_means = np.vstack([emb_norm[labels == c].mean(axis=0) for c in classes])\n",
    "\n",
    "def nearest_class_name(vec):\n",
    "    \"\"\"Return the name of the class whose mean is closest (cosine distance) to `vec`.\"\"\"\n",
    "    v = vec / (np.linalg.norm(vec) + 1e-9)\n",
    "    dists = cosine_distances(v.reshape(1, -1), class_means)[0]  # smaller = closer\n",
    "    cidx = int(np.argmin(dists))\n",
    "    # cidx is a row of class_means, i.e. a position in `classes`; map it back to the\n",
    "    # label id, which differs from the row index when some label ids have no samples.\n",
    "    return class_names[int(classes[cidx])]\n",
    "\n",
    "# --------------------\n",
    "# Recursive DBSCAN\n",
    "# --------------------\n",
    "def run_dbscan(X, eps, min_samples):\n",
    "    \"\"\"Cluster the rows of X with cosine-metric DBSCAN and return the per-row cluster labels.\"\"\"\n",
    "    # metric='cosine' is handled by scikit-learn itself, so no distance matrix is precomputed\n",
    "    clusterer = DBSCAN(eps=eps, min_samples=min_samples, metric='cosine', n_jobs=-1)\n",
    "    return clusterer.fit_predict(X)\n",
    "\n",
    "def summarize_distribution(y, top_k=3):\n",
    "    \"\"\"Format the most frequent classes in `y` as \"name pct, name pct, ...\".\n",
    "\n",
    "    y     : array of integer class ids\n",
    "    top_k : number of classes to list, most frequent first; None lists every class.\n",
    "            The default of 3 matches the top-3 summary described by SHOW_CLASS_DISTRIB.\n",
    "    \"\"\"\n",
    "    ct = Counter(y.tolist())\n",
    "    total = len(y)\n",
    "    return \", \".join(\n",
    "        f\"{class_names[int(cls_idx)]} {cnt/total:.1%}\"\n",
    "        for cls_idx, cnt in ct.most_common(top_k)\n",
    "    )\n",
    "\n",
    "def build_taxonomy(X, y, eps, depth=0, node_name=\"ROOT\"):\n",
    "    \"\"\"\n",
    "    Recursively split (X, y) with DBSCAN and return the result as a nested dict.\n",
    "\n",
    "    X: embeddings for this node (image-level, already normalized)\n",
    "    y: class labels for X (ints)\n",
    "    eps: DBSCAN epsilon (cosine distance)\n",
    "    depth: recursion depth\n",
    "    node_name: display name used as the prefix of this node's \"name\"\n",
    "\n",
    "    Returns a nested structure (dict) describing the taxonomy:\n",
    "      - leaf nodes carry \"name\", \"leaf\": True and \"count\" (mixed leaves and noise\n",
    "        nodes also carry \"distribution\");\n",
    "      - internal nodes carry \"name\" and \"children\" only.\n",
    "    Points DBSCAN marks as noise (-1) appear in the tree only when HANDLE_NOISE is set.\n",
    "    \"\"\"\n",
    "    # Stopping conditions\n",
    "    unique_classes = np.unique(y)\n",
    "    if len(unique_classes) == 1:\n",
    "        # Pure leaf\n",
    "        pure_cls = class_names[int(unique_classes[0])]\n",
    "        return {\"name\": f\"{node_name} → {pure_cls}\", \"leaf\": True, \"count\": len(y)}\n",
    "\n",
    "    if eps < MIN_EPS or depth >= MAX_DEPTH:\n",
    "        # Could not / should not split further\n",
    "        return {\n",
    "            \"name\": f\"{node_name} (mixed; eps={eps:.3f})\",\n",
    "            \"leaf\": True,\n",
    "            \"count\": len(y),\n",
    "            \"distribution\": summarize_distribution(y) if SHOW_CLASS_DISTRIB else None,\n",
    "        }\n",
    "\n",
    "    # Run DBSCAN at this level\n",
    "    lbls = run_dbscan(X, eps=eps, min_samples=MIN_SAMPLES)\n",
    "    # Real clusters only; -1 is the noise label and is handled separately below\n",
    "    clusters = [k for k in np.unique(lbls) if k != -1]\n",
    "    has_noise = np.any(lbls == -1)\n",
    "\n",
    "    # If DBSCAN produced <2 clusters (i.e., one cluster or none), try one decay step right here\n",
    "    if len(clusters) < 2:\n",
    "        next_eps = eps * EPS_DECAY\n",
    "        if next_eps >= MIN_EPS:\n",
    "            # Retry the same node (same depth, same name) with a tighter eps\n",
    "            return build_taxonomy(X, y, next_eps, depth=depth, node_name=node_name)\n",
    "        else:\n",
    "            return {\n",
    "                \"name\": f\"{node_name} (mixed; unsplit; eps≈{eps:.3f})\",\n",
    "                \"leaf\": True,\n",
    "                \"count\": len(y),\n",
    "                \"distribution\": summarize_distribution(y) if SHOW_CLASS_DISTRIB else None,\n",
    "            }\n",
    "\n",
    "    # Otherwise, we have multiple clusters — recurse into each\n",
    "    children = []\n",
    "    for i, k in enumerate(clusters, start=1):\n",
    "        mask = (lbls == k)\n",
    "        Xk, yk = X[mask], y[mask]\n",
    "        # Name child by nearest class to the centroid\n",
    "        centroid = Xk.mean(axis=0)\n",
    "        cname = nearest_class_name(centroid)\n",
    "        child_name = f\"Cluster {i}: {cname} (n={len(yk)})\"\n",
    "        # Children are split with a decayed eps, one level deeper\n",
    "        child = build_taxonomy(\n",
    "            Xk, yk, eps * EPS_DECAY, depth=depth + 1, node_name=child_name\n",
    "        )\n",
    "        children.append(child)\n",
    "\n",
    "    # Optionally attach noise as its own branch\n",
    "    if HANDLE_NOISE and has_noise:\n",
    "        nmask = (lbls == -1)\n",
    "        Xn, yn = X[nmask], y[nmask]\n",
    "        if len(yn) > 0:\n",
    "            noise_name = f\"Noise (n={len(yn)})\"\n",
    "            # Don't recurse into noise; just summarize\n",
    "            noise_node = {\n",
    "                \"name\": noise_name,\n",
    "                \"leaf\": True,\n",
    "                \"count\": len(yn),\n",
    "                \"distribution\": summarize_distribution(yn) if SHOW_CLASS_DISTRIB else None,\n",
    "            }\n",
    "            children.append(noise_node)\n",
    "\n",
    "    return {\"name\": f\"{node_name} (eps={eps:.3f})\", \"children\": children}\n",
    "\n",
    "# --------------------\n",
    "# Build + Print taxonomy\n",
    "# --------------------\n",
    "tree = build_taxonomy(emb_norm, labels, eps=INIT_EPS, depth=0, node_name=\"ROOT\")\n",
    "\n",
    "def print_tree(node, indent=\"\"):\n",
    "    \"\"\"Print `node` and its descendants, indenting each level by two spaces.\n",
    "\n",
    "    Leaves also print their class distribution when SHOW_CLASS_DISTRIB is set and\n",
    "    the node carries one.\n",
    "    \"\"\"\n",
    "    print(indent + node[\"name\"])\n",
    "    if node.get(\"leaf\"):\n",
    "        distribution = node.get(\"distribution\")\n",
    "        if SHOW_CLASS_DISTRIB and distribution:\n",
    "            print(indent + \"  classes: \" + distribution)\n",
    "        return\n",
    "    child_indent = indent + \"  \"\n",
    "    for child in node.get(\"children\", []):\n",
    "        print_tree(child, child_indent)\n",
    "\n",
    "print_tree(tree)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e3d1a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class id -> class name lookup table\n",
    "for class_id in range(len(class_names)):\n",
    "    print(class_id, class_names[class_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46362ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "def tsne_subset(embeddings, labels, class_names, class_ids,\n",
    "                perplexity=30, n_iter=1000, random_state=42,\n",
    "                point_size=6, label_fraction=0.01):\n",
    "    \"\"\"\n",
    "    Run and plot t-SNE for a subset of classes.\n",
    "\n",
    "    embeddings    : np.ndarray, shape [N, D]\n",
    "    labels        : np.ndarray, shape [N]\n",
    "    class_names   : list[str]\n",
    "    class_ids     : list[int]  (class IDs to include)\n",
    "    perplexity    : t-SNE perplexity\n",
    "    n_iter        : number of t-SNE optimization iterations\n",
    "    random_state  : seed for t-SNE and for choosing which points get a text label\n",
    "    point_size    : scatter marker size\n",
    "    label_fraction: fraction of plotted points annotated with their class name\n",
    "\n",
    "    Returns: tsne_2d coords, subset_labels\n",
    "    \"\"\"\n",
    "    import inspect\n",
    "\n",
    "    # Mask for subset of classes\n",
    "    mask = np.isin(labels, class_ids)\n",
    "    emb_sub = embeddings[mask]\n",
    "    lbl_sub = labels[mask]\n",
    "\n",
    "    # Run t-SNE. n_iter used to be accepted but ignored; the keyword is `max_iter` in\n",
    "    # newer scikit-learn releases and `n_iter` in older ones, so pick the one that exists.\n",
    "    tsne_params = inspect.signature(TSNE.__init__).parameters\n",
    "    iter_kwarg = \"max_iter\" if \"max_iter\" in tsne_params else \"n_iter\"\n",
    "    tsne = TSNE(\n",
    "        n_components=2, perplexity=perplexity,\n",
    "        learning_rate=\"auto\", metric=\"euclidean\",\n",
    "        init=\"pca\", random_state=random_state, verbose=1,\n",
    "        **{iter_kwarg: n_iter}\n",
    "    )\n",
    "    tsne_2d = tsne.fit_transform(emb_sub)\n",
    "\n",
    "    # Plot\n",
    "    plt.figure(figsize=(8, 6), dpi=120)\n",
    "    colors = make_class_colors(len(class_ids))\n",
    "    for i, cid in enumerate(class_ids):\n",
    "        m = lbl_sub == cid\n",
    "        plt.scatter(tsne_2d[m,0], tsne_2d[m,1], s=point_size,\n",
    "                    alpha=0.7, c=[colors[i]], label=class_names[cid])\n",
    "\n",
    "    # Optionally label some points\n",
    "    n = tsne_2d.shape[0]\n",
    "    if label_fraction > 0 and n > 0:\n",
    "        k = min(n, max(1, int(n * label_fraction)))  # never ask for more points than exist\n",
    "        # Draw from a private RNG tied to random_state so the labelled points are\n",
    "        # reproducible and independent of the global NumPy RNG state.\n",
    "        if isinstance(random_state, np.random.RandomState):\n",
    "            picker = random_state\n",
    "        else:\n",
    "            picker = np.random.RandomState(random_state)\n",
    "        chosen = picker.choice(n, k, replace=False)\n",
    "        for idx in chosen:\n",
    "            cid = lbl_sub[idx]\n",
    "            plt.text(tsne_2d[idx,0], tsne_2d[idx,1],\n",
    "                     class_names[cid], fontsize=6, alpha=0.6)\n",
    "\n",
    "    plt.legend(markerscale=2, bbox_to_anchor=(1.05, 1),\n",
    "               loc=\"upper left\", borderaxespad=0.)\n",
    "    plt.title(\"t-SNE projection (subset of classes)\")\n",
    "    # background color\n",
    "    plt.gca().set_facecolor(\"gray\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return tsne_2d, lbl_sub\n",
    "\n",
    "# Class IDs to visualize; tsne_subset uses them as indices into class_names.\n",
    "subset_ids = [\n",
    "    3, 24, 25, 28, 34, 39, 43,\n",
    "    58, 62, 64, 70, 73, 95, 96,\n",
    "]\n",
    "tsne_coords, subset_labels = tsne_subset(\n",
    "    embeddings=embeddings,\n",
    "    labels=labels,\n",
    "    class_names=class_names,\n",
    "    class_ids=subset_ids,\n",
    "    perplexity=40,\n",
    "    n_iter=1500,\n",
    "    point_size=5,\n",
    "    label_fraction=0.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d685edc6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "091ed122",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import top_k_accuracy_score\n",
    "\n",
    "# Linear probe: logistic regression with very weak regularization (C=1e4).\n",
    "# `multi_class` is deliberately not passed: it is deprecated in scikit-learn 1.5+,\n",
    "# and with solver=\"lbfgs\" on more than two classes the default already fits a\n",
    "# multinomial (softmax) model.\n",
    "clf = LogisticRegression(\n",
    "    solver=\"lbfgs\",\n",
    "    max_iter=200,\n",
    "    C=1e4,        # very weak regularization\n",
    "    verbose=1,\n",
    "    n_jobs=-1,\n",
    ")\n",
    "clf.fit(embeddings, labels)\n",
    "# NOTE: scored on the same samples the probe was fit on, i.e. training accuracy.\n",
    "probs = clf.predict_proba(embeddings)\n",
    "# labels=clf.classes_ ties every probability column to its class label explicitly.\n",
    "top1_acc = top_k_accuracy_score(labels, probs, k=1, labels=clf.classes_)\n",
    "top5_acc = top_k_accuracy_score(labels, probs, k=5, labels=clf.classes_)\n",
    "print(f\"Linear probe (LogReg): Top-1 acc = {top1_acc:.3f}, Top-5 acc = {top5_acc:.3f}\")\n",
    "\n",
    "# kNN classifier (k=5)\n",
    "# NOTE: every query point is also part of the fitted set, so it counts among its\n",
    "# own neighbors; expect optimistic numbers compared with a held-out evaluation.\n",
    "knn = KNeighborsClassifier(n_neighbors=5, metric=\"cosine\", n_jobs=-1)\n",
    "knn.fit(embeddings, labels)\n",
    "knn_probs = knn.predict_proba(embeddings)\n",
    "knn_top1 = top_k_accuracy_score(labels, knn_probs, k=1, labels=knn.classes_)\n",
    "knn_top5 = top_k_accuracy_score(labels, knn_probs, k=5, labels=knn.classes_)\n",
    "print(f\"kNN (k=5): Top-1 acc = {knn_top1:.3f}, Top-5 acc = {knn_top5:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d79163e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.metrics import top_k_accuracy_score\n",
    "\n",
    "# Configurable train/eval split\n",
    "TRAIN_FRAC = 0.8  # fraction of data for training\n",
    "\n",
    "# Stratified split so the class proportions are kept in both partitions.\n",
    "X_train, X_eval, y_train, y_eval = train_test_split(\n",
    "    embeddings, labels, train_size=TRAIN_FRAC, random_state=42, stratify=labels\n",
    ")\n",
    "\n",
    "# Linear probe (Logistic Regression)\n",
    "# `multi_class` is deliberately not passed: it is deprecated in scikit-learn 1.5+,\n",
    "# and with solver=\"lbfgs\" on more than two classes the default already fits a\n",
    "# multinomial (softmax) model.\n",
    "clf = LogisticRegression(\n",
    "    solver=\"lbfgs\",\n",
    "    max_iter=200,\n",
    "    C=1e4,        # very weak regularization\n",
    "    verbose=1,\n",
    "    n_jobs=-1,\n",
    ")\n",
    "clf.fit(X_train, y_train)\n",
    "probs = clf.predict_proba(X_eval)\n",
    "# Pass the fitted class order explicitly: without `labels`, the scorer infers the\n",
    "# column order from y_eval and fails if the eval split is missing any class.\n",
    "top1_acc = top_k_accuracy_score(y_eval, probs, k=1, labels=clf.classes_)\n",
    "top5_acc = top_k_accuracy_score(y_eval, probs, k=5, labels=clf.classes_)\n",
    "print(f\"Linear probe (LogReg): Top-1 acc = {top1_acc:.6f}, Top-5 acc = {top5_acc:.6f}\")\n",
    "\n",
    "# kNN classifier (k=5)\n",
    "knn = KNeighborsClassifier(n_neighbors=5, metric=\"cosine\", n_jobs=-1)\n",
    "knn.fit(X_train, y_train)\n",
    "knn_probs = knn.predict_proba(X_eval)\n",
    "knn_top1 = top_k_accuracy_score(y_eval, knn_probs, k=1, labels=knn.classes_)\n",
    "knn_top5 = top_k_accuracy_score(y_eval, knn_probs, k=5, labels=knn.classes_)\n",
    "print(f\"kNN (k=5): Top-1 acc = {knn_top1:.6f}, Top-5 acc = {knn_top5:.6f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv (3.12.3)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
