# Experiment configuration for an ImageNet-1k classification run of the
# `vislstm` model (see the `datasets`, `model` and `trainer` sections).
#
# Base path for resolving `kind:` entries — presumably consumed by the
# project's object factory; TODO confirm against the loader.
master_factory_base_path: vislstm
# Run name. It appears to encode the settings of this file (e400 ~ max_epochs,
# res192 ~ resolution, lr1e3 ~ lr, bialter ~ alternation, midpatch ~ pooling);
# NOTE(review): update it by hand if any of those values change.
name: in1k-lstm-6m16-e400res192-bialter-midpatch-lr1e3
stage_name: in1k
# Shared values referenced elsewhere in this file as ${vars.<key>}. The
# placeholders are plain strings to YAML; the consuming loader resolves them.
vars:
  # referenced by model.optim.lr
  lr: 1.0e-3
  # referenced by trainer.max_epochs
  max_epochs: 400
  # referenced by trainer.effective_batch_size
  batch_size: 1024
  # referenced by the train random_resized_crop and the val resize/center_crop
  resolution: 192

# Dataset definitions. `val` is the key the trainer's accuracy callback
# points at via `dataset_key: val`.
datasets:
  train:
    kind: imagenet1k
    split: train
    sample_wrappers:
      # Image augmentation pipeline; entries are listed in this order:
      # crop -> flip -> three_augment -> color jitter -> normalisation.
      - kind: x_transform_wrapper
        transform:
          - kind: random_resized_crop
            size: ${vars.resolution}
            # two-element range — presumably (min, max) crop area fraction; TODO confirm
            scale: [0.08, 1.0]
            interpolation: bicubic
          - kind: random_horizontal_flip
          - kind: transforms.three_augment
            # two-element range — presumably (min, max) blur sigma; TODO confirm
            blur_sigma: [0.1, 2.0]
          - kind: color_jitter
            brightness: 0.3
            contrast: 0.3
            saturation: 0.3
            hue: 0.0
          - kind: imagenet1k_norm
      # Only the train split lists this wrapper; `val` below does not.
      - kind: one_hot_wrapper
    collators:
      # mixup_p + cutmix_p = 1.0. Assuming these are selection probabilities,
      # every batch receives one of the two mixes — TODO confirm.
      - kind: mix_collator
        mixup_alpha: 0.8
        cutmix_alpha: 1.0
        mixup_p: 0.5
        cutmix_p: 0.5
        apply_mode: batch
        lamb_mode: batch
        shuffle_mode: flip
  val:
    kind: imagenet1k
    split: val
    sample_wrappers:
      # No random augmentation is listed for validation. Resize and center
      # crop both use ${vars.resolution}.
      - kind: x_transform_wrapper
        transform:
          - kind: resize
            size: ${vars.resolution}
            interpolation: bicubic
          - kind: center_crop
            size: ${vars.resolution}
          - kind: imagenet1k_norm

# Model architecture together with its optimiser configuration.
model:
  kind: models.single.vislstm
  patch_size: 16
  dim: 192
  depth: 24
  # NOTE(review): `bidirectional: false` alongside `alternation: bidirectional`
  # presumably means each block scans in a single direction and the direction
  # alternates between blocks — confirm against the model implementation.
  bidirectional: false
  alternation: bidirectional
  pos_embed_mode: learnable
  mode: classifier
  pooling:
    kind: middle_tokens
    num_tokens: 1
  optim:
    kind: adamw
    # base learning rate, taken from vars.lr
    lr: ${vars.lr}
    betas: [0.9, 0.999]
    weight_decay: 0.05
    clip_grad_norm: 1.0
    schedule:
      kind: linear_warmup_cosine_decay_schedule
      warmup_epochs: 5
      end_value: 1.0e-6
    # divisor equals vars.batch_size (1024). Assuming the scaler multiplies lr
    # by batch_size / divisor, the factor is 1 here — TODO confirm.
    lr_scaler:
      kind: linear_lr_scaler
      divisor: 1024


# Training loop settings and callbacks.
trainer:
  kind: classification_trainer
  # NOTE(review): backup_precision is presumably the fallback used when
  # bfloat16 is unavailable on the device — confirm against the trainer.
  precision: bfloat16
  backup_precision: float16
  # both values below are taken from the `vars` section
  max_epochs: ${vars.max_epochs}
  effective_batch_size: ${vars.batch_size}
  log_every_n_epochs: 1
  use_torch_compile: true
  callbacks:
    # checkpoint callback with no options set; per the original note this one
    # saves the last (final) checkpoint
    - kind: checkpoint_callback
    # every 10 epochs: update the "latest" weights and optimizer state
    # (save_weights is off, so — presumably — no separate per-interval weight
    # snapshot is written; TODO confirm)
    - kind: checkpoint_callback
      every_n_epochs: 10
      save_weights: false
      save_latest_weights: true
      save_latest_optim: true
    # metrics: accuracy on the `val` dataset, run every epoch
    - kind: offline_accuracy_callback
      every_n_epochs: 1
      dataset_key: val