# Top-level model: a MIMOHeadWrapper configured with the `trunk`, `heads`
# and `trunk_fields` entries of this file.
_target_: omnivision.model.model_wrappers.MIMOHeadWrapper
handle_list_inputs: true
# Backbone: a VisionTransformer with 12 blocks, 768-dim embeddings and
# [2, 16, 16] patches. `num_frames` is resolved from the root config.
trunk:
  _target_: omnivision.models.vision_transformer.VisionTransformer
  # Presumably [channels, frames, height, width] -- TODO confirm against
  # VisionTransformer's signature.
  img_size: [3, "${num_frames}", 224, 224]
  patch_size:
    - 2
    - 16
    - 16
  embed_dim: 768
  depth: 12
  drop_path_rate: 0.1
  use_cls_token: false
  classifier_feature: global_pool
  learnable_pos_embed: false  # Use sinusoidal positional encoding
  patch_embed_type: generic
  # Patch embedding pipeline; entries are listed in the order given here.
  patch_embed_params_list:
    - _target_: omnivision.models.PadIm2Video
      ntimes: 2
      pad_type: repeat
    - _target_: omnivision.models.make_conv_or_linear
      # Hydra must not instantiate `layer` / `init_weight` itself; they are
      # handed to make_conv_or_linear as configs.
      _recursive_: false
      layer:
        _target_: torch.nn.Conv3d
        in_channels: 3
        # Four dots climb from `layer` up to `trunk`.
        out_channels: ${....embed_dim}
        kernel_size: ${....patch_size}
        # Stride equals the kernel size, i.e. non-overlapping patches.
        stride: ${.kernel_size}
      init_weight:
        _target_: omnivision.models.reshape_and_init_as_mlp
  # `_partial_` makes Hydra build a functools.partial instead of calling
  # Attention right away; the remaining arguments are supplied by the caller.
  attn_target:
    _target_: omnivision.models.vision_transformer.Attention
    _partial_: true
    num_heads: 12
    qkv_bias: true
    qk_scale: null
    attn_drop: 0
    proj_drop: 0
# A single classification head: one Linear layer wrapped in a Sequential.
heads:
  - head:
      _target_: torch.nn.Sequential
      _args_:
        # Disabled: optional dropout in front of the classifier.
        # - _target_: torch.nn.Dropout
        #   p: 0.0
        - _target_: omnivision.model.model_init_utils.init_parameters
          model:
            _target_: torch.nn.Linear
            in_features: 768  # same value as trunk.embed_dim
            # Number of classes; `video_labels` comes from the root config.
            out_features: ${video_labels}
          # Per-parameter initialisers, passed as partials (`_partial_`).
          init_fns:
            weight:
              _target_: torch.nn.init.normal_
              _partial_: true
              mean: 0
              std: 0.01
            bias:
              _target_: torch.nn.init.zeros_
              _partial_: true
    fork_module: ""
    input_key: null
    output_key: null
# Trunk input specification: a single entry with no input key and the
# argument list ["vision"].
trunk_fields:
  - input_key: null
    args:
      - vision