# Tokenizer settings. model_name here is the same value as
# alg.policy.args.model_name (both are t5-base).
tokenizer:
  model_name: t5-base
  # Padding and truncation are both configured for the left side.
  # NOTE(review): env.args and alg.policy.args both set
  # prompt_truncation_side to "right" - confirm the two truncation
  # settings are meant to differ.
  padding_side: left
  truncation_side: left
  # Presumably controls whether the EOS token is reused as the pad token;
  # disabled here - verify against the code that loads this file.
  pad_token_as_eos_token: False

# Active reward: the "meteor" reward id with a single shaping_fn argument.
reward_fn:
  id: meteor
  args:
    shaping_fn: "common_gen_repeat_penalty"
    # Everything below is commented out and therefore inactive. It is a
    # "values" list of alternative reward entries (rouge_combined, meteor,
    # rouge with rouge_type rouge1, and three spider variants that differ
    # only in spice_coeff / cider_coeff).
    # Note that the spider entries use the "_batched" shaping function name
    # while the others use "common_gen_repeat_penalty".
    # NOTE(review): presumably kept as a reference for sweeping over reward
    # functions - confirm before deleting.
    # values:
    # - id: rouge_combined
    #   args:
    #     shaping_fn: "common_gen_repeat_penalty"
    # - id: meteor
    #   args:
    #     shaping_fn: "common_gen_repeat_penalty"
    # - id: rouge
    #   args:
    #     rouge_type: "rouge1"
    #     shaping_fn: "common_gen_repeat_penalty"
    # - id: spider
    #   args:
    #     spice_coeff: 0.0
    #     cider_coeff: 1.0
    #     shaping_fn: "common_gen_repeat_penalty_batched"
    # - id: spider
    #   args:
    #     spice_coeff: 1.0
    #     cider_coeff: 0.0
    #     shaping_fn: "common_gen_repeat_penalty_batched"
    # - id: spider
    #   args:
    #     spice_coeff: 0.5
    #     cider_coeff: 0.5
    #     shaping_fn: "common_gen_repeat_penalty_batched"


# Dataset selection: the "commongen" datapool id plus three string
# arguments used for prompt formatting.
datapool:
  id: commongen
  args:
    # Presumably the concept words are joined with the separator token and
    # closed with the end token - TODO confirm against the datapool code.
    concept_end_token: '.'
    concept_separator_token: ' '
    # Must stay quoted: the value contains ": " and ends with a trailing
    # space, neither of which survives as a plain (unquoted) scalar.
    prefix: "generate a sentence with: "


# Environment settings: n_envs is 10, and the prompt and episode length
# limits are both set to 20.
env:
  n_envs: 10
  args:
    # Assumed to be counted in tokens - TODO confirm.
    max_prompt_length: 20
    # Same value as max_new_tokens in both generation_kwargs sections of
    # this file (alg.policy.args and train_evaluation).
    max_episode_length: 20
    terminate_on_eos: True
    # NOTE(review): presumably a token id used to start the generation
    # context - confirm what id 0 maps to for this tokenizer.
    context_start_token: 0
    # Same value as alg.policy.args.prompt_truncation_side; differs from
    # tokenizer.truncation_side, which is left.
    prompt_truncation_side: "right"


# Training algorithm: the "ppo" id with its hyperparameters, a kl_div
# section, and the policy definition.
alg:
  id: ppo
  args:
    n_steps: 256
    batch_size: 64
    verbose: 1
    # Equal to 2e-6. Keep the plain decimal form: YAML 1.1 loaders such as
    # PyYAML read an exponent literal without a dot (e.g. 2e-6) as a
    # string, not a float.
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.01
  # Presumably the KL-penalty settings (coefficient and target value) -
  # TODO confirm how the trainer consumes them.
  kl_div:
    coeff: 0.001
    target_kl: 2.0
  policy:
    id: seq2seq_lm_actor_critic_policy
    args:
      # Same value as tokenizer.model_name.
      model_name: t5-base
      apply_model_parallel: True
      # Same value as env.args.prompt_truncation_side.
      prompt_truncation_side: "right"
      # Generation settings for the policy: sampling is enabled here,
      # whereas train_evaluation.generation_kwargs sets num_beams instead.
      # NOTE(review): presumably forwarded to the model's generate call,
      # where top_k: 0 would disable top-k filtering - confirm.
      generation_kwargs:
        do_sample: True
        top_k: 0
        min_length: 5
        max_new_tokens: 20
    
# Training-loop schedule, evaluation metrics, and evaluation-time
# generation settings.
train_evaluation:
  eval_batch_size: 20
  n_iters: 200
  eval_every: 20
  # NOTE(review): save_every is 1 while eval_every is 20 - confirm that
  # saving this frequently is intended.
  save_every: 1
  # Metric ids to compute. The entries are not uniform: rouge, cider and
  # spice carry no args key, meteor/bleu/diversity pass an empty mapping,
  # and bert_score passes language: en.
  # NOTE(review): presumably a missing args key is treated like {} by the
  # consumer - verify.
  metrics:
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: cider
    - id: spice
    - id: diversity
      args: {}
  # Evaluation-time generation: num_beams is 5 and no do_sample key is
  # given, unlike alg.policy.args.generation_kwargs. min_length and
  # max_new_tokens match the policy's values.
  generation_kwargs:
    num_beams: 5
    min_length: 5
    max_new_tokens: 20

