# @package _group_
# Name of this algorithm configuration.
name: 'mbpo'

# Model-related switches. Their exact effect is decided by the code that
# consumes this config (not visible here); judging by the names they control
# input normalization, delta-vs-absolute prediction targets, and whether the
# reward is learned -- TODO confirm against the training loop.
learned_rewards: true
target_is_delta: true
normalize: true
# Not fixed here: interpolated from the `overrides` config node.
freq_train_model: ${overrides.freq_train_model}

# Exploration and evaluation settings (semantics inferred from the key names
# only -- verify against the caller).
num_eval_episodes: 1
random_initial_explore: false
initial_exploration_steps: 10000
sac_samples_action: true

# --------------------------------------------
#          SAC Agent configuration
# --------------------------------------------
# Node describing the SAC agent. `_target_` holds the dotted path of the class
# this node is meant to build (Hydra instantiate convention); the remaining
# keys are the values supplied alongside it.
agent:
  _target_: mbrl.third_party.pytorch_sac.agent.sac.SACAgent
  device: ${device}

  # `???` is the OmegaConf marker for a mandatory value that has no default
  # here; these three must be filled in by the caller before use.
  obs_dim: ???
  action_dim: ???
  action_range: ???

  # Sub-configs pulled in by interpolation from the `double_q_critic` and
  # `diag_gaussian_actor` nodes under `algorithm`.
  actor_cfg: ${algorithm.diag_gaussian_actor}
  critic_cfg: ${algorithm.double_q_critic}

  # Values fixed in this file.
  batch_size: 256
  discount: 0.99

  # Actor settings; learning rate and update frequency come from `overrides`.
  # The betas pair is presumably Adam-style (beta1, beta2) -- TODO confirm.
  actor_lr: ${overrides.sac_actor_lr}
  actor_update_frequency: ${overrides.sac_actor_update_frequency}
  actor_betas:
    - 0.9
    - 0.999

  # Critic settings; learning rate and target update frequency come from
  # `overrides`, the rest are fixed here.
  critic_lr: ${overrides.sac_critic_lr}
  critic_target_update_frequency: ${overrides.sac_critic_target_update_frequency}
  critic_tau: 0.005
  critic_betas:
    - 0.9
    - 0.999

  # Temperature (alpha) settings; learning rate and target entropy come from
  # `overrides`.
  learnable_temperature: true
  init_temperature: 0.1
  target_entropy: ${overrides.sac_target_entropy}
  alpha_lr: ${overrides.sac_alpha_lr}
  alpha_betas:
    - 0.9
    - 0.999

# Critic sub-config, referenced from `agent.critic_cfg`. Observation and action
# sizes mirror whatever is set on the agent node; the depth comes from
# `overrides`, while the hidden width is fixed here.
double_q_critic:
  _target_: mbrl.third_party.pytorch_sac.agent.critic.DoubleQCritic
  action_dim: ${algorithm.agent.action_dim}
  obs_dim: ${algorithm.agent.obs_dim}
  hidden_depth: ${overrides.sac_hidden_depth}
  hidden_dim: 1024

# Actor sub-config, referenced from `agent.actor_cfg`. Observation and action
# sizes mirror whatever is set on the agent node; the depth comes from
# `overrides`, while the hidden width and log-std bounds are fixed here.
diag_gaussian_actor:
  _target_: mbrl.third_party.pytorch_sac.agent.actor.DiagGaussianActor
  action_dim: ${algorithm.agent.action_dim}
  obs_dim: ${algorithm.agent.obs_dim}
  hidden_dim: 1024
  hidden_depth: ${overrides.sac_hidden_depth}
  # Two-element [lower, upper] pair (ordering presumed from the values).
  log_std_bounds:
    - -5
    - 2
