# UNetConvNext model configuration from The Well

model:
  _target_: warpspeed.models.unet_convnext.UNetConvNext
  # Set automatically from dataset metadata, so intentionally omitted here:
  #   input_dim, output_dim, resolution, n_spatial_dims
  init_features: 42              # initial feature dimension
  blocks_per_stage: 2            # blocks in each stage
  stages: 4                      # number of up/down blocks
  blocks_at_neck: 1              # blocks at the bottleneck (neck)
  gradient_checkpointing: false
  # The spatial filter size (7) is hardcoded in the UNetConvNext Block class
  # and therefore cannot be set from this file.

# Optimizer configuration
optimizer:
  _target_: torch.optim.AdamW
  # Remaining keys are AdamW keyword arguments.
  lr: 1.0e-3
  weight_decay: 1.0e-2

# Learning rate scheduler
lr_scheduler:
  _target_: the_well.benchmark.optim.schedulers.LinearWarmupCosineAnnealingLR
  warmup_epochs: 5
  # Supplied by the train script rather than configured here:
  #   - optimizer
  #   - max_epochs
  #   - warmup_start_lr and eta_min (both set to optimizer.lr * 0.1)

# Per-dataset batch sizes, keyed by dataset name.
# Tune these values to the available GPU memory.
batch_size_map:
  # Inline comments give each dataset's spatial resolution.

  # 2D datasets
  acoustic_scattering_maze: 128           # 256x256
  acoustic_scattering_discontinuous: 128  # 256x256
  acoustic_scattering_inclusions: 128     # 256x256
  active_matter: 128                      # 256x256
  euler_multi_quadrants_periodic: 32      # 512x512
  gray_scott_reaction_diffusion: 580      # 128x128
  helmholtz_staircase: 32                 # 1024x256
  pdebench-2D_DarcyFlow: 580              # 128x128
  pdebench-diffusion_reaction: 580        # 128x128
  pdebench-shallow_water: 580             # 128x128
  planetswe: 64                           # 256x512
  rayleigh_benard: 128                    # 512x128
  rayleigh_benard_uniform: 128            # 512x128
  shear_flow: 64                          # 256x512
  turbulent_radiative_layer_2D: 180       # 128x384
  viscoelastic_instability: 32            # 512x512
  viscoelastic_instability_fixed: 32      # 512x512
  wavebench-helmholtz_anisotropic: 580    # 128x128

  # 3D datasets
  convective_envelope_rsg: 1              # 256x128x256
  MHD_64: 52                              # 64x64x64
  post_neutron_star_merger: 9             # 192x128x66
  supernova_explosion_64: 52              # 64x64x64
  supernova_explosion_128: 6              # 128x128x128 (6 gives 91% on H200)
  turbulence_gravity_cooling: 52          # 64x64x64
  rayleigh_taylor_instability: 6          # 128x128x128
  turbulent_radiative_layer_3D: 3         # 128x128x256
