# Model config in Hydra `_target_` style: the top-level object is a
# MIMOHeadWrapper assembled from the `trunk`, `heads` and `trunk_fields`
# sections of this file.
_target_: omnivision.model.model_wrappers.MIMOHeadWrapper
# Encoder trunk: a VisionTransformer with masked image modeling, patch
# dropping and an attached decoder.
trunk:
  _target_: omnivision.models.vision_transformer.VisionTransformer
  # Input size; presumably (channels, frames, height, width) -- confirm against
  # VisionTransformer. The second entry is read from the root-level
  # `num_frames` key of the composed config.
  img_size:
    - 3
    - ${num_frames}
    - 224
    - 224
  embed_dim: 768
  depth: 12
  # Reused below as both kernel size and stride of the Conv3d patch embedding.
  patch_size:
    - 2
    - 16
    - 16
  classifier_feature: global_pool
  drop_path_rate: 0.0
  use_cls_token: false
  patch_embed_type: generic
  # Two entries, listed in this order: PadIm2Video, then the conv patch
  # embedding.
  patch_embed_params_list:
    # NOTE(review): presumably repeats a single image along the time axis to
    # reach 2 frames -- confirm against PadIm2Video.
    - _target_: omnivision.models.PadIm2Video
      pad_type: repeat
      ntimes: 2
    - _target_: omnivision.models.make_conv_or_linear
      layer:
        _target_: torch.nn.Conv3d
        in_channels: 3
        # Relative interpolations: `....` climbs from `layer` up to the trunk
        # mapping, so these resolve to trunk.embed_dim and trunk.patch_size.
        out_channels: ${....embed_dim}
        kernel_size: ${....patch_size}
        # stride == kernel_size, i.e. the convolution windows do not overlap.
        stride: ${.kernel_size}
      init_weight:
        _target_: omnivision.models.reshape_and_init_as_mlp
      # Hydra flag: the nested `layer` / `init_weight` nodes are handed to
      # make_conv_or_linear as config instead of being instantiated eagerly.
      _recursive_: false
  # Attention settings for the encoder; `_partial_` makes Hydra build a
  # partial rather than a finished object, so the trunk can supply the
  # remaining arguments.
  attn_target:
    _target_: omnivision.models.vision_transformer.Attention
    _partial_: true
    num_heads: 12
    proj_drop: 0
    qk_scale: null
    qkv_bias: true
    attn_drop: 0
  learnable_pos_embed: false  # Use sinusoidal positional encoding
  masked_image_modeling: true
  patch_dropping: true
  # Decoder definition, also passed as a partial.
  decoder:
    _target_: omnivision.models.vision_transformer.Decoder
    _partial_: true
    # Absolute interpolation: requires this file to be composed under
    # `trainer.model` in the final config.
    embed_dim: ${trainer.model.trunk.embed_dim}
    decoder_depth: 4
    decoder_embed_dim: 384
    learnable_pos_embed: false  # Use sinusoidal positional encoding
    attn_target:
      _target_: omnivision.models.vision_transformer.Attention
      _partial_: true
      num_heads: 16
      proj_drop: 0
      qk_scale: null
      qkv_bias: true
      attn_drop: 0
# Output heads: a single MAEHead whose input width is the decoder width.
heads:
  - head:
      _target_: omnivision.models.heads.mae_head.MAEHead
      in_features: ${trainer.model.trunk.decoder.decoder_embed_dim}
      # img_size[0] * patch_size[0] * patch_size[1] * patch_size[2]; with the
      # trunk values in this file that is 3 x 2 x 16 x 16 = 1536. `times` is a
      # custom resolver that must be registered elsewhere (not in this file).
      out_features: ${times:${times:${times:${trainer.model.trunk.img_size.0},${trainer.model.trunk.patch_size.0}},${trainer.model.trunk.patch_size.1}},${trainer.model.trunk.patch_size.2}}
    input_key: null
    output_key: null
    fork_module: ""
# Argument routing for the trunk call.
# NOTE(review): presumably `args` / `kwargs` name batch fields forwarded
# positionally / by keyword -- confirm against MIMOHeadWrapper.
trunk_fields:
  - input_key: null
    args:
      - "vision"
    kwargs:
      "mask": "mask"