# Model arguments
# Starting checkpoint for GRPO training: a local SFT checkpoint of
# Llama-3.1-8B (per the path). Base models used in earlier runs:
#   - deepseek-ai/DeepSeek-R1-Distill-Llama-8B
#   - deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
#   - Qwen/Qwen2.5-7B-Instruct
model_name_or_path: /fast/XXXX-3/forecasting/sft/llama3.1-8b/full
model_revision: main
# Precision and attention kernel: bf16 weights/training, TF32 matmuls,
# FlashAttention-2.
torch_dtype: bfloat16
attn_implementation: flash_attention_2
bf16: true
tf32: true
# Directory for checkpoints of this run. Output names used in earlier runs:
#   - DeepSeek-R1-Distill-Llama-8B-Brier-DataMix-Raw
#   - Qwen2.5-7B-Instruct-Raw-DataMix
#   - R1-Distill-Qwen-7B-Raw-Halawi
#   - DataMix
#   - Qwen-2.5-7B-Instruct-Brier-DataMix
#   - /fast/XXXX-3/forecasting/training/DeepSeek-R1-Distill-Llama-8B-Brier-Raw
#   - Qwen-2.5-7B-Instruct-LOG-Raw-102
#   - /fast/XXXX-3/forecasting/training/DeepSeek-R1-Distill-Llama-8B-Brier-256
#   - R1-Distill-Llama-8B-Brier
#   - runs/qwen-2.5-3b-r1-countdown
#   - Llama-3.1-8B-MATH-IDK
output_dir: /fast/XXXX-3/forecasting/training/SFT-Llama-Full-DataMix-Brier

# Dataset arguments
# Dataset (Hub id or local path) and the split to train on.
# The split must match the training split used.
dataset_id_or_path: 'YuehHanChen/forecasting'
dataset_splits: 'train'

# LoRA arguments
# None: LoRA is not used in this run.

# Training arguments
# max_steps: 500
per_device_train_batch_size: 1 # 2
gradient_accumulation_steps: 256 # 256 # 32
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 1.0e-5 # 1.0e-6 as in the deepseek math paper 5-e7 from XXXX
lr_scheduler_type: cosine
warmup_ratio: 0.1
# GRPO specific parameters
beta: 0.04 # 0.04 as in the deepseek math paper 0.001 from XXXX
max_prompt_length: 1024
max_completion_length: 2048 # 3072 # 3584 # 3072 # 3584 # 2048
num_generations: 7 # 14
use_vllm: true
# vllm_device: "cuda:7"
vllm_gpu_memory_utilization: 0.8
num_train_epochs: 15
temperature: 0.9

# Logging arguments
logging_strategy: steps
logging_steps: 1
log_completions: false  # set to true to enable
save_strategy: steps
save_steps: 100
seed: 42
report_to:
  - wandb

# Eval arguments
# NOTE(review): step-based evaluation needs an eval dataset, but
# dataset_splits lists only `train`. Confirm that the training script
# provides an eval split.
eval_strategy: steps
eval_steps: 30
per_device_eval_batch_size: 1
# eval_on_start: true
# do_eval: true
# batch_eval_metrics: true