# Hydra config for the vision-based CVAE agent.
# `${name}` values are OmegaConf interpolations that are resolved against the
# root of the composed config.
_target_: agents.cvae_vision_agent.CVAEAgent
# Nested `_target_` nodes are handed over as config rather than being
# instantiated automatically by Hydra.
_recursive_: false

# Policy config; its `model` and `obs_encoder` children are nested below.
model:
  _target_: agents.cvae_vision_agent.CVAEPolicy
  _recursive_: false

  visual_input: true
  device: ${device}

  # Config for CVAEPolicy's `model` entry: a VariationalAE whose encoder and
  # decoder are both configured with ResidualMLPNetwork.
  model:
    _target_: agents.models.vae.cvae.VariationalAE
    _recursive_: false
    device: ${device}

    encoder:
      _target_: agents.models.vae.cvae.VariationalEncoder
      _recursive_: false
      latent_dim: 32
      device: ${device}
      model_config:
        _target_: agents.models.common.mlp.ResidualMLPNetwork
        # `add` is not a built-in OmegaConf resolver, so the project has to
        # register it before this config is resolved (presumably it sums its
        # arguments -- verify against the registration site).
        # NOTE(review): 128 looks like the obs_encoder feature size (two rgb
        # inputs x rgb_model.output_size 64) -- confirm and keep in sync.
        input_dim: ${add:${action_dim}, 128}  # was: ${add:${action_dim}, ${obs_dim}}
        hidden_dim: ${hidden_dim}
        num_hidden_layers: ${num_hidden_layers}
        # Disabled:
        # output_dim: 128
        dropout: 0
        activation: "Mish"
        use_spectral_norm: false
        device: ${device}

    decoder:
      _target_: agents.models.common.mlp.ResidualMLPNetwork
      _recursive_: false
      # NOTE(review): 128 does not include latent_dim (32); confirm whether
      # VariationalAE widens the decoder input itself.
      input_dim: 128  # was: ${obs_dim}
      # Disabled -- NOTE(review): without these the decoder depends on
      # whatever defaults ResidualMLPNetwork defines; confirm that is intended.
      # hidden_dim: ${hidden_dim}
      # num_hidden_layers: ${num_hidden_layers}
      output_dim: ${action_dim}
      dropout: 0
      activation: "Mish"
      use_spectral_norm: false
      use_norm: false
      # Presumably only relevant when use_norm is true -- verify.
      norm_style: "BatchNorm"
      device: ${device}

  # Encoder for the image observations declared in `shape_meta`.
  obs_encoder:
    _target_: agents.models.vision.multi_image_obs_encoder.MultiImageObsEncoder
    shape_meta: &shape_meta
      # Acceptable types: rgb, low_dim.
      obs:
        agentview_image:
          shape: [3, 96, 96]
          type: rgb
        in_hand_image:
          shape: [3, 96, 96]
          type: rgb
        # Disabled low-dimensional inputs (remove the leading `# ` to enable):
        # robot_ee_pos:
        #   shape: [3]
        #   # type defaults to low_dim
        # robot0_eef_quat:
        #   shape: [4]
        # robot0_gripper_qpos:
        #   shape: [2]
      # action:
      #   shape: [10]

    rgb_model:
      _target_: agents.models.vision.model_getter.get_resnet
      input_shape: [3, 96, 96]
      # NOTE(review): the VAE `input_dim` values in this file hard-code 128,
      # presumably 2 rgb inputs x this output_size -- keep them in sync.
      output_size: 64
      # Disabled:
      # name: resnet18
      # weights: null
    resize_shape: null
    # Disabled:
    # crop_shape: [76, 76]
    # false: constant center crop rather than a random crop.
    random_crop: false
    use_group_norm: true
    share_rgb_model: false
    imagenet_norm: true

# NOTE(review): presumably the weight applied to the KL term of the CVAE loss;
# the unrounded value suggests it came from an automated search -- confirm
# its provenance before changing it.
kl_loss_factor: 1.6412506376100464

# Training / validation set entries, taken from the root config keys of the
# same name.
trainset: ${trainset}
valset: ${valset}

# Optimizer hyperparameters. No parameter list is given here; it is presumably
# supplied by the agent when it instantiates this node -- confirm in CVAEAgent.
optimization:
  _target_: torch.optim.AdamW
  lr: 1.0e-4
  # beta1 = 0.95 differs from the torch.optim.AdamW default of 0.9.
  betas: [0.95, 0.999]
  eps: 1.0e-8
  weight_decay: 1.0e-6

# Alternative optimizer (disabled): plain Adam without weight decay.
# optimization:
#   _target_: torch.optim.Adam
#   lr: 1.0e-4
#   weight_decay: 0

# Run-level settings, each forwarded unchanged from the root config key of the
# same name.
train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}