# @package _global_

# Shared values. The sections below pull them in through `${...}`
# interpolation, so each one is defined exactly once here.

# Forwarded as `sr` to both datasets and as `sampling_rate` to the audio
# sample logger. NOTE(review): presumably Hz - confirm.
sampling_rate: 22050
# `work_dir` is not defined in this file; it must come from another config.
dataset_path: ${work_dir}/data/slakh2100
# Forwarded as `sample_length` to both datasets and as `length` to the audio
# sample logger. 262144 == 2**18.
# NOTE(review): presumably a sample count (about 11.9 s at 22050) - confirm.
length: 262144
# Forwarded to the datasets only; unrelated to `model.channels`.
channels: 1
# Forwarded to both datasets. NOTE(review): presumably seconds - confirm.
min_duration: 12.0
max_duration: 640.0
# Forwarded to both datasets (training and validation alike).
aug_shift: true
# Forwarded to both datasets and to the audio sample logger. The list has
# four entries; `model.in_channels` and the logger's `channels` are
# hard-coded to 4 and have to be edited by hand if this list changes length.
stems: [bass, drums, guitar, piano]

# Keyword arguments for `main.module_base.Model` (selected by `_target_`).
model:
  _target_: main.module_base.Model
  # Written with an explicit mantissa dot: a bare `1e-4` is read as a string
  # by YAML 1.1 resolvers (e.g. plain PyYAML), whereas `1.0e-4` is a float
  # for every loader.
  learning_rate: 1.0e-4
  # NOTE(review): presumably the optimizer's beta coefficients - confirm.
  beta1: 0.9
  beta2: 0.99
  # Hard-coded to 4, the number of entries in the top-level `stems` list.
  # NOTE(review): presumably one input channel per stem - confirm in Model.
  in_channels: 4
  channels: 256
  patch_factor: 16
  patch_blocks: 1
  resnet_groups: 8
  kernel_multiplier_downsample: 2
  kernel_sizes_init: [1, 3, 7]
  # `multipliers` has 7 entries; `factors`, `num_blocks` and `attentions`
  # have 6 each.
  # NOTE(review): presumably one entry per resolution level, with
  # `multipliers` needing one extra - confirm before changing any length.
  multipliers: [1, 2, 4, 4, 4, 4, 4]
  factors: [4, 4, 4, 2, 2, 2]
  num_blocks: [2, 2, 2, 2, 2, 2]
  attentions: [false, false, false, true, true, true]
  attention_heads: 8
  attention_features: 128
  attention_multiplier: 2
  use_nearest_upsample: false
  use_skip_scale: true
  use_attention_bottleneck: true
  # Nested `_target_` node: a log-normal distribution object with the given
  # mean and std, handed to the model as `diffusion_sigma_distribution`.
  diffusion_sigma_distribution:
    _target_: audio_diffusion_pytorch.LogNormalDistribution
    mean: -3.0
    std: 1.0
  diffusion_sigma_data: 0.2
  diffusion_dynamic_threshold: 0.0

# Keyword arguments for `main.module_base.DatamoduleWithValidation`.
datamodule:
  _target_: main.module_base.DatamoduleWithValidation

  # Training split: `<dataset_path>/train`. Every other argument is taken
  # from the shared top-level values.
  train_dataset:
    _target_: main.data.MultiSourceDataset
    audio_files_dir: ${dataset_path}/train
    sr: ${sampling_rate}
    channels: ${channels}
    min_duration: ${min_duration}
    max_duration: ${max_duration}
    aug_shift: ${aug_shift}
    sample_length: ${length}
    stems: ${stems}

  # Validation split: `<dataset_path>/validation`. Configured exactly like
  # the training split, including `aug_shift`.
  val_dataset:
    _target_: main.data.MultiSourceDataset
    audio_files_dir: ${dataset_path}/validation
    sr: ${sampling_rate}
    channels: ${channels}
    min_duration: ${min_duration}
    max_duration: ${max_duration}
    aug_shift: ${aug_shift}
    sample_length: ${length}
    stems: ${stems}

  # NOTE(review): presumably forwarded to the data loaders - confirm in
  # DatamoduleWithValidation.
  batch_size: 8
  num_workers: 4
  pin_memory: true

# One entry per callback; each entry names its class through `_target_`.
callbacks:
  rich_progress_bar:
    _target_: pytorch_lightning.callbacks.RichProgressBar

  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    # Name of the logged metric that decides whether the model is improving.
    monitor: 'valid_loss'
    # Keep the k best checkpoints as ranked by `monitor`.
    save_top_k: 1
    # When true, additionally always save the model from the last epoch.
    save_last: false
    # Either 'max' or 'min'.
    mode: 'min'
    verbose: false
    # `work_dir` comes from another config. `RUN_NAME` is read from the
    # environment with no fallback given, so it is expected to be set.
    dirpath: ${work_dir}/ckpts/${oc.env:RUN_NAME}
    filename: '{epoch:02d}-{valid_loss:.3f}'
    save_weights_only: true

  # The trainer section sets `enable_model_summary` to false; the summary
  # is configured through this callback entry instead.
  model_summary:
    _target_: pytorch_lightning.callbacks.RichModelSummary
    max_depth: 2

  audio_samples_logger:
    _target_: main.module_base.MultiSourceSampleLogger
    num_items: 1
    # Hard-coded to 4, the number of entries in the top-level `stems` list.
    channels: 4
    sampling_rate: ${sampling_rate}
    length: ${length}
    # NOTE(review): presumably the sampler step counts to render - confirm.
    sampling_steps: [100]
    stems: ${stems}
    # Sampler and noise schedule objects handed to the logger.
    diffusion_sampler:
      _target_: audio_diffusion_pytorch.ADPM2Sampler
      rho: 1.0
    diffusion_schedule:
      _target_: audio_diffusion_pytorch.KarrasSchedule
      sigma_min: 0.0001
      sigma_max: 3.0
      rho: 9.0

# Experiment loggers; only Weights & Biases is configured.
loggers:
  wandb:
    _target_: pytorch_lightning.loggers.wandb.WandbLogger
    # Run name, project and entity are read from the environment variables
    # RUN_NAME, WANDB_PROJECT and WANDB_ENTITY; no fallbacks are given.
    name: ${oc.env:RUN_NAME}
    project: ${oc.env:WANDB_PROJECT}
    entity: ${oc.env:WANDB_ENTITY}
    # offline: False  # set True to store all logs only locally
    job_type: 'train'
    group: ''
    # `logs_dir` is not defined in this file; it must come from another
    # config.
    save_dir: ${logs_dir}

# Keyword arguments for `pytorch_lightning.Trainer`.
trainer:
  _target_: pytorch_lightning.Trainer
  # `1` trains on one GPU, `0` on CPU only, `-1` on all GPUs (default `0`).
  gpus: 1
  # Tensor precision (default `32`).
  precision: 'bf16'
  # Left unset.
  # NOTE(review): whether a value such as `ddp` belongs here or under
  # `strategy` depends on the installed pytorch_lightning version - confirm.
  accelerator: null
  min_epochs: 0
  # -1 removes the epoch limit.
  max_epochs: -1
  # The summary is produced by the `model_summary` callback entry instead.
  enable_model_summary: false
  # Log metrics every N batches.
  log_every_n_steps: 10
  # With the per-epoch check disabled (null), the integer interval below
  # counts training batches across epoch boundaries: validation runs every
  # 4000 training batches.
  check_val_every_n_epoch: null
  val_check_interval: 4000
