# @package _global_

# Compose this experiment on top of the shared base experiment config.
defaults:
  - /experiments/base.yaml
  # `_self_` comes last so that values set in this file take precedence
  # over the ones inherited from the base config.
  - _self_

# Job launcher overrides: one node, no explicit per-node GPU count.
launcher:
  # null here; what the launcher does with a null GPU count is decided
  # outside this file.
  gpus_per_node: null
  num_nodes: 1

# Trainer overrides; the data, model, optim, metrics and loss sections
# below are all nested under this key.
trainer:
  max_epochs: 5
  # Disabled override, kept for reference:
  # accelerator: "cpu"

  # Train and val loaders. Both wrap a SyntheticDataset (length 10,
  # tensor_shape [3, 224, 224]) and collate under the `synth_image` key.
  data:
    train:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.synthetic_dataset.SyntheticDataset
        tensor_shape: [3, 224, 224]
        length: 10
      shuffle: true
      batch_size: 2
      num_workers: 2
      pin_memory: true
      drop_last: true
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: synth_image
        # Batch-level augmentation; configured for the train split only.
        batch_transforms:
          - _target_: omnivision.data.transforms.cutmixup.CutMixUp
            # Mixup alpha value; mixup is active if > 0.
            mixup_alpha: 0.8
            # Cutmix alpha value; cutmix is active if > 0.
            cutmix_alpha: 1.0
            # Probability of applying mixup or cutmix per batch or element.
            prob: 1.0
            # Probability of switching to cutmix instead of mixup when both
            # are active.
            switch_prob: 0.5
            # How mixup/cutmix params are applied: per 'batch', 'pair'
            # (pair of elements), or 'elem' (element).
            mode: batch
            # Apply lambda correction when the cutmix bbox is clipped by
            # the image borders.
            correct_lam: true
            # Label smoothing applied to the mixed target tensor.
            label_smoothing: 0.1
            # Number of classes for the target.
            num_classes: 1000
      worker_init_fn: null
    val:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        _target_: omnivision.data.synthetic_dataset.SyntheticDataset
        tensor_shape: [3, 224, 224]
        length: 10
      shuffle: false
      batch_size: 2
      num_workers: 2
      pin_memory: true
      drop_last: true
      # Same collator as train, but without any batch_transforms.
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        output_key: synth_image
      worker_init_fn: null

  # MIMOHeadWrapper around a SwinTransformer3D trunk with a single
  # torch.nn.Linear head (768 -> 1000).
  model:
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    trunk:
      _target_: omnivision.models.swin_transformer.SwinTransformer3D
      # No pretrained weights are configured.
      pretrained: null
      pretrained2d: false
      patch_size: [1, 4, 4]
      embed_dim: 96
      depths: [2, 2, 6, 2]
      num_heads: [3, 6, 12, 24]
      window_size: [1, 7, 7]
      mlp_ratio: 4.0
      qkv_bias: true
      qk_scale: null
      drop_rate: 0.0
      attn_drop_rate: 0.0
      drop_path_rate: 0.2
      patch_norm: true
    heads:
      - head:
          # init_parameters receives the Linear layer plus one initializer
          # per parameter name.
          _target_: omnivision.model.model_init_utils.init_parameters
          model:
            _target_: torch.nn.Linear
            # 8 * 96, where 96 is the trunk's embed_dim.
            in_features: 768
            out_features: 1000
          init_fns:
            # `_partial_: true` makes Hydra build a functools.partial of
            # the target instead of calling it at instantiation time.
            weight:
              _target_: torch.nn.init.normal_
              _partial_: true
              mean: 0
              std: 0.01
            bias:
              _target_: torch.nn.init.zeros_
              _partial_: true
        fork_module: ""
        input_key: null
        output_key: null
    trunk_fields:
      - input_key: null
        args: ["vision"]

  # AdamW with a scheduled learning rate and two weight-decay entries.
  optim:
    optimizer:
      _target_: torch.optim.AdamW
    options:
      lr:
        - scheduler:
            # Linear ramp followed by cosine decay. Values are written
            # with a decimal point and a signed exponent so that every
            # YAML loader resolves them as floats; a mantissa without a
            # dot (e.g. `10e-7`) is a plain string under strict YAML 1.1.
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 1.0e-6
                end_value: 1.0e-3
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                start_value: 1.0e-3
                end_value: 1.0e-5
            # Fraction of training given to each scheduler above, in order.
            lengths: [0.07, 0.93]
            interval_scaling: ['rescaled', 'rescaled']
      weight_decay:
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.05
        # Zero weight decay for the parameters selected below.
        # NOTE(review): presumably the entry above, which has no
        # selectors, covers all remaining parameters - confirm against
        # the optimizer construction code.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            # TODO should we allow unused param names or classes?
            - '*.bias'
            # - '*.pos_embed'
            # - '*.cls_token'
            # - '*.absolute_pos_embed'
            - '*.relative_position_bias_table'
            # - '*.norm'
          # TODO: allow other forms for class names
          module_cls_names: ['torch.nn.LayerNorm']
  
  # Top-1 and top-5 accuracy for both splits, keyed by `synth_image`
  # (the same name the data collators use as their output_key).
  metrics:
    train:
      synth_image:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
    # Same metric set as train.
    val:
      synth_image:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
  # Loss for the `synth_image` key: cross-entropy with default arguments.
  loss:
    synth_image:
      _target_: torch.nn.CrossEntropyLoss
