# ############################################################################
# Model: Tokenized TTS (WhisperSpeech-inspired)
# ############################################################################

# Experiment identifier; it is interpolated into <output_folder> below.
experiment_name: tokotron/discrete_ssl

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 74443
# Calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Model Type
# Selects the SSL backbone. The `choice` entries in this file that are keyed
# on this value list three options: wavlm, hubert and wav2vec2.
ssl_model_type: wavlm

# NOTE(review): <experiment_name> already starts with "tokotron/", so the
# resolved output path contains "tokotron" twice - confirm this is intended.
output_folder: !ref results/tokotron/<experiment_name>/<ssl_model_type>/<seed>
# Checkpoint directory (used by the checkpointer) and training log file
# (used by the train logger).
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt


# Data files
# Has no default value: it must be supplied by the user
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
# Prepared data is kept inside the dataset folder, in a sub-folder named
# after the selected SSL model type
prepare_save_folder: !ref <data_folder>/prepared/discrete-<ssl_model_type>
# Pretrained models are saved to the same folder as the prepared data
pretrained_model_save_folder: !ref <prepare_save_folder>
# Local folder for the vocoder, named after the selected SSL model type
vocoder_model_name: !ref unithifigan-dasb-<ssl_model_type>-discrete
vocoder_model_path: !ref <pretrained_model_save_folder>/<vocoder_model_name>
# Forwarded to the freezer as archive_path; null by default
prepare_archive_path: null
prepare_skip_ignore_folders: False
# Data manifests, one per split, stored alongside the prepared data
train_json: !ref <prepare_save_folder>/train.json
valid_json: !ref <prepare_save_folder>/valid.json
test_json: !ref <prepare_save_folder>/test.json
frozen_split_path: null
sample_path: null
# Progress artifacts, all located under <progress_folder>
progress_folder: !ref <output_folder>/progress
progress_archive: !ref <progress_folder>/progress.tar
progress_current: !ref <progress_folder>/current
progress_meta: !ref <progress_folder>/meta.yaml
# NOTE(review): presumably the number of audio samples to generate and how
# often they are generated - confirm against the training script
num_audio_samples: 32
samples_interval: 5

# Forwarded as `freeze` to every SSL model entry defined further below
freeze_token_model: True
# Source repository of the SSL model, selected by <ssl_model_type>
token_model_src: !apply:speechbrain.utils.hparams.choice
    value: !ref <ssl_model_type>
    choices:
        wavlm: microsoft/wavlm-large
        hubert: facebook/hubert-large-ll60k
        wav2vec2: facebook/wav2vec2-large-960h-lv60-self

g2p_src: speechbrain/soundchoice-g2p
# K-means repository and dataset name, forwarded to the token model
token_model_kmeans_src: poonehmousavi/SSL_Quantization
token_model_kmeans_dataset: LibriSpeech-100-360-500
# SSL layer numbers; the token model uses the same list by default
ssl_model_layers: [1, 3, 7, 12, 18, 23]
token_model_layers: !ref <ssl_model_layers>
token_offset: 1
# Source repository of the vocoder, selected by <ssl_model_type>
vocoder_src: !apply:speechbrain.utils.hparams.choice
    value: !ref <ssl_model_type>
    choices:
        wavlm: chaanks/hifigan-wavlm-l1-3-7-12-18-23-k1000-LibriTTS
        hubert: chaanks/hifigan-hubert-l1-3-7-12-18-23-k1000-LibriTTS
        wav2vec2: chaanks/hifigan-wav2vec-l1-3-7-12-18-23-k1000-LibriTTS
# Source repository of the speaker embedding model
spk_emb_src: speechbrain/spkrec-ecapa-voxceleb-mel-spec
use_spk_emb: False

# Same numbers as the "l1-3-7-12-18-23" part of the vocoder repository names.
# Kept as a separate literal list rather than a reference to <ssl_model_layers>.
vocoder_available_layers: [1, 3, 7, 12, 18, 23]

# Names of the dataset splits.
splits:
    - train
    - valid
    - test
# One entry per split name above.
# NOTE(review): presumably percentages in the same order - confirm against
# the data preparation script.
split_ratio:
    - 90
    - 5
    - 5


ckpt_interval_minutes: 30 # save checkpoint every N min

# Training parameters
# Input type. The `choice` entries keyed on this value accept
# "text" or "phonemes".
input: text
# Used as the limit of the epoch counter
number_of_epochs: 50
# Shared by the train, valid, sample and feature-extraction dataloaders
batch_size: 16
grad_accumulation_factor: 1
max_grad_norm: 0.01
sorting: random
# Worker count for the train, valid, test and sample dataloaders
num_workers: 4
skip_prep: False
# Overfit test settings; the sample count defaults to one batch
overfit_test: False
overfit_test_sample_count: !ref <batch_size>
overfit_test_epoch_data_count: 1000


# Special token indices
# pad_index is the padding value of every dataloader collate function below
pad_index: 0
# NOTE(review): bos_index and bos_width are not referenced elsewhere in this
# file - presumably read directly by the training script
bos_index: 0
bos_width: 1
# eos_index and eos_width are forwarded to the loss
eos_index: 0
eos_width: 1
# Forwarded to both the model and the loss
audio_token_shift: 0

# Stage-related parameters: optimization and loss weighting
# Base learning rate, used by the optimizer and the LR scheduler
lr: 0.0005
# Warm-up step count, forwarded to the LR scheduler
lr_warmup_steps: 10000
lr_annealing_mode: step
# Guided attention settings, forwarded to the loss
guided_attention_weight: 50.0
guided_attention_sigma: 0.5
# Gate settings: the threshold goes to the model; weight, beta, gamma and
# max weight go to the loss; beta, gamma and max weight are also the
# arguments used to compute gate_offset
gate_loss_weight: 1.0
gate_threshold: 0.5
gate_loss_beta: 0.2
gate_loss_gamma: 0.01
# Written as 1.0 (not "1.") for a canonical float literal, consistent with
# gate_loss_weight above
gate_loss_max_weight: 1.0

# Inference parameters
# eos_mode is forwarded to both the model and the loss;
# decoder_mode and scale_factor are forwarded to the model only
eos_mode: gate
decoder_mode: autoregressive
scale_factor: 4

# Beam Search-specific parameters
# NOTE(review): not referenced elsewhere in this file - presumably read
# directly by the training/inference script
min_decode_ratio: 1.0
max_decode_ratio: 10.0
beam_size: 5


# Feature parameters
# Both sample rates are forwarded to extract_features_opts
sample_rate: 22050
model_sample_rate: 16000
# Forwarded to the model; the inference limit defaults to the same value
max_audio_length: 1000
infer_max_audio_length: !ref <max_audio_length>
# NOTE(review): not referenced elsewhere in this file - presumably a shorter
# inference limit used in debug runs; confirm against the training script
debug_infer_max_audio_length: 10

# Label encoder
# A new TextEncoder instance, created with no constructor arguments
label_encoder: !new:speechbrain.dataio.encoder.TextEncoder
# Token list files, one for each supported input type
token_list_file_text: ./hparams/char_en.txt
token_list_file_phn: ./hparams/arpabet.txt
# Picks the token list file that matches <input>
token_list_file: !apply:speechbrain.utils.hparams.choice
    value: !ref <input>
    choices:
        text: !ref <token_list_file_text>
        phonemes: !ref <token_list_file_phn>

# Gate offset
# Result of calling Tokotron.distance_diff_loss_ramp with the gate loss
# beta, gamma and max weight; the call happens when this file is loaded.
# The value is forwarded to the model as gate_offset.
gate_offset: !apply:Tokotron.distance_diff_loss_ramp
    beta: !ref <gate_loss_beta>
    gamma: !ref <gate_loss_gamma>
    max_weight: !ref <gate_loss_max_weight>

# Silence padding reuses the gate offset value; it is forwarded to the loss
silence_padding: !ref <gate_offset>
use_silence_padding: True


# Token model (pretrained)
# SSL backbone selected by <ssl_model_type>. Every entry is configured the
# same way: same source, same save path, same freeze flag, and
# output_all_hiddens enabled.
# NOTE(review): all three entries share <token_model_src>, which resolves to
# the repository of the selected type only - confirm how the unselected
# `!new:` entries are handled at load time.
ssl_model: !apply:speechbrain.utils.hparams.choice
    value: !ref <ssl_model_type>
    choices:
        wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
            source: !ref <token_model_src>
            save_path: !ref <pretrained_model_save_folder>
            freeze: !ref <freeze_token_model>
            output_all_hiddens: True
        hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
            source: !ref <token_model_src>
            save_path: !ref <pretrained_model_save_folder>
            freeze: !ref <freeze_token_model>
            output_all_hiddens: True
        wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
            source: !ref <token_model_src>
            save_path: !ref <pretrained_model_save_folder>
            freeze: !ref <freeze_token_model>
            output_all_hiddens: True


# Discrete tokenizer built on top of the selected SSL model. The number of
# clusters is <audio_num_tokens>, which is defined in the model parameters
# section further down in this file.
token_model: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL
    ssl_model: !ref <ssl_model>
    kmeans_repo_id: !ref <token_model_kmeans_src>
    kmeans_dataset: !ref <token_model_kmeans_dataset>
    num_clusters: !ref <audio_num_tokens>
    save_path: !ref <pretrained_model_save_folder>
    layers_num: !ref <token_model_layers>

# Speaker embedding model. `!name:` stores a callable with the arguments
# below pre-filled, so from_hparams is not called while this file is loaded.
spk_emb_model: !name:speechbrain.inference.encoders.MelSpectrogramEncoder.from_hparams
    source: !ref <spk_emb_src>
    savedir: !ref <pretrained_model_save_folder>/ecapa

# Dataloader options
# All four option sets collate with PaddedBatch, padding with <pad_index>.
# Training: the only option set with shuffling enabled
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Validation: same batch size as training, no shuffle option
valid_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Test: one item per batch
test_dataloader_opts:
    batch_size: 1
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Samples: same settings as validation
sample_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Keyword arguments for the token model: the layer list under `SSL_layers`
token_model_kwargs:
    SSL_layers: !ref <token_model_layers>

# Options bundle for feature extraction: dataloader batch size, the token and
# SSL models with their layer lists, both sample rates, and the speaker
# embedding model
extract_features_opts:
    dataloader_opts:
        batch_size: !ref <batch_size>
    token_model: !ref <token_model>
    token_model_kwargs: !ref <token_model_kwargs>
    ssl_model: !ref <ssl_model>
    ssl_model_layers: !ref <ssl_model_layers>
    token_model_layers: !ref <token_model_layers>
    sample_rate: !ref <sample_rate>
    model_sample_rate: !ref <model_sample_rate>
    spk_emb_model: !ref <spk_emb_model>


####################### Model parameters ###########################
# Transformer
# Dimensions, head/layer counts and dropout values, forwarded to the model
d_model: 512
nhead: 4
enc_num_layers: 6
dec_num_layers: 12
d_ffn: 2048
transformer_dropout: 0.2
target_dropout: 0.2
# `!name:` stores the GELU class itself rather than an instance
activation: !name:torch.nn.GELU
# Also used as the number of clusters of the token model
audio_num_tokens: 1000
audio_emb_size: 1024
audio_emb_freeze: False
audio_emb_pretrained: False
# Second entry of lr_initial in the LR scheduler
audio_emb_lr: 0.00001
audio_emb_weight_decay: 0.001
# Input vocabulary sizes, one per supported input type
text_num_tokens: 39
phn_num_tokens: 52
# Picks the vocabulary size that matches <input>
input_num_tokens: !apply:speechbrain.utils.hparams.choice
    value: !ref <input>
    choices:
        text: !ref <text_num_tokens>
        phonemes: !ref <phn_num_tokens>
# Equal to the number of entries in <ssl_model_layers> (6)
# NOTE(review): presumably one token per SSL layer per step - confirm
audio_tokens_per_step: 6
attention_type: regularMHA

############################## models ################################

# Vocoder. `!apply:` calls UnitHIFIGAN.from_hparams while this file is being
# loaded, using the source selected by <ssl_model_type> and the local vocoder
# folder as savedir.
vocoder: !apply:speechbrain.inference.vocoders.UnitHIFIGAN.from_hparams
    source: !ref <vocoder_src>
    savedir: !ref <vocoder_model_path>

# Main model. Every argument is a reference to a hyperparameter defined
# above, except representation_mode, which is fixed to "discrete" here.
model: !new:Tokotron.TokotronTransformerModel # yamllint disable-line rule:line-length
    input_num_tokens: !ref <input_num_tokens>
    audio_num_tokens: !ref <audio_num_tokens>
    audio_tokens_per_step: !ref <audio_tokens_per_step>
    d_model: !ref <d_model>
    d_ffn: !ref <d_ffn>
    nhead: !ref <nhead>
    enc_num_layers: !ref <enc_num_layers>
    dec_num_layers: !ref <dec_num_layers>
    dropout: !ref <transformer_dropout>
    target_dropout: !ref <target_dropout>
    activation: !ref <activation>
    attention_type: !ref <attention_type>
    gate_threshold: !ref <gate_threshold>
    gate_offset: !ref <gate_offset>
    audio_emb_size: !ref <audio_emb_size>
    audio_emb_freeze: !ref <audio_emb_freeze>
    max_audio_length: !ref <max_audio_length>
    eos_mode: !ref <eos_mode>
    infer_max_audio_length: !ref <infer_max_audio_length>
    audio_token_shift: !ref <audio_token_shift>
    decoder_mode: !ref <decoder_mode>
    scale_factor: !ref <scale_factor>
    representation_mode: discrete

# Modules mapping: the model, the vocoder and the loss.
# <compute_cost> is defined further down in this file.
modules:
    model: !ref <model>
    vocoder: !ref <vocoder>
    compute_cost: !ref <compute_cost>

# Optimizer: a callable for torch.optim.Adam with the base learning rate
# pre-filled (`!name:` does not instantiate the optimizer here).
# NOTE(review): this comment previously announced two optimizers for
# two-stage training, but this is the only optimizer defined in this file.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

# Loss. It receives the guided attention settings, the gate loss settings,
# the silence padding, the EOS settings and the audio token layout.
# eos_mode, audio_tokens_per_step, audio_token_shift and representation_mode
# have the same values as the ones given to the model.
compute_cost: !new:Tokotron.TokotronLoss
    guided_attention_weight: !ref <guided_attention_weight>
    guided_attention_sigma: !ref <guided_attention_sigma>
    gate_weight: !ref <gate_loss_weight>
    gate_beta: !ref <gate_loss_beta>
    gate_gamma: !ref <gate_loss_gamma>
    gate_max_weight: !ref <gate_loss_max_weight>
    silence_padding: !ref <silence_padding>
    eos_mode: !ref <eos_mode>
    eos_index: !ref <eos_index>
    eos_width: !ref <eos_width>
    audio_tokens_per_step: !ref <audio_tokens_per_step>
    audio_token_shift: !ref <audio_token_shift>
    representation_mode: discrete


# Learning-rate scheduler. lr_initial carries two starting rates, in this
# order: the base rate and the audio-embedding rate.
# NOTE(review): param_group presumably selects which optimizer parameter
# group is scheduled - confirm in Tokotron.TargetedNoamScheduler.
lr_annealing: !new:Tokotron.TargetedNoamScheduler
    n_warmup_steps: !ref <lr_warmup_steps>
    param_group: 0
    lr_initial:
        - !ref <lr>
        - !ref <audio_emb_lr>

# Checkpointer writing to <save_folder>. The recoverables are the model,
# the LR scheduler and the epoch counter (<epoch_counter> is defined further
# down in this file).
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        lr_scheduler: !ref <lr_annealing>
        counter: !ref <epoch_counter>

# Freezer for the prepared data folder; archive_path is null unless
# <prepare_archive_path> is overridden.
# NOTE(review): presumably archives/restores the prepared data - confirm in
# preparation.Freezer.
freezer: !new:preparation.Freezer
    save_path: !ref <prepare_save_folder>
    archive_path: !ref <prepare_archive_path>

# Epoch counter limited to <number_of_epochs>
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# File-based training logger writing to <train_log>
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
