# Algorithm identifier for this config file. Each task section visible below
# (imdb, commongen, tldr) sets its own `id` to this same value.
# NOTE(review): `bc` presumably stands for behavior cloning — confirm.
alg_id: bc

# Settings for the `imdb` task.
# Values written as ${sampling.*} are interpolation references into a
# `sampling` config node that is not defined in this section
# (presumably resolved by the config loader at load time — confirm).
imdb:
  id: bc
  build_reward: false

  # Training / evaluation loop settings.
  args:
    seed: 0
    grad_accumulation: 1
    batch_size_per_process: 15
    n_epochs: 10
    eval_batch_size: 100
    eval_every: 500
    save_every: 500  # TODO
    eval_zero_shot: false
    save_checkpoints: false  # TODO
    eval_splits: ['val']

  # Floats are written with an explicit mantissa dot (1.0e-5, not 1e-5) so
  # that YAML 1.1 loaders such as PyYAML parse them as floats, not strings.
  optimizer_kwargs:
    lr: 1.0e-5
    weight_decay: 1.0e-6
    eps: 1.0e-5

  # Tokenizer uses the same checkpoint name as the policy model.
  tokenizer:
    model_name: lvwerra/gpt2-imdb
    padding_side: left
    truncation_side: left
    pad_token_as_eos_token: true

  policy:
    id: actor
    args:
      model_type: causal
      model_name: lvwerra/gpt2-imdb
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      create_reference: false
      gen_kwargs: ${sampling.train_generation_kwargs}
      prompt_truncation_side: ${sampling.prompt_truncation_side}

# Settings for the `commongen` task.
# Values written as ${sampling.*} are interpolation references into a
# `sampling` config node that is not defined in this section
# (presumably resolved by the config loader at load time — confirm).
commongen:
  id: bc
  build_reward: false

  # Training / evaluation loop settings.
  args:
    seed: 0
    grad_accumulation: 1
    batch_size_per_process: 30
    n_epochs: 10
    eval_batch_size: 100
    eval_every: 500
    save_every: 500  # TODO
    eval_zero_shot: false
    save_checkpoints: false  # TODO
    eval_splits: ['val']

  # Floats are written with an explicit mantissa dot (1.0e-5, not 1e-5) so
  # that YAML 1.1 loaders such as PyYAML parse them as floats, not strings.
  optimizer_kwargs:
    lr: 1.0e-5
    weight_decay: 1.0e-6
    eps: 1.0e-5

  # Tokenizer uses the same checkpoint name as the policy model.
  # Unlike the causal-model sections, the pad token is not aliased to EOS here.
  tokenizer:
    model_name: t5-base
    padding_side: left
    truncation_side: left
    pad_token_as_eos_token: false

  policy:
    id: actor
    args:
      model_type: seq2seq
      model_name: t5-base
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      create_reference: false
      gen_kwargs: ${sampling.train_generation_kwargs}
      prompt_truncation_side: ${sampling.prompt_truncation_side}

# Settings for the `tldr` task.
# Values written as ${sampling.*} are interpolation references into a
# `sampling` config node that is not defined in this section
# (presumably resolved by the config loader at load time — confirm).
tldr:
  id: bc
  build_reward: false

  # Training / evaluation loop settings. This section pairs a small
  # per-process batch (2) with gradient accumulation (4).
  args:
    seed: 0
    grad_accumulation: 4
    batch_size_per_process: 2
    n_epochs: 10
    eval_batch_size: 100
    eval_every: 500
    save_every: 500  # TODO
    eval_zero_shot: false
    save_checkpoints: false  # TODO
    eval_splits: ['val']

  # Floats are written with an explicit mantissa dot (1.0e-5, not 1e-5) so
  # that YAML 1.1 loaders such as PyYAML parse them as floats, not strings.
  optimizer_kwargs:
    lr: 1.0e-5
    weight_decay: 1.0e-6
    eps: 1.0e-5

  # NOTE(review): the tokenizer checkpoint (gpt2) differs from the policy
  # checkpoint below, and truncation_side is `right` here while padding_side
  # is `left` — confirm both are intentional.
  tokenizer:
    model_name: gpt2
    padding_side: left
    truncation_side: right
    pad_token_as_eos_token: true

  policy:
    id: actor
    args:
      model_type: causal
      model_name: CarperAI/openai_summarize_tldr_sft
      max_prompt_len: ${sampling.max_prompt_len}
      max_gen_len: ${sampling.max_gen_len}
      create_reference: false
      mlp_head: false
      quantize_model: true
      gen_kwargs: ${sampling.train_generation_kwargs}
      prompt_truncation_side: ${sampling.prompt_truncation_side}

  # Low-rank adapter settings.
  # NOTE(review): field names match PEFT's LoraConfig; presumably forwarded
  # to it unchanged — confirm against the consumer.
  lora:
    peft_config:
      r: 8
      lora_alpha: 64
      lora_dropout: 0.1
      task_type: CAUSAL_LM
