# Run identification: factory base path, run name and stage name.
# The run name mirrors settings found further down in this file:
# deit3 pretrained initializer, patch_size 16, 160000 updates,
# UperNet model and use_relpos_bias: true.
master_factory_base_path: vislstm
name: deit3-b16-u160k-upernet-relpos
stage_name: ade20k


# Shared values that other sections of this file reference via
# ${vars.<key>} interpolation.
vars:
  # referenced by trainer.effective_batch_size
  batch_size: 16
  # referenced by trainer.max_updates
  updates: 160000
  # referenced by model.encoder.drop_path_rate
  drop_path_rate: 0.1
  # optimizer settings reused by model.postprocessor, model.decoder and
  # model.auxiliary as-is, and by model.encoder as a template
  optim:
    kind: adamw
    lr: 5.0e-4
    betas: [ 0.9, 0.999 ]
    weight_decay: 0.05
    schedule:
      kind: linear_warmup_cosine_decay_schedule
      warmup_updates: 1500
      end_value: 1.0e-6
    lr_scaler:
      kind: linear_lr_scaler
      # NOTE(review): equals vars.batch_size (16) but is hard-coded, not
      # interpolated; presumably lr is scaled by batch_size / divisor, so the
      # base lr would be used unchanged here — confirm against the scaler
      divisor: 16


# Dataset definitions. The "val" key is referenced by the
# offline_segmentation_callback (dataset_key: val) in the trainer section.
datasets:
  train:
    kind: datasets.ade20k
    split: training
    # Three wrappers are listed; presumably applied in the listed order
    # (TODO confirm): joint image/mask augmentation, image-only
    # transforms, then padding.
    sample_wrappers:
      - kind: segmentation_transform_wrapper
        transform:
          - kind: segmentation_random_resize
            ratio_resolution: [ 2048, 512 ]
            ratio_range: [ 0.5, 2.0 ]
            interpolation: bicubic
          - kind: segmentation_random_crop
            # same size as the segmentation_pad entry below
            size: 512
            max_category_ratio: 0.75
            # same value as trainer.ignore_index
            ignore_index: -1
          - kind: segmentation_random_horizontal_flip
      - kind: x_transform_wrapper
        transform:
          - kind: color_jitter
            brightness: 0.5
            contrast: 0.5
            saturation: 0.5
            hue: 0.25
          - kind: imagenet1k_norm
      - kind: segmentation_transform_wrapper
        transform:
          - kind: segmentation_pad
            size: 512
            # same value as trainer.ignore_index; presumably the fill applied
            # to padded label pixels so they are ignored — TODO confirm
            fill: -1
  val:
    kind: datasets.ade20k
    split: validation
    # Only normalization is configured here; no resize/crop/pad entries.
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: imagenet1k_norm

# UperNet model composed of an encoder, a feature extractor, a postprocessor,
# a decoder and an auxiliary component. Each trainable component gets its own
# optim entry derived from vars.optim.
model:
  kind: models.upernet.upernet_model
  # NOTE(review): presumably an index into the extracted features listed in
  # extractor.block_indices (index 2 would be block 7) — confirm
  auxiliary_feature_idx: 2
  encoder:
    kind: models.single.vit
    patch_size: 16
    dim: 768
    num_attn_heads: 12
    depth: 12
    drop_path_rate: ${vars.drop_path_rate}
    drop_path_decay: true
    mode: features
    use_relpos_bias: true
    # vars.optim used as a template, extended with a layerwise lr decay
    # param group modifier (only the encoder has this modifier)
    optim:
      template: ${vars.optim}
      template.param_group_modifiers:
        - kind: optim.layerwise_lr_decay
          layerwise_lr_decay: 0.65
    # encoder weights are initialized from a pretrained deit3 checkpoint
    initializers:
      - kind: deit3_pretrained_initializer
        model: base_res224_in1k
        use_checkpoint_kwargs: true
  extractor:
    kind: block_extractor
    # four block indices, all within the encoder depth of 12
    block_indices: [ 3, 5, 7, 11 ]
    pooling:
      kind: to_image
  postprocessor:
    kind: models.upernet.upernet_postprocessor
    optim: ${vars.optim}
  decoder:
    kind: models.upernet.upernet_decoder
    dim: 384
    optim: ${vars.optim}
  auxiliary:
    kind: models.upernet.upernet_auxiliary
    # same dim as the decoder
    dim: 384
    optim: ${vars.optim}

# Training loop settings; batch size and update budget come from vars.
trainer:
  kind: segmentation_trainer
  precision: float16
  # NOTE(review): presumably the fallback precision when float16 is
  # unavailable — confirm against the trainer implementation
  backup_precision: float32
  effective_batch_size: ${vars.batch_size}
  max_updates: ${vars.updates}
  # same value as ignore_index/fill in the train dataset transforms
  ignore_index: -1
  log_every_n_updates: 1000
  skip_nan_loss: true
  # NOTE(review): presumably the weight of the auxiliary loss term — confirm
  loss_weights:
    auxiliary: 0.4
  # All three callbacks below use the same interval of 10000 updates.
  callbacks:
    # resume: keeps only the latest weights and optimizer state
    # (save_weights is false)
    - kind: checkpoint_callback
      every_n_updates: 10000
      save_weights: false
      save_latest_weights: true
      save_latest_optim: true
    # miou: segmentation evaluation on the "val" dataset
    - kind: offline_segmentation_callback
      every_n_updates: 10000
      batch_size: 1
      dataset_key: val
      ignore_index: -1
      # NOTE(review): "slide" presumably means sliding-window inference with
      # the stride below; the window size is not specified here — confirm
      mode: slide
      mode_kwargs:
        stride: [ 341, 341 ]
    # best checkpoint: tracked via the metric key below
    - kind: best_checkpoint_callback
      every_n_updates: 10000
      metric_key: iou/val/main