"""
Router evaluation: predict model selections using the trained router.

This module evaluates the router's ability to select the correct model for each test prompt.
"""

import torch
import torch.nn.functional as F
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path
import json
from collections import defaultdict
import numpy as np
import hashlib

from .model_selection_carve import ModelRegistry, RouterModel
from .model_selection_carve.candidates import CandidateSetBuilder
from .model_selection_carve.model_registry import normalize_domain
from .openmodel_carve import LoRAModelManager
from .utils_carve.compute_tracker import ComputeTracker


def tensor_sha(t: torch.Tensor) -> str:
    """
    Compute a short SHA1 fingerprint of a tensor's contents.

    The tensor is detached from any autograd graph, cast to float32 and moved
    to CPU before hashing. Tensors of any dtype/device are therefore hashed
    through the same float32 byte representation. Tensors that require grad
    are also accepted.

    Args:
        t: Input tensor

    Returns:
        First 8 hex characters of the SHA1 hash of the float32 bytes
    """
    # detach() is required: Tensor.numpy() refuses tensors that require grad.
    t_bytes = t.detach().float().cpu().numpy().tobytes()
    return hashlib.sha1(t_bytes).hexdigest()[:8]


def format_eval_prompt(
    ex: Dict[str, Any],
    tokenizer: Any,
    eval_use_chat_template: bool = False,
    system_prompt: str = "",
    model_card: str = "",
) -> str:
    """
    Format an evaluation prompt to match the training format.

    Two formats are supported:
      - chat template (when ``eval_use_chat_template`` is True and the tokenizer
        has ``apply_chat_template``): an optional system message followed by a
        user message containing ``prompt_text + model_card``, rendered with
        ``add_generation_prompt=True``;
      - plain concatenation (default):
        ``system_prompt + prompt_text + model_card + "\\n###Response:"``.

    Args:
        ex: Example dict with 'prompt_text' or 'instruction'
        tokenizer: Tokenizer (only used when the chat template is enabled)
        eval_use_chat_template: If True, use tokenizer.apply_chat_template
        system_prompt: System prompt (prepended, or sent as the system message)
        model_card: Model card/retriever info appended to the prompt text

    Returns:
        Formatted prompt string
    """
    # `or ""` also covers keys that are present with a None value, which would
    # otherwise make the string concatenation below raise TypeError.
    prompt_text = ex.get("prompt_text") or ex.get("instruction") or ""
    system_prompt = system_prompt or ""
    model_card = model_card or ""

    if eval_use_chat_template and hasattr(tokenizer, "apply_chat_template"):
        # Only emit a system message when there is something to say.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt_text + model_card})
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Default training format: system_prompt + prompt + model_card + "\n###Response:"
    return system_prompt + prompt_text + model_card + "\n###Response:"


def compute_registry_fingerprint(model_registry: ModelRegistry) -> str:
    """
    Fingerprint the registry's index -> model-name ordering.

    Model names are read in index order (0 .. len(registry) - 1), joined with
    newlines and hashed with SHA1. Any reordering, insertion or removal
    therefore yields a different fingerprint (barring hash collisions).

    Args:
        model_registry: ModelRegistry instance (must support len() and idx2model lookup)

    Returns:
        First 12 hex characters of the SHA1 hash
    """
    ordered_names = (model_registry.idx2model[pos] for pos in range(len(model_registry)))
    digest = hashlib.sha1("\n".join(ordered_names).encode())
    return digest.hexdigest()[:12]


def _resize_model_embeddings_in_state(
    state: Dict[str, torch.Tensor],
    router: RouterModel,
) -> int:
    """
    Resize the checkpoint's ``model_embeddings.weight`` in place to fit ``router``.

    The replacement table starts as a copy of the router's freshly constructed
    embedding table. The leading rows shared with the checkpoint are then
    overwritten with the trained values. The key stays in ``state`` (it is not
    dropped), so ``load_state_dict(strict=True)`` still finds every parameter.

    Args:
        state: Checkpoint state dict (mutated in place)
        router: Router constructed with the target number of models

    Returns:
        Number of embedding rows copied from the checkpoint

    Raises:
        KeyError: If the checkpoint has no ``model_embeddings.weight``
        ValueError: If the embedding dimension differs between checkpoint and router
    """
    if "model_embeddings.weight" not in state:
        raise KeyError("Checkpoint missing model_embeddings.weight; cannot resize safely.")

    old_w = state["model_embeddings.weight"]  # [num_models_ckpt, D]
    new_w = router.model_embeddings.weight    # [num_models, D]
    if old_w.shape[1] != new_w.shape[1]:
        raise ValueError(
            f"Embedding dim mismatch: ckpt {old_w.shape} vs model {new_w.shape}"
        )

    # Use actual tensor sizes so a config/tensor disagreement cannot cause an
    # out-of-range copy.
    overlap = min(new_w.shape[0], old_w.shape[0])
    with torch.no_grad():
        resized = new_w.detach().clone()
        # copy_ converts dtype/device if the checkpoint tensor differs.
        resized[:overlap].copy_(old_w[:overlap])
    state["model_embeddings.weight"] = resized
    return overlap


def _print_router_load_report(
    router: RouterModel,
    missing: List[str],
    unexpected: List[str],
) -> None:
    """Print key-mismatch lists, per-parameter statistics and the architecture config."""
    for label, keys in (("Missing", missing), ("Unexpected", unexpected)):
        if keys:
            print(f"  ⚠️  {label} keys ({len(keys)}):")
            for key in keys:
                print(f"    - {key}")
        else:
            print(f"  ✓ No {label.lower()} keys")

    # Summary statistics for trainable router parameters (sanity check that
    # weights were actually loaded rather than left at init).
    print("\n  [Router Parameter Summary]")
    with torch.inference_mode():
        for name, param in router.named_parameters():
            if param.requires_grad:
                p32 = param.float()
                print(f"    {name}:")
                print(f"      shape: {tuple(param.shape)}")
                print(f"      mean: {p32.mean().item():.6f}")
                print(f"      std:  {p32.std().item():.6f}")
                print(f"      L2 norm: {p32.norm(p=2).item():.6f}")

    print("\n  [Router Architecture Config]")
    print(f"    num_models: {router.num_models}")
    print(f"    embedding_dim: {router.embedding_dim}")
    print(f"    lm_hidden_size: {router.lm_hidden_size}")
    print(f"    tau: {router.tau}")
    print(f"    pooling: {router.pooling}")


def load_trained_router(
    checkpoint_dir: Path,
    device: str = "cuda",
    strict: bool = True,
    num_models_override: Optional[int] = None,
) -> RouterModel:
    """
    Load a trained RouterModel robustly.

    Expected files:
      - router_config.json   (saved at train time)
      - router_model.pt      (state_dict)

    Supports registry expansion/shrinkage via num_models_override. The
    checkpoint's model_embeddings table is resized before loading: overlapping
    rows keep their trained values, and extra rows keep the router's fresh
    initialization. This works with strict=True as well.

    Args:
        checkpoint_dir: Directory containing the router checkpoint
        device: Device to load the router on
        strict: Passed to load_state_dict (default: True for evaluation). With
            strict=True any key/shape mismatch raises, so the missing/unexpected
            key lists printed afterwards can only be non-empty when strict=False.
        num_models_override: Optional override for num_models if registry evolved since training

    Returns:
        Loaded RouterModel (in eval mode) with the training-time architecture parameters

    Raises:
        FileNotFoundError: If router checkpoint or config files are missing
        KeyError: If the config or checkpoint is missing required keys
        ValueError: If embedding dimensions don't match
    """
    router_path = checkpoint_dir / "router_model.pt"
    cfg_path = checkpoint_dir / "router_config.json"

    if not router_path.exists():
        raise FileNotFoundError(f"No router checkpoint found at {router_path}")
    if not cfg_path.exists():
        raise FileNotFoundError(
            f"Missing router_config.json at {cfg_path}. "
            f"Save config during training to ensure evaluation parity."
        )

    cfg = json.loads(cfg_path.read_text())

    # Use trained config; optionally override num_models if registry evolved
    num_models_ckpt = int(cfg["num_models"])
    num_models = int(num_models_override) if num_models_override is not None else num_models_ckpt

    # Backward compatibility: lm_hidden_size may be missing in older checkpoints.
    lm_hidden_size = int(cfg.get("lm_hidden_size", 4096))

    router = RouterModel(
        num_models=num_models,
        embedding_dim=int(cfg["embedding_dim"]),
        lm_hidden_size=lm_hidden_size,
        tau=float(cfg["tau"]),
        pooling=cfg["pooling"],
    ).to(device)

    # Prefer weights_only loading (no arbitrary unpickling); PyTorch versions
    # without the weights_only argument raise TypeError.
    try:
        state = torch.load(router_path, map_location=device, weights_only=True)
    except TypeError:
        state = torch.load(router_path, map_location=device)

    size_changed = num_models != num_models_ckpt
    overlap = _resize_model_embeddings_in_state(state, router) if size_changed else 0

    missing, unexpected = router.load_state_dict(state, strict=strict)

    # Free checkpoint state dict from (GPU) memory after loading
    del state
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"✓ Loaded router weights from {router_path}")
    print(f"  strict_mode: {strict}")
    if size_changed:
        print(f"  ⚠️  Registry size changed: {num_models_ckpt} → {num_models}")
        print(f"  ✓ Copied {overlap} overlapping embedding rows")
        if num_models > num_models_ckpt:
            print(f"  ⚠️  {num_models - num_models_ckpt} new rows initialized (not trained)")
        else:
            print(f"  ⚠️  {num_models_ckpt - num_models} rows truncated (registry shrunk)")

    _print_router_load_report(router, missing, unexpected)

    router.eval()
    return router


def compute_domain_scores(
    logits_all: torch.Tensor,
    domain_to_indices_tensor: Dict[str, torch.Tensor],
    mode: str = "logsumexp",
    topk: int = 10,
    alpha: float = 0.5,
) -> Dict[str, torch.Tensor]:
    """
    Aggregate per-model router logits into one score per domain.

    Available aggregation modes:
    - "logsumexp": logsumexp over every model in the domain (default)
    - "max": the domain's largest logit
    - "topk_logsumexp": logsumexp over the domain's top-k logits
    - "hybrid": alpha * max + (1 - alpha) * logsumexp(top-k)

    Domains with no models are left out of the result. The mode is validated
    only when the first non-empty domain is scored, so an unknown mode paired
    with only empty domains returns ``{}``.

    Args:
        logits_all: Router logits for all models [num_models]
        domain_to_indices_tensor: Domain name -> tensor of model indices (same device as logits_all)
        mode: Scoring mode ("logsumexp", "max", "topk_logsumexp", "hybrid")
        topk: Number of top models used by topk_logsumexp / hybrid (default: 10)
        alpha: Weight of the max term in hybrid mode (default: 0.5)

    Returns:
        Mapping from domain name to its aggregated score (0-dim tensor)

    Raises:
        ValueError: If mode is not one of the supported names
    """

    def _lse_of_top(vals: torch.Tensor) -> torch.Tensor:
        # Clamp k so small domains still work.
        k = min(topk, vals.numel())
        return torch.logsumexp(torch.topk(vals, k=k, largest=True).values, dim=0)

    scorers = {
        "max": lambda vals: vals.max(),
        "logsumexp": lambda vals: torch.logsumexp(vals, dim=0),
        "topk_logsumexp": _lse_of_top,
        "hybrid": lambda vals: alpha * vals.max() + (1.0 - alpha) * _lse_of_top(vals),
    }

    scores: Dict[str, torch.Tensor] = {}
    for name, members in domain_to_indices_tensor.items():
        if members.numel() == 0:
            continue
        member_logits = logits_all.index_select(0, members)
        scorer = scorers.get(mode)
        if scorer is None:
            raise ValueError(f"Unknown hier domain score mode: {mode}")
        scores[name] = scorer(member_logits)
    return scores


def hierarchical_rerank_topN(
    logits_all: torch.Tensor,
    restricted_indices: List[int],
    gold_model_idx: int,
    device: torch.device,
) -> Tuple[bool, bool, bool]:
    """
    Rank the gold model among a restricted candidate pool and report top-k hits.

    The candidates' logits are sorted in descending order. The gold model's
    1-based position in that order decides the top-1/5/10 outcomes. If the
    gold model is not in the pool, every outcome is False.

    Args:
        logits_all: Router logits for all models [num_models]
        restricted_indices: Model indices forming the candidate pool (precomputed)
        gold_model_idx: Index of the gold model
        device: Device on which to build the index tensor

    Returns:
        Tuple of (top1_correct, top5_correct, top10_correct)
    """
    # A gold model outside the pool can never be ranked, so every top-k misses.
    if gold_model_idx not in restricted_indices:
        return False, False, False

    pool = torch.tensor(restricted_indices, dtype=torch.long, device=device)
    pool_logits = logits_all[pool]  # [pool_size]
    gold_slot = restricted_indices.index(gold_model_idx)

    ranking = torch.argsort(pool_logits, descending=True)
    hits = (ranking == gold_slot).nonzero(as_tuple=True)[0]
    rank = hits.item() + 1

    return rank == 1, rank <= 5, rank <= 10


def build_restricted_set(
    top_domains: List[str],
    domain_to_indices: Dict[str, List[int]],
    N: int,
    num_models: int,
) -> List[int]:
    """
    Collect the model indices belonging to the first N domains of ``top_domains``.

    Domains absent from ``domain_to_indices`` are ignored. When the union is
    empty, every model index ``0 .. num_models - 1`` is returned instead, so
    reranking always has candidates.

    Args:
        top_domains: Domain names ordered by score, best first
        domain_to_indices: Domain name -> list of model indices
        N: How many leading domains to include
        num_models: Registry size, used for the fallback

    Returns:
        Sorted, de-duplicated list of model indices
    """
    pool = {
        model_idx
        for name in top_domains[:N]
        if name in domain_to_indices
        for model_idx in domain_to_indices[name]
    }
    return sorted(pool) if pool else list(range(num_models))


def evaluate_router(
    router_model: RouterModel,
    model_registry: ModelRegistry,
    lm_model: LoRAModelManager,
    test_data: List[Dict[str, Any]],
    k_values: List[int] = [1, 3, 5, 10],
    batch_size: int = 32,
    device: str = "cuda",
    debug: bool = False,
    eval_use_chat_template: bool = False,
    system_prompt: str = "",
    checkpoint_dir: Optional[Path] = None,
    max_length: int = 512,
    candidate_K: Optional[int] = None,
    router_config: Optional[Dict[str, Any]] = None,
    known_domain_mode: bool = False,
    hierarchical_eval: bool = False,
    hierarchy_level: str = "domain",
    hierarchical_topk: int = 1,
    hier_domain_score_mode: str = "logsumexp",
    hier_domain_topk: int = 10,
    hier_domain_hybrid_alpha: float = 0.5,
    model_family_lookup: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """
    Evaluate router model on test data.
    
    Returns:
        Dictionary with evaluation metrics including compute metrics if tracked.
    """
    # Initialize compute tracker for evaluation
    eval_compute_tracker = ComputeTracker()
    """
    Evaluate router on test data.
    
    Args:
        router_model: Trained RouterModel
        model_registry: ModelRegistry with model name → ID mappings
        lm_model: LoRAModelManager to encode prompts
        test_data: List of test examples with 'prompt_text' and 'model_name'
        k_values: List of k values for top-k accuracy computation
        batch_size: Batch size for evaluation
        device: Device for computation
        debug: Whether to enable detailed debug output
        eval_use_chat_template: If True, use tokenizer.apply_chat_template for formatting
        system_prompt: System prompt to prepend (if not using chat template)
        checkpoint_dir: Optional checkpoint directory for registry fingerprint validation
        known_domain_mode: If True, only compare against models within the same domain
            (assumes domain is known and filters candidates accordingly)
        hierarchical_eval: If True, enable two-stage hierarchical evaluation
        hierarchy_level: Level for hierarchical grouping ("domain" or "parent_group")
        hierarchical_topk: Number of top groups to consider (default: 1)
        hier_domain_score_mode: Domain scoring strategy ("logsumexp", "max", "topk_logsumexp", "hybrid", default: "logsumexp")
        hier_domain_topk: Number of top models for topk_logsumexp/hybrid modes (default: 10)
        hier_domain_hybrid_alpha: Weight for max in hybrid mode (default: 0.5)
    
    Returns:
        Dictionary of evaluation metrics (and diagnostics if debug=True)
        If known_domain_mode=True, includes separate metrics prefixed with "known_domain_"
        If hierarchical_eval=True, includes hierarchical metrics prefixed with "hier_"
    """
    # A) Enforce eval mode for both router and LM
    router_model.eval()
    if hasattr(lm_model, 'model') and hasattr(lm_model.model, 'eval'):
        lm_model.model.eval()
    elif hasattr(lm_model, 'eval'):
        lm_model.eval()
    
    # Validate registry ↔ embedding alignment (C)
    num_models = len(model_registry)
    embedding_dim = router_model.embedding_dim
    
    # Check embedding table shape matches registry
    assert router_model.model_embeddings.weight.shape[0] == num_models, \
        f"Registry has {num_models} models but embedding table has {router_model.model_embeddings.weight.shape[0]} rows"
    assert router_model.model_embeddings.weight.shape[1] == embedding_dim, \
        f"Expected embedding_dim={embedding_dim} but got {router_model.model_embeddings.weight.shape[1]}"
    
    # Compute registry fingerprint
    registry_fingerprint = compute_registry_fingerprint(model_registry)
    
    if debug:
        print(f"\n[Registry Validation]")
        print(f"  num_models: {num_models}")
        print(f"  embedding_dim: {embedding_dim}")
        print(f"  embedding_table.shape: {router_model.model_embeddings.weight.shape}")
        print(f"  registry_fingerprint: {registry_fingerprint}")
    
    # Check against saved fingerprint if checkpoint_dir provided
    if checkpoint_dir is not None:
        fingerprint_path = checkpoint_dir / "registry_fingerprint.txt"
        router_config_path = checkpoint_dir / "router_config.json"
        
        saved_fingerprint = None
        if fingerprint_path.exists():
            saved_fingerprint = fingerprint_path.read_text().strip()
        elif router_config_path.exists():
            # Try to load from router_config.json if it has fingerprint
            try:
                with open(router_config_path, 'r') as f:
                    config_data = json.load(f)
                    saved_fingerprint = config_data.get("registry_fingerprint")
            except:
                pass
        
        if saved_fingerprint:
            if saved_fingerprint != registry_fingerprint:
                raise RuntimeError(
                    f"Registry fingerprint mismatch!\n"
                    f"  Saved: {saved_fingerprint}\n"
                    f"  Current: {registry_fingerprint}\n"
                    f"This indicates idx2model ordering changed. Check registry loading."
                )
            if debug:
                print(f"  ✓ Fingerprint matches saved: {saved_fingerprint}")
        else:
            if debug:
                print(f"  ⚠️  No saved fingerprint found (expected at {fingerprint_path})")
    
    # Track diagnostics for return
    diagnostics = {
        "registry_fingerprint": registry_fingerprint,
        "ids_hashes": [],
        "prompt_emb_hashes": [],
        "score_vec_hashes": [],
    } if debug else None
    
    all_predictions: List[int] = []
    all_labels: List[int] = []
    all_domains: List[str] = []
    all_scores: List[List[float]] = []  # Store all scores for analysis
    
    # Track per-domain metrics
    domain_correct = defaultdict(int)
    domain_total = defaultdict(int)
    
    # Track model family accuracy
    family_correct = defaultdict(int)
    family_total = defaultdict(int)
    
    # Track forgetting metrics (for old models/domains/families from earlier experiences)
    # Forgetting = 1 - accuracy on old items (measures how much performance degraded)
    old_model_correct = 0
    old_model_total = 0
    old_domain_correct = defaultdict(int)
    old_domain_total = defaultdict(int)
    old_family_correct = defaultdict(int)
    old_family_total = defaultdict(int)
    
    # Track top-k accuracies (acc_all over all models)
    topk_correct = {k: 0 for k in k_values}

    # Track known-domain mode metrics (if enabled)
    known_domain_topk_correct = {k: 0 for k in k_values} if known_domain_mode else None
    known_domain_total = 0 if known_domain_mode else 0
    known_domain_gold_ranks = [] if known_domain_mode else []  # Track gold model rank within domain
    
    # Track known-domain candidate-set metrics (comparable to training)
    known_domain_cand_correct = 0 if known_domain_mode else 0
    known_domain_cand_total = 0 if known_domain_mode else 0
    
    # Track score margins for comparison with training
    candidate_score_margins = []  # For candidate-set evaluation
    all_models_score_margins = []  # For all-models evaluation

    # Track candidate-set accuracies (acc_candidate over K candidates)
    cand_top1_correct = 0
    cand_total = 0
    
    # Track gold_in_registry for diagnostic
    gold_in_registry_count = 0
    total_examples_processed = 0
    
    # Track entropy for diagnostic
    all_entropies = []
    
    # Track example IDs for debug output
    example_counter = 0
    
    # Track hierarchical evaluation metrics (if enabled)
    hier_group_correct = 0 if hierarchical_eval else 0
    hier_group_total = 0 if hierarchical_eval else 0
    hier_model_topk_correct = {k: 0 for k in k_values} if hierarchical_eval else {}
    hier_model_total = 0 if hierarchical_eval else 0
    hier_e2e_top1_correct = 0 if hierarchical_eval else 0
    hier_restricted_sizes = [] if hierarchical_eval else []
    
    # Track Top-N Domain Hierarchical Rerank metrics (N=1,2,3)
    # Legacy: conditional metrics (only when gold domain in top-N)
    hier_model_top1_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}
    hier_model_top5_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}
    hier_model_top10_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}
    hier_model_total_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}  # Count of examples where gold in restricted set
    
    # New: E2E metrics (all examples, correct denominator)
    hier_model_top1_e2e_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}  # E2E correct count
    hier_model_top1_cond_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}  # Conditional correct count (same as hier_model_top1_atN, but renamed for clarity)
    hier_domain_included_atN = {1: 0, 2: 0, 3: 0} if hierarchical_eval else {}  # Count where gold domain ∈ predicted top-N
    
    hier_restricted_sizes_atN = {1: [], 2: [], 3: []} if hierarchical_eval else {}
    
    # Track missing gold examples (gold model not in registry)
    hier_missing_gold_count = 0 if hierarchical_eval else 0
    
    # Debug: Track why examples are filtered out (kept for backward compatibility)
    hier_gold_not_in_predicted_domain = {1: 0, 2: 0, 3: 0} if hierarchical_eval and debug else {}
    hier_gold_domain_missing_from_registry = 0 if hierarchical_eval and debug else 0
    
    # Precompute group mappings for hierarchical evaluation
    group_to_model_indices: Dict[str, List[int]] = {}
    group_to_model_indices_tensor: Dict[str, torch.Tensor] = {}
    model_idx_to_group: Dict[int, str] = {}
    if hierarchical_eval:
        if hierarchy_level == "domain":
            # Use domain2models mapping
            for domain, model_indices in model_registry.domain2models.items():
                group_to_model_indices[domain] = model_indices
                for model_idx in model_indices:
                    model_idx_to_group[model_idx] = domain
        elif hierarchy_level == "parent_group":
            # Use parent_group2models mapping
            for parent_group, model_indices in model_registry.parent_group2models.items():
                group_to_model_indices[parent_group] = model_indices
                for model_idx in model_indices:
                    model_idx_to_group[model_idx] = parent_group
        else:
            raise ValueError(f"Invalid hierarchy_level: {hierarchy_level}. Must be 'domain' or 'parent_group'")
        
        # Ensure all models have a group (map missing to "unknown")
        for model_idx in range(num_models):
            if model_idx not in model_idx_to_group:
                if "unknown" not in group_to_model_indices:
                    group_to_model_indices["unknown"] = []
                group_to_model_indices["unknown"].append(model_idx)
                model_idx_to_group[model_idx] = "unknown"
        
        # Precompute tensor-based indices for efficient domain scoring
        device_tensor = torch.device(device)
        for domain, model_indices in group_to_model_indices.items():
            if model_indices:
                group_to_model_indices_tensor[domain] = torch.tensor(
                    model_indices, dtype=torch.long, device=device_tensor
                )
            else:
                group_to_model_indices_tensor[domain] = torch.tensor(
                    [], dtype=torch.long, device=device_tensor
                )
        
        # Debug: Verify registry contains all test examples' gold models
        if debug:
            print(f"\n[Hierarchical Evaluation Setup]")
            print(f"  hierarchy_level: {hierarchy_level}")
            print(f"  hierarchical_topk: {hierarchical_topk}")
            print(f"  hier_domain_score_mode: {hier_domain_score_mode}")
            print(f"  hier_domain_topk: {hier_domain_topk}")
            print(f"  hier_domain_hybrid_alpha: {hier_domain_hybrid_alpha}")
            print(f"  num_groups: {len(group_to_model_indices)}")
            for group, indices in sorted(group_to_model_indices.items(), key=lambda x: -len(x[1]))[:10]:
                print(f"    {group}: {len(indices)} models")
            
            # Check if all test examples' gold models have a group mapping
            test_gold_models = set()
            test_gold_domains = set()
            for ex in test_data:
                model_name = ex.get('model_name', '')
                domain = ex.get('domain', 'unknown')
                if model_name in model_registry.model2idx:
                    test_gold_models.add(model_name)
                    test_gold_domains.add(domain)
            
            print(f"\n[Registry Coverage Check]")
            print(f"  Test examples with gold models in registry: {len(test_gold_models)}")
            print(f"  Unique domains in test data: {len(test_gold_domains)}")
            print(f"  Domains in registry: {len(model_registry.domain2models)}")
            print(f"  Groups in group_to_model_indices: {len(group_to_model_indices)}")
            
            # Check if any test gold models don't have a group mapping
            missing_group_count = 0
            for model_name in test_gold_models:
                if model_name in model_registry.model2idx:
                    model_idx = model_registry.model2idx[model_name]
                    if model_idx not in model_idx_to_group:
                        missing_group_count += 1
            if missing_group_count > 0:
                print(f"  ⚠️  WARNING: {missing_group_count} test gold models don't have a group mapping!")
            else:
                print(f"  ✓ All test gold models have group mappings")
    
    num_examples = len(test_data)
    print(f"\nEvaluating router on {num_examples} test examples...")
    
    # Check if model_family_lookup is available
    if model_family_lookup is None:
        print(f"  ⚠️  WARNING: model_family_lookup is None - family accuracy may not be computed correctly!")
        print(f"      Make sure to pass model_family_lookup when calling evaluate_router.")
    else:
        print(f"  ✓ model_family_lookup available with {len(model_family_lookup)} entries")
    
    if known_domain_mode:
        print(f"[KNOWN DOMAIN MODE ENABLED] Only comparing against models within the same domain")
    
    if debug:
        print(f"[DEBUG MODE ENABLED] Will print detailed diagnostics")
    
    # Track first batch for diagnostics
    first_batch_processed = False
    
    with torch.no_grad():
        for start_idx in range(0, num_examples, batch_size):
            end_idx = min(start_idx + batch_size, num_examples)
            batch = test_data[start_idx:end_idx]
            
            # B) Format prompts to match training format
            tokenizer = lm_model.tokenizer
            formatted_prompts = []
            for ex in batch:
                # Extract model_card if available (for retriever info)
                model_card = ex.get("model_card", "") or ex.get("reference_api", "")
                if model_card and not model_card.startswith("\n<Reference API>:"):
                    model_card = "\n<Reference API>: " + model_card
                
                formatted = format_eval_prompt(
                    ex=ex,
                    tokenizer=tokenizer,
                    eval_use_chat_template=eval_use_chat_template,
                    system_prompt=system_prompt,
                    model_card=model_card,
                )
                formatted_prompts.append(formatted)
            
            # Debug output for first batch (B)
            if debug and not first_batch_processed:
                print(f"\n[Prompt Formatting Debug] (first batch, start_idx={start_idx})")
                for i in range(min(2, len(formatted_prompts))):
                    prompt_snippet = formatted_prompts[i][:200]
                    print(f"  Example {i} prompt (first 200 chars): {repr(prompt_snippet)}")
            
            prompts = formatted_prompts
            labels = [ex['model_name'] for ex in batch]
            domains = [ex.get('domain', 'unknown') for ex in batch]
            
            # Convert labels to indices
            label_indices = []
            valid_mask = []
            gold_in_registry_flags = []
            for label in labels:
                if label in model_registry.model2idx:
                    label_indices.append(model_registry.model2idx[label])
                    valid_mask.append(True)
                    gold_in_registry_flags.append(True)
                else:
                    # Unknown model in test set
                    label_indices.append(-1)
                    valid_mask.append(False)
                    gold_in_registry_flags.append(False)
            
            # Encode prompts using the LM backbone
            # We need to get hidden states for prompt tokens only
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)
            
            # Debug: token counts and input_ids hashes for first batch (B)
            if debug and not first_batch_processed:
                print(f"\n[Tokenization Debug] (first batch)")
                for i in range(min(2, len(prompts))):
                    token_count = inputs['attention_mask'][i].sum().item()
                    ids_hash = tensor_sha(inputs['input_ids'][i])
                    print(f"  Example {i}: {token_count} tokens, input_ids hash: {ids_hash}")
                    if diagnostics is not None:
                        diagnostics["ids_hashes"].append(ids_hash)
            
            # Get model outputs with hidden states
            lm_outputs = lm_model.model.model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                output_hidden_states=True,
                return_dict=True,
            )
            
            # Extract prompt embeddings using router's encode_prompt method
            # This respects the pooling setting (last_token vs mean) and applies projection
            hidden_states = lm_outputs.hidden_states[-1]  # [B, L, D]
            attention_mask = inputs['attention_mask']  # [B, L]
            
            # Ensure router model and hidden states are on same device and dtype
            router_model = router_model.to(device=hidden_states.device, dtype=hidden_states.dtype)
            
            # In evaluation, all tokens are prompt tokens (no completion)
            # So prompt_mask = attention_mask
            prompt_mask = inputs["attention_mask"].to(dtype=torch.bool)  # [B, L]
            
            # Debug: compute last_nonpad indices and verify pooled token ids (first batch only)
            if debug and not first_batch_processed:
                print(f"\n[Pooling Debug] (first batch, start_idx={start_idx})")
                B = prompt_mask.shape[0]
                L = prompt_mask.shape[1]
                input_ids = inputs['input_ids']
                pad_token_id = tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id
                padding_side = tokenizer.padding_side if hasattr(tokenizer, 'padding_side') else 'unknown'
                
                print(f"  tokenizer.padding_side: {padding_side}")
                print(f"  pad_token_id: {pad_token_id}")
                
                # Compute last_nonpad indices using SAME method as RouterModel (padding-side agnostic)
                # Use mask * positions max method (works for both left and right padding)
                if prompt_mask.dtype == torch.bool:
                    mask_long = prompt_mask.long()
                else:
                    mask_long = (prompt_mask > 0.5).long()
                
                positions = torch.arange(L, device=prompt_mask.device).unsqueeze(0).expand(B, L)  # [B, L]
                masked_positions = mask_long * positions  # [B, L]
                last_nonpad = masked_positions.max(dim=1).values  # [B] - padding-side agnostic
                
                last_ids = input_ids[torch.arange(B, device=input_ids.device), last_nonpad]  # [B]
                
                print(f"  last_nonpad (padding-side agnostic, first 8): {last_nonpad[:8].tolist()}")
                print(f"  last_ids (first 8): {last_ids[:8].tolist()}")
                
                # Verify that last_nonpad points to non-pad positions
                if pad_token_id is not None:
                    pad_check = (last_ids != pad_token_id).all().item()
                    print(f"  ✓ All last_ids != pad_token_id: {pad_check}")
                    if not pad_check:
                        pad_examples = (last_ids == pad_token_id).nonzero(as_tuple=True)[0].tolist()
                        print(f"  ⚠️ WARNING: Examples {pad_examples} have last_ids == pad_token_id!")
                else:
                    # If no pad_token_id, verify attention_mask at last_nonpad
                    mask_at_last = mask_long[torch.arange(B, device=prompt_mask.device), last_nonpad]
                    mask_check = (mask_at_last == 1).all().item()
                    print(f"  ✓ All attention_mask[last_nonpad] == 1: {mask_check}")
                    if not mask_check:
                        invalid_examples = (mask_at_last == 0).nonzero(as_tuple=True)[0].tolist()
                        print(f"  ⚠️ WARNING: Examples {invalid_examples} have attention_mask[last_nonpad] == 0!")
                
                # Hash hidden_states at last_nonpad for first few examples
                for i in range(min(3, B)):
                    hidden_at_last = hidden_states[i, last_nonpad[i]]  # [D]
                    hidden_hash = tensor_sha(hidden_at_last)
                    print(f"  Example {i}: last_nonpad={last_nonpad[i].item()}, hidden_hash={hidden_hash}")
            
            # Use router's encode_prompt which respects pooling setting and applies projection
            prompt_embeddings = router_model.encode_prompt(
                hidden_states=hidden_states,
                prompt_mask=prompt_mask,
                debug=False
            )  # [B, embedding_dim] - already projected to router embedding space
            
            # ====================================================================
            # DIAGNOSTICS: After computing prompt_embs
            # ====================================================================
            prompt_embs = prompt_embeddings  # Alias for clarity
            batch_size_actual = prompt_embs.shape[0]
            
            if debug:
                print(f"[DIAG] prompt_embs.shape: {prompt_embs.shape}")
                assert prompt_embs.ndim == 2, f"prompt_embs.ndim must be 2, got {prompt_embs.ndim}"
                assert prompt_embs.shape[0] == batch_size_actual, f"prompt_embs.shape[0] must be {batch_size_actual}, got {prompt_embs.shape[0]}"
                
                # Compute max_abs_diff_p01 and cos_p01 (if batch_size >= 2)
                if batch_size_actual >= 2:
                    max_abs_diff_p01 = (prompt_embs[0] - prompt_embs[1]).abs().max().item()
                    cos_p01 = F.cosine_similarity(prompt_embs[0:1], prompt_embs[1:2], dim=1).item()
                    print(f"[DIAG] max_abs_diff_p01: {max_abs_diff_p01:.6f}")
                    print(f"[DIAG] cos_p01: {cos_p01:.6f}")
                else:
                    print(f"[DIAG] batch_size < 2, skipping p01 comparisons")
                
                # Compute rowwise_max_abs
                rowwise_max_abs = (prompt_embs - prompt_embs[0]).abs().max(dim=1).values
                print(f"[DIAG] rowwise_max_abs.min: {rowwise_max_abs.min().item():.6f}")
                print(f"[DIAG] rowwise_max_abs.max: {rowwise_max_abs.max().item():.6f}")
                print(f"[DIAG] rowwise_max_abs.mean: {rowwise_max_abs.mean().item():.6f}")
            
            # D) Pipeline parity: compute scores using router_model.score_all() (preferred)
            # This uses the same internal pipeline as forward() and avoids manual scoring
            B = batch_size_actual
            result = router_model.score_all(
                hidden_states=hidden_states,
                prompt_mask=prompt_mask,
                debug=False,
                return_compute_metrics=True,
            )
            if isinstance(result, tuple):
                scores, compute_metrics = result
                # Accumulate compute metrics
                eval_compute_tracker.accumulate(compute_metrics, phase="evaluation")
            else:
                scores = result
            
            # Parity check: compare with manual scoring (non-blocking, dtype-aware)
            # Only run on first batch to avoid performance impact
            if not first_batch_processed:
                # Manual scoring for comparison
                all_model_embeddings = router_model.model_embeddings.weight  # [N, embedding_dim]
                prompt_embeddings_norm = F.normalize(prompt_embeddings, p=2, dim=-1)  # [B, embedding_dim]
                all_model_embeddings_norm = F.normalize(all_model_embeddings, p=2, dim=-1)  # [N, embedding_dim]
                cosine_sims = torch.matmul(prompt_embeddings_norm, all_model_embeddings_norm.T)  # [B, N]
                scores_manual = cosine_sims / router_model.tau  # [B, N]
                
                # Convert to float32 for comparison (avoids dtype-specific numerical issues)
                scores_manual_f = scores_manual.float()
                scores_fwd_f = scores.float()
                
                # Compute max_abs_diff on float32
                max_abs_diff = (scores_manual_f - scores_fwd_f).abs().max().item()
                
                # Top-1 matching
                top1_manual = scores_manual.argmax(dim=-1)  # [B]
                top1_fwd = scores.argmax(dim=-1)  # [B]
                top1_match = (top1_manual == top1_fwd).all().item()
                top1_match_rate = (top1_manual == top1_fwd).float().mean().item()
                
                # Determine dtype and set appropriate tolerances
                score_dtype = scores.dtype
                if score_dtype in (torch.bfloat16, torch.float16):
                    atol = 1e-2
                    rtol = 1e-2
                else:  # float32 or float64
                    atol = 1e-4
                    rtol = 1e-4
                
                # Check if scores are close using dtype-aware tolerances
                scores_close = torch.allclose(scores_manual_f, scores_fwd_f, atol=atol, rtol=rtol)
                
                # Print dtype information and parity check results
                if debug:
                    print(f"\n[Pipeline Parity Check] (first batch)")
                    print(f"  hidden_states.dtype: {hidden_states.dtype}")
                    print(f"  prompt_embeddings.dtype: {prompt_embeddings.dtype}")
                    print(f"  model_embeddings.dtype: {all_model_embeddings.dtype}")
                    print(f"  manual_scores.dtype: {scores_manual.dtype}")
                    print(f"  forward_scores.dtype: {scores.dtype}")
                    print(f"  max_abs_diff (float32): {max_abs_diff:.6e}")
                    print(f"  top1_match (all examples): {top1_match}")
                    print(f"  top1_match_rate: {top1_match_rate:.4f}")
                    print(f"  dtype-aware tolerance: atol={atol:.0e}, rtol={rtol:.0e}")
                    print(f"  torch.allclose (float32): {scores_close}")
                
                # Warn if top1 matches but allclose fails (expected numerical drift)
                if top1_match_rate == 1.0 and not scores_close:
                    print(f"\n⚠️  WARNING: Train/Eval scoring parity check")
                    print(f"   Top-1 predictions match perfectly, but scores differ within tolerance.")
                    print(f"   max_abs_diff: {max_abs_diff:.6e} (dtype: {score_dtype})")
                    print(f"   atol={atol:.0e}, rtol={rtol:.0e}")
                    print(f"   This is expected numerical drift in {score_dtype} (e.g., bfloat16 ULP).")
                    print(f"   Continuing evaluation...")
                elif not top1_match:
                    # Only raise error if top-1 predictions don't match (real mismatch)
                    error_msg = (
                        f"Train/Eval scoring mismatch detected!\n"
                        f"  max_abs_diff: {max_abs_diff:.6e} (dtype: {score_dtype})\n"
                        f"  top1_match: {top1_match}\n"
                        f"  top1_match_rate: {top1_match_rate:.4f}\n"
                        f"  torch.allclose (atol={atol:.0e}, rtol={rtol:.0e}): {scores_close}\n"
                        f"This indicates manual scoring diverges from router_model.forward()."
                    )
                    if debug:
                        print(f"\n[Pipeline Flags]")
                        router_model.forward(
                            hidden_states=hidden_states[:1],
                            prompt_mask=prompt_mask[:1],
                            candidate_indices=torch.arange(num_models, device=device).unsqueeze(0),
                            debug=False,
                            print_pipeline_flags=True,
                        )
                    raise RuntimeError(error_msg)
            
            # ====================================================================
            # DIAGNOSTICS: After computing logits
            # ====================================================================
            logits = scores  # Alias for clarity
            
            if debug:
                print(f"[DIAG] logits.shape: {logits.shape}")
                assert logits.ndim == 2, f"logits.ndim must be 2, got {logits.ndim}"
                assert logits.shape == (batch_size_actual, num_models), f"logits.shape must be ({batch_size_actual}, {num_models}), got {logits.shape}"
                
                # Compute max_abs_diff_l01 (if batch_size >= 2)
                if batch_size_actual >= 2:
                    max_abs_diff_l01 = (logits[0] - logits[1]).abs().max().item()
                    print(f"[DIAG] max_abs_diff_l01: {max_abs_diff_l01:.6f}")
                else:
                    print(f"[DIAG] batch_size < 2, skipping l01 comparison")
                
                # Compute rowwise_max_abs_logits (E: check for constant logits)
                rowwise_max_abs_logits = (logits - logits[0]).abs().max(dim=1).values
                print(f"[DIAG] rowwise_max_abs_logits.min: {rowwise_max_abs_logits.min().item():.6f}")
                print(f"[DIAG] rowwise_max_abs_logits.max: {rowwise_max_abs_logits.max().item():.6f}")
                print(f"[DIAG] rowwise_max_abs_logits.mean: {rowwise_max_abs_logits.mean().item():.6f}")
                
                # Check if all logits rows are identical (indexing/broadcasting bug)
                # Fix: compare against batch size, not tautology
                if batch_size_actual > 1 and rowwise_max_abs_logits.max().item() < 1e-6:
                    raise RuntimeError("All logits rows identical: indexing/broadcasting bug")
            
            # E) Root-cause diagnostics: hashes for first 10 examples
            if debug and example_counter < 10:
                num_to_collect = min(10 - example_counter, batch_size_actual)
                for i in range(num_to_collect):
                    # Get hashes
                    ids_hash = tensor_sha(inputs['input_ids'][i])
                    prompt_emb_hash = tensor_sha(prompt_embs[i])
                    score_vec_hash = tensor_sha(logits[i])
                    
                    if diagnostics is not None:
                        # Only collect ids_hashes if we haven't collected them yet (first batch)
                        if not first_batch_processed and len(diagnostics["ids_hashes"]) < 10:
                            diagnostics["ids_hashes"].append(ids_hash)
                        diagnostics["prompt_emb_hashes"].append(prompt_emb_hash)
                        diagnostics["score_vec_hashes"].append(score_vec_hash)
                    
                    if not first_batch_processed:
                        print(f"\n[Root-Cause Diagnostics] Example {example_counter + i}")
                        print(f"  ids_hash: {ids_hash}")
                        print(f"  prompt_emb_hash: {prompt_emb_hash}")
                        print(f"  score_vec_hash: {score_vec_hash}")
            
            # E) Explicit checks for constant embeddings/logits (first batch only, after collecting hashes)
            if debug and not first_batch_processed and batch_size_actual > 1:
                # Wait until we've collected hashes from this batch
                if diagnostics is not None and len(diagnostics["prompt_emb_hashes"]) >= min(2, batch_size_actual):
                    # Check: if ids hashes differ but prompt_emb_hash identical
                    if len(diagnostics["ids_hashes"]) > 1:
                        unique_ids = len(set(diagnostics["ids_hashes"]))
                        unique_prompt_embs = len(set(diagnostics["prompt_emb_hashes"]))
                        
                        if unique_ids > 1 and unique_prompt_embs == 1:
                            raise RuntimeError(
                                "Prompt embeddings constant across different tokenized inputs!\n"
                                f"  Unique ids_hashes: {unique_ids}, unique prompt_emb_hashes: {unique_prompt_embs}"
                            )
                        
                        # Check: if prompt_emb_hash differ but score_vec_hash identical
                        unique_scores = len(set(diagnostics["score_vec_hashes"]))
                        if unique_prompt_embs > 1 and unique_scores == 1:
                            raise RuntimeError(
                                "Scores/logits constant across examples (broadcasting/overwriting bug)!\n"
                                f"  Unique prompt_emb_hashes: {unique_prompt_embs}, unique score_vec_hashes: {unique_scores}"
                            )
                        
                        # Check: if ids hashes identical
                        if unique_ids == 1 and batch_size_actual > 1:
                            raise RuntimeError(
                                "Tokenized inputs identical across examples (dataset/prompt formatting bug)!\n"
                                f"  All {batch_size_actual} examples have same ids_hash: {diagnostics['ids_hashes'][0]}"
                            )
            
            first_batch_processed = True
            
            # Batch-level prompt embedding statistics (for first batch only, if debug enabled)
            if debug and start_idx == 0:
                print(f"\n[DEBUG] Computing batch-level statistics for first batch (start_idx={start_idx})")
                # Compute batch statistics
                prompt_norms = prompt_embeddings_norm.norm(p=2, dim=-1).cpu().float()  # [B]
                prompt_norm_mean = prompt_norms.mean().item()
                prompt_norm_std = prompt_norms.std().item()
                prompt_norm_min = prompt_norms.min().item()
                prompt_norm_max = prompt_norms.max().item()
                
                # Per-dimension std (mean of std across dimensions)
                prompt_emb_std_per_dim = prompt_embeddings_norm.float().std(dim=0).cpu()  # [embedding_dim]
                prompt_emb_std_per_dim_mean = prompt_emb_std_per_dim.mean().item()
                
                # Mean pairwise cosine similarity (use subsample if batch is large)
                batch_size_actual = prompt_embeddings_norm.size(0)
                subsample_size = min(32, batch_size_actual)  # Subsample for efficiency
                if subsample_size < batch_size_actual:
                    indices = torch.randperm(batch_size_actual, device=prompt_embeddings_norm.device)[:subsample_size]
                    prompt_emb_subset = prompt_embeddings_norm[indices]  # [subsample_size, embedding_dim]
                else:
                    prompt_emb_subset = prompt_embeddings_norm
                
                # Compute pairwise cosine similarities
                pairwise_cosine = torch.matmul(prompt_emb_subset, prompt_emb_subset.T)  # [subsample_size, subsample_size]
                # Extract upper triangle (excluding diagonal)
                mask = torch.triu(torch.ones(subsample_size, subsample_size, device=pairwise_cosine.device), diagonal=1).bool()
                mean_pairwise_cosine = pairwise_cosine[mask].mean().item()
                
                print(f"\n{'='*80}")
                print(f"[BATCH-LEVEL PROMPT EMBEDDING STATISTICS] (batch_size={batch_size_actual})")
                print(f"{'='*80}")
                print(f"  prompt_norm_mean: {prompt_norm_mean:.6f}")
                print(f"  prompt_norm_std:  {prompt_norm_std:.6f}")
                print(f"  prompt_norm_min:  {prompt_norm_min:.6f}")
                print(f"  prompt_norm_max:  {prompt_norm_max:.6f}")
                print(f"  prompt_emb_std_per_dim_mean: {prompt_emb_std_per_dim_mean:.6f}")
                print(f"  mean_pairwise_cosine(prompt_embs): {mean_pairwise_cosine:.6f} (computed on {subsample_size} examples)")
                print(f"{'='*80}\n")
            
            # Get top-k predictions (over all models)
            top_scores, top_indices = scores.topk(k=max(k_values), dim=-1)
            
            # If known_domain_mode is enabled, also compute filtered scores per example
            known_domain_scores = None
            known_domain_top_indices = None
            if known_domain_mode:
                # For each example, filter scores to only include models from the same domain
                batch_size_actual = scores.shape[0]
                known_domain_scores_list = []
                known_domain_top_indices_list = []
                
                for i in range(batch_size_actual):
                    domain = domains[i]
                    # Normalize domain
                    normalized_domain = normalize_domain(domain)
                    
                    # Get all models in this domain
                    domain_model_indices = model_registry.get_domain_models(normalized_domain)
                    
                    # Create a mask for domain models
                    domain_mask = torch.zeros(len(model_registry), dtype=torch.bool, device=device)
                    
                    if len(domain_model_indices) == 0:
                        # No models in domain - skip this example for known-domain metrics
                        # Set empty top indices
                        top_indices_filtered = torch.tensor([], dtype=torch.long, device=device)
                        filtered_scores = scores[i].clone()  # Keep original scores for consistency
                    else:
                        domain_mask[domain_model_indices] = True
                        
                        # Filter scores to only domain models
                        # Set scores for non-domain models to a very negative value
                        filtered_scores = scores[i].clone()
                        filtered_scores[~domain_mask] = float('-inf')
                        
                        # Get top-k over filtered scores
                        max_k = max(k_values)
                        if len(domain_model_indices) < max_k:
                            max_k = len(domain_model_indices)
                        
                        if max_k > 0:
                            top_scores_filtered, top_indices_filtered = filtered_scores.topk(k=max_k, dim=-1)
                        else:
                            # Edge case: no models in domain (shouldn't happen due to check above)
                            top_indices_filtered = torch.tensor([], dtype=torch.long, device=device)
                    
                    known_domain_scores_list.append(filtered_scores)
                    known_domain_top_indices_list.append(top_indices_filtered)
                
                # Stack into tensors (note: top_indices may have different lengths, so we keep as list)
                known_domain_scores = torch.stack(known_domain_scores_list)  # [B, num_models]
                known_domain_top_indices = known_domain_top_indices_list  # List of [k] tensors
            
            # Optionally compute candidate-set logits/metrics on the SAME batch
            cand_logits: Optional[torch.Tensor] = None
            cand_indices_tensor: Optional[torch.Tensor] = None
            if candidate_K is not None and candidate_K > 0:
                # Build candidate sets using the same strategy as training
                # Defaults match router_config.json if available
                K_total = candidate_K
                K_semantic = router_config.get("K_semantic", K_total - 1) if router_config else min(48, K_total - 1)
                K_far = router_config.get("K_far", 0) if router_config else 0
                K_hard = router_config.get("K_hard", 0) if router_config else 0

                cand_builder = CandidateSetBuilder(
                    registry=model_registry,
                    K_total=K_total,
                    K_semantic=K_semantic,
                    K_far=K_far,
                    K_hard=K_hard,
                )

                # Build candidate sets only for examples whose gold model is in the registry
                cand_y_indices: List[int] = []
                cand_domains: List[str] = []
                cand_valid_indices: List[int] = []
                for i, (label_idx, domain, valid) in enumerate(zip(label_indices, domains, valid_mask)):
                    if not valid or label_idx < 0:
                        continue
                    cand_y_indices.append(label_idx)
                    cand_domains.append(domain)
                    cand_valid_indices.append(i)

                if cand_y_indices:
                    cand_list = cand_builder.build_batch(
                        y_indices=cand_y_indices,
                        domains=cand_domains,
                        hard_negative_cache=None,
                    )
                    cand_indices_tensor = torch.tensor(
                        cand_list, dtype=torch.long, device=device
                    )  # [B_cand, K]

                    # Sub-select hidden_states / prompt_mask for candidate examples
                    hs_cand = hidden_states[cand_valid_indices]
                    pm_cand = prompt_mask[cand_valid_indices]

                    result = router_model(
                        hidden_states=hs_cand,
                        prompt_mask=pm_cand,
                        candidate_indices=cand_indices_tensor,
                        debug=False,
                        return_compute_metrics=True,
                    )
                    if isinstance(result, tuple):
                        cand_logits, compute_metrics = result
                        eval_compute_tracker.accumulate(compute_metrics, phase="evaluation")
                    else:
                        cand_logits = result

                    # Candidate-set top1 accuracy: gold always at index 0 by construction
                    cand_pred = cand_logits.argmax(dim=-1)  # [B_cand]
                    cand_top1_correct += (cand_pred == 0).sum().item()
                    cand_total += cand_logits.shape[0]
                    
                    # Track score margins for candidate sets (comparable to training)
                    if cand_logits.shape[0] > 0:
                        positive_scores = cand_logits[:, 0]  # [B_cand] - gold at index 0
                        negative_scores = cand_logits[:, 1:].mean(dim=-1)  # [B_cand] - average of negatives
                        margins = (positive_scores - negative_scores).cpu().tolist()
                        candidate_score_margins.extend(margins)

            # Compute metrics for valid examples (global acc_all)
            for i, (label_idx, pred_indices, domain, valid, gold_in_reg) in enumerate(
                zip(label_indices, top_indices, domains, valid_mask, gold_in_registry_flags)
            ):
                example_id = start_idx + i
                total_examples_processed += 1
                
                # ====================================================================
                # PER-EXAMPLE DIAGNOSTICS: Use p_i and logits_i, print hashes
                # ====================================================================
                p_i = prompt_embs[i]  # [embedding_dim]
                logits_i = logits[i]  # [num_models]
                
                # Compute hashes (using tensor_sha helper)
                p_hash = tensor_sha(p_i)
                l_hash = tensor_sha(logits_i)
                if debug:
                    print(f"[DIAG] i={i} p_hash={p_hash} l_hash={l_hash}")
                
                if gold_in_reg:
                    gold_in_registry_count += 1
                
                if not valid:
                    continue
                
                # Get score vector for this example
                # Convert to float32 before numpy conversion (handles BFloat16)
                score_vec = logits_i.cpu().float().numpy()  # [N]
                
                # Track score margins for all-models evaluation
                gold_score = logits_i[label_idx].item()
                # Get average score of all other models
                other_scores = torch.cat([logits_i[:label_idx], logits_i[label_idx+1:]]).mean().item()
                all_models_score_margins.append(gold_score - other_scores)
                
                # Compute entropy of softmax distribution
                # Convert scores to probabilities using softmax
                probs = F.softmax(logits_i, dim=-1).cpu().float().numpy()
                # Compute entropy: -sum(p * log(p))
                entropy = -np.sum(probs * np.log(probs + 1e-10))
                all_entropies.append(entropy)
                
                # Get top-10 predictions for debug output
                top10_scores, top10_indices = logits_i.topk(k=min(10, len(score_vec)), dim=-1)
                top10_predictions = [
                    (rank + 1, model_registry.idx2model[idx.item()], score.item())
                    for rank, (idx, score) in enumerate(zip(top10_indices, top10_scores))
                ]
                
                # Find gold rank among ALL models
                # Sort all scores to find rank
                sorted_indices = torch.argsort(logits_i, descending=True)
                gold_rank = (sorted_indices == label_idx).nonzero(as_tuple=True)[0].item() + 1
                gold_score = logits_i[label_idx].item()
                
                # Enhanced debug output for first N=10 examples (or all if fewer)
                if debug and example_counter < 10:
                    if example_counter == 0:
                        print(f"\n[DEBUG] Entering diagnostic output block (debug={debug}, example_counter={example_counter})")
                    gold_model_id = labels[i]
                    
                    # ====================================================================
                    # A) Verify gold indexing correctness
                    # ====================================================================
                    gold_idx_from_registry = model_registry.model2idx.get(gold_model_id, -1)
                    gold_model_id_from_idx = model_registry.idx2model.get(label_idx, "NOT_FOUND")
                    
                    # ====================================================================
                    # B) Logit range + probability sharpness
                    # ====================================================================
                    # Compute full logits over ALL models (scores are already computed with temperature)
                    logits_np = logits_i.cpu().float().numpy()  # [num_models]
                    
                    # Compute probabilities from logits (using softmax)
                    probs_tensor = F.softmax(logits_i, dim=-1)
                    probs_np = probs_tensor.cpu().float().numpy()
                    
                    # Get top-1
                    top1_idx = logits_i.argmax().item()
                    top1_model_id = model_registry.idx2model[top1_idx]
                    top1_prob = probs_np[top1_idx]
                    top1_logit = logits_np[top1_idx]
                    
                    # Gold info
                    gold_logit = logits_np[label_idx]
                    gold_prob = probs_np[label_idx]
                    
                    # Get top-10 with probabilities
                    top10_logits, top10_indices_tensor = logits_i.topk(k=min(10, len(logits_i)))
                    top10_list = []
                    for rank, (idx, logit_val) in enumerate(zip(top10_indices_tensor, top10_logits)):
                        model_id = model_registry.idx2model[idx.item()]
                        prob_val = probs_np[idx.item()]
                        top10_list.append((rank + 1, model_id, idx.item(), logit_val.item(), prob_val))
                    
                    # ====================================================================
                    # C) Prompt embedding health checks
                    # ====================================================================
                    prompt_emb = p_i  # [embedding_dim] (before normalization)
                    prompt_emb_norm = prompt_emb.norm(p=2).item()
                    prompt_emb_mean = prompt_emb.float().mean().item()
                    prompt_emb_std = prompt_emb.float().std().item()
                    
                    # ====================================================================
                    # D) Model embedding health checks
                    # ====================================================================
                    # Model embedding norms (over all models, before normalization)
                    model_emb_norms = all_model_embeddings.norm(p=2, dim=-1)  # [num_models]
                    model_emb_norm_mean = model_emb_norms.float().mean().item()
                    model_emb_norm_std = model_emb_norms.float().std().item()
                    model_emb_norm_min = model_emb_norms.float().min().item()
                    model_emb_norm_max = model_emb_norms.float().max().item()
                    
                    # Dot products (after normalization, matching training)
                    prompt_emb_normalized = F.normalize(p_i, p=2, dim=-1)  # [embedding_dim]
                    dot_products = (prompt_emb_normalized @ all_model_embeddings_norm.T).cpu().float()  # [num_models]
                    dot_products_np = dot_products.numpy()
                    
                    # ====================================================================
                    # Print comprehensive diagnostics
                    # ====================================================================
                    print(f"\n{'='*80}")
                    print(f"[DIAGNOSTIC] Example {example_counter + 1} (example_id={example_id})")
                    print(f"{'='*80}")
                    
                    # A) Gold indexing correctness
                    print(f"\n[A] Gold Indexing Correctness")
                    print(f"  gold_model_id: {gold_model_id}")
                    print(f"  gold_idx (from ModelRegistry.model2idx): {gold_idx_from_registry}")
                    print(f"  gold_idx (from label_indices): {label_idx}")
                    print(f"  ModelRegistry.idx2model[{label_idx}]: {gold_model_id_from_idx}")
                    if gold_idx_from_registry == label_idx and gold_model_id_from_idx == gold_model_id:
                        print(f"  ✓ Gold indexing is CORRECT")
                    else:
                        print(f"  ⚠️  Gold indexing MISMATCH!")
                        if gold_idx_from_registry != label_idx:
                            print(f"     model2idx[{gold_model_id}] = {gold_idx_from_registry} != label_idx {label_idx}")
                        if gold_model_id_from_idx != gold_model_id:
                            print(f"     idx2model[{label_idx}] = {gold_model_id_from_idx} != gold_model_id {gold_model_id}")
                    
                    # B) Logit range + probability sharpness
                    print(f"\n[B] Logit Range + Probability Sharpness")
                    print(f"  logits.mean(): {np.mean(logits_np):.6f}")
                    print(f"  logits.std():  {np.std(logits_np):.6f}")
                    print(f"  logits.min():  {np.min(logits_np):.6f}")
                    print(f"  logits.max():  {np.max(logits_np):.6f}")
                    print(f"  logits.range:  {np.max(logits_np) - np.min(logits_np):.6f}")
                    print(f"  top1_logit:    {top1_logit:.6f}")
                    print(f"  top1_prob:     {top1_prob:.6f}")
                    print(f"  gold_logit:    {gold_logit:.6f}")
                    print(f"  gold_rank:     {gold_rank} (among {len(model_registry)} models)")
                    print(f"  (logits.max - logits.min): {np.max(logits_np) - np.min(logits_np):.6f}")
                    print(f"  (top1_logit - gold_logit):  {top1_logit - gold_logit:.6f}")
                    
                    print(f"\n  [Top-10 Predictions] (rank, model_id, idx, logit, prob)")
                    for rank, model_id, idx, logit_val, prob_val in top10_list:
                        marker = " <-- GOLD" if model_id == gold_model_id else ""
                        print(f"    {rank:2d}. {model_id:40s} idx={idx:4d} logit={logit_val:8.6f} prob={prob_val:.6f}{marker}")
                    
                    # C) Prompt embedding health checks
                    print(f"\n[C] Prompt Embedding Health Checks")
                    print(f"  p.norm():      {prompt_emb_norm:.6f}")
                    print(f"  p.mean():      {prompt_emb_mean:.6f}")
                    print(f"  p.std():       {prompt_emb_std:.6f}")
                    
                    # D) Model embedding health checks
                    print(f"\n[D] Model Embedding Health Checks")
                    print(f"  model_embeddings row-norm stats:")
                    print(f"    mean: {model_emb_norm_mean:.6f}")
                    print(f"    std:  {model_emb_norm_std:.6f}")
                    print(f"    min:  {model_emb_norm_min:.6f}")
                    print(f"    max:  {model_emb_norm_max:.6f}")
                    print(f"  dot-product stats (p @ model_embeddings.T, after normalization):")
                    print(f"    mean: {np.mean(dot_products_np):.6f}")
                    print(f"    std:  {np.std(dot_products_np):.6f}")
                    print(f"    min:  {np.min(dot_products_np):.6f}")
                    print(f"    max:  {np.max(dot_products_np):.6f}")
                    
                    # E) Pipeline parity (will be printed once at the start)
                    if example_counter == 0:
                        print(f"\n[E] Pipeline Parity (Train vs Eval)")
                        router_model.forward(
                            hidden_states=hidden_states[:1],
                            prompt_mask=prompt_mask[:1],
                            candidate_indices=torch.zeros(1, 1, dtype=torch.long, device=hidden_states.device),
                            debug=False,
                            print_pipeline_flags=True
                        )
                    
                    print(f"\n  Entropy: {entropy:.6f}")
                    print(f"{'='*80}")
                    example_counter += 1
                
                # Top-1 prediction
                pred_idx = pred_indices[0].item()
                all_predictions.append(pred_idx)
                all_labels.append(label_idx)
                all_domains.append(domain)
                all_scores.append(logits_i.cpu().tolist())
                
                # Top-k accuracy (over all models)
                for k in k_values:
                    if label_idx in pred_indices[:k].tolist():
                        topk_correct[k] += 1
                
                # Known-domain mode: compute accuracy over filtered domain models
                if known_domain_mode and known_domain_top_indices is not None:
                    # Get filtered top indices for this example
                    filtered_top_indices = known_domain_top_indices[i]
                    
                    # Check if gold model is in the domain
                    normalized_domain = normalize_domain(domain)
                    domain_model_indices = model_registry.get_domain_models(normalized_domain)
                    
                    # Also check what domain the gold model actually has in the registry
                    gold_model_domain_in_registry = model_registry.metadata.get(label_idx, {}).get('domain', 'unknown')
                    gold_model_domain_normalized = normalize_domain(gold_model_domain_in_registry)
                    
                    # Verify domain match
                    domain_matches = (normalized_domain == gold_model_domain_normalized)
                    
                    if label_idx in domain_model_indices:
                        known_domain_total += 1
                        
                        # Compute gold model's rank within domain
                        # Get filtered scores for this example
                        filtered_scores_i = known_domain_scores[i] if known_domain_scores is not None else None
                        if filtered_scores_i is not None:
                            # Get scores for domain models only (convert list to tensor for indexing)
                            domain_indices_tensor = torch.tensor(domain_model_indices, dtype=torch.long, device=device)
                            domain_scores = filtered_scores_i[domain_indices_tensor]  # [num_domain_models]
                            gold_score_in_domain = filtered_scores_i[label_idx].item()
                            
                            # Count how many domain models have higher scores
                            gold_rank_in_domain = (domain_scores > gold_score_in_domain).sum().item() + 1
                            known_domain_gold_ranks.append(gold_rank_in_domain)
                        
                        # Check top-k accuracy over filtered set
                        for k in k_values:
                            if len(filtered_top_indices) >= k:
                                if label_idx in filtered_top_indices[:k].tolist():
                                    known_domain_topk_correct[k] += 1
                            elif len(filtered_top_indices) > 0:
                                # If fewer than k models in domain, check if gold is in the available ones
                                if label_idx in filtered_top_indices.tolist():
                                    known_domain_topk_correct[k] += 1
                    elif not domain_matches and debug:
                        # Debug: log domain mismatches
                        if example_counter < 5:  # Only log first few
                            print(f"\n[KNOWN DOMAIN DEBUG] Example {example_counter}: Domain mismatch")
                            print(f"  Test data domain: '{domain}' -> normalized: '{normalized_domain}'")
                            print(f"  Gold model domain in registry: '{gold_model_domain_in_registry}' -> normalized: '{gold_model_domain_normalized}'")
                            print(f"  Gold model idx: {label_idx}, model name: {model_registry.idx2model[label_idx]}")
                    # Note: If gold model is not in domain, we skip it for known-domain metrics
                
                # Known-domain candidate-set mode: build candidate sets from within domain only
                # This is more comparable to training (candidate sets) but restricted to domain
                if known_domain_mode and candidate_K is not None and candidate_K > 0 and valid:
                    normalized_domain = normalize_domain(domain)
                    domain_model_indices = model_registry.get_domain_models(normalized_domain)
                    
                    # Only process if gold model is in domain and domain has enough models for candidate set
                    # Need at least 2 models (1 positive + 1 negative minimum)
                    if label_idx in domain_model_indices and len(domain_model_indices) >= 2:
                        # Build candidate set from domain models only
                        # Use same K_total as regular candidate sets, but sample only from domain
                        K_domain = min(candidate_K, len(domain_model_indices))
                        
                        # Skip if domain is too small (need at least K_domain models)
                        if K_domain < 2:
                            continue
                        
                        # Create a domain-only candidate builder (CandidateSetBuilder already imported at top)
                        domain_cand_builder = CandidateSetBuilder(
                            registry=model_registry,
                            K_total=K_domain,
                            K_semantic=K_domain - 1,  # All from same domain
                            K_far=0,  # No far negatives (all in domain)
                            K_hard=0,  # No hard negatives for now
                            semantic_pool_mode="domain_only",  # Only exact domain
                        )
                        
                        # Build candidate set (positive at index 0)
                        domain_candidates = domain_cand_builder.build(
                            y_idx=label_idx,
                            domain=normalized_domain,
                            hard_negative_cache=None,
                        )
                        
                        # Verify positive is at index 0
                        if domain_candidates[0] != label_idx:
                            if debug and known_domain_cand_total < 5:
                                print(f"\n[KNOWN DOMAIN CAND DEBUG] Example {example_counter}: Positive not at index 0!")
                                print(f"  domain_candidates[0] = {domain_candidates[0]}, label_idx = {label_idx}")
                            continue
                        
                        # Compute logits over domain candidate set
                        domain_cand_tensor = torch.tensor([domain_candidates], dtype=torch.long, device=device)  # [1, K_domain]
                        result = router_model(
                            hidden_states=hidden_states[i:i+1],
                            prompt_mask=prompt_mask[i:i+1],
                            candidate_indices=domain_cand_tensor,
                            debug=False,
                            return_compute_metrics=True,
                        )
                        if isinstance(result, tuple):
                            domain_cand_logits, compute_metrics = result
                            eval_compute_tracker.accumulate(compute_metrics, phase="evaluation")
                        else:
                            domain_cand_logits = result
                        
                        # Check if gold (at index 0) is ranked #1
                        domain_cand_pred = domain_cand_logits.argmax(dim=-1).item()
                        if domain_cand_pred == 0:
                            known_domain_cand_correct += 1
                        known_domain_cand_total += 1
                
                # Hierarchical evaluation: two-stage inference with Top-N Domain Hierarchical Rerank
                if hierarchical_eval and valid:
                    # Debug: Print registry coverage summary for first example only
                    if debug and i == 0 and start_idx == 0:
                        print(f"\n{'='*80}")
                        print(f"[Hierarchical Evaluation Registry Coverage]")
                        print(f"{'='*80}")
                        print(f"  Registry contains:")
                        print(f"    - {len(model_registry)} total models")
                        print(f"    - {len(model_registry.domain2models)} domains")
                        print(f"    - {len(group_to_model_indices)} groups in group_to_model_indices")
                        
                        # Count test examples by domain
                        test_domains_count = defaultdict(int)
                        test_domains_in_registry = set()
                        for ex in test_data:
                            domain = ex.get('domain', 'unknown')
                            test_domains_count[domain] += 1
                            if domain in model_registry.domain2models:
                                test_domains_in_registry.add(domain)
                        
                        print(f"  Test data contains:")
                        print(f"    - {len(test_data)} total examples")
                        print(f"    - {len(test_domains_count)} unique domains")
                        print(f"    - {len(test_domains_in_registry)} domains that are in registry")
                        if len(test_domains_count) != len(test_domains_in_registry):
                            missing_domains = set(test_domains_count.keys()) - test_domains_in_registry
                            print(f"    ⚠️  WARNING: {len(missing_domains)} test domains are NOT in registry: {list(missing_domains)[:5]}")
                        print(f"{'='*80}\n")
                    
                    # Step 1: Compute domain scores using selected aggregation strategy
                    domain_scores = compute_domain_scores(
                        logits_all=logits_i,
                        domain_to_indices_tensor=group_to_model_indices_tensor,
                        mode=hier_domain_score_mode,
                        topk=hier_domain_topk,
                        alpha=hier_domain_hybrid_alpha,
                    )
                    
                    # Step 2: Identify top-N domains (N=1,2,3) by domain_score
                    if domain_scores:
                        # Sort domains by score (descending)
                        # Performance: Use torch.topk on tensor instead of sorting with .item() to avoid CPU sync
                        domain_names = list(domain_scores.keys())
                        domain_scores_tensor = torch.stack([domain_scores[name] for name in domain_names])
                        # Get top-3 domains (for N=1,2,3)
                        k = min(3, len(domain_scores_tensor))
                        topk_scores, topk_indices = torch.topk(domain_scores_tensor, k=k, largest=True)
                        top_domains = [domain_names[idx.item()] for idx in topk_indices]
                        
                        # Step 3: Check if gold model is in registry (for denominator tracking)
                        # Note: gold_in_reg is True when model is in registry (label_idx >= 0 is implied)
                        gold_in_registry = (label_idx >= 0 and gold_in_reg)
                        if not gold_in_registry:
                            hier_missing_gold_count += 1
                        
                        # Step 4: Get gold group and predicted group (for legacy metrics)
                        # Normalize top_domains to ensure consistency (group_to_model_indices keys are normalized)
                        top_domains_normalized = [normalize_domain(d) for d in top_domains] if top_domains else []
                        predicted_group = top_domains_normalized[0] if top_domains_normalized else "unknown"
                        
                        # Compute legacy group accuracy (top-1 domain) - only when gold in registry
                        if gold_in_registry:
                            gold_group = model_idx_to_group.get(label_idx, "unknown")
                            gold_group_normalized = normalize_domain(gold_group) if gold_group != "unknown" else "unknown"
                            hier_group_total += 1
                            if predicted_group == gold_group_normalized:
                                hier_group_correct += 1
                        
                        # Step 5: For each N in {1,2,3}, compute hierarchical rerank metrics
                        # Build restricted sets once and reuse (fixes issue #2: avoid computing twice)
                        restricted_indices_atN = {}
                        for N in [1, 2, 3]:
                            restricted_indices = build_restricted_set(
                                top_domains=top_domains_normalized,
                                domain_to_indices=group_to_model_indices,
                                N=N,
                                num_models=len(logits_i),
                            )
                            restricted_indices_atN[N] = restricted_indices
                        
                        for N in [1, 2, 3]:
                            restricted_indices = restricted_indices_atN[N]
                            # Track restricted set size (fixes issue #1: use same fallback logic)
                            restricted_size = len(restricted_indices)
                            hier_restricted_sizes_atN[N].append(restricted_size)
                            
                            # Check if gold domain is in predicted top-N domains (for domain recall)
                            # Note: Both model_idx_to_group and top_domains_normalized use normalized domains
                            gold_domain_in_topN = False
                            if gold_in_registry:
                                gold_group = model_idx_to_group.get(label_idx, "unknown")
                                gold_group_normalized = normalize_domain(gold_group) if gold_group != "unknown" else "unknown"
                                # top_domains_normalized are already normalized
                                gold_domain_in_topN = (gold_group_normalized in top_domains_normalized[:N])
                                if gold_domain_in_topN:
                                    hier_domain_included_atN[N] += 1
                            
                            # Compute hierarchical rerank metrics (using pre-built restricted set)
                            top1_correct, top5_correct, top10_correct = hierarchical_rerank_topN(
                                logits_all=logits_i,
                                restricted_indices=restricted_indices,
                                gold_model_idx=label_idx,
                                device=device,
                            )
                            
                            # Check if gold model is in restricted set
                            gold_in_restricted = (label_idx in restricted_indices) if gold_in_registry else False
                            
                            # E2E metrics: count correct predictions for ALL examples where gold is in registry
                            # E2E = domain_recall * conditional_accuracy
                            # Requires: (1) gold domain in top-N AND (2) gold model top-1 in restricted set
                            # If gold domain not in top-N, prediction is incorrect (count as 0)
                            if gold_in_registry:
                                # E2E is correct only when domain is selected AND model is top-1
                                if gold_domain_in_topN and gold_in_restricted and top1_correct:
                                    hier_model_top1_e2e_atN[N] += 1
                                # Note: If gold domain not in top-N OR gold not top-1, E2E count stays 0 (correctly counts as incorrect)
                            
                            # Conditional metrics: only count when gold DOMAIN is in top-N (not just gold model in restricted set)
                            # This ensures: E2E = domain_recall * conditional_accuracy
                            # Note: gold_domain_in_topN implies gold_in_restricted (if domain mapping is correct),
                            # but we check gold_domain_in_topN explicitly to ensure correct conditional denominator
                            if gold_domain_in_topN and gold_in_registry:
                                hier_model_total_atN[N] += 1  # Count of examples where gold domain in top-N
                                if gold_in_restricted and top1_correct:
                                    hier_model_top1_atN[N] += 1  # Legacy: conditional top1
                                    hier_model_top1_cond_atN[N] += 1  # New: explicit conditional top1
                                if gold_in_restricted and top5_correct:
                                    hier_model_top5_atN[N] += 1
                                if gold_in_restricted and top10_correct:
                                    hier_model_top10_atN[N] += 1
                            
                            # Debug: Track why gold model is not in restricted set
                            if debug and gold_in_registry and not gold_in_restricted:
                                gold_group = model_idx_to_group.get(label_idx, "unknown")
                                gold_group_normalized = normalize_domain(gold_group) if gold_group != "unknown" else "unknown"
                                if gold_group_normalized not in group_to_model_indices:
                                    hier_gold_domain_missing_from_registry += 1
                                elif gold_group_normalized not in top_domains_normalized[:N]:
                                    hier_gold_not_in_predicted_domain[N] += 1
                        
                        # Legacy: Restrict to top-1 domain and re-rank (for backward compatibility)
                        # NOTE: These metrics are CONDITIONAL (only when gold model is in restricted set)
                        # They should NOT be compared to E2E metrics which use correct denominators
                        legacy_restricted_indices = group_to_model_indices.get(predicted_group, [])
                        if legacy_restricted_indices:
                            # Convert to tensor for indexing
                            legacy_restricted_indices_tensor = torch.tensor(legacy_restricted_indices, dtype=torch.long, device=device)
                            # Get logits for models in restricted set
                            legacy_restricted_logits = logits_i[legacy_restricted_indices_tensor]  # [num_models_in_restricted]
                            
                            # Track restricted set size (legacy)
                            hier_restricted_sizes.append(len(legacy_restricted_indices))
                            
                            # Find gold model's position in restricted set
                            # NOTE: This is conditional - only counts when gold is in restricted set
                            if label_idx in legacy_restricted_indices and gold_in_registry:
                                # Get index of gold model in restricted_indices list
                                gold_pos_in_restricted = legacy_restricted_indices.index(label_idx)
                                
                                # Compute top-k accuracy within restricted set (legacy - CONDITIONAL)
                                hier_model_total += 1
                                sorted_restricted = torch.argsort(legacy_restricted_logits, descending=True)
                                gold_rank_restricted = (sorted_restricted == gold_pos_in_restricted).nonzero(as_tuple=True)[0].item() + 1
                                
                                # Check top-k accuracy (legacy - CONDITIONAL)
                                for k in k_values:
                                    if gold_rank_restricted <= k:
                                        hier_model_topk_correct[k] += 1
                                
                                # Compute end-to-end accuracy: group correct AND model correct (top-1) (legacy - CONDITIONAL)
                                # NOTE: This is NOT true E2E - it's conditional on gold being in restricted set
                                if gold_in_registry:
                                    gold_group = model_idx_to_group.get(label_idx, "unknown")
                                    gold_group_normalized = normalize_domain(gold_group) if gold_group != "unknown" else "unknown"
                                    if predicted_group == gold_group_normalized and gold_rank_restricted == 1:
                                        hier_e2e_top1_correct += 1
                        else:
                            # Predicted group is empty (shouldn't happen, but handle gracefully)
                            hier_restricted_sizes.append(0)
                
                # Domain accuracy (is predicted model in same domain as ground truth?)
                pred_model_name = model_registry.idx2model[pred_idx]
                true_model_name = model_registry.idx2model[label_idx]
                # Metadata is keyed by index, not model name
                pred_domain = model_registry.metadata.get(pred_idx, {}).get('domain', 'unknown')
                true_domain = model_registry.metadata.get(label_idx, {}).get('domain', 'unknown')
                
                # Normalize domains before comparison (consistent with other domain comparisons in this file)
                pred_domain_normalized = normalize_domain(pred_domain)
                true_domain_normalized = normalize_domain(true_domain)
                
                # Use normalized domain as key for consistent counting
                domain_total[true_domain_normalized] += 1
                if pred_domain_normalized == true_domain_normalized:
                    domain_correct[true_domain_normalized] += 1
                
                # Model family accuracy (is predicted model in same family as ground truth?)
                # Get model names first
                pred_model_name = model_registry.idx2model[pred_idx]
                
                # Try to get family from model_family_lookup FIRST (most reliable source)
                # This is the primary source since it comes from the original dataset
                pred_family = None
                true_family = None
                
                if model_family_lookup is not None:
                    # Try exact match first
                    pred_family = model_family_lookup.get(pred_model_name)
                    true_family = model_family_lookup.get(true_model_name)
                    
                    # Also try normalized names (case-insensitive lookup) if exact match failed
                    if pred_family is None:
                        from .model_selection_carve.model_registry import normalize_model_name
                        pred_family = model_family_lookup.get(normalize_model_name(pred_model_name))
                    if true_family is None:
                        from .model_selection_carve.model_registry import normalize_model_name
                        true_family = model_family_lookup.get(normalize_model_name(true_model_name))
                
                # Fallback: try registry metadata (might not have family info if created without family_key)
                if pred_family is None:
                    pred_metadata = model_registry.metadata.get(pred_idx, {})
                    pred_family = pred_metadata.get('family') or pred_metadata.get('model_family')
                
                if true_family is None:
                    true_metadata = model_registry.metadata.get(label_idx, {})
                    true_family = true_metadata.get('family') or true_metadata.get('model_family')
                
                # Fallback: try to get from batch/test_data if not found yet
                if pred_family is None:
                    # Look up predicted model's family from test data by matching model name
                    for test_ex in test_data:
                        if test_ex.get('model_name') == pred_model_name:
                            pred_family = test_ex.get('model_family')
                            break
                
                if true_family is None:
                    # Get true family from current batch example (more efficient than searching all test_data)
                    if i < len(batch):
                        true_family = batch[i].get('model_family')
                    # Fallback: search test_data if not in batch
                    if true_family is None:
                        # Use example_id to index into test_data directly (more efficient)
                        if example_id < len(test_data):
                            true_family = test_data[example_id].get('model_family')
                        # Last resort: search by model name
                        if true_family is None:
                            for test_ex in test_data:
                                if test_ex.get('model_name') == true_model_name:
                                    true_family = test_ex.get('model_family')
                                    break
                
                # Only count if both families are available (not None and not empty)
                if pred_family and true_family and pred_family.strip() and true_family.strip():
                    family_total[true_family] += 1
                    if pred_family == true_family:
                        family_correct[true_family] += 1
                elif debug and example_id < 10:
                    # Debug: print first few examples where family is missing
                    print(f"[FAMILY DEBUG] Example {example_id}: pred_family={pred_family}, true_family={true_family}")
                    print(f"  pred_model={pred_model_name}, true_model={true_model_name}")
                    print(f"  pred_idx={pred_idx}, label_idx={label_idx}")
                    print(f"  pred_metadata={pred_metadata}")
                    print(f"  true_metadata={true_metadata}")
                    if model_family_lookup:
                        print(f"  pred_in_lookup={pred_model_name in model_family_lookup}, true_in_lookup={true_model_name in model_family_lookup}")
                        if pred_model_name in model_family_lookup:
                            print(f"  pred_family_from_lookup={model_family_lookup[pred_model_name]}")
                        if true_model_name in model_family_lookup:
                            print(f"  true_family_from_lookup={model_family_lookup[true_model_name]}")
                
                # Track forgetting metrics for old models/domains/families
                # Old items are those from earlier experiences (identified by M_old if available)
                # For now, we'll compute this after the loop using exp2 diagnostics if available
                # Otherwise, we'll need to identify old items differently
            
            if (end_idx) % 100 == 0 or end_idx == num_examples:
                print(f"  Processed {end_idx}/{num_examples} examples...")
    
    # Compute overall metrics
    num_valid = len(all_predictions)
    
    topk_accuracy = {}
    for k in k_values:
        topk_accuracy[f"top{k}_accuracy"] = topk_correct[k] / num_valid if num_valid > 0 else 0.0
    
    # Overall domain accuracy
    total_domain_correct = sum(domain_correct.values())
    total_domain_count = sum(domain_total.values())
    overall_domain_accuracy = total_domain_correct / total_domain_count if total_domain_count > 0 else 0.0
    
    # Per-domain accuracy
    per_domain_accuracy = {}
    for domain in domain_total:
        acc = domain_correct[domain] / domain_total[domain] if domain_total[domain] > 0 else 0.0
        per_domain_accuracy[f"accuracy_domain_{domain}"] = acc
    
    # Overall model family accuracy
    # NOTE: Use total_family_count as denominator (only examples with family info)
    # This ensures family accuracy is calculated on the same set of examples where
    # family information is available, making it comparable to model accuracy
    total_family_correct = sum(family_correct.values())
    total_family_count = sum(family_total.values())
    # Use total_family_count (examples with family info) as denominator
    # This is the correct denominator since family_correct only counts examples with family info
    overall_family_accuracy = total_family_correct / total_family_count if total_family_count > 0 else 0.0
    # Also compute family accuracy using num_valid as denominator for comparison
    # Examples without family info are treated as incorrect (0 in numerator)
    family_accuracy_all_examples = total_family_correct / num_valid if num_valid > 0 else 0.0
    
    # Per-family accuracy
    per_family_accuracy = {}
    for family in family_total:
        acc = family_correct[family] / family_total[family] if family_total[family] > 0 else 0.0
        per_family_accuracy[f"accuracy_family_{family}"] = acc
    
    # Debug: Print family accuracy statistics if count is 0 or significantly different from num_valid
    # This helps diagnose denominator mismatch issues
    if total_family_count == 0:
        print(f"\n{'='*80}")
        print(f"[Model Family Accuracy Debug]")
        print(f"{'='*80}")
        print(f"  ⚠️  WARNING: No examples had both pred_family and true_family available!")
        print(f"  total_family_count: {total_family_count} (out of {num_valid} valid examples)")
        print(f"  total_family_correct: {total_family_correct}")
        print(f"  overall_family_accuracy: {overall_family_accuracy:.4f} (uses total_family_count={total_family_count} as denominator)")
        print(f"  family_accuracy_all_examples: {family_accuracy_all_examples:.4f} (uses num_valid={num_valid} as denominator)")
    elif total_family_count < num_valid * 0.5:
        # Warn if less than 50% of examples have family info
        print(f"\n{'='*80}")
        print(f"[Model Family Accuracy Warning]")
        print(f"{'='*80}")
        print(f"  ⚠️  WARNING: Only {total_family_count}/{num_valid} examples have family info ({100*total_family_count/num_valid:.1f}%)")
        print(f"  model_family_accuracy: {overall_family_accuracy:.4f} (uses total_family_count={total_family_count} as denominator)")
        print(f"  model_family_accuracy_all_examples: {family_accuracy_all_examples:.4f} (uses num_valid={num_valid} as denominator)")
        print(f"  Note: model_family_accuracy now uses total_family_count as denominator (only examples with family info)")
        if model_family_lookup:
            print(f"  model_family_lookup size: {len(model_family_lookup)}")
            # Sample a few entries
            sample_models = list(model_family_lookup.items())[:5]
            print(f"  Sample lookup entries: {sample_models}")
        else:
            print(f"  ⚠️  model_family_lookup is None!")
        
        # Check a few examples from registry
        print(f"\n  Checking registry metadata for family info:")
        sample_indices = list(range(min(5, len(model_registry))))
        for idx in sample_indices:
            model_name = model_registry.idx2model.get(idx, "unknown")
            metadata = model_registry.metadata.get(idx, {})
            family = metadata.get('family') or metadata.get('model_family')
            print(f"    idx={idx}, model={model_name}, family={family}")
        
        # Check test_data
        if test_data:
            print(f"\n  Checking test_data for model_family:")
            for i, ex in enumerate(test_data[:5]):
                model_name = ex.get('model_name', 'unknown')
                family = ex.get('model_family')
                print(f"    test_data[{i}]: model={model_name}, family={family}")
        
        print(f"{'='*80}\n")
    
    # Initialize forgetting metrics (will be computed in exp2 diagnostics section if M_old is available)
    model_forgetting = None
    domain_forgetting = None
    model_family_forgetting = None
    
    # Candidate-set accuracy (acc_candidate)
    if cand_total > 0:
        acc_candidate_top1 = cand_top1_correct / cand_total
    else:
        acc_candidate_top1 = 0.0
    
    # Score margin statistics (for comparison with training)
    import numpy as _np_margin
    candidate_margin_mean = float(_np_margin.mean(candidate_score_margins)) if candidate_score_margins else 0.0
    all_models_margin_mean = float(_np_margin.mean(all_models_score_margins)) if all_models_score_margins else 0.0
    
    # Known-domain mode metrics (if enabled)
    known_domain_metrics = {}
    if known_domain_mode:
        if known_domain_topk_correct is not None and known_domain_total > 0:
            for k in k_values:
                acc = known_domain_topk_correct[k] / known_domain_total
                known_domain_metrics[f"known_domain_top{k}_accuracy"] = acc
            known_domain_metrics["known_domain_num_examples"] = known_domain_total
            
            # Add rank statistics
            if known_domain_gold_ranks:
                import numpy as _np3
                known_domain_metrics["known_domain_gold_rank_median"] = float(_np3.median(known_domain_gold_ranks))
                known_domain_metrics["known_domain_gold_rank_mean"] = float(_np3.mean(known_domain_gold_ranks))
                known_domain_metrics["known_domain_gold_rank_p90"] = float(_np3.percentile(known_domain_gold_ranks, 90))
                known_domain_metrics["known_domain_gold_mrr"] = float(_np3.mean([1.0 / r for r in known_domain_gold_ranks]))
        
        # Known-domain candidate-set accuracy (comparable to training)
        if known_domain_cand_total > 0:
            known_domain_metrics["known_domain_candidate_top1_accuracy"] = known_domain_cand_correct / known_domain_cand_total
            known_domain_metrics["known_domain_candidate_num_examples"] = known_domain_cand_total
        else:
            # Debug: why are candidate sets not being computed?
            print(f"\n[KNOWN DOMAIN CAND DEBUG]")
            print(f"  known_domain_cand_total: {known_domain_cand_total}")
            print(f"  candidate_K: {candidate_K}")
            print(f"  known_domain_total: {known_domain_total}")
            if candidate_K is None or candidate_K == 0:
                print(f"  WARNING: candidate_K is None or 0, so candidate sets not built!")
            elif known_domain_total > 0:
                print(f"  WARNING: {known_domain_total} examples in known-domain mode, but 0 candidate sets built!")
                print(f"  This might indicate domains are too small (< 2 models) or other filtering issue.")
        
        # Debug: why are main metrics not being computed?
        if known_domain_topk_correct is None or known_domain_total == 0:
            print(f"\n[KNOWN DOMAIN MODE DEBUG]")
            print(f"  known_domain_topk_correct is None: {known_domain_topk_correct is None}")
            print(f"  known_domain_total: {known_domain_total}")
            if known_domain_topk_correct is None:
                print(f"  WARNING: known_domain_topk_correct was not initialized!")
            if known_domain_total == 0:
                print(f"  WARNING: No examples had gold models in their domains!")
                print(f"  This might indicate a domain mismatch between test data and registry.")

    # Gold rank diagnostics over ALL models (geometry / global separation)
    gold_ranks: List[int] = []
    gold_mrr_values: List[float] = []
    gold_margin_top1: List[float] = []
    gold_margin_topK: List[float] = []

    if num_valid > 0 and all_scores:
        K_global = max(k_values) if k_values else 1
        import numpy as _np
        for score_vec, label_idx in zip(all_scores, all_labels):
            if label_idx < 0 or label_idx >= len(score_vec):
                continue
            scores_arr = _np.asarray(score_vec, dtype=_np.float32)
            gold_score = float(scores_arr[label_idx])
            # Rank: 1 + number of models with strictly higher score
            rank = int((scores_arr > gold_score).sum()) + 1
            gold_ranks.append(rank)
            gold_mrr_values.append(1.0 / rank)
            # Margin vs top-1
            top1_score = float(scores_arr.max())
            gold_margin_top1.append(gold_score - top1_score)
            # Margin vs top-K boundary
            if len(scores_arr) >= K_global:
                kth_idx = _np.argpartition(-scores_arr, K_global - 1)[K_global - 1]
                kth_score = float(scores_arr[kth_idx])
            else:
                kth_score = float(scores_arr.min())
            gold_margin_topK.append(gold_score - kth_score)

    def _safe_stat(xs, fn, default=0.0):
        return float(fn(xs)) if xs else float(default)

    import numpy as _np2
    gold_rank_median = _safe_stat(gold_ranks, lambda v: _np2.median(v))
    gold_rank_p90 = _safe_stat(gold_ranks, lambda v: _np2.percentile(v, 90))
    gold_rank_mean = _safe_stat(gold_ranks, lambda v: _np2.mean(v))
    gold_mrr = _safe_stat(gold_mrr_values, lambda v: _np2.mean(v))
    gold_margin_top1_mean = _safe_stat(gold_margin_top1, lambda v: _np2.mean(v))
    gold_margin_topK_mean = _safe_stat(gold_margin_topK, lambda v: _np2.mean(v))

    # Diagnostic: percent of examples where gold_in_registry==True
    gold_in_registry_percent = (gold_in_registry_count / total_examples_processed * 100.0) if total_examples_processed > 0 else 0.0
    
    # Diagnostic: entropy statistics
    # Convert numpy types to Python native types for JSON serialization
    mean_entropy = float(np.mean(all_entropies)) if all_entropies else 0.0
    std_entropy = float(np.std(all_entropies)) if all_entropies else 0.0
    min_entropy = float(np.min(all_entropies)) if all_entropies else 0.0
    max_entropy = float(np.max(all_entropies)) if all_entropies else 0.0
    
    # Expected entropy for uniform distribution (log(N))
    # F) Fix: uniform has HIGH entropy, low entropy means peaked/collapsed
    expected_entropy_uniform = float(np.log(len(model_registry))) if len(model_registry) > 0 else 0.0
    
    # Compute hierarchical metrics (if enabled)
    hier_metrics = {}
    if hierarchical_eval:
        # Group accuracy
        hier_metrics["hier_group_accuracy"] = hier_group_correct / hier_group_total if hier_group_total > 0 else 0.0
        
        # Model top-k accuracy within restricted set (legacy - top-1 domain only)
        for k in k_values:
            hier_metrics[f"hier_model_top{k}"] = hier_model_topk_correct[k] / hier_model_total if hier_model_total > 0 else 0.0
        
        # End-to-end accuracy (group correct AND model correct) (legacy)
        hier_metrics["hier_e2e_top1"] = hier_e2e_top1_correct / hier_model_total if hier_model_total > 0 else 0.0
        
        # Restricted set size statistics (legacy - top-1 domain only)
        if hier_restricted_sizes:
            hier_metrics["hier_restricted_size_mean"] = float(np.mean(hier_restricted_sizes))
            hier_metrics["hier_restricted_size_median"] = float(np.median(hier_restricted_sizes))
            hier_metrics["hier_restricted_size_p90"] = float(np.percentile(hier_restricted_sizes, 90))
        else:
            hier_metrics["hier_restricted_size_mean"] = 0.0
            hier_metrics["hier_restricted_size_median"] = 0.0
            hier_metrics["hier_restricted_size_p90"] = 0.0
        
        # Top-N Domain Hierarchical Rerank metrics (N=1,2,3)
        # Use correct denominators: gold_in_registry_count for primary, total_examples_processed for strict
        denom_in_registry = gold_in_registry_count  # Examples where gold model is in registry
        denom_total = total_examples_processed  # All examples (strict)
        
        for N in [1, 2, 3]:
            # Legacy conditional metrics (only when gold domain in top-N)
            total_N = hier_model_total_atN[N]  # Count where gold in restricted set
            if total_N > 0:
                hier_metrics[f"hier_model_top1_at{N}"] = hier_model_top1_atN[N] / total_N
                hier_metrics[f"hier_model_top5_at{N}"] = hier_model_top5_atN[N] / total_N
                hier_metrics[f"hier_model_top10_at{N}"] = hier_model_top10_atN[N] / total_N
            else:
                hier_metrics[f"hier_model_top1_at{N}"] = 0.0
                hier_metrics[f"hier_model_top5_at{N}"] = 0.0
                hier_metrics[f"hier_model_top10_at{N}"] = 0.0
            
            # New: Conditional metrics (explicit naming)
            # Conditional denominator: examples where gold domain in top-N (should match hier_domain_included_atN[N])
            included_count = hier_domain_included_atN[N]  # Count where gold domain in top-N
            # Validation: hier_model_total_atN[N] should equal hier_domain_included_atN[N] after our fix
            if hier_model_total_atN[N] != included_count:
                # This should not happen - warn if there's a mismatch
                pass  # Will be caught in validation below
            if included_count > 0:
                hier_metrics[f"hier_model_top1_cond_at{N}"] = hier_model_top1_cond_atN[N] / included_count
            else:
                hier_metrics[f"hier_model_top1_cond_at{N}"] = 0.0
            
            # New: E2E metrics (primary - all examples where gold in registry)
            # E2E numerator should equal conditional numerator (both count: gold_domain_in_topN AND gold_in_restricted AND top1_correct)
            # Therefore: E2E = domain_recall * conditional_accuracy
            if denom_in_registry > 0:
                hier_metrics[f"hier_model_top1_e2e_at{N}"] = hier_model_top1_e2e_atN[N] / denom_in_registry
            else:
                hier_metrics[f"hier_model_top1_e2e_at{N}"] = 0.0
            
            # New: E2E strict metrics (all examples, missing gold counts as incorrect)
            if denom_total > 0:
                hier_metrics[f"hier_model_top1_e2e_at{N}_strict"] = hier_model_top1_e2e_atN[N] / denom_total
            else:
                hier_metrics[f"hier_model_top1_e2e_at{N}_strict"] = 0.0
            
            # New: Domain recall@N (fraction where gold domain ∈ predicted top-N)
            if denom_in_registry > 0:
                hier_metrics[f"hier_domain_recall_at{N}"] = hier_domain_included_atN[N] / denom_in_registry
            else:
                hier_metrics[f"hier_domain_recall_at{N}"] = 0.0
            
            # New: Domain recall@N strict (all examples)
            if denom_total > 0:
                hier_metrics[f"hier_domain_recall_at{N}_strict"] = hier_domain_included_atN[N] / denom_total
            else:
                hier_metrics[f"hier_domain_recall_at{N}_strict"] = 0.0
            
            # Restricted set size statistics for each N
            sizes_N = hier_restricted_sizes_atN[N]
            if sizes_N:
                hier_metrics[f"hier_restricted_size_mean_at{N}"] = float(np.mean(sizes_N))
                hier_metrics[f"hier_restricted_size_median_at{N}"] = float(np.median(sizes_N))
                hier_metrics[f"hier_restricted_size_p90_at{N}"] = float(np.percentile(sizes_N, 90))
            else:
                hier_metrics[f"hier_restricted_size_mean_at{N}"] = 0.0
                hier_metrics[f"hier_restricted_size_median_at{N}"] = 0.0
                hier_metrics[f"hier_restricted_size_p90_at{N}"] = 0.0
        
        # Add diagnostic counts
        hier_metrics["hier_missing_gold_count"] = hier_missing_gold_count
        hier_metrics["hier_gold_in_registry_count"] = denom_in_registry
        hier_metrics["hier_num_examples_total"] = denom_total
        hier_metrics["hier_domain_score_mode"] = hier_domain_score_mode
        hier_metrics["hier_domain_topk"] = hier_domain_topk
        hier_metrics["hier_domain_hybrid_alpha"] = hier_domain_hybrid_alpha
    
    # =====================================================================
    # Exp2 Collapse Diagnostics: Separate new-model interference vs exp1 drift
    # =====================================================================
    exp2_diagnostics = {}
    M_old = None
    print(f"  [Exp2 Diagnostics] Starting M_old detection...")
    print(f"  [Exp2 Diagnostics] checkpoint_dir: {checkpoint_dir}")
    print(f"  [Exp2 Diagnostics] router_config is None: {router_config is None}")
    if checkpoint_dir is not None:
        # Try to read M_old (base registry size from exp1) from multiple sources:
        # 1. router_config.json (may have router_registry_base_path or router_exp1_preservation_M_old)
        # 2. exp1 checkpoint's model_registry.json (from router_registry_base_path)
        # 3. current checkpoint's model_registry.json (if it's the exp1 checkpoint)
        
        # Try 1: Check router_config for explicit M_old or base registry path
        if router_config is not None:
            # Always print what we're checking (not just in debug mode)
            print(f"  [Exp2 Diagnostics] Checking router_config for M_old...")
            print(f"  [Exp2 Diagnostics] router_config keys: {list(router_config.keys())}")
            
            # Check for explicit M_old
            if "router_exp1_preservation_M_old" in router_config:
                M_old_value = router_config.get("router_exp1_preservation_M_old")
                print(f"  [Exp2 Diagnostics] Found router_exp1_preservation_M_old = {M_old_value}")
                if M_old_value is not None:
                    M_old = int(M_old_value)
                    print(f"  [Exp2 Diagnostics] ✓ M_old={M_old} (from router_config.router_exp1_preservation_M_old)")
                else:
                    print(f"  [Exp2 Diagnostics] ⚠️  router_exp1_preservation_M_old is None")
            else:
                print(f"  [Exp2 Diagnostics] router_exp1_preservation_M_old not found in router_config")
            
            # Check for base registry path
            if M_old is None and "router_registry_base_path" in router_config:
                base_registry_path_str = router_config.get("router_registry_base_path")
                print(f"  [Exp2 Diagnostics] Found router_registry_base_path = {base_registry_path_str}")
                if base_registry_path_str:
                    base_registry_path = Path(base_registry_path_str)
                    # Try to resolve relative path: check multiple possible locations
                    if not base_registry_path.is_absolute():
                        # Strategy 1: Try relative to current working directory (project root)
                        resolved_paths = [Path.cwd() / base_registry_path]
                        # Strategy 2: Try relative to checkpoint directory's parent (if checkpoint is in results/ or cco/experiments/)
                        if "results" in str(checkpoint_dir) or "experiments" in str(checkpoint_dir):
                            # Go up to project root: results/apibench/... -> results/ -> project root
                            # or cco/experiments/... -> cco/experiments/ -> cco/ -> project root
                            project_root = checkpoint_dir
                            while project_root.name not in ["results", "cco", "CCO"] and len(project_root.parts) > 1:
                                project_root = project_root.parent
                            if project_root.parent.exists():
                                resolved_paths.append(project_root.parent / base_registry_path)
                        # Strategy 3: Try relative to checkpoint directory itself
                        resolved_paths.append(checkpoint_dir / base_registry_path)
                        
                        # Find first existing path
                        for candidate_path in resolved_paths:
                            if candidate_path.exists():
                                base_registry_path = candidate_path
                                break
                    
                    if base_registry_path.exists():
                        try:
                            with open(base_registry_path, 'r') as f:
                                base_registry_data = json.load(f)
                                if "num_models" in base_registry_data:
                                    M_old = int(base_registry_data["num_models"])
                                elif "model2idx" in base_registry_data:
                                    M_old = len(base_registry_data["model2idx"])
                            if M_old is not None:
                                print(f"  [Exp2 Diagnostics] M_old={M_old} (from router_registry_base_path: {base_registry_path})")
                        except Exception as e:
                            print(f"  ⚠️  [Exp2 Diagnostics] Could not read M_old from base registry {base_registry_path}: {e}")
                    else:
                        print(f"  ⚠️  [Exp2 Diagnostics] Base registry path does not exist: {base_registry_path_str}")
                        print(f"      Resolved to: {base_registry_path.resolve()}")
                        print(f"      Current working directory: {Path.cwd()}")
                        print(f"      Checkpoint directory: {checkpoint_dir}")
        
        # Try 2: Check current checkpoint's model_registry.json (if it's exp1-sized)
        if M_old is None:
            model_registry_path = checkpoint_dir / "model_registry.json"
            if model_registry_path.exists():
                try:
                    with open(model_registry_path, 'r') as f:
                        registry_data = json.load(f)
                        # Get num_models from saved registry
                        if "num_models" in registry_data:
                            candidate_M_old = int(registry_data["num_models"])
                        elif "model2idx" in registry_data:
                            candidate_M_old = len(registry_data["model2idx"])
                        else:
                            candidate_M_old = None
                        
                        # Only use if it's smaller than current registry (indicates exp1)
                        if candidate_M_old is not None and candidate_M_old < len(model_registry):
                            M_old = candidate_M_old
                            print(f"  [Exp2 Diagnostics] M_old={M_old} (from checkpoint model_registry.json, current registry: {len(model_registry)})")
                        elif candidate_M_old is not None:
                            print(f"  [Exp2 Diagnostics] Checkpoint registry size ({candidate_M_old}) >= current registry ({len(model_registry)}), skipping")
                except Exception as e:
                    print(f"  ⚠️  [Exp2 Diagnostics] Could not read M_old from {model_registry_path}: {e}")
            else:
                print(f"  ⚠️  [Exp2 Diagnostics] Checkpoint model_registry.json not found: {model_registry_path}")
        
        # Warn if M_old still not found
        if M_old is None:
            print(f"  ⚠️  [Exp2 Diagnostics] M_old not found. Cannot compute exp2 collapse diagnostics.")
            print(f"      Tried: router_config, router_registry_base_path, checkpoint model_registry.json")
            if router_config is None:
                print(f"      router_config is None - check if router_config.json exists in checkpoint_dir")
            elif "router_exp1_preservation_M_old" not in router_config and "router_registry_base_path" not in router_config:
                print(f"      router_config exists but missing both router_exp1_preservation_M_old and router_registry_base_path")
                print(f"      router_config keys: {list(router_config.keys())}")
            else:
                print(f"      router_config has fields but path resolution may have failed")
            print(f"      To enable exp2 diagnostics, ensure router_config contains router_exp1_preservation_M_old or router_registry_base_path")
        
        if M_old is not None and len(all_predictions) > 0 and len(all_labels) > 0:
            M_new = len(model_registry)
            
            # Convert to numpy for easier computation
            pred_indices = np.array(all_predictions)
            gold_indices = np.array(all_labels)
            
            # A) pred_new_rate = mean(pred_idx >= M_old) - % predictions with pred_idx >= M_old
            pred_new_mask = pred_indices >= M_old
            pred_new_rate = pred_new_mask.mean()
            
            # B) pred_new_rate_on_old_gold = mean((pred_idx >= M_old) & (gold_idx < M_old))
            # to measure new-model interference on old-gold examples
            old_gold_mask = gold_indices < M_old
            pred_new_on_old_gold = (pred_new_mask & old_gold_mask).mean()
            num_old_gold = old_gold_mask.sum()
            
            # C) old_only_top1_accuracy: compute top1 using argmax over logits[:M_old] only
            # (for examples with gold_idx < M_old)
            # Note: We need to recompute predictions with restricted logits
            # For now, we'll approximate by checking if pred_idx < M_old when gold_idx < M_old
            old_only_correct = ((pred_indices < M_old) & (pred_indices == gold_indices) & old_gold_mask).sum()
            old_only_total = old_gold_mask.sum()
            old_only_top1_accuracy = old_only_correct / old_only_total if old_only_total > 0 else 0.0
            
            # D) full_top1_accuracy (existing - already computed as top1_accuracy)
            full_top1_accuracy = topk_accuracy.get("top1_accuracy", 0.0)
            
            # E) top1_accuracy_exp1_slice: compute top-1 accuracy restricted to exp1 slice (logits[:M_old])
            # for ALL examples (not just old-gold ones)
            # This requires recomputing predictions using only logits[:M_old]
            top1_accuracy_exp1_slice = None
            if len(all_scores) > 0 and len(all_scores) == len(all_labels):
                # Recompute predictions using only exp1 slice
                exp1_slice_correct = 0
                exp1_slice_total = 0
                for scores_list, gold_idx in zip(all_scores, all_labels):
                    if gold_idx < 0:  # Skip invalid examples
                        continue
                    # Convert to tensor and restrict to exp1 slice
                    scores_tensor = torch.tensor(scores_list[:M_old], dtype=torch.float32)  # [M_old]
                    if len(scores_tensor) > 0:
                        pred_idx_exp1 = scores_tensor.argmax().item()
                        if pred_idx_exp1 == gold_idx:
                            exp1_slice_correct += 1
                        exp1_slice_total += 1
                
                top1_accuracy_exp1_slice = exp1_slice_correct / exp1_slice_total if exp1_slice_total > 0 else 0.0
            else:
                # Fallback: approximate using existing predictions (less accurate)
                # Only count as correct if pred_idx < M_old and matches gold
                exp1_slice_correct = ((pred_indices < M_old) & (pred_indices == gold_indices)).sum()
                exp1_slice_total = len(pred_indices)
                top1_accuracy_exp1_slice = exp1_slice_correct / exp1_slice_total if exp1_slice_total > 0 else 0.0
            
            # Compute forgetting metrics for old models/domains/families
            # Forgetting = 1 - accuracy on old items (measures performance degradation)
            
            # Identify old models, domains, and families
            old_model_indices = set(range(M_old))
            old_domains = set()
            old_families = set()
            for idx in old_model_indices:
                metadata = model_registry.metadata.get(idx, {})
                domain = metadata.get('domain')
                family = metadata.get('family') or metadata.get('model_family')
                if domain:
                    old_domains.add(domain)
                if family:
                    old_families.add(family)
            
            # Compute accuracy on old models (already computed as old_only_top1_accuracy)
            model_forgetting = 1.0 - old_only_top1_accuracy if old_only_total > 0 else None
            
            # Compute accuracy on old domains
            old_domain_correct_count = 0
            old_domain_total_count = 0
            for domain in old_domains:
                if domain in domain_total:
                    old_domain_total_count += domain_total[domain]
                    old_domain_correct_count += domain_correct[domain]
            domain_forgetting = 1.0 - (old_domain_correct_count / old_domain_total_count) if old_domain_total_count > 0 else None
            
            # Compute accuracy on old families
            old_family_correct_count = 0
            old_family_total_count = 0
            for family in old_families:
                if family in family_total:
                    old_family_total_count += family_total[family]
                    old_family_correct_count += family_correct[family]
            model_family_forgetting = 1.0 - (old_family_correct_count / old_family_total_count) if old_family_total_count > 0 else None
            
            exp2_diagnostics = {
                "exp2_M_old": M_old,
                "exp2_M_new": M_new,
                "exp2_pred_new_rate": float(pred_new_rate),  # % predictions with pred_idx >= M_old
                "exp2_pred_new_rate_on_old_gold": float(pred_new_on_old_gold),
                "exp2_old_only_top1_accuracy": float(old_only_top1_accuracy),
                "exp2_full_top1_accuracy": float(full_top1_accuracy),  # Top-1 accuracy over full registry
                "exp2_top1_accuracy_exp1_slice": float(top1_accuracy_exp1_slice) if top1_accuracy_exp1_slice is not None else 0.0,  # Top-1 accuracy restricted to exp1 slice
                "exp2_num_old_gold_examples": int(num_old_gold),
                "exp2_old_only_correct": int(old_only_correct),
            }
            
            # Update forgetting metrics (will be added to main metrics dict)
            if model_forgetting is not None:
                model_forgetting = float(model_forgetting)
            if domain_forgetting is not None:
                domain_forgetting = float(domain_forgetting)
            if model_family_forgetting is not None:
                model_family_forgetting = float(model_family_forgetting)
            
            # Always print exp2 diagnostics (not just when debug=True)
            print(f"\n{'='*80}")
            print(f"[Exp2 Collapse Diagnostics]")
            print(f"{'='*80}")
            print(f"  M_old (exp1 registry size): {M_old}")
            print(f"  M_new (current registry size): {M_new}")
            print(f"  pred_new_rate: {pred_new_rate:.4f} ({pred_new_rate*100:.2f}% of predictions are new models)")
            print(f"  pred_new_rate_on_old_gold: {pred_new_on_old_gold:.4f} ({pred_new_on_old_gold*100:.2f}% of old-gold examples predicted as new)")
            print(f"  old_only_top1_accuracy: {old_only_top1_accuracy:.4f} ({old_only_correct}/{old_only_total} correct)")
            print(f"  full_top1_accuracy: {full_top1_accuracy:.4f} (over full registry)")
            if top1_accuracy_exp1_slice is not None:
                print(f"  top1_accuracy_exp1_slice: {top1_accuracy_exp1_slice:.4f} (restricted to exp1 slice)")
                print(f"  accuracy_drop_exp1_vs_full: {full_top1_accuracy - top1_accuracy_exp1_slice:.4f}")
            print(f"  Interpretation:")
            print(f"    - High pred_new_rate_on_old_gold suggests new-model interference")
            print(f"    - Low old_only_top1_accuracy suggests exp1 drift (even without new models)")
            print(f"{'='*80}\n")
    
    # Add forgetting metrics to metrics dict (if computed)
    forgetting_metrics = {}
    if model_forgetting is not None:
        forgetting_metrics["model_forgetting"] = model_forgetting
    if domain_forgetting is not None:
        forgetting_metrics["domain_forgetting"] = domain_forgetting
    if model_family_forgetting is not None:
        forgetting_metrics["model_family_forgetting"] = model_family_forgetting
    
    metrics = {
        **topk_accuracy,
        "domain_accuracy": overall_domain_accuracy,
        **per_domain_accuracy,
        "model_family_accuracy": overall_family_accuracy,  # Uses total_family_count as denominator (correct calculation)
        "model_family_accuracy_conditional": overall_family_accuracy,  # Same as model_family_accuracy (for backward compatibility)
        "model_family_accuracy_all_examples": family_accuracy_all_examples,  # Uses num_valid as denominator (for comparison)
        "model_family_num_examples_with_family_info": total_family_count,  # Diagnostic: how many examples had family info
        **per_family_accuracy,
        **forgetting_metrics,
        "num_examples_evaluated": num_valid,
        "num_examples_total": num_examples,
        "num_models": len(model_registry),
        "gold_in_registry_percent": gold_in_registry_percent,
        # Dual-metric reporting
        "top1_accuracy_candidate": acc_candidate_top1,
        # Global rank diagnostics over all models
        "gold_rank_median": gold_rank_median,
        "gold_rank_p90": gold_rank_p90,
        "gold_rank_mean": gold_rank_mean,
        "gold_mrr": gold_mrr,
        "gold_margin_top1_mean": gold_margin_top1_mean,
        "gold_margin_topK_mean": gold_margin_topK_mean,
        "entropy_mean": mean_entropy,
        "entropy_std": std_entropy,
        "entropy_min": min_entropy,
        "entropy_max": max_entropy,
        "entropy_expected_uniform": expected_entropy_uniform,
        # Score margins (for comparison with training)
        "candidate_score_margin_mean": candidate_margin_mean,
        "all_models_score_margin_mean": all_models_margin_mean,
        # Known-domain mode metrics (if enabled)
        **known_domain_metrics,
        # Hierarchical evaluation metrics (if enabled)
        **hier_metrics,
        # Exp2 collapse diagnostics (if checkpoint_dir provided)
        **exp2_diagnostics,
    }
    
    # G) Add diagnostics to return value if debug enabled
    if debug and diagnostics is not None:
        # Add entropy stats to diagnostics
        diagnostics["entropy_mean"] = mean_entropy
        diagnostics["entropy_std"] = std_entropy
        diagnostics["entropy_min"] = min_entropy
        diagnostics["entropy_max"] = max_entropy
        
        # Count unique hashes
        diagnostics["unique_ids_hashes"] = len(set(diagnostics["ids_hashes"])) if diagnostics["ids_hashes"] else 0
        diagnostics["unique_prompt_emb_hashes"] = len(set(diagnostics["prompt_emb_hashes"])) if diagnostics["prompt_emb_hashes"] else 0
        diagnostics["unique_score_vec_hashes"] = len(set(diagnostics["score_vec_hashes"])) if diagnostics["score_vec_hashes"] else 0
        
        metrics["diagnostics"] = diagnostics
    
    # Print diagnostic summary if debug enabled
    if debug:
        print(f"\n{'='*80}")
        print(f"[DEBUG] Diagnostic Summary")
        print(f"{'='*80}")
        print(f"  gold_in_registry_percent: {gold_in_registry_percent:.2f}% ({gold_in_registry_count}/{total_examples_processed})")
        print(f"  Entropy statistics:")
        print(f"    mean: {mean_entropy:.6f}")
        print(f"    std:  {std_entropy:.6f}")
        print(f"    min:  {min_entropy:.6f}")
        print(f"    max:  {max_entropy:.6f}")
        print(f"    expected (uniform): {expected_entropy_uniform:.6f}")
        print(f"  Note: High entropy (~log(N)) suggests uniform distribution. Low entropy suggests peaked/collapsed distribution.")
        print(f"\n  Per-domain accuracy (with support counts):")
        for domain in sorted(domain_total.keys()):
            n_examples = domain_total[domain]
            acc = domain_correct[domain] / n_examples if n_examples > 0 else 0.0
            print(f"    domain_{domain}: {acc:.4f} ({domain_correct[domain]}/{n_examples} examples)")
        
        if family_total:
            print(f"\n  Per-family accuracy (with support counts):")
            for family in sorted(family_total.keys()):
                n_examples = family_total[family]
                acc = family_correct[family] / n_examples if n_examples > 0 else 0.0
                print(f"    family_{family}: {acc:.4f} ({family_correct[family]}/{n_examples} examples)")
        
        print(f"{'='*80}\n")
    
    # Print score margin comparison with training
    print(f"\n{'='*80}")
    print(f"[SCORE MARGIN COMPARISON]")
    print(f"{'='*80}")
    print(f"  Candidate-set score margin (comparable to training): {candidate_margin_mean:.4f}")
    print(f"  All-models score margin: {all_models_margin_mean:.4f}")
    print(f"  Training avg_score_margin (from logs): ~5.88")
    print(f"  {'✓' if candidate_margin_mean > 0 else '✗'} Candidate margin is {'positive' if candidate_margin_mean > 0 else 'negative'}")
    print(f"{'='*80}\n")
    
    # Always print known-domain metrics if enabled (not just in debug mode)
    if known_domain_mode:
        print(f"\n{'='*80}")
        print(f"[KNOWN DOMAIN MODE] Metrics Summary")
        print(f"{'='*80}")
        if known_domain_metrics:
            for k in k_values:
                key = f"known_domain_top{k}_accuracy"
                if key in known_domain_metrics:
                    print(f"  {key}: {known_domain_metrics[key]:.4f}")
            if "known_domain_num_examples" in known_domain_metrics:
                num_examples = known_domain_metrics['known_domain_num_examples']
                print(f"  known_domain_num_examples: {num_examples}")
                print(f"  (out of {num_valid} total valid examples)")
                
                # Diagnostic: compute average domain size
                domain_sizes = []
                for domain in model_registry.get_all_domains():
                    size = len(model_registry.get_domain_models(domain))
                    if size > 0:
                        domain_sizes.append(size)
                if domain_sizes:
                    avg_domain_size = sum(domain_sizes) / len(domain_sizes)
                    min_domain_size = min(domain_sizes)
                    max_domain_size = max(domain_sizes)
                    print(f"  Domain size stats: avg={avg_domain_size:.1f}, min={min_domain_size}, max={max_domain_size}")
                    print(f"  Expected random top-1 accuracy: ~{1.0/avg_domain_size*100:.2f}%")
                
                # Print rank statistics if available
                if "known_domain_gold_rank_median" in known_domain_metrics:
                    print(f"  Gold model rank within domain:")
                    print(f"    median: {known_domain_metrics['known_domain_gold_rank_median']:.1f}")
                    print(f"    mean: {known_domain_metrics['known_domain_gold_rank_mean']:.1f}")
                    print(f"    p90: {known_domain_metrics['known_domain_gold_rank_p90']:.1f}")
                    print(f"    MRR: {known_domain_metrics['known_domain_gold_mrr']:.4f}")
                
                # Print candidate-set accuracy (comparable to training)
                if "known_domain_candidate_top1_accuracy" in known_domain_metrics:
                    cand_acc = known_domain_metrics['known_domain_candidate_top1_accuracy']
                    cand_num = known_domain_metrics['known_domain_candidate_num_examples']
                    all_domain_acc = known_domain_metrics.get('known_domain_top1_accuracy', 0)
                    print(f"  Known-domain candidate-set accuracy (comparable to training):")
                    print(f"    top1_accuracy: {cand_acc:.4f} ({known_domain_cand_correct}/{cand_num} examples)")
                    print(f"    All-domain-models accuracy: {all_domain_acc:.4f} (for comparison)")
                    print(f"    Training top1_accuracy: ~0.59 (59%)")
                    print(f"    {'✓' if cand_acc > 0.5 else '✗'} {'Better' if cand_acc > 0.59 else 'Worse'} than training")
                    if abs(cand_acc - all_domain_acc) < 0.01:
                        print(f"    ⚠️  WARNING: Candidate-set accuracy ({cand_acc:.4f}) is very similar to all-domain accuracy ({all_domain_acc:.4f})")
                        print(f"       This suggests candidate sets aren't helping, or domains are too small for candidate sets.")
        else:
            print(f"  WARNING: No known-domain metrics computed!")
            print(f"  This might indicate a domain mismatch between test data and registry.")
        print(f"{'='*80}\n")
        if diagnostics is not None:
            print(f"[Diagnostics Summary]")
            print(f"  unique_ids_hashes: {diagnostics.get('unique_ids_hashes', 'N/A')}")
            print(f"  unique_prompt_emb_hashes: {diagnostics.get('unique_prompt_emb_hashes', 'N/A')}")
            print(f"  unique_score_vec_hashes: {diagnostics.get('unique_score_vec_hashes', 'N/A')}")
            print(f"{'='*80}\n")
    
    # Always print hierarchical metrics if enabled (not just in debug mode)
    if hierarchical_eval:
        print(f"\n{'='*80}")
        print(f"[HIERARCHICAL EVALUATION] Metrics Summary")
        print(f"{'='*80}")
        if hier_metrics:
            print(f"  Hierarchy level: {hierarchy_level}")
            print(f"  Hierarchical top-k: {hierarchical_topk}")
            print(f"  Domain score mode: {hier_domain_score_mode}")
            print(f"  Domain topk: {hier_domain_topk}")
            if hier_domain_score_mode == "hybrid":
                print(f"  Hybrid alpha: {hier_domain_hybrid_alpha}")
            print(f"  hier_group_accuracy: {hier_metrics.get('hier_group_accuracy', 0.0):.4f}")
            
            # Compare hierarchical domain recall with flat domain accuracy
            # Note: These measure different things and can differ significantly
            if "domain_accuracy" in metrics:
                flat_domain_acc = metrics["domain_accuracy"]
                hier_domain_recall_1 = hier_metrics.get('hier_domain_recall_at1', 0.0)
                if abs(hier_domain_recall_1 - flat_domain_acc) > 0.01:
                    print(f"\n  [Domain Accuracy Comparison]")
                    print(f"    domain_accuracy (flat): {flat_domain_acc:.4f}")
                    print(f"      → Checks if top-1 MODEL's domain (from flat selection) matches gold domain")
                    print(f"    hier_domain_recall_at1: {hier_domain_recall_1:.4f}")
                    print(f"      → Checks if top-1 DOMAIN (from logsumexp aggregation) matches gold domain")
                    print(f"    Difference: {abs(hier_domain_recall_1 - flat_domain_acc):.4f}")
                    print(f"    ⚠️  These differ because domain-level aggregation can rank domains")
                    print(f"       differently than model-level selection (logsumexp vs max)")
            
            # Legacy metrics (top-1 domain only)
            # NOTE: These are CONDITIONAL metrics (only when gold model is in restricted set)
            # They should NOT be compared to E2E metrics which use correct denominators
            print(f"\n  [Legacy Metrics - Top-1 Domain Only (CONDITIONAL)]")
            print(f"    ⚠️  WARNING: These metrics are conditional (denom: {hier_model_total}, only when gold in restricted set)")
            print(f"       They are NOT comparable to E2E metrics which use correct denominators")
            for k in k_values:
                key = f"hier_model_top{k}"
                if key in hier_metrics:
                    print(f"    {key}: {hier_metrics[key]:.4f} (conditional, denom: {hier_model_total})")
            if "hier_e2e_top1" in hier_metrics:
                print(f"    hier_e2e_top1: {hier_metrics['hier_e2e_top1']:.4f} (conditional, NOT true E2E, denom: {hier_model_total})")
            if "hier_restricted_size_mean" in hier_metrics:
                print(f"    Restricted set size stats:")
                print(f"      mean: {hier_metrics['hier_restricted_size_mean']:.1f}")
                print(f"      median: {hier_metrics['hier_restricted_size_median']:.1f}")
                print(f"      p90: {hier_metrics['hier_restricted_size_p90']:.1f}")
            print(f"    Conditional denominator (gold in restricted set): {hier_model_total}")
            
            # Top-N Domain Hierarchical Rerank metrics (N=1,2,3)
            print(f"\n  [Top-N Domain Hierarchical Rerank Metrics]")
            denom_in_registry = hier_metrics.get("hier_gold_in_registry_count", 0)
            denom_total = hier_metrics.get("hier_num_examples_total", 0)
            missing_gold = hier_metrics.get("hier_missing_gold_count", 0)
            
            print(f"  Denominators:")
            print(f"    gold_in_registry_count: {denom_in_registry} (examples where gold model is in registry)")
            print(f"    num_examples_total: {denom_total} (all examples, strict)")
            if missing_gold > 0:
                print(f"    missing_gold_count: {missing_gold} (examples where gold model not in registry)")
            
            for N in [1, 2, 3]:
                total_N = hier_model_total_atN[N]  # Conditional denominator (gold in restricted set)
                print(f"\n    N={N} (Top-{N} Domains):")
                
                # E2E metrics (primary - headline metrics)
                print(f"      [End-to-End Metrics (Primary)]")
                e2e_atN = hier_metrics.get(f'hier_model_top1_e2e_at{N}', 0.0)
                e2e_strict = hier_metrics.get(f'hier_model_top1_e2e_at{N}_strict', 0.0)
                domain_recall = hier_metrics.get(f'hier_domain_recall_at{N}', 0.0)
                domain_recall_strict = hier_metrics.get(f'hier_domain_recall_at{N}_strict', 0.0)
                
                print(f"        hier_model_top1_e2e_at{N}: {e2e_atN:.4f} (denom: {denom_in_registry})")
                print(f"        hier_model_top1_e2e_at{N}_strict: {e2e_strict:.4f} (denom: {denom_total})")
                print(f"        hier_domain_recall_at{N}: {domain_recall:.4f} (fraction where gold domain ∈ top-{N} domains)")
                print(f"        hier_domain_recall_at{N}_strict: {domain_recall_strict:.4f}")
                # Note: hier_domain_recall_at1 may differ from domain_accuracy because:
                # - domain_accuracy: checks if top-1 MODEL's domain (from flat model selection) matches gold domain
                # - hier_domain_recall_at1: checks if top-1 DOMAIN (from logsumexp aggregation) matches gold domain
                # These can differ when domain-level aggregation ranks domains differently than model-level selection
                
                # Conditional metrics (diagnostic - only when gold domain in top-N)
                print(f"      [Conditional Metrics (Diagnostic)]")
                if total_N > 0:
                    cond_atN = hier_metrics.get(f'hier_model_top1_cond_at{N}', 0.0)
                    flat_top1 = hier_metrics.get('top1_accuracy', 0.0) if 'top1_accuracy' in hier_metrics else 0.0
                    print(f"        hier_model_top1_cond_at{N}: {cond_atN:.4f} (denom: {total_N}, only when gold domain ∈ top-{N})")
                    print(f"        Flat top1_accuracy (for comparison): {flat_top1:.4f}")
                    if cond_atN < flat_top1 and N == 1:
                        print(f"        ⚠️  NOTE: Conditional accuracy ({cond_atN:.4f}) < Flat accuracy ({flat_top1:.4f})")
                        print(f"           This can happen because:")
                        print(f"           - Hierarchical compares within UNION of top-{N} domain(s), not just gold domain")
                        print(f"           - Even when gold domain is selected, models from other selected domains can outrank gold")
                        print(f"           - Or: gold model may not be router's top choice even within its own domain")
                    print(f"        hier_model_top5_at{N}: {hier_metrics.get(f'hier_model_top5_at{N}', 0.0):.4f}")
                    print(f"        hier_model_top10_at{N}: {hier_metrics.get(f'hier_model_top10_at{N}', 0.0):.4f}")
                else:
                    print(f"        No examples where gold domain ∈ top-{N} domains")
                
                # Legacy metrics (for backward compatibility)
                if total_N > 0:
                    print(f"      [Legacy Metrics (Conditional)]")
                    print(f"        hier_model_top1_at{N}: {hier_metrics.get(f'hier_model_top1_at{N}', 0.0):.4f} (same as cond)")
                
                # Restricted set size stats
                print(f"      Restricted set size stats:")
                print(f"        mean: {hier_metrics.get(f'hier_restricted_size_mean_at{N}', 0.0):.1f}")
                print(f"        median: {hier_metrics.get(f'hier_restricted_size_median_at{N}', 0.0):.1f}")
                print(f"        p90: {hier_metrics.get(f'hier_restricted_size_p90_at{N}', 0.0):.1f}")
                print(f"      Conditional denominator (gold in restricted set): {total_N}")
            
            # Validation: E2E should equal domain_recall * conditional_accuracy
            for N in [1, 2, 3]:
                domain_recall_val = hier_metrics.get(f'hier_domain_recall_at{N}', 0.0)
                cond_acc_val = hier_metrics.get(f'hier_model_top1_cond_at{N}', 0.0)
                e2e_val = hier_metrics.get(f'hier_model_top1_e2e_at{N}', 0.0)
                expected_e2e = domain_recall_val * cond_acc_val
                if abs(e2e_val - expected_e2e) > 1e-4 and domain_recall_val > 0:
                    print(f"\n  ⚠️  WARNING: E2E validation failed for N={N}")
                    print(f"     hier_model_top1_e2e_at{N}: {e2e_val:.6f}")
                    print(f"     hier_domain_recall_at{N} * hier_model_top1_cond_at{N}: {expected_e2e:.6f}")
                    print(f"     Difference: {abs(e2e_val - expected_e2e):.6f}")
                    print(f"     This indicates a bug in the metric computation!")
                elif domain_recall_val > 0:
                    print(f"\n  ✓ Validation: E2E@N={N} = domain_recall * conditional ({e2e_val:.6f} = {domain_recall_val:.6f} * {cond_acc_val:.6f})")
            
            # Validation: hier_model_top1_at1 should match hier_model_top1 (legacy)
            if "hier_model_top1" in hier_metrics and "hier_model_top1_at1" in hier_metrics:
                legacy_val = hier_metrics["hier_model_top1"]
                new_val = hier_metrics["hier_model_top1_at1"]
                if abs(legacy_val - new_val) > 1e-5:
                    print(f"\n  ⚠️  WARNING: hier_model_top1 ({legacy_val:.4f}) != hier_model_top1_at1 ({new_val:.4f})")
                    print(f"     This may indicate an implementation difference.")
                else:
                    print(f"\n  ✓ Validation: hier_model_top1_at1 matches hier_model_top1 (legacy)")
            
            # Debug: Print filtering statistics (now showing domain recall instead of filtering)
            if debug:
                print(f"\n  [Hierarchical Evaluation Domain Recall Debug]")
                denom_in_registry = hier_metrics.get("hier_gold_in_registry_count", 0)
                if denom_in_registry > 0:
                    print(f"    Domain recall (fraction where gold domain ∈ predicted top-N):")
                    for N in [1, 2, 3]:
                        included = hier_domain_included_atN.get(N, 0)
                        recall = hier_metrics.get(f'hier_domain_recall_at{N}', 0.0)
                        print(f"      N={N}: {included}/{denom_in_registry} = {recall:.4f} ({100*(1-recall):.1f}% where gold domain NOT in top-{N})")
                if hier_gold_domain_missing_from_registry > 0:
                    print(f"    ⚠️  WARNING: {hier_gold_domain_missing_from_registry} examples have gold models whose domain is missing from registry!")
                missing_gold = hier_metrics.get("hier_missing_gold_count", 0)
                if missing_gold > 0:
                    print(f"    Missing gold count: {missing_gold} (examples where gold model not in registry)")
        else:
            print(f"  WARNING: No hierarchical metrics computed!")
        print(f"{'='*80}\n")
    
    # Add compute metrics to return dict
    compute_summary = eval_compute_tracker.get_summary()
    if compute_summary["total_examples"] > 0:
        metrics["compute"] = compute_summary
        metrics["total_flops"] = compute_summary["total_flops"]
        metrics["total_flops_gflops"] = compute_summary["total_flops_gflops"]
        metrics["flops_per_example"] = compute_summary["flops_per_example"]
        metrics["eval_total_examples"] = compute_summary["total_examples"]
        metrics["eval_total_batches"] = compute_summary["total_batches"]
    
    return metrics


def save_router_predictions(
    predictions: List[int],
    scores: List[List[float]],
    test_data: List[Dict[str, Any]],
    model_registry: "ModelRegistry",
    output_path: Path,
):
    """
    Save router predictions to a JSON file for analysis.

    One record is written per example, containing the prompt, the ground-truth
    model name, the predicted model name, whether the two match, the
    ground-truth domain (``"unknown"`` if absent), and the highest-scoring
    models (up to 5) with their scores.

    Args:
        predictions: List of predicted model indices
        scores: List of score vectors (one per example); position ``i`` of a
            vector is the score of ``model_registry.idx2model[i]``
        test_data: Original test examples. Each needs a 'model_name' key; the
            prompt is read from 'prompt_text', falling back to 'instruction'
            (same lookup as ``format_eval_prompt``)
        model_registry: ModelRegistry for name lookups (only ``idx2model`` is used)
        output_path: Path to save predictions
    """
    max_top = 5
    output = []
    # NOTE(review): zip() stops at the shortest input, so predictions/scores
    # must be aligned one-to-one with test_data — confirm callers never drop
    # examples from one list but not the others.
    for pred_idx, score_vec, example in zip(predictions, scores, test_data):
        pred_name = model_registry.idx2model[pred_idx]
        true_name = example['model_name']

        score_tensor = torch.as_tensor(score_vec)
        # Clamp k so registries with fewer than `max_top` models don't make
        # topk() raise "selected index k out of range".
        k = min(max_top, score_tensor.numel())
        top_scores, top_indices = score_tensor.topk(k=k)
        # .tolist() already returns plain Python numbers, which are
        # JSON-serializable as-is (they have no .item() method).
        top5_predictions = [
            {
                "model_name": model_registry.idx2model[idx],
                "score": score,
            }
            for idx, score in zip(top_indices.tolist(), top_scores.tolist())
        ]

        output.append({
            "prompt": example.get("prompt_text") or example.get("instruction", ""),
            "ground_truth": true_name,
            "prediction": pred_name,
            "correct": pred_name == true_name,
            "domain_ground_truth": example.get('domain', 'unknown'),
            "top5_predictions": top5_predictions,
        })

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

    print(f"✓ Saved router predictions to {output_path}")


def main_router_eval(
    checkpoint_dir: Path,
    test_data_path: Path,
    output_dir: Optional[Path] = None,
    embedding_dim: int = 256,
    batch_size: int = 32,
    device: str = "cuda",
):
    """
    Main function for router evaluation (currently a stub).

    At present this only checks that ``<checkpoint_dir>/model_registry.json``
    exists and parses as JSON, then prints a TODO notice and returns ``None``.
    No router is loaded and no evaluation is run. The remaining arguments are
    accepted for the future implementation but are currently unused.

    Args:
        checkpoint_dir: Directory containing trained router and model registry
            (``str`` or ``Path``)
        test_data_path: Path to test data JSON file (currently unused)
        output_dir: Optional directory to save predictions (currently unused)
        embedding_dim: Router embedding dimension (currently unused)
        batch_size: Evaluation batch size (currently unused)
        device: Device for computation (currently unused)

    Raises:
        FileNotFoundError: If ``model_registry.json`` is missing.
        json.JSONDecodeError: If the registry file is not valid JSON.
    """
    # Accept plain strings (e.g. from argparse) as well as Path objects;
    # `str / str` would otherwise raise TypeError.
    checkpoint_dir = Path(checkpoint_dir)

    # Load model registry
    registry_path = checkpoint_dir / "model_registry.json"
    if not registry_path.exists():
        raise FileNotFoundError(f"Model registry not found at {registry_path}")

    # Parse the file so a corrupt registry fails early; the parsed data is not
    # used until registry reconstruction is implemented.
    with open(registry_path, 'r', encoding='utf-8') as f:
        json.load(f)

    # TODO: Implement ModelRegistry.save() and ModelRegistry.load() methods so
    # the registry can be reconstructed here instead of rebuilt from data.
    print("TODO: Implement model registry save/load for evaluation")
    print("For now, router evaluation requires rebuilding the registry from training data.")

