# Data pipeline: one masked-video training stream (k150); no validation set.
# NOTE(review): the absolute interpolations below address this tree as
# `trainer.data...`, so this section is presumably mounted under a `trainer`
# node that is outside this view -- confirm.
data:
  train:
    _target_: omnivision.data.concat_dataset.ConcatDataset
    # NOTE(review): `sum` presumably means the per-dataset step counts are
    # added up -- confirm against ConcatDataset.
    max_steps: sum
    # One factor per entry in `datasets`. It is derived from the Replicate
    # transform's `num_times` (4 below), i.e. 1 / num_times assuming the
    # `divide` resolver is plain division, so it tracks that value.
    repeat_factors:
      - ${divide:1,${trainer.data.train.datasets.0.dataset.transforms.0.transforms.0.base_transform.num_times}}
    datasets:
      - _target_: omnivision.data.torch_dataset.TorchDataset
        dataset:
          _target_: omnivision.data.path_dataset.VideoPathDataset
          # Video paths and labels come from parallel single-entry file lists.
          path_file_list:
            - ${k150_train_vids_path}
          label_file_list:
            - ${k150_train_labels_path}
          new_prefix: ${k150_prefix}
          clip_sampler:
            _target_: pytorchvideo.data.clip_sampling.RandomClipSampler
            # presumably seconds -- confirm against RandomClipSampler
            clip_duration: 2.7
          frame_sampler:
            _target_: pytorchvideo.transforms.UniformTemporalSubsample
            num_samples: ${num_frames}
          decoder: pyav
          normalize_to_0_1: true
          # WARNING: several interpolations in this config address entries of
          # the lists below by numeric index (`transforms.0.transforms.0`,
          # `transforms.0.transforms.1...transforms.2`). Reordering or
          # inserting entries silently retargets or breaks those references.
          transforms:
            - _target_: torchvision.transforms.Compose
              transforms:
                # Entry 0: Replicate. `num_times` is read by
                # `repeat_factors` above and by `batch_size` below.
                - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                  base_transform:
                    _target_: omnivision.data.transforms.pytorchvideo.Replicate
                    num_times: 4
                # Entry 1: scale -> random resized crop -> normalize, wrapped
                # in ListTransform (presumably applied to each replica
                # independently -- confirm).
                - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
                  base_transform:
                    _target_: omnivision.data.transforms.transform_wrappers.ListTransform
                    base_transform:
                      _target_: torchvision.transforms.Compose
                      transforms:
                        - _target_: pytorchvideo.transforms.ShortSideScale
                          size: 256
                        - _target_: torchvision.transforms.RandomResizedCrop
                          size: 224
                        # Inner entry 2: `mean` and `std` are re-read by
                        # `loss.k150.unnormalize_img` through an index path.
                        - _target_: torchvision.transforms._transforms_video.NormalizeVideo
                          mean: [0.485, 0.456, 0.406]
                          std: [0.229, 0.224, 0.225]
            - _target_: omnivision.data.transforms.transform_wrappers.SingleFieldListToSampleList
              field: vision
            # Masking step, again wrapped in ListTransform.
            - _target_: omnivision.data.transforms.transform_wrappers.ListTransform
              base_transform:
                _target_: omnivision.data.transforms.transform_wrappers.MaskingTransform
                masking_object:
                  _target_: omnivision.data.transforms.mask_image_modeling.MaskImageModeling
                  # presumably the fraction of patches to mask, with zero
                  # variance around it -- confirm against MaskImageModeling
                  pred_ratio: 0.9
                  pred_ratio_var: 0.0
                  pred_shape:
                    _target_: omnivision.data.transforms.mask_image_modeling.RandMasking
                  # Same patch size as the model trunk.
                  patch_size: ${trainer.model.trunk.patch_size}
        shuffle: true
        # Global `batch_size` divided by Replicate's `num_times`, cast to int.
        batch_size: ${int:${divide:${batch_size},${trainer.data.train.datasets.0.dataset.transforms.0.transforms.0.base_transform.num_times}}}
        num_workers: ${num_workers}
        pin_memory: true
        drop_last: false
        collate_fn:
          _target_: omnivision.data.api.SampleListOmnivoreCollator
          # Same string as the `loss.k150` key.
          output_key: k150
          batch_kwargs:
            model_fwd_kwargs:
              # presumably enables activation checkpointing in the model
              # forward pass -- confirm
              use_checkpoint: true
        worker_init_fn: null
  # No validation dataset is configured.
  val: null

# No metrics are configured.
metrics: null

# Loss modules keyed by name. `k150` is the same string as the train
# collator's `output_key`.
loss:
  k150:
    _target_: omnivision.losses.mae_loss.MAELoss
    norm_pix_loss: true
    norm_pix_per_channel: true
    # Same patch size as the model trunk.
    patch_size: ${trainer.model.trunk.patch_size}
    # [mean, std], read from the NormalizeVideo entry of the train transforms
    # through an index-based path (outer Compose entry 1 -> inner Compose
    # entry 2). Reordering those transform lists breaks these references.
    unnormalize_img:
      - ${trainer.data.train.datasets.0.dataset.transforms.0.transforms.1.base_transform.base_transform.transforms.2.mean}
      - ${trainer.data.train.datasets.0.dataset.transforms.0.transforms.1.base_transform.base_transform.transforms.2.std}
    # First entry of the trunk's patch-embedding parameter list.
    pad_object: ${trainer.model.trunk.patch_embed_params_list.0}