# @package _global_
# Reproducing OmniMAE w/ ViT-Base trunk: fine-tunes a pretrained OmniMAE checkpoint on in1k classification

# Path of the pretrained OmniMAE checkpoint used to initialize the model; it is
# read by trainer.checkpoint.model_weight_initializer below. `???` is
# OmegaConf's marker for a mandatory value, so it must be supplied as an
# override before the config can be used.
pretrained_omnimae_checkpoint_path: ???

# Hydra defaults list: compose the shared base experiment config first, then
# this file (`_self_`), so values set here take precedence over the base.
defaults:
  - /experiments/base.yaml
  - _self_

# Job size: 4 nodes x 8 GPUs per node = 32 GPUs in total.
launcher:
  gpus_per_node: 8
  num_nodes: 4

# Everything below configures the trainer: checkpointing, data, model,
# optimization, metrics and loss.
trainer:
  # Length of the run in epochs. The lr schedule under optim is expressed as
  # fractions of the run, so its "5 epochs" warm-up assumes this value.
  max_epochs: 100

  # Checkpoint output location, plus initialization of the model weights from
  # the pretrained OmniMAE checkpoint.
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0 # 0: only the last checkpoint is saved.
    # `_partial_: True` makes Hydra build a functools.partial of
    # load_state_dict_into_model with `strict` and `state_dict` bound here; the
    # remaining argument(s) are supplied by the caller (presumably the model to
    # initialize -- confirm in the trainer).
    model_weight_initializer:
      _partial_: True
      _target_: omnivision.model.checkpoint_utils.load_state_dict_into_model
      strict: False # heads aren't loaded
      state_dict:
        _target_: omnivision.model.checkpoint_utils.load_checkpoint_and_apply_kernels
        checkpoint_path: ${pretrained_omnimae_checkpoint_path}
        ckpt_state_dict_key: NULL
        checkpoint_kernels:
        # Judging by the kernel's name, checkpoint entries whose keys match
        # these patterns (decoder, trunk norm, mask token, head) are dropped
        # before loading; `strict: False` above tolerates the resulting missing
        # keys. TODO confirm against CkptExcludeKernel.
        - _target_: omnivision.model.checkpoint_utils.CkptExcludeKernel
          key_pattern:
          - "trunk.decoder.*"
          - "trunk.norm.*"
          - "trunk.mask_token"
          - "head.*"

  # Data loading. The train and val splits both emit batches under the key
  # `in1k`, the same key used by the head, metrics and loss in this file.
  data:
    # Training split: image paths and labels are read from the file lists
    # below; the ${in1k_*} values are not defined in this file (presumably they
    # come from the base config -- confirm).
    train:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.path_dataset.ImagePathDataset
        path_file_list:
          - ${in1k_train_imgs_path}
        label_file_list:
          - ${in1k_train_labels_path}
        new_prefix: ${in1k_prefix}
        # Per-sample augmentation, applied in the listed order.
        transforms:
          - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
            base_transform:
              _target_: torchvision.transforms.Compose
              transforms:
                - _target_: torchvision.transforms.RandomResizedCrop
                  size: 224
                  # Integer interpolation code; 3 is PIL's value for bicubic.
                  interpolation: 3
                - _target_: torchvision.transforms.RandomHorizontalFlip
                - _target_: omnivision.data.transforms.rand_auto_aug.RandAugment  # Essentially autoaugment rand-m9-mstd0.5-inc1
                  magnitude: 9
                  magnitude_std: 0.5
                  increasing_severity: True
                - _target_: torchvision.transforms.ToTensor
                # RandomErasing is listed after ToTensor, so it runs on the
                # tensor image rather than on the decoded PIL image.
                - _target_: torchvision.transforms.RandomErasing
                  p: .25
                # Same per-channel mean/std as the val split.
                - _target_: torchvision.transforms.Normalize
                  mean: [0.485, 0.456, 0.406]
                  std: [0.229, 0.224, 0.225]
      shuffle: True
      # Presumably the per-process batch size (32 GPUs x 32 = 1024 samples per
      # step) -- confirm against TorchDataset.
      batch_size: 32
      num_workers: 10
      pin_memory: True
      drop_last: True
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: in1k
        # Batch-level mixing. num_classes below must agree with the head's
        # out_features (1000).
        batch_transforms:
        - _target_: omnivision.data.transforms.cutmixup.CutMixUp
          mixup_alpha: 0.8 # mixup alpha value, mixup is active if > 0.
          cutmix_alpha: 1.0 # cutmix alpha value, cutmix is active if > 0.
          prob: 1.0 # probability of applying mixup or cutmix per batch or element
          switch_prob: 0.5 # probability of switching to cutmix instead of mixup when both are active
          mode: batch # how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements) or 'elem' (element)
          correct_lam: True # apply lambda correction when cutmix bbox clipped by image borders
          label_smoothing: 0.1 # apply label smoothing to the mixed target tensor
          num_classes: 1000 # number of classes for target
      worker_init_fn: NULL
    # Validation split: deterministic preprocessing only, with no random
    # augmentation and no batch-level mixing.
    val:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.path_dataset.ImagePathDataset
        path_file_list:
          - ${in1k_val_imgs_path}
        label_file_list:
          - ${in1k_val_labels_path}
        new_prefix: ${in1k_prefix}
        transforms:
          - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
            base_transform:
              _target_: torchvision.transforms.Compose
              transforms:
                # With an int size, torchvision's Resize scales the shorter
                # side to 224; CenterCrop then takes the central 224x224.
                - _target_: torchvision.transforms.Resize
                  size: 224
                  # Integer interpolation code; 3 is PIL's value for bicubic.
                  interpolation: 3
                - _target_: torchvision.transforms.CenterCrop
                  size: 224
                - _target_: torchvision.transforms.ToTensor
                # Same per-channel mean/std as the train split.
                - _target_: torchvision.transforms.Normalize
                  mean: [0.485, 0.456, 0.406]
                  std: [0.229, 0.224, 0.225]
      shuffle: False
      batch_size: 32
      num_workers: 8
      pin_memory: True
      # NOTE(review): with drop_last enabled on the validation loader, a
      # trailing partial batch is never evaluated, so the reported val metrics
      # may cover slightly fewer samples than the full split -- confirm this
      # is intended.
      drop_last: True
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: in1k
      worker_init_fn: NULL

  # Model: a wrapper holding one trunk and a list of heads (`heads` and
  # `trunk_fields` follow the trunk definition).
  model:
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    handle_list_inputs: True
    # ViT-Base sized trunk: width 768, 12 blocks, 12 attention heads.
    trunk:
      _target_: omnivision.models.vision_transformer.VisionTransformer
      # Presumably (channels, frames, height, width) -- TODO confirm against
      # VisionTransformer.
      img_size:
        - 3
        - 16
        - 224
        - 224
      embed_dim: 768
      depth: 12
      # Presumably (frames, height, width) per patch; reused below as the
      # Conv3d kernel size and stride.
      patch_size: [2, 16, 16]
      classifier_feature: global_pool
      drop_path_rate: 0.1
      use_cls_token: False
      patch_embed_type: generic
      patch_embed_params_list:
      # Presumably repeats a single image along the time axis (ntimes: 2) so
      # that it spans one temporal patch -- confirm against PadIm2Video.
      - _target_: omnivision.models.PadIm2Video
        pad_type: repeat
        ntimes: 2
      - _target_: omnivision.models.make_conv_or_linear
        layer:
          _target_: torch.nn.Conv3d
          in_channels: 3
          # The two `${....}` interpolations climb four levels (layer -> list
          # item -> list -> trunk) and resolve to the trunk's embed_dim and
          # patch_size. `${.kernel_size}` is the sibling key, so stride equals
          # kernel size and the patches do not overlap.
          out_channels: ${....embed_dim}
          kernel_size: ${....patch_size}
          stride: ${.kernel_size}
        init_weight:
          _target_: omnivision.models.reshape_and_init_as_mlp
        # Stop Hydra from instantiating the nested `_target_` nodes of this
        # entry; make_conv_or_linear receives `layer` and `init_weight` as
        # config and builds them itself.
        _recursive_: False
      # Bound as a partial: the trunk calls it with whatever arguments are not
      # fixed here.
      attn_target:
        _target_: omnivision.models.vision_transformer.Attention
        _partial_: True
        num_heads: 12
        proj_drop: 0
        qk_scale: NULL
        qkv_bias: True
        attn_drop: 0
      learnable_pos_embed: False  # Use sinusoidal positional encoding
    # Single classification head, wired to the `in1k` batches.
    heads:
      - head:
          # init_parameters receives the freshly built Linear layer together
          # with one init function per parameter name.
          _target_: omnivision.model.model_init_utils.init_parameters
          model:
            _target_: torch.nn.Linear
            # Must match trunk.embed_dim (768) and the CutMixUp num_classes
            # (1000) respectively; both are repeated here by hand.
            in_features: 768
            out_features: 1000
          init_fns:
            # Partials: the weight/bias tensor is passed in by the caller.
            weight:
              _target_: timm.models.layers.trunc_normal_
              _partial_: True
              mean: 0
              std: 2e-5
            bias:
              _target_: torch.nn.init.zeros_
              _partial_: True
        # Presumably an empty fork_module attaches the head to the trunk's
        # final output -- confirm against MIMOHeadWrapper.
        fork_module: ""
        input_key: in1k
        output_key: in1k
    # Presumably selects which field(s) of a batch are passed to the trunk
    # ("vision"), with input_key NULL meaning "any key" -- TODO confirm.
    trunk_fields:
      - input_key: NULL
        args: ["vision"]
  
  # Optimization: AdamW with scheduled learning rate and weight decay.
  optim:
    gradient_clip: NULL
    # Mixed precision is switched off, so amp_dtype presumably has no effect
    # until `enabled` is turned on -- confirm in the trainer.
    amp:
      enabled: False
      amp_dtype: float16 # bfloat16 or float16
    optimizer:
      _target_: torch.optim.AdamW
    # Presumably applies layer-wise learning-rate decay with factor 0.65 --
    # confirm against layer_decay_param_modifier.
    param_group_modifiers:
      - _target_: omnivision.optim.layer_decay_param_modifier.layer_decay_param_modifier
        _partial_: True
        layer_decay_value: 0.65
    options:
      # Learning rate: linear warm-up from 1e-6 to 4e-3 over the first 5% of
      # training, then cosine decay back to 1e-6 over the remaining 95%.
      lr:
        - scheduler:
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 1e-6
                end_value: 4e-3
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                # Resolves to the end_value of scheduler 0 above (4e-3), so the
                # cosine phase starts exactly where the warm-up ends.
                start_value: ${..0.end_value}
                end_value: 1e-6
            lengths: [0.05, 0.95]  # warm up for 5 epochs (5% of max_epochs: 100)
            interval_scaling: ['rescaled', 'rescaled']
      # Weight decay: the first entry has no filter, the second sets decay to 0
      # for the listed parameter names and module classes. Presumably the
      # unfiltered entry acts as the default for all remaining parameters --
      # confirm in the optimizer construction code.
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 1e-4  # NOTE(review): open question left by the author -- should this be 0.05?
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
             - '*.bias'
            #  - '*.pos_embed'
            #  - '*.cls_token'
          module_cls_names: ['torch.nn.LayerNorm']
  # Top-1 and top-5 accuracy, declared identically for both phases under the
  # `in1k` key (the output_key of the head and of both collators).
  metrics:
    train:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
    val:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
  # One loss per output key; `in1k` matches the head's output_key.
  # NOTE(review): training targets go through CutMixUp with label smoothing,
  # so they are presumably mixed (non one-hot) targets. CrossEntropyLoss only
  # accepts class-probability targets in newer PyTorch releases -- confirm the
  # required version.
  loss:
    in1k:
      _target_: torch.nn.CrossEntropyLoss
