# Tokenizer configuration.
tokenizer:
  # Model identifier for the tokenizer (presumably a Hugging Face hub id).
  model_name: google/flan-t5-large
  padding_side: "right"
  truncation_side: "right"
  # Canonical lowercase boolean; parses the same as the former `False`.
  pad_token_as_eos_token: false


# Reward function used during training. The `id` selects the implementation
# and `args` is the mapping of settings supplied for it.
reward_fn:
  id: summarization_with_hint
  args:
    # --- Settings for the GPT model queried by the reward function ---
    gpt3_model: "gpt-3.5-turbo"
    # NOTE(review): interval/timeout/exp/patience look like request pacing and
    # retry/backoff settings -- confirm units and meaning in the consumer.
    interval: 0.5
    timeout: 20.0
    exp: 2.0
    patience: 10
    temperature: 0.7
    max_tokens: 160
    num_seqs: 4
    selection_strategy: "choose_all"
    top_p: 1.0
    stop_words: ["Article:", "Q:", "A:", "<|im_end|>"]
    # The trailing space is part of the value; keep it quoted.
    prompt_prefix: "Extract the keywords: "
    prompt_path: "./prompts/cnndm_fs.txt"
    hint_prompt_path: "./prompts/cnndm_hint_fs.txt"
    # --- Reward composition ---
    gpt3_metric: "rouge-avg"
    gpt3_coef: 10.0
    use_baseline: false
    t5_pos_coef: 1.0
    t5_neg_coef: -0.2  # alternative value noted by the author: -0.1
    step_reward_coef: 1.0
    t5_metric: "hint_hit"
    
    
# Environment configuration: `n_envs` plus the `args` mapping for each
# environment.
env:
  n_envs: 10
  args:
    max_prompt_length: 512
    max_episode_length: 100
    # Canonical lowercase boolean; parses the same as the former `True`.
    terminate_on_eos: true
    prompt_truncation_side: "right"
    # NOTE(review): presumably the decoder start token id of the T5 model --
    # confirm against the tokenizer.
    context_start_token: 0

# Data pool selection: `id` names the data pool and `args` carries its
# settings.
datapool:
  id: cnn_daily_mail_with_hint
  args:
    # Quoted because the value contains ": " and ends with a space.
    prompt_prefix: "Extract the keywords: "
    dataset: cnndm
    # Sizes of the train / validation / test portions.
    n_train: 1000
    n_val: 500
    n_test: 500
    # NOTE(review): presumably how the hint keywords are extracted and from
    # which text -- confirm in the data pool implementation.
    extraction_mode: textrank
    extraction_source: all

# Training algorithm configuration: algorithm `id`, its `args`, the KL
# settings and the policy definition.
alg:
  id: nlpo
  args:
    n_steps: 512
    batch_size: 24
    verbose: 1
    # Kept in plain decimal notation on purpose: YAML 1.1 loaders such as
    # PyYAML read exponent forms without a dot and sign (e.g. `2e-6`) as
    # strings, not floats.
    learning_rate: 0.000002
    n_epochs: 5
    ent_coef: 0.0
    vf_coef: 0.5
  kl_div:
    coeff: 0.005
    target_kl: 0.5
  policy:
    id: maskable_seq2seq_lm_actor_critic_policy
    args:
      # NOTE(review): the path segments (cnndm_1000, textrank-all) mirror the
      # datapool values in this file -- presumably they must stay in sync.
      model_name: ./sft4lms/ckpt/cnndm_1000/textrank-all/flan-t5-large-ep5/
      apply_model_parallel: true
      prompt_truncation_side: "right"
      min_tokens_to_keep: 100
      top_mask: 0.9
      mask_type: "learned_top_p"
      target_update_iterations: 20
      # Generation settings for the policy.
      generation_kwargs:
        min_length: 10
        max_new_tokens: 80
        do_sample: true
        top_k: 100
        
 
# Training-loop and evaluation configuration: iteration counts, evaluation and
# checkpoint cadence, the metrics to compute and the generation settings.
train_evaluation:
  eval_batch_size: 10
  n_iters: 20
  eval_every: 2
  save_every: 2
  metrics:
    - id: summarization_with_hint
      args:
        gpt3_model: "gpt-3.5-turbo"
        interval: 0.5
        timeout: 20.0
        # NOTE(review): reward_fn in this file uses `exp: 2.0` and
        # `num_seqs: 4`; confirm the differences here are intentional.
        exp: 2
        patience: 10
        split_token: ";"
        split_token_id: 117  # token id of t5 for ";"
        temperature: 0.7
        max_tokens: 160
        num_seqs: 3
        selection_strategy: "choose_all"
        top_p: 1.0
        stop_words: ["Article:", "Q:", "A:", "<|im_end|>"]
        # The trailing space is part of the value; keep it quoted.
        prompt_prefix: "Extract the keywords: "
        prompt_path: "./prompts/cnndm_fs.txt"
        hint_prompt_path: "./prompts/cnndm_hint_fs.txt"
        evaluate_policy_model: true
        use_lower_baseline: false
        use_upper_baseline: false
        # Metrics listed for the GPT outputs.
        gpt3_metrics:
          - id: meteor
            args: {}
          - id: rouge
            args:
              use_single_ref: false
          - id: bleu
            args: {}
          - id: bert_score
            args:
              language: en
        # Metrics listed for the T5 outputs.
        t5_metrics:
          - id: "hint_hit"
            args:
              split: ";"
  # Generation settings used during evaluation.
  generation_kwargs:
    min_length: 10
    max_new_tokens: 80
    do_sample: true
    top_k: 0
    temperature: 0.7

