# ########################################
# Recipe for training an emotion recognition system from speech data
# only using IEMOCAP and an SSL feature extractor
# The system classifies 4 emotions (anger, happiness, sadness, neutrality)
# with an ECAPA-TDNN model.
# ########################################

# Random seed. It needs to be set at the top of the yaml, before objects with
# parameters are made. <seed> is also part of <output_folder> below.
seed: 1986
# Calls torch.manual_seed(<seed>) while this file is being loaded.
__set_seed: !apply:torch.manual_seed [!ref <seed>]

# Root folder of the IEMOCAP dataset. This is a placeholder: a value has to be
# supplied when the recipe is run.
data_folder: !PLACEHOLDER # e.g., /path/to/IEMOCAP_full_release
# Base folder for everything this file writes; it ends with the seed value.
output_folder: !ref results/IEMOCAP/IEMOCAP_full_release/weighted_ssl/<seed>
# Directory handed to the checkpointer (checkpoints_dir).
save_folder: !ref <output_folder>/save
# File handed to the train logger (save_file).
train_log: !ref <output_folder>/train_log.txt

# HuggingFace hub name of the SSL encoder model; you can change it to
# benchmark different models.
# Important: we use the base SSL encoder and not the one fine-tuned on an ASR
# task. This allows for a ~4% improvement.

ssl_hub: microsoft/wavlm-large
# Folder passed to the SSL model as save_path.
ssl_folder: !ref <output_folder>/ssl_checkpoints
# Input size of the embedding model. Presumably it has to equal the feature
# size of the encoder selected by ssl_hub - confirm when changing the model.
encoder_dim: 1024

# Use different speakers for the train, valid and test sets
different_speakers: True
# Speaker used for the test set, value from 1 to 10.
# For a complete evaluation, run the recipe once for each of the 10 values
# and take the mean of the results.
test_spk_id: 1

# Paths where the data manifest files (JSON, one per split) will be stored
train_annotation: !ref <output_folder>/train.json
valid_annotation: !ref <output_folder>/valid.json
test_annotation: !ref <output_folder>/test.json
# Presumably lets the training script skip the data preparation step when set
# to True - the consumer is not part of this file, confirm there.
skip_prep: False

# The train logger writes training statistics to a file, as well as stdout.
# The file it writes to is <train_log>.
train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
    save_file: !ref <train_log>

ckpt_interval_minutes: 15 # save checkpoint every N min

# Training parameters
# precision: presumably the numeric precision used during training - it is
# not referenced anywhere else in this file, confirm in the training script.
precision: fp32
# Upper limit given to epoch_counter.
number_of_epochs: 30
# Batch size of the train and valid data loaders.
batch_size: 2
# Batch size of the test data loader.
test_batch_size: 1

# Learning rate of model_opt_class and initial value of lr_annealing_model.
lr: 0.0002
# Learning rate of weights_opt_class and initial value of lr_annealing_weights.
lr_weights: 0.01

# Number of emotions (output size of the classifier)
out_n_neurons: 4 # (anger, happiness, sadness, neutral)

# Options for the three data loaders (train / valid / test).
# Training data is shuffled and batched with <batch_size>; drop_last is off.
train_dataloader_opts:
    num_workers: 2  # 2 on linux but 0 works on windows
    shuffle: True
    drop_last: False
    batch_size: !ref <batch_size>

# Validation reuses the training batch size.
valid_dataloader_opts:
    batch_size: !ref <batch_size>

# Testing has its own batch size, <test_batch_size>.
test_dataloader_opts:
    batch_size: !ref <test_batch_size>

# SSL feature extractor. It is built from the hub model named by <ssl_hub>
# and receives <ssl_folder> as its save_path.
weighted_ssl_model: !new:speechbrain.lobes.models.huggingface_transformers.WeightedSSLModel # yamllint disable-line rule:line-length
    hub: !ref <ssl_hub>
    save_path: !ref <ssl_folder>

# Size of the embedding exchanged between the two models below. It is the
# output size of the ECAPA-TDNN (lin_neurons) and the input size of the
# classifier (input_size); both read it from here so they cannot drift apart.
emb_dim: 192

# ECAPA-TDNN embedding model; its input size is <encoder_dim>.
embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
    input_size: !ref <encoder_dim>
    channels: [1024, 1024, 1024, 1024, 3072]
    kernel_sizes: [5, 3, 3, 3, 1]
    dilations: [1, 2, 3, 4, 1]
    attention_channels: 64
    lin_neurons: !ref <emb_dim>

# Classification head with one output per emotion (<out_n_neurons>).
classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
    input_size: !ref <emb_dim>
    out_neurons: !ref <out_n_neurons>


# Epoch counter limited to <number_of_epochs>. It is also registered with the
# checkpointer below (recoverable `counter`).
epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
    limit: !ref <number_of_epochs>

# Name -> module mapping of the three networks defined above. Presumably this
# is what the training script hands to its Brain class - confirm there.
modules:
    classifier: !ref <classifier>
    embedding_model: !ref <embedding_model>
    weighted_ssl_model: !ref <weighted_ssl_model>

# Embedding model and classifier grouped into one ModuleList, which the
# checkpointer saves as `model`. The SSL model is not part of this list; the
# checkpointer saves it separately as `ssl_model`.
model: !new:torch.nn.ModuleList
    - [!ref <embedding_model>, !ref <classifier>]

# Softmax activation configured with apply_log: True.
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# Loss: a reference to nll_loss (not called here; the caller supplies the
# arguments).
compute_cost: !name:speechbrain.nnet.losses.nll_loss

# Factory for a MetricStats tracker whose metric is classification_error
# with reduction: batch.
error_stats: !name:speechbrain.utils.metric_stats.MetricStats
    metric: !name:speechbrain.nnet.losses.classification_error
        reduction: batch

# Two Adam optimizer factories with separate learning rates. They are only
# partially configured here (lr); the parameters each one optimizes are
# chosen by the training script - presumably the model for the first and the
# SSL model for the second, confirm there.
model_opt_class: !name:torch.optim.Adam
    lr: !ref <lr>

weights_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_weights>

# Learning-rate schedulers, one per optimizer. Both use the same NewBob
# settings and differ only in their initial value. `patient: 0` (also the
# scheduler's default) is spelled out on both so the definitions stay in sync.

# Starts from <lr>, the learning rate given to model_opt_class.
lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0

# Starts from <lr_weights>, the learning rate given to weights_opt_class.
lr_annealing_weights: !new:speechbrain.nnet.schedulers.NewBobScheduler
    initial_value: !ref <lr_weights>
    improvement_threshold: 0.0025
    annealing_factor: 0.9
    patient: 0


# Checkpointer writing to <save_folder>. The keys under `recoverables` are the
# names each object is registered under; renaming one presumably breaks
# loading of checkpoints written earlier - confirm before changing.
checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
    checkpoints_dir: !ref <save_folder>
    recoverables:
        model: !ref <model>
        ssl_model: !ref <weighted_ssl_model>
        scheduler_model: !ref <lr_annealing_model>
        # Despite its name, this entry holds the lr_annealing_weights scheduler.
        scheduler_encoder: !ref <lr_annealing_weights>
        counter: !ref <epoch_counter>
