# Fixed seed for the run (presumably applied globally by the Lightning CLI
# before the trainer, model and data are instantiated — confirm).
seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
# Trainer settings, logger and callbacks. Booleans are written as lowercase
# true/false throughout so they read the same under every YAML loader.
#
# The run name
#   extreme_detection_climax_deeper_cnn_downsample_finetune_backbone_weighted_jaccard_8bs_1e-4_lr
# appears three times below (logger save_dir, logger name, checkpoint dirpath);
# change all three together when starting a new experiment.
trainer:
  # Base directory; logger.save_dir and the checkpoint dirpath are built from
  # it through ${trainer.default_root_dir} interpolation.
  default_root_dir: /path/to/root_dir

  precision: 16

  # NOTE(review): null presumably leaves device selection to Lightning —
  # confirm that the installed Lightning version accepts null here.
  devices: null
  num_nodes: 1
  accelerator: gpu
  strategy: ddp_find_unused_parameters_true

  min_epochs: 1
  # model.max_epochs carries the same value (50); presumably the two are meant
  # to stay in sync — confirm before changing only one of them.
  max_epochs: 50
  enable_progress_bar: true

  sync_batchnorm: true
  enable_checkpointing: true
  num_sanity_val_steps: 1

  # debugging
  fast_dev_run: false

  logger:
    class_path: lightning.pytorch.loggers.wandb.WandbLogger
    init_args:
      project: 'atmos-arena'
      save_dir: "${trainer.default_root_dir}/extreme_detection_climax_deeper_cnn_downsample_finetune_backbone_weighted_jaccard_8bs_1e-4_lr"
      name: extreme_detection_climax_deeper_cnn_downsample_finetune_backbone_weighted_jaccard_8bs_1e-4_lr

  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"

    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        dirpath: "${trainer.default_root_dir}/extreme_detection_climax_deeper_cnn_downsample_finetune_backbone_weighted_jaccard_8bs_1e-4_lr/checkpoints"
        monitor: "val/loss" # name of the logged metric that decides which checkpoint is best
        mode: "min" # a lower monitored value is better ("max" would mean higher is better)
        save_top_k: 1 # keep the k best models (ranked by the metric above)
        save_last: true # additionally always save the model from the last epoch
        verbose: false
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: false

    - class_path: lightning.pytorch.callbacks.EarlyStopping
      init_args:
        monitor: "val/loss" # same metric as the checkpoint callback above
        mode: "min" # a lower monitored value is better ("max" would mean higher is better)
        patience: 10 # how many validation epochs without improvement before training stops
        min_delta: 0.0 # minimum change in the monitored metric that counts as an improvement

    - class_path: lightning.pytorch.callbacks.RichModelSummary
      init_args:
        max_depth: -1

    - class_path: lightning.pytorch.callbacks.TQDMProgressBar

# ---------------------------- MODEL -------------------------------------------
# Optimisation hyperparameters plus the network definition under `net`.
model:
  loss_type: 'weighted_jaccard'
  # Exponent-notation floats are written with a decimal point (1.0e-4, not
  # 1e-4): YAML 1.1 loaders such as PyYAML resolve the dot-less form as a
  # string, while the dotted form is a float under both YAML 1.1 and 1.2.
  lr: 1.0e-4
  beta_1: 0.9
  beta_2: 0.999
  weight_decay: 1.0e-5
  warmup_epochs: 5
  # trainer.max_epochs carries the same value (50).
  max_epochs: 50
  warmup_start_lr: 1.0e-8
  eta_min: 1.0e-8
  pretrained_path: https://huggingface.co/tungnd/climax/resolve/main/1.40625deg.ckpt

  net:
    class_path: extreme_detection.climax_climatenet_arch.ClimaXClimateNet
    init_args:
      # Both dimensions are divisible by patch_size (768 / 8 = 96, 1152 / 8 = 144).
      img_size: [768, 1152]
      # Same 16 names, in the same order, as data.in_variables.
      default_vars:
        - 'TMQ'
        - 'U850'
        - 'V850'
        - 'UBOT'
        - 'VBOT'
        - 'QREFHT'
        - 'PS'
        - 'PSL'
        - 'T200'
        - 'T500'
        - 'PRECT'
        - 'TS'
        - 'TREFHT'
        - 'Z1000'
        - 'Z200'
        - 'ZBOT'
      # Output class names.
      out_vars:
        - 'Background'
        - 'Tropical Cyclone'
        - 'Atmospheric River'
      patch_size: 8
      embed_dim: 1024
      depth: 8
      decoder_depth: 2
      num_heads: 16
      mlp_ratio: 4
      drop_path: 0.1
      drop_rate: 0.1
      freeze_encoder: false

# ---------------------------- DATA -------------------------------------------
# Dataset location, input variables and loader settings.
data:
  root_dir: /path/to/data
  # Input variable names; same names and order as
  # model.net.init_args.default_vars.
  in_variables:
    - 'TMQ'
    - 'U850'
    - 'V850'
    - 'UBOT'
    - 'VBOT'
    - 'QREFHT'
    - 'PS'
    - 'PSL'
    - 'T200'
    - 'T500'
    - 'PRECT'
    - 'TS'
    - 'TREFHT'
    - 'Z1000'
    - 'Z200'
    - 'ZBOT'
  val_ratio: 0.1
  # NOTE(review): the run name in the trainer section contains "8bs" while
  # batch_size is 2 — presumably 8 is the global batch across devices; confirm.
  batch_size: 2
  num_workers: 4
  pin_memory: false