# ###########################################################################################
# Model: Conformer with continuous WavLM audio representations
# ###########################################################################################

# Experiment identifier; it is one path component of <output_folder>
experiment_name: continuous_wavlm

# Seed needs to be set at top of YAML
# (presumably so that torch.manual_seed runs before the modules declared
# further down are constructed — confirm against HyperPyYAML load order).
# The seed is also the last path component of <output_folder>.
seed: 0
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Data preparation
# Dataset root; `!PLACEHOLDER` marks a value that has to be supplied as an override
data_folder: !PLACEHOLDER
# One CSV manifest per split, all located in <save_folder> (declared under "Output folders")
train_csv: !ref <save_folder>/trainset_28spk_wav.csv
valid_csv: !ref <save_folder>/validset_wav.csv
test_csv: !ref <save_folder>/testset_wav.csv
# Split names; each one equals the file stem of one of the CSV manifests above
splits: [trainset_28spk_wav, validset_wav, testset_wav]
# NOTE(review): presumably the number of speakers held out to build the
# validation split — confirm in the data preparation script
num_valid_speakers: 2

# Output folders
# Results are grouped by experiment name, then by seed
output_folder: !ref results/<experiment_name>/<seed>
# Holds the CSV manifests and is the directory used by <checkpointer>
save_folder: !ref <output_folder>/save
# Hugging Face hub cache location, reused below as download/save path for the
# pretrained SSL model, the vocoder and the metric models.
# NOTE(review): `!name:` is applied to a constant rather than a callable here;
# confirm the installed HyperPyYAML version resolves it to the path string
cache_folder: !name:huggingface_hub.constants.HUGGINGFACE_HUB_CACHE

# Save options
# NOTE(review): neither flag is referenced inside this file; presumably both are
# read by the training script (metric computation / dumping of audio) — confirm
compute_metrics: True
save_audios: False

# Preprocessing parameters
# Per-split duration thresholds.
# NOTE(review): presumably utterances longer than the threshold are discarded
# when the datasets are built — confirm in the training script
train_remove_if_longer: 60.0  # Seconds
valid_remove_if_longer: 60.0  # Seconds
test_remove_if_longer: 60.0  # Seconds
# The training dataloader shuffles only when this equals `random`
# (see `shuffle` in <train_dataloader_kwargs>)
sorting: random
# NOTE(review): not referenced in this file; meaning is defined by the training script
use_cache: False

# Training parameters
# Upper limit handed to <epoch_counter>
num_epochs: 50
# NOTE(review): the run options from here to `keep_checkpoints` are not
# referenced inside this file; presumably they are consumed by the training
# script / trainer — confirm
grad_accumulation_factor: 16
# Batch sizes forwarded to the matching *_dataloader_kwargs
train_batch_size: 1
valid_batch_size: 1
test_batch_size: 1
# Worker count shared by the train, valid and test dataloaders
dataloader_workers: 4
nonfinite_patience: 10
max_grad_norm: 5.0
precision: fp32
ckpt_interval_minutes: 6000
keep_checkpoints: 1
# NOTE(review): presumably toggles use of <augmentation> in the training script — confirm
augment: False
# Forwarded to <augmentation> as its `augment_prob`
augment_prob: 0.75

# Optimizer parameters
# Learning rate for <opt_class> and initial value of <scheduler>
lr: 0.0005
# Forwarded to <opt_class> (AdamW)
weight_decay: 0.01
# The next three values are forwarded unchanged to <scheduler> (NewBobScheduler)
improvement_threshold: 0.0025
annealing_factor: 0.9
patient: 1

# SSL parameters
# Sample rate handed to every metric computer below
sample_rate: 16000
# Hidden layers selected from the SSL model (passed as `layer_ids` to <ssl_model>)
SSL_layers: [1, 3, 7, 12, 18, 23]
# Number of selected layers, i.e. len(<SSL_layers>)
num_codebooks: !apply:len [!ref <SSL_layers>]
# Hub identifier of the pretrained SSL model loaded by <ssl_model>
ssl_hub: microsoft/wavlm-large
vocoder_hub: chaanks/hifigan-wavlm-l1-3-7-12-18-23-LibriTTS  # Must be consistent with ssl_hub/SSL_layers

# Encoder parameters
# Feature size used as input/hidden size of <attention_mlp>, as input size of
# <encoder>, and as one of the two arguments that determine the output size of <head>
embedding_dim: 1024
# Everything from here to `causal` is forwarded unchanged to <encoder>
dropout: 0.1
activation: !name:torch.nn.GELU
# Also used as `in_features` of <head>
d_model: 256
nhead: 4
# Passed to <encoder> as `num_encoder_layers`
num_layers: 6
d_ffn: 2048
max_length: 2000
causal: False

# Augmentation
# Frequency-band dropout, first entry of the <augmentation> pipeline.
# NOTE(review): per the SpeechBrain DropFreq documentation, `drop_freq_low` /
# `drop_freq_high` bound the range of frequencies that may be dropped and
# `drop_freq_width` is the band width, all expressed as fractions of the
# Nyquist frequency (they are not probabilities) — confirm against the
# installed SpeechBrain version
drop_freq: !new:speechbrain.augment.time_domain.DropFreq
    drop_freq_low: 0  # Lower bound of the droppable frequency range
    drop_freq_high: 1  # Upper bound of the droppable frequency range
    drop_freq_count_low: 1  # Min number of frequency bands to drop
    drop_freq_count_high: 3  # Max number of frequency bands to drop
    drop_freq_width: 0.05  # Width of frequency bands to drop

# Chunk dropout, second entry of the <augmentation> pipeline: zeroes out a
# random number of chunks of random length in the waveform.
# In SpeechBrain's DropChunk, `drop_length_*` is the chunk length (in samples)
# and `drop_count_*` is the number of chunks. The values were previously
# attached to the wrong keys (1-5 sample chunks, dropped 1000-2000 times);
# they now match the intent stated in the inline comments.
drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
    drop_length_low: 1000  # Min length of audio chunks to drop
    drop_length_high: 2000  # Max length of audio chunks to drop
    drop_count_low: 1  # Min number of audio chunks to drop
    drop_count_high: 5  # Max number of audio chunks to drop

# Augmentation pipeline: runs the two time-domain augmentations listed under
# `augmentations`, in that order, with probability <augment_prob>
augmentation: !new:speechbrain.augment.augmenter.Augmenter
    # Which augmentations are available and how often they are applied
    augmentations:
        - !ref <drop_freq>
        - !ref <drop_chunk>
    augment_prob: !ref <augment_prob>
    min_augmentations: 2
    max_augmentations: 2
    # Composition options
    shuffle_augmentations: False
    repeat_augment: 1
    parallel_augment: False
    concat_original: False

# Modules
# Pretrained WavLM (<ssl_hub>) wrapped by a project-local forward wrapper.
# The backbone is frozen and configured to return all hidden layers; the
# wrapper receives the layer indices listed in <SSL_layers>.
ssl_model: !new:utils.SBWav2Vec2ForwardWrapper
    wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
        source: !ref <ssl_hub>
        output_norm: False
        freeze: True
        freeze_feature_extractor: True
        output_all_hiddens: True
        # Downloaded weights are stored in the Hugging Face cache folder
        save_path: !ref <cache_folder>
    layer_ids: !ref <SSL_layers>

# Pretrained HiFi-GAN vocoder loaded from <vocoder_hub> at YAML load time
# (`!apply` calls `from_hparams` immediately); files are saved under
# <cache_folder>/<vocoder_hub>
ssl_vocoder: !apply:speechbrain.inference.vocoders.UnitHIFIGAN.from_hparams
    source: !ref <vocoder_hub>
    savedir: !apply:os.path.join [!ref <cache_folder>, !ref <vocoder_hub>]

# Project-local attention MLP; input and hidden sizes both equal <embedding_dim>.
# NOTE(review): presumably used to weight/combine the selected SSL layers —
# confirm in custom_model.AttentionMLP
attention_mlp: !new:custom_model.AttentionMLP
    input_dim: !ref <embedding_dim>
    hidden_dim: !ref <embedding_dim>

# Conformer encoder, built from SpeechBrain's TransformerASR with
# `encoder_module: conformer` and zero decoder layers (encoder-only use)
encoder: !new:speechbrain.lobes.models.transformer.TransformerASR.TransformerASR
    input_size: !ref <embedding_dim>
    # NOTE(review): presumably a dummy value since no decoder layers are built — confirm
    tgt_vocab: -1
    d_model: !ref <d_model>
    nhead: !ref <nhead>
    num_encoder_layers: !ref <num_layers>
    num_decoder_layers: 0
    d_ffn: !ref <d_ffn>
    dropout: !ref <dropout>
    activation: !ref <activation>
    max_length: !ref <max_length>
    encoder_module: conformer
    normalize_before: True
    causal: !ref <causal>

# Linear output layer applied to the <d_model>-dimensional encoder output
head: !new:torch.nn.Linear
    in_features: !ref <d_model>
    # Workaround: the output size is computed by a helper from the training
    # script because it cannot be expressed with plain HyperPyYAML references.
    # NOTE(review): presumably len(<SSL_layers>) * <embedding_dim> — confirm in
    # train_continuous_ssl.len_
    out_features: !apply:train_continuous_ssl.len_ [!ref <SSL_layers>, !ref <embedding_dim>]

# Name -> module mapping of the trainable components (the frozen SSL model and
# the vocoder are intentionally not listed here)
modules:
    attention_mlp: !ref <attention_mlp>
    encoder: !ref <encoder>
    head: !ref <head>

# Single container holding the trainable modules; this is the object registered
# under `model` in the checkpointer's recoverables.
# The block sequence has exactly one item (the list of modules), which is
# passed to ModuleList as its only positional argument.
model: !new:torch.nn.ModuleList
    - [!ref <attention_mlp>, !ref <encoder>, !ref <head>]

# Loss functions
# Reconstruction loss: SpeechBrain's MSE loss with these keyword arguments
# pre-bound (`!name:` yields a callable, it does not call the function)
rec_loss: !name:speechbrain.nnet.losses.mse_loss
    allowed_len_diff: 0
    reduction: mean

# Optimizers
# AdamW constructor with hyperparameters pre-bound; the parameters to optimize
# are supplied later by whoever calls it
opt_class: !name:torch.optim.AdamW
    lr: !ref <lr>
    betas: (0.9, 0.98)
    eps: 1.e-8
    weight_decay: !ref <weight_decay>

# Schedulers
# Learning-rate scheduler (NewBob), starting from <lr>; all settings come from
# the "Optimizer parameters" section
scheduler: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: !ref <improvement_threshold>
    annealing_factor: !ref <annealing_factor>
    patient: !ref <patient>

# Dataloaders
# Keyword arguments for the training dataloader
train_dataloader_kwargs:
    batch_size: !ref <train_batch_size>
    num_workers: !ref <dataloader_workers>
    pin_memory: True
    # Evaluates to True only when <sorting> is exactly the string `random`
    shuffle: !apply:str.__eq__ [!ref <sorting>, random]

# Keyword arguments for the validation dataloader (no shuffling requested)
valid_dataloader_kwargs:
    pin_memory: True
    num_workers: !ref <dataloader_workers>
    batch_size: !ref <valid_batch_size>

# Keyword arguments for the test dataloader (no shuffling requested)
test_dataloader_kwargs:
    pin_memory: True
    num_workers: !ref <dataloader_workers>
    batch_size: !ref <test_batch_size>

# Performance metrics
# Project-local DNSMOS metric; `!name:` only pre-binds the sample rate, the
# object itself is constructed by the caller
dnsmos_computer: !name:metrics.dnsmos.DNSMOS
    sample_rate: !ref <sample_rate>

# Project-local dWER metric backed by the `openai/whisper-small` hub model;
# `!name:` pre-binds the arguments, construction is left to the caller
dwer_computer: !name:metrics.dwer.DWER
    sample_rate: !ref <sample_rate>
    model_hub: openai/whisper-small
    save_path: !ref <cache_folder>

# Project-local speaker-similarity metric backed by the
# `microsoft/wavlm-base-sv` hub model; arguments are pre-bound via `!name:`
wavlm_sim_computer: !name:metrics.spk_sim.SpkSimWavLM
    sample_rate: !ref <sample_rate>
    model_hub: microsoft/wavlm-base-sv
    save_path: !ref <cache_folder>

# Project-local speaker-similarity metric backed by the
# `speechbrain/spkrec-ecapa-voxceleb` hub model; unlike the other metrics it
# saves into a dedicated sub-folder of <cache_folder>
ecapatdnn_sim_computer: !name:metrics.spk_sim.SpkSimECAPATDNN
    sample_rate: !ref <sample_rate>
    model_hub: speechbrain/spkrec-ecapa-voxceleb
    save_path: !apply:os.path.join [!ref <cache_folder>, models--speechbrain--spkrec-ecapa-voxceleb]

# Counters, checkpointers, loggers, etc.
# Epoch counter limited to <num_epochs>; it is also registered as a recoverable
# in <checkpointer>
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <num_epochs>

# Checkpointer writing to <save_folder>; it tracks the trainable model, the
# learning-rate scheduler and the epoch counter
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        scheduler: !ref <scheduler>
        counter: !ref <epoch_counter>

# Plain-text training log stored next to the `save` folder inside <output_folder>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <output_folder>/train_log.txt
