# Hydra config for the CVAE agent. `_target_` names the class that Hydra
# instantiates; the remaining top-level keys are passed to it as keyword
# arguments.
_target_: agents.cvae_agent.CVAEAgent
# Disable recursive instantiation: nested `_target_` nodes (model,
# optimization) are handed over as config nodes instead of built objects.
# Presumably CVAEAgent instantiates them itself — TODO confirm.
_recursive_: false

# Network definition handed to the agent: a VariationalAE described by the
# `encoder` and `decoder` sub-configs below.
model:
  _target_: agents.models.vae.cvae.VariationalAE
  # Nested `_target_` nodes are not instantiated by Hydra at this level;
  # presumably VariationalAE builds encoder and decoder itself — TODO confirm.
  _recursive_: false
  device: ${device}

  # Encoder: a VariationalEncoder wrapping the MLP described by `model_config`.
  encoder:
    _target_: agents.models.vae.cvae.VariationalEncoder
    _recursive_: false
    latent_dim: 32
    device: ${device}
    model_config:
      _target_: agents.models.common.mlp.ResidualMLPNetwork
      # The encoder input is sized as action_dim + obs_dim. `add` is not an
      # OmegaConf built-in resolver, so the project must register it before
      # this config is resolved.
      input_dim: ${add:${action_dim}, ${obs_dim}}
      hidden_dim: ${hidden_dim}
      num_hidden_layers: ${num_hidden_layers}
      # Disabled: explicit output size. Presumably the output size is chosen
      # by VariationalEncoder or a ResidualMLPNetwork default — TODO confirm.
      # output_dim: 128
      dropout: 0
      activation: 'Mish'
      use_spectral_norm: false
      device: ${device}

  # Decoder: a ResidualMLPNetwork whose output is sized to action_dim.
  decoder:
    _target_: agents.models.common.mlp.ResidualMLPNetwork
    _recursive_: false
    # NOTE(review): input_dim is obs_dim only and does not include the
    # encoder's latent_dim (32). Confirm that VariationalAE adjusts the
    # decoder input size for the latent vector; otherwise this looks too small.
    input_dim: ${obs_dim}
    # Disabled: with these commented out the decoder does not follow the
    # shared hidden_dim / num_hidden_layers settings used by the encoder;
    # presumably ResidualMLPNetwork defaults apply — TODO confirm.
    # hidden_dim: ${hidden_dim}
    # num_hidden_layers: ${num_hidden_layers}
    output_dim: ${action_dim}
    dropout: 0
    activation: 'Mish'
    use_spectral_norm: false
    # NOTE(review): norm_style is presumably ignored while use_norm is false;
    # verify against ResidualMLPNetwork.
    use_norm: false
    norm_style: 'BatchNorm'
    device: ${device}

# Scalar passed to the agent as `kl_loss_factor`; presumably the weight of the
# KL-divergence term in the training loss — TODO confirm. The unrounded value
# looks like the output of a hyperparameter search; verify before hand-editing.
kl_loss_factor: 1.6412506376100464

# Training and validation datasets, forwarded by interpolation from the
# same-named keys at the root of the composed config.
trainset: ${trainset}
valset: ${valset}

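# Alternative optimizer (disabled): AdamW with weight decay. To use it, swap it
# in for the active `optimization` block below; do not enable both, because a
# duplicate `optimization` key is invalid YAML.
# optimization:
#   _target_: torch.optim.AdamW
#   lr: 1.0e-3
#   betas: [0.9, 0.995]
#   eps: 1.0e-8
#   weight_decay: 0.1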

# Optimizer config targeting torch.optim.Adam. No `params` entry is given
# here, so the parameters to optimize have to be supplied by the code that
# instantiates this node.
optimization:
  _target_: torch.optim.Adam
  # Written with an explicit decimal point: YAML 1.1 loaders such as PyYAML
  # read a bare `1e-4` as the string '1e-4' instead of a float.
  lr: 1.0e-4
  weight_decay: 0

# Run-level settings forwarded unchanged to the agent. Each value is an
# absolute interpolation of the same-named key at the root of the composed
# config, so this file is expected to be mounted under a sub-key rather than
# used as the root config itself.
train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}