# task name used for logging
task_name: flow/qm9-refiner
# path to a pretrained checkpoint
# NOTE(review): the value looks like a placeholder (and suggests an ET-Flow
# checkpoint is expected) — replace with a real path before running; confirm
pretrained_ckpt: CKPT/DIR/FOR/ET-FLOW

# random seed for experiment reproducibility
seed: 42

# data config: datamodule name plus the arguments that accompany it
datamodule: BaseDataModule
datamodule_args:
    # dataset partition to use
    partition: qm9
    # NOTE(review): these keys match torch.utils.data.DataLoader keyword
    # arguments; presumably forwarded to it — confirm against the datamodule
    dataloader_args:
        batch_size: 128
        num_workers: 8
        pin_memory: false
        persistent_workers: false
# NOTE(review): DATA_DIR looks like a placeholder for the dataset root
# directory — replace before running; confirm
data_root: DATA_DIR
# NOTE(review): same value as the top-level `model` key; how the two keys
# relate is not visible in this file — confirm against the consumer
cascade: GlobalRefiner

# sigma schedule settings
sigma_args:
    # NOTE(review): distinct from `model_args.sigma` (0.1); how the two
    # interact is not visible in this file — confirm
    sigma_scale: 1

# model config: model name plus its arguments; network, flow-matching,
# optimizer and lr-scheduler settings all live in the same flat mapping
model: GlobalRefiner
model_args:
    # network args
    network_type: TorchMDDynamics
    hidden_channels: 160
    num_layers: 20
    # radial basis function settings
    num_rbf: 64
    rbf_type: expnorm
    trainable_rbf: true
    activation: silu
    neighbor_embedding: true
    # NOTE(review): distance cutoff bounds; units are not stated in this
    # file (presumably Angstrom) — confirm
    cutoff_lower: 0.0
    cutoff_upper: 10.0
    max_z: 100
    node_attr_dim: 10
    edge_attr_dim: 1
    # attention settings
    attn_activation: silu
    num_heads: 8
    distance_influence: both
    reduce_op: sum
    qk_norm: true
    clip_during_norm: true
    so3_equivariant: false
    output_layer_norm: true
    parity_switch: post_hoc

    # flow matching specific
    sigma: 0.1
    prior_type: gaussian

    # optimizer args
    # floats are written with a decimal point (1.e-4): YAML 1.1 loaders such
    # as PyYAML load a bare `1e-4` as a string rather than a float
    optimizer_type: AdamW
    lr: 1.e-4
    weight_decay: 1.e-8

    # lr scheduler args
    lr_scheduler_type: CosineAnnealingWarmupRestarts
    first_cycle_steps: 10000
    cycle_mult: 1.0
    # max_lr is set to the same value as the optimizer `lr` above
    max_lr: 1.e-4
    min_lr: 1.e-6
    warmup_steps: 500
    gamma: 0.95
    last_epoch: -1
    # same metric name that the ModelCheckpoint callback monitors
    lr_scheduler_monitor: val/loss
    # scheduler interval is `step`, with frequency 1
    lr_scheduler_interval: step
    lr_scheduler_frequency: 1

# callbacks: each list entry names a callback and gives its arguments
callbacks:
    # checkpointing keyed on val/loss with mode `min`; save_top_k is 3 and
    # save_last is enabled
    -   callback: ModelCheckpoint
        callback_args:
            dirpath: './checkpoint'
            monitor: val/loss
            mode: min
            save_last: true
            every_n_epochs: 1
            save_top_k: 3

    # learning-rate logging
    -   callback: LearningRateMonitor
        callback_args:
            log_momentum: false
            # NOTE(review): null leaves the interval unset; presumably the
            # callback then follows the lr scheduler's own interval — confirm
            logging_interval: null

# trainer: trainer name plus its arguments
trainer: Trainer
trainer_args:
    max_epochs: 250
    devices: 1
    enable_progress_bar: true
    # NOTE(review): `auto` for both; presumably strategy and hardware are
    # chosen by the trainer at runtime — confirm
    strategy: auto
    accelerator: auto

# eval config
eval_args:
    # differs from the training batch size (128) under dataloader_args
    batch_size: 32
    # sampler settings used at evaluation time
    sampler_args:
        sampler_type: sde
        n_timesteps: 20
        s_churn: 1.0
        # NOTE(review): t_min/t_max sit just inside 0 and 1; presumably they
        # bound the sampling time interval away from its endpoints — confirm
        t_min: 0.0001
        t_max: 0.9999
        std: 1.0
