---
# Model hyperparameters: an encoder/decoder pair configured symmetrically.
model:
    model_name: AE
    # Written with a decimal point on purpose: YAML 1.1 loaders (e.g. PyYAML)
    # only resolve exponent floats like this when a "." is present.
    lr: 1.0e-4
    encoder:
        in_channels: 4
        latent_channels: 4
        # Presumably one entry per stage in width_list, block_type and depth_list
        # (all have length 4) — confirm against the model code.
        width_list: [64, 128, 256, 512]  # 8x downsample
        block_type: ["ResBlock", "ResBlock", "ResBlock", "ResBlock"]
        depth_list: [4, 4, 4, 4]
        cond_dim: 2  # Prandtl number, Rayleigh number
    decoder:
        # NOTE(review): latent_channels presumably must match encoder.latent_channels — confirm.
        in_channels: 4
        latent_channels: 4
        width_list: [64, 128, 256, 512]
        block_type: ["ResBlock", "ResBlock", "ResBlock", "ResBlock"]
        depth_list: [4, 4, 4, 4]
        cond_dim: 2  # same conditioning inputs as the encoder

data:
    dataset:
        # Placeholder path — replace with the local root of The Well data.
        base_path: /path/to/data/the_well/
    normalizer:
        # Statistics file read by the normalizer (placeholder path; presumably
        # lives under dataset.base_path — confirm).
        stat_path: /path/to/data/the_well/datasets/rayleigh_benard/stats.yaml
    pde: rayleigh_benard
    num_workers: 1
    batch_size: 10
    # Canonical lowercase boolean. `True` also loads as true, but lowercase is
    # the portable spelling across YAML 1.1/1.2 loaders (yamllint `truthy`).
    # NOTE(review): presumably selects autoencoder-style data loading — confirm.
    ae: true

training:
    # Reproducibility.
    seed: 42

    # Hardware / distribution: a single GPU device.
    devices: 1
    accelerator: 'gpu'
    strategy: 'auto'

    # Validation, logging cadence and run length.
    check_val_every_n_epoch: 1
    log_every_n_steps: 100
    max_epochs: 100

    # Experiment-tracking output (local log directory and W&B project/mode).
    log_dir: 'logs/'
    project: 'stochastic_interpolants'
    wandb_mode: 'online'

    # NOTE(review): presumably a checkpoint path to load/resume from; null = none — confirm.
    checkpoint: null