# Hydra defaults list. Entries are composed in the order listed; `_self_` is
# last, so the values set in this file override the selected group configs.
defaults:
  - model: per_token_iql
  # `group@package: option` form: the `dataset` group option is placed under
  # the named top-level key (`train_dataset` / `eval_dataset`).
  - dataset@train_dataset: list_train
  - dataset@eval_dataset: list_test
  # - score_evaluators@evaluator: iql_eval
  - evaluator: iql_evaluator
  - _self_

# Overrides for the training dataset (base config: `dataset@train_dataset: list_train`).
train_dataset:
  # This id is also referenced by `model.dataset.cache_id` and by the
  # evaluator's reward-model dataset in this file.
  cache_id: d_train
  data:
    cache_id: train_raw_data
    cache_path: null
    # NOTE(review): how shift/scale are applied to the raw reward is defined by
    # the consuming code, not by this file — confirm before changing.
    reward_shift: 5.0
    reward_scale: 5.0
    # Active reward function; reads the Reddit comments data and the train index split.
    reward_f:
      # name: toxicity_noised_reward
      name: score_human_reward
      reddit_path: data/reddit_comments/
      index_path: data/reddit_comments/train_idxs.json
    # Disabled alternative: reward produced by a RoBERTa reward model loaded
    # from a checkpoint. Kept commented out for reference.
    # reward_f:
    #   name: model_reward
    #   cache_id: reward_model
    #   model:
    #     name: roberta_binary_reward_model
    #     dataset:
    #       name: toxicity_list_dataset
    #       data:
    #         name: reddit_comments
    #         path: data/reddit_comments/
    #         cache_path: null
    #         reward_shift: 0.0
    #         reward_scale: 1.0
    #         reward_f: null
    #         index_path: data/reddit_comments/test_idxs.json
    #       token_reward:
    #         name: constant_token_reward
    #         c: 0.0
    #       max_len: 256
    #       cuttoff: null
    #       resample_timeout: 0.0
    #       include_parent: true
    #     roberta_kind: roberta-base
    #     freeze_roberta: false
    #     reward_cuttoff: 0.0
    #     load:
    #       name: roberta_binary_reward_model
    #       checkpoint_path: outputs/toxicity/toxicity_upvote_roberta_binary_reward_f2/model.pkl
    #       strict_load: true
  max_len: 512
  # cuttoff: 10.0
  resample_timeout: 0.0
  include_parent: true

# Overrides for the evaluation dataset (base config: `dataset@eval_dataset: list_test`).
# Mirrors `train_dataset` but uses the test index split and its own cache ids.
eval_dataset:
  cache_id: d_test
  data:
    # This id is also referenced by `evaluator.env.data.cache_id` in this file.
    cache_id: test_raw_data
    cache_path: null
    # Same shift/scale values as `train_dataset.data`.
    reward_shift: 5.0
    reward_scale: 5.0
    # reward_f:
      # name: toxicity_noised_reward
    # Active reward function; same kind as training, pointed at the test index split.
    reward_f:
      name: score_human_reward
      reddit_path: data/reddit_comments/
      index_path: data/reddit_comments/test_idxs.json
    # Disabled alternative: model-based reward.
    # reward_f:
    #   name: model_reward
    #   cache_id: reward_model
  max_len: 512
  # cuttoff: 10.0
  resample_timeout: 0.0
  include_parent: true

# Overrides for the model selected in the defaults list (`model: per_token_iql`).
# NOTE(review): the meaning of each hyperparameter is defined by the consuming
# model code, which is not visible from this file.
model:
  alpha: 0.005
  gamma: 0.99
  beta: 0.0
  transition_weight: 0.0
  clip_weight: null
  value_max: null
  value_min: null
  detach_v: false
  detach_q: false
  detach_pi: false
  double_q: true
  # NOTE(review): the `seperate_*` spelling is presumably the exact key the
  # consuming code reads — do not "correct" it here without checking that code.
  seperate_policy: true
  seperate_target: true
  tau: 0.5
  exp_weights: true
  dm_margin: 0.0
  advanced_mlp: false
  cql_temp: 1.0
  gpt2:
    lm_head: true
    from_pretrained: true
  # References the dataset by the same cache id as `train_dataset.cache_id`.
  dataset:
    name: toxicity_list_dataset
    cache_id: d_train
  # Initial weights are loaded from the checkpoint below; the commented lines
  # are alternative settings (no checkpoint / a different run).
  load:
    # checkpoint_path: null
    checkpoint_path: outputs/toxicity/conditional_toxicity_bc_test1/model_converted.pkl
    # checkpoint_path: outputs/toxicity/upvotes_iql_reward_from_model_test1/model.pkl
    # NOTE(review): presumably tolerates missing/unexpected keys in the checkpoint — confirm.
    strict_load: false

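# Disabled alternative evaluator configuration (entirely commented out).
# The active `evaluator:` stanza follows this block.
# evaluator: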
#   env:
#     reward_shift: 0.0
#     reward_scale: 10.0
#     # reward_f:
#     #   name: model_reward
#     #   model:
#     #     name: roberta_binary_reward_model
#     #     dataset:
#     #       name: toxicity_list_dataset
#     #       cache_id: d_train
#     #     roberta_kind: roberta-base
#     #     freeze_roberta: false
#     #     reward_cuttoff: 0.0
#     #     load:
#     #       name: roberta_binary_reward_model
#     #       checkpoint_path: outputs/toxicity/toxicity_upvote_roberta_binary_reward_f2/model.pkl
#     #       strict_load: true
#     reward_f:
#       name: model_reward
#       cache_id: reward_model
#   kwargs_main:
#     beta: 4.0
#     exp_weights: true
#     clip_weight: null
#     logit_temp: 1.0
#     logit_top_k: null
#     logit_top_p: null
#     include_logits: true
#     include_advantage: true
#   verbose: true
#   num_generations: 1
#   max_generation_len: 256

# Overrides for the evaluator selected in the defaults list (`evaluator: iql_evaluator`).
evaluator:
  env:
    # NOTE(review): these differ from the datasets' reward_shift/reward_scale
    # (5.0 / 5.0 in `train_dataset.data` and `eval_dataset.data`) — confirm
    # this is intentional for the model-based reward used here.
    reward_shift: 0.0
    reward_scale: 10.0
    # Same cache id as `eval_dataset.data.cache_id`.
    data:
      name: reddit_comments
      cache_id: test_raw_data
    # Reward comes from a RoBERTa reward model loaded from a checkpoint,
    # unlike the `score_human_reward` used by the train/eval datasets.
    reward_f:
      # name: toxicity_noised_reward
      name: model_reward
      model:
        name: roberta_binary_reward_model
        # Same cache id as `train_dataset.cache_id`.
        dataset:
          name: toxicity_list_dataset
          cache_id: d_train
        roberta_kind: roberta-base
        freeze_roberta: false
        reward_cuttoff: 0.0
        load:
          name: roberta_binary_reward_model
          checkpoint_path: outputs/toxicity/toxicity_upvote_roberta_binary_reward_f2/model.pkl
          strict_load: true
      # name: model_reward
      # cache_id: reward_model
    include_parent: true
  verbose: true
  kind: sample
  # Generation settings passed to the evaluator.
  # NOTE(review): exact semantics of each key are defined by the consuming code.
  generation_kwargs:
    max_generation_len: 256
    # beam_width: 16
    temp: 1.0
    top_k: null
    top_p: null
    exp_adv: true
    adv_weight: 8.0
    adv_clip: null
    include_logits: true
    include_adv: true
    num_generations: 1
    rerank_log_prob_weight: 0.0
    rerank_advantage_weight: 1.0

# Training-loop settings: checkpointing, logging/eval cadence, optimizer, and
# loss-term weights.
train:
  save_checkpoint_dir: outputs/toxicity/conditional_gt_upvotes_official_iql_test1_2/
  optim_state_path: null
  epochs: 10000000
  dataloader_workers: 0
  # NOTE(review): with bsize 1 and grad_accum_steps 64, presumably 64 samples
  # contribute to each optimizer step — confirm against the training loop.
  bsize: 1
  grad_accum_steps: 64
  log_every: 256
  eval_every: 4096
  save_every: 32768
  max_checkpoints: 1
  eval_bsize: 1
  eval_batches: 32
  # Written with an explicit decimal point: a bare `1e-4` is resolved as a
  # string (not a float) by YAML 1.1 parsers such as stock PyYAML.
  lr: 1.0e-4
  weight_decay: 0.00
  hard_update_every: null
  max_steps: null
  # Per-term loss weights; `awac_weight` and `dm_loss_weight` are set to zero.
  loss:
    v_loss_weight: 1.0
    q_loss_weight: 1.0
    awac_weight: 0.0
    cql_loss_weight: 0.25
    dm_loss_weight: 0.0
    mc_returns: false

# Weights & Biases logging: enabled, with runs recorded under the named project.
wandb:
  use_wandb: true
  wandb_project: toxicity_iql
