# ################################
# Model: SepFormer for source separation
# https://arxiv.org/abs/2010.13154
# Dataset : WSJ0-2mix and WSJ0-3mix
# ################################
#
# Basic parameters
# Seed needs to be set at top of yaml, before objects with parameters are made
#
# Global random seed. It is also interpolated into <output_folder>, so runs
# with different seeds write to different result directories.
seed: 1234
# `!apply` calls the function while the file is being loaded; keeping this
# entry ahead of every `!new:` object means those objects are built after
# seeding.
__set_seed: !apply:speechbrain.utils.seed_everything [!ref <seed>]

# Data params
# NOTE(review): `load` is not referenced anywhere in the visible part of this
# file -- confirm how the training script uses it.
load: false
# NOTE(review): the two comment lines below describe a dataset-folder entry
# (conventionally `data_folder`), but no such key is defined in the visible
# part of this file. Presumably it must be supplied as an override on the
# command line -- confirm.
# e.g. '/yourpath/wsj0-mix/2speakers'
# end with 2speakers for wsj0-2mix or 3speakers for wsj0-3mix

# the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
# e.g. /yourpath/wsj0-processed/si_tr_s/
# you need to convert the original wsj0 to 8k
# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
base_folder_dm: /yourpath/wsj0-processed/si_tr_s/

# Everything produced by a run lives under results/<experiment_name>/<seed>.
experiment_name: sepformer
output_folder: !ref results/<experiment_name>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save
# CSV manifests for the train / validation / test splits, kept in <save_folder>.
train_data: !ref <save_folder>/wsj_tr.csv
valid_data: !ref <save_folder>/wsj_cv.csv
test_data: !ref <save_folder>/wsj_tt.csv
# Presumably skips the data-preparation step when True -- confirm in the
# training script.
skip_prep: False


# Experiment params
precision: fp16 # one of bf16, fp16 or fp32 (bf16/fp16 select mixed precision)
# Number of output sources; forwarded to MaskNet as `num_spks`.
# NOTE(review): the file header targets WSJ0-2mix / WSJ0-3mix, for which 2 or 3
# would be expected here -- confirm that 1 is intentional for this variant.
num_spks: 1 # set to 3 for wsj0-3mix
noprogressbar: False
save_audio: True # Save estimated sources on disk
# Presumably the number of examples written when save_audio is True -- confirm.
n_audio_to_save: 20
# Sampling rate in Hz; also passed to speed_perturb as `orig_freq`.
sample_rate: 8000

####################### Training Parameters ####################################
# Maximum number of epochs; used as the `limit` of <epoch_counter>.
N_epochs: 200
# Forwarded to the dataloader through <dataloader_opts>.
batch_size: 1
# Initial learning rate handed to the Adam optimizer below.
lr: 0.00015
# Presumably the max gradient norm used for clipping -- confirm in the
# training script.
clip_grad_norm: 5
loss_upper_lim: 999999  # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to a specified length
limit_training_signal_len: False
# this is the length of sequences if we choose to limit
# the signal length of training sequences
# (presumably expressed in samples -- TODO confirm)
training_signal_len: 32000000

# Set it to True to dynamically create mixtures at training time
dynamic_mixing: False

# Parameters for data augmentation
# Each switch enables one augmentation; the corresponding objects
# (speed_perturb, drop_freq, drop_chunk) are configured further down.
use_wavedrop: False
use_speedperturb: True
use_rand_shift: False
# Bounds of the random shift used when use_rand_shift is True
# (presumably in samples, i.e. +/- 1 s at 8 kHz -- TODO confirm).
min_shift: -8000
max_shift: 8000

# Speed perturbation
# Candidate playback speeds, given as percentages of the original speed.
speed_changes:
    - 95
    - 100
    - 105

# Speed-perturbation augmenter operating on <sample_rate> audio.
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    speeds: !ref <speed_changes>
    orig_freq: !ref <sample_rate>

# Frequency drop: randomly drops a number of frequency bands to zero.
# Per the speechbrain DropFreq API, the low/high/width values are fractions of
# the Nyquist frequency (sample_rate / 2); they are not probabilities.
drop_freq_low: 0  # Lowest frequency that may be dropped (fraction of Nyquist)
drop_freq_high: 1  # Highest frequency that may be dropped (fraction of Nyquist)
drop_freq_count_low: 1  # Min number of frequency bands to drop
drop_freq_count_high: 3  # Max number of frequency bands to drop
drop_freq_width: 0.05  # Width of each dropped band (fraction of Nyquist)

# Frequency-drop augmenter built from the five values above.
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Time drop: randomly drops a number of temporal chunks.
# How many chunks are zeroed per signal (lower / upper bound).
drop_chunk_count_low: 1
drop_chunk_count_high: 5
# How long each zeroed chunk is (lower / upper bound).
drop_chunk_length_low: 1000
drop_chunk_length_high: 2000

# Chunk-drop augmenter; arguments listed in the same order as the values above.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_count_low: !ref <drop_chunk_count_low>
    drop_count_high: !ref <drop_chunk_count_high>
    drop_length_low: !ref <drop_chunk_length_low>
    drop_length_high: !ref <drop_chunk_length_high>

# loss thresholding -- this thresholds the training loss
# Enables thresholding of the training loss against `threshold`.
threshold_byloss: True
# NOTE(review): the loss configured in this file is an SI-SNR based one, so
# this value is presumably a (negative) SI-SNR level in dB -- confirm how the
# training script compares against it.
threshold: -30

# Encoder parameters
# Channels produced by SignalEncoder / MaskEncoder and consumed by Decoder.
N_encoder_out: 256
# Model width: `d_model` of both transformer blocks and MaskNet `out_channels`.
out_channels: 256
# Input channels of MaskNet; kept as its own entry so it can be overridden
# independently of <N_encoder_out>.
MaskNet_in_channels: 256
# Kernel size shared by both encoders and the decoder.
kernel_size: 16
# Stride passed to the Decoder only; the encoders in this file are not given
# an explicit stride.
kernel_stride: 8

# Dataloader options
# Set num_workers: 0 on MacOS due to behavior of the multiprocessing library
dataloader_opts:
    # Worker processes used for data loading.
    num_workers: 3
    # Kept in sync with the training batch size declared above.
    batch_size: !ref <batch_size>


# Specifying the network
# First of two identically configured encoders: <kernel_size> kernel,
# <N_encoder_out> output channels.
SignalEncoder: !new:speechbrain.lobes.models.dual_path.Encoder
    out_channels: !ref <N_encoder_out>
    kernel_size: !ref <kernel_size>

# Second encoder instance, with the same hyperparameters as SignalEncoder but
# its own weights (a separate `!new:` object).
MaskEncoder: !new:speechbrain.lobes.models.dual_path.Encoder
    out_channels: !ref <N_encoder_out>
    kernel_size: !ref <kernel_size>

# Transformer stack handed to MaskNet as its `intra_model`.
SBtfintra: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
    d_model: !ref <out_channels>
    d_ffn: 1024
    nhead: 8
    num_layers: 8
    dropout: 0
    norm_before: True
    use_positional_encoding: True

# Transformer stack handed to MaskNet as its `inter_model`; configured
# identically to SBtfintra but instantiated separately.
SBtfinter: !new:speechbrain.lobes.models.dual_path.SBTransformerBlock
    d_model: !ref <out_channels>
    d_ffn: 1024
    nhead: 8
    num_layers: 8
    dropout: 0
    norm_before: True
    use_positional_encoding: True

# Dual-path masking network. It takes <MaskNet_in_channels> input channels,
# works internally at <out_channels>, and wraps the two transformer stacks
# defined above as its intra/inter models.
MaskNet: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
    num_spks: !ref <num_spks>
    in_channels: !ref <MaskNet_in_channels>
    out_channels: !ref <out_channels>
    # Number of stacked dual-path (intra + inter) blocks.
    num_layers: 2
    # Chunk length used by the dual-path segmentation (see Dual_Path_Model API).
    K: 250
    intra_model: !ref <SBtfintra>
    inter_model: !ref <SBtfinter>
    norm: ln
    linear_layer_after_inter_intra: False
    skip_around_intra: True

# Decoder mapping <N_encoder_out> channels back to a single output channel,
# using the shared kernel size and <kernel_stride>.
Decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    in_channels: !ref <N_encoder_out>
    out_channels: 1
    bias: False

# `!name:` stores a constructor with these keyword arguments pre-filled
# instead of instantiating Adam here; the parameters to optimize are supplied
# later by the caller.
optimizer: !name:torch.optim.Adam
    weight_decay: 0
    lr: !ref <lr>

# Training objective: a reference to the SI-SNR loss with a PIT wrapper
# (`!name:` stores the callable; it is not invoked while loading this file).
loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

# Plateau-based learning-rate scheduler. Going by the argument names, the rate
# is scaled by `factor` after `patience` epochs without improvement, and never
# before epoch `dont_halve_until_epoch` (see the ReduceLROnPlateau API for the
# exact rule).
lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    dont_halve_until_epoch: 85
    patience: 2
    factor: 0.5

# Epoch counter capped at <N_epochs>.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>
