# @package _global_

# Hydra defaults list: this config is composed on top of the shared
# experiment base config.
defaults:
  - /experiments/base.yaml
  # `_self_` is listed last, so (per Hydra's defaults-list ordering) the keys
  # set in this file take precedence over those coming from the base config.
  - _self_

# Job launch resources: a single node with two GPUs.
launcher:
  num_nodes: 1
  gpus_per_node: 2

# Trainer settings. `mode: val` with `max_epochs: 1` — presumably a single
# validation pass with no training; confirm against the trainer implementation.
trainer:
  max_epochs: 1
  mode: val

  # No optimizer is configured: both entries are explicitly null
  # (presumably to override values inherited from /experiments/base.yaml —
  # confirm).
  optim:
    optimizer: null
    options: null

  # Data pipeline for the `val` phase.
  # NOTE(review): the ${in1k_*} interpolations are not defined in this file;
  # presumably they come from the base config or command-line overrides.
  data:
    val:
      _target_: omnivision.data.torch_dataset.TorchDataset
      dataset:
        # Dataset driven by a list of image-path files and a parallel list of
        # label files (one entry each here).
        _target_: omnivision.data.path_dataset.ImagePathDataset
        path_file_list:
          - ${in1k_val_imgs_path}
        label_file_list:
          - ${in1k_val_labels_path}
        new_prefix: ${in1k_prefix}
        transforms:
          - _target_: omnivision.data.transforms.transform_wrappers.VisionTransform
            base_transform:
              # Preprocessing chain: resize -> center crop -> tensor ->
              # normalize -> wrap as single-frame video.
              _target_: torchvision.transforms.Compose
              transforms:
                - _target_: torchvision.transforms.Resize
                  size: 224
                  # 3 is the Pillow integer code for bicubic resampling.
                  interpolation: 3
                - _target_: torchvision.transforms.CenterCrop
                  size: 224
                - _target_: torchvision.transforms.ToTensor
                # Per-channel mean/std (the commonly used ImageNet statistics).
                - _target_: torchvision.transforms.Normalize
                  mean: [0.485, 0.456, 0.406]
                  std: [0.229, 0.224, 0.225]
                # The public Omnivore model expects a 5-dimensional input, so
                # each image is turned into a single-frame video.
                - _target_: omnivision.data.transforms.image_video.ImageToSingleFrameVideo
      # Loader options: fixed order, and the last partial batch is kept.
      shuffle: false
      # NOTE(review): presumably a per-process batch size — confirm how it
      # combines with launcher.gpus_per_node.
      batch_size: 25
      num_workers: 8
      pin_memory: true
      drop_last: false
      collate_fn:
        _target_: omnivision.data.api.DefaultOmnivoreCollator
        # `in1k` is also the key used under trainer.metrics.val and
        # trainer.loss in this file.
        output_key: in1k
        batch_kwargs:
          # The public Omnivore model expects an `input_type` key during
          # forward.
          model_fwd_kwargs:
            input_type: image
      worker_init_fn: null

  # Model: an Omnivore Swin-T trunk wrapped in MIMOHeadWrapper with a single
  # identity head.
  model:
    _target_: omnivision.model.model_wrappers.MIMOHeadWrapper
    handle_list_inputs: true
    trunk:
      # NOTE(review): no arguments are passed, so the constructor's defaults
      # apply (e.g. regarding pretrained weights) — confirm they suit
      # evaluation.
      _target_: omnivore.models.omnivore_model.omnivore_swinT
    heads:
      # torch.nn.Identity returns its input unchanged; presumably the trunk
      # output is already the prediction to be scored — confirm.
      - head:
          _target_: torch.nn.Identity
        fork_module: ""
        input_key: null
        output_key: null
    trunk_fields:
      # NOTE(review): presumably `args` names the batch field(s) passed
      # positionally to the trunk — verify against MIMOHeadWrapper.
      - input_key: null
        args: ["vision"]

  # Metrics for the `val` phase, grouped under `in1k` — the same name the
  # collator uses as its output_key in the data section.
  metrics:
    val:
      in1k:
        # Top-1 and top-5 accuracy from the same Accuracy class, differing
        # only in `top_k`.
        accuracy_top1:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 1
        accuracy_top5:
          _target_: omnivision.metrics.accuracy.Accuracy
          top_k: 5

  # Loss registered under the same `in1k` key: cross-entropy.
  loss:
    in1k:
      _target_: torch.nn.CrossEntropyLoss