# Base configuration. The `finetune` section (and, through it, `frozen` and
# `less_frozen`) pulls these keys in with a `<<` merge and overrides a few.
basic_config: &basic_config
  # Run settings
  log_to_wandb: true  # Use wandb integration
  log_to_screen: true  # Log progress to screen
  save_checkpoint: true  # Save checkpoints
  # Save every N epochs - also saves "best" according to val loss
  checkpoint_save_interval: 100
  debug_grad: true  # Compute gradient/step sizes/etc. for debugging
  # Debugging setting - sets num workers to zero and activates syncs
  true_time: false
  # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio
  num_data_workers: 6
  # Use automatic mixed precision - blows up with low variance fields right now
  enable_amp: false
  compile: false  # Compile model - does not currently work
  # Whether to use gradient checkpointing - slow, but lower memory
  gradient_checkpointing: false
  exp_dir: 'exp_dir'  # Output path
  log_interval: 1  # How often to log - don't think this is actually implemented
  pretrained: false  # Whether to load a pretrained model
  # Whether to load a pretrained model
  # NOTE(review): same description as `pretrained`; how the two flags differ
  # is not visible in this file - confirm against the loader.
  vmae_pretrained: false
  # wandb settings
  project: 'foundationmodels'
  group: 'multiple_physics_pretraining'
  entity: 'entity'
  # Training settings ################################
  # NOTE(review): `drop_path_rate` below carries the same value; which of the
  # two keys the model actually reads is not visible here - confirm.
  drop_path: 0.1
  batch_size: 4
  # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum
  accum_grad: 2
  scheduler_epochs: -1
  pretrain_train: [0.9, 0.1]  # TODO: document
  train_subsample: 1.0  # TODO: document
  max_epochs: 500
  epoch_size: 15  # TODO: Artificial epoch size
  # Activate hook that scales block gradients to norm 1
  rescale_gradients: false
  # adam, adan, whatever else i end up adding - adan did better on HP sweep
  optimizer: 'adan'
  scheduler: 'cosine'  # Only cosine implemented
  warmup_steps: 50  # TODO: Warmup when not using DAdapt
  ######################################################
  learning_rate: -1  # -1 means use DAdapt
  # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML) only
  # resolve exponent notation as a float when the dot is present; a bare
  # `1e-3` would be loaded as the string '1e-3'.
  weight_decay: 1.0e-3
  # TODO: Number of state variables across the datasets - can be larger than
  # the real number and things will just go unused
  n_states: 12
  # TODO: Should be sorted
  # NOTE(review): 'Pressure', 'Vx', 'Vy' and 'Density' each appear twice, and
  # 8 names are listed against n_states: 12 - confirm this is intended.
  state_names: ['Pressure', 'Vx', 'Vy', 'Density', 'Vx', 'Vy', 'Density', 'Pressure']
  dt: 1  # TODO: Striding of data - not currently implemented > 1
  n_steps: 16  # TODO: Length of history to include in input
  # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception.
  enforce_max_steps: false
  # Model settings ####################################
  model_type: 'vmae'  # vit_small_patch16_224
  # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L
  encoder_embed_dim: 384
  decoder_embed_dim: 192
  # Number of heads for attention - 3/6/12/16 for Ti/S/B/L
  encoder_num_heads: 6
  decoder_num_heads: 3
  decoder_depth: 4
  decoder_num_classes: 1536
  tubelet_size: 2
  ######################################################
  input_size: 512
  drop_path_rate: 0.1
  init_scale: 0.001
  # The lines below are commented-out command-line style flags and disabled
  # keys kept for reference; none of them are parsed from this file.
  # --num_frames 16 \
  # --opt adamw \
  # --lr 5e-4 \
  # --opt_betas 0.9 0.999 \
  # --weight_decay 0.05 \
  # --dist_eval \
  # --test_num_segment 2 \
  # --test_num_crop 3 \
  # block_type: 'axial' # Which type of block to use - if axial, next two fields must be set to define axial ops
  # time_type: 'attention' # Conditional on block type
  # space_type: 'axial_attention' # Conditional on block type
  tie_fields: false  # Whether to use 1 embedding per field per data
  # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L
  processor_blocks: 12
  patch_size: 16  # Actually currently hardcoded at 16
  bias_type: 'rel'  # Options: rel, continuous, none
  # Data settings
  train_val_test: [0.8, 0.1, 0.1]
  augmentation: false  # Augmentation not implemented
  # Prepopulate the field metadata dictionary from dictionary in datasets
  use_all_fields: true
  tie_batches: false  # Force everything in batch to come from one dset
  # Whether to use extended names - not currently implemented
  extended_names: false
  embedding_offset: 0  # Use when adding extra finetuning fields
  # Each entry is a three-element list; the third element is an empty string here.
  train_data_paths:
    - ['../data/PDEBench/2D/NS_incom', 'incompNS', '']
  valid_data_paths:
    - ['../data/PDEBench/2D/NS_incom', 'incompNS', '']

  mask_ratio: 0.1  # TODO: document

# Fine-tuning preset: merges every key from `basic_config`, then overrides
# or adds the keys listed below.
finetune: &finetune
  <<: *basic_config
  max_epochs: 500
  train_val_test: [0.8, 0.1, 0.1]
  accum_grad: 1
  pretrained: true
  group: 'debugging'
  pretrained_ckpt_path: '/B16-noNS/training_checkpoints/ckpt.tar'
  train_data_paths:
    - ['/PDEBench/2D/CFD/2D_Train_Turb', 'compNS', 'M1.0']
  # These are the same for all configs - uses split according to train_val_test
  valid_data_paths:
    - ['/PDEBench/2D/CFD/2D_Train_Turb', 'compNS', 'M1.0']
  # Number of fields in original model - FT fields start after this
  embedding_offset: 0
  freeze_middle: false  # Whether to freeze the middle layers of the model
  freeze_processor: false
  # List of datasets to append to the input/output projections for finetuning
  append_datasets: []
  

# Variant of `finetune` with freeze_middle switched on and freeze_processor
# left off; every other key comes from the merge.
frozen: &frozen
  <<: *finetune
  freeze_middle: true  # Whether to freeze the middle layers of the model
  freeze_processor: false

# Variant of `finetune` with both freeze flags switched on; every other key
# comes from the merge.
# NOTE(review): despite the name, this preset enables a superset of the freeze
# flags set by `frozen` (freeze_middle is true in both) - confirm that
# freeze_middle is meant to be true here.
less_frozen: &less_frozen
  <<: *finetune
  freeze_middle: true  # Whether to freeze the middle layers of the model
  freeze_processor: true
