### base config ###
# Base configuration. Anchored as FULL_FIELD so that other configs in this file
# can pull in every key below with a merge key (`<<: *FULL_FIELD`) and override
# individual values.
#
# NOTE(review): several keys below use a plain `None`. YAML loads that as the
# string 'None', not as null -- presumably the code reading this file converts
# it; confirm before replacing it with `null`.
full_field: &FULL_FIELD
  loss: 'l2'
  # Explicit !!float: a plain `1E-3` (no decimal point) is resolved as a string
  # by YAML 1.1 loaders such as PyYAML.
  lr: !!float 1E-3
  scheduler: 'ReduceLROnPlateau'
  num_data_workers: 4
  dt: 1  # how many timesteps ahead the model will predict
  n_history: 0  # how many previous timesteps to consider
  prediction_type: 'iterative'
  prediction_length: 41  # applicable only if prediction_type == 'iterative'
  n_initial_conditions: 5  # applicable only if prediction_type == 'iterative'
  ics_type: "default"
  save_raw_forecasts: !!bool True
  save_channel: !!bool False
  masked_acc: !!bool False
  maskpath: None
  perturb: !!bool False
  add_grid: !!bool False
  N_grid_channels: 0
  gridtype: 'sinusoidal'  # options: 'sinusoidal' or 'linear'
  roll: !!bool False
  max_epochs: 50
  batch_size: 64

  # afno hyperparams
  num_blocks: 8
  nettype: 'afno'
  patch_size: 8
  width: 56
  modes: 32
  # options: default, residual
  target: 'default'
  in_channels: [0, 1]
  out_channels: [0, 1]  # must be same as in_channels if prediction_type == 'iterative'
  normalization: 'zscore'  # options: zscore (minmax not supported)

  # data, output and normalization-statistics locations
  train_data_path: '/pscratch/sd/j/jpathak/wind/train'
  valid_data_path: '/pscratch/sd/j/jpathak/wind/test'
  inf_data_path: '/pscratch/sd/j/jpathak/wind/out_of_sample'  # test set path for inference
  exp_dir: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind'
  time_means_path:   '/pscratch/sd/j/jpathak/wind/time_means.npy'
  global_means_path: '/pscratch/sd/j/jpathak/wind/global_means.npy'
  global_stds_path:  '/pscratch/sd/j/jpathak/wind/global_stds.npy'

  orography: !!bool False
  orography_path: None

  log_to_screen: !!bool True
  log_to_wandb: !!bool True
  save_checkpoint: !!bool True

  enable_nhwc: !!bool False
  optimizer_type: 'FusedAdam'
  crop_size_x: None
  crop_size_y: None

  two_step_training: !!bool False
  plot_animations: !!bool False

  add_noise: !!bool False
  noise_std: 0

# Backbone configuration: starts from FULL_FIELD and overrides the keys listed
# below. Anchored as `backbone` so later configs can merge it in turn.
afno_backbone: &backbone
  <<: *FULL_FIELD
  log_to_wandb: !!bool True
  # Explicit !!float: a plain `5E-4` (no decimal point) is resolved as a string
  # by YAML 1.1 loaders such as PyYAML.
  lr: !!float 5E-4
  batch_size: 64
  max_epochs: 150
  scheduler: 'CosineAnnealingLR'
  # Channel indices 0 through 71 (72 entries), wrapped for readability.
  in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
                48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
                60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71]
  # Alternative 20-channel output, currently disabled:
  # out_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
  out_channels: [0, 1, 2]
  orography: !!bool False
  # NOTE(review): plain `None` loads as the string 'None', not null --
  # presumably converted by the code reading this file; confirm.
  orography_path: None
  exp_dir: '/pscratch/sd/s/shas1693/results/era5_wind'
  train_data_path: '/pscratch/sd/s/shas1693/data/era5/train'
  valid_data_path: '/pscratch/sd/s/shas1693/data/era5/test'
  inf_data_path:   '/pscratch/sd/s/shas1693/data/era5/out_of_sample'
  time_means_path:   '/pscratch/sd/s/shas1693/data/era5/time_means.npy'
  global_means_path: '/pscratch/sd/s/shas1693/data/era5/global_means.npy'
  global_stds_path:  '/pscratch/sd/s/shas1693/data/era5/global_stds.npy'

# Same as the `backbone` config, except that orography is switched on and
# pointed at a static orography file.
afno_backbone_orography: &backbone_orography
  <<: *backbone
  orography_path: '/pscratch/sd/s/shas1693/data/era5/static/orography.h5'
  orography: true

# Variant of the `backbone` config that enables `pretrained` and
# `two_step_training` and supplies the checkpoint path to start from.
afno_backbone_finetune:
  <<: *backbone
  # Explicit !!float: a plain `1E-4` (no decimal point) is resolved as a string
  # by YAML 1.1 loaders such as PyYAML.
  lr: !!float 1E-4
  batch_size: 64
  log_to_wandb: !!bool True
  max_epochs: 50
  pretrained: !!bool True
  two_step_training: !!bool True
  pretrained_ckpt_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone/0/training_checkpoints/best_ckpt.tar'

# Perturbation config: starts from the `backbone` config, enables `perturb`
# and overrides the keys listed below.
perturbations:
  <<: *backbone
  # Explicit !!float: a plain `1E-4` (no decimal point) is resolved as a string
  # by YAML 1.1 loaders such as PyYAML.
  lr: !!float 1E-4
  batch_size: 64
  max_epochs: 50
  pretrained: !!bool True
  two_step_training: !!bool True
  pretrained_ckpt_path: '/pscratch/sd/j/jpathak/ERA5_expts_gtc/wind/afno_20ch_bs_64_lr5em4_blk_8_patch_8_cosine_sched/1/training_checkpoints/best_ckpt.tar'
  prediction_length: 24
  ics_type: "datetime"
  n_perturbations: 100
  # `!!bool` (two bangs) is the standard boolean tag; the single-bang `!bool`
  # is an application-local tag that loaders have no constructor for.
  save_channel: !!bool True
  save_idx: 4
  save_raw_forecasts: !!bool False
  date_strings: ["2018-01-01 00:00:00"]
  # NOTE(review): the next two values are a single space and a path ending in
  # a space -- they look like placeholders to be filled in; confirm.
  inference_file_tag: " "
  valid_data_path: "/pscratch/sd/j/jpathak/ "
  perturb: !!bool True
  n_level: 0.3

### PRECIP ###
# Precipitation config: merges the `backbone` config, then overrides the
# channels and training settings and adds the precip-specific keys below.
precip: &precip
  <<: *backbone

  # network selection and channels
  nettype: 'afno'
  nettype_wind: 'afno'
  in_channels: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
  out_channels: [0]

  # training settings
  lr: 2.5E-4
  max_epochs: 25
  batch_size: 64
  log_to_wandb: true

  # precip-specific paths and constants
  precip: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation'
  time_means_path_tp: '/pscratch/sd/p/pharring/ERA5/precip/total_precipitation/time_means.npy'
  model_wind_path: '/pscratch/sd/s/shas1693/results/era5_wind/afno_backbone_finetune/0/training_checkpoints/best_ckpt.tar'
  precip_eps: !!float 1e-5

