# YT MEGA
# 32768 folds, each one has roughly 2362 segments of 16,
# which is 77410304 examples in total.
# With a batch size of 1024, one epoch is 75596 steps.

# Input pipeline settings: location of the training records plus the
# vision, audio, masking and sequence-shape parameters listed below.
data:
  # NOTE(review): the filename template says "of32800" while
  # num_train_files is 32768 -- confirm the difference is intentional.
  train_fns: "gs://${YOURPATHHERE}/train{:05d}of32800.tfrecord"
  num_train_files: 32768

  # NOTE(review): presumably a probability in [0, 1] -- confirm in the loader.
  use_audio_token_prob: 0.5

  # Vision
  random_scale_max: 1.1
  random_scale_min: 1.05

  # Audio
  # NOTE(review): sample_rate is presumably in Hz and the FFT sizes in
  # samples -- confirm against the spectrogram code.
  fft_hop_length: 588
  fft_window_size: 1536
  num_mels: 64
  sample_rate: 22050
  spec_size: 188

  # Masking
  mask_rate: 0.25

  # How many sequences of each kind to use
  num_audio2text_seqs: 1
  num_text2audio_seqs: 1
  num_text_seqs: 1
  num_text_seqs_in_record: 1

  # Shapes
  num_segments: 16
  num_segment_groups: 2
  num_audio_subsegments: 3

  # seq_len should be set to
  # (output_grid[0] * output_grid[1]) // (vit_pooling_ratio ** 2) *
  #    num_segments // num_segment_groups  + lang_seq_len
  # With the values in this file:
  #    (12 * 20) // (2 ** 2) * 16 // 2 + 160 = 640
  seq_len: 640
  lang_seq_len: 160

  num_text_spans_to_include: 48
  text_span_budget: 38

# Architecture settings for the joint, audio, vision and text-span encoders.
model:
  # Joint
  hidden_size: 1024
  joint_num_layers: 24
  use_bfloat16: true

  # Audio
  audio_num_layers: 12
  audio_patch_size: 2
  audio_seq_length: 60
  audio_token_length: 6

  # Vision
  # output_grid and vit_pooling_ratio both feed the data.seq_len formula:
  # (12 * 20) // (2 ** 2) = 60 positions per segment.
  output_grid: [12, 20]
  vit_patch_size: 16
  vit_pooling_ratio: 2
  vit_num_layers: 24

  # Text span -- note that text_span_length does not count the CLS token
  # that we add.
  span_num_layers: 4
  text_span_length: 15

# Hardware, output location and input-throughput settings for the run.
device:
  # Canonical lowercase boolean, consistent with model.use_bfloat16.
  use_tpu: true
  num_tpu_cores: 512
  # "lr=3e-4" in the path mirrors optimizer.learning_rate (0.0003).
  output_dir: "gs://${YOURPATHHERE}/flagship~size=large~lr=3e-4/"
  # The epoch length quoted in the file header (75596) assumes this value.
  batch_size: 1024
  iterations_per_loop: 7500
  commit_every_nsteps: 50
  n_fns_per_cycle: 128
  num_parallel_reads: 128
  shuffle_buffer_size: 4096
  wandb_project: merlotreserve

# Optimizer hyperparameters and training-schedule length.
# Booleans are written as lowercase true/false, consistent with
# model.use_bfloat16 and portable across YAML 1.1 / 1.2 loaders.
optimizer:
  learning_rate: 0.0003
  num_train_steps: 750000  # ~10 epochs (one epoch is 75596 steps)
  num_warmup_steps: 3750  # roughly the first 1/20th of an epoch
  weight_decay_rate: 0.1
  beta_2: 0.98
  adafactor: false
  use_bfloat16_adam: true
  # Kept in plain decimal form; exponent notation such as 1e-6 is not
  # read as a float by every YAML loader.
  eps: 0.000001
  use_bfloat16_weights: false