# Hydra config for the ACT agent.
# Alternative agent targets noted by the original author:
#   agents.joint_ibc_agent.PlanarRobotJointIBCAgent
#   agents.joint_distribution_ebm_agent.PlanarBotJointEBMAgent
_target_: agents.act_agent.ActAgent
# With `_recursive_: false` Hydra does not instantiate nested `_target_`
# nodes; they are passed on as configs (presumably built inside ActAgent).
_recursive_: false

# Optimizer settings (target: torch.optim.AdamW).
optimization:
  _target_: torch.optim.AdamW
  # Exponent floats carry an explicit decimal point so that YAML 1.1 loaders
  # (e.g. plain PyYAML) also read them as floats rather than strings.
  lr: 1.0e-4
  betas: [0.9, 0.95]
  weight_decay: 1.0e-6

# Learning-rate schedule (target: CosineAnnealingLR).
lr_scheduler:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  # Commented-out StepLR-style settings; not CosineAnnealingLR arguments.
  # step_size: 100
  # gamma: 0.99
  # Counted in scheduler steps; whether one step is a batch or an epoch
  # depends on the training loop, which is not visible in this file.
  T_max: 100000
  # Decimal point added so YAML 1.1 loaders parse the value as a float.
  eta_min: 1.0e-6

# ACT VAE model. Sub-module nodes keep their own `_target_`; because
# `_recursive_` is false Hydra leaves them uninstantiated.
model:
  _target_: agents.models.act.act_vae.ActVAE
  _recursive_: false

  # --- Model-level dimensions ---
  hidden_dim: 64  # could be ${hidden_dim}
  latent_dim: 32  # could be ${latent_dim}
  state_dim: ${obs_dim}
  action_dim: ${action_dim}
  # goal_dim: ${goal_dim}
  act_seq_size: ${window_size}  # number of actions per chunk

  # --- Transformer over the action sequence ---
  action_encoder:
    _target_: agents.models.act.act_vae.TransformerEncoder
    # `add` is a custom resolver that must be registered elsewhere;
    # presumably this evaluates to window_size + 1.
    block_size: ${add:${window_size}, 1}
    embed_dim: 64
    n_layers: 2
    n_heads: 4
    bias: false
    attn_pdrop: 0.1
    resid_pdrop: 0.1

  # --- Transformer encoder ---
  encoder:
    _target_: agents.models.act.act_vae.TransformerEncoder
    block_size: ${window_size}
    embed_dim: 64  # could be ${hidden_dim}
    n_layers: 2
    n_heads: 4
    bias: false
    attn_pdrop: 0.1
    resid_pdrop: 0.1

  # --- Transformer decoder ---
  decoder:
    _target_: agents.models.act.act_vae.TransformerDecoder
    block_size: ${window_size}
    embed_dim: 64  # could be ${hidden_dim}
    # The cross-attention embedding may use a different dimension.
    cross_embed: 64  # could be ${hidden_dim}
    n_layers: 4
    n_heads: 4
    bias: false
    attn_pdrop: 0.1
    resid_pdrop: 0.1

# --- Values forwarded from the composed config ---
# Interpolations without a leading dot resolve from the config root, so this
# file is presumably mounted under a config group (otherwise `window_size`
# below would reference itself) - TODO confirm.
device: ${device}
epoch: ${epoch}
eval_every_n_epochs: ${eval_every_n_epochs}
scale_data: ${scale_data}
trainset: ${trainset}
valset: ${valset}
train_batch_size: ${train_batch_size}
val_batch_size: ${val_batch_size}
num_workers: ${num_workers}

# --- Sequence / conditioning settings ---
window_size: ${window_size}
action_seq_size: ${window_size}
obs_size: 1
goal_conditioned: false
goal_window_size: 1

# --- Training settings ---
# Presumably the weight of the KL term in the VAE loss - verify in ActAgent.
kl_loss_factor: 1.0
# NOTE(review): the meaning of `decay` is not evident from this file.
decay: 0
# Interval for early stopping during epoch training.
patience: 80