# Taken and modified from Coste's (tlc4418) configs/config_rl.yaml

# Baseline values for RLHF runs.
# NOTE(review): presumably overlaid by a named experiment section at load
# time; the merge logic is not part of this file - confirm in the loader.
defaults_rlhf:
  # Hex integer literal; equals 2703368087 in decimal.
  rng_seed: 0xa1221f97
  datasets: []
  datasets_extra: []
  cache_dir: .cache
  output_dir: runs/ppo
  eval_size: 64
  num_eval_prompts: 64
  # Explicit nulls (previously bare keys, which also parse as null).
  # pythia_rlhf_individual and pythia_rlhf_ensemble each define both mappings.
  rank_config: null
  sft_config: null
  debug: false
  dtype: bf16


# Most parameters are based on OpenAssistant's defaults for pythia_rlhf
pythia_rlhf_individual:
  # Experiment section with a single entry under rank_config.model_names.
  output_dir: runs/ppo_individual_kl_0.1_64_a2_4
  datasets:
    - alpaca_farm

  gold_config:
    model_name: alpaca_models/reward-model-human
    is_alpacafarm_rm: true
    batch_size: 32

  # Flagged is_reward_model: true; model_names is a list with one path.
  rank_config:
    is_reward_model: true
    model_names:
      - models/rm-pythia-1_3b_seed3
    cache_dir: .cache
    pooling: last
    residual_dropout: 0.01
    use_flash_attention: false
    dtype: bf16
    batch_size: 128

  # Flagged is_reward_model: false; takes a single model_name, not a list.
  sft_config:
    is_reward_model: false
    model_name: tlc4418/pythia_1.4b_sft_policy
    cache_dir: .cache
    quantization: false
    seq2seqmodel: false
    # Explicit null (previously a bare key, which also parses as null).
    freeze_layer: null
    num_layers_unfrozen: 2
    residual_dropout: 0.2
    use_flash_attention: false
    dtype: bf16
    batch_size: 32


# Experiment section with three entries under rank_config.model_names,
# combined according to rank_config.objective_name.
pythia_rlhf_ensemble:
  output_dir: runs/ppo_ensemble_kl_0.1_64
  datasets:
    - alpaca_farm

  gold_config:
    model_name: alpaca_models/reward-model-human
    # NOTE(review): pythia_rlhf_individual sets this flag to true for the
    # same model_name; confirm that false is intended here.
    is_alpacafarm_rm: false
    batch_size: 32

  rank_config:
    is_reward_model: true
    model_names:
      - models_hf/rm-pythia-44m_seed1
      - models_hf/rm-pythia-44m_seed2
      - models_hf/rm-pythia-44m_seed3

    # Change objective (mean, WCO, or UWO)
    objective_name: UWO
    # Change UWO weight (only for UWO)
    uwo_weight: 0.1
    cache_dir: .cache
    pooling: last
    residual_dropout: 0.01
    use_flash_attention: false
    dtype: bf16
    batch_size: 128

  # Flagged is_reward_model: false; takes a single model_name, not a list.
  sft_config:
    is_reward_model: false
    model_name: tlc4418/pythia_1.4b_sft_policy
    cache_dir: .cache
    quantization: false
    seq2seqmodel: false
    # Explicit null (previously a bare key, which also parses as null).
    freeze_layer: null
    num_layers_unfrozen: 2
    residual_dropout: 0.2
    use_flash_attention: false
    dtype: bf16
    batch_size: 32