
basic_config: &basic_config
  # Run settings
  log_to_wandb: !!bool False # Use wandb integration
  log_to_screen: !!bool True # Log progress to screen.
  save_checkpoint: !!bool True # Save checkpoints
  checkpoint_save_interval: 100 # Save every N epochs - a "best" checkpoint (by val loss) is also saved
  debug_grad: !!bool True # Compute gradients/step sizes/etc. for debugging
  true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs
  # TODO: data-loader worker count. Set to 0 here; an earlier note suggested 6 workers when
  # ~8 CPUs are available per process - the best ratio has not been established.
  num_data_workers: 0
  enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now
  compile: !!bool False # Compile model - Does not currently work
  gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory
  exp_dir: './logs/exp' # Output path
  log_interval: 1 # How often to log - Don't think this is actually implemented
  pretrained: !!bool False # Whether to load a pretrained model
  # vmae_pretrained: False # Whether to load a pretrained model
  checkpoint_path: '' # Checkpoint file path; '' = none. NOTE(review): confirm how the loader uses this
  # wandb settings
  project: 'hyfluid'
  group: 'transformer'
  entity: ' ' # NOTE(review): single-space value looks like a placeholder - set before enabling wandb
  # Training settings ################################
  drop_path: 0.1 # NOTE(review): drop_path_rate is also set below - confirm which key the model reads
  batch_size: 1
  accum_grad: 4 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum
  scheduler_epochs: -1
  train_subsample: 1. # TODO:
  # max_epochs: 100 # TODO:
  epoch_size: 20 # TODO: Artificial epoch size
  rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1
  optimizer: 'sgd' # adam, adan, whatever else i end up adding - adan did better on HP sweep
  scheduler: 'cosine' # Only cosine implemented
  warmup_steps: 50 # Warmup when not using DAdapt. Cosine scheduler does not use warmup
  ######################################################
  learning_rate: 0.01 # -1 means use DAdapt
  # Written with a decimal point on purpose: YAML 1.1 loaders (e.g. PyYAML) read `1e-3`
  # as the string '1e-3', not a float; `1.0e-3` is a float under both YAML 1.1 and 1.2.
  weight_decay: 1.0e-3
  # state_names: ['Pressure', 'Vx', 'Vy', 'Density',  'Vx', 'Vy', 'Density', 'Pressure'] # TODO: Should be sorted
  state_names: ['Pressure', 'Vx', 'Vy'] # TODO: Should be sorted
  dt: 1 # Striding of data - Not currently implemented > 1
  n_steps: 10 # TODO: Length of history to include in input
  enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception.
  rollout_train: 5
  rollout_test: 100 # TODO: total_num_frame - scalarflow_frame_num
  # Model settings ####################################
  model_type: 'vmae' # vit_small_patch16_224
  tubelet_size: 2
  encoder_embed_dim: 192 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L
  decoder_embed_dim: 96
  encoder_num_heads: 12 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L
  # processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L
  decoder_num_heads: 6
  encoder_depth: 12
  decoder_depth: 6
  in_chans: 3 # number of input channels, i.e. number of physical variables to learn and predict

  # NOTE(review): an older comment claimed the patch size is hardcoded at 16, while the
  # decoder_num_classes note below assumes 4 - confirm which value the model actually uses.
  patch_size: 4
  # decoder_num_classes: params.in_chans* params.tubelet_size * params.patch_size ** 2 # added to build vmae 3*2*4*4
  window_size: [8, 7, 7] # swin block window size
  input_size: 224 # TODO: will resize input to vmae to this shape
  ######################################################
  drop_path_rate: 0.1
  init_scale: 0.001
  # --num_frames 16 \
  # --opt adamw \
  # --lr 5e-4 \
  # --opt_betas 0.9 0.999 \
  # --weight_decay 0.05 \
  # --dist_eval \
  # --test_num_segment 2 \
  # --test_num_crop 3 \
  # block_type: 'axial' # Which type of block to use - if axial, next two fields must be set to define axial ops
  # time_type: 'attention' # Conditional on block type
  # space_type: 'axial_attention' # Conditional on block type
  tie_fields: !!bool False # Whether to use 1 embedding per field per data
  bias_type: 'rel' # Options rel, continuous, none
  # Data settings
  train_val_test: [0.8, 0.1, 0.1] # TODO: use all data
  augmentation: !!bool False # Augmentation not implemented
  use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets
  tie_batches: !!bool False # Force everything in batch to come from one dset
  extended_names: !!bool False # Whether to use extended names - not currently implemented
  embedding_offset: 0 # Use when adding extra finetuning fields
  # paths, dataset_type, include_string, pde_param
  train_data_paths: []
  valid_data_paths: []
  ood_train_data_paths: []
  ood_valid_data_paths: []
  # scalarflow_frame_num: 20 # TODO:
  temporal_cutoff: -1 # TODO: -1: no cutoff; >0: cutoff over temporal steps
  temporal_cutoff_test: -1 # TODO: -1: no cutoff; >0: cutoff over temporal steps
  ######################################################
  # ssl options: 'none'; 'interp' (temporal interpolation of inp to fine-tune);
  # 'gt' (use target ground truth to fine-tune - for comparison purposes only)
  ssl: 'gt'
  mask_ratio: 0.
  dropout_p: 0.005
  st_repeat: 10


finetune: &finetune
  # Fine-tuning profile: inherits every basic_config setting and overrides the keys below.
  <<: *basic_config
  max_epochs: 500
  train_val_test: [0.8, 0.1, 0.1]
  accum_grad: 1
  pretrained: true
  group: 'debugging'
  pretrained_ckpt_path: 'training_checkpoints/ckpt.tar'
  train_data_paths: []
  valid_data_paths: []
  embedding_offset: 0 # Field count of the original model; fine-tuning fields are placed after it
  freeze_middle: false # Freeze the middle layers of the model when true
  freeze_processor: false
  append_datasets: [] # Datasets to append to the input/output projections during fine-tuning
  

frozen: &frozen
  # Fine-tuning variant: only freeze_middle is enabled; freeze_processor stays off.
  <<: *finetune
  freeze_middle: true
  freeze_processor: false

less_frozen: &less_frozen
  # Fine-tuning variant: both freeze_middle and freeze_processor are enabled.
  # NOTE(review): this turns on more freeze flags than `frozen`, so the name
  # "less_frozen" may be inverted - confirm against the training code.
  <<: *finetune
  freeze_middle: true
  freeze_processor: true

