################################
# Recipe for Training K-Means Clustering on LJSpeech Data
# Using Self-Supervised Model-Based Representations
#
# It is used for creating discrete audio representations from LJSpeech data.
#
################################
# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
__set_seed: !apply:torch.manual_seed [!ref <seed>]
# Results are grouped by SSL model type and seed, so that switching
# `ssl_model_type` (e.g. to wavlm or wav2vec2) does not write into, or reuse
# checkpoints from, a directory created for a different model.
# With the default `ssl_model_type: hubert` the path is unchanged.
output_folder: !ref results/LJSpeech/clustering/<ssl_model_type>/<seed>
# Everything persisted by the recipe (prepared manifests, SSL checkpoint) lives here.
save_folder: !ref <output_folder>/save

# Data files
# Root directory of the LJSpeech corpus; a placeholder that must be supplied by the user.
data_folder: !PLACEHOLDER # e.g., /path/to/LJSpeech-1.1

# JSON file for the training split, written under <save_folder>.
train_json: !ref <save_folder>/train.json

# Data preparation options.
# NOTE(review): presumably consumed by the LJSpeech preparation script — confirm
# how `split_ratio` is interpreted when only one split is requested.
splits: ["train"]
split_ratio: [80]
skip_prep: False
sample_rate: 16000

# Self-supervised (SSL) model settings.
# Supported values for ssl_model_type: hubert, wavlm, wav2vec2
# Matching values for ssl_hub: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large
# `ssl_model_type` and `ssl_hub` should be changed together so that the hub
# checkpoint corresponds to the selected model class.
ssl_model_type: hubert # key looked up in the `ssl_model` choices below
ssl_hub: facebook/hubert-large-ll60k
ssl_folder: !ref <save_folder>/ssl_checkpoint
freeze_feature_extractor: True
freeze_ssl: True
# NOTE(review): presumably the index of the hidden layer whose features are clustered — confirm in the training script.
ssl_layer_num: 7
batch_size: 128 # batch_size for loading and extracting features. It is different from kmeans_batch_size.
# NOTE(review): unit (batches/steps) is not visible here — confirm in the training script.
checkpoint_interval: 100


# Dataloader options
# Options for the training dataloader; uses the feature-extraction
# `batch_size` defined above (not `kmeans_batch_size`).
train_dataloader_opts:
    batch_size: !ref <batch_size>
    # With drop_last enabled, a final batch smaller than `batch_size` is discarded.
    drop_last: True

# SSL encoder used for feature extraction. The entry under `choices` whose key
# equals `ssl_model_type` is selected. All three candidates are configured
# identically: weights come from `ssl_hub`, are cached in `ssl_folder`, and
# freezing is controlled by `freeze_ssl` / `freeze_feature_extractor`.
# NOTE(review): each `!new:` entry may be instantiated at load time, not only
# the selected one — confirm against HyperPyYAML semantics.
ssl_model: !apply:speechbrain.utils.hparams.choice
    value: !ref <ssl_model_type>
    choices:
        wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
            source: !ref <ssl_hub>
            output_norm: False
            freeze: !ref <freeze_ssl>
            freeze_feature_extractor: !ref <freeze_feature_extractor>
            # NOTE(review): presumably required so that layer `ssl_layer_num` can be picked — confirm.
            output_all_hiddens: True
            save_path: !ref <ssl_folder>
        hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
            source: !ref <ssl_hub>
            output_norm: False
            freeze: !ref <freeze_ssl>
            freeze_feature_extractor: !ref <freeze_feature_extractor>
            output_all_hiddens: True
            save_path: !ref <ssl_folder>
        wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
            source: !ref <ssl_hub>
            output_norm: False
            freeze: !ref <freeze_ssl>
            freeze_feature_extractor: !ref <freeze_feature_extractor>
            output_all_hiddens: True
            save_path: !ref <ssl_folder>

####################
# Model Parameters #
####################
# K-means clustering hyperparameters.
# NOTE(review): the names mirror the arguments of scikit-learn's
# MiniBatchKMeans; presumably they are forwarded to it — confirm in the
# training script.
num_clusters: 128
init: k-means++
max_iter: 100
kmeans_batch_size: 1000 # should be >= num_clusters
tol: 0.0
max_no_improvement: 100
n_init: 20
reassignment_ratio: 0.0
