# Run-level switches.
# NOTE(review): presumably `action` selects which pipeline the launcher runs
# and `debug` toggles a debug mode — confirm against the entry point.
action: train_avclip
debug: False

## Segment AVCLIP config
# Every component below is described by a dotted `target` path plus a `params`
# mapping (presumably instantiated as target(**params) — confirm against the
# config loader). `${a.b}` values are interpolations that point at other keys
# of this file.
model:
  target: model.modules.feat_extractors.train_clip_src.open_clip.model.AVCLIP
  params:
    # the three scale settings are taken from the `training` section
    init_scale: ${training.init_scale}
    clamp_scale_min: ${training.clamp_scale_min}
    clamp_scale_max: ${training.clamp_scale_max}
    # also used as in_features/out_features of `aproj` and `vproj` below
    n_embd: 768
    gather_for_loss: False
    # audio feature extractor (AST)
    afeat_extractor:
      is_trainable: True
      target: model.modules.feat_extractors.audio.ast.AST
      params:
        # ckpt_path: null   # == from scratch
        # NOTE(review): looks like a Hugging Face hub model id rather than a local path — confirm
        ckpt_path: "MIT/ast-finetuned-audioset-10-10-0.4593"
        extract_features: True
        # feat_type: "last_hidden_state"  # 'pooler_output' 'CLS' 'last_hidden_state' 'last_hidden_state_no_AUX'
        # also read by the PadOrTruncate transform in both transform sequences
        max_spec_t: 66  # time dimension of the input spectrogram
        factorize_freq_time: True  # used if extract_features is True; if True, feat_type will be 'last_hidden_state'
        agg_freq_module: 'TransformerEncoderLayer'  # 'AveragePooling' or 'TransformerEncoderLayer'
        agg_time_module: 'AveragePooling'  # 'AveragePooling' or 'TransformerEncoderLayer'
        add_global_repr: False  # used if extract_features is True
        agg_segments_module: 'AveragePooling'  # 'AveragePooling' or 'TransformerEncoderLayer'
        # if None, will use the default value in the model (16=10sec//0.64sec+1)
        max_segments: ${data.n_segments_train}  # TODO: if agg_segments_module is with PE, what to do during eval?
    # visual feature extractor (MotionFormer)
    vfeat_extractor:
      is_trainable: True
      target: model.modules.feat_extractors.visual.motionformer.MotionFormer
      params:
        # ckpt_path: null  # from scratch
        # ckpt_path: "./model/modules/feat_extractors/visual/motionformer_src/ssv2_motionformer_224_16x4.pyth"
        # ckpt_path: "./model/modules/feat_extractors/visual/motionformer_src/ssv2_joint_224_16x4.pyth"
        ckpt_path: "./model/modules/feat_extractors/visual/motionformer_src/ssv2_divided_224_16x4.pyth"
        extract_features: True
        factorize_space_time: True  # used if extract_features is True
        agg_space_module: 'TransformerEncoderLayer'  # 'AveragePooling' or 'TransformerEncoderLayer'
        agg_time_module: 'AveragePooling'  # 'AveragePooling' or 'TransformerEncoderLayer'
        add_global_repr: False  # used if extract_features is True
        agg_segments_module: 'AveragePooling'  # 'AveragePooling' or 'TransformerEncoderLayer'
        # if None, will use the default value in the model (16=10sec//0.64sec+1)
        max_segments: ${data.n_segments_train}  # TODO: what if during eval it'll be different?
    # audio projection; input and output widths are both n_embd
    # NOTE(review): `DoNothingBridge` presumably is a pass-through — confirm in model.modules.bridges
    aproj:
      target: model.modules.bridges.DoNothingBridge
      params:
        in_features: ${model.params.n_embd}
        out_features: ${model.params.n_embd}
    # visual projection; configured identically to `aproj`
    vproj:
      target: model.modules.bridges.DoNothingBridge
      params:
        in_features: ${model.params.n_embd}
        out_features: ${model.params.n_embd}

# Optimisation, scheduling, distributed and validation settings.
# `init_scale`, `clamp_scale_min` and `clamp_scale_max` are interpolated into
# `model.params`, so changing them here changes the model arguments too.
training:
  resume: null
  learning_rate: 0.0001
  lr_cooldown_end: 0.0
  lr_cooldown_power: 1.0
  base_batch_size: 2
  queue_size: 0
  for_loop_segment_fwd: False  # if True, the forward pass will be done in a for loop over the segments, otherwise treated as a batch dim (B*S)
  grad_checkpointing: False
  # NOTE(review): a separate key from `optimizer.momentum` below; presumably a
  # momentum-model / EMA coefficient — confirm against the training loop
  momentum: 0.995
  num_workers: 8
  num_epochs: 100
  patience: 20
  epochs_cooldown: null
  val_frequency: 1
  compile: False  # does not work with DDP (Exception: Please convert all Tensors to FakeTensors first)
  to_max_metric: True
  metric_name: 'precision'
  early_stop_phase: 'valid'  # care about which phase when deciding to early stop
  precision: 'amp'
  alpha: 0.0  # weight the pseudo-targets in the targets (interpolated linearly with this param alpha).
  seed: 1337
  run_test_only: False
  dist_backend: 'nccl'
  dist_url: env://
  ddp_static_graph: False
  no_set_device_rank: False
  remote_sync: null
  remote_sync_frequency: 300
  remote_sync_protocol: s3
  distill_model: null
  distill_pretrained: null
  force_image_size: null
  lock_rgb: False
  lock_audio: False
  lock_rgb_unlocked_groups: 0
  lock_audio_unlocked_layers: 0
  lock_rgb_freeze_bn_stats: False
  lock_audio_freeze_layer_norm: False
  trace: False
  use_bn_sync: False
  max_clip_norm: 1.0
  # read by model.params.init_scale / clamp_scale_min / clamp_scale_max
  init_scale: 0.07
  clamp_scale_min: 0.001
  clamp_scale_max: 0.5
  run_shifted_win_val: True
  run_shifted_win_val_winsize_valid: 8  # (in segments) 5sec/0.64sec~8  where 5sec is the expected len of the downstream task dataset
  run_shifted_win_val_winsize_train: 8  # should be smaller than data.n_segments_train
  segment_loss_weight: 1.0
  global_loss_weight: 1.0
  skip_scheduler: False
  lr_scheduler:
    name: 'cosine'  # 'cosine' 'const' 'const-cooldown'
    warmup: 1000
  optimizer:
    name: adamw # adamw, adam or sgd
    betas: [0.9, 0.999]
    momentum: 0.9
    weight_decay: 0.0

# Dataset location, clip geometry and augmentation probabilities.
# Most keys here are consumed through `${data.*}` interpolations in the
# transform sequences; `n_segments_train` is additionally interpolated into
# `max_segments` of both feature extractors in `model`.
data:
  dataset_type: sparsesync
  vids_path: 'PLACEHOLDER' # something that ends with 'CODEC_video_XXfps_YYYside_ZZZZZhz' or '..._25fps_...'
  size_before_crop: 256  # video resolution -> size_before_crop resolution -> input_size (crop resolution)
  input_size: 224
  segment_size_vframes: 16
  is_spatial_crop_random: True  # if the crop transform should be random or just center crop should be used
  is_temporal_crop_random: True  # if True, the starting position of the 1st clip will be random but respecting n_segments
  sometimes_upscale_p: 0.2  # how often to apply the smaller crop and upscale? if 0.0 or null, works as RGBSpatialCrop
  p_horizontal_flip: 0.5
  # shared probability used by every Audio* augmentation in transform_sequence_train
  p_audio_aug: 0.2
  p_color_jitter: 0.2  # ignored if 0
  p_gray_scale: 0.2  # ignored if 0
  # if null, the max number of segments will be made (i.e. no gap between segments), otherwise segments will be uniformly distributed
  n_segments_train: 14
  n_segments_valid: 14
  audio_jitter_sec: 0.05  # offset the audio by small amount U ~ [-jitter_s, jitter_s]
  step_size_seg: 1.0  # step size between segments in segments (1 = no overlap between segments, 0.5 = 50% overlap)
  # changing `dataset` arguments here won't affect the init call. See train_utils.get_datasets
  dataset:
    target: 'PLACEHOLDER'
    params:
      load_fixed_offsets_on: []
      vis_load_backend: 'read_video'
      size_ratio: null  # null or 1.0: full dataset; a ratio will use a proportion of it

# Training-time transforms. They are sequentially defined, so the list order
# matters (presumably applied top to bottom — confirm against the dataset code).
# Probabilities and sizes are interpolated from the `data` section.
transform_sequence_train:
  - target: dataset.transforms.EqualifyFromRight
    params:
      clip_max_len_sec: 10 # for LRS3 this can be increased to allow more training data as clips may be >10s
  - target: dataset.transforms.RGBSpatialCropSometimesUpscale
    params:
      sometimes_p: ${data.sometimes_upscale_p}
      smaller_input_size: 192 # the size of the smaller crop. null 192 112
      target_input_size: ${data.input_size}
      is_random: ${data.is_spatial_crop_random}
  - target: dataset.transforms.GenerateMultipleSegments
    params:
      segment_size_vframes: ${data.segment_size_vframes}
      n_segments: ${data.n_segments_train}
      is_start_random: ${data.is_temporal_crop_random}
      audio_jitter_sec: ${data.audio_jitter_sec}
      step_size_seg: ${data.step_size_seg}
  - target: dataset.transforms.RandomApplyColorDistortion
    params:
      p_color_jitter: ${data.p_color_jitter}
      s: 1.0 # strength of the color jitter if applied
      p_gray_scale: ${data.p_gray_scale}
  - target: dataset.transforms.RandomHorizontalFlip
    params:
      p: ${data.p_horizontal_flip}
  - target: dataset.transforms.RGBToHalfToZeroOne # RGBToFloatToZeroOne
  - target: dataset.transforms.RGBNormalize
    params:
      mean: [0.5, 0.5, 0.5] # motionformer normalization
      std: [0.5, 0.5, 0.5]
  # audio augmentations; each one is gated by the same ${data.p_audio_aug}
  - target: dataset.transforms.AudioRandomReverb
    params:
      p: ${data.p_audio_aug}
  - target: dataset.transforms.AudioRandomVolume
    params:
      p: ${data.p_audio_aug}
      gain: 2.0
      gain_type: 'amplitude'
  - target: dataset.transforms.AudioRandomPitchShift
    params:
      p: ${data.p_audio_aug}
      shift: 1000
  - target: dataset.transforms.AudioRandomLowpassFilter
    params:
      p: ${data.p_audio_aug}
      cutoff_freq: 100
  - target: dataset.transforms.AudioRandomGaussNoise
    params:
      p: ${data.p_audio_aug}
      amplitude: 0.01
  # spectrogram settings are duplicated verbatim in transform_sequence_test
  - target: dataset.transforms.AudioMelSpectrogram
    params:
      sample_rate: 16000
      win_length: 400  # 25 ms * 16 kHz
      hop_length: 160  # 10 ms * 16 kHz
      # NOTE(review): for win_length 400 the formula in the comment below
      # evaluates to 512, not 1024 — confirm which of the two is intended
      n_fft: 1024  # 2^(ceil(log2(window_size * sampling_rate)))
      n_mels: 128  # as in AST
  - target: dataset.transforms.AudioLog
  - target: dataset.transforms.PadOrTruncate
    params:
      max_spec_t: ${model.params.afeat_extractor.params.max_spec_t}
  - target: dataset.transforms.AudioNormalizeAST
    params:
      mean: -4.2677393  # AST, pre-trained on AudioSet
      std: 4.5689974
  - target: dataset.transforms.PermuteStreams
    params:
      einops_order_audio: "S F T -> S T F"
      einops_order_rgb: "S T C H W -> S C T H W"

# Evaluation-time transforms: no random augmentations are listed, the spatial
# crop and the segment start are non-random, and the number of segments comes
# from data.n_segments_valid.
# NOTE(review): visible differences from transform_sequence_train that may or
# may not be intentional — confirm:
#   * EqualifyFromRight has no params here (train passes clip_max_len_sec: 10)
#   * RGBToFloatToZeroOne is used here, RGBToHalfToZeroOne in train
#   * step_size_seg is hard-coded to 1.0 instead of ${data.step_size_seg}
transform_sequence_test:
  - target: dataset.transforms.EqualifyFromRight
  - target: dataset.transforms.RGBSpatialCrop
    params:
      input_size: ${data.input_size}
      is_random: False
  - target: dataset.transforms.GenerateMultipleSegments
    params:
      segment_size_vframes: ${data.segment_size_vframes}
      n_segments: ${data.n_segments_valid}
      is_start_random: False
      step_size_seg: 1.0
  - target: dataset.transforms.RGBToFloatToZeroOne # RGBToHalfToZeroOne
  - target: dataset.transforms.RGBNormalize
    params:
      mean: [0.5, 0.5, 0.5]
      std: [0.5, 0.5, 0.5]
  # spectrogram settings mirror the ones in transform_sequence_train
  - target: dataset.transforms.AudioMelSpectrogram
    params:
      sample_rate: 16000
      win_length: 400  # 25 ms * 16 kHz
      hop_length: 160  # 10 ms * 16 kHz
      n_fft: 1024  # 2^(ceil(log2(window_size * sampling_rate)))
      n_mels: 128  # as in AST
  - target: dataset.transforms.AudioLog
  - target: dataset.transforms.PadOrTruncate
    params:
      max_spec_t: ${model.params.afeat_extractor.params.max_spec_t}
  - target: dataset.transforms.AudioNormalizeAST
    params:
      mean: -4.2677393  # AST, pre-trained on AudioSet
      std: 4.5689974
  - target: dataset.transforms.PermuteStreams
    params:
      einops_order_audio: "S F T -> S T F"
      einops_order_rgb: "S T C H W -> S C T H W"

# Where run artefacts go and which loggers are enabled.
# NOTE(review): the exact meaning of `save_frequency: 0` and of the
# checkpoint-retention flags is defined by the training code — confirm there.
logging:
  logdir: './logs/avclip_models'
  log_code_state: True
  log_frequency: 100
  log_local: False
  delete_previous_checkpoint: False
  save_most_recent: True
  save_frequency: 0
  # patterns to ignore when backing up the code folder
  patterns_to_ignore: ['logs', '.git', '__pycache__', 'data', '*.pt', '*.pyth', 'sbatch_logs', '*.mp4', '*.wav', '*.jpg', '*.gif', 'misc*']
  use_tboard: False
  use_wandb: False
  wandb_notes: null
