# @package _global_
general:
    name: 'planar'
    gpus: 1
    wandb: 'online'
    # NOTE(review): a checkpoint path here presumably makes the run evaluate
    # this checkpoint instead of training; set to null to train (confirm in
    # the entry point).
    test_only: 'checkpoints/planar.ckpt'
    evaluate_all_checkpoints: false
    # Final evaluation settings. NOTE(review): presumably one seed per final
    # sampling round, so len(final_seeds) should equal num_final_sampling.
    num_final_sampling: 5
    final_seeds: [0, 1, 2, 3, 4]
    final_model_samples_to_generate: 40
    final_model_samples_to_save: 30
    final_model_chains_to_save: 20
train:
    n_epochs: 150000
    batch_size: 64
    save_model: true
    # NOTE(review): presumably the decay of an exponential moving average of
    # model weights; confirm against the training loop.
    ema_decay: 0.999
model:
    n_layers: 10
    # NOTE(review): two loss weights; which loss terms they scale is not
    # visible here — confirm in the training loss.
    lambda_train: [1, 5]
    # With corrector_num_steps at 0 the corrector presumably never runs, so
    # the entry time and tau multiplier have no effect (confirm).
    corrector_entry_time: 0.0
    corrector_num_steps: 0
    corrector_tau_multiplier: 0.1
    transition: 'marginal'
    # NOTE(review): three rate constants; the meaning of each entry is
    # defined by the model code, not here.
    rate_constant: [5.0, 5.0, 1.0]
    diffusion_steps: 500

    # Keys stay quoted: YAML 1.1 lists a bare y among boolean spellings, so an
    # unquoted 'y' key could load as true on some parsers.
    hidden_mlp_dims:
        'X': 128
        'E': 64
        'y': 64
    hidden_dims:
        'dx': 256
        'de': 64
        'dy': 64
        'n_head': 8
        'dim_ffX': 256
        'dim_ffE': 64
        'dim_ffy': 128
dataset:
    # NOTE(review): presumably forwarded to the data loader's pin_memory flag.
    pin_memory: true