# @package _group_

# Run-wide settings: device/precision switches and logging output.
common:
  # TPU execution is switched on and fp16 is switched off for this run.
  tpu: true
  fp16: false
  # Logs are emitted in JSON form.
  log_format: json
  # NOTE(review): presumably counted in training updates/batches -- confirm.
  log_interval: 10

# Checkpoint policy: update-count based saving, with per-epoch checkpoints
# disabled (see no_epoch_checkpoints below).
checkpoint:
  save_interval_updates: 25000
  # NOTE(review): presumably the number of interval checkpoints retained on
  # disk, i.e. only the newest one is kept -- confirm.
  keep_interval_updates: 1
  no_epoch_checkpoints: true

# Task selection and input handling for audio pretraining.
task:
  _name: audio_pretraining
  # '???' is the OmegaConf/Hydra marker for a mandatory value: no default is
  # given here, so the data location must be supplied from outside this file.
  data: ???
  # Upper and lower bounds on sample size.
  # NOTE(review): units are presumably raw audio samples -- confirm.
  max_sample_size: 250000
  min_sample_size: 32000
  normalize: true
  num_batch_buckets: 3
  precompute_mask_indices: true
  enable_padding: true
  # Masking parameters given to the task.
  # NOTE(review): presumably consumed when precompute_mask_indices is true and
  # expected to agree with the model's own masking settings -- confirm against
  # the task implementation.
  inferred_w2v_config:
    mask_prob: 0.65
    mask_selection: static
    mask_other: 0
    mask_channel_prob: 0.1

# Data-loading settings and batch-size limit.
dataset:
  num_workers: 6
  # NOTE(review): for audio input the "tokens" are presumably raw samples per
  # batch -- confirm.
  max_tokens: 1200000
  skip_invalid_size_inputs_valid_test: true

# Distributed setup: world size of 8 using the legacy_ddp backend.
# NOTE(review): 8 presumably corresponds to the number of TPU cores -- confirm.
distributed_training:
  distributed_world_size: 8
  ddp_backend: legacy_ddp

# Training objective: the wav2vec criterion with its infonce option enabled.
criterion:
  _name: wav2vec
  infonce: true
  # NOTE(review): presumably the names of extra values to report in the
  # training logs -- confirm.
  log_keys:
    - prob_perplexity
    - code_perplexity
    - temp
  # Two weights; the second is 0.
  # NOTE(review): which loss terms these apply to is not visible in this file.
  loss_weights: [0.1, 0]

# Optimization budget and learning rate.
optimization:
  max_update: 1000000
  # The learning rate is given as a one-element list.
  lr: [0.005]

# Optimizer selection (adam) and its hyper-parameters.
optimizer:
  _name: adam
  # YAML has no tuple type, so this value is a plain string; it is quoted to
  # make that explicit. The string content is unchanged.
  # NOTE(review): presumably parsed into a pair of floats by the optimizer.
  adam_betas: '(0.9,0.98)'
  # Written with an explicit decimal point: YAML 1.1 resolvers (e.g. stock
  # PyYAML) treat "1e-06" as a string, whereas "1.0e-06" is a float under both
  # YAML 1.1 and 1.2.
  adam_eps: 1.0e-06
  weight_decay: 0.01

# Learning-rate schedule: polynomial_decay with 32000 warmup updates.
lr_scheduler:
  _name: polynomial_decay
  warmup_updates: 32000

# Model definition: wav2vec2 with quantized targets.
model:
  _name: wav2vec2
  quantize_targets: true
  extractor_mode: layer_norm
  layer_norm_first: true
  final_dim: 768
  # NOTE(review): presumably (start, end, decay factor) of a temperature
  # schedule -- confirm against the model code.
  latent_temp: [2.0, 0.1, 0.999995]
  # Every dropout-style setting in this config is zero.
  encoder_layerdrop: 0.0
  dropout_input: 0.0
  dropout_features: 0.0
  dropout: 0.0
  attention_dropout: 0.0
  conv_bias: true

  # Encoder size: 24 layers, width 1024 (64 per head across 16 heads) and an
  # FFN width of 4096 (4x the embedding width).
  encoder_layers: 24
  encoder_embed_dim: 1024
  encoder_ffn_embed_dim: 4096
  encoder_attention_heads: 16

  feature_grad_mult: 1.0
