# Training Configuration
# Default parameters for the TrainConfig class.

# Model and experiment identifiers
# Experience name, e.g. 'apibench', 'mllm', 'hugging-bench-1', 'hugging-bench-2'
experience_name: 'apibench'
# Variant name of the experiment
variant_name: 'sequential-finetuning'
# Optional extra text appended to the output directory name
extra_info: ''
# Root directory for all outputs
output_root: 'cco/experiments'

# Base model, e.g. 'huggyllama/llama-7b' or 'deepseek-ai/deepseek-coder-7b-instruct-v1.5'
repo_id: 'huggyllama/llama-7b'
# Retriever, if needed: 'bm25', 'sentence_transformer', 'splade' or 'flagembedding'
retriever: null

# System prompt configuration
# Custom system prompt. When it is None/null, the default gorilla_prompt is used.
# NOTE(review): confirm how TrainConfig treats the empty string "" compared with null.
# Example custom prompt (uncomment to use):
# system_prompt: "Return only the repository identifier in the form owner/repo. Do not include any extra words or punctuation. Use only HuggingFace models."
system_prompt: ""

# System prompt formats:
# - gorilla_prompt: standard gorilla prompt with no explanation; predict only the model_name
# - gorilla_prompt_explanation: gorilla prompt with an explanation, in gorilla format
#     <<<model_name>>>:my_model <<<explanation>>>:my_explanation
# - gorilla_prompt_explanation_json: gorilla prompt with an explanation, in JSON format
#     {"model_name": "my_model", "explanation": "my_explanation"}
system_prompt_format: "gorilla_prompt"

# Training hyperparameters
epochs: 5
batch_size: 32  # examples per optimisation micro-step
grad_accum: 4  # gradient-accumulation steps; effective batch = batch_size * grad_accum = 32 * 4 = 128 (per process if distributed — confirm)
lr: 0.0005
max_length: 550 # DO NOT DELETE THIS COMMENT max tokenized sequence length considering {model_name:..., explanation:...} for apibench is about 920 tokens. if we truncate explanations to 1000 chars at max, max tokenized sequence length is about 350 tokens
max_grad_norm: 1.0
packing: false
group_by_length: true
completion_only_loss: true
label_smoothing: 0.05

# LoRA parameters
lora_r: 32  # LoRA rank; try 64
lora_alpha: 64  # typically 2 * lora_r
lora_dropout: 0.05  # LoRA dropout probability
# Modules that receive LoRA adapters
target_modules:
  - 'q_proj'
  - 'k_proj'
  - 'v_proj'
  - 'o_proj'
  - 'gate_proj'
  - 'down_proj'
  - 'up_proj'

# Checkpoint and evaluation options
# Path of a checkpoint to resume a previous training run from
resume_from: null
# LoRA adapters to use
lora_adapters: []
# Early stopping, as a guard against overfitting
# NOTE(review): with no_validation set to true there may be no separate eval
# split for early stopping / metric_for_best_model — confirm.
early_stopping_patience: 3
early_stopping_threshold: 0.01
# true = merge the training and validation sets during training
no_validation: true
# true when running bayes_hyperparameter_search.py; the model is then evaluated after training
hyperparameters_search: false

# Optimizer and scheduler
weight_decay: 0.001
warmup_steps: 10
# Scheduler choices: linear, cosine, cosine_with_restarts, polynomial,
# constant, constant_with_warmup
lr_scheduler_type: 'linear'
optim: 'adamw_torch'

# Logging and checkpoint saving
logging_steps: 5
save_strategy: 'epoch'
save_total_limit: 1

# Best-model selection: a lower eval_loss is better
metric_for_best_model: 'eval_loss'
greater_is_better: false

activation_checkpointing: true

# Memory optimization options
# low_memory_mode = true turns on Flash Attention 2 (when available), CPU
# activation offloading and bfloat16 (bf16) precision to cut GPU memory usage.
low_memory_mode: false
# 4-bit quantization
use_quantization: false


# Preference-training options.
# NOTE(review): presumably only read when preference_mode is true; orpo_beta
# suggests ORPO — confirm against TrainConfig.
preference_mode: false
negative_sampling_strategy: "mixed"
num_rejections_per_example: 2
orpo_beta: 0.1
max_prompt_length: 512  # maximum prompt length (presumably in tokens, like max_length)

# =============================================================================
# Replay Configuration (Continual Learning)
# =============================================================================
# replay_percentage and replay_num_samples control how much data to replay
# replay_strategy controls HOW the replay data is selected:
#   - "random": baseline random sampling (default)
#   - "domain_model_coreset": diversity-preserving sampling with domain/model awareness
#
# NOTE: the replay and neighbour-consistency keys used to be defined twice in
# this file. Duplicate mapping keys are invalid YAML: PyYAML silently keeps the
# LAST occurrence, while strict loaders (e.g. ruamel.yaml) raise an error.
# Each key now appears exactly once, with the value that last-wins loading
# already produced.

replay_strategy: "random"  # options: "random", "domain_model_coreset"

# Coreset replay hyperparameters (only used when replay_strategy="domain_model_coreset")
# These preserve long-tail domains/models and maximize embedding-space diversity
replay_min_per_domain: 5      # Minimum examples to include per domain (floor)
replay_max_per_domain: null   # Optional cap per domain (null = no cap)
replay_max_per_model: 3       # Maximum examples per model within each domain
replay_embedding_source: "flagembedding"  # Embedding model for coreset selection (e.g. "sentence_transformer", "flagembedding")

# =============================================================================
# Neighbour-Consistency Regularisation (Continual Learning)
# =============================================================================
# Adds an auxiliary loss term encouraging locally smooth predictions.
# For each training example, retrieves similar prompts from previous experiences
# and regularizes the model's output to be consistent between anchor and neighbours.
# This enforces knowledge preservation without a separate teacher model.
#
# IMPORTANT: In multi-model routing scenarios, similar prompts may intentionally
# route to DIFFERENT models. Use "strict" domain filtering to only apply consistency
# when same-domain neighbours exist in the replay buffer. This preserves domain
# knowledge without forcing identical model routing.

use_neighbor_consistency: false   # Enable consistency loss (default off for baseline)
neighbor_k: 3                     # Number of neighbours to retrieve per example
neighbor_consistency_weight: 0.1  # Weight of consistency loss vs supervised loss
neighbor_consistency_temperature: 1.0  # Temperature for KL divergence
neighbor_source: "apibench"       # Source for neighbours: "apibench" or "current_train"

# Neighbour-consistency refinement options
neighbor_max_consistency_samples: 4   # Max batch examples for consistency loss (default: 4)
neighbor_consistency_num_tokens: 5    # Number of final tokens pooled for the KL loss (default: 1 = last token only)
neighbor_embedding_model: "flagembedding"  # "flagembedding" (BGE-M3) or any sentence-transformer model name

# Domain-aware neighbour filtering
# - "none": ignore domains (use semantic similarity only)
# - "soft": add domain_bias to similarity for same-domain neighbours
# - "hard": prefer same-domain, fall back to cross-domain if needed
# - "strict": only use same-domain neighbours, skip consistency if too few exist
neighbor_domain_filter_mode: "strict"
neighbor_domain_bias: 0.2         # Similarity bonus for same-domain neighbours (soft mode)
neighbor_min_same_domain: 1       # Minimum same-domain neighbours required (strict mode)

# Replay-only consistency (RECOMMENDED for multi-model routing)
# When true, only applies consistency loss when the anchor prompt is ALSO from
# the replay buffer (not a new experience prompt). This ensures we're maintaining
# consistency between old prompts, not pushing new prompts toward old outputs.
neighbor_replay_only: false       # Only use replay examples as anchors
neighbor_replay_similarity_threshold: 0.95  # Similarity to detect replay examples


# =============================================================================
# Neighbour-Contrastive Regularisation (Hard Negative Mining)
# =============================================================================
# An alternative to neighbour-consistency for multi-model routing.
# Neighbour-consistency pushes similar prompts towards similar outputs. That is
# mis-specified when similar prompts are meant to route to different models.
# This loss instead DISCRIMINATES between semantically similar prompts that have
# different target models.
#
# For each anchor (replay examples preferred), the top-k similar prompts are
# retrieved. Neighbours with a different model_id become hard negatives, and a
# ranking loss raises the score of the anchor's gold model over the negatives.
#
#   score(prompt, model_id) = log-probability the router assigns to emitting
#                             the target response with that model_id
#
# Per-anchor ranking loss:
#   softplus: L = mean_j softplus(score_neg_j - score_pos)
#   hinge:    L = mean_j relu(margin + score_neg_j - score_pos)
#
# Total loss: L_total = L_supervised + weight * L_contrastive

# Master switch for the contrastive ranking loss (off by default)
use_neighbor_contrastive: false
# 'weight' in L_total above (relative to the supervised loss)
neighbor_contrastive_weight: 0.1
# Neighbours retrieved per anchor (k)
neighbor_contrastive_k: 3
# Unique negative model_ids to use (<= k)
neighbor_contrastive_num_negatives: 3
# Cap on the number of anchors the loss is applied to per batch
neighbor_contrastive_max_anchors_per_batch: 4
# 'softplus' or 'hinge'
neighbor_contrastive_loss_type: 'softplus'
# Hinge margin (only relevant when loss_type is 'hinge')
neighbor_contrastive_margin: 0
# Which examples to apply it to: 'replay_only' or 'all'
neighbor_contrastive_apply_to: 'replay_only'

# =============================================================================
# X-Sample Contrastive Loss (X-CLR) - ICLR 2025
# =============================================================================
# X-CLR replaces the hard "single positive" target with a soft target
# distribution built from a similarity graph. It minimises the cross-entropy
# between that target distribution and the learned similarity distribution.
#
# Advantages over neighbour-consistency:
# - Works with self-instruct data, where near-duplicate prompts may map to
#   different models (the taxonomy graph handles this explicitly).
# - Prompt embeddings come from the router model's hidden states, so no
#   separate model is needed.
# - Soft targets avoid pushing together prompts that are semantically similar
#   but have different labels.
#
# Graph modes:
# - 'taxonomy':  G[i,j] = 1.0 for the same model_id, alpha for the same domain,
#                0 otherwise
# - 'modelcard': soft similarity from frozen SBERT embeddings of model cards
#
# Loss:  L_xclr = mean_i [ - sum_j s_i[j] * log p_i[j] ]   (cross-entropy)
#   with s_i = softmax(G_soft / tau_target)  and  p_i = softmax(z @ z.T / tau)

# Master switch (off by default)
use_xclr: false
# λ in L_total = L_supervised + λ * L_xclr
xclr_weight: 0.2
# Temperature of the learned similarity distribution (tau)
xclr_tau: 0.1
# Temperature of the soft target distribution (tau_target)
xclr_tau_target: 0.1
# 'taxonomy' or 'modelcard'
xclr_graph_mode: 'taxonomy'
# Weight for same-domain, different-model pairs (alpha)
xclr_domain_alpha: 0.5
# Which examples to apply it to: 'all' or 'replay_only'
xclr_apply_to: 'replay_only'
# Projection dimension of the prompt embeddings
xclr_proj_dim: 256
# Cap on anchors per batch (for efficiency)
xclr_max_anchors_per_batch: 8
# Log X-CLR metrics every N steps
xclr_log_every: 100
# SBERT encoder used in modelcard mode
xclr_modelcard_encoder: 'all-mpnet-base-v2'
# Blend of modelcard with taxonomy (0 = pure modelcard)
xclr_modelcard_blend: 0.0

# X-CLR debugging / diagnostics (helps diagnose a flat X-CLR loss)
# Detailed diagnostics: target/pred distributions and gradients
xclr_debug: true
# Diagnostics logging interval in steps (null = fall back to xclr_log_every)
xclr_debug_log_every: 10
# Number of anchors that get a detailed breakdown
xclr_debug_max_print: 1

# X-CLR replay candidate pool (relevant when xclr_apply_to is 'replay_only')
# In taxonomy mode many replay anchors may have no in-batch positives. The
# candidate pool samples from the replay buffer so that, where possible, each
# anchor gets same-domain positives.
xclr_num_candidates: 31  # candidates per anchor (31 + the anchor itself = 32)
xclr_min_pos: 1  # minimum same-domain candidates to include per anchor

# X-CLR prompt-similarity configuration
# Turns on prompt-similarity-based candidate sampling and target-graph
# construction, reusing the few-shot retrieval infrastructure to define
# X-CLR neighbourhoods.
# Candidate sampling: 'domain' (default) or 'prompt_similarity'
xclr_candidate_sampling: 'prompt_similarity'
# Target graph: 'taxonomy' (default) or 'prompt_similarity'
xclr_target_graph: 'prompt_similarity'
# Candidates for prompt_similarity (null = use xclr_num_candidates)
xclr_promptsim_k: 3
# Retriever: 'bm25', 'sentence_transformer', 'splade' or 'flagembedding'
xclr_promptsim_retriever_type: 'flagembedding'
# Exclude the anchor itself from its candidates (default: true)
xclr_promptsim_exclude_self: true
# true = retrieve over all prompts (replay + current experience); false = replay pool only
xclr_promptsim_retrieve_over_all: true

# X-CLR candidate queue (memory bank), to counter near-uniform predicted distributions.
# Embeddings from earlier batches are kept in a queue and added to the candidate
# set, which raises the number of negatives/positives and helps keep p from
# collapsing to uniform.
xclr_queue_size: 2048  # max candidates held in the FIFO queue
xclr_queue_use: false  # enable the queue
xclr_queue_device: 'cpu'  # where queue embeddings live: 'cpu' saves VRAM, 'cuda' is faster
xclr_num_queue_candidates: 256  # candidates drawn from the queue per batch
xclr_force_positive_from_queue: false  # anchor without in-batch positives: try to draw >= 1 same-domain positive from the queue


# X-CLR stability fixes (added for training stability).
# NOTE(review): the per-key descriptions below are inferred from the key names
# only — confirm them against the X-CLR implementation.
# Two-stream sampler (presumably mixes current-experience and replay streams)
xclr_use_two_stream_sampler: false
# Replay examples per batch (presumably used by the two-stream sampler)
xclr_replay_per_batch: 4
# Separate learning rate for the X-CLR projection head (null presumably = global lr)
xclr_proj_lr: null
# Presumably stops X-CLR gradients from reaching the base model
xclr_stopgrad_base: false
# Presumably the number of warmup steps for the stop-gradient behaviour
xclr_stopgrad_warmup_steps: 20

# =============================================================================
# Router Training (Semantic Batching + Candidate-Set Routing Loss)
# =============================================================================
# A neural model selector that learns to predict which model should handle each
# prompt. It uses semantic batching, a candidate-set routing loss and hard
# negative mining.
#
# Features:
# - Semantic batching groups examples by domain for better negative sampling.
# - Candidate-set routing loss: supervised ranking over sampled candidate models.
# - Hard negative mining identifies confusable models as hard negatives.
# - Optional soft targets spread the supervisory signal to graph neighbours.
# - Optional label-graph regularisation aligns model embeddings with the taxonomy.
#
# loss_mode options:
# - 'supervised':              standard LM training only (no router)
# - 'router':                  router training only (recommended for pure routing)
# - 'router+graph':            router + label-side graph regularisation
# - 'supervised+router':       combined LM + router training
# - 'supervised+router+graph': combined, plus graph regularisation

loss_mode: 'supervised+router'
# Weight of the routing loss
router_loss_weight: 1.0
# Weight of the LM supervised loss (combined modes)
lm_loss_weight: 1.0

# Router architecture
# Dimension of the model embeddings (null = use the base model's hidden_size, 4096)
router_embedding_dim: 2048
# Temperature used to scale the logits
router_tau: 0.15
# Pooling strategy: 'last_token' or 'mean'
router_pooling: 'mean'

# Router Learning Rates (Split by Parameter Group)
# Split LR helps stabilize training by reducing spiky updates from the embedding table.
# If null, the global learning_rate from the training args is used.
#
# The values are written as 3.0e-4 / 5.0e-5, not 3e-4 / 5e-5. YAML 1.1 loaders
# such as PyYAML only resolve a float when the mantissa contains a dot, so a
# bare "3e-4" loads as the STRING '3e-4'. The dotted form is a float under
# both YAML 1.1 and YAML 1.2.

# Learning rate for the projection head (null = use global lr).
# Recommended: 3e-4 to 5e-4 for faster convergence.
router_proj_lr: 3.0e-4

# Learning rate for the embedding table (null = use global lr).
# Recommended: 1e-4 to 3e-4 for stability (lower than proj_lr).
# NOTE(review): 5e-5 sits below the recommended range — confirm it is intended.
router_embedding_lr: 5.0e-5

# Candidate sampling strategy
# Every example gets a candidate set made of:
#   1 positive (the ground-truth model)
#   + K_semantic semantic negatives (same domain, different model)
#   + K_far far negatives (different domain)
#   + K_hard hard negatives (mined confusable models)
# With the values below: 1 + 40 + 8 + 15 = 64 = router_K_total.
router_K_total: 64      # total candidate set size
router_K_semantic: 40   # in-domain (semantic) negatives
router_K_far: 8         # out-of-domain (far) negatives
router_K_hard: 15       # mined hard negatives

# Hard negative mining
# Periodically mines "confusable" models, i.e. wrong models that get high
# router scores. Mining runs under torch.no_grad() for efficiency.
# Mining interval, in training steps
router_mine_every_steps: 200
# Top confusable models cached per example
router_K_hard_pool: 20
# Size of the semantic pool scored during mining
router_semantic_pool_size: 512
# Upper bound on the mining pool size
router_max_pool_size: 1024

# Semantic pool expansion (Option B)
# Decides where semantic negatives come from: the exact domain only, or an
# expanded set of related domains. This addresses sparse domains, where most
# domains have fewer than K_semantic models.
# Modes:
# - 'domain_only':    exact domain only (original behaviour; triggers sparse-domain warnings)
# - 'parent_group':   expand to related domains sharing a parent group
#                     (e.g. every "computer vision *" domain); this greatly
#                     enlarges the semantic pool for sparse domains
# - 'taxonomy_graph': explicit taxonomy graph (future extension; needs a graph definition)
router_semantic_pool_mode: 'parent_group'
# Max related domains to include (null = whole parent group)
router_semantic_pool_max_domains: 8
# Graph traversal depth (taxonomy_graph mode only)
router_semantic_pool_depth: 1

# Soft targets (optional)
# Spreads part of the supervisory signal to graph neighbours of the positive
# model; related models are found via the taxonomy/domain graph.
# Enable soft-target distribution
router_use_soft_targets: true
# Mass given to neighbours (the positive keeps 1 - eps)
router_soft_target_eps: 0.1
# Number of neighbours that share the eps mass
router_soft_target_k_neighbors: 5

# Label-side graph regularisation (optional)
# An auxiliary loss that aligns model embeddings with the taxonomy: it builds a
# target similarity distribution over models and matches the learned embedding
# similarities to it.
router_use_label_graph_reg: false     # enable the regulariser
router_label_graph_lambda: 0.1        # weight of the regulariser loss
router_label_graph_tau: 0.07          # temperature of the learned model similarities
router_label_graph_tau_target: 0.1    # temperature of the target graph similarities
router_label_graph_max_models: 256    # cap on models in the graph computation (efficiency)
router_label_graph_alpha_domain: 0.5  # graph weight of same-domain, different-model pairs

# Semantic batching (recommended)
# Builds batches from training examples grouped by domain, so that most
# examples in a batch share a domain, which improves negative sampling.
# Enable domain-based semantic batching (router training)
semantic_batching: true
# Domains per batch
domains_per_batch: 2

# Model registry
# Maps model names to stable integer IDs and holds their metadata. It can be
# built once and reused across experiences in continual learning.
# Path to a pre-built registry; null = build it from the data
router_registry_path: null

