# @package __global__

# Hydra defaults list: the other configs composed together with this file.
# NOTE(review): under Hydra's composition rules later entries take precedence
# over earlier ones, so the position of `_self_` matters — confirm before
# reordering anything here.
defaults:
  - /solver/default
  - /conditioner: none
  # `_self_` stands for the keys defined in this file.
  - _self_
  # Listed after `_self_`; selects the `none` option of this config group.
  - /solver/musicgen/evaluation: none
  # `override` replaces a `dset` choice already made elsewhere in the defaults.
  - override /dset: audio/default

# Autocast switch and the dtype paired with it.
# NOTE(review): presumably read by the solver to enable mixed precision —
# confirm against the consumer.
autocast: true
autocast_dtype: float16

# Name of the solver this config targets.
solver: musicgen
# `???` is OmegaConf's mandatory-value marker: the three keys below have no
# default in this file and must be supplied by another config or an override.
sample_rate: ???
channels: ???
compression_model_checkpoint: ???
# Sets the number of codebooks on the underlying compression model. This may
# differ from the n_q value given to the transformer when the model output is
# post-processed (for instance, for stereo channels). If not provided (null),
# the compression model's default value is used.
compression_model_n_q: null

# Token-level options.
# NOTE(review): presumably selects whether padding positions are filled with a
# special token — confirm against the solver.
tokens:
  padding_with_special_token: false

# Stereo codebook interleaving; disabled here (`use: false`).
# NOTE(review): `per_timestep` presumably picks the interleaving granularity
# when `use` is true — confirm against the consumer.
interleave_stereo_codebooks:
  use: false
  per_timestep: false

# Cache settings. `path` is null in this file, so no cache location is set here.
# NOTE(review): what gets cached, and how `write_shard` / `write_num_shards`
# partition the writes, is defined by the consumer — confirm there.
cache:
  path: null  # explicit null instead of a bare empty value; parses identically
  write: false
  write_shard: 0
  write_num_shards: 1


# Dataset settings. The keys directly under `dataset` sit alongside three
# per-split sections (`train`, `valid`, `generate`) that each set `num_samples`.
# NOTE(review): `segment_duration` is presumably in seconds — confirm.
dataset:
  batch_size: 128
  num_workers: 10
  segment_duration: 30
  min_segment_ratio: 0.8  # lower values such as 0.5 result in generations with a lot of silence.
  return_info: true
  train:
    num_samples: 1000000  # AudioDataset needs an arbitrarily large number here
  valid:
    num_samples: 10000
  generate:
    num_samples: 50

# Per-metric configuration. Each metric names a `model` and carries a sub-section
# with that same name holding the model's parameters. Whether a metric is
# actually computed is toggled separately under `evaluate.metrics`.
# NOTE(review): the `//reference/...` paths are presumably resolved against a
# site-specific reference directory by the consumer — confirm.
metrics:
  fad:
    use_gt: false
    model: tf
    tf:
      bin: null  # path to local frechet_audio_distance code
      model_path: //reference/fad/vggish_model.ckpt
  kld:
    use_gt: false
    model: passt
    passt:
      pretrained_length: 20
  text_consistency:
    use_gt: false
    model: clap
    clap:
      model_path: //reference/clap/music_audioset_epoch_15_esc_90.14.pt
      model_arch: 'HTSAT-base'
      enable_fusion: false
  chroma_cosine:
    use_gt: false
    model: chroma_base
    chroma_base:
      # Interpolated from the top-level `sample_rate` key.
      sample_rate: ${sample_rate}
      n_chroma: 12
      radix2_exp: 14
      argmax: true

# Sample generation settings.
# NOTE(review): `every` is presumably a period in epochs — confirm.
generate:
  every: 25
  num_workers: 5
  path: samples
  audio:
    format: wav
    strategy: loudness
    # Interpolated from the top-level `sample_rate` key.
    sample_rate: ${sample_rate}
    loudness_headroom_db: 14
  lm:
    prompted_samples: true
    unprompted_samples: true
    gen_gt_samples: false
    prompt_duration: null   # if not set, will use dataset.generate.segment_duration / 4
    gen_duration: null      # if not set, will use dataset.generate.segment_duration
    remove_prompts: false
    # generation params
    use_sampling: false
    temp: 1.0
    top_k: 0
    top_p: 0.0

# Evaluation settings. Every metric toggle below is off in this file; the
# per-metric parameters live under the top-level `metrics` key.
# NOTE(review): `every` is presumably a period in epochs — confirm.
evaluate:
  every: 25
  num_workers: 5
  metrics:
    base: false
    fad: false
    kld: false
    text_consistency: false
    chroma_cosine: false

# Checkpoint retention settings.
# NOTE(review): `save_every` / `keep_last` are presumably counted in epochs and
# in number of checkpoints respectively — confirm against the consumer.
checkpoint:
  save_last: true
  save_every: 50
  keep_last: 10
  keep_every_states: null

# Optimization settings. `optimizer` names the optimizer and the `adam`
# sub-section holds its hyper-parameters.
# Floats in scientific notation are written with an explicit decimal point
# (1.0e-4, not 1e-4): YAML 1.1 loaders require the point and would otherwise
# read the value as a string.
# NOTE(review): `max_norm` is presumably the gradient-clipping norm — confirm.
optim:
  epochs: 200
  updates_per_epoch: 2000
  lr: 1.0e-4
  optimizer: adamw
  max_norm: 1.0
  eager_sync: true
  adam:
    betas: [0.9, 0.95]
    weight_decay: 0.1
    eps: 1.0e-8

# Learning-rate schedule. `lr_scheduler` is null, so no scheduler is selected
# by this file.
schedule:
  lr_scheduler: null
