# Hydra-style defaults list: the configs below are composed in order.
defaults:
  - basic
  # Selects the `flow_matching_small` option from the `model` config group.
  - model: flow_matching_small
  # Selects `train_val` from the `data` group and mounts it under `data_dict`
  # (referenced below as ${data_dict.train_data_list} / ${data_dict.val_data_list}).
  - data@data_dict: train_val
  # Selects `weighted_sum` from the `loss` group and mounts it under `loss_fn`.
  - loss@loss_fn: weighted_sum
  # `_self_` is listed last, so keys set in this file are merged after the
  # entries above.
  - _self_

# Experiment identity; also reused below as the logging run name.
exp_name: 'wave_dit_ada_sola_bias_diff'
# Output root for this experiment (the value ends with a trailing slash).
exp_dir: 'experiments/${exp_name}/'
# NOTE: because exp_dir already ends in "/", this resolves to a path with a
# doubled slash ("experiments/<exp_name>//train.log"). The explicit "/" is
# kept so the path stays valid if exp_dir is overridden without one.
logging_file: '${exp_dir}/train.log'

# Training data pipeline: dataset, batch sampler and collate function are each
# built from their `_target_` class.
train_dataloader:
  _target_: torchdata.stateful_dataloader.StatefulDataLoader
  dataset:
    _target_: data_module.dataset.TaskGroupedAudioGenConcatDataset
    # Training dataset list taken from the `data_dict` config node.
    datasets: ${data_dict.train_data_list}
  batch_sampler:
    _target_: data_module.sampler.TaskGroupedIteratingBatchSampler
    shuffle: true
    # Per-device batch size.
    batch_size: 24
  num_workers: 4
  collate_fn:
    _target_: data_module.collate_function.PaddingCollateWithAnyContent
    pad_keys:
      - waveform
      - duration
      - instruction
    torchify_keys:
      - is_time_aligned
    content_pad_keys:
      - phoneme
      - phoneme_duration
      - midi
      - midi_duration
      - is_slur
      - frames
    content_torchify_keys:
      - spk

# Validation data pipeline: same dataset and collate classes as training, but
# with a plain torch DataLoader and a sequential, non-shuffling batch sampler.
val_dataloader:
  _target_: torch.utils.data.DataLoader
  dataset:
    _target_: data_module.dataset.TaskGroupedAudioGenConcatDataset
    # Validation dataset list taken from the `data_dict` config node.
    datasets: ${data_dict.val_data_list}
  batch_sampler:
    _target_: data_module.sampler.TaskGroupedSequentialBatchSampler
    # Per-device batch size.
    batch_size: 24
    shuffle: false
    # NOTE(review): presumably discards a final incomplete batch, which would
    # exclude some validation samples - confirm against the sampler.
    drop_last: true
  num_workers: 4
  collate_fn:
    _target_: data_module.collate_function.PaddingCollateWithAnyContent
    pad_keys:
      - waveform
      - duration
      - instruction
    torchify_keys:
      - is_time_aligned
    content_pad_keys:
      - phoneme
      - phoneme_duration
      - midi
      - midi_duration
      - is_slur
      - frames
    content_torchify_keys:
      - spk

# Sets the content encoder class under `model`. Presumably merged on top of
# the `flow_matching_small` model config chosen in `defaults` - TODO confirm
# that config defines the remaining content_encoder arguments.
model:
  content_encoder:
    _target_: models.content_encoder.content_encoder.BatchedContentEncoder

# Warmup settings: a step count is given here and the epoch-based variant is
# left unset.
warmup_params:
  warmup_steps: 10000
  warmup_epochs: null
  # Mirrors the top-level `epoch_length` value.
  epoch_length: ${epoch_length}

# Forwarded to the trainer via ${gradient_accumulation_steps}.
gradient_accumulation_steps: 1

# AdamW optimizer settings.
optimizer:
  _target_: torch.optim.AdamW
  # Written with a decimal point so it resolves as a float (5e-05) without
  # needing an explicit type tag.
  lr: 5.0e-5
  weight_decay: 0.01

# Learning-rate schedule: targets `transformers.get_scheduler` with the
# "linear" schedule name. Other arguments are not set in this file.
lr_scheduler:
  _target_: transformers.get_scheduler
  name: linear

# Total number of epochs; forwarded to the trainer via ${epochs}.
epochs: 600
# Referenced by both `warmup_params` and `trainer` via ${epoch_length}.
# NOTE(review): presumably the number of steps per epoch - confirm in trainer.
epoch_length: 2000

# Trainer configuration: most values are forwarded from the top-level keys of
# this file through interpolation.
trainer:
  _target_: audio_generation_trainer.MultiTaskAudioGenerationTrainer
  project_dir: ${exp_dir}
  logging_file: ${logging_file}
  gradient_accumulation_steps: ${gradient_accumulation_steps}
  max_grad_norm: 1.0
  epochs: ${epochs}
  epoch_length: ${epoch_length}
  save_last_k: 2
  # No early-stopping setting is provided.
  early_stop: null
  # Experiment-tracking settings.
  logging_config:
    _target_: trainer.LoggingConfig
    report_to: swanlab
    # Alternative project name kept for reference:
    # project: x_to_audio_generation
    project: runs
    save_dir: ${exp_dir}
    name: ${exp_name}
    # No run id to resume from is provided.
    resume_id: null
  # Monitors the `loss` metric in `min` mode.
  metric_monitor:
    _target_: trainer.MetricMonitor
    metric_name: loss
    mode: min
