# Config for `agents.ddpm_encdec_agent.DiffusionAgent`, instantiated by Hydra
# through the `_target_` key.
_target_: agents.ddpm_encdec_agent.DiffusionAgent
# With `_recursive_: false` Hydra does not instantiate nested `_target_` nodes
# (model, optimization, ...); they are passed to the agent as configs.
_recursive_: false

# NOTE(review): the nested `model.model` section repeats these as
# `action_seq_len: 4` and `obs_seq_len: 5`; the values are not linked by
# interpolation, so keep them in sync when changing either one.
action_seq_size: 4
obs_seq_len: 5

# Diffusion wrapper (`Diffusion`) around the encoder-decoder denoising network.
# Every level sets `_recursive_: false`, so Hydra leaves nested `_target_`
# nodes as configs instead of instantiating them.
model:
  _target_: agents.models.diffusion.diffusion_policy.Diffusion
  _recursive_: false

  # Dimensions and device are taken from the root config.
  state_dim: ${obs_dim}
  action_dim: ${action_dim}
  beta_schedule: 'cosine'
  n_timesteps: 16
  loss_type: 'l2'
  clip_denoised: true
  predict_epsilon: true
  device: ${device}
  diffusion_x: false
  diffusion_x_M: 10

  # Inner network handed to the Diffusion wrapper.
  model:
    _target_: agents.models.diffusion.diffusion_models.DiffusionEncDec
    _recursive_: false
    state_dim: ${obs_dim}
    action_dim: ${action_dim}
    goal_conditioned: false
    goal_seq_len: 10
    # NOTE(review): these repeat the top-level `obs_seq_len: 5` and
    # `action_seq_size: 4`; they are not interpolated, keep them in sync.
    obs_seq_len: 5
    action_seq_len: 4
    # NOTE(review): spelled `embed_pdrob` (compare `attn_pdrop` /
    # `resid_pdrop` below). The key must match the target's constructor
    # argument, so confirm against DiffusionEncDec before renaming.
    embed_pdrob: 0
    # Architecture details
    # `embed_dim` is hard-coded to 64 here and in the encoder/decoder below;
    # the `${n_embd}` interpolation is kept as a comment only.
    embed_dim: 64  # ${n_embd}
    device: ${device}
    linear_output: true

    encoder:
      _target_: agents.models.act.act_vae.TransformerEncoder
      embed_dim: 64  # ${n_embd}
      n_heads: 4
      n_layers: 2
      attn_pdrop: 0.1
      resid_pdrop: 0.1
      bias: false
      # `add` is not a built-in OmegaConf resolver; it has to be registered by
      # the application before this config is resolved.
      block_size: ${add:${window_size}, 1}

    decoder:
      _target_: agents.models.act.act_vae.TransformerDecoder
      embed_dim: 64  # ${n_embd}
      # allow cross embedding to have different dimension
      cross_embed: 64  # ${n_embd}
      n_heads: 4
      n_layers: 4
      attn_pdrop: 0.1
      resid_pdrop: 0.1
      bias: false
      # Same custom `add` resolver as the encoder's `block_size`.
      block_size: ${add:${window_size}, 1}

#image_encoder:
#  _target_: agents.models.vision.multi_image_obs_encoder.MultiImageObsEncoder
#  shape_meta: &shape_meta
#    # acceptable types: rgb, low_dim
#    obs:
#      agentview_image:
#        shape: [3, 96, 96]
#        type: rgb
#      in_hand_image:
#        shape: [3, 96, 96]
#        type: rgb
##      robot_ee_pos:
##        shape: [2]
#        # type default: low_dim
##      robot0_eef_quat:
##        shape: [4]
##      robot0_gripper_qpos:
##        shape: [2]
#    action:
#      shape: [10]
#
#  rgb_model:
#    _target_: agents.models.vision.model_getter.get_resnet
#    name: resnet18
#    weights: null
#  resize_shape: null
#  crop_shape: [76, 76]
#  # constant center crop
#  random_crop: True
#  use_group_norm: True
#  share_rgb_model: True
#  imagenet_norm: True
#
#visual_input: True

# Optimizer config (`torch.optim.Adam`).
optimization:
  _target_: torch.optim.Adam
  # Written with an explicit decimal point: YAML 1.1 loaders such as plain
  # PyYAML resolve `5e-4` as the string "5e-4" rather than a float, whereas
  # `5.0e-4` is a float under every loader (same value, 0.0005).
  lr: 5.0e-4  # for transformer
  # lr: 5.0e-4  # for MLP
  weight_decay: 0

#optimization:
#  _target_: torch.optim.AdamW
#  lr: 1.0e-4
#  betas: [0.9, 0.995]
#  eps: 1.0e-8
#  weight_decay: 0.1

# Values forwarded unchanged from the root config via absolute interpolation.
trainset: ${trainset}
valset: ${valset}
train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}
epoch: ${epoch}
device: ${device}
scale_data: ${scale_data}
eval_every_n_epochs: ${eval_every_n_epochs}

# Agent-level constants fixed in this file. Booleans use canonical lowercase
# `true`/`false` throughout.
discount: 0.99
# NOTE(review): `decay` and `update_ema_every_n_steps` presumably configure the
# EMA enabled by `use_ema`; confirm against the agent implementation.
use_ema: true
decay: 0.995
update_ema_every_n_steps: 1
goal_window_size: 1
window_size: ${window_size}
diffusion_kde: false
diffusion_kde_samples: 100
# NOTE(review): `model.model.goal_conditioned` is set separately (also false);
# keep the two flags consistent.
goal_conditioned: false