"""
Custom SFTTrainer with Neighbour-Consistency and Neighbour-Contrastive Regularisation




"""

import json
import logging
import os
import random
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader

from trl import SFTTrainer, SFTConfig
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.trainer_callback import TrainerCallback


from .xclr import (
     enable_hidden_states,
     get_last_hidden_states,
     canonicalize_domain,
 )
from .replay_candidate_pool import ReplayCandidatePool, build_replay_candidate_pool, extract_instruction_only
from .router_training import (
    ROUTER_DEBUG,
    compute_routing_loss,
    compute_label_graph_regularizer,
    compute_router_metrics,
)
from .compute_tracker import ComputeTracker
from .router_debug_checks import RouterDebugChecker
from .router_exceptions import RouterTrainingError
from ..model_selection_carve import (
    ModelRegistry,
    normalize_domain,
    normalize_model_name,
    CandidateSetBuilder,
    HardNegativeMiner,
    RouterModel,
    CompositeModelWithRouter,
)
from ..model_selection_carve.router import extract_prompt_mask




class MetadataPreservingDataCollator:
    """
    Data collator that wraps another collator and preserves metadata fields.

    TRL's SFTTrainer tokenizes datasets and the default collator only handles
    tensor fields. This wrapper preserves non-tensor metadata like is_replay,
    model_name, and domain as Python lists in the batch.

    The metadata keys are removed from the features handed to the wrapped
    collator. Collators that tensorize every key (e.g. ``tokenizer.pad``-based
    ones) therefore never see string/bool metadata. The caller's feature dicts
    are not mutated. A metadata field is attached to the batch only when
    *every* feature carries it, so each list stays aligned with the batch
    dimension.
    """

    METADATA_FIELDS = ["is_replay", "model_name", "domain"]

    def __init__(self, base_collator):
        # Callable: list of feature dicts -> batch mapping (supports item assignment).
        self.base_collator = base_collator

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate *features* and re-attach metadata fields as per-example lists.

        Args:
            features: Per-example feature dicts, possibly containing metadata keys.

        Returns:
            The wrapped collator's batch, plus one list per metadata field that is
            present in every feature.
        """
        metadata_fields = self.METADATA_FIELDS
        metadata = {field: [] for field in metadata_fields}
        stripped_features = []

        for feature in features:
            for field in metadata_fields:
                if field in feature:
                    metadata[field].append(feature[field])
            # Build a new dict instead of deleting keys so the caller's data is untouched.
            stripped_features.append(
                {key: value for key, value in feature.items() if key not in metadata_fields}
            )

        # Base collator only sees the tensorizable fields (input_ids, attention_mask, labels, ...).
        batch = self.base_collator(stripped_features)

        # Only re-attach fields present for every example; partial fields would misalign.
        for field, values in metadata.items():
            if values and len(values) == len(features):
                batch[field] = values

        return batch


class ConsistencyLoggingCallback(TrainerCallback):
    """Callback that logs router metrics and persists router checkpoints.

    Events handled:

    * ``on_epoch_end``: in router loss modes, logs ``trainer.get_router_metrics()``.
      In every mode, calls ``trainer.reset_consistency_metrics()``.
    * ``on_save``: in router loss modes with a router model, saves the router
      checkpoint into the checkpoint directory that was just written. It also
      merges the router compute summary into that directory's ``metrics.json``.
    * ``on_train_end``: in router loss modes with a router model, saves the
      router checkpoint into the highest-numbered ``checkpoint-<step>``
      directory under ``args.output_dir``. When no such directory exists, it
      saves into ``args.output_dir`` itself.
    """

    # Loss modes that train a router; router logging/saving is skipped for all others.
    ROUTER_LOSS_MODES = (
        "router",
        "router+graph",
        "supervised+router",
        "supervised+router+graph",
    )
    CHECKPOINT_PREFIX = "checkpoint-"

    def __init__(self, trainer: "NeighborConsistencySFTTrainer"):
        self.trainer = trainer

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _router_mode_enabled(self) -> bool:
        """Return True when the trainer's ``_loss_mode`` is one of ROUTER_LOSS_MODES."""
        return getattr(self.trainer, "_loss_mode", None) in self.ROUTER_LOSS_MODES

    def _router_model_available(self) -> bool:
        """Return True when router mode is enabled and the trainer holds a router model."""
        return (
            self._router_mode_enabled()
            and getattr(self.trainer, "_router_model", None) is not None
        )

    @classmethod
    def _checkpoint_dirs(cls, output_dir: Path) -> List[Path]:
        """List ``checkpoint-*`` subdirectories of *output_dir* (empty if it does not exist)."""
        if not output_dir.is_dir():
            return []
        return [
            d for d in output_dir.iterdir()
            if d.is_dir() and d.name.startswith(cls.CHECKPOINT_PREFIX)
        ]

    @classmethod
    def _checkpoint_sort_key(cls, path: Path) -> Tuple[int, str]:
        """Order checkpoint dirs by numeric step suffix, breaking ties by name.

        A plain string sort would rank ``checkpoint-200`` above ``checkpoint-1000``.
        Directories without a numeric suffix rank lowest.
        """
        suffix = path.name[len(cls.CHECKPOINT_PREFIX):]
        step = int(suffix) if suffix.isdecimal() else -1
        return step, path.name

    def _find_saved_checkpoint_dir(self, output_dir: Path, state) -> Optional[Path]:
        """Locate the checkpoint directory written by the save that triggered ``on_save``.

        Tries ``checkpoint-{global_step}`` and then ``checkpoint-{int(epoch)}``.
        If neither exists, falls back to the most recently modified
        ``checkpoint-*`` directory. Returns None when nothing is found.
        """
        possible_names = [f"{self.CHECKPOINT_PREFIX}{state.global_step}"]  # Step-based
        if state.epoch is not None:
            possible_names.append(f"{self.CHECKPOINT_PREFIX}{int(state.epoch)}")  # Epoch-based

        for name in possible_names:
            candidate = output_dir / name
            if candidate.exists():
                return candidate

        checkpoint_dirs = self._checkpoint_dirs(output_dir)
        if checkpoint_dirs:
            return max(checkpoint_dirs, key=os.path.getmtime)
        return None

    def _merge_compute_metrics(self, checkpoint_dir: Path, state, logger) -> None:
        """Merge the router compute summary into ``<checkpoint_dir>/metrics.json``.

        Nothing is written unless the tracker reports ``total_examples > 0``.
        If the existing file is unreadable or does not hold a JSON object, it is
        replaced rather than raising.
        """
        compute_summary = self.trainer._compute_tracker.get_summary()
        if not (compute_summary["total_examples"] > 0):
            return

        metrics_file = checkpoint_dir / "metrics.json"
        existing_metrics = {}
        if metrics_file.exists():
            try:
                with open(metrics_file, 'r') as f:
                    loaded = json.load(f)
            except (ValueError, OSError):
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                loaded = {}
            if isinstance(loaded, dict):
                existing_metrics = loaded

        existing_metrics.update({
            "router_compute": compute_summary,
            "router_total_flops": compute_summary["total_flops"],
            "router_total_flops_gflops": compute_summary["total_flops_gflops"],
            "router_flops_per_example": compute_summary["flops_per_example"],
            "router_total_examples_processed": compute_summary["total_examples"],
            "router_total_batches_processed": compute_summary["total_batches"],
            "router_epoch": int(state.epoch) if state.epoch is not None else None,
            "router_global_step": state.global_step,
        })

        with open(metrics_file, 'w') as f:
            json.dump(existing_metrics, f, indent=2)
        logger.info(f"  Saved compute metrics to {metrics_file}")

    # ------------------------------------------------------------------
    # Trainer events
    # ------------------------------------------------------------------
    def on_epoch_end(self, args, state, control, **kwargs):
        """Log router metrics (router modes only), then reset the trainer's running metrics."""
        if self._router_mode_enabled():
            metrics = self.trainer.get_router_metrics()
            if metrics:
                logger = logging.getLogger(__name__)
                logger.info(f"\n  [Router Metrics @ Epoch {state.epoch:.0f}]")
                for key, value in metrics.items():
                    if isinstance(value, float):
                        logger.info(f"    {key}: {value:.4f}")
                    else:
                        logger.info(f"    {key}: {value}")

        self.trainer.reset_consistency_metrics()

    def on_save(self, args, state, control, **kwargs):
        """Save the router checkpoint (and compute metrics) into the checkpoint just written."""
        # NOTE(review): this runs on every process; confirm save_router_checkpoint and
        # the metrics.json write are safe (or internally gated) under multi-GPU training.
        if not self._router_model_available():
            return

        checkpoint_dir = self._find_saved_checkpoint_dir(Path(args.output_dir), state)
        if checkpoint_dir is None or not checkpoint_dir.exists():
            return

        logger = logging.getLogger(__name__)
        logger.info(f"\n[Router] Saving router checkpoint to {checkpoint_dir}")
        self.trainer.save_router_checkpoint(str(checkpoint_dir))

        if hasattr(self.trainer, "_compute_tracker"):
            self._merge_compute_metrics(checkpoint_dir, state, logger)

    def on_train_end(self, args, state, control, **kwargs):
        """Save the router checkpoint at end of training.

        Kept for backwards compatibility; ``on_save`` handles most cases. The target
        is the highest-step ``checkpoint-*`` directory, or ``args.output_dir`` if none.
        """
        # NOTE(review): if the final step was not a save step, this writes the final
        # router state into an older checkpoint whose LM weights predate it — confirm
        # that overwrite is intended.
        if not self._router_model_available():
            return

        logger = logging.getLogger(__name__)
        checkpoint_dirs = self._checkpoint_dirs(Path(args.output_dir))
        if checkpoint_dirs:
            latest_checkpoint = max(checkpoint_dirs, key=self._checkpoint_sort_key)
            logger.info(f"\n[Router] Saving router checkpoint to latest checkpoint: {latest_checkpoint}")
            self.trainer.save_router_checkpoint(str(latest_checkpoint))
        else:
            logger.info(f"\n[Router] Saving router checkpoint to {args.output_dir}")
            self.trainer.save_router_checkpoint(args.output_dir)


class NeighborConsistencySFTTrainer(SFTTrainer):
    """
    SFTTrainer with optional neighbour-consistency regularisation.
    
    When use_neighbor_consistency=True, this trainer:
    1. Retrieves similar prompts for each training example
    2. Computes the standard supervised loss
    3. Adds a KL-divergence consistency loss between anchor and neighbor predictions
    4. Returns the combined loss: L_supervised + weight * L_consistency
    
    The neighbor index should be built from previous experience data (e.g., APIBench)
    and passed during initialization.
    


    """
    
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module],
        args: SFTConfig,
        train_dataset,
        processing_class: Optional[PreTrainedTokenizer] = None,
        eval_dataset=None,
        callbacks=None,

        # ======================================================================
        # Router Training Parameters (Semantic Batching + Candidate-Set Routing)
        # ======================================================================
        loss_mode: str = "supervised",  # "supervised", "router", "router+graph"
        router_loss_weight: float = 1.0,  # Weight for routing loss
        lm_loss_weight: float = 1.0,  # Weight for LM supervised loss
        # Semantic batching (for router training)
        semantic_batching: bool = False,  # Enable domain-based semantic batching
        domains_per_batch: int = 1,  # Number of domains per batch
        mix_replay_in_semantic_batches: bool = True,  # Mix replay into semantic batches
        # Router architecture
        router_embedding_dim: Optional[int] = None,  # Defaults to hidden_size
        router_tau: float = 0.07,  # Temperature for scaling logits
        router_pooling: str = "last_token",  # "last_token" or "mean"
        # Router learning rates (split by parameter group)
        router_proj_lr: Optional[float] = None,  # Learning rate for projection head (None = use args.learning_rate)
        router_embedding_lr: Optional[float] = None,  # Learning rate for embedding table (None = use args.learning_rate)
        # Candidate sampling
        router_K_total: int = 64,
        router_K_semantic: int = 48,
        router_K_far: int = 8,
        router_K_hard: int = 7,
        # Hard negative mining
        router_mine_every_steps: int = 200,
        router_K_hard_pool: int = 20,
        router_semantic_pool_size: int = 512,
        router_max_pool_size: int = 1024,
        # Semantic pool expansion (Option B)
        router_semantic_pool_mode: str = "parent_group",
        router_semantic_pool_max_domains: Optional[int] = None,
        router_semantic_pool_depth: int = 1,
        # Soft targets
        router_use_soft_targets: bool = False,
        router_soft_target_eps: float = 0.1,
        router_soft_target_k_neighbors: int = 5,
        # Label-side graph regularizer
        router_use_label_graph_reg: bool = False,
        router_label_graph_lambda: float = 0.1,
        router_label_graph_tau: float = 0.07,
        router_label_graph_tau_target: float = 0.1,
        router_label_graph_max_models: int = 256,
        router_label_graph_alpha_domain: float = 0.3,
        # Model registry
        router_registry_path: Optional[str] = None,
        router_registry_init_mode: str = "extend",
        router_registry_base_path: Optional[str] = None,
        # Router debug parameters
        debug_router_supervision: bool = False,
        debug_router_every: int = 100,
        debug_router_first_steps: int = 50,
        debug_router_strict: bool = False,
        # Two-phase training schedule (for Experience 2+ to reduce forgetting)
        router_two_phase_enable: bool = False,
        router_phase1_frac: float = 0.2,
        router_phase1_loss_mode: str = "router",
        router_phase1_replay_ratio: Optional[float] = None,
        router_phase1_router_loss_weight: float = 1.0,
        router_phase1_lm_loss_weight: float = 0.0,
        router_phase1_proj_lr: Optional[float] = None,
        router_phase1_embedding_lr: Optional[float] = None,
        router_phase1_use_soft_targets: bool = False,
        router_phase1_soft_target_eps: float = 0.02,
        router_replay_loss_multiplier: float = 1.0,
        # Exp1-preservation training mode
        router_exp1_preservation_enable: bool = False,
        router_exp1_preservation_M_old: Optional[int] = None,
        # Router embedding anchoring regularizer
        router_anchor_enable: bool = False,
        router_anchor_lambda: float = 1e-3,
        router_anchor_mode: str = "normalized",
        router_anchor_apply_phase: str = "phase1",
        router_anchor_scope: str = "all_old",
        router_anchor_M_old: Optional[int] = None,
        # Router projection anchoring regularizer
        router_proj_anchor_enable: bool = False,
        router_proj_anchor_lambda: float = 1e-2,
        router_proj_anchor_apply_phase: str = "phase1",
        # Router freeze LM option (for router-only runs)
        router_freeze_lm: bool = False,
        **kwargs
    ):
        """
        Initialize the trainer.
        
        Args:
            model: The model to train
            args: SFTConfig training arguments
            train_dataset: Training dataset
            processing_class: Tokenizer
            eval_dataset: Optional evaluation dataset
            callbacks: Optional list of callbacks
            

            

            
            **kwargs: Additional arguments for SFTTrainer
            
        Note on multi-model routing:
            In scenarios where similar prompts may route to different models, enforcing
            output consistency can be counterproductive. Use neighbor_replay_only=True
            to only apply consistency between replay examples (both anchor and neighbor
            from the replay buffer). This preserves knowledge about old models without
            incorrectly pushing new model prompts toward old outputs.
            
            Alternatively, use neighbor_contrastive=True with apply_to="replay_only"
            to add a ranking loss that DISCRIMINATES between similar prompts with
            different target models (hard negative mining).
        """
        # Initialize callbacks list
        if callbacks is None:
            callbacks = []
        
        # # Neighbor consistency and contrastive features have been removed
        # self._use_neighbor_consistency = False
        # self._neighbor_retriever = None
        # self._neighbor_k = 3
        # self._consistency_weight = 0.1
        # self._consistency_temperature = 1.0
        # self._raw_train_prompts = []
        
        # # Additional consistency refinement parameters (disabled)
        # self._neighbor_max_consistency_samples = 4
        # self._neighbor_consistency_num_tokens = 1
        # self._neighbor_domain_filter_mode = "none"
        # self._neighbor_domain_bias = 0.0
        # self._neighbor_min_same_domain = 1
        # self._neighbor_replay_only = False
        # self._neighbor_replay_similarity_threshold = 0.95
        

        
        # Initialize tracking metrics
        self._consistency_loss_sum = 0.0
        self._consistency_loss_count = 0
        self._supervised_loss_sum = 0.0
        self._supervised_loss_count = 0
        
        # Contrastive loss tracking
        self._contrastive_loss_sum = 0.0
        self._contrastive_loss_count = 0
        self._contrastive_anchors_used = 0
        self._contrastive_negatives_used = 0
        

        

        
        # Extract custom parameters that SFTTrainer doesn't accept
        replay_source_examples = kwargs.pop('replay_source_examples', None)
        
        # Call parent init
        super().__init__(
            model=model,
            args=args,
            train_dataset=train_dataset,
            processing_class=processing_class,
            eval_dataset=eval_dataset,
            callbacks=callbacks,
            **kwargs
        )
        
        # Store replay_source_examples for later use
        self._xclr_replay_source_examples = replay_source_examples
        
        # Add our logging callback if any auxiliary loss mode is enabled
        # Also add for router training modes to ensure router checkpoint is saved
        if loss_mode in ["router", "router+graph", "supervised+router", "supervised+router+graph"]:
            self.add_callback(ConsistencyLoggingCallback(self))
        
  
        # ======================================================================
        # Router Training Initialization (Semantic Batching + Candidate Routing)
        # ======================================================================
        self._loss_mode = loss_mode
        self._router_loss_weight = router_loss_weight
        self._lm_loss_weight = lm_loss_weight
        # Semantic batching
        self._semantic_batching = semantic_batching
        self._domains_per_batch = domains_per_batch
        self._mix_replay_in_semantic_batches = mix_replay_in_semantic_batches
        # Router components
        self._router_model = None
        self._router_registry = None
        self._compute_tracker = ComputeTracker()  # Track FLOPs and compute metrics
        self._router_candidate_builder = None
        self._router_hard_miner = None
        self._router_hard_negative_cache = {}
        self._router_use_soft_targets = router_use_soft_targets
        self._router_soft_target_eps = router_soft_target_eps
        self._router_soft_target_k_neighbors = router_soft_target_k_neighbors
        self._router_use_label_graph_reg = router_use_label_graph_reg
        self._router_label_graph_lambda = router_label_graph_lambda
        self._router_label_graph_tau = router_label_graph_tau
        self._router_label_graph_tau_target = router_label_graph_tau_target
        self._router_label_graph_max_models = router_label_graph_max_models
        self._router_label_graph_alpha_domain = router_label_graph_alpha_domain
        self._router_mine_every_steps = router_mine_every_steps
        # Router learning rates (ensure they are floats)
        self._router_proj_lr = float(router_proj_lr) if router_proj_lr is not None else float(args.learning_rate)
        self._router_embedding_lr = float(router_embedding_lr) if router_embedding_lr is not None else float(args.learning_rate)
        # Router debug parameters
        self._debug_router_supervision = debug_router_supervision
        self._debug_router_every = debug_router_every
        self._debug_router_first_steps = debug_router_first_steps
        self._debug_router_strict = debug_router_strict
        
        # Two-phase training schedule (for Experience 2+ to reduce forgetting)
        self._router_two_phase_enable = router_two_phase_enable
        self._router_phase1_frac = router_phase1_frac
        self._router_phase1_loss_mode = router_phase1_loss_mode
        self._router_phase1_replay_ratio = router_phase1_replay_ratio
        self._router_phase1_router_loss_weight = router_phase1_router_loss_weight
        self._router_phase1_lm_loss_weight = router_phase1_lm_loss_weight
        self._router_phase1_proj_lr = float(router_phase1_proj_lr) if router_phase1_proj_lr is not None else self._router_proj_lr
        self._router_phase1_embedding_lr = float(router_phase1_embedding_lr) if router_phase1_embedding_lr is not None else self._router_embedding_lr
        self._router_phase1_use_soft_targets = router_phase1_use_soft_targets
        self._router_phase1_soft_target_eps = router_phase1_soft_target_eps
        self._router_replay_loss_multiplier = router_replay_loss_multiplier
        
        # Exp1-preservation training mode
        self._router_exp1_preservation_enable = router_exp1_preservation_enable
        self._router_exp1_preservation_M_old = router_exp1_preservation_M_old
        # Router embedding anchoring regularizer
        self._router_anchor_enable = router_anchor_enable
        self._router_anchor_lambda = float(router_anchor_lambda) if router_anchor_lambda is not None else 1e-3
        self._router_anchor_mode = router_anchor_mode
        self._router_anchor_apply_phase = router_anchor_apply_phase
        self._router_anchor_scope = router_anchor_scope
        self._router_anchor_M_old = router_anchor_M_old
        self._router_anchor_ref_cpu = None  # Reference snapshot on CPU (FP32, compact)
        self._router_anchor_ref = None  # Cached device copy (materialized on-demand)
        # Projection anchoring
        self._router_proj_anchor_enable = router_proj_anchor_enable
        self._router_proj_anchor_lambda = float(router_proj_anchor_lambda) if router_proj_anchor_lambda is not None else 1e-2
        self._router_proj_anchor_apply_phase = router_proj_anchor_apply_phase
        self._router_proj_anchor_ref_cpu = None  # Reference snapshot of projection weights (CPU, FP32)
        self._router_proj_anchor_ref = None  # Cached device copy (materialized on-demand)
        # Router freeze LM option
        self._router_freeze_lm = router_freeze_lm
        self._exp1_preservation_applied = False  # Track if freezing has been applied
        self._exp1_preservation_hooks = []  # Store gradient hook handles for cleanup
        # Store base registry path for saving to router_config.json (for evaluation diagnostics)
        self._router_registry_base_path = router_registry_base_path
        
        # Phase tracking (computed during training setup)
        self._phase1_steps = None  # Will be computed once during training setup
        self._current_phase = None  # Will be set in compute_loss
        self._phase_transition_logged = False  # Track if we've logged phase transitions
        self._experience_start_global_step = None  # Track when this experience started (for per-experience phase switching)
        self._lm_param_groups_original_lr = None  # Store original LM param group LRs for restoration
        
        # Store original config values (for Phase 2)
        self._original_loss_mode = loss_mode
        self._original_router_loss_weight = router_loss_weight
        self._original_lm_loss_weight = lm_loss_weight
        self._original_router_proj_lr = self._router_proj_lr
        self._original_router_embedding_lr = self._router_embedding_lr
        self._original_router_use_soft_targets = router_use_soft_targets
        self._original_router_soft_target_eps = router_soft_target_eps
        
        # Track if LM params are frozen (for Phase 1)
        self._lm_params_frozen = False
        self._lm_param_groups_original = None  # Store original optimizer param groups
        
        # Router metrics tracking
        self._router_loss_sum = 0.0
        self._router_loss_count = 0
        self._router_graph_loss_sum = 0.0
        self._router_graph_loss_count = 0
        
        # Semantic batching validation
        self._semantic_batch_validation_count = 0
        self._semantic_batch_validation_max = 10  # Validate first 10 batches
        
        # Initialize debug checker for runtime invariant validation
        self._router_debug_checker = RouterDebugChecker(max_check_steps=20)
        
        # Track if we've already logged the frozen LM eval mode message
        self._frozen_lm_eval_mode_logged = False
        
        if loss_mode in ["router", "router+graph", "supervised+router", "supervised+router+graph"]:
            print(f"\n[NeighborConsistencySFTTrainer] Router training mode initialized:")
            print(f"  loss_mode: {loss_mode}")
            print(f"  router_loss_weight: {router_loss_weight}")
            print(f"  lm_loss_weight: {lm_loss_weight}")
            print(f"  tau: {router_tau}")
            print(f"  pooling: {router_pooling}")
            print(f"  K_total: {router_K_total} (K_semantic={router_K_semantic}, K_far={router_K_far}, K_hard={router_K_hard})")
            print(f"  mine_every_steps: {router_mine_every_steps}")
            print(f"  use_soft_targets: {router_use_soft_targets}")
            if router_use_soft_targets:
                print(f"  soft_target_eps: {router_soft_target_eps}")
                print(f"  soft_target_k_neighbors: {router_soft_target_k_neighbors}")
            print(f"  semantic_batching: {semantic_batching}")
            if semantic_batching:
                print(f"    domains_per_batch: {domains_per_batch}")
                print(f"    mix_replay: {mix_replay_in_semantic_batches}")
            if loss_mode in ["router+graph", "supervised+router+graph"]:
                print(f"  label_graph_lambda: {router_label_graph_lambda}")
                print(f"  label_graph_alpha_domain: {router_label_graph_alpha_domain}")
            
            # Build model registry
            print(f"\n[Router] Building model registry...")
            train_examples_for_registry = []
            
            # Determine if we need to process ALL examples (for joint training or large datasets)
            # Check dataset size - if it's large, we should process all to avoid missing models
            train_dataset_size = len(train_dataset) if hasattr(train_dataset, '__len__') else None
            eval_dataset_size = len(eval_dataset) if eval_dataset and hasattr(eval_dataset, '__len__') else None
            
            # Use limits only for smaller datasets (sequential training)
            # For large datasets (likely joint training or 100% replay), process all examples
            max_train_examples = None
            max_eval_examples = None
            
            # For registry building, we need to see all examples to register all models.
            # Only apply limits for very large datasets (>50k) to avoid memory issues.
            # For 100% replay scenarios, combined datasets (previous + current experience)
            # can be 10k-20k examples, so we should process all of them.
            if train_dataset_size and train_dataset_size <= 10000:
                max_train_examples = 10000  # For small datasets, limit equals size (process all)
            elif train_dataset_size and train_dataset_size <= 50000:
                # For medium datasets (including 100% replay), process all examples
                max_train_examples = None
                print(f"  [Medium Dataset Detected] Processing ALL {train_dataset_size} training examples (no limit for registry)")
            elif train_dataset_size:
                print(f"  [Large Dataset Detected] Processing ALL {train_dataset_size} training examples (no limit)")
            
            if eval_dataset_size and eval_dataset_size <= 10000:
                max_eval_examples = 5000  # Limit for smaller datasets
            elif eval_dataset_size:
                print(f"  [Large Dataset Detected] Processing ALL {eval_dataset_size} eval examples (no limit)")
            
            # Collect from training dataset
            if hasattr(train_dataset, '__iter__'):
                try:
                    for example in train_dataset:
                        if isinstance(example, dict):
                            train_examples_for_registry.append(example)
                        if max_train_examples and len(train_examples_for_registry) >= max_train_examples:
                            print(f"  Reached training example limit ({max_train_examples}), stopping collection")
                            break
                except Exception as e:
                    print(f"  Warning: Could not iterate train_dataset for registry: {e}")
            
            # Collect from replay source examples
            if self._xclr_replay_source_examples:
                train_examples_for_registry.extend(self._xclr_replay_source_examples)
            
            # Collect from eval dataset if available
            eval_examples_for_registry = []
            if eval_dataset and hasattr(eval_dataset, '__iter__'):
                try:
                    for example in eval_dataset:
                        if isinstance(example, dict):
                            eval_examples_for_registry.append(example)
                        if max_eval_examples and len(eval_examples_for_registry) >= max_eval_examples:
                            print(f"  Reached eval example limit ({max_eval_examples}), stopping collection")
                            break
                except Exception as e:
                    print(f"  Warning: Could not iterate eval_dataset for registry: {e}")
            
            all_examples_for_registry = train_examples_for_registry + eval_examples_for_registry
            
            # Determine base registry path for extend mode
            base_registry_path = router_registry_base_path
            if router_registry_init_mode == "extend" and not base_registry_path:
                # Try to derive from resume_from_checkpoint if available
                if hasattr(self.args, 'resume_from_checkpoint') and self.args.resume_from_checkpoint:
                    resume_path = self.args.resume_from_checkpoint
                    # Check if it's a checkpoint directory
                    if os.path.isdir(resume_path):
                        potential_registry = os.path.join(resume_path, "model_registry.json")
                        if os.path.exists(potential_registry):
                            base_registry_path = potential_registry
                            print(f"  Derived base registry path from resume_from_checkpoint: {base_registry_path}")
                
                # If still not found, try to find previous experience's latest checkpoint
                # This is useful when using experiences_sequence (e.g., apibench -> mllm)
                if not base_registry_path:
                    # Try to find the latest checkpoint in the output directory's parent
                    # Pattern: {output_root}/{prev_experience}-{variant}/checkpoint-{step}/model_registry.json
                    output_dir = getattr(self.args, 'output_dir', None)
                    if output_dir:
                        output_path = Path(output_dir)
                        # Check if we're in a checkpoint subdirectory
                        if output_path.name.startswith('checkpoint-'):
                            # We're inside a checkpoint, go up to experiment root
                            exp_root = output_path.parent
                        else:
                            # We're at the experiment root
                            exp_root = output_path
                        
                        # Look for checkpoints in this directory
                        checkpoint_dirs = sorted(
                            [d for d in exp_root.iterdir() if d.is_dir() and d.name.startswith('checkpoint-')],
                            key=lambda x: int(x.name.split('-')[1]) if x.name.split('-')[1].isdigit() else 0,
                            reverse=True
                        )
                        
                        # Check the latest checkpoint for model_registry.json
                        for checkpoint_dir in checkpoint_dirs:
                            potential_registry = checkpoint_dir / "model_registry.json"
                            if potential_registry.exists():
                                base_registry_path = str(potential_registry)
                                print(f"  Auto-found base registry from latest checkpoint: {base_registry_path}")
                                break
                        
                        # If not found in current exp dir, try to find previous experience's directory
                        # This handles the case: apibench -> mllm (different experience names)
                        if not base_registry_path and exp_root.parent.exists():
                            # Look for other experience directories in the same output_root
                            parent_dir = exp_root.parent
                            # Get current experience name from directory name (e.g., "mllm-variant" -> "mllm")
                            current_exp_name = exp_root.name.split('-')[0] if '-' in exp_root.name else None
                            
                            if current_exp_name:
                                # Find directories for other experiences (previous ones)
                                other_exp_dirs = [
                                    d for d in parent_dir.iterdir()
                                    if d.is_dir() and not d.name.startswith(current_exp_name) and '-' in d.name
                                ]
                                
                                # Sort by modification time (most recent first) and check for registries
                                other_exp_dirs.sort(key=lambda x: x.stat().st_mtime, reverse=True)
                                
                                for other_exp_dir in other_exp_dirs:
                                    # Find latest checkpoint in this experience directory
                                    other_checkpoints = sorted(
                                        [d for d in other_exp_dir.iterdir() 
                                         if d.is_dir() and d.name.startswith('checkpoint-')],
                                        key=lambda x: int(x.name.split('-')[1]) if x.name.split('-')[1].isdigit() else 0,
                                        reverse=True
                                    )
                                    
                                    for checkpoint_dir in other_checkpoints:
                                        potential_registry = checkpoint_dir / "model_registry.json"
                                        if potential_registry.exists():
                                            base_registry_path = str(potential_registry)
                                            print(f"  Auto-found base registry from previous experience ({other_exp_dir.name}): {base_registry_path}")
                                            break
                                    
                                    if base_registry_path:
                                        break
                
                # Warn loudly if extend mode requested but no base path found
                if not base_registry_path:
                    print(f"  ⚠️  To use extend mode, provide router_registry_base_path or ensure previous checkpoint exists\n")
                    router_registry_init_mode = "fresh"  # Override to fresh mode
            
            # Load or build registry based on init mode
            if router_registry_init_mode == "extend" and base_registry_path:
                # Normalize path: if it's a directory, append model_registry.json
                if os.path.isdir(base_registry_path):
                    base_registry_path = os.path.join(base_registry_path, "model_registry.json")
                
                if os.path.exists(base_registry_path):
                    print(f"  [Extend Mode] Loading base registry from {base_registry_path}")
                    self._router_registry = ModelRegistry.load(base_registry_path)
                M_old = len(self._router_registry)
                
                # Extend registry with new models from current dataset
                print(f"  Extending registry with models from current dataset...")
                num_added = self._router_registry.extend_from_examples(
                    examples=all_examples_for_registry,
                    model_name_key="model_name",
                    domain_key="domain",
                    family_key=None,
                )
                M_new = len(self._router_registry)
                
                print(f"  Loaded base registry: {M_old} models; extended registry: {M_new} models; added: {num_added}")
                
                # Print first 5 newly added model names for sanity
                if num_added > 0:
                    new_model_names = []
                    for idx in range(M_old, M_new):
                        if idx in self._router_registry.idx2model:
                            new_model_names.append(self._router_registry.idx2model[idx])
                    print(f"  First {min(5, len(new_model_names))} newly added models:")
                    for name in new_model_names[:5]:
                        print(f"    - {name}")
                
                # ENFORCE APPEND-ONLY REGISTRY: Assert ID stability (verify ID equality, not just membership)
                # All prior model IDs must be unchanged (critical for continual learning)
                base_registry = None
                if base_registry_path:
                    # Normalize path: if it's a directory, append model_registry.json
                    normalized_path = base_registry_path
                    if os.path.isdir(normalized_path):
                        normalized_path = os.path.join(normalized_path, "model_registry.json")
                    if os.path.exists(normalized_path):
                        base_registry = ModelRegistry.load(normalized_path)
                
                violations = []
                if base_registry is not None:
                    # Check 1: For every old ID i, new.idx2model[i] == old.idx2model[i]
                    for idx in range(M_old):
                        if idx not in self._router_registry.idx2model:
                            violations.append(f"Base model ID {idx} missing after extension")
                        else:
                            current_model = self._router_registry.idx2model[idx]
                            old_model = base_registry.idx2model.get(idx, None)
                            if old_model is None:
                                violations.append(f"ID {idx} exists in extended registry but not in base registry")
                            elif current_model != old_model:
                                violations.append(f"ID {idx} changed: base='{old_model}' -> extended='{current_model}'")
                    
                    # Check 2: For every old model name, new.model2idx[name] == old.model2idx[name]
                    for model_name, old_idx in base_registry.model2idx.items():
                        if model_name not in self._router_registry.model2idx:
                            violations.append(f"Model '{model_name}' (base ID {old_idx}) missing in extended registry")
                        else:
                            new_idx = self._router_registry.model2idx[model_name]
                            if new_idx != old_idx:
                                violations.append(f"Model '{model_name}' ID changed: base={old_idx} -> extended={new_idx}")
                
                if violations:
                    error_msg = "CRITICAL: Registry ID stability violations detected:\n" + "\n".join(f"  - {v}" for v in violations[:10])
                    if len(violations) > 10:
                        error_msg += f"\n  ... and {len(violations) - 10} more violations"
                    raise ValueError(error_msg)
                
                if base_registry is not None:
                    print(f"  ✓ Registry ID stability check passed:")
                    print(f"    - All {M_old} prior IDs unchanged (idx2model equality verified)")
                    print(f"    - All {len(base_registry.model2idx)} prior model names have same IDs (model2idx equality verified)")
                
            elif router_registry_path and os.path.exists(router_registry_path):
                # Load existing registry (backward compatibility)
                print(f"  Loading registry from {router_registry_path}")
                self._router_registry = ModelRegistry.load(router_registry_path)
            else:
                # Build fresh registry
                print(f"  [Fresh Mode] Building registry from scratch")
                self._router_registry = ModelRegistry.from_examples(
                    train_examples=train_examples_for_registry,
                    replay_examples=None,  # Already included above
                    raw_prompts=None,
                )
                if router_registry_path:
                    print(f"  Saving registry to {router_registry_path}")
                    os.makedirs(os.path.dirname(router_registry_path), exist_ok=True)
                    self._router_registry.save(router_registry_path)
            
            print(f"  Registry: {len(self._router_registry)} unique models, {len(self._router_registry.get_all_domains())} domains")
            
            # Initialize router model
            if router_embedding_dim is None:
                # Infer from base model
                if hasattr(model, 'config') and hasattr(model.config, 'hidden_size'):
                    router_embedding_dim = model.config.hidden_size
                elif hasattr(model, 'model') and hasattr(model.model, 'config'):
                    router_embedding_dim = model.model.config.hidden_size
                elif hasattr(model, 'base_model') and hasattr(model.base_model, 'config'):
                    router_embedding_dim = model.base_model.config.hidden_size
                else:
                    router_embedding_dim = 4096  # Fallback
            
            device = next(model.parameters()).device
            dtype = next(model.parameters()).dtype
            
            # Get LM hidden size from model config
            lm_hidden_size = model.config.hidden_size if hasattr(model.config, 'hidden_size') else 4096
            
            # Check if we're resuming from checkpoint and should load router weights
            # =====================================================================
            # Router Checkpoint Loading: CRITICAL for exp2+ to load trained exp1 router
            # =====================================================================
            router_checkpoint_path = None
            resume_path = None
            
            # First, try explicit resume_from_checkpoint
            if hasattr(self.args, 'resume_from_checkpoint') and self.args.resume_from_checkpoint:
                resume_path = self.args.resume_from_checkpoint
                print(f"  [Router Checkpoint] Using explicit resume_from_checkpoint: {resume_path}")
            
            # If not set and in extend mode, try to infer from router_registry_base_path
            elif router_registry_init_mode == "extend" and router_registry_base_path:
                # router_registry_base_path can be:
                # 1. A file: .../checkpoint-XXX/model_registry.json -> router_model.pt in same dir
                # 2. A directory: .../experiment-name -> look for latest checkpoint-XXX/router_model.pt
                if os.path.isdir(router_registry_base_path):
                    # It's a directory, look for latest checkpoint inside it
                    exp_dir = Path(router_registry_base_path)
                    checkpoint_dirs = sorted(
                        [d for d in exp_dir.iterdir() if d.is_dir() and d.name.startswith('checkpoint-')],
                        key=lambda x: int(x.name.split('-')[1]) if x.name.split('-')[1].isdigit() else 0,
                        reverse=True
                    )
                    for checkpoint_dir in checkpoint_dirs:
                        potential_router = checkpoint_dir / "router_model.pt"
                        if potential_router.exists():
                            resume_path = str(checkpoint_dir)
                            print(f"  [Router Checkpoint] Inferred resume_from_checkpoint from router_registry_base_path (latest checkpoint): {resume_path}")
                            break
                    if not resume_path:
                        print(f"  ⚠️  [Router Checkpoint] router_registry_base_path is directory {router_registry_base_path}")
                        print(f"     But no checkpoint directories with router_model.pt found")
                else:
                    # It's a file path, router_model.pt should be in the same directory
                    base_registry_dir = os.path.dirname(router_registry_base_path)
                    if os.path.isdir(base_registry_dir):
                        potential_router = os.path.join(base_registry_dir, "router_model.pt")
                        if os.path.exists(potential_router):
                            resume_path = base_registry_dir
                            print(f"  [Router Checkpoint] Inferred resume_from_checkpoint from router_registry_base_path: {resume_path}")
                        else:
                            print(f"  ⚠️  [Router Checkpoint] router_registry_base_path points to {base_registry_dir}")
                            print(f"     But router_model.pt not found at {potential_router}")
            
            # Now check for router checkpoint in the resolved path
            if resume_path:
                if os.path.isdir(resume_path):
                    potential_router = os.path.join(resume_path, "router_model.pt")
                    if os.path.exists(potential_router):
                        router_checkpoint_path = potential_router
                        print(f"  ✓ Found router checkpoint at {router_checkpoint_path}")
                    else:
                        print(f"  ⚠️  Router checkpoint NOT found at {potential_router}")
                        print(f"     Router will be randomly initialized (not loading exp1 weights)")
                else:
                    print(f"  ⚠️  resume_from_checkpoint is not a directory: {resume_path}")
            else:
                print(f"  ⚠️  No resume_from_checkpoint specified")
                if router_registry_init_mode == "extend":
                    print(f"     WARNING: Registry init_mode='extend' but no checkpoint to load from!")
            
            # Create router model with current registry size
            print(f"  Creating router model: num_models={len(self._router_registry)}, embedding_dim={router_embedding_dim}")
            self._router_model = RouterModel(
                num_models=len(self._router_registry),
                embedding_dim=router_embedding_dim,
                lm_hidden_size=lm_hidden_size,
                tau=router_tau,
                pooling=router_pooling,
            ).to(device=device, dtype=dtype)

            # Load router weights from checkpoint if available, handling embedding resize
            if router_checkpoint_path:
                print(f"\n  [Router Checkpoint Loading] Loading router weights from {router_checkpoint_path}")
                try:
                    import json
                    router_config_path = os.path.join(os.path.dirname(router_checkpoint_path), "router_config.json")
                    
                    # Load checkpoint state
                    router_state = torch.load(router_checkpoint_path, map_location=device)
                    print(f"  [Router Checkpoint Loading] Checkpoint loaded, keys: {list(router_state.keys())[:5]}... (total: {len(router_state)} keys)")
                    
                    # Track if embeddings will be manually handled (resized)
                    embeddings_manually_handled = False
                    
                    # Check if we need to resize embeddings
                    if "model_embeddings.weight" in router_state:
                        old_emb_shape = router_state["model_embeddings.weight"].shape
                        new_emb_shape = self._router_model.model_embeddings.weight.shape
                        M_old, D_old = old_emb_shape
                        M_new, D_new = new_emb_shape
                        print(f"  [Router Checkpoint Loading] Embedding shapes: checkpoint={old_emb_shape}, current={new_emb_shape}")
                        
                        # Auto-detect M_old for exp1-preservation if not set
                        if self._router_exp1_preservation_enable and self._router_exp1_preservation_M_old is None:
                            self._router_exp1_preservation_M_old = M_old
                            print(f"  [Exp1-Preservation] Auto-detected M_old={M_old} from checkpoint")
                        
                        # Auto-detect M_old for router anchor if not set
                        if self._router_anchor_enable and self._router_anchor_M_old is None:
                            self._router_anchor_M_old = M_old
                            print(f"  [Router Anchor] Auto-detected M_old={M_old} from checkpoint")
                        
                        if M_new != M_old or D_new != D_old:
                            print(f"  Router embedding size changed: {old_emb_shape} → {new_emb_shape}")
                            
                            if D_new != D_old:
                                raise ValueError(
                                    f"Embedding dimension mismatch: checkpoint has {D_old}, "
                                    f"but model requires {D_new}. Cannot resize."
                                )
                            
                            # Resize embedding table
                            old_emb = router_state["model_embeddings.weight"]  # [M_old, D]
                            new_emb = self._router_model.model_embeddings.weight  # [M_new, D]
                            
                            # Copy overlapping rows
                            overlap = min(M_old, M_new)
                            with torch.no_grad():
                                new_emb[:overlap].copy_(old_emb[:overlap])
                            
                            # Initialize remaining rows if registry grew
                            if M_new > M_old:
                                with torch.no_grad():
                                    # Use xavier_uniform for new rows (matching RouterModel init)
                                    nn.init.xavier_uniform_(new_emb[M_old:])
                                print(f"  ✓ Copied {overlap} embedding rows, initialized {M_new - M_old} new rows")
                            else:
                                print(f"  ✓ Copied {overlap} embedding rows (registry shrunk)")
                            
                            # Free old embedding tensor from memory immediately
                            del old_emb
                            
                            # Mark that embeddings were manually handled
                            embeddings_manually_handled = True
                            
                            # Remove embedding from state dict to avoid shape mismatch
                            # (embeddings were already manually copied above)
                            router_state = {k: v for k, v in router_state.items() if k != "model_embeddings.weight"}
                    
                    # Load remaining weights (non-strict to handle architecture changes)
                    missing, unexpected = self._router_model.load_state_dict(router_state, strict=False)
                    
                    # Verify critical parameters were loaded (before deleting router_state)
                    critical_keys = ["model_embeddings.weight", "prompt_projection.weight"]
                    # Filter out model_embeddings.weight from missing if it was manually handled
                    missing_for_diagnostic = [k for k in missing if not (k == "model_embeddings.weight" and embeddings_manually_handled)]
                    missing_critical = [k for k in critical_keys if k in missing_for_diagnostic]
                    

                    
                    
                    # Free checkpoint state dict from GPU memory after loading
                    del router_state
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                except Exception as e:
                    print(f"  ❌ Error loading router checkpoint: {e}")
                    import traceback
                    traceback.print_exc()
                    print(f"  ⚠️  Router will be randomly initialized (checkpoint loading failed)")
            
            # =====================================================================
            # Router Embedding Anchoring: Capture/Load Reference Snapshot
            # =====================================================================
            # CRITICAL: Only capture snapshot when starting exp2 from exp1, not when resuming exp2
            # The reference must be the exp1 embeddings, not already-drifted exp2 embeddings
            if self._router_anchor_enable:
                # Determine M_old (base registry size)
                anchor_M_old = self._router_anchor_M_old
                if anchor_M_old is None:
                    # Try to infer from exp1-preservation M_old if available
                    if self._router_exp1_preservation_M_old is not None:
                        anchor_M_old = self._router_exp1_preservation_M_old
                        print(f"  [Router Anchor] Using M_old from exp1-preservation: {anchor_M_old}")
                    elif router_registry_init_mode == "extend" and base_registry_path:
                        # Infer from base registry size
                        try:
                            # Normalize path: if it's a directory, append model_registry.json
                            normalized_path = base_registry_path
                            if os.path.isdir(normalized_path):
                                normalized_path = os.path.join(normalized_path, "model_registry.json")
                            if os.path.exists(normalized_path):
                                # ModelRegistry is already imported at top level
                                base_registry = ModelRegistry.load(normalized_path)
                                anchor_M_old = len(base_registry)
                            else:
                                anchor_M_old = None
                        except Exception as e:
                            print(f"  ⚠️  [Router Anchor] Could not infer M_old from base registry: {e}")
                            anchor_M_old = None
                
                if anchor_M_old is not None and anchor_M_old > 0:
                    M_new = len(self._router_registry)
                    if M_new > anchor_M_old:
                        # Try to load saved anchor reference from checkpoint (if resuming exp2)
                        anchor_loaded = False
                        if router_checkpoint_path:
                            anchor_ref_path = os.path.join(os.path.dirname(router_checkpoint_path), "router_anchor_ref.pt")
                            if os.path.exists(anchor_ref_path):
                                try:
                                    self._router_anchor_ref_cpu = torch.load(anchor_ref_path, map_location="cpu")
                                    if self._router_anchor_ref_cpu.shape[0] == anchor_M_old:                                      anchor_loaded = True
                                except Exception as e:
                                    print(f"  ⚠️  [Router Anchor] Could not load reference from checkpoint: {e}")
                        
                        # Only capture new snapshot if:
                        # 1. We're starting exp2 from exp1 (extend mode + no checkpoint anchor ref found)
                        # 2. We're NOT resuming exp2 (which would have a saved anchor ref)
                        if not anchor_loaded:
                            # Check if we're actually starting exp2 from exp1
                            is_starting_exp2_from_exp1 = (
                                router_registry_init_mode == "extend" and 
                                base_registry_path and 
                                os.path.exists(base_registry_path) and
                                (not router_checkpoint_path or not os.path.exists(os.path.join(os.path.dirname(router_checkpoint_path), "router_anchor_ref.pt")))
                            )
                            
                            if is_starting_exp2_from_exp1:
                                # Capture reference snapshot of old embedding rows (exp1 state)
                                with torch.no_grad():
                                    E_old = self._router_model.model_embeddings.weight[:anchor_M_old]
                                    # DIAGNOSTICS: Check data_ptr() before snapshot
                                    live_ptr_before = E_old.data_ptr() if hasattr(E_old, 'data_ptr') else None
                                    
                                    # Store as FP32 CPU for compactness
                                    self._router_anchor_ref_cpu = E_old.detach().float().cpu().clone()
                                    # Ensure it doesn't require grad
                                    self._router_anchor_ref_cpu.requires_grad_(False)
                                
                                # DIAGNOSTICS: Check data_ptr() after snapshot capture
                                ref_ptr_after = self._router_anchor_ref_cpu.data_ptr() if hasattr(self._router_anchor_ref_cpu, 'data_ptr') else None
                                with torch.no_grad():
                                    E_old_current = self._router_model.model_embeddings.weight[:anchor_M_old]
                                    live_ptr_after = E_old_current.data_ptr() if hasattr(E_old_current, 'data_ptr') else None
                                    print(f"  [Router Anchor] After snapshot capture:")
                                    print(f"    ref.data_ptr(): {ref_ptr_after}")
                                    print(f"    live_embeddings.weight[:M_old].data_ptr(): {live_ptr_after}")
                                    if ref_ptr_after is not None and live_ptr_after is not None:
                                        if ref_ptr_after == live_ptr_after:
                                            print(f"    ⚠️  WARNING: Reference and live embeddings share the same data_ptr()! ALIASING DETECTED!")
                                            print(f"    This means the snapshot is aliasing live weights. Fix: ensure .detach().clone() is used.")
                                        else:
                                            print(f"    ✓ No aliasing detected (different data_ptr)")
                                
                                # Sanity check: verify snapshot matches current state (should be 0 diff)
                                with torch.no_grad():
                                    E_old_current = self._router_model.model_embeddings.weight[:anchor_M_old].float().cpu()
                                    max_diff = (E_old_current - self._router_anchor_ref_cpu).abs().max().item()
                                    mean_diff = (E_old_current - self._router_anchor_ref_cpu).abs().mean().item()
                                    print(f"  [Router Anchor] Captured exp1 reference snapshot: M_old={anchor_M_old}, M_new={M_new}")
                                    print(f"  [Router Anchor] Sanity check: max abs diff = {max_diff:.2e}, mean abs diff = {mean_diff:.2e} (should be ~0)")
                                    if max_diff > 1e-5:
                                        print(f"  ⚠️  WARNING: Snapshot diff is non-zero! This may indicate a problem.")
                                    
                                    # Compute initial anchor loss for reference
                                    if self._router_anchor_mode == "normalized":
                                        E_old_norm = F.normalize(E_old_current, p=2, dim=-1)
                                        ref_norm = F.normalize(self._router_anchor_ref_cpu, p=2, dim=-1)
                                        initial_anchor_loss = (E_old_norm - ref_norm).pow(2).mean().item()
                                    else:
                                        initial_anchor_loss = (E_old_current - self._router_anchor_ref_cpu).pow(2).mean().item()
                                    lambda_val = float(self._router_anchor_lambda) if not isinstance(self._router_anchor_lambda, (int, float)) else self._router_anchor_lambda
                                    print(f"  [Router Anchor] Initial anchor loss (should be ~0): {initial_anchor_loss:.6f}")
                                    print(f"  [Router Anchor] Initial weighted anchor loss (λ={lambda_val:.1e}): {lambda_val * initial_anchor_loss:.6f}")
                            else:
                                print(f"  ⚠️  [Router Anchor] Not starting exp2 from exp1 (or anchor ref missing), disabling anchoring")
                                self._router_anchor_enable = False
                                anchor_M_old = None
                        
                        # Store M_old for later use
                        if anchor_M_old is not None:
                            self._router_anchor_M_old = anchor_M_old
                    else:
                        print(f"  ⚠️  [Router Anchor] M_new ({M_new}) <= M_old ({anchor_M_old}), disabling anchoring")
                        self._router_anchor_enable = False
                else:
                    print(f"  ⚠️  [Router Anchor] Could not determine M_old, disabling anchoring")
                    self._router_anchor_enable = False
            
            # Capture projection anchor snapshot if enabled
            if self._router_proj_anchor_enable and self._router_model is not None:
                with torch.no_grad():
                    # Snapshot projection weights immediately after loading/resizing (CPU FP32)
                    self._router_proj_anchor_ref_cpu = {
                        k: v.detach().clone().cpu().float() 
                        for k, v in self._router_model.prompt_projection.state_dict().items()
                    }
                    # Ensure no gradients
                    for v in self._router_proj_anchor_ref_cpu.values():
                        v.requires_grad_(False)
                    # Device cache will be materialized on-demand in _compute_router_proj_anchor_loss
                    self._router_proj_anchor_ref = None
                    print(f"  [Router Proj Anchor] Captured projection reference snapshot (CPU FP32)")
                    # Log initial projection anchor loss for sanity check
                    proj_anchor_loss_init = self._compute_router_proj_anchor_loss()
                    if proj_anchor_loss_init is not None:
                        lambda_val = float(self._router_proj_anchor_lambda) if not isinstance(self._router_proj_anchor_lambda, (int, float)) else self._router_proj_anchor_lambda
                        print(f"  [Router Proj Anchor] Initial proj anchor loss (should be ~0): {proj_anchor_loss_init.item():.8e}")
                        print(f"  [Router Proj Anchor] Initial weighted loss (λ={lambda_val:.1e}): {lambda_val * proj_anchor_loss_init.item():.8e}")
            

            
            # Initialize candidate builder
            self._router_candidate_builder = CandidateSetBuilder(
                registry=self._router_registry,
                K_total=router_K_total,
                K_semantic=router_K_semantic,
                K_far=router_K_far,
                K_hard=router_K_hard,
                semantic_pool_mode=router_semantic_pool_mode,
                semantic_pool_max_domains=router_semantic_pool_max_domains,
                semantic_pool_depth=router_semantic_pool_depth,
            )
            
            # Initialize hard negative miner
            self._router_hard_miner = HardNegativeMiner(
                registry=self._router_registry,
                K_hard_pool=router_K_hard_pool,
                semantic_pool_size=router_semantic_pool_size,
                max_pool_size=router_max_pool_size,
                semantic_pool_mode=router_semantic_pool_mode,
                semantic_pool_max_domains=router_semantic_pool_max_domains,
                semantic_pool_depth=router_semantic_pool_depth,
            )
            
            # Enable hidden states output for prompt embedding extraction
            enable_hidden_states(model)
            
            # If the LM is frozen (router-only training), force it to eval mode
            # This ensures dropout is disabled and the model behaves deterministically
            self._ensure_frozen_lm_in_eval_mode()
    
    def _ensure_frozen_lm_in_eval_mode(self):
        """
        If the base LM is frozen (all parameters have requires_grad=False),
        force it to eval mode to disable dropout and ensure deterministic behavior.
        
        This is critical for router-only training modes where the LM is just used
        as a feature extractor and should not be updated.
        
        NOTE: Only applies to "router" and "router+graph" modes. For "supervised+router"
        and "supervised+router+graph", the LM should be trained, so it stays in train mode.
        """
        # Only apply this logic for pure router modes (not supervised+router modes)
        if self._loss_mode not in ["router", "router+graph"]:
            return
        
        # Get the base model (could be model, model.model, or model.base_model)
        base_model = self.model
        if hasattr(self.model, 'model'):
            base_model = self.model.model
        elif hasattr(self.model, 'base_model'):
            base_model = self.model.base_model
        
        # Check if base model parameters are frozen
        base_params = list(base_model.parameters())
        if not base_params:
            return
        
        # Check if all base model parameters are frozen
        all_frozen = all(not p.requires_grad for p in base_params)
        
        if all_frozen:
            # Base model is frozen - set to eval mode and disable dropout
            base_model.eval()
            if not self._frozen_lm_eval_mode_logged:
                self._frozen_lm_eval_mode_logged = True
            
            # Explicitly disable dropout modules
            for module in base_model.modules():
                if isinstance(module, torch.nn.Dropout):
                    module.eval()
                elif hasattr(module, 'dropout') and isinstance(module.dropout, torch.nn.Dropout):
                    module.dropout.eval()
                elif hasattr(module, 'attention_dropout') and isinstance(module.attention_dropout, torch.nn.Dropout):
                    module.attention_dropout.eval()
            
            # Ensure router is in train mode
            if self._router_model is not None:
                self._router_model.train()
        else:
            # Some parameters are trainable - check if router is the only trainable component
            trainable_base_params = [p for p in base_params if p.requires_grad]
            router_params = list(self._router_model.parameters()) if self._router_model is not None else []
            trainable_router_params = [p for p in router_params if p.requires_grad] if router_params else []
            
            # If router has trainable params but base model doesn't, set base to eval
            if trainable_router_params and not trainable_base_params:
                base_model.eval()

                
                # Explicitly disable dropout modules
                for module in base_model.modules():
                    if isinstance(module, torch.nn.Dropout):
                        module.eval()
                    elif hasattr(module, 'dropout') and isinstance(module.dropout, torch.nn.Dropout):
                        module.dropout.eval()
                    elif hasattr(module, 'attention_dropout') and isinstance(module.attention_dropout, torch.nn.Dropout):
                        module.attention_dropout.eval()
                
                # Ensure router is in train mode
                if self._router_model is not None:
                    self._router_model.train()
    
    def create_optimizer(self):
        """
        Create the optimizer and register the router parameters with it.

        The base optimizer is built by the parent implementation. When a router
        loss mode is active and ``self._router_model`` exists, up to two extra
        param groups are appended:

        * ``self._router_model.prompt_projection`` parameters with
          ``lr=self._router_proj_lr``;
        * ``self._router_model.model_embeddings`` parameters with
          ``lr=self._router_embedding_lr``;

        both with ``weight_decay=self.args.weight_decay``. Parameters that the
        optimizer already manages are skipped, because
        ``torch.optim.Optimizer.add_param_group`` raises ``ValueError`` when a
        parameter appears in more than one group. This can happen when the router
        is part of the trained model, or when this method runs again on an
        existing optimizer.

        Returns:
            ``self.optimizer``.
        """
        # Call parent to create the base optimizer
        super().create_optimizer()

        router_modes = ["router", "router+graph", "supervised+router", "supervised+router+graph"]
        if self._loss_mode in router_modes and self._router_model is not None:
            already_registered = {
                id(p) for group in self.optimizer.param_groups for p in group['params']
            }
            # (log label, submodule, learning rate) for each router param group.
            router_groups = (
                ("projection head", self._router_model.prompt_projection, self._router_proj_lr),
                ("embedding table", self._router_model.model_embeddings, self._router_embedding_lr),
            )
            for label, module, lr in router_groups:
                params = list(module.parameters())
                if not params:
                    continue
                new_params = [p for p in params if id(p) not in already_registered]
                if not new_params:
                    print(f"[Router] {label} parameters are already in the optimizer; not re-adding")
                    continue
                self.optimizer.add_param_group({
                    'params': new_params,
                    'lr': lr,
                    'weight_decay': self.args.weight_decay,
                })
                already_registered.update(id(p) for p in new_params)
                num_params = sum(p.numel() for p in new_params)
                print(f"[Router] Added {label} parameters to optimizer "
                      f"({num_params} params, lr={lr})")

        return self.optimizer
    
    def save_router_checkpoint(self, output_dir: str):
        """
        Persist the router so it can be evaluated (or exp2 training resumed).

        Writes into ``output_dir``, which is created if needed:

        * ``router_model.pt``: the router ``state_dict``;
        * ``model_registry.json``: written by ``self._router_registry.save``;
        * ``router_config.json``: router hyper-parameters and candidate-set
          sizes, plus the optional exp1-preservation and anchor metadata;
        * ``router_anchor_ref.pt``: only when router anchoring is enabled and a
          CPU anchor reference exists.

        Nothing is written when the router model or the registry is missing.

        Args:
            output_dir: Directory to save router checkpoints
        """
        if self._router_model is None or self._router_registry is None:
            return

        import json
        import os

        router = self._router_model
        registry = self._router_registry
        os.makedirs(output_dir, exist_ok=True)

        weights_file = os.path.join(output_dir, "router_model.pt")
        torch.save(router.state_dict(), weights_file)
        print(f"✓ Saved router model to {weights_file}")

        registry.save(os.path.join(output_dir, "model_registry.json"))

        # Key order matters for the written JSON: model hyper-parameters first,
        # then candidate-set sizes, then optional metadata.
        builder = self._router_candidate_builder
        config = dict(
            num_models=len(registry),
            embedding_dim=router.embedding_dim,
            lm_hidden_size=router.lm_hidden_size,
            tau=router.tau,
            pooling=router.pooling,
        )
        for size_key in ("K_total", "K_semantic", "K_far", "K_hard"):
            config[size_key] = getattr(builder, size_key)

        # exp1-preservation info is kept for evaluation diagnostics.
        if self._router_registry_base_path:
            config["router_registry_base_path"] = self._router_registry_base_path
        if self._router_exp1_preservation_M_old is not None:
            config["router_exp1_preservation_M_old"] = self._router_exp1_preservation_M_old

        anchoring = self._router_anchor_enable
        if anchoring and self._router_anchor_M_old is not None:
            config.update(
                router_anchor_M_old=self._router_anchor_M_old,
                router_anchor_mode=self._router_anchor_mode,
            )

        with open(os.path.join(output_dir, "router_config.json"), 'w') as config_file:
            json.dump(config, config_file, indent=2)

        # The anchor reference lets exp2 training resume with the same anchor.
        if anchoring and self._router_anchor_ref_cpu is not None:
            torch.save(self._router_anchor_ref_cpu, os.path.join(output_dir, "router_anchor_ref.pt"))

    
    def get_train_dataloader(self) -> DataLoader:
        """
        Build the training dataloader, optionally with semantic (domain) batching.

        When ``self._semantic_batching`` is set and a router loss mode is active,
        batches come from ``DomainBatchSampler`` (``self._domains_per_batch``
        domains per batch, seeded with ``args.seed`` (42 if unset) plus the
        current integer epoch). Otherwise the parent implementation is used.

        If two-phase router training is enabled, the Phase 1 length is
        ``int(self._router_phase1_frac * total_optimizer_steps)``. The total is
        ``args.max_steps`` when positive, otherwise
        ``(batches_per_epoch // gradient_accumulation_steps) * num_train_epochs``
        truncated to an int. The semantic-batching path recomputes it on every
        call; the default path only fills it in while it is still unset.
        """
        def _total_optimizer_steps(num_batches):
            # TrainingArguments.max_steps, when positive, overrides num_train_epochs.
            max_steps = getattr(self.args, 'max_steps', None)
            if max_steps is not None and max_steps > 0:
                return max_steps
            # num_train_epochs is a float; keep the step count integral.
            return int(num_batches // self.args.gradient_accumulation_steps * self.args.num_train_epochs)

        router_modes = ["router", "router+graph", "supervised+router", "supervised+router+graph"]

        # Router semantic batching takes priority
        if self._semantic_batching and self._loss_mode in router_modes:
            from ..data_carve.samplers import DomainBatchSampler

            dataset = self.train_dataset

            # Determine seed for reproducibility (shifted per epoch)
            seed = self.args.seed if self.args.seed is not None else 42
            if hasattr(self, 'state') and hasattr(self.state, 'epoch') and self.state.epoch is not None:
                seed = seed + int(self.state.epoch)

            batch_sampler = DomainBatchSampler(
                dataset=dataset,
                batch_size=self.args.per_device_train_batch_size,
                domains_per_batch=self._domains_per_batch,
                shuffle=True,
                seed=seed,
                drop_last=self.args.dataloader_drop_last,
            )

            # Compute Phase 1 steps for two-phase training
            if self._router_two_phase_enable:
                total_optimizer_steps = _total_optimizer_steps(len(batch_sampler))
                self._phase1_steps = int(self._router_phase1_frac * total_optimizer_steps)
                print(f"    Phase 1 (stability warmup): steps 0-{self._phase1_steps-1} ({self._router_phase1_frac*100:.1f}% of training)")
                print(f"    Phase 2 (main training): steps {self._phase1_steps}-{total_optimizer_steps-1} ({(1-self._router_phase1_frac)*100:.1f}% of training)")

            return DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                collate_fn=self.data_collator,
                num_workers=self.args.dataloader_num_workers,
                pin_memory=self.args.dataloader_pin_memory,
            )

        # Use default parent implementation
        dataloader = super().get_train_dataloader()

        # Compute phase1_steps once (for non-semantic batching case)
        if self._router_two_phase_enable and self._phase1_steps is None:
            total_optimizer_steps = _total_optimizer_steps(len(dataloader))
            self._phase1_steps = int(self._router_phase1_frac * total_optimizer_steps)
            print(f"    Phase 1 (stability warmup): {self._phase1_steps} steps ({self._router_phase1_frac*100:.1f}% of training)")
            print(f"    Phase 2 (main training): {total_optimizer_steps - self._phase1_steps} steps ({(1-self._router_phase1_frac)*100:.1f}% of training)")

        return dataloader
    
    def _compute_phase1_steps_from_dataloader(self):
        """Compute phase1_steps from dataloader (canonical computation, called once)."""
        if not self._router_two_phase_enable or self._phase1_steps is not None:
            return
        
        # Get dataloader to compute total steps
        try:
            dataloader = super().get_train_dataloader()
            steps_per_epoch = len(dataloader) // self.args.gradient_accumulation_steps
            total_optimizer_steps = int(steps_per_epoch * self.args.num_train_epochs)
            
            # Use max_steps if available (takes precedence)
            if hasattr(self.args, 'max_steps') and self.args.max_steps is not None and self.args.max_steps > 0:
                total_optimizer_steps = self.args.max_steps
            
            self._phase1_steps = int(self._router_phase1_frac * total_optimizer_steps)
        except Exception as e:
            print(f"  Warning: Could not compute phase1_steps from dataloader: {e}")
            # Will be computed later when dataloader is available
    
    def _freeze_lm_parameters(self, model: nn.Module):
        """Freeze LM parameters for Phase 1 by setting LR to 0 (more reliable than requires_grad=False).
        """
        # CRITICAL: Ensure router remains trainable FIRST (before freezing anything)
        if hasattr(self, '_router_model') and self._router_model is not None:
            for param in self._router_model.parameters():
                param.requires_grad = True
        
        if not hasattr(self, 'optimizer') or self.optimizer is None:
            # Fallback: use requires_grad if optimizer not available
            # CRITICAL: Only freeze base_model, not router
            if hasattr(model, 'base_model'):
                # CompositeModelWithRouter: freeze only base_model
                for param in model.base_model.parameters():
                    param.requires_grad = False
                # Ensure router remains trainable (if it's part of the composite)
                if hasattr(model, 'router'):
                    for param in model.router.parameters():
                        param.requires_grad = True
            else:
                # Regular model (PEFT wrapped): freeze the model itself
                # But router is separate in self._router_model, so this is OK
                for param in model.parameters():
                    param.requires_grad = False
            print(f"  [Phase 1] Set requires_grad=False for LM parameters (optimizer not available)")
            return
        
        # Store original LRs for restoration
        if self._lm_param_groups_original_lr is None:
            self._lm_param_groups_original_lr = {}
        
        # Identify LM parameter groups (exclude router and X-CLR projection groups)
        # CRITICAL: Router params are in separate param groups added in create_optimizer        
        lm_param_group_indices = []
        router_param_group_indices = []
        
        for i, param_group in enumerate(self.optimizer.param_groups):
            # Check parameter IDs to identify router groups
            # Router params were added in create_optimizer with specific structure
            param_ids_in_group = {id(p) for p in param_group['params']}
            
            # Check if this group contains router parameters
            is_router_group = False
            if hasattr(self, '_router_model') and self._router_model is not None:
                router_param_ids = {id(p) for p in self._router_model.parameters()}
                if param_ids_in_group & router_param_ids:  # Intersection is non-empty
                    is_router_group = True
                    router_param_group_indices.append(i)
            
            # Also check by name (fallback)
            group_name = param_group.get('name', '').lower()
            if not is_router_group and ('router' in group_name):
                is_router_group = True
                router_param_group_indices.append(i)
            
            if not is_router_group:
                # This is an LM parameter group
                lm_param_group_indices.append(i)
                # Store original LR
                if i not in self._lm_param_groups_original_lr:
                    self._lm_param_groups_original_lr[i] = param_group['lr']
        
        # Set LM param group LRs to 0
        num_lm_groups = 0
        for i in lm_param_group_indices:
            self.optimizer.param_groups[i]['lr'] = 0.0
            num_lm_groups += 1
        
        # Verify router groups are NOT frozen
        for i in router_param_group_indices:
            if self.optimizer.param_groups[i]['lr'] == 0.0:
                #print(f"  ⚠️  WARNING: Router param group {i} has LR=0! Restoring router LR.")
                # Restore router LR (use phase1 router LR if in phase1, else original)
                if hasattr(self, '_router_phase1_proj_lr'):
                    # Determine if this is proj or embedding group (simplified)
                    if 'proj' in self.optimizer.param_groups[i].get('name', '').lower():
                        self.optimizer.param_groups[i]['lr'] = self._router_phase1_proj_lr
                    else:
                        self.optimizer.param_groups[i]['lr'] = self._router_phase1_embedding_lr
                else:
                    self.optimizer.param_groups[i]['lr'] = self._router_proj_lr  # Fallback
        
        # Also set requires_grad=False as backup - BUT ONLY FOR LM, NOT ROUTER
        # CRITICAL: Only freeze base_model, not router
        if hasattr(model, 'base_model'):
            # CompositeModelWithRouter: freeze only base_model
            for param in model.base_model.parameters():
                param.requires_grad = False
            # Ensure router remains trainable (if it's part of the composite)
            if hasattr(model, 'router'):
                for param in model.router.parameters():
                    param.requires_grad = True
        else:
            # Regular model (PEFT wrapped): freeze the model itself
            # Router is separate in self._router_model, already set to trainable above
            for param in model.parameters():
                param.requires_grad = False
        
        # CRITICAL: Ensure router remains trainable (double-check)
        if hasattr(self, '_router_model') and self._router_model is not None:
            for param in self._router_model.parameters():
                param.requires_grad = True
        
        # Verify router parameters are actually trainable
        router_trainable_count = 0
        router_total_count = 0
        if hasattr(self, '_router_model') and self._router_model is not None:
            for param in self._router_model.parameters():
                router_total_count += 1
                if param.requires_grad:
                    router_trainable_count += 1
        
        if router_total_count > 0 and router_trainable_count == 0:
            raise RuntimeError("CRITICAL: All router parameters are frozen! Router must be trainable in Phase 1.")
    
    def _apply_exp1_preservation(self):
        """
        Apply exp1-preservation mode: freeze old embeddings during Phase 1.
        
        Freezes:
        - router_model.model_embeddings[:M_old] (old rows) - using gradient hook
        
        Keeps trainable:
        - router_model.model_embeddings[M_old:] (new rows)
        - router_model.prompt_projection (projection) - MUST remain trainable to adapt to new embeddings
        
        NOTE: Projection is NOT frozen because:
        1. Frozen projection + frozen old embeddings creates a mismatch
        2. New embeddings (random init) can learn to work with frozen projection
        3. Old embeddings (frozen) cannot adapt to frozen projection
        4. This biases router toward new models, causing accuracy drop on exp1
        5. Keeping projection trainable allows it to adapt to work with both old and new embeddings
        """
        if not self._router_exp1_preservation_enable:
            return
        
        if self._router_model is None:
            return
        
        M_old = self._router_exp1_preservation_M_old
        if M_old is None:
            print(f"  ⚠️  [Exp1-Preservation] M_old not set, skipping preservation mode")
            return
        
        M_new = len(self._router_registry)
        if M_new <= M_old:
            print(f"  ⚠️  [Exp1-Preservation] M_new ({M_new}) <= M_old ({M_old}), skipping preservation mode")
            return
        
        # Register gradient hook to zero gradients for old embedding rows
        old_emb_count = 0
        new_emb_count = 0
        if hasattr(self._router_model, 'model_embeddings'):
            emb_weight = self._router_model.model_embeddings.weight
            old_emb_count = M_old * emb_weight.shape[1]
            new_emb_count = (M_new - M_old) * emb_weight.shape[1]
            
            # Register hook to zero gradients for old rows
            def zero_old_grad_hook(grad):
                if grad is not None:
                    # Zero out gradients for old rows [0:M_old]
                    grad_clone = grad.clone()
                    grad_clone[:M_old] = 0.0
                    return grad_clone
                return grad
            
            # Register the hook (store handle for later removal)
            if not hasattr(self, '_exp1_preservation_hooks'):
                self._exp1_preservation_hooks = []
            hook_handle = emb_weight.register_hook(zero_old_grad_hook)
            self._exp1_preservation_hooks.append(hook_handle)
        

        
        self._exp1_preservation_applied = True
    
    def _remove_exp1_preservation(self):
        """
        Remove exp1-preservation mode: unfreeze old embeddings for Phase 2.
        """
        if not self._router_exp1_preservation_enable or not self._exp1_preservation_applied:
            return
        
        if self._router_model is None:
            return
        
        # Remove gradient hooks for old embedding rows
        if hasattr(self, '_exp1_preservation_hooks'):
            for hook_handle in self._exp1_preservation_hooks:
                hook_handle.remove()
            self._exp1_preservation_hooks = []
        
        self._exp1_preservation_applied = False
    
    def _unfreeze_lm_parameters(self, model: nn.Module):
        """Unfreeze LM parameters for Phase 2 by restoring original LRs.
        
        CRITICAL: Only unfreezes LM/base_model. Router should already be trainable.
        """
        if hasattr(self, '_router_model') and self._router_model is not None:
            for param in self._router_model.parameters():
                param.requires_grad = True
        
        if not hasattr(self, 'optimizer') or self.optimizer is None:
            # Fallback: use requires_grad if optimizer not available
            if hasattr(model, 'base_model'):
                for param in model.base_model.parameters():
                    param.requires_grad = True
                # Ensure router remains trainable (if it's part of the composite)
                if hasattr(model, 'router'):
                    for param in model.router.parameters():
                        param.requires_grad = True
            else:
                # Regular model: unfreeze all
                for param in model.parameters():
                    param.requires_grad = True
            print(f"  [Phase 2] Set requires_grad=True for LM parameters (optimizer not available)")
            return
        
        # Restore original LRs for LM groups only
        if self._lm_param_groups_original_lr is not None:
            num_restored = 0
            for i, original_lr in self._lm_param_groups_original_lr.items():
                if i < len(self.optimizer.param_groups):
                    self.optimizer.param_groups[i]['lr'] = original_lr
                    num_restored += 1
            print(f"  [Phase 2] Restored original LR for {num_restored} LM param groups")
            self._lm_param_groups_original_lr = None
        
        # Restore requires_grad - BUT ONLY FOR LM, ROUTER SHOULD ALREADY BE TRAINABLE
        # CRITICAL: Only unfreeze base_model, router should already be trainable
        if hasattr(model, 'base_model'):
            for param in model.base_model.parameters():
                param.requires_grad = True
            # Ensure router remains trainable (if it's part of the composite)
            if hasattr(model, 'router'):
                for param in model.router.parameters():
                    param.requires_grad = True
        else:
            # Regular model: unfreeze all
            for param in model.parameters():
                param.requires_grad = True
        
        # CRITICAL: Ensure router remains trainable (double-check)
        if hasattr(self, '_router_model') and self._router_model is not None:
            for param in self._router_model.parameters():
                param.requires_grad = True
    
    def _update_router_learning_rates(self, proj_lr: float, embedding_lr: float):
        """Update router learning rates in optimizer."""
        if not hasattr(self, 'optimizer') or self.optimizer is None:
            return
        
        # Find router parameter groups in optimizer
        # Router parameters are typically in groups with names containing "router" or "model_embedding"
        updated_proj = False
        updated_embedding = False
        
        for param_group in self.optimizer.param_groups:
            # Check if this is the router projection group
            if 'router' in param_group.get('name', '').lower() or 'proj' in param_group.get('name', '').lower():
                if 'embedding' not in param_group.get('name', '').lower():
                    param_group['lr'] = proj_lr
                    updated_proj = True
            # Check if this is the router embedding group
            elif 'embedding' in param_group.get('name', '').lower() or 'model_embedding' in param_group.get('name', '').lower():
                param_group['lr'] = embedding_lr
                updated_embedding = True
        
        # If parameter groups don't have names, try to identify by parameter names
        if not updated_proj or not updated_embedding:
            for param_group in self.optimizer.param_groups:
                # Check parameter names in this group
                if len(param_group['params']) > 0:
                    first_param_name = ''
                    if hasattr(param_group['params'][0], 'name'):
                        first_param_name = param_group['params'][0].name
                    elif hasattr(self, '_router_model') and self._router_model is not None:
                        # Try to match by checking if params belong to router model
                        router_param_ids = {id(p) for p in self._router_model.parameters()}
                        if any(id(p) in router_param_ids for p in param_group['params']):
                            # This is a router parameter group
                            # Check if it's projection or embedding
                            if hasattr(self._router_model, 'projection') and any(id(p) in router_param_ids for p in self._router_model.projection.parameters()):
                                param_group['lr'] = proj_lr
                                updated_proj = True
                            elif hasattr(self._router_model, 'model_embedding') and any(id(p) in router_param_ids for p in self._router_model.model_embedding.parameters()):
                                param_group['lr'] = embedding_lr
                                updated_embedding = True
        
        if updated_proj or updated_embedding:
            print(f"  [LR Update] Router proj_lr: {proj_lr:.2e}, embedding_lr: {embedding_lr:.2e}")
    
    def _log_router_param_groups(self, step: Optional[int] = None):
        """
        Log router parameter groups with their LRs for diagnostics.
        """
        if not hasattr(self, 'optimizer') or self.optimizer is None:
            return
        
        step_str = f" @ step {step}" if step is not None else ""
        print(f"\n  [Router Param Groups{step_str}]:")
        
        router_proj_lr = None
        router_embedding_lr = None
        
        for i, param_group in enumerate(self.optimizer.param_groups):
            name = param_group.get('name', 'unnamed')
            lr = param_group['lr']
            num_params = sum(p.numel() for p in param_group['params'])
            
            # Determine if this is a router group
            is_router = False
            is_proj = False
            is_embedding = False
            
            name_lower = name.lower()
            if 'router' in name_lower:
                is_router = True
                if 'proj' in name_lower and 'embedding' not in name_lower:
                    is_proj = True
                    router_proj_lr = lr
                elif 'embedding' in name_lower or 'model_embedding' in name_lower:
                    is_embedding = True
                    router_embedding_lr = lr
            
            # Also check by parameter ownership if name doesn't help
            if not is_router and hasattr(self, '_router_model') and self._router_model is not None:
                router_param_ids = {id(p) for p in self._router_model.parameters()}
                if any(id(p) in router_param_ids for p in param_group['params']):
                    is_router = True
                    if hasattr(self._router_model, 'projection') and any(id(p) in router_param_ids for p in self._router_model.projection.parameters()):
                        is_proj = True
                        router_proj_lr = lr
                    elif hasattr(self._router_model, 'model_embedding') and any(id(p) in router_param_ids for p in self._router_model.model_embedding.parameters()):
                        is_embedding = True
                        router_embedding_lr = lr
            
            group_type = "ROUTER"
            if is_proj:
                group_type = "ROUTER_PROJ"
            elif is_embedding:
                group_type = "ROUTER_EMB"
            elif not is_router:
                group_type = "LM"
            
            print(f"    Group {i}: {name} | LR={lr:.2e} | {num_params:,} params | {group_type}")
        
        # Explicitly log router LRs
        if router_proj_lr is not None or router_embedding_lr is not None:
            print(f"  [Router LRs{step_str}]: proj={router_proj_lr:.2e if router_proj_lr is not None else 'N/A'}, "
                  f"embedding={router_embedding_lr:.2e if router_embedding_lr is not None else 'N/A'}")
    
    def lr_scheduler_step(self, scheduler, metric):
        """
        Step the LR scheduler, then re-apply the router's intended learning rates.

        A scheduler step can rewrite the LR of every optimizer param group,
        router groups included. After the parent step, the router groups are
        reset to the Phase 1 router LRs when two-phase training is enabled and
        the current step (relative to ``self._experience_start_global_step``,
        treated as 0 when unset) is below ``self._phase1_steps``. In every other
        case they are reset to the original router LRs. Nothing is re-applied
        when no optimizer exists yet.

        NOTE(review): ``transformers.Trainer`` does not appear to define an
        ``lr_scheduler_step`` hook (the name matches PyTorch Lightning's API).
        Confirm that a parent class provides it and actually calls it;
        otherwise this restoration never runs.
        """
        # Let the parent step the scheduler first.
        super().lr_scheduler_step(scheduler, metric)

        if not hasattr(self, 'optimizer') or self.optimizer is None:
            return

        in_phase1 = False
        if self._router_two_phase_enable:
            has_step = hasattr(self, 'state') and hasattr(self.state, 'global_step')
            current_step = self.state.global_step if has_step else 0
            experience_offset = self._experience_start_global_step or 0
            phase1_length = self._phase1_steps
            in_phase1 = phase1_length is not None and (current_step - experience_offset) < phase1_length

        if in_phase1:
            proj_lr = self._router_phase1_proj_lr
            embedding_lr = self._router_phase1_embedding_lr
        else:
            # Phase 2, or two-phase training disabled.
            proj_lr = self._original_router_proj_lr
            embedding_lr = self._original_router_embedding_lr
        self._update_router_learning_rates(proj_lr, embedding_lr)

    
    def compute_loss(
        self,
        model: nn.Module,
        inputs: Dict[str, torch.Tensor],
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        """
        Compute the training loss, adding router-side losses in router loss modes.

        Steps:
        1. Once, at ``global_step == 0``: report how many non-router (LM)
           parameters still require grad and, if ``router_freeze_lm`` is set,
           freeze them.
        2. If two-phase training is enabled, decide Phase 1 vs Phase 2 from the
           step index *within the current experience* and apply that phase's
           loss mode, loss weights, soft-target settings, LM freeze state and
           router learning rates; otherwise use the configured values.
        3. Forward pass. Router loss modes call ``model`` directly with
           ``output_hidden_states=True`` (under ``torch.no_grad()`` in Phase 1)
           and keep only the last hidden state; other modes defer to
           ``SFTTrainer.compute_loss``.
        4. Router loss modes add the weighted routing loss plus the optional
           label-graph, router-embedding-anchor and router-projection-anchor
           regularisers. Failures there are reported (first failure, then on
           every 100th step) and the step continues with the supervised loss
           scaled by the effective LM loss weight.

        Args:
            model: Model being trained.
            inputs: Collated batch. Router modes read ``input_ids``,
                ``attention_mask`` and ``labels`` here; the routing-loss helper
                also reads batch metadata such as ``model_name`` and ``domain``.
            return_outputs: If True, return ``(loss, outputs)``.
            num_items_in_batch: Forwarded to the parent ``compute_loss``.

        Returns:
            ``total_loss``, or ``(total_loss, outputs)`` when ``return_outputs``.
            In router loss modes ``outputs`` only exposes ``.loss`` and a
            1-tuple ``.hidden_states`` holding the last layer.
        """
        router_loss_modes = ["router", "router+graph", "supervised+router", "supervised+router+graph"]
        router_only_modes = ["router", "router+graph"]

        def _is_router_param_name(param_name: str) -> bool:
            """Name heuristic used to exclude router parameters from the LM checks."""
            lowered = param_name.lower()
            return 'router' in lowered or 'prompt_projection' in lowered or 'model_embeddings' in lowered

        global_step = self.state.global_step if hasattr(self, 'state') and hasattr(self.state, 'global_step') else 0

        # ======================================================================
        # LM Sanity Check: Verify LM is not unintentionally updating in router-only runs
        # ======================================================================
        # global_step stays 0 for every micro-batch before the first optimizer
        # step (e.g. under gradient accumulation), so run this only once.
        if global_step == 0 and not getattr(self, '_lm_sanity_check_done', False):
            self._lm_sanity_check_done = True
            # Count LM parameters with requires_grad=True
            lm_params_with_grad = 0
            lm_total_params = 0
            for name, param in model.named_parameters():
                if _is_router_param_name(name):
                    continue
                lm_total_params += 1
                if param.requires_grad:
                    lm_params_with_grad += 1

            print(f"\n  [LM Sanity Check @ step 0]")
            print(f"    LM params with requires_grad=True: {lm_params_with_grad}/{lm_total_params}")
            if self._loss_mode in router_only_modes:
                if lm_params_with_grad > 0:
                    print(f"    ⚠️  WARNING: {lm_params_with_grad} LM params have requires_grad=True in router-only mode!")
                    print(f"    Consider setting router_freeze_lm=True to force requires_grad=False")
                else:
                    print(f"    ✓ All LM params have requires_grad=False (expected for router-only mode)")

            # Apply router_freeze_lm if enabled
            if self._router_freeze_lm:
                frozen_count = 0
                for name, param in model.named_parameters():
                    if _is_router_param_name(name):
                        continue
                    if param.requires_grad:
                        param.requires_grad = False
                        frozen_count += 1
                if frozen_count > 0:
                    print(f"    [router_freeze_lm] Froze {frozen_count} LM parameters (requires_grad=False)")
                    # Also put the LM in eval mode to disable dropout
                    if hasattr(model, 'base_model'):
                        model.base_model.eval()
                    elif hasattr(model, 'model'):
                        model.model.eval()
                    print(f"    [router_freeze_lm] Set LM to eval mode (dropout disabled)")

                    # Warn if optimizer already created (should freeze before optimizer creation)
                    if hasattr(self, 'optimizer') and self.optimizer is not None:
                        print(f"    ⚠️  WARNING: Optimizer already created! LM params may still be in optimizer.")
                        print(f"    Consider setting router_freeze_lm=True before optimizer creation.")

        # ======================================================================
        # Two-Phase Training: Phase Switching Logic (PER-EXPERIENCE)
        # ======================================================================
        if self._router_two_phase_enable:
            # Initialize experience start step on first call (per-experience phase tracking)
            if self._experience_start_global_step is None:
                self._experience_start_global_step = global_step
                print(f"\n  [Two-Phase Training] Experience started at global_step={global_step}")

            # Compute step within this experience (critical for continual learning)
            step_in_experience = global_step - self._experience_start_global_step

            # Compute phase1_steps if not already computed
            if self._phase1_steps is None:
                # Try to compute from dataloader if available
                if hasattr(self, 'train_dataloader') and self.train_dataloader is not None:
                    steps_per_epoch = len(self.train_dataloader) // self.args.gradient_accumulation_steps
                    total_optimizer_steps = int(steps_per_epoch * self.args.num_train_epochs)
                    if hasattr(self.args, 'max_steps') and self.args.max_steps is not None and self.args.max_steps > 0:
                        total_optimizer_steps = self.args.max_steps
                    self._phase1_steps = int(self._router_phase1_frac * total_optimizer_steps)
                    print(f"  [Two-Phase Training] Computed phase1_steps={self._phase1_steps} from dataloader")

            # Determine current phase based on step WITHIN EXPERIENCE (not global step).
            # While phase1_steps is unknown, this resolves to Phase 2.
            is_phase1 = (self._phase1_steps is not None and step_in_experience < self._phase1_steps)
            current_phase = "Phase 1" if is_phase1 else "Phase 2"

            # Log and handle phase transitions
            if current_phase != self._current_phase:
                if self._current_phase is not None:
                    print(f"[PHASE TRANSITION] Switching from {self._current_phase} to {current_phase}")
                    # Diagnostic logging only: it must not abort the training step.
                    # NOTE(review): the "[Router LRs]" f-string in _log_router_param_groups
                    # puts a conditional inside the format spec, which raises
                    # ValueError/TypeError; fix it there.
                    try:
                        self._log_router_param_groups(step=global_step)
                    except (ValueError, TypeError) as log_err:
                        print(f"  Warning: router param-group logging failed: {log_err}")
                else:
                    # First call - log initial phase
                    print(f"[TWO-PHASE TRAINING] Starting in {current_phase}")
                    if is_phase1:
                        # Apply exp1-preservation if enabled
                        if self._router_exp1_preservation_enable:
                            self._apply_exp1_preservation()
                        # Log router anchor settings
                        if self._router_anchor_enable:
                            lambda_val = float(self._router_anchor_lambda) if not isinstance(self._router_anchor_lambda, (int, float)) else self._router_anchor_lambda
                            print(f"      lambda: {lambda_val:.1e}")
                    else:
                        print(f"  Phase 2 settings (using original config):")

                if self._current_phase is not None and current_phase == "Phase 2":
                    # Unfreeze the LM once on the transition to Phase 2. Clear the
                    # flag so the Phase 2 branch below does not unfreeze it again.
                    self._unfreeze_lm_parameters(model)
                    self._lm_params_frozen = False
                    # Remove exp1-preservation when transitioning to Phase 2
                    self._remove_exp1_preservation()

                self._current_phase = current_phase
                self._phase_transition_logged = True

            if is_phase1:
                # Phase 1 overrides: loss mode and weights
                effective_loss_mode = self._router_phase1_loss_mode
                effective_router_loss_weight = self._router_phase1_router_loss_weight
                effective_lm_loss_weight = self._router_phase1_lm_loss_weight
                effective_use_soft_targets = self._router_phase1_use_soft_targets
                effective_soft_target_eps = self._router_phase1_soft_target_eps

                # Freeze LM parameters in Phase 1 (via _freeze_lm_parameters)
                if not self._lm_params_frozen:
                    self._freeze_lm_parameters(model)
                    self._lm_params_frozen = True

                # Update router learning rates for Phase 1
                if hasattr(self, 'optimizer') and self.optimizer is not None:
                    self._update_router_learning_rates(
                        self._router_phase1_proj_lr,
                        self._router_phase1_embedding_lr
                    )
            else:
                # Phase 2: use original config
                effective_loss_mode = self._original_loss_mode
                effective_router_loss_weight = self._original_router_loss_weight
                # In router-only mode, force LM loss weight to 0 at runtime
                if effective_loss_mode in router_only_modes:
                    effective_lm_loss_weight = 0.0
                else:
                    effective_lm_loss_weight = self._original_lm_loss_weight
                effective_use_soft_targets = self._original_router_use_soft_targets
                effective_soft_target_eps = self._original_router_soft_target_eps

                # Unfreeze LM parameters if still marked frozen
                if self._lm_params_frozen:
                    self._unfreeze_lm_parameters(model)
                    self._lm_params_frozen = False

                # Restore original router learning rates
                if hasattr(self, 'optimizer') and self.optimizer is not None:
                    self._update_router_learning_rates(
                        self._original_router_proj_lr,
                        self._original_router_embedding_lr
                    )
        else:
            # Two-phase disabled: use original config
            is_phase1 = False
            effective_loss_mode = self._loss_mode
            effective_router_loss_weight = self._router_loss_weight
            # In router-only mode, force LM loss weight to 0 at runtime
            if effective_loss_mode in router_only_modes:
                effective_lm_loss_weight = 0.0
            else:
                effective_lm_loss_weight = self._lm_loss_weight
            effective_use_soft_targets = self._router_use_soft_targets
            effective_soft_target_eps = self._router_soft_target_eps

        # Ensure frozen base model stays in eval mode (dropout disabled)
        if effective_loss_mode in router_only_modes or (self._router_two_phase_enable and is_phase1):
            self._ensure_frozen_lm_in_eval_mode()

        # ======================================================================
        # Forward pass / supervised loss
        # ======================================================================
        if effective_loss_mode in router_loss_modes:
            class _LastHiddenOutputs:
                """Minimal stand-in for model outputs: ``.loss`` plus the last hidden state."""

                def __init__(self, loss_val, last_hidden):
                    self.loss = loss_val
                    # 1-tuple so that ``hidden_states[-1]`` yields the last layer.
                    self.hidden_states = (last_hidden,)

            forward_kwargs = dict(
                input_ids=inputs.get("input_ids"),
                attention_mask=inputs.get("attention_mask"),
                labels=inputs.get("labels"),
                output_hidden_states=True,
                return_dict=True,
            )
            if self._router_two_phase_enable and is_phase1:
                # Phase 1: the LM is frozen, so skip building an autograd graph through it.
                with torch.no_grad():
                    full_outputs = model(**forward_kwargs)
            else:
                full_outputs = model(**forward_kwargs)

            # Keep only the last hidden state and the LM loss.
            last_hidden_state = full_outputs.hidden_states[-1].clone()
            loss_supervised = full_outputs.loss

            # Drop references to the other layers' hidden states and the logits.
            # NOTE(review): when a graph is built (Phase 2), tensors autograd saved
            # for backward stay alive regardless, so the saving is limited.
            del full_outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # In Phase 1 this turns the detached hidden state into a leaf that
            # requires grad, so the router loss can still backprop into router
            # parameters. When the tensor already requires grad this is a no-op.
            last_hidden_state = last_hidden_state.requires_grad_()

            # NOTE(review): this object is neither a dict nor indexable; confirm no
            # caller that receives it via return_outputs (e.g. an evaluation
            # prediction_step slicing outputs[1:]) relies on that.
            outputs = _LastHiddenOutputs(loss_supervised, last_hidden_state)
        else:
            need_outputs = return_outputs or self._use_neighbor_contrastive

            if need_outputs:
                loss_supervised, outputs = super().compute_loss(
                    model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
                )
            else:
                loss_supervised = super().compute_loss(
                    model, inputs, return_outputs=False, num_items_in_batch=num_items_in_batch
                )
                outputs = None

        # In router loss modes, honour the effective LM loss weight even when no
        # router loss is added below (None / non-positive / exception). Previously
        # the unweighted LM loss leaked through, even in router-only modes where
        # the LM weight is forced to 0.
        # NOTE(review): in Phase 1 loss_supervised carries no grad_fn, so a batch
        # without router loss still yields a loss that cannot be backpropagated.
        total_loss = loss_supervised
        if (effective_loss_mode in router_loss_modes
                and loss_supervised is not None
                and effective_lm_loss_weight is not None):
            total_loss = effective_lm_loss_weight * loss_supervised

        # =================================================================
        # Router Loss Computation (Semantic Batching + Candidate Routing)
        # =================================================================
        if effective_loss_mode in router_loss_modes:
            try:
                # Apply replay pressure if in Phase 1 and multiplier > 1.0
                apply_replay_multiplier = (self._router_two_phase_enable and
                                           is_phase1 and
                                           self._router_replay_loss_multiplier > 1.0)

                router_loss = self._compute_routing_loss_for_batch(
                    model, inputs, outputs,
                    apply_replay_multiplier=apply_replay_multiplier,
                    replay_loss_multiplier=self._router_replay_loss_multiplier if apply_replay_multiplier else 1.0,
                    use_soft_targets=effective_use_soft_targets,
                    soft_target_eps=effective_soft_target_eps,
                )

                if router_loss is not None and router_loss > 0:
                    self._router_loss_sum += router_loss.item()
                    self._router_loss_count += 1

                    # Apply weighted combination based on effective loss mode
                    if effective_loss_mode in router_only_modes:
                        # Pure router mode: use ONLY router loss (no LM loss)
                        total_loss = effective_router_loss_weight * router_loss
                    else:
                        # Supervised+router modes: combine both losses
                        total_loss = (effective_lm_loss_weight * loss_supervised +
                                      effective_router_loss_weight * router_loss)

                    # Optional label-side graph regularizer
                    if effective_loss_mode in ["router+graph", "supervised+router+graph"]:
                        graph_loss = self._compute_router_graph_regularizer_for_batch(inputs)
                        if graph_loss is not None and graph_loss > 0:
                            self._router_graph_loss_sum += graph_loss.item()
                            self._router_graph_loss_count += 1
                            total_loss = total_loss + self._router_label_graph_lambda * graph_loss

                    # Router embedding anchoring regularizer
                    if self._router_anchor_enable and self._router_anchor_ref_cpu is not None:
                        should_apply = False
                        if self._router_anchor_apply_phase == "both":
                            should_apply = True
                        elif self._router_anchor_apply_phase == "phase1":
                            # With two-phase disabled there is no Phase 1, so
                            # "phase1" behaves like "both".
                            should_apply = is_phase1 if self._router_two_phase_enable else True
                        elif self._router_anchor_apply_phase == "phase2" and not is_phase1:
                            should_apply = True

                        if should_apply:
                            anchor_loss = self._compute_router_anchor_loss()
                            if anchor_loss is not None and anchor_loss > 0:
                                if not hasattr(self, '_router_anchor_loss_sum'):
                                    self._router_anchor_loss_sum = 0.0
                                    self._router_anchor_loss_count = 0
                                    self._router_anchor_weighted_sum = 0.0
                                anchor_loss_val = anchor_loss.item()
                                weighted_anchor_loss = self._router_anchor_lambda * anchor_loss
                                self._router_anchor_loss_sum += anchor_loss_val
                                self._router_anchor_loss_count += 1
                                self._router_anchor_weighted_sum += weighted_anchor_loss.item()
                                total_loss = total_loss + weighted_anchor_loss

                    # Router projection anchoring regularizer
                    if self._router_proj_anchor_enable and self._router_proj_anchor_ref_cpu is not None:
                        proj_should_apply = False
                        if self._router_proj_anchor_apply_phase == "both":
                            proj_should_apply = True
                        elif self._router_proj_anchor_apply_phase == "phase1":
                            # With two-phase disabled there is no Phase 1, so
                            # "phase1" behaves like "both".
                            proj_should_apply = is_phase1 if self._router_two_phase_enable else True
                        elif self._router_proj_anchor_apply_phase == "phase2" and not is_phase1:
                            proj_should_apply = True

                        if proj_should_apply:
                            proj_anchor_loss = self._compute_router_proj_anchor_loss()
                            if proj_anchor_loss is not None and proj_anchor_loss > 0:
                                if not hasattr(self, '_router_proj_anchor_loss_sum'):
                                    self._router_proj_anchor_loss_sum = 0.0
                                    self._router_proj_anchor_loss_count = 0
                                    self._router_proj_anchor_weighted_sum = 0.0
                                proj_anchor_loss_val = proj_anchor_loss.item()
                                weighted_proj_anchor_loss = self._router_proj_anchor_lambda * proj_anchor_loss
                                self._router_proj_anchor_loss_sum += proj_anchor_loss_val
                                self._router_proj_anchor_loss_count += 1
                                self._router_proj_anchor_weighted_sum += weighted_proj_anchor_loss.item()
                                total_loss = total_loss + weighted_proj_anchor_loss

            except Exception as e:
                # Best-effort: a failing router loss must not kill training. Count
                # failures and always report the first one, so failures are not
                # hidden between the every-100-steps report points.
                self._router_loss_failure_count = getattr(self, '_router_loss_failure_count', 0) + 1
                if self._router_loss_failure_count == 1 or self.state.global_step % 100 == 0:
                    import traceback
                    print(f"  Warning: Router loss computation failed: {e}")
                    traceback.print_exc()

        if return_outputs:
            return total_loss, outputs
        return total_loss
    
    
    def _compute_routing_loss_for_batch(
        self,
        model: nn.Module,
        inputs: Dict[str, torch.Tensor],
        outputs: Any,
        apply_replay_multiplier: bool = False,
        replay_loss_multiplier: float = 1.0,
        use_soft_targets: Optional[bool] = None,
        soft_target_eps: Optional[float] = None,
    ) -> Optional[torch.Tensor]:
        """
        Compute routing loss for the current batch.
        
        Steps:
        1. Get hidden states from model outputs
        2. Extract model_name and domain from batch metadata
        3. Build candidate sets for each example
        4. Compute routing logits
        5. Compute loss (hard or soft targets)
        6. Optionally mine hard negatives
        
        Returns:
            Routing loss tensor, or None if computation fails
        """
        if self._router_model is None or self._router_registry is None:
            return None
        
        batch_size = inputs["input_ids"].shape[0]
        device = inputs["input_ids"].device
        
        # Get hidden states
        hidden_states = get_last_hidden_states(outputs)
        if hidden_states is None:
            enable_hidden_states(model)
            return None
        
        # Note: outputs.hidden_states contains all layer activations which can be large.
        # get_last_hidden_states() extracts only the last layer, but intermediate layers
        
        # Get labels and attention mask
        labels = inputs.get("labels")
        if labels is None:
            return None
        attention_mask = inputs.get("attention_mask")
        
        # Extract metadata from batch
        if "model_name" not in inputs or "domain" not in inputs:
            return None
        
        model_names = inputs["model_name"]
        domains = inputs["domain"]
        
        # Convert to lists if needed
        if not isinstance(model_names, list):
            model_names = list(model_names) if hasattr(model_names, '__iter__') else [str(model_names)] * batch_size
        if not isinstance(domains, list):
            domains = list(domains) if hasattr(domains, '__iter__') else [str(domains)] * batch_size
        
        # Convert model names to indices with canonical normalization
        # Helper function for normalized lookup (model2idx stores original names        
        def _get_model_idx(model_name: str) -> Optional[int]:
            """Get model index with normalized lookup (case-insensitive)."""
            # First try direct lookup (fast path)
            if model_name in self._router_registry.model2idx:
                return self._router_registry.model2idx[model_name]
            # Then try normalized lookup (handles case/normalization mismatches)
            normalized = normalize_model_name(model_name)
            for existing_name, idx in self._router_registry.model2idx.items():
                if normalize_model_name(existing_name) == normalized:
                    return idx
            return None
        
        # Get is_replay flags if available (for tracking replay vs non-replay unknowns)
        is_replay_list = None
        if "is_replay" in inputs:
            is_replay_raw = inputs["is_replay"]
            if isinstance(is_replay_raw, list):
                is_replay_list = is_replay_raw
            elif hasattr(is_replay_raw, '__iter__'):
                is_replay_list = list(is_replay_raw)
            else:
                is_replay_list = [bool(is_replay_raw)] * batch_size
        
        # Track stats BEFORE normalization (for diagnostics)
        unknown_before_norm = []
        unknown_replay_before = []
        unknown_nonreplay_before = []
        
        # Track stats AFTER normalization
        y_indices = []
        valid_indices = []
        unknown_models = []
        unknown_replay = []
        unknown_nonreplay = []
        
        for i, model_name in enumerate(model_names):
            # Check BEFORE normalization
            if model_name not in self._router_registry.model2idx:
                unknown_before_norm.append(model_name)
                if is_replay_list is not None and i < len(is_replay_list) and is_replay_list[i]:
                    unknown_replay_before.append(model_name)
                else:
                    unknown_nonreplay_before.append(model_name)
            
            # Try normalized lookup
            model_idx = _get_model_idx(model_name)
            if model_idx is not None:
                y_indices.append(model_idx)
                valid_indices.append(i)
            else:
                # Still unknown after normalization
                unknown_models.append(model_name)
                if is_replay_list is not None and i < len(is_replay_list) and is_replay_list[i]:
                    unknown_replay.append(model_name)
                else:
                    unknown_nonreplay.append(model_name)
        
        # Track filtering stats
        num_filtered = len(model_names) - len(y_indices)
        num_unknown = len(unknown_models)
        num_unknown_before_norm = len(unknown_before_norm)
        num_unknown_after_norm = len(unknown_models)
        num_fixed_by_norm = num_unknown_before_norm - num_unknown_after_norm
        
        # Accumulate filtering stats for periodic logging (with normalization diagnostics)
        if not hasattr(self, '_router_filter_stats'):
            self._router_filter_stats = {
                'total_examples': 0, 
                'total_filtered': 0, 
                'total_unknown': 0,
                'unknown_before_norm': 0,
                'unknown_after_norm': 0,
                'fixed_by_norm': 0,
                'unknown_replay_before': 0,
                'unknown_nonreplay_before': 0,
                'unknown_replay_after': 0,
                'unknown_nonreplay_after': 0,
            }
        self._router_filter_stats['total_examples'] += len(model_names)
        self._router_filter_stats['total_filtered'] += num_filtered
        self._router_filter_stats['total_unknown'] += num_unknown
        self._router_filter_stats['unknown_before_norm'] += num_unknown_before_norm
        self._router_filter_stats['unknown_after_norm'] += num_unknown_after_norm
        self._router_filter_stats['fixed_by_norm'] += num_fixed_by_norm
        self._router_filter_stats['unknown_replay_before'] += len(unknown_replay_before)
        self._router_filter_stats['unknown_nonreplay_before'] += len(unknown_nonreplay_before)
        self._router_filter_stats['unknown_replay_after'] += len(unknown_replay)
        self._router_filter_stats['unknown_nonreplay_after'] += len(unknown_nonreplay)
        
        # Log normalization fix impact periodically
        if self.state.global_step % 100 == 0 and self._router_filter_stats['total_examples'] > 0:
            total = self._router_filter_stats['total_examples']
            unknown_before = self._router_filter_stats['unknown_before_norm']
            unknown_after = self._router_filter_stats['unknown_after_norm']
            fixed = self._router_filter_stats['fixed_by_norm']
            replay_before = self._router_filter_stats['unknown_replay_before']
            replay_after = self._router_filter_stats['unknown_replay_after']
            nonreplay_before = self._router_filter_stats['unknown_nonreplay_before']
            nonreplay_after = self._router_filter_stats['unknown_nonreplay_after']
            
            if unknown_before > 0:
                print(f"  [Model Name Normalization] Step {self.state.global_step}:")
                print(f"    Unknown before norm: {unknown_before} ({100*unknown_before/total:.2f}%)")
                print(f"    Unknown after norm: {unknown_after} ({100*unknown_after/total:.2f}%)")
                print(f"    Fixed by normalization: {fixed} ({100*fixed/max(1,unknown_before):.1f}% of unknowns)")
                if replay_before > 0 or replay_after > 0:
                    print(f"    Replay unknowns: {replay_before} → {replay_after} (fixed: {replay_before - replay_after})")
                if nonreplay_before > 0 or nonreplay_after > 0:
                    print(f"    Non-replay unknowns: {nonreplay_before} → {nonreplay_after} (fixed: {nonreplay_before - nonreplay_after})")
                
                # Debug mode: print a few dropped replay model names
                if hasattr(self, '_router_debug_unknown_models') and self._router_debug_unknown_models:
                    if unknown_replay:
                        print(f"    [Debug] Sample dropped replay models (after norm): {unknown_replay[:3]}")
                        if len(unknown_replay) > 3:
                            print(f"    [Debug] ... and {len(unknown_replay) - 3} more")
        
        # Warn if replay examples are being dropped
        if len(unknown_replay) > 0 and self.state.global_step % 50 == 0:
            print(f"  ⚠️  [Replay Drop Warning] {len(unknown_replay)} replay examples dropped due to unknown models (step {self.state.global_step})")
        
        if not y_indices:
            return None
                
        # Normalize domains for consistent lookups
        normalized_domains = [normalize_domain(domains[i]) for i in valid_indices]
        
        # Build candidate sets
        candidate_list = self._router_candidate_builder.build_batch(
            y_indices=y_indices,
            domains=normalized_domains,
            hard_negative_cache=self._router_hard_negative_cache,
        )
        
        # =====================================================================
        # B) Candidate hygiene filtering (before building logits / loss)
        # =====================================================================
        from .router_training import filter_and_validate_candidates
        
        # Get gold model names for filtering
        gold_model_names = [self._router_registry.idx2model.get(y_idx, f"unknown_idx_{y_idx}") for y_idx in y_indices]
        
        # Filter and validate candidates
        filtered_candidates, updated_y_indices, filter_stats = filter_and_validate_candidates(
            candidates_list=candidate_list,
            gold_model_names=gold_model_names,
            gold_indices=y_indices,
            registry=self._router_registry,
            K_total=self._router_candidate_builder.K_total,
            debug=(self.state.global_step < 3),
        )
        
        # Update y_indices if gold was re-inserted (shouldn't happen, but be safe)
        y_indices = updated_y_indices
        candidate_list = filtered_candidates
        
        # Convert to tensor
        candidate_indices = torch.tensor(candidate_list, dtype=torch.long, device=device)  # [B', K]
        
        # Store candidate indices and gold indices for anchor loss computation (if enabled)
        if self._router_anchor_enable and self._router_anchor_scope == "touched":
            # Store for use in anchor loss computation
            self._router_anchor_candidate_indices = candidate_indices
            self._router_anchor_y_indices = y_indices
        
        # =====================================================================
        # A) Strict label↔candidate alignment assertion
        # =====================================================================
        from .router_training import check_label_candidate_alignment
        
        # Check alignment: candidates_i[gold_index_i] == gold_model
        # Note: gold_index_i should always be 0 (positive at index 0)
        check_label_candidate_alignment(
            candidates_list=candidate_list,
            gold_model_names=gold_model_names,
            gold_indices=[0] * len(y_indices),  # Always 0 since positive is at index 0
            registry=self._router_registry,
            debug=(self.state.global_step < 3),
        )
        
        
        # Filter inputs to valid indices
        hidden_states_valid = hidden_states[valid_indices]  # [B', seq_len, D]
        labels_valid = labels[valid_indices]  # [B', seq_len]
        attention_mask_valid = attention_mask[valid_indices] if attention_mask is not None else None
        
        # Extract prompt_len from batch (added by RouterDataCollator)
        prompt_len = inputs.get("prompt_len")
        if prompt_len is None:
            # Fallback: infer from labels (but this is unreliable if labels are wrong)
            print(f"  ⚠️ WARNING: prompt_len not found in batch! Falling back to label inference.")
            print(f"             This indicates RouterDataCollator is not being used.")
            # Count leading -100s per example as a fallback
            prompt_len = torch.zeros(batch_size, dtype=torch.long, device=device)
            for i in range(batch_size):
                for j in range(labels.shape[1]):
                    if labels[i, j] != -100:
                        prompt_len[i] = j
                        break
                else:
                    # All -100, use half as fallback
                    prompt_len[i] = max(1, labels.shape[1] // 2)
        
        prompt_len_valid = prompt_len[valid_indices]  # [B']
        
        # Get neighbor indices for soft targets if needed
        # Use effective soft target settings (from two-phase training if enabled)
        effective_use_soft_targets_local = use_soft_targets if use_soft_targets is not None else self._router_use_soft_targets
        neighbor_indices = None
        if effective_use_soft_targets_local:
            neighbor_indices = []
            for y_idx in y_indices:
                neighbors = self._router_registry.get_neighbors(
                    y_idx, 
                    k=self._router_soft_target_k_neighbors
                )
                neighbor_indices.append(neighbors)
        
    
        # Extract prompt mask for routing loss computation
        prompt_mask = extract_prompt_mask(
            prompt_len_valid,
            attention_mask_valid if attention_mask_valid is not None else torch.ones_like(labels_valid),
            labels=labels_valid if (self.state.global_step < 3) else None,
            debug=False,
            global_step=self.state.global_step,
        )
        
        # Compute routing loss (with debug for first 3 steps)
        # Use effective soft target settings (from two-phase training if enabled)
        effective_soft_target_eps_local = soft_target_eps if soft_target_eps is not None else self._router_soft_target_eps
        
        # Request per-example losses if replay multiplier is enabled
        need_per_example = apply_replay_multiplier and replay_loss_multiplier > 1.0
        
        if need_per_example:
            loss_mean, loss_per_example, accuracy_metrics = compute_routing_loss(
                router_model=self._router_model,
                hidden_states=hidden_states_valid,
                labels=labels_valid,
                attention_mask=attention_mask_valid,
                candidate_indices=candidate_indices,
                prompt_len=prompt_len_valid,
                use_soft_targets=effective_use_soft_targets_local,
                soft_target_eps=effective_soft_target_eps_local,
                neighbor_indices=neighbor_indices,
                device=device,
                return_accuracy=True,
                return_per_example=True,  # Request per-example losses
                debug=(self.state.global_step < 3),
                global_step=self.state.global_step,
                # Debug parameters
                debug_router_supervision=self._debug_router_supervision,
                debug_router_every=self._debug_router_every,
                debug_router_first_steps=self._debug_router_first_steps,
                debug_router_strict=self._debug_router_strict,
                gold_model_names=gold_model_names,
                domains=normalized_domains,
                registry=self._router_registry,
                candidate_builder=self._router_candidate_builder,
                hard_negative_cache=self._router_hard_negative_cache,
                micro_idx=0,  # TODO: track microbatch index if using gradient accumulation
            )
            
            # Apply replay loss multiplier to per-example losses
            is_replay = inputs.get("is_replay", None)
            if is_replay is not None:
                # Convert to list if needed
                if not isinstance(is_replay, list):
                    is_replay = list(is_replay) if hasattr(is_replay, '__iter__') else [bool(is_replay)] * batch_size
                
                # Filter to valid indices (replay instrumentation)
                is_replay_valid = torch.tensor([is_replay[i] for i in valid_indices], dtype=torch.float32, device=device)
                num_replay_valid = is_replay_valid.sum().item()
                num_total_valid = len(is_replay_valid)
                replay_valid_fraction = num_replay_valid / num_total_valid if num_total_valid > 0 else 0.0
                
                # Count replay in original batch (before filtering)
                num_replay_original = sum(1 for r in is_replay if r)
                num_total_original = len(is_replay)
                replay_ratio_original = num_replay_original / num_total_original if num_total_original > 0 else 0.0
                
                # Apply multiplier: w = 1 + (multiplier - 1) * is_replay
                weights = 1.0 + (replay_loss_multiplier - 1.0) * is_replay_valid
                weighted_loss_per_example = loss_per_example * weights
                loss = weighted_loss_per_example.mean()
    
                
                # Warn if replay_ratio > 0 but replay_valid_fraction is ~0 for many steps
                if replay_ratio_original > 0.1 and replay_valid_fraction < 0.01:
                    if not hasattr(self, '_replay_drop_warning_count'):
                        self._replay_drop_warning_count = 0
                    self._replay_drop_warning_count += 1
                    if self._replay_drop_warning_count <= 5 or self.state.global_step % 200 == 0:
                        print(f"  ⚠️  [Replay Drop Warning] Step {self.state.global_step}: "
                              f"Original batch has {replay_ratio_original:.1%} replay, "
                              f"but only {replay_valid_fraction:.1%} in valid examples. "
                              f"Replay may be dropped by unknown model filtering!")
            else:
                # No is_replay flag, use mean loss
                loss = loss_mean
        else:
            loss, accuracy_metrics = compute_routing_loss(
                router_model=self._router_model,
                hidden_states=hidden_states_valid,
                labels=labels_valid,
                attention_mask=attention_mask_valid,
                candidate_indices=candidate_indices,
                prompt_len=prompt_len_valid,
                use_soft_targets=effective_use_soft_targets_local,
                soft_target_eps=effective_soft_target_eps_local,
                neighbor_indices=neighbor_indices,
                device=device,
                return_accuracy=True,
                return_per_example=False,
                debug=(self.state.global_step < 3),
                global_step=self.state.global_step,
                # Debug parameters
                debug_router_supervision=self._debug_router_supervision,
                debug_router_every=self._debug_router_every,
                debug_router_first_steps=self._debug_router_first_steps,
                debug_router_strict=self._debug_router_strict,
                gold_model_names=gold_model_names,
                domains=normalized_domains,
                registry=self._router_registry,
                candidate_builder=self._router_candidate_builder,
                hard_negative_cache=self._router_hard_negative_cache,
                micro_idx=0,  # TODO: track microbatch index if using gradient accumulation
            )
        
        # =====================================================================
        # E) Debug report (behind ROUTER_DEBUG flag, first N batches)
        # =====================================================================
        if ROUTER_DEBUG and self.state.global_step < 3:
            from .router_training import print_router_debug_report
            
            # Compute logits for debug report
            with torch.no_grad():
                logits_debug = self._router_model(
                    hidden_states_valid,
                    prompt_mask,
                    candidate_indices,
                    debug=False,
                )
            
            # Map pooling mode for display
            pooling_mode_str = getattr(self._router_model, 'pooling', 'last_token')
            pooling_display = 'last' if pooling_mode_str == 'last_token' else 'mean'
            
            print_router_debug_report(
                batch_size=len(y_indices),
                domains=normalized_domains,
                K=self._router_candidate_builder.K_total,
                pooling_mode=pooling_display,
                model_names=gold_model_names,
                gold_indices=y_indices,
                candidates_list=candidate_list,
                logits=logits_debug,
                registry=self._router_registry,
                prompt_mask=prompt_mask,
                candidate_stats=filter_stats,
                num_examples=3,
            )
        
        # Accumulate compute metrics if available
        if accuracy_metrics is not None and "compute" in accuracy_metrics:
            compute_metrics = accuracy_metrics["compute"]
            # Get current experience name if available
            experience_name = getattr(self, '_current_experience_name', None)
            self._compute_tracker.accumulate(
                compute_metrics,
                phase="training",
                experience=experience_name,
            )
        

        # Track accuracy metrics (exclude "compute" which is handled separately)
        if accuracy_metrics:
            # Filter out "compute" key as it's a dict, not a float metric
            accuracy_only = {k: v for k, v in accuracy_metrics.items() if k != "compute"}
            
            if accuracy_only:
                if not hasattr(self, '_router_accuracy_sum'):
                    self._router_accuracy_sum = {k: 0.0 for k in accuracy_only.keys()}
                    self._router_accuracy_count = 0
                
                for key, value in accuracy_only.items():
                    if isinstance(value, (int, float)):
                        self._router_accuracy_sum[key] += value
                self._router_accuracy_count += 1
        
        # Compute metrics (no_grad to avoid storing gradients - metrics don't need them)
        with torch.no_grad():
            metrics = compute_router_metrics(
                logits=self._router_model(
                    hidden_states_valid, 
                    extract_prompt_mask(
                        prompt_len_valid, 
                        attention_mask_valid if attention_mask_valid is not None else torch.ones_like(labels_valid),
                        debug=False, 
                        global_step=0
                    ), 
                    candidate_indices
                ),
                candidate_indices=candidate_indices,
                candidate_builder=self._router_candidate_builder,
                y_indices=y_indices,
                domains=normalized_domains,
                hard_negative_cache=self._router_hard_negative_cache,
            )
        
        # Compute domain diversity in batch (use normalized domains)
        unique_domains = len(set(normalized_domains))
        pct_same_domain = (normalized_domains.count(normalized_domains[0]) / len(normalized_domains) * 100) if normalized_domains else 0
        
        # Log comprehensive metrics
        should_log = (self.state.global_step < 10) or (self.state.global_step % 100 == 0)
        

        # Mine hard negatives periodically
        if self.state.global_step % self._router_mine_every_steps == 0 and self.state.global_step > 0:
            # Prepare examples for mining
            batch_examples = []
            for i, idx in enumerate(valid_indices):
                # Get prompt embedding for mining
                prompt_mask = extract_prompt_mask(
                    prompt_len[idx:idx+1],
                    attention_mask[idx:idx+1] if attention_mask is not None else torch.ones(1, labels.shape[1], device=device),
                    debug=False, 
                    global_step=0
                )
                prompt_emb = self._router_model.encode_prompt(hidden_states[idx:idx+1], prompt_mask).squeeze(0)
                
                batch_examples.append({
                    'model_idx': y_indices[i],
                    'domain': domains[idx],
                    'prompt_embedding': prompt_emb.detach(),
                })
            
            # Update cache
            self._router_hard_miner.update_cache(
                batch_examples=batch_examples,
                router_model=self._router_model,
                max_examples=128,
            )
            
            # The candidate builder uses self._router_hard_negative_cache, not miner.cache
            self._router_hard_negative_cache.update(self._router_hard_miner.cache)
            
            # Log mining stats
            stats = self._router_hard_miner.get_stats()
            print(f"  [Hard Mining @ step {self.state.global_step}] "
                  f"Updates: {stats['num_updates']}, "
                  f"Examples processed: {stats['num_examples_processed']}, "
                  f"Cache size: {stats['cache_size']} (trainer cache: {len(self._router_hard_negative_cache)})")
            
            # INVARIANT CHECK: Hard negative mining correctness
            if self._router_debug_checker.should_check(self.state.global_step):
                mining_check = self._router_debug_checker.check_hard_mining_invariants(
                    hard_cache=self._router_hard_negative_cache,
                    K_hard=self._router_candidate_builder.K_hard,
                    global_step=self.state.global_step,
                    min_cache_entries=5,
                )
                if not mining_check["passed"]:
                    print(f"  ⚠️ [Hard Mining Check]: {mining_check.get('reason', 'unknown')}")
                else:
                    print(f"  ✅ [Hard Mining Check]: Cache has {mining_check['cache_size']} entries")
        
        return loss
    
    def _compute_router_anchor_loss(self) -> Optional[torch.Tensor]:
        """
        Compute router embedding anchor loss to preserve exp1 routing.

        Penalizes drift of the old embedding rows (indices < M_old) of
        ``self._router_model.model_embeddings`` away from the reference
        snapshot ``self._router_anchor_ref_cpu``, taken after loading/resizing
        from the exp1 checkpoint.

        Row selection is controlled by ``self._router_anchor_scope``:
            * ``"touched"``: only old rows appearing in the gold/candidate
              indices cached on ``self`` by the routing-loss step
              (``_router_anchor_y_indices`` / ``_router_anchor_candidate_indices``).
              If those caches have not been set yet, all old rows are anchored.
            * any other value (``"all_old"``): every row with index < M_old.

        The distance is controlled by ``self._router_anchor_mode``:
            * ``"normalized"``: mean over rows of (1 - cosine similarity).
            * any other value: mean squared element-wise difference (L2).

        Returns:
            Scalar float32 anchor loss; a zero tensor when the "touched" scope
            selects no old rows; or None when the router model, the reference
            snapshot, or a positive M_old is unavailable.
        """
        if self._router_model is None:
            if self.state.global_step < 5:
                print(f"  [Anchor Loss Debug @ step {self.state.global_step}] router_model is None")
            return None

        if self._router_anchor_ref_cpu is None:
            if self.state.global_step < 5:
                print(f"  [Anchor Loss Debug @ step {self.state.global_step}] router_anchor_ref_cpu is None")
            return None

        M_old = self._router_anchor_M_old
        if M_old is None or M_old <= 0:
            if self.state.global_step < 5:
                print(f"  [Anchor Loss Debug @ step {self.state.global_step}] M_old is None or <= 0: {M_old}")
            return None

        weight = self._router_model.model_embeddings.weight
        device = weight.device
        dtype = weight.dtype

        # Cache a device/dtype-matched copy of the reference so the CPU->device
        # transfer happens once, not on every step.
        if (
            self._router_anchor_ref is None
            or self._router_anchor_ref.device != device
            or self._router_anchor_ref.dtype != dtype
        ):
            self._router_anchor_ref = self._router_anchor_ref_cpu.to(device=device, dtype=dtype)

        # Only the first M_old rows are anchored. Slicing the reference as well
        # keeps both operands aligned even if the snapshot holds extra rows.
        current_rows = weight[:M_old]  # [M_old, D]
        reference_rows = self._router_anchor_ref[:M_old]  # [M_old, D]

        has_touched_cache = (
            hasattr(self, '_router_anchor_candidate_indices')
            and hasattr(self, '_router_anchor_y_indices')
        )
        if self._router_anchor_scope == "touched" and has_touched_cache:
            touched = set()

            # Gold indices (may be Python ints or 0-d tensors).
            for y_idx in self._router_anchor_y_indices:
                y_int = int(y_idx.item() if isinstance(y_idx, torch.Tensor) else y_idx)
                # Negative values are never valid model indices (e.g. padding);
                # letting them through would wrap around to unrelated rows.
                if 0 <= y_int < M_old:
                    touched.add(y_int)

            # Candidate indices: a [B, K] tensor or a list-of-lists.
            candidate_indices = self._router_anchor_candidate_indices
            if candidate_indices is not None:
                if isinstance(candidate_indices, torch.Tensor):
                    flat_candidates = candidate_indices.flatten().tolist()
                else:
                    flat_candidates = []
                    for entry in candidate_indices:
                        if isinstance(entry, (list, tuple)):
                            flat_candidates.extend(entry)
                        else:
                            flat_candidates.append(entry)
                for idx in flat_candidates:
                    idx_int = int(idx)
                    if 0 <= idx_int < M_old:
                        touched.add(idx_int)

            if not touched:
                # No old rows were touched this step: nothing to anchor.
                return torch.tensor(0.0, device=device, dtype=dtype, requires_grad=True)

            row_indices = torch.tensor(sorted(touched), dtype=torch.long, device=device)
            current_rows = current_rows[row_indices]  # [num_touched, D]
            reference_rows = reference_rows[row_indices]  # [num_touched, D]
        # Otherwise ("all_old", or "touched" before the routing step has cached
        # its indices) every old row is anchored.

        # Compute in float32 to avoid tiny-number underflow in half precision.
        current_f32 = current_rows.float()
        reference_f32 = reference_rows.float()

        if self._router_anchor_mode == "normalized":
            # Row-level cosine distance: 0 when identical, 2 when opposite.
            cos = (
                F.normalize(current_f32, p=2, dim=-1) * F.normalize(reference_f32, p=2, dim=-1)
            ).sum(dim=-1).clamp(-1, 1)  # [num_rows]
            return (1 - cos).mean()

        # Raw L2 anchor: mean over all elements of (E_old - E_old_ref)^2.
        return (current_f32 - reference_f32).pow(2).mean()
    
    def _compute_router_proj_anchor_loss(self) -> Optional[torch.Tensor]:
        """
        Compute router projection anchor loss to preserve exp1 projection.

        This loss penalizes drift of ``self._router_model.prompt_projection``
        weights away from the reference snapshot
        ``self._router_proj_anchor_ref_cpu`` (a name -> tensor dict) taken
        immediately after loading/resizing from the exp1 checkpoint.

        CRITICAL: Uses named_parameters() (not state_dict()) so that gradients
        flow through the computation graph.

        Returns:
            Mean over parameter tensors of the per-tensor MSE (computed in
            float32), or None when there is no router, no (or an empty)
            reference snapshot, no projection parameters, or no parameter
            name matches the snapshot.
        """
        if self._router_model is None:
            return None

        reference_cpu = self._router_proj_anchor_ref_cpu
        if not reference_cpu:
            # None or an empty snapshot: nothing to anchor against.
            return None

        proj = self._router_model.prompt_projection
        first_param = next(proj.parameters(), None)
        if first_param is None:
            return None
        device = first_param.device
        dtype = first_param.dtype

        # Materialize a cached device copy and rebuild it when the device or
        # dtype changes. The cache is inspected defensively, because an empty
        # dict has no first value to compare against.
        cached = self._router_proj_anchor_ref
        sample = next(iter(cached.values()), None) if cached else None
        if sample is None or sample.device != device or sample.dtype != dtype:
            self._router_proj_anchor_ref = {
                name: tensor.to(device=device, dtype=dtype)
                for name, tensor in reference_cpu.items()
            }

        # Per-tensor MSE, mean((W - W_ref)^2), in float32 for stability.
        losses = [
            (param.float() - self._router_proj_anchor_ref[name].float()).pow(2).mean()
            for name, param in proj.named_parameters()
            if name in self._router_proj_anchor_ref
        ]

        if not losses:
            return None

        # Average across all anchored parameter tensors.
        return torch.stack(losses).mean()
    
    def _compute_router_graph_regularizer_for_batch(
        self,
        inputs: Dict[str, torch.Tensor],
    ) -> Optional[torch.Tensor]:
        """
        Compute label-side graph regularizer for router embeddings.

        Aligns learned model embeddings with taxonomy structure. The batch's
        ``model_name`` metadata is resolved to registry indices (exact match
        first, then via ``normalize_model_name``). Unresolved names are
        dropped, and the resolved indices are passed as a single ``[1, B']``
        candidate row to ``compute_label_graph_regularizer``.

        Args:
            inputs: Batch dict. It must contain ``input_ids``; ``model_name``
                and ``domain`` are both required for a non-None result, but
                only ``model_name`` is consumed here.

        Returns:
            The regularizer loss, or None when the router/registry is missing,
            the metadata is absent, or no model name resolves.
        """
        if self._router_model is None or self._router_registry is None:
            return None

        batch_size = inputs["input_ids"].shape[0]
        device = inputs["input_ids"].device

        if "model_name" not in inputs or "domain" not in inputs:
            return None

        model_names = inputs["model_name"]
        # A bare string is a single name broadcast over the batch. It must not
        # go through list(), which would split it into characters.
        if isinstance(model_names, str) or not hasattr(model_names, '__iter__'):
            model_names = [str(model_names)] * batch_size
        elif not isinstance(model_names, list):
            model_names = list(model_names)

        model2idx = self._router_registry.model2idx
        # normalized name -> index. Built lazily, at most once per call; the
        # first registry entry wins on collisions, matching a linear scan.
        normalized_lookup: Optional[Dict[str, int]] = None

        y_indices = []
        for model_name in model_names:
            model_idx = model2idx.get(model_name)
            if model_idx is None:
                if normalized_lookup is None:
                    normalized_lookup = {}
                    for existing_name, existing_idx in model2idx.items():
                        normalized_lookup.setdefault(normalize_model_name(existing_name), existing_idx)
                model_idx = normalized_lookup.get(normalize_model_name(model_name))
            if model_idx is not None:
                y_indices.append(model_idx)

        if not y_indices:
            return None

        # Simplified candidates: the batch's own gold models as one row.
        candidate_indices = torch.tensor([y_indices], dtype=torch.long, device=device)  # [1, B']

        return compute_label_graph_regularizer(
            router_model=self._router_model,
            candidate_indices=candidate_indices,
            registry=self._router_registry,
            tau=self._router_label_graph_tau,
            tau_target=self._router_label_graph_tau_target,
            alpha_domain=self._router_label_graph_alpha_domain,
            max_models=self._router_label_graph_max_models,
            device=device,
        )
    
    def get_router_metrics(self) -> Dict[str, float]:
        """
        Get average router metrics for logging.

        Averages are computed from the running sums/counts accumulated on the
        trainer. Every source is optional: an attribute that has not been
        initialised (e.g. when router training is disabled) counts as "no data"
        instead of raising ``AttributeError``.

        Returns:
            Dict that may contain:
                * ``avg_router_loss`` and ``avg_supervised_loss``, plus
                  ``router_to_supervised_ratio`` when the supervised average
                  is positive;
                * ``avg_graph_loss``;
                * ``avg_<key>`` for each accumulated accuracy metric;
                * ``hard_mining_*`` stats when a hard-negative miner is set;
                * FLOPs/example totals from the compute tracker once it has
                  processed at least one example.
        """
        metrics: Dict[str, float] = {}

        router_count = getattr(self, '_router_loss_count', 0)
        if router_count > 0:
            avg_router = self._router_loss_sum / router_count
            avg_supervised = getattr(self, '_supervised_loss_sum', 0.0) / max(
                1, getattr(self, '_supervised_loss_count', 0)
            )
            metrics["avg_router_loss"] = avg_router
            metrics["avg_supervised_loss"] = avg_supervised
            if avg_supervised > 0:
                metrics["router_to_supervised_ratio"] = avg_router / avg_supervised

        graph_count = getattr(self, '_router_graph_loss_count', 0)
        if graph_count > 0:
            metrics["avg_graph_loss"] = self._router_graph_loss_sum / graph_count

        # Accuracy metrics
        accuracy_sums = getattr(self, '_router_accuracy_sum', None)
        accuracy_count = getattr(self, '_router_accuracy_count', 0)
        if accuracy_sums is not None and accuracy_count > 0:
            for key, value_sum in accuracy_sums.items():
                metrics[f"avg_{key}"] = value_sum / accuracy_count

        # Hard mining stats
        hard_miner = getattr(self, '_router_hard_miner', None)
        if hard_miner:
            stats = hard_miner.get_stats()
            metrics.update({
                "hard_mining_updates": stats["num_updates"],
                "hard_mining_examples_processed": stats["num_examples_processed"],
                "hard_mining_cache_size": stats["cache_size"],
            })

        # Compute metrics (FLOPs tracking)
        compute_tracker = getattr(self, '_compute_tracker', None)
        if compute_tracker is not None:
            compute_summary = compute_tracker.get_summary()
            if compute_summary["total_examples"] > 0:
                metrics["total_flops"] = compute_summary["total_flops"]
                metrics["total_flops_gflops"] = compute_summary["total_flops_gflops"]
                metrics["flops_per_example"] = compute_summary["flops_per_example"]
                metrics["total_examples_processed"] = compute_summary["total_examples"]
                metrics["total_batches_processed"] = compute_summary["total_batches"]

        return metrics
    
    
    def reset_consistency_metrics(self):
        """Reset tracking metrics (call at epoch start)."""
        self._consistency_loss_sum = 0.0
        self._consistency_loss_count = 0
        self._supervised_loss_sum = 0.0
        self._supervised_loss_count = 0
        
        # Reset router metrics
        if hasattr(self, '_router_loss_sum'):
            self._router_loss_sum = 0.0
            self._router_loss_count = 0
            self._router_graph_loss_sum = 0.0
            self._router_graph_loss_count = 0
        
        # Reset router accuracy metrics
        if hasattr(self, '_router_accuracy_sum'):
            for key in self._router_accuracy_sum.keys():
                self._router_accuracy_sum[key] = 0.0
            self._router_accuracy_count = 0
        self._neighbor_domain_stats = defaultdict(lambda: {"same": 0, "different": 0})
        