# Robustness-evaluation stage config for the run named `vil2-b16`.
# NOTE(review): presumably the package/path used to resolve the `kind:` entries
# below into factories -- confirm against the config loader.
master_factory_base_path: vislstm

# Run name and stage name of this config.
name: vil2-b16
stage_name: robustness
# Shared values referenced elsewhere in this file as `${vars.<key>}`.
vars:
  # Used as both the resize size and the center-crop size of every dataset.
  resolution: 224

# Datasets referenced by `dataset_key` in the trainer callbacks.
# All four entries use the same x-transform pipeline: resize (bicubic) to
# `${vars.resolution}`, center crop to `${vars.resolution}`, then
# `imagenet1k_norm`. Keep the four pipelines in sync when editing one of them.
datasets:
  # NOTE(review): presumably the ImageNet-A (adversarial) benchmark -- confirm.
  adversarial:
    kind: imagenet_adversarial
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: resize
            size: ${vars.resolution}
            interpolation: bicubic
          - kind: center_crop
            size: ${vars.resolution}
          - kind: imagenet1k_norm
  # NOTE(review): presumably the ImageNet-R (rendition) benchmark -- confirm.
  rendition:
    kind: imagenet_rendition
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: resize
            size: ${vars.resolution}
            interpolation: bicubic
          - kind: center_crop
            size: ${vars.resolution}
          - kind: imagenet1k_norm
  # NOTE(review): presumably the ImageNet-Sketch benchmark -- confirm.
  sketch:
    kind: imagenet_sketch
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: resize
            size: ${vars.resolution}
            interpolation: bicubic
          - kind: center_crop
            size: ${vars.resolution}
          - kind: imagenet1k_norm
  # `val` split of the `imagenet1k` dataset kind.
  val:
    kind: imagenet1k
    split: val
    sample_wrappers:
      - kind: x_transform_wrapper
        transform:
          - kind: resize
            size: ${vars.resolution}
            interpolation: bicubic
          - kind: center_crop
            size: ${vars.resolution}
          - kind: imagenet1k_norm

# Model under evaluation: frozen and initialized from an earlier run.
model:
  is_frozen: true
  initializers:
    # Loads the `last` checkpoint of model `vislstm` from stage `in1k` of the
    # run with id `6x5zxnr9`.
    - kind: previous_run_initializer
      stage_id: 6x5zxnr9
      stage_name: in1k
      model_name: vislstm
      checkpoint: last
      # NOTE(review): no model `kind` is given in this file; presumably the
      # constructor kwargs stored in the checkpoint are used to build the
      # model -- confirm against the initializer implementation.
      use_checkpoint_kwargs: true



# Trainer settings plus the evaluation callbacks for this stage.
trainer:
  kind: classification_trainer
  precision: bfloat16
  # NOTE(review): presumably the fallback when bfloat16 is unavailable -- confirm.
  backup_precision: float16
  # NOTE(review): zero epochs together with `model.is_frozen: true` suggests
  # that no training happens and only the callbacks below produce results --
  # confirm that callbacks still fire when max_epochs is 0.
  max_epochs: 0
  effective_batch_size: 32
  log_every_n_epochs: 1
  # Each `dataset_key` refers to an entry of the top-level `datasets` mapping.
  callbacks:
    # `adversarial` and `rendition` use the class-subset accuracy callback.
    # NOTE(review): presumably because these datasets cover only a subset of
    # the ImageNet-1k classes -- confirm.
    - kind: offline_classsubset_accuracy_callback
      every_n_epochs: 1
      dataset_key: adversarial
    - kind: offline_classsubset_accuracy_callback
      every_n_epochs: 1
      dataset_key: rendition
    # `sketch` and `val` use the plain accuracy callback.
    - kind: offline_accuracy_callback
      every_n_epochs: 1
      dataset_key: sketch
    - kind: offline_accuracy_callback
      every_n_epochs: 1
      dataset_key: val
    # No `dataset_key` is given here.
    # NOTE(review): presumably this callback loads its own corruption data
    # (ImageNet-C?) -- confirm.
    - kind: offline_imagenet_corruption_callback
      every_n_epochs: 1