# Base configuration. Other presets in this file pull it in through the YAML
# merge key (`<<: *basic_config`) and override individual keys.
basic_config: &basic_config
  # ---- Run settings ----
  log_to_wandb: true  # Use wandb integration
  log_to_screen: true  # Log progress to screen
  save_checkpoint: true  # Save checkpoints
  # Save every N epochs - also saves "best" according to val loss
  checkpoint_save_interval: 10
  debug_grad: true  # Compute gradient/step sizes/etc. for debugging
  # Debugging setting - sets num workers to zero and activates syncs
  true_time: false
  # Generally pulling 8 CPUs per process, so using 6 for the data loader -
  # not sure this is the best ratio
  num_data_workers: 6
  # Automatic mixed precision - blows up with low variance fields right now
  enable_amp: false
  compile: false  # Compile model - does not currently work
  # Gradient checkpointing - slow, but lower memory
  gradient_checkpointing: false
  exp_dir: '~/MPP'  # Output path
  log_interval: 1  # How often to log - probably not actually implemented
  pretrained: false  # Whether to load a pretrained model

  # ---- wandb settings ----
  project: 'mpp'
  group: 'debugging'
  entity: 'kim-anonymous'

  # ---- Training settings ----
  drop_path: 0.1
  batch_size: 4
  max_epochs: 500
  scheduler_epochs: -1
  epoch_size: 2000  # Artificial epoch size
  # Activate hook that scales block gradients to norm 1
  rescale_gradients: false
  # Options: adam, adan (more may be added) - adan did better on HP sweep
  optimizer: 'adan'
  scheduler: 'cosine'  # Only cosine implemented
  warmup_steps: 1000  # Warmup when not using DAdapt
  learning_rate: -1  # -1 means use DAdapt
  # Written with an explicit mantissa dot: YAML 1.1 loaders such as PyYAML
  # resolve a bare `1e-3` as a string rather than a float.
  weight_decay: 1.0e-3
  # Number of state variables across the datasets - can be larger than the
  # real number; the extra slots just go unused
  n_states: 12
  # Should be sorted.
  # NOTE(review): as written this list is not sorted, repeats every name
  # twice, and has 8 entries against n_states: 12 - confirm this is intended.
  state_names: ['Pressure', 'Vx', 'Vy', 'Density', 'Vx', 'Vy', 'Density', 'Pressure']
  dt: 1  # Striding of data - values > 1 not currently implemented
  n_steps: 16  # Length of history to include in input
  # If false and n_steps > dataset steps, use dataset steps.
  # Otherwise, raise an exception.
  enforce_max_steps: false
  # Real batch size is accum_grad * batch_size;
  # real steps per "epoch" is epoch_size / accum_grad
  accum_grad: 5

  # ---- Model settings ----
  model_type: 'xavit'  # Only option so far
  # Which type of block to use - if axial, the next two fields must be set
  # to define the axial ops
  block_type: 'axial'
  time_type: 'attention'  # Conditional on block type
  space_type: 'axial_attention'  # Conditional on block type
  tie_fields: false  # Whether to use 1 embedding per field per data
  # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L
  embed_dim: 192
  # Number of heads for attention - 3/6/12/16 for Ti/S/B/L
  num_heads: 3
  # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L
  processor_blocks: 12
  patch_size: [16, 16]  # Actually currently hardcoded at 16
  bias_type: 'rel'  # Options: rel, continuous, none

  # ---- Data settings ----
  train_val_test: [0.8, 0.1, 0.1]
  augmentation: false  # Augmentation not implemented
  # Prepopulate the field metadata dictionary from dictionary in datasets
  use_all_fields: true
  tie_batches: false  # Force everything in a batch to come from one dataset
  extended_names: false  # Extended names - not currently implemented
  embedding_offset: 0  # Use when adding extra finetuning fields
  # Each entry is a three-element list: [path, dataset name, string].
  # NOTE(review): the meaning of the third element is not visible here -
  # confirm against the data loader.
  train_data_paths:
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/shallow-water', 'swe', '']
    # - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/NS_incom', 'incompNS', '']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Rand', 'compNS', '128']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Rand', 'compNS', '512']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Turb', 'compNS', '']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/diffusion-reaction', 'diffre2d', '']
  valid_data_paths:
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/shallow-water', 'swe', '']
    # - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/NS_incom', 'incompNS', '']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Rand', 'compNS', '128']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Rand', 'compNS', '512']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Turb', 'compNS', '']
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/diffusion-reaction', 'diffre2d', '']
  # List of datasets to append to the input/output projections for finetuning
  append_datasets: []


# Fine-tuning preset: merges in basic_config, turns `pretrained` on and names
# the checkpoint to start from. Keys listed here override the merged ones.
finetune: &finetune
  <<: *basic_config
  pretrained: true
  pretrained_ckpt_path: '/B16-noNS/training_checkpoints/ckpt.tar'
  group: 'debugging'
  max_epochs: 500
  accum_grad: 1
  train_val_test: [0.8, 0.1, 0.1]
  train_data_paths:
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Turb', 'compNS', 'M1.0']
  # These are the same for all configs - the split is taken according to
  # train_val_test
  valid_data_paths:
    - ['/mnt/home/anonymous/cvit/PDEDatasets/2D/CFD/2D_Train_Turb', 'compNS', 'M1.0']
  # Number of fields in original model - finetuning fields start after this
  embedding_offset: 0
  freeze_middle: false  # Whether to freeze the middle layers of the model
  freeze_processor: false
  # List of datasets to append to the input/output projections for finetuning
  append_datasets: []
  

# Variant of finetune with freeze_middle switched on; freeze_processor is
# left off.
frozen: &frozen
  <<: *finetune
  freeze_processor: false
  freeze_middle: true  # Freeze the middle layers of the model

# Variant of finetune with both freeze_middle and freeze_processor switched on.
# NOTE(review): despite the name, this enables more freeze flags than the
# `frozen` preset does - confirm that is intended.
less_frozen: &less_frozen
  <<: *finetune
  freeze_processor: true
  freeze_middle: true  # Freeze the middle layers of the model