### CONFIGURATION FOR INSPIRATION RETRIEVAL SFT - FULL PARAMETER ###
### Key fixes from HC: ZeRO-3 + lower weight_decay to stabilize loss ###
### R1-Distill Native Format: No system prompt, reasoning starts directly ###
### Dataset: 150,218 train + 2,377 eval samples (alphabetical label format) ###

# Chat template for R1-Distill models; per the original note it is the
# correct one for this family and adds <think> to the prompt.
template: deepseekr1
# Local path to the DeepSeek-R1-Distill-Qwen-7B weights.
model_name_or_path: /pfs/training-data/hf/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B

# Dataset selection
# `dataset` is a logical name; dataset_info.json maps it to the actual file.
dataset: train
# MODIFY THIS: directory holding the IR SFT data.
# Data flow: the value comes from IR_SFT_DATA_DIR in main.sh.
dataset_dir: <YOUR_IR_SFT_DATA_DIR>/
# Evaluation split is disabled: no eval runs during training.
# eval_dataset: eval

### RESUME TRAINING CONFIGURATION ###
# To resume from checkpoint: uncomment the line below
# To start fresh training: keep the line commented out
# resume_from_checkpoint: <YOUR_CHECKPOINT_DIR>/full_training_inspiration_retrieval/checkpoint-XXX

# Training mode: SFT stage with full-parameter fine-tuning (no freezing).
stage: sft
finetuning_type: full
do_train: true

# Data handling and reproducibility
seed: 42
disable_shuffling: false
mask_history: true

# Batching and sequence length
# Effective batch on 128 GPUs:
#   1 (per device) * 1 (accumulation) * 128 (GPUs) = 128
gradient_accumulation_steps: 1
# Conservative choice for full-parameter training with ZeRO-3.
per_device_train_batch_size: 1
packing: false
# Long enough to cover every sample.
cutoff_len: 16384

# Preprocessing: tokenization runs in parallel.
# MODIFY THIS (optional): tokenized cache directory; speeds up repeated runs.
tokenized_path: <YOUR_TOKENIZED_CACHE_DIR>/inspiration_retrieval
preprocessing_num_workers: 32

# Optimizer and schedule, aligned with the working HC config.
# Trailing notes record the value each setting had before the alignment.
num_train_epochs: 1
lr_scheduler_type: cosine
# Keep the decimal point in the mantissa: some YAML 1.1 loaders resolve a
# bare `1e-5` as a string rather than a float.
learning_rate: 1.0e-5  # was 5e-6
warmup_ratio: 0.05  # was 0.1
max_grad_norm: 0.5  # was 1.0
# KEY FIX: lowered from 0.1 to match HC.
weight_decay: 0.01

# Numeric precision: bf16 enabled, fp16 disabled.
fp16: false
bf16: true
# Attention backend: PyTorch's built-in SDPA (no GLIBC dependency).
flash_attn: sdpa
gradient_checkpointing: true

# KEY FIX: DeepSpeed ZeRO-3, chosen as more stable than ZeRO-2 for
# full-parameter training.
deepspeed: SFT/deepspeed_zero3.json

# Evaluation Settings (disabled)
# Evaluation is switched off for this run; the commented keys below are the
# settings that were used when it was enabled.
# NOTE(review): re-enabling presumably also requires uncommenting
# `eval_dataset` in the dataset section and setting do_eval to true -- confirm.
do_eval: false
# eval_strategy: steps
# eval_steps: 132
# per_device_eval_batch_size: 2
# eval_accumulation_steps: 4
# load_best_model_at_end: true
# metric_for_best_model: eval_loss
# greater_is_better: false

# Note: with 150,218 training samples and an effective batch size of 128
# (per_device_train_batch_size 1 * gradient_accumulation_steps 1 * 128 GPUs):
# steps per epoch = 150,218 / 128 = ~1,174

# Logging
report_to: tensorboard
# Every step is logged for debugging; can be increased once training is
# verified stable.
logging_steps: 1

# Checkpointing
# MODIFY THIS: directory that receives the checkpoints.
output_dir: <YOUR_CHECKPOINT_DIR>/full_training_inspiration_retrieval
save_strategy: steps
# Roughly 10 saves per epoch (1,174 steps / 10).
save_steps: 117
# Retention limit; several checkpoints are kept for safety.
save_total_limit: 3
