# Hydra/OmegaConf config for the codec perplexity (PPL) experiment.
# `???` marks mandatory values that must be supplied via overrides.
project: ppl_${codec_name}  # resolves to "ppl_xcodec" with the default codec_name
seed: 114514
ppl_ckpt_dir: ???
tensorboard_save_dir: ???
codec_name: xcodec
codec_ckpt_dir: ???
accumulate_grad_batches: 1  # 1 disables gradient accumulation

# PyTorch Lightning Trainer arguments (instantiated via `_target_`).
trainer:
  _target_: pytorch_lightning.Trainer

  # Hardware / distributed setup.
  accelerator: gpu
  devices: ???
  strategy: ddp_find_unused_parameters_true
  precision: bf16-mixed

  # Training length: Lightning stops at whichever limit is reached first.
  max_steps: 1000000
  max_epochs: 1000
  accumulate_grad_batches: ${accumulate_grad_batches}

  # Validation and logging cadence.
  val_check_interval: 400
  limit_val_batches: 100
  log_every_n_steps: 50

# Data module config (instantiated via `_target_`). Only one dataset
# section below may be active at a time.
data:
  # --- Emilia-en (active) ---
  _target_: codec_evaluation.perplexity.dataset.emilia_en_dataset.Emilia_ppl_Module
  base_audio_dir: ???
  dataset_path: ???
  train_batch_size: ???
  train_num_workers: 4
  valid_batch_size: ???
  valid_num_workers: 1

  # --- MTG-Jamendo (inactive) ---
  # To switch, comment out the Emilia-en keys above and uncomment the keys
  # below; leaving both sections active would produce duplicate keys.
  # NOTE(review): this section's sample_rate (44100) differs from
  # model.sample_rate (24000) — confirm whether the model side must change too.
  #_target_: codec_evaluation.perplexity.dataset.mtg_jamendo_dataset.Jamendo_ppl_Module
  #base_audio_dir: ???
  #dataset_path: ???
  #train_batch_size: ???
  #train_num_workers: 4
  #valid_batch_size: ???
  #valid_num_workers: 1
  #sample_rate: 44100
  #target_sec: 15

# Lightning module for the PPL experiment (instantiated via `_target_`).
model:
  _target_: codec_evaluation.perplexity.model.lit_modules.PPL_lit_modules
  accumulate_grad_batches: ${accumulate_grad_batches}

  # Qwen2 model config loaded from a JSON file.
  # NOTE(review): the path is relative, so it resolves against the process's
  # working directory — confirm training is launched from the expected root.
  ppl_model_config:
    _target_: transformers.models.qwen2.configuration_qwen2.Qwen2Config.from_pretrained
    pretrained_model_name_or_path: perplexity/config/ppl_model_config.json

  sample_rate: 24000  # audio sample rate
  codec_name: ${codec_name}
  codec_ckpt_dir: ${codec_ckpt_dir}
  lm_head_nums: 8  # presumably one LM head per codec codebook — TODO confirm

  # Built as a partial (`_partial_: true`); the module supplies the
  # remaining arguments when it constructs the optimizer.
  optimizer_builder:
    _target_: torch.optim.AdamW
    _partial_: true
    # Floats carry a decimal point so YAML 1.1 loaders (e.g. PyYAML's
    # SafeLoader) also read them as floats instead of strings.
    lr: 1.0e-4
    betas:
      - 0.8
      - 0.99
    eps: 1.0e-5
    weight_decay: 0.01

  # LambdaLR driven by a warmup + cosine schedule over trainer.max_steps.
  lr_scheduler_builder:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: codec_evaluation.utils.schedule.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 50
      num_training_steps: ${trainer.max_steps}
      # NOTE(review): if this is the final/peak LR ratio (as the name
      # suggests), 0.99 makes the cosine decay nearly flat — confirm intent.
      final_lr_ratio: 0.99

# Lightning callbacks, one entry per callback (each instantiated via `_target_`).
callbacks:
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar

  model_summary:
    _target_: pytorch_lightning.callbacks.ModelSummary
    max_depth: 1

  # Keeps the single best checkpoint by `val_loss_mean` (lower is better),
  # considered every 2000 training steps.
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: val_loss_mean
    mode: min
    every_n_train_steps: 2000
    dirpath: ${ppl_ckpt_dir}
    # The codec part follows codec_name; with the default codec_name this is
    # unchanged ("..._speech_xcodec_ppl").
    filename: '{epoch:03d}-{step:06d}_speech_${codec_name}_ppl'
    save_top_k: 1
    verbose: true
    # Also write/refresh `last.ckpt` whenever a checkpoint is saved.
    save_last: true

# TensorBoard logger; runs are written under <tensorboard_save_dir>/<project>/.
tensorboard_logger:
  _target_: pytorch_lightning.loggers.TensorBoardLogger
  save_dir: ${tensorboard_save_dir}
  name: ${project}
  # NOTE(review): log_graph requires the LightningModule to define
  # `example_input_array`; otherwise Lightning skips the graph with a warning.
  log_graph: true