# @package _group_
# Top-level settings for this algorithm config. Values written as
# ${overrides.<key>} are interpolations: they are not set here but read from
# an `overrides` node that is not defined in the visible part of this file.
name: "mbail"

is_model_based: true
# Step budgets come from the `overrides` node.
total_env_steps: ${overrides.total_env_steps}
bc_init_steps: ${overrides.bc_init_steps}

normalize: true
normalize_double_precision: true
target_is_delta: true
learned_rewards: true
# NOTE(review): the unit of this interval (env steps vs. updates) is not
# visible in this file — confirm against the training loop.
log_frequency: 5000
freq_train_model: ${overrides.freq_train_model}

sac_samples_action: true
random_initial_explore: false

num_eval_episodes: 5

# --------------------------------------------
#          Dynamics Model configuration
# --------------------------------------------
# Instantiation spec for the dynamics model: `_target_` names the class
# (mbrl.models.GaussianMLP) and the sibling keys are the values passed to it.
# `???` marks a mandatory value that this file leaves unset; it has to be
# filled in elsewhere before the value is accessed.
dynamics_model:
  _target_: mbrl.models.GaussianMLP
  # `device` is read from the root-level `device` key (not defined here).
  device: ${device}
  num_layers: 4
  # Input/output sizes are mandatory and intentionally left unset here.
  in_size: ???
  out_size: ???
  ensemble_size: 7
  hid_size: ${overrides.model_hid_size}
  deterministic: false
  # propagation_method: random_model
  learn_logvar_bounds: false  # so far this works better
  clip: ${overrides.model_clip_output}
  # Nested spec naming torch.nn.SiLU as the activation function class.
  activation_fn_cfg:
    _target_: torch.nn.SiLU

# --------------------------------------------
#          SAC Agent configuration
# --------------------------------------------
# Instantiation spec for the SAC agent: `_target_` names the class and the
# sibling keys are the values passed to it. `???` marks a mandatory value that
# this file leaves unset; it has to be filled in elsewhere before use.
agent:
  _target_: garage.mbrl.third_party.pytorch_sac_pranz24.sac.SAC
  num_inputs: ???
  relabel_samples: true
  # Action-space description; bounds and shape are mandatory and left unset.
  action_space:
    # NOTE(review): `Box` is normally exported from `gym.spaces`, so
    # `gym.env.Box` looks like a wrong import path. Left unchanged because
    # this file does not show whether the nested target is ever resolved —
    # confirm against the code that instantiates `agent`.
    _target_: gym.env.Box
    low: ???
    high: ???
    shape: ???
  # SAC hyperparameters; all values except `device` are read from the
  # `overrides` node (sac_* keys), which is not defined in this file.
  args:
    gamma: ${overrides.sac_gamma}
    tau: ${overrides.sac_tau}
    alpha: ${overrides.sac_alpha}
    policy: ${overrides.sac_policy}
    target_update_interval: ${overrides.sac_target_update_interval}
    automatic_entropy_tuning: ${overrides.sac_automatic_entropy_tuning}
    target_entropy: ${overrides.sac_target_entropy}
    hidden_size: ${overrides.sac_hidden_size}
    device: ${device}
    lr: ${overrides.sac_lr}
    decay_horizon: ${overrides.sac_decay_horizon}

# --------------------------------------------
#          TD3-BC Agent configuration
# --------------------------------------------
# Instantiation spec for the TD3-BC agent: `_target_` names the class and the
# sibling keys are the values passed to it. `???` marks a mandatory value that
# this file leaves unset; it has to be filled in elsewhere before use.
td3_agent:
  _target_: garage.models.td3_bc.TD3_BC
  # Mandatory entries that are intentionally left unset in this file.
  env: ???
  expert_buffer: ???
  learner_buffer: ???
  discriminator: ???
  cfg: ???
  # Hyperparameters read from the `overrides.actor` node, which is not
  # defined in this file.
  actor_wd: ${overrides.actor.actor_wd}
  critic_wd: ${overrides.actor.critic_wd}
  discount: ${overrides.actor.discount}
  tau: ${overrides.actor.tau}
  policy_noise_scalar: ${overrides.actor.policy_noise_scalar}
  noise_clip_scalar: ${overrides.actor.noise_clip_scalar}
  policy_freq: ${overrides.actor.policy_freq}
  alpha: ${overrides.actor.alpha}
  decay_lr: ${overrides.actor.decay_lr}
  hybrid_sampling: true
  device: ${device}