# ########################################
# Recipe for training an emotion recognition system from speech data
# only, using IEMOCAP and a discrete SSL feature extractor.
# The system classifies 4 emotions (anger, happiness, sadness, neutrality)
# with an attention MLP and a linear output layer over discrete SSL token embeddings.
# ########################################

# Seed needs to be set at top of yaml, before objects with parameters are made
seed: 1986
# `!apply` calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Root folder of the IEMOCAP dataset. `!PLACEHOLDER` has no default value, so
# this entry must be supplied by the user.
data_folder: !PLACEHOLDER # e.g., /path/to/IEMOCAP_full_release
# All outputs live under a seed-specific folder; save_folder and train_log
# are derived from it.
output_folder: !ref results/IEMOCAP/IEMOCAP_full_release/linear/discrete_ssl/<seed>
save_folder: !ref <output_folder>/save
train_log: !ref <output_folder>/train_log.txt

# Identifier of the SSL encoder model; you can change it to benchmark different models.
# Important: we use the base SSL encoder and not the one fine-tuned on the ASR task.
# This allows you to gain a ~4% improvement.

### Configuration for the discrete SSL model
# Possible ssl_model_type values: hubert, wavlm, wav2vec2
# Example ssl_hub values: facebook/hubert-large-ll60k, microsoft/wavlm-large, facebook/wav2vec2-large
# ssl_model_type is the key used to pick an entry from the `choices` of
# `ssl_model`; ssl_hub is passed to that entry as `source`.
ssl_model_type: hubert # hubert, wavlm or wav2vec2
ssl_hub: facebook/hubert-large-ll60k
# Passed to the SSL model as `save_path`.
ssl_folder: !ref <save_folder>/ssl_checkpoint
# The three kmeans_* entries are forwarded to the `codec` object
# (kmeans_cache_dir as its `save_path`).
kmeans_repo_id: speechbrain/SSL_Quantization
kmeans_cache_dir: !ref <save_folder>/kmeans_checkpoint
kmeans_dataset: LibriSpeech-100-360-500
# Forwarded to the SSL model as `freeze` and `freeze_feature_extractor`.
freeze_ssl: True
freeze_feature_extractor: True
# Used both as the codec's `num_clusters` and as the embedding layer's
# `vocab_size`.
num_clusters: 1000


### Config for Tokenizer
# Layer numbers should be among the layers supported by the discrete SSL model
# (a k-means model should be available for each selected layer).
# Example with four layers:
# ssl_layer_num: [3, 7, 12, 23]
# deduplicate: [False, False, False, False]
# bpe_tokenizer_path: [null, null, null, null]
ssl_layer_num: [1, 3, 7, 12, 18, 23]
# NOTE(review): presumably must equal the number of entries in ssl_layer_num
# (6 here) — confirm against custom_model.Discrete_EmbeddingLayer.
num_codebooks: 6
# The two lists below have one entry per entry of ssl_layer_num in this
# configuration.
deduplicate: [False, False, False, False, False, False]
bpe_tokenizer_path: [null, null, null, null, null, null]
# NOTE(review): not referenced elsewhere in this file; presumably the audio
# sampling rate in Hz read by the training script — confirm.
sample_rate: 16000

# Feature parameters

# Shared width: used as `emb_dim` of the discrete embedding layer, as
# `input_dim`/`hidden_dim` of the attention MLP and as `input_size` of the
# output layer.
encoder_dim: 1024

# different speakers for train, valid and test sets
different_speakers: True
# which speaker is used for test set, value from 1 to 10
# For a full evaluation, run once with each of the 10 values and average the
# results.
test_spk_id: 1

# Paths where the data manifest (JSON) files will be stored; all three live
# directly under <output_folder>.
train_annotation: !ref <output_folder>/train.json
valid_annotation: !ref <output_folder>/valid.json
test_annotation: !ref <output_folder>/test.json
# NOTE(review): presumably set to True to skip the data preparation step when
# the manifests already exist — confirm against the training script.
skip_prep: False

# The train logger writes training statistics to a file, as well as stdout.
# `!new:` instantiates FileTrainLogger with <train_log> as its save_file.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

ckpt_interval_minutes: 15 # save checkpoint every N min

# Training parameters
precision: fp32
# Used as the `limit` of the epoch counter.
number_of_epochs: 30
# batch_size feeds the train and valid dataloaders; test_batch_size feeds the
# test dataloader.
batch_size: 2
test_batch_size: 1

# Learning rate: given to the Adam optimizer and used as the scheduler's
# initial value.
lr: 0.0002

# Number of emotions; this is the output size of the final linear layer.
out_n_neurons: 4 # (anger, happiness, sadness, neutral)

# Dataloader options
# Training loader: shuffled, keeps the last incomplete batch (drop_last: False).
train_dataloader_opts:
    batch_size: !ref <batch_size>
    shuffle: True
    num_workers: 2  # 2 on linux but 0 works on windows
    drop_last: False

# Validation reuses the training batch size; no other option is overridden.
valid_dataloader_opts:
    batch_size: !ref <batch_size>

# Testing uses its own batch size (<test_batch_size>).
test_dataloader_opts:
    batch_size: !ref <test_batch_size>

# Modules
# Plain mapping that bundles the per-layer tokenizer settings defined earlier
# in this file under the key names SSL_layers / deduplicates / bpe_tokenizers.
# NOTE(review): presumably unpacked by the training script when calling the
# codec — confirm.
tokenizer_config:
    SSL_layers: !ref <ssl_layer_num>
    deduplicates: !ref <deduplicate>
    bpe_tokenizers: !ref <bpe_tokenizer_path>

# SSL encoder. `!apply` calls speechbrain.utils.hparams.choice with
# value=<ssl_model_type> and the `choices` mapping below, whose keys are the
# supported ssl_model_type values. All three entries receive identical
# arguments and differ only in the wrapper class.
# NOTE(review): each `!new:` entry appears to be constructed when the file is
# loaded, not only the selected one, and all of them use the same <ssl_hub>
# source — confirm this is intended.
ssl_model: !apply:speechbrain.utils.hparams.choice
    value: !ref <ssl_model_type>
    choices:
        wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM
            source: !ref <ssl_hub>
            output_norm: False
            freeze: !ref <freeze_ssl>
            freeze_feature_extractor: !ref <freeze_feature_extractor>
            output_all_hiddens: True
            save_path: !ref <ssl_folder>
        hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
            source: !ref <ssl_hub>
            output_norm: False
            freeze: !ref <freeze_ssl>
            freeze_feature_extractor: !ref <freeze_feature_extractor>
            output_all_hiddens: True
            save_path: !ref <ssl_folder>
        wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
            source: !ref <ssl_hub>
            output_norm: False
            freeze: !ref <freeze_ssl>
            freeze_feature_extractor: !ref <freeze_feature_extractor>
            output_all_hiddens: True
            save_path: !ref <ssl_folder>

# Discrete SSL tokenizer: wraps the selected <ssl_model> together with the
# k-means settings (repo id, dataset name, number of clusters). Its save_path
# is <kmeans_cache_dir>.
codec: !new:speechbrain.lobes.models.huggingface_transformers.discrete_ssl.DiscreteSSL
    save_path: !ref <kmeans_cache_dir>
    ssl_model: !ref <ssl_model>
    kmeans_dataset: !ref <kmeans_dataset>
    kmeans_repo_id: !ref <kmeans_repo_id>
    num_clusters: !ref <num_clusters>

# Embedding layer from the recipe-local `custom_model` module, configured with
# <num_codebooks> codebooks, a vocabulary of <num_clusters> entries and an
# embedding size of <encoder_dim>.
discrete_embedding_layer: !new:custom_model.Discrete_EmbeddingLayer
    num_codebooks: !ref <num_codebooks>
    vocab_size: !ref <num_clusters>
    emb_dim: !ref <encoder_dim>

# Attention MLP from the recipe-local `custom_model` module; input and hidden
# sizes are both <encoder_dim>.
# NOTE(review): presumably weights the per-layer embeddings — confirm against
# custom_model.AttentionMLP.
attention_mlp: !new:custom_model.AttentionMLP
    input_dim: !ref <encoder_dim>
    hidden_dim: !ref <encoder_dim>

# Statistics pooling with the standard-deviation output disabled
# (return_std: False) — presumably leaving mean pooling only, as the key name
# suggests; confirm.
avg_pool: !new:speechbrain.nnet.pooling.StatisticsPooling
    return_std: False

# Final classifier: a bias-free linear layer mapping <encoder_dim> features to
# <out_n_neurons> emotion classes.
output_mlp: !new:speechbrain.nnet.linear.Linear
    input_size: !ref <encoder_dim>
    n_neurons: !ref <out_n_neurons>
    bias: False

# Epoch counter limited to <number_of_epochs>; it is also registered with the
# checkpointer under the name `counter`.
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Name -> module mapping for the objects defined above.
# NOTE(review): presumably handed to the SpeechBrain Brain class as its
# `modules` argument by the training script — confirm. `avg_pool` and
# `log_softmax` are not listed here.
modules:
    output_mlp: !ref <output_mlp>
    attention_mlp: !ref <attention_mlp>
    codec: !ref <codec>
    discrete_embedding_layer: !ref <discrete_embedding_layer>

# ModuleList grouping the output layer, the embedding layer and the attention
# MLP (the codec is not included). The leading `-` makes the inner list the
# single positional argument of torch.nn.ModuleList.
model: !new:torch.nn.ModuleList
    - [!ref <output_mlp>, !ref <discrete_embedding_layer>, !ref <attention_mlp>]

# Softmax activation configured to return log-probabilities (apply_log: True).
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# `!name:` stores a reference to the function instead of calling it here.
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Factory (not an instance) for a MetricStats tracker whose metric is
# classification_error with `reduction: batch` pre-bound.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Optimizer factory: torch.optim.Adam with lr pre-bound to <lr>. `!name:` does
# not instantiate it here; the parameters to optimize are supplied later.
model_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>


# NewBob learning-rate scheduler starting from <lr>, with an annealing factor
# of 0.9, an improvement threshold of 0.0025 and `patient: 0`.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

# Checkpointer writing to <save_folder>; `recoverables` maps a name to each
# object registered with it.
# NOTE(review): attention_mlp and discrete_embedding_layer are already members
# of <model> (the ModuleList) and are registered again here under their own
# names — confirm the duplication is intended.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        scheduler_model: !ref <lr_annealing_model>
        attention_mlp: !ref <attention_mlp>
        codec: !ref <codec>
        discrete_embedding_layer: !ref <discrete_embedding_layer>
        counter: !ref <epoch_counter>
