---
# Weights & Biases run settings
wandb:
  project: 'Large-Physics-Foundation-Model'
  entity: 'xxxxxxxx'  # looks like a placeholder; set to your W&B entity
  id: 'final-training'
  tags: [GPT-XL, AllDatasets, MainRun]
  log_model: 'gradients'  # one of "all", "parameters", "gradients" or null
  notes: 'Final training run with GPT-XL on all datasets'

logging:
  log_dir: '/results'
  log_file: null  # e.g. log.txt
  log_level: 'INFO'
  # Subdirectory name used for each cycle; should end with an underscore.
  subdir_name: 'val_'

model:
  architecture: gphyt  # one of: gphyt, unet, fno
  transformer:
    # Number of input fields; usually 5: pressure, density, temperature, vel_x, vel_y.
    input_channels: 5
    model_size: GPT_XL  # GPT_S, GPT_M, GPT_L or GPT_XL
    att_mode: full  # attention mode: full or full_causal
    dropout: 0.0  # dropout rate
    pos_enc_mode: absolute  # positional encoding mode, usually rope or absolute
    patch_size: [1, 16, 16]  # usually 1x16x16
    stochastic_depth_rate: 0.0  # stochastic depth rate
    use_derivatives: true  # whether the model uses derivatives
    integrator: Euler  # Euler, RK4, Heun or null
  tokenizer:
    tokenizer_mode: linear  # usually linear or conv_net
    detokenizer_mode: linear  # usually linear or conv_net
    tokenizer_overlap: 0  # pixel overlap between patches (linear tokenizer only)
    detokenizer_overlap: 0  # pixel overlap between patches (linear detokenizer only)

  # Alternative U-Net settings, kept for reference:
  # architecture: unet  # gphyt or unet or fno
  # model_size: UNet_M  # UNet_S or UNet_M
  # n_time_steps: 4  # number of input time steps
  # integrate: false  # whether to integrate the output with the last input time step

training:
  # Numbers below are written as plain integers and the learning rate as a
  # dotted float on purpose: YAML 1.1 loaders (e.g. PyYAML safe_load) only
  # resolve floats that contain a '.' and a signed exponent, so forms like
  # 1e3 or 1e-4 would otherwise load as strings.
  compile: true
  tf32: true  # use TF32 for faster matrix multiplications
  amp: true  # use automatic mixed precision training
  # Memory budget as a fraction of total memory (presumably 1 = 100%);
  # values below 1 enable gradient checkpointing.
  mem_budget: 1
  seed: 42
  batch_size: 64  # batch size per GPU
  batches: 1000000  # number of batches to use for training
  checkpoint_every_batches: 1000  # number of batches between checkpoints

  # ------------------------- Validation -------------------------
  val_every_batches: 20000  # number of batches between validation runs
  # Applied per dataset, and also per dt if separate_dt is true.
  val_frac_samples: 1  # fraction of val samples to validate on (1 = all)

  num_workers: 16  # dataloader workers per GPU
  prefetch_factor: 2  # number of batches prefetched by each worker
  grad_clip: 1.0  # gradient clipping value; null disables clipping
  optimizer:
    name: AdamW  # Adam or AdamW
    learning_rate: 1.0e-4
    weight_decay: 0.01
    betas: [0.9, 0.999]
  criterion: MSE  # MSE or MAE

  lr_scheduler:
    first_stage:
      name: LinearLR  # linear warmup scheduler
      start_factor: 0.001
      end_factor: 1.0
      num_updates: 5000  # number of updates for the linear warmup
    second_stage:
      name: CosineAnnealingLR  # cosine annealing scheduler
      num_updates: -1  # batches for cosine annealing; -1 = all remaining batches
      end_factor: 0.01  # minimum LR as a fraction of the initial LR

data:
  use_normalization: true  # whether to normalize the data
  max_samples_per_ds: null  # limit on samples per epoch; null for no limit
  # Take every dt_stride-th timestep; if a list, the stride is randomized
  # between the given values.
  dt_stride: [1, 8]
  n_steps_input: 4  # number of input timesteps
  n_steps_output: 1  # number of target timesteps
  out_shape: [256, 128]  # output shape (usually 256 x 128)
  flip_x: 0.5  # probability of flipping the data along the x-axis
  flip_y: 0.5  # probability of flipping the data along the y-axis

  data_dir: data/datasets
  datasets:
    - cylinder_sym_flow_water
    - cylinder_pipe_flow_water
    - object_periodic_flow_water
    - object_sym_flow_water
    - object_sym_flow_air

    - heated_object_pipe_flow_air
    - cooled_object_pipe_flow_air
    - rayleigh_benard_obstacle

    - twophase_flow

    - rayleigh_benard
    - shear_flow
    - euler_multi_quadrants_periodicBC
    - acoustic_scattering_inclusions
