# @package _global_

# Hydra defaults list (order-sensitive): the shared base experiment config is
# composed first, then this file (`_self_`), so keys set below override it.
defaults:
  - /experiments/base.yaml
  - _self_

# Job launch topology: 8 nodes x 8 GPUs per node (64 GPUs if every slot is used).
launcher:
  gpus_per_node: 8
  num_nodes: 8

# Everything below configures the trainer: schedule length, distributed
# options, data loaders, model, optimizer, metrics and losses.
trainer:
  max_epochs: 300

  distributed:
    # NOTE(review): presumably forwarded to DistributedDataParallel; likely
    # needed because each batch only exercises one of the heads - confirm.
    find_unused_parameters: true

  data:
    train:
      # Training stream assembled from the three loaders listed under
      # `datasets` (output keys in1k, sunrgbd, k400, in that order).
      _target_: omnivision.data.concat_dataset.ConcatDataset
      max_steps: sum
      # One factor per entry of `datasets`, in list order; the second entry
      # (sunrgbd) gets 50.0.
      # NOTE(review): presumably an oversampling factor - confirm in ConcatDataset.
      repeat_factors:
        - 1.0
        - 50.0
        - 1.0
      datasets:
      - _target_: omnivision.data.torch_dataset.TorchDataset
        # Training loader for RGB images; batches are emitted under `in1k`.
        dataset:
          _target_: omnivision.data.path_dataset.ImagePathDataset
          path_file_list:
            - ${in1k_train_imgs_path}
          label_file_list:
            - ${in1k_train_labels_path}
          new_prefix: ${in1k_prefix}
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                # Applied in list order: crop/flip/augment, then ToTensor,
                # then RandomErasing and Normalize.
                transforms:
                  - _target_: torchvision.transforms.RandomResizedCrop
                    size: 224
                    interpolation: 3  # integer code 3 is PIL's bicubic filter
                  - _target_: torchvision.transforms.RandomHorizontalFlip
                  # Essentially the AutoAugment policy string rand-m9-mstd0.5-inc1.
                  - _target_: omnivision.data.transforms.rand_auto_aug.RandAugment
                    magnitude: 9
                    magnitude_std: 0.5
                    increasing_severity: true
                  - _target_: torchvision.transforms.ColorJitter
                    brightness: 0.4
                    contrast: 0.4
                    saturation: 0.4
                    hue: 0.4
                  - _target_: torchvision.transforms.ToTensor
                  - _target_: torchvision.transforms.RandomErasing
                    p: 0.25
                  - _target_: torchvision.transforms.Normalize
                    mean: [0.485, 0.456, 0.406]
                    std: [0.229, 0.224, 0.225]
        shuffle: true
        batch_size: 32
        num_workers: 5
        pin_memory: true
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          output_key: in1k
          batch_transforms:
            - _target_: omnivision.data.transforms.cutmixup.CutMixUp
              # Mixup alpha; mixup is active when this is > 0.
              mixup_alpha: 0.8
              # CutMix alpha; cutmix is active when this is > 0.
              cutmix_alpha: 1.0
              # Probability of applying mixup or cutmix (per batch or element).
              prob: 1.0
              # With both active, probability of choosing cutmix over mixup.
              switch_prob: 0.5
              # How mixup/cutmix params are applied: 'batch', 'pair' (pair of
              # elements) or 'elem' (element).
              mode: batch
              # Correct lambda when the cutmix box is clipped by image borders.
              correct_lam: true
              # Label smoothing applied to the mixed target tensor.
              label_smoothing: 0.1
              # Number of target classes; matches the in1k head's out_features.
              num_classes: 1000
        worker_init_fn: null
      - _target_: omnivision.data.torch_dataset.TorchDataset
        # Training loader for images with a paired depth map; batches are
        # emitted under `sunrgbd`.
        dataset:
          _target_: omnivision.data.path_dataset.ImageWithDepthPathDataset
          path_file_list:
            - ${sunrgbd_train_imgs_path}
          label_file_list:
            - ${sunrgbd_train_labels_path}
          depth_path_file_list:
            - ${sunrgbd_train_depth_imgs_path}
          # RGB and depth files share the same path prefix.
          new_prefix: ${sunrgbd_prefix}
          new_depth_prefix: ${sunrgbd_prefix}
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                # Applied in list order. There is no ToTensor step here.
                # NOTE(review): presumably the dataset already yields a
                # 4-channel (RGB + depth) tensor - confirm.
                transforms:
                  - _target_: omnivision.data.transforms.image_rgbd.DepthNorm
                    max_depth: 75
                    clamp_max_before_scale: true
                  - _target_: torchvision.transforms.RandomResizedCrop
                    size: 224
                    interpolation: 2  # integer code 2 is PIL's bilinear filter
                  - _target_: torchvision.transforms.RandomHorizontalFlip
                  # Original author's note: essentially the AutoAugment policy
                  # string rand-m9-mstd0.5-inc1.
                  - _target_: omnivision.data.transforms.image_rgbd.RandAugment3d
                    num_ops: 2
                    magnitude: 9
                    interpolation: 2
                  - _target_: omnivision.data.transforms.image_rgbd.ColorJitter3d
                    brightness: 0.4
                    contrast: 0.4
                    saturation: 0.4
                    hue: 0.4
                  - _target_: torchvision.transforms.RandomErasing
                    p: 0.25
                  # Four channels: the last one uses mean 0 / std 1, so
                  # Normalize leaves it unchanged.
                  - _target_: torchvision.transforms.Normalize
                    mean: [0.485, 0.456, 0.406, 0.0]
                    std: [0.229, 0.224, 0.225, 1.0]
                  # NOTE(review): presumably drops channels 0-2 together with
                  # probability 0.5 and never drops channel 3 - confirm.
                  - _target_: omnivision.data.transforms.image_rgbd.DropChannels
                    channel_probs: [0.5, 0.5, 0.5, 0]
                    tie_channels: [0, 1, 2]
                    fill_values: [0, 0, 0, 0]
        shuffle: true
        batch_size: 32
        num_workers: 5
        pin_memory: true
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          output_key: sunrgbd
          batch_transforms:
            - _target_: omnivision.data.transforms.cutmixup.CutMixUp
              # Mixup alpha; mixup is active when this is > 0.
              mixup_alpha: 0.8
              # CutMix alpha; cutmix is active when this is > 0.
              cutmix_alpha: 1.0
              # Probability of applying mixup or cutmix (per batch or element).
              prob: 1.0
              # With both active, probability of choosing cutmix over mixup.
              switch_prob: 0.5
              # How mixup/cutmix params are applied: 'batch', 'pair' (pair of
              # elements) or 'elem' (element).
              mode: batch
              # Correct lambda when the cutmix box is clipped by image borders.
              correct_lam: true
              # Label smoothing applied to the mixed target tensor.
              label_smoothing: 0.1
              # Number of target classes; matches the sunrgbd head's out_features.
              num_classes: 19
        worker_init_fn: null
      - _target_: omnivision.data.torch_dataset.TorchDataset
        # Training loader for video clips; batches are emitted under `k400`.
        dataset:
          _target_: omnivision.data.path_dataset.VideoPathDataset
          path_file_list:
            - ${k400_train_vids_path}
          label_file_list:
            - ${k400_train_labels_path}
          new_prefix: ${k400_prefix}
          # A random clip of duration 2 is sampled per video and uniformly
          # subsampled to 32 frames.
          clip_sampler:
            _target_: pytorchvideo.data.clip_sampling.RandomClipSampler
            clip_duration: 2
          frame_sampler:
            _target_: pytorchvideo.transforms.UniformTemporalSubsample
            num_samples: 32
          decoder: pyav
          normalize_to_0_1: true
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                # Applied in list order: scale short side to 256, random
                # 224 crop, random flip, per-channel normalization.
                transforms:
                  - _target_: pytorchvideo.transforms.ShortSideScale
                    size: 256
                  - _target_: torchvision.transforms.RandomResizedCrop
                    size: 224
                  - _target_: torchvision.transforms.RandomHorizontalFlip
                    p: 0.5
                  - _target_: torchvision.transforms._transforms_video.NormalizeVideo
                    mean: [0.485, 0.456, 0.406]
                    std: [0.229, 0.224, 0.225]
        shuffle: true
        batch_size: 32
        num_workers: 3
        pin_memory: true
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          output_key: k400
          # Extra per-batch settings attached by the collator.
          # NOTE(review): presumably enables activation checkpointing in the
          # model forward and splits the batch into 4 accumulation steps -
          # confirm against the trainer.
          batch_kwargs:
            model_fwd_kwargs:
              use_checkpoint: true
            accum_steps: 4
          batch_transforms:
            - _target_: omnivision.data.transforms.cutmixup.CutMixUp
              # Mixup alpha; mixup is active when this is > 0.
              mixup_alpha: 0.8
              # CutMix alpha; cutmix is active when this is > 0.
              cutmix_alpha: 1.0
              # Probability of applying mixup or cutmix (per batch or element).
              prob: 1.0
              # With both active, probability of choosing cutmix over mixup.
              switch_prob: 0.5
              # How mixup/cutmix params are applied: 'batch', 'pair' (pair of
              # elements) or 'elem' (element).
              mode: batch
              # Correct lambda when the cutmix box is clipped by image borders.
              correct_lam: true
              # Label smoothing applied to the mixed target tensor.
              label_smoothing: 0.1
              # Number of target classes; matches the k400 head's out_features.
              num_classes: 400
        worker_init_fn: null
    # Validation stream: the same three output keys as training (in1k, sunrgbd,
    # k400), with no repeat_factors and no batch-level CutMixUp transforms.
    val:
      _target_: omnivision.data.concat_dataset.ConcatDataset
      max_steps: sum
      datasets:
      - _target_: omnivision.data.torch_dataset.TorchDataset
        # Validation loader for RGB images; batches are emitted under `in1k`.
        dataset:
          _target_: omnivision.data.path_dataset.ImagePathDataset
          path_file_list:
            - ${in1k_val_imgs_path}
          label_file_list:
            - ${in1k_val_labels_path}
          new_prefix: ${in1k_prefix}
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                # Deterministic pipeline: resize, center crop, tensor
                # conversion, per-channel normalization.
                transforms:
                  - _target_: torchvision.transforms.Resize
                    size: 224
                    interpolation: 3  # integer code 3 is PIL's bicubic filter
                  - _target_: torchvision.transforms.CenterCrop
                    size: 224
                  - _target_: torchvision.transforms.ToTensor
                  - _target_: torchvision.transforms.Normalize
                    mean: [0.485, 0.456, 0.406]
                    std: [0.229, 0.224, 0.225]
        shuffle: false
        batch_size: 64
        num_workers: 4
        pin_memory: true
        # NOTE(review): drop_last on a validation loader discards the final
        # partial batch - confirm this is intended.
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          output_key: in1k
        worker_init_fn: null
      - _target_: omnivision.data.torch_dataset.TorchDataset
        # Validation loader for images with a paired depth map; batches are
        # emitted under `sunrgbd`.
        dataset:
          _target_: omnivision.data.path_dataset.ImageWithDepthPathDataset
          path_file_list:
            - ${sunrgbd_val_imgs_path}
          label_file_list:
            - ${sunrgbd_val_labels_path}
          # Depth maps must come from the validation split so they pair up
          # with the val images and labels above (previously this pointed at
          # the *train* depth list). The key follows the naming of the sibling
          # sunrgbd_* variables and has to be defined alongside them,
          # presumably in /experiments/base.yaml - confirm.
          depth_path_file_list:
            - ${sunrgbd_val_depth_imgs_path}
          # RGB and depth files share the same path prefix.
          new_prefix: ${sunrgbd_prefix}
          new_depth_prefix: ${sunrgbd_prefix}
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                # Deterministic pipeline: depth normalization, resize, center
                # crop, per-channel normalization.
                transforms:
                  - _target_: omnivision.data.transforms.image_rgbd.DepthNorm
                    max_depth: 75
                    clamp_max_before_scale: true
                  # NOTE(review): the training loader for this dataset uses
                  # interpolation 2 while validation uses 3 - confirm intended.
                  - _target_: torchvision.transforms.Resize
                    size: 224
                    interpolation: 3
                  - _target_: torchvision.transforms.CenterCrop
                    size: 224
                  # Four channels: the last one uses mean 0 / std 1, so
                  # Normalize leaves it unchanged.
                  - _target_: torchvision.transforms.Normalize
                    mean: [0.485, 0.456, 0.406, 0.0]
                    std: [0.229, 0.224, 0.225, 1.0]
        # Validation order is fixed, matching the other two val loaders.
        shuffle: false
        batch_size: 64
        num_workers: 4
        pin_memory: true
        # NOTE(review): drop_last on a validation loader discards the final
        # partial batch - confirm this is intended.
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          output_key: sunrgbd
        worker_init_fn: null
      - _target_: omnivision.data.torch_dataset.TorchDataset
        # Validation loader for video clips; batches are emitted under `k400`.
        dataset:
          _target_: omnivision.data.path_dataset.VideoPathDataset
          path_file_list:
            - ${k400_val_vids_path}
          label_file_list:
            - ${k400_val_labels_path}
          new_prefix: ${k400_prefix}
          # One clip of duration 10 per video, uniformly subsampled to 160
          # frames before the transforms below cut it into shorter clips.
          clip_sampler:
            _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
            clip_duration: 10
            clips_per_video: 1
          frame_sampler:
            _target_: pytorchvideo.transforms.UniformTemporalSubsample
            num_samples: 160
          decoder: pyav
          normalize_to_0_1: true
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                # Applied in list order.
                # NOTE(review): presumably TemporalCrop and SpatialCrop expand
                # each video into a list of 32-frame, 224x224 views (3 spatial
                # crops each) - confirm in omnivision.data.transforms.pytorchvideo.
                transforms:
                  - _target_: pytorchvideo.transforms.ShortSideScale
                    size: 224
                  - _target_: torchvision.transforms._transforms_video.NormalizeVideo
                    mean: [0.485, 0.456, 0.406]
                    std: [0.229, 0.224, 0.225]
                  - _target_: omnivision.data.transforms.pytorchvideo.TemporalCrop
                    frames_per_clip: 32
                    stride: 40
                  - _target_: omnivision.data.transforms.pytorchvideo.SpatialCrop
                    crop_size: 224
                    num_crops: 3
        shuffle: false
        batch_size: 1
        num_workers: 4
        pin_memory: true
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          output_key: k400
        worker_init_fn: null


  model:
    # Wrapper combining one trunk with the per-dataset heads under `heads`.
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    handle_list_inputs: true
    # Single 3D Swin Transformer trunk, built without 2D pretrained weights.
    trunk:
      _target_: omnivision.models.swin_transformer.SwinTransformer3D
      pretrained2d: false
      patch_size: [2, 4, 4]
      embed_dim: 128
      # Four stages: one depth and one attention-head count per stage.
      depths: [2, 2, 18, 2]
      num_heads: [4, 8, 16, 32]
      window_size: [16, 7, 7]
      mlp_ratio: 4.0
      qkv_bias: true
      qk_scale: null
      # Plain and attention dropout are off; only stochastic depth
      # (drop_path_rate) is non-zero.
      drop_rate: 0.0
      attn_drop_rate: 0.0
      drop_path_rate: 0.3
      patch_norm: true
      # NOTE(review): presumably controls how depth tokens are merged with
      # RGB tokens - confirm in SwinTransformer3D.
      depth_mode: summed_rgb_d_tokens
    # One classification head per dataset. `input_key` / `output_key` use the
    # same names as the collators' output_key values (in1k, sunrgbd, k400).
    # Every head takes 1024 input features.
    # NOTE(review): 1024 is presumably the trunk's final feature width - confirm.
    heads:
      # in1k: linear classifier with 1000 outputs.
      - head:
          _target_: omnivision.model.model_init_utils.init_parameters
          model:
            _target_: torch.nn.Linear
            in_features: 1024
            out_features: 1000
          # Partially-applied initializers: N(0, 0.01) weights, zero bias.
          init_fns:
            weight:
              _target_: torch.nn.init.normal_
              _partial_: true
              mean: 0
              std: 0.01
            bias:
              _target_: torch.nn.init.zeros_
              _partial_: true
        fork_module: ""
        input_key: in1k
        output_key: in1k
      # sunrgbd: linear classifier with 19 outputs.
      - head:
          _target_: omnivision.model.model_init_utils.init_parameters
          model:
            _target_: torch.nn.Linear
            in_features: 1024
            out_features: 19
          # Partially-applied initializers: N(0, 0.01) weights, zero bias.
          init_fns:
            weight:
              _target_: torch.nn.init.normal_
              _partial_: true
              mean: 0
              std: 0.01
            bias:
              _target_: torch.nn.init.zeros_
              _partial_: true
        fork_module: ""
        input_key: sunrgbd
        output_key: sunrgbd
      # k400: Dropout(p=0.5) followed by a linear classifier with 400 outputs.
      - head:
          _target_: torch.nn.Sequential
          _args_:
            - _target_: torch.nn.Dropout
              p: 0.5
            - _target_: omnivision.model.model_init_utils.init_parameters
              model:
                _target_: torch.nn.Linear
                in_features: 1024
                out_features: 400
              # Partially-applied initializers: N(0, 0.01) weights, zero bias.
              init_fns:
                weight:
                  _target_: torch.nn.init.normal_
                  _partial_: true
                  mean: 0
                  std: 0.01
                bias:
                  _target_: torch.nn.init.zeros_
                  _partial_: true
        fork_module: ""
        input_key: k400
        output_key: k400
    # NOTE(review): presumably selects which batch field(s) are passed to the
    # trunk; a null input_key looks like "applies to every input key" - confirm
    # in MIMOHeadWrapper.
    trunk_fields:
      - input_key: null
        args:
          - vision
  optim:
    # No gradient clipping and no mixed precision.
    gradient_clip: null

    amp:
      enabled: false
      amp_dtype: float16  # bfloat16 or float16

    optimizer:
      _target_: torch.optim.AdamW
    options:
      # Learning rate: three phases covering 10% / 80% / 10% of training
      # (`lengths`), each rescaled to its own interval:
      #   1. linear warm-up  2e-6 -> 2e-3
      #   2. cosine decay    2e-3 -> 2e-4
      #   3. linear decay    2e-4 -> 1e-6
      lr:
        - scheduler:
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 2.0e-6
                end_value: 2.0e-3
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                start_value: 2.0e-3
                end_value: 2.0e-4
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 2.0e-4
                end_value: 1.0e-6
            lengths: [0.1, 0.8, 0.1]
            interval_scaling:
              - rescaled
              - rescaled
              - rescaled
      weight_decay:
        # Entry without any parameter filter: constant weight decay of 0.05.
        # NOTE(review): presumably the default for all remaining parameters - confirm.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.05
        # Zero weight decay for the parameters matched by the name patterns
        # and module classes listed below.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*.bias'
            # - '*.pos_embed'
            # - '*.cls_token'
            # - '*.absolute_pos_embed'
            - '*.relative_position_bias_table'
            # - '*.norm'
          module_cls_names:
            - torch.nn.LayerNorm
  # Top-1 / top-5 accuracy meters per phase (train, val) and per output key.
  # The image keys (in1k, sunrgbd) use the plain Accuracy meter; k400 uses
  # AvgPooledAccuracyListMeter in both phases.
  # NOTE(review): presumably the list meter averages predictions over the
  # multiple views of a video - confirm in omnivision.metrics.
  metrics:
    train:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
      sunrgbd:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
      k400:
        accuracy_top1:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 5
    val:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
      sunrgbd:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
      k400:
        accuracy_top1:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 5
  # One loss per output key; the keys match the head / collator output_key
  # values. All three use cross-entropy with default arguments.
  loss:
    in1k:
      _target_: torch.nn.CrossEntropyLoss
    sunrgbd:
      _target_: torch.nn.CrossEntropyLoss
    k400:
      _target_: torch.nn.CrossEntropyLoss
