"""CodebookViz: Minimal backbone for codebook editing + network visualization.

Based on ike_chain.py structure with augmentation keys (not patch keys) from ike_chain_dual.py.

Embedding modes:
- "vision": vision_layer(image, text)
- "language": language_layer(image, text)
- "dual_sbert": <vision_layer(image, text), lang_scaler * sbert(text)>
- "dual_internal": <vision_layer(image, text), lang_scaler * language_layer(blank, text)>

Usage:
    viz = CodebookViz(config, vlm, mode="dual_sbert", lang_scaler=8)
    viz.add_edit(image, question, sentences, answer)
    viz.plot_codebook()
"""

import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from typing import List, Dict, Optional
from ..utils import parent_module, brackets_to_periods, Augmenter


class CodebookViz(nn.Module):
    """Minimal codebook for VLM editing with network visualization.
    
    Key structure (from ike_chain.py):
    - key_emb: embedding of <image, text>
    - value: sentence to retrieve
    - radius: estimated radius for retrieval
    
    Augmentation:
    - n_aug: augmented keys per sentence (image variants, or text rephrasings when aug_mode="text_only")
    - Question+answer is treated as one of the sentences
    """

    # ==================== 1. INIT ====================

    def __init__(self, config, model, mode="dual_sbert", lang_scaler=8.0,
                 n_aug=3, aug_mode="image_only", aug_area_pct=0.9,
                 radius_method="fixed", distance="l2", pool_method="mean",
                 plot_edge_pct=75):
        """Create an empty codebook and register the activation hooks `mode` needs.

        Args:
            config: Config object. `device` and `experiment.dataset_name` are read
                from it directly; `lang_scaler`, `inner_params`,
                `inner_params_vision` and `inner_params_lang` are read from
                `config.model` (or from `config` itself when it has no `.model`).
            model: Either a wrapper exposing `.model` (kept as `self.wrapper`) or
                the bare model (in which case `self.wrapper` is None).
            mode: "vision", "language", "dual_sbert", or "dual_internal".
            lang_scaler: Scaling for the language component in dual modes. A
                `lang_scaler` attribute found on the config takes precedence.
            n_aug: Augmented keys per sentence.
            aug_mode: "image_only" (augment image, keep text) or "text_only"
                (rephrase text, keep image).
            aug_area_pct: Area percentage for image augmentation.
            radius_method: "balance", "augment", or "fixed".
            distance: "l2" or "cosine".
            pool_method: "mean" or "last".
            plot_edge_pct: Percentile threshold for edges in the network plot.

        Raises:
            ValueError: If the config lacks the layer names required by `mode`.
        """
        super().__init__()

        wrapped = hasattr(model, "model")
        cfg_model = getattr(config, "model", config)

        # --- model handles -------------------------------------------------
        self.config = config
        self.wrapper = model if wrapped else None
        self.model = model.model if wrapped else model
        self.device = getattr(config, "device", torch.device("cpu"))

        # --- embedding ------------------------------------------------------
        self.mode = mode
        self.lang_scaler = float(getattr(cfg_model, "lang_scaler", lang_scaler))
        self.distance = distance
        self.pool_method = pool_method

        # --- augmentation ---------------------------------------------------
        self.n_aug = n_aug
        self.aug_mode = aug_mode
        self.aug_area_pct = aug_area_pct

        # --- radius estimation (constants carried over from ike_chain.py) ---
        self.radius_method = radius_method
        self.fixed_radius = 100.0
        self.n_positive_samples = 5
        self.balance_alpha = 0.5
        self.n_radiusaug_samples = 10
        self.radius_percentile = 99

        # --- retrieval / plotting -------------------------------------------
        self.cap_k = 3
        self.plot_edge_pct = plot_edge_pct

        # --- codebook state -------------------------------------------------
        self.codebook = []       # one dict per key
        self.key_embs = None     # stacked key embeddings, one row per key
        self.key_radii = None    # one radius per key
        self._edit_count = 0
        self._added_uids = set()
        self._blank_image = Image.new('RGB', (224, 224), (128, 128, 128))

        # --- hook slots (filled by the forward hooks) -----------------------
        self._vision_act = None
        self._lang_act = None
        self._vision_hook = None
        self._lang_hook = None

        generic_layers = getattr(cfg_model, "inner_params", [])
        vision_layers = getattr(cfg_model, "inner_params_vision", [])
        lang_layers = getattr(cfg_model, "inner_params_lang", [])

        # Every mode except "language" reads the vision layer.
        if mode in ("vision", "dual_sbert", "dual_internal"):
            if not vision_layers:
                raise ValueError(f"mode='{mode}' requires config.model.inner_params_vision")
            self._vision_hook = self._setup_hook(vision_layers[0], "_vision_act")

        if mode == "language":
            # Prefer inner_params; fall back to inner_params_lang.
            chosen = generic_layers or lang_layers
            if not chosen:
                raise ValueError("mode='language' requires config.model.inner_params or inner_params_lang")
            self._lang_hook = self._setup_hook(chosen[0], "_lang_act")
        elif mode == "dual_internal":
            if not lang_layers:
                raise ValueError("mode='dual_internal' requires config.model.inner_params_lang")
            self._lang_hook = self._setup_hook(lang_layers[0], "_lang_act")

        # Lazily created helpers.
        self._sbert = None
        self._augmenter = None
        experiment_cfg = getattr(config, "experiment", None)
        self._dataset_name = getattr(experiment_cfg, "dataset_name", None)

    def _setup_hook(self, param_name, attr_name):
        """Attach a forward hook that stores a layer's first input on `self`.

        Args:
            param_name: Layer name, or a parameter name ending in ".weight" /
                ".bias" (the suffix is stripped to get the layer name).
            attr_name: Attribute of `self` that receives the detached input
                tensor (or None when the first input is not a tensor).

        Returns:
            The handle returned by `register_forward_hook`.
        """
        if param_name.endswith((".weight", ".bias")):
            layer_path = param_name.rsplit(".", 1)[0]
        else:
            layer_path = param_name

        owner = parent_module(self.model, brackets_to_periods(layer_path))
        target = getattr(owner, layer_path.rsplit(".", 1)[-1])

        def _capture(module, inputs, output, an=attr_name):
            # Record the layer *input* (not its output); returns None so the
            # layer's output is left untouched.
            first = inputs[0]
            setattr(self, an, first.detach() if isinstance(first, torch.Tensor) else None)

        return target.register_forward_hook(_capture)

    def _get_sbert(self):
        """Lazy load SBERT (from ike_chain.py)."""
        if self._sbert is None:
            from sentence_transformers import SentenceTransformer
            self._sbert = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
            self._sbert.to(self.device)
        return self._sbert

    def _get_augmenter(self):
        """Lazy init augmenter."""
        if self._augmenter is None:
            self._augmenter = Augmenter(self.wrapper, mosaic_prob=1.0, dataset_name=self._dataset_name)
        return self._augmenter

    def forward(self, *a, **kw):
        return self.model(*a, **kw)

    def generate(self, *a, **kw):
        return (self.model if hasattr(self.model, "generate") else self.wrapper).generate(*a, **kw)

    # ==================== 2. ENCODING (from ike_chain.py) ====================

    def _pool_act(self, act, batch_size):
        """Pool activation to [B, hidden] shape."""
        if act is None:
            raise RuntimeError("Hook failed to capture activation")
        act = act.to(self.device, torch.float32)
        if act.dim() == 3:
            return act[:, -1, :] if self.pool_method == "last" else act.mean(dim=1)
        elif act.dim() == 2:
            if act.shape[0] == batch_size:
                return act
            elif act.shape[0] % batch_size == 0:
                patches = act.shape[0] // batch_size
                if self.pool_method == "last":
                    return act.view(batch_size, patches, -1)[:, -1, :]
                return act.view(batch_size, patches, -1).mean(dim=1)
            else:
                if self.pool_method == "last":
                    return act[-1:].expand(batch_size, -1)
                return act.mean(dim=0, keepdim=True).expand(batch_size, -1)
        raise RuntimeError(f"Expected 2D or 3D activation, got {act.shape}")

    @torch.no_grad()
    def _encode_vlm(self, images: List, texts: List[str]) -> torch.Tensor:
        """Get VLM embedding for <image, text> pairs (from ike_chain.py).
        
        Modes:
        - vision: vision_layer(image, text)
        - language: language_layer(image, text)
        - dual_sbert: concat(vision_layer(image, text), lang_scaler * sbert(text))
        - dual_internal: concat(vision_layer(image, text), lang_scaler * language_layer(blank, text))
        """
        self.model.eval()
        batch_size = len(images) if isinstance(images, list) else 1
        
        if self.mode == "vision":
            self._vision_act = None
            inputs = self.wrapper.encode(images, texts, tokenize=False)
            self.model(**inputs)
            emb = self._pool_act(self._vision_act, batch_size)
            self._vision_act = None
            return emb
        
        elif self.mode == "language":
            self._lang_act = None
            inputs = self.wrapper.encode(images, texts, tokenize=False)
            self.model(**inputs)
            emb = self._pool_act(self._lang_act, batch_size)
            self._lang_act = None
            return emb
        
        elif self.mode == "dual_sbert":
            # Pass 1: vision_layer(image, text)
            self._vision_act = None
            inputs = self.wrapper.encode(images, texts, tokenize=False)
            self.model(**inputs)
            vision_emb = self._pool_act(self._vision_act, batch_size)
            self._vision_act = None
            
            # Pass 2: SBERT(text) * lang_scaler
            sbert = self._get_sbert()
            lang_emb = sbert.encode(texts, convert_to_tensor=True)
            lang_emb = lang_emb.to(self.device, torch.float32) * self.lang_scaler
            
            return torch.cat([vision_emb, lang_emb], dim=-1)
        
        elif self.mode == "dual_internal":
            # Pass 1: vision_layer(image, text)
            self._vision_act = None
            inputs = self.wrapper.encode(images, texts, tokenize=False)
            self.model(**inputs)
            vision_emb = self._pool_act(self._vision_act, batch_size)
            self._vision_act = None
            
            # Pass 2: language_layer(blank, text) * lang_scaler
            self._lang_act = None
            blank_imgs = [self._blank_image] * batch_size
            inputs = self.wrapper.encode(blank_imgs, texts, tokenize=False)
            self.model(**inputs)
            lang_emb = self._pool_act(self._lang_act, batch_size) * self.lang_scaler
            self._lang_act = None
            
            return torch.cat([vision_emb, lang_emb], dim=-1)
        
        else:
            raise ValueError(f"Unknown mode: {self.mode}")

    # ==================== 3. RADIUS ESTIMATION (from ike_chain.py) ====================

    @torch.no_grad()
    def _estimate_radius(self, key_emb: torch.Tensor, img, text: str) -> float:
        """Estimate radius using configured method (from ike_chain.py)."""
        if self.radius_method == "fixed":
            return self.fixed_radius
        
        augmenter = self._get_augmenter()
        
        if self.radius_method == "balance":
            # Positive: augmented image + text
            pos_dists = []
            for _ in range(self.n_positive_samples):
                aug_img = augmenter.image(img, area_pct=self.aug_area_pct)
                aug_text = augmenter.question(text) if text else ""
                pos_emb = self._encode_vlm([aug_img], [aug_text])
                if self.distance == "cosine":
                    pos_emb = F.normalize(pos_emb, dim=-1)
                pos_dists.append(float(torch.norm(pos_emb.cpu() - key_emb.cpu())))
            d_pos = float(np.median(pos_dists)) if pos_dists else 0.0
            
            # Negative: blank image, same text
            neg_emb = self._encode_vlm([self._blank_image], [text])
            if self.distance == "cosine":
                neg_emb = F.normalize(neg_emb, dim=-1)
            d_neg = float(torch.norm(neg_emb.cpu() - key_emb.cpu()))
            
            return (1 - self.balance_alpha) * d_pos + self.balance_alpha * d_neg
        
        elif self.radius_method == "augment":
            aug_dists = []
            for _ in range(self.n_radiusaug_samples):
                aug_img = augmenter.image(img, area_pct=self.aug_area_pct)
                aug_emb = self._encode_vlm([aug_img], [text])
                if self.distance == "cosine":
                    aug_emb = F.normalize(aug_emb, dim=-1)
                aug_dists.append(float(torch.norm(aug_emb.cpu() - key_emb.cpu())))
            return float(np.percentile(aug_dists, self.radius_percentile))
        
        return self.fixed_radius

    # ==================== 4. KEY MANAGEMENT (augmentation from ike_chain_dual.py) ====================

    @torch.no_grad()
    def _add_edit(self, img, question: str, answer: str, rationale_sents: List[str]):
        """Add keys for one edit with augmentation. Question+answer treated as a sentence."""
        edit_idx = self._edit_count
        augmenter = self._get_augmenter()
        
        # Build all sentences: rationale + question/answer
        all_sents = list(rationale_sents)
        if question and answer:
            all_sents.append(f"The answer to '{question}' is {answer}.")
        
        # Collect all <img, text> pairs for this edit
        entries = []  # List of (img, text, entry_dict)
        
        # 1. Raw sentence keys: <img, sent> for each sentence
        for i, sent in enumerate(all_sents):
            entries.append((img, sent, {
                "value": sent,
                "edit_idx": edit_idx,
                "aug_idx": 0,
                "sent_idx": i,
            }))
        
        # 2. Augmented sentence keys
        if self.n_aug > 0:
            for aug_round in range(self.n_aug):
                aug_idx = aug_round + 1
                
                if self.aug_mode == "text_only":
                    # Rephrase text, keep image same
                    for i, sent in enumerate(all_sents):
                        aug_text = augmenter.question(sent) if sent else ""
                        entries.append((img, aug_text, {
                            "value": sent,  # Store original sentence as value
                            "edit_idx": edit_idx,
                            "aug_idx": aug_idx,
                            "sent_idx": i,
                        }))
                else:  # "image_only" (default)
                    # Augment image, keep text same
                    aug_img = augmenter.image(img, area_pct=self.aug_area_pct)
                    for i, sent in enumerate(all_sents):
                        entries.append((aug_img, sent, {
                            "value": sent,
                            "edit_idx": edit_idx,
                            "aug_idx": aug_idx,
                            "sent_idx": i,
                        }))
        
        self._edit_count += 1
        
        # Compute embeddings in batch
        imgs = [e[0] for e in entries]
        texts = [e[1] for e in entries]
        new_embs = self._encode_vlm(imgs, texts)
        if self.distance == "cosine":
            new_embs = F.normalize(new_embs, dim=-1)
        
        # Compute radii (optional - only for radius-based retrieval)
        new_radii = []
        for i, (src_img, text, _) in enumerate(entries):
            r = self._estimate_radius(new_embs[i:i+1], src_img, text)
            new_radii.append(r)
        new_radii = torch.tensor(new_radii, dtype=torch.float32)
        new_embs = new_embs.cpu()
        
        # Add to codebook
        for i, (_, _, entry) in enumerate(entries):
            entry["emb_idx"] = len(self.codebook)
            self.codebook.append(entry)
        
        # Append embeddings
        if self.key_embs is None:
            self.key_embs = new_embs
            self.key_radii = new_radii
        else:
            self.key_embs = torch.cat([self.key_embs, new_embs], dim=0)
            self.key_radii = torch.cat([self.key_radii, new_radii])
        
        n_raw = len(all_sents)
        n_aug_keys = len(entries) - n_raw
        print(f"  [edit {edit_idx}] +{n_raw} raw + {n_aug_keys} aug = {len(entries)} keys, total={len(self.codebook)}")

    def add_edit(self, image, question, sentences, answer=None, verbose=True):
        """Add an edit to the codebook (convenience wrapper)."""
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        self._add_edit(image, question, answer or "", sentences)

    # ==================== 5. RETRIEVAL (simplified from ike_chain.py) ====================

    @torch.no_grad()
    def _retrieve(self, image, question: str = "") -> List[str]:
        """Retrieve values for a query <image, question>."""
        if self.key_embs is None or len(self.codebook) == 0:
            return []
        
        # Encode query
        q_emb = self._encode_vlm([image], [question])
        if self.distance == "cosine":
            q_emb = F.normalize(q_emb, dim=-1)
        q_emb = q_emb.cpu()
        
        # Compute distances
        if self.distance == "cosine":
            dists = 1 - (q_emb @ self.key_embs.t()).squeeze(0)
        else:
            dists = torch.norm(self.key_embs.float() - q_emb.float(), dim=-1)
        
        # Radius-based filtering
        in_radius = dists <= self.key_radii
        if not in_radius.any():
            return []
        
        # Get top-k within radius
        valid_idx = torch.where(in_radius)[0]
        valid_dists = dists[valid_idx]
        top_k_local = valid_dists.argsort()[:self.cap_k]
        top_k_idx = valid_idx[top_k_local].tolist()
        
        # Collect unique values
        retrieved = []
        seen = set()
        for idx in top_k_idx:
            value = self.codebook[idx]["value"]
            if value and value not in seen:
                seen.add(value)
                retrieved.append(value)
        
        return retrieved

    def retrieve(self, image, question: str = "") -> List[str]:
        """Public retrieval API."""
        return self._retrieve(image, question)

    # ==================== 6. DATASET & IO ====================

    def edit(self, config, tokens=None, batch_history=None, edit_ds=None, train_ds=None):
        """Add edits from dataset to codebook (API compat with ike_chain.py)."""
        if edit_ds is None:
            return self.model
        
        n_before = len(self.codebook)
        
        for ex in getattr(edit_ds, "data", []):
            uid = ex.get("uid") or (ex.get("image"), ex.get("question"))
            if uid in self._added_uids:
                continue
            
            rat = ex.get("cot") or ex.get("rationale") or ""
            img = ex.get("image")
            if not rat or img is None:
                continue
            
            sents = [s.strip() for s in re.split(r"(?<=[.!?])\s+", rat.strip()) if s.strip()]
            if sents:
                self._add_edit(img, ex.get("question", ""), 
                              ex.get("answer") or ex.get("target") or "", sents)
                self._added_uids.add(uid)
        
        n_after = len(self.codebook)
        mem_mb = self.key_embs.numel() * 4 / 1024 / 1024 if self.key_embs is not None else 0
        print(f"[CodebookViz] {n_before}->{n_after} keys, {mem_mb:.1f} MB")
        
        return self.model

    def save_index(self, path):
        """Save codebook to disk."""
        torch.save({
            "codebook": self.codebook,
            "key_embs": self.key_embs,
            "key_radii": self.key_radii,
            "_edit_count": self._edit_count,
        }, path)
        print(f"[CodebookViz] saved {len(self.codebook)} keys to {path}")

    def load_index(self, path):
        """Load codebook from disk."""
        data = torch.load(path, map_location=self.device)
        self.codebook = data["codebook"]
        self.key_embs = data["key_embs"]
        self.key_radii = data["key_radii"]
        self._edit_count = data.get("_edit_count", len(set(e["edit_idx"] for e in self.codebook)))
        print(f"[CodebookViz] loaded {len(self.codebook)} keys from {path}")

    def get_stats(self) -> Dict:
        """Return statistics about codebook."""
        if not self.codebook:
            return {"num_keys": 0, "num_edits": 0, "mode": self.mode}
        
        n_raw = sum(1 for e in self.codebook if e.get("aug_idx", 0) == 0)
        n_aug = len(self.codebook) - n_raw
        
        stats = {
            "num_keys": len(self.codebook),
            "num_raw": n_raw,
            "num_aug": n_aug,
            "num_edits": self._edit_count,
            "mode": self.mode,
            "lang_scaler": self.lang_scaler if "dual" in self.mode else None,
            "n_aug": self.n_aug,
            "radius_method": self.radius_method,
            "emb_dim": self.key_embs.shape[-1] if self.key_embs is not None else 0,
            "emb_size_mb": round(self.key_embs.numel() * 4 / 1024 / 1024, 2) if self.key_embs is not None else 0,
        }
        if self.key_radii is not None:
            stats["avg_radius"] = float(self.key_radii.mean())
            stats["min_radius"] = float(self.key_radii.min())
            stats["max_radius"] = float(self.key_radii.max())
        return stats

    def clear(self):
        """Clear codebook."""
        self.codebook = []
        self.key_embs = None
        self.key_radii = None
        self._edit_count = 0
        self._added_uids = set()

    # ==================== 7. VISUALIZATION ====================

    def plot_augmentation(self, image=None, text=None, n_samples=None, figsize=None):
        """Plot an example of the augmentation used for codebook keys.

        In "text_only" mode the original text and its rephrasings are shown as
        a text panel; otherwise the original image is shown next to its
        augmented variants.

        Args:
            image: PIL Image or path. If None, a plain grey placeholder is used.
            text: Text to augment (text_only mode). Default: a sample question.
            n_samples: Number of augmented samples to show. Default: `n_aug`.
                0 shows only the original.
            figsize: Figure size. Default: sized from `n_samples`.
        """
        if n_samples is None:
            n_samples = self.n_aug
        n_samples = max(int(n_samples), 0)
        augmenter = self._get_augmenter()

        # Resolve the image to plot.
        if image is None:
            img = Image.new('RGB', (224, 224), (200, 200, 200))
        elif isinstance(image, str):
            with Image.open(image) as src:
                img = src.convert("RGB")
        else:
            img = image

        if self.aug_mode == "text_only":
            text = text or "What color is the object in the image?"
            figsize = figsize or (10, 2 + n_samples * 0.5)

            fig, ax = plt.subplots(figsize=figsize)
            ax.axis("off")

            lines = [f"Original: {text}"]
            lines.extend(f"Aug {i + 1}: {augmenter.question(text)}" for i in range(n_samples))

            ax.text(0.05, 0.95, "\n".join(lines), transform=ax.transAxes,
                    fontsize=10, verticalalignment='top', family='monospace',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
            plt.suptitle(f"Text Augmentation (aug_mode={self.aug_mode}, n_aug={self.n_aug})", fontsize=11)
        else:
            figsize = figsize or (3 * (n_samples + 1), 3)

            # squeeze=False keeps `axes` a 2D array even for a single panel,
            # so indexing works when n_samples == 0.
            fig, axes = plt.subplots(1, n_samples + 1, figsize=figsize, squeeze=False)
            panels = axes[0]

            panels[0].imshow(img)
            panels[0].set_title("Original", fontsize=10)
            panels[0].axis("off")

            for i in range(n_samples):
                aug_img = augmenter.image(img, area_pct=self.aug_area_pct)
                panel = panels[i + 1]
                panel.imshow(aug_img)
                panel.set_title(f"Aug {i + 1}", fontsize=10)
                panel.axis("off")

            plt.suptitle(f"Image Augmentation (area_pct={self.aug_area_pct}, n_aug={self.n_aug})", fontsize=11)

        plt.tight_layout()
        plt.show()

    def plot_codebook(self, max_edits=20, figsize=(6, 5)):
        """Plot force-directed network of keys.
        
        Colors by edit_idx, shapes by aug_idx.
        """
        import networkx as nx
        
        if self.key_embs is None or len(self.codebook) == 0:
            print("[CodebookViz] No keys to plot")
            return
        
        # Sample edits if too many
        all_edit_indices = set(e.get("edit_idx", 0) for e in self.codebook)
        if len(all_edit_indices) > max_edits:
            import random
            selected_edits = set(random.sample(list(all_edit_indices), max_edits))
        else:
            selected_edits = all_edit_indices
        
        # Filter keys by selected edits
        indices = [i for i, e in enumerate(self.codebook) if e.get("edit_idx", 0) in selected_edits]
        embs = self.key_embs[indices].float().cpu().numpy()
        
        # Pairwise similarity
        dists = np.linalg.norm(embs[:, None] - embs[None, :], axis=-1)
        sims = 1 / (1 + dists)
        
        # Build graph
        G = nx.Graph()
        for i, idx in enumerate(indices):
            G.add_node(i, 
                       aug_idx=self.codebook[idx].get("aug_idx", 0),
                       edit_idx=self.codebook[idx].get("edit_idx", 0),
                       sent_idx=self.codebook[idx].get("sent_idx", 0))
        
        # Add edges above threshold
        thresh = np.percentile(sims[np.triu_indices(len(indices), k=1)], self.plot_edge_pct) if len(indices) > 1 else 0
        for i in range(len(indices)):
            for j in range(i + 1, len(indices)):
                if sims[i, j] > thresh:
                    G.add_edge(i, j, weight=sims[i, j])
        
        # Layout
        pos = nx.spring_layout(G, weight='weight', seed=42, k=2/np.sqrt(len(indices)))
        
        # Shapes by aug_idx
        AUG_SHAPES = ['o', '^', 's', 'D', 'p', '*', 'h', 'v', '<', '>']
        
        # Group nodes by aug_idx
        nodes_by_aug = {}
        for i in G.nodes:
            aug_idx = G.nodes[i]['aug_idx']
            if aug_idx not in nodes_by_aug:
                nodes_by_aug[aug_idx] = []
            nodes_by_aug[aug_idx].append(i)
        
        # Colors by edit index
        n_edits = len(selected_edits)
        cmap = plt.cm.get_cmap('tab20', max(n_edits, 1))
        
        # Plot
        fig, ax = plt.subplots(figsize=figsize)
        nx.draw_networkx_edges(G, pos, alpha=0.15, width=0.1, ax=ax)
        
        # Draw each aug_idx group
        for aug_idx in sorted(nodes_by_aug.keys()):
            nodes = nodes_by_aug[aug_idx]
            colors = [cmap(G.nodes[i]['edit_idx'] % 20) for i in nodes]
            shape = AUG_SHAPES[aug_idx % len(AUG_SHAPES)]
            sizes = 40
            nx.draw_networkx_nodes(G, pos, nodelist=nodes, node_color=colors, 
                                   node_size=sizes, alpha=0.8, node_shape=shape, ax=ax)
        
        # Draw sentence index labels on all nodes
        labels = {i: str(G.nodes[i]['sent_idx']) for i in G.nodes}
        nx.draw_networkx_labels(G, pos, labels=labels, font_size=5, font_color='black', ax=ax)
        
        # Legend
        for aug_idx in sorted(nodes_by_aug.keys()):
            shape = AUG_SHAPES[aug_idx % len(AUG_SHAPES)]
            label = f'raw ({len(nodes_by_aug[aug_idx])})' if aug_idx == 0 else f'aug{aug_idx} ({len(nodes_by_aug[aug_idx])})'
            ax.scatter([], [], c='gray', s=15, marker=shape, label=label)
        ax.legend(loc='lower left', fontsize=5, markerscale=0.7)
        
        # Title
        mode_str = self.mode
        if "dual" in self.mode:
            mode_str += f" (λ={self.lang_scaler})"
        ax.set_title(f'Codebook ({n_edits} edits, {len(indices)} keys)\nmode={mode_str}', fontsize=9)
        ax.axis('off')
        
        plt.tight_layout()
        plt.show()

    def cleanup(self):
        """Remove hooks and free resources."""
        if self._vision_hook:
            self._vision_hook.remove()
        if self._lang_hook:
            self._lang_hook.remove()
        self._sbert = None
        self._augmenter = None
