# Global random seed for the run. Presumably forwarded by the launching CLI to
# Lightning's seed_everything() for reproducibility — confirm at the entry point.
seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
# Trainer settings: hardware, epoch budget, logging and callbacks.
# "${trainer.default_root_dir}" below is an interpolation of the
# default_root_dir key; it is resolved by the config loader (presumably the
# CLI's OmegaConf parser mode — confirm).
trainer:
  # Placeholder: point this at the directory that should hold logs/checkpoints.
  default_root_dir: /path/to/root_dir

  # NOTE(review): presumably 16-bit mixed precision; confirm how the installed
  # Lightning version interprets the bare integer 16.
  precision: 16

  # "auto" selects the available devices of the chosen accelerator.
  # The previous value `null` is not a valid `devices` setting in Lightning 2.x
  # (whose default is "auto"), so the file could not run without a CLI override.
  devices: auto
  num_nodes: 1
  accelerator: gpu
  # DDP with find_unused_parameters enabled (as the strategy name states).
  strategy: ddp_find_unused_parameters_true

  min_epochs: 1
  # Same value as model.max_epochs; keep the two in sync.
  max_epochs: 50
  enable_progress_bar: true

  sync_batchnorm: true
  enable_checkpointing: true
  num_sanity_val_steps: 1

  # debugging
  fast_dev_run: false

  logger:
    class_path: lightning.pytorch.loggers.wandb.WandbLogger
    init_args:
      project: "atmos-arena"
      # The experiment name is repeated in save_dir, name and the checkpoint
      # dirpath further down; change all three together.
      save_dir: "${trainer.default_root_dir}/air_comp_forecasting_unet_7days_surface_vars_finetune_backbone_128bs_5e-4_lr"
      name: air_comp_forecasting_unet_7days_surface_vars_finetune_backbone_128bs_5e-4_lr

  callbacks:
    # Log the learning rate once per optimizer step.
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"

    # Keep the single best checkpoint by val/w_mse, plus the latest one.
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: "${trainer.default_root_dir}/air_comp_forecasting_unet_7days_surface_vars_finetune_backbone_128bs_5e-4_lr/checkpoints"
        monitor: "val/w_mse" # name of the logged metric that decides whether the model is improving
        mode: "min" # lower metric value is better ("max" would mean higher is better)
        save_top_k: 1 # save the k best models (ranked by the metric above)
        save_last: true # additionally, always save the model from the last epoch
        verbose: false
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: false

    # Stop training once val/w_mse has not improved for `patience` validations.
    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: "val/w_mse" # same metric as the checkpoint callback
        mode: "min" # lower metric value is better
        patience: 10 # number of non-improving validation epochs before training stops
        min_delta: 0.0 # minimum change in the monitored metric that counts as an improvement

    - class_path: lightning.pytorch.callbacks.RichModelSummary
      init_args:
        # NOTE(review): -1 presumably means "no depth limit" — confirm.
        max_depth: -1

    - class_path: lightning.pytorch.callbacks.TQDMProgressBar

# ---------------------------- MODEL -------------------------------------------
# Optimizer / learning-rate-schedule hyperparameters and the network definition.
model:
  # Exponent floats carry an explicit mantissa dot (5.0e-4, not 5e-4) so that
  # YAML 1.1 loaders such as PyYAML resolve them as floats rather than strings.
  lr: 5.0e-4
  beta_1: 0.9
  beta_2: 0.95
  weight_decay: 1.0e-5
  warmup_epochs: 5
  # Same value as trainer.max_epochs; keep the two in sync.
  max_epochs: 50
  warmup_start_lr: 1.0e-8
  eta_min: 1.0e-8
  # Empty string: no checkpoint path is configured in this file.
  # NOTE(review): the run name contains "finetune_backbone" — confirm whether a
  # path is expected to be supplied at launch time.
  pretrained_path: ""

  net:
    class_path: unet_arch.Unet
    init_args:
      in_channels: 77 # equals the number of entries in data.in_variables
      out_channels: 8 # equals the number of entries in data.out_variables
      hidden_channels: 128
      activation: "leaky"
      norm: true
      dropout: 0.1
      # One entry per resolution stage; is_attn has the same length.
      ch_mults: [1, 2, 2, 4]
      is_attn: [false, false, false, false]
      mid_attn: false
      n_blocks: 2

# ---------------------------- DATA -------------------------------------------
# Dataset location, input/output variable lists and loader settings.
data:
  # Placeholder: point this at the dataset root.
  root_dir: /path/to/data

  # 77 input variables (matches model.net.init_args.in_channels).
  in_variables:
    # 4 meteorological fields without a level suffix
    - "2m_temperature"
    - "10m_u_component_of_wind"
    - "10m_v_component_of_wind"
    - "mean_sea_level_pressure"
    # 3 particulate-matter fields
    - "particulate_matter_10um"
    - "particulate_matter_1um"
    - "particulate_matter_2.5um"
    # 5 total-column fields
    - "total_column_carbon_monoxide"
    - "total_column_nitrogen_dioxide"
    - "total_column_nitrogen_monoxide"
    - "total_column_ozone"
    - "total_column_sulphur_dioxide"
    # ozone at 13 levels (numeric suffix; presumably pressure in hPa — confirm)
    - "ozone_50"
    - "ozone_100"
    - "ozone_150"
    - "ozone_200"
    - "ozone_250"
    - "ozone_300"
    - "ozone_400"
    - "ozone_500"
    - "ozone_600"
    - "ozone_700"
    - "ozone_850"
    - "ozone_925"
    - "ozone_1000"
    # specific humidity at the same 13 levels
    - "specific_humidity_50"
    - "specific_humidity_100"
    - "specific_humidity_150"
    - "specific_humidity_200"
    - "specific_humidity_250"
    - "specific_humidity_300"
    - "specific_humidity_400"
    - "specific_humidity_500"
    - "specific_humidity_600"
    - "specific_humidity_700"
    - "specific_humidity_850"
    - "specific_humidity_925"
    - "specific_humidity_1000"
    # sulphur dioxide at the same 13 levels
    - "sulphur_dioxide_50"
    - "sulphur_dioxide_100"
    - "sulphur_dioxide_150"
    - "sulphur_dioxide_200"
    - "sulphur_dioxide_250"
    - "sulphur_dioxide_300"
    - "sulphur_dioxide_400"
    - "sulphur_dioxide_500"
    - "sulphur_dioxide_600"
    - "sulphur_dioxide_700"
    - "sulphur_dioxide_850"
    - "sulphur_dioxide_925"
    - "sulphur_dioxide_1000"
    # u wind component at the same 13 levels
    - "u_component_of_wind_50"
    - "u_component_of_wind_100"
    - "u_component_of_wind_150"
    - "u_component_of_wind_200"
    - "u_component_of_wind_250"
    - "u_component_of_wind_300"
    - "u_component_of_wind_400"
    - "u_component_of_wind_500"
    - "u_component_of_wind_600"
    - "u_component_of_wind_700"
    - "u_component_of_wind_850"
    - "u_component_of_wind_925"
    - "u_component_of_wind_1000"
    # v wind component at the same 13 levels
    - "v_component_of_wind_50"
    - "v_component_of_wind_100"
    - "v_component_of_wind_150"
    - "v_component_of_wind_200"
    - "v_component_of_wind_250"
    - "v_component_of_wind_300"
    - "v_component_of_wind_400"
    - "v_component_of_wind_500"
    - "v_component_of_wind_600"
    - "v_component_of_wind_700"
    - "v_component_of_wind_850"
    - "v_component_of_wind_925"
    - "v_component_of_wind_1000"

  # 8 output variables (matches model.net.init_args.out_channels); every one of
  # them also appears in in_variables.
  out_variables:
    - "particulate_matter_10um"
    - "particulate_matter_1um"
    - "particulate_matter_2.5um"
    - "total_column_carbon_monoxide"
    - "total_column_nitrogen_dioxide"
    - "total_column_nitrogen_monoxide"
    - "total_column_ozone"
    - "total_column_sulphur_dioxide"

  # NOTE(review): presumably hours (168 h = 7 days, matching "7days" in the run
  # name) — confirm against the data module.
  lead_time: 168
  data_freq: 12
  # NOTE(review): the run name says "128bs"; presumably that is the effective
  # batch size across devices — confirm.
  batch_size: 4
  num_workers: 8
  pin_memory: false