# Job metadata: which script runs the experiment and what resources each
# worker requests from the scheduler.
meta_data:
  # Entry-point script (path presumably relative to the repo root — confirm).
  script_path: "run_scripts/adv_irl_lfo_exp_script.py"
  exp_name: "gail_lfo_halfcheetah_union"
  description: "Train an adversarial IRL model"
  # Per-worker resource request. The inline numbers are alternate settings
  # left by the author (presumably for a different run size — confirm).
  num_workers: 6  # 64
  num_gpu_per_worker: 1  # 0
  num_cpu_per_worker: 32  # 2
  mem_per_worker: "4gb"
  # NOTE(review): a GPU is requested per worker while the partition is
  # "cpu" and every excluded node is gpu*-named — confirm this combination
  # is what the scheduler/launcher expects.
  partitions: "cpu"
  # Comma-separated node names, kept as one string.
  node_exclusions: "gpu048,gpu024,gpu025,gpu012,gpu027"
# -----------------------------------------------------------------------------
variables:
  # Candidate values per key. The launcher presumably expands these lists
  # into one run per combination (TODO confirm). The nested sections share
  # their names with the sections under `constants`.
  adv_irl_params:
    grad_pen_weight:
      - 0.5
    state_predictor_alpha:
      - 0.38
    inverse_dynamic_beta:
      - 0.25
  sac_params:
    reward_scale:
      - 2.0
  # Commented-out alternatives, not used:
  # seed: [723894, 23789]
  # decay_ratio: [0.99]
  # NOTE(review): seed 1 is not in this list — confirm that is intentional.
  seed:
    - 0
    - 2
    - 3
    - 4
    - 5

# -----------------------------------------------------------------------------
constants:
  # Expert demonstration selection (semantics of name/idx/traj_num are
  # defined by the launcher — presumably which stored demos to load and how
  # many trajectories to use; confirm).
  expert_name: "halfcheetah_sac"
  expert_idx: 0
  scale_env_with_demo_stats: false
  traj_num: 4

  # Commented-out, not used:
  # decay_ratio: 0.99

  # Discriminator network settings.
  disc_num_blocks: 2
  disc_hid_dim: 128
  disc_hid_act: "tanh"
  disc_use_bn: false
  disc_clamp_magnitude: 10.0

  # Policy network settings.
  policy_net_size: 256
  policy_num_hidden_layers: 2

  adv_irl_params:
    # Algorithm variant and feature flags.
    mode: "gail2"
    inverse_mode: "MSE"
    state_only: true
    state_diff: false
    union: true
    update_weight: false
    sas: false
    qss: false

    # Training schedule.
    num_epochs: 302
    num_steps_per_epoch: 100000
    num_steps_between_train_calls: 1000
    max_path_length: 1000
    min_steps_before_training: 5000

    # Evaluation.
    eval_deterministic: true
    num_steps_per_eval: 20000

    # Replay buffer and terminal/absorbing-state handling.
    replay_buffer_size: 20000
    no_terminal: true
    eval_no_terminal: false
    wrap_absorbing: false

    # Update counts for each learned component, plus pretraining.
    num_update_loops_per_train_call: 100
    num_disc_updates_per_loop_iter: 1
    num_policy_updates_per_loop_iter: 1
    num_state_predictor_updates_per_loop_iter: 1
    num_inverse_dynamic_updates_per_loop_iter: 1
    num_pretrain_updates: 10
    pretrain_steps_per_epoch: 5000

    # Discriminator optimisation and batch sizes.
    disc_lr: 0.0003
    disc_momentum: 0.9
    use_grad_pen: true
    use_wgan: false
    # grad_pen_weight is supplied via variables.adv_irl_params; the
    # commented-out constant below is not used:
    # grad_pen_weight: 10.0
    disc_optim_batch_size: 256
    policy_optim_batch_size: 256
    policy_optim_batch_size_from_expert: 0

    # State-predictor and inverse-dynamics model optimisation.
    state_predictor_lr: 0.01
    state_predictor_momentum: 0.9
    inverse_dynamic_lr: 0.0001
    inverse_dynamic_momentum: 0.9

    # Checkpointing / what gets saved.
    save_best: true
    freq_saving: 20
    save_replay_buffer: false
    save_environment: false
    save_algorithm: false

  sac_params:
    # reward_scale is supplied via variables.sac_params; the commented-out
    # constant below is not used:
    # reward_scale: 8.0
    discount: 0.99
    soft_target_tau: 0.005
    beta_1: 0.25
    # Learning rates for the policy, Q-function and V-function.
    policy_lr: 0.0003
    qf_lr: 0.0003
    vf_lr: 0.0003
    # Policy regularisation weights (presumably on the policy's mean/std
    # outputs — confirm against the SAC implementation).
    policy_mean_reg_weight: 0.001
    policy_std_reg_weight: 0.001

  env_specs:
    env_name: "halfcheetah"
    # Empty mapping: no extra keyword arguments passed for the environment.
    env_kwargs: {}
    # Distinct fixed seeds for the evaluation and training environments.
    eval_env_seed: 78236
    training_env_seed: 24495
