# Identifier of the action this config is meant for
# (presumably read by the entry point to pick the routine to run -- confirm).
action: ft_avsync_model_for_syncability

# Model definition. Every component is described by a dotted `target` path and
# the `params` mapping that accompanies it.
model:
  target: model.sync_model.Synchformer
  params:
    # audio feature extractor (AST); `is_trainable` is switched off here
    afeat_extractor:
      is_trainable: false
      target: model.modules.feat_extractors.audio.ast.AST
      params:
        ckpt_path: null
        extract_features: true
        # feat_type: "last_hidden_state"  # 'pooler_output' 'CLS' 'last_hidden_state' 'last_hidden_state_no_AUX':
        # time dimension of the input spectrogram
        max_spec_t: 66
        # used if extract_features is true; if true, feat_type will be 'last_hidden_state'
        factorize_freq_time: true
        # 'AveragePooling' or 'TransformerEncoderLayer'
        agg_freq_module: TransformerEncoderLayer
        agg_time_module: torch.nn.Identity
        add_global_repr: false
    # visual feature extractor (MotionFormer); `is_trainable` is switched off here
    vfeat_extractor:
      is_trainable: false
      target: model.modules.feat_extractors.visual.motionformer.MotionFormer
      params:
        ckpt_path: null
        extract_features: true
        # used if extract_features is true
        factorize_space_time: true
        # 'AveragePooling' or 'TransformerEncoderLayer'
        agg_space_module: TransformerEncoderLayer
        agg_time_module: torch.nn.Identity
        add_global_repr: false
    # audio projection head (from D of feat_extractor to D of the transformer)
    aproj:
      # target: model.modules.bridges.DoNothingBridge
      target: torch.nn.Linear
      params:
        in_features: 768
        out_features: ${model.params.transformer.params.n_embd}
    # visual projection head (from D of feat_extractor to D of the transformer)
    vproj:
      # target: model.modules.bridges.DoNothingBridge
      target: torch.nn.Linear
      params:
        in_features: 768
        out_features: ${model.params.transformer.params.n_embd}
    transformer:
      # target: model.sync_model.GlobalTransformer
      target: model.sync_model.GlobalTransformerWithSyncabilityHead
      params:
        n_layer: 3
        n_head: 8
        # referenced by the projection heads, the positional encoding and the
        # offset head through interpolation
        n_embd: 768
        tok_pdrop: 0.0
        embd_pdrop: 0.1
        resid_pdrop: 0.1
        attn_pdrop: 0.1
        # positional encoding config, or null
        pos_emb_cfg:
          target: model.modules.transformer.RandInitPositionalEncoding
          params:
            # block_shape: [198]
            block_shape: [184]
            n_embd: ${model.params.transformer.params.n_embd}
        # offset head config, or null
        off_head_cfg:
          target: torch.nn.Linear
          params:
            in_features: ${model.params.transformer.params.n_embd}
            out_features: ${data.num_off_cls}

# Optimisation, early stopping and run-control settings.
training:
  # Written with an explicit mantissa dot so that YAML 1.1 loaders (e.g. plain
  # PyYAML) resolve it as a float; a bare `2e-6` is read as a string by them.
  base_learning_rate: 2.0e-6
  base_batch_size: 16
  num_workers: 7
  # just a large number (early stopper with the `patience` will stop it anyway)
  num_epochs: 10000
  patience: 20
  to_max_metric: true
  metric_name: 'accuracy_1'
  # care about which phase when deciding to early stop
  early_stop_phase: 'valid'
  use_half_precision: true
  seed: 1337
  compile: false
  skip_test: false
  run_test_only: false
  resume: false
  finetune: false
  dist_backend: 'nccl'
  max_clip_norm: 1
  lr_scheduler:
    # 'constant' 'constant_with_warmup'
    name: 'constant_with_warmup'
    # iterations to recover from base_learning_rate / 100
    warmup: 1000
  optimizer:
    # adamw, adam or sgd
    name: adam
    # NOTE(review): `betas` and `momentum` presumably apply to different
    # optimizers (adam-style vs sgd) -- confirm against the optimizer factory.
    betas: [0.9, 0.999]
    momentum: 0.9
    weight_decay: 0

# Data settings; many of these are referenced from the transform sequences and
# from the model section through `${data.*}` interpolation.
data:
  # 'grid' 'uniform' 'uniform_binary' (for the last one, prob_oos must be non-null)
  offset_type: 'grid'
  # e.g. 21 if offset_type is 'grid'
  num_off_cls: 21
  # float or null
  prob_oos: null
  max_off_sec: 2
  crop_len_sec: 5
  # 1 for no overlap, 0.5 for 50% overlap
  step_size_seg: 0.5
  # something that ends with 'CODEC_video_XXfps_YYYside_ZZZZZhz' or '..._25fps_...'
  vids_path: 'PLACEHOLDER'
  # video resolution -> size_before_crop resolution -> input_size (crop resolution)
  size_before_crop: 256
  input_size: 224
  segment_size_vframes: 16
  vfps: 25
  afps: 16000
  # n_segments: 14
  n_segments: 13
  do_offset: true
  # ignored if 0
  p_color_jitter: 0.0
  # ignored if 0
  p_gray_scale: 0.0
  # how often to apply the smaller crop and upscale? if 0.0 or null, works as RGBSpatialCrop
  sometimes_upscale_p: 0.0
  # if the crop transform should be random or just center crop should be used
  is_spatial_crop_random: true
  # if true, the starting position of the 1st clip will be random but respecting n_segments
  is_temporal_crop_random: true
  audio_jitter_sec: 0.05
  p_horizontal_flip: 0.5
  p_audio_aug: 0.0
  # changing `dataset` arguments here won't affect the init call. See train_utils.get_datasets
  dataset:
    target: 'PLACEHOLDER'
    params:
      load_fixed_offsets_on: []
      vis_load_backend: 'read_video'
      # null or 1.0: full dataset; a ratio will use a proportion of it
      size_ratio: null

# Training transform pipeline. The entries are sequentially defined: the list
# order is the order in which the transforms are listed for application, so do
# not reorder entries without checking the consumers.
transform_sequence_train:
  - target: dataset.transforms.EqualifyFromRight
    params:
      clip_max_len_sec: 10 # for LRS3 this can be increased to allow more training data as clips may be >10s
  # spatial crop; all knobs except `smaller_input_size` come from the `data` section
  - target: dataset.transforms.RGBSpatialCropSometimesUpscale
    params:
      sometimes_p: ${data.sometimes_upscale_p}
      smaller_input_size: 192 # the size of the smaller crop. null 192 112
      target_input_size: ${data.input_size}
      is_random: ${data.is_spatial_crop_random}
  - target: dataset.transforms.TemporalCropAndOffsetForSyncabilityTraining
    params:
      max_off_sec: ${data.max_off_sec}  # for the synchronizable videos
      max_wiggle_sec: ${data.audio_jitter_sec}
      do_offset: ${data.do_offset}
      grid_size: ${data.num_off_cls}  # synchronizable + 2 (for the non-synchronizable, audio earlier or later)
      segment_size_vframes: ${data.segment_size_vframes}
      n_segments: ${data.n_segments}
      step_size_seg: ${data.step_size_seg}
      vfps: ${data.vfps}
  - target: dataset.transforms.RandomApplyColorDistortion
    params:
      p_color_jitter: ${data.p_color_jitter}
      s: 1.0 # strength of the color jitter if applied
      p_gray_scale: ${data.p_gray_scale}
  - target: dataset.transforms.RandomHorizontalFlip
    params:
      p: ${data.p_horizontal_flip}
  # the five audio augmentations below all share the probability ${data.p_audio_aug}
  - target: dataset.transforms.AudioRandomReverb
    params:
      p: ${data.p_audio_aug}
  - target: dataset.transforms.AudioRandomVolume
    params:
      p: ${data.p_audio_aug}
      gain: 2.0
      gain_type: 'amplitude'
  - target: dataset.transforms.AudioRandomPitchShift
    params:
      p: ${data.p_audio_aug}
      shift: 1000
  - target: dataset.transforms.AudioRandomLowpassFilter
    params:
      p: ${data.p_audio_aug}
      cutoff_freq: 100
  - target: dataset.transforms.AudioRandomGaussNoise
    params:
      p: ${data.p_audio_aug}
      amplitude: 0.01
  # segment parameters mirror the ones given to the temporal crop above
  - target: dataset.transforms.GenerateMultipleSegments
    params:
      segment_size_vframes: ${data.segment_size_vframes}
      n_segments: ${data.n_segments}
      is_start_random: ${data.is_temporal_crop_random}
      step_size_seg: ${data.step_size_seg}
  # - target: dataset.transforms.RGBToFloatToZeroOne
  - target: dataset.transforms.RGBToHalfToZeroOne
  - target: dataset.transforms.RGBNormalize
    params:
      mean: [0.5, 0.5, 0.5] # motionformer normalization
      std: [0.5, 0.5, 0.5]
  - target: dataset.transforms.AudioMelSpectrogram
    params:
      sample_rate: ${data.afps}
      win_length: 400  # 25 ms * 16 kHz
      hop_length: 160  # 10 ms * 16 kHz
      # NOTE(review): for a 400-sample window the formula in the comment below
      # evaluates to 512, while the configured value is 1024 -- confirm which
      # one the pre-trained audio extractor expects before changing anything.
      n_fft: 1024  # 2^(ceil(log2(window_size * sampling_rate)))
      n_mels: 128  # as in AST
  - target: dataset.transforms.AudioLog
  # spectrogram length is tied to the audio extractor's `max_spec_t`
  - target: dataset.transforms.PadOrTruncate
    params:
      max_spec_t: ${model.params.afeat_extractor.params.max_spec_t}
  - target: dataset.transforms.AudioNormalizeAST
    params:
      mean: -4.2677393  # AST, pre-trained on AudioSet
      std: 4.5689974
  - target: dataset.transforms.PermuteStreams
    params:
      einops_order_audio: "S F T -> S 1 F T"
      einops_order_rgb: "S T C H W -> S T C H W"  # same

# Test transform pipeline; entries are listed in the order they are meant to be
# applied. Compared with the train sequence it uses a plain RGBSpatialCrop with
# `is_random: False`, a non-random segment start, and lists none of the
# color/flip/audio augmentation transforms.
transform_sequence_test:
  # no `params` given here (the train sequence sets `clip_max_len_sec`)
  - target: dataset.transforms.EqualifyFromRight
  - target: dataset.transforms.RGBSpatialCrop
    params:
      input_size: ${data.input_size}
      is_random: False
  # no `max_wiggle_sec` is passed here, unlike in the train sequence
  - target: dataset.transforms.TemporalCropAndOffsetForSyncabilityTraining
    params:
      max_off_sec: ${data.max_off_sec}  # for the synchronizable videos
      do_offset: ${data.do_offset}
      grid_size: ${data.num_off_cls}
      segment_size_vframes: ${data.segment_size_vframes}
      n_segments: ${data.n_segments}
      step_size_seg: ${data.step_size_seg}
      vfps: ${data.vfps}
  - target: dataset.transforms.GenerateMultipleSegments
    params:
      segment_size_vframes: ${data.segment_size_vframes}
      n_segments: ${data.n_segments}
      is_start_random: False
      step_size_seg: ${data.step_size_seg}
  # - target: dataset.transforms.RGBToFloatToZeroOne
  - target: dataset.transforms.RGBToHalfToZeroOne
  - target: dataset.transforms.RGBNormalize
    params:
      mean: [0.5, 0.5, 0.5] # motionformer normalization
      std: [0.5, 0.5, 0.5]
  - target: dataset.transforms.AudioMelSpectrogram
    params:
      sample_rate: ${data.afps}
      win_length: 400  # 25 ms * 16 kHz
      hop_length: 160  # 10 ms * 16 kHz
      # NOTE(review): for a 400-sample window the formula in the comment below
      # evaluates to 512, while the configured value is 1024 -- confirm which
      # one the pre-trained audio extractor expects before changing anything.
      n_fft: 1024  # 2^(ceil(log2(window_size * sampling_rate)))
      n_mels: 128  # as in AST
  - target: dataset.transforms.AudioLog
  # spectrogram length is tied to the audio extractor's `max_spec_t`
  - target: dataset.transforms.PadOrTruncate
    params:
      max_spec_t: ${model.params.afeat_extractor.params.max_spec_t}
  - target: dataset.transforms.AudioNormalizeAST
    params:
      mean: -4.2677393  # AST, pre-trained on AudioSet
      std: 4.5689974
  - target: dataset.transforms.PermuteStreams
    params:
      einops_order_audio: "S F T -> S 1 F T"
      einops_order_rgb: "S T C H W -> S T C H W"  # same

# Logging and experiment-tracking settings.
logging:
  logdir: './logs/sync_models'
  log_code_state: true
  log_frequency: 20
  # patterns to ignore when backing up the code folder
  patterns_to_ignore:
    - 'logs'
    - '.git'
    - '__pycache__'
    - 'data'
    - '*.pt'
    - 'sbatch_logs'
    - '*.mp4'
    - '*.wav'
    - '*.jpg'
    - '*.gif'
    - 'misc*'
  vis_segment_sim: true
  # max number of items to keep in running_results: lowering it helps with OOM;
  # the higher it is, the closer to the real metric
  log_max_items: 200000
  use_wandb: false
