# @package _global_
# Sizes and counts. None of these four are referenced elsewhere in this file;
# presumably they are consumed by the configs pulled in via `defaults` -- confirm.
num_frames: 16
batch_size: 16
num_workers: 10
video_labels: 400

# `epochs` is interpolated into trainer.max_epochs.
epochs: 50
accum_step: 1

# Learning-rate factors used by the expressions under trainer.optim.options.lr:
# `lr_scale_factor` divides the lr start/end values, `mixing_lr_factor`
# multiplies the values of the first lr entry only.
lr_scale_factor: 1  # assume bsize=4096
mixing_lr_factor: 1

# `???` marks a mandatory value with no default in this file; it has to be
# supplied by an override or by another config.
pretrained_checkpoint_path: ???
ensemble_num_models: ???
ensemble_combine_mode: ???
ensemble_use_beta: false

# Hydra defaults list. Entries are composed in the order written, so the
# sequence below should not be reordered casually.
defaults:
  - /experiments/base.yaml
  - /trainer/model: vitbase_ft.yaml
  - /trainer/checkpoint: default.yaml
  # Disabled entry kept for reference; the original note on it is "MISSING ARG".
#  - trainer: k200_config.yaml MISSING ARG
  # `_self_` comes last, so keys set in this file take precedence over the
  # entries listed above.
  - _self_

# Launcher settings. Both values are `???` (mandatory, no default here) and
# must be provided by an override or another config before the run starts.
launcher:
  gpus_per_node: ???
  num_nodes: ???

trainer:
  # Mirrors the top-level `epochs` value through interpolation.
  max_epochs: ${epochs}

  distributed:
    # Choices noted by the original author: NULL, float16, bfloat16.
    comms_dtype: float16


  optim:
    # Left as null; presumably this disables gradient clipping -- confirm
    # against the trainer implementation.
    gradient_clip: null
    amp:
      enabled: false
      # Choices noted by the original author: bfloat16 or float16.
      amp_dtype: float16
    optimizer:
      _target_: torch.optim.AdamW
    # Partially-applied callables that post-process the parameter groups.
    param_group_modifiers:
      - _target_: omnivision.optim.layer_decay_param_modifier.layer_decay_param_modifier
        _partial_: true
        layer_decay_value: 0.75
    options:
      # Learning-rate schedules. The start/end values are textual expressions
      # built from the top-level `lr_scale_factor` and `mixing_lr_factor`; they
      # are not plain YAML numbers.
      lr:
        # Entry without `param_names` (presumably the fallback for parameters
        # not matched by another entry -- confirm against the optimizer builder).
        - scheduler:
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                start_value: 16e-3/${lr_scale_factor}*${mixing_lr_factor}
                end_value: 4e-6/${lr_scale_factor}*${mixing_lr_factor}
            # One cosine segment covering the whole schedule. An earlier note
            # here mentioned a 5 epoch warmup with lengths [0.125, 0.875]; the
            # active setting has no warm-up segment.
            lengths: [1]
            interval_scaling: ['rescaled']
        # Entry applied to parameters whose names match `model.*`.
        - scheduler:
            _target_: fvcore.common.param_scheduler.CompositeParamScheduler
            schedulers:
              - _target_: fvcore.common.param_scheduler.LinearParamScheduler
                start_value: 4e-6/${lr_scale_factor}
                end_value: 16e-3/${lr_scale_factor}
              - _target_: fvcore.common.param_scheduler.CosineParamScheduler
                # Reuse the linear segment's end value unchanged. That value
                # already contains the division by lr_scale_factor, so dividing
                # again here would apply the factor twice whenever it is not 1.
                start_value: ${..0.end_value}
                end_value: 4e-6/${lr_scale_factor}
            # The linear segment is given zero length, so the cosine segment
            # spans the whole schedule. An earlier note here mentioned a
            # 5 epoch warmup with lengths [0.125, 0.875].
            lengths: [0, 1]
            interval_scaling: ['rescaled', 'rescaled']
          param_names:
            - 'model.*'

      # Weight-decay schedules, both constant over the run.
      weight_decay:
        # Entry without `param_names` (presumably the fallback for parameters
        # not matched by another entry -- confirm against the optimizer builder).
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 1e-4
        # Parameters whose names match `*.bias` get a weight decay of zero.
        - scheduler:
            _target_: fvcore.common.param_scheduler.ConstantParamScheduler
            value: 0.0
          param_names:
            # TODO: should unused param names or classes be allowed?
            - '*.bias'
          # TODO: allow other forms for class names, e.g.
          #   module_cls_names: ['torch.nn.LayerNorm']
  # Logging outputs. Both paths hang off `launcher.experiment_log_dir`, which
  # is not defined in this file and is expected to come from elsewhere.
  logging:
    log_dir: ${launcher.experiment_log_dir}/logs
    # Unit of `log_freq` is not visible here (presumably iterations -- confirm).
    log_freq: 100
    tensorboard_writer:
      _target_: omnivision.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
