# @package _global_
agent:
  # Settings for the FOIL agent. `_target_` names the class to construct
  # (presumably through hydra.utils.instantiate; confirm against the caller).
  name: FOIL
  _target_: agent.foil.FOIL
  # `???` marks a mandatory value that has no default in this file.
  obs_dim: ???  # to be specified later
  action_dim: ???  # to be specified later

  # Sub-configs are interpolated from the top-level nodes of the same names
  # defined in this file.
  critic_cfg: ${q_net}
  actor_cfg: ${diag_gaussian_actor}
  disc_cfg: ${disc_net}
  # Scientific-notation floats carry an explicit decimal point (1.0e-2, not
  # 1e-2): YAML 1.1 loaders such as stock PyYAML read the latter as a string.
  init_temp: 1.0e-2  # use a low temp for IL

  alpha_lr: 3.0e-4
  alpha_betas: [0.9, 0.999]

  actor_lr: 3.0e-5
  actor_betas: [0.9, 0.999]
  actor_update_frequency: 2

  disc_lr: 1.0e-5
  disc_betas: [0.9, 0.999]
  disc_update_frequency: 20

  critic_lr: 3.0e-4
  critic_betas: [0.9, 0.999]
  critic_tau: 0.005
  critic_target_update_frequency: 1

  # Learn the temperature coefficient (disabled by default).
  learn_temp: false
  update_per_step: 5
  disc_reg: 1

  # Selects the actor loss: true uses the value_dice actor loss, false uses
  # the normal SAC actor loss.
  vdice_actor: false

  bc_transit: false

q_net:
  # Critic config; pulled into the agent node through `critic_cfg: ${q_net}`.
  _target_: module.critic.DoubleQCritic
  # Fixed settings.
  hidden_depth: 2
  hidden_dim: 256
  # Dimensions are interpolated from the agent node, so they are set once.
  action_dim: ${agent.action_dim}
  obs_dim: ${agent.obs_dim}

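# Alternative discriminator config, currently disabled. To switch to it,
# uncomment this block and remove the active `disc_net` node below; leaving
# both in place would define the `disc_net` key twice.
#disc_net:
#  _target_: module.discriminator.ReshapedDiscriminator
#  obs_dim: ${agent.obs_dim}
#  action_dim: ${agent.action_dim}
#  hidden_dim: 256
#  hidden_depth: 2
#  log_std_bounds: [-5, 2]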

disc_net:
  # Discriminator config; pulled into the agent node through
  # `disc_cfg: ${disc_net}`.
  _target_: module.discriminator.Discriminator
  # Fixed settings.
  hidden_depth: 2
  hidden_dim: 256
  # Dimensions are interpolated from the agent node, so they are set once.
  action_dim: ${agent.action_dim}
  obs_dim: ${agent.obs_dim}

diag_gaussian_actor:
  # Actor config; pulled into the agent node through
  # `actor_cfg: ${diag_gaussian_actor}`.
  _target_: module.actor.DiagGaussianActor
  # Fixed settings.
  hidden_depth: 2
  hidden_dim: 256
  log_std_bounds: [-5, 2]
  # Dimensions are interpolated from the agent node, so they are set once.
  action_dim: ${agent.action_dim}
  obs_dim: ${agent.obs_dim}