# ############################################################################
# Model: Tokenized TTS (WhisperSpeech-inspired)
# ############################################################################

# Experiment identity; together with <seed> it forms the results path below.
experiment_name: tokotron/discrete_ssl

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 74443
# Calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Per-run root folder: results/<experiment_name>/<seed>
output_folder: !ref results/<experiment_name>/<seed>
# Used as checkpoints_dir by the checkpointer object in this file.
save_folder: !ref <output_folder>/save
# Used as save_file by the train_logger object in this file.
train_log: !ref <output_folder>/train_log.txt

# Pretrained sources.
# token_model_src is the `source` given to the speech_tokenizer object in this
# file. The g2p/vocoder entries are not referenced by any object defined here
# (presumably read by the training script -- verify there).
token_model_src: fnlp/SpeechTokenizer
g2p_src: flexthink/soundchoice-g2p
vocoder_type: encodec
vocoder_src: charactr/vocos-encodec-24khz

# Data files
# NOTE(review): the example path names LibriSpeech, but <sample_rate> in this
# file is 22050 and a custom train/valid/test <split_ratio> is configured --
# confirm which dataset this recipe actually expects.
data_folder: !PLACEHOLDER # e.g., /path/to/LibriSpeech
# Prepared data is kept under the dataset folder; the pretrained tokenizer
# download (speech_tokenizer.save_path) shares the same location.
prepare_save_folder: !ref <data_folder>/prepared/st
pretrained_model_save_folder: !ref <prepare_save_folder>
# Passed to the freezer object as archive_path; null by default.
prepare_archive_path: null
prepare_skip_ignore_folders: False
# Manifests for the three partitions, all inside <prepare_save_folder>.
train_json: !ref <prepare_save_folder>/train.json
valid_json: !ref <prepare_save_folder>/valid.json
test_json: !ref <prepare_save_folder>/test.json
frozen_split_path: null
sample_path: null
# Progress artifacts, all rooted at <output_folder>/progress.
progress_folder: !ref <output_folder>/progress
progress_archive: !ref <progress_folder>/progress.tar
progress_current: !ref <progress_folder>/current
progress_meta: !ref <progress_folder>/meta.yaml
# Not referenced by any object in this file; presumably control how many audio
# samples are produced and how often -- confirm in the training script.
num_audio_samples: 32
samples_interval: 5

# Partition names and one ratio entry per name, in the same order
# (presumably percentages -- confirm in the data preparation code).
splits: [train, valid, test]
split_ratio: [90, 5, 5]


ckpt_interval_minutes: 30 # save checkpoint every N min

# Training parameters
# `input` selects the text or phonemes branch of the token_list_file and
# input_num_tokens choices defined elsewhere in this file.
input: text
# Used as the limit of the epoch_counter object.
number_of_epochs: 150
# Shared by the train/valid/sample dataloaders and by extract_features_opts.
batch_size: 16
grad_accumulation_factor: 1
max_grad_norm: 5.0
sorting: random
# Shared by all four dataloader option maps.
num_workers: 4
skip_prep: False
# Overfit-test switches; the sample count defaults to one batch.
overfit_test: False
overfit_test_sample_count: !ref <batch_size>
overfit_test_epoch_data_count: 1000


# Special token indices
# pad_index is the padding value given to the PaddedBatch collate functions
# in the dataloader option maps of this file.
# NOTE(review): pad_index and bos_index are both 0 -- confirm the overlap is
# intended.
pad_index: 0
bos_index: 0
bos_width: 1

# Optimization schedule and loss weighting
# lr is both the Adam learning rate (opt_class) and lr_initial of the Noam
# scheduler; lr_warmup_steps is that scheduler's n_warmup_steps.
lr: 0.001
lr_warmup_steps: 10000
# Not referenced by any object in this file (presumably read by the training
# script -- verify there).
lr_annealing_mode: step
# Guided attention settings, forwarded to the compute_cost object.
guided_attention_weight: 50.0
guided_attention_sigma: 0.5
# Gate settings: weight/beta/gamma/max_weight are forwarded to compute_cost,
# beta/gamma/max_weight also feed the gate_offset computation, and
# gate_threshold is forwarded to the model.
gate_loss_weight: 1.0
gate_threshold: 0.5
gate_loss_beta: 0.2
gate_loss_gamma: 0.01
gate_loss_max_weight: 1.0

# Feature parameters
# Both rates are forwarded to extract_features_opts (presumably the dataset
# rate and the tokenizer input rate -- confirm in the extraction code).
sample_rate: 22050
model_sample_rate: 16000
# Forwarded to the model as max_audio_length / infer_max_audio_length; the
# unit is not stated in this file. Inference uses the training limit by default.
max_audio_length: 1000
infer_max_audio_length: !ref <max_audio_length>
# Not referenced by any object in this file.
debug_infer_max_audio_length: 10

# Label encoder
# A TextEncoder instance created with no arguments.
label_encoder: !new:speechbrain.dataio.encoder.TextEncoder
# Token list files for the two supported input modes.
token_list_file_text: ./hparams/char_en.txt
token_list_file_phn: ./hparams/arpabet.txt
# Resolves to one of the two files above, keyed by the value of <input>.
token_list_file: !apply:speechbrain.utils.hparams.choice
    value: !ref <input>
    choices:
        text: !ref <token_list_file_text>
        phonemes: !ref <token_list_file_phn>

# Gate offset
# Result of calling Tokotron.distance_diff_loss_ramp with the gate loss
# beta/gamma/max_weight values. It is forwarded to the model as gate_offset.
gate_offset: !apply:Tokotron.distance_diff_loss_ramp
    beta: !ref <gate_loss_beta>
    gamma: !ref <gate_loss_gamma>
    max_weight: !ref <gate_loss_max_weight>

# Same value as gate_offset; forwarded to the compute_cost object.
silence_padding: !ref <gate_offset>

# Token model (pretrained)
# SpeechTokenizer interface built from <token_model_src>, with save_path set
# to <pretrained_model_save_folder>.
speech_tokenizer: !new:speechbrain.lobes.models.discrete.speechtokenizer_interface.SpeechTokenizer_interface
    source: !ref <token_model_src>
    save_path: !ref <pretrained_model_save_folder>

# Feature extractor wrapping the tokenizer above. `codebooks` is tied to
# <audio_tokens_per_step>, which is defined in the model parameters section.
token_model: !new:Tokotron.SpeechTokenizerFeatureExtractor
    speech_tokenizer: !ref <speech_tokenizer>
    codebooks: !ref <audio_tokens_per_step>

# Dataloader options
# Training loader: shuffled, <batch_size> items per batch.
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: !ref <num_workers>
    # PaddedBatch with padding_kwargs preset so padding uses <pad_index>.
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Validation loader: same batch size, workers and collation as training;
# no shuffle key is set here.
valid_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Test loader: batch size is fixed at 1 rather than tied to <batch_size>.
test_dataloader_opts:
    batch_size: 1
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Sample loader: same settings as the validation loader. Presumably used for
# the audio samples configured by num_audio_samples -- confirm in the
# training script.
sample_dataloader_opts:
    batch_size: !ref <batch_size>
    num_workers: !ref <num_workers>
    collate_fn: !name:speechbrain.dataio.batch.PaddedBatch
        padding_kwargs:
            value: !ref <pad_index>

# Options for feature extraction: the token model to run, the two sample
# rates, and a dataloader configuration that sets only the batch size.
extract_features_opts:
    dataloader_opts:
        batch_size: !ref <batch_size>
    token_model: !ref <token_model>
    sample_rate: !ref <sample_rate>
    model_sample_rate: !ref <model_sample_rate>


####################### Model parameters ###########################
# Transformer
# d_model through activation are forwarded to the model object in this file
# (transformer_dropout is passed as its `dropout` argument).
d_model: 512
nhead: 4
enc_num_layers: 6
dec_num_layers: 12
d_ffn: 2048
transformer_dropout: 0.2
target_dropout: 0.2
activation: !name:torch.nn.GELU
# Audio token settings forwarded to the model.
audio_num_tokens: 1024
audio_emb_size: 1024
audio_emb_freeze: False
# NOTE(review): audio_emb_pretrained is not referenced by any object in this
# file -- confirm it is consumed by the training script.
audio_emb_pretrained: True
# Input vocabulary sizes per input mode; input_num_tokens picks one by <input>.
text_num_tokens: 39
phn_num_tokens: 52
input_num_tokens: !apply:speechbrain.utils.hparams.choice
    value: !ref <input>
    choices:
        text: !ref <text_num_tokens>
        phonemes: !ref <phn_num_tokens>
# Used both as the model's audio_tokens_per_step and as token_model.codebooks.
audio_tokens_per_step: 2
# NOTE(review): bandwidth is not referenced by any object in this file --
# confirm whether it is still needed.
bandwidth: 1.5
attention_type: regularMHA

############################## models ################################

# Tokotron transformer model, configured entirely from the hyperparameters
# defined in this file.
model: !new:Tokotron.TokotronTransformerModel  # yamllint disable-line rule:line-length
    input_num_tokens: !ref <input_num_tokens>
    audio_num_tokens: !ref <audio_num_tokens>
    audio_tokens_per_step: !ref <audio_tokens_per_step>
    d_model: !ref <d_model>
    d_ffn: !ref <d_ffn>
    nhead: !ref <nhead>
    enc_num_layers: !ref <enc_num_layers>
    dec_num_layers: !ref <dec_num_layers>
    dropout: !ref <transformer_dropout>
    target_dropout: !ref <target_dropout>
    activation: !ref <activation>
    attention_type: !ref <attention_type>
    gate_threshold: !ref <gate_threshold>
    gate_offset: !ref <gate_offset>
    audio_emb_size: !ref <audio_emb_size>
    audio_emb_freeze: !ref <audio_emb_freeze>
    max_audio_length: !ref <max_audio_length>
    infer_max_audio_length: !ref <infer_max_audio_length>

# Named modules: the trainable model and the pretrained token model.
modules:
    model: !ref <model>
    token_model: !ref <token_model>

# Optimizer constructor: torch.optim.Adam with the learning rate preset to <lr>.
# Only this one optimizer is defined in this file.
opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

# Loss object, configured with the guided attention and gate settings plus
# the silence padding value derived from gate_offset.
compute_cost: !new:Tokotron.TokotronLoss
    guided_attention_weight: !ref <guided_attention_weight>
    guided_attention_sigma: !ref <guided_attention_sigma>
    gate_weight: !ref <gate_loss_weight>
    gate_beta: !ref <gate_loss_beta>
    gate_gamma: !ref <gate_loss_gamma>
    gate_max_weight: !ref <gate_loss_max_weight>
    silence_padding: !ref <silence_padding>

# Noam learning rate scheduler starting from <lr> with <lr_warmup_steps>
# warmup steps.
lr_annealing: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr>
    n_warmup_steps: !ref <lr_warmup_steps>

# Checkpointer writing to <save_folder>. The model, the LR scheduler and the
# epoch counter are registered as recoverables.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        lr_scheduler: !ref <lr_annealing>
        counter: !ref <epoch_counter>

# Freezer from the recipe's `preparation` module, pointed at the prepared data
# folder and at <prepare_archive_path> (null unless overridden).
freezer: !new:preparation.Freezer
    save_path: !ref <prepare_save_folder>
    archive_path: !ref <prepare_archive_path>

# Epoch counter limited to <number_of_epochs>; also registered with the
# checkpointer as a recoverable.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# File-based training logger writing to <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>
