# Hydra instantiate config for the IBC agent: every top-level key in this
# file other than the `_target_` / `_recursive_` control keys is handed to
# the class named by `_target_`.
# `_recursive_: false` turns off Hydra's automatic instantiation of nested
# `_target_` nodes, so sub-configs such as `model` and `sampler` reach the
# agent as configs — presumably it instantiates them itself; TODO confirm.
_target_: agents.ibc_agent.IBCAgent
_recursive_: false

# Model config: an EBMMLP wrapping a ResidualMLPNetwork backbone.
# `_recursive_: false` means Hydra leaves the nested `mlp` node
# uninstantiated — presumably EBMMLP builds it; TODO confirm.
model:
  _target_: agents.models.ibc.ebms.EBMMLP
  _recursive_: false

  device: ${device}

  # Backbone network settings.
  mlp:
    _target_: agents.models.common.mlp.ResidualMLPNetwork
    # `add` is not a built-in OmegaConf resolver; it must be registered by
    # the project before this config is resolved. Presumably yields
    # action_dim + obs_dim — verify against the resolver registration.
    input_dim: ${add:${action_dim}, ${obs_dim}}
    hidden_dim: ${hidden_dim}
    num_hidden_layers: ${num_hidden_layers}
    # Single output value — presumably the scalar energy; TODO confirm.
    output_dim: 1
    dropout: 0
    activation: 'Mish'
    use_spectral_norm: false
    # NOTE(review): with use_norm disabled, norm_style is presumably
    # ignored — confirm in ResidualMLPNetwork.
    use_norm: false
    norm_style: 'BatchNorm'
    device: ${device}

# Sampler config for LangevinMCMCSampler.
# Keys carrying an `_infer` / `inference_` affix look like inference-time
# counterparts of the training values — NOTE(review): confirm how the
# sampler consumes each of them; their semantics are not visible here.
sampler:
  _target_: agents.models.ibc.samplers.langevin_mcmc.LangevinMCMCSampler
  _recursive_: false
  noise_scale: 0.1
  noise_scale_infer: 0.5
  noise_shrink: 0.894
  train_samples: 8
  inference_samples: 64
  train_iterations: 40
  inference_iterations: 10
  sampler_stepsize_init: 0.0493  # alternative value noted by author: 0.894
  sampler_stepsize_init_infer: 0.5
  second_inference_stepsize_init: 0.00001
  sampler_stepsize_decay: 0.8  # alternative value noted by author: 0.8655
  sampler_stepsize_final: 0.00001
  sampler_stepsize_power: 2.0
  use_polynomial_rate: true
  second_inference_iteration: true
  delta_action_clip: 0.1
  device: ${device}

# Optimizer config targeting torch.optim.AdamW.
# No `params` entry is given here, so whoever instantiates this node must
# supply the parameters to optimise — presumably the agent; TODO confirm.
optimization:
  _target_: torch.optim.AdamW
  lr: 1.0e-3
  betas: [0.95, 0.999]
  eps: 1.0e-8
  weight_decay: 1.0e-6

# Learning-rate schedule targeting torch.optim.lr_scheduler.StepLR, which
# multiplies the learning rate by `gamma` once every `step_size` scheduler
# steps. Whether a step is a batch or an epoch depends on where the agent
# calls the scheduler — not visible here; TODO confirm.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 200
  gamma: 0.99

# Agent-level training hyperparameters (top-level keys of this config).
loss_type: 'info_nce'  # one of: 'info_nce', 'cd', 'cd_kl'
avrg_e_regularization: 0
kl_loss_factor: 0.3
grad_norm_factor: 1
# NOTE(review): `decay` and `update_ema_every_n_steps` presumably only
# matter when `use_ema` is enabled — confirm in the agent.
use_ema: false
decay: 0.999
update_ema_every_n_steps: 1
stop_value: 1
goal_conditioning: false

# Pass-through values: each key is an absolute OmegaConf interpolation that
# resolves to the key of the same name at the root of the composed config.
# This only works when this file is composed under a non-root package
# (at the root, each key would refer to itself) — presumably an agent
# config group; verify against the primary config's defaults list.
trainset: ${trainset}
valset: ${valset}

train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}