# Hydra defaults list: picks one config from each group (dataset, surrogate,
# optim, server). Hydra composes those with the values in this file.
defaults:
  - dataset: rayleigh_benard
  - surrogate: vit_large_with_noise
  - optim: adamw
  - server: local

# `???` is the OmegaConf/Hydra mandatory-value marker: ae_run and
# surrogate_run must be supplied at launch (e.g. as command-line overrides).
ae_run: ???
surrogate_run: ???
# 16 is also used; keep train.batch_size in sync with the member count.
ensemble_size: 4
noise_emb_features: 32
# Canonical lowercase booleans (yamllint `truthy`); same values as True/False.
load_surrogate: true
load_ae: true
finetune: false

staged_training:
  # LR for the common (pretrained) parameters.
  # Written as 1.0e-4 rather than 1e-4: YAML 1.1 loaders such as plain
  # PyYAML only resolve floats that contain a '.', and would otherwise
  # yield the string "1e-4".
  threshold_lr: 1.0e-4

  # Scheduler settings for common (pretrained) parameters.
  # NOTE(review): warmup unit (epochs vs. steps) is not visible here — confirm.
  common_scheduler: "cosine"
  common_warmup: 10

  # Scheduler settings for new parameters
  new_scheduler: "cosine"
  new_warmup: 10

  # Optional: override optimizer settings for common parameters
  common_params_kwargs:
    weight_decay: 0.0
    betas: [0.9, 0.99]

  # Optional: override optimizer settings for new parameters
  new_params_kwargs:
    weight_decay: 0.0
    betas: [0.9, 0.99]

# NOTE(review): the meaning of length/stride and of the context parameters
# (lmbda, rho, atleast) is defined by the consuming code. Document them here
# once confirmed.
trajectory:
  length: 5
  stride: 1
  context:
    lmbda: 1.0
    rho: 0.66
    atleast: 1

val_eval:
  enabled: true
  num_val_indices: 20
  num_test_indices: 1000
  start: 0
  context: 1
  overlap: 1
  samples: 16
  interval: 10
  record: 3  # number of samples to render videos for
  # NOTE(review): this list includes 8 and 16, but the top-level
  # ensemble_size is 4. Confirm how the consumer treats the larger sizes.
  ensemble_sizes_to_save: [1, 2, 4, 8, 16]
  sampling:
    algorithm: "ab"
    steps: 16
  seed: 0

train:
  epochs: 100
  epoch_size: 16384
  # Depends on ensemble size: 32 for 4 members, 64 for 2 members.
  batch_size: 32
  # Per dataset: 1 for RB (rayleigh_benard), 2 for shear, 4 for euler.
  accumulation: 1

# Validation counterparts of train.epoch_size / train.batch_size.
valid:
  epoch_size: 4096
  batch_size: 64

# NOTE(review): this presumably initialises from an existing run. Confirm
# against the consumer that run: null disables this, that `target` selects
# the entry to load, and that `strict` controls exact key matching.
fork:
  run: null
  target: "state"
  strict: true

# Job resource request.
# NOTE(review): `time` looks like SLURM's D-HH:MM:SS format (7 days). Confirm
# against the selected server config. It is quoted so it stays a string.
compute:
  nodes: 1
  cpus_per_gpu: 8
  gpus: 1
  ram: "256GB"
  time: "7-00:00:00"

# NOTE(review): entity: null presumably defers to the default W&B entity — confirm.
wandb:
  entity: null
