# langdt.yml

# Base configuration for the LangDT planner. Override sections further down
# this file reuse the same `params` / `policy_kwargs` key layout.
planner:
  training:
    # `_` digit separators resolve to ints under YAML 1.1 (PyYAML).
    total_timesteps: 200_000
    log_interval: 1_000
  params:
    # `!!python/name:` resolves to a Python object at load time. These tags
    # are rejected by yaml.safe_load, so this file needs a non-safe loader.
    policy: !!python/name:diffgro.langdt.policies.LangDTPlannerPolicy ''
    # `!!float` is required: YAML 1.1 resolvers (PyYAML) read `3e-5`, which
    # has no decimal point, as a string rather than a float.
    learning_rate: !!float 3e-5
    batch_size: 8  # per task
  policy_kwargs:
    activation_fn: 'gelu_new'
    max_length: 8  # horizon
    max_ep_length: 500
    skill_dim: 512
    hidden_size: 128
    n_layer: 2
    n_head: 4
    n_inner: 512
    resid_pdrop: 0.1
    attn_pdrop: 0.1
    # Another `!!python/name:` tag; also requires a non-safe loader.
    normalization_class: !!python/name:sb3_jax.common.norm_layers.RunningNormLayer ''

# ----- overrides ----- #

# Override section for the `metaworld_complex` setting. It mirrors the
# `params` / `policy_kwargs` layout used by `planner`; presumably the config
# loader merges these keys over the `planner` defaults — confirm in the loader.
metaworld_complex:
  params:
    # `!!float` needed: PyYAML (YAML 1.1) reads a bare `1e-5` as a string.
    learning_rate: !!float 1e-5
  policy_kwargs:
    domain: 'long'  # no counterpart key in `planner.policy_kwargs`
    max_length: 16  # horizon
    max_ep_length: 2_000
