# Top-level algorithm identifier; it carries the same value as countdown.id.
alg_id: ppo

# Settings for the `countdown` entry. The nested sections under this key are
# args, kl_div, optimizer, scheduler, tokenizer and policy.
countdown:
  id: ppo
  # NOTE(review): presumably toggles construction of a reward function for
  # this entry — confirm against the code that consumes this key.
  build_reward: true

  # Arguments for the `ppo` entry above. Keys are grouped by theme; the
  # headings are inferred from key names only — the consuming trainer defines
  # the actual semantics.
  args:
    # Run control
    seed: 0
    verbose: 0
    n_iters: 200

    # Update sizing
    trajectories_per_update: 128
    batch_size: 32
    grad_accumulation: 4
    n_epochs: 4

    # Return / advantage parameters
    gamma: 1.0
    gae_lambda: 0.95
    use_whitening: false

    # Loss coefficients
    vf_coef: 1.0
    ent_coef: 0.0
    target_coef: 0.0
    target_regularization: false

    # Clipping and limits
    clip_range: 0.2
    clip_range_vf: 0.2
    max_grad_norm: 1.0
    # Explicitly null here; kl_div.target_kl holds a separate value (0.1).
    target_kl: null

    # Reward options
    subgoal_reward: false

    # Sequence lengths — resolved from a `sampling` node that is not defined
    # in this file.
    max_prompt_len: ${sampling.max_prompt_len}
    max_gen_len: ${sampling.max_gen_len}

    # Evaluation
    eval_zero_shot: true
    eval_every: 10
    eval_batch_size: 5
    eval_splits: [val, test]

    # Checkpointing
    save_checkpoints: true
    save_every: 50

  # Settings grouped under kl_div.
  # NOTE(review): kl_type names a fixed controller; whether kl_lr and
  # target_kl are read at all in that mode cannot be told from this file —
  # confirm against the controller implementation.
  kl_div:
    kl_type: fixedklcontroller
    coeff: 0.01
    kl_lr: 0.01
    target_kl: 0.1

  # Optimizer id and the args supplied for it.
  optimizer:
    id: adamw
    args:
      # Written with an explicit "." in the mantissa: YAML 1.1 resolvers
      # (e.g. stock PyYAML) only treat exponent notation as a float when the
      # number contains a dot, so a bare 1e-7 would load as the string
      # "1e-7" rather than a float. 1.0e-7 is a float under every resolver.
      lr: 1.0e-7
      weight_decay: 0.01

  # Scheduler selection. Only an id is set here; unlike optimizer, no args
  # mapping is given.
  scheduler:
    id: constant

  # Tokenizer settings.
  tokenizer:
    # NOTE(review): machine-specific absolute path under /home/user. The same
    # checkpoint path is repeated in policy.args.model_name; keep the two in
    # sync if they are meant to point at the same checkpoint.
    model_name: /home/user/train-countdown/stream-of-search/outputs/star3-final-rand-s0-gpt2/checkpoint-20000
    padding_side: right
    truncation_side: right
    # NOTE(review): the name suggests the EOS token is reused as the pad
    # token — confirm against the code that builds the tokenizer.
    pad_token_as_eos_token: true

  # Policy selection and the args supplied for it.
  policy:
    id: actor_critic
    args:
      # Model identity. model_name repeats the checkpoint path given in
      # tokenizer.model_name (machine-specific absolute path).
      model_type: causal
      model_name: /home/user/train-countdown/stream-of-search/outputs/star3-final-rand-s0-gpt2/checkpoint-20000

      # Boolean switches; their effect is defined by the policy
      # implementation, which is not visible from this file.
      create_reference: true
      mlp_head: false
      quantize_model: false

      # Values resolved from a `sampling` node that is not defined in this
      # file.
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      prompt_truncation_side: ${sampling.prompt_truncation_side}
      gen_kwargs: ${sampling.train_generation_kwargs}
