# Algorithm identifier; every task section in this file also sets `id: ppo`.
alg_id: ppo

# PPO settings for the `imdb` task.
# NOTE(review): every `${sampling.*}` value is an interpolation into a
# `sampling` node that is not defined in this section - confirm where it is set.
imdb:
  id: ppo
  build_reward: true

  args:
    seed: 0
    verbose: 0

    # Iteration / batching: trajectories_per_update (112) = 4 * batch_size (28).
    n_iters: 60
    batch_size: 28
    grad_accumulation: 1
    trajectories_per_update: 112
    n_epochs: 5

    # PPO hyperparameters.
    gamma: 0.99
    gae_lambda: 0.95
    vf_coef: 0.5
    target_coef: 0.1
    target_regularization: true
    ent_coef: 0.0
    clip_range: 0.2
    clip_range_vf: 0.2
    max_grad_norm: 1.0
    # Left unset here; kl_div.target_kl further down is a separate key.
    target_kl: null

    # eval_* / save_* settings. save_every equals n_iters (60) and
    # save_checkpoints is false.
    eval_batch_size: 100
    eval_every: 10
    save_every: 60
    eval_zero_shot: true
    save_checkpoints: false
    eval_splits:
      - val

    # Length limits are taken from the shared `sampling` node.
    max_prompt_len: ${sampling.max_prompt_len}
    max_gen_len: ${sampling.max_gen_len}

    # NOTE(review): semantics of robust_* / rce_Z are not visible in this
    # file - confirm against the consumer before tuning.
    robust_alpha: 1.0
    robust_beta: 0.0
    rce_Z: -1.0

  kl_div:
    kl_type: fixedklcontroller
    # NOTE(review): kl_lr and target_kl sit next to a controller whose name
    # says "fixed" - confirm whether that controller reads them.
    kl_lr: 0.01
    coeff: 0.001
    target_kl: 0.1

  # Disabled alternative optimizer / scheduler settings, kept for reference.
  # optimizer:
  #   id: adamw
  #   args:
  #     lr: 1e-5
  #     weight_decay: 1e-6
  #     eps: 1e-5
  #
  # scheduler:
  #   id: linear
  #   args:
  #     total_iters: 50

  # Active optimizer / scheduler.
  # NOTE(review): exponent literals without a decimal point (5e-6, 1e-6, 1e-5)
  # are read as strings by plain YAML 1.1 loaders - confirm the loader in use
  # resolves them to floats.
  optimizer:
    id: adamw
    args:
      lr: 5e-6
      weight_decay: 1e-6
      eps: 1e-5

  scheduler:
    id: constant

  # The tokenizer checkpoint name differs from policy.args.model_name below.
  tokenizer:
    model_name: lvwerra/gpt2-imdb
    padding_side: left
    truncation_side: left
    pad_token_as_eos_token: true

  policy:
    id: actor_critic
    args:
      model_type: causal
      model_name: rajkumarrrk/gpt2-fine-tuned-on-imdb-positive-reviews
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      create_reference: true
      mlp_head: false
      quantize_model: false
      # Generation kwargs and prompt truncation side also come from `sampling`.
      gen_kwargs: ${sampling.train_generation_kwargs}
      prompt_truncation_side: ${sampling.prompt_truncation_side}


# PPO settings for the `commongen` task.
# NOTE(review): every `${sampling.*}` value is an interpolation into a
# `sampling` node that is not defined in this section - confirm where it is set.
commongen:
  id: ppo
  build_reward: true

  args:
    seed: 0
    verbose: 0

    # Iteration / batching: trajectories_per_update (128) = 4 * batch_size (32).
    n_iters: 100
    batch_size: 32
    grad_accumulation: 1
    trajectories_per_update: 128
    n_epochs: 4

    # PPO hyperparameters.
    gamma: 0.99
    gae_lambda: 0.95
    vf_coef: 30.0
    ent_coef: 0.0
    target_coef: 0.1
    target_regularization: true
    clip_range: 0.2
    clip_range_vf: 0.2
    max_grad_norm: 1.0
    # Left unset here; kl_div.target_kl further down is a separate key.
    target_kl: null

    # eval_* / save_* settings. save_every equals n_iters (100) and
    # save_checkpoints is false.
    eval_batch_size: 256
    eval_every: 10
    save_every: 100
    eval_zero_shot: true
    save_checkpoints: false
    eval_splits:
      - val

    # Length limits are taken from the shared `sampling` node.
    max_prompt_len: ${sampling.max_prompt_len}
    max_gen_len: ${sampling.max_gen_len}

    # NOTE(review): semantics of robust_* / rce_Z are not visible in this
    # file - confirm against the consumer before tuning.
    robust_alpha: 1.0
    robust_beta: 0.0
    rce_Z: -1.0

  kl_div:
    kl_type: fixedklcontroller
    # NOTE(review): kl_lr and target_kl sit next to a controller whose name
    # says "fixed" - confirm whether that controller reads them.
    kl_lr: 0.01
    # The KL coefficient is exactly zero for this task.
    coeff: 0.0
    target_kl: 0.1

  # NOTE(review): exponent literals without a decimal point (5e-6, 1e-6, 1e-5)
  # are read as strings by plain YAML 1.1 loaders - confirm the loader in use
  # resolves them to floats.
  optimizer:
    id: adamw
    args:
      lr: 5e-6
      weight_decay: 1e-6
      eps: 1e-5

  scheduler:
    id: linear
    args:
      # NOTE(review): total_iters (500) differs from args.n_iters (100);
      # confirm which unit the scheduler counts in.
      total_iters: 500

  tokenizer:
    model_name: t5-base
    padding_side: left
    truncation_side: left
    pad_token_as_eos_token: false

  policy:
    id: actor_critic
    args:
      model_type: seq2seq
      model_name: rajkumarrrk/t5-common-gen
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      create_reference: true
      mlp_head: false
      quantize_model: false
      # Generation kwargs and prompt truncation side also come from `sampling`.
      gen_kwargs: ${sampling.train_generation_kwargs}
      prompt_truncation_side: ${sampling.prompt_truncation_side}


# PPO settings for the `tldr` task.
# NOTE(review): every `${sampling.*}` value is an interpolation into a
# `sampling` node that is not defined in this section - confirm where it is set.
tldr:
  id: ppo
  build_reward: true

  args:
    seed: 0
    verbose: 0

    # Iteration / batching: trajectories_per_update (128) = 4 * batch_size (32);
    # this is the only task section with grad_accumulation above 1.
    n_iters: 200
    batch_size: 32
    grad_accumulation: 4
    trajectories_per_update: 128
    n_epochs: 4

    # PPO hyperparameters.
    gamma: 1.0
    gae_lambda: 0.95
    vf_coef: 0.1
    ent_coef: 0.0
    target_coef: 0.1
    target_regularization: true
    clip_range: 0.2
    clip_range_vf: 0.2
    max_grad_norm: 1.0
    # Left unset here; kl_div.target_kl further down is a separate key.
    target_kl: null

    # eval_* / save_* settings. Checkpointing is enabled, and save_every (100)
    # is half of n_iters (200).
    eval_batch_size: 5
    eval_every: 10
    save_every: 100
    eval_zero_shot: true
    save_checkpoints: true
    eval_splits:
      - val

    # Length limits are taken from the shared `sampling` node.
    max_prompt_len: ${sampling.max_prompt_len}
    max_gen_len: ${sampling.max_gen_len}

    # NOTE(review): semantics of robust_* / rce_Z are not visible in this
    # file - confirm against the consumer before tuning.
    robust_alpha: 1.0
    robust_beta: 0.0
    rce_Z: -1.0

  kl_div:
    kl_type: fixedklcontroller
    # NOTE(review): kl_lr and target_kl sit next to a controller whose name
    # says "fixed" - confirm whether that controller reads them.
    kl_lr: 0.01
    coeff: 0.002
    target_kl: 0.1

  # NOTE(review): exponent literals without a decimal point (5e-6, 1e-5) are
  # read as strings by plain YAML 1.1 loaders - confirm the loader in use
  # resolves them to floats.
  optimizer:
    id: adamw
    args:
      lr: 5e-6
      weight_decay: 0.01
      eps: 1e-5

  scheduler:
    id: constant

  # The tokenizer checkpoint name differs from policy.args.model_name below.
  tokenizer:
    model_name: gpt2
    padding_side: left
    truncation_side: right
    pad_token_as_eos_token: true

  policy:
    id: actor_critic
    args:
      model_type: causal
      model_name: CarperAI/openai_summarize_tldr_sft
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      create_reference: true
      mlp_head: false
      quantize_model: true
      # Generation kwargs and prompt truncation side also come from `sampling`.
      gen_kwargs: ${sampling.train_generation_kwargs}
      prompt_truncation_side: ${sampling.prompt_truncation_side}

  # NOTE(review): presumably forwarded to a PEFT LoRA config - confirm.
  # lora_alpha / r = 64 / 8 = 8.
  lora:
    peft_config:
      r: 8
      lora_alpha: 64
      lora_dropout: 0.1
      task_type: CAUSAL_LM
