# Hydra-style instantiation target for the agent described by this file.
# Alternative targets, left disabled for reference:
#   agents.joint_ibc_agent.PlanarRobotJointIBCAgent
#   agents.joint_distribution_ebm_agent.PlanarBotJointEBMAgent
_target_: agents.act_vision_agent.ActAgent
# Nested `_target_` nodes are not instantiated automatically; the agent
# receives them as config and is expected to build them itself.
_recursive_: false

# Optimizer configuration; the remaining keys are passed as keyword arguments
# to the class named by `_target_`.
optimization:
  _target_: torch.optim.AdamW
  lr: 1.0e-4
  # Note: first beta is 0.95 rather than the torch default of 0.9.
  betas: [0.95, 0.999]
  eps: 1.0e-8
  weight_decay: 1.0e-6

# Learning-rate schedule: cosine annealing down to `eta_min` over `T_max`
# scheduler steps.
# NOTE(review): whether a "step" is one epoch or one batch depends on where
# the agent calls `scheduler.step()` - confirm against the training loop.
lr_scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  # Disabled knobs (kept for reference; not used by the scheduler above):
  # step_size: 100
  # gamma: 0.99
  T_max: 100000
  # Written with an explicit mantissa dot, like every other float in this
  # file: a bare `1e-6` is resolved as a *string* by YAML 1.1 loaders such as
  # stock PyYAML, whereas `1.0e-6` is a float under every loader.
  eta_min: 1.0e-6

# Policy definition: an ACT VAE (`model.model`) fed by an image observation
# encoder (`model.obs_encoder`). `_recursive_: false` leaves construction of
# the nested `_target_` nodes to the receiving class.
model:
  _target_: agents.act_vision_agent.ActPolicy
  _recursive_: false

  visual_input: true
  device: ${device}

  # Transformer-based VAE with three transformer stacks.
  model:
    _target_: agents.models.act.act_vae.ActVAE
    _recursive_: false

    # Same hyper-parameters as `encoder`, except block_size is
    # window_size + 1 (computed by the `add` resolver).
    action_encoder:
      _target_: agents.models.act.act_vae.TransformerEncoder
      embed_dim: 64
      n_heads: 4
      n_layers: 2
      attn_pdrop: 0.1
      resid_pdrop: 0.1
      bias: false
      block_size: ${add:${window_size}, 1}

    encoder:
      _target_: agents.models.act.act_vae.TransformerEncoder
      # Hard-coded; the ${hidden_dim} interpolation is disabled.
      embed_dim: 64
      n_heads: 4
      n_layers: 2
      attn_pdrop: 0.1
      resid_pdrop: 0.1
      bias: false
      block_size: ${window_size}

    decoder:
      _target_: agents.models.act.act_vae.TransformerDecoder
      # Hard-coded; the ${hidden_dim} interpolation is disabled.
      embed_dim: 64
      # Kept as a separate key so the cross-attention embedding may use a
      # different dimension; hard-coded instead of ${hidden_dim}.
      cross_embed: 64
      n_heads: 4
      n_layers: 4
      attn_pdrop: 0.1
      resid_pdrop: 0.1
      bias: false
      block_size: ${window_size}

    # Hard-coded; the ${hidden_dim} interpolation is disabled.
    hidden_dim: 64
    action_dim: ${action_dim}
    # Hard-coded; the ${obs_dim} interpolation is disabled.
    # NOTE(review): presumably 2 camera views x rgb_model.output_size (64)
    # - confirm against the obs encoder's output.
    state_dim: 128
    # goal_dim: ${goal_dim}
    # Number of actions per chunk (action chunking).
    act_seq_size: ${window_size}
    # Hard-coded; the ${latent_dim} interpolation is disabled.
    latent_dim: 32

  obs_encoder:
    _target_: agents.models.vision.multi_image_obs_encoder.MultiImageObsEncoder
    shape_meta: &shape_meta
      # Acceptable observation types: rgb, low_dim.
      obs:
        agentview_image:
          shape: [3, 96, 96]
          type: rgb
        in_hand_image:
          shape: [3, 96, 96]
          type: rgb
        # Disabled entries, kept for reference:
        # robot_ee_pos:
        #   shape: [3]
        #   type default: low_dim
        # robot0_eef_quat:
        #   shape: [4]
        # robot0_gripper_qpos:
        #   shape: [2]
      # action:
      #   shape: [10]

    rgb_model:
      _target_: agents.models.vision.model_getter.get_resnet
      input_shape: [3, 96, 96]
      output_size: 64
    resize_shape: null
    # crop_shape: [84, 84]
    # constant center crop
    random_crop: false
    use_group_norm: true
    share_rgb_model: false
    imagenet_norm: true

# Agent-level settings. Every `${...}` value below is an interpolation of a key
# defined outside this file.
trainset: ${trainset}
valset: ${valset}
train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}

# Same source value as model.model.act_seq_size.
action_seq_size: ${window_size}
obs_size: 1

goal_conditioned: false
decay: 0
goal_window_size: 1
# NOTE(review): this key interpolates a key of the same name; that only
# resolves if this file is composed under a config group so `${window_size}`
# refers to the root-level key - confirm, otherwise it is a self-reference.
window_size: ${window_size}
patience: 80 # interval for early stopping during epoch training