# @package _global_

defaults:
  - /experiments/base.yaml
  - _self_
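# Hydra note: because _self_ comes after /experiments/base.yaml in the defaults
# list, values defined in this file override the base config.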

launcher:
  gpus_per_node: 2
  num_nodes: 1
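# With 1 node x 2 GPUs and the per-GPU batch_size of 64 set below, the effective
# global batch size is 2 * 64 = 128 (assuming one training process per GPU).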

trainer:
  max_epochs: 5

  distributed:
    comms_dtype: float16 # NULL, float16, bfloat16
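    # comms_dtype presumably sets the dtype for gradient communication (e.g. the
    # DDP all-reduce): float16/bfloat16 halve bandwidth relative to float32 at
    # some cost in gradient precision; NULL keeps the parameters' native dtype.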

  limit_train_batches: 15
  limit_val_batches: 15

  data:
    train:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.path_dataset.ImagePathDataset
        path_file_list:
          - ${in1k_train_imgs_path}
        label_file_list:
          - ${in1k_train_labels_path}
        new_prefix: ${in1k_prefix}
        transforms:
          - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
            base_transform:
              _target_: torchvision.transforms.Compose
              transforms:
                - _target_: torchvision.transforms.RandomResizedCrop
                  size: 224
                  interpolation: 3 # 3 = PIL bicubic
                - _target_: torchvision.transforms.RandomHorizontalFlip
                - _target_: omnivision.data.transforms.rand_auto_aug.RandAugment  # Essentially AutoAugment rand-m9-mstd0.5-inc1
                  magnitude: 9
                  magnitude_std: 0.5
                  increasing_severity: True
                - _target_: torchvision.transforms.ColorJitter
                  brightness: 0.4
                  contrast: 0.4
                  saturation: 0.4
                  hue: 0.4
                - _target_: torchvision.transforms.ToTensor
                - _target_: torchvision.transforms.RandomErasing
                  p: 0.25
                - _target_: torchvision.transforms.Normalize
                  mean: [0.485, 0.456, 0.406]
                  std: [0.229, 0.224, 0.225]
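                # The stack above (RandomResizedCrop + flip + RandAugment
                # rand-m9-mstd0.5-inc1 + ColorJitter 0.4 + RandomErasing 0.25)
                # roughly follows the DeiT ImageNet training recipe; the
                # Normalize constants are the standard ImageNet-1k mean/std.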
      shuffle: True
      batch_size: 64
      num_workers: 8
      pin_memory: True
      drop_last: True
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: in1k
        batch_transforms:
        - _target_: omnivision.data.transforms.cutmixup.CutMixUp
          mixup_alpha: 0.8 # mixup alpha value, mixup is active if > 0.
          cutmix_alpha: 1.0 # cutmix alpha value, cutmix is active if > 0.
          prob: 1.0 # probability of applying mixup or cutmix per batch or element
          switch_prob: 0.5 # probability of switching to cutmix instead of mixup when both are active
          mode: batch # how to apply mixup/cutmix params: per 'batch', 'pair' (pair of elements), or 'elem' (element)
          correct_lam: True # apply lambda correction when cutmix bbox clipped by image borders
          label_smoothing: 0.1 # apply label smoothing to the mixed target tensor
          num_classes: 1000 # number of classes for target
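          # Sketch of the standard mixup/cutmix behavior this configures: with
          # lam ~ Beta(alpha, alpha), mixup computes
          #   mixed_x = lam * x + (1 - lam) * x[perm]
          # while cutmix pastes a (1 - lam)-area box from x[perm] into x. The
          # target becomes a lam-weighted soft tensor of size num_classes (with
          # label smoothing applied), so the loss below must accept soft targets.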
      worker_init_fn: NULL
    val:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.path_dataset.ImagePathDataset
        path_file_list:
          - ${in1k_val_imgs_path}
        label_file_list:
          - ${in1k_val_labels_path}
        new_prefix: ${in1k_prefix}
        transforms:
          - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
            base_transform:
              _target_: torchvision.transforms.Compose
              transforms:
                - _target_: torchvision.transforms.Resize
                  size: 224
                  interpolation: 3
                - _target_: torchvision.transforms.CenterCrop
                  size: 224
                - _target_: torchvision.transforms.ToTensor
                - _target_: torchvision.transforms.Normalize
                  mean: [0.485, 0.456, 0.406]
                  std: [0.229, 0.224, 0.225]
      shuffle: False
      batch_size: 64
      num_workers: 8
      pin_memory: True
      drop_last: True
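      # Note: drop_last on the val loader discards any final partial batch, so
      # validation metrics cover a multiple of batch_size samples per rank.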
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: in1k
      worker_init_fn: NULL

  model:
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    trunk:
      _target_: omnivision.models.swin_transformer.SwinTransformer3D
      pretrained: NULL
      pretrained2d: False
      patch_size: [1, 4, 4]
      embed_dim: 96
      depths: [2, 2, 6, 2]
      num_heads: [3, 6, 12, 24]
      window_size: [1, 7, 7]
      mlp_ratio: 4.
      qkv_bias: True
      qk_scale: NULL
      drop_rate: 0.
      attn_drop_rate: 0.
      drop_path_rate: 0.2
      patch_norm: True
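      # Swin stage widths double per stage (96 -> 192 -> 384 -> 768), so the
      # final feature dim is embed_dim * 2**3 = 768, matching the head's
      # in_features below. A patch_size/window_size of 1 in the leading
      # (temporal) dimension treats images as single-frame videos.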
    heads:
    - head:
        _target_: omnivision.model.model_init_utils.init_parameters
        model:
          _target_: torch.nn.Linear
          in_features: 768  # 8 * 96
          out_features: 1000
        init_fns:
          weight:
            _target_: torch.nn.init.normal_
            _partial_: True
            mean: 0
            std: 0.01
          bias:
            _target_: torch.nn.init.zeros_
            _partial_: True
      fork_module: ""
      input_key: NULL
      output_key: NULL
    trunk_fields:
      - input_key: NULL
        args: ["vision"]
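    # Rough reading of the wrapper wiring (exact semantics live in
    # omnivision.model.model_wrappers): trunk_fields picks which fields of each
    # batch feed the trunk (the "vision" field here), and each head attaches at
    # the submodule named by fork_module, with "" presumably meaning the trunk's
    # final output.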
  optim:
    gradient_clip:
      _partial_: True
      _target_: torch.nn.utils.clip_grad_norm_
      max_norm: 1.0
      norm_type: 2
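      # _partial_: True makes Hydra build this as
      # functools.partial(torch.nn.utils.clip_grad_norm_, max_norm=1.0, norm_type=2),
      # leaving the trainer to supply the parameters at call time.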

    amp:
      enabled: True
      amp_dtype: float16 # bfloat16 or float16
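      # float16 autocast generally needs loss scaling (torch.cuda.amp.GradScaler)
      # to avoid gradient underflow; bfloat16 does not, but requires Ampere-class
      # or newer GPUs. The trainer presumably manages the scaler when enabled.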

    optimizer:
      _target_: torch.optim.AdamW

    options:
      lr:
        - scheduler:
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 10e-7
                end_value: 10e-4
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                start_value: 10e-4
                end_value: 10e-6
            lengths: [0.07, 0.93]
            interval_scaling: ['rescaled', 'rescaled']
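            # Net LR schedule: linear warmup from 10e-7 (= 1e-6) to 10e-4
            # (= 1e-3) over the first 7% of training, then cosine decay from
            # 1e-3 down to 10e-6 (= 1e-5) over the remaining 93%. 'rescaled'
            # stretches each sub-scheduler's own [0, 1] domain to fit its
            # interval.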
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.05
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            - '*.bias'
            #- '*.pos_embed'
            #- '*.cls_token'
            #- '*.absolute_pos_embed'
            - '*.relative_position_bias_table'
            #- '*.norm'
          module_cls_names: ['torch.nn.LayerNorm']
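        # The entries presumably apply in order, with the more specific second
        # entry overriding the first for matched parameters: biases, Swin
        # relative position bias tables, and all LayerNorm parameters get
        # wd = 0.0, and everything else gets wd = 0.05 (the usual no-decay
        # split for transformers).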

  metrics:
    train:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
    val:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5

  loss:
    in1k:
      _target_: torch.nn.CrossEntropyLoss
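      # Since CutMixUp (with label_smoothing) emits soft num_classes-dim
      # targets, this relies on torch.nn.CrossEntropyLoss accepting
      # class-probability targets, supported from PyTorch 1.10 onward.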
