# Tokenizer settings. model_name is the same checkpoint (gpt2) that
# alg.policy.args.model_name uses; padding and truncation are both left-sided.
tokenizer:
  model_name: gpt2
  padding_side: left
  truncation_side: left
  pad_token_as_eos_token: true

# Reward function selection (by id) and its arguments.
reward_fn:
  id: intent_accuracy
  args:
    # NOTE(review): presumably the weights of the intent term and of the
    # automatic-metric term of the reward -- confirm against the reward
    # implementation.
    intent_coeff: 1.0
    auto_coeff: 1.0
    constraint_name: intent


# Dataset selection (by id) and its arguments.
datapool:
  id: daily_dialog
  args:
    context_size: 5

# Environment settings: 10 environments, with the constrained flag enabled.
env:
  n_envs: 10
  constrained: true
  args:
    max_prompt_length: 128
    # Same value as max_new_tokens (20) in both generation_kwargs sections
    # of this file; keep them in sync when changing one.
    max_episode_length: 20
    terminate_on_eos: true

# Training algorithm selection (by id) and its hyper-parameters.
alg:
  id: constrained_ppo
  args:
    n_steps: 256
    batch_size: 64
    verbose: 1
    learning_rate: 0.000001
    n_epochs: 5
    squash_lagrange: true
    lagrange_lr: 0.1
    fixed_lagrange: false
    # ========================== mu/xi-PPO - KL ==========================
    # We currently consider "task" == METEOR and "constraint" == Intent.
    # Constraint threshold := intent threshold.
    constraint_threshold: 0.48
    vf_coef: 0.5
    constraint_vf_coef: 0.5
    kl_vf_coef: 0.2
    maximizing_reward: "kl"
    # Task threshold := METEOR threshold.
    task_threshold: 0.238
    equality_constraints: true

    
  kl_div:
    coeff: 0.2  # was 0.1 (presumably an earlier setting)
    # target_kl: 0.5
    # Notes on target_kl (commented out above, so it is not set here):
    #   - To use a fixed coefficient, set it to None (spelled `null` in YAML).
    #   - The target is really more like an upper bound: if the KL is below
    #     the target the coefficient is decreased, otherwise it is increased
    #     (it is essentially like a Lagrange multiplier).
    #   - Its default value is 0.5.
 
  # Policy selection (by id) and its arguments.
  policy:
    id: causal_lm_actor_critic_policy
    args:
      model_name: gpt2
      apply_model_parallel: true
      # NOTE(review): presumably one value head per value-loss coefficient in
      # alg.args (vf_coef, constraint_vf_coef, kl_vf_coef) -- confirm.
      num_value_heads: 3
      debug_mode: false
      # Identical to train_evaluation.generation_kwargs in this file.
      generation_kwargs:
        do_sample: true
        top_k: 20
        min_length: 2
        max_new_tokens: 20

# Nelder-Mead settings; `active` switches the feature on.
nelder_mead:
  active: true
  args:
    max_iters: 10
    # NOTE(review): presumably the usual Nelder-Mead reflection (alpha),
    # expansion (gamma), contraction (rho) and shrink (sigma) coefficients --
    # confirm against the consumer.
    alpha: 1.0
    gamma: 2.0
    rho: 0.5
    sigma: 0.5
    tol: 0.01
      
# Evaluation and checkpointing schedule, metric list, and the generation
# settings used during evaluation.
train_evaluation:
  eval_batch_size: 32
  n_iters: 100
  eval_every: 5
  save_every: 10
  # Each entry selects a metric by id; `args` is optional (intent_accuracy
  # and rouge are listed without it).
  metrics:
    - id: intent_accuracy
    - id: causal_perplexity
      args:
        tokenizer_id: gpt2
        stride: 128
        model_type: causal
    - id: diversity
      args: {}
    - id: meteor
      args: {}
    - id: rouge
    - id: bleu
      args: {}
    - id: bert_score
      args:
        language: en
    - id: sacre_bleu
      args:
        tokenize: "intl"
    # NOTE(review): these args are identical to sacre_bleu's above -- confirm
    # that crlhf_eval really takes a `tokenize` argument and that this is not
    # a copy-paste leftover.
    - id: crlhf_eval
      args:
        tokenize: "intl"
  # Identical to alg.policy.args.generation_kwargs in this file.
  generation_kwargs:
    do_sample: true
    top_k: 20
    min_length: 2
    max_new_tokens: 20
