# Hydra defaults list: config groups composed into this config.
# `_self_` is listed last so values defined in this file override anything
# coming from the groups above. Hydra 1.1+ appends it there implicitly when
# omitted (and warns about it for primary configs); listing it makes the
# composition order explicit.
defaults:
  - dataset: rayleigh_benard
  - denoiser: vit_large
  - denoiser/loss: uniform
  - denoiser/schedule: log_logit
  - optim: adamw
  - server: local
  - _self_

# Pretrained runs this config builds on. `???` is the Hydra/OmegaConf marker
# for a mandatory value: it must be supplied (e.g. as a command-line
# override) before the key is read.
# NOTE(review): load_ae / load_surrogate presumably toggle loading the
# corresponding run — confirm in the training code.
ae_run: ???
load_ae: true
load_surrogate: true
surrogate_run: ???

# Optimizer/scheduler settings for two parameter groups: "common" parameters
# (shared with a pretrained model) and "new" parameters.
# NOTE(review): no LR for the new parameters is set here; presumably it comes
# from the `optim` group — confirm.
staged_training:
  # LR for COMMON/pretrained parameters. Written as 1.0e-4 rather than 1e-4
  # because YAML 1.1 loaders (e.g. plain PyYAML) parse `1e-4` as a string.
  threshold_lr: 1.0e-4

  # Scheduler settings for common (pretrained) parameters
  common_scheduler: "cosine"
  common_warmup: 10

  # Scheduler settings for new parameters
  new_scheduler: "cosine"
  new_warmup: 10

  # Optional: override optimizer settings for common parameters
  common_params_kwargs:
    weight_decay: 0.0
    betas: [0.9, 0.99]

  # Optional: override optimizer settings for new parameters
  new_params_kwargs:
    weight_decay: 0.0
    betas: [0.9, 0.99]

# Trajectory sampling settings.
trajectory:
  length: 5
  stride: 1
  # NOTE(review): lmbda / rho / atleast are interpreted by the data/training
  # code; presumably they control how many context frames are drawn (with a
  # minimum of `atleast`) — confirm before changing.
  context: {lmbda: 1.0, rho: 0.66, atleast: 1}

# Evaluation on validation/test indices.
val_eval:
  enabled: true
  num_val_indices: 20
  num_test_indices: 1000
  start: 0
  context: 1
  overlap: 1
  samples: 16
  interval: 50
  record: 3  # number of samples to render as videos
  # NOTE(review): the largest ensemble size equals `samples` (16); presumably
  # sizes above `samples` are not meaningful — confirm.
  ensemble_sizes_to_save: [1, 2, 4, 8, 16]
  sampling:
    algorithm: "ab"
    steps: 16
  seed: 0

# Per-dataset presets below keep batch_size * accumulation = 128 in every
# case (128*1, 64*2, 32*4) — assuming `accumulation` is the number of
# gradient-accumulation steps, the effective batch size is constant.
train:
  epochs: 100
  epoch_size: 65536
  batch_size: 128  # 128 for RB, 64 for shear, 32 for euler
  accumulation: 1  # 1 for RB, 2 for shear, 4 for euler

# Validation loop sizing.
valid:
  epoch_size: 4096  # samples per validation epoch (presumably; confirm)
  batch_size: 256

# Optional fork from an earlier run.
# NOTE(review): presumably `run: null` disables forking, `target` selects what
# is restored and `strict` requires an exact match — confirm in the loader.
fork:
  run: null
  target: 'state'
  strict: true

# Requested job resources.
# NOTE(review): `time` looks like a Slurm-style D-HH:MM:SS limit (7 days);
# keep it quoted so it always stays a string — confirm the consumer's format.
compute:
  nodes: 1
  cpus_per_gpu: 8
  gpus: 1
  ram: '256GB'
  time: '7-00:00:00'

# Experiment-tracking settings (presumably Weights & Biases);
# `entity: null` leaves the entity unset.
wandb:
  entity: null
