# PRIME trainer config: values set here override the defaults from ppo_trainer.yaml.

hydra:
  # Directories Hydra searches for the configs named in the defaults list;
  # ppo_trainer.yaml must be reachable from one of them.
  # NOTE(review): `XXXX` looks like an unresolved placeholder. Replace it with
  # the real config directory (e.g. a `file://` or `pkg://` path) before running.
  searchpath: [XXXX]

# Hydra defaults list. Order matters: `ppo_trainer` is composed first and
# `_self_` (the values in this file) is merged last, so every key set in this
# file overrides the corresponding ppo_trainer.yaml default.
defaults:
  - ppo_trainer
  - _self_

data:
  # NOTE(review): presumably keeps only prompts whose response accuracy lies in
  # [accuracy_lower_bound, accuracy_upper_bound]. Confirm against the PRIME trainer.
  filter_accuracy: true
  accuracy_lower_bound: 0.2
  accuracy_upper_bound: 0.8
  # Sample more responses than the batch size; prompts satisfying the filter
  # are prioritized.
  oversample_factor: 4.0
  # NOTE(review): presumably filters out truncated samples. Confirm against the trainer.
  filter_truncate: true
  truncation: right

actor_rollout_ref:
  hybrid_engine: true
  model:
    use_remove_padding: true
  rollout:
    # Number of responses to sample (i.e. number of sample times); presumably
    # per prompt. Confirm against the rollout worker.
    n: 4
  actor:
    entropy_coeff: 0.001

reward_model:
  enable: true
  strategy: fsdp
  model:
    # Defaults to the same checkpoint as reward_model.model.path.
    ref_path: ${reward_model.model.path}
    use_remove_padding: true
    use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
    fused_kernel_options:
      impl_backend: torch  # triton or torch
    tokenizer_path: ${actor_rollout_ref.model.path}
    enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
    # NOTE(review): presumably keeps the reference model frozen. Confirm.
    ref_type: freeze
    fsdp_config:
      min_num_params: 0
      param_offload: ${actor_rollout_ref.actor.fsdp_config.param_offload}
      optimizer_offload: ${actor_rollout_ref.actor.fsdp_config.optimizer_offload}
    # `before` for double-forward, `after` for single-forward.
    update: before
    optim:
      # Written as 1.0e-6 rather than 1e-6: YAML 1.1 loaders (e.g. plain
      # PyYAML) require a dot in floats and would read 1e-6 as a string.
      lr: 1.0e-6
      # Takes priority; a negative value delegates to lr_warmup_steps_ratio.
      lr_warmup_steps: -1
      lr_warmup_steps_ratio: 0.0  # the total steps are injected at runtime
      min_lr_ratio: null
      warmup_style: constant
      total_training_steps: -1  # must be overridden by the program
      weight_decay: 0.0
      grad_clip: 10.0
    beta_train: 0.05
    loss_type: ce  # currently only ce loss is supported
  prime_granularity: token
  # batch_norm or none; with none, the normalizer is beta_train.
  prime_norm: batch_norm
  mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
  reward_manager: prime

algorithm:
  # Currently supports rloo, which treats each source of reward separately.
  adv_estimator: rloo
  kl_ctrl:
    type: fixed
    kl_coef: 0.0
  # NOTE(review): presumably the weights of the ground-truth and DPO-style
  # reward sources. Confirm against the advantage computation.
  reward_gt_coef: 5
  reward_dpo_coef: 5

trainer:
  project_name: prime
  experiment_name: examples
  val_before_train: false
  balance_batch: false
