# @package _global_
# Reproducing OmniMAE w/ ViT-Base trunk

# Path of the pre-trained OmniMAE checkpoint. `???` is OmegaConf's
# mandatory-value marker, so this has to be supplied by the user; it is read
# by trainer.checkpoint.model_weight_initializer further down.
pretrained_omnimae_checkpoint_path: ???

# Compose on top of the shared base experiment config. `_self_` is listed last
# so that the values in this file take precedence over the base.
defaults:
  - /experiments/base.yaml
  - _self_

# Job size: 4 nodes x 8 GPUs each.
launcher:
  gpus_per_node: 8
  num_nodes: 4

trainer:
  max_epochs: 100

  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    # With a save frequency of 0 only the last checkpoint is saved.
    save_freq: 0
    # Partially-applied loader that initialises the model from the pre-trained
    # OmniMAE checkpoint given by `pretrained_omnimae_checkpoint_path`.
    model_weight_initializer:
      _target_: omnivision.model.checkpoint_utils.load_state_dict_into_model
      _partial_: true
      # Non-strict load, because the heads aren't loaded from the checkpoint.
      strict: false
      state_dict:
        _target_: omnivision.model.checkpoint_utils.load_checkpoint_and_apply_kernels
        checkpoint_path: ${pretrained_omnimae_checkpoint_path}
        ckpt_state_dict_key: model
        checkpoint_kernels:
          # Key patterns handed to the exclude kernel; presumably matching
          # entries are dropped from the loaded state dict -- verify against
          # CkptExcludeKernel.
          - _target_: omnivision.model.checkpoint_utils.CkptExcludeKernel
            key_pattern:
              - 'trunk.decoder.*'
              - 'trunk.norm.*'
              - 'trunk.mask_token'
              - 'heads.*'

  data:
    # Training split of SSv2.
    train:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.path_dataset.VideoPathDataset
        path_file_list:
          - ${ssv2_train_vids_path}
        label_file_list:
          - ${ssv2_train_labels_path}
        new_prefix: ${ssv2_prefix}
        # Random clips of duration 2.7, uniformly subsampled to 16 frames.
        # The val split interpolates these values, so edit them here.
        clip_sampler:
          _target_: pytorchvideo.data.clip_sampling.RandomClipSampler
          clip_duration: 2.7
        frame_sampler:
          _target_: pytorchvideo.transforms.UniformTemporalSubsample
          num_samples: 16
        decoder: pyav
        normalize_to_0_1: true
        transforms:
          - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
            base_transform:
              _target_: torchvision.transforms.Compose
              # NOTE: the val split reads mean/std from entry 5 of this list
              # (NormalizeVideo). Adding, removing or re-enabling an entry
              # before it shifts that index and breaks the interpolation.
              transforms:
                - _target_: pytorchvideo.transforms.ShortSideScale
                  size: 256
                - _target_: torchvision.transforms.RandomResizedCrop
                  size: 224
                # Horizontal flipping is intentionally left disabled.
                # - _target_: torchvision.transforms.RandomHorizontalFlip
                #   p: 0.5
                # RandAugment expects T,C,H,W: swap C,T,H,W -> T,C,H,W ...
                - _target_: omnivision.data.transforms.basic.Permute
                  ordering: [1, 0, 2, 3]
                - _target_: pytorchvideo.transforms.RandAugment
                  magnitude: 7
                  num_layers: 4
                  prob: 0.5
                # ... and swap back T,C,H,W -> C,T,H,W afterwards.
                - _target_: omnivision.data.transforms.basic.Permute
                  ordering: [1, 0, 2, 3]
                - _target_: torchvision.transforms._transforms_video.NormalizeVideo
                  mean: [0.485, 0.456, 0.406]
                  std: [0.229, 0.224, 0.225]
                # Cube RandErase works over the batch (time) dim, so go
                # C,T,H,W -> T,C,H,W first ...
                - _target_: omnivision.data.transforms.basic.Permute
                  ordering: [1, 0, 2, 3]
                - _target_: omnivision.data.transforms.video_random_erasing.RandomErasing
                  probability: 0.25
                  mode: pixel
                  max_count: 1
                  num_splits: 1
                  cube: true
                  device: cpu
                # ... and restore T,C,H,W -> C,T,H,W after RandErase.
                - _target_: omnivision.data.transforms.basic.Permute
                  ordering: [1, 0, 2, 3]
      shuffle: true
      batch_size: 8
      num_workers: 8
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: ssv2
        batch_transforms:
          - _target_: omnivision.data.transforms.cutmixup.CutMixUp
            # Mixup alpha value; mixup is active if > 0.
            mixup_alpha: 0.8
            # Cutmix alpha value; cutmix is active if > 0.
            cutmix_alpha: 1.0
            # Probability of applying mixup or cutmix per batch or element.
            prob: 1.0
            # Probability of switching to cutmix instead of mixup when both
            # are active.
            switch_prob: 0.5
            # How the mixup/cutmix params are applied: per 'batch', per 'pair'
            # (pair of elements) or per 'elem' (element).
            mode: batch
            # Apply lambda correction when the cutmix bbox is clipped by the
            # image borders.
            correct_lam: true
            # Label smoothing applied to the mixed target tensor.
            label_smoothing: 0.1
            # Number of classes for the target.
            num_classes: 174
      worker_init_fn: null
    # Validation split of SSv2. Clip duration, frame count, decoder and
    # normalisation are all interpolated from the train split so the two stay
    # in sync.
    val:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.path_dataset.VideoPathDataset
        path_file_list:
          - ${ssv2_val_vids_path}
        label_file_list:
          - ${ssv2_val_labels_path}
        new_prefix: ${ssv2_prefix}
        # 5 clips per video, each later expanded into 3 spatial crops below.
        clip_sampler:
          _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
          clip_duration: ${trainer.data.train.dataset.clip_sampler.clip_duration}
          clips_per_video: 5
        frame_sampler:
          _target_: pytorchvideo.transforms.UniformTemporalSubsample
          num_samples: ${trainer.data.train.dataset.frame_sampler.num_samples}
        decoder: ${trainer.data.train.dataset.decoder}
        normalize_to_0_1: ${trainer.data.train.dataset.normalize_to_0_1}
        transforms:
          - _target_: torchvision.transforms.Compose
            transforms:
              # Not splitting into multiple sample objects, to keep the crops
              # of one video in 1 single Sample object.
              - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                base_transform:
                  _target_: omnivision.data.transforms.transform_wrappers.ListTransform
                  base_transform:
                    _target_: pytorchvideo.transforms.ShortSideScale
                    size: 224
              # mean/std come from entry 5 (NormalizeVideo) of the train
              # transform list.
              - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                base_transform:
                  _target_: omnivision.data.transforms.transform_wrappers.ListTransform
                  base_transform:
                    _target_: torchvision.transforms._transforms_video.NormalizeVideo
                    mean: ${trainer.data.train.dataset.transforms.0.base_transform.transforms.5.mean}
                    std: ${trainer.data.train.dataset.transforms.0.base_transform.transforms.5.std}
              - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                base_transform:
                  _target_: omnivision.data.transforms.transform_wrappers.ListTransform
                  base_transform:
                    _target_: omnivision.data.transforms.basic.SpatialCrop
                    crop_size: 224
                    num_crops: 3
              # Each sample now has a list in the .data object, so convert
              # into sublists.
              - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                base_transform:
                  _target_: omnivision.data.transforms.transform_wrappers.FlattenListOfList
      shuffle: false
      batch_size: 1
      num_workers: 8
      pin_memory: true
      # Never drop validation samples. With batch_size 1 this is equivalent to
      # the previous `True`, but it stays correct if batch_size is overridden
      # to a value that does not divide the number of samples.
      drop_last: false
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: ssv2
      worker_init_fn: null


  # ViT-Base video trunk (embed_dim 768, depth 12, 12 attention heads) with a
  # single linear classification head of 174 outputs.
  model:
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    handle_list_inputs: true
    trunk:
      _target_: omnivision.models.vision_transformer.VisionTransformer
      # Input size; the second entry follows the number of frames sampled by
      # the train frame sampler.
      img_size:
        - 3
        - ${trainer.data.train.dataset.frame_sampler.num_samples}
        - 224
        - 224
      embed_dim: 768
      depth: 12
      patch_size: [2, 16, 16]
      classifier_feature: global_pool
      drop_path_rate: 0.1
      use_cls_token: false
      patch_embed_type: generic
      patch_embed_params_list:
        - _target_: omnivision.models.PadIm2Video
          pad_type: repeat
          ntimes: 2
        # Conv3d patch embedding. The relative interpolations climb back up to
        # the trunk node, so kernel/stride and output width follow
        # trunk.patch_size and trunk.embed_dim.
        - _target_: omnivision.models.make_conv_or_linear
          layer:
            _target_: torch.nn.Conv3d
            in_channels: 3
            out_channels: ${....embed_dim}
            kernel_size: ${....patch_size}
            stride: ${.kernel_size}
          init_weight:
            _target_: omnivision.models.reshape_and_init_as_mlp
          _recursive_: false
      attn_target:
        _target_: omnivision.models.vision_transformer.Attention
        _partial_: true
        num_heads: 12
        proj_drop: 0
        qk_scale: null
        qkv_bias: true
        attn_drop: 0
      # Use sinusoidal positional encoding instead of a learnable one.
      learnable_pos_embed: false
    heads:
      - head:
          _target_: torch.nn.Sequential
          _args_:
            # - _target_: torch.nn.Dropout
            #   p: 0.0
            - _target_: omnivision.model.model_init_utils.init_parameters
              model:
                _target_: torch.nn.Linear
                # Tied to the trunk width so the head follows any change of
                # trunk.embed_dim (768 here) instead of duplicating it.
                in_features: ${trainer.model.trunk.embed_dim}
                out_features: 174
              # Weight ~ N(0, 0.01), bias = 0.
              init_fns:
                weight:
                  _target_: torch.nn.init.normal_
                  _partial_: true
                  mean: 0
                  std: 0.01
                bias:
                  _target_: torch.nn.init.zeros_
                  _partial_: true
        fork_module: ""
        input_key: null
        output_key: null
    trunk_fields:
      - input_key: null
        args: ["vision"]
  # AdamW with layer-wise LR decay, linear warmup + cosine LR schedule, and
  # weight decay disabled for biases and LayerNorm parameters.
  optim:
    gradient_clip: null
    amp:
      enabled: false
      # Either bfloat16 or float16.
      amp_dtype: float16
    optimizer:
      _target_: torch.optim.AdamW
    param_group_modifiers:
      - _target_: omnivision.optim.layer_decay_param_modifier.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.75
    options:
      lr:
        - scheduler:
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              # Linear warmup from 1e-6 up to the peak LR ...
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 1e-6
                end_value: 1e-3
              # ... then cosine decay from that peak back down to 1e-6.
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                start_value: ${..0.end_value}
                end_value: 1e-6
            # Fractions of the whole run: 12.5% warmup, 87.5% cosine decay.
            # NOTE(review): this used to be labelled "5 epoch warmup", which
            # holds for a 40-epoch run; with trainer.max_epochs = 100 the
            # warmup is 12.5 epochs -- confirm which one is intended.
            lengths: [0.125, 0.875]
            interval_scaling:
              - rescaled
              - rescaled
      weight_decay:
        # Default weight decay for every parameter ...
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.05
        # ... except the parameters selected below, which get none.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          # TODO should we allow unused param names or classes?
          param_names:
            - '*.bias'
            # - '*.pos_embed'  # using sinusoidal
            # - '*.cls_token'  # no cls token
          # TODO: allow other forms for class names
          module_cls_names:
            - torch.nn.LayerNorm
  
  # Top-1 and top-5 accuracy meters on the `ssv2` output key, configured
  # identically for the train and val phases.
  metrics:
    train:
      ssv2:
        accuracy_top1:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 5
    val:
      ssv2:
        accuracy_top1:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 5
  # Cross-entropy loss for the `ssv2` output key; targets equal to -1 are
  # ignored.
  loss:
    ssv2:
      _target_: omnivision.losses.cross_entropy_multiple_output_single_target.CrossEntropyMultipleOutputSingleTargetLoss
      ignore_index: -1
      update_output_apply_activation: true