# @package _global_
# Reproducing OmniMAE w/ ViT-Base trunk

# Hydra defaults list: /experiments/base.yaml is composed first and this file
# (`_self_`) last, so the keys below override values from the base config.
# Variables used below but not defined in this file (e.g. ${in1k_prefix},
# ${ssv2_prefix}) are presumably supplied by the base config or by
# command-line overrides -- TODO confirm.
defaults:
  - /experiments/base.yaml
  - _self_

# Per-replica batch size. Used directly as the batch size of the image dataset
# (trainer.data.train.datasets.0); the video dataset (datasets.1) uses this
# value divided by its Replicate `num_times`.
base_batchsize_per_replica: 32

# Job topology: 8 nodes x 8 GPUs per node = 64 GPUs in total.
# NOTE(review): if there is one replica per GPU, the global image batch is
# 32 x 64 = 2048 -- confirm against the launcher implementation.
launcher:
  gpus_per_node: 8
  num_nodes: 8

trainer:
  # Total number of training epochs. The warm-up comment in
  # trainer.optim.options.lr ("40 epochs" for a 0.05 fraction) is derived from
  # this value, so keep the two in sync when changing it.
  max_epochs: 800

  distributed:
    # Presumably the dtype used for distributed (gradient) communication; the
    # inline comment lists the accepted values -- TODO confirm in the trainer.
    comms_dtype: float16 # NULL, float16, bfloat16

  data:
    train:
      # Training data is a ConcatDataset over the two TorchDataset entries in
      # `datasets`: index 0 = ImageNet-1k images, index 1 = SSv2 videos.
      # NOTE: several interpolations in this file address these entries by
      # list index (datasets.0 / datasets.1 and positions inside their
      # `transforms` lists). Do not reorder list items without updating them.
      _target_: omnivision.data.concat_dataset.ConcatDataset
      max_steps: sum
      # One factor per dataset, in `datasets` order. The video factor is
      # 1 / num_times (0.25 with num_times: 4), assuming the `divide` resolver
      # performs a plain division -- presumably compensating for each video
      # being replicated num_times times; confirm against ConcatDataset.
      repeat_factors:
        - 1.0
        - ${divide:1,${trainer.data.train.datasets.1.dataset.transforms.0.transforms.0.base_transform.num_times}}
      datasets:
        # ---- datasets.0: ImageNet-1k images ----
        - _target_: omnivision.data.torch_dataset.TorchDataset
          dataset:
            _target_: omnivision.data.path_dataset.ImagePathDataset
            path_file_list:
              - ${in1k_train_imgs_path}
            label_file_list:
              - ${in1k_train_labels_path}
            new_prefix: ${in1k_prefix}
            transforms:
              # transforms.0: image augmentation pipeline.
              - _target_: torchvision.transforms.Compose
                transforms:
                  - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                    base_transform:
                      _target_: torchvision.transforms.RandomResizedCrop
                      size: 224
                      scale: [0.2, 1.0]
                      interpolation: 3  # PIL interpolation code 3 = bicubic
                  - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                    base_transform:
                      _target_: torchvision.transforms.RandomHorizontalFlip
                      p: 0.5
                  - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                    base_transform:
                      _target_: torchvision.transforms.ToTensor
                  # transforms.0.transforms.3: its mean/std are reused by the
                  # video NormalizeVideo transform and by trainer.loss
                  # (unnormalize_img). Keep this entry at index 3.
                  - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                    base_transform:
                      _target_: torchvision.transforms.Normalize
                      mean: [0.485, 0.456, 0.406]
                      std: [0.229, 0.224, 0.225]
              # transforms.1: mask generation for images (pred_ratio 0.75;
              # presumably the fraction of patches that is masked).
              - _target_: omnivision.data.transforms.transform_wrappers.MaskingTransform
                masking_object:
                  _target_: omnivision.data.transforms.mask_image_modeling.MaskImageModeling
                  pred_ratio: 0.75
                  pred_ratio_var: 0.0
                  pred_shape:
                    _target_: omnivision.data.transforms.mask_image_modeling.TubeMasking
                    frame_masking:
                      _target_: omnivision.data.transforms.mask_image_modeling.RandMasking
                  patch_size: ${trainer.model.trunk.patch_size}
          shuffle: true
          batch_size: ${base_batchsize_per_replica}
          num_workers: 10
          pin_memory: true
          drop_last: false
          collate_fn:
            _target_: omnivision.data.api.DefaultOmnivoreCollator
            # Same name as the `in1k` entry under trainer.loss.
            output_key: in1k
          worker_init_fn: null
        # ---- datasets.1: Something-Something-v2 videos ----
        - _target_: omnivision.data.torch_dataset.TorchDataset
          dataset:
            _target_: omnivision.data.path_dataset.VideoPathDataset
            path_file_list:
              - ${ssv2_train_vids_path}
            label_file_list:
              - ${ssv2_train_labels_path}
            new_prefix: ${ssv2_prefix}
            clip_sampler:
              _target_: pytorchvideo.data.clip_sampling.RandomClipSampler
              clip_duration: 2.7
            frame_sampler:
              _target_: pytorchvideo.transforms.UniformTemporalSubsample
              # Also used as trainer.model.trunk.img_size.1.
              num_samples: 16
            decoder: pyav
            normalize_to_0_1: true
            transforms:
              - _target_: torchvision.transforms.Compose
                transforms:
                  # transforms.0.transforms.0: `num_times` drives both this
                  # dataset's batch_size and its repeat factor above. Keep
                  # this entry at index 0.
                  - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                    base_transform:
                      _target_: omnivision.data.transforms.pytorchvideo.Replicate
                      num_times: 4
                  - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                    base_transform:
                      _target_: omnivision.data.transforms.transform_wrappers.ListTransform
                      base_transform:
                        _target_: torchvision.transforms.Compose
                        transforms:
                          - _target_: pytorchvideo.transforms.ShortSideScale
                            size: 256
                          - _target_: torchvision.transforms.RandomResizedCrop
                            size: 224
                          # Same normalization statistics as the image pipeline.
                          - _target_: torchvision.transforms._transforms_video.NormalizeVideo
                            mean: ${trainer.data.train.datasets.0.dataset.transforms.0.transforms.3.base_transform.mean}
                            std: ${trainer.data.train.datasets.0.dataset.transforms.0.transforms.3.base_transform.std}
              - _target_: omnivision.data.transforms.transform_wrappers.SingleFieldListToSampleList
                field: vision
              # Mask generation for videos (pred_ratio 0.9), wrapped in a
              # ListTransform.
              - _target_: omnivision.data.transforms.transform_wrappers.ListTransform
                base_transform:
                  _target_: omnivision.data.transforms.transform_wrappers.MaskingTransform
                  masking_object:
                    _target_: omnivision.data.transforms.mask_image_modeling.MaskImageModeling
                    pred_ratio: 0.9
                    pred_ratio_var: 0.0
                    pred_shape:
                      _target_: omnivision.data.transforms.mask_image_modeling.RandMasking
                    patch_size: ${trainer.model.trunk.patch_size}
          shuffle: true
          # base_batchsize_per_replica / num_times (32 / 4 = 8), cast to int.
          batch_size: ${int:${divide:${base_batchsize_per_replica},${trainer.data.train.datasets.1.dataset.transforms.0.transforms.0.base_transform.num_times}}}
          num_workers: 6
          pin_memory: true
          drop_last: false
          collate_fn:
            _target_: omnivision.data.api.SampleListOmnivoreCollator
            # Same name as the `ssv2` entry under trainer.loss.
            output_key: ssv2
            batch_kwargs:
              model_fwd_kwargs:
                use_checkpoint: true
          worker_init_fn: null
    # No validation data is configured for this run.
    val: null

  model:
    # Wrapper built from one `trunk` and a list of `heads`. `trunk_fields`
    # names the batch fields handed to the trunk (positional `vision`,
    # keyword `mask`); exact semantics are defined by MIMOHeadWrapper, which
    # is not visible in this file.
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    trunk:
      _target_: omnivision.models.vision_transformer.VisionTransformer
      # Four entries: 3, the video frame count (16), 224, 224 -- presumably
      # (channels, frames, height, width); TODO confirm. img_size.0 is reused
      # by the head's out_features below.
      img_size:
        - 3
        - ${trainer.data.train.datasets.1.dataset.frame_sampler.num_samples}
        - 224
        - 224
      embed_dim: 768
      depth: 12
      # Referenced by the masking transforms, the Conv3d patch embedding, the
      # head's out_features and the loss.
      patch_size: [2, 16, 16]
      classifier_feature: global_pool
      drop_path_rate: 0.0
      use_cls_token: false
      patch_embed_type: generic
      patch_embed_params_list:
        # Index 0 is referenced by trainer.loss.ssv2.pad_object; keep it first.
        - _target_: omnivision.models.PadIm2Video
          pad_type: repeat
          ntimes: 2
        - _target_: omnivision.models.make_conv_or_linear
          layer:
            _target_: torch.nn.Conv3d
            in_channels: 3
            # `....` climbs from this node up to `trunk`.
            out_channels: ${....embed_dim}
            kernel_size: ${....patch_size}
            # Stride equals kernel size, so patches do not overlap.
            stride: ${.kernel_size}
          init_weight:
            _target_: omnivision.models.reshape_and_init_as_mlp
          _recursive_: false
      attn_target:
        _target_: omnivision.models.vision_transformer.Attention
        _partial_: true
        num_heads: 12
        proj_drop: 0
        qk_scale: null
        qkv_bias: true
        attn_drop: 0
      learnable_pos_embed: false  # Use sinusoidal positional encoding
      masked_image_modeling: true
      patch_dropping: true
      decoder:
        _target_: omnivision.models.vision_transformer.Decoder
        _partial_: true
        embed_dim: ${trainer.model.trunk.embed_dim}
        decoder_depth: 4
        # Also the head's in_features below.
        decoder_embed_dim: 384
        learnable_pos_embed: false  # Use sinusoidal positional encoding
        attn_target:
          _target_: omnivision.models.vision_transformer.Attention
          _partial_: true
          num_heads: 16
          proj_drop: 0
          qk_scale: null
          qkv_bias: true
          attn_drop: 0
    heads:
      - head:
          _target_: omnivision.models.heads.mae_head.MAEHead
          in_features: ${trainer.model.trunk.decoder.decoder_embed_dim}
          # img_size.0 x patch_size.0 x patch_size.1 x patch_size.2
          # = 3 x 2 x 16 x 16 = 1536 (assuming the `times` resolver multiplies
          # its two arguments).
          out_features: ${times:${times:${times:${trainer.model.trunk.img_size.0},${trainer.model.trunk.patch_size.0}},${trainer.model.trunk.patch_size.1}},${trainer.model.trunk.patch_size.2}}
        input_key: null
        output_key: null
        fork_module: ""
    trunk_fields:
      - input_key: null
        args: ["vision"]
        kwargs: {"mask": "mask"}

  optim:
    # No gradient clipping is configured.
    gradient_clip: null
    amp:
      enabled: true
      amp_dtype: float16  # bfloat16 or float16

    optimizer:
      _target_: torch.optim.AdamW
      betas: [0.9, 0.95]
    # Per-option schedules. No learning rate is set on the optimizer above;
    # it comes from the `lr` scheduler here.
    options:
      lr:
        - scheduler:
            # Two phases, sized by `lengths` below: linear warm-up followed
            # by cosine decay to 0.
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                # Written with a decimal point so that every YAML 1.1 loader
                # parses it as a float (a bare `1e-6` is a string under
                # PyYAML's default resolver).
                start_value: 1.0e-6
                end_value: 1.6e-3  # 8e-4 in orig config
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                # Starts at the warm-up's end value (schedulers.0.end_value).
                start_value: ${..0.end_value}
                end_value: 0.0
            # Fractions of training: 0.05 x 800 epochs = 40 warm-up epochs
            # (with trainer.max_epochs: 800).
            lengths: [0.05, 0.95]
            interval_scaling: ['rescaled', 'fixed']
      weight_decay:
        # First entry: constant 0.05 with no parameter filter.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.05
        # Second entry: constant 0.0, selected by the `*.bias` name pattern
        # and the torch.nn.LayerNorm module class (presumably overriding the
        # entry above for those parameters -- confirm in the optimizer
        # construction code).
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*.bias'
          module_cls_names: ['torch.nn.LayerNorm']

  metrics: null

  loss:
    # Entries are keyed `ssv2` and `in1k`, the same names used as the
    # collators' `output_key` values in trainer.data.train (presumably how
    # the trainer pairs a batch with its loss -- TODO confirm).
    ssv2:
      _target_: omnivision.losses.mae_loss.MAELoss
      norm_pix_loss: true
      norm_pix_per_channel: true
      patch_size: ${trainer.model.trunk.patch_size}
      # [mean, std] of the image Normalize transform. These paths address
      # list items by index (datasets.0 -> transforms.0 -> transforms.3);
      # update them if the transform order changes.
      unnormalize_img:
        - ${trainer.data.train.datasets.0.dataset.transforms.0.transforms.3.base_transform.mean}
        - ${trainer.data.train.datasets.0.dataset.transforms.0.transforms.3.base_transform.std}
      # The PadIm2Video config (first patch_embed_params_list entry).
      pad_object: ${trainer.model.trunk.patch_embed_params_list.0}
    # Images use the same loss configuration as videos (sibling `ssv2` node).
    in1k: ${.ssv2}
