# Tokenizer settings.
tokenizer:
  # Same identifier as alg.policy.args.model_name; looks like a Hugging Face
  # Hub id -- keep the two in sync when swapping checkpoints.
  model_name: lvwerra/gpt2-imdb
  # Both padding and truncation are applied on the left side.
  # NOTE(review): presumably so the tail of the prompt stays adjacent to the
  # generated continuation -- confirm against the consumer.
  padding_side: left
  truncation_side: left
  # NOTE(review): presumably makes the tokenizer reuse its EOS token as the
  # pad token -- confirm against the tokenizer builder.
  pad_token_as_eos_token: true

# Reward function: `id` names the implementation, `args` carries its settings.
reward_fn:
  id: learned_reward
  args:
    # Same model_name / label_ix pair as the `learned_reward` entry under
    # train_evaluation.metrics.
    model_name: lvwerra/distilbert-imdb
    # NOTE(review): presumably the index of the classifier output whose score
    # is used as the reward -- confirm which class index 1 denotes.
    label_ix: 1
    # NOTE(review): presumably scores prompt + generation together rather
    # than the generation alone -- confirm against the reward implementation.
    include_prompt_for_eval: true

# Data source: `id` names the datapool; no arguments are passed (`args` is an
# explicit empty mapping, not null).
datapool:
  id: imdb
  args: {}

# Environment settings.
env:
  # NOTE(review): presumably the number of environments run in parallel --
  # confirm against the trainer.
  n_envs: 10
  args:
    # NOTE(review): presumably measured in tokens -- confirm.
    max_prompt_length: 64
    # Same value (48) as min_length and max_new_tokens under
    # alg.policy.args.generation_kwargs.
    # NOTE(review): these presumably need to be changed together -- confirm.
    max_episode_length: 48
    terminate_on_eos: true

# Training algorithm: `id` names the algorithm, `args` its hyperparameters,
# `kl_div` the KL settings and `policy` the policy model configuration.
alg:
  id: nlpo
  args:
    n_steps: 128
    batch_size: 64
    verbose: 1
    # Kept in plain decimal form on purpose: a dot-less exponent such as 1e-6
    # is loaded as a string (not a float) by YAML 1.1 parsers like PyYAML.
    learning_rate: 0.000001
    n_epochs: 5

  # NOTE(review): presumably the KL penalty coefficient and the KL target
  # used to adapt it -- confirm against the algorithm implementation.
  kl_div:
    coeff: 1.0
    target_kl: 1.0
  policy:
    id: maskable_causal_lm_actor_critic_policy
    args:
      # Same identifier as tokenizer.model_name.
      model_name: lvwerra/gpt2-imdb
      apply_model_parallel: false
      # Action-masking settings.
      # NOTE(review): with mask_type learned_top_p, top_mask presumably acts
      # as the top-p threshold and min_tokens_to_keep as a floor on the number
      # of unmasked tokens -- confirm against the policy implementation.
      top_mask: 0.9
      min_tokens_to_keep: 100
      mask_type: learned_top_p
      # NOTE(review): presumably how often (in iterations) the masking model
      # is refreshed; equals train_evaluation.n_iters (50) -- confirm intended.
      target_update_iterations: 50
      # NOTE(review): presumably forwarded to the model's generate call.
      generation_kwargs:
        do_sample: true
        # min_length equals max_new_tokens, and both equal
        # env.args.max_episode_length (48).
        min_length: 48
        max_new_tokens: 48

# Training-loop schedule and the metrics computed at evaluation time.
train_evaluation:
  eval_batch_size: 256
  # NOTE(review): eval_every / save_every are presumably expressed in
  # training iterations (out of n_iters) -- confirm against the trainer.
  n_iters: 50
  eval_every: 10
  save_every: 1
  # Metrics are listed as a sequence of {id, args} entries.
  metrics:
    # Same model_name and label_ix as reward_fn.args; unlike reward_fn, no
    # include_prompt_for_eval is set here.
    - id: learned_reward
      args:
        model_name: lvwerra/distilbert-imdb
        label_ix: 1
        batch_size: 100
    # NOTE(review): tokenizer_id is the base `gpt2` id, not the
    # lvwerra/gpt2-imdb id used for the policy -- confirm this is intended.
    - id: causal_perplexity
      args:
        tokenizer_id: gpt2
        stride: 512
        model_type: causal
    # No arguments: explicit empty mapping.
    - id: diversity
      args: {}