# Hydra defaults list: the config-group options this experiment is composed
# from. `group@package: option` places the chosen option under the named
# package (e.g. `dataset@train_dataset` -> top-level `train_dataset`).
defaults:
  - model: per_token_iql
  - dataset@train_dataset: list_train
  - dataset@eval_dataset: list_val
  - evaluator: iql_evaluator
  # Disabled alternatives from the `score_evaluators` group.
  # - score_evaluators@evaluator: bc_iql_eval
  # - score_evaluators@evaluator: iql_eval
  # `_self_` is listed last so the values in this file override the group
  # defaults selected above.
  - _self_

# Overrides for the training dataset (composed from `dataset: list_train`).
train_dataset:
  # The same id is used by `model.dataset.cache_id`.
  cache_id: d_train
  data:
    reward_cache: data/vis_dialogue/processed/visdial_0.5/train_rank_reward_cache1.json
    # Disabled alternatives / extra inputs, kept for reference.
    # reward_cache: data/vis_dialogue/processed/visdial_0.5/train_reward_cache2.json
    # additional_scenes: data/vis_dialogue/processed/visdial_0.5/is_it_sunny_events.pkl
    # reward_shift: 30.0
    # reward_scale: 1e6
    mode: env_stops
    # mode: 10_stop
    cutoff_rule:
      name: percentile_cutoff_rule
      goal_value: 1.0
      percentile: 0.5
    # NOTE(review): `yn_reward` / `yn_reward_kind` are repeated with the same
    # values under `eval_dataset.data` and `evaluator.env`; presumably they
    # must be kept in sync -- confirm against the consuming code.
    yn_reward: -2.0
    yn_reward_kind: hard
  # Disabled per-token reward settings.
  # token_reward:
  #   name: specified_token_reward
  #   token_file: data/wikitext/wikitext-103-train_gpt2_token_freq.json
  #   scale: 20.0
  #   shift: -1.0
  # top_p: 0.1

# Overrides for the evaluation dataset (composed from `dataset: list_val`).
eval_dataset:
  # The same id is used by `evaluator.env.dataset.cache_id`.
  cache_id: d_eval
  data:
    reward_cache: data/vis_dialogue/processed/visdial_0.5/val_rank_reward_cache1.json
    # NOTE(review): the disabled alternative below points at a *train* reward
    # cache; verify the path before re-enabling it for evaluation.
    # reward_cache: data/vis_dialogue/processed/visdial_0.5/train_reward_cache2.json
    # reward_shift: 30.0
    # reward_scale: 1e6
    mode: env_stops
    # mode: 10_stop
    cutoff_rule:
      name: percentile_cutoff_rule
      goal_value: 1.0
      percentile: 0.5
    # NOTE(review): same values as `train_dataset.data` and `evaluator.env`;
    # presumably they must be kept in sync -- confirm against the consumer.
    yn_reward: -2.0
    yn_reward_kind: hard
  # Disabled per-token reward settings.
  # token_reward:
  #   name: specified_token_reward
  #   token_file: data/wikitext/wikitext-103-train_gpt2_token_freq.json
  #   scale: 20.0
  #   shift: -1.0
  # top_p: 0.1

# Overrides for the model selected in the defaults list (`per_token_iql`).
model:
  alpha: 0.005
  gamma: 0.99
  beta: 0.0
  transition_weight: 0.0
  clip_weight: null
  value_max: null
  value_min: null
  detach_v: false
  detach_q: false
  detach_pi: false
  double_q: true
  # NOTE: "seperate" is misspelled, but these are runtime keys read by the
  # model code; do not correct the spelling here without renaming the
  # consumer as well.
  seperate_policy: true
  seperate_target: true
  tau: 0.8
  exp_weights: true
  dm_margin: 0.0
  advanced_mlp: false
  cql_temp: 1.0
  gpt2:
    lm_head: true
    from_pretrained: true
  # Refers to the training dataset by its `cache_id` (see `train_dataset`).
  dataset:
    name: vis_dial_list_dataset
    cache_id: d_train
  load:
    # Exactly one `checkpoint_path` is active; the others are disabled
    # alternatives kept for reference.
    # checkpoint_path: outputs/visual_dialogue/visdial_bc_iql_test1/model.pkl
    # checkpoint_path: outputs/visual_dialogue/visdial_mc_iql_test1/model_360447.pkl
    # checkpoint_path: outputs/visual_dialogue/visdial_iql_standard_new_extraction_test1/model_393215.pkl
    checkpoint_path: outputs/visual_dialogue/visdial_bc_official_test1/model_converted.pkl
    # checkpoint_path: outputs/visual_dialogue/visdial_hard_yn_iql_official_test2/awac_model.pkl
    # NOTE(review): presumably tolerates missing/unexpected keys when loading
    # the checkpoint -- confirm against the loader.
    strict_load: false

# Disabled alternative `evaluator` config, kept for reference only; the
# active `evaluator:` mapping is defined after this comment block.
# evaluator:
#   env:
#     url: http://localhost:5001/step_rank
#     dataset:
#       name: vis_dial_list_dataset
#       cache_id: d_eval
#     # reward_shift: 30.0
#     # reward_scale: 1e6
#   verbose: true
#   kind: sample
#   generation_kwargs:
#     max_generation_len: 40
#     # beam_width: 1
#     # temp: 1.0
#     num_generations: 1
#     # max_generation_len: null
#     # top_k: null
#     # top_p: null


# Overrides for the evaluator selected in the defaults list (`iql_evaluator`).
evaluator:
  env:
    # Endpoint of the environment service; expects it on localhost:5001.
    url: http://localhost:5001/step_rank
    actor_stop: false
    # Refers to the evaluation dataset by its `cache_id` (see `eval_dataset`).
    dataset:
      name: vis_dial_list_dataset
      cache_id: d_eval
    # reward_shift: 30.0
    # reward_scale: 1e6
    # NOTE(review): same values as `train_dataset.data` / `eval_dataset.data`;
    # presumably they must be kept in sync -- confirm against the consumer.
    yn_reward: -2.0
    yn_reward_kind: hard
  verbose: true
  kind: sample
  generation_kwargs:
    max_generation_len: 40
    # beam_width: 1
    temp: 1.0
    top_k: null
    top_p: null
    exp_adv: true
    adv_weight: 16.0
    adv_clip: null
    include_logits: true
    include_adv: true
    num_generations: 1
    # With these weights, reranking uses only the advantage term
    # (log-prob weight is 0.0) -- TODO confirm against the reranking code.
    rerank_log_prob_weight: 0.0
    rerank_advantage_weight: 1.0

# Training-loop settings for this experiment.
train:
  save_checkpoint_dir: outputs/visual_dialogue/visdial_hard_yn_iql_no_cql_test2/
  optim_state_path: null
  # Very large epoch count with `max_steps: null`; presumably the run is
  # meant to continue until stopped manually -- confirm against the trainer.
  epochs: 10000000
  dataloader_workers: 1
  # NOTE(review): bsize 1 with 64 accumulation steps presumably yields an
  # effective batch of 64 per optimizer step -- confirm against the trainer.
  bsize: 1
  grad_accum_steps: 64
  log_every: 256
  eval_every: 4096
  save_every: 32768
  max_checkpoints: 1
  eval_bsize: 1
  eval_batches: 32
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g.
  # plain PyYAML) also resolve it as a float instead of the string "1e-5".
  lr: 1.0e-5
  weight_decay: 0.0
  hard_update_every: null
  max_steps: null
  loss:
    v_loss_weight: 1.0
    q_loss_weight: 1.0
    # The AWAC, CQL and DM loss weights are all 0.0 in this configuration.
    awac_weight: 0.0
    cql_loss_weight: 0.0
    dm_loss_weight: 0.0
    mc_returns: false

# Weights & Biases logging settings: logging is enabled and runs are
# reported under the `visdial_iql` project.
wandb:
  use_wandb: true
  wandb_project: visdial_iql
