from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

# Locate the project root directory by searching upward for pyproject.toml.
def _find_project_root() -> Path:
    """Return the project root directory.

    Walks upward through the ancestors of this file (resolved to an absolute
    path) and returns the first directory that contains a ``pyproject.toml``
    marker file.

    Returns:
        The nearest ancestor directory holding ``pyproject.toml``. When no
        ancestor has the marker, the grandparent of the directory containing
        this file, derived from the same resolved path as the search so the
        two code paths agree.
    """
    current = Path(__file__).resolve()
    # ``current`` points at this file, so only its ancestors can hold the marker.
    for parent in current.parents:
        if (parent / "pyproject.toml").exists():
            return parent
    # Fallback to assuming standard structure. Use the resolved path here too,
    # so the result does not depend on how ``__file__`` happened to be spelled.
    return current.parent.parent.parent


PROJECT_ROOT = _find_project_root()


@dataclass
class ApibenchDataConfig:
    """Default dataset file paths and model date cutoff for the apibench data."""

    # Split files live under <PROJECT_ROOT>/data/processed; paths are stored as strings.
    train_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-apibench-hf-train.json"))
    val_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-apibench-hf-val.json"))
    # The test split is read from the "-eval" file.
    test_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-apibench-hf-eval.json"))
    # Date cutoff for model selection (e.g., "Jun 2023").
    model_date_cutoff: Optional[str] = "Jun 2023"

@dataclass
class MLLMDataConfig:
    """Default dataset file paths and model date cutoff for the mllm data."""

    # Split files live under <PROJECT_ROOT>/data/processed; paths are stored as strings.
    train_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-mllm-train.json"))
    val_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-mllm-val.json"))
    # The test split is read from the "-eval" file.
    test_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-mllm-eval.json"))
    # Date cutoff for model selection (e.g., "Dec 2023"); defaults to "Oct 2024".
    model_date_cutoff: Optional[str] = "Oct 2024"


@dataclass
class HuggingBench1DataConfig:
    """Default dataset file paths and model date cutoff for the hugging-bench-1 data."""

    # Split files live under <PROJECT_ROOT>/data/processed; paths are stored as strings.
    train_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-hugging-bench-1-train.json"))
    val_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-hugging-bench-1-val.json"))
    # The test split is read from the "-eval" file.
    test_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-hugging-bench-1-eval.json"))
    # Date cutoff for model selection; defaults to None.
    model_date_cutoff: Optional[str] = None

@dataclass
class HuggingBench2DataConfig:
    """Default dataset file paths and model date cutoff for the hugging-bench-2 data."""

    # Split files live under <PROJECT_ROOT>/data/processed; paths are stored as strings.
    train_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-hugging-bench-2-train.json"))
    val_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-hugging-bench-2-val.json"))
    # The test split is read from the "-eval" file.
    test_set: str = str(PROJECT_ROOT.joinpath("data", "processed", "cleaned-hugging-bench-2-eval.json"))
    # Date cutoff for model selection; defaults to None.
    model_date_cutoff: Optional[str] = None

@dataclass
class TrainConfig:
    """Settings for a training run.

    Only ``experience_name`` and ``output_root`` are required; every other
    field has a default. The fields are grouped by the feature they configure
    (core hyperparameters, LoRA, replay, neighbour regularisers, X-CLR,
    router); see the section comments below. This class only stores values:
    it performs no validation of its own.
    """

    # Model and experiment identifiers
    experience_name: str  # name of the experience, e.g., "apibench" "mllm"
    output_root: Path
    variant_name: str = ""  # variant name for the experiment
    extra_info: str = ""  # any extra info to append to the output directory name

    repo_id: str = "huggyllama/llama-7b"  # base model to use
    retriever: Optional[str] = None  # specify retriever if needed, e.g., "bm25", "sentence_transformer", "splade", "flagembedding"

    # System prompt configuration
    # custom system prompt; per the original note, None uses the default gorilla_prompt.
    # NOTE(review): the default here is "" rather than None - confirm how the consumer treats an empty string.
    system_prompt: Optional[str] = ""
    system_prompt_format: Optional[str] = None  # specify system prompt format if needed, e.g., "gorilla_prompt", "gorilla_prompt_explanation", "gorilla_prompt_explanation_json"

    # Reproducibility
    seed: Optional[int] = None  # random seed for reproducibility (defaults to None, so the annotation must be Optional)

    # Training hyperparameters
    epochs: int = 15
    batch_size: int = 4
    grad_accum: int = 2
    lr: float = 0.0005
    max_length: int = 1024
    max_grad_norm: float = 1.0
    packing: bool = False
    group_by_length: bool = True
    completion_only_loss: bool = True
    label_smoothing: float = 0.0

    # LoRA parameters
    lora_r: int = 32  # try 64
    lora_alpha: int = 32  # commonly set to 2 * lora_r (64 for lora_r=32, 128 for lora_r=64)
    lora_dropout: float = 0.1  # try 0.05
    target_modules: List[str] = field(
        default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'o_proj'])

    # Checkpoint and evaluation options
    resume_from: Optional[str] = None
    lora_adapters: Optional[List[str]] = None  # list of LoRA adapters to use
    early_stopping_patience: int = 3  # check we are not overfitting
    early_stopping_threshold: float = 0.0
    no_validation: bool = False
    hyperparameters_search: bool = False  # whether to perform hyperparameter search
    eval_at_step0: bool = False  # Run evaluation immediately after loading checkpoint but before training (global_step==0)

    # Optimizer and scheduler
    weight_decay: float = 0.001
    warmup_steps: int = 10
    lr_scheduler_type: str = "linear"  # "warmup_stable_decay"
    optim: str = "adamw_torch"
    logging_steps: int = 1
    save_strategy: str = "epoch"
    save_total_limit: int = 3
    metric_for_best_model: str = "eval_loss"
    greater_is_better: bool = False

    activation_checkpointing: bool = True

    # Memory optimization options
    low_memory_mode: bool = False  # Enable memory optimizations: Flash Attention 2 (if available) and activation offloading
    use_quantization: bool = False

    # Preference optimization options (for DPO/RLHF/ORPO)
    preference_mode: bool = False  # If True, convert dataset to preference format (chosen/rejected pairs)
    negative_sampling_strategy: str = "same_domain"  # Strategy for sampling rejected responses: "same_domain", "cross_domain", or "random"
    num_rejections_per_example: int = 1  # Number of rejected responses per example (creates multiple preference pairs)

    # ORPO-specific parameters
    orpo_beta: float = 0.1  # ORPO beta parameter (λ in paper, controls relative ratio loss weight)
    max_prompt_length: int = 512  # Maximum length of the prompt for ORPO
    disable_dropout: bool = True  # Whether to disable dropout during ORPO training

    # Replay configuration for continual learning
    experiences_sequence: Optional[List[str]] = None  # List of experience names to train on sequentially, e.g., ["apibench", "mllm"]
    joint_training: bool = False  # If True, combine all experiences into one dataset for joint training (upper bound baseline)
    replay_percentage: Optional[float] = None  # Percentage of previous experience samples to replay (e.g., 0.1 for 10%)
    replay_num_samples: Optional[int] = None  # Fixed number of samples to replay from each previous experience
    # Note: If both replay_percentage and replay_num_samples are None, no replay is performed
    # If both are provided, replay_num_samples takes precedence
    # Note: If joint_training is True, replay is disabled (all data is available)

    # Few-shot retrieval replay configuration
    fewshot_top_k: int = 3  # Number of retrieved neighbors for few-shot augmentation
    fewshot_replay_ratio: float = 0.1  # Replay buffer ratio (default 10%)
    fewshot_max_card_tokens: int = 200  # Max tokens for model card snippets
    fewshot_dropout_prob: float = 0.0  # Probability of using original prompt without examples (0.0 = always use examples, 1.0 = never use examples)

    # ==========================================================================
    # Domain + Model-Aware Coreset Replay Configuration
    # ==========================================================================
    # These options control how replay examples are selected from previous
    # experiences during continual learning. The coreset strategy aims to
    # preserve diversity across domains and models, rather than random sampling.
    #
    # CL Motivation: Random replay can undersample long-tail domains/models,
    # leading to catastrophic forgetting on those subpopulations. Coreset
    # selection ensures balanced coverage and embedding-space diversity.
    # ==========================================================================
    replay_strategy: str = "random"  # "random" (baseline) or "domain_model_coreset" (new)
    replay_min_per_domain: int = 5  # Minimum examples to include per domain (floor)
    replay_max_per_domain: Optional[int] = None  # Optional cap per domain (None = no cap)
    replay_max_per_model: int = 3  # Max examples per model within each domain
    replay_embedding_source: str = "sentence_transformer"  # Embedding model for coreset: "sentence_transformer", "flagembedding", etc.
    replay_embedding_cache_dir: Optional[str] = None  # Directory to cache embeddings (None = cco/cache/embeddings)

    # ==========================================================================
    # Neighbour-Consistency Regularisation Configuration
    # ==========================================================================
    # Adds an auxiliary loss term that encourages locally smooth predictions:
    # for each training example, we retrieve similar prompts from APIBench and
    # regularize the model's output distribution to be consistent between the
    # anchor and its retrieved neighbours.
    #
    # CL Motivation: This enforces local smoothness in the output space using
    # retrieved neighbours as anchors, helping preserve knowledge from previous
    # experiences without requiring a separate teacher model.
    # ==========================================================================
    use_neighbor_consistency: bool = False  # Enable neighbour-consistency loss (default off for baseline)
    neighbor_k: int = 3  # Number of neighbours to retrieve for consistency
    neighbor_consistency_weight: float = 0.1  # Weight of consistency loss relative to supervised loss
    neighbor_consistency_temperature: float = 1.0  # Temperature for softmax in KL divergence
    neighbor_source: str = "apibench"  # Source for neighbours: "apibench" or "current_train"

    # NEW: How many batch examples participate in consistency loss
    neighbor_max_consistency_samples: int = 4  # Default = current hard-coded behaviour

    # NEW: How many last tokens to pool over for KL loss
    neighbor_consistency_num_tokens: int = 1  # Default = current behaviour (last token only)

    # NEW: Domain-aware neighbour behaviour
    neighbor_domain_filter_mode: str = "none"  # Choices: "none", "soft", "hard", "strict"
    # - "none": ignore domains entirely
    # - "soft": add domain_bias to similarity when domains match
    # - "hard": prefer same-domain neighbors, fall back to cross-domain if needed
    # - "strict": only use same-domain neighbors, skip consistency if too few exist
    neighbor_domain_bias: float = 0.2  # Similarity bias for same-domain neighbors when filter_mode == "soft"
    neighbor_min_same_domain: int = 1  # Minimum same-domain neighbors required (for "strict" mode)

    # NEW: Replay-only consistency mode
    # When True, only applies consistency loss when the anchor prompt is also from the
    # replay buffer (not a new experience prompt). This avoids the problem of pushing
    # new model prompts toward old model outputs.
    neighbor_replay_only: bool = False  # Only use replay examples as anchors
    neighbor_replay_similarity_threshold: float = 0.95  # Similarity threshold to detect replay examples

    # NEW: Embedding model for neighbor retrieval
    neighbor_embedding_model: str = "all-mpnet-base-v2"  # Sentence transformer model for neighbor similarity

    # ======================================================================
    # Neighbor-Contrastive Regularisation (Hard Negative Mining)
    # ======================================================================
    # Instead of enforcing neighbor-consistency (which can be mis-specified when
    # similar prompts route to different models), this loss learns to discriminate
    # between semantically similar prompts that map to different models.
    #
    # For each anchor (preferably replay examples), we retrieve top-k similar prompts,
    # find neighbors with different model_ids as hard negatives, and add a ranking
    # loss that increases the score of the anchor's gold model vs negative models.
    # ======================================================================
    use_neighbor_contrastive: bool = False  # Enable contrastive ranking loss (default off)
    neighbor_contrastive_weight: float = 0.1  # Weight of contrastive loss vs supervised loss

    neighbor_contrastive_k: int = 3  # How many neighbours to retrieve per anchor
    neighbor_contrastive_num_negatives: int = 3  # How many unique negative model_ids to use (<= k)
    neighbor_contrastive_max_anchors_per_batch: int = 4  # Max anchors to apply contrastive loss per batch

    neighbor_contrastive_loss_type: str = "softplus"  # Loss type: "softplus" or "hinge"
    neighbor_contrastive_margin: float = 0.0  # Margin for hinge loss (only used when loss_type="hinge")

    neighbor_contrastive_apply_to: str = "replay_only"  # Apply to: "replay_only" or "all" examples

    # ======================================================================
    # X-Sample Contrastive Loss (X-CLR) Configuration
    # ======================================================================
    # X-CLR (ICLR 2025) replaces hard "one positive" targets with soft target
    # distributions derived from a similarity graph. This is particularly useful
    # for self-instruct data where near-duplicate prompts may map to different models.
    #
    # The loss is: L_xclr = cross_entropy(soft_target_distribution, learned_similarity)
    # Where soft targets come from either taxonomy (same model/domain) or model-card
    # embeddings.
    # ======================================================================
    use_xclr: bool = False  # Enable X-CLR loss
    xclr_weight: float = 0.1  # Weight λ for L_total = L_supervised + λ * L_xclr
    xclr_tau: float = 0.1  # Temperature for learned similarity distribution
    xclr_tau_target: float = 0.1  # Temperature for target soft distribution
    xclr_graph_mode: str = "taxonomy"  # Graph mode: "taxonomy" or "modelcard"
    xclr_domain_alpha: float = 0.3  # Similarity weight for same-domain, different-model pairs
    xclr_apply_to: str = "all"  # Apply to: "all" or "replay_only" examples
    xclr_proj_dim: int = 128  # Projection dimension for prompt embeddings
    xclr_max_anchors_per_batch: int = 8  # Max anchors per batch (for efficiency)
    xclr_log_every: int = 100  # Log X-CLR metrics every N steps
    xclr_modelcard_encoder: str = "all-mpnet-base-v2"  # SBERT model for modelcard embeddings
    xclr_modelcard_blend: float = 0.0  # Blend factor with taxonomy (0 = pure modelcard)

    # X-CLR Debug/Diagnostics Configuration
    xclr_debug: bool = False  # Enable detailed X-CLR diagnostics (target/pred distributions, gradients)
    xclr_debug_log_every: Optional[int] = None  # Log diagnostics every N steps (None = use xclr_log_every)
    xclr_debug_max_print: int = 1  # Max anchors to show detailed breakdown for
    xclr_stability_checks: bool = False  # Enable stability checks and assertions (K consistency, self-exclusion, positive ranking)

    # X-CLR Replay Candidate Pool Configuration (for xclr_apply_to="replay_only")
    # When xclr_apply_to="replay_only", many replay anchors may have no in-batch positives
    # under taxonomy mode. This candidate pool samples from the replay buffer to guarantee
    # same-domain positives for each anchor where possible.
    xclr_num_candidates: int = 31  # Number of candidates to sample per anchor (31 + anchor = 32)
    xclr_min_pos: int = 1  # Minimum same-domain candidates to include per anchor

    # X-CLR Prompt Similarity Configuration
    # Enables prompt similarity-based candidate sampling and target graph construction
    xclr_candidate_sampling: str = "domain"  # Candidate sampling strategy: "domain" or "prompt_similarity"
    xclr_target_graph: str = "taxonomy"  # Target graph mode: "taxonomy" or "prompt_similarity"
    xclr_promptsim_k: Optional[int] = None  # Number of candidates for prompt_similarity (defaults to xclr_num_candidates)
    xclr_promptsim_retriever_type: str = "sentence_transformer"  # Retriever type: "bm25", "sentence_transformer", "splade", "flagembedding"
    xclr_promptsim_exclude_self: bool = True  # Exclude anchor from candidates (default: True)
    xclr_promptsim_retrieve_over_all: bool = False  # If True, retrieve over all prompts (replay + current exp), else just replay pool
    xclr_promptsim_taxonomy_blend: float = 0.3  # Blend prompt similarity with taxonomy (0 = pure prompt similarity, 1 = pure taxonomy)
                                                # Recommended: 0.2-0.4 to ensure same-domain examples get higher scores

    # X-CLR Candidate Queue (Memory Bank) Configuration
    xclr_queue_size: int = 2048  # Maximum number of candidates to store in queue (FIFO)
    xclr_queue_use: bool = False  # Enable candidate queue (memory bank)
    xclr_queue_device: str = "cpu"  # Device to store queue embeddings ("cpu" to save VRAM, "cuda" for speed)
    xclr_num_queue_candidates: int = 256  # Number of candidates to sample from queue per batch
    xclr_force_positive_from_queue: bool = False  # If anchor has no positives in-batch, try to draw at least 1 from queue with same domain

    # X-CLR Stability Fixes Configuration
    # These parameters improve training stability by ensuring constant batch composition,
    # stable loss scaling, and optional projection-only early training.
    xclr_use_two_stream_sampler: bool = False  # If true, use TwoStreamBatchSampler for constant replay per batch
                                               # Reduces gradient noise by making replay anchors reproducible
                                               # Only works when xclr_apply_to="replay_only"
    xclr_replay_per_batch: Optional[int] = None  # Number of replay samples per batch (None = auto: min(max_anchors, batch_size//4))
                                                   # Only used when xclr_use_two_stream_sampler=true
    xclr_proj_lr: Optional[float] = None  # Learning rate for projection head (None = use args.learning_rate)
                                         # Allows fine-tuning projection separately from backbone
    xclr_stopgrad_base: bool = False  # If true, detach prompt_embeddings before projection (projection-only training)
                                       # Makes X-CLR updates only affect projection head, not base model
    xclr_stopgrad_warmup_steps: int = 0  # If > 0, use stop-grad for first N steps (projection-only early training)
                                         # After warmup, gradients flow to backbone
                                         # Useful for stabilizing early training: projection learns first, then backbone

    # NOTE on neighbor consistency design (concerns the use_neighbor_consistency / neighbor_* options):
    # The consistency loss enforces similar outputs for similar prompts. However, in a
    # multi-model routing scenario, similar prompts may intentionally route to DIFFERENT
    # models. Therefore, consistency is most meaningful when:
    # 1. Neighbors are from the SAME domain (knowledge transfer within domain)
    # 2. We DON'T expect exact output matching (different models are okay)
    # 3. There are sufficient same-domain neighbors in the sparse replay buffer
    # Use "strict" mode to only apply consistency when same-domain neighbors exist.

    # ======================================================================
    # Semantic Batching + Routing Loss Configuration
    # ======================================================================
    # This section configures the semantic batching + candidate-set routing loss
    # system that replaces X-CLR as the primary training objective.
    #
    # The router learns to select the correct model from a candidate set containing:
    # - 1 positive model (the generator model)
    # - K_semantic semantic negatives (same domain)
    # - K_far far negatives (other domains)
    # - K_hard hard negatives (mined confusable models)
    #
    # Key features:
    # - Domain-homogeneous batching for semantic negative overlap
    # - Hard negative mining from confusable models
    # - Optional soft targets with graph neighbors
    # - Optional label-side graph regularization (X-CLR style on model embeddings)
    # ======================================================================

    # Loss mode selection
    loss_mode: str = "supervised"  # "supervised", "router", "router+graph", "xclr"
    # - "supervised": Only SFT loss (baseline)
    # - "router": SFT + routing loss
    # - "router+graph": SFT + routing + label-graph regularizer
    # - "xclr": SFT + X-CLR (backwards compatibility)

    # Loss weights
    router_loss_weight: float = 1.0  # Weight for routing loss
    lm_loss_weight: float = 1.0  # Weight for LM supervised loss
    # NOTE: Setting lm_loss_weight < 1.0 or = 0.0 can be useful when routing is primary objective

    # Semantic batching
    semantic_batching: bool = False  # Enable domain-homogeneous batching
    domains_per_batch: int = 1  # Number of domains per batch (1 = pure, >1 = mixed)
    mix_replay_in_semantic_batches: bool = True  # Mix replay examples into semantic batches

    # Router architecture
    router_embedding_dim: Optional[int] = None  # Defaults to hidden_size of base model
    router_tau: float = 0.07  # Temperature for scaling logits
    router_pooling: str = "last_token"  # Pooling strategy: "last_token" or "mean"
    # Router learning rates (split by parameter group for stability)
    router_proj_lr: Optional[float] = None  # Learning rate for projection head (None = use args.learning_rate)
    router_embedding_lr: Optional[float] = None  # Learning rate for embedding table (None = use args.learning_rate)
    # Recommended: router_proj_lr=3e-4 to 5e-4, router_embedding_lr=1e-4 to 3e-4
    # This helps stabilize training by reducing spiky updates from the embedding table

    # Candidate sampling
    # Default budget: 1 positive + 48 semantic + 8 far + 7 hard = 64 candidates.
    router_K_total: int = 64  # Total candidates per example (including positive)
    router_K_semantic: int = 48  # Target semantic negatives (same domain)
    router_K_far: int = 8  # Target far negatives (other domains)
    router_K_hard: int = 7  # Target hard negatives from cache

    # Hard negative mining
    router_mine_every_steps: int = 200  # Mine hard negatives every N steps
    router_K_hard_pool: int = 20  # Store top K confusable models in cache
    router_semantic_pool_size: int = 512  # Semantic pool size for mining (per domain)
    router_max_pool_size: int = 1024  # Maximum pool size (cap for large domains)

    # Semantic Pool Expansion (Option B)
    # Controls how semantic negatives are sampled: from exact domain only, or expanded to related domains
    router_semantic_pool_mode: str = "parent_group"  # "domain_only", "parent_group", "taxonomy_graph"
    # - "domain_only": Sample only from exact domain (original behavior, causes sparse domain issues)
    # - "parent_group": Expand to related domains via parent group (e.g., all "computer vision *" domains)
    # - "taxonomy_graph": Use explicit taxonomy graph (future extension)
    router_semantic_pool_max_domains: Optional[int] = None  # Max related domains to include (None = all)
    router_semantic_pool_depth: int = 1  # Graph traversal depth (for taxonomy_graph mode)

    # Soft targets (graph-smoothed supervision)
    router_use_soft_targets: bool = False  # Distribute small mass to graph neighbors
    router_soft_target_eps: float = 0.1  # Mass to distribute to neighbors (1-eps on positive)
    router_soft_target_k_neighbors: int = 5  # Number of neighbors to consider

    # Label-side graph regularizer (optional, X-CLR style on model embeddings)
    router_use_label_graph_reg: bool = False  # Enable label-side graph alignment
    router_label_graph_lambda: float = 0.1  # Weight for graph regularizer
    router_label_graph_tau: float = 0.07  # Temperature for predicted similarities
    router_label_graph_tau_target: float = 0.1  # Temperature for target similarities
    router_label_graph_max_models: int = 256  # Max models for graph regularizer (subsample if exceeds)
    router_label_graph_alpha_domain: float = 0.3  # Similarity for same-domain pairs in taxonomy

    # Model registry persistence
    router_registry_path: Optional[str] = None  # Path to save/load registry (for continual learning)
    router_registry_init_mode: str = "extend"  # "fresh" or "extend" - how to initialize registry
    router_registry_base_path: Optional[str] = None  # Path to previous registry JSON (for extend mode)
    # IMPORTANT: Registry should be built once and persisted across experiences to maintain stable IDs
    # In "extend" mode, loads base registry and appends new models. In "fresh" mode, builds from scratch.

    # Router debug parameters
    debug_router_supervision: bool = False  # Enable comprehensive router supervision debug checks
    debug_router_every: int = 100  # Run router debug every N steps
    debug_router_first_steps: int = 50  # Run router debug for first N steps
    debug_router_strict: bool = False  # Raise AssertionError on router debug mismatches (default: False, only warns)

    # Two-phase training schedule (for Experience 2+ to reduce forgetting)
    router_two_phase_enable: bool = False  # Enable two-phase schedule (Phase 1: stability warmup, Phase 2: main training)
    router_phase1_frac: float = 0.2  # Fraction of total steps for Phase 1 (stability warmup)
    router_phase1_loss_mode: str = "router"  # Loss mode for Phase 1 (typically "router" for router-only)
    router_phase1_replay_ratio: Optional[float] = None  # NOTE: Not implemented - replay ratio must be set in sampler, not here. Use router_replay_loss_multiplier for replay pressure.
    router_phase1_router_loss_weight: float = 1.0  # Router loss weight for Phase 1
    router_phase1_lm_loss_weight: float = 0.0  # LM loss weight for Phase 1 (0.0 ensures LM is frozen)
    router_phase1_proj_lr: Optional[float] = None  # Router projection LR for Phase 1 (None = use router_proj_lr)
    router_phase1_embedding_lr: Optional[float] = None  # Router embedding LR for Phase 1 (None = use router_embedding_lr)
    router_phase1_use_soft_targets: bool = False  # Soft targets for Phase 1
    router_phase1_soft_target_eps: float = 0.02  # Soft target epsilon for Phase 1
    router_replay_loss_multiplier: float = 1.0  # Multiplier for router loss on replay examples (applied in Phase 1 if >1.0)

    # Exp1-preservation training mode (for exp2 to reduce catastrophic forgetting)
    router_exp1_preservation_enable: bool = False  # Enable exp1-preservation mode (freeze old embeddings during Phase 1, keep projection trainable to prevent accuracy drop)
    router_exp1_preservation_M_old: Optional[int] = None  # Base registry size from exp1 (auto-detected from checkpoint if None)

    # Router embedding anchoring regularizer (for exp2+ to reduce forgetting)
    router_anchor_enable: bool = False  # Enable embedding anchoring to preserve exp1 routing when registry is extended
    router_anchor_lambda: float = 1e-3  # Weight for anchor loss (lambda in total_loss += lambda * anchor_loss)
    router_anchor_mode: str = "normalized"  # Anchor mode: "raw" (L2) or "normalized" (cosine, preferred for router scoring)
    router_anchor_apply_phase: str = "phase1"  # When to apply anchoring: "phase1", "phase2", or "both"
    router_anchor_scope: str = "all_old"  # Which rows to anchor: "all_old" (all rows < M_old) or "touched" (only rows in current step's candidates + gold)
    router_anchor_M_old: Optional[int] = None  # Base registry size M_old (auto-detected from checkpoint if None)

    # Router projection anchoring regularizer (for exp2+ to reduce projection drift)
    router_proj_anchor_enable: bool = False  # Enable projection anchoring to preserve exp1 projection when registry is extended
    router_proj_anchor_lambda: float = 1e-2  # Weight for projection anchor loss (lambda in total_loss += lambda * proj_anchor_loss)
    router_proj_anchor_apply_phase: str = "phase1"  # When to apply projection anchoring: "phase1", "phase2", or "both"

    # Router freeze LM option (for router-only runs to prevent unintentional LM updates)
    router_freeze_lm: bool = False  # If True, force LM requires_grad=False in router-only mode (router still trainable)
    
    
@dataclass
class EvalConfig:
    """Settings for an evaluation run.

    Every field has a default, so ``EvalConfig()`` is a valid configuration.
    This class only stores values; it performs no validation of its own.
    """

    # ----- Model and experiment identifiers -----
    # Name of the experience, e.g., "apibench" "mllm".
    experience_name: str = "apibench"
    # List of LoRA adapters to use (a fresh empty list per instance).
    lora_adapters: List[str] = field(
        default_factory=list)
    # Base model to use.
    repo_id: str = "huggyllama/llama-7b"
    # Optionally evaluate on train set as well (default: off).
    eval_on_train: bool = False

    # ----- Input/Output settings -----
    input_max_length: int = 1024
    max_new_tokens: int = 64
    temperature: float = 0.4
    do_sample: bool = True

    # ----- Decoding controls (optional; only applied if provided) -----
    top_p: float = 1.0
    top_k: Optional[int] = None
    penalty_alpha: Optional[float] = None
    # Not currently used in generation, but accepted for experimentation/logging.
    random_prefix_len: Optional[int] = None
    sample_num: Optional[int] = None

    # Name of the directory to save the evaluation results.
    output_name: Optional[str] = None

    # ----- Evaluation settings -----
    eval_batch_size: int = 4

    # ----- LoRA merging strategy settings -----
    # ties, dare_linear, arithmetic_mean, or null.
    lora_merging_strategy: Optional[str] = None
    # Use only when lora_merging_strategy is "ties" or "dare_linear".
    ties_or_dare_weights: List[float] = field(
        default_factory=lambda: [1.0, 1.0])
    ties_or_dare_density: float = 0.3

    # ----- Retrieval and prompt format -----
    # Specify retriever if needed, e.g., "bm25", "sentence_transformer", "splade", "flagembedding".
    retriever: Optional[str] = None
    # Specify system prompt format if needed, e.g., "gorilla_prompt",
    # "gorilla_prompt_explanation", "gorilla_prompt_explanation_json".
    system_prompt_format: Optional[str] = None

    # ----- Router evaluation -----
    # Use router evaluation instead of text generation.
    use_router: bool = False
    # Enable detailed debugging output for router evaluation.
    debug_router_eval: bool = False
    # Use strict=True when loading router weights.
    strict_router_load: bool = False
    # Load 50 examples from training split and run router evaluation on them.
    eval_on_train_samples: bool = False
    # Only compare against models within the same domain (for router evaluation).
    known_domain_mode: bool = False

    # ----- Hierarchical (two-stage) evaluation -----
    # Enable hierarchical evaluation: predict group then model within group.
    hierarchical_eval: bool = False
    # Hierarchy level for hierarchical evaluation: "domain" or "parent_group".
    hierarchy_level: str = "domain"
    # Number of top groups to consider in hierarchical evaluation.
    hierarchical_topk: int = 1
    # Domain scoring strategy: "logsumexp", "max", "topk_logsumexp", "hybrid".
    hier_domain_score_mode: str = "logsumexp"
    # Number of top models for topk_logsumexp/hybrid domain scoring modes.
    hier_domain_topk: int = 10
    # Weight for max in hybrid domain scoring mode (0.0 = pure logsumexp, 1.0 = pure max).
    hier_domain_hybrid_alpha: float = 0.5

    # ----- Few-shot retrieval configuration -----
    # NOTE: These should match the training config values for consistency between
    # train/test distributions. If you trained with fewshot_top_k=1, set this to 1
    # for fair evaluation.
    # Number of retrieved neighbors for few-shot augmentation (should match TrainConfig.fewshot_top_k).
    fewshot_top_k: int = 3
    # Max tokens for model card snippets (should match TrainConfig.fewshot_max_card_tokens).
    fewshot_max_card_tokens: int = 200
    # Seed for replay buffer sampling (should match TrainConfig.seed used during training).
    fewshot_replay_seed: Optional[int] = 42
    # Replay buffer ratio (should match TrainConfig.fewshot_replay_ratio, default 10%).
    fewshot_replay_ratio: float = 0.1
    # Probability of using original prompt without examples during evaluation
    # (0.0 = always use examples, typically 0.0 for eval).
    fewshot_dropout_prob: float = 0.0

    # Commented out parameters (uncomment and adjust as needed)
    # penalty_alpha: float = 0.6
    # top_k: int = 10
    # top_p: float = 0.7
    # random_prefix_len: int = 5
    # sample_num: int = 2
    # decoding_method: str = "sampling"
