---
model:
    # NOTE(review): semantics of `correlation` are not visible in this file;
    # confirm against the consuming model code.
    correlation: 0.8
    # Presumably picks one of the architecture sub-sections below
    # (fno2d, unet2d, lns, dit) — TODO confirm.
    model_name: fno2d
    # Keep the decimal point and signed exponent: PyYAML's YAML 1.1 resolver
    # reads `1e-4` (no dot) as a string, not a float.
    lr: 1.0e-4
    # Canonical lowercase boolean (yamllint `truthy`). Presumably switches
    # training into the autoencoder latent space configured below — confirm.
    latent: true
    # fno2d: presumably a 2D Fourier Neural Operator backbone.
    fno2d:
        num_layers: 5
        hidden_channels: 64
        in_channels: 4
        out_channels: 4
        # Presumably the number of Fourier modes kept per spatial axis.
        # NOTE(review): when `latent` is enabled the grid is [64, 16]; a
        # length-16 axis has at most 16 FFT frequencies (9 for a real FFT),
        # fewer than 24 — confirm the backbone tolerates this.
        modes1: 24
        modes2: 24
        # Conditioning on 2 scalars (rayleigh number, prandtl number),
        # presumably embedded to cond_channels features.
        cond_dim: 2
        cond_channels: 64
    unet2d:
        in_channels: 4
        out_channels: 4
        hidden_channels: 64
        # Presumably one multiplier of hidden_channels per resolution level.
        ch_mults: [1, 2, 4]
        cond_channels: 64
        cond_dim: 2  # rayleigh number, prandtl number
        # Canonical lowercase booleans (yamllint `truthy`); parsed values
        # are unchanged.
        norm: true
        use_scale_shift_norm: false
    lns:
        in_dim: 4
        out_dim: 4
        # Model size: width, attention heads, number of blocks.
        dim: 1024
        num_heads: 16
        num_layers: 16
        # Presumably patchifies the [64, 16] latent grid into
        # (64 / 2) x (16 / 2) = 256 tokens — confirm.
        patch_size: [2, 2]
        input_size: [64, 16]
        cond_dim: 1024
        num_cond: 2  # rayleigh number, prandtl number
        # Canonical lowercase boolean (yamllint `truthy`).
        scale_by_sigma: false
    ddpm:
        noise_steps: 100
        # NOTE(review): meaning of `scale` is not visible here — confirm.
        scale: 400
        skip_percent: 0
        beta_start: 1.0e-4
        beta_end: 0.02
        schedule: linear
        # Latent grid: the full 512x128 field downsampled 8x per axis
        # (512 / 8 = 64, 128 / 8 = 16).
        grid_size: [64, 16]
    ddim:
        noise_steps: 100
        scale: 400
        beta_start: 1.0e-4
        beta_end: 0.02
        schedule: linear
        # Latent grid: the full 512x128 field downsampled 8x per axis
        # (512 / 8 = 64, 128 / 8 = 16).
        grid_size: [64, 16]
        # Presumably the number of sampling steps taken over the
        # noise_steps-long schedule at inference — confirm.
        num_ddim_steps: 10
    tsm:
        noise_steps: 100
        scale: 400
        beta_start: 1.0e-4
        beta_end: 0.02
        schedule: linear
        # Latent grid: the full 512x128 field downsampled 8x per axis
        # (512 / 8 = 64, 128 / 8 = 16).
        grid_size: [64, 16]
        skip_percent: 0.9
    edm:
        # Latent grid: the full 512x128 field downsampled 8x per axis
        # (512 / 8 = 64, 128 / 8 = 16).
        grid_size: [64, 16]
        num_steps: 10
    # NOTE(review): num_train_steps is num_refinement_steps + 1, as in the
    # interpolant section; presumably intentional (steps vs. grid points) —
    # confirm before changing either value.
    flow_matching:
        num_train_steps: 11
        num_refinement_steps: 10
    interpolant:
        # Presumably an explicit Euler sampler run for num_refinement_steps.
        integrator: euler
        num_refinement_steps: 10
        num_train_steps: 11
        # NOTE(review): semantics of sigma_coef not visible here — confirm.
        sigma_coef: 0.25
    dit:
        # NOTE(review): in_dim (8) is twice out_dim (4); presumably the noisy
        # latent is concatenated with a 4-channel conditioning latent — confirm.
        in_dim: 8
        out_dim: 4
        dim: 1024
        num_heads: 16
        num_layers: 12
        input_size: [64, 16]
        patch_size: [2, 2]
        num_cond: 2  # rayleigh number, prandtl number
        cond_dim: 1024
    autoencoder:
        model:
            model_name: AE
            lr: 1.0e-4
            encoder:
                in_channels: 4
                latent_channels: 4
                width_list: [64, 128, 256, 512]  # 8x downsample
                block_type: ["ResBlock", "ResBlock", "ResBlock", "ResBlock"]
                depth_list: [4, 4, 4, 4]
                cond_dim: 2  # rayleigh number, prandtl number
            decoder:
                in_channels: 4
                latent_channels: 4
                width_list: [64, 128, 256, 512]  # 8x downsample
                block_type: ["ResBlock", "ResBlock", "ResBlock", "ResBlock"]
                depth_list: [4, 4, 4, 4]
                cond_dim: 2  # rayleigh number, prandtl number
        training:
            log_dir: logs/
            strategy: auto
        data:
            pde: rayleigh_benard
        # NOTE(review): presumably a multiplier applied to latents — confirm.
        scale_factor: 1.0
        # Placeholder: replace with the path to the pretrained autoencoder
        # checkpoint before training.
        checkpoint: /path/to/pretrained/autoencoder.ckpt

data:
    pde: rayleigh_benard
    # Placeholder paths: point these at the local copy of The Well dataset
    # and its rayleigh_benard statistics file.
    dataset:
        base_path: /path/to/data/the_well/
    normalizer:
        stat_path: /path/to/data/the_well/datasets/rayleigh_benard/stats.yaml
    batch_size: 8
    num_workers: 1

training:
    # Keys look like PyTorch Lightning Trainer arguments — confirm.
    accelerator: gpu
    devices: [0]
    strategy: auto
    seed: 42
    max_epochs: 40
    # Presumably gradients are accumulated over 4 batches of data.batch_size
    # (8) each, i.e. 32 samples per optimizer step on the single device.
    accumulate_grad_batches: 4
    check_val_every_n_epoch: 2
    log_every_n_steps: 100
    log_dir: logs/
    project: stochastic_interpolants
    wandb_mode: online
    # Presumably null means "start from scratch"; set a path to resume.
    checkpoint: null
