# @package _global_

# Hydra defaults list: the shared base experiment config is listed first and
# `_self_` last, so the values in this file are merged after the base config.
defaults:
  - /experiments/base.yaml
  - _self_

# Job-launch topology. Presumably 4 nodes x 8 GPUs each (32 processes in
# total) - confirm against the launcher implementation.
launcher:
  num_nodes: 4
  gpus_per_node: 8

trainer:
  # Run a single epoch in `val` mode (presumably evaluation only, with no
  # training loop - confirm against the trainer implementation).
  max_epochs: 1
  mode: val

  # No optimizer is configured: both entries are explicitly null.
  optim:
    optimizer: null
    options: null

  data:
    val:
      # The validation loader is a ConcatDataset built from the per-dataset
      # loaders listed under `datasets`.
      _target_: omnivision.data.concat_dataset.ConcatDataset
      # NOTE(review): `sum` presumably makes the total step count the sum of
      # the individual loaders' lengths - confirm in ConcatDataset.
      max_steps: sum
      datasets:
      # Image validation loader (presumably ImageNet-1k). Paths, labels and the
      # prefix come from the `in1k_*` interpolations, which are expected to be
      # defined elsewhere (e.g. in the base config).
      - _target_: omnivision.data.torch_dataset.TorchDataset
        dataset:
          _target_: omnivision.data.path_dataset.ImagePathDataset
          path_file_list:
            - ${in1k_val_imgs_path}
          label_file_list:
            - ${in1k_val_labels_path}
          new_prefix: ${in1k_prefix}
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                transforms:
                  - _target_: torchvision.transforms.Resize
                    size: 224
                    # NOTE(review): 3 is presumably PIL's BICUBIC resampling
                    # code - confirm that the installed torchvision still
                    # accepts integer interpolation values.
                    interpolation: 3
                  - _target_: torchvision.transforms.CenterCrop
                    size: 224
                  - _target_: torchvision.transforms.ToTensor
                  - _target_: torchvision.transforms.Normalize
                    mean: [0.485, 0.456, 0.406]
                    std: [0.229, 0.224, 0.225]
                  # The public Omnivore model expects a 5-dimensional input, so
                  # each image is wrapped as a single-frame video.
                  - _target_: omnivision.data.transforms.image_video.ImageToSingleFrameVideo
        shuffle: false
        batch_size: 25
        num_workers: 8
        pin_memory: true
        drop_last: false
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          # Must match the `in1k` keys used under `metrics.val` and `loss`.
          output_key: in1k
          batch_kwargs:
            # The public Omnivore model expects an `input_type` key during
            # forward.
            model_fwd_kwargs:
              input_type: image
        worker_init_fn: null
      # Video validation loader (presumably Kinetics-400). Paths, labels and the
      # prefix come from the `k400_*` interpolations, which are expected to be
      # defined elsewhere (e.g. in the base config).
      - _target_: omnivision.data.torch_dataset.TorchDataset
        dataset:
          _target_: omnivision.data.path_dataset.VideoPathDataset
          path_file_list:
            - ${k400_val_vids_path}
          label_file_list:
            - ${k400_val_labels_path}
          new_prefix: ${k400_prefix}
          # Clip and frame sampling settings: one clip per video and 160
          # uniformly subsampled frames.
          clip_sampler:
            _target_: pytorchvideo.data.clip_sampling.ConstantClipsPerVideoSampler
            clip_duration: 10
            clips_per_video: 1
          frame_sampler:
            _target_: pytorchvideo.transforms.UniformTemporalSubsample
            num_samples: 160
          decoder: pyav
          normalize_to_0_1: true
          transforms:
            - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
              base_transform:
                _target_: torchvision.transforms.Compose
                transforms:
                  - _target_: pytorchvideo.transforms.ShortSideScale
                    size: 224
                  - _target_: torchvision.transforms._transforms_video.NormalizeVideo
                    mean: [0.485, 0.456, 0.406]
                    std: [0.229, 0.224, 0.225]
                  # NOTE(review): presumably produces several 32-frame clips
                  # from the 160 sampled frames - confirm in TemporalCrop.
                  - _target_: omnivision.data.transforms.pytorchvideo.TemporalCrop
                    frames_per_clip: 32
                    stride: 40
                  - _target_: omnivision.data.transforms.pytorchvideo.SpatialCrop
                    crop_size: 224
                    num_crops: 3
        shuffle: false
        batch_size: 4
        num_workers: 8
        pin_memory: true
        # NOTE(review): dropping the last incomplete batch in a validation
        # loader may exclude a few videos from the reported accuracy (the image
        # loader uses `drop_last: false`) - confirm this is intentional.
        drop_last: true
        collate_fn:
          _target_: omnivision.data.api.DefaultOmnivoreCollator
          # Must match the `k400` keys used under `metrics.val` and `loss`.
          output_key: k400
          batch_kwargs:
            # The public Omnivore model expects an `input_type` key during
            # forward.
            model_fwd_kwargs:
              input_type: video
        worker_init_fn: null

  # Model: the Omnivore Swin-T trunk wrapped in MIMOHeadWrapper with a single
  # `torch.nn.Identity` head.
  model:
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    handle_list_inputs: true
    trunk:
      _target_: omnivore.models.omnivore_model.omnivore_swinT
    heads:
      # NOTE(review): the null keys and empty `fork_module` presumably apply
      # this head to every input at the trunk output - confirm in the wrapper.
      - head:
          _target_: torch.nn.Identity
        fork_module: ""
        input_key: null
        output_key: null
    trunk_fields:
      - input_key: null
        args: ["vision"]

  # Validation metrics, keyed by the `output_key` of each dataset's collator
  # (`in1k` and `k400`). Top-1 and top-5 accuracy are reported for both.
  metrics:
    val:
      in1k:
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5
      # NOTE(review): the list meter presumably average-pools predictions over
      # the multiple clips/crops of one video - confirm in the meter.
      k400:
        accuracy_top1:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.avg_pooled_accuracy_list_meter.AvgPooledAccuracyListMeter
          top_k: 5
  # One cross-entropy loss per dataset, keyed by the same `in1k` / `k400` names
  # used for the collators' `output_key` values.
  loss:
    in1k:
      _target_: torch.nn.CrossEntropyLoss
    k400:
      _target_: torch.nn.CrossEntropyLoss

  # Logging outputs; both directories live under the launcher's experiment log
  # directory.
  logging:
    tensorboard_writer:
      _target_: omnivision.logger.make_tensorboard_logger
      log_dir: ${launcher.experiment_log_dir}/tensorboard
      flush_secs: 120
    log_dir: ${launcher.experiment_log_dir}/logs
    log_freq: 10

  # Checkpoints are written under the launcher's experiment log directory.
  checkpoint:
    save_dir: ${launcher.experiment_log_dir}/checkpoints
    save_freq: 0  # 0 means only the last checkpoint is saved.
