# diffgro.yaml

# Default planner configuration.
# NOTE: this file uses PyYAML-specific tags (!!python/name, !!python/tuple), so
# it must be read with a PyYAML loader that constructs Python objects;
# yaml.safe_load rejects these tags.
planner:
  training:
    total_timesteps: 200_000
    log_interval: 1_000
  inference:
    history: 4
  params:
    # Loaded as the Python class object itself (not an instance).
    policy: !!python/name:diffgro.diffgro.policies.DiffGroPlannerPolicy ''
    # Explicit !!float tags are required: under YAML 1.1 (PyYAML) exponent
    # notation without a '.' (e.g. 3e-5) does not resolve to a float and would
    # otherwise load as a string.
    learning_rate: !!float 3e-5
    batch_size: 8  # per task
    beta: !!float 5e-5
  policy_kwargs:
    net_arch:
      # Loaded as a Python tuple (1, 4, 8) rather than a list.
      act: !!python/tuple
        - 1
        - 4
        - 8
      pri:
        - 128
        - 128
        - 128
    activation_fn: "mish"
    horizon: 8
    skill_dim: 512
    emb_dim: 128
    hid_dim: 128
    ctx_dim: 64
    n_denoise: 20
    cf_weight: !!float 1.0
    predict_epsilon: false
    beta_scheduler: "cosine"
    # Loaded as the Python class object itself (not an instance).
    normalization_class: !!python/name:sb3_jax.common.norm_layers.RunningNormLayer ''

# ----- overrides ----- #

# Override section for the metaworld_complex setting. Keys mirror the layout
# of `planner` and are intended to replace the same-named defaults there;
# presumably merged key-by-key by the config loader — confirm merge depth.
metaworld_complex:
  inference:
    history: 8
  params:
    # !!float needed: PyYAML would load a bare 1e-5 as a string.
    learning_rate: !!float 1e-5
  policy_kwargs:
    domain: "long"  # no default for this key in `planner`
    horizon: 16
    ctx_dim: 128
