# Hydra-style instantiation node for `agents.ibc_vision_agent.IBCAgent`.
# `_recursive_: false` keeps nested `_target_` nodes as config: they are passed
# to the target un-instantiated instead of being built up front.
_target_: agents.ibc_vision_agent.IBCAgent
_recursive_: false

# Policy node. With `_recursive_: false` the nested `_target_` nodes below are
# handed to the target as config rather than being instantiated eagerly.
model:
  _target_: agents.ibc_vision_agent.IBCPolicy
  _recursive_: false

  visual_input: true
  device: ${device}

  # Network wrapped by the policy (target `EBMMLP`); its MLP has a single
  # output unit (`output_dim: 1`).
  model:
    _target_: agents.models.ibc.ebms.EBMMLP
    _recursive_: false

    device: ${device}

    mlp:
      _target_: agents.models.common.mlp.ResidualMLPNetwork
      # `add` is not a built-in OmegaConf resolver, so the project has to
      # register it before this config is resolved.
      # NOTE(review): 128 presumably equals the two rgb views configured under
      # `obs_encoder` times `rgb_model.output_size` (2 * 64) - confirm, and
      # keep the two values in sync when either changes.
      input_dim: ${add:${action_dim}, 128}
      hidden_dim: ${hidden_dim}
      num_hidden_layers: ${num_hidden_layers}
      output_dim: 1
      dropout: 0
      activation: 'Mish'
      use_spectral_norm: false
      use_norm: false
      norm_style: 'BatchNorm'
      device: ${device}

  # Image observation encoder: one entry per observation key under
  # `shape_meta.obs`.
  obs_encoder:
    _target_: agents.models.vision.multi_image_obs_encoder.MultiImageObsEncoder
    # The `&shape_meta` anchor is not aliased in the visible part of this file.
    shape_meta: &shape_meta
      # acceptable types: rgb, low_dim
      obs:
        agentview_image:
          shape: [3, 96, 96]
          type: rgb
        in_hand_image:
          shape: [3, 96, 96]
          type: rgb
        # Disabled entries, kept for reference:
        # robot_ee_pos:
        #   shape: [3]
        #   # type default: low_dim
        # robot0_eef_quat:
        #   shape: [4]
        # robot0_gripper_qpos:
        #   shape: [2]
      # action:
      #   shape: [10]

    rgb_model:
      _target_: agents.models.vision.model_getter.get_resnet
      input_shape: [3, 96, 96]
      output_size: 64
    resize_shape: null
    # crop_shape: [84, 84]
    # constant center crop
    random_crop: false
    use_group_norm: true
    share_rgb_model: false
    imagenet_norm: true

# Sampler node (target `LangevinMCMCSampler`). Several settings come in
# train/inference pairs: `noise_scale` / `noise_scale_infer`,
# `train_samples` / `inference_samples`, `train_iterations` /
# `inference_iterations`, `sampler_stepsize_init` /
# `sampler_stepsize_init_infer`.
sampler:
  _target_: agents.models.ibc.samplers.langevin_mcmc.LangevinMCMCSampler
  _recursive_: false
  noise_scale: 0.1
  noise_scale_infer: 0.5
  noise_shrink: 0.894
  train_samples: 8
  inference_samples: 64
  train_iterations: 40
  inference_iterations: 10
  sampler_stepsize_init: 0.0493  # 0.894
  sampler_stepsize_init_infer: 0.5
  second_inference_stepsize_init: 0.00001
  sampler_stepsize_decay: 0.8  # 0.8655
  sampler_stepsize_final: 0.00001
  sampler_stepsize_power: 2.0
  use_polynomial_rate: true
  # NOTE(review): `second_inference_stepsize_init` presumably only applies
  # when this flag is true - confirm in the sampler implementation.
  second_inference_iteration: true
  delta_action_clip: 0.1
  device: ${device}

# Keyword arguments for `torch.optim.AdamW`. The parameters to optimise are not
# part of this node, so whoever instantiates it has to supply them.
optimization:
  _target_: torch.optim.AdamW
  lr: 1.0e-4
  betas: [0.95, 0.999]
  eps: 1.0e-8
  weight_decay: 1.0e-6

# Keyword arguments for `torch.optim.lr_scheduler.StepLR`: the learning rate is
# multiplied by `gamma` every `step_size` calls to the scheduler's `step()`.
# How often `step()` is called is decided by the training code (not shown
# here); the optimizer is likewise supplied by the caller.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 200
  gamma: 0.99

# Scalar training options set directly on the agent node.
loss_type: 'info_nce'  # possible loss types: 'info_nce', 'cd', 'cd_kl'
avrg_e_regularization: 0
kl_loss_factor: 0.3
grad_norm_factor: 1
# NOTE(review): `decay` and `update_ema_every_n_steps` presumably only take
# effect when `use_ema` is true - confirm in IBCAgent.
use_ema: false
decay: 0.999
update_ema_every_n_steps: 1
stop_value: 1
goal_conditioning: false

# Absolute OmegaConf interpolations: each value is looked up by name at the
# root of the composed config, so the root config has to define these keys.
trainset: ${trainset}
valset: ${valset}

train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}