# Model: Flower, instantiated via its _target_ path.
model:
  _target_: warpspeed.models.flower.Flower
  # Filled in automatically from the dataset metadata (do not set here):
  #   dim_in, dim_out, spatial_resolution, n_spatial_dims
  lifting_dim: 160
  n_levels: 4
  # NOTE(review): lifting_dim (160) is a multiple of both num_heads and
  # groups (40); presumably Flower requires this divisibility — confirm.
  num_heads: 40
  groups: 40
  dropout_rate: 0.0

# Optimizer: AdamW with decoupled weight decay.
optimizer:
  _target_: torch.optim.AdamW
  lr: 1.0e-3
  weight_decay: 0.01

# Learning-rate scheduler (linear warmup, then cosine annealing).
lr_scheduler:
  _target_: the_well.benchmark.optim.schedulers.LinearWarmupCosineAnnealingLR
  warmup_epochs: 5
  # Supplied by the train script rather than this file:
  #   - optimizer and max_epochs
  #   - warmup_start_lr and eta_min, both set to optimizer.lr * 0.1

# Per-dataset batch size, keyed by dataset name.
# Values were tuned to fit GPU memory; adjust them for your hardware.
# Inline comments give each dataset's spatial resolution.
batch_size_map:
  # 2D datasets
  acoustic_scattering_maze: 111           # 256x256
  acoustic_scattering_discontinuous: 111  # 256x256
  acoustic_scattering_inclusions: 111     # 256x256
  active_matter: 111                      # 256x256
  euler_multi_quadrants_periodicBC: 28    # 512x512
  gray_scott_reaction_diffusion: 445      # 128x128
  helmholtz_staircase: 27                 # 1024x256
  pdebench-2D_DarcyFlow: 445              # 128x128
  pdebench-diffusion_reaction: 445        # 128x128
  pdebench-shallow_water: 445             # 128x128
  planetswe: 54                           # 256x512
  rayleigh_benard: 104                    # 512x128
  rayleigh_benard_uniform: 104            # 512x128
  shear_flow: 52                          # 256x512
  turbulent_radiative_layer_2D: 150       # 128x384
  viscoelastic_instability: 28            # 512x512
  viscoelastic_instability_fixed: 28      # 512x512
  wavebench-helmholtz_anisotropic: 445    # 128x128

  # 3D datasets
  convective_envelope_rsg: 1              # 256x128x256
  MHD_64: 32                              # 64x64x64
  post_neutron_star_merger: 5             # 192x128x66
  supernova_explosion_64: 32              # 64x64x64
  supernova_explosion_128: 4              # uses ~86% of GPU memory on an H200
  turbulence_gravity_cooling: 32          # 64x64x64
  rayleigh_taylor_instability: 4          # 128x128x128
  turbulent_radiative_layer_3D: 2         # 128x128x256
