# ################################
# Model: Dual-path Mamba for source separation
# Dataset : WSJ0-2mix and WSJ0-3mix
# ################################
#
# Basic parameters
# Seed needs to be set at top of yaml, before objects with parameters are made
#
# Random seed. It is also the last path component of <output_folder>, so each
# seed gets its own results directory.
seed: 1234
# `!apply` calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Flag for the training script (not referenced elsewhere in this file);
# presumably skips training and only evaluates when true -- TODO confirm.
eval_only: false

# the path for wsj0/si_tr_s/ folder -- only needed if dynamic mixing is used
# e.g. /yourpath/wsj0-processed/si_tr_s/
# you need to convert the original wsj0 to 8k
# you can do this conversion with the script ../meta/preprocess_dynamic_mixing.py
# NOTE(review): "8k" above disagrees with <sample_rate> (16000) in this file --
# confirm which sampling rate the preprocessed data must have.
# NOTE(review): <dynamic_mixing> is True in this file, so this null default
# presumably has to be overridden before training -- confirm.
base_folder_dm: null

# Experiment naming and output locations.
# Every path below is derived from <output_folder>, which itself is built from
# <experiment> and <seed>.
project: Mamba-TasNet
experiment: dpmamba_S
output_folder: !ref results/WSJ0Mix/<experiment>/<seed>
train_log: !ref <output_folder>/train_log.txt
save_folder: !ref <output_folder>/save
# CSV manifests for the train / validation / test splits, under <save_folder>.
train_data: !ref <save_folder>/wsj_tr.csv
valid_data: !ref <save_folder>/wsj_cv.csv
test_data: !ref <save_folder>/wsj_tt.csv
# Presumably skips the data-preparation step (CSV creation) when True --
# TODO confirm against the training script.
skip_prep: False

##### Data ######
# Forwarded to <dataloader_opts>.
batch_size: 2
# data_samples and sample_length are not referenced elsewhere in this file;
# presumably the number of training items and their length in seconds --
# TODO confirm against the training script.
data_samples: 22223
sample_length: 3
# Free-form tag describing the data setup; forwarded to <prefix>.
data_key: "20k_3s"


### Others ###
# The scalar flags in this section (sef, load, transformer, smart_init,
# stft_loss) are not consumed by any object in this file; their meaning is
# defined by the training script -- verify there.
sef: "random"
load: False
# Grouping of identifying strings (model name, data tag, sef value).
# Presumably used to build names for checkpoints or logs -- TODO confirm.
prefix:
    name: "dpmamba_S"
    data_key: !ref <data_key>
    sef: !ref <sef>
transformer: False
smart_init: True
# Active list of [low, high] pairs. The commented-out lines are alternative
# partitions of the same 1-7999 range. Presumably frequency-band edges in Hz
# (7999 is just below <sample_rate> / 2) -- TODO confirm.
# cutoffs: [[1,1600],[1600,3200],[3200,4800],[4800,6400],[6400, 7999]]
# cutoffs: [[1,7999],[1,7999]]
cutoffs: [[1,4000],[4000,7999]]
# cutoffs: [[1,2667],[2667,5333],[5333,7999]]
# cutoffs: [[1,2000],[2000,4000],[4000,6000],[6000,7999]]
stft_loss: False
# Experiment params
precision: fp32 # one of: bf16 | fp16 | fp32 (a string value, not a boolean)
# Forwarded to both Dual_Path_Model objects (<mamba>, <mamba_encoder>).
# NOTE(review): the file header names WSJ0-2mix / WSJ0-3mix, yet the value here
# is 1 -- confirm this is intended for this experiment.
num_spks: 1 # set to 3 for wsj0-3mix
noprogressbar: False
save_audio: True # Save estimated sources on disk
# Presumably the maximum number of examples written when <save_audio> is True
# -- TODO confirm.
n_audio_to_save: 20
# Sampling rate in Hz; passed to <speed_perturb> as orig_freq.
sample_rate: 16000

# Training parameters
# Upper bound for <epoch_counter>.
N_epochs: 200
# lr and weight_decay are forwarded to <optimizer>.
lr: 5e-5
weight_decay: 0
# Presumably the max gradient norm used for clipping by the training script
# -- TODO confirm.
clip_grad_norm: 5
loss_upper_lim: 999999  # this is the upper limit for an acceptable loss
# if True, the training sequences are cut to a specified length
limit_training_signal_len: False
# this is the length of sequences if we choose to limit
# the signal length of training sequences
training_signal_len: 32000000

# Not referenced elsewhere in this file; presumably the number of learning-rate
# warm-up steps applied by the training script -- TODO confirm.
n_warmup_step: 20000

# Set it to True to dynamically create mixtures at training time
dynamic_mixing: True

# Parameters for data augmentation
# Switches for the individual augmentations; they are not consumed by any
# object in this file, so the training script decides how they are applied.
use_wavedrop: False
use_speedperturb: True
use_rand_shift: False
# Bounds for the random shift; presumably expressed in samples -- TODO confirm.
min_shift: -8000
max_shift: 8000

# Speed perturbation
# NOTE(review): per the SpeedPerturb documentation these are percentages of the
# original speed (100 = unchanged) -- confirm for the installed version.
speed_changes: [95, 100, 105]  # List of speed changes for time-stretching

# Speed-perturbation augmenter built from <sample_rate> and <speed_changes>.
speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
    orig_freq: !ref <sample_rate>
    speeds: !ref <speed_changes>

# Frequency drop: randomly drops a number of frequency bands to zero.
# NOTE(review): per the DropFreq documentation, drop_freq_low / drop_freq_high /
# drop_freq_width are fractions of (sampling rate / 2), not probabilities --
# confirm for the installed speechbrain version.
drop_freq_low: 0  # Lowest frequency that can be dropped (fraction of sr / 2)
drop_freq_high: 1  # Highest frequency that can be dropped (fraction of sr / 2)
drop_freq_count_low: 1  # Min number of frequency bands to drop
drop_freq_count_high: 3  # Max number of frequency bands to drop
drop_freq_width: 0.05  # Width of frequency bands to drop

# Frequency-drop augmenter; every argument is forwarded from the scalars above.
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: !ref <drop_freq_low>
    drop_freq_high: !ref <drop_freq_high>
    drop_freq_count_low: !ref <drop_freq_count_low>
    drop_freq_count_high: !ref <drop_freq_count_high>
    drop_freq_width: !ref <drop_freq_width>

# Time drop: randomly drops a number of temporal chunks.
# Chunk lengths are presumably in samples (as in the DropChunk documentation)
# -- TODO confirm.
drop_chunk_count_low: 1  # Min number of audio chunks to drop
drop_chunk_count_high: 5  # Max number of audio chunks to drop
drop_chunk_length_low: 1000  # Min length of audio chunks to drop
drop_chunk_length_high: 2000  # Max length of audio chunks to drop

# Time-drop augmenter; every argument is forwarded from the scalars above.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: !ref <drop_chunk_length_low>
    drop_length_high: !ref <drop_chunk_length_high>
    drop_count_low: !ref <drop_chunk_count_low>
    drop_count_high: !ref <drop_chunk_count_high>

# loss thresholding -- this thresholds the training loss
# Both keys are consumed by the training script (not visible here).
# NOTE(review): the exact rule (which side of <threshold> is kept, and the
# unit of the value) should be confirmed there.
threshold_byloss: True
threshold: -40

# Encoder parameters
# Channel count produced by <encoder>; also the input width of <mamba_encoder>.
N_encoder_out: 256
# Shared model width: d_model of the Mamba stacks, in/out channels of the
# dual-path models, and in_channels of <decoder>.
out_channels: 256
# Kernel size used by both <encoder> and <decoder>.
kernel_size: 16
# Integer division evaluated by HyperPyYAML at load time (16 // 2 = 8);
# used as the <decoder> stride.
kernel_stride: !ref <kernel_size> // 2

# Dual-path parameter
# Number of dual-path layers (num_layers) for the Dual_Path_Model objects.
n_dp: 8
# Presumably the number of dual-path layers meant for <mamba_encoder> --
# verify where it is consumed.
n_dp_encoder: 8
# Passed as K to the Dual_Path_Model objects.
chunk_size: 250
# Not referenced elsewhere in this file; meaning is defined by the training
# script -- verify there.
skip_n_block: 0
skip_around_intra: False

# Mamba parameters
# All of these are forwarded to <Mambaintra> and <Mambainter>.
bidirectional: True
# Total Mamba-layer budget per dual-path block; each of the intra and inter
# stacks receives <n_mamba_dp> // 2.
n_mamba_dp: 2
ssm_dim: 16  # passed as d_state
mamba_expand: 2  # passed as expand
mamba_conv: 4  # passed as d_conv

# Dataloader options
# Keyword arguments for the data loader; batch size comes from <batch_size>.
dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: 16
    drop_last: True

# Normalisation / residual options forwarded unchanged to both
# MambaBlocksSequential objects (<Mambaintra>, <Mambainter>).
fused_add_norm: False
rms_norm: True
residual_in_fp32: False


# Specifying the network
# Encoder from speechbrain's dual_path lobe, configured with <kernel_size>
# and producing <N_encoder_out> channels.
encoder: !new:speechbrain.lobes.models.dual_path.Encoder
    kernel_size: !ref <kernel_size>
    out_channels: !ref <N_encoder_out>

# Mamba stack passed as `intra_model` to the Dual_Path_Model objects.
# It gets half of the <n_mamba_dp> layer budget (integer division evaluated by
# HyperPyYAML at load time); the other settings mirror <Mambainter>.
Mambaintra: !new:modules.mamba_blocks.MambaBlocksSequential
    n_mamba: !ref <n_mamba_dp> // 2
    bidirectional: !ref <bidirectional>
    d_model: !ref <out_channels>
    d_state: !ref <ssm_dim>
    expand: !ref <mamba_expand>
    d_conv: !ref <mamba_conv>
    fused_add_norm: !ref <fused_add_norm>
    rms_norm: !ref <rms_norm>
    residual_in_fp32: !ref <residual_in_fp32>

# Mamba stack passed as `inter_model` to the Dual_Path_Model objects.
# It gets half of the <n_mamba_dp> layer budget (integer division evaluated by
# HyperPyYAML at load time); the other settings mirror <Mambaintra>.
Mambainter: !new:modules.mamba_blocks.MambaBlocksSequential
    n_mamba: !ref <n_mamba_dp> // 2
    bidirectional: !ref <bidirectional>
    d_model: !ref <out_channels>
    d_state: !ref <ssm_dim>
    expand: !ref <mamba_expand>
    d_conv: !ref <mamba_conv>
    fused_add_norm: !ref <fused_add_norm>
    rms_norm: !ref <rms_norm>
    residual_in_fp32: !ref <residual_in_fp32>

# Dual-path model registered as `mamba` in <modules>.
# It stacks <n_dp> dual-path layers with chunk size K = <chunk_size>, using
# <Mambaintra> / <Mambainter> as the intra- and inter-chunk models, with
# <out_channels> channels in and out.
# NOTE(review): <mamba_encoder> references the same <Mambaintra> / <Mambainter>
# objects; whether parameters end up shared depends on whether
# Dual_Path_Model copies the models it is given -- confirm.
mamba: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
    num_spks: !ref <num_spks>
    in_channels: !ref <out_channels>
    out_channels: !ref <out_channels>
    num_layers: !ref <n_dp>
    K: !ref <chunk_size>
    intra_model: !ref <Mambaintra>
    inter_model: !ref <Mambainter>
    norm: ln
    linear_layer_after_inter_intra: False
    skip_around_intra: !ref <skip_around_intra>

# Second dual-path model, registered as `mamba_encoder` in <modules>.
# It takes <N_encoder_out> input channels and emits <out_channels> channels,
# using <Mambaintra> / <Mambainter> as the intra- and inter-chunk models with
# chunk size K = <chunk_size>.
# Its depth is driven by the dedicated <n_dp_encoder> hyperparameter, so that
# overriding that value actually changes this model (it was wired to <n_dp>).
# Both default to 8, so the default configuration is unchanged.
mamba_encoder: !new:speechbrain.lobes.models.dual_path.Dual_Path_Model
    num_spks: !ref <num_spks>
    in_channels: !ref <N_encoder_out>
    out_channels: !ref <out_channels>
    num_layers: !ref <n_dp_encoder>
    K: !ref <chunk_size>
    intra_model: !ref <Mambaintra>
    inter_model: !ref <Mambainter>
    norm: ln
    linear_layer_after_inter_intra: False
    skip_around_intra: !ref <skip_around_intra>
    
# Decoder from speechbrain's dual_path lobe: takes <out_channels> channels and
# emits a single channel, using <kernel_size> and a stride of <kernel_stride>
# (i.e. <kernel_size> // 2), without a bias term.
decoder: !new:speechbrain.lobes.models.dual_path.Decoder
    in_channels: !ref <out_channels>
    out_channels: 1
    kernel_size: !ref <kernel_size>
    stride: !ref <kernel_stride>
    bias: False


# `!name` yields a callable with lr / weight_decay pre-bound rather than an
# optimizer instance; the caller supplies the remaining arguments (e.g. the
# model parameters) later.
optimizer: !name:torch.optim.Adam
    lr: !ref <lr>
    weight_decay: !ref <weight_decay>

# Reference to speechbrain's get_si_snr_with_pitwrapper loss function
# (not called at load time).
loss: !name:speechbrain.nnet.losses.get_si_snr_with_pitwrapper

# NOTE(review): a cosine-schedule flag and a ReduceLROnPlateau object are both
# configured here. Which one actually drives the learning rate, and how
# <percent_lr> is used, is decided by the training script -- confirm there.
use_cosine_schedule: True
percent_lr: 0.01

# Plateau scheduler object: factor 0.5, patience 2, and no initial epochs
# exempt from reduction (dont_halve_until_epoch: 0).
lr_scheduler: !new:speechbrain.nnet.schedulers.ReduceLROnPlateau
    factor: 0.5
    patience: 2
    dont_halve_until_epoch: 0

# Epoch counter capped at <N_epochs>.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <N_epochs>

# Name -> object mapping of the network parts defined above.
# Presumably handed to the training script's Brain class as its trainable
# modules -- TODO confirm.
modules:
    encoder: !ref <encoder>
    decoder: !ref <decoder>
    mamba_encoder: !ref <mamba_encoder>
    mamba: !ref <mamba>
