# Base configuration this file builds on.
# NOTE(review): presumably a Hydra-style defaults list whose values are
# overridden by the keys below — confirm against the config loader.
defaults:
  - ../base/ppo_trainer

# GSM8K task section: prompt style plus the train/validation parquet files.
# The file paths are absolute and machine-specific.
gsm8k:
  developer_prompt: DeepSeekZero
  datasets:
    file_path: /nlp/scr/qinanyu/data/gsm8k_train.parquet
    # NOTE(review): presumably 0 means "use the whole file" — confirm with the loader.
    size: 0
    seed: 42
  validation_dataset:
    file_path: /nlp/scr/qinanyu/data/gsm8k_test.parquet
    size: 0
    # Differs from the training seed (42).
    seed: 41
  val_path: trainers/val_gsm8k

# Dataset and batching options. `task` matches the top-level `gsm8k` section.
data:
  task: gsm8k
  train_batch_size: 512
  val_batch_size: 100
  # Token budgets; actor_rollout_ref.rollout.prompt_length / response_length
  # interpolate these two values.
  max_prompt_length: 1024
  max_response_length: 2048
  filter_overlong_prompts: true
  truncation: 'error'
  shuffle: false
  chat: true

# Policy model plus the `actor`, `ref`, and `rollout` sub-sections.
actor_rollout_ref:
  hybrid_engine: true
  model:
    # Absolute, machine-specific checkpoint path; critic.model.tokenizer_path and
    # reward_model.model.input_tokenizer interpolate it.
    path: /nlp/scr/qinanyu/models/qwen2.5-3b
    external_lib: null
    override_config: {}
    enable_gradient_checkpointing: true
    use_remove_padding: false
    use_shm: false
    # lora_rank: 64
  # Actor (policy update) settings. Several `ref`, `rollout`, and `critic` keys
  # interpolate values defined here.
  actor:
    loss_agg_mode: "token-mean"
    strategy: fsdp  # kept for backward compatibility
    ppo_mini_batch_size: 128
    ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
    ppo_micro_batch_size_per_gpu: 16
    use_dynamic_bsz: false
    # n * ${data.max_prompt_length} + ${data.max_response_length}
    ppo_max_token_len_per_gpu: 16384
    grad_clip: 1.0
    # Default used when clip_ratio_low and clip_ratio_high are not specified.
    clip_ratio: 0.2
    clip_ratio_low: 0.2
    clip_ratio_high: 0.2
    # Lower bound of the value for Dual-clip PPO, from https://arxiv.org/pdf/1912.09729
    clip_ratio_c: 3.0
    entropy_coeff: 0
    use_kl_loss: true  # true for GRPO
    use_torch_compile: true  # false to disable torch compile
    kl_loss_coef: 0.001  # for GRPO
    kl_loss_type: low_var_kl  # for GRPO
    ppo_epochs: 1
    shuffle: false
    ulysses_sequence_parallel_size: 1  # sequence-parallel (sp) size
    optim:
      lr: 1e-6
      lr_warmup_steps_ratio: 0.0  # the total steps will be injected during runtime
      min_lr_ratio: null  # only useful for warmup with cosine
      warmup_style: constant  # select from constant/cosine
      total_training_steps: -1  # must be overridden by the program
    fsdp_config:
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      param_offload: true
      optimizer_offload: true
      fsdp_size: -1
      forward_prefetch: false
    checkpoint:
      contents: ['model', 'optimizer', 'extra']
  # Reference-policy settings; the dynamic-batching, token-budget, and
  # sequence-parallel values are interpolated from the actor section.
  ref:
    fsdp_config:
      param_offload: true
      optimizer_offload: true
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 16
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size}  # sp size
  # Generation (rollout) settings; the backend selected here is vllm.
  rollout:
    name: vllm
    mode: sync
    temperature: 1.0
    # NOTE(review): 2048 is smaller than data.max_prompt_length +
    # data.max_response_length (1024 + 2048 = 3072) — confirm this cap is intended.
    max_model_len: 2048
    top_k: -1  # 0 for hf rollout, -1 for vllm rollout
    top_p: 1
    prompt_length: ${data.max_prompt_length}  # not used for open source
    response_length: ${data.max_response_length}
    # --- vllm rollout ---
    dtype: bfloat16  # should align with FSDP
    gpu_memory_utilization: 0.8
    ignore_eos: false
    enforce_eager: true
    free_cache_engine: true
    load_format: safetensors
    # NOTE(review): `tensor_parallel_size` further down carries the same value;
    # confirm which of the two keys the rollout actually reads.
    tensor_model_parallel_size: 2
    max_num_batched_tokens: 8192
    max_num_seqs: 1024
    log_prob_micro_batch_size: null  # will be deprecated; use log_prob_micro_batch_size_per_gpu
    log_prob_micro_batch_size_per_gpu: 8
    log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
    log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
    disable_log_stats: true
    enable_chunked_prefill: true  # could get higher throughput
    tensor_parallel_size: 2
    # --- hf rollout ---
    do_sample: true
    use_fire_sampling: false
    # Number of responses sampled per prompt; > 1 for GRPO.
    n: 8
    val_kwargs:
      do_sample: true
      n: 4
    multi_turn:
      # Set to true for multi-turn tool interaction tasks; rollout.name should
      # then be set to sglang as well.
      enable: false
      max_turns: null  # null for no limit (default max_length // 3)
      tool_config_path: null  # null for no tool
      format: chatml

# Critic (value model) settings. Batch, epoch, shuffle, and dynamic-batching
# values are interpolated from actor_rollout_ref.actor.
critic:
  strategy: fsdp
  optim:
    lr: 1e-5
    lr_warmup_steps_ratio: 0.0  # the total steps will be injected during runtime
    min_lr_ratio: null  # only useful for warmup with cosine
    warmup_style: constant  # select from constant/cosine
    total_training_steps: -1  # must be overridden by the program
  model:
    path: ~/models/deepseek-llm-7b-chat
    tokenizer_path: ${actor_rollout_ref.model.path}
    override_config: {}
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: true
    use_remove_padding: false
    fsdp_config:
      param_offload: false
      optimizer_offload: false
      wrap_policy:
        # transformer_layer_cls_to_wrap: None
        min_num_params: 0
      fsdp_size: -1
  ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  ppo_micro_batch_size: null  # will be deprecated; use ppo_micro_batch_size_per_gpu
  ppo_micro_batch_size_per_gpu: null
  forward_micro_batch_size: ${critic.ppo_micro_batch_size}
  forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu}
  use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
  # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2, i.e. 16384 * 2.
  # The previous literal, 16768, did not match this formula.
  ppo_max_token_len_per_gpu: 32768
  forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu}
  ulysses_sequence_parallel_size: 1  # sp size
  ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
  shuffle: ${actor_rollout_ref.actor.shuffle}
  grad_clip: 1.0
  cliprange_value: 0.5

# Model-based reward settings; turned off in this config (`enable: false`).
reward_model:
  enable: false
  strategy: fsdp
  model:
    # Set this to null if the chat template is identical.
    input_tokenizer: ${actor_rollout_ref.model.path}
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: false
    fsdp_config:
      # NOTE(review): the other fsdp_config blocks in this file nest
      # min_num_params under wrap_policy — confirm this flat layout is expected.
      min_num_params: 0
      param_offload: false
      fsdp_size: -1
  micro_batch_size: null  # will be deprecated; use micro_batch_size_per_gpu
  micro_batch_size_per_gpu: null  # set a number
  max_length: null
  ulysses_sequence_parallel_size: 1  # sp size
  use_dynamic_bsz: ${critic.use_dynamic_bsz}
  forward_max_token_len_per_gpu: ${critic.forward_max_token_len_per_gpu}
  launch_reward_fn_async: false

# Advantage-estimation and KL settings; the estimator selected here is grpo.
algorithm:
  use_kl_in_reward: false
  gamma: 1.0
  lam: 1.0
  adv_estimator: grpo
  kl_penalty: kl  # how to estimate kl divergence
  kl_ctrl:
    type: fixed
    kl_coef: 0.001
  use_pf_ppo: false
  pf_ppo:
    reweight_method: pow  # one of: "pow", "max_min", "max_random"
    weight_pow: 2.0

# Run-level settings: schedule, logging, hardware layout, and checkpointing.
trainer:
  balance_batch: true
  total_epochs: 15
  total_training_steps: null
  project_name: gsm8k_ft
  experiment_name: gsm8k
  logger: ['console', 'wandb']
  val_generations_to_log_to_wandb: 0
  nnodes: 1
  n_gpus_per_node: 4
  save_freq: 50
  # auto: find the last checkpoint to resume from; if none is found, start
  # from scratch.
  # NOTE(review): the original inline comment was truncated ("or auto or
  # resume_path if"); presumably `resume_path` is the alternative mode and
  # relies on resume_from_path being set — confirm.
  resume_mode: auto
  resume_from_path: false
  test_freq: 5
  critic_warmup: 0
  default_hdfs_dir: null
  remove_previous_ckpt_in_save: false
  del_local_ckpt_after_load: false
  default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
  val_before_train: true